xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/runtime/symbolic_shape.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_RUNTIME_SYMBOLIC_SHAPE_H_
17 #define XLA_RUNTIME_SYMBOLIC_SHAPE_H_
18 
19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/ADT/Hashing.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "tensorflow/compiler/xla/runtime/arguments.h"
23 #include "tensorflow/compiler/xla/runtime/constraints.h"
24 
25 namespace xla {
26 namespace runtime {
27 
28 // Symbolic shapes resolver computes the symbolic shapes of the arguments based
29 // on the function signature, and concrete shapes of the arguments at runtime.
30 //
31 // Example: dimensions that have the same symbolic shape at runtime.
32 //
33 //   signature: func @compute(%arg0: tensor<?xf32>, %arg1: tensor<?xf32)
34 //                            ^                     ^
35 //   arguments:               memref<123xf32>       memref<123xf32>
36 //                            ^                     ^
37 //   symbolic shapes:         [-2xf32]              [-2xf32]
38 //
39 // Each unknown dimension in the function signature will be assigned a symbolic
40 // dimension. If multiple shaped arguments have unknown dimensions that are the
41 // same at runtime, they will be assigned the same symbolic dimensions value
42 // (e.g. `-2` in the example above).
43 //
44 // If an unknown dimension at runtime is equal to some statically known
45 // dimension in the function signature (of any shaped argument), it will be
46 // resolved to that statically known constant value:
47 //
48 // Example: in this example unknown dimension of `arg0` replaced with a `32`.
49 //
50 //  signature:  func @compute(%arg0: tensor<?xf32>, %arg1: tensor<32xf32>)
51 //                            ^                     ^
52 //  arguments:                memref<32xf32>        memref<32xf32>
53 //                            ^                     ^
54 //  symbolic shapes:          [32xf32]              [32xf32]
55 //
56 // Unknown dimensions that are `1` at runtime are always materialized as a
57 // statically known `1` in the symbolic shape.
58 class SymbolicShapesResolver {
59  public:
60   // Dimension size can be symbolic (<= -2) or static.
61   using SymbolicShape = llvm::SmallVector<int64_t>;
62   // Dimension size can be dynamic (ShapedType::kDynamicSize) or static.
63   using StaticShape = llvm::SmallVector<int64_t>;
64 
65   SymbolicShapesResolver(const FunctionType& signature,
66                          llvm::ArrayRef<ArgumentConstraint> constraints);
67 
68   // Resolves symbolic shapes from the runtime arguments. Returns failure if
69   // runtime dimensions do not match the statically known dimensions.
70   llvm::ErrorOr<llvm::SmallVector<SymbolicShape>> Resolve(
71       ArgumentsRef arguments) const;
72 
73   // Resolves symbolic shapes and computes the hash value from the runtime
74   // arguments. Returns failure if runtime dimensions do not match the
75   // statically known dimensions.
76   //
77   // This function might not return the same hash value as calling `Resolve` and
78   // then `Hash`, because it might use more efficient hashing algorithm.
79   llvm::ErrorOr<llvm::hash_code> ResolveHash(ArgumentsRef arguments) const;
80 
81   // Replaces all symbolic dimensions with dynamic dimension.
82   static llvm::SmallVector<int64_t> Normalize(const SymbolicShape& shape);
83 
84   // Computes a hash value of the symbolic shapes.
85   static llvm::hash_code Hash(llvm::ArrayRef<SymbolicShape> symbolic_shapes);
86 
87   ArgumentConstraint constraint(size_t index) const;
88   size_t num_arguments() const;
89   bool has_argument_sizes(size_t index) const;
90   const StaticShape& argument_sizes(size_t index) const;
91   bool seen_static_size(size_t dim) const;
92 
93  private:
94   // Constraints on the function arguments.
95   llvm::SmallVector<ArgumentConstraint> constraints_;
96 
97   // Statically known sizes of shaped arguments from the function signature. For
98   // non-shaped arguments (e.g. opaque pointers) we keep empty shape value.
99   llvm::SmallVector<llvm::Optional<StaticShape>> arguments_sizes_;
100 
101   // Values of statically known dimensions sizes in the function signature.
102   llvm::DenseSet<int64_t> seen_static_sizes_;
103 
104   // The iteration order for the arguments when resolving symbolic shapes.
105   llvm::SmallVector<size_t> iteration_order_;
106 
107   // The iteration order for the arguments when resolving symbolic shapes hash.
108   llvm::SmallVector<size_t> hash_iteration_order_;
109 };
110 
111 }  // namespace runtime
112 }  // namespace xla
113 
114 #endif  // XLA_RUNTIME_SYMBOLIC_SHAPE_H_
115