xref: /aosp_15_r20/external/angle/src/common/matrix_utils.h (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2015 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Matrix:
7 //   Utility class implementing various matrix operations.
8 //   Supports matrices with minimum 2 and maximum 4 number of rows/columns.
9 //
10 // TODO: Check if we can merge Matrix.h in sample_util with this and replace it with this
11 // implementation.
12 // TODO: Rename this file to Matrix.h once we remove Matrix.h in sample_util.
13 
14 #ifndef COMMON_MATRIX_UTILS_H_
15 #define COMMON_MATRIX_UTILS_H_
16 
17 #include <array>
18 #include <vector>
19 
20 #include "common/debug.h"
21 #include "common/mathutil.h"
22 #include "common/vector_utils.h"
23 
24 namespace
25 {
// Writes the transposed cofactor matrix of the 4x4 matrix |mat| into |coft|:
// coft(i, j) is (-1)^(i + j) times the determinant of the 3x3 sub-matrix of |mat| that remains
// after removing row j and column i.
// T4x4 only has to provide operator()(row, column), readable on |mat| and assignable on |coft|.
template <typename T4x4>
void CofactorTransposed(const T4x4 &mat, T4x4 &coft)
{
    // Determinant of the 3x3 sub-matrix of |mat| made of rows (r0, r1, r2) and columns
    // (c0, c1, c2). |mat| is read at call time, one entry of |coft| at a time.
    const auto minor3 = [&mat](unsigned int r0, unsigned int r1, unsigned int r2, unsigned int c0,
                               unsigned int c1, unsigned int c2) {
        return mat(r0, c0) * mat(r1, c1) * mat(r2, c2) + mat(r1, c0) * mat(r2, c1) * mat(r0, c2) +
               mat(r2, c0) * mat(r0, c1) * mat(r1, c2) - mat(r0, c0) * mat(r2, c1) * mat(r1, c2) -
               mat(r1, c0) * mat(r0, c1) * mat(r2, c2) - mat(r2, c0) * mat(r1, c1) * mat(r0, c2);
    };

    // Minors taken from rows 1..3 of |mat| (row 0 removed).
    coft(0, 0) = minor3(1, 2, 3, 1, 2, 3);
    coft(1, 0) = -(minor3(1, 2, 3, 0, 2, 3));
    coft(2, 0) = minor3(1, 2, 3, 0, 1, 3);
    coft(3, 0) = -(minor3(1, 2, 3, 0, 1, 2));

    // Row 1 removed.
    coft(0, 1) = -(minor3(0, 2, 3, 1, 2, 3));
    coft(1, 1) = minor3(0, 2, 3, 0, 2, 3);
    coft(2, 1) = -(minor3(0, 2, 3, 0, 1, 3));
    coft(3, 1) = minor3(0, 2, 3, 0, 1, 2);

    // Row 2 removed.
    coft(0, 2) = minor3(0, 1, 3, 1, 2, 3);
    coft(1, 2) = -(minor3(0, 1, 3, 0, 2, 3));
    coft(2, 2) = minor3(0, 1, 3, 0, 1, 3);
    coft(3, 2) = -(minor3(0, 1, 3, 0, 1, 2));

    // Row 3 removed.
    coft(0, 3) = -(minor3(0, 1, 2, 1, 2, 3));
    coft(1, 3) = minor3(0, 1, 2, 0, 2, 3);
    coft(2, 3) = -(minor3(0, 1, 2, 0, 1, 3));
    coft(3, 3) = minor3(0, 1, 2, 0, 1, 2);
}
78 }  // namespace
79 
80 namespace angle
81 {
82 
83 template <typename T>
84 class Matrix
85 {
86   public:
Matrix(const std::vector<T> & elements,const unsigned int numRows,const unsigned int numCols)87     Matrix(const std::vector<T> &elements, const unsigned int numRows, const unsigned int numCols)
88         : mElements(elements), mRows(numRows), mCols(numCols)
89     {
90         ASSERT(rows() >= 1 && rows() <= 4);
91         ASSERT(columns() >= 1 && columns() <= 4);
92     }
93 
Matrix(const std::vector<T> & elements,const unsigned int size)94     Matrix(const std::vector<T> &elements, const unsigned int size)
95         : mElements(elements), mRows(size), mCols(size)
96     {
97         ASSERT(rows() >= 1 && rows() <= 4);
98         ASSERT(columns() >= 1 && columns() <= 4);
99     }
100 
Matrix(const T * elements,const unsigned int size)101     Matrix(const T *elements, const unsigned int size) : mRows(size), mCols(size)
102     {
103         ASSERT(rows() >= 1 && rows() <= 4);
104         ASSERT(columns() >= 1 && columns() <= 4);
105         for (size_t i = 0; i < size * size; i++)
106             mElements.push_back(elements[i]);
107     }
108 
operator()109     const T &operator()(const unsigned int rowIndex, const unsigned int columnIndex) const
110     {
111         ASSERT(rowIndex < mRows);
112         ASSERT(columnIndex < mCols);
113         return mElements[rowIndex * columns() + columnIndex];
114     }
115 
operator()116     T &operator()(const unsigned int rowIndex, const unsigned int columnIndex)
117     {
118         ASSERT(rowIndex < mRows);
119         ASSERT(columnIndex < mCols);
120         return mElements[rowIndex * columns() + columnIndex];
121     }
122 
at(const unsigned int rowIndex,const unsigned int columnIndex)123     const T &at(const unsigned int rowIndex, const unsigned int columnIndex) const
124     {
125         ASSERT(rowIndex < mRows);
126         ASSERT(columnIndex < mCols);
127         return operator()(rowIndex, columnIndex);
128     }
129 
130     Matrix<T> operator*(const Matrix<T> &m)
131     {
132         ASSERT(columns() == m.rows());
133 
134         unsigned int resultRows = rows();
135         unsigned int resultCols = m.columns();
136         Matrix<T> result(std::vector<T>(resultRows * resultCols), resultRows, resultCols);
137         for (unsigned int i = 0; i < resultRows; i++)
138         {
139             for (unsigned int j = 0; j < resultCols; j++)
140             {
141                 T tmp = 0.0f;
142                 for (unsigned int k = 0; k < columns(); k++)
143                     tmp += at(i, k) * m(k, j);
144                 result(i, j) = tmp;
145             }
146         }
147 
148         return result;
149     }
150 
151     void operator*=(const Matrix<T> &m)
152     {
153         ASSERT(columns() == m.rows());
154         Matrix<T> res  = (*this) * m;
155         size_t numElts = res.elements().size();
156         mElements.resize(numElts);
157         memcpy(mElements.data(), res.data(), numElts * sizeof(float));
158     }
159 
160     bool operator==(const Matrix<T> &m) const
161     {
162         ASSERT(columns() == m.columns());
163         ASSERT(rows() == m.rows());
164         return mElements == m.elements();
165     }
166 
167     bool operator!=(const Matrix<T> &m) const { return !(mElements == m.elements()); }
168 
nearlyEqual(T epsilon,const Matrix<T> & m)169     bool nearlyEqual(T epsilon, const Matrix<T> &m) const
170     {
171         ASSERT(columns() == m.columns());
172         ASSERT(rows() == m.rows());
173         const auto &otherElts = m.elements();
174         for (size_t i = 0; i < otherElts.size(); i++)
175         {
176             if ((mElements[i] - otherElts[i] > epsilon) && (otherElts[i] - mElements[i] > epsilon))
177                 return false;
178         }
179         return true;
180     }
181 
size()182     unsigned int size() const
183     {
184         ASSERT(rows() == columns());
185         return rows();
186     }
187 
rows()188     unsigned int rows() const { return mRows; }
189 
columns()190     unsigned int columns() const { return mCols; }
191 
elements()192     std::vector<T> elements() const { return mElements; }
data()193     T *data() { return mElements.data(); }
constData()194     const T *constData() const { return mElements.data(); }
195 
compMult(const Matrix<T> & mat1)196     Matrix<T> compMult(const Matrix<T> &mat1) const
197     {
198         Matrix result(std::vector<T>(mElements.size()), rows(), columns());
199         for (unsigned int i = 0; i < rows(); i++)
200         {
201             for (unsigned int j = 0; j < columns(); j++)
202             {
203                 T lhs        = at(i, j);
204                 T rhs        = mat1(i, j);
205                 result(i, j) = rhs * lhs;
206             }
207         }
208 
209         return result;
210     }
211 
outerProduct(const Matrix<T> & mat1)212     Matrix<T> outerProduct(const Matrix<T> &mat1) const
213     {
214         unsigned int cols = mat1.columns();
215         Matrix result(std::vector<T>(rows() * cols), rows(), cols);
216         for (unsigned int i = 0; i < rows(); i++)
217             for (unsigned int j = 0; j < cols; j++)
218                 result(i, j) = at(i, 0) * mat1(0, j);
219 
220         return result;
221     }
222 
transpose()223     Matrix<T> transpose() const
224     {
225         Matrix result(std::vector<T>(mElements.size()), columns(), rows());
226         for (unsigned int i = 0; i < columns(); i++)
227             for (unsigned int j = 0; j < rows(); j++)
228                 result(i, j) = at(j, i);
229 
230         return result;
231     }
232 
determinant()233     T determinant() const
234     {
235         ASSERT(rows() == columns());
236 
237         switch (size())
238         {
239             case 2:
240                 return at(0, 0) * at(1, 1) - at(0, 1) * at(1, 0);
241 
242             case 3:
243                 return at(0, 0) * at(1, 1) * at(2, 2) + at(0, 1) * at(1, 2) * at(2, 0) +
244                        at(0, 2) * at(1, 0) * at(2, 1) - at(0, 2) * at(1, 1) * at(2, 0) -
245                        at(0, 1) * at(1, 0) * at(2, 2) - at(0, 0) * at(1, 2) * at(2, 1);
246 
247             case 4:
248             {
249                 const float minorMatrices[4][3 * 3] = {{
250                                                            at(1, 1),
251                                                            at(2, 1),
252                                                            at(3, 1),
253                                                            at(1, 2),
254                                                            at(2, 2),
255                                                            at(3, 2),
256                                                            at(1, 3),
257                                                            at(2, 3),
258                                                            at(3, 3),
259                                                        },
260                                                        {
261                                                            at(1, 0),
262                                                            at(2, 0),
263                                                            at(3, 0),
264                                                            at(1, 2),
265                                                            at(2, 2),
266                                                            at(3, 2),
267                                                            at(1, 3),
268                                                            at(2, 3),
269                                                            at(3, 3),
270                                                        },
271                                                        {
272                                                            at(1, 0),
273                                                            at(2, 0),
274                                                            at(3, 0),
275                                                            at(1, 1),
276                                                            at(2, 1),
277                                                            at(3, 1),
278                                                            at(1, 3),
279                                                            at(2, 3),
280                                                            at(3, 3),
281                                                        },
282                                                        {
283                                                            at(1, 0),
284                                                            at(2, 0),
285                                                            at(3, 0),
286                                                            at(1, 1),
287                                                            at(2, 1),
288                                                            at(3, 1),
289                                                            at(1, 2),
290                                                            at(2, 2),
291                                                            at(3, 2),
292                                                        }};
293                 return at(0, 0) * Matrix<T>(minorMatrices[0], 3).determinant() -
294                        at(0, 1) * Matrix<T>(minorMatrices[1], 3).determinant() +
295                        at(0, 2) * Matrix<T>(minorMatrices[2], 3).determinant() -
296                        at(0, 3) * Matrix<T>(minorMatrices[3], 3).determinant();
297             }
298 
299             default:
300                 UNREACHABLE();
301                 break;
302         }
303 
304         return T();
305     }
306 
inverse()307     Matrix<T> inverse() const
308     {
309         ASSERT(rows() == columns());
310 
311         Matrix<T> coft(std::vector<T>(mElements.size()), rows(), columns());
312         switch (size())
313         {
314             case 2:
315                 coft(0, 0) = at(1, 1);
316                 coft(1, 0) = -at(1, 0);
317                 coft(0, 1) = -at(0, 1);
318                 coft(1, 1) = at(0, 0);
319                 break;
320 
321             case 3:
322                 coft(0, 0) = at(1, 1) * at(2, 2) - at(2, 1) * at(1, 2);
323                 coft(1, 0) = -(at(1, 0) * at(2, 2) - at(2, 0) * at(1, 2));
324                 coft(2, 0) = at(1, 0) * at(2, 1) - at(2, 0) * at(1, 1);
325                 coft(0, 1) = -(at(0, 1) * at(2, 2) - at(2, 1) * at(0, 2));
326                 coft(1, 1) = at(0, 0) * at(2, 2) - at(2, 0) * at(0, 2);
327                 coft(2, 1) = -(at(0, 0) * at(2, 1) - at(2, 0) * at(0, 1));
328                 coft(0, 2) = at(0, 1) * at(1, 2) - at(1, 1) * at(0, 2);
329                 coft(1, 2) = -(at(0, 0) * at(1, 2) - at(1, 0) * at(0, 2));
330                 coft(2, 2) = at(0, 0) * at(1, 1) - at(1, 0) * at(0, 1);
331                 break;
332 
333             case 4:
334                 CofactorTransposed(*this, coft);
335                 break;
336 
337             default:
338                 UNREACHABLE();
339                 break;
340         }
341 
342         // The inverse of A is the transpose of the cofactor matrix times the reciprocal of the
343         // determinant of A.
344         T det = determinant();
345         Matrix<T> result(std::vector<T>(mElements.size()), rows(), columns());
346         for (unsigned int i = 0; i < rows(); i++)
347             for (unsigned int j = 0; j < columns(); j++)
348                 result(i, j) = (det != static_cast<T>(0)) ? coft(i, j) / det : T();
349 
350         return result;
351     }
352 
setToIdentity()353     void setToIdentity()
354     {
355         ASSERT(rows() == columns());
356 
357         const auto one  = T(1);
358         const auto zero = T(0);
359 
360         for (auto &e : mElements)
361             e = zero;
362 
363         for (unsigned int i = 0; i < rows(); ++i)
364         {
365             const auto pos = i * columns() + (i % columns());
366             mElements[pos] = one;
367         }
368     }
369 
370     template <unsigned int Size>
setToIdentity(T (& matrix)[Size])371     static void setToIdentity(T (&matrix)[Size])
372     {
373         static_assert(gl::iSquareRoot<Size>() != 0, "Matrix is not square.");
374 
375         const auto cols = gl::iSquareRoot<Size>();
376         const auto one  = T(1);
377         const auto zero = T(0);
378 
379         for (auto &e : matrix)
380             e = zero;
381 
382         for (unsigned int i = 0; i < cols; ++i)
383         {
384             const auto pos = i * cols + (i % cols);
385             matrix[pos]    = one;
386         }
387     }
388 
389   protected:
390     std::vector<T> mElements;
391     unsigned int mRows;
392     unsigned int mCols;
393 };
394 
395 // Not derived from Matrix<float>: fixed-size std::array instead, to avoid malloc
396 class Mat4
397 {
398   public:
399     Mat4();
400     Mat4(const Matrix<float> generalMatrix);
401     Mat4(const std::vector<float> &elements);
402     Mat4(const float *elements);
403     Mat4(float m00,
404          float m01,
405          float m02,
406          float m03,
407          float m10,
408          float m11,
409          float m12,
410          float m13,
411          float m20,
412          float m21,
413          float m22,
414          float m23,
415          float m30,
416          float m31,
417          float m32,
418          float m33);
419 
420     static Mat4 Rotate(float angle, const Vector3 &axis);
421     static Mat4 Translate(const Vector3 &t);
422     static Mat4 Scale(const Vector3 &s);
423     static Mat4 Frustum(float l, float r, float b, float t, float n, float f);
424     static Mat4 Perspective(float fov, float aspectRatio, float n, float f);
425     static Mat4 Ortho(float l, float r, float b, float t, float n, float f);
426 
427     Mat4 product(const Mat4 &m);
428     Vector4 product(const Vector4 &b);
429     void dump();
430 
data()431     float *data() { return mElements.data(); }
constData()432     const float *constData() const { return mElements.data(); }
433 
operator()434     float operator()(const unsigned int rowIndex, const unsigned int columnIndex) const
435     {
436         ASSERT(rowIndex < 4);
437         ASSERT(columnIndex < 4);
438         return mElements[rowIndex * 4 + columnIndex];
439     }
440 
operator()441     float &operator()(const unsigned int rowIndex, const unsigned int columnIndex)
442     {
443         ASSERT(rowIndex < 4);
444         ASSERT(columnIndex < 4);
445         return mElements[rowIndex * 4 + columnIndex];
446     }
447 
at(const unsigned int rowIndex,const unsigned int columnIndex)448     float at(const unsigned int rowIndex, const unsigned int columnIndex) const
449     {
450         ASSERT(rowIndex < 4);
451         ASSERT(columnIndex < 4);
452         return operator()(rowIndex, columnIndex);
453     }
454 
455     bool operator==(const Mat4 &m) const { return mElements == m.elements(); }
456 
nearlyEqual(float epsilon,const Mat4 & m)457     bool nearlyEqual(float epsilon, const Mat4 &m) const
458     {
459         const auto &otherElts = m.elements();
460         for (size_t i = 0; i < otherElts.size(); i++)
461         {
462             if ((mElements[i] - otherElts[i] > epsilon) && (otherElts[i] - mElements[i] > epsilon))
463                 return false;
464         }
465         return true;
466     }
467 
elements()468     const std::array<float, 4 * 4> &elements() const { return mElements; }
469 
transpose()470     Mat4 transpose() const
471     {
472         Mat4 result;
473         for (unsigned int i = 0; i < 4; i++)
474             for (unsigned int j = 0; j < 4; j++)
475                 result(i, j) = at(j, i);
476 
477         return result;
478     }
479 
inverse()480     Mat4 inverse() const
481     {
482         Mat4 coft;
483         CofactorTransposed(*this, coft);
484 
485         float det = at(0, 0) * coft(0, 0) + at(0, 1) * coft(1, 0) + at(0, 2) * coft(2, 0) +
486                     at(0, 3) * coft(3, 0);
487 
488         Mat4 result = coft;
489         for (int i = 0; i < 16; i++)
490         {
491             result.data()[i] /= det;
492         }
493 
494         return result;
495     }
496 
497   private:
498     std::array<float, 4 * 4> mElements;
499 };
500 
501 }  // namespace angle
502 
503 #endif  // COMMON_MATRIX_UTILS_H_
504