1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ 18 19 #include <type_traits> 20 21 #include "absl/container/flat_hash_map.h" 22 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 23 #include "tensorflow/compiler/xla/service/hlo_module.h" 24 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 25 #include "tensorflow/core/util/ptr_util.h" 26 27 namespace xla { 28 29 // IndexedArrayAnalysis decides if an HLO instruction can be rewritten as a 30 // gather from another array. It does this by mapping HLO instructions to 31 // instances of IndexedArrayAnalysis::Array, which can be inspected to discover 32 // whether said HLO is equivalent to a gather. 33 class IndexedArrayAnalysis { 34 public: 35 // IndexedArrayAnalysis maps each HLO instruction to an instance of a Array. 36 // Array really just a sum type of the classes that inherit from it. The 37 // meaning of each of the subtypes is documented on the subtype declaration. 38 // 39 // Array instances are immutable once created. 
40 class Array { 41 public: 42 enum Kind { 43 kUnknown, 44 kConstant, 45 kReshaped, 46 kScalarIndexedConstant, 47 kScalarIndexed 48 }; 49 50 virtual Kind kind() const = 0; 51 virtual const Shape& shape() const = 0; 52 53 // Does a checked downcast from `Array` to `T` which must be one of its 54 // subtypes. 55 template <typename T> as()56 T* as() { 57 static_assert((std::is_base_of<Array, T>::value), 58 "target type not derived from source type"); 59 // We skip the CHECK and hence the dynamic_cast if RTTI is disabled. 60 #if !defined(__GNUC__) || defined(__GXX_RTTI) 61 CHECK_NE(dynamic_cast<T*>(this), nullptr); 62 #endif // !defined(__GNUC__) || defined(__GXX_RTTI) 63 64 return static_cast<T*>(this); 65 } 66 67 virtual ~Array() = default; 68 69 Array& operator=(const Array& other) = delete; 70 }; 71 72 // Represents an HLO instruction that was not analyzable by this 73 // IndexedArrayAnalysis. Instances of UnknownArray just wrap an existing 74 // HloInstruction. 75 class UnknownArray : public Array { 76 public: kind()77 Kind kind() const override { return kUnknown; } shape()78 const Shape& shape() const override { return instruction().shape(); } instruction()79 const HloInstruction& instruction() const { return instruction_; } 80 81 private: UnknownArray(const HloInstruction * instr)82 explicit UnknownArray(const HloInstruction* instr) : instruction_(*instr) {} 83 84 const HloInstruction& instruction_; 85 86 friend class IndexedArrayAnalysis; 87 }; 88 89 // Represents a constant value. This constant value may be present in the HLO 90 // module being analyzed, or it could have been created on the fly by the 91 // analysis. 
92 class ConstantArray : public Array { 93 public: kind()94 Kind kind() const override { return kConstant; } shape()95 const Shape& shape() const override { return literal()->shape(); } literal()96 const Literal* literal() const { return literal_; } 97 98 private: ConstantArray(const Literal * literal)99 explicit ConstantArray(const Literal* literal) : literal_(literal) {} 100 const Literal* literal_; 101 102 friend class IndexedArrayAnalysis; 103 }; 104 105 // Represents an Array that is a reshape of another Array. 106 class ReshapedArray : public Array { 107 public: kind()108 Kind kind() const override { return kReshaped; } 109 110 // The array to reshape. operand()111 Array* operand() const { return operand_; } 112 113 // The output shape. shape()114 const Shape& shape() const override { return shape_; } 115 116 private: ReshapedArray(Array * operand,Shape shape)117 explicit ReshapedArray(Array* operand, Shape shape) 118 : operand_(operand), shape_(shape) {} 119 120 Array* operand_; 121 const Shape shape_; 122 123 friend class IndexedArrayAnalysis; 124 }; 125 126 // --------------------------------------------------------------------------- 127 // Indexed Array Overview 128 // --------------------------------------------------------------------------- 129 // 130 // ScalarIndexedArray and ScalarIndexedConstantArray form the core of this 131 // analysis. ScalarIndexedConstantArray is just a specialization of 132 // ScalarIndexedArray so we will only discuss ScalarIndexedArray in this 133 // overview. 134 // 135 // A ScalarIndexedArray represents an array that can be computed by indexing 136 // into a "source" array using an "indices" tensor. A simple example is a 137 // gather operation gathering 12 rows out of a [100,100] matrix -- such an 138 // operation will be represented by an instance of a ScalarIndexedArray with 139 // the [100,100] matrix as the "source" array and the [12]-shaped indices 140 // array as the "indices" tensor. 
The ScalarIndexedArray operation itself 141 // will be of shape [12,100] (assuming we were gathering with axis=0). 142 // 143 // Gather operations are not the only operation that maps to 144 // ScalarIndexedArray instances (if that were true there would be little point 145 // in having a separate analysis). We can often infer ScalarIndexedArrays for 146 // other operations too. For instance, consider: 147 // 148 // %source = f32[100,100] constant 149 // %indices = s32[12] ... 150 // %gather = f32[12,100] ... gather from %source using %indices at axis 0 151 // %dot = dot(%gather, other_constant) [canonical contracting dims] 152 // 153 // The dot operation itself is also a ScalarIndexedArray with source = 154 // dot(constant, other_constant) and indices = %indices. A reshape of %gather 155 // to [12,5,20] too is a ScalarIndexedArray with source = an appropriately 156 // reshaped constant and indices = %indices. 157 158 // Represents the result of a gather operation. This gather operation may 159 // explicitly be present in the HLO module being analyzed, or it could have 160 // been created on the fly by the analysis. 161 // 162 // An instance of ScalarIndexedArray represents a array whose I'th element can 163 // be mapped to the J'th element of the `source` array (where I and J are 164 // multidimensional indices) in this way: 165 // 166 // I' = remove components at positions `output_dims` from I 167 // G' = remove components not at positions `output_dims` from I 168 // T = indices[G'] 169 // J = I' with T inserted at position `source_dim` 170 // 171 // For example, if source is of shape [11,13,17,19], indices is of shape 172 // [23,29], output_dims is [0,2] and source_dim is 2 then the output is of 173 // shape [23,11,29,13,19] and the output index [A,B,C,D,E] is mapped to the 174 // input index [B,D,indices[A,C],E]. 
175 class ScalarIndexedArray : public Array { 176 public: kind()177 Kind kind() const override { return kScalarIndexed; } shape()178 const Shape& shape() const override { return shape_; } 179 source()180 Array* source() const { return source_; } indices()181 Array* indices() const { return indices_; } 182 183 // `source_dim` is the dimension in the source array that is being indexed 184 // over using indices from the `indices` array. See the class documentation 185 // and the overview for more details. source_dim()186 int64_t source_dim() const { return source_dim_; } 187 188 // `output_dims` are the dimensions in the output array that are being used 189 // to compute an index into the `indices` array. See the class 190 // documentation and the overview for more details. output_dims()191 absl::Span<const int64_t> output_dims() const { return output_dims_; } 192 193 private: ScalarIndexedArray(Array * source,Array * indices,int64_t source_dim,std::vector<int64_t> output_dims,Shape shape)194 explicit ScalarIndexedArray(Array* source, Array* indices, 195 int64_t source_dim, 196 std::vector<int64_t> output_dims, Shape shape) 197 : source_(source), 198 indices_(indices), 199 source_dim_(source_dim), 200 output_dims_(std::move(output_dims)), 201 shape_(std::move(shape)) {} 202 203 Array* source_; 204 Array* indices_; 205 int64_t source_dim_; 206 std::vector<int64_t> output_dims_; 207 Shape shape_; 208 209 friend class IndexedArrayAnalysis; 210 }; 211 212 // A ScalarIndexedConstantArray is just a ScalarIndexedArray constrained to 213 // have a ConstantArray instance as the source. This is an ergonomic 214 // concession -- in theory it is possible to just keep ScalarIndexedArray and 215 // check source()->kind(). 
216 class ScalarIndexedConstantArray : public ScalarIndexedArray { 217 public: kind()218 Kind kind() const override { return kScalarIndexedConstant; } 219 literal()220 const Literal& literal() const { 221 return *source()->as<ConstantArray>()->literal(); 222 } 223 224 private: ScalarIndexedConstantArray(Array * source,Array * indices,int64_t source_dim,std::vector<int64_t> output_dims,Shape shape)225 explicit ScalarIndexedConstantArray(Array* source, Array* indices, 226 int64_t source_dim, 227 std::vector<int64_t> output_dims, 228 Shape shape) 229 : ScalarIndexedArray(source, indices, source_dim, 230 std::move(output_dims), std::move(shape)) { 231 CHECK(dynamic_cast<ConstantArray*>(source)); 232 } 233 234 friend class IndexedArrayAnalysis; 235 }; 236 237 // Returns an Array instance for `instr`. The IndexedArrayAnalysis instance 238 // keeps ownership of the returned Array instance. 239 // 240 // Caching Behavior: IndexedArrayAnalysis has a cache mapping HLO 241 // instructions to IndexedArrayAnalysis::Array instances. This entire cache 242 // becomes stale and may cause the analysis to return incorrect results if any 243 // transitive operand (stopping at the containing computation) is modified for 244 // any HLO instruction on which GetArrayFor has been invoked. 245 // 246 // NB! By inspecting the implementation, you may be able to infer a stronger 247 // caching guarantee than what is mentioned above. Nevertheless, what is 248 // stated above is the contract. 249 StatusOr<Array*> GetArrayFor(const HloInstruction* instr); 250 251 // Pretty-prints the expression rooted at `root`. 252 std::string ToString(Array* root, bool print_constants = false); 253 254 private: 255 // Helper function that ensures that every HLO instruction that is 256 // transitively used by `root` has an entry in `cache_`. 
257 Status TraverseAndPopulateCache(const HloInstruction* root); 258 259 // Creates an Array instance for `instr` under the assumption that all 260 // operations of `instr` are present in `cache_`. 261 StatusOr<Array*> ComputeArrayFor(const HloInstruction* instr); 262 263 StatusOr<Array*> ComputeArrayForConstant(const Literal& literal); 264 265 StatusOr<Array*> ComputeArrayForGather( 266 const Shape& shape, const GatherDimensionNumbers& dim_numbers, 267 absl::Span<const int64_t> slice_sizes, Array* source, Array* indices); 268 269 StatusOr<Array*> ComputeArrayForDotWithIndexedLhs( 270 const Shape& shape, const DotDimensionNumbers& dim_numbers, 271 const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs, 272 ConstantArray* rhs); 273 274 StatusOr<Array*> ComputeArrayForDotWithIndexedRhs( 275 const Shape& shape, const DotDimensionNumbers& dim_numbers, 276 const PrecisionConfig& precision_config, ConstantArray* lhs, 277 ScalarIndexedConstantArray* rhs); 278 279 StatusOr<Array*> ComputeArrayForDot(const Shape& shape, 280 const DotDimensionNumbers& dim_numbers, 281 const PrecisionConfig& precision_config, 282 Array* lhs, Array* rhs); 283 284 // This tries to fold a ScalarIndexedArray which has another 285 // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a 286 // ScalarIndexedArray as indices. If `source` happened to be a 287 // ScalarIndexedConstantArray this can result in an expression that is more 288 // canonical. 289 // 290 // As an example, consider a gather operation, G0, gathering 7 elements from 291 // an array "Arr" of shape [100] resulting in an array of shape [7], and a 292 // second gather operation, G1, which gathers 3 elements out of the result of 293 // G0 resulting in an array of shape [3]. Let the indices uses by G0 be I0 294 // (of shape [7]) and the indices used by G1 be I1 (of shape [3]). We can 295 // instead rewrite G1 to gather directly from "Arr" with the three indices 296 // from I0 as per I1. 
In other words, we can rewrite: 297 // 298 // G0 = [Arr[i] for i in I0] 299 // G1 = [G0[i] for i in I1] 300 // 301 // into 302 // 303 // I2 = [I0[i] for i in I1] 304 // G1 = [Arr[i] for i in I2] 305 StatusOr<ScalarIndexedArray*> FoldGatherOfGather( 306 ScalarIndexedArray* source, Array* indices, int64_t source_dim, 307 absl::Span<const int64_t> output_dims, Shape shape); 308 309 // Reshapes a scalar-indexed node to remove the degenerate dimensions in its 310 // output. The result is always a scalar-indexed node. 311 StatusOr<ScalarIndexedArray*> ReshapeToRemoveDegenerateDims( 312 ScalarIndexedArray* operand); 313 314 // Reshapes a scalar-indexed node such that the result has the degenerate 315 // dimensions `degenerate_dims`. The result is always a scalar-indexed node. 316 StatusOr<ScalarIndexedArray*> ReshapeToAddDegenerateDims( 317 ScalarIndexedArray* operand, absl::Span<const int64_t> degenerate_dims); 318 319 StatusOr<ScalarIndexedArray*> FoldReshapeOfGather( 320 const Shape& shape, ScalarIndexedConstantArray* operand); 321 StatusOr<ScalarIndexedArray*> FoldReshapeOfGatherNoDegenerateDims( 322 const Shape& shape, ScalarIndexedConstantArray* scalar_indexed); 323 StatusOr<Array*> ComputeArrayForReshape(const Shape& shape, Array* operand); 324 325 StatusOr<Array*> ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, 326 Array* lhs, Array* rhs); 327 StatusOr<Array*> ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, 328 Array* operand); 329 330 template <typename T, typename... Args> Construct(Args &&...args)331 T* Construct(Args&&... 
args) { 332 T* new_tensor = new T(std::forward<Args>(args)...); 333 owned_tensors_.push_back(std::unique_ptr<T>(new_tensor)); 334 return new_tensor; 335 } 336 ConstructScalarIndexedArray(Array * source,Array * indices,int64_t source_dim,std::vector<int64_t> output_dims,Shape shape)337 ScalarIndexedArray* ConstructScalarIndexedArray( 338 Array* source, Array* indices, int64_t source_dim, 339 std::vector<int64_t> output_dims, Shape shape) { 340 if (source->kind() == Array::kConstant) { 341 return Construct<ScalarIndexedConstantArray>(source, indices, source_dim, 342 std::move(output_dims), 343 std::move(shape)); 344 } else { 345 return Construct<ScalarIndexedArray>(source, indices, source_dim, 346 std::move(output_dims), 347 std::move(shape)); 348 } 349 } 350 TakeOwnership(Literal literal)351 Literal* TakeOwnership(Literal literal) { 352 owned_literals_.push_back(std::move(literal)); 353 return &owned_literals_.back(); 354 } 355 TakeOwnership(StatusOr<Literal> literal_or_error)356 StatusOr<Literal*> TakeOwnership(StatusOr<Literal> literal_or_error) { 357 TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error)); 358 owned_literals_.push_back(std::move(literal)); 359 return &owned_literals_.back(); 360 } 361 362 std::vector<std::unique_ptr<Array>> owned_tensors_; 363 std::vector<Literal> owned_literals_; 364 absl::flat_hash_map<const HloInstruction*, Array*> cache_; 365 }; 366 367 // A pass that prints all non-trivial results returned by IndexedArrayAnalysis. 368 // This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to 369 // unconditionally add to the regular HLO pass pipeline. 
class IndexedArrayAnalysisPrinterPass : public HloModulePass {
 public:
  // Name under which this pass appears in the HLO pass pipeline.
  absl::string_view name() const override;
  // Bring the base class's convenience Run overloads into scope alongside the
  // override below.
  using HloPassInterface::Run;
  // Runs the printer over `module`.  NOTE(review): per the comment preceding
  // this class it only logs (no-op unless VLOG_IS_ON(2)), so presumably it
  // never mutates the module -- confirm against the .cc implementation.
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_