xref: /aosp_15_r20/external/executorch/runtime/core/evalue.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 #include <executorch/runtime/core/exec_aten/exec_aten.h>
11 #include <executorch/runtime/core/tag.h>
12 #include <executorch/runtime/platform/assert.h>
13 
14 namespace executorch {
15 namespace runtime {
16 
17 struct EValue;
18 
19 namespace internal {
20 
21 // Tensor gets proper reference treatment because its expensive to copy in aten
22 // mode, all other types are just copied.
23 template <typename T>
24 struct evalue_to_const_ref_overload_return {
25   using type = T;
26 };
27 
28 template <>
29 struct evalue_to_const_ref_overload_return<executorch::aten::Tensor> {
30   using type = const executorch::aten::Tensor&;
31 };
32 
33 template <typename T>
34 struct evalue_to_ref_overload_return {
35   using type = T;
36 };
37 
38 template <>
39 struct evalue_to_ref_overload_return<executorch::aten::Tensor> {
40   using type = executorch::aten::Tensor&;
41 };
42 
43 } // namespace internal
44 
45 /*
46  * Helper class used to correlate EValues in the executor table, with the
47  * unwrapped list of the proper type. Because values in the runtime's values
48  * table can change during execution, we cannot statically allocate list of
49  * objects at deserialization. Imagine the serialized list says index 0 in the
50  * value table is element 2 in the list, but during execution the value in
51  * element 2 changes (in the case of tensor this means the TensorImpl* stored in
52  * the tensor changes). To solve this instead they must be created dynamically
53  * whenever they are used.
54  */
55 template <typename T>
56 class BoxedEvalueList {
57  public:
58   BoxedEvalueList() = default;
59   /*
60    * Wrapped_vals is a list of pointers into the values table of the runtime
61    * whose destinations correlate with the elements of the list, unwrapped_vals
62    * is a container of the same size whose serves as memory to construct the
63    * unwrapped vals.
64    */
65   BoxedEvalueList(EValue** wrapped_vals, T* unwrapped_vals, int size)
66       : wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {}
67   /*
68    * Constructs and returns the list of T specified by the EValue pointers
69    */
70   executorch::aten::ArrayRef<T> get() const;
71 
72  private:
73   // Source of truth for the list
74   executorch::aten::ArrayRef<EValue*> wrapped_vals_;
75   // Same size as wrapped_vals
76   mutable T* unwrapped_vals_;
77 };
78 
79 template <>
80 executorch::aten::ArrayRef<executorch::aten::optional<executorch::aten::Tensor>>
81 BoxedEvalueList<executorch::aten::optional<executorch::aten::Tensor>>::get()
82     const;
83 
84 // Aggregate typing system similar to IValue only slimmed down with less
85 // functionality, no dependencies on atomic, and fewer supported types to better
86 // suit embedded systems (ie no intrusive ptr)
87 struct EValue {
88   union Payload {
89     // When in ATen mode at::Tensor is not trivially copyable, this nested union
90     // lets us handle tensor as a special case while leaving the rest of the
91     // fields in a simple state instead of requiring a switch on tag everywhere.
92     union TriviallyCopyablePayload {
93       TriviallyCopyablePayload() : as_int(0) {}
94       // Scalar supported through these 3 types
95       int64_t as_int;
96       double as_double;
97       bool as_bool;
98       // TODO(jakeszwe): convert back to pointers to optimize size of this
99       // struct
100       executorch::aten::ArrayRef<char> as_string;
101       executorch::aten::ArrayRef<double> as_double_list;
102       executorch::aten::ArrayRef<bool> as_bool_list;
103       BoxedEvalueList<int64_t> as_int_list;
104       BoxedEvalueList<executorch::aten::Tensor> as_tensor_list;
105       BoxedEvalueList<executorch::aten::optional<executorch::aten::Tensor>>
106           as_list_optional_tensor;
107     } copyable_union;
108 
109     // Since a Tensor just holds a TensorImpl*, there's no value to use Tensor*
110     // here.
111     executorch::aten::Tensor as_tensor;
112 
113     Payload() {}
114     ~Payload() {}
115   };
116 
117   // Data storage and type tag
118   Payload payload;
119   Tag tag;
120 
121   // Basic ctors and assignments
122   EValue(const EValue& rhs) : EValue(rhs.payload, rhs.tag) {}
123 
124   EValue(EValue&& rhs) noexcept : tag(rhs.tag) {
125     moveFrom(std::move(rhs));
126   }
127 
128   EValue& operator=(EValue&& rhs) & noexcept {
129     if (&rhs == this) {
130       return *this;
131     }
132 
133     destroy();
134     moveFrom(std::move(rhs));
135     return *this;
136   }
137 
138   EValue& operator=(EValue const& rhs) & {
139     // Define copy assignment through copy ctor and move assignment
140     *this = EValue(rhs);
141     return *this;
142   }
143 
144   ~EValue() {
145     destroy();
146   }
147 
148   /****** None Type ******/
149   EValue() : tag(Tag::None) {
150     payload.copyable_union.as_int = 0;
151   }
152 
153   bool isNone() const {
154     return tag == Tag::None;
155   }
156 
157   /****** Int Type ******/
158   /*implicit*/ EValue(int64_t i) : tag(Tag::Int) {
159     payload.copyable_union.as_int = i;
160   }
161 
162   bool isInt() const {
163     return tag == Tag::Int;
164   }
165 
166   int64_t toInt() const {
167     ET_CHECK_MSG(isInt(), "EValue is not an int.");
168     return payload.copyable_union.as_int;
169   }
170 
171   /****** Double Type ******/
172   /*implicit*/ EValue(double d) : tag(Tag::Double) {
173     payload.copyable_union.as_double = d;
174   }
175 
176   bool isDouble() const {
177     return tag == Tag::Double;
178   }
179 
180   double toDouble() const {
181     ET_CHECK_MSG(isDouble(), "EValue is not a Double.");
182     return payload.copyable_union.as_double;
183   }
184 
185   /****** Bool Type ******/
186   /*implicit*/ EValue(bool b) : tag(Tag::Bool) {
187     payload.copyable_union.as_bool = b;
188   }
189 
190   bool isBool() const {
191     return tag == Tag::Bool;
192   }
193 
194   bool toBool() const {
195     ET_CHECK_MSG(isBool(), "EValue is not a Bool.");
196     return payload.copyable_union.as_bool;
197   }
198 
199   /****** Scalar Type ******/
200   /// Construct an EValue using the implicit value of a Scalar.
201   /*implicit*/ EValue(executorch::aten::Scalar s) {
202     if (s.isIntegral(false)) {
203       tag = Tag::Int;
204       payload.copyable_union.as_int = s.to<int64_t>();
205     } else if (s.isFloatingPoint()) {
206       tag = Tag::Double;
207       payload.copyable_union.as_double = s.to<double>();
208     } else if (s.isBoolean()) {
209       tag = Tag::Bool;
210       payload.copyable_union.as_bool = s.to<bool>();
211     } else {
212       ET_CHECK_MSG(false, "Scalar passed to EValue is not initialized.");
213     }
214   }
215 
216   bool isScalar() const {
217     return tag == Tag::Int || tag == Tag::Double || tag == Tag::Bool;
218   }
219 
220   executorch::aten::Scalar toScalar() const {
221     // Convert from implicit value to Scalar using implicit constructors.
222 
223     if (isDouble()) {
224       return toDouble();
225     } else if (isInt()) {
226       return toInt();
227     } else if (isBool()) {
228       return toBool();
229     } else {
230       ET_CHECK_MSG(false, "EValue is not a Scalar.");
231     }
232   }
233 
234   /****** Tensor Type ******/
235   /*implicit*/ EValue(executorch::aten::Tensor t) : tag(Tag::Tensor) {
236     // When built in aten mode, at::Tensor has a non trivial constructor
237     // destructor, so regular assignment to a union field is UB. Instead we must
238     // go through placement new (which causes a refcount bump).
239     new (&payload.as_tensor) executorch::aten::Tensor(t);
240   }
241 
242   // Template constructor that allows construction from types that can be
243   // dereferenced to produce a type that EValue can be implicitly constructed
244   // from.
245   template <
246       typename T,
247       typename = typename std::enable_if<std::is_convertible<
248           decltype(*std::forward<T>(std::declval<T>())), // declval to simulate
249                                                          // forwarding
250           EValue>::value>::type>
251   /*implicit*/ EValue(T&& value) {
252     ET_CHECK_MSG(value != nullptr, "Pointer is null.");
253     // Note that this ctor does not initialize this->tag directly; it is set by
254     // moving in the new value.
255     moveFrom(*std::forward<T>(value));
256   }
257 
258   // Delete constructor for raw pointers to ensure they cannot be used.
259   template <typename T>
260   explicit EValue(T* value) = delete;
261 
262   bool isTensor() const {
263     return tag == Tag::Tensor;
264   }
265 
266   executorch::aten::Tensor toTensor() && {
267     ET_CHECK_MSG(isTensor(), "EValue is not a Tensor.");
268     auto res = std::move(payload.as_tensor);
269     clearToNone();
270     return res;
271   }
272 
273   executorch::aten::Tensor& toTensor() & {
274     ET_CHECK_MSG(isTensor(), "EValue is not a Tensor.");
275     return payload.as_tensor;
276   }
277 
278   const executorch::aten::Tensor& toTensor() const& {
279     ET_CHECK_MSG(isTensor(), "EValue is not a Tensor.");
280     return payload.as_tensor;
281   }
282 
283   /****** String Type ******/
284   /*implicit*/ EValue(const char* s, size_t size) : tag(Tag::String) {
285     payload.copyable_union.as_string =
286         executorch::aten::ArrayRef<char>(s, size);
287   }
288 
289   bool isString() const {
290     return tag == Tag::String;
291   }
292 
293   executorch::aten::string_view toString() const {
294     ET_CHECK_MSG(isString(), "EValue is not a String.");
295     return executorch::aten::string_view(
296         payload.copyable_union.as_string.data(),
297         payload.copyable_union.as_string.size());
298   }
299 
300   /****** Int List Type ******/
301   /*implicit*/ EValue(BoxedEvalueList<int64_t> i) : tag(Tag::ListInt) {
302     payload.copyable_union.as_int_list = i;
303   }
304 
305   bool isIntList() const {
306     return tag == Tag::ListInt;
307   }
308 
309   executorch::aten::ArrayRef<int64_t> toIntList() const {
310     ET_CHECK_MSG(isIntList(), "EValue is not an Int List.");
311     return payload.copyable_union.as_int_list.get();
312   }
313 
314   /****** Bool List Type ******/
315   /*implicit*/ EValue(executorch::aten::ArrayRef<bool> b) : tag(Tag::ListBool) {
316     payload.copyable_union.as_bool_list = b;
317   }
318 
319   bool isBoolList() const {
320     return tag == Tag::ListBool;
321   }
322 
323   executorch::aten::ArrayRef<bool> toBoolList() const {
324     ET_CHECK_MSG(isBoolList(), "EValue is not a Bool List.");
325     return payload.copyable_union.as_bool_list;
326   }
327 
328   /****** Double List Type ******/
329   /*implicit*/ EValue(executorch::aten::ArrayRef<double> d)
330       : tag(Tag::ListDouble) {
331     payload.copyable_union.as_double_list = d;
332   }
333 
334   bool isDoubleList() const {
335     return tag == Tag::ListDouble;
336   }
337 
338   executorch::aten::ArrayRef<double> toDoubleList() const {
339     ET_CHECK_MSG(isDoubleList(), "EValue is not a Double List.");
340     return payload.copyable_union.as_double_list;
341   }
342 
343   /****** Tensor List Type ******/
344   /*implicit*/ EValue(BoxedEvalueList<executorch::aten::Tensor> t)
345       : tag(Tag::ListTensor) {
346     payload.copyable_union.as_tensor_list = t;
347   }
348 
349   bool isTensorList() const {
350     return tag == Tag::ListTensor;
351   }
352 
353   executorch::aten::ArrayRef<executorch::aten::Tensor> toTensorList() const {
354     ET_CHECK_MSG(isTensorList(), "EValue is not a Tensor List.");
355     return payload.copyable_union.as_tensor_list.get();
356   }
357 
358   /****** List Optional Tensor Type ******/
359   /*implicit*/ EValue(
360       BoxedEvalueList<executorch::aten::optional<executorch::aten::Tensor>> t)
361       : tag(Tag::ListOptionalTensor) {
362     payload.copyable_union.as_list_optional_tensor = t;
363   }
364 
365   bool isListOptionalTensor() const {
366     return tag == Tag::ListOptionalTensor;
367   }
368 
369   executorch::aten::ArrayRef<
370       executorch::aten::optional<executorch::aten::Tensor>>
371   toListOptionalTensor() const {
372     return payload.copyable_union.as_list_optional_tensor.get();
373   }
374 
375   /****** ScalarType Type ******/
376   executorch::aten::ScalarType toScalarType() const {
377     ET_CHECK_MSG(isInt(), "EValue is not a ScalarType.");
378     return static_cast<executorch::aten::ScalarType>(
379         payload.copyable_union.as_int);
380   }
381 
382   /****** MemoryFormat Type ******/
383   executorch::aten::MemoryFormat toMemoryFormat() const {
384     ET_CHECK_MSG(isInt(), "EValue is not a MemoryFormat.");
385     return static_cast<executorch::aten::MemoryFormat>(
386         payload.copyable_union.as_int);
387   }
388 
389   /****** Layout Type ******/
390   executorch::aten::Layout toLayout() const {
391     ET_CHECK_MSG(isInt(), "EValue is not a Layout.");
392     return static_cast<executorch::aten::Layout>(payload.copyable_union.as_int);
393   }
394 
395   /****** Device Type ******/
396   executorch::aten::Device toDevice() const {
397     ET_CHECK_MSG(isInt(), "EValue is not a Device.");
398     return executorch::aten::Device(
399         static_cast<executorch::aten::DeviceType>(
400             payload.copyable_union.as_int),
401         -1);
402   }
403 
404   template <typename T>
405   T to() &&;
406   template <typename T>
407   typename internal::evalue_to_const_ref_overload_return<T>::type to() const&;
408   template <typename T>
409   typename internal::evalue_to_ref_overload_return<T>::type to() &;
410 
411   /**
412    * Converts the EValue to an optional object that can represent both T and
413    * an uninitialized state.
414    */
415   template <typename T>
416   inline executorch::aten::optional<T> toOptional() const {
417     if (this->isNone()) {
418       return executorch::aten::nullopt;
419     }
420     return this->to<T>();
421   }
422 
423  private:
424   // Pre cond: the payload value has had its destructor called
425   void clearToNone() noexcept {
426     payload.copyable_union.as_int = 0;
427     tag = Tag::None;
428   }
429 
430   // Shared move logic
431   void moveFrom(EValue&& rhs) noexcept {
432     if (rhs.isTensor()) {
433       new (&payload.as_tensor)
434           executorch::aten::Tensor(std::move(rhs.payload.as_tensor));
435       rhs.payload.as_tensor.~Tensor();
436     } else {
437       payload.copyable_union = rhs.payload.copyable_union;
438     }
439     tag = rhs.tag;
440     rhs.clearToNone();
441   }
442 
443   // Destructs stored tensor if there is one
444   void destroy() {
445     // Necessary for ATen tensor to refcount decrement the intrusive_ptr to
446     // tensorimpl that got a refcount increment when we placed it in the evalue,
447     // no-op if executorch tensor #ifdef could have a
448     // minor performance bump for a code maintainability hit
449     if (isTensor()) {
450       payload.as_tensor.~Tensor();
451     } else if (isTensorList()) {
452       for (auto& tensor : toTensorList()) {
453         tensor.~Tensor();
454       }
455     } else if (isListOptionalTensor()) {
456       for (auto& optional_tensor : toListOptionalTensor()) {
457         optional_tensor.~optional();
458       }
459     }
460   }
461 
462   EValue(const Payload& p, Tag t) : tag(t) {
463     if (isTensor()) {
464       new (&payload.as_tensor) executorch::aten::Tensor(p.as_tensor);
465     } else {
466       payload.copyable_union = p.copyable_union;
467     }
468   }
469 };
470 
471 #define EVALUE_DEFINE_TO(T, method_name)                                       \
472   template <>                                                                  \
473   inline T EValue::to<T>()&& {                                                 \
474     return static_cast<T>(std::move(*this).method_name());                     \
475   }                                                                            \
476   template <>                                                                  \
477   inline ::executorch::runtime::internal::evalue_to_const_ref_overload_return< \
478       T>::type                                                                 \
479   EValue::to<T>() const& {                                                     \
480     typedef ::executorch::runtime::internal::                                  \
481         evalue_to_const_ref_overload_return<T>::type return_type;              \
482     return static_cast<return_type>(this->method_name());                      \
483   }                                                                            \
484   template <>                                                                  \
485   inline ::executorch::runtime::internal::evalue_to_ref_overload_return<       \
486       T>::type                                                                 \
487   EValue::to<T>()& {                                                           \
488     typedef ::executorch::runtime::internal::evalue_to_ref_overload_return<    \
489         T>::type return_type;                                                  \
490     return static_cast<return_type>(this->method_name());                      \
491   }
492 
493 EVALUE_DEFINE_TO(executorch::aten::Scalar, toScalar)
494 EVALUE_DEFINE_TO(int64_t, toInt)
495 EVALUE_DEFINE_TO(bool, toBool)
496 EVALUE_DEFINE_TO(double, toDouble)
497 EVALUE_DEFINE_TO(executorch::aten::string_view, toString)
498 EVALUE_DEFINE_TO(executorch::aten::ScalarType, toScalarType)
499 EVALUE_DEFINE_TO(executorch::aten::MemoryFormat, toMemoryFormat)
500 EVALUE_DEFINE_TO(executorch::aten::Layout, toLayout)
501 EVALUE_DEFINE_TO(executorch::aten::Device, toDevice)
502 // Tensor and Optional Tensor
503 EVALUE_DEFINE_TO(
504     executorch::aten::optional<executorch::aten::Tensor>,
505     toOptional<executorch::aten::Tensor>)
506 EVALUE_DEFINE_TO(executorch::aten::Tensor, toTensor)
507 
508 // IntList and Optional IntList
509 EVALUE_DEFINE_TO(executorch::aten::ArrayRef<int64_t>, toIntList)
510 EVALUE_DEFINE_TO(
511     executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>,
512     toOptional<executorch::aten::ArrayRef<int64_t>>)
513 
514 // DoubleList and Optional DoubleList
515 EVALUE_DEFINE_TO(executorch::aten::ArrayRef<double>, toDoubleList)
516 EVALUE_DEFINE_TO(
517     executorch::aten::optional<executorch::aten::ArrayRef<double>>,
518     toOptional<executorch::aten::ArrayRef<double>>)
519 
520 // BoolList and Optional BoolList
521 EVALUE_DEFINE_TO(executorch::aten::ArrayRef<bool>, toBoolList)
522 EVALUE_DEFINE_TO(
523     executorch::aten::optional<executorch::aten::ArrayRef<bool>>,
524     toOptional<executorch::aten::ArrayRef<bool>>)
525 
526 // TensorList and Optional TensorList
527 EVALUE_DEFINE_TO(
528     executorch::aten::ArrayRef<executorch::aten::Tensor>,
529     toTensorList)
530 EVALUE_DEFINE_TO(
531     executorch::aten::optional<
532         executorch::aten::ArrayRef<executorch::aten::Tensor>>,
533     toOptional<executorch::aten::ArrayRef<executorch::aten::Tensor>>)
534 
535 // List of Optional Tensor
536 EVALUE_DEFINE_TO(
537     executorch::aten::ArrayRef<
538         executorch::aten::optional<executorch::aten::Tensor>>,
539     toListOptionalTensor)
540 #undef EVALUE_DEFINE_TO
541 
542 template <typename T>
543 executorch::aten::ArrayRef<T> BoxedEvalueList<T>::get() const {
544   for (typename executorch::aten::ArrayRef<T>::size_type i = 0;
545        i < wrapped_vals_.size();
546        i++) {
547     ET_CHECK(wrapped_vals_[i] != nullptr);
548     unwrapped_vals_[i] = wrapped_vals_[i]->template to<T>();
549   }
550   return executorch::aten::ArrayRef<T>{unwrapped_vals_, wrapped_vals_.size()};
551 }
552 
553 } // namespace runtime
554 } // namespace executorch
555 
556 namespace torch {
557 namespace executor {
558 // TODO(T197294990): Remove these deprecated aliases once all users have moved
559 // to the new `::executorch` namespaces.
560 using ::executorch::runtime::BoxedEvalueList;
561 using ::executorch::runtime::EValue;
562 } // namespace executor
563 } // namespace torch
564