xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/runtime/symbolic_shape.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/runtime/symbolic_shape.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <iterator>
21 #include <numeric>
22 #include <system_error>  // NOLINT TODO(ezhulenev): Migrate to absl::Status.
23 
24 #include "llvm/ADT/Hashing.h"
25 #include "llvm/Support/Compiler.h"
26 #include "tensorflow/compiler/xla/runtime/arguments.h"
27 #include "tensorflow/compiler/xla/runtime/constraints.h"
28 #include "tensorflow/compiler/xla/runtime/types.h"
29 
30 namespace xla {
31 namespace runtime {
32 
33 using llvm::cast;
34 using llvm::dyn_cast;
35 using llvm::isa;
36 
37 using llvm::ArrayRef;
38 using llvm::MutableArrayRef;
39 
40 using SymbolicShape = SymbolicShapesResolver::SymbolicShape;
41 using StaticShape = SymbolicShapesResolver::StaticShape;
42 
SymbolicShapesResolver(const FunctionType & signature,ArrayRef<ArgumentConstraint> constraints)43 SymbolicShapesResolver::SymbolicShapesResolver(
44     const FunctionType& signature, ArrayRef<ArgumentConstraint> constraints)
45     : constraints_(constraints.begin(), constraints.end()) {
46   for (unsigned i = 0; i < signature.num_operands(); ++i) {
47     auto* type = signature.operand(i);
48 
49     // For unranked arguments we do not know any static shape information.
50     if (isa<UnrankedTensorType, UnrankedMemrefType>(type)) {
51       arguments_sizes_.emplace_back();
52       continue;
53     }
54 
55     auto emplace_sizes = [&](ArrayRef<int64_t> sizes) {
56       arguments_sizes_.emplace_back(llvm::to_vector(sizes));
57 
58       // Keep track of all statically known dimension sizes.
59       for (int64_t size : sizes) {
60         if (size != MemrefType::kDynamicSize) seen_static_sizes_.insert(size);
61       }
62     };
63 
64     // Copy memref dimensions sizes from the signature type.
65     if (auto* memref = dyn_cast<MemrefType>(type)) {
66       emplace_sizes(memref->sizes());
67       continue;
68     }
69 
70     // Copy tensor dimensions sizes from the signature type.
71     if (auto* tensor = dyn_cast<RankedTensorType>(type)) {
72       emplace_sizes(tensor->sizes());
73       continue;
74     }
75 
76     // TODO(ezhulenev): Add support for `ShapedType` to allow users to enable
77     // symbolic shape resolution for user-defined types.
78 
79     // All non-shaped types have statically known empty shape.
80     emplace_sizes({});
81   }
82 
83   // When resolving symbolic shapes we should visit arguments starting from the
84   // more constrained ones, because they can change the static signature of the
85   // function, and this information should be propagated to arguments with
86   // dynamic shapes (e.g. all seen static sizes should be materialized in the
87   // function signature).
88   iteration_order_.resize(signature.num_operands());
89   std::iota(iteration_order_.begin(), iteration_order_.end(), 0);
90 
91   // Make the sort stable so that dynamic shapes are computed deterministically.
92   llvm::sort(iteration_order_, [&](size_t a, size_t b) {
93     unsigned ca = static_cast<unsigned>(constraints[a]);
94     unsigned cb = static_cast<unsigned>(constraints[b]);
95     if (ca > cb) return true;
96     return ca < cb ? false : a < b;
97   });
98 
99   // We can safely skip arguments with a known empty symbolic shape, because
100   // that's the default value we return when resolving symbolic shapes for
101   // the arguments, and such shapes do not participate in the hash computation.
102   llvm::erase_if(iteration_order_, [&](size_t i) {
103     return arguments_sizes_[i].has_value() && arguments_sizes_[i]->empty();
104   });
105 
106   // When computing a symbolic shapes hash we don't need to visit arguments with
107   // a statically known shape.
108   auto is_dynamic_shape_argument = [&](size_t idx) {
109     return !arguments_sizes_[idx].has_value() ||
110            llvm::any_of(*arguments_sizes_[idx],
111                         [](int64_t d) { return d < 0; });
112   };
113   llvm::copy_if(iteration_order_, std::back_inserter(hash_iteration_order_),
114                 is_dynamic_shape_argument);
115 }
116 
constraint(size_t index) const117 ArgumentConstraint SymbolicShapesResolver::constraint(size_t index) const {
118   return constraints_[index];
119 }
120 
num_arguments() const121 size_t SymbolicShapesResolver::num_arguments() const {
122   return arguments_sizes_.size();
123 }
124 
has_argument_sizes(size_t index) const125 bool SymbolicShapesResolver::has_argument_sizes(size_t index) const {
126   return arguments_sizes_[index].has_value();
127 }
128 
argument_sizes(size_t index) const129 const StaticShape& SymbolicShapesResolver::argument_sizes(size_t index) const {
130   return *arguments_sizes_[index];
131 }
132 
seen_static_size(size_t dim) const133 bool SymbolicShapesResolver::seen_static_size(size_t dim) const {
134   return seen_static_sizes_.contains(dim);
135 }
136 
137 template <typename SymbolicShapes>
ResolveImpl(const SymbolicShapesResolver & resolver,ArgumentsRef arguments,ArrayRef<size_t> iteration_order,SymbolicShapes & symbolic_shapes)138 LLVM_ATTRIBUTE_ALWAYS_INLINE static std::error_code ResolveImpl(
139     const SymbolicShapesResolver& resolver, ArgumentsRef arguments,
140     ArrayRef<size_t> iteration_order, SymbolicShapes& symbolic_shapes) {
141   // The number of arguments must match the function signature.
142   assert(arguments.size() == resolver.num_arguments());
143 
144   // Mapping from the runtime dimension size to the symbolic dimension.
145   llvm::SmallDenseMap<int64_t, int64_t, 16> size_to_symbolic_dim;
146 
147   int64_t sym_dim = -2;  // the next symbolic dimension id
148 
149   for (size_t i : iteration_order) {
150     bool has_static_sizes = resolver.has_argument_sizes(i);
151 
152     // TODO(ezhulenev): Add support for `ShapedArgument` to allow users to
153     // enable symbolic shape resolution for user-defined arguments.
154     //
155     // At this point it's guaranteed that the argument at `i` is a shaped one,
156     // because non-shaped argument are not in the `iteration_order`.
157     const MemrefDesc* shaped = cast<MemrefDesc>(&arguments[i]);
158     ArrayRef<int64_t> runtime_sizes = shaped->sizes();
159 
160     // Check that statically known rank matches the runtime rank.
161     if (LLVM_UNLIKELY(has_static_sizes && resolver.argument_sizes(i).size() !=
162                                               runtime_sizes.size()))
163       return llvm::errc::invalid_argument;
164 
165     // For shape constrained argument use runtime shape.
166     if (resolver.constraint(i) == ArgumentConstraint::kShape) {
167       symbolic_shapes[i].assign(runtime_sizes.begin(), runtime_sizes.end());
168 
169       // Add all runtime dimensions to the `size_to_symbolic_dim` to materialize
170       // all dynamic dimensions of the same size as static dimensions.
171       for (int64_t d : runtime_sizes) size_to_symbolic_dim.try_emplace(d, d);
172 
173       continue;
174     }
175 
176     // Initialize symbolic shape with a statically known shape of the argument
177     // if it is available, otherwise initialize it with a fully dynamic shape
178     // with rank matching the runtime rank.
179     if (has_static_sizes) {
180       ArrayRef<int64_t> static_sizes = resolver.argument_sizes(i);
181       assert(runtime_sizes.size() == static_sizes.size());
182       symbolic_shapes[i].assign(static_sizes.begin(), static_sizes.end());
183     } else {
184       size_t rank = runtime_sizes.size();
185       symbolic_shapes[i].resize(rank, MemrefType::kDynamicSize);
186     }
187 
188     MutableArrayRef<int64_t> symbolic_sizes = symbolic_shapes[i];
189 
190     for (unsigned d = 0; d < runtime_sizes.size(); ++d) {
191       int64_t symbolic_dim = symbolic_sizes[d];
192       int64_t runtime_dim = runtime_sizes[d];
193 
194       // Skip statically known dimensions.
195       if (symbolic_dim >= 0) {
196         // Check that statically known dimension agrees with runtime dimension.
197         if (LLVM_UNLIKELY(symbolic_dim != runtime_dim))
198           return llvm::errc::invalid_argument;
199         continue;
200       }
201 
202       // Update unknown dimension to a static dimension.
203       if (runtime_dim == 1 || resolver.seen_static_size(runtime_dim)) {
204         symbolic_sizes[d] = runtime_dim;
205         continue;
206       }
207 
208       // Try to assign a symbolic dimension to the runtime dimension.
209       auto emplaced = size_to_symbolic_dim.try_emplace(runtime_dim, sym_dim);
210       symbolic_sizes[d] = emplaced.first->second;
211 
212       // Update the symbolic dimension if we assigned the previous value to the
213       // runtime dimension size.
214       if (emplaced.second) --sym_dim;
215     }
216   }
217 
218   return {};
219 }
220 
Resolve(ArgumentsRef arguments) const221 llvm::ErrorOr<llvm::SmallVector<SymbolicShape>> SymbolicShapesResolver::Resolve(
222     ArgumentsRef arguments) const {
223   // Prepare storage for resolving symbolic shapes.
224   llvm::SmallVector<SymbolicShape> symbolic_shapes;
225   symbolic_shapes.resize(arguments.size());
226 
227   if (LLVM_UNLIKELY(
228           ResolveImpl(*this, arguments, iteration_order_, symbolic_shapes)))
229     return llvm::errc::invalid_argument;
230 
231   return symbolic_shapes;
232 }
233 
234 namespace {
235 // A struct to accumulate all resolved symbolic dimensions in a single vector.
236 // Resolved symbolic dimensions stored according to the iteration order, and not
237 // the argument order, however for computing the hash value it doesn't matter.
238 struct SymbolicShapesFingerprint {
SymbolicShapesFingerprintxla::runtime::__anonbef848620611::SymbolicShapesFingerprint239   SymbolicShapesFingerprint() : offset(0) {}
240 
241   // Make sure that we do not copy the fingerprint.
242   SymbolicShapesFingerprint(const SymbolicShapesFingerprint&) = delete;
243 
operator []xla::runtime::__anonbef848620611::SymbolicShapesFingerprint244   SymbolicShapesFingerprint& operator[](size_t i) { return *this; }
245 
246   template <typename InputIt>
assignxla::runtime::__anonbef848620611::SymbolicShapesFingerprint247   LLVM_ATTRIBUTE_ALWAYS_INLINE void assign(InputIt first, InputIt last) {
248     auto rank = std::distance(first, last);
249     offset = values.size();
250     values.resize_for_overwrite(offset + rank);
251     llvm::copy(llvm::make_range(first, last), values.begin() + offset);
252   }
253 
resizexla::runtime::__anonbef848620611::SymbolicShapesFingerprint254   LLVM_ATTRIBUTE_ALWAYS_INLINE void resize(int64_t rank, int64_t dim) {
255     values.push_back(rank);
256     offset = values.size();
257     values.resize(offset + rank, dim);
258   }
259 
operator MutableArrayRef<int64_t>xla::runtime::__anonbef848620611::SymbolicShapesFingerprint260   operator MutableArrayRef<int64_t>() {  // NOLINT
261     return {values.begin() + offset, values.end()};
262   }
263 
264   size_t offset;
265   llvm::SmallVector<int64_t, 32> values;
266 };
267 }  // namespace
268 
ResolveHash(ArgumentsRef arguments) const269 llvm::ErrorOr<llvm::hash_code> SymbolicShapesResolver::ResolveHash(
270     ArgumentsRef arguments) const {
271   // Accumulate symbolic shapes into the shapes fingerprint.
272   SymbolicShapesFingerprint fingerprint;
273 
274   if (LLVM_UNLIKELY(
275           ResolveImpl(*this, arguments, hash_iteration_order_, fingerprint)))
276     return llvm::errc::invalid_argument;
277 
278   return llvm::hash_combine_range(fingerprint.values.begin(),
279                                   fingerprint.values.end());
280 }
281 
Normalize(const SymbolicShape & shape)282 /*static*/ llvm::SmallVector<int64_t> SymbolicShapesResolver::Normalize(
283     const SymbolicShape& shape) {
284   auto normalize = llvm::map_range(shape, [](int64_t dim) {
285     return std::max(dim, MemrefType::kDynamicSize);
286   });
287   return {normalize.begin(), normalize.end()};
288 }
289 
SymbolicShapeHash(const SymbolicShape & shape)290 static llvm::hash_code SymbolicShapeHash(const SymbolicShape& shape) {
291   return llvm::hash_combine(
292       shape.size(), llvm::hash_combine_range(shape.begin(), shape.end()));
293 }
294 
Hash(ArrayRef<SymbolicShape> symbolic_shapes)295 /*static*/ llvm::hash_code SymbolicShapesResolver::Hash(
296     ArrayRef<SymbolicShape> symbolic_shapes) {
297   if (LLVM_UNLIKELY(symbolic_shapes.empty())) return llvm::hash_code(0);
298 
299   llvm::hash_code hash = SymbolicShapeHash(symbolic_shapes[0]);
300   for (unsigned i = 1; i < symbolic_shapes.size(); ++i)
301     hash = llvm::hash_combine(hash, SymbolicShapeHash(symbolic_shapes[i]));
302 
303   return hash;
304 }
305 
306 }  // namespace runtime
307 }  // namespace xla
308