1 #include <ATen/core/dynamic_type.h>
2
3 #include <string>
4
5 #include <ATen/core/class_type.h>
6 #include <ATen/core/ivalue.h>
7 #include <ATen/core/jit_type.h>
8 #include <ATen/core/type_factory.h>
9 #include <c10/util/Exception.h>
10
11 namespace c10 {
12
13 namespace {
14
contains(DynamicType::Tag lhs,DynamicTypeBits rhs)15 bool contains(DynamicType::Tag lhs, DynamicTypeBits rhs) {
16 return (static_cast<DynamicTypeBits>(lhs) | rhs) ==
17 static_cast<DynamicTypeBits>(lhs);
18 }
19
contains(DynamicType::Tag lhs,DynamicType::Tag rhs)20 bool contains(DynamicType::Tag lhs, DynamicType::Tag rhs) {
21 return contains(lhs, static_cast<DynamicTypeBits>(rhs));
22 }
23
24 } // namespace
25
26 namespace detail {
27
makeBaseType(DynamicType::Tag tag)28 DynamicTypePtr makeBaseType(DynamicType::Tag tag) {
29 return std::make_shared<DynamicType>(tag, DynamicType::Arguments{});
30 }
31
32 } // namespace detail
33
str() const34 std::string DynamicType::str() const {
35 if (name_) {
36 return *name_;
37 }
38 std::string ret = "Dynamic<";
39 ret += std::to_string(static_cast<DynamicTypeBits>(tag_));
40 ret += ">";
41 if (tag_ != Tag::Class && !arguments_.elems.empty()) {
42 ret += "[";
43 for (const auto& arg : arguments_.elems) {
44 if (arg.label) {
45 ret += *arg.label + ":";
46 }
47 ret += arg.ty->str();
48 ret += ",";
49 }
50 ret += "]";
51 }
52 return ret;
53 }
54
Arguments(c10::ArrayRef<TypePtr> args)55 DynamicType::Arguments::Arguments(c10::ArrayRef<TypePtr> args) {
56 elems.reserve(args.size());
57 for (const auto& arg : args) {
58 elems.emplace_back(create(*arg));
59 }
60 }
61
Arguments(const std::vector<c10::string_view> & names,c10::ArrayRef<TypePtr> args)62 DynamicType::Arguments::Arguments(
63 const std::vector<c10::string_view>& names,
64 c10::ArrayRef<TypePtr> args)
65 : Arguments(args) {
66 TORCH_INTERNAL_ASSERT(names.size() == args.size());
67 for (size_t i = 0; i < args.size(); i++) {
68 elems[i].label = std::string{names[i]};
69 }
70 }
71
~DynamicType()72 DynamicType::~DynamicType() {
73 if (tag_ == Tag::Class) {
74 class_.~ClassTypePtr();
75 return;
76 }
77
78 arguments_.~Arguments();
79 }
80
create(const Type & other)81 std::shared_ptr<const DynamicType> DynamicType::create(const Type& other) {
82 if (auto dynRaw = other.castRaw<DynamicType>()) {
83 TORCH_INTERNAL_ASSERT(!dynRaw->weak_from_this().expired(),
84 "Error creating dynamic type instance not managed by shared_ptr: ",
85 other.str());
86 }
87 if (auto dyn = other.cast<DynamicType>()) {
88 return dyn;
89 }
90 return std::shared_ptr<const DynamicType>(new DynamicType{other});
91 }
92
create(Type & other)93 DynamicTypePtr DynamicType::create(Type& other) {
94 if (auto dynRaw = other.castRaw<DynamicType>()) {
95 TORCH_INTERNAL_ASSERT(!dynRaw->weak_from_this().expired(),
96 "Error creating dynamic type instance not managed by shared_ptr: ",
97 other.str());
98 }
99 if (auto dyn = other.cast<DynamicType>()) {
100 return dyn;
101 }
102 return std::shared_ptr<DynamicType>(new DynamicType{other});
103 }
104
DynamicType(Tag tag,Arguments arguments)105 DynamicType::DynamicType(Tag tag, Arguments arguments)
106 : SharedType(Kind), tag_(tag), arguments_(std::move(arguments)) {}
107
DynamicType(Tag tag,c10::string_view name,Arguments arguments)108 DynamicType::DynamicType(Tag tag, c10::string_view name, Arguments arguments)
109 : SharedType(Kind),
110 tag_(tag),
111 name_(std::string{name}),
112 arguments_(std::move(arguments)) {}
113
DynamicType(const Type & other)114 DynamicType::DynamicType(const Type& other) : SharedType(DynamicType::Kind) {
115 auto kind = other.kind();
116 TORCH_INTERNAL_ASSERT(kind != Kind);
117 if (auto n = other.castRaw<NamedType>()) {
118 if (const auto& qn = n->name()) {
119 name_ = qn->qualifiedName();
120 }
121 } else if (auto v = other.castRaw<VarType>()) {
122 name_ = v->name();
123 }
124
125 if (auto cls = other.cast<ClassType>()) {
126 new (&class_) ClassTypePtr(std::move(cls));
127 tag_ = Tag::Class;
128 return;
129 }
130 switch (kind) {
131 #define CASE_TYPE(T, _, __) \
132 case T##Type::Kind: \
133 tag_ = Tag::T; \
134 break;
135 FORALL_DYNAMIC_TYPES(CASE_TYPE)
136 FORALL_DYNAMIC_TYPES_FAKE(CASE_TYPE)
137 #undef CASE_TYPE
138 default:
139 TORCH_INTERNAL_ASSERT(false, "Unsupported dynamic type: ", other.str());
140 }
141
142 auto args = other.containedTypes();
143 if (args.empty()) {
144 new (&arguments_) Arguments();
145 return;
146 }
147
148 if (auto tup = other.castRaw<TupleType>()) {
149 if (auto names = tup->names()) {
150 new (&arguments_) Arguments(*names, args);
151 return;
152 }
153 }
154
155 new (&arguments_) Arguments(args);
156 }
157
equals(const DynamicType & other) const158 bool DynamicType::equals(const DynamicType& other) const {
159 if (this == &other) {
160 return true;
161 }
162 if (tag_ != other.tag_) {
163 return false;
164 }
165 switch (tag_) {
166 case Tag::Class:
167 return *class_ == *other.class_;
168 default:
169 return compareArguments(
170 other, [](const LabeledDynamicType& a, const LabeledDynamicType& b) {
171 return a.equals(b);
172 });
173 }
174 }
175
equals(const Type & rhs) const176 bool DynamicType::equals(const Type& rhs) const {
177 return equals(*create(rhs));
178 }
179
isSubtypeOfExt(const Type & rhs,std::ostream *) const180 bool DynamicType::isSubtypeOfExt(const Type& rhs, std::ostream*) const {
181 auto other = create(rhs);
182 if (tag_ == other->tag_) {
183 if (equals(*other)) {
184 return true;
185 }
186 if (contains(tag_, kDynamicCovariantTypeBit)) {
187 if (compareArguments(
188 *other,
189 [](const LabeledDynamicType& a, const LabeledDynamicType& b) {
190 return a.isSubtypeOf(b);
191 })) {
192 return true;
193 };
194 }
195 } else if (contains(other->tag_, tag_)) {
196 return true;
197 }
198
199 if (other->tag_ == Tag::Optional) {
200 if (isSubtypeOf(other->arguments_.elems[0].ty)) {
201 return true;
202 }
203 }
204
205 return false;
206 }
207
containedType(size_t i) const208 TypePtr DynamicType::containedType(size_t i) const {
209 TORCH_INTERNAL_ASSERT(tag_ != Tag::Class);
210 return arguments_.elems.at(i).ty;
211 }
212
containedTypeSize() const213 size_t DynamicType::containedTypeSize() const {
214 TORCH_INTERNAL_ASSERT(tag_ != Tag::Class);
215 return arguments_.elems.size();
216 }
217
dynamicKind() const218 TypeKind DynamicType::dynamicKind() const {
219 switch (tag_) {
220 #define CASE_TYPE(T, _, __) \
221 case Tag::T: \
222 return TypeKind::T##Type;
223 FORALL_DYNAMIC_TYPES(CASE_TYPE)
224 // FORALL_DYNAMIC_TYPES_FAKE is intentionally omitted here
225 // as these dynamic types map to the same tag, so they always
226 // resolve to integers
227 #undef CASE_TYPE
228 default:
229 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
230 return TypeKind::AnyType;
231 }
232 }
233
fallback() const234 TypePtr DynamicType::fallback() const {
235 switch (tag_) {
236 case Tag::Tensor:
237 return TensorType::get();
238 case Tag::None:
239 return NoneType::get();
240 case Tag::Bool:
241 return BoolType::get();
242 case Tag::Int:
243 return IntType::get();
244 case Tag::Float:
245 return FloatType::get();
246 case Tag::Complex:
247 return ComplexType::get();
248 case Tag::Number:
249 return NumberType::get();
250 case Tag::String:
251 return StringType::get();
252 case Tag::List:
253 return ListType::create(arguments_.elems[0].ty->fallback());
254 case Tag::Tuple: {
255 std::vector<TypePtr> fallbacks;
256 fallbacks.reserve(arguments_.elems.size());
257 for (const auto& elem : arguments_.elems) {
258 fallbacks.push_back(elem.ty->fallback());
259 }
260 if (name_) {
261 std::vector<c10::string_view> fields;
262 fields.reserve(arguments_.elems.size());
263 for (const auto& elem : arguments_.elems) {
264 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
265 fields.emplace_back(*elem.label);
266 }
267 return TupleType::createNamed(*name_, fields, fallbacks);
268 }
269 return TupleType::create(std::move(fallbacks));
270 }
271 case Tag::Dict:
272 return DictType::create(
273 arguments_.elems[0].ty->fallback(),
274 arguments_.elems[1].ty->fallback());
275 case Tag::Class:
276 return std::make_shared<ClassType>(*class_);
277 case Tag::Optional:
278 return OptionalType::create(arguments_.elems[0].ty->fallback());
279 case Tag::AnyList:
280 return AnyListType::get();
281 case Tag::AnyTuple:
282 return AnyTupleType::get();
283 case Tag::DeviceObj:
284 return DeviceObjType::get();
285 case Tag::StreamObj:
286 return StreamObjType::get();
287 case Tag::Capsule:
288 return CapsuleType::get();
289 case Tag::Generator:
290 return GeneratorType::get();
291 case Tag::Storage:
292 return StorageType::get();
293 case Tag::Var:
294 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
295 return VarType::create(*name_);
296 case Tag::AnyClass:
297 return AnyClassType::get();
298 case Tag::QScheme:
299 return QSchemeType::get();
300 case Tag::Quantizer:
301 return QuantizerType::get();
302 case Tag::AnyEnum:
303 return AnyEnumType::get();
304 case Tag::RRef:
305 return RRefType::create(arguments_.elems[0].ty->fallback());
306 case Tag::Future:
307 return FutureType::create(arguments_.elems[0].ty->fallback());
308 case Tag::Await:
309 return AwaitType::create(arguments_.elems[0].ty->fallback());
310 case Tag::Any:
311 return AnyType::get();
312 }
313 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
314 return nullptr;
315 }
316
isSubtypeOf(const LabeledDynamicType & other) const317 bool DynamicType::LabeledDynamicType::isSubtypeOf(
318 const LabeledDynamicType& other) const {
319 if (!other.label || (label == other.label)) {
320 return ty->isSubtypeOf(other.ty);
321 }
322
323 return false;
324 }
325
equals(const LabeledDynamicType & other) const326 bool DynamicType::LabeledDynamicType::equals(
327 const LabeledDynamicType& other) const {
328 return (label == other.label) && (*ty == *other.ty);
329 }
330
get(const c10::IValue & v)331 DynamicType::Ptr IValue::TagType<c10::DynamicType>::get(const c10::IValue& v) {
332 switch (v.tag) {
333 case Tag::None:
334 return DynamicTypeTrait<NoneType>::getBaseType();
335 case Tag::Tensor:
336 return DynamicTypeTrait<TensorType>::getBaseType();
337 case Tag::Double:
338 return DynamicTypeTrait<FloatType>::getBaseType();
339 case Tag::ComplexDouble:
340 return DynamicTypeTrait<ComplexType>::getBaseType();
341 case Tag::Int:
342 return DynamicTypeTrait<IntType>::getBaseType();
343 case Tag::Bool:
344 return DynamicTypeTrait<BoolType>::getBaseType();
345 case Tag::String:
346 return DynamicTypeTrait<StringType>::getBaseType();
347 case Tag::GenericDict: {
348 auto d = v.toGenericDict();
349 return DynamicTypeFactory::create<DictType>(d.keyType(), d.valueType());
350 }
351 case Tag::GenericList:
352 return DynamicTypeFactory::create<ListType>(v.toList().elementType());
353 case Tag::Device:
354 return DynamicTypeTrait<DeviceObjType>::getBaseType();
355 case Tag::Stream:
356 return DynamicTypeTrait<StreamObjType>::getBaseType();
357 case Tag::Object:
358 return v.toObjectRef().type();
359 case Tag::Capsule:
360 return DynamicTypeTrait<CapsuleType>::getBaseType();
361 case Tag::Tuple:
362 return v.toTupleRef().type<c10::DynamicType>();
363 default:
364 return DynamicTypeTrait<AnyType>::getBaseType();
365 }
366 }
367
create(const std::vector<TypePtr> & elemTypes)368 DynamicTypePtr ivalue::TupleTypeFactory<c10::DynamicType>::create(
369 const std::vector<TypePtr>& elemTypes) {
370 return DynamicTypeFactory::create<TupleType>(elemTypes);
371 }
372
fallback(const Type &)373 DynamicTypePtr ivalue::TupleTypeFactory<c10::DynamicType>::fallback(
374 const Type&) {
375 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false);
376 return nullptr;
377 }
378
379 TORCH_API TupleTypePtr
fallback(C10_UNUSED const Type & type)380 ivalue::TupleTypeFactory<TupleType>::fallback(C10_UNUSED const Type& type) {
381 #ifdef C10_MOBILE
382 return nullptr;
383 #else
384 const auto& dyn = type.expectRef<DynamicType>();
385 std::vector<c10::string_view> fields;
386 std::vector<TypePtr> types;
387
388 for (const auto& elem : dyn.arguments().elems) {
389 types.emplace_back(elem.ty);
390 if (const auto& name = elem.label) {
391 fields.emplace_back(*name);
392 }
393 }
394 if (const auto& name = dyn.name()) {
395 return TupleType::createNamed(*name, fields, types);
396 }
397 return TupleType::create(std::move(types));
398 #endif
399 }
400
401
402 } // namespace c10
403