xref: /aosp_15_r20/external/pytorch/c10/core/ConstantSymNodeImpl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/SymNodeImpl.h>
4 #include <c10/macros/Export.h>
5 #include <c10/util/Exception.h>
6 #include <cstdint>
7 #include <optional>
8 #include <string>
9 #include <variant>
10 
11 namespace c10 {
12 
13 // Unlike other SymNodeImpl, this cannot be "dispatched" conventionally,
14 // as it typically needs to defer to another SymNodeImpl
15 //
16 // Can either represent a bool, int (don't support float yet) this is useful
17 // for representing otherwise unrepresentable large negative integer constant.
18 template <typename T>
19 class C10_API ConstantSymNodeImpl : public SymNodeImpl {
20   static_assert(
21       ::std::is_same_v<T, int64_t> || ::std::is_same_v<T, bool>,
22       "ConstantSymNodeImpl can only accept int64_t or bool types");
23 
24  public:
ConstantSymNodeImpl(T val)25   ConstantSymNodeImpl(T val) : value_(val) {}
26 
is_int()27   bool is_int() override {
28     return is_int_();
29   }
is_bool()30   bool is_bool() override {
31     return is_bool_();
32   }
is_float()33   bool is_float() override {
34     return false;
35   }
guard_int(const char * file,int64_t line)36   int64_t guard_int(
37       const char* file [[maybe_unused]],
38       int64_t line [[maybe_unused]]) override {
39     TORCH_CHECK(is_int(), "not an int");
40     return int_();
41   }
guard_bool(const char * file,int64_t line)42   bool guard_bool(
43       const char* file [[maybe_unused]],
44       int64_t line [[maybe_unused]]) override {
45     TORCH_CHECK(is_bool(), "not a bool");
46     return bool_();
47   }
guard_float(const char * file,int64_t line)48   double guard_float(
49       const char* file [[maybe_unused]],
50       int64_t line [[maybe_unused]]) override {
51     TORCH_CHECK(false, "not a float");
52   }
int_()53   int64_t int_() override {
54     TORCH_CHECK(is_int(), "not an int");
55     return ::std::get<int64_t>(value_);
56   }
bool_()57   bool bool_() override {
58     TORCH_CHECK(is_bool(), "not a bool");
59     return ::std::get<bool>(value_);
60   }
has_hint()61   bool has_hint() override {
62     return true;
63   }
64   c10::SymNode eq(const c10::SymNode& other) override;
65   c10::SymNode ne(const c10::SymNode& other) override;
66   c10::SymNode ge(const c10::SymNode& other) override;
67   c10::SymNode le(const c10::SymNode& other) override;
68   c10::SymNode lt(const c10::SymNode& other) override;
69   c10::SymNode gt(const c10::SymNode& other) override;
70   c10::SymNode mul(const c10::SymNode& other) override;
str()71   ::std::string str() override {
72     if constexpr (is_int_()) {
73       return ::std::to_string(::std::get<int64_t>(value_));
74     } else {
75       return ::std::get<bool>(value_) ? "true" : "false";
76     }
77   }
constant_int()78   std::optional<int64_t> constant_int() override {
79     if constexpr (is_int_()) {
80       return ::std::get<int64_t>(value_);
81     } else {
82       return std::nullopt;
83     }
84   }
constant_bool()85   std::optional<bool> constant_bool() override {
86     if constexpr (is_bool_()) {
87       return ::std::get<bool>(value_);
88     } else {
89       return std::nullopt;
90     }
91   }
is_constant()92   bool is_constant() override {
93     return true;
94   }
is_symbolic()95   bool is_symbolic() override {
96     return false;
97   }
98 
99  private:
100   ::std::variant<int64_t, bool> value_;
101 
is_int_()102   static constexpr bool is_int_() {
103     return ::std::is_same_v<T, int64_t>;
104   }
is_bool_()105   static constexpr bool is_bool_() {
106     return ::std::is_same_v<T, bool>;
107   }
108 };
109 
110 } // namespace c10
111