/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"

#include <algorithm>
#include <numeric>
#include <optional>
#include <string>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

namespace {
using Analysis = IndexedArrayAnalysis;
using UnknownArray = Analysis::UnknownArray;
using ConstantArray = Analysis::ConstantArray;
using ReshapedArray = Analysis::ReshapedArray;
using ScalarIndexedArray = Analysis::ScalarIndexedArray;
using absl::StrJoin;
}  // namespace

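// Pretty-prints the analysis for `root` as an s-expression; for example, a
// gather from a constant source renders as something like
//   (scalar-indexed-const (constant s32[3,5,2]) %indices 1->[2])
// (the notation used in the worked examples further down in this file).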
std::string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
  switch (root->kind()) {
    case Array::kUnknown: {
      auto* unknown_tensor = root->as<UnknownArray>();
      return absl::StrCat("%", unknown_tensor->instruction().name());
    }

    case Array::kConstant: {
      if (print_constants) {
        std::string contents = root->as<ConstantArray>()->literal()->ToString();
        return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()),
                            " ", contents, ")");
      }
      return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()),
                          ")");
    }

    case Array::kReshaped: {
      ReshapedArray* reshaped_array = root->as<ReshapedArray>();
      return absl::StrCat(
          "(reshape ", ToString(reshaped_array->operand(), print_constants),
          " to ", ShapeUtil::HumanString(reshaped_array->shape()), ")");
    }

    case Array::kScalarIndexedConstant:
    case Array::kScalarIndexed: {
      auto* indexed_array = root->as<ScalarIndexedArray>();
      std::string name = root->kind() == Array::kScalarIndexedConstant
                             ? "scalar-indexed-const"
                             : "scalar-indexed";
      return absl::StrCat(
          "(", name, " ", ToString(indexed_array->source(), print_constants),
          " ", ToString(indexed_array->indices(), print_constants), " ",
          indexed_array->source_dim(), "->[",
          StrJoin(indexed_array->output_dims(), ","), "])");
    }
  }
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::GetArrayFor(
    const HloInstruction* instr) {
  auto it = cache_.find(instr);
  if (it != cache_.end()) {
    return it->second;
  }

  TF_RETURN_IF_ERROR(TraverseAndPopulateCache(instr));
  return FindOrDie(cache_, instr);
}

Status IndexedArrayAnalysis::TraverseAndPopulateCache(
    const HloInstruction* root) {
  // Depth first search over the DAG, invoking ComputeArrayFor in post order.
  // The HLO instructions already in the cache are considered leaves.

  absl::InlinedVector<const HloInstruction*, 4> stack;

  enum DfsState { kDiscovered, kVisited };
  absl::flat_hash_map<const HloInstruction*, DfsState> dfs_state_map;

  stack.push_back(root);
  InsertOrDie(&dfs_state_map, root, kDiscovered);

  do {
    const HloInstruction* instr = stack.back();
    if (cache_.contains(instr)) {
      stack.pop_back();
      continue;
    }

    switch (FindOrDie(dfs_state_map, instr)) {
      case kDiscovered: {
        for (const HloInstruction* operand : instr->operands()) {
          if (!cache_.contains(operand)) {
            stack.push_back(operand);
            CHECK(!dfs_state_map.contains(operand) ||
                  dfs_state_map[operand] == kDiscovered);
            dfs_state_map[operand] = kDiscovered;
          }
        }
        dfs_state_map[instr] = kVisited;
        break;
      }

      case kVisited:
        stack.pop_back();
        TF_ASSIGN_OR_RETURN(Array * array, ComputeArrayFor(instr));
        InsertOrDie(&cache_, instr, array);
        break;
    }
  } while (!stack.empty());

  return OkStatus();
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
    const HloInstruction* instr) {
  Array* computed_array;
  if (instr->IsElementwise() && instr->operand_count() == 1) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForElementwiseUnaryOp(
            instr->opcode(), FindOrDie(cache_, instr->operand(0))));
  } else if (instr->IsElementwise() && instr->operand_count() == 2) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForElementwiseBinaryOp(
            instr->opcode(), FindOrDie(cache_, instr->operand(0)),
            FindOrDie(cache_, instr->operand(1))));
  } else if (instr->opcode() == HloOpcode::kConstant) {
    TF_ASSIGN_OR_RETURN(computed_array,
                        ComputeArrayForConstant(instr->literal()));
  } else if (instr->opcode() == HloOpcode::kGather) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(),
                              instr->gather_slice_sizes(),
                              FindOrDie(cache_, instr->operand(0)),
                              FindOrDie(cache_, instr->operand(1))));
  } else if (instr->opcode() == HloOpcode::kReshape) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForReshape(instr->shape(),
                               FindOrDie(cache_, instr->operand(0))));
  } else if (instr->opcode() == HloOpcode::kDot) {
    TF_ASSIGN_OR_RETURN(
        computed_array,
        ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(),
                           instr->precision_config(),
                           FindOrDie(cache_, instr->operand(0)),
                           FindOrDie(cache_, instr->operand(1))));
  } else {
    computed_array = nullptr;
  }

  if (!computed_array) {
    computed_array = Construct<UnknownArray>(instr);
  }

  return computed_array;
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForConstant(
    const Literal& literal) {
  return Construct<ConstantArray>(&literal);
}

StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather(
    ScalarIndexedArray* source, Array* indices, int64_t source_dim,
    absl::Span<const int64_t> output_dims, Shape shape) {
  // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)).
  // `source` is the inner Gather(A, X).
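  //
  // As a concrete (minimal) illustration, with A = s32[3,4], X = s32[5],
  // Y = s32[2]:
  //
  //   inner = (scalar-indexed A X 0->[0])      ; shape s32[5,4]
  //   outer = (scalar-indexed inner Y 0->[0])  ; shape s32[2,4]
  //
  // folds to
  //
  //   (scalar-indexed A (scalar-indexed X Y 0->[0]) 0->[0])  ; shape s32[2,4]
  //
  // i.e. the composed gather picks rows of A at positions X[Y[i]].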

  Array* a = source->source();
  Array* x = source->indices();
  Array* y = indices;

  // This bit is slightly tricky, so we do a naive "simulation" of the two
  // consecutive gather operations to infer what the composed gather should look
  // like.

  enum class IndexComponent { Ungathered, GatheredFirst, GatheredSecond };

  std::vector<IndexComponent> simulated_index(a->shape().dimensions_size(),
                                              IndexComponent::Ungathered);

  // Simulate the first gather.
  EraseAt(&simulated_index, source->source_dim());
  for (int64_t gather_dim : source->output_dims()) {
    simulated_index.insert(simulated_index.begin() + gather_dim,
                           IndexComponent::GatheredFirst);
  }

  // Simulate the second gather.
  EraseAt(&simulated_index, source_dim);
  for (int64_t output_dim : output_dims) {
    simulated_index.insert(simulated_index.begin() + output_dim,
                           IndexComponent::GatheredSecond);
  }

  int64_t source_dim_for_index_array =
      FindIndex(source->output_dims(), source_dim);
  CHECK_NE(source_dim_for_index_array, source->output_dims().size());

  std::vector<int64_t> output_dims_for_index_array;
  int64_t gathered_index_components_seen = 0;
  for (IndexComponent simulation_dim : simulated_index) {
    if (simulation_dim == IndexComponent::GatheredSecond) {
      output_dims_for_index_array.push_back(gathered_index_components_seen);
    }
    if (simulation_dim != IndexComponent::Ungathered) {
      gathered_index_components_seen++;
    }
  }

  std::vector<int64_t> dim_sizes_for_composed_index;
  std::vector<int64_t> output_dims_for_new_gather;
  for (int64_t i = 0, e = simulated_index.size(); i < e; i++) {
    if (simulated_index[i] != IndexComponent::Ungathered) {
      dim_sizes_for_composed_index.push_back(shape.dimensions(i));
      output_dims_for_new_gather.push_back(i);
    }
  }

  Array* inner_indices = ConstructScalarIndexedArray(
      x, y, source_dim_for_index_array, output_dims_for_index_array,
      ShapeUtil::MakeShape(x->shape().element_type(),
                           dim_sizes_for_composed_index));
  return ConstructScalarIndexedArray(a, inner_indices, source->source_dim(),
                                     output_dims_for_new_gather,
                                     std::move(shape));
}

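// A gather that is expressible as a ScalarIndexedArray picks whole slices of
// the source along a single dimension.  For example (a minimal, hypothetical
// configuration):
//
//   operand = s32[3,4] ...
//   indices = s32[5] ...
//   gather(operand, indices), offset_dims={1}, collapsed_slice_dims={0},
//       start_index_map={0}, index_vector_dim=1, slice_sizes={1,4}
//
// computes result[i,j] = operand[indices[i],j] and maps to
// (scalar-indexed %operand %indices 0->[0]) with shape s32[5,4].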
StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
    const Shape& shape, const GatherDimensionNumbers& dim_numbers,
    absl::Span<const int64_t> slice_sizes, Array* source, Array* indices) {
  if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) {
    VLOG(3) << "ComputeArrayForGather: indices are not scalar";
    return nullptr;
  }

  CHECK_EQ(dim_numbers.start_index_map_size(), 1);

  // We can also handle dim_numbers.collapsed_slice_dims_size() == 0 here,
  // should it become relevant.

  if (dim_numbers.collapsed_slice_dims_size() != 1 ||
      dim_numbers.collapsed_slice_dims(0) != dim_numbers.start_index_map(0)) {
    VLOG(3) << "ComputeArrayForGather: gather operations must elide "
               "start_index_map[0] and "
               "start_index_map[0] only";
    return nullptr;
  }

  // ScalarIndexedArray cannot represent gathers that "slice" along some
  // dimensions -- for instance it cannot represent a gather that picks 5 [2,3]
  // arrays from an array of size [7,4,6].  We check that condition down below:

  for (int64_t i = 0, e = source->shape().dimensions_size(); i < e; i++) {
    if (i != dim_numbers.collapsed_slice_dims(0) &&
        source->shape().dimensions(i) != slice_sizes[i]) {
      VLOG(3) << "ComputeArrayForGather: slice_sizes[" << i
              << "] != source->shape().dimensions(" << i << ") -- "
              << source->shape().dimensions(i) << " vs. " << slice_sizes[i]
              << " with dim_numbers.collapsed_slice_dims(0) = "
              << dim_numbers.collapsed_slice_dims(0);
      return nullptr;
    }
  }

  int64_t source_dim = dim_numbers.start_index_map(0);
  std::vector<int64_t> output_dims;
  for (int64_t i = 0, e = shape.dimensions_size(); i < e; i++) {
    if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) {
      output_dims.push_back(i);
    }
  }

  if (auto* indexed = dynamic_cast<ScalarIndexedArray*>(source)) {
    if (absl::c_linear_search(indexed->output_dims(), source_dim)) {
      return FoldGatherOfGather(indexed, indices, source_dim, output_dims,
                                shape);
    }
  } else if (auto* constant = dynamic_cast<ConstantArray*>(source)) {
    return Construct<ScalarIndexedConstantArray>(constant, indices, source_dim,
                                                 output_dims, shape);
  }

  return Construct<ScalarIndexedArray>(source, indices, source_dim, output_dims,
                                       shape);
}

namespace {
// Returns an index into `values` such that the product of the range
// [values.begin()+index, values.end()) is equal to `product`.  If there is no
// such index, return -1.  All integers in `values` must be positive.
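//
// For example, FindSuffixWithProduct({2, 3, 5, 7}, 35) returns 2 (since
// 5 * 7 == 35), while FindSuffixWithProduct({2, 3, 5, 7}, 6) returns -1.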
int64_t FindSuffixWithProduct(absl::Span<const int64_t> values,
                              int64_t product) {
  DCHECK(absl::c_all_of(values, [](int64_t value) { return value > 0; }));

  int64_t current_product = 1;
  int64_t i;
  for (i = values.size() - 1; i >= 0 && product > current_product; --i) {
    current_product *= values[i];
  }

  if (product == current_product) {
    return i + 1;
  }

  return -1;
}

struct ReshapePassthroughDimPair {
  int64_t result_dim;
  int64_t operand_dim;
};

// Returns a set of dimension pairs such that for all (result_dim, operand_dim)
// in the set:
//
// output_index[result_dim] = SourceIndexOfReshape(output_index)[operand_dim]
//
// The returned vector of pairs is sorted in both the result_dim and the
// operand_dim components.
std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
    absl::Span<const int64_t> operand_shape,
    absl::Span<const int64_t> result_shape) {
  // A reshape can be seen as an index mapping from output index to input index:
  //
  // (i_0, ..., i_n) = f(o_0, ..., o_m)
  //
  // This function returns the pairs (j, k) for which the following invariant
  // holds for all indices in the shape:
  //
  //   o_j == i_k
  //
  // And this occurs when:
  //
  //    O_{j+1} * ... * O_n == I_{k+1} * ...  * I_m
  //
  // (where O_x are the sizes of the output shape and I_x are the sizes of the
  // input shape) and the size of the dimension j of the result is the same as
  // the size of dimension k in the operand.
  //
  // These conditions are sufficient because the Reshape HLO is spec'ed such
  // that the rightmost dimensions are always minor in the flattening and refine
  // operation.
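  //
  // For example, for a reshape from [4,5,6] to [20,6] this returns the single
  // pair (result_dim=1, operand_dim=2): the trailing dimension of size 6 is
  // untouched, while dimensions 0 and 1 of the operand are fused into the
  // result's dimension 0.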

  std::vector<ReshapePassthroughDimPair> result;
  int64_t result_subarray_size = 1;
  for (int64_t result_dim = result_shape.size() - 1; result_dim >= 0;
       --result_dim) {
    int64_t candidate_operand_dim =
        FindSuffixWithProduct(operand_shape, result_subarray_size);

    // result_subarray_size does not include the elements in the current
    // `result_dim` dimension (we multiply in result_shape[result_dim] at the
    // end of loop body) so candidate_operand_dim can never be zero.
    CHECK_NE(candidate_operand_dim, 0)
        << "result_dim = " << result_dim
        << ", result_subarray_size = " << result_subarray_size
        << ", result_shape = [" << StrJoin(result_shape, ",") << "]"
        << ", operand_shape = [" << StrJoin(operand_shape, ",") << "]";

    if (candidate_operand_dim != -1 &&
        result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) {
      result.push_back({/*result_dim=*/result_dim,
                        /*operand_dim=*/candidate_operand_dim - 1});
    }
    result_subarray_size *= result_shape[result_dim];
  }

  absl::c_reverse(result);

  if (VLOG_IS_ON(3)) {
    std::vector<std::string> result_strings;
    absl::c_transform(result, std::back_inserter(result_strings),
                      [](ReshapePassthroughDimPair value) {
                        return absl::StrCat(value.result_dim, "->",
                                            value.operand_dim);
                      });
    VLOG(3) << "For a reshape from [" << StrJoin(operand_shape, ",") << "] to ["
            << StrJoin(result_shape, ",") << "] passthrough indices are ["
            << StrJoin(result_strings, ",")
            << "] (legend: `result`->`operand`)";
  }

  DCHECK(absl::c_is_sorted(
      result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
        return lhs.result_dim < rhs.result_dim;
      }));

  DCHECK(absl::c_is_sorted(
      result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) {
        return lhs.operand_dim < rhs.operand_dim;
      }));

  return result;
}

// Returns true if `dim` is stated as a passthrough operand dim in
// `passthrough_dims`.
bool IsReshapePassthroughOperandDim(
    absl::Span<const ReshapePassthroughDimPair> passthrough_dims, int64_t dim) {
  return absl::c_any_of(passthrough_dims,
                        [&](ReshapePassthroughDimPair passthrough_dim_pair) {
                          return passthrough_dim_pair.operand_dim == dim;
                        });
}

// Maps `operand_dim`, which must be a passthrough operand dimension, to its
// corresponding passthrough result dimension based on `passthrough_dims`.
int64_t MapPassthroughOperandDimToResultDim(
    absl::Span<const ReshapePassthroughDimPair> passthrough_dims,
    int64_t operand_dim) {
  auto it = absl::c_find_if(
      passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) {
        return passthrough_dim_pair.operand_dim == operand_dim;
      });
  CHECK(it != passthrough_dims.end());
  return it->result_dim;
}

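// Returns the position in `result_shape` whose suffix has the same element
// count as the suffix of `operand_shape` after `source_passthrough_dim`, i.e.
// the position at which the source dimension can be re-inserted into
// `result_shape` while remaining a reshape-passthrough dimension.  Returns -1
// if no such position exists.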
int64_t FindSourcePositionForPassthroughResultDim(
    absl::Span<const int64_t> operand_shape,
    absl::Span<const int64_t> result_shape, int64_t source_passthrough_dim) {
  VLOG(3) << "FindSourcePositionForPassthroughResultDim(["
          << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",")
          << "], " << source_passthrough_dim << ")";

  int64_t indexed_source_subarray_size =
      std::accumulate(operand_shape.begin() + source_passthrough_dim + 1,
                      operand_shape.end(), 1LL, std::multiplies<int64_t>());

  return FindSuffixWithProduct(result_shape, indexed_source_subarray_size);
}

Shape StripDegenerateDimensions(const Shape& shape) {
  DimensionVector new_dims;
  absl::c_copy_if(shape.dimensions(), std::back_inserter(new_dims),
                  [](int64_t dim) { return dim != 1; });
  return ShapeUtil::MakeShape(shape.element_type(), new_dims);
}
}  // namespace

StatusOr<ScalarIndexedArray*>
IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims(
    ScalarIndexedArray* operand) {
  const Shape& shape = operand->shape();
  if (!ShapeUtil::HasDegenerateDimensions(shape)) {
    return operand;
  }

  // We only need to reshape out the degenerate dims from the indices and the
  // source (except the source dim).
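  //
  // For example (a minimal case): if `operand` is
  //   (scalar-indexed (source s32[3,1]) (indices s32[5]) 0->[0]) : s32[5,1]
  // we reshape the source to s32[3] (keeping dim 0, the source dim), leave
  // the indices as-is, and produce
  //   (scalar-indexed (source s32[3]) (indices s32[5]) 0->[0]) : s32[5].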

  const Shape& source_shape = operand->source()->shape();
  DimensionVector new_source_shape_dims;
  for (int64_t i = 0, e = source_shape.dimensions_size(); i < e; i++) {
    if (i == operand->source_dim() || source_shape.dimensions(i) != 1) {
      new_source_shape_dims.push_back(source_shape.dimensions(i));
    }
  }

  Shape new_source_shape =
      ShapeUtil::MakeShape(shape.element_type(), new_source_shape_dims);
  Shape new_indices_shape =
      StripDegenerateDimensions(operand->indices()->shape());

  TF_ASSIGN_OR_RETURN(
      Array* const new_source,
      ComputeArrayForReshape(new_source_shape, operand->source()));
  TF_ASSIGN_OR_RETURN(
      Array* const new_indices,
      ComputeArrayForReshape(new_indices_shape, operand->indices()));

  // Build the new output dims while keeping track of the degenerate dims that
  // will no longer be present.
  DimensionVector new_output_dims;
  int64_t degenerate_dims_seen = 0;
  for (int64_t i = 0, e = shape.dimensions_size(); i < e; i++) {
    if (shape.dimensions(i) == 1) {
      degenerate_dims_seen++;
    } else if (absl::c_linear_search(operand->output_dims(), i)) {
      new_output_dims.push_back(i - degenerate_dims_seen);
    }
  }

  // Similarly, build the new source dim while keeping track of the degenerate
  // dims that will no longer be present.
  int64_t degenerate_dims_before_source_dim =
      std::count(source_shape.dimensions().begin(),
                 source_shape.dimensions().begin() + operand->source_dim(), 1);
  int64_t new_source_dim =
      operand->source_dim() - degenerate_dims_before_source_dim;

  return ConstructScalarIndexedArray(
      new_source, new_indices, new_source_dim,
      InlinedVectorToVector(new_output_dims),
      StripDegenerateDimensions(operand->shape()));
}

StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims(
    ScalarIndexedArray* operand, absl::Span<const int64_t> degenerate_dims) {
  if (degenerate_dims.empty()) {
    return operand;
  }

  CHECK(!ShapeUtil::HasDegenerateDimensions(operand->shape()));

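  // For example (a minimal case): if `operand` is
  //   (scalar-indexed (source s32[3,2]) (indices s32[5]) 0->[0]) : s32[5,2]
  // and degenerate_dims = {1}, we reshape the source to s32[3,1,2] and produce
  //   (scalar-indexed (source s32[3,1,2]) (indices s32[5]) 0->[0]) : s32[5,1,2]
  // where the new source is a reshape of the original source.
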
  DimensionVector new_output_dims = [&]() {
    // To make things easy we use a "scratch" buffer of bools where the i'th
    // element is true iff the i'th component of the result index is an output
    // index.

    absl::InlinedVector<bool, 6> output_dims_bitvector(
        operand->shape().dimensions_size());
    for (int64_t output_dim : operand->output_dims()) {
      output_dims_bitvector[output_dim] = true;
    }

    for (int64_t degenerate_dim : degenerate_dims) {
      InsertAt(&output_dims_bitvector, degenerate_dim, false);
    }

    DimensionVector result;
    result.reserve(operand->output_dims().size());
    for (int64_t i = 0, e = output_dims_bitvector.size(); i < e; i++) {
      if (output_dims_bitvector[i]) {
        result.push_back(i);
      }
    }

    return result;
  }();

  DimensionVector new_result_shape_dims;
  absl::c_copy(operand->shape().dimensions(),
               std::back_inserter(new_result_shape_dims));
  for (int64_t degenerate_dim : degenerate_dims) {
    InsertAt(&new_result_shape_dims, degenerate_dim, 1);
  }

  DimensionVector new_source_shape_dims = new_result_shape_dims;
  // Erase from the highest position to the lowest so that earlier erasures do
  // not shift the positions we are about to erase.
  for (auto it = new_output_dims.rbegin(); it != new_output_dims.rend(); ++it) {
    EraseAt(&new_source_shape_dims, *it);
  }

  int64_t new_source_dim = [&]() {
    int64_t non_degenerate_dims_seen = 0;
    for (int i = 0, e = new_source_shape_dims.size(); i < e; i++) {
      if (non_degenerate_dims_seen == operand->source_dim()) {
        return i;
      }
      if (new_source_shape_dims[i] != 1) {
        non_degenerate_dims_seen++;
      }
    }
    LOG(FATAL) << "Did not find source dim in " << ToString(operand);
  }();

  int64_t source_dim_size =
      operand->source()->shape().dimensions(operand->source_dim());
  InsertAt(&new_source_shape_dims, /*index=*/new_source_dim,
           /*value=*/source_dim_size);

  Shape new_source_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                                new_source_shape_dims);
  Shape new_result_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
                                                new_result_shape_dims);

  TF_ASSIGN_OR_RETURN(
      Array* const new_source,
      ComputeArrayForReshape(new_source_shape, operand->source()));
  return ConstructScalarIndexedArray(
      new_source, operand->indices(), new_source_dim,
      InlinedVectorToVector(new_output_dims), new_result_shape);
}

StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldReshapeOfGather(
    const Shape& shape, ScalarIndexedConstantArray* operand) {
  VLOG(3) << "FoldReshapeOfGather(" << ToString(operand) << ")";

  // To make things easier on ourselves, instead of directly trying to fold the
  // reshape of `operand` to `shape`, we call
  // `FoldReshapeOfGatherNoDegenerateDims` on shapes without degenerate dims and
  // handle the degenerate dimensions here by inserting reshapes.

  TF_ASSIGN_OR_RETURN(ScalarIndexedArray* const operand_without_degenerate_dims,
                      ReshapeToRemoveDegenerateDims(operand));

  Shape output_shape_without_degenerate_dims = StripDegenerateDimensions(shape);
  TF_ASSIGN_OR_RETURN(
      ScalarIndexedArray* const folded_reshape_without_degenerate_dims,
      FoldReshapeOfGatherNoDegenerateDims(
          output_shape_without_degenerate_dims,
          operand_without_degenerate_dims->as<ScalarIndexedConstantArray>()));

  if (folded_reshape_without_degenerate_dims == nullptr) {
    return nullptr;
  }

  DimensionVector degenerate_result_dims;
  for (int64_t i = 0, e = shape.dimensions_size(); i < e; i++) {
    if (shape.dimensions(i) == 1) {
      degenerate_result_dims.push_back(i);
    }
  }

  return ReshapeToAddDegenerateDims(folded_reshape_without_degenerate_dims,
                                    degenerate_result_dims);
}

StatusOr<ScalarIndexedArray*>
IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
    const Shape& shape, ScalarIndexedConstantArray* scalar_indexed) {
  VLOG(3) << "FoldReshapeOfGatherNoDegenerateDims(" << ToString(scalar_indexed)
          << ")";
  CHECK(!ShapeUtil::HasDegenerateDimensions(shape));
  CHECK(!ShapeUtil::HasDegenerateDimensions(scalar_indexed->shape()));

  // Try to fold Reshape(ScalarIndexed(Const, Indices))
  //          => ScalarIndexed(Const', Indices)
  //
  // We can view the reshape and the scalar-indexed operations as functions that
  // map an output index (i.e. an index into the result) to an input index
  // (i.e. an index into the operand).  The key idea used here is that the
  // output-to-input mapping for some reshape operations may "pass through" some
  // output dimensions into the input space unchanged -- i.e. there may exist
  // output dimension "O" and input dimension "I" such that OutputIndex[O] is
  // always == InputIndexForReshape(OutputIndex)[I].  If these pass-through
  // dimensions in the input space of the reshape happen to include all the
  // output dimensions for the scalar-indexed node then, roughly, the following
  // holds:
  //
  //    SourceIndexOfScalarIndexed(SourceIndexOfReshape(Idx))
  // == SourceIndexOfScalarIndexed(SourceIndexOfReshape(Ps ++ Qs))
  //
  //      Where Ps are the set of the pass-through components of Idx that are
  //      also the output dims of the scalar-indexed node, and Qs are the rest.
  //      For brevity, we're playing fast and loose with the notation here -- we
  //      don't literally require Idx to be a concatenation of Ps and Qs, as
  //      suggested by the "++".
  //
  // == SourceIndexOfScalarIndexed(Ps ++ SourceIndexOfReshape(Qs))
  //
  //      Again, we're playing fast and loose with the notation around "++".
  //      Generally this ++ will be a different function than the ++ in the
  //      previous step.
  //
  // If the scalar-indexed node has a constant as the source then the
  // SourceIndexOfReshape function can be "folded into" the constant itself by
  // reshaping it, leaving us with:
  //
  // == SourceIndexOfScalarIndexed(Ps ++ Qs)
  // == SourceIndexOfScalarIndexed(Idx)
  //
  // which is just a scalar-indexed node (with parameters different from the
  // scalar-indexed node we started with) with a reshaped constant as the
  // source.
  //
  // We can't fold SourceIndexOfReshape into the constant without introducing
  // another precondition: since the new scalar-indexed node will have a
  // reshaped (constant) array as its source it will, in general, have a
  // different source dimension than the original scalar-indexed node.  This
  // source dimension will have to be a passthrough dimension of the
  // SourceIndexOfReshape indexing function that is folded into the source. And
  // such a dimension need not exist so this is a non-trivial precondition.
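  //
  // As a concrete (minimal) positive example: reshaping
  //   (scalar-indexed-const (constant s32[3,4]) %indices 0->[0]) : s32[5,4]
  // to s32[5,2,2] succeeds -- output dim 0 passes through the reshape, and
  // the constant can be reshaped to s32[3,2,2], giving
  //   (scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0]) : s32[5,2,2]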

  std::vector<ReshapePassthroughDimPair> reshape_passthrough_dims =
      ComputeReshapePassthroughDimPairs(
          /*operand_shape=*/scalar_indexed->shape().dimensions(),
          /*result_shape=*/shape.dimensions());

  auto is_reshape_passthrough_operand_dim = [&](int64_t operand_dim) {
    return IsReshapePassthroughOperandDim(reshape_passthrough_dims,
                                          operand_dim);
  };

  if (!absl::c_all_of(scalar_indexed->output_dims(),
                      is_reshape_passthrough_operand_dim)) {
    VLOG(3) << "Not all output dims are passthrough dims "
            << ToString(scalar_indexed);
    return nullptr;
  }

  // To compute the shape of the source for the new scalar-indexed node we're
  // going to create, we first "undo" the scalar-indexed operation.
  std::vector<int64_t> new_scalar_indexed_source_shape(
      shape.dimensions().begin(), shape.dimensions().end());
  for (int64_t i = scalar_indexed->output_dims().size() - 1; i >= 0; i--) {
    int64_t output_dim = scalar_indexed->output_dims()[i];
    int64_t output_dim_after_reshape = MapPassthroughOperandDimToResultDim(
        reshape_passthrough_dims, output_dim);
    EraseAt(&new_scalar_indexed_source_shape, output_dim_after_reshape);
  }

  // After this, we need to add in the dimension that will be the source
  // dimension for the new scalar-indexed node.  A scalar-indexed node "removes"
  // the source dimensions and "adds" the output dimensions, so to get back to
  // the shape for the *source* of the scalar-indexed node we need to remove the
  // output dims (which we did above) and then add back the source dim (which we
  // are about to do below):

  const Shape& scalar_indexed_source_shape = scalar_indexed->source()->shape();

  int64_t source_dim_for_new_scalar_indexed_node =
      FindSourcePositionForPassthroughResultDim(
          /*operand_shape=*/scalar_indexed_source_shape.dimensions(),
          /*result_shape=*/new_scalar_indexed_source_shape,
          scalar_indexed->source_dim());

  // We may not be able to find a source dim for the new scalar-indexed node.
  // For instance consider:
  //
  //   operand = s32[3,5,2] constant({...})
  //   indices = s32[7] parameter(0)
  //   gather = s32[3,2,7] gather(operand, indices),
  //       offset_dims={0,1},
  //       collapsed_slice_dims={1},
  //       start_index_map={1},
  //       index_vector_dim=1,
  //       slice_sizes={3,1,2}
  //   reshape = s32[6,7] reshape(gather)
  //
  // In this case the gather maps to:
  //    (scalar-indexed-const (constant s32[3,5,2]) %indices 1->[2])
  //
  // and the reshape passes through dimension 2 from its input into dimension 1
  // in its output.  However, we can't rewrite the reshape as a scalar-indexed
  // node because then we'd have to reshape the [3,5,2] `operand` array to
  // [6,5], but then dimension 1 of the reshaped [6,5] array indexes differently
  // (a.k.a. isn't pass-through) than the [3,5,2] array.

  if (source_dim_for_new_scalar_indexed_node == -1) {
    VLOG(3) << "Could not compute the source dim for the new scalar indexed "
               "node: scalar_indexed_source_shape = ["
            << StrJoin(scalar_indexed_source_shape.dimensions(), ",")
            << "] and new_scalar_indexed_source_shape = ["
            << StrJoin(new_scalar_indexed_source_shape, ",") << "]";
    return nullptr;
  }

  InsertAt(
      &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node,
      scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim()));

  CHECK_EQ(absl::c_accumulate(new_scalar_indexed_source_shape, 1LL,
                              std::multiplies<int64_t>()),
           ShapeUtil::ElementsIn(scalar_indexed_source_shape));

  CHECK(IsReshapePassthroughOperandDim(
      ComputeReshapePassthroughDimPairs(
          /*operand_shape=*/scalar_indexed_source_shape.dimensions(),
          /*result_shape=*/new_scalar_indexed_source_shape),
      scalar_indexed->source_dim()));

  auto map_passthrough_operand_dim_to_result_dim = [&](int64_t result_dim) {
    return MapPassthroughOperandDimToResultDim(reshape_passthrough_dims,
                                               result_dim);
  };

  std::vector<int64_t> output_dims_for_new_scalar_indexed_node;
  absl::c_transform(scalar_indexed->output_dims(),
                    std::back_inserter(output_dims_for_new_scalar_indexed_node),
                    map_passthrough_operand_dim_to_result_dim);

  TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal,
                      TakeOwnership(scalar_indexed->literal().Reshape(
                          new_scalar_indexed_source_shape)));
  TF_ASSIGN_OR_RETURN(
      Array * new_scalar_indexed_source,
      ComputeArrayForConstant(*new_scalar_indexed_source_literal));

  return ConstructScalarIndexedArray(
      new_scalar_indexed_source, scalar_indexed->indices(),
      source_dim_for_new_scalar_indexed_node,
      output_dims_for_new_scalar_indexed_node, shape);
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForReshape(
    const Shape& shape, Array* operand) {
  if (ShapeUtil::Compatible(operand->shape(), shape)) {
    return operand;
  }

  if (auto* scalar_indexed =
          dynamic_cast<ScalarIndexedConstantArray*>(operand)) {
    TF_ASSIGN_OR_RETURN(Analysis::Array * reshape_folded_into_gather,
                        FoldReshapeOfGather(shape, scalar_indexed));
    if (reshape_folded_into_gather) {
      return reshape_folded_into_gather;
    }
  }

  if (auto* constant_array = dynamic_cast<ConstantArray*>(operand)) {
    TF_ASSIGN_OR_RETURN(
        Literal* const new_literal,
        TakeOwnership(constant_array->literal()->Reshape(shape.dimensions())));
    return Construct<ConstantArray>(new_literal);
  }

  return Construct<ReshapedArray>(operand, shape);
}

StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
                                                         Array* lhs,
                                                         Array* rhs) {
  // Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices))
  //          => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices)
  //
  // We can do this if every output dimension from the scalar-indexed node is a
  // broadcasted dimension for the broadcast node.  Informally, the precondition
  // means Broadcast(Const0)[IDX] is solely a function of the components of IDX
  // that are not output-dims for the scalar-indexed node. In other words, for
  // every assignment to the non-output dims in IDX we have a "constant" LHS to
  // the BinaryOp.  This transform propagates this "constant" to the source for
  // the scalar-indexed node.
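  //
  // For example (a minimal case), with C0 = s32[4] and C1 = s32[3,4]:
  //   Add(Broadcast(C0) : s32[5,4] (dims={1}),
  //       (scalar-indexed-const C1 %indices 0->[0]) : s32[5,4])
  // folds to
  //   (scalar-indexed-const Add(Broadcast(C0 -> s32[3,4], dims={1}), C1)
  //    %indices 0->[0]) : s32[5,4].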

  ScalarIndexedConstantArray* lhs_scalar_indexed_const =
      dynamic_cast<ScalarIndexedConstantArray*>(lhs);
  ScalarIndexedConstantArray* rhs_scalar_indexed_const =
      dynamic_cast<ScalarIndexedConstantArray*>(rhs);

  bool lhs_is_indexed;

  // One of the operands must be scalar-indexed and the other must be a
  // broadcast of a constant.
  if (lhs_scalar_indexed_const && !rhs_scalar_indexed_const) {
    lhs_is_indexed = true;
  } else if (rhs_scalar_indexed_const && !lhs_scalar_indexed_const) {
    lhs_is_indexed = false;
  } else {
    return nullptr;
  }

  ScalarIndexedConstantArray* scalar_indexed_const =
      lhs_is_indexed ? lhs_scalar_indexed_const : rhs_scalar_indexed_const;
  UnknownArray* candidate_broadcast_array =
      dynamic_cast<UnknownArray*>(lhs_is_indexed ? rhs : lhs);
  if (!candidate_broadcast_array ||
      candidate_broadcast_array->instruction().opcode() !=
          HloOpcode::kBroadcast) {
    return nullptr;
  }

  const HloInstruction* broadcast_instr =
      &candidate_broadcast_array->instruction();
  const HloInstruction* broadcast_const_operand = broadcast_instr->operand(0);
  if (broadcast_const_operand->opcode() != HloOpcode::kConstant) {
    return nullptr;
  }

  absl::Span<const int64_t> broadcast_dims = broadcast_instr->dimensions();
  auto is_broadcasted_dim = [&](int64_t output_dim) {
    return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end();
  };

  // All of the output dims must be "broadcasted" dims for the other operand.
  if (!absl::c_all_of(scalar_indexed_const->output_dims(),
                      is_broadcasted_dim)) {
    return nullptr;
  }

  // To figure out the broadcast dimensions for the (constant) source for the
  // scalar-indexed node, we "simulate" the index transformation done by the
  // existing broadcast:
  enum class IndexComponent { Broadcasted, NotBroadcasted };
  std::vector<IndexComponent> simulated_index(
      broadcast_instr->shape().dimensions_size(), IndexComponent::Broadcasted);
  for (int64_t broadcast_dim : broadcast_dims) {
    simulated_index[broadcast_dim] = IndexComponent::NotBroadcasted;
  }

  // The scalar-indexed node "removes" the source dim and "inserts" the output
  // dims.  We do the opposite here to undo the scalar-indexed operation.
  absl::Span<const int64_t> output_dims = scalar_indexed_const->output_dims();
  for (int64_t i = output_dims.size() - 1; i >= 0; --i) {
    CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted);
    EraseAt(&simulated_index, output_dims[i]);
  }

  InsertAt(&simulated_index, scalar_indexed_const->source_dim(),
           IndexComponent::Broadcasted);

  // new_inner_broadcast_dims holds the broadcast dimensions for the inner
  // BinaryOp(Broadcast'(Const0), Const1).  We now translate simulated_index to
  // new_inner_broadcast_dims.
  std::vector<int64_t> new_inner_broadcast_dims;
  for (int64_t i = 0; i < simulated_index.size(); i++) {
    if (simulated_index[i] == IndexComponent::NotBroadcasted) {
      new_inner_broadcast_dims.push_back(i);
    }
  }

  // inner_broadcast_result is the Broadcast'(Const0) bit in
  // BinaryOp(Broadcast'(Const0), Const1)
  TF_ASSIGN_OR_RETURN(
      Literal inner_broadcast_result,
      broadcast_const_operand->literal().Broadcast(
          scalar_indexed_const->source()->shape(), new_inner_broadcast_dims));

  // literal_for_new_source is BinaryOp(Broadcast'(Const0), Const1)
  const Literal* literal_for_new_source;
  if (lhs_is_indexed) {
    TF_ASSIGN_OR_RETURN(
        literal_for_new_source,
        TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
            opcode, scalar_indexed_const->literal(), inner_broadcast_result)));
  } else {
    TF_ASSIGN_OR_RETURN(
        literal_for_new_source,
        TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
            opcode, inner_broadcast_result, scalar_indexed_const->literal())));
  }

  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
  return Construct<ScalarIndexedConstantArray>(
      new_source, scalar_indexed_const->indices(),
      scalar_indexed_const->source_dim(),
      std::vector<int64_t>(scalar_indexed_const->output_dims().begin(),
                           scalar_indexed_const->output_dims().end()),
      scalar_indexed_const->shape());
}

StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
                                                        Array* operand) {
  auto* scalar_indexed_const =
      dynamic_cast<ScalarIndexedConstantArray*>(operand);
  if (scalar_indexed_const == nullptr) {
    return nullptr;
  }

  // Fold UnaryOp(ScalarIndexed(Const, Indices))
  //   => ScalarIndexed(UnaryOp(Const), Indices)

  TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
                      TakeOwnership(HloEvaluator{}.EvaluateElementwiseUnaryOp(
                          opcode, scalar_indexed_const->literal())));
  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
  return Construct<ScalarIndexedConstantArray>(
      new_source, scalar_indexed_const->indices(),
      scalar_indexed_const->source_dim(),
      SpanToVector(scalar_indexed_const->output_dims()),
      scalar_indexed_const->shape());
}

namespace {

// Returns the non-contracting non-batch dimension (as per `contracting_dims`
// and `batch_dims`) if there is exactly one, otherwise returns nullopt.
std::optional<int64_t> GetOnlyNonContractingNonBatchDim(
    int64_t rank, absl::Span<const int64_t> contracting_dims,
    absl::Span<const int64_t> batch_dims) {
  std::optional<int64_t> result;
  for (int64_t dim = 0; dim < rank; dim++) {
    if (!absl::c_linear_search(contracting_dims, dim) &&
        !absl::c_linear_search(batch_dims, dim)) {
      if (result.has_value()) {
        return std::nullopt;
      }
      result = dim;
    }
  }
  return result;
}

// Returns true if `indexed_array`, which is either the LHS or the RHS of a Dot
// HLO, can be folded into the dot operation.  For now these conditions are both
// necessary and sufficient.
//
// `tag` describes the caller.  Used only for logging.
//
// `contracting_dims` and `batch_dims` are the contracting and batch dimensions
// of whatever operand `indexed_array` is to the dot (LHS or RHS).
bool CanFoldDotIntoIndexedArray(
    absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array,
    absl::Span<const int64_t> contracting_dims,
    absl::Span<const int64_t> batch_dims) {
  std::optional<int64_t> non_contracting_non_batch_dim =
      GetOnlyNonContractingNonBatchDim(indexed_array->shape().rank(),
                                       contracting_dims, batch_dims);
  if (!non_contracting_non_batch_dim.has_value()) {
    VLOG(3) << tag << ": multiple or no non-contracting non-batch dimensions";
    return false;
  }

  if (indexed_array->output_dims().size() != 1 ||
      indexed_array->output_dims()[0] != *non_contracting_non_batch_dim) {
    VLOG(3) << tag << ": output dims != the non-contracting non-batch dim";
    return false;
  }

  int64_t indexed_array_rank = indexed_array->shape().rank();
  if (indexed_array->source_dim() < (indexed_array_rank - 2)) {
    // This restriction can be lifted by inserting reshape nodes.
    VLOG(3) << tag
            << ": source dim is not in the low two dims, won't be able to form "
               "a matmul";
    return false;
  }

  return true;
}

}  // namespace

StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
    const Shape& shape, const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs,
    ConstantArray* rhs) {
  VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " "
          << ToString(rhs) << ")";
  if (!CanFoldDotIntoIndexedArray(
          "ComputeArrayForDotWithIndexedLhs", lhs, /*contracting_dims=*/
          dim_numbers.lhs_contracting_dimensions(),
          /*batch_dims=*/dim_numbers.lhs_batch_dimensions())) {
    return nullptr;
  }

  int64_t lhs_rank = lhs->shape().rank();
  DotDimensionNumbers new_dim_numbers = dim_numbers;
  new_dim_numbers.set_lhs_contracting_dimensions(
      0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1));

  TF_ASSIGN_OR_RETURN(
      Literal * literal_for_new_source,
      TakeOwnership(HloEvaluator{}.EvaluateDotOp(
          new_dim_numbers, precision_config, lhs->literal(), *rhs->literal())));

  // The new source dimension is wherever the non-batch non-contracting LHS
  // dimension "went".
  int64_t new_source_dim = dim_numbers.lhs_batch_dimensions_size() +
                           dim_numbers.rhs_batch_dimensions_size();

  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
  return Construct<ScalarIndexedConstantArray>(
      new_source, lhs->indices(), new_source_dim,
      SpanToVector(lhs->output_dims()), shape);
}

StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
    const Shape& shape, const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config, ConstantArray* lhs,
    ScalarIndexedConstantArray* rhs) {
  VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " "
          << ToString(rhs) << ")";
  if (!CanFoldDotIntoIndexedArray(
          "ComputeArrayForDotWithIndexedRhs", rhs, /*contracting_dims=*/
          dim_numbers.rhs_contracting_dimensions(),
          /*batch_dims=*/dim_numbers.rhs_batch_dimensions())) {
    return nullptr;
  }

  int64_t rhs_rank = rhs->shape().rank();

  DotDimensionNumbers new_dim_numbers = dim_numbers;
  new_dim_numbers.set_rhs_contracting_dimensions(
      0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1));

  TF_ASSIGN_OR_RETURN(
      Literal * literal_for_new_source,
      TakeOwnership(HloEvaluator{}.EvaluateDotOp(
          new_dim_numbers, precision_config, *lhs->literal(), rhs->literal())));

  // The new source dimension is wherever the non-batch non-contracting RHS
  // dimension "went".
  int64_t new_source_dim = dim_numbers.lhs_batch_dimensions_size() +
                           dim_numbers.rhs_batch_dimensions_size() + 1;

  ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
  return Construct<ScalarIndexedConstantArray>(
      new_source, rhs->indices(), new_source_dim,
      SpanToVector(rhs->output_dims()), shape);
}

StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
    const Shape& shape, const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config, Array* lhs, Array* rhs) {
  // Intuitively, if
  //
  //  - The LHS of a dot product is a gathered sequence of rows from a constant
  //    array (i.e. LHS[I,J] = Const[Indices[I],J]) and the RHS is a constant
  //
  //  OR
  //
  //  - If the RHS of a dot product is a gathered sequence of columns from a
  //    constant array (i.e. RHS[I,J] = Const[I, Indices[J]]) and the LHS is a
  //    constant
  //
  // then the result of the dot product itself is a gather from a constant
  // array.  E.g. Dot(LHS, ConstRhs) where LHS[I,J] = Const[Indices[I],J] can be
  // rewritten as Result where Result[I,J] = Dot(Const, ConstRhs)[Indices[I],
  // J].
  //
  // We do a general version of this rewrite here.
  VLOG(3) << "ComputeArrayForDot(" << ToString(lhs) << " " << ToString(rhs)
          << ")";
  if (auto* lhs_indexed_array =
          dynamic_cast<ScalarIndexedConstantArray*>(lhs)) {
    if (auto* rhs_constant = dynamic_cast<ConstantArray*>(rhs)) {
      return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers,
                                              precision_config,
                                              lhs_indexed_array, rhs_constant);
    }
  }

  if (auto* rhs_indexed_array =
          dynamic_cast<ScalarIndexedConstantArray*>(rhs)) {
    if (auto* lhs_constant = dynamic_cast<ConstantArray*>(lhs)) {
      return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers,
                                              precision_config, lhs_constant,
                                              rhs_indexed_array);
    }
  }

  return nullptr;
}

absl::string_view IndexedArrayAnalysisPrinterPass::name() const {
  return "indexed-array-analysis-printer-pass";
}

StatusOr<bool> IndexedArrayAnalysisPrinterPass::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  if (!VLOG_IS_ON(2)) {
    return false;
  }

  IndexedArrayAnalysis analysis;
  for (auto* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    for (auto* instr : computation->instructions()) {
      TF_ASSIGN_OR_RETURN(Analysis::Array * t, analysis.GetArrayFor(instr));
      if (!dynamic_cast<UnknownArray*>(t) && !dynamic_cast<ConstantArray*>(t)) {
        VLOG(2) << instr->ToString() << "   ->   " << analysis.ToString(t);
      }
    }
  }

  return false;
}

}  // namespace xla