/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"

#include <algorithm>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/math/math_util.h"

namespace xla {

namespace {

// Get the diagonal blocks of the coefficient matrix.
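// For example (illustrative, assuming a has shape [..., 5, 5] and
// block_size == 2): the result has shape [..., 3, 2, 2], holding two full
// diagonal blocks plus the trailing 1x1 block padded out to 2x2 with the
// identity, as implemented below.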
XlaOp DiagonalBlocks(XlaOp a, int64_t block_size) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a));
    int ndims = shape.rank();
    int64_t n = ShapeUtil::GetDimension(shape, -1);
    int64_t num_blocks = n / block_size;
    absl::Span<int64_t const> batch_dims = absl::MakeConstSpan(
        shape.dimensions().begin(), shape.dimensions().begin() + (ndims - 2));

    XlaOp diag_blocks;

    // If the coefficient matrix is exactly one block in size, we just add a
    // singleton dimension, i.e. [..., n, n] -> [..., 1, n, n].
    if (n == block_size) {
      std::vector<int64_t> permutation(ndims);
      std::iota(permutation.begin(), permutation.end(), 1);
      permutation.insert(permutation.end() - 2, 0);
      return Transpose(Broadcast(a, /*broadcast_sizes=*/{1}), permutation);
    }

    // We can grab entire blocks using gather.
    if (n > block_size) {
      // Construct the starting indices of the diagonal blocks.
      auto start_indices =
          Transpose(Broadcast(Mul(Iota(builder, S32, num_blocks),
                                  ConstantR0<int32_t>(builder, block_size)),
                              /*broadcast_sizes=*/{2}),
                    /*permutation=*/{1, 0});
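      // e.g. for num_blocks == 3 and block_size == 2, start_indices is
      // [[0, 0], [2, 2], [4, 4]]: the (row, column) origin of each block.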

      // Gather the diagonal blocks.
      std::vector<int64_t> slice_sizes(ndims);
      GatherDimensionNumbers dim_numbers;
      for (int i = 0; i < ndims - 2; ++i) {
        dim_numbers.add_offset_dims(i);
        slice_sizes[i] = ShapeUtil::GetDimension(shape, i);
      }
      slice_sizes[ndims - 2] = slice_sizes[ndims - 1] = block_size;
      dim_numbers.add_offset_dims(ndims - 1);
      dim_numbers.add_offset_dims(ndims);
      dim_numbers.add_start_index_map(ndims - 2);
      dim_numbers.add_start_index_map(ndims - 1);
      dim_numbers.set_index_vector_dim(1);
      diag_blocks = Gather(a, start_indices, dim_numbers, slice_sizes);
    }

    // The last block might be smaller than the block size,
    // so we will need to pad it.
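    // (For example, with n == 5 and block_size == 2, the trailing 1x1 block
    // [[x]] is padded to [[x, 0], [0, 1]], which keeps it invertible.)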
    if (n % block_size != 0) {
      // Pad with an identity matrix.
      auto last_blocks =
          SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n});
      PaddingConfig config = MakeNoPaddingConfig(ndims);
      int64_t padding = block_size - n % block_size;
      config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding);
      last_blocks =
          Pad(last_blocks, Zero(builder, shape.element_type()), config);

      auto eye =
          IdentityMatrix(builder, shape.element_type(), padding, padding);
      config = MakeNoPaddingConfig(2);
      config.mutable_dimensions(0)->set_edge_padding_low(n % block_size);
      eye = Pad(eye, Zero(builder, shape.element_type()), config);
      eye = Broadcast(eye, batch_dims);
      last_blocks = ConcatInDim(builder, {last_blocks, eye}, ndims - 1);

      // Add a singleton dimension,
      // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size].
      TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks));
      auto shape_dims = blocks_shape.dimensions();
      auto last_blocks_dims = std::vector<int64_t>(ndims);
      std::copy(shape_dims.begin(), shape_dims.end(),
                last_blocks_dims.begin());
      last_blocks_dims.insert(last_blocks_dims.end() - 2, 1);
      last_blocks = Reshape(last_blocks, last_blocks_dims);

      // Concatenate with the other blocks if necessary.
      if (n > block_size) {
        diag_blocks =
            ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2);
      } else {
        diag_blocks = last_blocks;
      }
    }

    return diag_blocks;
  });
}

XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
                                      bool left_side, bool lower,
                                      bool transpose_a, bool conjugate_a,
                                      PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(inv_diag_blocks));
    TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
    int64_t block_size = ShapeUtil::GetDimension(blocks_shape, -1);

    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    int64_t ndims = a_shape.rank();
    int64_t n = ShapeUtil::GetDimension(a_shape, -1);
    int64_t num_blocks = n / block_size + (n % block_size != 0);
    int64_t m_dim = (left_side) ? -1 : -2;
    int64_t m = ShapeUtil::GetDimension(b_shape, m_dim);

    std::vector<XlaOp> update_ops;
    int bdims = b_shape.rank();
    int64_t block_dim = (left_side) ? bdims - 2 : bdims - 1;

    // Initialize the solution.
    XlaOp x;

    // This loop is unrolled for performance reasons, but it could be
    // expressed as a rolled loop as well, since the matrices have the same
    // size on each iteration.
    for (int i = 0; i < num_blocks; i++) {
      // High-level intuition: We have B[i] = L[i] @ X. Since L is lower
      // triangular this means B[i] = L[i, :i + 1] @ X[:i + 1]. We can split
      // this into two parts: B[i] = L[i, :i] @ X[:i] + L[i, i] @ X[i], which
      // can be solved for X[i] as
      // X[i] = inv(L[i, i]) @ (B[i] - L[i, :i] @ X[:i]).
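      //
      // In numpy-like block pseudocode, for the left_side/lower/no-transpose
      // case (a sketch of this unrolled loop, not executed code):
      //   for i in range(num_blocks):
      //     X[i] = inv(L[i, i]) @ (B[i] - L[i, :i] @ X[:i])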

      // Decide whether we go from the first block to the last, or vice versa.
      bool backward = left_side ^ lower ^ transpose_a;
      auto j = backward ? num_blocks - 1 - i : i;
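      // (e.g. left_side && lower && !transpose_a gives backward == false,
      // i.e. forward substitution; the corresponding upper-triangular case
      // runs from the last block back to the first.)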

      // Get the size of the inverse blocks (the last one might be smaller).
      int64_t block = (n % block_size != 0 && j + 1 == num_blocks)
                          ? n % block_size
                          : block_size;
      auto inv_block =
          MaybeConjugate(Collapse(SliceInMinorDims(inv_diag_blocks, {j, 0, 0},
                                                   {j + 1, block, block}),
                                  /*dimensions=*/{ndims - 2, ndims - 1}),
                         conjugate_a);

      // Get the corresponding row of B.
      int64_t k = std::min((j + 1) * block_size, n);
      std::vector<int64_t> start = {j * block_size, 0};
      std::vector<int64_t> end = {k, m};
      if (!left_side) {
        std::swap(start[0], start[1]);
        std::swap(end[0], end[1]);
      }
      auto b_row = SliceInMinorDims(b, start, end);

      XlaOp remainder;
      if (i == 0) {
        remainder = b_row;
      } else {
        // This matrix multiply computes L[i, :i] @ X[:i] while avoiding a
        // lot of multiplication by zero (namely, X[i * block_size:] = 0).
        if (backward) {
          start = {j * block_size,
                   std::max(int64_t{0}, (num_blocks - i) * block_size)};
          end = {k, n};
        } else {
          start = {j * block_size, 0};
          end = {k, std::min(i * block_size, n)};
        }

        if (!left_side ^ transpose_a) {
          std::swap(start[0], start[1]);
          std::swap(end[0], end[1]);
        }
        auto a_row =
            MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a);
        if (left_side) {
          remainder =
              b_row - BatchDot(a_row, transpose_a, x, false, precision);
        } else {
          remainder =
              b_row - BatchDot(x, false, a_row, transpose_a, precision);
        }
      }

      XlaOp x_update;
      if (left_side) {
        x_update =
            BatchDot(inv_block, transpose_a, remainder, false, precision);
      } else {
        x_update =
            BatchDot(remainder, false, inv_block, transpose_a, precision);
      }

      if (i == 0) {
        x = x_update;
      } else {
        if (backward) {
          x = ConcatInDim(builder, {x_update, x}, block_dim);
        } else {
          x = ConcatInDim(builder, {x, x_update}, block_dim);
        }
      }
    }

    return x;
  });
}

}  // namespace

XlaOp TriangularSolveExpander::InvertDiagonalBlocks(
    XlaOp diag_blocks, bool lower_triangular,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = diag_blocks.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // The input is a batch of square triangular matrices. Its shape is
    // (..., size, size). We reshape this to (num_blocks, size, size).
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks));
    int64_t block_size = ShapeUtil::GetDimension(shape, -1);
    int64_t num_blocks = ShapeUtil::ElementsIn(shape) / IPow(block_size, 2);
    diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size});

    // The input must be triangular, because we rely on that when doing
    // multiplications later on.
    diag_blocks = Triangle(diag_blocks, /*lower=*/lower_triangular);

    // Rescale the blocks to be unit triangular, but avoid dividing by zero
    // (which can happen if the last block was padded); otherwise we would
    // introduce NaNs, which would then propagate.
    auto diags = GetMatrixDiagonal(diag_blocks);
    auto ones = FullLike(diags, 1);
    diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags);
    auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2});

    // We can now use the fact that for a lower triangular block matrix
    // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have
    // L21' = -L22' @ L21 @ L11'. In our case, L21 is a vector and our blocks
    // have been rescaled to be unit triangular, so L22 = L22' = 1.
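    //
    // Concretely, one can check the 2x2 block identity (a sketch of the fact
    // used above):
    //   [[L11', 0], [-L22' @ L21 @ L11', L22']] @ [[L11, 0], [L21, L22]] = I.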

    // Initialize the output matrix with -1s on the diagonal. We use -1
    // instead of 1 because we cannot do matrix-vector multiplies with
    // variable shapes inside of a loop, or do irregularly shaped in-place
    // updates. Hence, L21' <- -L22' @ L21 @ L11' cannot be done naively.
    // Instead, we update the entire row, i.e. we calculate
    //   [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I])
    // which means [L21 L22 0] <- [-L21 @ L11', L22, 0].
    auto identity =
        IdentityMatrix(builder, shape.element_type(), block_size, block_size);
    auto neg_identity = -identity;

    // The first or last diagonal element should be set to 1 instead of -1,
    // though, since we never update it.
    auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1});
    auto start_index =
        ConstantR0<int>(builder, lower_triangular ? 0 : block_size - 1);
    auto output_block =
        DynamicUpdateSlice(neg_identity, pos_one,
                           /*start_indices=*/{start_index, start_index});

    // Broadcast diag([1, -1, -1, ...]) to every block.
    XlaOp output = Broadcast(output_block,
                             /*broadcast_sizes=*/{num_blocks});

    // Now we construct a loop that performs matrix-vector multiplications,
    // inverting the blocks one row at a time.
    std::vector<Shape> tuple_shapes = {
        // The loop iteration counter is a scalar, incremented each iteration.
        ShapeUtil::MakeShape(S32, {}),
        // The output has the shape of A, with one row updated each iteration.
        ShapeUtil::MakeShape(shape.element_type(),
                             {num_blocks, block_size, block_size}),
        // The input is a loop invariant.
        ShapeUtil::MakeShape(shape.element_type(),
                             {num_blocks, block_size, block_size})};
    Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes);

    auto init_i = One(builder, S32);
    auto init = Tuple(builder, {init_i, output, scaled_diag_blocks});

    // Construct the loop condition function.
    std::unique_ptr<XlaBuilder> condb =
        builder->CreateSubBuilder("InvertDiagCond");
    {
      auto i = GetTupleElement(
          Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0);
      Lt(i, ConstantR0<int32_t>(condb.get(), block_size));
    }
    TF_ASSIGN_OR_RETURN(auto cond, condb->Build());

    // Construct the loop body function.
    std::unique_ptr<XlaBuilder> bodyb =
        builder->CreateSubBuilder("InvertDiagBody");
    {
      auto input_tuple =
          Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple");

      auto i = GetTupleElement(input_tuple, 0);
      auto body_out = GetTupleElement(input_tuple, 1);
      auto body_input = GetTupleElement(input_tuple, 2);

      auto zero = ConstantR0<int32_t>(bodyb.get(), 0);
      auto j = lower_triangular ? i : ScalarLike(i, block_size - 1) - i;
      auto input_row =
          DynamicSlice(body_input, {zero, j, zero},
                       /*slice_sizes=*/{num_blocks, 1, block_size});

      // We want -L21 @ L11^{-1}.
      DotDimensionNumbers dnums;
      dnums.add_lhs_batch_dimensions(0);
      dnums.add_rhs_batch_dimensions(0);
      dnums.add_lhs_contracting_dimensions(2);
      dnums.add_rhs_contracting_dimensions(1);
      PrecisionConfig precision_proto;
      precision_proto.add_operand_precision(precision);
      precision_proto.add_operand_precision(precision);
      auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);

      body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero});

      auto next_i = i + ScalarLike(i, 1);
      Tuple(bodyb.get(), {next_i, body_out, body_input});
    }
    TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());

    // Construct the While loop and return the result, i.e.
    //   return while_loop(cond_fun, body_fun, init)[1].
    auto invert_while = While(cond, body, init);
    auto inv_diag_blocks = GetTupleElement(invert_while, 1);
    // Undo the scaling.
    inv_diag_blocks = Div(inv_diag_blocks, diags,
                          /*broadcast_dimensions=*/{0, 1});

    // Reshape back to the original batch-major dimensions.
    return Reshape(inv_diag_blocks, shape.dimensions());
  });
}

XlaOp TriangularSolveExpander::SolveByInvertingDiagonalBlocks(
    XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
    bool conjugate_a, bool unit_diagonal,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    const int64_t ndims = a_shape.rank();
    int64_t k = ShapeUtil::GetDimension(a_shape, -1);

    // TODO(phawkins): consider pushing triangle masking into
    // InvertDiagonalBlocks.
    if (unit_diagonal) {
      // Mask off everything but the strictly lower/upper triangular
      // elements, then add the implicit unit diagonal explicitly.
      a = lower ? Select(TriangleMask(a, -1), a, ZerosLike(a))
                : Select(TriangleMask(a, 0), ZerosLike(a), a);
      a = xla::Add(a, IdentityMatrix(builder, a_shape.element_type(), k, k),
                   /*broadcast_dimensions=*/{ndims - 2, ndims - 1});
    } else {
      // Mask off the ignored elements of the triangular matrix a.
      a = Triangle(a, lower);
    }

    // We find the diagonal blocks of the coefficient matrix.
    int64_t block_size = std::min(block_size_, k);
    auto diag_blocks = DiagonalBlocks(a, block_size);

    // We invert these blocks in parallel using batched matrix-vector products.
    auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, precision);

    // We now find the solution using GEMMs.
    return SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side,
                                           lower, transpose_a, conjugate_a,
                                           precision);
  });
}

// def trsm_left_lower_leftlooking(a, b):
//   n = a.shape[-1]
//   assert a.shape == (n, n)
//   b = b.copy()
//   for j in range(n):
//     b[j, :] = (b[j, :] - np.dot(a[j, :j], b[:j, :])) / a[j, j]
//   return b
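//
// For comparison, the upper-triangular (backward-substitution) variant would
// look like the following sketch; the loop below derives this and the other
// flag combinations from the same slicing logic:
//
// def trsm_left_upper_backlooking(a, b):
//   n = a.shape[-1]
//   assert a.shape == (n, n)
//   b = b.copy()
//   for j in reversed(range(n)):
//     b[j, :] = (b[j, :] - np.dot(a[j, j + 1:], b[j + 1:, :])) / a[j, j]
//   return b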
XlaOp TriangularSolveExpander::SolveDirectly(
    XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
    bool conjugate_a, bool unit_diagonal,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
    int64_t m = ShapeUtil::GetDimension(b_shape, -2);
    int64_t n = ShapeUtil::GetDimension(b_shape, -1);
    const int64_t a_size = ShapeUtil::GetDimension(a_shape, -1);
    a = MaybeConjugate(a, conjugate_a);
    bool backwards = transpose_a ^ lower ^ !left_side;
    for (int64_t i = 0; i < a_size; ++i) {
      int64_t j = backwards ? i : (a_size - i - 1);
      std::vector<int64_t> b_row_start, b_row_end;
      if (left_side) {
        b_row_start = {j, 0};
        b_row_end = {j + 1, n};
      } else {
        b_row_start = {0, j};
        b_row_end = {m, j + 1};
      }
      auto b_row = SliceInMinorDims(b, b_row_start, b_row_end);

      std::vector<int64_t> a_start = {j, backwards ? 0 : (j + 1)};
      std::vector<int64_t> a_end = {j + 1, backwards ? j : a_size};
      if (transpose_a ^ !left_side) {
        std::swap(a_start[0], a_start[1]);
        std::swap(a_end[0], a_end[1]);
      }
      auto a_chunk = SliceInMinorDims(a, a_start, a_end);
      if (left_side) {
        bool which = transpose_a ^ lower;
        auto b_chunk =
            SliceInMinorDims(b, {which ? 0 : (j + 1), 0}, {which ? j : m, n});
        b_row = b_row - BatchDot(a_chunk, /*transpose_x=*/transpose_a, b_chunk,
                                 /*transpose_y=*/false, precision);
      } else {
        bool which = transpose_a ^ !lower;
        auto b_chunk =
            SliceInMinorDims(b, {0, which ? 0 : (j + 1)}, {m, which ? j : n});
        b_row = b_row - BatchDot(b_chunk, /*transpose_x=*/false, a_chunk,
                                 /*transpose_y=*/transpose_a, precision);
      }
      if (!unit_diagonal) {
        auto a_diag = SliceInMinorDims(a, {j, j}, {j + 1, j + 1});
        b_row = b_row / a_diag;
      }

      b = UpdateSliceInMinorDims(b, b_row, b_row_start);
    }

    return b;
  });
}

XlaOp TriangularSolveExpander::BuildTriangularSolve(
    XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
    bool conjugate_a, bool unit_diagonal, int64_t block_size,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
    if (a_shape.rank() != b_shape.rank()) {
      return InvalidArgument(
          "Arguments to TriangularSolve have shapes with different ranks: "
          "%s vs. %s",
          ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
    }
    const int64_t ndims = a_shape.rank();
    if (ndims < 2) {
      return InvalidArgument(
483 "Arguments to TriangularSolve was rank %d but must have rank >= 2.",
484 ndims);
    }
    // The batch dimensions must be equal.
    std::vector<int64_t> batch_dimensions;
    int64_t batch = 1;
    for (int i = 0; i < ndims - 2; ++i) {
      int64_t a_size = a_shape.dimensions(i);
      int64_t b_size = b_shape.dimensions(i);
      if (a_size != b_size) {
        return InvalidArgument(
            "Batch dimensions of arguments to TriangularSolve must be equal; "
            "shapes were %s and %s.",
            ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
      }
      batch_dimensions.push_back(a_size);
      batch *= a_size;
    }

    if (ShapeUtil::GetDimension(a_shape, -1) !=
        ShapeUtil::GetDimension(a_shape, -2)) {
      return InvalidArgument(
          "The 'a' argument to TriangularSolve must be a batched square "
          "matrix; shape was: %s",
          ShapeUtil::HumanString(a_shape));
    }
    const int64_t m = ShapeUtil::GetDimension(b_shape, -2);
    const int64_t n = ShapeUtil::GetDimension(b_shape, -1);
    if ((left_side ? m : n) != ShapeUtil::GetDimension(a_shape, -1)) {
      return InvalidArgument(
          "Arguments to TriangularSolve have incompatible matrix shapes %s "
          "and %s.",
          ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
    }

    int64_t a_size = ShapeUtil::GetDimension(a_shape, -1);

    if (ShapeUtil::IsZeroElementArray(b_shape)) {
      // The output has the same shape as 'b', and since the output has zero
      // elements, any such array will do.
      return b;
    }

    // Degenerate case: 1x1 matrices.
    if (a_size == 1) {
      return unit_diagonal ? b : Div(b, MaybeConjugate(a, conjugate_a));
    }

    // Prefer the direct implementation whenever there is a nontrivial batch
    // dimension and the matrix is very small.
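    // (For example, with block_size_ == 128 this picks the direct solver
    // when batch > 8 and a_size < 32.)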
    if (UseDirectSolves() && batch > block_size_ / 16 &&
        a_size < block_size_ / 4) {
      return SolveDirectly(a, b, left_side, lower, transpose_a, conjugate_a,
                           unit_diagonal, precision);
    } else {
      return SolveByInvertingDiagonalBlocks(a, b, left_side, lower,
                                            transpose_a, conjugate_a,
                                            unit_diagonal, precision);
    }
  });
}

TriangularSolveExpander::TriangularSolveExpander(int64_t block_size)
    : block_size_(block_size) {
  CHECK_GE(block_size_, 1);
}

bool TriangularSolveExpander::InstructionMatchesPattern(
    HloInstruction* instruction) {
  return instruction->opcode() == HloOpcode::kTriangularSolve;
}

StatusOr<HloInstruction*> TriangularSolveExpander::ExpandInstruction(
    HloInstruction* instruction) {
  const TriangularSolveOptions& options =
      instruction->triangular_solve_options();
  const std::string name = absl::StrFormat(
      "xla.triangular_solve_%s_%s_%s_%s_%s_%s",
      instruction->operand(0)->shape().ToString(),
      instruction->operand(1)->shape().ToString(),
      options.left_side() ? "left" : "right",
      options.lower() ? "lower" : "upper",
      TriangularSolveOptions_Transpose_Name(options.transpose_a()),
      options.unit_diagonal() ? "unit" : "nonunit");

  HloModule* module = instruction->parent()->parent();

  HloComputation*& computation =
      computation_cache_.emplace(name, nullptr).first->second;
  if (!computation) {
    // Builds a new expansion.
    //
    // We do something unusual here: we build the computation using the
    // XlaBuilder API, which is nominally an XLA client API. We do this
    // because the external APIs for building complicated computations
    // (XlaBuilder) are much more ergonomic than the internal ones. As it
    // turns out, XlaBuilder isn't really a client API: what it does is build
    // an HloModuleProto protocol buffer that we can then deserialize and
    // clone into our HloModule. Ideally we would avoid the protocol buffer
    // step; that is left as an exercise for future work.
    XlaBuilder builder(name);
    XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
    XlaOp b = Parameter(&builder, 1, instruction->operand(1)->shape(), "b");
    bool transpose_a =
        options.transpose_a() != TriangularSolveOptions::NO_TRANSPOSE;
    bool conjugate_a = options.transpose_a() == TriangularSolveOptions::ADJOINT;

    BuildTriangularSolve(a, b, options.left_side(), options.lower(),
                         transpose_a, conjugate_a, options.unit_diagonal(),
                         /*block_size=*/block_size_,
                         /*precision=*/PrecisionConfig::HIGHEST);
    TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());

    TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                        xla_computation.GetProgramShape());
    HloModuleConfig config(program_shape);
    TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
                                             xla_computation.proto(), config));
    HloCloneContext context(module);
    computation = module->DeepCloneComputation(new_module->entry_computation(),
                                               &context);
  }

  return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
      instruction->shape(), instruction->operands(), computation));
}

}  // namespace xla