/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/qr_expander.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {

namespace {

std::vector<int64_t> ConcatVectors(absl::Span<const int64_t> xs,
                                   absl::Span<const int64_t> ys) {
  std::vector<int64_t> output;
  output.reserve(xs.size() + ys.size());
  std::copy(xs.begin(), xs.end(), std::back_inserter(output));
  std::copy(ys.begin(), ys.end(), std::back_inserter(output));
  return output;
}

// Computes sqrt(x^2 + y^2 + ...), avoiding overflow/underflow.
// e.g. for 3 arguments:
// def norm(x, y, z):
//   xabs = np.abs(x)
//   yabs = np.abs(y)
//   zabs = np.abs(z)
//   w = np.maximum(np.maximum(xabs, yabs), zabs)
//   if w == 0:
//     return 0
//   else:
//     return w * np.sqrt((xabs / w)**2 + (yabs / w) ** 2 + (zabs / w) ** 2)
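// For example, norm(3, 4, 0) picks w = 4 and returns
// 4 * np.sqrt((3/4)**2 + 1**2 + 0**2) = 5; dividing by w before squaring
// keeps the intermediate values representable even when the inputs are close
// to the overflow/underflow thresholds of the element type.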
XlaOp Norm(std::vector<XlaOp> xs) {
  CHECK(!xs.empty());
  XlaOp w;
  for (size_t i = 0; i < xs.size(); ++i) {
    xs[i] = Abs(xs[i]);
    w = i == 0 ? xs[i] : xla::Max(w, xs[i]);
  }

  XlaOp out;
  for (size_t i = 0; i < xs.size(); ++i) {
    XlaOp t = Square(xs[i] / w);
    out = i == 0 ? t : xla::Add(out, t);
  }
  return Select(Eq(w, ZerosLike(w)), ZerosLike(w), w * Sqrt(out));
}

// Computes a Householder reflection of the form:
// H = I - tau v v.T.
// such that
// H . ( x1  ) = ( x1   )
//     ( x2  ) = ( x2   )
//     ( ... ) = ( ...  )
//     ( xk  ) = ( beta )
//     ( ... )   ( 0    )
//     ( ... )   ( 0    )
// Unlike the usual formulation, we allow the caller to supply 'k' rather than
// only providing the relevant part of 'x', in order to maintain XLA's static
// shape invariant. In addition, the implementation supports batching.
// Pseudo-code, without batching:
//   alpha = x[k]
//   x_copy = np.copy(x)
//   x_copy[:k+1] = 0
//   xnorm = norm2(x_copy)
//   if xnorm == 0 and np.imag(alpha) == 0:
//     beta = alpha
//     tau = 0
//     v = np.zeros_like(x)
//   else:
//     beta = -np.sign(np.real(alpha)) * np.sqrt(alpha * np.conj(alpha) + xnorm ** 2)
//     if np.issubdtype(x.dtype, np.complexfloating):
//       tau = ((beta - np.real(alpha)) / beta) + (-np.imag(alpha) / beta) * 1j
//     else:
//       tau = (beta - alpha) / beta
//     v = x / (alpha - beta)
//     v[k] = 1
//   return (v, tau, beta)
// TODO(phawkins): LAPACK's xLARFG implementation has code for handling
// overflows in the norm/beta calculations. Perhaps do the same here.
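// Worked example (real case, no batching): for x = [3, 4] and k = 0 we get
// alpha = 3, xnorm = 4, beta = -5, tau = (beta - alpha) / beta = 1.6 and
// v = [1, 4 / (3 - (-5))] = [1, 0.5]. Then
//   H = I - tau * v * v.T = [[-0.6, -0.8], [-0.8, 0.6]]
// and H @ x = [-5, 0] = [beta, 0], as required.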
Status House(XlaOp x, XlaOp k, absl::Span<const int64_t> batch_dims,
             const int64_t m, XlaOp* v, XlaOp* tau, XlaOp* beta) {
  XlaBuilder* const builder = x.builder();
  TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
  const PrimitiveType type = x_shape.element_type();

  std::vector<int64_t> batch_dim_ids(batch_dims.size());
  std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
  const int64_t minor_dim = batch_dims.size();

  XlaOp zero = ScalarLike(x, 0.0);

  // alpha = x[k]
  XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);

  // Compute x[k+1:] (padded with zeros in elements 0..k)
  XlaOp iota = Iota(builder, S32, m);
  XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type),
                        /*broadcast_dimensions=*/{minor_dim});

  XlaOp sigma_is_zero;
  if (primitive_util::IsComplexType(type)) {
    // sigma = np.dot(x[k+1:], np.conj(x[k+1:]))
    auto x_squared = Real(x_after_k * Conj(x_after_k));
    auto sigma =
        Reduce(x_squared, ScalarLike(x_squared, 0.0),
               CreateScalarAddComputation(
                   primitive_util::ComplexComponentType(type), builder),
               {minor_dim});
    auto mu = Norm({Real(alpha), Imag(alpha), Sqrt(sigma)});

    sigma_is_zero = Eq(sigma, ScalarLike(sigma, 0));
    sigma_is_zero = And(sigma_is_zero, Eq(Imag(alpha), ScalarLike(sigma, 0)));

    *beta = Select(Lt(Real(alpha), ScalarLike(sigma, 0)), ScalarLike(mu, 1),
                   ScalarLike(mu, -1)) *
            mu;
    *beta = Select(sigma_is_zero, Real(alpha), *beta);
    *tau = Complex((*beta - Real(alpha)) / *beta, -Imag(alpha) / *beta);
  } else {
    // sigma = np.dot(x[k+1:], x[k+1:])
    auto sigma = Reduce(x_after_k * x_after_k, zero,
                        CreateScalarAddComputation(type, builder), {minor_dim});
    auto mu = Norm({alpha, Sqrt(sigma)});
    sigma_is_zero = Eq(sigma, zero);

    XlaOp one = ScalarLike(x, 1.0);
    *beta = Select(Lt(alpha, zero), one, -one) * mu;
    *beta = Select(sigma_is_zero, alpha, *beta);
    *tau = (*beta - alpha) / *beta;
  }
  *tau = Select(sigma_is_zero, ZerosLike(*tau), *tau);

  auto divisor =
      Select(sigma_is_zero, Broadcast(ScalarLike(alpha, 1), batch_dims),
             alpha - ConvertElementType(*beta, type));

  auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type),
                       std::vector<int64_t>(batch_dims.size(), 1));

  // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
  // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
  *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
  return OkStatus();
}

}  // namespace

// Householder QR decomposition. Algorithm 5.2.1 from Golub and Van
// Loan "Matrix Computations", 4th Edition. This is an unblocked implementation
// used as an inner routine of the blocked implementation.
// The algorithm is adapted slightly so the shapes inside the loop are static,
// at the cost of some redundant computation. Since this is used as an inner
// block kernel, it accumulates the Householder transformations (vs, taus)
// rather than the matrix q.
// Equivalent Python code, without batching:
// def qr(a):
//   m = a.shape[0]
//   n = a.shape[1]
//   taus = np.zeros([n])
//   for j in xrange(min(m, n)):
//     v, tau, beta = house(a[:, j], j)
//     a[:, j+1:] -= np.conj(tau) * np.dot(v[:, np.newaxis],
//         np.dot(np.conj(v[np.newaxis, :]), a[:, j+1:]))
//     # Form column j explicitly rather than relying on the precision of the
//     # Householder update.
//     a[j, j] = beta
//     a[j+1:, j] = v[j+1:]
//     taus[j] = tau
//   return (a, taus)
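// Note on the static-shape adaptation: rather than slicing out the
// loop-variant panel a[:, j+1:], the loop body below multiplies against all
// of `a` and masks away columns 0..j (the Select on Lt(j, iota_mn)), so every
// iteration sees operands of the same static shape.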
StatusOr<QrDecomposition> QrExpander::QrBlock(
    XlaOp a, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  const int num_dims = a_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
                           a_shape.ToString());
  }
  PrimitiveType type = a_shape.element_type();

  const int64_t m = ShapeUtil::GetDimension(a_shape, -2);
  const int64_t n = ShapeUtil::GetDimension(a_shape, -1);

  const int64_t num_batch_dims = num_dims - 2;
  std::vector<int64_t> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
  }

  std::vector<int64_t> batch_dim_indices(num_batch_dims);
  std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);

  auto qr_body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
                        XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    auto a = values[0];
    auto taus = values[1];

    // v, tau, beta = house(a[:, j], j)
    auto x = DynamicSliceInMinorDims(a, {j}, {1});
    XlaOp v, tau, beta;
    TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j,
                             batch_dims, m, &v, &tau, &beta));

    const int64_t minor_dim = batch_dims.size();
    auto iota_mn = Iota(
        builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {m, n})),
        minor_dim + 1);

    std::vector<int64_t> shape = batch_dims;
    shape.push_back(1);
    shape.push_back(m);
    auto v_broadcast = Reshape(v, shape);
    // a[:, j+1:] -= np.conj(tau) * (v[:, np.newaxis] @
    //     (np.conj(v[np.newaxis, :]) @ a[:, j+1:]))
    // We use masking rather than a loop-variant shape to handle the j+1:
    // indexing.
    auto vva = BatchDot(MaybeConjugate(v_broadcast, true),
                        Select(Lt(j, iota_mn), a, ZerosLike(a)), precision);
    vva = BatchDot(v_broadcast, true, vva, false, precision);
    a = a - Mul(MaybeConjugate(tau, true), vva,
                /*broadcast_dimensions=*/batch_dim_indices);

    // a[j, j] = beta
    // a[j+1:, j] = v[j+1:]
    auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1});
    auto predecessor_mask = ConvertElementType(Lt(iota, j), type);
    auto mask = Broadcast(ConvertElementType(Eq(iota, j), type),
                          std::vector<int64_t>(batch_dims.size(), 1));
    auto successor_mask = Gt(Iota(a.builder(), S32, m), j);
    auto new_x = Mul(x, predecessor_mask,
                     /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
                 Mul(ConvertElementType(beta, type), mask,
                     /*broadcast_dimensions=*/batch_dim_indices);
    new_x = Add(
        new_x, Select(Broadcast(successor_mask, batch_dims), v, ZerosLike(v)),
        /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {minor_dim}));
    // Update a[:, j]
    std::vector<int64_t> dim_ids(num_dims);
    std::iota(dim_ids.begin(), dim_ids.end(), 0);
    new_x = BroadcastInDim(new_x, ConcatVectors(batch_dims, {m, n}),
                           /*broadcast_dimensions=*/dim_ids);
    a = Select(Eq(iota_mn, j), new_x, a);

    // taus[j] = tau
    std::vector<int64_t> tau_broadcast_dims(batch_dims.size());
    std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0);

    auto iota_n =
        Iota(builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {n})),
             minor_dim);
    auto taus_zeros = ZerosLike(taus);
    auto taus_update = Select(
        Eq(iota_n, j),
        Add(taus_zeros, tau, /*broadcast_dimensions=*/tau_broadcast_dims),
        taus_zeros);
    taus = taus + taus_update;
    return std::vector<XlaOp>{a, taus};
  };

  auto taus = Zeros(
      builder,
      ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {std::min(m, n)})));

  TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn,
                                                {a, taus}, "qr", builder));

  QrDecomposition result;
  result.q_and_r = values[0];
  result.taus = values[1];
  return result;
}

// Computes an upper triangular matrix T such that (I + Y @ T @ Y^t) is a
// product of the elementary Householder reflectors given by `vs` and `taus`.
//
// Schreiber, Robert, and Charles Van Loan. "A storage-efficient WY
// representation for products of Householder transformations." SIAM Journal on
// Scientific and Statistical Computing 10.1 (1989): 53-57.
//
// def compact_wy(vs, taus):
//   m, n = vs.shape[-2:]
//   t = np.eye(n) * -taus
//   # We precompute np.conj(vs.T) @ vs, since we would rather perform a single
//   # matrix-matrix multiplication than many matrix-vector products.
//   vtv = -taus[None, :] * np.triu(np.conj(vs.T) @ vs, 1) + np.eye(n)
//   for i in range(1, n):
//     t[:, i] = scipy.linalg.blas.strmm(t, vtv[:, i])
//   return t
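// As a sanity check for n = 2, writing ti = taus[i] and vi = vs[:, i], the
// recurrence above yields
//   T = [[-t0, t0 * t1 * (np.conj(v0) @ v1)],
//        [  0,                          -t1]]
// and expanding I + Y @ T @ Y^t recovers
// (I - t0 * v0 @ v0^t) @ (I - t1 * v1 @ v1^t), the product of the two
// reflectors.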
StatusOr<XlaOp> QrExpander::CompactWYRepresentation(
    PrimitiveType type, absl::Span<const int64_t> batch_dims, XlaOp vs,
    XlaOp taus, int64_t m, int64_t n, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = vs.builder();

  std::vector<int64_t> batch_dim_indices(batch_dims.size());
  std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
  int64_t n_index = batch_dims.size() + 1;

  auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
                     XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    // t has shape [..., n, n]
    auto t = values[0];
    const auto vtv = values[1];

    // yv has shape [..., n, 1]
    auto yv = DynamicSliceInMinorDims(vtv, {j}, {1});

    // z has shape [..., n, 1]
    auto z = BatchDot(t, yv, precision);

    t = DynamicUpdateSliceInMinorDims(t, z, {j});

    return std::vector<XlaOp>{t, vtv};
  };

  auto tau_scale = BroadcastInDim(-taus, ConcatVectors(batch_dims, {1, n}),
                                  ConcatVectors(batch_dim_indices, {n_index}));

  auto eye = Broadcast(IdentityMatrix(builder, type, n, n), batch_dims);
  auto t = eye;

  auto vtv = BatchDot(MaybeConjugate(vs, true), /*transpose_x=*/true, vs,
                      /*transpose_y=*/false, precision);
  vtv = Select(TriangleMask(vtv, 0), ZerosLike(vtv), vtv);
  vtv = (vtv + eye) * tau_scale;

  TF_ASSIGN_OR_RETURN(auto values,
                      ForEachIndex(n, S32, body_fn, {t, vtv}, "wy", builder));
  return values[0];
}

// Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van Loan.
// def qr_blocked(a, block_size):
//   m = a.shape[0]
//   n = a.shape[1]
//   q = np.eye(m)
//   for i in xrange(0, min(m, n), block_size):
//     k = min(block_size, min(m, n) - i)
//     (a, taus) = qr(a[i:, i:i+k])
//     y = np.eye(m - i, k) + np.tril(a, -1)
//     t = CompactWYRepresentation(y, taus, m - i, k)
//     a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
//     q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
//   return (q, a)
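// With block_size = 1 this essentially reduces to the unblocked algorithm
// above; larger blocks replace many rank-1 Householder updates with a few
// matrix-matrix products per panel, which tend to map much better onto
// hardware matmul units.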
StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
    XlaOp a, int64_t block_size, PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  const int num_dims = a_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument("Arguments to QR must have rank >= 2; got shape %s",
                           a_shape.ToString());
  }
  PrimitiveType type = a_shape.element_type();

  const int64_t m = ShapeUtil::GetDimension(a_shape, -2);
  const int64_t n = ShapeUtil::GetDimension(a_shape, -1);
  const int64_t p = std::min(m, n);

  if (block_size < 1) {
    return InvalidArgument("block_size argument to QR must be >= 1; got %d",
                           block_size);
  }

  const int64_t num_batch_dims = num_dims - 2;
  std::vector<int64_t> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
  }

  std::vector<int64_t> taus_dims = batch_dims;
  taus_dims.push_back(p);
  auto taus = Zeros(builder, ShapeUtil::MakeShape(type, taus_dims));
  for (int64_t i = 0; i < p; i += block_size) {
    int64_t k = std::min(block_size, p - i);

    auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
    TF_ASSIGN_OR_RETURN(auto qr_block, QrBlock(a_block, precision));
    auto y = Add(IdentityMatrix(builder, type, m - i, k),
                 Select(TriangleMask(qr_block.q_and_r, -1), qr_block.q_and_r,
                        ZerosLike(qr_block.q_and_r)),
                 /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1});

    a = UpdateSliceInMinorDims(a, qr_block.q_and_r, {i, i});
    taus = UpdateSliceInMinorDims(taus, qr_block.taus, {i});

    // Compute the I + Y @ T @ Y^t block representation of a product of
    // Householder matrices.
    TF_ASSIGN_OR_RETURN(
        auto t, CompactWYRepresentation(type, batch_dims, y, qr_block.taus,
                                        m - i, k, precision));

    // a[i:, i+k:] += (y @ np.conj(t.T)) @ (np.conj(y.T) @ a[i:, i+k:])
    auto yt = BatchDot(y, /*transpose_x=*/false, MaybeConjugate(t, true),
                       /*transpose_y=*/true, precision);
    auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
    auto a_update =
        BatchDot(MaybeConjugate(y, true), /*transpose_x=*/true, a_panel,
                 /*transpose_y=*/false, precision);
    a_update = BatchDot(yt, a_update, precision);
    a_panel = a_panel + a_update;
    a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
  }

  return Tuple(builder, {a, taus});
}

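// Computes the matrix Q from its implicit representation: the Householder
// vectors stored below the diagonal of `a` and the scalar factors `taus`
// (analogous to LAPACK's xORGQR/xUNGQR), using the same compact-WY blocking
// as BuildQrDecomposition above.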
StatusOr<XlaOp> QrExpander::ProductOfElementaryHouseholderReflectors(
    XlaOp a, XlaOp taus, int64_t block_size,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
  TF_ASSIGN_OR_RETURN(Shape taus_shape, builder->GetShape(taus));
  const int num_dims = a_shape.rank();
  if (num_dims < 2) {
    return InvalidArgument("Arguments to QR must have rank >= 2; got shape %s",
                           a_shape.ToString());
  }
  PrimitiveType type = a_shape.element_type();

  const int64_t m = ShapeUtil::GetDimension(a_shape, -2);
  int64_t n = ShapeUtil::GetDimension(a_shape, -1);
  const int64_t p = ShapeUtil::GetDimension(taus_shape, -1);
  if (m < n) {
    return InvalidArgument(
        "Argument to product of elementary Householder "
        "reflectors must have m >= n, got shape %s",
        a_shape.ToString());
  }

  if (block_size < 1) {
    return InvalidArgument("block_size argument to QR must be >= 1; got %d",
                           block_size);
  }

  const int64_t num_batch_dims = num_dims - 2;
  std::vector<int64_t> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
  }

  auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
  for (int64_t i = 0; i < p; i += block_size) {
    int64_t k = std::min(block_size, p - i);

    auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
    auto y = Add(IdentityMatrix(builder, type, m - i, k),
                 Select(TriangleMask(a_block, -1), a_block, ZerosLike(a_block)),
                 /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1});

    // Compute the I + Y @ T @ Y^t block representation of a product of
    // Householder matrices.
    auto taus_block = SliceInMinorDims(taus, {i}, {i + k});

    TF_ASSIGN_OR_RETURN(
        auto t, CompactWYRepresentation(type, batch_dims, y, taus_block, m - i,
                                        k, precision));
    // q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
    auto yt = BatchDot(y, /*transpose_x=*/false, MaybeConjugate(t, true),
                       /*transpose_y=*/true, precision);
    auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
    auto q_update = BatchDot(q_panel, y, precision);
    q_update =
        BatchDot(q_update, /*transpose_x=*/false, MaybeConjugate(yt, true),
                 /*transpose_y=*/true, precision);
    q_panel = q_panel + q_update;
    q = UpdateSliceInMinorDims(q, q_panel, {0, i});
  }
  q = SliceInMinorDims(q, {0, 0}, {m, n});
  return q;
}

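// These custom-call targets are emitted by the QR helpers in the XLA client
// library (tensorflow/compiler/xla/client/lib/qr.h); this pass rewrites them
// into explicit HLO.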
static const char* kQrCustomCallName = "Qr";
static const char* kHouseholderProductCustomCallName =
    "ProductOfElementaryHouseholderReflectors";

bool QrExpander::InstructionMatchesPattern(HloInstruction* instruction) {
  return instruction->opcode() == HloOpcode::kCustomCall &&
         (instruction->custom_call_target() == kQrCustomCallName ||
          instruction->custom_call_target() ==
              kHouseholderProductCustomCallName);
}

StatusOr<HloInstruction*> QrExpander::ExpandInstruction(
    HloInstruction* instruction) {
  const std::string name =
      absl::StrFormat("xla.%s_%s", instruction->custom_call_target(),
                      instruction->operand(0)->shape().ToString());

  HloModule* module = instruction->parent()->parent();

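  // One expansion is cached per (custom-call target, operand shape) pair,
  // keyed by `name`; emplace() returns the existing entry if there is one, so
  // the computation is only built on the first miss.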
  HloComputation*& computation =
      computation_cache_.emplace(name, nullptr).first->second;
  if (!computation) {
    // Builds a new expansion.
    //
    // TODO(b/62327888): We do something unusual here: we build the computation
    // using the XlaBuilder API, which is nominally an XLA client API. We do
    // this because the external APIs for building complicated computations
    // (XlaBuilder) are much more ergonomic than the internal ones. As it turns
    // out, XlaBuilder isn't really a client API; what it does is build an
    // HloModuleProto protocol buffer, which we can then deserialize and clone
    // into our HloModule. Ideally we would avoid the protocol buffer step;
    // that is left as an exercise for future work.
    XlaBuilder builder(name);
    TF_RET_CHECK(instruction->operand_count() >= 1);
    XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
    XlaOp result;
    if (instruction->custom_call_target() == kQrCustomCallName) {
      TF_RET_CHECK(instruction->operand_count() == 1);
      TF_ASSIGN_OR_RETURN(
          result, BuildQrDecomposition(a,
                                       /*block_size=*/128,
                                       /*precision=*/PrecisionConfig::HIGHEST));
    } else {
      TF_RET_CHECK(instruction->operand_count() == 2);
      XlaOp taus =
          Parameter(&builder, 1, instruction->operand(1)->shape(), "taus");
      TF_ASSIGN_OR_RETURN(result, ProductOfElementaryHouseholderReflectors(
                                      a, taus, /*block_size=*/128,
                                      /*precision=*/PrecisionConfig::HIGHEST));
    }

    TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(result));

    TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                        xla_computation.GetProgramShape());
    HloModuleConfig config(program_shape);
    TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
                                             xla_computation.proto(), config));
    HloCloneContext context(module);
    computation =
        module->DeepCloneComputation(new_module->entry_computation(), &context);
  }

  return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
      instruction->shape(), instruction->operands(), computation));
}

}  // namespace xla