xref: /aosp_15_r20/external/libaom/av1/encoder/sparse_linear_solver.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1*77c1e3ccSAndroid Build Coastguard Worker /*
2*77c1e3ccSAndroid Build Coastguard Worker  * Copyright (c) 2021, Alliance for Open Media. All rights reserved.
3*77c1e3ccSAndroid Build Coastguard Worker  *
4*77c1e3ccSAndroid Build Coastguard Worker  * This source code is subject to the terms of the BSD 2 Clause License and
5*77c1e3ccSAndroid Build Coastguard Worker  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6*77c1e3ccSAndroid Build Coastguard Worker  * was not distributed with this source code in the LICENSE file, you can
7*77c1e3ccSAndroid Build Coastguard Worker  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8*77c1e3ccSAndroid Build Coastguard Worker  * Media Patent License 1.0 was not distributed with this source code in the
9*77c1e3ccSAndroid Build Coastguard Worker  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10*77c1e3ccSAndroid Build Coastguard Worker  */
11*77c1e3ccSAndroid Build Coastguard Worker #include "av1/common/av1_common_int.h"
12*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/sparse_linear_solver.h"
13*77c1e3ccSAndroid Build Coastguard Worker #include "config/aom_config.h"
14*77c1e3ccSAndroid Build Coastguard Worker #include "aom_mem/aom_mem.h"
15*77c1e3ccSAndroid Build Coastguard Worker #include "av1/common/alloccommon.h"
16*77c1e3ccSAndroid Build Coastguard Worker 
17*77c1e3ccSAndroid Build Coastguard Worker #if CONFIG_OPTICAL_FLOW_API
18*77c1e3ccSAndroid Build Coastguard Worker /*
19*77c1e3ccSAndroid Build Coastguard Worker  * Input:
20*77c1e3ccSAndroid Build Coastguard Worker  * rows: array of row positions
21*77c1e3ccSAndroid Build Coastguard Worker  * cols: array of column positions
22*77c1e3ccSAndroid Build Coastguard Worker  * values: array of element values
23*77c1e3ccSAndroid Build Coastguard Worker  * num_elem: total number of elements in the matrix
24*77c1e3ccSAndroid Build Coastguard Worker  * num_rows: number of rows in the matrix
25*77c1e3ccSAndroid Build Coastguard Worker  * num_cols: number of columns in the matrix
26*77c1e3ccSAndroid Build Coastguard Worker  *
27*77c1e3ccSAndroid Build Coastguard Worker  * Output:
28*77c1e3ccSAndroid Build Coastguard Worker  * sm: pointer to the sparse matrix to be initialized
29*77c1e3ccSAndroid Build Coastguard Worker  *
30*77c1e3ccSAndroid Build Coastguard Worker  * Return: 0  - success
31*77c1e3ccSAndroid Build Coastguard Worker  *         -1 - failed
32*77c1e3ccSAndroid Build Coastguard Worker  */
av1_init_sparse_mtx(const int * rows,const int * cols,const double * values,int num_elem,int num_rows,int num_cols,SPARSE_MTX * sm)33*77c1e3ccSAndroid Build Coastguard Worker int av1_init_sparse_mtx(const int *rows, const int *cols, const double *values,
34*77c1e3ccSAndroid Build Coastguard Worker                         int num_elem, int num_rows, int num_cols,
35*77c1e3ccSAndroid Build Coastguard Worker                         SPARSE_MTX *sm) {
36*77c1e3ccSAndroid Build Coastguard Worker   sm->n_elem = num_elem;
37*77c1e3ccSAndroid Build Coastguard Worker   sm->n_rows = num_rows;
38*77c1e3ccSAndroid Build Coastguard Worker   sm->n_cols = num_cols;
39*77c1e3ccSAndroid Build Coastguard Worker   if (num_elem == 0) {
40*77c1e3ccSAndroid Build Coastguard Worker     sm->row_pos = NULL;
41*77c1e3ccSAndroid Build Coastguard Worker     sm->col_pos = NULL;
42*77c1e3ccSAndroid Build Coastguard Worker     sm->value = NULL;
43*77c1e3ccSAndroid Build Coastguard Worker     return 0;
44*77c1e3ccSAndroid Build Coastguard Worker   }
45*77c1e3ccSAndroid Build Coastguard Worker   sm->row_pos = aom_calloc(num_elem, sizeof(*sm->row_pos));
46*77c1e3ccSAndroid Build Coastguard Worker   sm->col_pos = aom_calloc(num_elem, sizeof(*sm->col_pos));
47*77c1e3ccSAndroid Build Coastguard Worker   sm->value = aom_calloc(num_elem, sizeof(*sm->value));
48*77c1e3ccSAndroid Build Coastguard Worker 
49*77c1e3ccSAndroid Build Coastguard Worker   if (!sm->row_pos || !sm->col_pos || !sm->value) {
50*77c1e3ccSAndroid Build Coastguard Worker     av1_free_sparse_mtx_elems(sm);
51*77c1e3ccSAndroid Build Coastguard Worker     return -1;
52*77c1e3ccSAndroid Build Coastguard Worker   }
53*77c1e3ccSAndroid Build Coastguard Worker 
54*77c1e3ccSAndroid Build Coastguard Worker   memcpy(sm->row_pos, rows, num_elem * sizeof(*sm->row_pos));
55*77c1e3ccSAndroid Build Coastguard Worker   memcpy(sm->col_pos, cols, num_elem * sizeof(*sm->col_pos));
56*77c1e3ccSAndroid Build Coastguard Worker   memcpy(sm->value, values, num_elem * sizeof(*sm->value));
57*77c1e3ccSAndroid Build Coastguard Worker 
58*77c1e3ccSAndroid Build Coastguard Worker   return 0;
59*77c1e3ccSAndroid Build Coastguard Worker }
60*77c1e3ccSAndroid Build Coastguard Worker 
61*77c1e3ccSAndroid Build Coastguard Worker /*
62*77c1e3ccSAndroid Build Coastguard Worker  * Combines two sparse matrices (allocating new space).
63*77c1e3ccSAndroid Build Coastguard Worker  *
64*77c1e3ccSAndroid Build Coastguard Worker  * Input:
65*77c1e3ccSAndroid Build Coastguard Worker  * sm1, sm2: matrices to be combined
66*77c1e3ccSAndroid Build Coastguard Worker  * row_offset1, row_offset2: row offset of each matrix in the new matrix
67*77c1e3ccSAndroid Build Coastguard Worker  * col_offset1, col_offset2: column offset of each matrix in the new matrix
68*77c1e3ccSAndroid Build Coastguard Worker  * new_n_rows, new_n_cols: number of rows and columns in the new matrix
69*77c1e3ccSAndroid Build Coastguard Worker  *
70*77c1e3ccSAndroid Build Coastguard Worker  * Output:
71*77c1e3ccSAndroid Build Coastguard Worker  * sm: the combined matrix
72*77c1e3ccSAndroid Build Coastguard Worker  *
73*77c1e3ccSAndroid Build Coastguard Worker  * Return: 0  - success
74*77c1e3ccSAndroid Build Coastguard Worker  *         -1 - failed
75*77c1e3ccSAndroid Build Coastguard Worker  */
av1_init_combine_sparse_mtx(const SPARSE_MTX * sm1,const SPARSE_MTX * sm2,SPARSE_MTX * sm,int row_offset1,int col_offset1,int row_offset2,int col_offset2,int new_n_rows,int new_n_cols)76*77c1e3ccSAndroid Build Coastguard Worker int av1_init_combine_sparse_mtx(const SPARSE_MTX *sm1, const SPARSE_MTX *sm2,
77*77c1e3ccSAndroid Build Coastguard Worker                                 SPARSE_MTX *sm, int row_offset1,
78*77c1e3ccSAndroid Build Coastguard Worker                                 int col_offset1, int row_offset2,
79*77c1e3ccSAndroid Build Coastguard Worker                                 int col_offset2, int new_n_rows,
80*77c1e3ccSAndroid Build Coastguard Worker                                 int new_n_cols) {
81*77c1e3ccSAndroid Build Coastguard Worker   sm->n_elem = sm1->n_elem + sm2->n_elem;
82*77c1e3ccSAndroid Build Coastguard Worker   sm->n_cols = new_n_cols;
83*77c1e3ccSAndroid Build Coastguard Worker   sm->n_rows = new_n_rows;
84*77c1e3ccSAndroid Build Coastguard Worker 
85*77c1e3ccSAndroid Build Coastguard Worker   if (sm->n_elem == 0) {
86*77c1e3ccSAndroid Build Coastguard Worker     sm->row_pos = NULL;
87*77c1e3ccSAndroid Build Coastguard Worker     sm->col_pos = NULL;
88*77c1e3ccSAndroid Build Coastguard Worker     sm->value = NULL;
89*77c1e3ccSAndroid Build Coastguard Worker     return 0;
90*77c1e3ccSAndroid Build Coastguard Worker   }
91*77c1e3ccSAndroid Build Coastguard Worker 
92*77c1e3ccSAndroid Build Coastguard Worker   sm->row_pos = aom_calloc(sm->n_elem, sizeof(*sm->row_pos));
93*77c1e3ccSAndroid Build Coastguard Worker   sm->col_pos = aom_calloc(sm->n_elem, sizeof(*sm->col_pos));
94*77c1e3ccSAndroid Build Coastguard Worker   sm->value = aom_calloc(sm->n_elem, sizeof(*sm->value));
95*77c1e3ccSAndroid Build Coastguard Worker 
96*77c1e3ccSAndroid Build Coastguard Worker   if (!sm->row_pos || !sm->col_pos || !sm->value) {
97*77c1e3ccSAndroid Build Coastguard Worker     av1_free_sparse_mtx_elems(sm);
98*77c1e3ccSAndroid Build Coastguard Worker     return -1;
99*77c1e3ccSAndroid Build Coastguard Worker   }
100*77c1e3ccSAndroid Build Coastguard Worker 
101*77c1e3ccSAndroid Build Coastguard Worker   for (int i = 0; i < sm1->n_elem; i++) {
102*77c1e3ccSAndroid Build Coastguard Worker     sm->row_pos[i] = sm1->row_pos[i] + row_offset1;
103*77c1e3ccSAndroid Build Coastguard Worker     sm->col_pos[i] = sm1->col_pos[i] + col_offset1;
104*77c1e3ccSAndroid Build Coastguard Worker   }
105*77c1e3ccSAndroid Build Coastguard Worker   memcpy(sm->value, sm1->value, sm1->n_elem * sizeof(*sm1->value));
106*77c1e3ccSAndroid Build Coastguard Worker   int n_elem1 = sm1->n_elem;
107*77c1e3ccSAndroid Build Coastguard Worker   for (int i = 0; i < sm2->n_elem; i++) {
108*77c1e3ccSAndroid Build Coastguard Worker     sm->row_pos[n_elem1 + i] = sm2->row_pos[i] + row_offset2;
109*77c1e3ccSAndroid Build Coastguard Worker     sm->col_pos[n_elem1 + i] = sm2->col_pos[i] + col_offset2;
110*77c1e3ccSAndroid Build Coastguard Worker   }
111*77c1e3ccSAndroid Build Coastguard Worker   memcpy(sm->value + n_elem1, sm2->value, sm2->n_elem * sizeof(*sm2->value));
112*77c1e3ccSAndroid Build Coastguard Worker   return 0;
113*77c1e3ccSAndroid Build Coastguard Worker }
114*77c1e3ccSAndroid Build Coastguard Worker 
/*
 * Releases the element storage of a sparse matrix and resets it to an empty
 * 0 x 0 matrix with no elements.
 *
 * The element pointers are reset to NULL afterwards, so the struct never
 * holds dangling pointers. Because n_elem becomes 0, calling this again on
 * the same matrix frees nothing.
 */
void av1_free_sparse_mtx_elems(SPARSE_MTX *sm) {
  sm->n_cols = 0;
  sm->n_rows = 0;
  // The init functions in this file only allocate when n_elem != 0 (and set
  // the pointers to NULL otherwise).
  if (sm->n_elem != 0) {
    aom_free(sm->row_pos);
    aom_free(sm->col_pos);
    aom_free(sm->value);
  }
  sm->row_pos = NULL;
  sm->col_pos = NULL;
  sm->value = NULL;
  sm->n_elem = 0;
}
125*77c1e3ccSAndroid Build Coastguard Worker 
126*77c1e3ccSAndroid Build Coastguard Worker /*
127*77c1e3ccSAndroid Build Coastguard Worker  * Calculate matrix and vector multiplication: A*b
128*77c1e3ccSAndroid Build Coastguard Worker  *
129*77c1e3ccSAndroid Build Coastguard Worker  * Input:
130*77c1e3ccSAndroid Build Coastguard Worker  * sm: matrix A
131*77c1e3ccSAndroid Build Coastguard Worker  * srcv: the vector b to be multiplied to
132*77c1e3ccSAndroid Build Coastguard Worker  * dstl: the length of vectors
133*77c1e3ccSAndroid Build Coastguard Worker  *
134*77c1e3ccSAndroid Build Coastguard Worker  * Output:
135*77c1e3ccSAndroid Build Coastguard Worker  * dstv: pointer to the resulting vector
136*77c1e3ccSAndroid Build Coastguard Worker  */
av1_mtx_vect_multi_right(const SPARSE_MTX * sm,const double * srcv,double * dstv,int dstl)137*77c1e3ccSAndroid Build Coastguard Worker void av1_mtx_vect_multi_right(const SPARSE_MTX *sm, const double *srcv,
138*77c1e3ccSAndroid Build Coastguard Worker                               double *dstv, int dstl) {
139*77c1e3ccSAndroid Build Coastguard Worker   memset(dstv, 0, sizeof(*dstv) * dstl);
140*77c1e3ccSAndroid Build Coastguard Worker   for (int i = 0; i < sm->n_elem; i++) {
141*77c1e3ccSAndroid Build Coastguard Worker     dstv[sm->row_pos[i]] += srcv[sm->col_pos[i]] * sm->value[i];
142*77c1e3ccSAndroid Build Coastguard Worker   }
143*77c1e3ccSAndroid Build Coastguard Worker }
144*77c1e3ccSAndroid Build Coastguard Worker /*
145*77c1e3ccSAndroid Build Coastguard Worker  * Calculate matrix and vector multiplication: b*A
146*77c1e3ccSAndroid Build Coastguard Worker  *
147*77c1e3ccSAndroid Build Coastguard Worker  * Input:
148*77c1e3ccSAndroid Build Coastguard Worker  * sm: matrix A
149*77c1e3ccSAndroid Build Coastguard Worker  * srcv: the vector b to be multiplied to
150*77c1e3ccSAndroid Build Coastguard Worker  * dstl: the length of vectors
151*77c1e3ccSAndroid Build Coastguard Worker  *
152*77c1e3ccSAndroid Build Coastguard Worker  * Output:
153*77c1e3ccSAndroid Build Coastguard Worker  * dstv: pointer to the resulting vector
154*77c1e3ccSAndroid Build Coastguard Worker  */
av1_mtx_vect_multi_left(const SPARSE_MTX * sm,const double * srcv,double * dstv,int dstl)155*77c1e3ccSAndroid Build Coastguard Worker void av1_mtx_vect_multi_left(const SPARSE_MTX *sm, const double *srcv,
156*77c1e3ccSAndroid Build Coastguard Worker                              double *dstv, int dstl) {
157*77c1e3ccSAndroid Build Coastguard Worker   memset(dstv, 0, sizeof(*dstv) * dstl);
158*77c1e3ccSAndroid Build Coastguard Worker   for (int i = 0; i < sm->n_elem; i++) {
159*77c1e3ccSAndroid Build Coastguard Worker     dstv[sm->col_pos[i]] += srcv[sm->row_pos[i]] * sm->value[i];
160*77c1e3ccSAndroid Build Coastguard Worker   }
161*77c1e3ccSAndroid Build Coastguard Worker }
162*77c1e3ccSAndroid Build Coastguard Worker 
/*
 * Returns the inner (dot) product of two vectors.
 *
 * Input:
 * src1, src2: the vectors to be multiplied
 * src1l: number of entries to use from each vector
 *
 * Output:
 * the sum of src1[i] * src2[i] for i in [0, src1l), accumulated in index
 * order; 0 when src1l <= 0
 */
double av1_vect_vect_multi(const double *src1, int src1l, const double *src2) {
  double acc = 0;
  const double *lhs = src1;
  const double *rhs = src2;
  for (int remaining = src1l; remaining > 0; --remaining) {
    acc += (*lhs++) * (*rhs++);
  }
  return acc;
}
180*77c1e3ccSAndroid Build Coastguard Worker 
181*77c1e3ccSAndroid Build Coastguard Worker /*
182*77c1e3ccSAndroid Build Coastguard Worker  * Multiply each element in the matrix sm with a constant c
183*77c1e3ccSAndroid Build Coastguard Worker  */
av1_constant_multiply_sparse_matrix(SPARSE_MTX * sm,double c)184*77c1e3ccSAndroid Build Coastguard Worker void av1_constant_multiply_sparse_matrix(SPARSE_MTX *sm, double c) {
185*77c1e3ccSAndroid Build Coastguard Worker   for (int i = 0; i < sm->n_elem; i++) {
186*77c1e3ccSAndroid Build Coastguard Worker     sm->value[i] *= c;
187*77c1e3ccSAndroid Build Coastguard Worker   }
188*77c1e3ccSAndroid Build Coastguard Worker }
189*77c1e3ccSAndroid Build Coastguard Worker 
/*
 * Releases up to seven solver scratch buffers with aom_free(). Callers in
 * this file pass NULL for unused slots and for buffers whose allocation
 * failed, relying on aom_free() accepting NULL.
 */
static inline void free_solver_local_buf(double *buf1, double *buf2,
                                         double *buf3, double *buf4,
                                         double *buf5, double *buf6,
                                         double *buf7) {
  double *const bufs[] = { buf1, buf2, buf3, buf4, buf5, buf6, buf7 };
  for (size_t n = 0; n < sizeof(bufs) / sizeof(bufs[0]); ++n) {
    aom_free(bufs[n]);
  }
}
202*77c1e3ccSAndroid Build Coastguard Worker 
203*77c1e3ccSAndroid Build Coastguard Worker /*
204*77c1e3ccSAndroid Build Coastguard Worker  * Solve for Ax = b
205*77c1e3ccSAndroid Build Coastguard Worker  * no requirement on A
206*77c1e3ccSAndroid Build Coastguard Worker  *
207*77c1e3ccSAndroid Build Coastguard Worker  * Input:
208*77c1e3ccSAndroid Build Coastguard Worker  * A: the sparse matrix
209*77c1e3ccSAndroid Build Coastguard Worker  * b: the vector b
210*77c1e3ccSAndroid Build Coastguard Worker  * bl: length of b
211*77c1e3ccSAndroid Build Coastguard Worker  * x: the vector x
212*77c1e3ccSAndroid Build Coastguard Worker  *
213*77c1e3ccSAndroid Build Coastguard Worker  * Output:
214*77c1e3ccSAndroid Build Coastguard Worker  * x: pointer to the solution vector
215*77c1e3ccSAndroid Build Coastguard Worker  *
216*77c1e3ccSAndroid Build Coastguard Worker  * Return: 0  - success
217*77c1e3ccSAndroid Build Coastguard Worker  *         -1 - failed
218*77c1e3ccSAndroid Build Coastguard Worker  */
av1_bi_conjugate_gradient_sparse(const SPARSE_MTX * A,const double * b,int bl,double * x)219*77c1e3ccSAndroid Build Coastguard Worker int av1_bi_conjugate_gradient_sparse(const SPARSE_MTX *A, const double *b,
220*77c1e3ccSAndroid Build Coastguard Worker                                      int bl, double *x) {
221*77c1e3ccSAndroid Build Coastguard Worker   double *r = NULL, *r_hat = NULL, *p = NULL, *p_hat = NULL, *Ap = NULL,
222*77c1e3ccSAndroid Build Coastguard Worker          *p_hatA = NULL, *x_hat = NULL;
223*77c1e3ccSAndroid Build Coastguard Worker   double alpha, beta, rtr, r_norm_2;
224*77c1e3ccSAndroid Build Coastguard Worker   double denormtemp;
225*77c1e3ccSAndroid Build Coastguard Worker 
226*77c1e3ccSAndroid Build Coastguard Worker   // initialize
227*77c1e3ccSAndroid Build Coastguard Worker   r = aom_calloc(bl, sizeof(*r));
228*77c1e3ccSAndroid Build Coastguard Worker   r_hat = aom_calloc(bl, sizeof(*r_hat));
229*77c1e3ccSAndroid Build Coastguard Worker   p = aom_calloc(bl, sizeof(*p));
230*77c1e3ccSAndroid Build Coastguard Worker   p_hat = aom_calloc(bl, sizeof(*p_hat));
231*77c1e3ccSAndroid Build Coastguard Worker   Ap = aom_calloc(bl, sizeof(*Ap));
232*77c1e3ccSAndroid Build Coastguard Worker   p_hatA = aom_calloc(bl, sizeof(*p_hatA));
233*77c1e3ccSAndroid Build Coastguard Worker   x_hat = aom_calloc(bl, sizeof(*x_hat));
234*77c1e3ccSAndroid Build Coastguard Worker   if (!r || !r_hat || !p || !p_hat || !Ap || !p_hatA || !x_hat) {
235*77c1e3ccSAndroid Build Coastguard Worker     free_solver_local_buf(r, r_hat, p, p_hat, Ap, p_hatA, x_hat);
236*77c1e3ccSAndroid Build Coastguard Worker     return -1;
237*77c1e3ccSAndroid Build Coastguard Worker   }
238*77c1e3ccSAndroid Build Coastguard Worker 
239*77c1e3ccSAndroid Build Coastguard Worker   int i;
240*77c1e3ccSAndroid Build Coastguard Worker   for (i = 0; i < bl; i++) {
241*77c1e3ccSAndroid Build Coastguard Worker     r[i] = b[i];
242*77c1e3ccSAndroid Build Coastguard Worker     r_hat[i] = b[i];
243*77c1e3ccSAndroid Build Coastguard Worker     p[i] = r[i];
244*77c1e3ccSAndroid Build Coastguard Worker     p_hat[i] = r_hat[i];
245*77c1e3ccSAndroid Build Coastguard Worker     x[i] = 0;
246*77c1e3ccSAndroid Build Coastguard Worker     x_hat[i] = 0;
247*77c1e3ccSAndroid Build Coastguard Worker   }
248*77c1e3ccSAndroid Build Coastguard Worker   r_norm_2 = av1_vect_vect_multi(r_hat, bl, r);
249*77c1e3ccSAndroid Build Coastguard Worker   for (int k = 0; k < MAX_CG_SP_ITER; k++) {
250*77c1e3ccSAndroid Build Coastguard Worker     rtr = r_norm_2;
251*77c1e3ccSAndroid Build Coastguard Worker     av1_mtx_vect_multi_right(A, p, Ap, bl);
252*77c1e3ccSAndroid Build Coastguard Worker     av1_mtx_vect_multi_left(A, p_hat, p_hatA, bl);
253*77c1e3ccSAndroid Build Coastguard Worker 
254*77c1e3ccSAndroid Build Coastguard Worker     denormtemp = av1_vect_vect_multi(p_hat, bl, Ap);
255*77c1e3ccSAndroid Build Coastguard Worker     if (denormtemp < 1e-10) break;
256*77c1e3ccSAndroid Build Coastguard Worker     alpha = rtr / denormtemp;
257*77c1e3ccSAndroid Build Coastguard Worker     r_norm_2 = 0;
258*77c1e3ccSAndroid Build Coastguard Worker     for (i = 0; i < bl; i++) {
259*77c1e3ccSAndroid Build Coastguard Worker       x[i] += alpha * p[i];
260*77c1e3ccSAndroid Build Coastguard Worker       x_hat[i] += alpha * p_hat[i];
261*77c1e3ccSAndroid Build Coastguard Worker       r[i] -= alpha * Ap[i];
262*77c1e3ccSAndroid Build Coastguard Worker       r_hat[i] -= alpha * p_hatA[i];
263*77c1e3ccSAndroid Build Coastguard Worker       r_norm_2 += r_hat[i] * r[i];
264*77c1e3ccSAndroid Build Coastguard Worker     }
265*77c1e3ccSAndroid Build Coastguard Worker     if (sqrt(r_norm_2) < 1e-2) {
266*77c1e3ccSAndroid Build Coastguard Worker       break;
267*77c1e3ccSAndroid Build Coastguard Worker     }
268*77c1e3ccSAndroid Build Coastguard Worker     if (rtr < 1e-10) break;
269*77c1e3ccSAndroid Build Coastguard Worker     beta = r_norm_2 / rtr;
270*77c1e3ccSAndroid Build Coastguard Worker     for (i = 0; i < bl; i++) {
271*77c1e3ccSAndroid Build Coastguard Worker       p[i] = r[i] + beta * p[i];
272*77c1e3ccSAndroid Build Coastguard Worker       p_hat[i] = r_hat[i] + beta * p_hat[i];
273*77c1e3ccSAndroid Build Coastguard Worker     }
274*77c1e3ccSAndroid Build Coastguard Worker   }
275*77c1e3ccSAndroid Build Coastguard Worker   // free
276*77c1e3ccSAndroid Build Coastguard Worker   free_solver_local_buf(r, r_hat, p, p_hat, Ap, p_hatA, x_hat);
277*77c1e3ccSAndroid Build Coastguard Worker   return 0;
278*77c1e3ccSAndroid Build Coastguard Worker }
279*77c1e3ccSAndroid Build Coastguard Worker 
280*77c1e3ccSAndroid Build Coastguard Worker /*
281*77c1e3ccSAndroid Build Coastguard Worker  * Solve for Ax = b when A is symmetric and positive definite
282*77c1e3ccSAndroid Build Coastguard Worker  *
283*77c1e3ccSAndroid Build Coastguard Worker  * Input:
284*77c1e3ccSAndroid Build Coastguard Worker  * A: the sparse matrix
285*77c1e3ccSAndroid Build Coastguard Worker  * b: the vector b
286*77c1e3ccSAndroid Build Coastguard Worker  * bl: length of b
287*77c1e3ccSAndroid Build Coastguard Worker  * x: the vector x
288*77c1e3ccSAndroid Build Coastguard Worker  *
289*77c1e3ccSAndroid Build Coastguard Worker  * Output:
290*77c1e3ccSAndroid Build Coastguard Worker  * x: pointer to the solution vector
291*77c1e3ccSAndroid Build Coastguard Worker  *
292*77c1e3ccSAndroid Build Coastguard Worker  * Return: 0  - success
293*77c1e3ccSAndroid Build Coastguard Worker  *         -1 - failed
294*77c1e3ccSAndroid Build Coastguard Worker  */
av1_conjugate_gradient_sparse(const SPARSE_MTX * A,const double * b,int bl,double * x)295*77c1e3ccSAndroid Build Coastguard Worker int av1_conjugate_gradient_sparse(const SPARSE_MTX *A, const double *b, int bl,
296*77c1e3ccSAndroid Build Coastguard Worker                                   double *x) {
297*77c1e3ccSAndroid Build Coastguard Worker   double *r = NULL, *p = NULL, *Ap = NULL;
298*77c1e3ccSAndroid Build Coastguard Worker   double alpha, beta, rtr, r_norm_2;
299*77c1e3ccSAndroid Build Coastguard Worker   double denormtemp;
300*77c1e3ccSAndroid Build Coastguard Worker 
301*77c1e3ccSAndroid Build Coastguard Worker   // initialize
302*77c1e3ccSAndroid Build Coastguard Worker   r = aom_calloc(bl, sizeof(*r));
303*77c1e3ccSAndroid Build Coastguard Worker   p = aom_calloc(bl, sizeof(*p));
304*77c1e3ccSAndroid Build Coastguard Worker   Ap = aom_calloc(bl, sizeof(*Ap));
305*77c1e3ccSAndroid Build Coastguard Worker   if (!r || !p || !Ap) {
306*77c1e3ccSAndroid Build Coastguard Worker     free_solver_local_buf(r, p, Ap, NULL, NULL, NULL, NULL);
307*77c1e3ccSAndroid Build Coastguard Worker     return -1;
308*77c1e3ccSAndroid Build Coastguard Worker   }
309*77c1e3ccSAndroid Build Coastguard Worker 
310*77c1e3ccSAndroid Build Coastguard Worker   int i;
311*77c1e3ccSAndroid Build Coastguard Worker   for (i = 0; i < bl; i++) {
312*77c1e3ccSAndroid Build Coastguard Worker     r[i] = b[i];
313*77c1e3ccSAndroid Build Coastguard Worker     p[i] = r[i];
314*77c1e3ccSAndroid Build Coastguard Worker     x[i] = 0;
315*77c1e3ccSAndroid Build Coastguard Worker   }
316*77c1e3ccSAndroid Build Coastguard Worker   r_norm_2 = av1_vect_vect_multi(r, bl, r);
317*77c1e3ccSAndroid Build Coastguard Worker   int k;
318*77c1e3ccSAndroid Build Coastguard Worker   for (k = 0; k < MAX_CG_SP_ITER; k++) {
319*77c1e3ccSAndroid Build Coastguard Worker     rtr = r_norm_2;
320*77c1e3ccSAndroid Build Coastguard Worker     av1_mtx_vect_multi_right(A, p, Ap, bl);
321*77c1e3ccSAndroid Build Coastguard Worker     denormtemp = av1_vect_vect_multi(p, bl, Ap);
322*77c1e3ccSAndroid Build Coastguard Worker     if (denormtemp < 1e-10) break;
323*77c1e3ccSAndroid Build Coastguard Worker     alpha = rtr / denormtemp;
324*77c1e3ccSAndroid Build Coastguard Worker     r_norm_2 = 0;
325*77c1e3ccSAndroid Build Coastguard Worker     for (i = 0; i < bl; i++) {
326*77c1e3ccSAndroid Build Coastguard Worker       x[i] += alpha * p[i];
327*77c1e3ccSAndroid Build Coastguard Worker       r[i] -= alpha * Ap[i];
328*77c1e3ccSAndroid Build Coastguard Worker       r_norm_2 += r[i] * r[i];
329*77c1e3ccSAndroid Build Coastguard Worker     }
330*77c1e3ccSAndroid Build Coastguard Worker     if (r_norm_2 < 1e-8 * bl) break;
331*77c1e3ccSAndroid Build Coastguard Worker     if (rtr < 1e-10) break;
332*77c1e3ccSAndroid Build Coastguard Worker     beta = r_norm_2 / rtr;
333*77c1e3ccSAndroid Build Coastguard Worker     for (i = 0; i < bl; i++) {
334*77c1e3ccSAndroid Build Coastguard Worker       p[i] = r[i] + beta * p[i];
335*77c1e3ccSAndroid Build Coastguard Worker     }
336*77c1e3ccSAndroid Build Coastguard Worker   }
337*77c1e3ccSAndroid Build Coastguard Worker   // free
338*77c1e3ccSAndroid Build Coastguard Worker   free_solver_local_buf(r, p, Ap, NULL, NULL, NULL, NULL);
339*77c1e3ccSAndroid Build Coastguard Worker 
340*77c1e3ccSAndroid Build Coastguard Worker   return 0;
341*77c1e3ccSAndroid Build Coastguard Worker }
342*77c1e3ccSAndroid Build Coastguard Worker 
343*77c1e3ccSAndroid Build Coastguard Worker /*
344*77c1e3ccSAndroid Build Coastguard Worker  * Solve for Ax = b using Jacobi method
345*77c1e3ccSAndroid Build Coastguard Worker  *
346*77c1e3ccSAndroid Build Coastguard Worker  * Input:
347*77c1e3ccSAndroid Build Coastguard Worker  * A: the sparse matrix
348*77c1e3ccSAndroid Build Coastguard Worker  * b: the vector b
349*77c1e3ccSAndroid Build Coastguard Worker  * bl: length of b
350*77c1e3ccSAndroid Build Coastguard Worker  * x: the vector x
351*77c1e3ccSAndroid Build Coastguard Worker  *
352*77c1e3ccSAndroid Build Coastguard Worker  * Output:
353*77c1e3ccSAndroid Build Coastguard Worker  * x: pointer to the solution vector
354*77c1e3ccSAndroid Build Coastguard Worker  *
355*77c1e3ccSAndroid Build Coastguard Worker  * Return: 0  - success
356*77c1e3ccSAndroid Build Coastguard Worker  *         -1 - failed
357*77c1e3ccSAndroid Build Coastguard Worker  */
av1_jacobi_sparse(const SPARSE_MTX * A,const double * b,int bl,double * x)358*77c1e3ccSAndroid Build Coastguard Worker int av1_jacobi_sparse(const SPARSE_MTX *A, const double *b, int bl, double *x) {
359*77c1e3ccSAndroid Build Coastguard Worker   double *diags = NULL, *Rx = NULL, *x_last = NULL, *x_cur = NULL,
360*77c1e3ccSAndroid Build Coastguard Worker          *tempx = NULL;
361*77c1e3ccSAndroid Build Coastguard Worker   double resi2;
362*77c1e3ccSAndroid Build Coastguard Worker 
363*77c1e3ccSAndroid Build Coastguard Worker   diags = aom_calloc(bl, sizeof(*diags));
364*77c1e3ccSAndroid Build Coastguard Worker   Rx = aom_calloc(bl, sizeof(*Rx));
365*77c1e3ccSAndroid Build Coastguard Worker   x_last = aom_calloc(bl, sizeof(*x_last));
366*77c1e3ccSAndroid Build Coastguard Worker   x_cur = aom_calloc(bl, sizeof(*x_cur));
367*77c1e3ccSAndroid Build Coastguard Worker 
368*77c1e3ccSAndroid Build Coastguard Worker   if (!diags || !Rx || !x_last || !x_cur) {
369*77c1e3ccSAndroid Build Coastguard Worker     free_solver_local_buf(diags, Rx, x_last, x_cur, NULL, NULL, NULL);
370*77c1e3ccSAndroid Build Coastguard Worker     return -1;
371*77c1e3ccSAndroid Build Coastguard Worker   }
372*77c1e3ccSAndroid Build Coastguard Worker 
373*77c1e3ccSAndroid Build Coastguard Worker   int i;
374*77c1e3ccSAndroid Build Coastguard Worker   memset(x_last, 0, sizeof(*x_last) * bl);
375*77c1e3ccSAndroid Build Coastguard Worker   // get the diagonals of A
376*77c1e3ccSAndroid Build Coastguard Worker   memset(diags, 0, sizeof(*diags) * bl);
377*77c1e3ccSAndroid Build Coastguard Worker   for (int c = 0; c < A->n_elem; c++) {
378*77c1e3ccSAndroid Build Coastguard Worker     if (A->row_pos[c] != A->col_pos[c]) continue;
379*77c1e3ccSAndroid Build Coastguard Worker     diags[A->row_pos[c]] = A->value[c];
380*77c1e3ccSAndroid Build Coastguard Worker   }
381*77c1e3ccSAndroid Build Coastguard Worker   int k;
382*77c1e3ccSAndroid Build Coastguard Worker   for (k = 0; k < MAX_CG_SP_ITER; k++) {
383*77c1e3ccSAndroid Build Coastguard Worker     // R = A - diag(diags)
384*77c1e3ccSAndroid Build Coastguard Worker     // get R*x_last
385*77c1e3ccSAndroid Build Coastguard Worker     memset(Rx, 0, sizeof(*Rx) * bl);
386*77c1e3ccSAndroid Build Coastguard Worker     for (int c = 0; c < A->n_elem; c++) {
387*77c1e3ccSAndroid Build Coastguard Worker       if (A->row_pos[c] == A->col_pos[c]) continue;
388*77c1e3ccSAndroid Build Coastguard Worker       Rx[A->row_pos[c]] += x_last[A->col_pos[c]] * A->value[c];
389*77c1e3ccSAndroid Build Coastguard Worker     }
390*77c1e3ccSAndroid Build Coastguard Worker     resi2 = 0;
391*77c1e3ccSAndroid Build Coastguard Worker     for (i = 0; i < bl; i++) {
392*77c1e3ccSAndroid Build Coastguard Worker       x_cur[i] = (b[i] - Rx[i]) / diags[i];
393*77c1e3ccSAndroid Build Coastguard Worker       resi2 += (x_last[i] - x_cur[i]) * (x_last[i] - x_cur[i]);
394*77c1e3ccSAndroid Build Coastguard Worker     }
395*77c1e3ccSAndroid Build Coastguard Worker     if (resi2 <= 1e-10 * bl) break;
396*77c1e3ccSAndroid Build Coastguard Worker     // swap last & cur buffer ptrs
397*77c1e3ccSAndroid Build Coastguard Worker     tempx = x_last;
398*77c1e3ccSAndroid Build Coastguard Worker     x_last = x_cur;
399*77c1e3ccSAndroid Build Coastguard Worker     x_cur = tempx;
400*77c1e3ccSAndroid Build Coastguard Worker   }
401*77c1e3ccSAndroid Build Coastguard Worker   printf("\n numiter: %d\n", k);
402*77c1e3ccSAndroid Build Coastguard Worker   for (i = 0; i < bl; i++) {
403*77c1e3ccSAndroid Build Coastguard Worker     x[i] = x_cur[i];
404*77c1e3ccSAndroid Build Coastguard Worker   }
405*77c1e3ccSAndroid Build Coastguard Worker   free_solver_local_buf(diags, Rx, x_last, x_cur, NULL, NULL, NULL);
406*77c1e3ccSAndroid Build Coastguard Worker   return 0;
407*77c1e3ccSAndroid Build Coastguard Worker }
408*77c1e3ccSAndroid Build Coastguard Worker 
409*77c1e3ccSAndroid Build Coastguard Worker /*
410*77c1e3ccSAndroid Build Coastguard Worker  * Solve for Ax = b using Steepest descent method
411*77c1e3ccSAndroid Build Coastguard Worker  *
412*77c1e3ccSAndroid Build Coastguard Worker  * Input:
413*77c1e3ccSAndroid Build Coastguard Worker  * A: the sparse matrix
414*77c1e3ccSAndroid Build Coastguard Worker  * b: the vector b
415*77c1e3ccSAndroid Build Coastguard Worker  * bl: length of b
416*77c1e3ccSAndroid Build Coastguard Worker  * x: the vector x
417*77c1e3ccSAndroid Build Coastguard Worker  *
418*77c1e3ccSAndroid Build Coastguard Worker  * Output:
419*77c1e3ccSAndroid Build Coastguard Worker  * x: pointer to the solution vector
420*77c1e3ccSAndroid Build Coastguard Worker  *
421*77c1e3ccSAndroid Build Coastguard Worker  * Return: 0  - success
422*77c1e3ccSAndroid Build Coastguard Worker  *         -1 - failed
423*77c1e3ccSAndroid Build Coastguard Worker  */
av1_steepest_descent_sparse(const SPARSE_MTX * A,const double * b,int bl,double * x)424*77c1e3ccSAndroid Build Coastguard Worker int av1_steepest_descent_sparse(const SPARSE_MTX *A, const double *b, int bl,
425*77c1e3ccSAndroid Build Coastguard Worker                                 double *x) {
426*77c1e3ccSAndroid Build Coastguard Worker   double *d = NULL, *Ad = NULL, *Ax = NULL;
427*77c1e3ccSAndroid Build Coastguard Worker   double resi2, resi2_last, dAd, temp;
428*77c1e3ccSAndroid Build Coastguard Worker 
429*77c1e3ccSAndroid Build Coastguard Worker   d = aom_calloc(bl, sizeof(*d));
430*77c1e3ccSAndroid Build Coastguard Worker   Ax = aom_calloc(bl, sizeof(*Ax));
431*77c1e3ccSAndroid Build Coastguard Worker   Ad = aom_calloc(bl, sizeof(*Ad));
432*77c1e3ccSAndroid Build Coastguard Worker 
433*77c1e3ccSAndroid Build Coastguard Worker   if (!d || !Ax || !Ad) {
434*77c1e3ccSAndroid Build Coastguard Worker     free_solver_local_buf(d, Ax, Ad, NULL, NULL, NULL, NULL);
435*77c1e3ccSAndroid Build Coastguard Worker     return -1;
436*77c1e3ccSAndroid Build Coastguard Worker   }
437*77c1e3ccSAndroid Build Coastguard Worker 
438*77c1e3ccSAndroid Build Coastguard Worker   int i;
439*77c1e3ccSAndroid Build Coastguard Worker   // initialize with 0s
440*77c1e3ccSAndroid Build Coastguard Worker   resi2 = 0;
441*77c1e3ccSAndroid Build Coastguard Worker   for (i = 0; i < bl; i++) {
442*77c1e3ccSAndroid Build Coastguard Worker     x[i] = 0;
443*77c1e3ccSAndroid Build Coastguard Worker     d[i] = b[i];
444*77c1e3ccSAndroid Build Coastguard Worker     resi2 += d[i] * d[i] / bl;
445*77c1e3ccSAndroid Build Coastguard Worker   }
446*77c1e3ccSAndroid Build Coastguard Worker   int k;
447*77c1e3ccSAndroid Build Coastguard Worker   for (k = 0; k < MAX_CG_SP_ITER; k++) {
448*77c1e3ccSAndroid Build Coastguard Worker     // get A*x_last
449*77c1e3ccSAndroid Build Coastguard Worker     av1_mtx_vect_multi_right(A, d, Ad, bl);
450*77c1e3ccSAndroid Build Coastguard Worker     dAd = resi2 * bl / av1_vect_vect_multi(d, bl, Ad);
451*77c1e3ccSAndroid Build Coastguard Worker     for (i = 0; i < bl; i++) {
452*77c1e3ccSAndroid Build Coastguard Worker       temp = dAd * d[i];
453*77c1e3ccSAndroid Build Coastguard Worker       x[i] = x[i] + temp;
454*77c1e3ccSAndroid Build Coastguard Worker     }
455*77c1e3ccSAndroid Build Coastguard Worker     av1_mtx_vect_multi_right(A, x, Ax, bl);
456*77c1e3ccSAndroid Build Coastguard Worker     resi2_last = resi2;
457*77c1e3ccSAndroid Build Coastguard Worker     resi2 = 0;
458*77c1e3ccSAndroid Build Coastguard Worker     for (i = 0; i < bl; i++) {
459*77c1e3ccSAndroid Build Coastguard Worker       d[i] = b[i] - Ax[i];
460*77c1e3ccSAndroid Build Coastguard Worker       resi2 += d[i] * d[i] / bl;
461*77c1e3ccSAndroid Build Coastguard Worker     }
462*77c1e3ccSAndroid Build Coastguard Worker     if (resi2 <= 1e-8) break;
463*77c1e3ccSAndroid Build Coastguard Worker     if (resi2_last - resi2 < 1e-8) {
464*77c1e3ccSAndroid Build Coastguard Worker       break;
465*77c1e3ccSAndroid Build Coastguard Worker     }
466*77c1e3ccSAndroid Build Coastguard Worker   }
467*77c1e3ccSAndroid Build Coastguard Worker   free_solver_local_buf(d, Ax, Ad, NULL, NULL, NULL, NULL);
468*77c1e3ccSAndroid Build Coastguard Worker 
469*77c1e3ccSAndroid Build Coastguard Worker   return 0;
470*77c1e3ccSAndroid Build Coastguard Worker }
471*77c1e3ccSAndroid Build Coastguard Worker 
472*77c1e3ccSAndroid Build Coastguard Worker #endif  // CONFIG_OPTICAL_FLOW_API
473