xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/dynamic_type.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/dynamic_type.h>
2 
3 #include <string>
4 
5 #include <ATen/core/class_type.h>
6 #include <ATen/core/ivalue.h>
7 #include <ATen/core/jit_type.h>
8 #include <ATen/core/type_factory.h>
9 #include <c10/util/Exception.h>
10 
11 namespace c10 {
12 
13 namespace {
14 
contains(DynamicType::Tag lhs,DynamicTypeBits rhs)15 bool contains(DynamicType::Tag lhs, DynamicTypeBits rhs) {
16   return (static_cast<DynamicTypeBits>(lhs) | rhs) ==
17       static_cast<DynamicTypeBits>(lhs);
18 }
19 
contains(DynamicType::Tag lhs,DynamicType::Tag rhs)20 bool contains(DynamicType::Tag lhs, DynamicType::Tag rhs) {
21   return contains(lhs, static_cast<DynamicTypeBits>(rhs));
22 }
23 
24 } // namespace
25 
26 namespace detail {
27 
makeBaseType(DynamicType::Tag tag)28 DynamicTypePtr makeBaseType(DynamicType::Tag tag) {
29   return std::make_shared<DynamicType>(tag, DynamicType::Arguments{});
30 }
31 
32 } // namespace detail
33 
str() const34 std::string DynamicType::str() const {
35   if (name_) {
36     return *name_;
37   }
38   std::string ret = "Dynamic<";
39   ret += std::to_string(static_cast<DynamicTypeBits>(tag_));
40   ret += ">";
41   if (tag_ != Tag::Class && !arguments_.elems.empty()) {
42     ret += "[";
43     for (const auto& arg : arguments_.elems) {
44       if (arg.label) {
45         ret += *arg.label + ":";
46       }
47       ret += arg.ty->str();
48       ret += ",";
49     }
50     ret += "]";
51   }
52   return ret;
53 }
54 
Arguments(c10::ArrayRef<TypePtr> args)55 DynamicType::Arguments::Arguments(c10::ArrayRef<TypePtr> args) {
56   elems.reserve(args.size());
57   for (const auto& arg : args) {
58     elems.emplace_back(create(*arg));
59   }
60 }
61 
Arguments(const std::vector<c10::string_view> & names,c10::ArrayRef<TypePtr> args)62 DynamicType::Arguments::Arguments(
63     const std::vector<c10::string_view>& names,
64     c10::ArrayRef<TypePtr> args)
65     : Arguments(args) {
66   TORCH_INTERNAL_ASSERT(names.size() == args.size());
67   for (size_t i = 0; i < args.size(); i++) {
68     elems[i].label = std::string{names[i]};
69   }
70 }
71 
~DynamicType()72 DynamicType::~DynamicType() {
73   if (tag_ == Tag::Class) {
74     class_.~ClassTypePtr();
75     return;
76   }
77 
78   arguments_.~Arguments();
79 }
80 
create(const Type & other)81 std::shared_ptr<const DynamicType> DynamicType::create(const Type& other) {
82   if (auto dynRaw = other.castRaw<DynamicType>()) {
83     TORCH_INTERNAL_ASSERT(!dynRaw->weak_from_this().expired(),
84         "Error creating dynamic type instance not managed by shared_ptr: ",
85         other.str());
86   }
87   if (auto dyn = other.cast<DynamicType>()) {
88     return dyn;
89   }
90   return std::shared_ptr<const DynamicType>(new DynamicType{other});
91 }
92 
create(Type & other)93 DynamicTypePtr DynamicType::create(Type& other) {
94   if (auto dynRaw = other.castRaw<DynamicType>()) {
95     TORCH_INTERNAL_ASSERT(!dynRaw->weak_from_this().expired(),
96         "Error creating dynamic type instance not managed by shared_ptr: ",
97         other.str());
98   }
99   if (auto dyn = other.cast<DynamicType>()) {
100     return dyn;
101   }
102   return std::shared_ptr<DynamicType>(new DynamicType{other});
103 }
104 
DynamicType(Tag tag,Arguments arguments)105 DynamicType::DynamicType(Tag tag, Arguments arguments)
106     : SharedType(Kind), tag_(tag), arguments_(std::move(arguments)) {}
107 
DynamicType(Tag tag,c10::string_view name,Arguments arguments)108 DynamicType::DynamicType(Tag tag, c10::string_view name, Arguments arguments)
109     : SharedType(Kind),
110       tag_(tag),
111       name_(std::string{name}),
112       arguments_(std::move(arguments)) {}
113 
DynamicType(const Type & other)114 DynamicType::DynamicType(const Type& other) : SharedType(DynamicType::Kind) {
115   auto kind = other.kind();
116   TORCH_INTERNAL_ASSERT(kind != Kind);
117   if (auto n = other.castRaw<NamedType>()) {
118     if (const auto& qn = n->name()) {
119       name_ = qn->qualifiedName();
120     }
121   } else if (auto v = other.castRaw<VarType>()) {
122     name_ = v->name();
123   }
124 
125   if (auto cls = other.cast<ClassType>()) {
126     new (&class_) ClassTypePtr(std::move(cls));
127     tag_ = Tag::Class;
128     return;
129   }
130   switch (kind) {
131 #define CASE_TYPE(T, _, __) \
132   case T##Type::Kind:       \
133     tag_ = Tag::T;          \
134     break;
135     FORALL_DYNAMIC_TYPES(CASE_TYPE)
136     FORALL_DYNAMIC_TYPES_FAKE(CASE_TYPE)
137 #undef CASE_TYPE
138     default:
139       TORCH_INTERNAL_ASSERT(false, "Unsupported dynamic type: ", other.str());
140   }
141 
142   auto args = other.containedTypes();
143   if (args.empty()) {
144     new (&arguments_) Arguments();
145     return;
146   }
147 
148   if (auto tup = other.castRaw<TupleType>()) {
149     if (auto names = tup->names()) {
150       new (&arguments_) Arguments(*names, args);
151       return;
152     }
153   }
154 
155   new (&arguments_) Arguments(args);
156 }
157 
equals(const DynamicType & other) const158 bool DynamicType::equals(const DynamicType& other) const {
159   if (this == &other) {
160     return true;
161   }
162   if (tag_ != other.tag_) {
163     return false;
164   }
165   switch (tag_) {
166     case Tag::Class:
167       return *class_ == *other.class_;
168     default:
169       return compareArguments(
170           other, [](const LabeledDynamicType& a, const LabeledDynamicType& b) {
171             return a.equals(b);
172           });
173   }
174 }
175 
equals(const Type & rhs) const176 bool DynamicType::equals(const Type& rhs) const {
177   return equals(*create(rhs));
178 }
179 
isSubtypeOfExt(const Type & rhs,std::ostream *) const180 bool DynamicType::isSubtypeOfExt(const Type& rhs, std::ostream*) const {
181   auto other = create(rhs);
182   if (tag_ == other->tag_) {
183     if (equals(*other)) {
184       return true;
185     }
186     if (contains(tag_, kDynamicCovariantTypeBit)) {
187       if (compareArguments(
188               *other,
189               [](const LabeledDynamicType& a, const LabeledDynamicType& b) {
190                 return a.isSubtypeOf(b);
191               })) {
192         return true;
193       };
194     }
195   } else if (contains(other->tag_, tag_)) {
196     return true;
197   }
198 
199   if (other->tag_ == Tag::Optional) {
200     if (isSubtypeOf(other->arguments_.elems[0].ty)) {
201       return true;
202     }
203   }
204 
205   return false;
206 }
207 
containedType(size_t i) const208 TypePtr DynamicType::containedType(size_t i) const {
209   TORCH_INTERNAL_ASSERT(tag_ != Tag::Class);
210   return arguments_.elems.at(i).ty;
211 }
212 
containedTypeSize() const213 size_t DynamicType::containedTypeSize() const {
214   TORCH_INTERNAL_ASSERT(tag_ != Tag::Class);
215   return arguments_.elems.size();
216 }
217 
dynamicKind() const218 TypeKind DynamicType::dynamicKind() const {
219   switch (tag_) {
220 #define CASE_TYPE(T, _, __) \
221   case Tag::T:              \
222     return TypeKind::T##Type;
223     FORALL_DYNAMIC_TYPES(CASE_TYPE)
224     // FORALL_DYNAMIC_TYPES_FAKE is intentionally omitted here
225     // as these dynamic types map to the same tag, so they always
226     // resolve to integers
227 #undef CASE_TYPE
228     default:
229       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
230       return TypeKind::AnyType;
231   }
232 }
233 
fallback() const234 TypePtr DynamicType::fallback() const {
235   switch (tag_) {
236     case Tag::Tensor:
237       return TensorType::get();
238     case Tag::None:
239       return NoneType::get();
240     case Tag::Bool:
241       return BoolType::get();
242     case Tag::Int:
243       return IntType::get();
244     case Tag::Float:
245       return FloatType::get();
246     case Tag::Complex:
247       return ComplexType::get();
248     case Tag::Number:
249       return NumberType::get();
250     case Tag::String:
251       return StringType::get();
252     case Tag::List:
253       return ListType::create(arguments_.elems[0].ty->fallback());
254     case Tag::Tuple: {
255       std::vector<TypePtr> fallbacks;
256       fallbacks.reserve(arguments_.elems.size());
257       for (const auto& elem : arguments_.elems) {
258         fallbacks.push_back(elem.ty->fallback());
259       }
260       if (name_) {
261         std::vector<c10::string_view> fields;
262         fields.reserve(arguments_.elems.size());
263         for (const auto& elem : arguments_.elems) {
264           // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
265           fields.emplace_back(*elem.label);
266         }
267         return TupleType::createNamed(*name_, fields, fallbacks);
268       }
269       return TupleType::create(std::move(fallbacks));
270     }
271     case Tag::Dict:
272       return DictType::create(
273           arguments_.elems[0].ty->fallback(),
274           arguments_.elems[1].ty->fallback());
275     case Tag::Class:
276       return std::make_shared<ClassType>(*class_);
277     case Tag::Optional:
278       return OptionalType::create(arguments_.elems[0].ty->fallback());
279     case Tag::AnyList:
280       return AnyListType::get();
281     case Tag::AnyTuple:
282       return AnyTupleType::get();
283     case Tag::DeviceObj:
284       return DeviceObjType::get();
285     case Tag::StreamObj:
286       return StreamObjType::get();
287     case Tag::Capsule:
288       return CapsuleType::get();
289     case Tag::Generator:
290       return GeneratorType::get();
291     case Tag::Storage:
292       return StorageType::get();
293     case Tag::Var:
294       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
295       return VarType::create(*name_);
296     case Tag::AnyClass:
297       return AnyClassType::get();
298     case Tag::QScheme:
299       return QSchemeType::get();
300     case Tag::Quantizer:
301       return QuantizerType::get();
302     case Tag::AnyEnum:
303       return AnyEnumType::get();
304     case Tag::RRef:
305       return RRefType::create(arguments_.elems[0].ty->fallback());
306     case Tag::Future:
307       return FutureType::create(arguments_.elems[0].ty->fallback());
308     case Tag::Await:
309       return AwaitType::create(arguments_.elems[0].ty->fallback());
310     case Tag::Any:
311       return AnyType::get();
312   }
313   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
314   return nullptr;
315 }
316 
isSubtypeOf(const LabeledDynamicType & other) const317 bool DynamicType::LabeledDynamicType::isSubtypeOf(
318     const LabeledDynamicType& other) const {
319   if (!other.label || (label == other.label)) {
320     return ty->isSubtypeOf(other.ty);
321   }
322 
323   return false;
324 }
325 
equals(const LabeledDynamicType & other) const326 bool DynamicType::LabeledDynamicType::equals(
327     const LabeledDynamicType& other) const {
328   return (label == other.label) && (*ty == *other.ty);
329 }
330 
get(const c10::IValue & v)331 DynamicType::Ptr IValue::TagType<c10::DynamicType>::get(const c10::IValue& v) {
332   switch (v.tag) {
333     case Tag::None:
334       return DynamicTypeTrait<NoneType>::getBaseType();
335     case Tag::Tensor:
336       return DynamicTypeTrait<TensorType>::getBaseType();
337     case Tag::Double:
338       return DynamicTypeTrait<FloatType>::getBaseType();
339     case Tag::ComplexDouble:
340       return DynamicTypeTrait<ComplexType>::getBaseType();
341     case Tag::Int:
342       return DynamicTypeTrait<IntType>::getBaseType();
343     case Tag::Bool:
344       return DynamicTypeTrait<BoolType>::getBaseType();
345     case Tag::String:
346       return DynamicTypeTrait<StringType>::getBaseType();
347     case Tag::GenericDict: {
348       auto d = v.toGenericDict();
349       return DynamicTypeFactory::create<DictType>(d.keyType(), d.valueType());
350     }
351     case Tag::GenericList:
352       return DynamicTypeFactory::create<ListType>(v.toList().elementType());
353     case Tag::Device:
354       return DynamicTypeTrait<DeviceObjType>::getBaseType();
355     case Tag::Stream:
356       return DynamicTypeTrait<StreamObjType>::getBaseType();
357     case Tag::Object:
358       return v.toObjectRef().type();
359     case Tag::Capsule:
360       return DynamicTypeTrait<CapsuleType>::getBaseType();
361     case Tag::Tuple:
362       return v.toTupleRef().type<c10::DynamicType>();
363     default:
364       return DynamicTypeTrait<AnyType>::getBaseType();
365   }
366 }
367 
create(const std::vector<TypePtr> & elemTypes)368 DynamicTypePtr ivalue::TupleTypeFactory<c10::DynamicType>::create(
369     const std::vector<TypePtr>& elemTypes) {
370   return DynamicTypeFactory::create<TupleType>(elemTypes);
371 }
372 
fallback(const Type &)373 DynamicTypePtr ivalue::TupleTypeFactory<c10::DynamicType>::fallback(
374     const Type&) {
375   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
376   return nullptr;
377 }
378 
379 TORCH_API TupleTypePtr
fallback(C10_UNUSED const Type & type)380 ivalue::TupleTypeFactory<TupleType>::fallback(C10_UNUSED const Type& type) {
381 #ifdef C10_MOBILE
382   return nullptr;
383 #else
384   const auto& dyn = type.expectRef<DynamicType>();
385   std::vector<c10::string_view> fields;
386   std::vector<TypePtr> types;
387 
388   for (const auto& elem : dyn.arguments().elems) {
389     types.emplace_back(elem.ty);
390     if (const auto& name = elem.label) {
391       fields.emplace_back(*name);
392     }
393   }
394   if (const auto& name = dyn.name()) {
395     return TupleType::createNamed(*name, fields, types);
396   }
397   return TupleType::create(std::move(types));
398 #endif
399 }
400 
401 
402 } // namespace c10
403