xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/NestedIntSymNodeImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/ConstantSymNodeImpl.h>
4 #include <c10/core/SymNodeImpl.h>
5 #include <c10/macros/Export.h>
6 #include <c10/util/Exception.h>
7 #include <c10/util/intrusive_ptr.h>
8 #include <cstdint>
9 #include <optional>
10 #include <string>
11 
12 namespace c10 {
13 
// The motivating use case for this is to represent the ragged size structure
// of a jagged tensor [B, [s_0, s_1, s_2], D] as a single integer j0. This
// allows us to simply return [B, j0, D] if someone queries for the size of our
// tensor.
//
// Morally, we define comparison between two nested ints to return true if
// that comparison holds for all corresponding elements of the arrays they
// represent. Comparison between a nested int and a plain int is defined
// similarly.
//
// To simulate this desired behavior but also avoid the O(N) cost of checking,
// we associate each raggedness pattern with an integer "id" that can be used as
// a proxy to evaluate equality. We also constrain the range of values for this
// so as to enable inequality checks.
//
// We also support a positive integer scalar "coeff" that is used for computing
// strides. For example, given a [B, j0, D] tensor, it can be strided in two
// different ways: [D * j0, D, 1] and [j0, 1, sum(j0)]. The coeff is used to
// differentiate the two cases.
//
// During tracing, the strides of the outputs need to be a function of the size
// and strides of the inputs, so it is important that NestedIntSymNode itself is
// able to express this.
37 class TORCH_API NestedIntSymNodeImpl : public SymNodeImpl {
38  public:
39   // CAUTION: you should probably not be constructing these directly; please
40   // the higher-level API in python instead (TODO: actually introduce that).
NestedIntSymNodeImpl(int64_t val,int64_t coeff)41   explicit NestedIntSymNodeImpl(int64_t val, int64_t coeff)
42       : val_(val), coeff_(coeff) {}
43 
bool_()44   bool bool_() override {
45     return false;
46   }
47 
is_int()48   bool is_int() override {
49     return true;
50   }
51 
is_float()52   bool is_float() override {
53     return false;
54   }
55 
is_bool()56   bool is_bool() override {
57     return false;
58   }
59 
is_nested_int()60   bool is_nested_int() const override {
61     return true;
62   }
63 
has_hint()64   bool has_hint() override {
65     return true;
66   }
67 
wrap_int(int64_t num)68   c10::SymNode wrap_int(int64_t num) override {
69     return SymNode(c10::make_intrusive<ConstantSymNodeImpl<int64_t>>(num));
70   };
71 
guard_int(const char * file,int64_t line)72   int64_t guard_int(const char* file, int64_t line) override {
73     TORCH_CHECK(false);
74   }
75 
guard_float(const char * file,int64_t line)76   double guard_float(const char* file, int64_t line) override {
77     TORCH_CHECK(false, "not a float");
78   }
79 
guard_bool(const char * file,int64_t line)80   bool guard_bool(const char* file, int64_t line) override {
81     TORCH_CHECK(false, "not a bool");
82   }
83 
int_()84   int64_t int_() override {
85     TORCH_CHECK(false);
86   }
87 
str()88   std::string str() override {
89     if (coeff_ == 1) {
90       return "j" + std::to_string(val_);
91     }
92     return std::to_string(coeff_) + "*j" + std::to_string(val_);
93   }
94 
95   // NOTE [ Inequalities with nested int ]
96   //
97   // The semantics of nested int when it comes to relations is that it is
98   // treated as integer known to be within a certain range,
99   //
100   //     j0 \in [2, int64_t::max]
101   //
102   // allowing us to answer queries like j0 >= 1 (True), and j0 == 0 (False).
103   // This is a useful default range for the raggedness pattern of a jagged
104   // tensor (1) since sizes are non-negative, and (2) we need to get past 0/1
105   // specialization checks.
106   //
107   // [ Indeterminate inequalities error out ]
108   //
109   // Given the semantic defined above, certain relations like j0 < 3 are thus
110   // indeterminable. In our impl today, evaluating such relations error
111   //
112   // It may seem convenient to just define indeterminate relations to return
113   // False, but the implementation we maintain in parallel using sympy does not
114   // allow this.
115   //
116   // Sympy only allows overriding of Ge. The other relations (Lt, Gt, Le) are,
117   // by consequence, all derived from Ge e.g., Lt(a, b) := !Ge(a, b). This
118   // would mean that means that if we define the indeterminate j0 >= 3 to be
119   // False, the also indeterminate j0 < 3 will be evaluated to be True!
120   //
121   // [ Coefficient are assumed positive ]
122   //
123   // For the purpose of computing inequalities, we consider the coefficient of
124   // the nested int to be a positive integer.
125   //
126   // Thus, no modifications are needed to the logic since
127   // j0 >= k implies coeff * j0 >= k
128   //
129   c10::SymNode eq(const c10::SymNode& other) override;
130   c10::SymNode ne(const c10::SymNode& other) override;
131   c10::SymNode ge(const c10::SymNode& other) override;
132   c10::SymNode gt(const c10::SymNode& other) override;
133   c10::SymNode lt(const c10::SymNode& other) override;
134   c10::SymNode le(const c10::SymNode& other) override;
135   c10::SymNode mul(const c10::SymNode& other) override;
136 
nested_int()137   std::optional<int64_t> nested_int() override {
138     return val_;
139   }
140 
nested_int_coeff()141   std::optional<int64_t> nested_int_coeff() override {
142     return coeff_;
143   }
144 
is_symbolic()145   bool is_symbolic() override {
146     return false;
147   }
148 
149   c10::SymNode clone() override;
150 
151 #define DEFINE_BINARY_NOT_SUPPORTED(name)                           \
152   c10::SymNode name(const c10::SymNode& other) override {           \
153     TORCH_CHECK(false, #name " not supported by NestedIntSymNode"); \
154   }
155 
156   DEFINE_BINARY_NOT_SUPPORTED(add)
157   DEFINE_BINARY_NOT_SUPPORTED(sub)
158   DEFINE_BINARY_NOT_SUPPORTED(truediv)
159   DEFINE_BINARY_NOT_SUPPORTED(pow)
160   DEFINE_BINARY_NOT_SUPPORTED(floordiv)
161   DEFINE_BINARY_NOT_SUPPORTED(mod)
162   DEFINE_BINARY_NOT_SUPPORTED(sym_min)
163   DEFINE_BINARY_NOT_SUPPORTED(sym_max)
164   DEFINE_BINARY_NOT_SUPPORTED(sym_and)
165   DEFINE_BINARY_NOT_SUPPORTED(sym_or)
166 
167 #undef DEFINE_BINARY_NOT_SUPPORTED
168 
169 #define DEFINE_NOT_SUPPORTED(name)                                     \
170   c10::SymNode name() override {                                       \
171     TORCH_CHECK(false, #name " is not supported by NestedIntSymNode"); \
172   }
173 
174   DEFINE_NOT_SUPPORTED(sym_not)
175   DEFINE_NOT_SUPPORTED(ceil)
176   DEFINE_NOT_SUPPORTED(floor)
177   DEFINE_NOT_SUPPORTED(neg)
178   DEFINE_NOT_SUPPORTED(sym_float)
179 
180 #undef DEFINE_NOT_SUPPORTED
181 
182  private:
183   int64_t val_;
184   int64_t coeff_;
185 };
186 
187 } // namespace c10
188