xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/runtime/arguments.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_RUNTIME_ARGUMENTS_H_
17 #define XLA_RUNTIME_ARGUMENTS_H_
18 
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <new>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include "llvm/Support/raw_ostream.h"
#include "tensorflow/compiler/xla/runtime/types.h"
#include "tfrt/dtype/dtype.h"  // from @tf_runtime
26 
27 namespace xla {
28 namespace runtime {
29 
30 //===----------------------------------------------------------------------===//
31 // A base class for XLA executable arguments.
32 //===----------------------------------------------------------------------===//
33 
34 class Argument : public llvm::RTTIExtends<Type, llvm::RTTIRoot> {
35  public:
36   static constexpr char ID = 0;  // NOLINT
37 
38   Argument() = default;
39 
40   // Verifies that the argument matches the expected type.
41   virtual llvm::Error Verify(const Type& type) const = 0;
42 
43   // Packs argument into the `args` array starting at the given `offset`
44   // according to the expected executable ABI. Return offset incremented by
45   // the number of packed pointers, so that result will point to the offset for
46   // packing the next argument.
47   //
48   // Arguments array is guaranteed to be properly sized to have space for all
49   // arguments according to the arguments memory layout.
50   virtual size_t Pack(llvm::MutableArrayRef<void*> args,
51                       size_t offset) const = 0;
52 
53   virtual llvm::raw_ostream& print(llvm::raw_ostream& os) const = 0;
54 };
55 
56 inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os,
57                                      const Argument& arg) {
58   return arg.print(os);
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // Owning container for storing arguments of different types.
63 //===----------------------------------------------------------------------===//
64 
65 // Forward declare class defined below.
66 class ArgumentsRef;
67 
68 // An owning container for the variadic arguments, optimized for storing all
69 // arguments of the declared types without dynamic memory allocations.
70 //
71 // Example:
72 //
73 //   Arguments<OpaqueArg, MemrefDesc> arguments;
74 //   arguments.emplace_back<OpaqueArg>(...);
75 //
76 // Variadic type parameter `Ts` specifies arguments of what types can be added
77 // to the container.
78 template <typename... Ts>
79 class Arguments {
80  public:
Arguments(size_t num_args)81   explicit Arguments(size_t num_args) : num_args_(num_args) {
82     storage_.reserve(num_args);
83   }
84 
~Arguments()85   ~Arguments() {
86     for (size_t i = 0; i < storage_.size(); ++i) (*this)[i].~Argument();
87   }
88 
89   template <typename T>
push_back(T value)90   T& push_back(T value) {
91     static_assert(std::disjunction_v<std::is_same<T, Ts>...>,
92                   "type is not supported by this instance of arguments");
93     assert(storage_.size() < num_args_ && "arguments overflow");
94     storage_.resize_for_overwrite(storage_.size() + 1);
95     return *(new (&storage_.back()) T(std::forward<T>(value)));
96   }
97 
98   template <typename T = std::tuple_element_t<0, std::tuple<Ts...>>,
99             typename... Args>
emplace_back(Args...args)100   T& emplace_back(Args... args) {
101     static_assert(std::disjunction_v<std::is_same<T, Ts>...>,
102                   "type is not supported by this instance of arguments");
103     assert(storage_.size() < num_args_ && "arguments overflow");
104     storage_.resize_for_overwrite(storage_.size() + 1);
105     return *(new (&storage_.back()) T(std::forward<Args>(args)...));
106   }
107 
108   const Argument& operator[](size_t index) const {
109     return *reinterpret_cast<const Argument*>(storage_[index].data);
110   }
111 
size()112   size_t size() const { return storage_.size(); }
113 
114  private:
115   friend class ArgumentsRef;
116 
117   static_assert(std::conjunction_v<std::is_base_of<Argument, Ts>...>,
118                 "all types must be arguments");
119 
120   // Arguments are not movable or copyable because we do manual memory
121   // management using the `Storage` struct, and moving or copying bytes storing
122   // the argument value is undefined behavior.
123   Arguments(const Arguments&) = delete;
124   Arguments& operator=(const Arguments&) = delete;
125   Arguments(Arguments&&) = delete;
126   Arguments& operator=(Arguments&&) = delete;
127 
128   // Avoid dynamic memory allocation for storing arguments of different types
129   // by storing them in the properly aligned byte array.
130   struct Storage {
131     alignas(Ts...) std::byte data[std::max({sizeof(Ts)...})];
132   };
133 
134   // To guarantee safe conversion between pointer to `Storage` and pointer to
135   // the first byte (Argument), the storage struct must have standard layout.
136   static_assert(std::is_standard_layout_v<Storage>,
137                 "storage must have standard layout");
138 
139   size_t num_args_;
140   llvm::SmallVector<Storage> storage_;
141 };
142 
143 // A constant reference to an array of arguments, somewhat similar to the
144 // `ArrayRef<Argument>`, however because `ArrayRef` of a virtual base is not
145 // possible, we have our own type that is constructible from the `Arguments`
146 // and array reference or vector of any argument subtype.
147 class ArgumentsRef {
148   template <typename T>
149   static constexpr bool is_argument = std::is_base_of_v<Argument, T>;
150 
151  public:
152   template <typename... Ts>
ArgumentsRef(const Arguments<Ts...> & args)153   ArgumentsRef(const Arguments<Ts...>& args)  // NOLINT
154       : data_(reinterpret_cast<const Argument*>(args.storage_.data())),
155         size_(args.size()),
156         stride_(sizeof(typename Arguments<Ts...>::Storage)) {}
157 
158   template <typename T, std::enable_if_t<is_argument<T>>* = nullptr>
ArgumentsRef(llvm::ArrayRef<T> ref)159   ArgumentsRef(llvm::ArrayRef<T> ref)  // NOLINT
160       : data_(ref.data()), size_(ref.size()), stride_(sizeof(T)) {}
161 
162   template <typename T, std::enable_if_t<is_argument<T>>* = nullptr>
ArgumentsRef(const llvm::SmallVectorImpl<T> & vec)163   ArgumentsRef(const llvm::SmallVectorImpl<T>& vec)  // NOLINT
164       : ArgumentsRef(llvm::ArrayRef<T>(vec)) {}
165 
166   template <typename T, std::enable_if_t<is_argument<T>>* = nullptr>
ArgumentsRef(const std::vector<T> & vec)167   ArgumentsRef(const std::vector<T>& vec)  // NOLINT
168       : ArgumentsRef(llvm::ArrayRef<T>(vec)) {}
169 
170   template <typename T, size_t n, std::enable_if_t<is_argument<T>>* = nullptr>
ArgumentsRef(const std::array<T,n> & arr)171   ArgumentsRef(const std::array<T, n>& arr)  // NOLINT
172       : ArgumentsRef(llvm::ArrayRef<T>(arr)) {}
173 
174   const Argument& operator[](size_t index) const {
175     assert(index < size_ && "index out of bounds");
176     auto* ptr = reinterpret_cast<const std::byte*>(data_) + index * stride_;
177     return *reinterpret_cast<const Argument*>(ptr);
178   }
179 
size()180   size_t size() const { return size_; }
181 
182  private:
183   // Arguments stored in the contiguous memory starting at `data_` pointer,
184   // with the given `stride_` in bytes.
185   const Argument* data_;
186   size_t size_;
187   size_t stride_;
188 };
189 
190 //===----------------------------------------------------------------------===//
191 // Canonical types for passing compiled kernel arguments.
192 //===----------------------------------------------------------------------===//
193 
// By default we provide a set of types for passing common arguments to the
// compiled kernel. The type hierarchy is open, and users can extend it by
// defining new `Type` and `Argument` with the corresponding MLIR types and
// MLIR passes to lower types and operations to the LLVM dialect.
198 
199 //===----------------------------------------------------------------------===//
200 // OpaqueArg for passing `!llvm.ptr` (opaque pointer) arguments.
201 //===----------------------------------------------------------------------===//
202 
203 class OpaqueArg final : public llvm::RTTIExtends<OpaqueArg, Argument> {
204  public:
205   static constexpr char ID = 0;  // NOLINT
206 
OpaqueArg(void * ptr)207   explicit OpaqueArg(void* ptr) : ptr_(ptr) {}
208 
ptr()209   void* ptr() const { return ptr_; }
210 
211   llvm::Error Verify(const Type& type) const final;
212   size_t Pack(llvm::MutableArrayRef<void*> args, size_t offset) const final;
213   llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
214 
215  private:
216   void* ptr_;
217 };
218 
219 //===----------------------------------------------------------------------===//
220 // MemrefDesc for passing `memref` arguments.
221 //===----------------------------------------------------------------------===//
222 
223 class MemrefDesc final : public llvm::RTTIExtends<MemrefDesc, Argument> {
224  public:
225   static constexpr char ID = 0;  // NOLINT
226 
MemrefDesc(tfrt::DType dtype,void * data,int64_t offset,llvm::ArrayRef<int64_t> sizes,llvm::ArrayRef<int64_t> strides)227   MemrefDesc(tfrt::DType dtype, void* data, int64_t offset,
228              llvm::ArrayRef<int64_t> sizes, llvm::ArrayRef<int64_t> strides)
229       : rank_(sizes.size()), dtype_(dtype), data_(data), offset_(offset) {
230     assert(sizes.size() == strides.size() && "invalid sizes and strides pair");
231     sizes_and_strides_.reserve(2 * rank_);
232     sizes_and_strides_.append(sizes.begin(), sizes.end());
233     sizes_and_strides_.append(strides.begin(), strides.end());
234   }
235 
236   // Constructs MemrefDesc of the given rank and calls user-provided callback to
237   // initialize sizes and strides.
238   //
239   // Expected `InitializeSizesAndStrides` callback signature:
240   //
241   //   void operator()(MutableArrayRef<int64_t> sizes,
242   //                   MutableArrayRef<int64_t> strides);
243   //
244   // We pass the init callback as a template argument to be able to
245   // inline it at the call site, because MemrefDesc construction is on a hot
246   // path.
247   template <typename InitializeSizesAndStrides>
248   MemrefDesc(unsigned rank, tfrt::DType dtype, void* data, int64_t offset,
249              InitializeSizesAndStrides initialize);
250 
251   // Ensure that MemrefDesc is always moved around instead of copying.
252   MemrefDesc(const MemrefDesc&) = delete;
253   MemrefDesc& operator=(const MemrefDesc&) = delete;
254   MemrefDesc(MemrefDesc&&) = default;
255   MemrefDesc& operator=(MemrefDesc&&) = default;
256 
rank()257   unsigned rank() const { return rank_; }
dtype()258   tfrt::DType dtype() const { return dtype_; }
259 
data()260   void* data() const { return data_; }
offset()261   int64_t offset() const { return offset_; }
262 
size(size_t index)263   int64_t size(size_t index) const { return sizes_and_strides_[index]; }
stride(size_t index)264   int64_t stride(size_t index) const {
265     return sizes_and_strides_[rank_ + index];
266   }
267 
sizes()268   llvm::ArrayRef<int64_t> sizes() const {
269     return {sizes_and_strides_.data(), rank_};
270   }
strides()271   llvm::ArrayRef<int64_t> strides() const {
272     return {sizes_and_strides_.data() + rank_, rank_};
273   }
274 
275   llvm::Error Verify(const Type& type) const final;
276   size_t Pack(llvm::MutableArrayRef<void*> args, size_t offset) const final;
277   llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
278 
279  private:
280   unsigned rank_;
281   tfrt::DType dtype_;
282   void* data_;
283   int64_t offset_;
284   // We keep sizes and strides in a single container to save one potential
285   // memory allocation for memrefs of higher ranks, and to save one vector
286   // constructor/destructor call.
287   llvm::SmallVector<int64_t, 8> sizes_and_strides_;
288 };
289 
290 template <typename InitializeSizesAndStrides>
MemrefDesc(unsigned rank,tfrt::DType dtype,void * data,int64_t offset,InitializeSizesAndStrides initialize)291 MemrefDesc::MemrefDesc(unsigned rank, tfrt::DType dtype, void* data,
292                        int64_t offset, InitializeSizesAndStrides initialize)
293     : rank_(rank), dtype_(dtype), data_(data), offset_(offset) {
294   sizes_and_strides_.resize(2 * rank_);
295   llvm::MutableArrayRef<int64_t> ref = sizes_and_strides_;
296   initialize(ref.drop_back(rank_), ref.drop_front(rank_));
297 }
298 
299 //===----------------------------------------------------------------------===//
300 // Verify that argument type is compatible with the run-time memref argument.
301 //===----------------------------------------------------------------------===//
302 
// Verifies that the type at the given `index` matches the run-time memref
// argument: type is a tensor or a memref with compatible element type, and all
// statically known dimensions match the run-time sizes. Returns a
// user-friendly error message in case of an error.
307 llvm::Error VerifyMemrefArgument(unsigned index, const Type& type,
308                                  const MemrefDesc& arg);
309 
310 //===----------------------------------------------------------------------===//
311 // BufferDesc for passing raw `buffer` (i.e. void ptr + size) arguments.
312 //===----------------------------------------------------------------------===//
313 
314 class BufferDesc final : public llvm::RTTIExtends<BufferDesc, Argument> {
315  public:
316   static constexpr char ID = 0;  // NOLINT
317 
BufferDesc(void * data,size_t size)318   BufferDesc(void* data, size_t size) : data_(data), size_(size) {}
319 
data()320   void* data() const { return data_; }
size()321   size_t size() const { return size_; }
322 
323   llvm::Error Verify(const Type& type) const final;
324   size_t Pack(llvm::MutableArrayRef<void*> args, size_t offset) const final;
325   llvm::raw_ostream& print(llvm::raw_ostream& os) const final;
326 
327  private:
328   void* data_;
329   size_t size_;
330 };
331 
332 }  // namespace runtime
333 }  // namespace xla
334 
335 #endif  // XLA_RUNTIME_ARGUMENTS_H_
336