xref: /aosp_15_r20/external/eigen/Eigen/src/Core/products/TriangularMatrixMatrix_BLAS.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 /*
2  Copyright (c) 2011, Intel Corporation. All rights reserved.
3 
4  Redistribution and use in source and binary forms, with or without modification,
5  are permitted provided that the following conditions are met:
6 
7  * Redistributions of source code must retain the above copyright notice, this
8    list of conditions and the following disclaimer.
9  * Redistributions in binary form must reproduce the above copyright notice,
10    this list of conditions and the following disclaimer in the documentation
11    and/or other materials provided with the distribution.
12  * Neither the name of Intel Corporation nor the names of its contributors may
13    be used to endorse or promote products derived from this software without
14    specific prior written permission.
15 
16  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20  ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
23  ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 
27  ********************************************************************************
28  *   Content : Eigen bindings to BLAS F77
29  *   Triangular matrix * matrix product functionality based on ?TRMM.
30  ********************************************************************************
31 */
32 
33 #ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
34 #define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
35 
36 namespace Eigen {
37 
38 namespace internal {
39 
40 
// Dispatch target for the BLAS-enabled specializations of
// product_triangular_matrix_matrix generated further down in this file.
// This primary template is the catch-all: it only inherits from the BuiltIn
// variant of product_triangular_matrix_matrix, so any parameter combination
// that is not matched by a specialization generated by EIGEN_BLAS_TRMM_L or
// EIGEN_BLAS_TRMM_R ends up in the BuiltIn implementation.
// NOTE(review): the literal 1 is presumably the result inner stride (the
// Specialized wrapper below asserts resIncr == 1) -- confirm against the
// declaration of product_triangular_matrix_matrix.
41 template <typename Scalar, typename Index,
42           int Mode, bool LhsIsTriangular,
43           int LhsStorageOrder, bool ConjugateLhs,
44           int RhsStorageOrder, bool ConjugateRhs,
45           int ResStorageOrder>
46 struct product_triangular_matrix_matrix_trmm :
47        product_triangular_matrix_matrix<Scalar,Index,Mode,
48           LhsIsTriangular,LhsStorageOrder,ConjugateLhs,
49           RhsStorageOrder, ConjugateRhs, ResStorageOrder, 1, BuiltIn> {};
50 
51 
// EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular)
// Generates a partial specialization of product_triangular_matrix_matrix for
// the given scalar type and triangular side, restricted to a ColMajor result
// with the literal 1 in the next-to-last slot and the Specialized tag.
// Its run() accepts the full argument list (including resIncr), asserts that
// resIncr == 1, and forwards everything except resIncr to
// product_triangular_matrix_matrix_trmm<...>::run().
52 // Entry point: try to route the product to the BLAS-backed specialization.
53 #define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
54 template <typename Index, int Mode, \
55           int LhsStorageOrder, bool ConjugateLhs, \
56           int RhsStorageOrder, bool ConjugateRhs> \
57 struct product_triangular_matrix_matrix<Scalar,Index, Mode, LhsIsTriangular, \
58            LhsStorageOrder,ConjugateLhs, RhsStorageOrder,ConjugateRhs,ColMajor,1,Specialized> { \
59   static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride,\
60     const Scalar* _rhs, Index rhsStride, Scalar* res, Index resIncr, Index resStride, Scalar alpha, level3_blocking<Scalar,Scalar>& blocking) { \
      /* resIncr is read only by the assertion below; it is not forwarded. */ \
61       EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
62       eigen_assert(resIncr == 1); \
      /* Delegate to the trmm dispatcher, whose run() has no resIncr parameter. */ \
63       product_triangular_matrix_matrix_trmm<Scalar,Index,Mode, \
64         LhsIsTriangular,LhsStorageOrder,ConjugateLhs, \
65         RhsStorageOrder, ConjugateRhs, ColMajor>::run( \
66           _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
67   } \
68 };
69 
// Instantiate the Specialized wrapper for the four scalar types used in this
// file (double, dcomplex, float, scomplex), once with the triangular factor on
// the left (true) and once with it on the right (false).
70 EIGEN_BLAS_TRMM_SPECIALIZE(double, true)
71 EIGEN_BLAS_TRMM_SPECIALIZE(double, false)
72 EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, true)
73 EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, false)
74 EIGEN_BLAS_TRMM_SPECIALIZE(float, true)
75 EIGEN_BLAS_TRMM_SPECIALIZE(float, false)
76 EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, true)
77 EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, false)
78 
// EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
// Generates the specialization of product_triangular_matrix_matrix_trmm for a
// triangular LEFT-hand side (LhsIsTriangular == true) and a ColMajor result.
//
// Macro parameters:
//   EIGTYPE   - scalar type of lhs, rhs, res and alpha.
//   BLASTYPE  - element type the alpha / a / b pointers are cast to when
//               BLASFUNC is called.
//   EIGPREFIX - suffix pasted after "MatrixX" to name the dense temporary type
//               (MatrixX##EIGPREFIX).
//   BLASFUNC  - the ?trmm routine to call.
//
// Generated run() arguments:
//   _rows, _cols, _depth - dimensions of the product (lhs is _rows x _depth).
//   _lhs, lhsStride      - triangular operand and its outer stride.
//   _rhs, rhsStride      - general operand and its outer stride.
//   res, resStride       - column-major result that is accumulated into.
//   alpha                - scaling factor of the product.
//   blocking             - only forwarded to the BuiltIn fallback.
//
// Flow of run():
//   1. If the triangular operand is not square, either call the BuiltIn
//      triangular product or copy the triangular view into a dense temporary
//      and call general_matrix_matrix_product, then return.
//   2. Otherwise copy the general operand into b_tmp, call BLASFUNC (which
//      leaves the scaled product in b_tmp) and add b_tmp to res.
79 // implements col-major res += alpha * op(triangular) * op(general)
80 #define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
81 template <typename Index, int Mode, \
82           int LhsStorageOrder, bool ConjugateLhs, \
83           int RhsStorageOrder, bool ConjugateRhs> \
84 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,true, \
85          LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
86 { \
  /* SetDiag is 1 only when neither ZeroDiag nor UnitDiag is requested, i.e. the stored diagonal is used. */ \
  /* conjA is 1 when the lhs conjugate must be materialised in a copy (column-major lhs). */ \
  /* A conjugated row-major lhs is handled through transa = 'C' instead. */ \
  /* LowUp is not referenced anywhere else in this macro. */ \
87   enum { \
88     IsLower = (Mode&Lower) == Lower, \
89     SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
90     IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
91     IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
92     LowUp = IsLower ? Lower : Upper, \
93     conjA = ((LhsStorageOrder==ColMajor) && ConjugateLhs) ? 1 : 0 \
94   }; \
95 \
96   static void run( \
97     Index _rows, Index _cols, Index _depth, \
98     const EIGTYPE* _lhs, Index lhsStride, \
99     const EIGTYPE* _rhs, Index rhsStride, \
100     EIGTYPE* res,        Index resStride, \
101     EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
102   { \
   /* Clip the lhs extent: a lower factor keeps all rows, an upper factor keeps all of the depth. */ \
103    Index diagSize  = (std::min)(_rows,_depth); \
104    Index rows      = IsLower ? _rows : diagSize; \
105    Index depth     = IsLower ? diagSize : _depth; \
106    Index cols      = _cols; \
107 \
108    typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
109    typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
110 \
111 /* Non-square case: does not fit BLAS ?TRMM. Fall back to the default triangular product or call ?GEMM. */ \
112    if (rows != depth) { \
113 \
114      /* FIXME handle mkl_domain_get_max_threads */ \
115      /*int nthr = mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS);*/ int nthr = 1;\
116 \
     /* Heuristic: the part of the lhs beyond the square diagSize block is less than half of diagSize. */ \
     /* NOTE(review): diagSize can be 0 here, making this a floating-point division by zero -- confirm this is intended. */ \
117      if (((nthr==1) && (((std::max)(rows,depth)-diagSize)/(double)diagSize < 0.5))) { \
118      /* Most likely no benefit in calling TRMM or GEMM from BLAS */ \
119        product_triangular_matrix_matrix<EIGTYPE,Index,Mode,true, \
120        LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run( \
121            _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, 1, resStride, alpha, blocking); \
122      /*std::cout << "TRMM_L: A is not square! Go to Eigen TRMM implementation!\n";*/ \
123      } else { \
124      /* It makes sense to call GEMM: copy the triangular view into a dense temporary first. */ \
125        Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs,rows,depth,OuterStride<>(lhsStride)); \
126        MatrixLhs aa_tmp=lhsMap.template triangularView<Mode>(); \
127        BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
128        gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
129        general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \
130        rows, cols, depth, aa_tmp.data(), aStride, _rhs, rhsStride, res, 1, resStride, alpha, gemm_blocking, 0); \
131 \
132      /*std::cout << "TRMM_L: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
133      } \
134      return; \
135    } \
   /* Square case: the triangular factor is applied from the left; diag stays 'N' because */ \
   /* unit / zero diagonals are written into a copy of the lhs further down. */ \
136    char side = 'L', transa, uplo, diag = 'N'; \
137    EIGTYPE *b; \
138    const EIGTYPE *a; \
139    BlasIndex m, n, lda, ldb; \
140 \
141 /* Set m, n */ \
142    m = convert_index<BlasIndex>(diagSize); \
143    n = convert_index<BlasIndex>(cols); \
144 \
145 /* Set trans: a row-major lhs is passed as a (conjugate-)transposed operand */ \
146    transa = (LhsStorageOrder==RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
147 \
148 /* Set b, ldb: b_tmp is a (possibly conjugated) copy of the rhs that is handed to BLASFUNC as B */ \
149    Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols,OuterStride<>(rhsStride)); \
150    MatrixX##EIGPREFIX b_tmp; \
151 \
152    if (ConjugateRhs) b_tmp = rhs.conjugate(); else b_tmp = rhs; \
153    b = b_tmp.data(); \
154    ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
155 \
156 /* Set uplo: flipped for a row-major lhs, consistent with the transposition selected in transa */ \
157    uplo = IsLower ? 'L' : 'U'; \
158    if (LhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
159 /* Set a, lda: copy the lhs only if it has to be conjugated or its diagonal has to be overwritten */ \
160    Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
161    MatrixLhs a_tmp; \
162 \
163    if ((conjA!=0) || (SetDiag==0)) { \
164      if (conjA) a_tmp = lhs.conjugate(); else a_tmp = lhs; \
165      if (IsZeroDiag) \
166        a_tmp.diagonal().setZero(); \
167      else if (IsUnitDiag) \
168        a_tmp.diagonal().setOnes();\
169      a = a_tmp.data(); \
170      lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
171    } else { \
     /* No copy needed: pass the caller's lhs buffer directly. */ \
172      a = _lhs; \
173      lda = convert_index<BlasIndex>(lhsStride); \
174    } \
175    /*std::cout << "TRMM_L: A is square! Go to BLAS TRMM implementation! \n";*/ \
176 /* call ?trmm */ \
177    BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
178 \
179 /* Add the product, now held in b_tmp, into res */ \
180    Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
181    res_tmp=res_tmp+b_tmp; \
182   } \
183 };
184 
// Instantiate the left-triangular specialization for the four scalar types.
// With EIGEN_USE_MKL the routines are referenced without a trailing underscore
// and complex pointers are cast to MKL_Complex16 / MKL_Complex8; otherwise the
// trailing-underscore names are used and complex pointers are cast to
// double* / float*.
185 #ifdef EIGEN_USE_MKL
186 EIGEN_BLAS_TRMM_L(double, double, d, dtrmm)
187 EIGEN_BLAS_TRMM_L(dcomplex, MKL_Complex16, cd, ztrmm)
188 EIGEN_BLAS_TRMM_L(float, float, f, strmm)
189 EIGEN_BLAS_TRMM_L(scomplex, MKL_Complex8, cf, ctrmm)
190 #else
191 EIGEN_BLAS_TRMM_L(double, double, d, dtrmm_)
192 EIGEN_BLAS_TRMM_L(dcomplex, double, cd, ztrmm_)
193 EIGEN_BLAS_TRMM_L(float, float, f, strmm_)
194 EIGEN_BLAS_TRMM_L(scomplex, float, cf, ctrmm_)
195 #endif
196 
// EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC)
// Generates the specialization of product_triangular_matrix_matrix_trmm for a
// triangular RIGHT-hand side (LhsIsTriangular == false) and a ColMajor result.
//
// Macro parameters:
//   EIGTYPE   - scalar type of lhs, rhs, res and alpha.
//   BLASTYPE  - element type the alpha / a / b pointers are cast to when
//               BLASFUNC is called.
//   EIGPREFIX - suffix pasted after "MatrixX" to name the dense temporary type
//               (MatrixX##EIGPREFIX).
//   BLASFUNC  - the ?trmm routine to call.
//
// Generated run() arguments:
//   _rows, _cols, _depth - dimensions of the product (rhs is _depth x _cols).
//   _lhs, lhsStride      - general operand and its outer stride.
//   _rhs, rhsStride      - triangular operand and its outer stride.
//   res, resStride       - column-major result that is accumulated into.
//   alpha                - scaling factor of the product.
//   blocking             - only forwarded to the BuiltIn fallback.
//
// Flow of run():
//   1. If the triangular operand is not square, either call the BuiltIn
//      triangular product or copy the triangular view into a dense temporary
//      and call general_matrix_matrix_product, then return.
//   2. Otherwise copy the general operand into b_tmp, call BLASFUNC (which
//      leaves the scaled product in b_tmp) and add b_tmp to res.
197 // implements col-major res += alpha * op(general) * op(triangular)
198 #define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
199 template <typename Index, int Mode, \
200           int LhsStorageOrder, bool ConjugateLhs, \
201           int RhsStorageOrder, bool ConjugateRhs> \
202 struct product_triangular_matrix_matrix_trmm<EIGTYPE,Index,Mode,false, \
203          LhsStorageOrder,ConjugateLhs,RhsStorageOrder,ConjugateRhs,ColMajor> \
204 { \
  /* SetDiag is 1 only when neither ZeroDiag nor UnitDiag is requested, i.e. the stored diagonal is used. */ \
  /* conjA is 1 when the rhs conjugate must be materialised in a copy (column-major rhs). */ \
  /* A conjugated row-major rhs is handled through transa = 'C' instead. */ \
  /* LowUp is not referenced anywhere else in this macro. */ \
205   enum { \
206     IsLower = (Mode&Lower) == Lower, \
207     SetDiag = (Mode&(ZeroDiag|UnitDiag)) ? 0 : 1, \
208     IsUnitDiag  = (Mode&UnitDiag) ? 1 : 0, \
209     IsZeroDiag  = (Mode&ZeroDiag) ? 1 : 0, \
210     LowUp = IsLower ? Lower : Upper, \
211     conjA = ((RhsStorageOrder==ColMajor) && ConjugateRhs) ? 1 : 0 \
212   }; \
213 \
214   static void run( \
215     Index _rows, Index _cols, Index _depth, \
216     const EIGTYPE* _lhs, Index lhsStride, \
217     const EIGTYPE* _rhs, Index rhsStride, \
218     EIGTYPE* res,        Index resStride, \
219     EIGTYPE alpha, level3_blocking<EIGTYPE,EIGTYPE>& blocking) \
220   { \
   /* Clip the rhs extent: a lower factor keeps all of the depth, an upper factor keeps all columns. */ \
221    Index diagSize  = (std::min)(_cols,_depth); \
222    Index rows      = _rows; \
223    Index depth     = IsLower ? _depth : diagSize; \
224    Index cols      = IsLower ? diagSize : _cols; \
225 \
226    typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
227    typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
228 \
229 /* Non-square case: does not fit BLAS ?TRMM. Fall back to the default triangular product or call ?GEMM. */ \
230    if (cols != depth) { \
231 \
232      int nthr = 1 /*mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS)*/; \
233 \
     /* Heuristic: the part of the rhs beyond the square diagSize block is less than half of diagSize. */ \
     /* NOTE(review): diagSize can be 0 here, making this a floating-point division by zero -- confirm this is intended. */ \
234      if ((nthr==1) && (((std::max)(cols,depth)-diagSize)/(double)diagSize < 0.5)) { \
235      /* Most likely no benefit in calling TRMM or GEMM from BLAS */ \
236        product_triangular_matrix_matrix<EIGTYPE,Index,Mode,false, \
237        LhsStorageOrder,ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run( \
238            _rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, 1, resStride, alpha, blocking); \
239        /*std::cout << "TRMM_R: A is not square! Go to Eigen TRMM implementation!\n";*/ \
240      } else { \
241      /* It makes sense to call GEMM: copy the triangular view into a dense temporary first. */ \
242        Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs,depth,cols, OuterStride<>(rhsStride)); \
243        MatrixRhs aa_tmp=rhsMap.template triangularView<Mode>(); \
244        BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
245        gemm_blocking_space<ColMajor,EIGTYPE,EIGTYPE,Dynamic,Dynamic,Dynamic> gemm_blocking(_rows,_cols,_depth, 1, true); \
246        general_matrix_matrix_product<Index,EIGTYPE,LhsStorageOrder,ConjugateLhs,EIGTYPE,RhsStorageOrder,ConjugateRhs,ColMajor,1>::run( \
247        rows, cols, depth, _lhs, lhsStride, aa_tmp.data(), aStride, res, 1, resStride, alpha, gemm_blocking, 0); \
248 \
249      /*std::cout << "TRMM_R: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
250      } \
251      return; \
252    } \
   /* Square case: the triangular factor is applied from the right; diag stays 'N' because */ \
   /* unit / zero diagonals are written into a copy of the rhs further down. */ \
253    char side = 'R', transa, uplo, diag = 'N'; \
254    EIGTYPE *b; \
255    const EIGTYPE *a; \
256    BlasIndex m, n, lda, ldb; \
257 \
258 /* Set m, n */ \
259    m = convert_index<BlasIndex>(rows); \
260    n = convert_index<BlasIndex>(diagSize); \
261 \
262 /* Set trans: a row-major rhs is passed as a (conjugate-)transposed operand */ \
263    transa = (RhsStorageOrder==RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
264 \
265 /* Set b, ldb: b_tmp is a (possibly conjugated) copy of the lhs that is handed to BLASFUNC as B */ \
266    Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs,rows,depth,OuterStride<>(lhsStride)); \
267    MatrixX##EIGPREFIX b_tmp; \
268 \
269    if (ConjugateLhs) b_tmp = lhs.conjugate(); else b_tmp = lhs; \
270    b = b_tmp.data(); \
271    ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
272 \
273 /* Set uplo: flipped for a row-major rhs, consistent with the transposition selected in transa */ \
274    uplo = IsLower ? 'L' : 'U'; \
275    if (RhsStorageOrder==RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
276 /* Set a, lda: copy the rhs only if it has to be conjugated or its diagonal has to be overwritten */ \
277    Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs,depth,cols, OuterStride<>(rhsStride)); \
278    MatrixRhs a_tmp; \
279 \
280    if ((conjA!=0) || (SetDiag==0)) { \
281      if (conjA) a_tmp = rhs.conjugate(); else a_tmp = rhs; \
282      if (IsZeroDiag) \
283        a_tmp.diagonal().setZero(); \
284      else if (IsUnitDiag) \
285        a_tmp.diagonal().setOnes();\
286      a = a_tmp.data(); \
287      lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
288    } else { \
     /* No copy needed: pass the caller's rhs buffer directly. */ \
289      a = _rhs; \
290      lda = convert_index<BlasIndex>(rhsStride); \
291    } \
292    /*std::cout << "TRMM_R: A is square! Go to BLAS TRMM implementation! \n";*/ \
293 /* call ?trmm */ \
294    BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, (BLASTYPE*)b, &ldb); \
295 \
296 /* Add the product, now held in b_tmp, into res */ \
297    Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res,rows,cols,OuterStride<>(resStride)); \
298    res_tmp=res_tmp+b_tmp; \
299   } \
300 };
301 
// Instantiate the right-triangular specialization for the four scalar types.
// With EIGEN_USE_MKL the routines are referenced without a trailing underscore
// and complex pointers are cast to MKL_Complex16 / MKL_Complex8; otherwise the
// trailing-underscore names are used and complex pointers are cast to
// double* / float*.
302 #ifdef EIGEN_USE_MKL
303 EIGEN_BLAS_TRMM_R(double, double, d, dtrmm)
304 EIGEN_BLAS_TRMM_R(dcomplex, MKL_Complex16, cd, ztrmm)
305 EIGEN_BLAS_TRMM_R(float, float, f, strmm)
306 EIGEN_BLAS_TRMM_R(scomplex, MKL_Complex8, cf, ctrmm)
307 #else
308 EIGEN_BLAS_TRMM_R(double, double, d, dtrmm_)
309 EIGEN_BLAS_TRMM_R(dcomplex, double, cd, ztrmm_)
310 EIGEN_BLAS_TRMM_R(float, float, f, strmm_)
311 EIGEN_BLAS_TRMM_R(scomplex, float, cf, ctrmm_)
312 #endif
313 } // end namespace internal
314 
315 } // end namespace Eigen
316 
317 #endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
318