1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef XLA_RUNTIME_SYMBOLIC_SHAPE_H_ 17 #define XLA_RUNTIME_SYMBOLIC_SHAPE_H_ 18 19 #include "llvm/ADT/DenseSet.h" 20 #include "llvm/ADT/Hashing.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include "tensorflow/compiler/xla/runtime/arguments.h" 23 #include "tensorflow/compiler/xla/runtime/constraints.h" 24 25 namespace xla { 26 namespace runtime { 27 28 // Symbolic shapes resolver computes the symbolic shapes of the arguments based 29 // on the function signature, and concrete shapes of the arguments at runtime. 30 // 31 // Example: dimensions that have the same symbolic shape at runtime. 32 // 33 // signature: func @compute(%arg0: tensor<?xf32>, %arg1: tensor<?xf32) 34 // ^ ^ 35 // arguments: memref<123xf32> memref<123xf32> 36 // ^ ^ 37 // symbolic shapes: [-2xf32] [-2xf32] 38 // 39 // Each unknown dimension in the function signature will be assigned a symbolic 40 // dimension. If multiple shaped arguments have unknown dimensions that are the 41 // same at runtime, they will be assigned the same symbolic dimensions value 42 // (e.g. `-2` in the example above). 43 // 44 // If an unknown dimension at runtime is equal to some statically known 45 // dimension in the function signature (of any shaped argument), it will be 46 // resolved to that statically known constant value: 47 // 48 // Example: in this example unknown dimension of `arg0` replaced with a `32`. 49 // 50 // signature: func @compute(%arg0: tensor<?xf32>, %arg1: tensor<32xf32>) 51 // ^ ^ 52 // arguments: memref<32xf32> memref<32xf32> 53 // ^ ^ 54 // symbolic shapes: [32xf32] [32xf32] 55 // 56 // Unknown dimensions that are `1` at runtime are always materialized as a 57 // statically known `1` in the symbolic shape. 58 class SymbolicShapesResolver { 59 public: 60 // Dimension size can be symbolic (<= -2) or static. 61 using SymbolicShape = llvm::SmallVector<int64_t>; 62 // Dimension size can be dynamic (ShapedType::kDynamicSize) or static. 63 using StaticShape = llvm::SmallVector<int64_t>; 64 65 SymbolicShapesResolver(const FunctionType& signature, 66 llvm::ArrayRef<ArgumentConstraint> constraints); 67 68 // Resolves symbolic shapes from the runtime arguments. Returns failure if 69 // runtime dimensions do not match the statically known dimensions. 70 llvm::ErrorOr<llvm::SmallVector<SymbolicShape>> Resolve( 71 ArgumentsRef arguments) const; 72 73 // Resolves symbolic shapes and computes the hash value from the runtime 74 // arguments. Returns failure if runtime dimensions do not match the 75 // statically known dimensions. 76 // 77 // This function might not return the same hash value as calling `Resolve` and 78 // then `Hash`, because it might use more efficient hashing algorithm. 79 llvm::ErrorOr<llvm::hash_code> ResolveHash(ArgumentsRef arguments) const; 80 81 // Replaces all symbolic dimensions with dynamic dimension. 82 static llvm::SmallVector<int64_t> Normalize(const SymbolicShape& shape); 83 84 // Computes a hash value of the symbolic shapes. 85 static llvm::hash_code Hash(llvm::ArrayRef<SymbolicShape> symbolic_shapes); 86 87 ArgumentConstraint constraint(size_t index) const; 88 size_t num_arguments() const; 89 bool has_argument_sizes(size_t index) const; 90 const StaticShape& argument_sizes(size_t index) const; 91 bool seen_static_size(size_t dim) const; 92 93 private: 94 // Constraints on the function arguments. 95 llvm::SmallVector<ArgumentConstraint> constraints_; 96 97 // Statically known sizes of shaped arguments from the function signature. For 98 // non-shaped arguments (e.g. opaque pointers) we keep empty shape value. 99 llvm::SmallVector<llvm::Optional<StaticShape>> arguments_sizes_; 100 101 // Values of statically known dimensions sizes in the function signature. 102 llvm::DenseSet<int64_t> seen_static_sizes_; 103 104 // The iteration order for the arguments when resolving symbolic shapes. 105 llvm::SmallVector<size_t> iteration_order_; 106 107 // The iteration order for the arguments when resolving symbolic shapes hash. 108 llvm::SmallVector<size_t> hash_iteration_order_; 109 }; 110 111 } // namespace runtime 112 } // namespace xla 113 114 #endif // XLA_RUNTIME_SYMBOLIC_SHAPE_H_ 115