1 #pragma once
2
#include <cmath>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

#include <c10/macros/Macros.h>
#include <c10/util/Logging.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/tensorexpr/types.h>
#include <torch/csrc/jit/tensorexpr/var_substitutor.h>
17
18 namespace torch::jit::tensorexpr {
19
20 class InterpValue {
21 public:
InterpValue()22 InterpValue() : dtype_(kInt) {
23 Intvalues.push_back(0);
24 }
25
26 template <typename T>
InterpValue(Dtype dtype,T v)27 InterpValue(Dtype dtype, T v) : dtype_(dtype) {
28 #define TYPE_CASE(Type, Name) \
29 if (dtype == k##Name) { \
30 Name##values.push_back(v); \
31 return; \
32 }
33 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
34 #undef TYPE_CASE
35 throw unsupported_dtype();
36 }
37
38 #define VALUE_CTOR(Type, Name) \
39 InterpValue(Type v) : dtype_(k##Name) { \
40 Name##values.push_back(v); \
41 }
42 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_CTOR);
43 #undef VALUE_CTOR
44
InterpValue(c10::quint8 v)45 explicit InterpValue(c10::quint8 v) : dtype_(kQUInt8) {
46 QUInt8values.emplace_back(v.val_);
47 }
48
InterpValue(c10::qint8 v)49 explicit InterpValue(c10::qint8 v) : dtype_(kQInt8) {
50 QInt8values.emplace_back(v.val_);
51 }
52
53 #define VALUE_VEC_CTOR(Type, Name) \
54 InterpValue(const std::vector<Type>& v) \
55 : dtype_(Dtype(k##Name, v.size())), Name##values(v) {}
56 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_VEC_CTOR);
57 VALUE_VEC_CTOR(c10::quint8, QUInt8)
58 VALUE_VEC_CTOR(c10::qint8, QInt8)
59 #undef VALUE_VEC_CTOR
60
61 template <typename T>
62 T as() const;
63
64 template <typename T>
65 const std::vector<T>& as_vec() const;
66
67 int64_t intValue() const;
68
dtype()69 Dtype dtype() const {
70 return dtype_;
71 }
72
73 private:
74 Dtype dtype_;
75
76 #define VALUE_STORAGE(Type, Name) std::vector<Type> Name##values;
77 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_STORAGE);
78 VALUE_STORAGE(c10::qint8, QInt8);
79 VALUE_STORAGE(c10::quint8, QUInt8);
80 #undef VALUE_STORAGE
81 void* ptr{nullptr};
82 };
83
84 #define VALUE_AS_DISPATCH(Type, Name) \
85 template <> \
86 inline Type InterpValue::as<Type>() const { \
87 if (dtype_ != k##Name) { \
88 throw unsupported_dtype(); \
89 } \
90 return Name##values[0]; \
91 }
92 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_DISPATCH);
93 VALUE_AS_DISPATCH(c10::quint8, QUInt8);
94 VALUE_AS_DISPATCH(c10::qint8, QInt8);
95 #undef VALUE_AS_DISPATCH
96
97 #define VALUE_AS_VEC_DISPATCH(Type, Name) \
98 template <> \
99 inline const std::vector<Type>& InterpValue::as_vec<Type>() const { \
100 if (dtype_.scalar_type() != ScalarType::Name) { \
101 throw unsupported_dtype(); \
102 } \
103 return Name##values; \
104 }
105 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_VEC_DISPATCH);
106 VALUE_AS_VEC_DISPATCH(c10::quint8, QUInt8);
107 VALUE_AS_VEC_DISPATCH(c10::qint8, QInt8);
108 #undef VALUE_AS_VEC_DISPATCH
109
// Generic case: the argument is already its own underlying value, so it is
// handed back unchanged (by value).
template <typename T>
auto underlyingValue(T x) {
  return x;
}
114
115 template <>
116 inline auto underlyingValue<c10::quint8>(c10::quint8 x) {
117 return x.val_;
118 }
119
120 template <>
121 inline auto underlyingValue<c10::qint8>(c10::qint8 x) {
122 return x.val_;
123 }
124
125 template <typename To, typename From>
raw_bitcast(const From & src)126 To raw_bitcast(const From& src) {
127 TORCH_CHECK(sizeof(To) == sizeof(From), "Invalid bitcast invocation");
128 To storage;
129 std::memcpy(&storage, &src, sizeof(To));
130 return reinterpret_cast<To&>(storage);
131 }
132
133 class SimpleIREvaluatorImpl;
134 class TORCH_API SimpleIREvaluator : public CodeGen {
135 public:
136 SimpleIREvaluator(
137 StmtPtr stmt,
138 const std::vector<BufferArg>& buffer_args,
139 at::Device device = at::kCPU,
140 const std::string& kernel_func_name = "func");
141
142 ~SimpleIREvaluator() override;
143
144 void call(const std::vector<CallArg>& args) override;
145 void call_raw(const std::vector<void*>& args) override;
146
147 template <typename... Ts>
operator()148 void operator()(const Ts&... ts) {
149 std::vector<CallArg> args({CallArg(ts)...});
150 call(args);
151 }
152
153 void bindVar(const VarPtr& v, const ExprPtr& e);
154 InterpValue value() const;
155
156 private:
157 void bindArg(const BufferArg& buf, void* data);
expand_intrinsics()158 void expand_intrinsics() {
159 GenericIntrinsicsExpander intrinsics_expander;
160 apply_mutator(&intrinsics_expander);
161 }
162
163 std::unique_ptr<SimpleIREvaluatorImpl> impl_;
164 };
165
166 template <class CodeGenType>
167 class ExprEval {
168 public:
169 using BufferArg = CodeGen::BufferArg;
170 using CallArg = CodeGen::CallArg;
171
172 template <typename... Ts>
ExprEval(const ExprHandle & expr,Ts...ts)173 ExprEval(const ExprHandle& expr, Ts... ts)
174 : ExprEval(expr, {BufferArg(ts)...}) {}
175
ExprEval(const ExprHandle & expr,const std::vector<BufferArg> & buffer_args)176 ExprEval(const ExprHandle& expr, const std::vector<BufferArg>& buffer_args)
177 : dtype_(expr.dtype()) {
178 std::vector<BufferArg> buffer_args_extended = buffer_args;
179 BufHandle ret_buf("ret_val", {1}, dtype_);
180 std::vector<ExprHandle> indices;
181 ExprHandle zero = IntImm::make(0);
182 for (size_t i = 0; i < ret_buf.ndim(); i++) {
183 indices.push_back(zero);
184 }
185 StmtPtr store_stmt = Store::make(ret_buf, indices, expr);
186 buffer_args_extended.emplace_back(ret_buf);
187 codegen_.reset(new CodeGenType(store_stmt, buffer_args_extended));
188 }
189
190 template <typename... Ts>
operator()191 void operator()(Ts... ts) {
192 call(ts...);
193 }
194
operator()195 void operator()(const std::vector<CallArg>& call_args) {
196 call(call_args);
197 }
198
bindVar(VarPtr v,ExprPtr e)199 void bindVar(VarPtr v, ExprPtr e) {
200 codegen_->bindVar(v, e);
201 }
202
bindVar(const VarHandle & v,const ExprHandle & e)203 void bindVar(const VarHandle& v, const ExprHandle& e) {
204 codegen_->bindVar(v.node(), e.node());
205 }
206
207 template <typename... Ts>
call(Ts...ts)208 void call(Ts... ts) {
209 call({CallArg(ts)...});
210 }
211
call(const std::vector<CallArg> & call_args)212 void call(const std::vector<CallArg>& call_args) {
213 std::vector<CallArg> call_args_extended = call_args;
214 switch (dtype_.scalar_type()) {
215 #define TYPE_CASE(Type, Name) \
216 case ScalarType::Name: { \
217 std::vector<Type> ret_val_arg(1); \
218 call_args_extended.emplace_back(ret_val_arg); \
219 codegen_->call(call_args_extended); \
220 ret_value_ = InterpValue(ret_val_arg[0]); \
221 } break;
222 AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
223 TYPE_CASE(c10::quint8, QUInt8);
224 TYPE_CASE(c10::qint8, QInt8);
225 #undef TYPE_CASE
226 case ScalarType::Bool: {
227 std::vector<unsigned char> ret_val_arg(1);
228 call_args_extended.emplace_back(ret_val_arg.data());
229 codegen_->call(call_args_extended);
230 ret_value_ = InterpValue((bool)ret_val_arg[0]);
231 } break;
232 default:
233 throw unsupported_dtype();
234 }
235 }
236
call_raw(const std::vector<void * > & args)237 void call_raw(const std::vector<void*>& args) {
238 std::vector<void*> args_extended = args;
239 switch (dtype_.scalar_type()) {
240 #define TYPE_CASE(Type, Name) \
241 case ScalarType::Name: { \
242 std::vector<Type> ret_val_arg(1); \
243 args_extended.push_back(ret_val_arg.data()); \
244 codegen_->call_raw(args_extended); \
245 ret_value_ = InterpValue(ret_val_arg[0]); \
246 } break;
247 AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
248 TYPE_CASE(c10::quint8, QUInt8);
249 TYPE_CASE(c10::qint8, QInt8);
250 #undef TYPE_CASE
251 case ScalarType::Bool: {
252 std::vector<unsigned char> ret_val_arg(1);
253 args_extended.push_back(ret_val_arg.data());
254 codegen_->call_raw(args_extended);
255 ret_value_ = InterpValue((bool)ret_val_arg[0]);
256 } break;
257 default:
258 throw unsupported_dtype();
259 }
260 }
261
262 template <typename T>
value(const std::vector<void * > & args)263 T value(const std::vector<void*>& args) {
264 call_raw(args);
265 return ret_value_.as<T>();
266 }
267
268 template <typename T, typename... Ts>
value(Ts...ts)269 T value(Ts... ts) {
270 call(std::forward<Ts>(ts)...);
271 return ret_value_.as<T>();
272 }
273
dtype()274 Dtype dtype() {
275 return dtype_;
276 }
277
278 private:
279 Dtype dtype_;
280 std::unique_ptr<CodeGenType> codegen_;
281 InterpValue ret_value_;
282 };
283
284 // Evaluates the given expression and returns an int64_t value if the result of
285 // the given expression is int64_t.
286 std::optional<int64_t> evalInt(ExprPtr e);
287
288 // Substitutes the given vars with their corresponding expressions in the input
289 // expression.
Substitute(const ExprPtr & expr,const VarMapping & var_mapping)290 inline ExprPtr Substitute(const ExprPtr& expr, const VarMapping& var_mapping) {
291 VarSubMutator var_sub(var_mapping);
292 return expr->accept_mutator(&var_sub);
293 }
294
295 // Substitutes the given vars with their corresponding expressions in the input
296 // statement.
Substitute(const StmtPtr & stmt,const VarMapping & var_mapping)297 inline StmtPtr Substitute(const StmtPtr& stmt, const VarMapping& var_mapping) {
298 VarSubMutator var_sub(var_mapping);
299 return stmt->accept_mutator(&var_sub);
300 }
301
302 // Creates a clone of the input expression and substitutes the given vars with
303 // their corresponding expressions in the clone.
304 // NOTE: This works because cloning reuses variables and does not create new
305 // ones, and `VarMapping` input has variables as the key.
SubstituteInClone(const ExprPtr & expr,const VarMapping & var_mapping)306 inline ExprPtr SubstituteInClone(
307 const ExprPtr& expr,
308 const VarMapping& var_mapping) {
309 VarSubMutator var_sub(var_mapping);
310 return Expr::clone(expr)->accept_mutator(&var_sub);
311 }
312
313 // Creates a clone of the input statement and substitutes the given vars with
314 // their corresponding expressions in the clone.
315 // NOTE: This works because cloning reuses variables and does not create new
316 // ones, and `VarMapping` input has variables as the key.
SubstituteInClone(const StmtPtr & stmt,const VarMapping & var_mapping)317 inline StmtPtr SubstituteInClone(
318 const StmtPtr& stmt,
319 const VarMapping& var_mapping) {
320 VarSubMutator var_sub(var_mapping);
321 return Stmt::clone(stmt)->accept_mutator(&var_sub);
322 }
323
324 } // namespace torch::jit::tensorexpr
325