1 #pragma once 2 3 #include <c10/macros/Export.h> 4 #include <c10/util/ArrayRef.h> 5 #include <c10/util/Exception.h> 6 #include <c10/util/intrusive_ptr.h> 7 #include <cstdint> 8 #include <optional> 9 #include <ostream> 10 #include <string> 11 12 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") 13 14 namespace c10 { 15 16 class SymNodeImpl; 17 using SymNode = c10::intrusive_ptr<SymNodeImpl>; 18 19 // When you add a method, you also need to edit 20 // torch/csrc/jit/python/init.cpp 21 // torch/csrc/utils/python_symnode.h 22 // c10/core/ConstantSymNodeImpl.h 23 class C10_API SymNodeImpl : public c10::intrusive_ptr_target { 24 public: 25 ~SymNodeImpl() override = default; 26 27 template <typename T> dyn_cast()28 c10::intrusive_ptr<T> dyn_cast() const { 29 return c10::intrusive_ptr<T>::reclaim_copy(dynamic_cast<T*>(this)); 30 } 31 32 // these could be pure virtual when we implement LTC versions is_int()33 virtual bool is_int() { 34 TORCH_CHECK(false, "NYI"); 35 } is_bool()36 virtual bool is_bool() { 37 TORCH_CHECK(false, "NYI"); 38 } is_float()39 virtual bool is_float() { 40 TORCH_CHECK(false, "NYI"); 41 } is_nested_int()42 virtual bool is_nested_int() const { 43 return false; 44 } add(const SymNode & other)45 virtual SymNode add(const SymNode& other) { 46 TORCH_CHECK(false, "NYI"); 47 } sub(const SymNode & other)48 virtual SymNode sub(const SymNode& other) { 49 TORCH_CHECK(false, "NYI"); 50 } mul(const SymNode & other)51 virtual SymNode mul(const SymNode& other) { 52 TORCH_CHECK(false, "NYI"); 53 } 54 // NB: legacy, prefer float_truediv or int_truediv truediv(const SymNode & other)55 virtual SymNode truediv(const SymNode& other) { 56 TORCH_CHECK(false, "NYI"); 57 } float_truediv(const SymNode & other)58 virtual SymNode float_truediv(const SymNode& other) { 59 return truediv(other); 60 } int_truediv(const SymNode & other)61 virtual SymNode int_truediv(const SymNode& other) { 62 return truediv(other); 63 } 64 // NB: legacy, prefer float_pow or pow_by_natural pow(const SymNode & other)65 virtual SymNode pow(const SymNode& other) { 66 TORCH_CHECK(false, "NYI"); 67 } float_pow(const SymNode & other)68 virtual SymNode float_pow(const SymNode& other) { 69 return pow(other); 70 } pow_by_natural(const SymNode & other)71 virtual SymNode pow_by_natural(const SymNode& other) { 72 return pow(other); 73 } 74 // NB: legacy, prefer int_floordiv floordiv(const SymNode & other)75 virtual SymNode floordiv(const SymNode& other) { 76 TORCH_CHECK(false, "NYI"); 77 } int_floordiv(const SymNode & other)78 virtual SymNode int_floordiv(const SymNode& other) { 79 return floordiv(other); 80 } mod(const SymNode & other)81 virtual SymNode mod(const SymNode& other) { 82 TORCH_CHECK(false, "NYI"); 83 } eq(const SymNode & other)84 virtual SymNode eq(const SymNode& other) { 85 TORCH_CHECK(false, "NYI"); 86 } ne(const SymNode & other)87 virtual SymNode ne(const SymNode& other) { 88 TORCH_CHECK(false, "NYI"); 89 } gt(const SymNode & other)90 virtual SymNode gt(const SymNode& other) { 91 TORCH_CHECK(false, "NYI"); 92 } lt(const SymNode & other)93 virtual SymNode lt(const SymNode& other) { 94 TORCH_CHECK(false, "NYI"); 95 } le(const SymNode & other)96 virtual SymNode le(const SymNode& other) { 97 TORCH_CHECK(false, "NYI"); 98 } ge(const SymNode & other)99 virtual SymNode ge(const SymNode& other) { 100 TORCH_CHECK(false, "NYI"); 101 } ceil()102 virtual SymNode ceil() { 103 TORCH_CHECK(false, "NYI"); 104 } floor()105 virtual SymNode floor() { 106 TORCH_CHECK(false, "NYI"); 107 } neg()108 virtual SymNode neg() { 109 TORCH_CHECK(false, "NYI"); 110 }; sym_min(const SymNode & other)111 virtual SymNode sym_min(const SymNode& other) { 112 TORCH_CHECK(false, "NYI"); 113 }; sym_max(const SymNode & other)114 virtual SymNode sym_max(const SymNode& other) { 115 TORCH_CHECK(false, "NYI"); 116 }; sym_or(const SymNode & other)117 virtual SymNode sym_or(const SymNode& other) { 118 TORCH_CHECK(false, "NYI"); 119 }; sym_and(const SymNode & other)120 virtual SymNode sym_and(const SymNode& other) { 121 TORCH_CHECK(false, "NYI"); 122 }; sym_not()123 virtual SymNode sym_not() { 124 TORCH_CHECK(false, "NYI"); 125 }; sym_ite(const SymNode & then_val,const SymNode & else_val)126 virtual SymNode sym_ite(const SymNode& then_val, const SymNode& else_val) { 127 TORCH_CHECK(false, "NYI"); 128 }; 129 // NB: self is ignored here, only the arguments are used is_contiguous(ArrayRef<SymNode> sizes,ArrayRef<SymNode> strides)130 virtual SymNode is_contiguous( 131 ArrayRef<SymNode> sizes, 132 ArrayRef<SymNode> strides) { 133 TORCH_CHECK(false, "NYI"); 134 }; is_channels_last_contiguous_2d(ArrayRef<SymNode> sizes,ArrayRef<SymNode> strides)135 virtual SymNode is_channels_last_contiguous_2d( 136 ArrayRef<SymNode> sizes, 137 ArrayRef<SymNode> strides) { 138 TORCH_CHECK(false, "NYI"); 139 }; is_channels_last_contiguous_3d(ArrayRef<SymNode> sizes,ArrayRef<SymNode> strides)140 virtual SymNode is_channels_last_contiguous_3d( 141 ArrayRef<SymNode> sizes, 142 ArrayRef<SymNode> strides) { 143 TORCH_CHECK(false, "NYI"); 144 }; is_channels_last_strides_2d(ArrayRef<SymNode> sizes,ArrayRef<SymNode> strides)145 virtual SymNode is_channels_last_strides_2d( 146 ArrayRef<SymNode> sizes, 147 ArrayRef<SymNode> strides) { 148 TORCH_CHECK(false, "NYI"); 149 }; is_channels_last_strides_3d(ArrayRef<SymNode> sizes,ArrayRef<SymNode> strides)150 virtual SymNode is_channels_last_strides_3d( 151 ArrayRef<SymNode> sizes, 152 ArrayRef<SymNode> strides) { 153 TORCH_CHECK(false, "NYI"); 154 }; is_non_overlapping_and_dense(ArrayRef<SymNode> sizes,ArrayRef<SymNode> strides)155 virtual SymNode is_non_overlapping_and_dense( 156 ArrayRef<SymNode> sizes, 157 ArrayRef<SymNode> strides) { 158 TORCH_CHECK(false, "NYI"); 159 }; clone()160 virtual SymNode clone() { 161 TORCH_CHECK(false, "NYI"); 162 }; sym_float()163 virtual SymNode sym_float() { 164 TORCH_CHECK(false, "NYI"); 165 } wrap_int(int64_t num)166 virtual SymNode wrap_int(int64_t num) { 167 TORCH_CHECK(false, "NYI"); 168 }; wrap_float(double num)169 virtual SymNode wrap_float(double num) { 170 TORCH_CHECK(false, "NYI"); 171 }; wrap_bool(bool num)172 virtual SymNode wrap_bool(bool num) { 173 TORCH_CHECK(false, "NYI"); 174 }; guard_int(const char * file,int64_t line)175 virtual int64_t guard_int(const char* file, int64_t line) { 176 TORCH_CHECK(false, "NYI"); 177 }; guard_bool(const char * file,int64_t line)178 virtual bool guard_bool(const char* file, int64_t line) { 179 TORCH_CHECK(false, "NYI"); 180 }; guard_float(const char * file,int64_t line)181 virtual double guard_float(const char* file, int64_t line) { 182 TORCH_CHECK(false, "NYI"); 183 }; guard_size_oblivious(const char * file,int64_t line)184 virtual bool guard_size_oblivious(const char* file, int64_t line) { 185 // No improvement for unbacked SymBools by default, replace this 186 // with a better implementation! 187 return guard_bool(file, line); 188 } expect_true(const char * file,int64_t line)189 virtual bool expect_true(const char* file, int64_t line) { 190 // No improvement for unbacked SymBools by default, replace this 191 // with a better implementation! 192 return guard_bool(file, line); 193 }; expect_size(const char * file,int64_t line)194 virtual bool expect_size(const char* file, int64_t line) { 195 // No improvement for unbacked SymInts by default, replace this 196 // with a better implementation! 197 return ge(wrap_int(0))->guard_bool(file, line); 198 }; int_()199 virtual int64_t int_() { 200 TORCH_CHECK(false, "NYI"); 201 }; bool_()202 virtual bool bool_() { 203 TORCH_CHECK(false, "NYI"); 204 }; has_hint()205 virtual bool has_hint() { 206 TORCH_CHECK(false, "NYI"); 207 }; str()208 virtual std::string str() { 209 TORCH_CHECK(false, "NYI"); 210 }; _graph_repr()211 virtual std::string _graph_repr() { 212 return str(); 213 }; nested_int()214 virtual std::optional<int64_t> nested_int() { 215 return std::nullopt; 216 } nested_int_coeff()217 virtual std::optional<int64_t> nested_int_coeff() { 218 return std::nullopt; 219 } constant_int()220 virtual std::optional<int64_t> constant_int() { 221 return std::nullopt; 222 } constant_bool()223 virtual std::optional<bool> constant_bool() { 224 return std::nullopt; 225 } maybe_as_int()226 virtual std::optional<int64_t> maybe_as_int() { 227 return std::nullopt; 228 } is_constant()229 virtual bool is_constant() { 230 return false; 231 } is_symbolic()232 virtual bool is_symbolic() { 233 return true; 234 } 235 std::ostream& operator<<(std::ostream& os) { 236 os << str(); 237 return os; 238 } 239 }; 240 241 } // namespace c10 242 C10_DIAGNOSTIC_POP() 243