/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/qr_expander.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

namespace {

std::vector<int64_t> ConcatVectors(absl::Span<const int64_t> xs,
                                   absl::Span<const int64_t> ys) {
  std::vector<int64_t> output;
  output.reserve(xs.size() + ys.size());
  std::copy(xs.begin(), xs.end(), std::back_inserter(output));
  std::copy(ys.begin(), ys.end(), std::back_inserter(output));
  return output;
}

// Computes sqrt(x^2 + y^2 + ...), avoiding overflow/underflow.
// e.g. for 3 arguments:
// def norm(x, y, z):
//   xabs = np.abs(x)
//   yabs = np.abs(y)
//   zabs = np.abs(z)
//   w = np.maximum(np.maximum(xabs, yabs), zabs)
//   if w == 0:
//     return 0
//   else:
//     return w * np.sqrt((xabs / w)**2 + (yabs / w) ** 2 + (zabs / w) ** 2)
XlaOp Norm(std::vector<XlaOp> xs) {
  CHECK(!xs.empty());
  XlaOp w;
  for (size_t i = 0; i < xs.size(); ++i) {
    xs[i] = Abs(xs[i]);
    w = i == 0 ? xs[i] : xla::Max(w, xs[i]);
  }

  XlaOp out;
  for (size_t i = 0; i < xs.size(); ++i) {
    XlaOp t = Square(xs[i] / w);
    out = i == 0 ? t : xla::Add(out, t);
  }
  return Select(Eq(w, ZerosLike(w)), ZerosLike(w), w * Sqrt(out));
}

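// For intuition, a hand-checkable example (not part of the implementation):
// Norm({3e200, 4e200}) rescales by w = 4e200 before squaring, so it yields
// 5e200, whereas the naive sqrt((3e200)**2 + (4e200)**2) would overflow to
// infinity in double precision.
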
// Computes a Householder reflection of the form:
// H = I - tau v v.T,
// such that
// H . ( x1  ) = ( x1   )
//     ( x2  ) = ( x2   )
//     ( ... ) = ( ...  )
//     ( xk  ) = ( beta )
//     ( ... )   ( 0    )
//     ( ... )   ( 0    )
// Unlike the usual formulation, we allow the caller to supply 'k' rather than
// only providing the relevant part of 'x' to maintain XLA's static shape
// invariant. In addition, the implementation supports batching.
// Pseudo-code, without batching:
//   alpha = x[k]
//   x_copy = np.copy(x)
//   x_copy[:k+1] = 0
//   xnorm = norm2(x_copy)
//   if xnorm == 0 and np.imag(alpha) == 0:
//     beta = alpha
//     tau = 0
//     v = np.zeros_like(x)
//   else:
//     beta = -np.sign(np.real(alpha)) * np.sqrt(alpha * np.conj(alpha) +
//                                               xnorm ** 2)
//     if np.issubdtype(x.dtype, np.complexfloating):
//       tau = (beta - alpha) / beta
//     else:
//       tau = ((beta - np.real(alpha)) / beta) + (-np.imag(alpha) / beta) * 1j
//     v = x / (alpha - beta)
//   v[k] = 1
//   return (v, tau, beta)
// TODO(phawkins): LAPACK's xLARFG implementation has code for handling
// overflows in the norm/beta calculations. Perhaps do the same here.
Status House(XlaOp x, XlaOp k, absl::Span<const int64_t> batch_dims,
             const int64_t m, XlaOp* v, XlaOp* tau, XlaOp* beta) {
  XlaBuilder* const builder = x.builder();
  TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
  const PrimitiveType type = x_shape.element_type();

  std::vector<int64_t> batch_dim_ids(batch_dims.size());
  std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
  const int64_t minor_dim = batch_dims.size();

  XlaOp zero = ScalarLike(x, 0.0);

  // alpha = x[k]
  XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);

  // Compute x[k+1:] (padded with zeros in elements 0..k)
  XlaOp iota = Iota(builder, S32, m);
  XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type),
                        /*broadcast_dimensions=*/{minor_dim});

  XlaOp sigma_is_zero;
  if (primitive_util::IsComplexType(type)) {
    // sigma = np.dot(x[k+1:], np.conj(x[k+1:]))
    auto x_squared = Real(x_after_k * Conj(x_after_k));
    auto sigma =
        Reduce(x_squared, ScalarLike(x_squared, 0.0),
               CreateScalarAddComputation(
                   primitive_util::ComplexComponentType(type), builder),
               {minor_dim});
    auto mu = Norm({Real(alpha), Imag(alpha), Sqrt(sigma)});

    sigma_is_zero = Eq(sigma, ScalarLike(sigma, 0));
    sigma_is_zero = And(sigma_is_zero, Eq(Imag(alpha), ScalarLike(sigma, 0)));

    *beta = Select(Lt(Real(alpha), ScalarLike(sigma, 0)), ScalarLike(mu, 1),
                   ScalarLike(mu, -1)) *
            mu;
    *beta = Select(sigma_is_zero, Real(alpha), *beta);
    *tau = Complex((*beta - Real(alpha)) / *beta, -Imag(alpha) / *beta);
  } else {
    // sigma = np.dot(x[k+1:], x[k+1:])
    auto sigma = Reduce(x_after_k * x_after_k, zero,
                        CreateScalarAddComputation(type, builder), {minor_dim});
    auto mu = Norm({alpha, Sqrt(sigma)});
    sigma_is_zero = Eq(sigma, zero);

    XlaOp one = ScalarLike(x, 1.0);
    *beta = Select(Lt(alpha, zero), one, -one) * mu;
    *beta = Select(sigma_is_zero, alpha, *beta);
    *tau = (*beta - alpha) / *beta;
  }
  *tau = Select(sigma_is_zero, ZerosLike(*tau), *tau);

  auto divisor =
      Select(sigma_is_zero, Broadcast(ScalarLike(alpha, 1), batch_dims),
             alpha - ConvertElementType(*beta, type));

  auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type),
                       std::vector<int64_t>(batch_dims.size(), 1));

  // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
  // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
  *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
  return OkStatus();
}

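// Worked example of House (real case, no batching), assuming x = [3, 4] and
// k = 0: alpha = 3, sigma = 16, mu = 5, so beta = -5, tau = 1.6, and
// v = [1, 0.5]. One can check that H = I - tau * v * v.T equals
// [[-0.6, -0.8], [-0.8, 0.6]] and that H @ [3, 4] = [-5, 0] = [beta, 0].
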
}  // namespace

// Householder QR decomposition. Algorithm 5.2.1 from Golub and Van
// Loan "Matrix Computations", 4th Edition. This is an unblocked implementation
// used as an inner routine of the blocked implementation.
// The algorithm is adapted slightly so the shapes inside the loop are static,
// at the cost of some redundant computation. Since this is used as an inner
// block kernel, it accumulates the Householder transformations (vs, taus)
// rather than the matrix q.
// Equivalent Python code, without batching:
// def qr(a):
//   m = a.shape[0]
//   n = a.shape[1]
//   taus = np.zeros([n])
//   for j in xrange(min(m, n)):
//     v, tau, beta = house(a[:, j], j)
//     a[:, j+1:] -= np.conj(tau) * np.dot(v[:, np.newaxis],
//                                np.dot(np.conj(v[np.newaxis, :]), a[:, j+1:]))
//     # Form column j explicitly rather than relying on the precision of the
//     # Householder update.
//     a[j, j] = beta
//     a[j+1:, j] = v[j+1:]
//     taus[j] = tau
//   return (a, taus)
StatusOr<QrDecomposition> QrExpander::QrBlock(
    XlaOp a, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  const int num_dims = a_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
                           a_shape.ToString());
  }
  PrimitiveType type = a_shape.element_type();

  const int64_t m = ShapeUtil::GetDimension(a_shape, -2);
  const int64_t n = ShapeUtil::GetDimension(a_shape, -1);

  const int64_t num_batch_dims = num_dims - 2;
  std::vector<int64_t> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
  }

  std::vector<int64_t> batch_dim_indices(num_batch_dims);
  std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);

  auto qr_body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
                        XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    auto a = values[0];
    auto taus = values[1];

    // v, tau, beta = house(a[:, j], j)
    auto x = DynamicSliceInMinorDims(a, {j}, {1});
    XlaOp v, tau, beta;
    TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j,
                             batch_dims, m, &v, &tau, &beta));

    const int64_t minor_dim = batch_dims.size();
    auto iota_mn = Iota(
        builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {m, n})),
        minor_dim + 1);

    std::vector<int64_t> shape = batch_dims;
    shape.push_back(1);
    shape.push_back(m);
    auto v_broadcast = Reshape(v, shape);
    // a[:, j+1:] -= np.conj(tau) * (v[:, np.newaxis] @
    //     (np.conj(v[np.newaxis, :]) @ a[:, j+1:]))
    // We use masking rather than a loop-variant shape to handle the j+1:
    // indexing.
    auto vva = BatchDot(MaybeConjugate(v_broadcast, true),
                        Select(Lt(j, iota_mn), a, ZerosLike(a)), precision);
    vva = BatchDot(v_broadcast, true, vva, false, precision);
    a = a - Mul(MaybeConjugate(tau, true), vva,
                /*broadcast_dimensions=*/batch_dim_indices);

    // a[j, j] = beta
    // a[j+1:,j] = v[j+1:]
    auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1});
    auto predecessor_mask = ConvertElementType(Lt(iota, j), type);
    auto mask = Broadcast(ConvertElementType(Eq(iota, j), type),
                          std::vector<int64_t>(batch_dims.size(), 1));
    auto successor_mask = Gt(Iota(a.builder(), S32, m), j);
    auto new_x = Mul(x, predecessor_mask,
                     /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
                 Mul(ConvertElementType(beta, type), mask,
                     /*broadcast_dimensions=*/batch_dim_indices);
    new_x = Add(
        new_x, Select(Broadcast(successor_mask, batch_dims), v, ZerosLike(v)),
        /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {minor_dim}));
    // Update a[:,j]
    std::vector<int64_t> dim_ids(num_dims);
    std::iota(dim_ids.begin(), dim_ids.end(), 0);
    new_x = BroadcastInDim(new_x, ConcatVectors(batch_dims, {m, n}),
                           /*broadcast_dimensions=*/dim_ids);
    a = Select(Eq(iota_mn, j), new_x, a);

    // taus[j] = tau
    std::vector<int64_t> tau_broadcast_dims(batch_dims.size());
    std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0);

    auto iota_n =
        Iota(builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {n})),
             minor_dim);
    auto taus_zeros = ZerosLike(taus);
    auto taus_update = Select(
        Eq(iota_n, j),
        Add(taus_zeros, tau, /*broadcast_dimensions=*/tau_broadcast_dims),
        taus_zeros);
    taus = taus + taus_update;
    return std::vector<XlaOp>{a, taus};
  };

  auto taus = Zeros(
      builder,
      ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {std::min(m, n)})));

  TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn,
                                                {a, taus}, "qr", builder));

  QrDecomposition result;
  result.q_and_r = values[0];
  result.taus = values[1];
  return result;
}

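// The (q_and_r, taus) pair above follows the LAPACK-style compact form: R on
// and above the diagonal, Householder vectors below it with an implicit unit
// diagonal. As a sketch of how Q could be recovered from this form (without
// batching; assemble_q is a hypothetical helper, not part of this file):
// def assemble_q(a, taus):
//   m, n = a.shape
//   q = np.eye(m, dtype=a.dtype)
//   for j in reversed(range(len(taus))):
//     v = np.zeros(m, dtype=a.dtype)
//     v[j] = 1
//     v[j + 1:] = a[j + 1:, j]
//     q -= taus[j] * np.outer(v, np.conj(v) @ q)  # q = H_j @ q
//   return q[:, :n]
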
// Computes an upper triangular matrix T such that (I + Y @ T @ Y^t) is a
// product of the elementary Householder reflectors given by `vs` and `taus`.
//
// Schreiber, Robert, and Charles Van Loan. "A storage-efficient WY
// representation for products of Householder transformations." SIAM Journal on
// Scientific and Statistical Computing 10.1 (1989): 53-57.
//
// def compact_wy(vs, taus):
//   m, n = vs.shape[-2:]
//   t = np.eye(n) * -taus
//   # We premultiply Y.T @ vs, since we would prefer to compute a single matrix
//   # multiplication to many matrix-vector products.
//   vtv = -taus[None, :] * np.triu(np.conj(vs.T) @ vs, 1) + np.eye(n)
//   for i in range(1, n):
//     t[:, i] = scipy.linalg.blas.strmm(t, vtv[:, i])
//   return t
StatusOr<XlaOp> QrExpander::CompactWYRepresentation(
    PrimitiveType type, absl::Span<const int64_t> batch_dims, XlaOp vs,
    XlaOp taus, int64_t m, int64_t n, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = vs.builder();

  std::vector<int64_t> batch_dim_indices(batch_dims.size());
  std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
  int64_t n_index = batch_dims.size() + 1;

  auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
                     XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    // w has shape [..., m, n]
    auto t = values[0];
    const auto vtv = values[1];

    // yv has shape [..., n, 1]
    auto yv = DynamicSliceInMinorDims(vtv, {j}, {1});

    // z has shape [..., n, 1]
    auto z = BatchDot(t, yv, precision);

    t = DynamicUpdateSliceInMinorDims(t, z, {j});

    return std::vector<XlaOp>{t, vtv};
  };

  auto tau_scale = BroadcastInDim(-taus, ConcatVectors(batch_dims, {1, n}),
                                  ConcatVectors(batch_dim_indices, {n_index}));

  auto eye = Broadcast(IdentityMatrix(builder, type, n, n), batch_dims);
  auto t = eye;

  auto vtv = BatchDot(MaybeConjugate(vs, true), /*transpose_x=*/true, vs,
                      /*transpose_y=*/false, precision);
  vtv = Select(TriangleMask(vtv, 0), ZerosLike(vtv), vtv);
  vtv = (vtv + eye) * tau_scale;

  TF_ASSIGN_OR_RETURN(auto values,
                      ForEachIndex(n, S32, body_fn, {t, vtv}, "wy", builder));
  return values[0];
}

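// Sanity check of the sign convention for a single reflector (so vs = v and
// taus = [tau]): after masking, vtv = [[-tau]], the single loop iteration
// sets t = [[-tau]], and I + v @ t @ np.conj(v.T) reduces to
// I - tau * np.outer(v, np.conj(v)), i.e. the Householder matrix itself.
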
// Block Householder QR Factorization. Algorithm 5.2.2 of Golub and Van Loan.
// def qr_blocked(a, block_size):
//   m = a.shape[0]
//   n = a.shape[1]
//   q = np.eye(m)
//   for i in xrange(0, min(m, n), block_size):
//     k = min(block_size, min(m, n) - i)
//     (a, taus) = qr(a[i:, i:i+k])
//     y = np.eye(m - i, k) + np.tril(a, -1)
//     t = CompactWYRepresentation(y, taus, m-i, k)
//     a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
//     q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
//   return (q, a)
StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
    XlaOp a, int64_t block_size, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  const int num_dims = a_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s",
                           a_shape.ToString());
  }
  PrimitiveType type = a_shape.element_type();

  const int64_t m = ShapeUtil::GetDimension(a_shape, -2);
  const int64_t n = ShapeUtil::GetDimension(a_shape, -1);
  const int64_t p = std::min(m, n);

  if (block_size < 1) {
    return InvalidArgument("block_size argument to QR must be >= 1; got %d",
                           block_size);
  }

  const int64_t num_batch_dims = num_dims - 2;
  std::vector<int64_t> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
  }

  std::vector<int64_t> taus_dims = batch_dims;
  taus_dims.push_back(p);
  auto taus = Zeros(builder, ShapeUtil::MakeShape(type, taus_dims));
  for (int64_t i = 0; i < p; i += block_size) {
    int64_t k = std::min(block_size, p - i);

    auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
    TF_ASSIGN_OR_RETURN(auto qr_block, QrBlock(a_block, precision));
    auto y = Add(IdentityMatrix(builder, type, m - i, k),
                 Select(TriangleMask(qr_block.q_and_r, -1), qr_block.q_and_r,
                        ZerosLike(qr_block.q_and_r)),
                 /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1});

    a = UpdateSliceInMinorDims(a, qr_block.q_and_r, {i, i});
    taus = UpdateSliceInMinorDims(taus, qr_block.taus, {i});

    // Compute the I + Y @ T @ Y^t block representation of a product of
    // Householder matrices.
    TF_ASSIGN_OR_RETURN(
        auto t, CompactWYRepresentation(type, batch_dims, y, qr_block.taus,
                                        m - i, k, precision));

    // a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
    auto yt = BatchDot(y, /*transpose_x=*/false, MaybeConjugate(t, true),
                       /*transpose_y=*/true, precision);
    auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
    auto a_update =
        BatchDot(MaybeConjugate(y, true), /*transpose_x=*/true, a_panel,
                 /*transpose_y=*/false, precision);
    a_update = BatchDot(yt, a_update, precision);
    a_panel = a_panel + a_update;
    a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
  }

  return Tuple(builder, {a, taus});
}

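// End to end, in the NumPy sketch notation used after QrBlock above (without
// batching): if (packed, taus) are the two elements of the returned tuple,
// then for m >= n, r = np.triu(packed)[:n] and q = assemble_q(packed, taus)
// should satisfy q @ r ~= a up to rounding.
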
StatusOr<XlaOp> QrExpander::ProductOfElementaryHouseholderReflectors(
    XlaOp a, XlaOp taus, int64_t block_size,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  TF_ASSIGN_OR_RETURN(Shape taus_shape, builder->GetShape(taus));
  const int num_dims = a_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s",
                           a_shape.ToString());
  }
  PrimitiveType type = a_shape.element_type();

  const int64_t m = ShapeUtil::GetDimension(a_shape, -2);
  int64_t n = ShapeUtil::GetDimension(a_shape, -1);
  const int64_t p = ShapeUtil::GetDimension(taus_shape, -1);
  if (m < n) {
    return InvalidArgument(
        "Argument to product of elementary Householder "
        "reflectors must have m >= n, got shape %s",
        a_shape.ToString());
  }

  if (block_size < 1) {
    return InvalidArgument("block_size argument to QR must be >= 1; got %d",
                           block_size);
  }

  const int64_t num_batch_dims = num_dims - 2;
  std::vector<int64_t> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
  }

  auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
  for (int64_t i = 0; i < p; i += block_size) {
    int64_t k = std::min(block_size, p - i);

    auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
    auto y = Add(IdentityMatrix(builder, type, m - i, k),
                 Select(TriangleMask(a_block, -1), a_block, ZerosLike(a_block)),
                 /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1});

    // Compute the I + Y @ T @ Y^t block representation of a product of
    // Householder matrices.
    auto taus_block = SliceInMinorDims(taus, {i}, {i + k});

    TF_ASSIGN_OR_RETURN(
        auto t, CompactWYRepresentation(type, batch_dims, y, taus_block, m - i,
                                        k, precision));
    // q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
    auto yt = BatchDot(y, /*transpose_x=*/false, MaybeConjugate(t, true),
                       /*transpose_y=*/true, precision);
    auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
    auto q_update = BatchDot(q_panel, y, precision);
    q_update =
        BatchDot(q_update, /*transpose_x=*/false, MaybeConjugate(yt, true),
                 /*transpose_y=*/true, precision);
    q_panel = q_panel + q_update;
    q = UpdateSliceInMinorDims(q, q_panel, {0, i});
  }
  q = SliceInMinorDims(q, {0, 0}, {m, n});
  return q;
}

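// This routine plays the role of LAPACK's xORGQR/xUNGQR: it explicitly forms
// the first n columns of Q from the packed reflectors and taus produced by
// BuildQrDecomposition.
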
static const char* kQrCustomCallName = "Qr";
static const char* kHouseholderProductCustomCallName =
    "ProductOfElementaryHouseholderReflectors";

bool QrExpander::InstructionMatchesPattern(HloInstruction* instruction) {
  return instruction->opcode() == HloOpcode::kCustomCall &&
         (instruction->custom_call_target() == kQrCustomCallName ||
          instruction->custom_call_target() ==
              kHouseholderProductCustomCallName);
}

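// For example, this pass matches and expands custom calls of (schematically)
// the following form, assuming an 8x4 f32 input:
//   %qr = (f32[8,4], f32[4]) custom-call(f32[8,4] %a), custom_call_target="Qr"
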
StatusOr<HloInstruction*> QrExpander::ExpandInstruction(
    HloInstruction* instruction) {
  const std::string name =
      absl::StrFormat("xla.%s_%s", instruction->custom_call_target(),
                      instruction->operand(0)->shape().ToString());

  HloModule* module = instruction->parent()->parent();

  HloComputation*& computation =
      computation_cache_.emplace(name, nullptr).first->second;
  if (!computation) {
    // Builds a new expansion.
    //
    // TODO(b/62327888): We do something unusual here: we build the computation
    // using the XlaBuilder API, which is nominally an XLA client API. We do
    // this because the external APIs for building complicated computations
    // (XlaBuilder) are much more ergonomic than the internal ones. As it turns
    // out, XlaBuilder isn't really a client API—what it does is build a
    // HloModuleProto protocol buffer, that we can then deserialize and clone
    // into our HloModule. Ideally we would avoid the protocol buffer step;
    // that is left as an exercise for future work.
    XlaBuilder builder(name);
    TF_RET_CHECK(instruction->operand_count() >= 1);
    XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
    XlaOp result;
    if (instruction->custom_call_target() == kQrCustomCallName) {
      TF_RET_CHECK(instruction->operand_count() == 1);
      TF_ASSIGN_OR_RETURN(
          result, BuildQrDecomposition(a,
                                       /*block_size=*/128,
                                       /*precision=*/PrecisionConfig::HIGHEST));
    } else {
      TF_RET_CHECK(instruction->operand_count() == 2);
      XlaOp taus =
          Parameter(&builder, 1, instruction->operand(1)->shape(), "taus");
      TF_ASSIGN_OR_RETURN(result, ProductOfElementaryHouseholderReflectors(
                                      a, taus, /*block_size=*/128,
                                      /*precision=*/PrecisionConfig::HIGHEST));
    }

    TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(result));

    TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                        xla_computation.GetProgramShape());
    HloModuleConfig config(program_shape);
    TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
                                             xla_computation.proto(), config));
    HloCloneContext context(module);
    computation =
        module->DeepCloneComputation(new_module->entry_computation(), &context);
  }

  return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
      instruction->shape(), instruction->operands(), computation));
}

}  // namespace xla