1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/runtime/symbolic_shape.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <iterator>
21 #include <numeric>
22 #include <system_error> // NOLINT TODO(ezhulenev): Migrate to absl::Status.
23
24 #include "llvm/ADT/Hashing.h"
25 #include "llvm/Support/Compiler.h"
26 #include "tensorflow/compiler/xla/runtime/arguments.h"
27 #include "tensorflow/compiler/xla/runtime/constraints.h"
28 #include "tensorflow/compiler/xla/runtime/types.h"
29
30 namespace xla {
31 namespace runtime {
32
33 using llvm::cast;
34 using llvm::dyn_cast;
35 using llvm::isa;
36
37 using llvm::ArrayRef;
38 using llvm::MutableArrayRef;
39
40 using SymbolicShape = SymbolicShapesResolver::SymbolicShape;
41 using StaticShape = SymbolicShapesResolver::StaticShape;
42
SymbolicShapesResolver(const FunctionType & signature,ArrayRef<ArgumentConstraint> constraints)43 SymbolicShapesResolver::SymbolicShapesResolver(
44 const FunctionType& signature, ArrayRef<ArgumentConstraint> constraints)
45 : constraints_(constraints.begin(), constraints.end()) {
46 for (unsigned i = 0; i < signature.num_operands(); ++i) {
47 auto* type = signature.operand(i);
48
49 // For unranked arguments we do not know any static shape information.
50 if (isa<UnrankedTensorType, UnrankedMemrefType>(type)) {
51 arguments_sizes_.emplace_back();
52 continue;
53 }
54
55 auto emplace_sizes = [&](ArrayRef<int64_t> sizes) {
56 arguments_sizes_.emplace_back(llvm::to_vector(sizes));
57
58 // Keep track of all statically known dimension sizes.
59 for (int64_t size : sizes) {
60 if (size != MemrefType::kDynamicSize) seen_static_sizes_.insert(size);
61 }
62 };
63
64 // Copy memref dimensions sizes from the signature type.
65 if (auto* memref = dyn_cast<MemrefType>(type)) {
66 emplace_sizes(memref->sizes());
67 continue;
68 }
69
70 // Copy tensor dimensions sizes from the signature type.
71 if (auto* tensor = dyn_cast<RankedTensorType>(type)) {
72 emplace_sizes(tensor->sizes());
73 continue;
74 }
75
76 // TODO(ezhulenev): Add support for `ShapedType` to allow users to enable
77 // symbolic shape resolution for user-defined types.
78
79 // All non-shaped types have statically known empty shape.
80 emplace_sizes({});
81 }
82
83 // When resolving symbolic shapes we should visit arguments starting from the
84 // more constrained ones, because they can change the static signature of the
85 // function, and this information should be propagated to arguments with
86 // dynamic shapes (e.g. all seen static sizes should be materialized in the
87 // function signature).
88 iteration_order_.resize(signature.num_operands());
89 std::iota(iteration_order_.begin(), iteration_order_.end(), 0);
90
91 // Make the sort stable so that dynamic shapes are computed deterministically.
92 llvm::sort(iteration_order_, [&](size_t a, size_t b) {
93 unsigned ca = static_cast<unsigned>(constraints[a]);
94 unsigned cb = static_cast<unsigned>(constraints[b]);
95 if (ca > cb) return true;
96 return ca < cb ? false : a < b;
97 });
98
99 // We can safely skip arguments with a known empty symbolic shape, because
100 // that's the default value we return when resolving symbolic shapes for
101 // the arguments, and such shapes do not participate in the hash computation.
102 llvm::erase_if(iteration_order_, [&](size_t i) {
103 return arguments_sizes_[i].has_value() && arguments_sizes_[i]->empty();
104 });
105
106 // When computing a symbolic shapes hash we don't need to visit arguments with
107 // a statically known shape.
108 auto is_dynamic_shape_argument = [&](size_t idx) {
109 return !arguments_sizes_[idx].has_value() ||
110 llvm::any_of(*arguments_sizes_[idx],
111 [](int64_t d) { return d < 0; });
112 };
113 llvm::copy_if(iteration_order_, std::back_inserter(hash_iteration_order_),
114 is_dynamic_shape_argument);
115 }
116
// Returns the constraint attached to the argument at `index`.
ArgumentConstraint SymbolicShapesResolver::constraint(size_t index) const {
  return constraints_[index];
}
120
// Returns the number of arguments in the function signature this resolver
// was constructed from.
size_t SymbolicShapesResolver::num_arguments() const {
  return arguments_sizes_.size();
}
124
has_argument_sizes(size_t index) const125 bool SymbolicShapesResolver::has_argument_sizes(size_t index) const {
126 return arguments_sizes_[index].has_value();
127 }
128
// Returns the statically known dimension sizes of the argument at `index`.
// Precondition: has_argument_sizes(index) is true.
const StaticShape& SymbolicShapesResolver::argument_sizes(size_t index) const {
  return *arguments_sizes_[index];
}
132
seen_static_size(size_t dim) const133 bool SymbolicShapesResolver::seen_static_size(size_t dim) const {
134 return seen_static_sizes_.contains(dim);
135 }
136
// Resolves symbolic shapes for `arguments`, visiting them in the given
// `iteration_order` and writing the result into `symbolic_shapes` (indexed by
// argument position). `SymbolicShapes` is either a vector of per-argument
// shapes (see `Resolve`) or a fingerprint accumulator (see `ResolveHash`).
// Returns `invalid_argument` if a runtime shape disagrees with the statically
// known signature; returns a default (success) error code otherwise.
template <typename SymbolicShapes>
LLVM_ATTRIBUTE_ALWAYS_INLINE static std::error_code ResolveImpl(
    const SymbolicShapesResolver& resolver, ArgumentsRef arguments,
    ArrayRef<size_t> iteration_order, SymbolicShapes& symbolic_shapes) {
  // The number of arguments must match the function signature.
  assert(arguments.size() == resolver.num_arguments());

  // Mapping from the runtime dimension size to the symbolic dimension.
  llvm::SmallDenseMap<int64_t, int64_t, 16> size_to_symbolic_dim;

  // Symbolic dimension ids start at -2; -1 is reserved for "fully dynamic"
  // (MemrefType::kDynamicSize) and non-negative values are static sizes.
  int64_t sym_dim = -2;  // the next symbolic dimension id

  for (size_t i : iteration_order) {
    bool has_static_sizes = resolver.has_argument_sizes(i);

    // TODO(ezhulenev): Add support for `ShapedArgument` to allow users to
    // enable symbolic shape resolution for user-defined arguments.
    //
    // At this point it's guaranteed that the argument at `i` is a shaped one,
    // because non-shaped argument are not in the `iteration_order`.
    const MemrefDesc* shaped = cast<MemrefDesc>(&arguments[i]);
    ArrayRef<int64_t> runtime_sizes = shaped->sizes();

    // Check that statically known rank matches the runtime rank.
    if (LLVM_UNLIKELY(has_static_sizes && resolver.argument_sizes(i).size() !=
                                              runtime_sizes.size()))
      return llvm::errc::invalid_argument;

    // For shape constrained argument use runtime shape.
    if (resolver.constraint(i) == ArgumentConstraint::kShape) {
      symbolic_shapes[i].assign(runtime_sizes.begin(), runtime_sizes.end());

      // Add all runtime dimensions to the `size_to_symbolic_dim` to materialize
      // all dynamic dimensions of the same size as static dimensions.
      for (int64_t d : runtime_sizes) size_to_symbolic_dim.try_emplace(d, d);

      continue;
    }

    // Initialize symbolic shape with a statically known shape of the argument
    // if it is available, otherwise initialize it with a fully dynamic shape
    // with rank matching the runtime rank.
    if (has_static_sizes) {
      ArrayRef<int64_t> static_sizes = resolver.argument_sizes(i);
      assert(runtime_sizes.size() == static_sizes.size());
      symbolic_shapes[i].assign(static_sizes.begin(), static_sizes.end());
    } else {
      size_t rank = runtime_sizes.size();
      symbolic_shapes[i].resize(rank, MemrefType::kDynamicSize);
    }

    MutableArrayRef<int64_t> symbolic_sizes = symbolic_shapes[i];

    for (unsigned d = 0; d < runtime_sizes.size(); ++d) {
      int64_t symbolic_dim = symbolic_sizes[d];
      int64_t runtime_dim = runtime_sizes[d];

      // Skip statically known dimensions.
      if (symbolic_dim >= 0) {
        // Check that statically known dimension agrees with runtime dimension.
        if (LLVM_UNLIKELY(symbolic_dim != runtime_dim))
          return llvm::errc::invalid_argument;
        continue;
      }

      // Update unknown dimension to a static dimension: dimensions of size 1
      // and sizes already seen in the static signature are materialized as-is.
      if (runtime_dim == 1 || resolver.seen_static_size(runtime_dim)) {
        symbolic_sizes[d] = runtime_dim;
        continue;
      }

      // Try to assign a symbolic dimension to the runtime dimension; equal
      // runtime sizes map to the same symbolic dimension id.
      auto emplaced = size_to_symbolic_dim.try_emplace(runtime_dim, sym_dim);
      symbolic_sizes[d] = emplaced.first->second;

      // Update the symbolic dimension if we assigned the previous value to the
      // runtime dimension size.
      if (emplaced.second) --sym_dim;
    }
  }

  return {};
}
220
Resolve(ArgumentsRef arguments) const221 llvm::ErrorOr<llvm::SmallVector<SymbolicShape>> SymbolicShapesResolver::Resolve(
222 ArgumentsRef arguments) const {
223 // Prepare storage for resolving symbolic shapes.
224 llvm::SmallVector<SymbolicShape> symbolic_shapes;
225 symbolic_shapes.resize(arguments.size());
226
227 if (LLVM_UNLIKELY(
228 ResolveImpl(*this, arguments, iteration_order_, symbolic_shapes)))
229 return llvm::errc::invalid_argument;
230
231 return symbolic_shapes;
232 }
233
namespace {
// A struct to accumulate all resolved symbolic dimensions in a single vector.
// Resolved symbolic dimensions stored according to the iteration order, and not
// the argument order, however for computing the hash value it doesn't matter.
//
// It is a drop-in replacement for a vector of symbolic shapes in ResolveImpl:
// `operator[]` returns the accumulator itself, and `assign`/`resize`/the
// implicit MutableArrayRef conversion mimic the per-shape container API while
// appending into one flat `values` vector.
struct SymbolicShapesFingerprint {
  SymbolicShapesFingerprint() : offset(0) {}

  // Make sure that we do not copy the fingerprint.
  SymbolicShapesFingerprint(const SymbolicShapesFingerprint&) = delete;

  // All per-argument writes land in this same accumulator.
  SymbolicShapesFingerprint& operator[](size_t i) { return *this; }

  // Appends the dimensions in [first, last) and marks them as the "current"
  // shape (subsequent MutableArrayRef conversions cover only this tail).
  template <typename InputIt>
  LLVM_ATTRIBUTE_ALWAYS_INLINE void assign(InputIt first, InputIt last) {
    auto rank = std::distance(first, last);
    offset = values.size();
    values.resize_for_overwrite(offset + rank);
    llvm::copy(llvm::make_range(first, last), values.begin() + offset);
  }

  // Records the rank followed by `rank` copies of `dim`, marking the copies
  // as the "current" shape.
  LLVM_ATTRIBUTE_ALWAYS_INLINE void resize(int64_t rank, int64_t dim) {
    values.push_back(rank);
    offset = values.size();
    values.resize(offset + rank, dim);
  }

  // View of the most recently started shape, for in-place dimension updates.
  operator MutableArrayRef<int64_t>() {  // NOLINT
    return {values.begin() + offset, values.end()};
  }

  // Start of the most recently started shape within `values`.
  size_t offset;
  // Flat accumulator of all resolved dimensions (and ranks from `resize`).
  llvm::SmallVector<int64_t, 32> values;
};
}  // namespace
268
ResolveHash(ArgumentsRef arguments) const269 llvm::ErrorOr<llvm::hash_code> SymbolicShapesResolver::ResolveHash(
270 ArgumentsRef arguments) const {
271 // Accumulate symbolic shapes into the shapes fingerprint.
272 SymbolicShapesFingerprint fingerprint;
273
274 if (LLVM_UNLIKELY(
275 ResolveImpl(*this, arguments, hash_iteration_order_, fingerprint)))
276 return llvm::errc::invalid_argument;
277
278 return llvm::hash_combine_range(fingerprint.values.begin(),
279 fingerprint.values.end());
280 }
281
Normalize(const SymbolicShape & shape)282 /*static*/ llvm::SmallVector<int64_t> SymbolicShapesResolver::Normalize(
283 const SymbolicShape& shape) {
284 auto normalize = llvm::map_range(shape, [](int64_t dim) {
285 return std::max(dim, MemrefType::kDynamicSize);
286 });
287 return {normalize.begin(), normalize.end()};
288 }
289
SymbolicShapeHash(const SymbolicShape & shape)290 static llvm::hash_code SymbolicShapeHash(const SymbolicShape& shape) {
291 return llvm::hash_combine(
292 shape.size(), llvm::hash_combine_range(shape.begin(), shape.end()));
293 }
294
Hash(ArrayRef<SymbolicShape> symbolic_shapes)295 /*static*/ llvm::hash_code SymbolicShapesResolver::Hash(
296 ArrayRef<SymbolicShape> symbolic_shapes) {
297 if (LLVM_UNLIKELY(symbolic_shapes.empty())) return llvm::hash_code(0);
298
299 llvm::hash_code hash = SymbolicShapeHash(symbolic_shapes[0]);
300 for (unsigned i = 1; i < symbolic_shapes.size(); ++i)
301 hash = llvm::hash_combine(hash, SymbolicShapeHash(symbolic_shapes[i]));
302
303 return hash;
304 }
305
306 } // namespace runtime
307 } // namespace xla
308