xref: /aosp_15_r20/external/pytorch/c10/core/SymNodeImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/macros/Export.h>
4 #include <c10/util/ArrayRef.h>
5 #include <c10/util/Exception.h>
6 #include <c10/util/intrusive_ptr.h>
7 #include <cstdint>
8 #include <optional>
9 #include <ostream>
10 #include <string>
11 
12 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
13 
14 namespace c10 {
15 
16 class SymNodeImpl;
17 using SymNode = c10::intrusive_ptr<SymNodeImpl>;
18 
19 // When you add a method, you also need to edit
20 // torch/csrc/jit/python/init.cpp
21 // torch/csrc/utils/python_symnode.h
22 // c10/core/ConstantSymNodeImpl.h
23 class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
24  public:
25   ~SymNodeImpl() override = default;
26 
27   template <typename T>
dyn_cast()28   c10::intrusive_ptr<T> dyn_cast() const {
29     return c10::intrusive_ptr<T>::reclaim_copy(dynamic_cast<T*>(this));
30   }
31 
32   // these could be pure virtual when we implement LTC versions
is_int()33   virtual bool is_int() {
34     TORCH_CHECK(false, "NYI");
35   }
is_bool()36   virtual bool is_bool() {
37     TORCH_CHECK(false, "NYI");
38   }
is_float()39   virtual bool is_float() {
40     TORCH_CHECK(false, "NYI");
41   }
is_nested_int()42   virtual bool is_nested_int() const {
43     return false;
44   }
add(const SymNode & other)45   virtual SymNode add(const SymNode& other) {
46     TORCH_CHECK(false, "NYI");
47   }
sub(const SymNode & other)48   virtual SymNode sub(const SymNode& other) {
49     TORCH_CHECK(false, "NYI");
50   }
mul(const SymNode & other)51   virtual SymNode mul(const SymNode& other) {
52     TORCH_CHECK(false, "NYI");
53   }
54   // NB: legacy, prefer float_truediv or int_truediv
truediv(const SymNode & other)55   virtual SymNode truediv(const SymNode& other) {
56     TORCH_CHECK(false, "NYI");
57   }
float_truediv(const SymNode & other)58   virtual SymNode float_truediv(const SymNode& other) {
59     return truediv(other);
60   }
int_truediv(const SymNode & other)61   virtual SymNode int_truediv(const SymNode& other) {
62     return truediv(other);
63   }
64   // NB: legacy, prefer float_pow or pow_by_natural
pow(const SymNode & other)65   virtual SymNode pow(const SymNode& other) {
66     TORCH_CHECK(false, "NYI");
67   }
float_pow(const SymNode & other)68   virtual SymNode float_pow(const SymNode& other) {
69     return pow(other);
70   }
pow_by_natural(const SymNode & other)71   virtual SymNode pow_by_natural(const SymNode& other) {
72     return pow(other);
73   }
74   // NB: legacy, prefer int_floordiv
floordiv(const SymNode & other)75   virtual SymNode floordiv(const SymNode& other) {
76     TORCH_CHECK(false, "NYI");
77   }
int_floordiv(const SymNode & other)78   virtual SymNode int_floordiv(const SymNode& other) {
79     return floordiv(other);
80   }
mod(const SymNode & other)81   virtual SymNode mod(const SymNode& other) {
82     TORCH_CHECK(false, "NYI");
83   }
eq(const SymNode & other)84   virtual SymNode eq(const SymNode& other) {
85     TORCH_CHECK(false, "NYI");
86   }
ne(const SymNode & other)87   virtual SymNode ne(const SymNode& other) {
88     TORCH_CHECK(false, "NYI");
89   }
gt(const SymNode & other)90   virtual SymNode gt(const SymNode& other) {
91     TORCH_CHECK(false, "NYI");
92   }
lt(const SymNode & other)93   virtual SymNode lt(const SymNode& other) {
94     TORCH_CHECK(false, "NYI");
95   }
le(const SymNode & other)96   virtual SymNode le(const SymNode& other) {
97     TORCH_CHECK(false, "NYI");
98   }
ge(const SymNode & other)99   virtual SymNode ge(const SymNode& other) {
100     TORCH_CHECK(false, "NYI");
101   }
ceil()102   virtual SymNode ceil() {
103     TORCH_CHECK(false, "NYI");
104   }
floor()105   virtual SymNode floor() {
106     TORCH_CHECK(false, "NYI");
107   }
neg()108   virtual SymNode neg() {
109     TORCH_CHECK(false, "NYI");
110   };
sym_min(const SymNode & other)111   virtual SymNode sym_min(const SymNode& other) {
112     TORCH_CHECK(false, "NYI");
113   };
sym_max(const SymNode & other)114   virtual SymNode sym_max(const SymNode& other) {
115     TORCH_CHECK(false, "NYI");
116   };
sym_or(const SymNode & other)117   virtual SymNode sym_or(const SymNode& other) {
118     TORCH_CHECK(false, "NYI");
119   };
sym_and(const SymNode & other)120   virtual SymNode sym_and(const SymNode& other) {
121     TORCH_CHECK(false, "NYI");
122   };
sym_not()123   virtual SymNode sym_not() {
124     TORCH_CHECK(false, "NYI");
125   };
sym_ite(const SymNode & then_val,const SymNode & else_val)126   virtual SymNode sym_ite(const SymNode& then_val, const SymNode& else_val) {
127     TORCH_CHECK(false, "NYI");
128   };
129   // NB: self is ignored here, only the arguments are used
is_contiguous(ArrayRef<SymNode> sizes,ArrayRef<SymNode> strides)130   virtual SymNode is_contiguous(
131       ArrayRef<SymNode> sizes,
132       ArrayRef<SymNode> strides) {
133     TORCH_CHECK(false, "NYI");
134   };
is_channels_last_contiguous_2d(ArrayRef<SymNode> sizes,ArrayRef<SymNode> strides)135   virtual SymNode is_channels_last_contiguous_2d(
136       ArrayRef<SymNode> sizes,
137       ArrayRef<SymNode> strides) {
138     TORCH_CHECK(false, "NYI");
139   };
is_channels_last_contiguous_3d(ArrayRef<SymNode> sizes,ArrayRef<SymNode> strides)140   virtual SymNode is_channels_last_contiguous_3d(
141       ArrayRef<SymNode> sizes,
142       ArrayRef<SymNode> strides) {
143     TORCH_CHECK(false, "NYI");
144   };
is_channels_last_strides_2d(ArrayRef<SymNode> sizes,ArrayRef<SymNode> strides)145   virtual SymNode is_channels_last_strides_2d(
146       ArrayRef<SymNode> sizes,
147       ArrayRef<SymNode> strides) {
148     TORCH_CHECK(false, "NYI");
149   };
is_channels_last_strides_3d(ArrayRef<SymNode> sizes,ArrayRef<SymNode> strides)150   virtual SymNode is_channels_last_strides_3d(
151       ArrayRef<SymNode> sizes,
152       ArrayRef<SymNode> strides) {
153     TORCH_CHECK(false, "NYI");
154   };
is_non_overlapping_and_dense(ArrayRef<SymNode> sizes,ArrayRef<SymNode> strides)155   virtual SymNode is_non_overlapping_and_dense(
156       ArrayRef<SymNode> sizes,
157       ArrayRef<SymNode> strides) {
158     TORCH_CHECK(false, "NYI");
159   };
clone()160   virtual SymNode clone() {
161     TORCH_CHECK(false, "NYI");
162   };
sym_float()163   virtual SymNode sym_float() {
164     TORCH_CHECK(false, "NYI");
165   }
wrap_int(int64_t num)166   virtual SymNode wrap_int(int64_t num) {
167     TORCH_CHECK(false, "NYI");
168   };
wrap_float(double num)169   virtual SymNode wrap_float(double num) {
170     TORCH_CHECK(false, "NYI");
171   };
wrap_bool(bool num)172   virtual SymNode wrap_bool(bool num) {
173     TORCH_CHECK(false, "NYI");
174   };
guard_int(const char * file,int64_t line)175   virtual int64_t guard_int(const char* file, int64_t line) {
176     TORCH_CHECK(false, "NYI");
177   };
guard_bool(const char * file,int64_t line)178   virtual bool guard_bool(const char* file, int64_t line) {
179     TORCH_CHECK(false, "NYI");
180   };
guard_float(const char * file,int64_t line)181   virtual double guard_float(const char* file, int64_t line) {
182     TORCH_CHECK(false, "NYI");
183   };
guard_size_oblivious(const char * file,int64_t line)184   virtual bool guard_size_oblivious(const char* file, int64_t line) {
185     // No improvement for unbacked SymBools by default, replace this
186     // with a better implementation!
187     return guard_bool(file, line);
188   }
expect_true(const char * file,int64_t line)189   virtual bool expect_true(const char* file, int64_t line) {
190     // No improvement for unbacked SymBools by default, replace this
191     // with a better implementation!
192     return guard_bool(file, line);
193   };
expect_size(const char * file,int64_t line)194   virtual bool expect_size(const char* file, int64_t line) {
195     // No improvement for unbacked SymInts by default, replace this
196     // with a better implementation!
197     return ge(wrap_int(0))->guard_bool(file, line);
198   };
int_()199   virtual int64_t int_() {
200     TORCH_CHECK(false, "NYI");
201   };
bool_()202   virtual bool bool_() {
203     TORCH_CHECK(false, "NYI");
204   };
has_hint()205   virtual bool has_hint() {
206     TORCH_CHECK(false, "NYI");
207   };
str()208   virtual std::string str() {
209     TORCH_CHECK(false, "NYI");
210   };
_graph_repr()211   virtual std::string _graph_repr() {
212     return str();
213   };
nested_int()214   virtual std::optional<int64_t> nested_int() {
215     return std::nullopt;
216   }
nested_int_coeff()217   virtual std::optional<int64_t> nested_int_coeff() {
218     return std::nullopt;
219   }
constant_int()220   virtual std::optional<int64_t> constant_int() {
221     return std::nullopt;
222   }
constant_bool()223   virtual std::optional<bool> constant_bool() {
224     return std::nullopt;
225   }
maybe_as_int()226   virtual std::optional<int64_t> maybe_as_int() {
227     return std::nullopt;
228   }
is_constant()229   virtual bool is_constant() {
230     return false;
231   }
is_symbolic()232   virtual bool is_symbolic() {
233     return true;
234   }
235   std::ostream& operator<<(std::ostream& os) {
236     os << str();
237     return os;
238   }
239 };
240 
241 } // namespace c10
242 C10_DIAGNOSTIC_POP()
243