1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
17
18 #include <algorithm>
19 #include <numeric>
20 #include <optional>
21 #include <string>
22 #include <utility>
23
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/str_join.h"
30 #include "tensorflow/compiler/xla/map_util.h"
31 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
32 #include "tensorflow/compiler/xla/util.h"
33
34 namespace xla {
35
36 namespace {
37 using Analysis = IndexedArrayAnalysis;
38 using UnknownArray = Analysis::UnknownArray;
39 using ConstantArray = Analysis::ConstantArray;
40 using ReshapedArray = Analysis::ReshapedArray;
41 using ScalarIndexedArray = Analysis::ScalarIndexedArray;
42 using absl::StrJoin;
43 } // namespace
44
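// Pretty-prints an Array as a nested S-expression-like string, e.g.
// "(scalar-indexed-const (constant s32[3,5,2]) %indices 1->[2])".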
std::string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
46 switch (root->kind()) {
47 case Array::kUnknown: {
48 auto* unknown_tensor = root->as<UnknownArray>();
49 return absl::StrCat("%", unknown_tensor->instruction().name());
50 }
51
52 case Array::kConstant: {
53 if (print_constants) {
54 std::string contents = root->as<ConstantArray>()->literal()->ToString();
55 return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()),
56 " ", contents, ")");
57 }
58 return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()),
59 ")");
60 }
61
62 case Array::kReshaped: {
63 ReshapedArray* reshaped_array = root->as<ReshapedArray>();
64 return absl::StrCat(
65 "(reshape ", ToString(reshaped_array->operand(), print_constants),
66 " to ", ShapeUtil::HumanString(reshaped_array->shape()), ")");
67 }
68
69 case Array::kScalarIndexedConstant:
70 case Array::kScalarIndexed: {
71 auto* indexed_array = root->as<ScalarIndexedArray>();
72 std::string name = root->kind() == Array::kScalarIndexedConstant
73 ? "scalar-indexed-const"
74 : "scalar-indexed";
75 return absl::StrCat(
76 "(", name, " ", ToString(indexed_array->source(), print_constants),
77 " ", ToString(indexed_array->indices(), print_constants), " ",
78 indexed_array->source_dim(), "->[",
79 StrJoin(indexed_array->output_dims(), ","), "])");
80 }
81 }
82 }
83
StatusOr<Analysis::Array*> IndexedArrayAnalysis::GetArrayFor(
85 const HloInstruction* instr) {
86 auto it = cache_.find(instr);
87 if (it != cache_.end()) {
88 return it->second;
89 }
90
91 TF_RETURN_IF_ERROR(TraverseAndPopulateCache(instr));
92 return FindOrDie(cache_, instr);
93 }
94
Status IndexedArrayAnalysis::TraverseAndPopulateCache(
96 const HloInstruction* root) {
97 // Depth first search over the DAG, invoking ComputeArrayFor in post order.
98 // The HLO instructions already in the cache are considered leaves.
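  //
  // The explicit stack together with the kDiscovered/kVisited states gives an
  // iterative post-order traversal: the first time an instruction is seen at
  // the top of the stack (kDiscovered) its uncached operands are pushed and
  // the instruction is marked kVisited; the next time it surfaces (kVisited)
  // it is popped and its Array is computed and cached.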
99
100 absl::InlinedVector<const HloInstruction*, 4> stack;
101
102 enum DfsState { kDiscovered, kVisited };
103 absl::flat_hash_map<const HloInstruction*, DfsState> dfs_state_map;
104
105 stack.push_back(root);
106 InsertOrDie(&dfs_state_map, root, kDiscovered);
107
108 do {
109 const HloInstruction* instr = stack.back();
110 if (cache_.contains(instr)) {
111 stack.pop_back();
112 continue;
113 }
114
115 switch (FindOrDie(dfs_state_map, instr)) {
116 case kDiscovered: {
117 for (const HloInstruction* operand : instr->operands()) {
118 if (!cache_.contains(operand)) {
119 stack.push_back(operand);
120 CHECK(!dfs_state_map.contains(operand) ||
121 dfs_state_map[operand] == kDiscovered);
122 dfs_state_map[operand] = kDiscovered;
123 }
124 }
125 dfs_state_map[instr] = kVisited;
126 break;
127 }
128
129 case kVisited:
130 stack.pop_back();
131 TF_ASSIGN_OR_RETURN(Array * array, ComputeArrayFor(instr));
132 InsertOrDie(&cache_, instr, array);
133 break;
134 }
135 } while (!stack.empty());
136
137 return OkStatus();
138 }
139
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
141 const HloInstruction* instr) {
142 Array* computed_array;
143 if (instr->IsElementwise() && instr->operand_count() == 1) {
144 TF_ASSIGN_OR_RETURN(
145 computed_array,
146 ComputeArrayForElementwiseUnaryOp(
147 instr->opcode(), FindOrDie(cache_, instr->operand(0))));
148 } else if (instr->IsElementwise() && instr->operand_count() == 2) {
149 TF_ASSIGN_OR_RETURN(
150 computed_array,
151 ComputeArrayForElementwiseBinaryOp(
152 instr->opcode(), FindOrDie(cache_, instr->operand(0)),
153 FindOrDie(cache_, instr->operand(1))));
154 } else if (instr->opcode() == HloOpcode::kConstant) {
155 TF_ASSIGN_OR_RETURN(computed_array,
156 ComputeArrayForConstant(instr->literal()));
157 } else if (instr->opcode() == HloOpcode::kGather) {
158 TF_ASSIGN_OR_RETURN(
159 computed_array,
160 ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(),
161 instr->gather_slice_sizes(),
162 FindOrDie(cache_, instr->operand(0)),
163 FindOrDie(cache_, instr->operand(1))));
164 } else if (instr->opcode() == HloOpcode::kReshape) {
165 TF_ASSIGN_OR_RETURN(
166 computed_array,
167 ComputeArrayForReshape(instr->shape(),
168 FindOrDie(cache_, instr->operand(0))));
169 } else if (instr->opcode() == HloOpcode::kDot) {
170 TF_ASSIGN_OR_RETURN(
171 computed_array,
172 ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(),
173 instr->precision_config(),
174 FindOrDie(cache_, instr->operand(0)),
175 FindOrDie(cache_, instr->operand(1))));
176 } else {
177 computed_array = nullptr;
178 }
179
180 if (!computed_array) {
181 computed_array = Construct<UnknownArray>(instr);
182 }
183
184 return computed_array;
185 }
186
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForConstant(
188 const Literal& literal) {
189 return Construct<ConstantArray>(&literal);
190 }
191
StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather(
193 ScalarIndexedArray* source, Array* indices, int64_t source_dim,
194 absl::Span<const int64_t> output_dims, Shape shape) {
195 // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)).
196 // `source` is the inner Gather(A, X).
197
198 Array* a = source->source();
199 Array* x = source->indices();
200 Array* y = indices;
201
202 // This bit is slightly tricky, so we do a naive "simulation" of the two
203 // consecutive gather operations to infer what the composed gather should look
204 // like.
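  //
  // Each element of `simulated_index` below tracks one component of the index
  // space: it starts out describing A's index space and, after both gathers
  // have been simulated, describes the final output's index space, with every
  // component tagged as untouched (Ungathered), introduced by the first
  // gather (GatheredFirst), or introduced by the second gather
  // (GatheredSecond).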
205
206 enum class IndexComponent { Ungathered, GatheredFirst, GatheredSecond };
207
208 std::vector<IndexComponent> simulated_index(a->shape().dimensions_size(),
209 IndexComponent::Ungathered);
210
211 // Simulate the first gather.
212 EraseAt(&simulated_index, source->source_dim());
213 for (int64_t gather_dim : source->output_dims()) {
214 simulated_index.insert(simulated_index.begin() + gather_dim,
215 IndexComponent::GatheredFirst);
216 }
217
218 // Simulate the second gather.
219 EraseAt(&simulated_index, source_dim);
220 for (int64_t output_dim : output_dims) {
221 simulated_index.insert(simulated_index.begin() + output_dim,
222 IndexComponent::GatheredSecond);
223 }
224
225 int64_t source_dim_for_index_array =
226 FindIndex(source->output_dims(), source_dim);
227 CHECK_NE(source_dim_for_index_array, source->output_dims().size());
228
229 std::vector<int64_t> output_dims_for_index_array;
230 int64_t gathered_index_components_seen = 0;
231 for (IndexComponent simulation_dim : simulated_index) {
232 if (simulation_dim == IndexComponent::GatheredSecond) {
233 output_dims_for_index_array.push_back(gathered_index_components_seen);
234 }
235 if (simulation_dim != IndexComponent::Ungathered) {
236 gathered_index_components_seen++;
237 }
238 }
239
240 std::vector<int64_t> dim_sizes_for_composed_index;
241 std::vector<int64_t> output_dims_for_new_gather;
242 for (int64_t i = 0, e = simulated_index.size(); i < e; i++) {
243 if (simulated_index[i] != IndexComponent::Ungathered) {
244 dim_sizes_for_composed_index.push_back(shape.dimensions(i));
245 output_dims_for_new_gather.push_back(i);
246 }
247 }
248
249 Array* inner_indices = ConstructScalarIndexedArray(
250 x, y, source_dim_for_index_array, output_dims_for_index_array,
251 ShapeUtil::MakeShape(x->shape().element_type(),
252 dim_sizes_for_composed_index));
253 return ConstructScalarIndexedArray(a, inner_indices, source->source_dim(),
254 output_dims_for_new_gather,
255 std::move(shape));
256 }
257
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
259 const Shape& shape, const GatherDimensionNumbers& dim_numbers,
260 absl::Span<const int64_t> slice_sizes, Array* source, Array* indices) {
261 if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) {
262 VLOG(3) << "ComputeArrayForGather: indices are not scalar";
263 return nullptr;
264 }
265
266 CHECK_EQ(dim_numbers.start_index_map_size(), 1);
267
268 // We can also handle dim_numbers.collapsed_slice_dims_size() == 0 here,
269 // should it become relevant.
270
271 if (dim_numbers.collapsed_slice_dims_size() != 1 ||
272 dim_numbers.collapsed_slice_dims(0) != dim_numbers.start_index_map(0)) {
273 VLOG(3) << "ComputeArrayForGather: gather operations must elide "
274 "start_index_map[0] and "
275 "start_index_map[0] only";
276 return nullptr;
277 }
278
279 // ScalarIndexedArray cannot represent gathers that "slice" along some
280 // dimensions -- for instance it cannot represent a gather that picks 5 [2,3]
281 // arrays from an array of size [7,4,6]. We check that condition down below:
282
283 for (int64_t i = 0, e = source->shape().dimensions_size(); i < e; i++) {
284 if (i != dim_numbers.collapsed_slice_dims(0) &&
285 source->shape().dimensions(i) != slice_sizes[i]) {
286 VLOG(3) << "ComputeArrayForGather: slice_sizes[" << i
287 << "] != source->shape().dimensions(" << i << ") -- "
288 << source->shape().dimensions(i) << " vs. " << slice_sizes[i]
289 << " with dim_numbers.collapsed_slice_dims(0) = "
290 << dim_numbers.collapsed_slice_dims(0);
291 return nullptr;
292 }
293 }
294
295 int64_t source_dim = dim_numbers.start_index_map(0);
296 std::vector<int64_t> output_dims;
297 for (int64_t i = 0, e = shape.dimensions_size(); i < e; i++) {
298 if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
299 output_dims.push_back(i);
300 }
301 }
302
303 if (auto* indexed = dynamic_cast<ScalarIndexedArray*>(source)) {
304 if (absl::c_linear_search(indexed->output_dims(), source_dim)) {
305 return FoldGatherOfGather(indexed, indices, source_dim, output_dims,
306 shape);
307 }
308 } else if (auto* constant = dynamic_cast<ConstantArray*>(source)) {
309 return Construct<ScalarIndexedConstantArray>(constant, indices, source_dim,
310 output_dims, shape);
311 }
312
313 return Construct<ScalarIndexedArray>(source, indices, source_dim, output_dims,
314 shape);
315 }
316
317 namespace {
318 // Returns an index into `values` such that the product of the range
319 // [values.begin()+index, values.end()) is equal to `product`. If there is no
320 // such index, return -1. All integers in `values` must be positive.
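// For example, FindSuffixWithProduct({2, 3, 4}, 12) returns 1 (since
// 3 * 4 == 12), while FindSuffixWithProduct({2, 3, 4}, 5) returns -1.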
int64_t FindSuffixWithProduct(absl::Span<const int64_t> values,
322 int64_t product) {
323 DCHECK(absl::c_all_of(values, [](int64_t value) { return value > 0; }));
324
325 int64_t current_product = 1;
326 int64_t i;
327 for (i = values.size() - 1; i >= 0 && product > current_product; --i) {
328 current_product *= values[i];
329 }
330
331 if (product == current_product) {
332 return i + 1;
333 }
334
335 return -1;
336 }
337
338 struct ReshapePassthroughDimPair {
339 int64_t result_dim;
340 int64_t operand_dim;
341 };
342
// Returns a set of dimension pairs such that, for all (result_dim, operand_dim) in
344 // the set:
345 //
346 // output_index[result_dim] = SourceIndexOfReshape(output_index)[operand_dim]
347 //
348 // The returned vector of pairs is sorted in both the result_dim and the
349 // operand_dim components.
std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
351 absl::Span<const int64_t> operand_shape,
352 absl::Span<const int64_t> result_shape) {
353 // A reshape can be seen as an index mapping from output index to input index:
354 //
355 // (i_0, ..., i_n) = f(o_0, ..., o_m)
356 //
357 // This function returns the pairs (j, k) for which the following invariant
358 // holds for all indices in the shape:
359 //
360 // o_j == i_k
361 //
362 // And this occurs when:
363 //
364 // O_{j+1} * ... * O_n == I_{k+1} * ... * I_m
365 //
366 // (where O_x are the sizes of the output shape and I_x are the sizes of the
367 // input shape) and the size of the dimension j of the result is the same as
368 // the size of dimension k in the operand.
369 //
370 // These conditions are sufficient because the Reshape HLO is spec'ed such
371 // that the rightmost dimensions are always minor in the flattening and refine
372 // operation.
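  //
  // For example, for a reshape from [4,6] to [2,2,6] the only passthrough
  // pair is (result_dim=2, operand_dim=1): the products of the dimensions to
  // the right agree (both are 1) and the dimension sizes match (6 == 6),
  // whereas result dimension 1 (size 2) does not match operand dimension 0
  // (size 4).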
373
374 std::vector<ReshapePassthroughDimPair> result;
375 int64_t result_subarray_size = 1;
376 for (int64_t result_dim = result_shape.size() - 1; result_dim >= 0;
377 --result_dim) {
378 int64_t candidate_operand_dim =
379 FindSuffixWithProduct(operand_shape, result_subarray_size);
380
381 // result_subarray_size does not include the elements in the current
382 // `result_dim` dimension (we multiply in result_shape[result_dim] at the
383 // end of loop body) so candidate_operand_dim can never be zero.
384 CHECK_NE(candidate_operand_dim, 0)
385 << "result_dim = " << result_dim
386 << ", result_subarray_size = " << result_subarray_size
387 << ", result_shape = [" << StrJoin(result_shape, ",") << "]"
388 << ", operand_shape = [" << StrJoin(operand_shape, ",") << "]";
389
390 if (candidate_operand_dim != -1 &&
391 result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) {
392 result.push_back({/*result_dim=*/result_dim,
393 /*operand_dim=*/candidate_operand_dim - 1});
394 }
395 result_subarray_size *= result_shape[result_dim];
396 }
397
398 absl::c_reverse(result);
399
400 if (VLOG_IS_ON(3)) {
401 std::vector<std::string> result_strings;
402 absl::c_transform(result, std::back_inserter(result_strings),
403 [](ReshapePassthroughDimPair value) {
404 return absl::StrCat(value.result_dim, "->",
405 value.operand_dim);
406 });
407 VLOG(3) << "For a reshape from [" << StrJoin(operand_shape, ",") << "] to ["
408 << StrJoin(result_shape, ",") << "] passthrough indices are ["
409 << StrJoin(result_strings, ",")
410 << "] (legend: `result`->`operand`)";
411 }
412
413 DCHECK(absl::c_is_sorted(
414 result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
415 return lhs.result_dim < rhs.result_dim;
416 }));
417
418 DCHECK(absl::c_is_sorted(
419 result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
420 return lhs.operand_dim < rhs.operand_dim;
421 }));
422
423 return result;
424 }
425
// Returns true if `dim` is listed as a passthrough operand dim in
427 // `passthrough_dims`.
bool IsReshapePassthroughOperandDim(
429 absl::Span<const ReshapePassthroughDimPair> passthrough_dims, int64_t dim) {
430 return absl::c_any_of(passthrough_dims,
431 [&](ReshapePassthroughDimPair passthrough_dim_pair) {
432 return passthrough_dim_pair.operand_dim == dim;
433 });
434 }
435
// Maps `operand_dim`, which must be a passthrough operand dimension, to its
437 // corresponding passthrough result dimension based on `passthrough_dims`.
int64_t MapPassthroughOperandDimToResultDim(
439 absl::Span<const ReshapePassthroughDimPair> passthrough_dims,
440 int64_t operand_dim) {
441 auto it = absl::c_find_if(
442 passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) {
443 return passthrough_dim_pair.operand_dim == operand_dim;
444 });
445 CHECK(it != passthrough_dims.end());
446 return it->result_dim;
447 }
448
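// Returns the index into `result_shape` at which the product of the
// dimensions to its right equals the product of the dimensions to the right
// of `source_passthrough_dim` in `operand_shape`, or -1 if there is no such
// index.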
int64_t FindSourcePositionForPassthroughResultDim(
450 absl::Span<const int64_t> operand_shape,
451 absl::Span<const int64_t> result_shape, int64_t source_passthrough_dim) {
452 VLOG(3) << "FindSourcePositionForPassthroughResultDim(["
453 << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",")
454 << "], " << source_passthrough_dim << ")";
455
456 int64_t indexed_source_subarray_size =
457 std::accumulate(operand_shape.begin() + source_passthrough_dim + 1,
458 operand_shape.end(), 1LL, std::multiplies<int64_t>());
459
460 return FindSuffixWithProduct(result_shape, indexed_source_subarray_size);
461 }
462
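// Returns `shape` with all degenerate (size 1) dimensions removed.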
Shape StripDegenerateDimensions(const Shape& shape) {
464 DimensionVector new_dims;
465 absl::c_copy_if(shape.dimensions(), std::back_inserter(new_dims),
466 [](int64_t dim) { return dim != 1; });
467 return ShapeUtil::MakeShape(shape.element_type(), new_dims);
468 }
}  // namespace
470
471 StatusOr<ScalarIndexedArray*>
IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims(
473 ScalarIndexedArray* operand) {
474 const Shape& shape = operand->shape();
475 if (!ShapeUtil::HasDegenerateDimensions(shape)) {
476 return operand;
477 }
478
479 // We only need to reshape out the degenerate dims from the indices and the
480 // source (except the source dim).
481
482 const Shape& source_shape = operand->source()->shape();
483 DimensionVector new_source_shape_dims;
484 for (int64_t i = 0, e = source_shape.dimensions_size(); i < e; i++) {
485 if (i == operand->source_dim() || source_shape.dimensions(i) != 1) {
486 new_source_shape_dims.push_back(source_shape.dimensions(i));
487 }
488 }
489
490 Shape new_source_shape =
491 ShapeUtil::MakeShape(shape.element_type(), new_source_shape_dims);
492 Shape new_indices_shape =
493 StripDegenerateDimensions(operand->indices()->shape());
494
495 TF_ASSIGN_OR_RETURN(
496 Array* const new_source,
497 ComputeArrayForReshape(new_source_shape, operand->source()));
498 TF_ASSIGN_OR_RETURN(
499 Array* const new_indices,
500 ComputeArrayForReshape(new_indices_shape, operand->indices()));
501
502 // Build the new output dims while keeping track of the degenerate dims that
503 // will no longer be present.
504 DimensionVector new_output_dims;
505 int64_t degenerate_dims_seen = 0;
506 for (int64_t i = 0, e = shape.dimensions_size(); i < e; i++) {
507 if (shape.dimensions(i) == 1) {
508 degenerate_dims_seen++;
509 } else if (absl::c_linear_search(operand->output_dims(), i)) {
510 new_output_dims.push_back(i - degenerate_dims_seen);
511 }
512 }
513
514 // Similarly, build the new source dim while keeping track of the degenerate
515 // dims that will no longer be present.
516 int64_t degenerate_dims_before_source_dim =
517 std::count(source_shape.dimensions().begin(),
518 source_shape.dimensions().begin() + operand->source_dim(), 1);
519 int64_t new_source_dim =
520 operand->source_dim() - degenerate_dims_before_source_dim;
521
522 return ConstructScalarIndexedArray(
523 new_source, new_indices, new_source_dim,
524 InlinedVectorToVector(new_output_dims),
525 StripDegenerateDimensions(operand->shape()));
526 }
527
StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims(
529 ScalarIndexedArray* operand, absl::Span<const int64_t> degenerate_dims) {
530 if (degenerate_dims.empty()) {
531 return operand;
532 }
533
534 CHECK(!ShapeUtil::HasDegenerateDimensions(operand->shape()));
535
536 DimensionVector new_output_dims = [&]() {
537 // To make things easy we use a "scratch" buffer of bools where the i'th
538 // element is true iff the i'th component of the result index is an output
539 // index.
540
541 absl::InlinedVector<bool, 6> output_dims_bitvector(
542 operand->shape().dimensions_size());
543 for (int64_t output_dim : operand->output_dims()) {
544 output_dims_bitvector[output_dim] = true;
545 }
546
547 for (int64_t degenerate_dim : degenerate_dims) {
548 InsertAt(&output_dims_bitvector, degenerate_dim, false);
549 }
550
551 DimensionVector result;
552 result.reserve(operand->output_dims().size());
553 for (int64_t i = 0, e = output_dims_bitvector.size(); i < e; i++) {
554 if (output_dims_bitvector[i]) {
555 result.push_back(i);
556 }
557 }
558
559 return result;
560 }();
561
562 DimensionVector new_result_shape_dims;
563 absl::c_copy(operand->shape().dimensions(),
564 std::back_inserter(new_result_shape_dims));
565 for (int64_t degenerate_dim : degenerate_dims) {
566 InsertAt(&new_result_shape_dims, degenerate_dim, 1);
567 }
568
569 DimensionVector new_source_shape_dims = new_result_shape_dims;
570 for (int64_t output_dim : new_output_dims) {
571 EraseAt(&new_source_shape_dims, output_dim);
572 }
573
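  // The new source dim is the first position in `new_source_shape_dims` that
  // is preceded by exactly operand->source_dim() non-degenerate dimensions.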
  int64_t new_source_dim = [&]() {
    int64_t non_degenerate_dims_seen = 0;
    for (int i = 0, e = new_source_shape_dims.size(); i < e; i++) {
      if (non_degenerate_dims_seen == operand->source_dim()) {
        return i;
      }
      if (new_source_shape_dims[i] != 1) {
        non_degenerate_dims_seen++;
      }
    }
    LOG(FATAL) << "Did not find source dim in " << ToString(operand);
  }();
586
587 int64_t source_dim_size =
588 operand->source()->shape().dimensions(operand->source_dim());
589 InsertAt(&new_source_shape_dims, /*index=*/new_source_dim,
590 /*value=*/source_dim_size);
591
592 Shape new_source_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
593 new_source_shape_dims);
594 Shape new_result_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
595 new_result_shape_dims);
596
597 TF_ASSIGN_OR_RETURN(
598 Array* const new_source,
599 ComputeArrayForReshape(new_source_shape, operand->source()));
600 return ConstructScalarIndexedArray(
601 new_source, operand->indices(), new_source_dim,
602 InlinedVectorToVector(new_output_dims), new_result_shape);
603 }
604
StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldReshapeOfGather(
606 const Shape& shape, ScalarIndexedConstantArray* operand) {
607 VLOG(3) << "FoldReshapeOfGather(" << ToString(operand) << ")";
608
609 // To make things easier on ourselves, instead of directly trying to fold the
610 // reshape of `operand` to `shape`, we call
611 // `FoldReshapeOfGatherNoDegenerateDims` on shapes without degenerate dims and
612 // handle the degenerate dimensions here by inserting reshapes.
613
614 TF_ASSIGN_OR_RETURN(ScalarIndexedArray* const operand_without_degenerate_dims,
615 ReshapeToRemoveDegenerateDims(operand));
616
617 Shape output_shape_without_degenerate_dims = StripDegenerateDimensions(shape);
618 TF_ASSIGN_OR_RETURN(
619 ScalarIndexedArray* const folded_reshape_without_degenerate_dims,
620 FoldReshapeOfGatherNoDegenerateDims(
621 output_shape_without_degenerate_dims,
622 operand_without_degenerate_dims->as<ScalarIndexedConstantArray>()));
623
624 if (folded_reshape_without_degenerate_dims == nullptr) {
625 return nullptr;
626 }
627
628 DimensionVector degenerate_result_dims;
629 for (int64_t i = 0, e = shape.dimensions_size(); i < e; i++) {
630 if (shape.dimensions(i) == 1) {
631 degenerate_result_dims.push_back(i);
632 }
633 }
634
635 return ReshapeToAddDegenerateDims(folded_reshape_without_degenerate_dims,
636 degenerate_result_dims);
637 }
638
639 StatusOr<ScalarIndexedArray*>
IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
641 const Shape& shape, ScalarIndexedConstantArray* scalar_indexed) {
642 VLOG(3) << "FoldReshapeOfGatherNoDegenerateDims(" << ToString(scalar_indexed)
643 << ")";
644 CHECK(!ShapeUtil::HasDegenerateDimensions(shape));
645 CHECK(!ShapeUtil::HasDegenerateDimensions(scalar_indexed->shape()));
646
647 // Try to fold Reshape(ScalarIndexed(Const, Indices))
648 // => ScalarIndexed(Const', Indices)
649 //
650 // We can view the reshape and the scalar-indexed operations as functions that
651 // map an output index (i.e. an index into the result) to an input index
652 // (i.e. an index into the operand). The key idea used here is that the
653 // output-to-input mapping for some reshape operations may "pass through" some
654 // output dimensions into the input space unchanged -- i.e. there may exist
655 // output dimension "O" and input dimension "I" such that OutputIndex[O] is
656 // always == InputIndexForReshape(OutputIndex)[I]. If these pass-through
// dimensions in the input space of the reshape happen to include all the
658 // output dimensions for the scalar-indexed node then, roughly, the following
659 // holds:
660 //
661 // SourceIndexOfScalarIndexed(SourceIndexOfReshape(Idx))
662 // == SourceIndexOfScalarIndexed(SourceIndexOfReshape(Ps ++ Qs))
663 //
664 // Where Ps are the set of the pass-through components of Idx that are
665 // also the output dims of the scalar-indexed node, and Qs are the rest.
666 // For brevity, we're playing fast and loose with the notation here -- we
667 // don't literally require Idx to be a concatenation of Ps and Qs, as
668 // suggested by the "++".
669 //
670 // == SourceIndexOfScalarIndexed(Ps ++ SourceIndexOfReshape(Qs))
671 //
672 // Again, we're playing fast and loose with the notation around "++".
// Generally this ++ will be a different function than the ++ in the
674 // previous step.
675 //
676 // If the scalar-indexed node has a constant as the source then the
677 // SourceIndexOfReshape function can be "folded into" the constant itself by
678 // reshaping it, leaving us with:
679 //
680 // == SourceIndexOfScalarIndexed(Ps ++ Qs)
681 // == SourceIndexOfScalarIndexed(Idx)
682 //
683 // which is just a scalar-indexed node (with parameters different from the
684 // scalar-indexed node we started with) with a reshaped constant as the
685 // source.
686 //
687 // We can't fold SourceIndexOfReshape into the constant without introducing
688 // another precondition: since the new scalar-indexed node will have a
689 // reshaped (constant) array as its source it will, in general, have a
690 // different source dimension than the original scalar-indexed node. This
691 // source dimension will have to be a passthrough dimension of the
692 // SourceIndexOfReshape indexing function that is folded into the source. And
693 // such a dimension need not exist so this is a non-trivial precondition.
694
695 std::vector<ReshapePassthroughDimPair> reshape_passthrough_dims =
696 ComputeReshapePassthroughDimPairs(
697 /*operand_shape=*/scalar_indexed->shape().dimensions(),
698 /*result_shape=*/shape.dimensions());
699
700 auto is_reshape_passthrough_operand_dim = [&](int64_t operand_dim) {
701 return IsReshapePassthroughOperandDim(reshape_passthrough_dims,
702 operand_dim);
703 };
704
705 if (!absl::c_all_of(scalar_indexed->output_dims(),
706 is_reshape_passthrough_operand_dim)) {
707 VLOG(3) << "Not all output dims are passthrough dims "
708 << ToString(scalar_indexed);
709 return nullptr;
710 }
711
712 // To compute the shape of the source for the new scalar-indexed node we're
713 // going to create, we first "undo" the scalar-indexed operation.
714 std::vector<int64_t> new_scalar_indexed_source_shape(
715 shape.dimensions().begin(), shape.dimensions().end());
716 for (int64_t i = scalar_indexed->output_dims().size() - 1; i >= 0; i--) {
717 int64_t output_dim = scalar_indexed->output_dims()[i];
718 int64_t output_dim_after_reshape = MapPassthroughOperandDimToResultDim(
719 reshape_passthrough_dims, output_dim);
720 EraseAt(&new_scalar_indexed_source_shape, output_dim_after_reshape);
721 }
722
723 // After this, we need to add in the dimension that will be the source
724 // dimension for the new scalar-indexed node. A scalar-indexed node "removes"
725 // the source dimensions and "adds" the output dimensions, so to get back to
726 // the shape for the *source* of the scalar-indexed node we need to remove the
727 // output dims (which we did above) and then add back the source dim (which we
728 // are about to do below):
729
730 const Shape& scalar_indexed_source_shape = scalar_indexed->source()->shape();
731
732 int64_t source_dim_for_new_scalar_indexed_node =
733 FindSourcePositionForPassthroughResultDim(
734 /*operand_shape=*/scalar_indexed_source_shape.dimensions(),
735 /*result_shape=*/new_scalar_indexed_source_shape,
736 scalar_indexed->source_dim());
737
738 // We may not be able to find a source dim for the new scalar-indexed node.
739 // For instance consider:
740 //
741 // operand = s32[3,5,2] constant({...})
742 // indices = s32[7] parameter(0)
743 // gather = s32[3,2,7] gather(operand, indices),
744 // offset_dims={0,1},
745 // collapsed_slice_dims={1},
746 // start_index_map={1},
747 // index_vector_dim=1,
748 // slice_sizes={3,1,2}
749 // reshape = s32[6,7] reshape(gather)
750 //
751 // In this case the gather maps to:
752 // (scalar-indexed-const (constant s32[3,5,2]) %indices 1->[2])
753 //
754 // and the reshape passes through dimension 2 from its input into dimension 1
755 // in its output. However, we can't rewrite the reshape as a scalar-indexed
756 // node because then we'd have to reshape the [3,5,2] `operand` array to
757 // [6,5], but then dimension 1 of the reshaped [6,5] array indexes differently
758 // (a.k.a. isn't pass-through) than the [3,5,2] array.
759
760 if (source_dim_for_new_scalar_indexed_node == -1) {
761 VLOG(3) << "Could not compute the source dim for the new scalar indexed "
762 "node: scalar_indexed_source_shape = ["
763 << StrJoin(scalar_indexed_source_shape.dimensions(), ",")
764 << "] and new_scalar_indexed_source_shape = ["
765 << StrJoin(new_scalar_indexed_source_shape, ",") << "]";
766 return nullptr;
767 }
768
769 InsertAt(
770 &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node,
771 scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim()));
772
773 CHECK_EQ(absl::c_accumulate(new_scalar_indexed_source_shape, 1LL,
774 std::multiplies<int64_t>()),
775 ShapeUtil::ElementsIn(scalar_indexed_source_shape));
776
777 CHECK(IsReshapePassthroughOperandDim(
778 ComputeReshapePassthroughDimPairs(
779 /*operand_shape=*/scalar_indexed_source_shape.dimensions(),
780 /*result_shape=*/new_scalar_indexed_source_shape),
781 scalar_indexed->source_dim()));
782
783 auto map_passthrough_operand_dim_to_result_dim = [&](int64_t result_dim) {
784 return MapPassthroughOperandDimToResultDim(reshape_passthrough_dims,
785 result_dim);
786 };
787
788 std::vector<int64_t> output_dims_for_new_scalar_indexed_node;
789 absl::c_transform(scalar_indexed->output_dims(),
790 std::back_inserter(output_dims_for_new_scalar_indexed_node),
791 map_passthrough_operand_dim_to_result_dim);
792
793 TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal,
794 TakeOwnership(scalar_indexed->literal().Reshape(
795 new_scalar_indexed_source_shape)));
796 TF_ASSIGN_OR_RETURN(
797 Array * new_scalar_indexed_source,
798 ComputeArrayForConstant(*new_scalar_indexed_source_literal));
799
800 return ConstructScalarIndexedArray(
801 new_scalar_indexed_source, scalar_indexed->indices(),
802 source_dim_for_new_scalar_indexed_node,
803 output_dims_for_new_scalar_indexed_node, shape);
804 }
805
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForReshape(
807 const Shape& shape, Array* operand) {
808 if (ShapeUtil::Compatible(operand->shape(), shape)) {
809 return operand;
810 }
811
812 if (auto* scalar_indexed =
813 dynamic_cast<ScalarIndexedConstantArray*>(operand)) {
814 TF_ASSIGN_OR_RETURN(Analysis::Array * reshape_folded_into_gather,
815 FoldReshapeOfGather(shape, scalar_indexed));
816 if (reshape_folded_into_gather) {
817 return reshape_folded_into_gather;
818 }
819 }
820
821 if (auto* constant_array = dynamic_cast<ConstantArray*>(operand)) {
822 TF_ASSIGN_OR_RETURN(
823 Literal* const new_literal,
824 TakeOwnership(constant_array->literal()->Reshape(shape.dimensions())));
825 return Construct<ConstantArray>(new_literal);
826 }
827
828 return Construct<ReshapedArray>(operand, shape);
829 }
830
831 StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
833 Array* lhs,
834 Array* rhs) {
835 // Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices))
836 // => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices)
837 //
838 // We can do this if every output dimension from the scalar-indexed node is a
839 // broadcasted dimension for the broadcast node. Informally, the precondition
840 // means Broadcast(Const0)[IDX] is solely a function of the components of IDX
841 // that are not output-dims for the scalar-indexed node. In other words, for
842 // every assignment to the non-output dims in IDX we have a "constant" LHS to
843 // the BinaryOp. This transform propagates this "constant" to the source for
844 // the scalar-indexed node.
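  //
  // Here Broadcast'(Const0) denotes Const0 broadcast into the shape of the
  // scalar-indexed node's *source* (rather than into the shape of its
  // result); its broadcast dimensions are computed below by simulating the
  // index mapping of the original broadcast.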
845
846 ScalarIndexedConstantArray* lhs_scalar_indexed_const =
847 dynamic_cast<ScalarIndexedConstantArray*>(lhs);
848 ScalarIndexedConstantArray* rhs_scalar_indexed_const =
849 dynamic_cast<ScalarIndexedConstantArray*>(rhs);
850
851 bool lhs_is_indexed;
852
853 // One of the operands must be scalar-indexed and the other must be a
854 // broadcast of a constant.
855 if (lhs_scalar_indexed_const && !rhs_scalar_indexed_const) {
856 lhs_is_indexed = true;
857 } else if (rhs_scalar_indexed_const && !lhs_scalar_indexed_const) {
858 lhs_is_indexed = false;
859 } else {
860 return nullptr;
861 }
862
863 ScalarIndexedConstantArray* scalar_indexed_const =
864 lhs_is_indexed ? lhs_scalar_indexed_const : rhs_scalar_indexed_const;
865 UnknownArray* candidate_broadcast_array =
866 dynamic_cast<UnknownArray*>(lhs_is_indexed ? rhs : lhs);
867 if (!candidate_broadcast_array ||
868 candidate_broadcast_array->instruction().opcode() !=
869 HloOpcode::kBroadcast) {
870 return nullptr;
871 }
872
873 const HloInstruction* broadcast_instr =
874 &candidate_broadcast_array->instruction();
875 const HloInstruction* broadcast_const_operand = broadcast_instr->operand(0);
876 if (broadcast_const_operand->opcode() != HloOpcode::kConstant) {
877 return nullptr;
878 }
879
880 absl::Span<const int64_t> broadcast_dims = broadcast_instr->dimensions();
881 auto is_broadcasted_dim = [&](int64_t output_dim) {
882 return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end();
883 };
884
885 // All of the output dims must be "broadcasted" dims for the other operand.
886 if (!absl::c_all_of(scalar_indexed_const->output_dims(),
887 is_broadcasted_dim)) {
888 return nullptr;
889 }
890
891 // To figure out the broadcast dimensions for the (constant) source for the
892 // scalar-indexed node, we "simulate" the index transformation done by the
893 // existing broadcast:
894 enum class IndexComponent { Broadcasted, NotBroadcasted };
895 std::vector<IndexComponent> simulated_index(
896 broadcast_instr->shape().dimensions_size(), IndexComponent::Broadcasted);
897 for (int64_t broadcast_dim : broadcast_dims) {
898 simulated_index[broadcast_dim] = IndexComponent::NotBroadcasted;
899 }
900
901 // The scalar-indexed node "removes" the source dim and "inserts" the output
902 // dims. We do the opposite here to undo the scalar-indexed operation.
903 absl::Span<const int64_t> output_dims = scalar_indexed_const->output_dims();
904 for (int64_t i = output_dims.size() - 1; i >= 0; --i) {
905 CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted);
906 EraseAt(&simulated_index, output_dims[i]);
907 }
908
909 InsertAt(&simulated_index, scalar_indexed_const->source_dim(),
910 IndexComponent::Broadcasted);
911
912 // new_inner_broadcast_dims holds the broadcast dimensions for the inner
913 // BinaryOp(Broadcast'(Const0), Const1). We now translate simulated_index to
914 // new_inner_broadcast_dims.
915 std::vector<int64_t> new_inner_broadcast_dims;
916 for (int64_t i = 0; i < simulated_index.size(); i++) {
917 if (simulated_index[i] == IndexComponent::NotBroadcasted) {
918 new_inner_broadcast_dims.push_back(i);
919 }
920 }
921
922 // inner_broadcast_result is the Broadcast'(Const0) bit in
923 // BinaryOp(Broadcast'(Const0), Const1)
924 TF_ASSIGN_OR_RETURN(
925 Literal inner_broadcast_result,
926 broadcast_const_operand->literal().Broadcast(
927 scalar_indexed_const->source()->shape(), new_inner_broadcast_dims));
928
929 // literal_for_new_source is BinaryOp(Broadcast'(Const0), Const1)
930 const Literal* literal_for_new_source;
931 if (lhs_is_indexed) {
932 TF_ASSIGN_OR_RETURN(
933 literal_for_new_source,
934 TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
935 opcode, scalar_indexed_const->literal(), inner_broadcast_result)));
936 } else {
937 TF_ASSIGN_OR_RETURN(
938 literal_for_new_source,
939 TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
940 opcode, inner_broadcast_result, scalar_indexed_const->literal())));
941 }
942
943 ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
944 return Construct<ScalarIndexedConstantArray>(
945 new_source, scalar_indexed_const->indices(),
946 scalar_indexed_const->source_dim(),
947 std::vector<int64_t>(scalar_indexed_const->output_dims().begin(),
948 scalar_indexed_const->output_dims().end()),
949 scalar_indexed_const->shape());
950 }
951
952 StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
954 Array* operand) {
955 auto* scalar_indexed_const =
956 dynamic_cast<ScalarIndexedConstantArray*>(operand);
957 if (scalar_indexed_const == nullptr) {
958 return nullptr;
959 }
960
961 // Fold UnaryOp(ScalarIndexed(Const, Indices))
962 // => ScalarIndexed(UnaryOp(Const), Indices)
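  //
  // This is valid because the unary op is elementwise, so applying it to the
  // constant source before gathering produces the same elements as applying
  // it to the gathered result.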
963
964 TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
965 TakeOwnership(HloEvaluator{}.EvaluateElementwiseUnaryOp(
966 opcode, scalar_indexed_const->literal())));
967 ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
968 return Construct<ScalarIndexedConstantArray>(
969 new_source, scalar_indexed_const->indices(),
970 scalar_indexed_const->source_dim(),
971 SpanToVector(scalar_indexed_const->output_dims()),
972 scalar_indexed_const->shape());
973 }
974
975 namespace {
976
977 // Returns the non-contracting non-batch dimension (as per `contracting_dims`
978 // and `batch_dims`) if there is exactly one, otherwise returns nullopt.
std::optional<int64_t> GetOnlyNonContractingNonBatchDim(
980 int64_t rank, absl::Span<const int64_t> contracting_dims,
981 absl::Span<const int64_t> batch_dims) {
982 std::optional<int64_t> result;
983 for (int64_t dim = 0; dim < rank; dim++) {
984 if (!absl::c_linear_search(contracting_dims, dim) &&
985 !absl::c_linear_search(batch_dims, dim)) {
986 if (result.has_value()) {
987 return std::nullopt;
988 }
989 result = dim;
990 }
991 }
992 return result;
993 }
994
995 // Returns true if `indexed_array`, which is either the LHS or the RHS of a Dot
996 // HLO, can be folded into the dot operation. For now these conditions are both
997 // necessary and sufficient.
998 //
999 // `tag` describes the caller. Used only for logging.
1000 //
1001 // `contracting_dims` and `batch_dims` are the contracting and batch dimensions
1002 // of whatever operand `indexed_array` is to the dot (LHS or RHS).
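//
// In other words, `indexed_array` must be a (possibly batched) matrix whose
// single non-contracting non-batch dimension is exactly the gather's output
// dimension and whose source dimension is one of the two minor-most
// dimensions, so that the dot can be evaluated directly against the constant
// source.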
bool CanFoldDotIntoIndexedArray(
1004 absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array,
1005 absl::Span<const int64_t> contracting_dims,
1006 absl::Span<const int64_t> batch_dims) {
1007 std::optional<int64_t> non_contracting_non_batch_dim =
1008 GetOnlyNonContractingNonBatchDim(indexed_array->shape().rank(),
1009 contracting_dims, batch_dims);
1010 if (!non_contracting_non_batch_dim.has_value()) {
1011 VLOG(3) << tag << ": multiple or no non-contracting non-batch dimensions";
1012 return false;
1013 }
1014
1015 if (indexed_array->output_dims().size() != 1 ||
1016 indexed_array->output_dims()[0] != *non_contracting_non_batch_dim) {
    VLOG(3) << tag << ": output dims != the non-contracting non-batch dim";
1018 return false;
1019 }
1020
1021 int64_t indexed_array_rank = indexed_array->shape().rank();
1022 if (indexed_array->source_dim() < (indexed_array_rank - 2)) {
1023 // This restriction can be lifted by inserting reshape nodes.
1024 VLOG(3) << tag
1025 << ": source dim is not in the low two dims, won't be able to form "
1026 "a matmul";
1027 return false;
1028 }
1029
1030 return true;
1031 }
1032
1033 } // namespace
1034
1035 StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
1037 const Shape& shape, const DotDimensionNumbers& dim_numbers,
1038 const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs,
1039 ConstantArray* rhs) {
1040 VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " "
          << ToString(rhs) << ")";
1042 if (!CanFoldDotIntoIndexedArray(
1043 "ComputeArrayForDotWithIndexedLhs", lhs, /*contracting_dims=*/
1044 dim_numbers.lhs_contracting_dimensions(),
1045 /*batch_dims=*/dim_numbers.lhs_batch_dimensions())) {
1046 return nullptr;
1047 }
1048
1049 int64_t lhs_rank = lhs->shape().rank();
1050 DotDimensionNumbers new_dim_numbers = dim_numbers;
1051 new_dim_numbers.set_lhs_contracting_dimensions(
1052 0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1));
1053
1054 TF_ASSIGN_OR_RETURN(
1055 Literal * literal_for_new_source,
1056 TakeOwnership(HloEvaluator{}.EvaluateDotOp(
1057 new_dim_numbers, precision_config, lhs->literal(), *rhs->literal())));
1058
1059 // The new source dimension is wherever the non-batch non-contracting LHS
1060 // dimension "went".
1061 int64_t new_source_dim = dim_numbers.lhs_batch_dimensions_size() +
1062 dim_numbers.rhs_batch_dimensions_size();
1063
1064 ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
1065 return Construct<ScalarIndexedConstantArray>(
1066 new_source, lhs->indices(), new_source_dim,
1067 SpanToVector(lhs->output_dims()), shape);
1068 }
1069
1070 StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
1072 const Shape& shape, const DotDimensionNumbers& dim_numbers,
1073 const PrecisionConfig& precision_config, ConstantArray* lhs,
1074 ScalarIndexedConstantArray* rhs) {
1075 VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " "
          << ToString(rhs) << ")";
1077 if (!CanFoldDotIntoIndexedArray(
1078 "ComputeArrayForDotWithIndexedRhs", rhs, /*contracting_dims=*/
1079 dim_numbers.rhs_contracting_dimensions(),
1080 /*batch_dims=*/dim_numbers.rhs_batch_dimensions())) {
1081 return nullptr;
1082 }
1083
1084 int64_t rhs_rank = rhs->shape().rank();
1085
1086 DotDimensionNumbers new_dim_numbers = dim_numbers;
1087 new_dim_numbers.set_rhs_contracting_dimensions(
1088 0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1));
1089
1090 TF_ASSIGN_OR_RETURN(
1091 Literal * literal_for_new_source,
1092 TakeOwnership(HloEvaluator{}.EvaluateDotOp(
1093 new_dim_numbers, precision_config, *lhs->literal(), rhs->literal())));
1094
1095 // The new source dimension is wherever the non-batch non-contracting RHS
1096 // dimension "went".
1097 int64_t new_source_dim = dim_numbers.lhs_batch_dimensions_size() +
1098 dim_numbers.rhs_batch_dimensions_size() + 1;
1099
1100 ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
1101 return Construct<ScalarIndexedConstantArray>(
1102 new_source, rhs->indices(), new_source_dim,
1103 SpanToVector(rhs->output_dims()), shape);
1104 }
1105
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
1107 const Shape& shape, const DotDimensionNumbers& dim_numbers,
1108 const PrecisionConfig& precision_config, Array* lhs, Array* rhs) {
1109 // Intuitively, if
1110 //
1111 // - The LHS of a dot product is a gathered sequence of rows from a constant
1112 // array (i.e. LHS[I,J] = Const[Indices[I],J]) and the RHS is a constant
1113 //
1114 // OR
1115 //
1116 // - If the RHS of a dot product is a gathered sequence of columns from a
1117 // constant array (i.e. RHS[I,J] = Const[I, Indices[J]]) and the LHS is a
1118 // constant
1119 //
1120 // then the result of the dot product itself is a gather from a constant
1121 // array. E.g. Dot(LHS, ConstRhs) where LHS[I,J] = Const[Indices[I],J] can be
1122 // rewritten as Result where Result[I,J] = Dot(Const, ConstRhs)[Indices[I],
1123 // J].
1124 //
1125 // We do a general version of this rewrite here.
  VLOG(3) << "ComputeArrayForDot(" << ToString(lhs) << " " << ToString(rhs)
          << ")";
1127 if (auto* lhs_indexed_array =
1128 dynamic_cast<ScalarIndexedConstantArray*>(lhs)) {
1129 if (auto* rhs_constant = dynamic_cast<ConstantArray*>(rhs)) {
1130 return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers,
1131 precision_config,
1132 lhs_indexed_array, rhs_constant);
1133 }
1134 }
1135
1136 if (auto* rhs_indexed_array =
1137 dynamic_cast<ScalarIndexedConstantArray*>(rhs)) {
1138 if (auto* lhs_constant = dynamic_cast<ConstantArray*>(lhs)) {
1139 return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers,
1140 precision_config, lhs_constant,
1141 rhs_indexed_array);
1142 }
1143 }
1144
1145 return nullptr;
1146 }
1147
absl::string_view IndexedArrayAnalysisPrinterPass::name() const {
1149 return "indexed-array-analysis-printer-pass";
1150 }
1151
StatusOr<bool> IndexedArrayAnalysisPrinterPass::Run(
1153 HloModule* module,
1154 const absl::flat_hash_set<absl::string_view>& execution_threads) {
1155 if (!VLOG_IS_ON(2)) {
1156 return false;
1157 }
1158
1159 IndexedArrayAnalysis analysis;
1160 for (auto* computation :
1161 module->MakeNonfusionComputations(execution_threads)) {
1162 for (auto* instr : computation->instructions()) {
1163 TF_ASSIGN_OR_RETURN(Analysis::Array * t, analysis.GetArrayFor(instr));
1164 if (!dynamic_cast<UnknownArray*>(t) && !dynamic_cast<ConstantArray*>(t)) {
1165 VLOG(2) << instr->ToString() << " -> " << analysis.ToString(t);
1166 }
1167 }
1168 }
1169
1170 return false;
1171 }
1172
1173 } // namespace xla
1174