xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/triangular_solve_expander.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
17 
18 #include <memory>
19 #include <vector>
20 
21 #include "absl/types/span.h"
22 #include "tensorflow/compiler/xla/client/lib/constants.h"
23 #include "tensorflow/compiler/xla/client/lib/math.h"
24 #include "tensorflow/compiler/xla/client/lib/matrix.h"
25 #include "tensorflow/compiler/xla/client/lib/slicing.h"
26 #include "tensorflow/compiler/xla/client/xla_builder.h"
27 #include "tensorflow/compiler/xla/client/xla_computation.h"
28 #include "tensorflow/compiler/xla/literal.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/lib/math/math_util.h"
34 
35 namespace xla {
36 
37 namespace {
38 
39 // Get the diagonal blocks of the coefficient matrix
DiagonalBlocks(XlaOp a,int64_t block_size)40 XlaOp DiagonalBlocks(XlaOp a, int64_t block_size) {
41   XlaBuilder* builder = a.builder();
42   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
43     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a));
44     int ndims = shape.rank();
45     int64_t n = ShapeUtil::GetDimension(shape, -1);
46     int64_t num_blocks = n / block_size;
47     absl::Span<int64_t const> batch_dims = absl::MakeConstSpan(
48         shape.dimensions().begin(), shape.dimensions().begin() + (ndims - 2));
49 
50     XlaOp diag_blocks;
51 
52     // If the coefficient matrix is exactly the block size, we just add a
53     // singleton dimension i.e. [..., n, n] -> [..., 1, n, n]
54     if (n == block_size) {
55       std::vector<int64_t> permutation(ndims);
56       std::iota(permutation.begin(), permutation.end(), 1);
57       permutation.insert(permutation.end() - 2, 0);
58       return Transpose(Broadcast(a, /*broadcast_sizes=*/{1}), permutation);
59     }
60 
61     // We can grab entire blocks using gather
62     if (n > block_size) {
63       // Construct the starting indices of the diagonal blocks
64       auto start_indices =
65           Transpose(Broadcast(Mul(Iota(builder, S32, num_blocks),
66                                   ConstantR0<int32_t>(builder, block_size)),
67                               /*broadcast_sizes=*/{2}),
68                     /*permutation=*/{1, 0});
69 
70       // Gather the diagonal blocks
71       std::vector<int64_t> slice_sizes(ndims);
72       GatherDimensionNumbers dim_numbers;
73       for (int i = 0; i < ndims - 2; ++i) {
74         dim_numbers.add_offset_dims(i);
75         slice_sizes[i] = ShapeUtil::GetDimension(shape, i);
76       }
77       slice_sizes[ndims - 2] = slice_sizes[ndims - 1] = block_size;
78       dim_numbers.add_offset_dims(ndims - 1);
79       dim_numbers.add_offset_dims(ndims);
80       dim_numbers.add_start_index_map(ndims - 2);
81       dim_numbers.add_start_index_map(ndims - 1);
82       dim_numbers.set_index_vector_dim(1);
83       diag_blocks = Gather(a, start_indices, dim_numbers, slice_sizes);
84     }
85 
86     // The last block might be smaller than the block size,
87     // so we will need to pad it
88     if (n % block_size != 0) {
89       // Pad with identity matrix.
90       auto last_blocks =
91           SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n});
92       PaddingConfig config = MakeNoPaddingConfig(ndims);
93       int64_t padding = block_size - n % block_size;
94       config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding);
95       last_blocks =
96           Pad(last_blocks, Zero(builder, shape.element_type()), config);
97 
98       auto eye =
99           IdentityMatrix(builder, shape.element_type(), padding, padding);
100       config = MakeNoPaddingConfig(2);
101       config.mutable_dimensions(0)->set_edge_padding_low(n % block_size);
102       eye = Pad(eye, Zero(builder, shape.element_type()), config);
103       eye = Broadcast(eye, batch_dims);
104       last_blocks = ConcatInDim(builder, {last_blocks, eye}, ndims - 1);
105 
106       // Add a singleton dimension
107       // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size]
108       TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks));
109       auto shape_dims = blocks_shape.dimensions();
110       auto last_blocks_dims = std::vector<int64_t>(ndims);
111       std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin());
112       last_blocks_dims.insert(last_blocks_dims.end() - 2, 1);
113       last_blocks = Reshape(last_blocks, last_blocks_dims);
114 
115       // Concatenate with the other blocks if necessary
116       if (n > block_size) {
117         diag_blocks =
118             ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2);
119       } else {
120         diag_blocks = last_blocks;
121       }
122     }
123 
124     return diag_blocks;
125   });
126 }
127 
SolveWithInvertedDiagonalBlocks(XlaOp a,XlaOp b,XlaOp inv_diag_blocks,bool left_side,bool lower,bool transpose_a,bool conjugate_a,PrecisionConfig::Precision precision)128 XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
129                                       bool left_side, bool lower,
130                                       bool transpose_a, bool conjugate_a,
131                                       PrecisionConfig::Precision precision) {
132   XlaBuilder* builder = a.builder();
133   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
134     TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(inv_diag_blocks));
135     TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
136     int64_t block_size = ShapeUtil::GetDimension(blocks_shape, -1);
137 
138     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
139     int64_t ndims = a_shape.rank();
140     int64_t n = ShapeUtil::GetDimension(a_shape, -1);
141     int64_t num_blocks = n / block_size + (n % block_size != 0);
142     int64_t m_dim = (left_side) ? -1 : -2;
143     int64_t m = ShapeUtil::GetDimension(b_shape, m_dim);
144 
145     std::vector<XlaOp> update_ops;
146     int bdims = b_shape.rank();
147     int64_t block_dim = (left_side) ? bdims - 2 : bdims - 1;
148 
149     // Initialize the solution
150     XlaOp x;
151 
152     // This loop is unrolled for performance reasons, but it could be expressed
153     // rolled as well since the matrices are of the same size each iteration
154     for (int i = 0; i < num_blocks; i++) {
155       // High-level intuition: We have B[i] = L[i] @ X. Since L is upper
156       // triangular this means B[i] = L[i, :i + 1] @ X[:i + 1]. We can split
157       // this into two parts: B[i] = L[i, :i] @ X[:i] + L[i, i] @ X[i] which
158       // can be solved for X[i] as X[i] = inv(L[i, i]) @ B[i] - L[i, :i] @ X[:i]
159 
160       // Decide whether we go from first block to last or vice versa
161       bool backward = left_side ^ lower ^ transpose_a;
162       auto j = backward ? num_blocks - 1 - i : i;
163 
164       // Get the size of the inverse blocks (the last one might be smaller)
165       int64_t block = (n % block_size != 0 && j + 1 == num_blocks)
166                           ? n % block_size
167                           : block_size;
168       auto inv_block =
169           MaybeConjugate(Collapse(SliceInMinorDims(inv_diag_blocks, {j, 0, 0},
170                                                    {j + 1, block, block}),
171                                   /*dimensions=*/{ndims - 2, ndims - 1}),
172                          conjugate_a);
173 
174       // Get the corresponding row of B
175       int64_t k = std::min((j + 1) * block_size, n);
176       std::vector<int64_t> start = {j * block_size, 0};
177       std::vector<int64_t> end = {k, m};
178       if (!left_side) {
179         std::swap(start[0], start[1]);
180         std::swap(end[0], end[1]);
181       }
182       auto b_row = SliceInMinorDims(b, start, end);
183 
184       XlaOp remainder;
185       if (i == 0) {
186         remainder = b_row;
187       } else {
188         // This matrix multiply get rid of a lot of multiplying with zero
189         // (namely, X[i * block_size:] = 0), L[i, :i] @ X[:i]
190         if (backward) {
191           start = {j * block_size,
192                    std::max(int64_t{0}, (num_blocks - i) * block_size)};
193           end = {k, n};
194         } else {
195           start = {j * block_size, 0};
196           end = {k, std::min(i * block_size, n)};
197         }
198 
199         if (!left_side ^ transpose_a) {
200           std::swap(start[0], start[1]);
201           std::swap(end[0], end[1]);
202         }
203         auto a_row =
204             MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a);
205         if (left_side) {
206           remainder = b_row - BatchDot(a_row, transpose_a, x, false, precision);
207         } else {
208           remainder = b_row - BatchDot(x, false, a_row, transpose_a, precision);
209         }
210       }
211 
212       XlaOp x_update;
213       if (left_side) {
214         x_update =
215             BatchDot(inv_block, transpose_a, remainder, false, precision);
216       } else {
217         x_update =
218             BatchDot(remainder, false, inv_block, transpose_a, precision);
219       }
220 
221       if (i == 0) {
222         x = x_update;
223       } else {
224         if (backward) {
225           x = ConcatInDim(builder, {x_update, x}, block_dim);
226         } else {
227           x = ConcatInDim(builder, {x, x_update}, block_dim);
228         }
229       }
230     }
231 
232     return x;
233   });
234 }
235 
236 }  // namespace
237 
InvertDiagonalBlocks(XlaOp diag_blocks,bool lower_triangular,PrecisionConfig::Precision precision)238 XlaOp TriangularSolveExpander::InvertDiagonalBlocks(
239     XlaOp diag_blocks, bool lower_triangular,
240     PrecisionConfig::Precision precision) {
241   XlaBuilder* builder = diag_blocks.builder();
242   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
243     // Input is a batch of square lower triangular square matrices. Its shape is
244     // (..., size, size). We resize this to (num_blocks, size, size).
245     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks));
246     int64_t block_size = ShapeUtil::GetDimension(shape, -1);
247     int64_t num_blocks = ShapeUtil::ElementsIn(shape) / IPow(block_size, 2);
248     diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size});
249 
250     // The input must be triangular because we rely on that when doing
251     // multiplications later on
252     diag_blocks = Triangle(diag_blocks, /*lower=*/lower_triangular);
253 
254     // Rescale blocks to be unit triangular, but avoid dividing by
255     // zero (which can happen if the last block was padded) otherwise it will
256     // introduce nans which will propagate
257     auto diags = GetMatrixDiagonal(diag_blocks);
258     auto ones = FullLike(diags, 1);
259     diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags);
260     auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2});
261 
262     // We can now use the fact that for an upper triangular matrix
263     // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have
264     // L22' = -L22' * L21 * L11'. In our case, L21 is a vector and our blocks
265     // have been rescaled to be unit triangular, so L22 = L22' = 1.
266 
267     // Initialize the output matrix with -1s on the diagonal. We use -1 instead
268     // of 1 because we cannot do matrix-vector multiplies with variable shapes
269     // inside of a loop, or do irregularly shaped in-place updates. Hence,
270     // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the
271     // entire row i.e. we calculate
272     // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I])
273     // which means [L21 L22 0] <- [-L21 * L11', L22, 0].
274     auto identity =
275         IdentityMatrix(builder, shape.element_type(), block_size, block_size);
276     auto neg_identity = -identity;
277 
278     // The first or last  diagonal element should be set to 1 instead of -1
279     // though, since we never update it
280     auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1});
281     auto start_index =
282         ConstantR0<int>(builder, lower_triangular ? 0 : block_size - 1);
283     auto output_block =
284         DynamicUpdateSlice(neg_identity, pos_one,
285                            /*start_indices=*/{start_index, start_index});
286 
287     // Broadcast diag([1, -1, -1, ...]) to every block
288     XlaOp output = Broadcast(output_block,
289                              /*broadcast_sizes=*/{num_blocks});
290 
291     // Now we construct a loop that performs matrix-vector multiplications
292     // inverting the blocks one row at a time
293     std::vector<Shape> tuple_shapes = {
294         // The loop iteration counter is a scalar, incremented each iteration.
295         ShapeUtil::MakeShape(S32, {}),
296         // The output has the shape of A, with one row updated each iteration.
297         ShapeUtil::MakeShape(shape.element_type(),
298                              {num_blocks, block_size, block_size}),
299         // The input is a loop invariant.
300         ShapeUtil::MakeShape(shape.element_type(),
301                              {num_blocks, block_size, block_size})};
302     Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes);
303 
304     auto init_i = One(builder, S32);
305     auto init = Tuple(builder, {init_i, output, scaled_diag_blocks});
306 
307     // Construct the loop condition function.
308     std::unique_ptr<XlaBuilder> condb =
309         builder->CreateSubBuilder("InvertDiagCond");
310     {
311       auto i = GetTupleElement(
312           Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0);
313       Lt(i, ConstantR0<int32_t>(condb.get(), block_size));
314     }
315     TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
316 
317     // Construct the loop body function.
318     std::unique_ptr<XlaBuilder> bodyb =
319         builder->CreateSubBuilder("InvertDiagBody");
320     {
321       auto input_tuple =
322           Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple");
323 
324       auto i = GetTupleElement(input_tuple, 0);
325       auto body_out = GetTupleElement(input_tuple, 1);
326       auto body_input = GetTupleElement(input_tuple, 2);
327 
328       auto zero = ConstantR0<int32_t>(bodyb.get(), 0);
329       auto j = lower_triangular ? i : ScalarLike(i, block_size - 1) - i;
330       auto input_row =
331           DynamicSlice(body_input, {zero, j, zero},
332                        /*slice_sizes=*/{num_blocks, 1, block_size});
333 
334       // We want -L21 L11^{-1}
335       DotDimensionNumbers dnums;
336       dnums.add_lhs_batch_dimensions(0);
337       dnums.add_rhs_batch_dimensions(0);
338       dnums.add_lhs_contracting_dimensions(2);
339       dnums.add_rhs_contracting_dimensions(1);
340       PrecisionConfig precision_proto;
341       precision_proto.add_operand_precision(precision);
342       precision_proto.add_operand_precision(precision);
343       auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);
344 
345       body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero});
346 
347       auto next_i = i + ScalarLike(i, 1);
348       Tuple(bodyb.get(), {next_i, body_out, body_input});
349     }
350     TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
351 
352     // Construct the While loop and return the result,
353     // return while_loop(cond_fun, body_fun, init)[1]
354     auto invert_while = While(cond, body, init);
355     auto inv_diag_blocks = GetTupleElement(invert_while, 1);
356     // Undo the scaling
357     inv_diag_blocks = Div(inv_diag_blocks, diags,
358                           /*broadcast_dimensions=*/{0, 1});
359 
360     // Reshape back to original batch major dimensions
361     return Reshape(inv_diag_blocks, shape.dimensions());
362   });
363 }
364 
SolveByInvertingDiagonalBlocks(XlaOp a,XlaOp b,bool left_side,bool lower,bool transpose_a,bool conjugate_a,bool unit_diagonal,PrecisionConfig::Precision precision)365 XlaOp TriangularSolveExpander::SolveByInvertingDiagonalBlocks(
366     XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
367     bool conjugate_a, bool unit_diagonal,
368     PrecisionConfig::Precision precision) {
369   XlaBuilder* builder = a.builder();
370   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
371     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
372     const int64_t ndims = a_shape.rank();
373     int64_t k = ShapeUtil::GetDimension(a_shape, -1);
374 
375     // TODO(phawkins): consider pushing triangle masking into
376     // InvertDiagonalBlocks.
377     if (unit_diagonal) {
378       // Mask everything but the subdiagonal/superdiagonal elements.
379       a = lower ? Select(TriangleMask(a, -1), a, ZerosLike(a))
380                 : Select(TriangleMask(a, 0), ZerosLike(a), a);
381       a = xla::Add(a, IdentityMatrix(builder, a_shape.element_type(), k, k),
382                    /*broadcast_dimensions=*/{ndims - 2, ndims - 1});
383     } else {
384       // Mask off the ignored elements of the triangular matrix a.
385       a = Triangle(a, lower);
386     }
387 
388     // We find the diagonal blocks of the coefficient matrix
389     int64_t block_size = std::min(block_size_, k);
390     auto diag_blocks = DiagonalBlocks(a, block_size);
391 
392     // We invert these blocks in parallel using batched matrix-vector products
393     auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, precision);
394 
395     // We now find the solution using GEMMs
396     return SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side,
397                                            lower, transpose_a, conjugate_a,
398                                            precision);
399   });
400 }
401 
402 // def trsm_left_lower_leftlooking(a, b):
403 //   n = a.shape[-1]
404 //   assert a.shape == (n, n)
405 //   b = b.copy()
406 //   for j in range(n):
407 //     b[j, :] = (b[j, :] - np.dot(a[j, :j], b[:j, :])) / a[j, j]
408 //   return b
SolveDirectly(XlaOp a,XlaOp b,bool left_side,bool lower,bool transpose_a,bool conjugate_a,bool unit_diagonal,PrecisionConfig::Precision precision)409 XlaOp TriangularSolveExpander::SolveDirectly(
410     XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
411     bool conjugate_a, bool unit_diagonal,
412     PrecisionConfig::Precision precision) {
413   XlaBuilder* builder = a.builder();
414   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
415     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
416     TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
417     int64_t m = ShapeUtil::GetDimension(b_shape, -2);
418     int64_t n = ShapeUtil::GetDimension(b_shape, -1);
419     const int64_t a_size = ShapeUtil::GetDimension(a_shape, -1);
420     a = MaybeConjugate(a, conjugate_a);
421     bool backwards = transpose_a ^ lower ^ !left_side;
422     for (int64_t i = 0; i < a_size; ++i) {
423       int64_t j = backwards ? i : (a_size - i - 1);
424       std::vector<int64_t> b_row_start, b_row_end;
425       if (left_side) {
426         b_row_start = {j, 0};
427         b_row_end = {j + 1, n};
428       } else {
429         b_row_start = {0, j};
430         b_row_end = {m, j + 1};
431       }
432       auto b_row = SliceInMinorDims(b, b_row_start, b_row_end);
433 
434       std::vector<int64_t> a_start = {j, backwards ? 0 : (j + 1)};
435       std::vector<int64_t> a_end = {j + 1, backwards ? j : a_size};
436       if (transpose_a ^ !left_side) {
437         std::swap(a_start[0], a_start[1]);
438         std::swap(a_end[0], a_end[1]);
439       }
440       auto a_chunk = SliceInMinorDims(a, a_start, a_end);
441       if (left_side) {
442         bool which = transpose_a ^ lower;
443         auto b_chunk =
444             SliceInMinorDims(b, {which ? 0 : (j + 1), 0}, {which ? j : m, n});
445         b_row = b_row - BatchDot(a_chunk, /*transpose_x=*/transpose_a, b_chunk,
446                                  /*transpose_y=*/false, precision);
447       } else {
448         bool which = transpose_a ^ !lower;
449         auto b_chunk =
450             SliceInMinorDims(b, {0, which ? 0 : (j + 1)}, {m, which ? j : n});
451         b_row = b_row - BatchDot(b_chunk, /*transpose_x=*/false, a_chunk,
452                                  /*transpose_y=*/transpose_a, precision);
453       }
454       if (!unit_diagonal) {
455         auto a_diag = SliceInMinorDims(a, {j, j}, {j + 1, j + 1});
456         b_row = b_row / a_diag;
457       }
458 
459       b = UpdateSliceInMinorDims(b, b_row, b_row_start);
460     }
461 
462     return b;
463   });
464 }
465 
BuildTriangularSolve(XlaOp a,XlaOp b,bool left_side,bool lower,bool transpose_a,bool conjugate_a,bool unit_diagonal,int64_t block_size,PrecisionConfig::Precision precision)466 XlaOp TriangularSolveExpander::BuildTriangularSolve(
467     XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
468     bool conjugate_a, bool unit_diagonal, int64_t block_size,
469     PrecisionConfig::Precision precision) {
470   XlaBuilder* builder = a.builder();
471   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
472     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
473     TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
474     if (a_shape.rank() != b_shape.rank()) {
475       return InvalidArgument(
476           "Arguments to TriangularSolve have shapes with different ranks: "
477           "%s vs. %s",
478           ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
479     }
480     const int64_t ndims = a_shape.rank();
481     if (ndims < 2) {
482       return InvalidArgument(
483           "Arguments to TriangularSolve was rank %d but must have rank >= 2.",
484           ndims);
485     }
486     // The batch dimensions must be equal.
487     std::vector<int64_t> batch_dimensions;
488     int64_t batch = 1;
489     for (int i = 0; i < ndims - 2; ++i) {
490       int64_t a_size = a_shape.dimensions(i);
491       int64_t b_size = b_shape.dimensions(i);
492       if (a_size != b_size) {
493         return InvalidArgument(
494             "Batch dimensions of arguments to TriangularSolve must be equal; "
495             "shapes were %s and %s.",
496             ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
497       }
498       batch_dimensions.push_back(a_size);
499       batch *= a_size;
500     }
501 
502     if (ShapeUtil::GetDimension(a_shape, -1) !=
503         ShapeUtil::GetDimension(a_shape, -2)) {
504       return InvalidArgument(
505           "The 'a' argument to TriangularSolve must be a batched square matrix;"
506           " shape was: %s",
507           ShapeUtil::HumanString(a_shape));
508     }
509     const int64_t m = ShapeUtil::GetDimension(b_shape, -2);
510     const int64_t n = ShapeUtil::GetDimension(b_shape, -1);
511     if ((left_side ? m : n) != ShapeUtil::GetDimension(a_shape, -1)) {
512       return InvalidArgument(
513           "Arguments to TriangularSolve have incompatible matrix shapes %s and "
514           "%s",
515           ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
516     }
517 
518     int64_t a_size = ShapeUtil::GetDimension(a_shape, -1);
519 
520     if (ShapeUtil::IsZeroElementArray(b_shape)) {
521       // The output has the same shape as 'b', and since the output has zero
522       // elements, any such array will do.
523       return b;
524     }
525 
526     // Degenerate case: 1x1 matrices.
527     if (a_size == 1) {
528       return unit_diagonal ? b : Div(b, MaybeConjugate(a, conjugate_a));
529     }
530 
531     // Prefer the direct implementation whenever there is a nontrivial batch
532     // dimension and the matrix is very small.
533     if (UseDirectSolves() && batch > block_size_ / 16 &&
534         a_size < block_size_ / 4) {
535       return SolveDirectly(a, b, left_side, lower, transpose_a, conjugate_a,
536                            unit_diagonal, precision);
537     } else {
538       return SolveByInvertingDiagonalBlocks(a, b, left_side, lower, transpose_a,
539                                             conjugate_a, unit_diagonal,
540                                             precision);
541     }
542   });
543 }
544 
TriangularSolveExpander(int64_t block_size)545 TriangularSolveExpander::TriangularSolveExpander(int64_t block_size)
546     : block_size_(block_size) {
547   CHECK_GE(block_size_, 1);
548 }
549 
InstructionMatchesPattern(HloInstruction * instruction)550 bool TriangularSolveExpander::InstructionMatchesPattern(
551     HloInstruction* instruction) {
552   return instruction->opcode() == HloOpcode::kTriangularSolve;
553 }
554 
ExpandInstruction(HloInstruction * instruction)555 StatusOr<HloInstruction*> TriangularSolveExpander::ExpandInstruction(
556     HloInstruction* instruction) {
557   const TriangularSolveOptions& options =
558       instruction->triangular_solve_options();
559   const std::string name = absl::StrFormat(
560       "xla.triangular_solve_%s_%s_%s_%s_%s_%s",
561       instruction->operand(0)->shape().ToString(),
562       instruction->operand(1)->shape().ToString(),
563       options.left_side() ? "left" : "right",
564       options.lower() ? "lower" : "upper",
565       TriangularSolveOptions_Transpose_Name(options.transpose_a()),
566       options.unit_diagonal() ? "unit" : "nonunit");
567 
568   HloModule* module = instruction->parent()->parent();
569 
570   HloComputation*& computation =
571       computation_cache_.emplace(name, nullptr).first->second;
572   if (!computation) {
573     // Builds a new expansion.
574     //
575     // We do something unusual here: we build the computation using the
576     // XlaBuilder API, which is nominally an XLA client API. We do this because
577     // the external APIs for building complicated computations (XlaBuilder)
578     // are much more ergonomic than the internal ones. As it turns out,
579     // XlaBuilder isn't really a client API—what it does is build a
580     // HloModuleProto protocol buffer, that we can then deserialize and clone
581     // into our HloModule. Ideally we would avoid the protocol buffer step;
582     // that is left as an exercise for future work.
583     XlaBuilder builder(name);
584     XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
585     XlaOp b = Parameter(&builder, 1, instruction->operand(1)->shape(), "b");
586     bool transpose_a =
587         options.transpose_a() != TriangularSolveOptions::NO_TRANSPOSE;
588     bool conjugate_a = options.transpose_a() == TriangularSolveOptions::ADJOINT;
589 
590     BuildTriangularSolve(a, b, options.left_side(), options.lower(),
591                          transpose_a, conjugate_a, options.unit_diagonal(),
592                          /*block_size=*/block_size_,
593                          /*precision=*/PrecisionConfig::HIGHEST);
594     TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());
595 
596     TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
597                         xla_computation.GetProgramShape());
598     HloModuleConfig config(program_shape);
599     TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
600                                              xla_computation.proto(), config));
601     HloCloneContext context(module);
602     computation =
603         module->DeepCloneComputation(new_module->entry_computation(), &context);
604   }
605 
606   return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
607       instruction->shape(), instruction->operands(), computation));
608 }
609 
610 }  // namespace xla
611