1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef XLA_RUNTIME_ARGUMENTS_H_
17 #define XLA_RUNTIME_ARGUMENTS_H_
18
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <new>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Error.h"
#include "tensorflow/compiler/xla/runtime/types.h"
#include "tfrt/dtype/dtype.h"  // from @tf_runtime
26
27 namespace xla {
28 namespace runtime {
29
30 //===----------------------------------------------------------------------===//
31 // A base class for XLA executable arguments.
32 //===----------------------------------------------------------------------===//
33
34 class Argument : public llvm::RTTIExtends<Type, llvm::RTTIRoot> {
35 public:
36 static constexpr char ID = 0; // NOLINT
37
38 Argument() = default;
39
40 // Verifies that the argument matches the expected type.
41 virtual llvm::Error Verify(const Type& type) const = 0;
42
43 // Packs argument into the `args` array starting at the given `offset`
44 // according to the expected executable ABI. Return offset incremented by
45 // the number of packed pointers, so that result will point to the offset for
46 // packing the next argument.
47 //
48 // Arguments array is guaranteed to be properly sized to have space for all
49 // arguments according to the arguments memory layout.
50 virtual size_t Pack(llvm::MutableArrayRef<void*> args,
51 size_t offset) const = 0;
52
53 virtual llvm::raw_ostream& print(llvm::raw_ostream& os) const = 0;
54 };
55
56 inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os,
57 const Argument& arg) {
58 return arg.print(os);
59 }
60
61 //===----------------------------------------------------------------------===//
62 // Owning container for storing arguments of different types.
63 //===----------------------------------------------------------------------===//
64
65 // Forward declare class defined below.
66 class ArgumentsRef;
67
68 // An owning container for the variadic arguments, optimized for storing all
69 // arguments of the declared types without dynamic memory allocations.
70 //
71 // Example:
72 //
73 // Arguments<OpaqueArg, MemrefDesc> arguments;
74 // arguments.emplace_back<OpaqueArg>(...);
75 //
76 // Variadic type parameter `Ts` specifies arguments of what types can be added
77 // to the container.
78 template <typename... Ts>
79 class Arguments {
80 public:
Arguments(size_t num_args)81 explicit Arguments(size_t num_args) : num_args_(num_args) {
82 storage_.reserve(num_args);
83 }
84
~Arguments()85 ~Arguments() {
86 for (size_t i = 0; i < storage_.size(); ++i) (*this)[i].~Argument();
87 }
88
89 template <typename T>
push_back(T value)90 T& push_back(T value) {
91 static_assert(std::disjunction_v<std::is_same<T, Ts>...>,
92 "type is not supported by this instance of arguments");
93 assert(storage_.size() < num_args_ && "arguments overflow");
94 storage_.resize_for_overwrite(storage_.size() + 1);
95 return *(new (&storage_.back()) T(std::forward<T>(value)));
96 }
97
98 template <typename T = std::tuple_element_t<0, std::tuple<Ts...>>,
99 typename... Args>
emplace_back(Args...args)100 T& emplace_back(Args... args) {
101 static_assert(std::disjunction_v<std::is_same<T, Ts>...>,
102 "type is not supported by this instance of arguments");
103 assert(storage_.size() < num_args_ && "arguments overflow");
104 storage_.resize_for_overwrite(storage_.size() + 1);
105 return *(new (&storage_.back()) T(std::forward<Args>(args)...));
106 }
107
108 const Argument& operator[](size_t index) const {
109 return *reinterpret_cast<const Argument*>(storage_[index].data);
110 }
111
size()112 size_t size() const { return storage_.size(); }
113
114 private:
115 friend class ArgumentsRef;
116
117 static_assert(std::conjunction_v<std::is_base_of<Argument, Ts>...>,
118 "all types must be arguments");
119
120 // Arguments are not movable or copyable because we do manual memory
121 // management using the `Storage` struct, and moving or copying bytes storing
122 // the argument value is undefined behavior.
123 Arguments(const Arguments&) = delete;
124 Arguments& operator=(const Arguments&) = delete;
125 Arguments(Arguments&&) = delete;
126 Arguments& operator=(Arguments&&) = delete;
127
128 // Avoid dynamic memory allocation for storing arguments of different types
129 // by storing them in the properly aligned byte array.
130 struct Storage {
131 alignas(Ts...) std::byte data[std::max({sizeof(Ts)...})];
132 };
133
134 // To guarantee safe conversion between pointer to `Storage` and pointer to
135 // the first byte (Argument), the storage struct must have standard layout.
136 static_assert(std::is_standard_layout_v<Storage>,
137 "storage must have standard layout");
138
139 size_t num_args_;
140 llvm::SmallVector<Storage> storage_;
141 };
142
143 // A constant reference to an array of arguments, somewhat similar to the
144 // `ArrayRef<Argument>`, however because `ArrayRef` of a virtual base is not
145 // possible, we have our own type that is constructible from the `Arguments`
146 // and array reference or vector of any argument subtype.
147 class ArgumentsRef {
148 template <typename T>
149 static constexpr bool is_argument = std::is_base_of_v<Argument, T>;
150
151 public:
152 template <typename... Ts>
ArgumentsRef(const Arguments<Ts...> & args)153 ArgumentsRef(const Arguments<Ts...>& args) // NOLINT
154 : data_(reinterpret_cast<const Argument*>(args.storage_.data())),
155 size_(args.size()),
156 stride_(sizeof(typename Arguments<Ts...>::Storage)) {}
157
158 template <typename T, std::enable_if_t<is_argument<T>>* = nullptr>
ArgumentsRef(llvm::ArrayRef<T> ref)159 ArgumentsRef(llvm::ArrayRef<T> ref) // NOLINT
160 : data_(ref.data()), size_(ref.size()), stride_(sizeof(T)) {}
161
162 template <typename T, std::enable_if_t<is_argument<T>>* = nullptr>
ArgumentsRef(const llvm::SmallVectorImpl<T> & vec)163 ArgumentsRef(const llvm::SmallVectorImpl<T>& vec) // NOLINT
164 : ArgumentsRef(llvm::ArrayRef<T>(vec)) {}
165
166 template <typename T, std::enable_if_t<is_argument<T>>* = nullptr>
ArgumentsRef(const std::vector<T> & vec)167 ArgumentsRef(const std::vector<T>& vec) // NOLINT
168 : ArgumentsRef(llvm::ArrayRef<T>(vec)) {}
169
170 template <typename T, size_t n, std::enable_if_t<is_argument<T>>* = nullptr>
ArgumentsRef(const std::array<T,n> & arr)171 ArgumentsRef(const std::array<T, n>& arr) // NOLINT
172 : ArgumentsRef(llvm::ArrayRef<T>(arr)) {}
173
174 const Argument& operator[](size_t index) const {
175 assert(index < size_ && "index out of bounds");
176 auto* ptr = reinterpret_cast<const std::byte*>(data_) + index * stride_;
177 return *reinterpret_cast<const Argument*>(ptr);
178 }
179
size()180 size_t size() const { return size_; }
181
182 private:
183 // Arguments stored in the contiguous memory starting at `data_` pointer,
184 // with the given `stride_` in bytes.
185 const Argument* data_;
186 size_t size_;
187 size_t stride_;
188 };
189
190 //===----------------------------------------------------------------------===//
191 // Canonical types for passing compiled kernel arguments.
192 //===----------------------------------------------------------------------===//
193
194 // By default we provide a set of types for passing common arguments to the
195 // compiled kernel. The type hierarchy is open, and users can extend it by
// defining new `Type` and `Argument` with the corresponding MLIR types and
197 // MLIR passes to lower types and operations to the LLVM dialect.
198
199 //===----------------------------------------------------------------------===//
200 // OpaqueArg for passing `!llvm.ptr` (opaque pointer) arguments.
201 //===----------------------------------------------------------------------===//
202
// Passes an opaque (`!llvm.ptr`) pointer value to the executable. The wrapped
// pointer is non-owning; the caller is responsible for its lifetime.
class OpaqueArg final : public llvm::RTTIExtends<OpaqueArg, Argument> {
 public:
  static constexpr char ID = 0;  // NOLINT

  explicit OpaqueArg(void* ptr) : ptr_(ptr) {}

  // Returns the wrapped pointer.
  void* ptr() const { return ptr_; }

  // `Argument` interface; declared here, implemented out of line.
  llvm::Error Verify(const Type& type) const final;
  size_t Pack(llvm::MutableArrayRef<void*> args, size_t offset) const final;
  llvm::raw_ostream& print(llvm::raw_ostream& os) const final;

 private:
  void* ptr_;  // non-owning
};
218
219 //===----------------------------------------------------------------------===//
220 // MemrefDesc for passing `memref` arguments.
221 //===----------------------------------------------------------------------===//
222
223 class MemrefDesc final : public llvm::RTTIExtends<MemrefDesc, Argument> {
224 public:
225 static constexpr char ID = 0; // NOLINT
226
MemrefDesc(tfrt::DType dtype,void * data,int64_t offset,llvm::ArrayRef<int64_t> sizes,llvm::ArrayRef<int64_t> strides)227 MemrefDesc(tfrt::DType dtype, void* data, int64_t offset,
228 llvm::ArrayRef<int64_t> sizes, llvm::ArrayRef<int64_t> strides)
229 : rank_(sizes.size()), dtype_(dtype), data_(data), offset_(offset) {
230 assert(sizes.size() == strides.size() && "invalid sizes and strides pair");
231 sizes_and_strides_.reserve(2 * rank_);
232 sizes_and_strides_.append(sizes.begin(), sizes.end());
233 sizes_and_strides_.append(strides.begin(), strides.end());
234 }
235
236 // Constructs MemrefDesc of the given rank and calls user-provided callback to
237 // initialize sizes and strides.
238 //
239 // Expected `InitializeSizesAndStrides` callback signature:
240 //
241 // void operator()(MutableArrayRef<int64_t> sizes,
242 // MutableArrayRef<int64_t> strides);
243 //
244 // We pass the init callback as a template argument to be able to
245 // inline it at the call site, because MemrefDesc construction is on a hot
246 // path.
247 template <typename InitializeSizesAndStrides>
248 MemrefDesc(unsigned rank, tfrt::DType dtype, void* data, int64_t offset,
249 InitializeSizesAndStrides initialize);
250
251 // Ensure that MemrefDesc is always moved around instead of copying.
252 MemrefDesc(const MemrefDesc&) = delete;
253 MemrefDesc& operator=(const MemrefDesc&) = delete;
254 MemrefDesc(MemrefDesc&&) = default;
255 MemrefDesc& operator=(MemrefDesc&&) = default;
256
rank()257 unsigned rank() const { return rank_; }
dtype()258 tfrt::DType dtype() const { return dtype_; }
259
data()260 void* data() const { return data_; }
offset()261 int64_t offset() const { return offset_; }
262
size(size_t index)263 int64_t size(size_t index) const { return sizes_and_strides_[index]; }
stride(size_t index)264 int64_t stride(size_t index) const {
265 return sizes_and_strides_[rank_ + index];
266 }
267
sizes()268 llvm::ArrayRef<int64_t> sizes() const {
269 return {sizes_and_strides_.data(), rank_};
270 }
strides()271 llvm::ArrayRef<int64_t> strides() const {
272 return {sizes_and_strides_.data() + rank_, rank_};
273 }
274
275 llvm::Error Verify(const Type& type) const final;
276 size_t Pack(llvm::MutableArrayRef<void*> args, size_t offset) const final;
277 llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
278
279 private:
280 unsigned rank_;
281 tfrt::DType dtype_;
282 void* data_;
283 int64_t offset_;
284 // We keep sizes and strides in a single container to save one potential
285 // memory allocation for memrefs of higher ranks, and to save one vector
286 // constructor/destructor call.
287 llvm::SmallVector<int64_t, 8> sizes_and_strides_;
288 };
289
290 template <typename InitializeSizesAndStrides>
MemrefDesc(unsigned rank,tfrt::DType dtype,void * data,int64_t offset,InitializeSizesAndStrides initialize)291 MemrefDesc::MemrefDesc(unsigned rank, tfrt::DType dtype, void* data,
292 int64_t offset, InitializeSizesAndStrides initialize)
293 : rank_(rank), dtype_(dtype), data_(data), offset_(offset) {
294 sizes_and_strides_.resize(2 * rank_);
295 llvm::MutableArrayRef<int64_t> ref = sizes_and_strides_;
296 initialize(ref.drop_back(rank_), ref.drop_front(rank_));
297 }
298
299 //===----------------------------------------------------------------------===//
300 // Verify that argument type is compatible with the run-time memref argument.
301 //===----------------------------------------------------------------------===//
302
303 // Verifies that the type at the given `index` matches the run-time memref
// argument: type is a tensor or a memref with compatible element type, and all
305 // statically known dimensions match the run-time sizes. Returns user-friendly
306 // error message in case of an error.
307 llvm::Error VerifyMemrefArgument(unsigned index, const Type& type,
308 const MemrefDesc& arg);
309
310 //===----------------------------------------------------------------------===//
311 // BufferDesc for passing raw `buffer` (i.e. void ptr + size) arguments.
312 //===----------------------------------------------------------------------===//
313
// Passes a raw buffer (untyped pointer plus size in bytes) to the executable.
// The wrapped pointer is non-owning; the caller is responsible for keeping
// the buffer alive while the argument is in use.
class BufferDesc final : public llvm::RTTIExtends<BufferDesc, Argument> {
 public:
  static constexpr char ID = 0;  // NOLINT

  BufferDesc(void* data, size_t size) : data_(data), size_(size) {}

  // Pointer to the first byte of the buffer.
  void* data() const { return data_; }
  // Buffer size in bytes.
  size_t size() const { return size_; }

  // `Argument` interface; declared here, implemented out of line.
  llvm::Error Verify(const Type& type) const final;
  size_t Pack(llvm::MutableArrayRef<void*> args, size_t offset) const final;
  llvm::raw_ostream& print(llvm::raw_ostream& os) const final;

 private:
  void* data_;   // non-owning
  size_t size_;  // in bytes
};
331
332 } // namespace runtime
333 } // namespace xla
334
335 #endif // XLA_RUNTIME_ARGUMENTS_H_
336