/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_

#include <type_traits>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/core/util/ptr_util.h"

namespace xla {

// IndexedArrayAnalysis decides if an HLO instruction can be rewritten as a
// gather from another array.  It does this by mapping HLO instructions to
// instances of IndexedArrayAnalysis::Array, which can be inspected to discover
// whether said HLO is equivalent to a gather.
class IndexedArrayAnalysis {
 public:
  // IndexedArrayAnalysis maps each HLO instruction to an instance of an Array.
  // Array is really just a sum type of the classes that inherit from it.  The
  // meaning of each of the subtypes is documented on the subtype declaration.
  //
  // Array instances are immutable once created.
  class Array {
   public:
    enum Kind {
      kUnknown,
      kConstant,
      kReshaped,
      kScalarIndexedConstant,
      kScalarIndexed
    };

    virtual Kind kind() const = 0;
    virtual const Shape& shape() const = 0;

    // Does a checked downcast from `Array` to `T`, which must be one of its
    // subtypes.
    template <typename T>
    T* as() {
      static_assert((std::is_base_of<Array, T>::value),
                    "target type not derived from source type");
      // We skip the CHECK and hence the dynamic_cast if RTTI is disabled.
#if !defined(__GNUC__) || defined(__GXX_RTTI)
      CHECK_NE(dynamic_cast<T*>(this), nullptr);
#endif  // !defined(__GNUC__) || defined(__GXX_RTTI)

      return static_cast<T*>(this);
    }
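    //
    // For example (an illustrative sketch, not from the original source):
    // given an `Array* arr` whose kind() is kConstant,
    // `arr->as<ConstantArray>()->literal()` retrieves the underlying Literal,
    // while calling as<ConstantArray>() on a non-constant Array trips the
    // CHECK above whenever RTTI is available.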

    virtual ~Array() = default;

    Array& operator=(const Array& other) = delete;
  };

  // Represents an HLO instruction that was not analyzable by this
  // IndexedArrayAnalysis.  Instances of UnknownArray just wrap an existing
  // HloInstruction.
  class UnknownArray : public Array {
   public:
    Kind kind() const override { return kUnknown; }
    const Shape& shape() const override { return instruction().shape(); }
    const HloInstruction& instruction() const { return instruction_; }

   private:
    explicit UnknownArray(const HloInstruction* instr) : instruction_(*instr) {}

    const HloInstruction& instruction_;

    friend class IndexedArrayAnalysis;
  };

  // Represents a constant value.  This constant value may be present in the HLO
  // module being analyzed, or it could have been created on the fly by the
  // analysis.
  class ConstantArray : public Array {
   public:
    Kind kind() const override { return kConstant; }
    const Shape& shape() const override { return literal()->shape(); }
    const Literal* literal() const { return literal_; }

   private:
    explicit ConstantArray(const Literal* literal) : literal_(literal) {}
    const Literal* literal_;

    friend class IndexedArrayAnalysis;
  };

  // Represents an Array that is a reshape of another Array.
  class ReshapedArray : public Array {
   public:
    Kind kind() const override { return kReshaped; }

    // The array to reshape.
    Array* operand() const { return operand_; }

    // The output shape.
    const Shape& shape() const override { return shape_; }

   private:
    explicit ReshapedArray(Array* operand, Shape shape)
        : operand_(operand), shape_(shape) {}

    Array* operand_;
    const Shape shape_;

    friend class IndexedArrayAnalysis;
  };

  // ---------------------------------------------------------------------------
  // Indexed Array Overview
  // ---------------------------------------------------------------------------
  //
  // ScalarIndexedArray and ScalarIndexedConstantArray form the core of this
  // analysis.  ScalarIndexedConstantArray is just a specialization of
  // ScalarIndexedArray so we will only discuss ScalarIndexedArray in this
  // overview.
  //
  // A ScalarIndexedArray represents an array that can be computed by indexing
  // into a "source" array using an "indices" tensor.  A simple example is a
  // gather operation gathering 12 rows out of a [100,100] matrix -- such an
  // operation will be represented by an instance of a ScalarIndexedArray with
  // the [100,100] matrix as the "source" array and the [12]-shaped indices
  // array as the "indices" tensor.  The ScalarIndexedArray operation itself
  // will be of shape [12,100] (assuming we were gathering with axis=0).
  //
  // Gather operations are not the only operations that map to
  // ScalarIndexedArray instances (if that were true there would be little point
  // in having a separate analysis).  We can often infer ScalarIndexedArrays for
  // other operations too.  For instance, consider:
  //
  //   %source = f32[100,100] constant
  //   %indices = s32[12] ...
  //   %gather = f32[12,100] ... gather from %source using %indices at axis 0
  //   %dot = dot(%gather, other_constant) [canonical contracting dims]
  //
  // The dot operation itself is also a ScalarIndexedArray with source =
  // dot(constant, other_constant) and indices = %indices.  A reshape of %gather
  // to [12,5,20] is likewise a ScalarIndexedArray with source = an
  // appropriately reshaped constant and indices = %indices.
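  //
  // Sketched in the same notation (an illustrative expansion, not from the
  // original source), the %dot above can therefore be re-expressed as:
  //
  //   %new_source = dot(%source, other_constant)
  //   %dot = f32[12,N] ... gather from %new_source using %indices at axis 0
  //
  // where N is the number of columns of other_constant.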

  // Represents the result of a gather operation.  This gather operation may
  // explicitly be present in the HLO module being analyzed, or it could have
  // been created on the fly by the analysis.
  //
  // An instance of ScalarIndexedArray represents an array whose I'th element
  // can be mapped to the J'th element of the `source` array (where I and J are
  // multidimensional indices) in this way:
  //
  //   I' = remove components at positions `output_dims` from I
  //   G' = remove components not at positions `output_dims` from I
  //   T  = indices[G']
  //   J  = I' with T inserted at position `source_dim`
  //
  // For example, if source is of shape [11,13,17,19], indices is of shape
  // [23,29], output_dims is [0,2] and source_dim is 2 then the output is of
  // shape [23,11,29,13,19] and the output index [A,B,C,D,E] is mapped to the
  // input index [B,D,indices[A,C],E].
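  //
  // Working through the recipe above on that example (illustrative):
  //
  //   I  = [A,B,C,D,E]
  //   I' = [B,D,E]      (components at output_dims [0,2] removed)
  //   G' = [A,C]        (only components at output_dims [0,2] kept)
  //   T  = indices[A,C]
  //   J  = [B,D,T,E]    (T inserted at source_dim = 2)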
  class ScalarIndexedArray : public Array {
   public:
    Kind kind() const override { return kScalarIndexed; }
    const Shape& shape() const override { return shape_; }

    Array* source() const { return source_; }
    Array* indices() const { return indices_; }

    // `source_dim` is the dimension in the source array that is being indexed
    // over using indices from the `indices` array.  See the class documentation
    // and the overview for more details.
    int64_t source_dim() const { return source_dim_; }

    // `output_dims` are the dimensions in the output array that are being used
    // to compute an index into the `indices` array.  See the class
    // documentation and the overview for more details.
    absl::Span<const int64_t> output_dims() const { return output_dims_; }

   private:
    explicit ScalarIndexedArray(Array* source, Array* indices,
                                int64_t source_dim,
                                std::vector<int64_t> output_dims, Shape shape)
        : source_(source),
          indices_(indices),
          source_dim_(source_dim),
          output_dims_(std::move(output_dims)),
          shape_(std::move(shape)) {}

    Array* source_;
    Array* indices_;
    int64_t source_dim_;
    std::vector<int64_t> output_dims_;
    Shape shape_;

    friend class IndexedArrayAnalysis;
  };

  // A ScalarIndexedConstantArray is just a ScalarIndexedArray constrained to
  // have a ConstantArray instance as the source.  This is an ergonomic
  // concession -- in theory it is possible to just keep ScalarIndexedArray and
  // check source()->kind().
  class ScalarIndexedConstantArray : public ScalarIndexedArray {
   public:
    Kind kind() const override { return kScalarIndexedConstant; }

    const Literal& literal() const {
      return *source()->as<ConstantArray>()->literal();
    }

   private:
    explicit ScalarIndexedConstantArray(Array* source, Array* indices,
                                        int64_t source_dim,
                                        std::vector<int64_t> output_dims,
                                        Shape shape)
        : ScalarIndexedArray(source, indices, source_dim,
                             std::move(output_dims), std::move(shape)) {
      CHECK(dynamic_cast<ConstantArray*>(source));
    }

    friend class IndexedArrayAnalysis;
  };

  // Returns an Array instance for `instr`.  The IndexedArrayAnalysis instance
  // keeps ownership of the returned Array instance.
  //
  // Caching Behavior: IndexedArrayAnalysis has a cache mapping HLO
  // instructions to IndexedArrayAnalysis::Array instances.  This entire cache
  // becomes stale, and may cause the analysis to return incorrect results, if
  // any transitive operand (stopping at the containing computation) of any HLO
  // instruction on which GetArrayFor has been invoked is modified.
  //
  // NB!  By inspecting the implementation, you may be able to infer a stronger
  // caching guarantee than what is mentioned above.  Nevertheless, what is
  // stated above is the contract.
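  //
  // Typical usage (an illustrative sketch, not from the original source):
  //
  //   IndexedArrayAnalysis analysis;
  //   TF_ASSIGN_OR_RETURN(Array* root, analysis.GetArrayFor(instr));
  //   if (root->kind() == Array::kScalarIndexedConstant) {
  //     auto* indexed = root->as<ScalarIndexedConstantArray>();
  //     // Inspect indexed->literal(), indexed->source_dim(), ...
  //   }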
  StatusOr<Array*> GetArrayFor(const HloInstruction* instr);

  // Pretty-prints the expression rooted at `root`.
  std::string ToString(Array* root, bool print_constants = false);

 private:
  // Helper function that ensures that every HLO instruction that is
  // transitively used by `root` has an entry in `cache_`.
  Status TraverseAndPopulateCache(const HloInstruction* root);

  // Creates an Array instance for `instr` under the assumption that all
  // operands of `instr` are already present in `cache_`.
  StatusOr<Array*> ComputeArrayFor(const HloInstruction* instr);

  StatusOr<Array*> ComputeArrayForConstant(const Literal& literal);

  StatusOr<Array*> ComputeArrayForGather(
      const Shape& shape, const GatherDimensionNumbers& dim_numbers,
      absl::Span<const int64_t> slice_sizes, Array* source, Array* indices);

  StatusOr<Array*> ComputeArrayForDotWithIndexedLhs(
      const Shape& shape, const DotDimensionNumbers& dim_numbers,
      const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs,
      ConstantArray* rhs);

  StatusOr<Array*> ComputeArrayForDotWithIndexedRhs(
      const Shape& shape, const DotDimensionNumbers& dim_numbers,
      const PrecisionConfig& precision_config, ConstantArray* lhs,
      ScalarIndexedConstantArray* rhs);

  StatusOr<Array*> ComputeArrayForDot(const Shape& shape,
                                      const DotDimensionNumbers& dim_numbers,
                                      const PrecisionConfig& precision_config,
                                      Array* lhs, Array* rhs);
  // This tries to fold a ScalarIndexedArray which has another
  // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a
  // ScalarIndexedArray as indices.  If `source` happened to be a
  // ScalarIndexedConstantArray this can result in an expression that is more
  // canonical.
  //
  // As an example, consider a gather operation, G0, gathering 7 elements from
  // an array "Arr" of shape [100] resulting in an array of shape [7], and a
  // second gather operation, G1, which gathers 3 elements out of the result of
  // G0 resulting in an array of shape [3].  Let the indices used by G0 be I0
  // (of shape [7]) and the indices used by G1 be I1 (of shape [3]).  We can
  // instead rewrite G1 to gather directly from "Arr" with the three indices
  // from I0 selected by I1.  In other words, we can rewrite:
  //
  //    G0 = [Arr[i] for i in I0]
  //    G1 = [G0[i]  for i in I1]
  //
  // into
  //
  //    I2 = [I0[i]  for i in I1]
  //    G1 = [Arr[i] for i in I2]
  StatusOr<ScalarIndexedArray*> FoldGatherOfGather(
      ScalarIndexedArray* source, Array* indices, int64_t source_dim,
      absl::Span<const int64_t> output_dims, Shape shape);

  // Reshapes a scalar-indexed node to remove the degenerate dimensions in its
  // output.  The result is always a scalar-indexed node.
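  //
  // For instance (illustrative), an operand of shape [1,12,1,100] would be
  // reshaped to a scalar-indexed node of shape [12,100].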
  StatusOr<ScalarIndexedArray*> ReshapeToRemoveDegenerateDims(
      ScalarIndexedArray* operand);

  // Reshapes a scalar-indexed node such that the result has the degenerate
  // dimensions `degenerate_dims`.  The result is always a scalar-indexed node.
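  //
  // For instance (illustrative), degenerate_dims = {0, 2} applied to an
  // operand of shape [12,100] yields a scalar-indexed node of shape
  // [1,12,1,100].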
  StatusOr<ScalarIndexedArray*> ReshapeToAddDegenerateDims(
      ScalarIndexedArray* operand, absl::Span<const int64_t> degenerate_dims);

  StatusOr<ScalarIndexedArray*> FoldReshapeOfGather(
      const Shape& shape, ScalarIndexedConstantArray* operand);
  StatusOr<ScalarIndexedArray*> FoldReshapeOfGatherNoDegenerateDims(
      const Shape& shape, ScalarIndexedConstantArray* scalar_indexed);
  StatusOr<Array*> ComputeArrayForReshape(const Shape& shape, Array* operand);

  StatusOr<Array*> ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
                                                      Array* lhs, Array* rhs);
  StatusOr<Array*> ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
                                                     Array* operand);

  // Allocates a new T constructed from `args`; the analysis retains ownership
  // of the allocation in `owned_tensors_`.
  template <typename T, typename... Args>
  T* Construct(Args&&... args) {
    T* new_tensor = new T(std::forward<Args>(args)...);
    owned_tensors_.push_back(std::unique_ptr<T>(new_tensor));
    return new_tensor;
  }

  // Constructs either a ScalarIndexedArray or a ScalarIndexedConstantArray,
  // depending on whether `source` is a constant.
  ScalarIndexedArray* ConstructScalarIndexedArray(
      Array* source, Array* indices, int64_t source_dim,
      std::vector<int64_t> output_dims, Shape shape) {
    if (source->kind() == Array::kConstant) {
      return Construct<ScalarIndexedConstantArray>(source, indices, source_dim,
                                                   std::move(output_dims),
                                                   std::move(shape));
    } else {
      return Construct<ScalarIndexedArray>(source, indices, source_dim,
                                           std::move(output_dims),
                                           std::move(shape));
    }
  }

  // Moves `literal` into `owned_literals_` and returns a pointer to the stored
  // value.
  Literal* TakeOwnership(Literal literal) {
    owned_literals_.push_back(std::move(literal));
    return &owned_literals_.back();
  }

  StatusOr<Literal*> TakeOwnership(StatusOr<Literal> literal_or_error) {
    TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error));
    owned_literals_.push_back(std::move(literal));
    return &owned_literals_.back();
  }

  std::vector<std::unique_ptr<Array>> owned_tensors_;
  std::vector<Literal> owned_literals_;
  absl::flat_hash_map<const HloInstruction*, Array*> cache_;
};

// A pass that prints all non-trivial results returned by IndexedArrayAnalysis.
// This pass is a no-op if !VLOG_IS_ON(2), so it should be fine to add it
// unconditionally to the regular HLO pass pipeline.
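//
// For example (an illustrative sketch, not from the original source), the pass
// can be appended to an HloPassPipeline in the usual way:
//
//   HloPassPipeline pipeline("some-pipeline");
//   pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();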
class IndexedArrayAnalysisPrinterPass : public HloModulePass {
 public:
  absl::string_view name() const override;
  using HloPassInterface::Run;
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_