1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
17
18 #include <string>
19
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_format.h"
22 #include "tensorflow/compiler/xla/service/gpu/precompiled_kernels.h"
23 #include "tensorflow/compiler/xla/types.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
28 #include "tensorflow/stream_executor/blas.h"
29 #include "tensorflow/stream_executor/device_memory.h"
30
31 namespace xla {
32 namespace gpu {
33
TriangularSolveThunk(ThunkInfo thunk_info,const TriangularSolveOptions & options,se::GpuAsmOpts asm_opts,const BufferAllocation::Slice & a_buffer,const BufferAllocation::Slice & b_buffer,const BufferAllocation::Slice & temp_buffer,PrimitiveType type,int64_t batch_size,int64_t m,int64_t n,int64_t a_batch_stride,int64_t b_batch_stride)34 TriangularSolveThunk::TriangularSolveThunk(
35 ThunkInfo thunk_info, const TriangularSolveOptions& options,
36 se::GpuAsmOpts asm_opts, //
37 const BufferAllocation::Slice& a_buffer,
38 const BufferAllocation::Slice& b_buffer,
39 const BufferAllocation::Slice& temp_buffer, //
40 PrimitiveType type, int64_t batch_size, int64_t m, int64_t n,
41 int64_t a_batch_stride, int64_t b_batch_stride)
42 : Thunk(Kind::kTriangularSolve, thunk_info),
43 asm_opts_(asm_opts),
44 uplo_(options.lower() ? se::blas::UpperLower::kLower
45 : se::blas::UpperLower::kUpper),
46 side_(options.left_side() ? se::blas::Side::kLeft
47 : se::blas::Side::kRight),
48 unit_diagonal_(options.unit_diagonal() ? se::blas::Diagonal::kUnit
49 : se::blas::Diagonal::kNonUnit),
50 a_buffer_(a_buffer),
51 b_buffer_(b_buffer),
52 temp_buffer_(temp_buffer),
53 type_(type),
54 batch_size_(batch_size),
55 m_(m),
56 n_(n),
57 a_batch_stride_(a_batch_stride),
58 b_batch_stride_(b_batch_stride) {
59 transpose_a_ = [&] {
60 switch (options.transpose_a()) {
61 case TriangularSolveOptions::NO_TRANSPOSE:
62 return se::blas::Transpose::kNoTranspose;
63 case TriangularSolveOptions::TRANSPOSE:
64 return se::blas::Transpose::kTranspose;
65 case TriangularSolveOptions::ADJOINT:
66 return se::blas::Transpose::kConjugateTranspose;
67 default:
68 LOG(ERROR) << "Invalid triangular solve transpose value "
69 << options.transpose_a();
70 return se::blas::Transpose::kNoTranspose;
71 }
72 }();
73 }
74
ExecuteOnStream(const ExecuteParams & params)75 Status TriangularSolveThunk::ExecuteOnStream(const ExecuteParams& params) {
76 auto& buffer_allocations = *params.buffer_allocations;
77 return RunTriangulatSolve(buffer_allocations.GetDeviceAddress(a_buffer_),
78 buffer_allocations.GetDeviceAddress(b_buffer_),
79 buffer_allocations.GetDeviceAddress(temp_buffer_),
80 asm_opts_, uplo_, side_, unit_diagonal_,
81 transpose_a_, type_, batch_size_, m_, n_,
82 a_batch_stride_, b_batch_stride_, params.stream);
83 }
84
RunTriangulatSolve(se::DeviceMemoryBase a_data,se::DeviceMemoryBase b_data,se::DeviceMemoryBase temp_data,se::GpuAsmOpts asm_opts,se::blas::UpperLower uplo,se::blas::Side side,se::blas::Diagonal unit_diagonal,se::blas::Transpose transpose_a,PrimitiveType type,int64_t batch_size,int64_t m,int64_t n,int64_t a_batch_stride,int64_t b_batch_stride,se::Stream * stream)85 Status RunTriangulatSolve(se::DeviceMemoryBase a_data,
86 se::DeviceMemoryBase b_data,
87 se::DeviceMemoryBase temp_data,
88 se::GpuAsmOpts asm_opts, se::blas::UpperLower uplo,
89 se::blas::Side side, se::blas::Diagonal unit_diagonal,
90 se::blas::Transpose transpose_a, PrimitiveType type,
91 int64_t batch_size, int64_t m, int64_t n,
92 int64_t a_batch_stride, int64_t b_batch_stride,
93 se::Stream* stream) {
94 VLOG(3) << "uplo=" << se::blas::UpperLowerString(uplo)
95 << " side=" << se::blas::SideString(side)
96 << " diagonal=" << se::blas::DiagonalString(unit_diagonal)
97 << " batch_size=" << batch_size << " m=" << m << " n=" << n
98 << " a_batch_stride=" << a_batch_stride
99 << " b_batch_stride=" << b_batch_stride;
100
101 const int lda = side == se::blas::Side::kLeft ? m : n;
102 const int ldb = m;
103
104 bool launch_ok;
105 if (batch_size == 1) {
106 switch (type) {
107 case F32: {
108 se::DeviceMemory<float> b_data_typed(b_data);
109 launch_ok =
110 stream
111 ->ThenBlasTrsm(side, uplo, transpose_a, unit_diagonal, m, n,
112 /*alpha=*/1.0f, se::DeviceMemory<float>(a_data),
113 lda, &b_data_typed, ldb)
114 .ok();
115 break;
116 }
117 case F64: {
118 se::DeviceMemory<double> b_data_typed(b_data);
119 launch_ok =
120 stream
121 ->ThenBlasTrsm(side, uplo, transpose_a, unit_diagonal, m, n,
122 /*alpha=*/1.0, se::DeviceMemory<double>(a_data),
123 lda, &b_data_typed, ldb)
124 .ok();
125 break;
126 }
127 case C64: {
128 se::DeviceMemory<std::complex<float>> b_data_typed(b_data);
129 launch_ok =
130 stream
131 ->ThenBlasTrsm(side, uplo, transpose_a, unit_diagonal, m, n,
132 /*alpha=*/1.0f,
133 se::DeviceMemory<std::complex<float>>(a_data),
134 lda, &b_data_typed, ldb)
135 .ok();
136 break;
137 }
138 case C128: {
139 se::DeviceMemory<std::complex<double>> b_data_typed(b_data);
140 launch_ok =
141 stream
142 ->ThenBlasTrsm(side, uplo, transpose_a, unit_diagonal, m, n,
143 /*alpha=*/1.0,
144 se::DeviceMemory<std::complex<double>>(a_data),
145 lda, &b_data_typed, ldb)
146 .ok();
147 break;
148 }
149 default:
150 return InvalidArgument("Invalid type for triangular solve %d", type);
151 }
152 } else {
153 // cublas trsmBatched requires us to materialize out two arrays of
154 // batch_size_ pointers, pointing to the individual `a` and `b` matrices of
155 // our input. batch_pointers_bytes is the size in bytes of one of these
156 // arrays.
157 int64_t batch_pointers_bytes = sizeof(void*) * batch_size;
158 TF_RET_CHECK(temp_data.size() >= 2 * batch_pointers_bytes);
159 void** temp_base = reinterpret_cast<void**>(temp_data.opaque());
160 se::DeviceMemoryBase a_pointers(temp_base, batch_pointers_bytes);
161 se::DeviceMemoryBase b_pointers(temp_base + batch_size,
162 batch_pointers_bytes);
163
164 TF_RETURN_IF_ERROR(MakeBatchPointers(
165 stream, asm_opts, a_data, a_batch_stride, batch_size, a_pointers));
166 TF_RETURN_IF_ERROR(MakeBatchPointers(
167 stream, asm_opts, b_data, b_batch_stride, batch_size, b_pointers));
168
169 switch (type) {
170 case F32: {
171 se::DeviceMemory<float*> typed_b_pointers(b_pointers);
172 launch_ok =
173 stream
174 ->ThenBlasTrsmBatched(side, uplo, transpose_a, unit_diagonal, m,
175 n, /*alpha=*/1.0f,
176 se::DeviceMemory<float*>(a_pointers), lda,
177 &typed_b_pointers, ldb, batch_size)
178 .ok();
179 break;
180 }
181 case F64: {
182 se::DeviceMemory<double*> typed_b_pointers(b_pointers);
183 launch_ok =
184 stream
185 ->ThenBlasTrsmBatched(side, uplo, transpose_a, unit_diagonal, m,
186 n, /*alpha=*/1.0f,
187 se::DeviceMemory<double*>(a_pointers),
188 lda, &typed_b_pointers, ldb, batch_size)
189 .ok();
190 break;
191 }
192 case C64: {
193 se::DeviceMemory<std::complex<float>*> typed_b_pointers(b_pointers);
194 launch_ok = stream
195 ->ThenBlasTrsmBatched(
196 side, uplo, transpose_a, unit_diagonal, m, n,
197 /*alpha=*/1.0f,
198 se::DeviceMemory<std::complex<float>*>(a_pointers),
199 lda, &typed_b_pointers, ldb, batch_size)
200 .ok();
201 break;
202 }
203 case C128: {
204 se::DeviceMemory<std::complex<double>*> typed_b_pointers(b_pointers);
205 launch_ok = stream
206 ->ThenBlasTrsmBatched(
207 side, uplo, transpose_a, unit_diagonal, m, n,
208 /*alpha=*/1.0f,
209 se::DeviceMemory<std::complex<double>*>(a_pointers),
210 lda, &typed_b_pointers, ldb, batch_size)
211 .ok();
212 break;
213 }
214 default:
215 return InvalidArgument("Invalid type for triangular solve %d", type);
216 }
217 }
218
219 if (!launch_ok) {
220 return InternalError("Unable to launch triangular solve");
221 }
222 return OkStatus();
223 }
224
225 } // namespace gpu
226 } // namespace xla
227