xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/dynamic_type.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <cstdint>
4 #include <memory>
5 #include <type_traits>
6 
7 #include <ATen/core/jit_type_base.h>
8 #include <optional>
9 
10 namespace c10 {
11 
/// Underlying integer type of DynamicType::Tag: every tag value is a set of
/// these bits.
using DynamicTypeBits = std::uint32_t;

/// Mask with only bit `x` set (bit 0 is the LSB). The argument is wrapped in
/// parentheses so that an argument containing a lower-precedence operator,
/// e.g. DYNAMIC_TYPE_BIT(cond ? 1 : 2), is shifted as a whole instead of
/// expanding to `(1u << cond) ? 1 : 2`.
#define DYNAMIC_TYPE_BIT(x) (1u << (x))

// Control bits: allocated from the MSB downward.
constexpr DynamicTypeBits kDynamicCovariantTypeBit = DYNAMIC_TYPE_BIT(31);
constexpr DynamicTypeBits kDynamicAnyTypeBit = DYNAMIC_TYPE_BIT(30);

// Data bits: allocated from the LSB upward. Named constants exist for the data
// bits that several entries of FORALL_DYNAMIC_TYPES combine or reuse.
constexpr DynamicTypeBits kDynamicNoneTypeBit = DYNAMIC_TYPE_BIT(1);
constexpr DynamicTypeBits kDynamicIntTypeBit = DYNAMIC_TYPE_BIT(3);
constexpr DynamicTypeBits kDynamicFloatTypeBit = DYNAMIC_TYPE_BIT(4);
constexpr DynamicTypeBits kDynamicComplexTypeBit = DYNAMIC_TYPE_BIT(5);
constexpr DynamicTypeBits kDynamicListTypeBit = DYNAMIC_TYPE_BIT(7);
constexpr DynamicTypeBits kDynamicTupleTypeBit = DYNAMIC_TYPE_BIT(8);
constexpr DynamicTypeBits kDynamicClassTypeBit = DYNAMIC_TYPE_BIT(10);
25 
// X-macro table of every DynamicType tag. Each entry has the shape
//   _(NAME, TAG_BITS, IS_BASE_TYPE)
// where
//   NAME         - c10 type name without its "Type" suffix (NAME##Type is
//                  forward-declared from this table);
//   TAG_BITS     - value of the enumerator DynamicType::Tag::NAME;
//   IS_BASE_TYPE - 0/1 flag copied into DynamicTypeTrait<NAME##Type>::isBaseType.
// The Any* entries OR kDynamicAnyTypeBit into the List/Tuple/Class bits, and
// Any itself sets every bit.
#define FORALL_DYNAMIC_TYPES(_)                                               \
  _(Tensor, DYNAMIC_TYPE_BIT(0), 1)                                           \
  _(None, kDynamicNoneTypeBit, 1)                                             \
  _(Bool, DYNAMIC_TYPE_BIT(2), 1)                                             \
  _(Int, kDynamicIntTypeBit, 1)                                               \
  _(Float, kDynamicFloatTypeBit, 1)                                           \
  _(Complex, kDynamicComplexTypeBit, 1)                                       \
  _(Number,                                                                   \
    (kDynamicIntTypeBit | kDynamicFloatTypeBit | kDynamicComplexTypeBit), 1)  \
  _(String, DYNAMIC_TYPE_BIT(6), 1)                                           \
  _(List, kDynamicListTypeBit, 0)                                             \
  _(Tuple, (kDynamicTupleTypeBit | kDynamicCovariantTypeBit), 0)              \
  _(Dict, DYNAMIC_TYPE_BIT(9), 0)                                             \
  _(Class, kDynamicClassTypeBit, 0)                                           \
  _(Optional,                                                                 \
    (DYNAMIC_TYPE_BIT(11) | kDynamicNoneTypeBit | kDynamicCovariantTypeBit),  \
    0)                                                                        \
  _(AnyList, (kDynamicListTypeBit | kDynamicAnyTypeBit), 1)                   \
  _(AnyTuple,                                                                 \
    (kDynamicTupleTypeBit | kDynamicCovariantTypeBit | kDynamicAnyTypeBit),   \
    1)                                                                        \
  _(DeviceObj, DYNAMIC_TYPE_BIT(12), 1)                                       \
  _(StreamObj, DYNAMIC_TYPE_BIT(13), 1)                                       \
  _(Capsule, DYNAMIC_TYPE_BIT(14), 1)                                         \
  _(Generator, DYNAMIC_TYPE_BIT(15), 1)                                       \
  _(Storage, DYNAMIC_TYPE_BIT(16), 1)                                         \
  _(Var, DYNAMIC_TYPE_BIT(17), 0)                                             \
  _(AnyClass, (kDynamicClassTypeBit | kDynamicAnyTypeBit), 1)                 \
  _(QScheme, DYNAMIC_TYPE_BIT(18), 1)                                         \
  _(Quantizer, DYNAMIC_TYPE_BIT(19), 1)                                       \
  _(AnyEnum, DYNAMIC_TYPE_BIT(20), 1)                                         \
  _(RRef, DYNAMIC_TYPE_BIT(21), 0)                                            \
  _(Future, DYNAMIC_TYPE_BIT(22), 0)                                          \
  _(Await, DYNAMIC_TYPE_BIT(23), 0)                                           \
  _(Any, 0xffffffff, 1)
62 
// Additional Tag entries, same _(NAME, TAG_BITS, IS_BASE_TYPE) shape as
// FORALL_DYNAMIC_TYPES, for types that have no data bit of their own: every
// entry reuses kDynamicIntTypeBit, so their tags carry the same bits as Int.
// They still receive a forward declaration, a Tag enumerator and a
// DynamicTypeTrait specialization in this header.
#define FORALL_DYNAMIC_TYPES_FAKE(_)       \
  _(ScalarType, kDynamicIntTypeBit, 1)     \
  _(Layout, kDynamicIntTypeBit, 1)         \
  _(SymInt, kDynamicIntTypeBit, 1)         \
  _(MemoryFormat, kDynamicIntTypeBit, 1)
68 
69 #define FORWARD_DECL_TYPE(NAME, _, __) struct NAME ## Type;
70   FORALL_DYNAMIC_TYPES(FORWARD_DECL_TYPE)
71   FORALL_DYNAMIC_TYPES_FAKE(FORWARD_DECL_TYPE)
72 #undef FORWARD_DECL_TYPE
73 
74 class DynamicType;
75 using DynamicTypePtr = std::shared_ptr<DynamicType>;
76 
77 /**
 * DynamicType is designed as a low-dependency type system for TorchScript. The
 * existing JIT types are used for both compilation and runtime. That makes
 * sense for server contexts, where we often compile and run the model in the
 * same process, but it does not hold for mobile devices, where we always
 * compile a model ahead of time. As a result, the mobile runtime is built with
 * dependencies it does not need, which by design causes binary size bloat:
 * every basic type like Int, Bool or String brings its vtable, typeinfo,
 * constructor, destructor and even more data from its specializations for STL
 * types into the binary, causing a long-tail bloat.
87  *
 * The core problem is the complexity of implementing and maintaining a single
 * type system for both analysis and execution purposes. Although both sides
 * should have exactly the same semantics, in practice implementing a unified
 * abstraction adds conceptual and representational overhead for both of them.
92  *
 * To address these issues, DynamicType implements a minimal subset of JIT
 * types and uses a generic algorithm to test all subtyping relations. To
 * achieve this, we assign each dynamic type a single integer tag that
 * represents its semantics. More specifically, a dynamic type is defined as a
 * set of "control bits" and "data bits": control bits describe special
 * behavior when testing a type, and data bits map to the identity of each
 * nominal type. All tests are performed with bit operations.
100  *
 * For example, the "covariant bit" is a control bit that marks a type as
 * covariant; it is currently set on the tuple types and on Optional. In
 * addition to that control bit, the tuple type's data bit is bit 8 (zero-based,
 * counting from the LSB). Control bits are allocated starting from the MSB and
 * data bits starting from the LSB.
105  *
 * If two types are equal, they are subtypes of each other. Also, if the bits of
 * one type tag are a subset of the other tag's bits, the first type
 * automatically becomes a subtype of the second. This simplifies the subtyping
 * logic a lot, and over the long term it may be possible to adopt this scheme
 * on the server side as well. Special cases can be added, but they generally
 * should not take much code size.
112  *
 * Interfacing with the existing JIT types is not a core requirement of
 * DynamicType, but it currently inherits from c10::Type (via SharedType) to
 * reduce the migration cost.
116  */
117 class DynamicType : public SharedType {
118   using ClassTypePtr = std::shared_ptr<const c10::ClassType>;
119 
120   /**
121    * A implementation detail to support NamedTuple.
122    */
123   struct LabeledDynamicType {
124     std::optional<std::string> label;
125     DynamicTypePtr ty;
LabeledDynamicTypeLabeledDynamicType126     explicit LabeledDynamicType(DynamicTypePtr t) : ty(std::move(t)) {}
127 
128     bool equals(const LabeledDynamicType& other) const;
129     bool isSubtypeOf(const LabeledDynamicType& other) const;
130   };
131 
132  public:
133   // TODO Change Ptr to DynamicTypePtr when all migrations are done.
134   using Ptr = TypePtr;
135   using ElementType = DynamicType;
136   ~DynamicType() override;
137 
138   struct Arguments {
139     Arguments() = default;
140     Arguments(c10::ArrayRef<TypePtr>);
141     Arguments(const std::vector<c10::string_view>&, c10::ArrayRef<TypePtr>);
142     std::vector<LabeledDynamicType> elems;
143   };
144 
145   enum class Tag : DynamicTypeBits {
146 #define DYNAMIC_TYPE_ITEM(NAME, VAL, _) NAME = VAL,
147     FORALL_DYNAMIC_TYPES(DYNAMIC_TYPE_ITEM)
148     FORALL_DYNAMIC_TYPES_FAKE(DYNAMIC_TYPE_ITEM)
149 #undef DYNAMIC_TYPE_ITEM
150   };
151 
152   bool equals(const Type& rhs) const override;
153   bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override;
154   std::string str() const override;
155   static const TypeKind Kind = TypeKind::DynamicType;
156   static TORCH_API DynamicTypePtr create(Type& ty);
157 
158   explicit DynamicType(Tag, Arguments);
159   explicit DynamicType(Tag, c10::string_view, Arguments);
160 
161   TypePtr containedType(size_t) const override;
162   size_t containedTypeSize() const override;
tag()163   Tag tag() const {
164     return tag_;
165   }
name()166   const std::optional<std::string>& name() const {
167     return name_;
168   }
arguments()169   const Arguments& arguments() const {
170     return arguments_;
171   }
172   TORCH_API TypeKind dynamicKind() const;
173 
174   // Should be used only on the server side to restore static type information.
175 #ifndef C10_MOBILE
176   TORCH_API
177 #endif
178   TypePtr fallback() const;
179 
180  private:
symmetric()181   bool symmetric() const override {
182     return false;
183   }
184   friend struct Type;
185   static std::shared_ptr<const DynamicType> create(const Type& ty);
186   DynamicType(const Type& other);
187   bool equals(const DynamicType& other) const;
188 
189   template <typename F>
compareArguments(const DynamicType & other,const F & f)190   bool compareArguments(const DynamicType& other, const F& f) const {
191     if (arguments_.elems.size() != other.arguments_.elems.size()) {
192       return false;
193     }
194     for (size_t i = 0; i < arguments_.elems.size(); i++) {
195       if (!f(arguments_.elems[i], other.arguments_.elems[i])) {
196         return false;
197       }
198     }
199     return true;
200   }
201 
202   Tag tag_;
203   std::optional<std::string> name_;
204   union {
205     Arguments arguments_;
206     ClassTypePtr class_;
207   };
208 };
209 
210 template <typename T>
211 struct DynamicTypeTrait {
tagValueDynamicTypeTrait212   C10_NOINLINE static auto tagValue() {
213     TORCH_CHECK(false);
214     return DynamicType::Tag::Any;
215   }
216 };
217 
218 namespace detail {
219 C10_NOINLINE DynamicTypePtr makeBaseType(DynamicType::Tag tag);
220 }
221 
222 #define DYNAMIC_TYPE_TAG_VALUE(NAME, _, IS_BASE_TYPE)      \
223   template <>                                              \
224   struct TORCH_API DynamicTypeTrait<NAME##Type> {          \
225     C10_ERASE static auto tagValue() {                     \
226       return DynamicType::Tag::NAME;                       \
227     }                                                      \
228     static constexpr bool isBaseType = IS_BASE_TYPE;       \
229     template <typename T = const DynamicTypePtr&>          \
230     static std::enable_if_t<isBaseType, T> getBaseType() { \
231       static auto type = detail::makeBaseType(tagValue()); \
232       return type;                                         \
233     }                                                      \
234   }; // namespace c10
235 FORALL_DYNAMIC_TYPES(DYNAMIC_TYPE_TAG_VALUE)
236 FORALL_DYNAMIC_TYPES_FAKE(DYNAMIC_TYPE_TAG_VALUE)
237 #undef DYNAMIC_TYPE_TAG_VALUE
238 
239 } // namespace c10
240