xref: /aosp_15_r20/external/eigen/blas/BandTriangularSolver.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1*bf2c3715SXin Li // This file is part of Eigen, a lightweight C++ template library
2*bf2c3715SXin Li // for linear algebra.
3*bf2c3715SXin Li //
4*bf2c3715SXin Li // Copyright (C) 2011 Gael Guennebaud <[email protected]>
5*bf2c3715SXin Li //
6*bf2c3715SXin Li // This Source Code Form is subject to the terms of the Mozilla
7*bf2c3715SXin Li // Public License v. 2.0. If a copy of the MPL was not distributed
8*bf2c3715SXin Li // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9*bf2c3715SXin Li 
10*bf2c3715SXin Li #ifndef EIGEN_BAND_TRIANGULARSOLVER_H
11*bf2c3715SXin Li #define EIGEN_BAND_TRIANGULARSOLVER_H
12*bf2c3715SXin Li 
13*bf2c3715SXin Li namespace internal {
14*bf2c3715SXin Li 
15*bf2c3715SXin Li  /* \internal
16*bf2c3715SXin Li   * Solve Ax=b with A a band triangular matrix
17*bf2c3715SXin Li   * TODO: extend it to matrices for x and b */
/* Primary template: declaration only. The solver itself is implemented by the
 * partial specializations on StorageOrder (RowMajor and ColMajor) in this file,
 * each exposing
 *   static void run(Index size, Index k, const LhsScalar* _lhs, Index lhsStride, RhsScalar* _other)
 *
 * Template parameters, as used by those specializations:
 *   Index        - integer type of size, k and lhsStride
 *   Mode         - bit flags; Mode&Lower selects the lower-triangular variant,
 *                  Mode&UnitDiag skips the division by the diagonal coefficient
 *   LhsScalar    - scalar type of the band matrix coefficients
 *   ConjLhs      - when true, the band matrix is read through scalar_conjugate_op
 *   RhsScalar    - scalar type of the right-hand side, which is updated in place
 *   StorageOrder - RowMajor or ColMajor; selects the specialization */
18*bf2c3715SXin Li template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, int StorageOrder>
19*bf2c3715SXin Li struct band_solve_triangular_selector;
20*bf2c3715SXin Li 
21*bf2c3715SXin Li 
22*bf2c3715SXin Li template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar>
23*bf2c3715SXin Li struct band_solve_triangular_selector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,RowMajor>
24*bf2c3715SXin Li {
25*bf2c3715SXin Li   typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
26*bf2c3715SXin Li   typedef Map<Matrix<RhsScalar,Dynamic,1> > RhsMap;
27*bf2c3715SXin Li   enum { IsLower = (Mode&Lower) ? 1 : 0 };
28*bf2c3715SXin Li   static void run(Index size, Index k, const LhsScalar* _lhs, Index lhsStride, RhsScalar* _other)
29*bf2c3715SXin Li   {
30*bf2c3715SXin Li     const LhsMap lhs(_lhs,size,k+1,OuterStride<>(lhsStride));
31*bf2c3715SXin Li     RhsMap other(_other,size,1);
32*bf2c3715SXin Li     typename internal::conditional<
33*bf2c3715SXin Li                           ConjLhs,
34*bf2c3715SXin Li                           const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
35*bf2c3715SXin Li                           const LhsMap&>
36*bf2c3715SXin Li                         ::type cjLhs(lhs);
37*bf2c3715SXin Li 
38*bf2c3715SXin Li     for(int col=0 ; col<other.cols() ; ++col)
39*bf2c3715SXin Li     {
40*bf2c3715SXin Li       for(int ii=0; ii<size; ++ii)
41*bf2c3715SXin Li       {
42*bf2c3715SXin Li         int i = IsLower ? ii : size-ii-1;
43*bf2c3715SXin Li         int actual_k = (std::min)(k,ii);
44*bf2c3715SXin Li         int actual_start = IsLower ? k-actual_k : 1;
45*bf2c3715SXin Li 
46*bf2c3715SXin Li         if(actual_k>0)
47*bf2c3715SXin Li           other.coeffRef(i,col) -= cjLhs.row(i).segment(actual_start,actual_k).transpose()
48*bf2c3715SXin Li                                   .cwiseProduct(other.col(col).segment(IsLower ? i-actual_k : i+1,actual_k)).sum();
49*bf2c3715SXin Li 
50*bf2c3715SXin Li         if((Mode&UnitDiag)==0)
51*bf2c3715SXin Li           other.coeffRef(i,col) /= cjLhs(i,IsLower ? k : 0);
52*bf2c3715SXin Li       }
53*bf2c3715SXin Li     }
54*bf2c3715SXin Li   }
55*bf2c3715SXin Li 
56*bf2c3715SXin Li };
57*bf2c3715SXin Li 
58*bf2c3715SXin Li template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar>
59*bf2c3715SXin Li struct band_solve_triangular_selector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ColMajor>
60*bf2c3715SXin Li {
61*bf2c3715SXin Li   typedef Map<const Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
62*bf2c3715SXin Li   typedef Map<Matrix<RhsScalar,Dynamic,1> > RhsMap;
63*bf2c3715SXin Li   enum { IsLower = (Mode&Lower) ? 1 : 0 };
64*bf2c3715SXin Li   static void run(Index size, Index k, const LhsScalar* _lhs, Index lhsStride, RhsScalar* _other)
65*bf2c3715SXin Li   {
66*bf2c3715SXin Li     const LhsMap lhs(_lhs,k+1,size,OuterStride<>(lhsStride));
67*bf2c3715SXin Li     RhsMap other(_other,size,1);
68*bf2c3715SXin Li     typename internal::conditional<
69*bf2c3715SXin Li                           ConjLhs,
70*bf2c3715SXin Li                           const CwiseUnaryOp<typename internal::scalar_conjugate_op<LhsScalar>,LhsMap>,
71*bf2c3715SXin Li                           const LhsMap&>
72*bf2c3715SXin Li                         ::type cjLhs(lhs);
73*bf2c3715SXin Li 
74*bf2c3715SXin Li     for(int col=0 ; col<other.cols() ; ++col)
75*bf2c3715SXin Li     {
76*bf2c3715SXin Li       for(int ii=0; ii<size; ++ii)
77*bf2c3715SXin Li       {
78*bf2c3715SXin Li         int i = IsLower ? ii : size-ii-1;
79*bf2c3715SXin Li         int actual_k = (std::min)(k,size-ii-1);
80*bf2c3715SXin Li         int actual_start = IsLower ? 1 : k-actual_k;
81*bf2c3715SXin Li 
82*bf2c3715SXin Li         if((Mode&UnitDiag)==0)
83*bf2c3715SXin Li           other.coeffRef(i,col) /= cjLhs(IsLower ? 0 : k, i);
84*bf2c3715SXin Li 
85*bf2c3715SXin Li         if(actual_k>0)
86*bf2c3715SXin Li           other.col(col).segment(IsLower ? i+1 : i-actual_k, actual_k)
87*bf2c3715SXin Li               -= other.coeff(i,col) * cjLhs.col(i).segment(actual_start,actual_k);
88*bf2c3715SXin Li 
89*bf2c3715SXin Li       }
90*bf2c3715SXin Li     }
91*bf2c3715SXin Li   }
92*bf2c3715SXin Li };
93*bf2c3715SXin Li 
94*bf2c3715SXin Li 
95*bf2c3715SXin Li } // end namespace internal
96*bf2c3715SXin Li 
97*bf2c3715SXin Li #endif // EIGEN_BAND_TRIANGULARSOLVER_H
98