#pragma once

#include <condition_variable>
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>

#include <ATen/core/Dict.h>
#include <ATen/core/List.h>
#include <ATen/core/IListRef.h>
#include <ATen/core/functional.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/qualified_name.h>
#include <ATen/core/rref_interface.h>
#include <ATen/core/symbol.h>
#include <c10/core/DeviceGuard.h>
#include <c10/core/Event.h>
#include <c10/core/Scalar.h>
#include <c10/core/Stream.h>
#include <c10/core/StreamGuard.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/FunctionRef.h>
#include <c10/util/Logging.h>
#include <c10/util/hash.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/irange.h>

namespace torch {
namespace jit {
struct Function;
struct CompilationUnit;
} // namespace jit
TORCH_API bool isCustomClass(const c10::IValue& v);
} // namespace torch
namespace c10 {
struct IValue;
struct ClassType;
struct TupleType;
struct EnumType;
struct InferredType;

// For custom class __init__ registration, we need to pass in a function
// that looks like this: [](IValue x, args...)

// However, make_boxed_from_unboxed_functor.h automatically sets the input
// types of the function by introspecting the types of the functor (which is
// IValue in this case), whereas we need the bound type to be Foo.

// Instead, we pass in a lambda [](ivalue_holder<CurClass> x, args...) from
// which getTypePtr can recover the original class pointer.

template <typename TaggedCapsuleType>
struct tagged_capsule {
  IValue ivalue;
};
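
// As a rough illustration of the mechanism described above (a sketch only,
// not the exact generated binding; `Foo` and its int64_t constructor argument
// are hypothetical), the custom-class machinery conceptually wraps an
// __init__ in a lambda whose first parameter is a tagged_capsule<Foo>, e.g.:
//
//   [](tagged_capsule<Foo> self, int64_t arg) {
//     auto obj = c10::make_intrusive<Foo>(arg);
//     self.ivalue.toObject()->setSlot(0, IValue::make_capsule(std::move(obj)));
//   };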

template <class T, class NullType>
c10::intrusive_ptr<T, NullType> IValue::moveToIntrusivePtr() {
  auto t = c10::intrusive_ptr<T, NullType>::reclaim(
      payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
      ? NullType::singleton()
      : static_cast<T*>(payload.u.as_intrusive_ptr));
  clearToNone();
  return t;
}
template <typename T, class NullType>
c10::intrusive_ptr<T, NullType> IValue::toIntrusivePtr() const {
  if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) {
    return c10::intrusive_ptr<T, NullType>();
  }
  c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr);
  return c10::intrusive_ptr<T, NullType>::reclaim(
      static_cast<T*>(payload.u.as_intrusive_ptr));
}

template <class T, class U>
intrusive_ptr<T> static_intrusive_pointer_cast(intrusive_ptr<U> r) {
  return intrusive_ptr<T>::reclaim(static_cast<T*>(r.release()));
}

template <class T, class U>
intrusive_ptr<T> dynamic_intrusive_pointer_cast(intrusive_ptr<U> r) {
  return intrusive_ptr<T>::reclaim(dynamic_cast<T*>(r.release()));
}

inline c10::intrusive_ptr<ivalue::Future> IValue::toFuture() && {
  AT_ASSERT(isFuture(), "Expected Future but got ", tagKind());
  return moveToIntrusivePtr<ivalue::Future>();
}
inline c10::intrusive_ptr<ivalue::Future> IValue::toFuture() const& {
  AT_ASSERT(isFuture(), "Expected Future but got ", tagKind());
  return toIntrusivePtr<ivalue::Future>();
}
inline c10::intrusive_ptr<ivalue::Await> IValue::toAwait() && {
  AT_ASSERT(isAwait(), "Expected Await but got ", tagKind());
  return moveToIntrusivePtr<ivalue::Await>();
}
inline c10::intrusive_ptr<ivalue::Await> IValue::toAwait() const& {
  AT_ASSERT(isAwait(), "Expected Await but got ", tagKind());
  return toIntrusivePtr<ivalue::Await>();
}
inline c10::intrusive_ptr<c10::RRefInterface> IValue::toRRef() && {
  AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind());
  return moveToIntrusivePtr<c10::RRefInterface>();
}
inline c10::intrusive_ptr<c10::RRefInterface> IValue::toRRef() const& {
  AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind());
  return toIntrusivePtr<c10::RRefInterface>();
}
inline c10::intrusive_ptr<at::Quantizer> IValue::toQuantizer() && {
  AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind());
  return moveToIntrusivePtr<at::Quantizer>();
}
inline c10::intrusive_ptr<at::Quantizer> IValue::toQuantizer() const& {
  AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind());
  return toIntrusivePtr<at::Quantizer>();
}
inline c10::intrusive_ptr<ivalue::ConstantString> IValue::toString() && {
  AT_ASSERT(isString(), "Expected String but got ", tagKind());
  return moveToIntrusivePtr<ivalue::ConstantString>();
}
inline c10::intrusive_ptr<ivalue::ConstantString> IValue::toString() const& {
  AT_ASSERT(isString(), "Expected String but got ", tagKind());
  return toIntrusivePtr<ivalue::ConstantString>();
}
inline c10::intrusive_ptr<ivalue::Object> IValue::toObject() && {
  AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
  return moveToIntrusivePtr<ivalue::Object>();
}
inline c10::intrusive_ptr<ivalue::Object> IValue::toObject() const& {
  AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
  return toIntrusivePtr<ivalue::Object>();
}
inline c10::intrusive_ptr<ivalue::PyObjectHolder> IValue::
    toPyObjectHolder() && {
  TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind());
  return moveToIntrusivePtr<ivalue::PyObjectHolder>();
}
inline c10::intrusive_ptr<ivalue::PyObjectHolder> IValue::toPyObjectHolder()
    const& {
  TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind());
  return toIntrusivePtr<ivalue::PyObjectHolder>();
}
inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() && {
  TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind());
  return moveToIntrusivePtr<ivalue::EnumHolder>();
}
inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() const& {
  TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind());
  return toIntrusivePtr<ivalue::EnumHolder>();
}
inline c10::complex<double> IValue::toComplexDouble() const {
  TORCH_INTERNAL_ASSERT(isComplexDouble(), "Expected ComplexDouble but got ", tagKind());
  auto ptr = toIntrusivePtr<ivalue::ComplexHolder>();
  return (*ptr).val;
}
inline at::Tensor IValue::toTensor() && {
  if (C10_UNLIKELY(!isTensor())) {
    reportToTensorTypeError();
  }
  auto result = std::move(payload.as_tensor);
  // As far as I can tell, omitting the usual explicit destructor call
  // is not UB in and of itself, and it's a slight perf win. The
  // destructor is a no-op, because the moved-from Tensor is
  // effectively an intrusive_ptr in the null state, so we don't need
  // the behavior for correctness reasons either. Leaving this
  // explanatory comment, including commented-out destructor call, to
  // make this abundantly clear.
  //
  // payload.as_tensor.~Tensor();
  clearToNone();
  return result;
}
inline at::Tensor& IValue::toTensor() & {
  if (C10_UNLIKELY(!isTensor())) {
    reportToTensorTypeError();
  }
  return payload.as_tensor;
}
inline const at::Tensor& IValue::toTensor() const& {
  if (C10_UNLIKELY(!isTensor())) {
    reportToTensorTypeError();
  }
  return payload.as_tensor;
}
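
// A minimal usage sketch of the Tensor accessors above (illustrative only;
// assumes an ATen ops header such as <ATen/ATen.h> is available for at::ones):
//
//   IValue iv(at::ones({2, 2}));
//   const at::Tensor& borrowed = iv.toTensor();       // no refcount bump
//   at::Tensor owned = std::move(iv).toTensor();      // steals the payload,
//                                                     // leaving iv as None
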
inline c10::Storage IValue::toStorage() && {
  AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
  return c10::Storage(
      moveToIntrusivePtr<at::StorageImpl>());
}
inline c10::Storage IValue::toStorage() const& {
  AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
  return c10::Storage(toIntrusivePtr<at::StorageImpl>());
}
inline c10::Stream IValue::toStream() && {
  AT_ASSERT(isStream(), "Expected Stream but got ", tagKind());
  auto ptr = toIntrusivePtr<ivalue::StreamData3Holder>();
  return c10::Stream::unpack3((*ptr).val.stream_id,
                              (*ptr).val.device_index,
                              (*ptr).val.device_type);
}
inline c10::Stream IValue::toStream() const& {
  AT_ASSERT(isStream(), "Expected Stream but got ", tagKind());
  auto ptr = toIntrusivePtr<ivalue::StreamData3Holder>();
  return c10::Stream::unpack3((*ptr).val.stream_id,
                              (*ptr).val.device_index,
                              (*ptr).val.device_type);
}
inline c10::intrusive_ptr<caffe2::Blob> IValue::toBlob() && {
  AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind());
  return moveToIntrusivePtr<caffe2::Blob>();
}
inline c10::intrusive_ptr<caffe2::Blob> IValue::toBlob() const& {
  AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind());
  return toIntrusivePtr<caffe2::Blob>();
}
inline c10::intrusive_ptr<torch::CustomClassHolder> IValue::toCapsule() && {
  TORCH_INTERNAL_ASSERT(isCapsule());
  return moveToIntrusivePtr<torch::CustomClassHolder>();
}
inline c10::intrusive_ptr<torch::CustomClassHolder> IValue::toCapsule() const& {
  TORCH_INTERNAL_ASSERT(isCapsule());
  return toIntrusivePtr<torch::CustomClassHolder>();
}
inline at::Generator IValue::toGenerator() && {
  AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind());
  return at::Generator(moveToIntrusivePtr<at::GeneratorImpl>());
}
inline at::Generator IValue::toGenerator() const& {
  AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind());
  return at::Generator(toIntrusivePtr<at::GeneratorImpl>());
}
inline c10::SymInt IValue::toSymInt() && {
  AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind());
  if (isSymInt()) {
    return c10::SymInt(moveToIntrusivePtr<c10::SymNodeImpl>());
  } else {
    return c10::SymInt(payload.u.as_int);
  }
}
inline c10::SymInt IValue::toSymInt() const& {
  AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind());
  if (isSymInt()) {
    return c10::SymInt(toIntrusivePtr<c10::SymNodeImpl>());
  } else {
    return c10::SymInt(payload.u.as_int);
  }
}
inline c10::SymFloat IValue::toSymFloat() && {
  AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind());
  if (isSymFloat()) {
    return c10::SymFloat(moveToIntrusivePtr<c10::SymNodeImpl>());
  } else {
    return c10::SymFloat(payload.u.as_double);
  }
}
inline c10::SymFloat IValue::toSymFloat() const& {
  AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind());
  if (isSymFloat()) {
    return c10::SymFloat(toIntrusivePtr<c10::SymNodeImpl>());
  } else {
    return c10::SymFloat(payload.u.as_double);
  }
}
inline c10::SymBool IValue::toSymBool() && {
  AT_ASSERT(isSymBool() || isBool(), "Expected SymBool or boolean but got ", tagKind());
  if (isSymBool()) {
    return c10::SymBool(moveToIntrusivePtr<c10::SymNodeImpl>());
  } else {
    return c10::SymBool(payload.u.as_bool);
  }
}

inline c10::SymBool IValue::toSymBool() const& {
  AT_ASSERT(isSymBool() || isBool(), "Expected SymBool or boolean but got ", tagKind());
  if (isSymBool()) {
    return c10::SymBool(toIntrusivePtr<c10::SymNodeImpl>());
  } else {
    return c10::SymBool(payload.u.as_bool);
  }
}

namespace ivalue {

void TORCH_API
checkCustomClassType(const ClassType* expected_type, const Type* actual_type);

template <typename T>
using Shared = c10::intrusive_ptr<T>;

// string
struct TORCH_API ConstantString final : c10::intrusive_ptr_target {
 private:
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const std::string str_;

 public:
  ConstantString(std::string str) : str_(std::move(str)) {}
  ConstantString(c10::string_view str) : str_(std::string(str)) {}
  static c10::intrusive_ptr<ConstantString> create(std::string str_);
  static c10::intrusive_ptr<ConstantString> create(c10::string_view str_);
  static c10::intrusive_ptr<ConstantString> create(const char* str_);

  const std::string& string() const {
    return str_;
  }
  c10::string_view string_view() const {
    return str_;
  }

  operator const std::string&() const {
    return string();
  }
  TORCH_API friend std::ostream& operator<<(
      std::ostream& out,
      const ConstantString& v);
};

struct Future;

struct TORCH_API TupleElements {
 private:
  size_t inlineSize_;
  // We represent TupleElements this way to save doing a heap
  // allocation in the common (at least for unpickling) case where we
  // have only 3 elements. We have our own union instead of
  // c10::SmallVector<IValue> because c10::SmallVector<IValue> always
  // stores the begin/end/capacity pointers, which would be a waste of
  // space in our use case.
  union {
    std::vector<IValue> elementsVector_;
    // Don't want to declare a std::array because the convenient
    // iteration and size members are a footgun in this case -- the
    // actual size of the array may be smaller than 3!
    // NOLINTNEXTLINE(*c-arrays*)
    IValue elementsInline_[3];
  };

  void destroyInline() {
   for (const auto ii : c10::irange(inlineSize_)) {
     elementsInline_[ii].~IValue();
   }
  }
 public:

  using iterator = IValue*;
  using const_iterator = const IValue*;

  TupleElements() : inlineSize_(0) {
    new (&elementsVector_) std::vector<IValue>();
  }

  explicit TupleElements(std::vector<IValue> elements)
  : inlineSize_(0), elementsVector_(std::move(elements)) {}

  explicit TupleElements(c10::ArrayRef<IValue> elements)
  : inlineSize_(elements.size() <= 3 ? elements.size() : 0) {
    switch (inlineSize_) {
      case 3:
        new (&elementsInline_[2]) IValue(elements[2]);
        [[fallthrough]];
      case 2:
        new (&elementsInline_[1]) IValue(elements[1]);
        [[fallthrough]];
      case 1:
        new (&elementsInline_[0]) IValue(elements[0]);
        break;
      case 0:
        new (&elementsVector_) std::vector<IValue>(elements.begin(), elements.end());
        break;
    }
  }

  explicit TupleElements(IValue&& e1)
  : inlineSize_(1) {
    new (&elementsInline_[0]) IValue(std::move(e1));
  }

  explicit TupleElements(IValue&& e1, IValue&& e2)
  : inlineSize_(2) {
    new (&elementsInline_[0]) IValue(std::move(e1));
    new (&elementsInline_[1]) IValue(std::move(e2));
  }

  explicit TupleElements(IValue&& e1, IValue&& e2, IValue&& e3)
  : inlineSize_(3) {
    new (&elementsInline_[0]) IValue(std::move(e1));
    new (&elementsInline_[1]) IValue(std::move(e2));
    new (&elementsInline_[2]) IValue(std::move(e3));
  }

  ~TupleElements() {
    if (inlineSize_) {
      destroyInline();
    } else {
      elementsVector_.~vector();
    }
  }

  // It would be nice to make this noncopyable to prevent people from
  // writing code like `auto output =
  // forward(...).toTupleRef().elements()` (which does refcount bumps on
  // each element, unlike the more efficient but verbose
  // ```
  // auto outputIntrusivePtr = forward(...).toTuple();
  // const auto& output = outputIntrusivePtr->elements();
  // ```
  // ), but there is simply an overwhelming amount of code that does
  // it the inefficient way.
  // See also operator std::vector below.
  TupleElements(const TupleElements& rhs)
  : inlineSize_(rhs.inlineSize_) {
    if (rhs.inlineSize_) {
      for (const auto ii : c10::irange(inlineSize_)) {
        new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]);
      }
    } else {
      new (&elementsVector_) std::vector<IValue>(rhs.elementsVector_);
    }
  }

  TupleElements& operator=(const TupleElements& rhs) {
    if (inlineSize_) {
      if (rhs.inlineSize_) {
        for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) {
          elementsInline_[ii] = rhs.elementsInline_[ii];
        }
        if (rhs.inlineSize_ > inlineSize_) {
          for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) {
            new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]);
          }
        } else {
          for (const auto ii : c10::irange(rhs.inlineSize_, inlineSize_)) {
            elementsInline_[ii].~IValue();
          }
        }
      } else {
        destroyInline();
        new (&elementsVector_) std::vector<IValue>(rhs.elementsVector_);
      }
    } else {
      if (rhs.inlineSize_) {
        elementsVector_.~vector();
        for (const auto ii : c10::irange(rhs.inlineSize_)) {
          new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]);
        }
      } else {
        elementsVector_ = rhs.elementsVector_;
      }
    }
    inlineSize_ = rhs.inlineSize_;
    return *this;
  }

  TupleElements(TupleElements&& rhs) noexcept
  : inlineSize_(rhs.inlineSize_) {
    if (inlineSize_) {
      for (const auto ii : c10::irange(inlineSize_)) {
        new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii]));
      }
    } else {
      new (&elementsVector_) std::vector<IValue>(std::move(rhs.elementsVector_));
    }
  }

  TupleElements& operator=(TupleElements&& rhs) noexcept {
    if (inlineSize_) {
      if (rhs.inlineSize_) {
        for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) {
          elementsInline_[ii] = std::move(rhs.elementsInline_[ii]);
        }
        if (rhs.inlineSize_ > inlineSize_) {
          for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) {
            new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii]));
          }
        } else {
          for (const auto ii : c10::irange(rhs.inlineSize_, inlineSize_)) {
            elementsInline_[ii].~IValue();
          }
        }
      } else {
        destroyInline();
        new (&elementsVector_) std::vector<IValue>(std::move(rhs.elementsVector_));
      }
    } else {
      if (rhs.inlineSize_) {
        elementsVector_.~vector();
        for (const auto ii : c10::irange(rhs.inlineSize_)) {
          new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii]));
        }
      } else {
        elementsVector_ = std::move(rhs.elementsVector_);
      }
    }
    inlineSize_ = rhs.inlineSize_;
    return *this;
  }

  C10_NODISCARD c10::ArrayRef<IValue> asArrayRef() const {
    if (inlineSize_) {
      return c10::ArrayRef<IValue>(elementsInline_, inlineSize_);
    } else {
      return elementsVector_;
    }
  }

  // Mimic implicit conversion from std::vector to ArrayRef.
  operator c10::ArrayRef<IValue>() const {
    return asArrayRef();
  }

  static size_t hash(const TupleElements& v) {
    return c10::hash<c10::ArrayRef<IValue>>()(v.asArrayRef());
  }

  void setContents(std::vector<IValue>&& contents) {
    if (inlineSize_) {
      destroyInline();
      new (&elementsVector_) std::vector<IValue>(std::move(contents));
      inlineSize_ = 0;
    } else {
      elementsVector_ = std::move(contents);
    }
  }

  C10_NODISCARD bool empty() const {
    return inlineSize_ ? false : elementsVector_.empty();
  }

  C10_NODISCARD size_t size() const {
    return inlineSize_ ? inlineSize_ : elementsVector_.size();
  }

  C10_NODISCARD IValue& operator[](size_t idx) {
    if (inlineSize_) {
      return elementsInline_[idx];
    } else {
      return elementsVector_[idx];
    }
  }

  C10_NODISCARD const IValue& operator[](size_t idx) const {
    if (inlineSize_) {
      return elementsInline_[idx];
    } else {
      return elementsVector_[idx];
    }
  }

  C10_NODISCARD IValue& at(size_t idx) {
    if (inlineSize_) {
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3);
      TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_);
      return elementsInline_[idx];
    } else {
      return elementsVector_.at(idx);
    }
  }

  C10_NODISCARD const IValue& at(size_t idx) const {
    if (inlineSize_) {
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3);
      TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_);
      return elementsInline_[idx];
    } else {
      TORCH_CHECK(idx < elementsVector_.size(), "TupleElements: invalid index Index = ", idx, "; Length = ", elementsVector_.size());
      return elementsVector_.at(idx);
    }
  }

  C10_NODISCARD iterator begin() {
    if (inlineSize_) {
      return elementsInline_;
    } else {
      return elementsVector_.data();
    }
  }

  C10_NODISCARD iterator end() {
    if (inlineSize_) {
      return elementsInline_ + inlineSize_;
    } else {
      return elementsVector_.data() + elementsVector_.size();
    }
  }

  C10_NODISCARD const_iterator begin() const {
    if (inlineSize_) {
      return elementsInline_;
    } else {
      return elementsVector_.data();
    }
  }

  C10_NODISCARD const_iterator end() const {
    if (inlineSize_) {
      return elementsInline_ + inlineSize_;
    } else {
      return elementsVector_.data() + elementsVector_.size();
    }
  }

  C10_NODISCARD const_iterator cbegin() const {
    return begin();
  }

  C10_NODISCARD const_iterator cend() const {
    return end();
  }

  C10_NODISCARD std::vector<IValue> vec() const & {
    return asArrayRef().vec();
  }

  C10_NODISCARD IValue& back() {
    return *(end() - 1);
  }

  C10_NODISCARD const IValue& back() const {
    return *(end() - 1);
  }

  C10_NODISCARD std::vector<IValue> vec() && {
    std::vector<IValue> result;
    result.reserve(size());
    for (auto&& iv : *this) {
      result.push_back(std::move(iv));
    }
    return result;
  }

  // More compatibility shims for the overwhelming amount of code that
  // likes to copy tuple elements into a vector; see comment above the
  // copy constructor.
  operator std::vector<IValue>() const & {
    return vec();
  }

  operator std::vector<IValue>() && {
    return vec();
  }
};
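
// A minimal usage sketch for TupleElements (illustrative only): up to three
// elements live inline, anything larger falls back to the heap-allocated
// std::vector member of the union above.
//
//   ivalue::TupleElements small(IValue(1), IValue(2.0), IValue(true));  // inline
//   std::vector<IValue> many = {IValue(1), IValue(2), IValue(3), IValue(4)};
//   ivalue::TupleElements large(std::move(many));                       // vector
//   for (const IValue& v : small) { /* iterate without copying */ }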

template <typename T>
struct TupleTypeFactory {};

template <>
struct TORCH_API TupleTypeFactory<TupleType> {
  static TupleTypePtr create(std::vector<TypePtr> types) {
    return TupleType::create(std::move(types));
  }
  static TupleTypePtr fallback(const Type& type);
};

template <>
struct TORCH_API TupleTypeFactory<c10::DynamicType> {
  static DynamicTypePtr create(const std::vector<TypePtr>& elemTypes);
  static DynamicTypePtr fallback(const Type&);
};

struct TORCH_API Tuple : c10::intrusive_ptr_target {
 private:
  TupleElements elements_;
  mutable c10::TypePtr type_; // lazily computed for unnamed tuples

 public:
  // named tuples have additional type information, so we
  // directly create them tagged
  static c10::intrusive_ptr<Tuple> createNamed(
      std::vector<IValue> elements_,
      c10::TypePtr type_) {
    return c10::make_intrusive<Tuple>(std::move(elements_), std::move(type_));
  }

  static c10::intrusive_ptr<Tuple> createNamed(
      TupleElements elements_,
      std::shared_ptr<TupleType> type_) {
    return c10::make_intrusive<Tuple>(std::move(elements_), std::move(type_));
  }

  static c10::intrusive_ptr<Tuple> createNamed(
      std::initializer_list<IValue> elements_,
      std::shared_ptr<TupleType> type_) {
    return createNamed(TupleElements(c10::ArrayRef<IValue>(elements_)), std::move(type_));
  }

  // MSVC apparently can't disambiguate the other two overloads of
  // create when passed an initializer_list without this.
  static c10::intrusive_ptr<Tuple> create(std::initializer_list<IValue> elements_) {
    return create(c10::ArrayRef<IValue>(elements_));
  }

  static c10::intrusive_ptr<Tuple> create(std::vector<IValue> elements_) {
    return c10::make_intrusive<Tuple>(std::move(elements_));
  }

  static c10::intrusive_ptr<Tuple> create(TupleElements elements_) {
    return c10::make_intrusive<Tuple>(std::move(elements_));
  }

  static c10::intrusive_ptr<Tuple> create(c10::ArrayRef<IValue> elements_) {
    return create(TupleElements(elements_));
  }

  static c10::intrusive_ptr<Tuple> create(IValue e1) {
    return c10::make_intrusive<Tuple>(std::move(e1));
  }

  static c10::intrusive_ptr<Tuple> create(IValue e1, IValue e2) {
    return c10::make_intrusive<Tuple>(std::move(e1), std::move(e2));
  }

  static c10::intrusive_ptr<Tuple> create(IValue e1, IValue e2, IValue e3) {
    return c10::make_intrusive<Tuple>(std::move(e1), std::move(e2), std::move(e3));
  }

 private:
  // Workaround inability to use `>` operator in template argument list.
  template <typename... Args>
  static constexpr bool hasMoreThanThreeArgs() {
    return sizeof...(Args) > 3;
  }

 public:
  template <typename... Args>
  static c10::intrusive_ptr<Tuple> create(Args&&... elements_) {
    switch (sizeof...(Args)) {
      case 1:
      case 2:
      case 3:
        return create(IValue(std::forward<Args>(elements_))...);
      default:
        return create(
            std::vector<IValue>{IValue(std::forward<Args>(elements_))...});
    }
  }

  // Again, it would be nice to make this noncopyable, but there's a
  // lot of extant code that copies Tuples.
  // Tuple(const Tuple& rhs) = delete;

  const TupleElements& elements() const& {
    return elements_;
  }

  TupleElements elements() && {
    return std::move(elements_);
  }

  void setElements(std::vector<IValue>&& elements) {
    elements_.setContents(std::move(elements));
  }

  void setElements(TupleElements&& elements) {
    elements_ = std::move(elements);
  }

  void unsafeSetElement(size_t idx, const IValue& element) {
    elements_[idx] = element;
  }

  void unsafeSetElement(size_t idx, IValue&& element) {
    elements_[idx] = std::move(element);
  }

  size_t size() const {
    return elements_.size();
  }

  template <typename T = c10::TupleType>
  std::shared_ptr<T> type() const {
    if (!type_) {
      type_ = TupleTypeFactory<T>::create(fmap(elements(), [&](const IValue& v) {
        return v.type<typename T::ElementType>();
      }));
    }
    if (auto t = type_->cast<T>()) {
      return t;
    }
    return TupleTypeFactory<T>::fallback(*type_);
  }

  static size_t hash(const Tuple& t) {
    return c10::get_hash(t.elements());
  }

  TORCH_API friend bool operator==(
      const ivalue::Tuple& lhs,
      const ivalue::Tuple& rhs);

 private:
  // NOTE: If we try to avoid the overloads without
  // `std::shared_ptr<TupleType> type` by defaulting it to nullptr, we
  // end up having to call (part of) the shared_ptr destructor for
  // `type` even though we should know statically it won't do
  // anything.
  explicit Tuple(std::vector<IValue> elements)
    : elements_(std::move(elements)){}

  explicit Tuple(std::vector<IValue> elements, c10::TypePtr type)
    : elements_(std::move(elements)), type_(std::move(type)) {}

  explicit Tuple(TupleElements&& elements)
    : elements_(std::move(elements)) {}

  explicit Tuple(TupleElements&& elements, std::shared_ptr<TupleType> type)
    : elements_(std::move(elements)), type_(std::move(type)) {}

  explicit Tuple(IValue&& e1)
    : elements_(std::move(e1)) {}

  explicit Tuple(IValue&& e1, std::shared_ptr<TupleType> type)
    : elements_(std::move(e1)), type_(std::move(type)) {}

  explicit Tuple(IValue&& e1, IValue&& e2)
    : elements_(std::move(e1), std::move(e2)) {}

  explicit Tuple(IValue&& e1, IValue&& e2, std::shared_ptr<TupleType> type)
    : elements_(std::move(e1), std::move(e2)), type_(std::move(type)) {}

  explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3)
    : elements_(std::move(e1), std::move(e2), std::move(e3)) {}

  explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3, std::shared_ptr<TupleType> type)
    : elements_(std::move(e1), std::move(e2), std::move(e3)), type_(std::move(type)) {}

  friend class c10::intrusive_ptr<Tuple>;
};
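
// A minimal usage sketch for Tuple (illustrative only). Since the constructors
// are private, tuples are created through the static create() factories, and
// the cheapest way to read elements is to keep the intrusive_ptr alive and
// borrow the elements rather than copying them:
//
//   auto tup = ivalue::Tuple::create(IValue(1), IValue(2.5), IValue(true));
//   const auto& elems = tup->elements();   // borrow, no per-element copies
//   int64_t first = elems[0].toInt();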

struct Object;
struct PyObjectHolder;
struct EnumHolder;
} // namespace ivalue

// Future
struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
 private:
  // Keep this private in order to force users to go through make_intrusive and
  // thus prevent creating a Future that's not held by an intrusive_ptr.
  explicit Future(TypePtr type, std::vector<c10::Device> devices={})
      : type_(std::move(type)),
        impl_(getTypeOfDevices(devices)),
        devices_(sortAndDeduplicateDevices(impl_, std::move(devices))) {}

  friend c10::intrusive_ptr<Future>;

  struct FutureCallback {
    std::function<void(Future&)> callback;
    bool uses_future; // whether the Future& passed in is actually used

    template <typename T>
    FutureCallback(T callback, bool uses_future)
        : callback(std::move(callback)), uses_future(uses_future) {}
  };

 public:
  Future(const Future&) = delete;
  Future(Future&&) = delete;
  Future& operator=(const Future&) = delete;
  Future& operator=(Future&&) = delete;

  struct TORCH_API FutureError final : public std::exception {
    explicit FutureError(std::string&& error_msg_)
        : error_msg(std::move(error_msg_)) {}

    FutureError() = default;

    const char* what() const noexcept override {
      return error_msg.c_str();
    }

    std::string error_msg;
  };

  /**
   * Wait on the future until it completes.
   */
  void wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    finished_cv_.wait(lock, [&]() -> bool { return completed_; });
    synchronizeWithCurrentStreams();
  }

  /**
   * Wait on the future until it completes and throw an
   * exception if an error exists.
   */
  void waitAndThrow() {
    wait();

    if (eptr_) {
      std::rethrow_exception(eptr_);
    }
  }

  /**
   * Explicitly mark the future as completed with the output value. Optionally,
   * the storages for all tensors in the IValue can be passed as well. The
   * DataPtrs of these storages are used to synchronize CUDA streams. If
   * storages aren't given, we will attempt to extract them from the value if
   * we need to (this happens if a non-empty set of devices was given to the
   * constructor). Thus one only needs to provide storages when 1) they cannot
   * be extracted through IValue::getSubValues() or through pickling in the
   * case of a Python object; or 2) customized storage extraction is more
   * efficient.
   */
  using WeakStorage = c10::weak_intrusive_ptr<c10::StorageImpl>;
  void markCompleted(
      IValue value,
      std::optional<std::vector<WeakStorage>> storages = std::nullopt) {
    // Start by performing all steps that can throw, before setting any field.
    // Do this before even acquiring the mutex, because extractStorages might
    // acquire the GIL, which could lead to a lock inversion with our mutex.
    // See https://github.com/pytorch/pytorch/issues/58239.
    std::vector<WeakStorage> actualStorages;
    std::vector<c10::Device> usedDevices;
    try {
      // FIXME We should always extract DataPtrs, in order to catch the case of
      // users using CUDA values but forgetting to set devices, which currently
      // leads to a silent synchronization/correctness issue. However, as this
      // might worsen perf in CPU-only cases, we should only do so after careful
      // benchmarks.
      if (impl_.type() != c10::kCPU) {
        actualStorages =
            storages.has_value() ? std::move(*storages) : extractStorages(value);
        usedDevices = getDevicesOfStorages(impl_, actualStorages);
        ensureIsSubsetOfDevices(usedDevices, devices_);
      }
    } catch (const std::exception&) {
      setError(std::current_exception());
      return;
    }

    std::unique_lock<std::mutex> lock(mutex_);
    TORCH_CHECK(
        !completed(),
        "Attempting to mark a completed Future as complete again. Note that "
        "a Future can only be marked completed once.");

    // Only set value_ and completed_ flag once all checks and preparation steps
    // have returned successfully to allow for proper error propagation.
    value_ = std::move(value);
    completed_ = true;

    currentDevice_ = impl_.getDevice();
    storages_ = std::move(actualStorages);
    for (const c10::Device& device : usedDevices) {
      c10::Event event(impl_.type());
      event.record(impl_.getStream(device));
      events_.push_back(std::move(event));
    }

    std::vector<FutureCallback> cbs;
    cbs.swap(callbacks_);
    lock.unlock();

    finished_cv_.notify_all();
    for (auto& callback : cbs) {
      invokeCallback(std::move(callback.callback), callback.uses_future);
    }
  }

  void markCompleted() {
    markCompleted(IValue{});
  }

  void setError(std::exception_ptr eptr) {
    std::unique_lock<std::mutex> lock(mutex_);
    setErrorInternal(std::move(eptr), lock);
  }

  void setErrorIfNeeded(std::exception_ptr eptr) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (completed_) {
      // This should be rare and shouldn't cause log spew. It's important to
      // log errors, and that's why we have this log here.
      std::string msg = c10::str(
          "Skipping setting following error on the Future since "
          "it is already marked completed (this is not necessarily "
          "an error):\n",
          tryRetrieveErrorMessageInternal(std::move(eptr)));
      if (eptr_) {
        msg += c10::str(
            ", \nOriginal exception:\n",
            tryRetrieveErrorMessageInternal(eptr_));
      }
      LOG(INFO) << msg;
      return;
    } else {
      setErrorInternal(std::move(eptr), lock);
    }
  }

  // Get the result of the current future.
  IValue value() {
    std::unique_lock<std::mutex> lock(mutex_);
    AT_ASSERT(completed());
    if (eptr_) {
      std::rethrow_exception(eptr_);
    }
    return value_;
  }

  // This accessor should only be used if we know that the future is
  // completed() with no error.
  const IValue& constValue() const {
    std::unique_lock<std::mutex> lock(mutex_);
    AT_ASSERT(completed());
    TORCH_INTERNAL_ASSERT(
      !eptr_,
      "value() accessor should only be used when future is not completed with ",
      "an error, but future had the following error: ",
      tryRetrieveErrorMessageInternal(eptr_)
    );
    return value_;
  }

  // This accessor should only be used if we know that the future is
  // completed() with no error.
  const std::vector<WeakStorage>& storages() const {
    std::unique_lock<std::mutex> lock(mutex_);
    AT_ASSERT(completed());
    AT_ASSERT(!eptr_);
    return storages_;
  }

  /**
   * Add a callback to the future.
   * The callbacks will be executed once the future completes.
   * If the future has already completed,
   * this function will execute the callback immediately.
   */
  template <typename T>
  void addCallback(T callback, bool uses_future = true) {
    static_assert(
        std::is_invocable_r<void, T, Future&>::value,
        "The callback must have signature void(Future&)");

    std::unique_lock<std::mutex> lock(mutex_);
    if (completed()) {
      lock.unlock();
      invokeCallback(std::move(callback), uses_future);
      return;
    }
    callbacks_.emplace_back(std::move(callback), uses_future);
  }
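
  // A minimal usage sketch of the API above (illustrative only; assumes a
  // CPU-only future, so no device synchronization is involved):
  //
  //   auto fut = c10::make_intrusive<ivalue::Future>(IntType::get());
  //   fut->addCallback(
  //       [](ivalue::Future& f) { /* runs once f completes */ },
  //       /*uses_future=*/true);
  //   fut->markCompleted(IValue(42));
  //   fut->wait();                      // returns immediately, already done
  //   int64_t v = fut->value().toInt(); // 42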

  /**
   * Add a callback to the future, and return another Future to hold the return
   * value of the callback. This is necessary when the callback provider needs
   * to know for sure when the callback has finished.
   */
  template <typename T>
  c10::intrusive_ptr<Future> then(T callback, TypePtr type) {
    using IValueWithStorages = std::tuple<IValue, std::vector<WeakStorage>>;
    static_assert(
        std::disjunction<
            std::is_invocable_r<IValue, T, Future&>,
            std::is_invocable_r<IValueWithStorages, T, Future&>>::value,
        "The callback must have signature IValue(Future&) or "
        "std::tuple<IValue, std::vector<Storage>>(Future&)");

    auto childFut = createInstance(::std::move(type));
    addCallback([childFut,
                 cb = std::move(callback)](Future& parentFut) mutable {
      try {
        if constexpr (::std::is_convertible_v<typename std::invoke_result_t<T &&, Future&>, IValueWithStorages>) {
          auto [ivalue, storages] = cb(parentFut);
          childFut->markCompleted(::std::move(ivalue), ::std::move(storages));
        } else {
          childFut->markCompleted(cb(parentFut));
        }
      } catch (std::exception&) {
        childFut->setError(std::current_exception());
      }
    });
    return childFut;
  }
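
  // A minimal chaining sketch for then() (illustrative only; `fut` is assumed
  // to be an existing c10::intrusive_ptr<ivalue::Future> holding an int):
  //
  //   auto child = fut->then(
  //       [](ivalue::Future& parent) -> IValue {
  //         return parent.value().toInt() + 1;
  //       },
  //       IntType::get());
  //   // child completes (or propagates the error) once the callback runs.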

  template <typename T>
  c10::intrusive_ptr<Future> thenAsync(T callback, TypePtr type) {
    static_assert(
        std::is_invocable_r<c10::intrusive_ptr<Future>, T, Future&>::value,
        "The callback must have signature c10::intrusive_ptr<Future>(Future&)");

    auto childFut = createInstance(std::move(type));
    addCallback(
        [childFut, cb = std::move(callback)](Future& parentFut) mutable {
          c10::intrusive_ptr<Future> intermediateFut;
          try {
            intermediateFut = cb(parentFut);
          } catch (std::exception&) {
            childFut->setError(std::current_exception());
            return;
          }
          intermediateFut->addCallback(
              [childFut = std::move(childFut)](Future& intermediateFut) {
                if (intermediateFut.hasError()) {
                  childFut->setError(intermediateFut.exception_ptr());
                } else {
                  childFut->markCompleted(
                      intermediateFut.value(), intermediateFut.storages());
                }
              });
        });
    return childFut;
  }

  // Tries to retrieve the error message from std::exception_ptr.
  std::string tryRetrieveErrorMessage() const {
    TORCH_CHECK(hasError(), "No error present on the future.");
    std::unique_lock<std::mutex> lock(mutex_);
    return tryRetrieveErrorMessageInternal(eptr_);
  }

  // Check if the current future has completed
  bool completed() const {
    return completed_;
  }

  bool hasValue() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return completed_ && !eptr_;
  }

  bool hasError() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return eptr_ ? true : false;
  }

  std::exception_ptr exception_ptr() const {
    std::unique_lock<std::mutex> lock(mutex_);
    return eptr_;
  }

  TORCH_API friend std::ostream& operator<<(
      std::ostream& out,
      const Future& v);

  const TypePtr& elementType() const {
    return type_;
  }

  const std::vector<c10::Device>& devices() const {
    return devices_;
  }

  // This method should be used when one intends to manually create a child
  // future, for example when implementing a customized version of then().
  c10::intrusive_ptr<Future> createInstance(at::TypePtr type) {
    return c10::make_intrusive<Future>(std::move(type), devices_);
  }

 private:

  // This method should always be used when invoking a callback (regardless of
  // how/when that happens) as it will ensure that the proper "environment" is
  // set up before running the callback, as in, it will set up the CUDA streams,
  // synchronize them with the value, and so on (if needed).
  template<typename T>
  void invokeCallback(T callback, bool uses_future) {
    static_assert(
        std::is_invocable_r<void, T, Future&>::value,
        "The callback must have signature void(Future&)");

    // The synchronization performed below shouldn't be needed when the future
    // is not used by the callback.
    if (uses_future) {
      c10::OptionalDeviceGuard deviceGuard(currentDevice_);

      std::vector<c10::Stream> streams;
      streams.reserve(devices_.size());
      for (const c10::Device& device : devices_) {
        streams.push_back(impl_.getStreamFromGlobalPool(device));
      }
      c10::MultiStreamGuard streamGuard(streams);
      synchronizeWithCurrentStreams();
      callback(*this);
    } else {
      callback(*this);
    }
  }

  // This method should be called before this future's value is used, as it
  // ensures that the CUDA streams that are "current" at the callsite properly
  // synchronize with the value.
  void synchronizeWithCurrentStreams() {
    for (c10::Event& event : events_) {
      event.block(impl_.getStream(event.device()));
    }

    for (const WeakStorage& weak_storage : storages_) {
      c10::intrusive_ptr<c10::StorageImpl> storage = weak_storage.lock();
      if (!storage) {
        continue;
      }
      if (!storage->device().is_cpu()) {
        impl_.recordDataPtrOnStream(
            storage->data_ptr(), impl_.getStream(storage->device()));
      }
    }
  }

  void setErrorInternal(
      std::exception_ptr eptr,
      std::unique_lock<std::mutex>& lock) {
    TORCH_CHECK(
        !eptr_,
        "Error already set on this Future: ",
        tryRetrieveErrorMessageInternal(eptr_),
        ", trying to set error: ",
        tryRetrieveErrorMessageInternal(eptr));
    TORCH_INTERNAL_ASSERT(!completed(), "Future is already marked completed");
    completed_ = true;
    eptr_ = std::move(eptr);

    std::vector<FutureCallback> cbs;
    cbs.swap(callbacks_);
    lock.unlock();

    finished_cv_.notify_all();
    for (auto& callback : cbs) {
      invokeCallback(std::move(callback.callback), callback.uses_future);
    }
  }

  // Tries to retrieve the error message from std::exception_ptr.
  std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) const {
    try {
      std::rethrow_exception(std::move(eptr));
    } catch (const std::exception& e) {
      return e.what();
    } catch (...) {
      return "Unknown Exception Type";
    }
  }

  // Defined in ivalue.cpp.
  static std::vector<WeakStorage> extractStorages(
      const at::IValue& value);

  static std::vector<c10::Device> getDevicesOfStorages(
      const c10::impl::VirtualGuardImpl& impl,
      const std::vector<WeakStorage>& storages) {
    c10::DeviceIndex deviceCount = impl.deviceCount();
    std::vector<bool> isDeviceUsed(deviceCount, false);
    for (const WeakStorage& weak_storage : storages) {
      c10::intrusive_ptr<c10::StorageImpl> storage = weak_storage.lock();
      if (!storage) {
        continue;
      }
      c10::Device device = storage->device();
      if (!device.is_cpu()) {
        TORCH_CHECK_VALUE(
            device.type() == impl.type(),
            "Expected all data ptrs to be on a device of type ",
            impl.type(),
            ", got one on device ",
            device);
        isDeviceUsed[device.index()] = true;
      }
    }
    std::vector<c10::Device> devices;
    for (c10::DeviceIndex idx = 0; idx < deviceCount; idx++) {
      if (isDeviceUsed[idx]) {
        devices.emplace_back(impl.type(), idx);
      }
    }
    return devices;
  }

  static std::string formatSetOfDevices(
      const std::vector<c10::Device>& devices) {
    if (devices.empty()) {
      return "(none)";
    }
    std::ostringstream oss;
    oss << devices[0];
    for (const auto idx : c10::irange(1, devices.size())) {
      if (idx == devices.size() - 1) {
        oss << " and ";
      } else {
        oss << ", ";
      }
      oss << devices[idx];
    }
    return oss.str();
  }

  static c10::DeviceType getTypeOfDevices(
      const std::vector<c10::Device>& devices) {
    if (devices.empty()) {
      return c10::kCPU;
    }
    c10::DeviceType deviceType = devices[0].type();
    for (const auto idx : c10::irange(1, devices.size())) {
      TORCH_CHECK_VALUE(
          devices[idx].type() == deviceType,
          "Expected all devices to be of the same type, but got a mismatch between ",
          devices[0],
          " and ",
          devices[idx]);
    }
    return deviceType;
  }

  // We need devices to be sorted in order to use ensureIsSubsetOfDevices.
  static std::vector<c10::Device> sortAndDeduplicateDevices(
      const c10::impl::VirtualGuardImpl& /*impl*/,
      std::vector<c10::Device> devices) {
    std::sort(
      devices.begin(), devices.end(),
      [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); });
    // Deduplicate by compacting.
    size_t targetIdx = 0;
    for (const auto sourceIdx : c10::irange(devices.size())) {
      TORCH_CHECK_VALUE(
          devices[sourceIdx].has_index(),
          "Expected devices to have indices, got ", devices[sourceIdx]);
      if (targetIdx > 0 && devices[targetIdx - 1].index() == devices[sourceIdx].index()) {
        // It's a duplicate, skip it.
        continue;
      }
      if (sourceIdx != targetIdx) {
        devices[targetIdx] = devices[sourceIdx];
      }
      targetIdx++;
    }
    // If there were duplicates there's now a gap at the end: trim it. Resizing
    // requires the item type to be default-constructible (which c10::Device is
    // not) because in principle it could be required to create new items. Since
    // we know we'll shrink the vector, we provide a custom dummy value instead.
    devices.resize(targetIdx, c10::Device(c10::kCPU));
    return devices;
  }

  static void ensureIsSubsetOfDevices(
      const std::vector<c10::Device>& subset,
      const std::vector<c10::Device>& superset) {
    // We assume the devices in both vectors have the same consistent type, and
    // their indices are unique and sorted.
    std::vector<c10::Device> excessDevices;
    std::set_difference(
        subset.begin(),
        subset.end(),
        superset.begin(),
        superset.end(),
        std::back_inserter(excessDevices),
        [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); });
    TORCH_CHECK_VALUE(
        excessDevices.empty(),
        "The result contained tensors residing on device(s) ",
        formatSetOfDevices(excessDevices),
        " which are not among the expected device(s) ",
        formatSetOfDevices(superset));
  }

  mutable std::mutex mutex_;
  std::atomic_bool completed_ = {false}; // is this future complete
  std::condition_variable finished_cv_;

  IValue value_; // when finished the value
  TypePtr type_;
  std::vector<FutureCallback> callbacks_;
  std::exception_ptr eptr_;

  // An upcast pointer to a virtual class which allows us to manipulate events,
  // streams, ... in a generic way, without an explicit dependency on CUDA.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const c10::impl::VirtualGuardImpl impl_;

  // The device that was current when markCompleted was called, which we'll
  // restore when invoking callbacks. It's optional because we'll only store it
  // if the future completes successfully.
  std::optional<c10::Device> currentDevice_;

  // The events that correspond to the completion of the async I/O kernels. They
  // are recorded on the appropriate streams when the future is marked completed
  // and can then be queried/waited/blocked on. There is one event for each
  // distinct device on which the value's tensors reside.
  std::vector<c10::Event> events_;

  // A cached version of the storages extracted from the value when the future
  // is first marked completed.
  std::vector<WeakStorage> storages_;

  // The bounding set of devices that this future, and any of its children, is
  // allowed to use. This is a superset of the set of devices used by the events
  // above. We need this to know what streams (for which devices) to set as
  // current when invoking a callback, thus allowing the callback to use devices
  // that the parent future didn't use. This field is set to the value provided
  // in the constructor and will be "inherited" by all child futures.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const std::vector<c10::Device> devices_;
};

struct C10_EXPORT ivalue::Await final : c10::intrusive_ptr_target {
 private:
  explicit Await(TypePtr elType, std::function<IValue()> fn)
      : elType_(std::move(elType)), type_(AwaitType::create(elType_)), fn_(std::move(fn)) {}

  explicit Await(TypePtr elType) : elType_(std::move(elType)), type_(AwaitType::create(elType_)) { }

  friend c10::intrusive_ptr<Await>;

 public:
  Await(const Await&) = delete;
  Await(Await&&) = delete;
  Await& operator=(const Await&) = delete;
  Await& operator=(Await&&) = delete;

  IValue wait() {
    if (!completed_) {
      TORCH_CHECK(fn_, "Incomplete Await: fn can't be None");
      value_ = fn_();
      completed_ = true;
      args_ = {};
    }
    return value_;
  }

  IValue value() {
    TORCH_CHECK(completed_, "Await must be completed");
    return value_;
  }

  void setFn(std::function<IValue()> fn) {
    fn_ = std::move(fn);
  }

  bool completed() {
    return completed_;
  }

  void markCompleted(IValue value) {
    value_ = std::move(value);
    completed_ = true;
  }

  TORCH_API friend std::ostream& operator<<(
      std::ostream& out,
      const Await& v);

  const TypePtr& elementType() const {
    return elType_;
  }

  const TypePtr& type() const {
    return type_;
  }

  void setArgs(std::vector<IValue> args) {
    args_ = std::move(args);
  }

  std::vector<IValue>& args() {
    return args_;
  }

 private:
  TypePtr elType_;
  TypePtr type_;
  std::vector<IValue> args_;
  std::function<IValue()> fn_;
  IValue value_;
  bool completed_{};
};
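
// A minimal usage sketch for Await (illustrative only): the wrapped function
// runs lazily, on the first wait() call, and the result is cached afterwards.
//
//   auto aw = c10::make_intrusive<ivalue::Await>(
//       IntType::get(), []() { return IValue(7); });
//   IValue v = aw->wait();   // invokes the function, marks the Await completed
//   IValue w = aw->value();  // returns the cached result without re-running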

// Input is a list of Futures with the same target type.
// Output is a Future to the List of completed Futures.
TORCH_API intrusive_ptr<ivalue::Future> collectAll(
    const c10::List<c10::intrusive_ptr<ivalue::Future>>& srcs);
// Input is a List of Futures with the same target type.
// Output is a Future that will be updated with a seen value.
TORCH_API intrusive_ptr<ivalue::Future> collectAny(
    const c10::List<c10::intrusive_ptr<ivalue::Future>>& srcs);

// User-defined object.
struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
 public:
  // In general, class types hold a shared_ptr to their owning CompilationUnit,
  // so that the type and its methods do not get deallocated while the class
  // still exists. However, the CompilationUnit holds ownership of the type's
  // graphs, so inserting a constant object into a Graph would create a
  // reference cycle if that constant object held a shared_ptr to its CU. For
  // these objects we instantiate them with non-owning references to their CU.
1490   Object(WeakOrStrongTypePtr type, size_t numSlots) : type_(std::move(type)) {
1491     slots_.resize(numSlots);
1492   }
1493 
1494   Object(StrongTypePtr type, size_t numSlots)
1495       : type_(WeakOrStrongTypePtr(std::move(type))) {
1496     slots_.resize(numSlots);
1497   }
1498 
1499   static c10::intrusive_ptr<Object> create(
1500       WeakOrStrongTypePtr type,
1501       size_t numSlots) {
1502     return c10::make_intrusive<Object>(std::move(type), numSlots);
1503   }
1504 
1505   static c10::intrusive_ptr<Object> create(
1506       StrongTypePtr type,
1507       size_t numSlots) {
1508     return c10::make_intrusive<Object>(std::move(type), numSlots);
1509   }
1510 
1511   static c10::intrusive_ptr<Object> create(ClassTypePtr classType, size_t numSlots);
1512 
1513   /**
1514    * Slot API.
1515    *
1516    * Attributes are stored as a simple vector so that lookups are fast at
1517    * runtime. A "slot" is just an index into that vector, which can be computed
1518    * statically if you have access to the class type. Use this API if you are
1519    * writing compiler stuff.
1520    */
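  // Usage sketch (assuming `obj` is a c10::intrusive_ptr<Object> and the slot
  // index is obtained from the class type, e.g. via ClassType::getAttributeSlot):
  //
  //   size_t slot = obj->type()->getAttributeSlot("counter");
  //   obj->setSlot(slot, IValue(1));
  //   const IValue& v = obj->getSlot(slot);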
1521   void setSlot(size_t slot, IValue v) {
1522     if (slot >= slots_.size()) {
1523       // For module types, it is possible that the members of the class have
1524       // expanded after the object was created. In this case, we expand
1525       // the slots to the right size.
1526       resizeObject(slot);
1527     }
1528     slots_[slot] = std::move(v);
1529   }
1530 
1531   const IValue& getSlot(size_t slot) const {
1532     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(slot < slots_.size());
1533     // NOTE: This lookup is fairly hot, so we use unchecked access to the
1534     // vector.  Errors should still be detectable with ASan.
1535     return slots_[slot];
1536   }
1537 
1538   void unsafeRemoveSlot(size_t slot) {
1539     TORCH_CHECK(slot < slots_.size());
1540     slots_.erase(slots_.begin() + static_cast<std::ptrdiff_t>(slot));
1541   }
1542 
1543   /**
1544    * Attribute API.
1545    *
1546    * Wrappers around the slot stuff so that users can access attributes
1547    * directly. Use this API if you are a user.
1548    *
1549    * Note: Unlike in Python, TorchScript must make a distinction between
1550    * attributes (which are IValues) and methods (which are Methods). If you
1551    * want a method, use `obj.type()->getMethod()`
1552    */
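  // Usage sketch (assuming `obj` is a c10::intrusive_ptr<ivalue::Object> whose
  // class type declares an attribute named "counter"):
  //
  //   obj->setAttr("counter", IValue(1));
  //   int64_t n = obj->getAttr("counter").toInt();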
1553   IValue getAttr(const std::string& name) const;
1554   void setAttr(const std::string& name, IValue v);
1555   // Remove an attribute by name; the caller is responsible for
1556   // the safety of this operation.
1557   // We do not remove the attribute from the type because the type
1558   // might be shared by multiple objects.
1559   // Therefore, after removing an attribute, the object is in an inconsistent
1560   // state where it has more attribute types in its Type than
1561   // attribute slots. The user needs to make the object
1562   // consistent again by removing the attribute from the type as well.
1563   void unsafeRemoveAttr(const std::string& name);
1564 
1565   std::string name() const;
1566 
1567   const std::vector<IValue>& slots() const {
1568     return slots_;
1569   }
1570   std::shared_ptr<ClassType> type() const;
1571 
1572   std::shared_ptr<torch::jit::CompilationUnit> compilation_unit() {
1573     if (type_.holds_strong_ref()) {
1574       return type_.cu_.getStrongRefOrThrow();
1575     } else {
1576       auto weak_ptr = type_.cu_.getWeakRefOrThrow();
1577       return std::shared_ptr<torch::jit::CompilationUnit>(weak_ptr);
1578     }
1579   }
1580 
1581   c10::intrusive_ptr<Object> copy_to_weak_compilation_ref() const;
1582 
1583   void unsafe_make_weak_compilation_ref() {
1584     type_ = WeakOrStrongTypePtr(type_.asWeakTypePtr());
1585   }
1586 
1587   c10::intrusive_ptr<Object> copy() const;
1588 
1589   c10::intrusive_ptr<Object> deepcopy(
1590       std::optional<at::Device> device = std::nullopt) const;
1591 
1592   c10::intrusive_ptr<Object> deepcopy(
1593       IValue::HashIdentityIValueMap& memo,
1594       std::optional<at::Device> device = std::nullopt) const;
1595 
1596   bool is_weak_compilation_ref() const {
1597     return !type_.holds_strong_ref();
1598   }
1599 
1600   bool is_empty_strong_compilation_ref() const {
1601     return type_.holds_empty_strong_ref();
1602   }
1603 
1604  private:
1605   void resizeObject(size_t slot);
1606   WeakOrStrongTypePtr type_;
1607   std::vector<IValue> slots_;
1608 };
1609 
1610 // Virtual ivalue::PyObjectHolder that holds a py::object. We make this virtual
1611 // because the py::object and refcounting logic should live in libtorch_python;
1612 // see the concrete implementation in python_ivalue.h.
1613 struct ivalue::PyObjectHolder : c10::intrusive_ptr_target {
1614  public:
1615   virtual PyObject* getPyObject() = 0;
1616   virtual c10::InferredType tryToInferType() = 0;
1617   virtual IValue toIValue(const TypePtr& type, std::optional<int32_t> N = std::nullopt) = 0;
1618   virtual std::string toStr() = 0;
1619   virtual std::vector<at::Tensor> extractTensors() = 0;
1620 
1621   ~PyObjectHolder() override = default;
1622 };
1623 
1624 struct ivalue::EnumHolder : c10::intrusive_ptr_target {
1625  public:
1626   EnumHolder(std::shared_ptr<EnumType> type, std::string name, IValue value)
1627       : type_(std::move(type)),
1628         name_(std::move(name)),
1629         value_(std::move(value)) {}
1630 
1631   bool is(const ivalue::EnumHolder& rhs) {
1632     return *this == rhs;
1633   }
1634 
1635   friend bool operator==(
1636       const ivalue::EnumHolder& lhs,
1637       const ivalue::EnumHolder& rhs);
1638 
1639   TORCH_API friend std::ostream& operator<<(
1640       std::ostream& out,
1641       const ivalue::EnumHolder& v);
1642 
1643   TORCH_API const std::string& qualifiedClassName() const;
1644 
1645   const std::string& unqualifiedClassName() const;
1646 
1647   const std::string& name() const {
1648     return name_;
1649   }
1650 
1651   const IValue& value() const {
1652     return value_;
1653   }
1654 
1655   std::shared_ptr<EnumType> type() const {
1656     return type_;
1657   }
1658 
1659  private:
1660   std::shared_ptr<EnumType> type_;
1661   std::string name_;
1662   IValue value_;
1663 };
1664 
1665 #undef TORCH_FORALL_TAGS
1666 
1667 namespace detail {
1668 
1669 struct _guarded_unsigned_long_unique_dummy final {
1670   _guarded_unsigned_long_unique_dummy(int64_t){};
1671 };
1672 using _guarded_unsigned_long = std::conditional_t<
1673     std::is_same_v<unsigned long, uint32_t> ||
1674         std::is_same_v<unsigned long, uint64_t>,
1675     _guarded_unsigned_long_unique_dummy,
1676     unsigned long>;
1677 
1678 } // namespace detail
1679 
1680 inline ivalue::Object& IValue::toObjectRef() const {
1681   AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
1682   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "Attempted to create null reference");
1683   return *static_cast<c10::ivalue::Object*>(payload.u.as_intrusive_ptr);
1684 }
1685 
1686 // Note: when adding a DEFINE_TO case here you should also add a
1687 // toX method to IValue. These named methods are much more discoverable
1688 // than the templated to() function.
1689 
1690 #define DEFINE_TO(T, method_name)                          \
1691   template <>                                              \
1692   inline T IValue::to<T>()&& {                             \
1693     return static_cast<T>(std::move(*this).method_name()); \
1694   }                                                        \
1695   template <>                                              \
1696   inline c10::detail::ivalue_to_const_ref_overload_return<T>::type IValue::to<T>() const& { \
1697     typedef c10::detail::ivalue_to_const_ref_overload_return<T>::type return_type;          \
1698     return static_cast<return_type>(this->method_name());                                   \
1699   }
1700 
1701 DEFINE_TO(at::Tensor, toTensor)
1702 DEFINE_TO(at::Storage, toStorage)
1703 DEFINE_TO(c10::Stream, toStream)
1704 DEFINE_TO(float, toDouble)
1705 DEFINE_TO(double, toDouble)
1706 DEFINE_TO(c10::complex<double>, toComplexDouble)
1707 DEFINE_TO(unsigned char, toInt)
1708 DEFINE_TO(signed char, toInt)
1709 DEFINE_TO(unsigned short, toInt)
1710 DEFINE_TO(short, toInt)
1711 DEFINE_TO(int, toInt)
1712 DEFINE_TO(uint32_t, toInt)
1713 DEFINE_TO(uint64_t, toInt)
1714 DEFINE_TO(detail::_guarded_unsigned_long, toInt)
1715 DEFINE_TO(int64_t, toInt)
1716 DEFINE_TO(bool, toBool)
1717 DEFINE_TO(c10::intrusive_ptr<caffe2::Blob>, toBlob)
1718 DEFINE_TO(c10::intrusive_ptr<ivalue::ConstantString>, toString)
1719 DEFINE_TO(c10::intrusive_ptr<ivalue::Object>, toObject)
1720 DEFINE_TO(at::Scalar, toScalar)
1721 DEFINE_TO(c10::List<int64_t>, toIntList)
1722 DEFINE_TO(c10::List<double>, toDoubleList)
1723 DEFINE_TO(c10::List<c10::complex<double>>, toComplexDoubleList)
1724 DEFINE_TO(c10::List<bool>, toBoolList)
1725 DEFINE_TO(c10::List<at::Tensor>, toTensorList)
1726 DEFINE_TO(c10::impl::GenericList, toList)
1727 DEFINE_TO(c10::impl::GenericDict, toGenericDict)
1728 DEFINE_TO(c10::intrusive_ptr<ivalue::Tuple>, toTuple)
1729 DEFINE_TO(std::string, toStringRef)
1730 DEFINE_TO(c10::string_view, toStringView)
1731 DEFINE_TO(c10::intrusive_ptr<ivalue::Future>, toFuture)
1732 DEFINE_TO(c10::intrusive_ptr<ivalue::Await>, toAwait)
1733 DEFINE_TO(c10::intrusive_ptr<c10::RRefInterface>, toRRef)
1734 DEFINE_TO(c10::intrusive_ptr<at::Quantizer>, toQuantizer)
1735 DEFINE_TO(IValue, toIValue)
1736 DEFINE_TO(c10::Device, toDevice)
1737 DEFINE_TO(at::ScalarType, toScalarType)
1738 DEFINE_TO(at::Layout, toLayout)
1739 DEFINE_TO(at::MemoryFormat, toMemoryFormat)
1740 DEFINE_TO(at::QScheme, toQScheme)
1741 DEFINE_TO(at::Dimname, toDimname)
1742 DEFINE_TO(at::Generator, toGenerator)
1743 DEFINE_TO(c10::SymInt, toSymInt)
1744 DEFINE_TO(c10::SymFloat, toSymFloat)
1745 DEFINE_TO(c10::SymBool, toSymBool)
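// Usage sketch: the DEFINE_TO cases above are what make the templated accessor
// work for these types; to<T>() is equivalent to calling the named method.
//
//   IValue iv(3.0);
//   double d = iv.to<double>();           // same as iv.toDouble()
//   int64_t i = IValue(7).to<int64_t>();  // same as IValue(7).toInt()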
1746 
1747 template <class T>
1748 struct _fake_type {};
1749 
1750 // generic_to<T> converts an IValue from a generic list or generic dict
1751 // to a concrete list/dict type like List<T>, Dict<...> or std::optional<T>.
1752 // Note that in the case of lists, this only works for IValue-based lists,
1753 // i.e. not for int64_t, double, ...
1754 // generic_to<T> is an implementation detail of IValue::to<T> and not
1755 // supposed to be called directly.
1756 // The _fake_type<T> parameter allows us to overload
1757 // based on the return type.
1758 template <class Elem>
1759 // TODO this is deprecated but we don't throw a warning because a lot of ops in
1760 // native_functions.yaml still return std::vector.
1761 // C10_DEPRECATED_MESSAGE("IValues based on std::vector<T> are potentially slow
1762 // and deprecated. Please use torch::List<T> instead.")
1763 std::vector<Elem> generic_to(IValue ivalue, _fake_type<std::vector<Elem>>) {
1764   // We need to do a deep copy of the vector because there might be other
1765   // references to this same IValue that also use the list. We can't just
1766   // move the elements out.
1767   auto list = std::move(ivalue).to<List<Elem>>();
1768   std::vector<Elem> result;
1769   result.reserve(list.size());
1770   for (Elem v : list) {
1771     result.push_back(std::move(v));
1772   }
1773   return result;
1774 }
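// Usage sketch: to<std::vector<T>>() goes through the overload above and deep
// copies the underlying list, so other references to the same IValue stay valid.
//
//   IValue iv(c10::List<int64_t>({1, 2, 3}));
//   std::vector<int64_t> copy = iv.to<std::vector<int64_t>>();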
1775 
1776 template <typename T>
1777 c10::intrusive_ptr<T> IValue::toCustomClass() && {
1778   static_assert(
1779       std::is_base_of<torch::CustomClassHolder, T>::value == true,
1780       "toCustomClass requires that template parameter T must inherit "
1781       "from torch::CustomClassHolder");
1782   auto obj = toObject();
1783   TORCH_CHECK(
1784       obj->slots().size() == 1,
1785       "Tried to cast IValue to custom class but it did "
1786       "not contain a custom class!");
1787   const auto* expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>().get();
1788   ivalue::checkCustomClassType(expected_type, type().get());
1789   auto userObj =
1790       c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
1791   return userObj;
1792 }
1793 
1794 template <typename T>
1795 c10::intrusive_ptr<T> IValue::toCustomClass() const& {
1796   static_assert(
1797       std::is_base_of<torch::CustomClassHolder, T>::value == true,
1798       "toCustomClass requires that template parameter T must inherit "
1799       "from torch::CustomClassHolder");
1800   auto obj = toObject();
1801   TORCH_CHECK(
1802       obj->slots().size() == 1,
1803       "Tried to cast IValue to custom class but it did "
1804       "not contain a custom class!");
1805   const auto* expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>().get();
1806   ivalue::checkCustomClassType(expected_type, type().get());
1807   auto userObj =
1808       c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
1809   return userObj;
1810 }
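// Usage sketch (assuming MyStackClass is a custom class registered via
// torch::class_<MyStackClass> and `iv` holds an instance of it):
//
//   c10::intrusive_ptr<MyStackClass> p = iv.toCustomClass<MyStackClass>();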
1811 
1812 template <typename T>
1813 T generic_to(IValue ivalue, _fake_type<T>) {
1814   using ElemType = typename std::remove_pointer<T>::type::element_type;
1815   return std::move(ivalue).toCustomClass<ElemType>();
1816 }
1817 
1818 template <typename T>
1819 tagged_capsule<T> generic_to(IValue ivalue, _fake_type<tagged_capsule<T>>) {
1820   return tagged_capsule<T>{std::move(ivalue)};
1821 }
1822 
1823 template <typename Elem>
1824 c10::List<Elem> generic_to(IValue ivalue, _fake_type<c10::List<Elem>>) {
1825   return impl::toTypedList<Elem>(std::move(ivalue).toList());
1826 }
1827 
1828 template <typename T>
1829 static T createVectorLikeFromList(const c10::detail::ListImpl* impl) {
1830   T result;
1831   result.reserve(impl->list.size());
1832   for (const auto & i : impl->list) {
1833     result.push_back(i.to<typename T::value_type>());
1834   }
1835   return result;
1836 }
1837 
1838 template <typename T>
1839 static std::vector<T> createVectorFromList(const c10::detail::ListImpl* impl) {
1840   return createVectorLikeFromList<std::vector<T>>(impl);
1841 }
1842 
1843 template <typename T>
1844 std::vector<T> createVectorFromList(const c10::List<T>& impl) {
1845   std::vector<T> result;
1846   result.reserve(impl.size());
1847   for (size_t i = 0, N = impl.size(); i < N; ++i) {
1848     result.push_back(impl[i]);
1849   }
1850   return result;
1851 }
1852 
1853 template <typename T>
1854 OptionalArray<T> generic_to(IValue ivalue, _fake_type<OptionalArray<T>>) {
1855   if (ivalue.isNone()) {
1856     return {};
1857   }
1858   return createVectorFromList<T>(
1859     std::move(ivalue).to<c10::List<T>>()
1860   );
1861 }
1862 
1863 namespace detail {
1864 template <typename Elem, size_t... I>
1865 std::array<Elem, sizeof...(I)> generic_to_array(
1866     IValue ivalue,
1867     _fake_type<std::array<Elem, sizeof...(I)>>,
1868     std::index_sequence<I...>) {
1869   // We need to do a deep copy of the array because there might be other
1870   // references to this same IValue that also use the list. We can't just
1871   // move the elements out.
1872   auto list = std::move(ivalue).to<List<Elem>>();
1873   TORCH_CHECK(
1874       list.size() == sizeof...(I),
1875       "Tried to convert a List with ",
1876       list.size(),
1877       " elements to a fixed-size array of size ",
1878       sizeof...(I));
1879   return {list[I]...};
1880 }
1881 } // namespace detail
1882 
1883 template <typename Elem, size_t N>
1884 std::array<Elem, N> generic_to(
1885     IValue ivalue,
1886     _fake_type<std::array<Elem, N>> ft) {
1887   return detail::generic_to_array(ivalue, ft, std::make_index_sequence<N>());
1888 }
1889 
1890 template <typename Key, typename Value>
1891 c10::Dict<Key, Value> generic_to(
1892     IValue ivalue,
1893     _fake_type<c10::Dict<Key, Value>>) {
1894   return impl::toTypedDict<Key, Value>(std::move(ivalue).toGenericDict());
1895 }
1896 
1897 template <typename K, typename V>
1898 C10_DEPRECATED_MESSAGE(
1899     "IValues based on std::unordered_map are slow and deprecated. Please use c10::Dict<K, V> instead.")
1900 std::unordered_map<K, V> generic_to(
1901     IValue ivalue,
1902     _fake_type<std::unordered_map<K, V>>) {
1903   std::unordered_map<K, V> specialized_dict;
1904 
1905   for (const auto& item : std::move(ivalue).toGenericDict()) {
1906     specialized_dict[item.key().template to<K>()] = item.value().template to<V>();
1907   }
1908 
1909   return specialized_dict;
1910 }
1911 
1912 template <typename T>
1913 std::optional<T> generic_to(IValue ivalue, _fake_type<std::optional<T>>) {
1914   if (ivalue.isNone()) {
1915     return std::nullopt;
1916   }
1917   return std::move(ivalue).to<T>();
1918 }
1919 
1920 namespace detail {
1921 template <typename Tuple, std::size_t... INDEX>
1922 Tuple generic_to_tuple_impl(
1923     const ivalue::TupleElements& t,
1924     std::index_sequence<INDEX...>) {
1925   return std::make_tuple(
1926       t[INDEX].to<typename std::tuple_element<INDEX, Tuple>::type>()...);
1927 }
1928 } // namespace detail
1929 
1930 template <
1931     typename... Args,
1932     typename Indices = std::make_index_sequence<sizeof...(Args)>,
1933     std::enable_if_t<
1934         !std::disjunction_v<
1935             std::is_lvalue_reference<Args>...,
1936             std::negation<std::is_constructible<IValue, Args>>...>,
1937         std::nullptr_t> = nullptr>
1938 std::tuple<Args...> generic_to(const IValue& ivalue, _fake_type<std::tuple<Args...>>) {
1939   const auto& vals = ivalue.toTupleRef().elements();
1940   TORCH_CHECK(vals.size() == sizeof...(Args));
1941   return detail::generic_to_tuple_impl<std::tuple<Args...>>(vals, Indices{});
1942 }
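// Usage sketch: a Tuple IValue unboxes into std::tuple when every element type
// is itself convertible via to<T>().
//
//   IValue tup(std::make_tuple(1, 2.0));
//   auto unpacked = tup.to<std::tuple<int64_t, double>>();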
1943 
1944 template <typename T>
1945 inline T IValue::to() && {
1946   return generic_to(std::move(*this), _fake_type<T>{});
1947 }
1948 
1949 template <>
1950 inline std::optional<c10::string_view> IValue::to() && {
1951   // In the default implementation, the IValue is destroyed with std::move.
1952   // But if the unboxed type is std::optional<string_view> we cannot destroy
1953   // the IValue.
1954   return generic_to(*this, _fake_type<std::optional<c10::string_view>>{});
1955 }
1956 
1957 template <typename T>
1958 inline typename c10::detail::ivalue_to_const_ref_overload_return<T>::type IValue::to() const& {
1959   return generic_to(*this, _fake_type<T>{});
1960 }
1961 
1962 inline c10::List<int64_t> IValue::toIntList() && {
1963   AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
1964   return c10::List<int64_t>(moveToIntrusivePtr<c10::detail::ListImpl>());
1965 }
1966 inline c10::List<int64_t> IValue::toIntList() const& {
1967   AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
1968   return c10::List<int64_t>(toIntrusivePtr<c10::detail::ListImpl>());
1969 }
1970 inline std::vector<int64_t> IValue::toIntVector() const {
1971   AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
1972   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1973       payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
1974       "called toIntVector on null intrusive_ptr IValue");
1975   return createVectorFromList<int64_t>(
1976       static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
1977 }
1978 inline std::vector<c10::SymInt> IValue::toSymIntVector() const {
1979   AT_ASSERT(isSymIntList() || isIntList(), "Expected SymIntList or IntList but got ", tagKind());
1980   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1981       payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
1982       "called toSymIntVector on null intrusive_ptr IValue");
1983   return createVectorFromList<c10::SymInt>(
1984       static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
1985 }
1986 inline at::DimVector IValue::toDimVector() const {
1987   AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
1988   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1989       payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
1990       "called toDimVector on null intrusive_ptr IValue");
1991   return createVectorLikeFromList<at::DimVector>(
1992       static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
1993 }
1994 inline c10::List<double> IValue::toDoubleList() && {
1995   AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
1996   return c10::List<double>(moveToIntrusivePtr<c10::detail::ListImpl>());
1997 }
1998 inline c10::List<double> IValue::toDoubleList() const& {
1999   AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
2000   return c10::List<double>(toIntrusivePtr<c10::detail::ListImpl>());
2001 }
2002 inline std::vector<double> IValue::toDoubleVector() const {
2003   AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
2004   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2005       payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2006       "called toDoubleVector on null intrusive_ptr IValue");
2007   return createVectorFromList<double>(
2008       static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
2009 }
2010 inline c10::List<c10::complex<double>> IValue::toComplexDoubleList() && {
2011   AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
2012   return c10::List<c10::complex<double>>(moveToIntrusivePtr<c10::detail::ListImpl>());
2013 }
2014 inline c10::List<c10::complex<double>> IValue::toComplexDoubleList() const& {
2015   AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
2016   return c10::List<c10::complex<double>>(toIntrusivePtr<c10::detail::ListImpl>());
2017 }
2018 inline std::vector<c10::complex<double>> IValue::toComplexDoubleVector() const {
2019   AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
2020   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2021       payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2022       "called toComplexDoubleVector on null intrusive_ptr IValue");
2023   return createVectorFromList<c10::complex<double>>(
2024       static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
2025 }
2026 inline c10::List<bool> IValue::toBoolList() && {
2027   AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind());
2028   return c10::List<bool>(moveToIntrusivePtr<c10::detail::ListImpl>());
2029 }
2030 inline c10::List<bool> IValue::toBoolList() const& {
2031   AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind());
2032   return c10::List<bool>(toIntrusivePtr<c10::detail::ListImpl>());
2033 }
2034 inline c10::List<at::Tensor> IValue::toTensorList() && {
2035   AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
2036   return c10::List<at::Tensor>(moveToIntrusivePtr<c10::detail::ListImpl>());
2037 }
2038 inline c10::List<at::Tensor> IValue::toTensorList() const& {
2039   AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
2040   return c10::List<at::Tensor>(toIntrusivePtr<c10::detail::ListImpl>());
2041 }
2042 inline std::vector<at::Tensor> IValue::toTensorVector() const {
2043   AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
2044   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2045       payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2046       "called toTensorVector on null intrusive_ptr IValue");
2047   return createVectorFromList<at::Tensor>(
2048       static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
2049 }
2050 inline c10::List<std::optional<at::Tensor>> IValue::toOptionalTensorList() && {
2051   AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind());
2052   return c10::List<std::optional<at::Tensor>>(moveToIntrusivePtr<c10::detail::ListImpl>());
2053 }
2054 inline c10::List<std::optional<at::Tensor>> IValue::toOptionalTensorList() const& {
2055   AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind());
2056   return c10::List<std::optional<at::Tensor>>(toIntrusivePtr<c10::detail::ListImpl>());
2057 }
2058 inline std::vector<std::optional<at::Tensor>> IValue::toOptionalTensorVector() const {
2059   AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind());
2060   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2061       payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2062       "called toOptionalTensorVector on null intrusive_ptr IValue");
2063   return createVectorFromList<std::optional<at::Tensor>>(
2064       static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
2065 }
2066 inline c10::List<IValue> IValue::toList() && {
2067   AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
2068   return c10::List<IValue>(moveToIntrusivePtr<c10::detail::ListImpl>());
2069 }
2070 inline c10::List<IValue> IValue::toList() const& {
2071   AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
2072   return c10::List<IValue>(toIntrusivePtr<c10::detail::ListImpl>());
2073 }
2074 inline c10::ArrayRef<IValue> IValue::toListRef() const {
2075   AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
2076   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2077       payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2078       "called toListRef on null intrusive_ptr IValue");
2079   return static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr)
2080       ->list;
2081 }
2082 inline c10::Dict<IValue, IValue> IValue::toGenericDict() && {
2083   AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind());
2084   return c10::Dict<IValue, IValue>(moveToIntrusivePtr<c10::detail::DictImpl>());
2085 }
2086 inline c10::Dict<IValue, IValue> IValue::toGenericDict() const& {
2087   AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind());
2088   return c10::Dict<IValue, IValue>(toIntrusivePtr<c10::detail::DictImpl>());
2089 }
2090 inline c10::intrusive_ptr<ivalue::Tuple> IValue::toTuple() && {
2091   AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind());
2092   return moveToIntrusivePtr<ivalue::Tuple>();
2093 }
2094 inline c10::intrusive_ptr<ivalue::Tuple> IValue::toTuple() const& {
2095   AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind());
2096   return toIntrusivePtr<ivalue::Tuple>();
2097 }
2098 inline ivalue::Tuple& IValue::toTupleRef() const {
2099   AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind());
2100   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2101       payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2102       "called toTupleRef on null intrusive_ptr IValue");
2103   return *static_cast<c10::ivalue::Tuple*>(
2104       payload.u.as_intrusive_ptr);
2105 }
2106 
2107 inline IValue::IValue(c10::intrusive_ptr<ivalue::Tuple> v)
2108     : tag(Tag::Tuple) {
2109   payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2110 }
2111 template <
2112     typename... Args,
2113     std::enable_if_t<
2114         !std::disjunction_v<
2115             std::is_lvalue_reference<Args>...,
2116             std::negation<std::is_constructible<IValue, Args>>...>,
2117         std::nullptr_t>>
2118 inline IValue::IValue(const std::tuple<Args...>& t)
2119     : IValue(c10::guts::apply(c10::ivalue::Tuple::create<const Args&...>, t)) {
2120 }
2121 
2122 template <
2123     typename... Args,
2124     std::enable_if_t<
2125         !std::disjunction_v<
2126             std::is_lvalue_reference<Args>...,
2127             std::negation<std::is_constructible<IValue, Args>>...>,
2128         std::nullptr_t>>
2129 inline IValue::IValue(std::tuple<Args...>&& t)
2130     : IValue(c10::guts::apply(c10::ivalue::Tuple::create<Args&&...>, std::move(t))) {
2131 }
2132 
2133 inline IValue::IValue(c10::intrusive_ptr<ivalue::ConstantString> v)
2134     : tag(Tag::String) {
2135   payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2136 }
2137 inline IValue::IValue(std::string v)
2138     : IValue(ivalue::ConstantString::create(std::move(v))) {}
2139 
2140 inline IValue::IValue(c10::impl::GenericList v)
2141     : tag(Tag::GenericList) {
2142   payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release());
2143 }
2144 
2145 template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2146 inline IValue::IValue(c10::List<T>&& v) : IValue(impl::toList<T>(std::move(v))) {}
2147 template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2148 inline IValue::IValue(const c10::List<T>& v) : IValue(impl::toList<T>(v)) {}
2149 template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2150 inline IValue::IValue(at::ArrayRef<T> v) : IValue(c10::List<T>()) {
2151   auto list = to<c10::List<T>>();
2152   list.reserve(v.size());
2153   for (const auto& e : v) {
2154     list.push_back(e);
2155   }
2156 }
2157 template <class T, IValue::enable_if_symint<T>>
2158 inline IValue::IValue(at::ArrayRef<T> v) : IValue() {
2159   auto vi = c10::asIntArrayRefSlowOpt(v);
2160   if (vi.has_value()) {
2161     // This list is entirely integers; ensure it is typed as
2162     // an IntList so toIntList works
2163     *this = IValue(*vi);
2164   } else {
2165     // This list has SymInts; type it as a SymInt list
2166     *this = IValue(impl::toList<c10::SymInt>(c10::List<c10::SymInt>()));
2167     auto list = to<c10::List<c10::SymInt>>();
2168     list.reserve(v.size());
2169     for (const auto& e : v) {
2170       list.push_back(e);
2171     }
2172   }
2173 }
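// Usage sketch for the ArrayRef<SymInt> constructor above: SymInts that are all
// backed by plain ints are stored as an IntList; only genuinely symbolic values
// produce a SymInt list.
//
//   std::vector<c10::SymInt> sizes{c10::SymInt(2), c10::SymInt(3)};
//   at::ArrayRef<c10::SymInt> ref(sizes);
//   IValue iv(ref);
//   bool is_int_list = iv.isIntList();   // true here, no symbolic values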
2174 template <class T, IValue::enable_if_symint<T>>
2175 inline IValue::IValue(at::OptionalArrayRef<T> mb_v) : IValue() {
2176   if (!mb_v.has_value()) return;
2177   *this = IValue(*mb_v);
2178 }
2179 template <class T, IValue::enable_if_symint<T>>
2180 inline IValue::IValue(const std::vector<T>& v) : IValue() {
2181   *this = IValue(at::ArrayRef<T>(v));
2182 }
2183 template <class T, IValue::enable_if_symint<T>>
2184 inline IValue::IValue(std::vector<T>&& v) : IValue() {
2185   auto vi = c10::asIntArrayRefSlowOpt(v);
2186   if (vi.has_value()) {
2187     // This list is entirely integers; ensure it is typed as
2188     // an IntList so toIntList works
2189     *this = IValue(*vi);
2190   } else {
2191     // This list has SymInts; type it as a SymInt list
2192     *this = IValue(impl::toList<c10::SymInt>(c10::List<c10::SymInt>()));
2193     auto list = to<c10::List<c10::SymInt>>();
2194     list.reserve(v.size());
2195     for (auto&& e : std::move(v)) {
2196       list.push_back(std::move(e));
2197     }
2198   }
2199 }
2200 template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2201 inline IValue::IValue(const std::vector<T>& v) : IValue(c10::List<T>()) {
2202   auto list = to<c10::List<T>>();
2203   list.reserve(v.size());
2204   for (const auto& e : v) {
2205     list.push_back(e);
2206   }
2207 }
2208 
2209 template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2210 inline IValue::IValue(std::vector<T>&& v) : IValue(c10::List<T>()) {
2211   auto list = to<c10::List<T>>();
2212   list.reserve(v.size());
2213   if constexpr (std::is_same_v<T, bool>) {
2214     for (auto e : v) {
2215       list.push_back(e);
2216     }
2217   } else {
2218     for (auto&& e : std::move(v)) {
2219       list.push_back(std::move(e));
2220     }
2221   }
2222 }
2223 
2224 template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2225 inline IValue::IValue(c10::OptionalArrayRef<T> v) : IValue() {
2226   if (v.has_value()) {
2227     *this = IValue(std::move(*v));
2228   }
2229 }
2230 
2231 template <class T, size_t N>
2232 inline IValue::IValue(std::array<T, N> v) : IValue(c10::List<T>()) {
2233   auto list = to<c10::List<T>>();
2234   list.reserve(v.size());
2235   for (auto& e : v) {
2236     list.push_back(std::move(e));
2237   }
2238 }
2239 
2240 template <class T, IValue::enable_if_ilist_is_ivalue_constructible<T>>
2241 inline IValue::IValue(c10::IListRef<T> v) : IValue() {
2242   constexpr bool boxed_type_constructs_ivalue =
2243       std::is_constructible<IValue, typename c10::IListRef<T>::boxed_type>::value;
2244   // First, we try to use the boxed value.
2245   // If we fail (either it's not in the boxed state, or its boxed type
2246   // cannot construct an IValue), we fall back to copying the list.
2247   if (boxed_type_constructs_ivalue && v.isBoxed()) {
2248     *this = IValue(impl::toList(v.toBoxed()));
2249   } else {
2250     c10::List<T> list;
2251     list.reserve(v.size());
2252     for (const auto& t : v) {
2253       list.push_back(t);
2254     }
2255     *this = IValue(impl::toList(std::move(list)));
2256   }
2257 }
2258 
2259 inline IValue::IValue(c10::impl::GenericDict v)
2260     : tag(Tag::GenericDict) {
2261   payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release());
2262 }
2263 template <class Key, class Value>
2264 inline IValue::IValue(c10::Dict<Key, Value> v)
2265     : IValue(impl::toGenericDict(std::move(v))) {}
2266 
2267 template <class Key, class Value>
2268 inline IValue::IValue(std::unordered_map<Key, Value> v)
2269     : IValue(Dict<Key, Value>()) {
2270   auto dict = to<c10::Dict<Key, Value>>();
2271   dict.reserve(v.size());
2272   for (auto& e : v) {
2273     dict.insert(std::move(e.first), std::move(e.second));
2274   }
2275 }
2276 
2277 template <class T, IValue::enable_if_ivalue_constructible<T>>
2278 inline IValue::IValue(std::optional<T> v) : IValue() {
2279   if (v.has_value()) {
2280     *this = IValue(std::move(*v));
2281   }
2282 }
2283 
2284 inline IValue::IValue(std::nullopt_t) : IValue() {}
2285 
2286 inline IValue::IValue(c10::intrusive_ptr<ivalue::Object> v)
2287     : tag(Tag::Object) {
2288   payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2289 }
2290 
2291 inline IValue::IValue(c10::intrusive_ptr<ivalue::PyObjectHolder> v)
2292     : tag(Tag::PyObject) {
2293   payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2294 }
2295 
2296 inline IValue::IValue(c10::intrusive_ptr<ivalue::EnumHolder> v)
2297     : tag(Tag::Enum) {
2298   payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2299 }
2300 
2301 inline IValue IValue::make_capsule(
2302     intrusive_ptr<torch::CustomClassHolder> blob) {
2303   IValue iv;
2304   iv.tag = Tag::Capsule;
2305   iv.payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release());
2306   return iv;
2307 }
2308 
2309 template <
2310     typename T,
2311     std::enable_if_t<std::is_base_of_v<torch::CustomClassHolder, T>, int>>
2312 IValue::IValue(c10::intrusive_ptr<T> custom_class) : tag(Tag::Object) {
2313   auto classType = []() {
2314     try {
2315       return c10::getCustomClassType<c10::intrusive_ptr<T>>();
2316     } catch (const c10::Error&) {
2317       throw c10::Error(
2318           "Trying to instantiate a class that isn't a registered custom class: " +
2319           std::string(c10::util::get_fully_qualified_type_name<T>()));
2320     }
2321   }();
2322   auto ivalue_obj = c10::ivalue::Object::create(std::move(classType), /* numSlots */1);
2323   ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class)));
2324   payload.u.as_intrusive_ptr = null_to_undefined_tensor(ivalue_obj.release());
2325 
2326 }
2327 
2328 inline IValue::IValue(c10::intrusive_ptr<ivalue::Future> v)
2329     : tag(Tag::Future) {
2330   payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2331 }
2332 
2333 inline IValue::IValue(c10::intrusive_ptr<ivalue::Await> v)
2334     : tag(Tag::Await) {
2335   payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2336 }
2337 
2338 inline IValue::IValue(c10::intrusive_ptr<c10::RRefInterface> v)
2339     : tag(Tag::RRef) {
2340   payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2341 }
2342 
2343 inline IValue::IValue(c10::intrusive_ptr<at::Quantizer> v)
2344     : tag(Tag::Quantizer) {
2345   payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2346 }
2347 
2348 template <typename T>
2349 inline IValue::IValue(c10::complex<T> c)
2350     : tag(Tag::ComplexDouble) {
2351   auto v = c10::make_intrusive<ivalue::ComplexHolder>(c);
2352   payload.u.as_intrusive_ptr = v.release();
2353 }
2354 
2355 inline const std::string& IValue::toStringRef() const {
2356   AT_ASSERT(isString(), "Expected String but got ", tagKind());
2357   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2358       payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2359       "called toStringRef on null intrusive_ptr IValue");
2360   return static_cast<const c10::ivalue::ConstantString*>(
2361              payload.u.as_intrusive_ptr)
2362       ->string();
2363 }
2364 inline std::optional<std::reference_wrapper<const std::string>> IValue::
2365     toOptionalStringRef() const {
2366   if (isNone()) {
2367     return std::nullopt;
2368   }
2369   AT_ASSERT(isString(), "Expected std::optional<string> but got ", tagKind());
2370   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2371       payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2372       "called toOptionalStringRef on null intrusive_ptr IValue");
2373   return std::reference_wrapper<const std::string>(
2374       static_cast<const c10::ivalue::ConstantString*>(payload.u.as_intrusive_ptr)
2375           ->string());
2376 }
2377 
2378 inline c10::string_view IValue::toStringView() const {
2379   AT_ASSERT(isString(), "Expected String but got ", tagKind());
2380   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2381       payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2382       "called toStringView on null intrusive_ptr IValue");
2383   return static_cast<const c10::ivalue::ConstantString*>(
2384         payload.u.as_intrusive_ptr)
2385     ->string_view();
2386 }
2387 
2388 inline PyObject* IValue::toPyObject() const {
2389   return toPyObjectHolder()->getPyObject();
2390 }
2391 
2392 template <typename T>
2393 inline std::optional<T> IValue::toOptional() {
2394   if (this->isNone()) {
2395     return std::nullopt;
2396   }
2397   return this->to<T>();
2398 }
2399 
2400 template <typename T>
2401 inline std::optional<T> IValue::toOptional() const {
2402   if (this->isNone()) {
2403     return std::nullopt;
2404   }
2405   return this->to<T>();
2406 }
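// Usage sketch: None maps to std::nullopt; anything else converts via to<T>().
//
//   IValue none;                                                  // default IValue is None
//   std::optional<int64_t> a = none.toOptional<int64_t>();        // std::nullopt
//   std::optional<int64_t> b = IValue(5).toOptional<int64_t>();   // 5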
2407 
2408 inline bool IValue::isCustomClass() const {
2409   return torch::isCustomClass(*this);
2410 }
2411 
2412 inline bool IValue::isSameIdentity(const IValue& rhs) const {
2413   // We choose not to use memcmp for the payload check due to potential random
2414   // padding bytes in the union type.
2415 
2416   // Semantics:
2417   // 1. Immutable primitive values of the same type (Int, Double, None, Bool,
2418   // Str) compare by value equality.
2419   // 2. If it is a tensor type, we need to take undefined tensors into account.
2420   // 3. An undefined tensor and None are the same identity, and vice versa.
2421   // 4. If it is a reference type (i.e. isIntrusivePtr()), it is True when
2422   // the pointed-to object is the same.
2423   // 5. False for all other comparisons.
2424   if (this->isNone() && rhs.isNone()) {
2425     return true;
2426   } else if (this->isBool() && rhs.isBool()) {
2427     // for bool type, do equality check
2428     return this->toBool() == rhs.toBool();
2429   } else if (this->isTensor() && rhs.isTensor()) {
2430     return this->payload.as_tensor.is_same(rhs.payload.as_tensor);
2431   } else if (this->isTensor() && rhs.isNone()) {
2432     // special case: undefined tensor and None are the same identity
2433     return !this->payload.as_tensor.defined();
2434   } else if (this->isNone() && rhs.isTensor()) {
2435     // special case: undefined tensor and None are the same identity
2436     return !rhs.payload.as_tensor.defined();
2437   } else if (this->isInt() && rhs.isInt()) {
2438     return this->toInt() == rhs.toInt();
2439   } else if (this->isDouble() && rhs.isDouble()) {
2440     return this->toDouble() == rhs.toDouble();
2441   } else if (this->isString() && rhs.isString()) {
2442     return this->toStringRef() == rhs.toStringRef();
2443   } else {
2444   // for objects held in IValue, do a shallow compare on the pointer address to
2445   // verify identity
2446     return this->isIntrusivePtr() && rhs.isIntrusivePtr() &&
2447         this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
2448   }
2449 }
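// Usage sketch of the semantics above:
//
//   IValue(1).isSameIdentity(IValue(1));             // true: value equality for ints
//   IValue(at::Tensor()).isSameIdentity(IValue());   // true: undefined tensor and None match
//   // Intrusive-ptr backed values (lists, dicts, objects) compare by pointer identity.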
2450 
2451 namespace ivalue {
2452 namespace detail {
2453 
2454 template <typename T>
2455 IValue from_(T&& x, std::true_type) {
2456   return IValue(std::forward<T>(x));
2457 }
2458 template <typename T>
2459 IValue from_(c10::intrusive_ptr<T> x, std::false_type) {
2460   return IValue(std::move(x));
2461 }
2462 template <typename T>
2463 IValue from_(T&& /*x*/, std::false_type) {
2464   static_assert(
2465       guts::false_t<T>::value,
2466       "You are calling from with a type that it doesn't support, and isn't a potential custom class (ie: is an intrusive_ptr)");
2467   return IValue();
2468 }
2469 } // namespace detail
2470 
2471 template <typename T>
2472 IValue from(T&& x) {
2473   return detail::from_(
2474       std::forward<T>(x), typename std::is_constructible<IValue, T>::type{});
2475 }
2476 
2477 } // namespace ivalue
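// Usage sketch: ivalue::from forwards anything IValue-constructible, and the
// intrusive_ptr overload covers registered custom classes.
//
//   IValue a = ivalue::from(3.14);
//   IValue b = ivalue::from(std::string("hello"));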
2478 
2479 
2480 template <>
2481 struct MaybeOwnedTraits<IValue> {
2482   using owned_type = IValue;
2483   using borrow_type = IValue;
2484 
2485   static borrow_type createBorrow(const owned_type& from) {
2486     if (!from.isPtrType()) {
2487       return from;
2488     }
2489     if (from.isTensor()) {
2490       return IValue(MaybeOwnedTraits<at::Tensor>::createBorrow(from.toTensor()));
2491     } else {
2492       return IValue(from.payload, from.tag);
2493     }
2494   }
2495 
2496   static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
2497     lhs.clearToNone();
2498     if (!rhs.isPtrType()) {
2499       lhs = rhs;
2500     } else if (rhs.isTensor()) {
2501       lhs = IValue(MaybeOwnedTraits<at::Tensor>::createBorrow(rhs.toTensor()));
2502     } else {
2503       lhs = IValue(rhs.payload, rhs.tag);
2504     }
2505   }
2506 
2507   static void destroyBorrow(borrow_type& toDestroy) {
2508     toDestroy.clearToNone();
2509   }
2510 
2511   static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
2512     return borrow;
2513   }
2514 
2515   static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
2516     return &borrow;
2517   }
2518 
2519   static bool debugBorrowIsValid(const borrow_type&) {
2520     return true;
2521   }
2522 };
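// Usage sketch (assuming c10::MaybeOwned<IValue> from c10/util/MaybeOwned.h):
// borrowing avoids a refcount bump for pointer-backed IValues.
//
//   IValue owned(std::string("payload"));
//   c10::MaybeOwned<IValue> borrowed = c10::MaybeOwned<IValue>::borrowed(owned);
//   const IValue& ref = *borrowed;   // valid only while `owned` is alive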
2523 
2524 template <>
2525 struct IValue::TagType<c10::Type> {
2526   static TORCH_API c10::TypePtr get(const IValue&);
2527 };
2528 
2529 template <>
2530 struct IValue::TagType<c10::DynamicType> {
2531   static TORCH_API c10::TypePtr get(const IValue&);
2532 };
2533 
2534 template <typename T>
2535 TypePtr IValue::type() const {
2536   return IValue::TagType<T>::get(*this);
2537 }
2538 
2539 } // namespace c10
2540