1 #pragma once 2 3 #include <cstdint> 4 #include <memory> 5 #include <type_traits> 6 7 #include <ATen/core/jit_type_base.h> 8 #include <optional> 9 10 namespace c10 { 11 12 using DynamicTypeBits = std::uint32_t; 13 #define DYNAMIC_TYPE_BIT(x) (1u << x) 14 15 constexpr DynamicTypeBits kDynamicCovariantTypeBit = DYNAMIC_TYPE_BIT(31); 16 constexpr DynamicTypeBits kDynamicAnyTypeBit = DYNAMIC_TYPE_BIT(30); 17 18 constexpr DynamicTypeBits kDynamicNoneTypeBit = DYNAMIC_TYPE_BIT(1); 19 constexpr DynamicTypeBits kDynamicIntTypeBit = DYNAMIC_TYPE_BIT(3); 20 constexpr DynamicTypeBits kDynamicFloatTypeBit = DYNAMIC_TYPE_BIT(4); 21 constexpr DynamicTypeBits kDynamicComplexTypeBit = DYNAMIC_TYPE_BIT(5); 22 constexpr DynamicTypeBits kDynamicListTypeBit = DYNAMIC_TYPE_BIT(7); 23 constexpr DynamicTypeBits kDynamicTupleTypeBit = DYNAMIC_TYPE_BIT(8); 24 constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10); 25 26 #define FORALL_DYNAMIC_TYPES(_) \ 27 _(Tensor, DYNAMIC_TYPE_BIT(0), 1) \ 28 _(None, kDynamicNoneTypeBit, 1) \ 29 _(Bool, DYNAMIC_TYPE_BIT(2), 1) \ 30 _(Int, kDynamicIntTypeBit, 1) \ 31 _(Float, kDynamicFloatTypeBit, 1) \ 32 _(Complex, kDynamicComplexTypeBit, 1) \ 33 _(Number, \ 34 (kDynamicIntTypeBit | kDynamicFloatTypeBit | kDynamicComplexTypeBit), \ 35 1) \ 36 _(String, DYNAMIC_TYPE_BIT(6), 1) \ 37 _(List, kDynamicListTypeBit, 0) \ 38 _(Tuple, (kDynamicTupleTypeBit | kDynamicCovariantTypeBit), 0) \ 39 _(Dict, DYNAMIC_TYPE_BIT(9), 0) \ 40 _(Class, kDynamicClassTypeBit, 0) \ 41 _(Optional, \ 42 (DYNAMIC_TYPE_BIT(11) | kDynamicNoneTypeBit | kDynamicCovariantTypeBit), \ 43 0) \ 44 _(AnyList, (kDynamicListTypeBit | kDynamicAnyTypeBit), 1) \ 45 _(AnyTuple, \ 46 (kDynamicTupleTypeBit | kDynamicCovariantTypeBit | kDynamicAnyTypeBit), \ 47 1) \ 48 _(DeviceObj, DYNAMIC_TYPE_BIT(12), 1) \ 49 _(StreamObj, DYNAMIC_TYPE_BIT(13), 1) \ 50 _(Capsule, DYNAMIC_TYPE_BIT(14), 1) \ 51 _(Generator, DYNAMIC_TYPE_BIT(15), 1) \ 52 _(Storage, DYNAMIC_TYPE_BIT(16), 1) \ 53 _(Var, DYNAMIC_TYPE_BIT(17), 0) \ 54 _(AnyClass, (kDynamicClassTypeBit | kDynamicAnyTypeBit), 1) \ 55 _(QScheme, DYNAMIC_TYPE_BIT(18), 1) \ 56 _(Quantizer, DYNAMIC_TYPE_BIT(19), 1) \ 57 _(AnyEnum, DYNAMIC_TYPE_BIT(20), 1) \ 58 _(RRef, DYNAMIC_TYPE_BIT(21), 0) \ 59 _(Future, DYNAMIC_TYPE_BIT(22), 0) \ 60 _(Await, DYNAMIC_TYPE_BIT(23), 0) \ 61 _(Any, 0xffffffff, 1) 62 63 #define FORALL_DYNAMIC_TYPES_FAKE(_) \ 64 _(ScalarType, kDynamicIntTypeBit, 1) \ 65 _(Layout, kDynamicIntTypeBit, 1) \ 66 _(SymInt, kDynamicIntTypeBit, 1) \ 67 _(MemoryFormat, kDynamicIntTypeBit, 1) 68 69 #define FORWARD_DECL_TYPE(NAME, _, __) struct NAME ## Type; 70 FORALL_DYNAMIC_TYPES(FORWARD_DECL_TYPE) 71 FORALL_DYNAMIC_TYPES_FAKE(FORWARD_DECL_TYPE) 72 #undef FORWARD_DECL_TYPE 73 74 class DynamicType; 75 using DynamicTypePtr = std::shared_ptr<DynamicType>; 76 77 /** 78 * DynamicType is designed as a low dependency type system for TorchScript. The 79 * existing JIT types are used for both compilation and runtime, which makes 80 * sense for server contexts because we often compile and run the model in 81 * the same process, however this doesn't hold for mobile devices where we 82 * always compiles a model ahead of time, therefore there will be dependencies 83 * which are not needed, but built with mobile runtime causing binary size 84 * bloat, by design. Every basic type like Int, Bool or String will bring their 85 * vtable, typeinfo, constructor, destructor and even more data from their 86 * specializations for STL types to the binary causing a long tail bloat. 87 * 88 * The core problem is about the complexity to implement and maintain a single 89 * type system for both analysis and execution purposes. Although they should 90 * have the exactly same semantics, in practice implement a unified abstraction 91 * adds conceptual and representational overhead for both sides of the world. 92 * 93 * To address the issues, DynamicType implements a minimal subset of JIT types 94 * and uses a generic algorithm to test all subtyping relations. To achieve 95 * this, we assign each dynamic type a single integer tag to represent its 96 * semantics. More specifically, a dynamic type is defined as a set of "control 97 * bits" and "data bits", where control bits describe the special behavior when 98 * testing a type and data bits map to identity of each nominal type. We use bit 99 * operations to perform all the tests. 100 * 101 * For example, a "covariant bit" is a control bit used to describe if a type 102 * is covariant, right now the most used one is tuple type, and in addition to 103 * the control bit, tuple type's data bit is the 8th bit from the LSB. Control 104 * bits start from MSB and data bits start from LSB. 105 * 106 * If two types are equal, then they are subtype of each other, also if the bits 107 * from one type tag is subset of the other tag, it automatically becomes a 108 * subtype of the other. This simplifies the subtyping logic a lot, and over the 109 * long term it is possible to adopt this scheme on the server side as well. 110 * Special cases can be added but they generally should not take too much code 111 * size. 112 * 113 * DynamicType may or may not inherit from c10::Type because it's not the core 114 * requirement of DynamicType to interface with existing JIT types, but we might 115 * want to inherit from c10::Type to reduce the migration cost. 116 */ 117 class DynamicType : public SharedType { 118 using ClassTypePtr = std::shared_ptr<const c10::ClassType>; 119 120 /** 121 * A implementation detail to support NamedTuple. 122 */ 123 struct LabeledDynamicType { 124 std::optional<std::string> label; 125 DynamicTypePtr ty; LabeledDynamicTypeLabeledDynamicType126 explicit LabeledDynamicType(DynamicTypePtr t) : ty(std::move(t)) {} 127 128 bool equals(const LabeledDynamicType& other) const; 129 bool isSubtypeOf(const LabeledDynamicType& other) const; 130 }; 131 132 public: 133 // TODO Change Ptr to DynamicTypePtr when all migrations are done. 134 using Ptr = TypePtr; 135 using ElementType = DynamicType; 136 ~DynamicType() override; 137 138 struct Arguments { 139 Arguments() = default; 140 Arguments(c10::ArrayRef<TypePtr>); 141 Arguments(const std::vector<c10::string_view>&, c10::ArrayRef<TypePtr>); 142 std::vector<LabeledDynamicType> elems; 143 }; 144 145 enum class Tag : DynamicTypeBits { 146 #define DYNAMIC_TYPE_ITEM(NAME, VAL, _) NAME = VAL, 147 FORALL_DYNAMIC_TYPES(DYNAMIC_TYPE_ITEM) 148 FORALL_DYNAMIC_TYPES_FAKE(DYNAMIC_TYPE_ITEM) 149 #undef DYNAMIC_TYPE_ITEM 150 }; 151 152 bool equals(const Type& rhs) const override; 153 bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; 154 std::string str() const override; 155 static const TypeKind Kind = TypeKind::DynamicType; 156 static TORCH_API DynamicTypePtr create(Type& ty); 157 158 explicit DynamicType(Tag, Arguments); 159 explicit DynamicType(Tag, c10::string_view, Arguments); 160 161 TypePtr containedType(size_t) const override; 162 size_t containedTypeSize() const override; tag()163 Tag tag() const { 164 return tag_; 165 } name()166 const std::optional<std::string>& name() const { 167 return name_; 168 } arguments()169 const Arguments& arguments() const { 170 return arguments_; 171 } 172 TORCH_API TypeKind dynamicKind() const; 173 174 // Should be used only on the server side to restore static type information. 175 #ifndef C10_MOBILE 176 TORCH_API 177 #endif 178 TypePtr fallback() const; 179 180 private: symmetric()181 bool symmetric() const override { 182 return false; 183 } 184 friend struct Type; 185 static std::shared_ptr<const DynamicType> create(const Type& ty); 186 DynamicType(const Type& other); 187 bool equals(const DynamicType& other) const; 188 189 template <typename F> compareArguments(const DynamicType & other,const F & f)190 bool compareArguments(const DynamicType& other, const F& f) const { 191 if (arguments_.elems.size() != other.arguments_.elems.size()) { 192 return false; 193 } 194 for (size_t i = 0; i < arguments_.elems.size(); i++) { 195 if (!f(arguments_.elems[i], other.arguments_.elems[i])) { 196 return false; 197 } 198 } 199 return true; 200 } 201 202 Tag tag_; 203 std::optional<std::string> name_; 204 union { 205 Arguments arguments_; 206 ClassTypePtr class_; 207 }; 208 }; 209 210 template <typename T> 211 struct DynamicTypeTrait { tagValueDynamicTypeTrait212 C10_NOINLINE static auto tagValue() { 213 TORCH_CHECK(false); 214 return DynamicType::Tag::Any; 215 } 216 }; 217 218 namespace detail { 219 C10_NOINLINE DynamicTypePtr makeBaseType(DynamicType::Tag tag); 220 } 221 222 #define DYNAMIC_TYPE_TAG_VALUE(NAME, _, IS_BASE_TYPE) \ 223 template <> \ 224 struct TORCH_API DynamicTypeTrait<NAME##Type> { \ 225 C10_ERASE static auto tagValue() { \ 226 return DynamicType::Tag::NAME; \ 227 } \ 228 static constexpr bool isBaseType = IS_BASE_TYPE; \ 229 template <typename T = const DynamicTypePtr&> \ 230 static std::enable_if_t<isBaseType, T> getBaseType() { \ 231 static auto type = detail::makeBaseType(tagValue()); \ 232 return type; \ 233 } \ 234 }; // namespace c10 235 FORALL_DYNAMIC_TYPES(DYNAMIC_TYPE_TAG_VALUE) 236 FORALL_DYNAMIC_TYPES_FAKE(DYNAMIC_TYPE_TAG_VALUE) 237 #undef DYNAMIC_TYPE_TAG_VALUE 238 239 } // namespace c10 240