1 #pragma once 2 3 #include <c10/core/ConstantSymNodeImpl.h> 4 #include <c10/core/SymNodeImpl.h> 5 #include <c10/macros/Export.h> 6 #include <c10/util/Exception.h> 7 #include <c10/util/intrusive_ptr.h> 8 #include <cstdint> 9 #include <optional> 10 #include <string> 11 12 namespace c10 { 13 14 // The motivating usecase for this is to represent the ragged size structure 15 // of a jagged tensor [B, [s_0, s_1, s_2], D] as a single integer j0. This 16 // allows us to simply return [B, j0, D] if someone queries for the size of our 17 // tensor. 18 // 19 // Morally we define comparison between two nested ints to return true if 20 // that comparison holds for all corresponding elements of the arrays they 21 // represent. Comparison between a nested int and a plain int is defined 22 // similarly. 23 // 24 // To simulate this desired behavior but also avoid the O(N) cost of checking, 25 // we associate each raggedness pattern with an integer "id" that can be used as 26 // a proxy to evaluate equality. We also constrain the range of values for this 27 // as to enable inequality checks. 28 // 29 // We also support a positive integer scalar "coeff" that is used for computing 30 // strides. For example given, a [B, j0, D] tensor, it can be strided in two 31 // different ways: [D * j0, D, 1] and [j0, 1, sum(j0)]. The coeff is used to 32 // differentiate the two cases. 33 // 34 // During tracing the strides of the outputs need to be a function of the size 35 // and strides of the inputs so it is important that NestedIntSymNode itself is 36 // able to express this. 37 class TORCH_API NestedIntSymNodeImpl : public SymNodeImpl { 38 public: 39 // CAUTION: you should probably not be constructing these directly; please 40 // the higher-level API in python instead (TODO: actually introduce that). NestedIntSymNodeImpl(int64_t val,int64_t coeff)41 explicit NestedIntSymNodeImpl(int64_t val, int64_t coeff) 42 : val_(val), coeff_(coeff) {} 43 bool_()44 bool bool_() override { 45 return false; 46 } 47 is_int()48 bool is_int() override { 49 return true; 50 } 51 is_float()52 bool is_float() override { 53 return false; 54 } 55 is_bool()56 bool is_bool() override { 57 return false; 58 } 59 is_nested_int()60 bool is_nested_int() const override { 61 return true; 62 } 63 has_hint()64 bool has_hint() override { 65 return true; 66 } 67 wrap_int(int64_t num)68 c10::SymNode wrap_int(int64_t num) override { 69 return SymNode(c10::make_intrusive<ConstantSymNodeImpl<int64_t>>(num)); 70 }; 71 guard_int(const char * file,int64_t line)72 int64_t guard_int(const char* file, int64_t line) override { 73 TORCH_CHECK(false); 74 } 75 guard_float(const char * file,int64_t line)76 double guard_float(const char* file, int64_t line) override { 77 TORCH_CHECK(false, "not a float"); 78 } 79 guard_bool(const char * file,int64_t line)80 bool guard_bool(const char* file, int64_t line) override { 81 TORCH_CHECK(false, "not a bool"); 82 } 83 int_()84 int64_t int_() override { 85 TORCH_CHECK(false); 86 } 87 str()88 std::string str() override { 89 if (coeff_ == 1) { 90 return "j" + std::to_string(val_); 91 } 92 return std::to_string(coeff_) + "*j" + std::to_string(val_); 93 } 94 95 // NOTE [ Inequalities with nested int ] 96 // 97 // The semantics of nested int when it comes to relations is that it is 98 // treated as integer known to be within a certain range, 99 // 100 // j0 \in [2, int64_t::max] 101 // 102 // allowing us to answer queries like j0 >= 1 (True), and j0 == 0 (False). 103 // This is a useful default range for the raggedness pattern of a jagged 104 // tensor (1) since sizes are non-negative, and (2) we need to get past 0/1 105 // specialization checks. 106 // 107 // [ Indeterminate inequalities error out ] 108 // 109 // Given the semantic defined above, certain relations like j0 < 3 are thus 110 // indeterminable. In our impl today, evaluating such relations error 111 // 112 // It may seem convenient to just define indeterminate relations to return 113 // False, but the implementation we maintain in parallel using sympy does not 114 // allow this. 115 // 116 // Sympy only allows overriding of Ge. The other relations (Lt, Gt, Le) are, 117 // by consequence, all derived from Ge e.g., Lt(a, b) := !Ge(a, b). This 118 // would mean that means that if we define the indeterminate j0 >= 3 to be 119 // False, the also indeterminate j0 < 3 will be evaluated to be True! 120 // 121 // [ Coefficient are assumed positive ] 122 // 123 // For the purpose of computing inequalities, we consider the coefficient of 124 // the nested int to be a positive integer. 125 // 126 // Thus, no modifications are needed to the logic since 127 // j0 >= k implies coeff * j0 >= k 128 // 129 c10::SymNode eq(const c10::SymNode& other) override; 130 c10::SymNode ne(const c10::SymNode& other) override; 131 c10::SymNode ge(const c10::SymNode& other) override; 132 c10::SymNode gt(const c10::SymNode& other) override; 133 c10::SymNode lt(const c10::SymNode& other) override; 134 c10::SymNode le(const c10::SymNode& other) override; 135 c10::SymNode mul(const c10::SymNode& other) override; 136 nested_int()137 std::optional<int64_t> nested_int() override { 138 return val_; 139 } 140 nested_int_coeff()141 std::optional<int64_t> nested_int_coeff() override { 142 return coeff_; 143 } 144 is_symbolic()145 bool is_symbolic() override { 146 return false; 147 } 148 149 c10::SymNode clone() override; 150 151 #define DEFINE_BINARY_NOT_SUPPORTED(name) \ 152 c10::SymNode name(const c10::SymNode& other) override { \ 153 TORCH_CHECK(false, #name " not supported by NestedIntSymNode"); \ 154 } 155 156 DEFINE_BINARY_NOT_SUPPORTED(add) 157 DEFINE_BINARY_NOT_SUPPORTED(sub) 158 DEFINE_BINARY_NOT_SUPPORTED(truediv) 159 DEFINE_BINARY_NOT_SUPPORTED(pow) 160 DEFINE_BINARY_NOT_SUPPORTED(floordiv) 161 DEFINE_BINARY_NOT_SUPPORTED(mod) 162 DEFINE_BINARY_NOT_SUPPORTED(sym_min) 163 DEFINE_BINARY_NOT_SUPPORTED(sym_max) 164 DEFINE_BINARY_NOT_SUPPORTED(sym_and) 165 DEFINE_BINARY_NOT_SUPPORTED(sym_or) 166 167 #undef DEFINE_BINARY_NOT_SUPPORTED 168 169 #define DEFINE_NOT_SUPPORTED(name) \ 170 c10::SymNode name() override { \ 171 TORCH_CHECK(false, #name " is not supported by NestedIntSymNode"); \ 172 } 173 174 DEFINE_NOT_SUPPORTED(sym_not) 175 DEFINE_NOT_SUPPORTED(ceil) 176 DEFINE_NOT_SUPPORTED(floor) 177 DEFINE_NOT_SUPPORTED(neg) 178 DEFINE_NOT_SUPPORTED(sym_float) 179 180 #undef DEFINE_NOT_SUPPORTED 181 182 private: 183 int64_t val_; 184 int64_t coeff_; 185 }; 186 187 } // namespace c10 188