/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"

#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/gpu/precompiled_kernels.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"

namespace xla {
namespace gpu {

TriangularSolveThunk::TriangularSolveThunk(
    ThunkInfo thunk_info, const TriangularSolveOptions& options,
    se::GpuAsmOpts asm_opts,  //
    const BufferAllocation::Slice& a_buffer,
    const BufferAllocation::Slice& b_buffer,
    const BufferAllocation::Slice& temp_buffer,  //
    PrimitiveType type, int64_t batch_size, int64_t m, int64_t n,
    int64_t a_batch_stride, int64_t b_batch_stride)
    : Thunk(Kind::kTriangularSolve, thunk_info),
      asm_opts_(asm_opts),
      uplo_(options.lower() ? se::blas::UpperLower::kLower
                            : se::blas::UpperLower::kUpper),
      side_(options.left_side() ? se::blas::Side::kLeft
                                : se::blas::Side::kRight),
      unit_diagonal_(options.unit_diagonal() ? se::blas::Diagonal::kUnit
                                             : se::blas::Diagonal::kNonUnit),
      a_buffer_(a_buffer),
      b_buffer_(b_buffer),
      temp_buffer_(temp_buffer),
      type_(type),
      batch_size_(batch_size),
      m_(m),
      n_(n),
      a_batch_stride_(a_batch_stride),
      b_batch_stride_(b_batch_stride) {
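  // Map the proto's transpose enum onto the StreamExecutor BLAS enum; an
  // unrecognized value is logged and treated as no-transpose.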
  transpose_a_ = [&] {
    switch (options.transpose_a()) {
      case TriangularSolveOptions::NO_TRANSPOSE:
        return se::blas::Transpose::kNoTranspose;
      case TriangularSolveOptions::TRANSPOSE:
        return se::blas::Transpose::kTranspose;
      case TriangularSolveOptions::ADJOINT:
        return se::blas::Transpose::kConjugateTranspose;
      default:
        LOG(ERROR) << "Invalid triangular solve transpose value "
                   << options.transpose_a();
        return se::blas::Transpose::kNoTranspose;
    }
  }();
}

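// Resolves each buffer slice to its device address and forwards all solve
// parameters to the free function below.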
Status TriangularSolveThunk::ExecuteOnStream(const ExecuteParams& params) {
  auto& buffer_allocations = *params.buffer_allocations;
  return RunTriangulatSolve(buffer_allocations.GetDeviceAddress(a_buffer_),
                            buffer_allocations.GetDeviceAddress(b_buffer_),
                            buffer_allocations.GetDeviceAddress(temp_buffer_),
                            asm_opts_, uplo_, side_, unit_diagonal_,
                            transpose_a_, type_, batch_size_, m_, n_,
                            a_batch_stride_, b_batch_stride_, params.stream);
}

Status RunTriangulatSolve(se::DeviceMemoryBase a_data,
                          se::DeviceMemoryBase b_data,
                          se::DeviceMemoryBase temp_data,
                          se::GpuAsmOpts asm_opts, se::blas::UpperLower uplo,
                          se::blas::Side side, se::blas::Diagonal unit_diagonal,
                          se::blas::Transpose transpose_a, PrimitiveType type,
                          int64_t batch_size, int64_t m, int64_t n,
                          int64_t a_batch_stride, int64_t b_batch_stride,
                          se::Stream* stream) {
  VLOG(3) << "uplo=" << se::blas::UpperLowerString(uplo)
          << " side=" << se::blas::SideString(side)
          << " diagonal=" << se::blas::DiagonalString(unit_diagonal)
          << " batch_size=" << batch_size << " m=" << m << " n=" << n
          << " a_batch_stride=" << a_batch_stride
          << " b_batch_stride=" << b_batch_stride;

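  // BLAS TRSM is column-major: A is m x m for a left-side solve and n x n for
  // a right-side solve, while B is always m x n. The leading dimension of A
  // therefore follows `side`, and the leading dimension of B is fixed at m.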
  const int lda = side == se::blas::Side::kLeft ? m : n;
  const int ldb = m;

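  // Set by every case below: whether the stream was still in an ok state
  // after the BLAS call was enqueued.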
  bool launch_ok;
  if (batch_size == 1) {
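    // A single solve maps directly onto plain trsm; the batched path below
    // needs trsmBatched plus device-side arrays of per-matrix pointers.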
    switch (type) {
      case F32: {
        se::DeviceMemory<float> b_data_typed(b_data);
        launch_ok =
            stream
                ->ThenBlasTrsm(side, uplo, transpose_a, unit_diagonal, m, n,
                               /*alpha=*/1.0f, se::DeviceMemory<float>(a_data),
                               lda, &b_data_typed, ldb)
                .ok();
        break;
      }
      case F64: {
        se::DeviceMemory<double> b_data_typed(b_data);
        launch_ok =
            stream
                ->ThenBlasTrsm(side, uplo, transpose_a, unit_diagonal, m, n,
                               /*alpha=*/1.0, se::DeviceMemory<double>(a_data),
                               lda, &b_data_typed, ldb)
                .ok();
        break;
      }
      case C64: {
        se::DeviceMemory<std::complex<float>> b_data_typed(b_data);
        launch_ok =
            stream
                ->ThenBlasTrsm(side, uplo, transpose_a, unit_diagonal, m, n,
                               /*alpha=*/1.0f,
                               se::DeviceMemory<std::complex<float>>(a_data),
                               lda, &b_data_typed, ldb)
                .ok();
        break;
      }
      case C128: {
        se::DeviceMemory<std::complex<double>> b_data_typed(b_data);
        launch_ok =
            stream
                ->ThenBlasTrsm(side, uplo, transpose_a, unit_diagonal, m, n,
                               /*alpha=*/1.0,
                               se::DeviceMemory<std::complex<double>>(a_data),
                               lda, &b_data_typed, ldb)
                .ok();
        break;
      }
      default:
        return InvalidArgument("Invalid type for triangular solve %d", type);
    }
  } else {
    // cublas trsmBatched requires us to materialize two arrays of batch_size
    // pointers, pointing to the individual `a` and `b` matrices of our input.
    // batch_pointers_bytes is the size in bytes of one of these arrays.
    int64_t batch_pointers_bytes = sizeof(void*) * batch_size;
    TF_RET_CHECK(temp_data.size() >= 2 * batch_pointers_bytes);
    void** temp_base = reinterpret_cast<void**>(temp_data.opaque());
    se::DeviceMemoryBase a_pointers(temp_base, batch_pointers_bytes);
    se::DeviceMemoryBase b_pointers(temp_base + batch_size,
                                    batch_pointers_bytes);

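    // Layout of temp_data after this point:
    //   [a_ptr_0 .. a_ptr_{batch_size-1} | b_ptr_0 .. b_ptr_{batch_size-1}]
    // MakeBatchPointers fills each array on-device with the address of
    // matrix i, i.e. the data base address plus i times the batch stride.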
    TF_RETURN_IF_ERROR(MakeBatchPointers(
        stream, asm_opts, a_data, a_batch_stride, batch_size, a_pointers));
    TF_RETURN_IF_ERROR(MakeBatchPointers(
        stream, asm_opts, b_data, b_batch_stride, batch_size, b_pointers));

    switch (type) {
      case F32: {
        se::DeviceMemory<float*> typed_b_pointers(b_pointers);
        launch_ok =
            stream
                ->ThenBlasTrsmBatched(side, uplo, transpose_a, unit_diagonal, m,
                                      n, /*alpha=*/1.0f,
                                      se::DeviceMemory<float*>(a_pointers), lda,
                                      &typed_b_pointers, ldb, batch_size)
                .ok();
        break;
      }
      case F64: {
        se::DeviceMemory<double*> typed_b_pointers(b_pointers);
        launch_ok =
            stream
                ->ThenBlasTrsmBatched(side, uplo, transpose_a, unit_diagonal, m,
                                      n, /*alpha=*/1.0,
                                      se::DeviceMemory<double*>(a_pointers),
                                      lda, &typed_b_pointers, ldb, batch_size)
                .ok();
        break;
      }
      case C64: {
        se::DeviceMemory<std::complex<float>*> typed_b_pointers(b_pointers);
        launch_ok = stream
                        ->ThenBlasTrsmBatched(
                            side, uplo, transpose_a, unit_diagonal, m, n,
                            /*alpha=*/1.0f,
                            se::DeviceMemory<std::complex<float>*>(a_pointers),
                            lda, &typed_b_pointers, ldb, batch_size)
                        .ok();
        break;
      }
      case C128: {
        se::DeviceMemory<std::complex<double>*> typed_b_pointers(b_pointers);
        launch_ok = stream
                        ->ThenBlasTrsmBatched(
                            side, uplo, transpose_a, unit_diagonal, m, n,
                            /*alpha=*/1.0,
                            se::DeviceMemory<std::complex<double>*>(a_pointers),
                            lda, &typed_b_pointers, ldb, batch_size)
                        .ok();
        break;
      }
      default:
        return InvalidArgument("Invalid type for triangular solve %d", type);
    }
  }

  if (!launch_ok) {
    return InternalError("Unable to launch triangular solve");
  }
  return OkStatus();
}

}  // namespace gpu
}  // namespace xla