1 #pragma once
2
3 #include <condition_variable>
4 #include <memory>
5 #include <optional>
6 #include <type_traits>
7 #include <utility>
8
9 #include <ATen/core/Dict.h>
10 #include <ATen/core/List.h>
11 #include <ATen/core/IListRef.h>
12 #include <ATen/core/functional.h>
13 #include <ATen/core/jit_type.h>
14 #include <ATen/core/qualified_name.h>
15 #include <ATen/core/rref_interface.h>
16 #include <ATen/core/symbol.h>
17 #include <c10/core/DeviceGuard.h>
18 #include <c10/core/Event.h>
19 #include <c10/core/Scalar.h>
20 #include <c10/core/Stream.h>
21 #include <c10/core/StreamGuard.h>
22 #include <c10/core/TensorImpl.h>
23 #include <c10/core/UndefinedTensorImpl.h>
24 #include <c10/core/impl/DeviceGuardImplInterface.h>
25 #include <c10/util/FunctionRef.h>
26 #include <c10/util/Logging.h>
27 #include <c10/util/hash.h>
28 #include <c10/util/intrusive_ptr.h>
29 #include <c10/util/irange.h>
30
31 namespace torch {
32 namespace jit {
33 struct Function;
34 struct CompilationUnit;
35 } // namespace jit
36 TORCH_API bool isCustomClass(const c10::IValue& v);
37 } // namespace torch
38 namespace c10 {
39 struct IValue;
40 struct ClassType;
41 struct TupleType;
42 struct EnumType;
43 struct InferredType;
44
45 // For custom class __init__ registration, we need to pass in a function
46 // that looks like this: [](IValue x, args...)
47
48 // However, make_boxed_from_unboxed_functor.h automatically sets the input types
49 // of the function by introspecting the types of the functor (which is IValue in
50 // this case). However, we need the type it binds to be Foo.
51
52 // Instead, we pass in a lambda [](ivalue_holder<CurClass> x, args...) from
53 // which getTypePtr can recover the original class pointer.
54
55 template <typename TaggedCapsuleType>
56 struct tagged_capsule {
57 IValue ivalue;
58 };
59
60 template <class T, class NullType>
moveToIntrusivePtr()61 c10::intrusive_ptr<T, NullType> IValue::moveToIntrusivePtr() {
62 auto t = c10::intrusive_ptr<T, NullType>::reclaim(
63 payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()
64 ? NullType::singleton()
65 : static_cast<T*>(payload.u.as_intrusive_ptr));
66 clearToNone();
67 return t;
68 }
69 template <typename T, class NullType>
toIntrusivePtr()70 c10::intrusive_ptr<T, NullType> IValue::toIntrusivePtr() const {
71 if (payload.u.as_intrusive_ptr == c10::UndefinedTensorImpl::singleton()) {
72 return c10::intrusive_ptr<T, NullType>();
73 }
74 c10::raw::intrusive_ptr::incref(payload.u.as_intrusive_ptr);
75 return c10::intrusive_ptr<T, NullType>::reclaim(
76 static_cast<T*>(payload.u.as_intrusive_ptr));
77 }
78
79 template <class T, class U>
static_intrusive_pointer_cast(intrusive_ptr<U> r)80 intrusive_ptr<T> static_intrusive_pointer_cast(intrusive_ptr<U> r) {
81 return intrusive_ptr<T>::reclaim(static_cast<T*>(r.release()));
82 }
83
84 template <class T, class U>
dynamic_intrusive_pointer_cast(intrusive_ptr<U> r)85 intrusive_ptr<T> dynamic_intrusive_pointer_cast(intrusive_ptr<U> r) {
86 return intrusive_ptr<T>::reclaim(dynamic_cast<T*>(r.release()));
87 }
88
toFuture()89 inline c10::intrusive_ptr<ivalue::Future> IValue::toFuture() && {
90 AT_ASSERT(isFuture(), "Expected Future but got ", tagKind());
91 return moveToIntrusivePtr<ivalue::Future>();
92 }
toFuture()93 inline c10::intrusive_ptr<ivalue::Future> IValue::toFuture() const& {
94 AT_ASSERT(isFuture(), "Expected Future but got ", tagKind());
95 return toIntrusivePtr<ivalue::Future>();
96 }
toAwait()97 inline c10::intrusive_ptr<ivalue::Await> IValue::toAwait() && {
98 AT_ASSERT(isAwait(), "Expected Await but got ", tagKind());
99 return moveToIntrusivePtr<ivalue::Await>();
100 }
toAwait()101 inline c10::intrusive_ptr<ivalue::Await> IValue::toAwait() const& {
102 AT_ASSERT(isAwait(), "Expected Await but got ", tagKind());
103 return toIntrusivePtr<ivalue::Await>();
104 }
toRRef()105 inline c10::intrusive_ptr<c10::RRefInterface> IValue::toRRef() && {
106 AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind());
107 return moveToIntrusivePtr<c10::RRefInterface>();
108 }
toRRef()109 inline c10::intrusive_ptr<c10::RRefInterface> IValue::toRRef() const& {
110 AT_ASSERT(isRRef(), "Expected RRef but got ", tagKind());
111 return toIntrusivePtr<c10::RRefInterface>();
112 }
toQuantizer()113 inline c10::intrusive_ptr<at::Quantizer> IValue::toQuantizer() && {
114 AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind());
115 return moveToIntrusivePtr<at::Quantizer>();
116 }
toQuantizer()117 inline c10::intrusive_ptr<at::Quantizer> IValue::toQuantizer() const& {
118 AT_ASSERT(isQuantizer(), "Expected Quantizer but got ", tagKind());
119 return toIntrusivePtr<at::Quantizer>();
120 }
toString()121 inline c10::intrusive_ptr<ivalue::ConstantString> IValue::toString() && {
122 AT_ASSERT(isString(), "Expected String but got ", tagKind());
123 return moveToIntrusivePtr<ivalue::ConstantString>();
124 }
toString()125 inline c10::intrusive_ptr<ivalue::ConstantString> IValue::toString() const& {
126 AT_ASSERT(isString(), "Expected String but got ", tagKind());
127 return toIntrusivePtr<ivalue::ConstantString>();
128 }
toObject()129 inline c10::intrusive_ptr<ivalue::Object> IValue::toObject() && {
130 AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
131 return moveToIntrusivePtr<ivalue::Object>();
132 }
toObject()133 inline c10::intrusive_ptr<ivalue::Object> IValue::toObject() const& {
134 AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
135 return toIntrusivePtr<ivalue::Object>();
136 }
137 inline c10::intrusive_ptr<ivalue::PyObjectHolder> IValue::
toPyObjectHolder()138 toPyObjectHolder() && {
139 TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind());
140 return moveToIntrusivePtr<ivalue::PyObjectHolder>();
141 }
toPyObjectHolder()142 inline c10::intrusive_ptr<ivalue::PyObjectHolder> IValue::toPyObjectHolder()
143 const& {
144 TORCH_INTERNAL_ASSERT(isPyObject(), "Expected PyObject but got ", tagKind());
145 return toIntrusivePtr<ivalue::PyObjectHolder>();
146 }
toEnumHolder()147 inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() && {
148 TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind());
149 return moveToIntrusivePtr<ivalue::EnumHolder>();
150 }
toEnumHolder()151 inline c10::intrusive_ptr<ivalue::EnumHolder> IValue::toEnumHolder() const& {
152 TORCH_INTERNAL_ASSERT(isEnum(), "Expected Enum but got ", tagKind());
153 return toIntrusivePtr<ivalue::EnumHolder>();
154 }
toComplexDouble()155 inline c10::complex<double> IValue::toComplexDouble() const {
156 TORCH_INTERNAL_ASSERT(isComplexDouble(), "Expected ComplexDouble but got ", tagKind());
157 auto ptr = toIntrusivePtr<ivalue::ComplexHolder>();
158 return (*ptr).val;
159 }
toTensor()160 inline at::Tensor IValue::toTensor() && {
161 if (C10_UNLIKELY(!isTensor())) {
162 reportToTensorTypeError();
163 }
164 auto result = std::move(payload.as_tensor);
165 // As far as I can tell, omitting the usual explicit destructor call
166 // is not UB in and of itself, and it's a slight perf win. The
167 // destructor is a no-op, because the moved-from Tensor is
168 // effectively an intrusive_ptr in the null state, so we don't need
169 // the behavior for correctness reasons either. Leaving this
170 // explanatory comment, including commented-out destructor call, to
171 // make this abundantly clear.
172 //
173 // payload.as_tensor.~Tensor();
174 clearToNone();
175 return result;
176 }
toTensor()177 inline at::Tensor& IValue::toTensor() & {
178 if (C10_UNLIKELY(!isTensor())) {
179 reportToTensorTypeError();
180 }
181 return payload.as_tensor;
182 }
toTensor()183 inline const at::Tensor& IValue::toTensor() const& {
184 if (C10_UNLIKELY(!isTensor())) {
185 reportToTensorTypeError();
186 }
187 return payload.as_tensor;
188 }
toStorage()189 inline c10::Storage IValue::toStorage() && {
190 AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
191 return c10::Storage(
192 moveToIntrusivePtr<at::StorageImpl>());
193 }
toStorage()194 inline c10::Storage IValue::toStorage() const& {
195 AT_ASSERT(isStorage(), "Expected Storage but got ", tagKind());
196 return c10::Storage(toIntrusivePtr<at::StorageImpl>());
197 }
toStream()198 inline c10::Stream IValue::toStream() && {
199 AT_ASSERT(isStream(), "Expected Stream but got ", tagKind());
200 auto ptr = toIntrusivePtr<ivalue::StreamData3Holder>();
201 return c10::Stream::unpack3((*ptr).val.stream_id,
202 (*ptr).val.device_index,
203 (*ptr).val.device_type);
204 }
toStream()205 inline c10::Stream IValue::toStream() const& {
206 AT_ASSERT(isStream(), "Expected Stream but got ", tagKind());
207 auto ptr = toIntrusivePtr<ivalue::StreamData3Holder>();
208 return c10::Stream::unpack3((*ptr).val.stream_id,
209 (*ptr).val.device_index,
210 (*ptr).val.device_type);
211 }
toBlob()212 inline c10::intrusive_ptr<caffe2::Blob> IValue::toBlob() && {
213 AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind());
214 return moveToIntrusivePtr<caffe2::Blob>();
215 }
toBlob()216 inline c10::intrusive_ptr<caffe2::Blob> IValue::toBlob() const& {
217 AT_ASSERT(isBlob(), "Expected Blob but got ", tagKind());
218 return toIntrusivePtr<caffe2::Blob>();
219 ;
220 }
toCapsule()221 inline c10::intrusive_ptr<torch::CustomClassHolder> IValue::toCapsule() && {
222 TORCH_INTERNAL_ASSERT(isCapsule());
223 return moveToIntrusivePtr<torch::CustomClassHolder>();
224 }
toCapsule()225 inline c10::intrusive_ptr<torch::CustomClassHolder> IValue::toCapsule() const& {
226 TORCH_INTERNAL_ASSERT(isCapsule());
227 return toIntrusivePtr<torch::CustomClassHolder>();
228 }
toGenerator()229 inline at::Generator IValue::toGenerator() && {
230 AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind());
231 return at::Generator(moveToIntrusivePtr<at::GeneratorImpl>());
232 }
toGenerator()233 inline at::Generator IValue::toGenerator() const& {
234 AT_ASSERT(isGenerator(), "Expected Generator but got ", tagKind());
235 return at::Generator(toIntrusivePtr<at::GeneratorImpl>());
236 }
toSymInt()237 inline c10::SymInt IValue::toSymInt() && {
238 AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind());
239 if (isSymInt()) {
240 return c10::SymInt(moveToIntrusivePtr<c10::SymNodeImpl>());
241 } else {
242 return c10::SymInt(payload.u.as_int);
243 }
244 }
toSymInt()245 inline c10::SymInt IValue::toSymInt() const& {
246 AT_ASSERT(isSymInt() || isInt(), "Expected SymInt or int but got ", tagKind());
247 if (isSymInt()) {
248 return c10::SymInt(toIntrusivePtr<c10::SymNodeImpl>());
249 } else {
250 return c10::SymInt(payload.u.as_int);
251 }
252 }
toSymFloat()253 inline c10::SymFloat IValue::toSymFloat() && {
254 AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind());
255 if (isSymFloat()) {
256 return c10::SymFloat(moveToIntrusivePtr<c10::SymNodeImpl>());
257 } else {
258 return c10::SymFloat(payload.u.as_double);
259 }
260 }
toSymFloat()261 inline c10::SymFloat IValue::toSymFloat() const& {
262 AT_ASSERT(isSymFloat() || isDouble(), "Expected SymFloat or double but got ", tagKind());
263 if (isSymFloat()) {
264 return c10::SymFloat(toIntrusivePtr<c10::SymNodeImpl>());
265 } else {
266 return c10::SymFloat(payload.u.as_double);
267 }
268 }
toSymBool()269 inline c10::SymBool IValue::toSymBool() && {
270 AT_ASSERT(isSymBool() || isBool(), "Expected SymBool or boolean but got ", tagKind());
271 if (isSymBool()) {
272 return c10::SymBool(moveToIntrusivePtr<c10::SymNodeImpl>());
273 } else {
274 return c10::SymBool(payload.u.as_bool);
275 }
276 }
277
toSymBool()278 inline c10::SymBool IValue::toSymBool() const& {
279 AT_ASSERT(isSymBool() || isBool(), "Expected SymBool or boolean but got ", tagKind());
280 if (isSymBool()) {
281 return c10::SymBool(toIntrusivePtr<c10::SymNodeImpl>());
282 } else {
283 return c10::SymBool(payload.u.as_bool);
284 }
285 }
286
287 namespace ivalue {
288
289 void TORCH_API
290 checkCustomClassType(const ClassType* expected_type, const Type* actual_type);
291
292 template <typename T>
293 using Shared = c10::intrusive_ptr<T>;
294
295 // string
296 struct TORCH_API ConstantString final : c10::intrusive_ptr_target {
297 private:
298 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
299 const std::string str_;
300
301 public:
ConstantStringfinal302 ConstantString(std::string str) : str_(std::move(str)) {}
ConstantStringfinal303 ConstantString(c10::string_view str) : str_(std::string(str)) {}
304 static c10::intrusive_ptr<ConstantString> create(std::string str_);
305 static c10::intrusive_ptr<ConstantString> create(c10::string_view str_);
306 static c10::intrusive_ptr<ConstantString> create(const char* str_);
307
stringfinal308 const std::string& string() const {
309 return str_;
310 }
string_viewfinal311 c10::string_view string_view() const {
312 return str_;
313 }
314
315 operator const std::string&() const {
316 return string();
317 }
318 TORCH_API friend std::ostream& operator<<(
319 std::ostream& out,
320 const ConstantString& v);
321 };
322
323 struct Future;
324
325 struct TORCH_API TupleElements {
326 private:
327 size_t inlineSize_;
328 // We represent TupleElements this way to save doing a heap
329 // allocation in the common (at least for unpickling) case where we
330 // have only 3 elements. We have our own union instead of
331 // c10::SmallVector<IValue> because c10::SmallVector<IValue> always
332 // stores the begin/end/capacity pointers, which would be a waste of
333 // space in our use case.
334 union {
335 std::vector<IValue> elementsVector_;
336 // Don't want to declare a std::array because the convenient
337 // iteration and size members are a footgun in this case -- the
338 // actual size of the array may be smaller than 3!
339 // NOLINTNEXTLINE(*c-arrays*)
340 IValue elementsInline_[3];
341 };
342
destroyInlineTupleElements343 void destroyInline() {
344 for (const auto ii : c10::irange(inlineSize_)) {
345 elementsInline_[ii].~IValue();
346 }
347 }
348 public:
349
350 using iterator = IValue*;
351 using const_iterator = const IValue*;
352
TupleElementsTupleElements353 TupleElements() : inlineSize_(0) {
354 new (&elementsVector_) std::vector<IValue>();
355 }
356
TupleElementsTupleElements357 explicit TupleElements(std::vector<IValue> elements)
358 : inlineSize_(0), elementsVector_(std::move(elements)) {}
359
TupleElementsTupleElements360 explicit TupleElements(c10::ArrayRef<IValue> elements)
361 : inlineSize_(elements.size() <= 3 ? elements.size() : 0) {
362 switch (inlineSize_) {
363 case 3:
364 new (&elementsInline_[2]) IValue(elements[2]);
365 [[fallthrough]];
366 case 2:
367 new (&elementsInline_[1]) IValue(elements[1]);
368 [[fallthrough]];
369 case 1:
370 new (&elementsInline_[0]) IValue(elements[0]);
371 break;
372 case 0:
373 new (&elementsVector_) std::vector<IValue>(elements.begin(), elements.end());
374 break;
375 }
376 }
377
TupleElementsTupleElements378 explicit TupleElements(IValue&& e1)
379 : inlineSize_(1) {
380 new (&elementsInline_[0]) IValue(std::move(e1));
381 }
382
TupleElementsTupleElements383 explicit TupleElements(IValue&& e1, IValue&& e2)
384 : inlineSize_(2) {
385 new (&elementsInline_[0]) IValue(std::move(e1));
386 new (&elementsInline_[1]) IValue(std::move(e2));
387 }
388
TupleElementsTupleElements389 explicit TupleElements(IValue&& e1, IValue&& e2, IValue&& e3)
390 : inlineSize_(3) {
391 new (&elementsInline_[0]) IValue(std::move(e1));
392 new (&elementsInline_[1]) IValue(std::move(e2));
393 new (&elementsInline_[2]) IValue(std::move(e3));
394 }
395
~TupleElementsTupleElements396 ~TupleElements() {
397 if (inlineSize_) {
398 destroyInline();
399 } else {
400 elementsVector_.~vector();
401 }
402 }
403
404 // It would be nice to make this noncopyable to prevent people from
405 // writing code like `auto output =
406 // forward(...).toTupleRef().elements()` (which does refcount bumps on
407 // each element, unlike the more efficient but verbose
408 // ```
409 // auto outputIntrusivePtr = forward(...).toTuple();
410 // const auto& output = outputIntrusivePtr->elements();
411 // ```
412 // ), but there is simply an overwhelming amount of code that does
413 // it the inefficient way.
414 // See also operator std::vector below.
TupleElementsTupleElements415 TupleElements(const TupleElements& rhs)
416 : inlineSize_(rhs.inlineSize_) {
417 if (rhs.inlineSize_) {
418 for (const auto ii : c10::irange(inlineSize_)) {
419 new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]);
420 }
421 } else {
422 new (&elementsVector_) std::vector<IValue>(rhs.elementsVector_);
423 }
424 }
425
426 TupleElements& operator=(const TupleElements& rhs) {
427 if (inlineSize_) {
428 if (rhs.inlineSize_) {
429 for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) {
430 elementsInline_[ii] = rhs.elementsInline_[ii];
431 }
432 if (rhs.inlineSize_ > inlineSize_) {
433 for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) {
434 new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]);
435 }
436 } else {
437 for (const auto ii : c10::irange(rhs.inlineSize_, inlineSize_)) {
438 elementsInline_[ii].~IValue();
439 }
440 }
441 } else {
442 destroyInline();
443 new (&elementsVector_) std::vector<IValue>(rhs.elementsVector_);
444 }
445 } else {
446 if (rhs.inlineSize_) {
447 elementsVector_.~vector();
448 for (const auto ii : c10::irange(rhs.inlineSize_)) {
449 new (&elementsInline_[ii]) IValue(rhs.elementsInline_[ii]);
450 }
451 } else {
452 elementsVector_ = rhs.elementsVector_;
453 }
454 }
455 inlineSize_ = rhs.inlineSize_;
456 return *this;
457 }
458
TupleElementsTupleElements459 TupleElements(TupleElements&& rhs) noexcept
460 : inlineSize_(rhs.inlineSize_) {
461 if (inlineSize_) {
462 for (const auto ii : c10::irange(inlineSize_)) {
463 new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii]));
464 }
465 } else {
466 new (&elementsVector_) std::vector<IValue>(std::move(rhs.elementsVector_));
467 }
468 }
469
470 TupleElements& operator=(TupleElements&& rhs) noexcept {
471 if (inlineSize_) {
472 if (rhs.inlineSize_) {
473 for (const auto ii : c10::irange(std::min(inlineSize_, rhs.inlineSize_))) {
474 elementsInline_[ii] = std::move(rhs.elementsInline_[ii]);
475 }
476 if (rhs.inlineSize_ > inlineSize_) {
477 for (const auto ii : c10::irange(inlineSize_, rhs.inlineSize_)) {
478 new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii]));
479 }
480 } else {
481 for (const auto ii : c10::irange(rhs.inlineSize_, inlineSize_)) {
482 elementsInline_[ii].~IValue();
483 }
484 }
485 } else {
486 destroyInline();
487 new (&elementsVector_) std::vector<IValue>(std::move(rhs.elementsVector_));
488 }
489 } else {
490 if (rhs.inlineSize_) {
491 elementsVector_.~vector();
492 for (const auto ii : c10::irange(rhs.inlineSize_)) {
493 new (&elementsInline_[ii]) IValue(std::move(rhs.elementsInline_[ii]));
494 }
495 } else {
496 elementsVector_ = std::move(rhs.elementsVector_);
497 }
498 }
499 inlineSize_ = rhs.inlineSize_;
500 return *this;
501 }
502
asArrayRefTupleElements503 C10_NODISCARD c10::ArrayRef<IValue> asArrayRef() const {
504 if (inlineSize_) {
505 return c10::ArrayRef<IValue>(elementsInline_, inlineSize_);
506 } else {
507 return elementsVector_;
508 }
509 }
510
511 // Mimic implicit conversion from std::vector to ArrayRef.
512 operator c10::ArrayRef<IValue>() const {
513 return asArrayRef();
514 }
515
hashTupleElements516 static size_t hash(const TupleElements& v) {
517 return c10::hash<c10::ArrayRef<IValue>>()(v.asArrayRef());
518 }
519
setContentsTupleElements520 void setContents(std::vector<IValue>&& contents) {
521 if (inlineSize_) {
522 destroyInline();
523 new (&elementsVector_) std::vector<IValue>(std::move(contents));
524 inlineSize_ = 0;
525 } else {
526 elementsVector_ = std::move(contents);
527 }
528 }
529
emptyTupleElements530 C10_NODISCARD bool empty() const {
531 return inlineSize_ ? false : elementsVector_.empty();
532 }
533
sizeTupleElements534 C10_NODISCARD size_t size() const {
535 return inlineSize_ ? inlineSize_ : elementsVector_.size();
536 }
537
538 C10_NODISCARD IValue& operator[](size_t idx) {
539 if (inlineSize_) {
540 return elementsInline_[idx];
541 } else {
542 return elementsVector_[idx];
543 }
544 }
545
546 C10_NODISCARD const IValue& operator[](size_t idx) const {
547 if (inlineSize_) {
548 return elementsInline_[idx];
549 } else {
550 return elementsVector_[idx];
551 }
552 }
553
atTupleElements554 C10_NODISCARD IValue& at(size_t idx) {
555 if (inlineSize_) {
556 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3);
557 TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_);
558 return elementsInline_[idx];
559 } else {
560 return elementsVector_.at(idx);
561 }
562 }
563
atTupleElements564 C10_NODISCARD const IValue& at(size_t idx) const {
565 if (inlineSize_) {
566 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3);
567 TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_);
568 return elementsInline_[idx];
569 } else {
570 TORCH_CHECK(idx < elementsVector_.size(), "TupleElements: invalid index Index = ", idx, "; Length = ", elementsVector_.size());
571 return elementsVector_.at(idx);
572 }
573 }
574
beginTupleElements575 C10_NODISCARD iterator begin() {
576 if (inlineSize_) {
577 return elementsInline_;
578 } else {
579 return elementsVector_.data();
580 }
581 }
582
endTupleElements583 C10_NODISCARD iterator end() {
584 if (inlineSize_) {
585 return elementsInline_ + inlineSize_;
586 } else {
587 return elementsVector_.data() + elementsVector_.size();
588 }
589 }
590
beginTupleElements591 C10_NODISCARD const_iterator begin() const {
592 if (inlineSize_) {
593 return elementsInline_;
594 } else {
595 return elementsVector_.data();
596 }
597 }
598
endTupleElements599 C10_NODISCARD const_iterator end() const {
600 if (inlineSize_) {
601 return elementsInline_ + inlineSize_;
602 } else {
603 return elementsVector_.data() + elementsVector_.size();
604 }
605 }
606
cbeginTupleElements607 C10_NODISCARD const_iterator cbegin() const {
608 return begin();
609 }
610
cendTupleElements611 C10_NODISCARD const_iterator cend() const {
612 return end();
613 }
614
vecTupleElements615 C10_NODISCARD std::vector<IValue> vec() const & {
616 return asArrayRef().vec();
617 }
618
backTupleElements619 C10_NODISCARD IValue& back() {
620 return *(end() - 1);
621 }
622
backTupleElements623 C10_NODISCARD const IValue& back() const {
624 return *(end() - 1);
625 }
626
vecTupleElements627 C10_NODISCARD std::vector<IValue> vec() && {
628 std::vector<IValue> result;
629 result.reserve(size());
630 for (auto&& iv : *this) {
631 result.push_back(std::move(iv));
632 }
633 return result;
634 }
635
636 // More compatibility shims for the overwhelming amount of code that
637 // likes to copy tuple elements into a vector; see comment above the
638 // copy constructor.
639 operator std::vector<IValue>() const & {
640 return vec();
641 }
642
643 operator std::vector<IValue>() && {
644 return vec();
645 }
646 };
647
648 template <typename T>
649 struct TupleTypeFactory {};
650
651 template <>
652 struct TORCH_API TupleTypeFactory<TupleType> {
653 static TupleTypePtr create(std::vector<TypePtr> types) {
654 return TupleType::create(std::move(types));
655 }
656 static TupleTypePtr fallback(const Type& type);
657 };
658
659 template <>
660 struct TORCH_API TupleTypeFactory<c10::DynamicType> {
661 static DynamicTypePtr create(const std::vector<TypePtr>& elemTypes);
662 static DynamicTypePtr fallback(const Type&);
663 };
664
665 struct TORCH_API Tuple : c10::intrusive_ptr_target {
666 private:
667 TupleElements elements_;
668 mutable c10::TypePtr type_; // lazily computed for unnamed tuples
669
670 public:
671 // named tuples have additional type information, so we
672 // directly create them tagged
673 static c10::intrusive_ptr<Tuple> createNamed(
674 std::vector<IValue> elements_,
675 c10::TypePtr type_) {
676 return c10::make_intrusive<Tuple>(std::move(elements_), std::move(type_));
677 }
678
679 static c10::intrusive_ptr<Tuple> createNamed(
680 TupleElements elements_,
681 std::shared_ptr<TupleType> type_) {
682 return c10::make_intrusive<Tuple>(std::move(elements_), std::move(type_));
683 }
684
685 static c10::intrusive_ptr<Tuple> createNamed(
686 std::initializer_list<IValue> elements_,
687 std::shared_ptr<TupleType> type_) {
688 return createNamed(TupleElements(c10::ArrayRef<IValue>(elements_)), std::move(type_));
689 }
690
691 // MSVC apparently can't disambiguate the other two overloads of
692 // create when passed an initializer_list without this.
693 static c10::intrusive_ptr<Tuple> create(std::initializer_list<IValue> elements_) {
694 return create(c10::ArrayRef<IValue>(elements_));
695 }
696
697 static c10::intrusive_ptr<Tuple> create(std::vector<IValue> elements_) {
698 return c10::make_intrusive<Tuple>(std::move(elements_));
699 }
700
701 static c10::intrusive_ptr<Tuple> create(TupleElements elements_) {
702 return c10::make_intrusive<Tuple>(std::move(elements_));
703 }
704
705 static c10::intrusive_ptr<Tuple> create(c10::ArrayRef<IValue> elements_) {
706 return create(TupleElements(elements_));
707 }
708
709 static c10::intrusive_ptr<Tuple> create(IValue e1) {
710 return c10::make_intrusive<Tuple>(std::move(e1));
711 }
712
713 static c10::intrusive_ptr<Tuple> create(IValue e1, IValue e2) {
714 return c10::make_intrusive<Tuple>(std::move(e1), std::move(e2));
715 }
716
717 static c10::intrusive_ptr<Tuple> create(IValue e1, IValue e2, IValue e3) {
718 return c10::make_intrusive<Tuple>(std::move(e1), std::move(e2), std::move(e3));
719 }
720
721 private:
722 // Workaround inability to use `>` operator in template argument list.
723 template <typename... Args>
724 static constexpr bool hasMoreThanThreeArgs() {
725 return sizeof...(Args) > 3;
726 }
727
728 public:
729 template <typename... Args>
730 static c10::intrusive_ptr<Tuple> create(Args&&... elements_) {
731 switch (sizeof...(Args)) {
732 case 1:
733 case 2:
734 case 3:
735 return create(IValue(std::forward<Args>(elements_))...);
736 default:
737 return create(
738 std::vector<IValue>{IValue(std::forward<Args>(elements_))...});
739 }
740 }
741
742 // Again, it would be nice to make this noncopyable, but there's a
743 // lot of extant code that copies Tuples.
744 // Tuple(const Tuple& rhs) = delete;
745
746 const TupleElements& elements() const& {
747 return elements_;
748 }
749
750 TupleElements elements() && {
751 return std::move(elements_);
752 }
753
754 void setElements(std::vector<IValue>&& elements) {
755 elements_.setContents(std::move(elements));
756 }
757
758 void setElements(TupleElements&& elements) {
759 elements_ = std::move(elements);
760 }
761
762 void unsafeSetElement(size_t idx, const IValue& element) {
763 elements_[idx] = element;
764 }
765
766 void unsafeSetElement(size_t idx, IValue&& element) {
767 elements_[idx] = std::move(element);
768 }
769
770 size_t size() const {
771 return elements_.size();
772 }
773
774 template <typename T = c10::TupleType>
775 std::shared_ptr<T> type() const {
776 if (!type_) {
777 type_ = TupleTypeFactory<T>::create(fmap(elements(), [&](const IValue& v) {
778 return v.type<typename T::ElementType>();
779 }));
780 }
781 if (auto t = type_->cast<T>()) {
782 return t;
783 }
784 return TupleTypeFactory<T>::fallback(*type_);
785 }
786
787 static size_t hash(const Tuple& t) {
788 return c10::get_hash(t.elements());
789 }
790
791 TORCH_API friend bool operator==(
792 const ivalue::Tuple& lhs,
793 const ivalue::Tuple& rhs);
794
795 private:
796 // NOTE: If we try to avoid the overloads without
797 // `std::shared_ptr<TupleType> type` by defaulting it to nullptr, we
798 // end up having to call (part of) the shared_ptr destructor for
799 // `type` even though we should know statically it won't do
800 // anything.
801 explicit Tuple(std::vector<IValue> elements)
802 : elements_(std::move(elements)){}
803
804 explicit Tuple(std::vector<IValue> elements, c10::TypePtr type)
805 : elements_(std::move(elements)), type_(std::move(type)) {}
806
807 explicit Tuple(TupleElements&& elements)
808 : elements_(std::move(elements)) {}
809
810 explicit Tuple(TupleElements&& elements, std::shared_ptr<TupleType> type)
811 : elements_(std::move(elements)), type_(std::move(type)) {}
812
813 explicit Tuple(IValue&& e1)
814 : elements_(std::move(e1)) {}
815
816 explicit Tuple(IValue&& e1, std::shared_ptr<TupleType> type)
817 : elements_(std::move(e1)), type_(std::move(type)) {}
818
819 explicit Tuple(IValue&& e1, IValue&& e2)
820 : elements_(std::move(e1), std::move(e2)) {}
821
822 explicit Tuple(IValue&& e1, IValue&& e2, std::shared_ptr<TupleType> type)
823 : elements_(std::move(e1), std::move(e2)), type_(std::move(type)) {}
824
825 explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3)
826 : elements_(std::move(e1), std::move(e2), std::move(e3)) {}
827
828 explicit Tuple(IValue&& e1, IValue&& e2, IValue&& e3, std::shared_ptr<TupleType> type)
829 : elements_(std::move(e1), std::move(e2), std::move(e3)), type_(std::move(type)) {}
830
831 friend class c10::intrusive_ptr<Tuple>;
832 };
833
834 struct Object;
835 struct PyObjectHolder;
836 struct EnumHolder;
837 } // namespace ivalue
838
839 // Future
840 struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
841 private:
842 // Keep this private in order to force users to go through make_intrusive and
843 // thus prevent creating a Future that's not held by an intrusive_ptr.
844 explicit Future(TypePtr type, std::vector<c10::Device> devices={})
845 : type_(std::move(type)),
846 impl_(getTypeOfDevices(devices)),
847 devices_(sortAndDeduplicateDevices(impl_, std::move(devices))) {}
848
849 friend c10::intrusive_ptr<Future>;
850
851 struct FutureCallback {
852 std::function<void(Future&)> callback;
853 bool uses_future; // whether the Future& passed in is actually used
854
855 template <typename T>
856 FutureCallback(T callback, bool uses_future)
857 : callback(std::move(callback)), uses_future(uses_future) {}
858 };
859
860 public:
861 Future(const Future&) = delete;
862 Future(Future&&) = delete;
863 Future& operator=(const Future&) = delete;
864 Future& operator=(Future&&) = delete;
865
866 struct TORCH_API FutureError final : public std::exception {
867 explicit FutureError(std::string&& error_msg_)
868 : error_msg(std::move(error_msg_)) {}
869
870 FutureError() = default;
871
872 const char* what() const noexcept override {
873 return error_msg.c_str();
874 }
875
876 std::string error_msg;
877 };
878
879 /**
880 * Wait on the future until it completes.
881 */
882 void wait() {
883 std::unique_lock<std::mutex> lock(mutex_);
884 finished_cv_.wait(lock, [&]() -> bool { return completed_; });
885 synchronizeWithCurrentStreams();
886 }
887
888 /**
889 * Wait on the future until it completes and throw an
890 * exception if an error exists.
891 */
892 void waitAndThrow() {
893 wait();
894
895 if (eptr_) {
896 std::rethrow_exception(eptr_);
897 }
898 }
899
900 /**
901 * Explicitly mark the future as completed with the output value. Optionally,
902 * the storages for all tensors in IValue can be passed as well. The DataPtrs
903 * of these storages are used to synchronize CUDA streams. If storages isn't
904 * given we will attempt to extract it from the value, if we need to (this
905 * happens if a non-empty set of devices was given to the constructor). Thus
906 * one only needs to provide storages when 1) they cannot be extracted through
907 * IValue::getSubValues() or through pickling in case of Python object; or
908 * when 2) customized storage extraction is more efficient.
909 */
910 using WeakStorage = c10::weak_intrusive_ptr<c10::StorageImpl>;
911 void markCompleted(
912 IValue value,
913 std::optional<std::vector<WeakStorage>> storages = std::nullopt) {
914 // Start by performing all steps that can throw, before setting any field.
915 // Do this before even acquiring the mutex, because extractStorages might
916 // acquire the GIL, which could lead to a lock inversion with our mutex.
917 // See https://github.com/pytorch/pytorch/issues/58239.
918 std::vector<WeakStorage> actualStorages;
919 std::vector<c10::Device> usedDevices;
920 try {
921 // FIXME We should always extract DataPtrs, in order to catch the case of
922 // users using CUDA values but forgetting to set devices, which currently
923 // leads to a silent synchronization/correctness issue. However, as this
924 // might worsen perf in CPU-only cases, we should only do so after careful
925 // benchmarks.
926 if (impl_.type() != c10::kCPU) {
927 actualStorages =
928 storages.has_value() ? std::move(*storages) : extractStorages(value);
929 usedDevices = getDevicesOfStorages(impl_, actualStorages);
930 ensureIsSubsetOfDevices(usedDevices, devices_);
931 }
932 } catch (const std::exception&) {
933 setError(std::current_exception());
934 return;
935 }
936
937 std::unique_lock<std::mutex> lock(mutex_);
938 TORCH_CHECK(
939 !completed(),
940 "Attempting to mark a completed Future as complete again. Note that "
941 "a Future can only be marked completed once.");
942
943 // Only set value_ and completed_ flag once all checks and preparation steps
944 // have returned successfully to allow for proper error propagation.
945 value_ = std::move(value);
946 completed_ = true;
947
948 currentDevice_ = impl_.getDevice();
949 storages_ = std::move(actualStorages);
950 for (const c10::Device& device : usedDevices) {
951 c10::Event event(impl_.type());
952 event.record(impl_.getStream(device));
953 events_.push_back(std::move(event));
954 }
955
956 std::vector<FutureCallback> cbs;
957 cbs.swap(callbacks_);
958 lock.unlock();
959
960 finished_cv_.notify_all();
961 for (auto& callback : cbs) {
962 invokeCallback(std::move(callback.callback), callback.uses_future);
963 }
964 }
965
966 void markCompleted() {
967 markCompleted(IValue{});
968 }
969
970 void setError(std::exception_ptr eptr) {
971 std::unique_lock<std::mutex> lock(mutex_);
972 setErrorInternal(std::move(eptr), lock);
973 }
974
975 void setErrorIfNeeded(std::exception_ptr eptr) {
976 std::unique_lock<std::mutex> lock(mutex_);
977 if (completed_) {
978 // This should be rare and shouldn't cause log spew. Its important to
979 // log errors and thats why we have this log here.
980 std::string msg = c10::str(
981 "Skipping setting following error on the Future since "
982 "it is already marked completed (this is not necessarily "
983 "an error):\n",
984 tryRetrieveErrorMessageInternal(std::move(eptr)));
985 if (eptr_) {
986 msg += c10::str(
987 ", \nOriginal exception:\n",
988 tryRetrieveErrorMessageInternal(eptr_));
989 }
990 LOG(INFO) << msg;
991 return;
992 } else {
993 setErrorInternal(std::move(eptr), lock);
994 }
995 }
996
997 // Get the result of the current future.
998 IValue value() {
999 std::unique_lock<std::mutex> lock(mutex_);
1000 AT_ASSERT(completed());
1001 if (eptr_) {
1002 std::rethrow_exception(eptr_);
1003 }
1004 return value_;
1005 }
1006
1007 // This accessor should only be used if we know that the future is
1008 // completed() with no error.
1009 const IValue& constValue() const {
1010 std::unique_lock<std::mutex> lock(mutex_);
1011 AT_ASSERT(completed());
1012 TORCH_INTERNAL_ASSERT(
1013 !eptr_,
1014 "value() accessor should only be used when future is not completed with ",
1015 "an error, but future had the following error: ",
1016 tryRetrieveErrorMessageInternal(eptr_)
1017 );
1018 return value_;
1019 }
1020
1021 // This accessor should only be used if we know that the future is
1022 // completed() with no error.
1023 const std::vector<WeakStorage>& storages() const {
1024 std::unique_lock<std::mutex> lock(mutex_);
1025 AT_ASSERT(completed());
1026 AT_ASSERT(!eptr_);
1027 return storages_;
1028 }
1029
1030 /**
1031 * Add a callback to the future.
1032 * The callbacks will be executed once the future completes.
1033 * If the future has already completed,
1034 * this function will execute the callback immediately.
1035 */
1036 template <typename T>
1037 void addCallback(T callback, bool uses_future = true) {
1038 static_assert(
1039 std::is_invocable_r<void, T, Future&>::value,
1040 "The callback must have signature void(Future&)");
1041
1042 std::unique_lock<std::mutex> lock(mutex_);
1043 if (completed()) {
1044 lock.unlock();
1045 invokeCallback(std::move(callback), uses_future);
1046 return;
1047 }
1048 callbacks_.emplace_back(std::move(callback), uses_future);
1049 }
1050
1051 /**
1052 * Add a callback to the future, and return another Future to hold the return
1053 * value of the callback. This is necessary when the callback provider needs
1054 * to know for sure when the callback has finished.
1055 */
1056 template <typename T>
1057 c10::intrusive_ptr<Future> then(T callback, TypePtr type) {
1058 using IValueWithStorages = std::tuple<IValue, std::vector<WeakStorage>>;
1059 static_assert(
1060 std::disjunction<
1061 std::is_invocable_r<IValue, T, Future&>,
1062 std::is_invocable_r<IValueWithStorages, T, Future&>>::value,
1063 "The callback must have signature IValue(Future&) or "
1064 "std::tuple<IValue, std::vector<Storage>>(Future&)");
1065
1066 auto childFut = createInstance(::std::move(type));
1067 addCallback([childFut,
1068 cb = std::move(callback)](Future& parentFut) mutable {
1069 try {
1070 if constexpr (::std::is_convertible_v<typename std::invoke_result_t<T &&, Future&>, IValueWithStorages>) {
1071 auto [ivalue, storages] = cb(parentFut);
1072 childFut->markCompleted(::std::move(ivalue), ::std::move(storages));
1073 } else {
1074 childFut->markCompleted(cb(parentFut));
1075 }
1076 } catch (std::exception&) {
1077 childFut->setError(std::current_exception());
1078 }
1079 });
1080 return childFut;
1081 }
1082
1083 template <typename T>
1084 c10::intrusive_ptr<Future> thenAsync(T callback, TypePtr type) {
1085 static_assert(
1086 std::is_invocable_r<c10::intrusive_ptr<Future>, T, Future&>::value,
1087 "The callback must have signature c10::intrusive_ptr<Future>(Future&)");
1088
1089 auto childFut = createInstance(std::move(type));
1090 addCallback(
1091 [childFut, cb = std::move(callback)](Future& parentFut) mutable {
1092 c10::intrusive_ptr<Future> intermediateFut;
1093 try {
1094 intermediateFut = cb(parentFut);
1095 } catch (std::exception&) {
1096 childFut->setError(std::current_exception());
1097 return;
1098 }
1099 intermediateFut->addCallback(
1100 [childFut = std::move(childFut)](Future& intermediateFut) {
1101 if (intermediateFut.hasError()) {
1102 childFut->setError(intermediateFut.exception_ptr());
1103 } else {
1104 childFut->markCompleted(
1105 intermediateFut.value(), intermediateFut.storages());
1106 }
1107 });
1108 });
1109 return childFut;
1110 }
1111
1112 // Tries to retrieve the error message from std::exception_ptr.
1113 std::string tryRetrieveErrorMessage() const {
1114 TORCH_CHECK(hasError(), "No error present on the future.");
1115 std::unique_lock<std::mutex> lock(mutex_);
1116 return tryRetrieveErrorMessageInternal(eptr_);
1117 }
1118
1119 // Check if the current future has completed
1120 bool completed() const {
1121 return completed_;
1122 }
1123
1124 bool hasValue() const {
1125 std::unique_lock<std::mutex> lock(mutex_);
1126 return completed_ && !eptr_;
1127 }
1128
1129 bool hasError() const {
1130 std::unique_lock<std::mutex> lock(mutex_);
1131 return eptr_ ? true : false;
1132 }
1133
1134 std::exception_ptr exception_ptr() const {
1135 std::unique_lock<std::mutex> lock(mutex_);
1136 return eptr_;
1137 }
1138
1139 TORCH_API friend std::ostream& operator<<(
1140 std::ostream& out,
1141 const Future& v);
1142
1143 const TypePtr& elementType() const {
1144 return type_;
1145 }
1146
1147 const std::vector<c10::Device>& devices() const {
1148 return devices_;
1149 }
1150
1151 // This method should be used when one intends to manually create a child
1152 // future, for example when implementing a customized version of then().
1153 c10::intrusive_ptr<Future> createInstance(at::TypePtr type) {
1154 return c10::make_intrusive<Future>(std::move(type), devices_);
1155 }
1156
1157 private:
1158
1159 // This method should always be used when invoking a callback (regardless of
1160 // how/when that happens) as it will ensure that the proper "environment" is
1161 // set up before running the callback, as in, it will set up the CUDA streams,
1162 // synchronize them with the value, and so on (if needed).
1163 template<typename T>
1164 void invokeCallback(T callback, bool uses_future) {
1165 static_assert(
1166 std::is_invocable_r<void, T, Future&>::value,
1167 "The callback must have signature void(Future&)");
1168
1169 // The synchronization performed below shouldn't be needed when the future
1170 // is not used by the callback.
1171 if (uses_future) {
1172 c10::OptionalDeviceGuard deviceGuard(currentDevice_);
1173
1174 std::vector<c10::Stream> streams;
1175 streams.reserve(devices_.size());
1176 for (const c10::Device& device : devices_) {
1177 streams.push_back(impl_.getStreamFromGlobalPool(device));
1178 }
1179 c10::MultiStreamGuard streamGuard(streams);
1180 synchronizeWithCurrentStreams();
1181 callback(*this);
1182 } else {
1183 callback(*this);
1184 }
1185 }
1186
1187 // This method should be called before this future's value is used, as it
1188 // ensures that the CUDA streams that are "current" at the callsite properly
1189 // synchronize with the value.
1190 void synchronizeWithCurrentStreams() {
1191 for (c10::Event& event : events_) {
1192 event.block(impl_.getStream(event.device()));
1193 }
1194
1195 for (const WeakStorage& weak_storage : storages_) {
1196 c10::intrusive_ptr<c10::StorageImpl> storage = weak_storage.lock();
1197 if (!storage) {
1198 continue;
1199 }
1200 if (!storage->device().is_cpu()) {
1201 impl_.recordDataPtrOnStream(
1202 storage->data_ptr(), impl_.getStream(storage->device()));
1203 }
1204 }
1205 }
1206
1207 void setErrorInternal(
1208 std::exception_ptr eptr,
1209 std::unique_lock<std::mutex>& lock) {
1210 TORCH_CHECK(
1211 !eptr_,
1212 "Error already set on this Future: ",
1213 tryRetrieveErrorMessageInternal(eptr_),
1214 ", trying to set error: ",
1215 tryRetrieveErrorMessageInternal(eptr));
1216 TORCH_INTERNAL_ASSERT(!completed(), "Future is already marked completed");
1217 completed_ = true;
1218 eptr_ = std::move(eptr);
1219
1220 std::vector<FutureCallback> cbs;
1221 cbs.swap(callbacks_);
1222 lock.unlock();
1223
1224 finished_cv_.notify_all();
1225 for (auto& callback : cbs) {
1226 invokeCallback(std::move(callback.callback), callback.uses_future);
1227 }
1228 }
1229
1230 // Tries to retrieve the error message from std::exception_ptr.
1231 std::string tryRetrieveErrorMessageInternal(std::exception_ptr eptr) const {
1232 try {
1233 std::rethrow_exception(std::move(eptr));
1234 } catch (const std::exception& e) {
1235 return e.what();
1236 } catch (...) {
1237 return "Unknown Exception Type";
1238 }
1239 }
1240
1241 // Defined in ivalue.cpp.
1242 static std::vector<WeakStorage> extractStorages(
1243 const at::IValue& value);
1244
1245 static std::vector<c10::Device> getDevicesOfStorages(
1246 const c10::impl::VirtualGuardImpl& impl,
1247 const std::vector<WeakStorage>& storages) {
1248 c10::DeviceIndex deviceCount = impl.deviceCount();
1249 std::vector<bool> isDeviceUsed(deviceCount, false);
1250 for (const WeakStorage& weak_storage : storages) {
1251 c10::intrusive_ptr<c10::StorageImpl> storage = weak_storage.lock();
1252 if (!storage) {
1253 continue;
1254 }
1255 c10::Device device = storage->device();
1256 if (!device.is_cpu()) {
1257 TORCH_CHECK_VALUE(
1258 device.type() == impl.type(),
1259 "Expected all data ptrs to be on a device of type ",
1260 impl.type(),
1261 ", got one on device ",
1262 device);
1263 isDeviceUsed[device.index()] = true;
1264 }
1265 }
1266 std::vector<c10::Device> devices;
1267 for (c10::DeviceIndex idx = 0; idx < deviceCount; idx++) {
1268 if (isDeviceUsed[idx]) {
1269 devices.emplace_back(impl.type(), idx);
1270 }
1271 }
1272 return devices;
1273 }
1274
1275 static std::string formatSetOfDevices(
1276 const std::vector<c10::Device>& devices) {
1277 if (devices.empty()) {
1278 return "(none)";
1279 }
1280 std::ostringstream oss;
1281 oss << devices[0];
1282 for (const auto idx : c10::irange(1, devices.size())) {
1283 if (idx == devices.size() - 1) {
1284 oss << " and ";
1285 } else {
1286 oss << ", ";
1287 }
1288 oss << devices[idx];
1289 }
1290 return oss.str();
1291 }
1292
1293 static c10::DeviceType getTypeOfDevices(
1294 const std::vector<c10::Device>& devices) {
1295 if (devices.empty()) {
1296 return c10::kCPU;
1297 }
1298 c10::DeviceType deviceType = devices[0].type();
1299 for (const auto idx : c10::irange(1, devices.size())) {
1300 TORCH_CHECK_VALUE(
1301 devices[idx].type() == deviceType,
1302 "Expected all devices to be of the same type, but got a mismatch between ",
1303 devices[0],
1304 " and ",
1305 devices[idx]);
1306 }
1307 return deviceType;
1308 }
1309
1310 // We need devices to be sorted in order to use ensureIsSubsetOfDevices.
1311 static std::vector<c10::Device> sortAndDeduplicateDevices(
1312 const c10::impl::VirtualGuardImpl& /*impl*/,
1313 std::vector<c10::Device> devices) {
1314 std::sort(
1315 devices.begin(), devices.end(),
1316 [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); });
1317 // Deduplicate by compacting.
1318 size_t targetIdx = 0;
1319 for (const auto sourceIdx : c10::irange(devices.size())) {
1320 TORCH_CHECK_VALUE(
1321 devices[sourceIdx].has_index(),
1322 "Expected devices to have indices, got ", devices[sourceIdx]);
1323 if (targetIdx > 0 && devices[targetIdx - 1].index() == devices[sourceIdx].index()) {
1324 // It's a duplicate, skip it.
1325 continue;
1326 }
1327 if (sourceIdx != targetIdx) {
1328 devices[targetIdx] = devices[sourceIdx];
1329 }
1330 targetIdx++;
1331 }
1332 // If there were duplicates there's now a gap at the end: trim it. Resizing
1333 // requires the item type to be default-constructible (which c10::Device is
1334 // not) because in principle it could be required to create new items. Since
1335 // we know we'll shrink the vector, we provide a custom dummy value instead.
1336 devices.resize(targetIdx, c10::Device(c10::kCPU));
1337 return devices;
1338 }
1339
1340 static void ensureIsSubsetOfDevices(
1341 const std::vector<c10::Device>& subset,
1342 const std::vector<c10::Device>& superset) {
1343 // We assume the devices in both vectors have the same consistent type, and
1344 // their indices are unique and sorted.
1345 std::vector<c10::Device> excessDevices;
1346 std::set_difference(
1347 subset.begin(),
1348 subset.end(),
1349 superset.begin(),
1350 superset.end(),
1351 std::back_inserter(excessDevices),
1352 [](const c10::Device& a, const c10::Device& b) { return a.index() < b.index(); });
1353 TORCH_CHECK_VALUE(
1354 excessDevices.empty(),
1355 "The result contained tensors residing on device(s) ",
1356 formatSetOfDevices(excessDevices),
1357 " which are not among the expected device(s) ",
1358 formatSetOfDevices(superset));
1359 }
1360
1361 mutable std::mutex mutex_;
1362 std::atomic_bool completed_ = {false}; // is this future complete
1363 std::condition_variable finished_cv_;
1364
1365 IValue value_; // when finished the value
1366 TypePtr type_;
1367 std::vector<FutureCallback> callbacks_;
1368 std::exception_ptr eptr_;
1369
1370 // An upcast pointer to a virtual class which allows us to manipulate events,
1371 // streams, ... in a generic way, without an explicit dependency on CUDA.
1372 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
1373 const c10::impl::VirtualGuardImpl impl_;
1374
1375 // The device that was current when markCompleted was called, which we'll
1376 // restore when invoking callbacks. It's optional because we'll only store it
1377 // if the future completes successfully.
1378 std::optional<c10::Device> currentDevice_;
1379
1380 // The events that correspond to the completion of the async I/O kernels. They
1381 // are recorded on the appropriate streams when the future is marked completed
1382 // and can then be queried/waited/blocked on. There is one event for each
1383 // distinct device on which the value's tensors reside.
1384 std::vector<c10::Event> events_;
1385
1386 // A cached version of the storages extracted from the value when the future
1387 // is first marked completed.
1388 std::vector<WeakStorage> storages_;
1389
1390 // The bounding set of devices that this future, and any of its children, is
1391 // allowed to use. This is a superset of the set of devices used by the events
1392 // above. We need this to know what streams (for which devices) to set as
1393 // current when invoking a callback, thus allowing the callback to use devices
1394 // that the parent future didn't use. This field is set to the value provided
1395 // in the constructor and will be "inherited" by all child futures.
1396 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
1397 const std::vector<c10::Device> devices_;
1398 };
1399
1400 struct C10_EXPORT ivalue::Await final : c10::intrusive_ptr_target {
1401 private:
1402 explicit Await(TypePtr elType, std::function<IValue()> fn)
1403 : elType_(std::move(elType)), type_(AwaitType::create(elType_)), fn_(std::move(fn)) {}
1404
1405 explicit Await(TypePtr elType) : elType_(std::move(elType)), type_(AwaitType::create(elType_)) { }
1406
1407 friend c10::intrusive_ptr<Await>;
1408
1409 public:
1410 Await(const Await&) = delete;
1411 Await(Await&&) = delete;
1412 Await& operator=(const Await&) = delete;
1413 Await& operator=(Await&&) = delete;
1414
1415 IValue wait() {
1416 if (!completed_) {
1417 TORCH_CHECK(fn_, "Incompleted Await: fn can't be None");
1418 value_ = fn_();
1419 completed_ = true;
1420 args_ = {};
1421 }
1422 return value_;
1423 }
1424
1425 IValue value() {
1426 TORCH_CHECK(completed_, "Await must be completed");
1427 return value_;
1428 }
1429
1430 void setFn(std::function<IValue()> fn) {
1431 fn_ = std::move(fn);
1432 }
1433
1434 bool completed() {
1435 return completed_;
1436 }
1437
1438 void markCompleted(IValue value) {
1439 value_ = std::move(value);
1440 completed_ = true;
1441 }
1442
1443 TORCH_API friend std::ostream& operator<<(
1444 std::ostream& out,
1445 const Await& v);
1446
1447 const TypePtr& elementType() const {
1448 return elType_;
1449 }
1450
1451 const TypePtr& type() const {
1452 return type_;
1453 }
1454
1455 void setArgs(std::vector<IValue> args) {
1456 args_ = std::move(args);
1457 }
1458
1459 std::vector<IValue>& args() {
1460 return args_;
1461 }
1462
1463 private:
1464 TypePtr elType_;
1465 TypePtr type_;
1466 std::vector<IValue> args_;
1467 std::function<IValue()> fn_;
1468 IValue value_;
1469 bool completed_{};
1470 };
1471
1472 // Input is a list of Futures with the same target type.
1473 // Output is a Future to the List of completed Futures.
1474 TORCH_API intrusive_ptr<ivalue::Future> collectAll(
1475 const c10::List<c10::intrusive_ptr<ivalue::Future>>& srcs);
1476 // Input is a List of Futures with the same target type.
1477 // Output is a Future that will be updated with a seen value.
1478 TORCH_API intrusive_ptr<ivalue::Future> collectAny(
1479 const c10::List<c10::intrusive_ptr<ivalue::Future>>& srcs);
1480
1481 // User-defined object.
1482 struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
1483 public:
1484 // In general, class types hold a shared_ptr to its owning CompilationUnit,
1485 // so that its type and methods do not get deallocated while the class exists.
1486 // However, the CompilationUnit holds ownership of the type's graphs, so
1487 // inserting a constant object into a Graph would create a reference cycle if
1488 // that constant object held a shared_ptr to its CU. For these objects we
1489 // instatiate them with non-owning references to its CU
1490 Object(WeakOrStrongTypePtr type, size_t numSlots) : type_(std::move(type)) {
1491 slots_.resize(numSlots);
1492 }
1493
1494 Object(StrongTypePtr type, size_t numSlots)
1495 : type_(WeakOrStrongTypePtr(std::move(type))) {
1496 slots_.resize(numSlots);
1497 }
1498
1499 static c10::intrusive_ptr<Object> create(
1500 WeakOrStrongTypePtr type,
1501 size_t numSlots) {
1502 return c10::make_intrusive<Object>(std::move(type), numSlots);
1503 }
1504
1505 static c10::intrusive_ptr<Object> create(
1506 StrongTypePtr type,
1507 size_t numSlots) {
1508 return c10::make_intrusive<Object>(std::move(type), numSlots);
1509 }
1510
1511 static c10::intrusive_ptr<Object> create(ClassTypePtr classType, size_t numSlots);
1512
1513 /**
1514 * Slot API.
1515 *
1516 * Attributes are stored as a simple vector so that lookups are fast at
1517 * runtime. A "slot" is just an index into that vector, which can be computed
1518 * statically if you have access to the class type. Use this API if you are
1519 * writing compiler stuff.
1520 */
1521 void setSlot(size_t slot, IValue v) {
1522 if (slot >= slots_.size()) {
1523 // for module types, it is possible that the members of the class have
1524 // expanded after the object was created. In this case, we expand
1525 // the slots to the right size
1526 resizeObject(slot);
1527 }
1528 slots_[slot] = std::move(v);
1529 }
1530
1531 const IValue& getSlot(size_t slot) const {
1532 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(slot < slots_.size());
1533 // NOTE: This lookup is fairly hot, so we use unchecked access to the
1534 // vector. Errors should still be detectable with ASan.
1535 return slots_[slot];
1536 }
1537
1538 void unsafeRemoveSlot(size_t slot) {
1539 TORCH_CHECK(slot < slots_.size());
1540 slots_.erase(slots_.begin() + static_cast<std::ptrdiff_t>(slot));
1541 }
1542
1543 /**
1544 * Attribute API.
1545 *
1546 * Wrappers around the slot stuff so that users can access attributes
1547 * directly. Use this API if you are a user.
1548 *
1549 * Note: Unlike in Python, TorchScript must make a distinction between
1550 * attributes (which are IValues) and methods (which are Methods). If you
1551 * want a method, use `obj.type()->getMethod()`
1552 */
1553 IValue getAttr(const std::string& name) const;
1554 void setAttr(const std::string& name, IValue v);
1555 // Remove attribute by name, caller is responsible for
1556 // the safety of this operation
1557 // We didn't remove the attribute in the type because the type
1558 // might be shared by multiple objects.
1559 // Therefore after removing attribute, the object is in an inconsistent
1560 // state where it has more attribute types in its Type than
1561 // the attribute slots it has, user needs to make sure the object
1562 // has consistent by removing the attribute in type as well
1563 void unsafeRemoveAttr(const std::string& name);
1564
1565 std::string name() const;
1566
1567 const std::vector<IValue>& slots() const {
1568 return slots_;
1569 }
1570 std::shared_ptr<ClassType> type() const;
1571
1572 std::shared_ptr<torch::jit::CompilationUnit> compilation_unit() {
1573 if (type_.holds_strong_ref()) {
1574 return type_.cu_.getStrongRefOrThrow();
1575 } else {
1576 auto weak_ptr = type_.cu_.getWeakRefOrThrow();
1577 return std::shared_ptr<torch::jit::CompilationUnit>(weak_ptr);
1578 }
1579 }
1580
1581 c10::intrusive_ptr<Object> copy_to_weak_compilation_ref() const;
1582
1583 void unsafe_make_weak_compilation_ref() {
1584 type_ = WeakOrStrongTypePtr(type_.asWeakTypePtr());
1585 }
1586
1587 c10::intrusive_ptr<Object> copy() const;
1588
1589 c10::intrusive_ptr<Object> deepcopy(
1590 std::optional<at::Device> device = std::nullopt) const;
1591
1592 c10::intrusive_ptr<Object> deepcopy(
1593 IValue::HashIdentityIValueMap& memo,
1594 std::optional<at::Device> device = std::nullopt) const;
1595
1596 bool is_weak_compilation_ref() const {
1597 return !type_.holds_strong_ref();
1598 }
1599
1600 bool is_empty_strong_compilation_ref() const {
1601 return type_.holds_empty_strong_ref();
1602 }
1603
1604 private:
1605 void resizeObject(size_t slot);
1606 WeakOrStrongTypePtr type_;
1607 std::vector<IValue> slots_;
1608 };
1609
1610 // virtual ivalue PyObjectHolder that hold a py::object, we make this virtual
1611 // because the py::object and refcounting logic should happen in libtorch_python
1612 // see concrete implementation in python_ivalue.h
1613 struct ivalue::PyObjectHolder : c10::intrusive_ptr_target {
1614 public:
1615 virtual PyObject* getPyObject() = 0;
1616 virtual c10::InferredType tryToInferType() = 0;
1617 virtual IValue toIValue(const TypePtr& type, std::optional<int32_t> N = std::nullopt) = 0;
1618 virtual std::string toStr() = 0;
1619 virtual std::vector<at::Tensor> extractTensors() = 0;
1620
1621 ~PyObjectHolder() override = default;
1622 };
1623
1624 struct ivalue::EnumHolder : c10::intrusive_ptr_target {
1625 public:
1626 EnumHolder(std::shared_ptr<EnumType> type, std::string name, IValue value)
1627 : type_(std::move(type)),
1628 name_(std::move(name)),
1629 value_(std::move(value)) {}
1630
1631 bool is(const ivalue::EnumHolder& rhs) {
1632 return *this == rhs;
1633 }
1634
1635 friend bool operator==(
1636 const ivalue::EnumHolder& lhs,
1637 const ivalue::EnumHolder& rhs);
1638
1639 TORCH_API friend std::ostream& operator<<(
1640 std::ostream& out,
1641 const ivalue::EnumHolder& v);
1642
1643 TORCH_API const std::string& qualifiedClassName() const;
1644
1645 const std::string& unqualifiedClassName() const;
1646
1647 const std::string& name() const {
1648 return name_;
1649 }
1650
1651 const IValue& value() const {
1652 return value_;
1653 }
1654
1655 std::shared_ptr<EnumType> type() const {
1656 return type_;
1657 }
1658
1659 private:
1660 std::shared_ptr<EnumType> type_;
1661 std::string name_;
1662 IValue value_;
1663 };
1664
1665 #undef TORCH_FORALL_TAGS
1666
1667 namespace detail {
1668
1669 struct _guarded_unsigned_long_unique_dummy final {
1670 _guarded_unsigned_long_unique_dummy(int64_t){};
1671 };
1672 using _guarded_unsigned_long = std::conditional_t<
1673 std::is_same_v<unsigned long, uint32_t> ||
1674 std::is_same_v<unsigned long, uint64_t>,
1675 _guarded_unsigned_long_unique_dummy,
1676 unsigned long>;
1677
1678 } // namespace detail
1679
1680 inline ivalue::Object& IValue::toObjectRef() const {
1681 AT_ASSERT(isObject(), "Expected Object but got ", tagKind());
1682 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(), "Attempted to create null reference");
1683 return *static_cast<c10::ivalue::Object*>(payload.u.as_intrusive_ptr);
1684 }
1685
1686 // note: when adding a DEFINE_TO case here you should also add a
1687 // toX method to IValue. These named methods are much more discoverable
1688 // than the to templated function.
1689
1690 #define DEFINE_TO(T, method_name) \
1691 template <> \
1692 inline T IValue::to<T>()&& { \
1693 return static_cast<T>(std::move(*this).method_name()); \
1694 } \
1695 template <> \
1696 inline c10::detail::ivalue_to_const_ref_overload_return<T>::type IValue::to<T>() const& { \
1697 typedef c10::detail::ivalue_to_const_ref_overload_return<T>::type return_type; \
1698 return static_cast<return_type>(this->method_name()); \
1699 }
1700
1701 DEFINE_TO(at::Tensor, toTensor)
1702 DEFINE_TO(at::Storage, toStorage)
1703 DEFINE_TO(c10::Stream, toStream)
1704 DEFINE_TO(float, toDouble)
1705 DEFINE_TO(double, toDouble)
1706 DEFINE_TO(c10::complex<double>, toComplexDouble)
1707 DEFINE_TO(unsigned char, toInt)
1708 DEFINE_TO(signed char, toInt)
1709 DEFINE_TO(unsigned short, toInt)
1710 DEFINE_TO(short, toInt)
1711 DEFINE_TO(int, toInt)
1712 DEFINE_TO(uint32_t, toInt)
1713 DEFINE_TO(uint64_t, toInt)
1714 DEFINE_TO(detail::_guarded_unsigned_long, toInt)
1715 DEFINE_TO(int64_t, toInt)
1716 DEFINE_TO(bool, toBool)
1717 DEFINE_TO(c10::intrusive_ptr<caffe2::Blob>, toBlob);
1718 DEFINE_TO(c10::intrusive_ptr<ivalue::ConstantString>, toString)
1719 DEFINE_TO(c10::intrusive_ptr<ivalue::Object>, toObject)
1720 DEFINE_TO(at::Scalar, toScalar)
1721 DEFINE_TO(c10::List<int64_t>, toIntList)
1722 DEFINE_TO(c10::List<double>, toDoubleList)
1723 DEFINE_TO(c10::List<c10::complex<double>>, toComplexDoubleList)
1724 DEFINE_TO(c10::List<bool>, toBoolList)
1725 DEFINE_TO(c10::List<at::Tensor>, toTensorList)
1726 DEFINE_TO(c10::impl::GenericList, toList)
1727 DEFINE_TO(c10::impl::GenericDict, toGenericDict)
1728 DEFINE_TO(c10::intrusive_ptr<ivalue::Tuple>, toTuple)
1729 DEFINE_TO(std::string, toStringRef)
1730 DEFINE_TO(c10::string_view, toStringView)
1731 DEFINE_TO(c10::intrusive_ptr<ivalue::Future>, toFuture)
1732 DEFINE_TO(c10::intrusive_ptr<ivalue::Await>, toAwait)
1733 DEFINE_TO(c10::intrusive_ptr<c10::RRefInterface>, toRRef)
1734 DEFINE_TO(c10::intrusive_ptr<at::Quantizer>, toQuantizer)
1735 DEFINE_TO(IValue, toIValue)
1736 DEFINE_TO(c10::Device, toDevice)
1737 DEFINE_TO(at::ScalarType, toScalarType)
1738 DEFINE_TO(at::Layout, toLayout)
1739 DEFINE_TO(at::MemoryFormat, toMemoryFormat)
1740 DEFINE_TO(at::QScheme, toQScheme)
1741 DEFINE_TO(at::Dimname, toDimname)
1742 DEFINE_TO(at::Generator, toGenerator)
1743 DEFINE_TO(c10::SymInt, toSymInt)
1744 DEFINE_TO(c10::SymFloat, toSymFloat)
1745 DEFINE_TO(c10::SymBool, toSymBool)
1746
1747 template <class T>
1748 struct _fake_type {};
1749
1750 // generic_to<T> converts an IValue from a generic list or generic dict
1751 // to a concrete list/dict type likelike List<T>, Dict<...> or std::optional<T>.
1752 // Note that in the case of lists, this only works for IValue-based lists,
1753 // i.e. not for int64_t, double, ...
1754 // generic_to<T> is an implementation detail of IValue::to<T> and not
1755 // supposed to be called directly.
1756 // The _fake_type<T> parameter allows us to overload
1757 // based on the return type.
1758 template <class Elem>
1759 // TODO this is deprecated but we don't throw a warning because a lot of ops in
1760 // native_functions.yaml still return std::vector.
1761 // C10_DEPRECATED_MESSAGE("IValues based on std::vector<T> are potentially slow
1762 // and deprecated. Please use torch::List<T> instead.")
1763 std::vector<Elem> generic_to(IValue ivalue, _fake_type<std::vector<Elem>>) {
1764 // We need to do a deep copy of the vector because there might be other
1765 // references to this same IValue that also use the list. We can't just
1766 // move the elements out.
1767 auto list = std::move(ivalue).to<List<Elem>>();
1768 std::vector<Elem> result;
1769 result.reserve(list.size());
1770 for (Elem v : list) {
1771 result.push_back(std::move(v));
1772 }
1773 return result;
1774 }
1775
1776 template <typename T>
1777 c10::intrusive_ptr<T> IValue::toCustomClass() && {
1778 static_assert(
1779 std::is_base_of<torch::CustomClassHolder, T>::value == true,
1780 "toCustomClass requires that template parameter T must inherit "
1781 "from torch::CustomClassHolder");
1782 auto obj = toObject();
1783 TORCH_CHECK(
1784 obj->slots().size() == 1,
1785 "Tried to cast IValue to custom class but it did "
1786 "not contain a custom class!");
1787 const auto* expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>().get();
1788 ivalue::checkCustomClassType(expected_type, type().get());
1789 auto userObj =
1790 c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
1791 return userObj;
1792 }
1793
1794 template <typename T>
1795 c10::intrusive_ptr<T> IValue::toCustomClass() const& {
1796 static_assert(
1797 std::is_base_of<torch::CustomClassHolder, T>::value == true,
1798 "toCustomClass requires that template parameter T must inherit "
1799 "from torch::CustomClassHolder");
1800 auto obj = toObject();
1801 TORCH_CHECK(
1802 obj->slots().size() == 1,
1803 "Tried to cast IValue to custom class but it did "
1804 "not contain a custom class!");
1805 const auto* expected_type = c10::getCustomClassType<c10::intrusive_ptr<T>>().get();
1806 ivalue::checkCustomClassType(expected_type, type().get());
1807 auto userObj =
1808 c10::static_intrusive_pointer_cast<T>(obj->getSlot(0).toCapsule());
1809 return userObj;
1810 }
1811
1812 template <typename T>
1813 T generic_to(IValue ivalue, _fake_type<T>) {
1814 using ElemType = typename std::remove_pointer<T>::type::element_type;
1815 return std::move(ivalue).toCustomClass<ElemType>();
1816 }
1817
1818 template <typename T>
1819 tagged_capsule<T> generic_to(IValue ivalue, _fake_type<tagged_capsule<T>>) {
1820 return tagged_capsule<T>{std::move(ivalue)};
1821 }
1822
1823 template <typename Elem>
1824 c10::List<Elem> generic_to(IValue ivalue, _fake_type<c10::List<Elem>>) {
1825 return impl::toTypedList<Elem>(std::move(ivalue).toList());
1826 }
1827
1828 template <typename T>
1829 static T createVectorLikeFromList(const c10::detail::ListImpl* impl) {
1830 T result;
1831 result.reserve(impl->list.size());
1832 for (const auto & i : impl->list) {
1833 result.push_back(i.to<typename T::value_type>());
1834 }
1835 return result;
1836 }
1837
1838 template <typename T>
1839 static std::vector<T> createVectorFromList(const c10::detail::ListImpl* impl) {
1840 return createVectorLikeFromList<std::vector<T>>(impl);
1841 }
1842
1843 template <typename T>
1844 std::vector<T> createVectorFromList(const c10::List<T>& impl) {
1845 std::vector<T> result;
1846 result.reserve(impl.size());
1847 for (size_t i = 0, N = impl.size(); i < N; ++i) {
1848 result.push_back(impl[i]);
1849 }
1850 return result;
1851 }
1852
1853 template <typename T>
1854 OptionalArray<T> generic_to(IValue ivalue, _fake_type<OptionalArray<T>>) {
1855 if (ivalue.isNone()) {
1856 return {};
1857 }
1858 return createVectorFromList<T>(
1859 std::move(ivalue).to<c10::List<T>>()
1860 );
1861 }
1862
1863 namespace detail {
1864 template <typename Elem, size_t... I>
1865 std::array<Elem, sizeof...(I)> generic_to_array(
1866 IValue ivalue,
1867 _fake_type<std::array<Elem, sizeof...(I)>>,
1868 std::index_sequence<I...>) {
1869 // We need to do a deep copy of the array because there might be other
1870 // references to this same IValue that also use the list. We can't just
1871 // move the elements out.
1872 auto list = std::move(ivalue).to<List<Elem>>();
1873 TORCH_CHECK(
1874 list.size() == sizeof...(I),
1875 "Tried to convert a List with ",
1876 list.size(),
1877 " elements to a fixed-size array of size ",
1878 sizeof...(I));
1879 return {list[I]...};
1880 }
1881 } // namespace detail
1882
1883 template <typename Elem, size_t N>
1884 std::array<Elem, N> generic_to(
1885 IValue ivalue,
1886 _fake_type<std::array<Elem, N>> ft) {
1887 return detail::generic_to_array(ivalue, ft, std::make_index_sequence<N>());
1888 }
1889
1890 template <typename Key, typename Value>
1891 c10::Dict<Key, Value> generic_to(
1892 IValue ivalue,
1893 _fake_type<c10::Dict<Key, Value>>) {
1894 return impl::toTypedDict<Key, Value>(std::move(ivalue).toGenericDict());
1895 }
1896
1897 template <typename K, typename V>
1898 C10_DEPRECATED_MESSAGE(
1899 "IValues based on std::unordered_map are slow and deprecated. Please use c10::Dict<K, V> instead.")
1900 std::unordered_map<K, V> generic_to(
1901 IValue ivalue,
1902 _fake_type<std::unordered_map<K, V>>) {
1903 std::unordered_map<K, V> specialized_dict;
1904
1905 for (const auto& item : std::move(ivalue).toGenericDict()) {
1906 specialized_dict[item.key().template to<K>()] = item.value().template to<V>();
1907 }
1908
1909 return specialized_dict;
1910 }
1911
1912 template <typename T>
1913 std::optional<T> generic_to(IValue ivalue, _fake_type<std::optional<T>>) {
1914 if (ivalue.isNone()) {
1915 return std::nullopt;
1916 }
1917 return std::move(ivalue).to<T>();
1918 }
1919
1920 namespace detail {
1921 template <typename Tuple, std::size_t... INDEX>
1922 Tuple generic_to_tuple_impl(
1923 const ivalue::TupleElements& t,
1924 std::index_sequence<INDEX...>) {
1925 return std::make_tuple(
1926 t[INDEX].to<typename std::tuple_element<INDEX, Tuple>::type>()...);
1927 }
1928 } // namespace detail
1929
1930 template <
1931 typename... Args,
1932 typename Indices = std::make_index_sequence<sizeof...(Args)>,
1933 std::enable_if_t<
1934 !std::disjunction_v<
1935 std::is_lvalue_reference<Args>...,
1936 std::negation<std::is_constructible<IValue, Args>>...>,
1937 std::nullptr_t> = nullptr>
1938 std::tuple<Args...> generic_to(const IValue& ivalue, _fake_type<std::tuple<Args...>>) {
1939 const auto& vals = ivalue.toTupleRef().elements();
1940 TORCH_CHECK(vals.size() == sizeof...(Args));
1941 return detail::generic_to_tuple_impl<std::tuple<Args...>>(vals, Indices{});
1942 }
1943
1944 template <typename T>
1945 inline T IValue::to() && {
1946 return generic_to(std::move(*this), _fake_type<T>{});
1947 }
1948
1949 template <>
1950 inline std::optional<c10::string_view> IValue::to() && {
1951 // In the default implementation, the IValue is destroyed with std::move.
1952 // But if the unboxed type is std::optional<string_view> we cannot destroy
1953 // the IValue.
1954 return generic_to(*this, _fake_type<std::optional<c10::string_view>>{});
1955 }
1956
1957 template <typename T>
1958 inline typename c10::detail::ivalue_to_const_ref_overload_return<T>::type IValue::to() const& {
1959 return generic_to(*this, _fake_type<T>{});
1960 }
1961
1962 inline c10::List<int64_t> IValue::toIntList() && {
1963 AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
1964 return c10::List<int64_t>(moveToIntrusivePtr<c10::detail::ListImpl>());
1965 }
1966 inline c10::List<int64_t> IValue::toIntList() const& {
1967 AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
1968 return c10::List<int64_t>(toIntrusivePtr<c10::detail::ListImpl>());
1969 }
1970 inline std::vector<int64_t> IValue::toIntVector() const {
1971 AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
1972 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1973 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
1974 "called toIntVector on null intrusive_ptr IValue");
1975 return createVectorFromList<int64_t>(
1976 static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
1977 }
1978 inline std::vector<c10::SymInt> IValue::toSymIntVector() const {
1979 AT_ASSERT(isSymIntList() || isIntList(), "Expected SymIntList or IntList but got ", tagKind());
1980 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1981 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
1982 "called toSymIntVector on null intrusive_ptr IValue");
1983 return createVectorFromList<c10::SymInt>(
1984 static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
1985 }
1986 inline at::DimVector IValue::toDimVector() const {
1987 AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind());
1988 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1989 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
1990 "called toDimVector on null intrusive_ptr IValue");
1991 return createVectorLikeFromList<at::DimVector>(
1992 static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
1993 }
1994 inline c10::List<double> IValue::toDoubleList() && {
1995 AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
1996 return c10::List<double>(moveToIntrusivePtr<c10::detail::ListImpl>());
1997 }
1998 inline c10::List<double> IValue::toDoubleList() const& {
1999 AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
2000 return c10::List<double>(toIntrusivePtr<c10::detail::ListImpl>());
2001 }
2002 inline std::vector<double> IValue::toDoubleVector() const {
2003 AT_ASSERT(isDoubleList(), "Expected DoubleList but got ", tagKind());
2004 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2005 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2006 "called toDoubleVector on null intrusive_ptr IValue");
2007 return createVectorFromList<double>(
2008 static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
2009 }
2010 inline c10::List<c10::complex<double>> IValue::toComplexDoubleList() && {
2011 AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
2012 return c10::List<c10::complex<double>>(moveToIntrusivePtr<c10::detail::ListImpl>());
2013 }
2014 inline c10::List<c10::complex<double>> IValue::toComplexDoubleList() const& {
2015 AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
2016 return c10::List<c10::complex<double>>(toIntrusivePtr<c10::detail::ListImpl>());
2017 }
2018 inline std::vector<c10::complex<double>> IValue::toComplexDoubleVector() const {
2019 AT_ASSERT(isComplexDoubleList(), "Expected ComplexDoubleList but got ", tagKind());
2020 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2021 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2022 "called toComplexDoubleVector on null intrusive_ptr IValue");
2023 return createVectorFromList<c10::complex<double>>(
2024 static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
2025 }
2026 inline c10::List<bool> IValue::toBoolList() && {
2027 AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind());
2028 return c10::List<bool>(moveToIntrusivePtr<c10::detail::ListImpl>());
2029 }
2030 inline c10::List<bool> IValue::toBoolList() const& {
2031 AT_ASSERT(isBoolList(), "Expected BoolList but got ", tagKind());
2032 return c10::List<bool>(toIntrusivePtr<c10::detail::ListImpl>());
2033 }
2034 inline c10::List<at::Tensor> IValue::toTensorList() && {
2035 AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
2036 return c10::List<at::Tensor>(moveToIntrusivePtr<c10::detail::ListImpl>());
2037 }
2038 inline c10::List<at::Tensor> IValue::toTensorList() const& {
2039 AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
2040 return c10::List<at::Tensor>(toIntrusivePtr<c10::detail::ListImpl>());
2041 }
2042 inline std::vector<at::Tensor> IValue::toTensorVector() const {
2043 AT_ASSERT(isTensorList(), "Expected TensorList but got ", tagKind());
2044 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2045 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2046 "called toTensorVector on null intrusive_ptr IValue");
2047 return createVectorFromList<at::Tensor>(
2048 static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
2049 }
2050 inline c10::List<std::optional<at::Tensor>> IValue::toOptionalTensorList() && {
2051 AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind());
2052 return c10::List<std::optional<at::Tensor>>(moveToIntrusivePtr<c10::detail::ListImpl>());
2053 }
2054 inline c10::List<std::optional<at::Tensor>> IValue::toOptionalTensorList() const& {
2055 AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind());
2056 return c10::List<std::optional<at::Tensor>>(toIntrusivePtr<c10::detail::ListImpl>());
2057 }
2058 inline std::vector<std::optional<at::Tensor>> IValue::toOptionalTensorVector() const {
2059 AT_ASSERT(isOptionalTensorList(), "Expected OptionalTensorList but got ", tagKind());
2060 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2061 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2062 "called toOptionalTensorVector on null intrusive_ptr IValue");
2063 return createVectorFromList<std::optional<at::Tensor>>(
2064 static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr));
2065 }
2066 inline c10::List<IValue> IValue::toList() && {
2067 AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
2068 return c10::List<IValue>(moveToIntrusivePtr<c10::detail::ListImpl>());
2069 }
2070 inline c10::List<IValue> IValue::toList() const& {
2071 AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
2072 return c10::List<IValue>(toIntrusivePtr<c10::detail::ListImpl>());
2073 }
2074 inline c10::ArrayRef<IValue> IValue::toListRef() const {
2075 AT_ASSERT(isList(), "Expected GenericList but got ", tagKind());
2076 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2077 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2078 "called toListRef on null intrusive_ptr IValue");
2079 return static_cast<const c10::detail::ListImpl*>(payload.u.as_intrusive_ptr)
2080 ->list;
2081 }
2082 inline c10::Dict<IValue, IValue> IValue::toGenericDict() && {
2083 AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind());
2084 return c10::Dict<IValue, IValue>(moveToIntrusivePtr<c10::detail::DictImpl>());
2085 }
2086 inline c10::Dict<IValue, IValue> IValue::toGenericDict() const& {
2087 AT_ASSERT(isGenericDict(), "Expected GenericDict but got ", tagKind());
2088 return c10::Dict<IValue, IValue>(toIntrusivePtr<c10::detail::DictImpl>());
2089 }
2090 inline c10::intrusive_ptr<ivalue::Tuple> IValue::toTuple() && {
2091 AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind());
2092 return moveToIntrusivePtr<ivalue::Tuple>();
2093 }
2094 inline c10::intrusive_ptr<ivalue::Tuple> IValue::toTuple() const& {
2095 AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind());
2096 return toIntrusivePtr<ivalue::Tuple>();
2097 }
2098 inline ivalue::Tuple& IValue::toTupleRef() const {
2099 AT_ASSERT(isTuple(), "Expected Tuple but got ", tagKind());
2100 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2101 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2102 "called toTupleRef on null intrusive_ptr IValue");
2103 return *static_cast<c10::ivalue::Tuple*>(
2104 payload.u.as_intrusive_ptr);
2105 }
2106
2107 inline IValue::IValue(c10::intrusive_ptr<ivalue::Tuple> v)
2108 : tag(Tag::Tuple) {
2109 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2110 }
2111 template <
2112 typename... Args,
2113 std::enable_if_t<
2114 !std::disjunction_v<
2115 std::is_lvalue_reference<Args>...,
2116 std::negation<std::is_constructible<IValue, Args>>...>,
2117 std::nullptr_t>>
2118 inline IValue::IValue(const std::tuple<Args...>& t)
2119 : IValue(c10::guts::apply(c10::ivalue::Tuple::create<const Args&...>, t)) {
2120 }
2121
2122 template <
2123 typename... Args,
2124 std::enable_if_t<
2125 !std::disjunction_v<
2126 std::is_lvalue_reference<Args>...,
2127 std::negation<std::is_constructible<IValue, Args>>...>,
2128 std::nullptr_t>>
2129 inline IValue::IValue(std::tuple<Args...>&& t)
2130 : IValue(c10::guts::apply(c10::ivalue::Tuple::create<Args&&...>, std::move(t))) {
2131 }
2132
2133 inline IValue::IValue(c10::intrusive_ptr<ivalue::ConstantString> v)
2134 : tag(Tag::String) {
2135 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2136 }
2137 inline IValue::IValue(std::string v)
2138 : IValue(ivalue::ConstantString::create(std::move(v))) {}
2139
2140 inline IValue::IValue(c10::impl::GenericList v)
2141 : tag(Tag::GenericList) {
2142 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release());
2143 }
2144
2145 template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2146 inline IValue::IValue(c10::List<T>&& v) : IValue(impl::toList<T>(std::move(v))) {}
2147 template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2148 inline IValue::IValue(const c10::List<T>& v) : IValue(impl::toList<T>(v)) {}
2149 template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2150 inline IValue::IValue(at::ArrayRef<T> v) : IValue(c10::List<T>()) {
2151 auto list = to<c10::List<T>>();
2152 list.reserve(v.size());
2153 for (const auto& e : v) {
2154 list.push_back(e);
2155 }
2156 }
2157 template <class T, IValue::enable_if_symint<T>>
2158 inline IValue::IValue(at::ArrayRef<T> v) : IValue() {
2159 auto vi = c10::asIntArrayRefSlowOpt(v);
2160 if (vi.has_value()) {
2161 // This list is entirely integers; ensure it is typed as
2162 // an IntList so toIntList works
2163 *this = IValue(*vi);
2164 } else {
2165 // This list has SymInts; type it as a SymInt
2166 *this = IValue(impl::toList<c10::SymInt>(c10::List<c10::SymInt>()));
2167 auto list = to<c10::List<c10::SymInt>>();
2168 list.reserve(v.size());
2169 for (const auto& e : v) {
2170 list.push_back(e);
2171 }
2172 }
2173 }
2174 template <class T, IValue::enable_if_symint<T>>
2175 inline IValue::IValue(at::OptionalArrayRef<T> mb_v) : IValue() {
2176 if (!mb_v.has_value()) return;
2177 *this = IValue(*mb_v);
2178 }
2179 template <class T, IValue::enable_if_symint<T>>
2180 inline IValue::IValue(const std::vector<T>& v) : IValue() {
2181 *this = IValue(at::ArrayRef<T>(v));
2182 }
2183 template <class T, IValue::enable_if_symint<T>>
2184 inline IValue::IValue(std::vector<T>&& v) : IValue() {
2185 auto vi = c10::asIntArrayRefSlowOpt(v);
2186 if (vi.has_value()) {
2187 // This list is entirely integers; ensure it is typed as
2188 // an IntList so toIntList works
2189 *this = IValue(*vi);
2190 } else {
2191 // This list has SymInts; type it as a SymInt
2192 *this = IValue(impl::toList<c10::SymInt>(c10::List<c10::SymInt>()));
2193 auto list = to<c10::List<c10::SymInt>>();
2194 list.reserve(v.size());
2195 for (auto&& e : std::move(v)) {
2196 list.push_back(std::move(e));
2197 }
2198 }
2199 }
2200 template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2201 inline IValue::IValue(const std::vector<T>& v) : IValue(c10::List<T>()) {
2202 auto list = to<c10::List<T>>();
2203 list.reserve(v.size());
2204 for (const auto& e : v) {
2205 list.push_back(e);
2206 }
2207 }
2208
2209 template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2210 inline IValue::IValue(std::vector<T>&& v) : IValue(c10::List<T>()) {
2211 auto list = to<c10::List<T>>();
2212 list.reserve(v.size());
2213 if constexpr (std::is_same_v<T, bool>) {
2214 for (auto e : v) {
2215 list.push_back(e);
2216 }
2217 } else {
2218 for (auto&& e : std::move(v)) {
2219 list.push_back(std::move(e));
2220 }
2221 }
2222 }
2223
2224 template <class T, IValue::enable_if_list_is_ivalue_constructible<T>>
2225 inline IValue::IValue(c10::OptionalArrayRef<T> v) : IValue() {
2226 if (v.has_value()) {
2227 *this = IValue(std::move(*v));
2228 }
2229 }
2230
2231 template <class T, size_t N>
2232 inline IValue::IValue(std::array<T, N> v) : IValue(c10::List<T>()) {
2233 auto list = to<c10::List<T>>();
2234 list.reserve(v.size());
2235 for (auto& e : v) {
2236 list.push_back(std::move(e));
2237 }
2238 }
2239
2240 template <class T, IValue::enable_if_ilist_is_ivalue_constructible<T>>
2241 inline IValue::IValue(c10::IListRef<T> v) : IValue() {
2242 constexpr bool boxed_type_constructs_ivalue =
2243 std::is_constructible<IValue, typename c10::IListRef<T>::boxed_type>::value;
2244 // First, we try to use the boxed value.
2245 // If we fail (either it's not in the boxed state, or its boxed type
2246 // can not construct an IValue), we fallback to copying the list.
2247 if (boxed_type_constructs_ivalue && v.isBoxed()) {
2248 *this = IValue(impl::toList(v.toBoxed()));
2249 } else {
2250 c10::List<T> list;
2251 list.reserve(v.size());
2252 for (const auto& t : v) {
2253 list.push_back(t);
2254 }
2255 *this = IValue(impl::toList(std::move(list)));
2256 }
2257 }
2258
2259 inline IValue::IValue(c10::impl::GenericDict v)
2260 : tag(Tag::GenericDict) {
2261 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.impl_.release());
2262 }
2263 template <class Key, class Value>
2264 inline IValue::IValue(c10::Dict<Key, Value> v)
2265 : IValue(impl::toGenericDict(std::move(v))) {}
2266
2267 template <class Key, class Value>
2268 inline IValue::IValue(std::unordered_map<Key, Value> v)
2269 : IValue(Dict<Key, Value>()) {
2270 auto dict = to<c10::Dict<Key, Value>>();
2271 dict.reserve(v.size());
2272 for (auto& e : v) {
2273 dict.insert(std::move(e.first), std::move(e.second));
2274 }
2275 }
2276
2277 template <class T, IValue::enable_if_ivalue_constructible<T>>
2278 inline IValue::IValue(std::optional<T> v) : IValue() {
2279 if (v.has_value()) {
2280 *this = IValue(std::move(*v));
2281 }
2282 }
2283
2284 inline IValue::IValue(std::nullopt_t) : IValue() {}
2285
2286 inline IValue::IValue(c10::intrusive_ptr<ivalue::Object> v)
2287 : tag(Tag::Object) {
2288 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2289 }
2290
2291 inline IValue::IValue(c10::intrusive_ptr<ivalue::PyObjectHolder> v)
2292 : tag(Tag::PyObject) {
2293 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2294 }
2295
2296 inline IValue::IValue(c10::intrusive_ptr<ivalue::EnumHolder> v)
2297 : tag(Tag::Enum) {
2298 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2299 }
2300
2301 inline IValue IValue::make_capsule(
2302 intrusive_ptr<torch::CustomClassHolder> blob) {
2303 IValue iv;
2304 iv.tag = Tag::Capsule;
2305 iv.payload.u.as_intrusive_ptr = null_to_undefined_tensor(blob.release());
2306 return iv;
2307 }
2308
2309 template <
2310 typename T,
2311 std::enable_if_t<std::is_base_of_v<torch::CustomClassHolder, T>, int>>
2312 IValue::IValue(c10::intrusive_ptr<T> custom_class) : tag(Tag::Object) {
2313 auto classType = []() {
2314 try {
2315 return c10::getCustomClassType<c10::intrusive_ptr<T>>();
2316 } catch (const c10::Error&) {
2317 throw c10::Error(
2318 "Trying to instantiate a class that isn't a registered custom class: " +
2319 std::string(c10::util::get_fully_qualified_type_name<T>()));
2320 }
2321 }();
2322 auto ivalue_obj = c10::ivalue::Object::create(std::move(classType), /* numSlots */1);
2323 ivalue_obj->setSlot(0, IValue::make_capsule(std::move(custom_class)));
2324 payload.u.as_intrusive_ptr = null_to_undefined_tensor(ivalue_obj.release());
2325
2326 }
2327
2328 inline IValue::IValue(c10::intrusive_ptr<ivalue::Future> v)
2329 : tag(Tag::Future) {
2330 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2331 }
2332
2333 inline IValue::IValue(c10::intrusive_ptr<ivalue::Await> v)
2334 : tag(Tag::Await) {
2335 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2336 }
2337
2338 inline IValue::IValue(c10::intrusive_ptr<c10::RRefInterface> v)
2339 : tag(Tag::RRef) {
2340 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2341 }
2342
2343 inline IValue::IValue(c10::intrusive_ptr<at::Quantizer> v)
2344 : tag(Tag::Quantizer) {
2345 payload.u.as_intrusive_ptr = null_to_undefined_tensor(v.release());
2346 }
2347
2348 template <typename T>
2349 inline IValue::IValue(c10::complex<T> c)
2350 : tag(Tag::ComplexDouble) {
2351 auto v = c10::make_intrusive<ivalue::ComplexHolder>(c);
2352 payload.u.as_intrusive_ptr = v.release();
2353 }
2354
2355 inline const std::string& IValue::toStringRef() const {
2356 AT_ASSERT(isString(), "Expected String but got ", tagKind());
2357 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2358 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2359 "called toStringRef on null intrusive_ptr IValue");
2360 return static_cast<const c10::ivalue::ConstantString*>(
2361 payload.u.as_intrusive_ptr)
2362 ->string();
2363 }
2364 inline std::optional<std::reference_wrapper<const std::string>> IValue::
2365 toOptionalStringRef() const {
2366 if (isNone()) {
2367 return std::nullopt;
2368 }
2369 AT_ASSERT(isString(), "Expected std::optional<string> but got ", tagKind());
2370 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2371 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2372 "called toOptionalStringRef on null intrusive_ptr IValue");
2373 return std::reference_wrapper<const std::string>(
2374 static_cast<const c10::ivalue::ConstantString*>(payload.u.as_intrusive_ptr)
2375 ->string());
2376 }
2377
2378 inline c10::string_view IValue::toStringView() const {
2379 AT_ASSERT(isString(), "Expected String but got ", tagKind());
2380 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
2381 payload.u.as_intrusive_ptr != c10::UndefinedTensorImpl::singleton(),
2382 "called toStringView on null intrusive_ptr IValue");
2383 return static_cast<const c10::ivalue::ConstantString*>(
2384 payload.u.as_intrusive_ptr)
2385 ->string_view();
2386 }
2387
2388 inline PyObject* IValue::toPyObject() const {
2389 return toPyObjectHolder()->getPyObject();
2390 }
2391
2392 template <typename T>
2393 inline std::optional<T> IValue::toOptional() {
2394 if (this->isNone()) {
2395 return std::nullopt;
2396 }
2397 return this->to<T>();
2398 }
2399
2400 template <typename T>
2401 inline std::optional<T> IValue::toOptional() const {
2402 if (this->isNone()) {
2403 return std::nullopt;
2404 }
2405 return this->to<T>();
2406 }
2407
2408 inline bool IValue::isCustomClass() const {
2409 return torch::isCustomClass(*this);
2410 }
2411
2412 inline bool IValue::isSameIdentity(const IValue& rhs) const {
2413 // We choose to not use memcmp for payload check due to potential random
2414 // padding characters on union type
2415
2416 // Semantics:
2417 // 1. Immutable primitive values of the same type (Int, Double, None, Bool,
2418 // Str) return value equality
2419 // 2. If it is a tensor type, we need to take undefined tensor into account
2420 // 3. Undefined_tensor is None and vice versa should be true
2421 // 4. If it is a reference type (i.e. isIntrusivePtr()), then is True when
2422 // the pointed-to object is the same.
2423 // 5. False for all other comparisons.
2424 if (this->isNone() && rhs.isNone()) {
2425 return true;
2426 } else if (this->isBool() && rhs.isBool()) {
2427 // for bool type, do equality check
2428 return this->toBool() == rhs.toBool();
2429 } else if (this->isTensor() && rhs.isTensor()) {
2430 return this->payload.as_tensor.is_same(rhs.payload.as_tensor);
2431 } else if (this->isTensor() && rhs.isNone()) {
2432 // special case: undefined tensor and None are the same identity
2433 return !this->payload.as_tensor.defined();
2434 } else if (this->isNone() && rhs.isTensor()) {
2435 // special case: undefined tensor and None are the same identity
2436 return !rhs.payload.as_tensor.defined();
2437 } else if (this->isInt() && rhs.isInt()) {
2438 return this->toInt() == rhs.toInt();
2439 } else if (this->isDouble() && rhs.isDouble()) {
2440 return this->toDouble() == rhs.toDouble();
2441 } else if (this->isString() && rhs.isString()) {
2442 return this->toStringRef() == rhs.toStringRef();
2443 } else {
2444 // for objects holding in IValue, do shallow compare on pointer address to
2445 // testify the identity
2446 return this->isIntrusivePtr() && rhs.isIntrusivePtr() &&
2447 this->payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
2448 }
2449 }
2450
2451 namespace ivalue {
2452 namespace detail {
2453
2454 template <typename T>
2455 IValue from_(T&& x, std::true_type) {
2456 return IValue(std::forward<T>(x));
2457 }
2458 template <typename T>
2459 IValue from_(c10::intrusive_ptr<T> x, std::false_type) {
2460 return IValue(std::move(x));
2461 }
2462 template <typename T>
2463 IValue from_(T&& /*x*/, std::false_type) {
2464 static_assert(
2465 guts::false_t<T>::value,
2466 "You are calling from with a type that it doesn't support, and isn't a potential custom class (ie: is an intrusive_ptr)");
2467 return IValue();
2468 }
2469 } // namespace detail
2470
2471 template <typename T>
2472 IValue from(T&& x) {
2473 return detail::from_(
2474 std::forward<T>(x), typename std::is_constructible<IValue, T>::type{});
2475 }
2476
2477 } // namespace ivalue
2478
2479
2480 template <>
2481 struct MaybeOwnedTraits<IValue> {
2482 using owned_type = IValue;
2483 using borrow_type = IValue;
2484
2485 static borrow_type createBorrow(const owned_type& from) {
2486 if (!from.isPtrType()) {
2487 return from;
2488 }
2489 if (from.isTensor()) {
2490 return IValue(MaybeOwnedTraits<at::Tensor>::createBorrow(from.toTensor()));
2491 } else {
2492 return IValue(from.payload, from.tag);
2493 }
2494 }
2495
2496 static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
2497 lhs.clearToNone();
2498 if (!rhs.isPtrType()) {
2499 lhs = rhs;
2500 } else if (rhs.isTensor()) {
2501 lhs = IValue(MaybeOwnedTraits<at::Tensor>::createBorrow(rhs.toTensor()));
2502 } else {
2503 lhs = IValue(rhs.payload, rhs.tag);
2504 }
2505 }
2506
2507 static void destroyBorrow(borrow_type& toDestroy) {
2508 toDestroy.clearToNone();
2509 }
2510
2511 static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
2512 return borrow;
2513 }
2514
2515 static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
2516 return &borrow;
2517 }
2518
2519 static bool debugBorrowIsValid(const borrow_type&) {
2520 return true;
2521 }
2522 };
2523
2524 template <>
2525 struct IValue::TagType<c10::Type> {
2526 static TORCH_API c10::TypePtr get(const IValue&);
2527 };
2528
2529 template <>
2530 struct IValue::TagType<c10::DynamicType> {
2531 static TORCH_API c10::TypePtr get(const IValue&);
2532 };
2533
2534 template <typename T>
2535 TypePtr IValue::type() const {
2536 return IValue::TagType<T>::get(*this);
2537 }
2538
2539 } // namespace c10
2540