1 #pragma once 2 3 #include <c10/core/SymNodeImpl.h> 4 #include <c10/macros/Export.h> 5 #include <c10/util/Exception.h> 6 #include <cstdint> 7 #include <optional> 8 #include <string> 9 #include <variant> 10 11 namespace c10 { 12 13 // Unlike other SymNodeImpl, this cannot be "dispatched" conventionally, 14 // as it typically needs to defer to another SymNodeImpl 15 // 16 // Can either represent a bool, int (don't support float yet) this is useful 17 // for representing otherwise unrepresentable large negative integer constant. 18 template <typename T> 19 class C10_API ConstantSymNodeImpl : public SymNodeImpl { 20 static_assert( 21 ::std::is_same_v<T, int64_t> || ::std::is_same_v<T, bool>, 22 "ConstantSymNodeImpl can only accept int64_t or bool types"); 23 24 public: ConstantSymNodeImpl(T val)25 ConstantSymNodeImpl(T val) : value_(val) {} 26 is_int()27 bool is_int() override { 28 return is_int_(); 29 } is_bool()30 bool is_bool() override { 31 return is_bool_(); 32 } is_float()33 bool is_float() override { 34 return false; 35 } guard_int(const char * file,int64_t line)36 int64_t guard_int( 37 const char* file [[maybe_unused]], 38 int64_t line [[maybe_unused]]) override { 39 TORCH_CHECK(is_int(), "not an int"); 40 return int_(); 41 } guard_bool(const char * file,int64_t line)42 bool guard_bool( 43 const char* file [[maybe_unused]], 44 int64_t line [[maybe_unused]]) override { 45 TORCH_CHECK(is_bool(), "not a bool"); 46 return bool_(); 47 } guard_float(const char * file,int64_t line)48 double guard_float( 49 const char* file [[maybe_unused]], 50 int64_t line [[maybe_unused]]) override { 51 TORCH_CHECK(false, "not a float"); 52 } int_()53 int64_t int_() override { 54 TORCH_CHECK(is_int(), "not an int"); 55 return ::std::get<int64_t>(value_); 56 } bool_()57 bool bool_() override { 58 TORCH_CHECK(is_bool(), "not a bool"); 59 return ::std::get<bool>(value_); 60 } has_hint()61 bool has_hint() override { 62 return true; 63 } 64 c10::SymNode eq(const c10::SymNode& other) override; 65 c10::SymNode ne(const c10::SymNode& other) override; 66 c10::SymNode ge(const c10::SymNode& other) override; 67 c10::SymNode le(const c10::SymNode& other) override; 68 c10::SymNode lt(const c10::SymNode& other) override; 69 c10::SymNode gt(const c10::SymNode& other) override; 70 c10::SymNode mul(const c10::SymNode& other) override; str()71 ::std::string str() override { 72 if constexpr (is_int_()) { 73 return ::std::to_string(::std::get<int64_t>(value_)); 74 } else { 75 return ::std::get<bool>(value_) ? "true" : "false"; 76 } 77 } constant_int()78 std::optional<int64_t> constant_int() override { 79 if constexpr (is_int_()) { 80 return ::std::get<int64_t>(value_); 81 } else { 82 return std::nullopt; 83 } 84 } constant_bool()85 std::optional<bool> constant_bool() override { 86 if constexpr (is_bool_()) { 87 return ::std::get<bool>(value_); 88 } else { 89 return std::nullopt; 90 } 91 } is_constant()92 bool is_constant() override { 93 return true; 94 } is_symbolic()95 bool is_symbolic() override { 96 return false; 97 } 98 99 private: 100 ::std::variant<int64_t, bool> value_; 101 is_int_()102 static constexpr bool is_int_() { 103 return ::std::is_same_v<T, int64_t>; 104 } is_bool_()105 static constexpr bool is_bool_() { 106 return ::std::is_same_v<T, bool>; 107 } 108 }; 109 110 } // namespace c10 111