xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/eval.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <cmath>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

#include <c10/macros/Macros.h>
#include <c10/util/Logging.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/jit/tensorexpr/exceptions.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/tensorexpr/types.h>
#include <torch/csrc/jit/tensorexpr/var_substitutor.h>
17 
18 namespace torch::jit::tensorexpr {
19 
20 class InterpValue {
21  public:
InterpValue()22   InterpValue() : dtype_(kInt) {
23     Intvalues.push_back(0);
24   }
25 
26   template <typename T>
InterpValue(Dtype dtype,T v)27   InterpValue(Dtype dtype, T v) : dtype_(dtype) {
28 #define TYPE_CASE(Type, Name)  \
29   if (dtype == k##Name) {      \
30     Name##values.push_back(v); \
31     return;                    \
32   }
33     AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
34 #undef TYPE_CASE
35     throw unsupported_dtype();
36   }
37 
38 #define VALUE_CTOR(Type, Name)            \
39   InterpValue(Type v) : dtype_(k##Name) { \
40     Name##values.push_back(v);            \
41   }
42   AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_CTOR);
43 #undef VALUE_CTOR
44 
InterpValue(c10::quint8 v)45   explicit InterpValue(c10::quint8 v) : dtype_(kQUInt8) {
46     QUInt8values.emplace_back(v.val_);
47   }
48 
InterpValue(c10::qint8 v)49   explicit InterpValue(c10::qint8 v) : dtype_(kQInt8) {
50     QInt8values.emplace_back(v.val_);
51   }
52 
53 #define VALUE_VEC_CTOR(Type, Name)        \
54   InterpValue(const std::vector<Type>& v) \
55       : dtype_(Dtype(k##Name, v.size())), Name##values(v) {}
56   AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_VEC_CTOR);
57   VALUE_VEC_CTOR(c10::quint8, QUInt8)
58   VALUE_VEC_CTOR(c10::qint8, QInt8)
59 #undef VALUE_VEC_CTOR
60 
61   template <typename T>
62   T as() const;
63 
64   template <typename T>
65   const std::vector<T>& as_vec() const;
66 
67   int64_t intValue() const;
68 
dtype()69   Dtype dtype() const {
70     return dtype_;
71   }
72 
73  private:
74   Dtype dtype_;
75 
76 #define VALUE_STORAGE(Type, Name) std::vector<Type> Name##values;
77   AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_STORAGE);
78   VALUE_STORAGE(c10::qint8, QInt8);
79   VALUE_STORAGE(c10::quint8, QUInt8);
80 #undef VALUE_STORAGE
81   void* ptr{nullptr};
82 };
83 
84 #define VALUE_AS_DISPATCH(Type, Name)         \
85   template <>                                 \
86   inline Type InterpValue::as<Type>() const { \
87     if (dtype_ != k##Name) {                  \
88       throw unsupported_dtype();              \
89     }                                         \
90     return Name##values[0];                   \
91   }
92 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_DISPATCH);
93 VALUE_AS_DISPATCH(c10::quint8, QUInt8);
94 VALUE_AS_DISPATCH(c10::qint8, QInt8);
95 #undef VALUE_AS_DISPATCH
96 
97 #define VALUE_AS_VEC_DISPATCH(Type, Name)                             \
98   template <>                                                         \
99   inline const std::vector<Type>& InterpValue::as_vec<Type>() const { \
100     if (dtype_.scalar_type() != ScalarType::Name) {                   \
101       throw unsupported_dtype();                                      \
102     }                                                                 \
103     return Name##values;                                              \
104   }
105 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, VALUE_AS_VEC_DISPATCH);
106 VALUE_AS_VEC_DISPATCH(c10::quint8, QUInt8);
107 VALUE_AS_VEC_DISPATCH(c10::qint8, QInt8);
108 #undef VALUE_AS_VEC_DISPATCH
109 
// Generic case: returns `x` itself. Explicit specializations of this template
// exist for c10::quint8 and c10::qint8, which return the wrapped val_ member
// instead. The return type is deduced (`auto`), so it follows whatever the
// selected definition returns.
template <typename Type>
auto underlyingValue(Type x) {
  return x;
}
114 
// Specialization for c10::quint8: returns the val_ member of `x` rather than
// the c10::quint8 object.
template <>
inline auto underlyingValue<c10::quint8>(c10::quint8 x) {
  return x.val_;
}
119 
// Specialization for c10::qint8: returns the val_ member of `x` rather than
// the c10::qint8 object.
template <>
inline auto underlyingValue<c10::qint8>(c10::qint8 x) {
  return x.val_;
}
124 
125 template <typename To, typename From>
raw_bitcast(const From & src)126 To raw_bitcast(const From& src) {
127   TORCH_CHECK(sizeof(To) == sizeof(From), "Invalid bitcast invocation");
128   To storage;
129   std::memcpy(&storage, &src, sizeof(To));
130   return reinterpret_cast<To&>(storage);
131 }
132 
133 class SimpleIREvaluatorImpl;
134 class TORCH_API SimpleIREvaluator : public CodeGen {
135  public:
136   SimpleIREvaluator(
137       StmtPtr stmt,
138       const std::vector<BufferArg>& buffer_args,
139       at::Device device = at::kCPU,
140       const std::string& kernel_func_name = "func");
141 
142   ~SimpleIREvaluator() override;
143 
144   void call(const std::vector<CallArg>& args) override;
145   void call_raw(const std::vector<void*>& args) override;
146 
147   template <typename... Ts>
operator()148   void operator()(const Ts&... ts) {
149     std::vector<CallArg> args({CallArg(ts)...});
150     call(args);
151   }
152 
153   void bindVar(const VarPtr& v, const ExprPtr& e);
154   InterpValue value() const;
155 
156  private:
157   void bindArg(const BufferArg& buf, void* data);
expand_intrinsics()158   void expand_intrinsics() {
159     GenericIntrinsicsExpander intrinsics_expander;
160     apply_mutator(&intrinsics_expander);
161   }
162 
163   std::unique_ptr<SimpleIREvaluatorImpl> impl_;
164 };
165 
166 template <class CodeGenType>
167 class ExprEval {
168  public:
169   using BufferArg = CodeGen::BufferArg;
170   using CallArg = CodeGen::CallArg;
171 
172   template <typename... Ts>
ExprEval(const ExprHandle & expr,Ts...ts)173   ExprEval(const ExprHandle& expr, Ts... ts)
174       : ExprEval(expr, {BufferArg(ts)...}) {}
175 
ExprEval(const ExprHandle & expr,const std::vector<BufferArg> & buffer_args)176   ExprEval(const ExprHandle& expr, const std::vector<BufferArg>& buffer_args)
177       : dtype_(expr.dtype()) {
178     std::vector<BufferArg> buffer_args_extended = buffer_args;
179     BufHandle ret_buf("ret_val", {1}, dtype_);
180     std::vector<ExprHandle> indices;
181     ExprHandle zero = IntImm::make(0);
182     for (size_t i = 0; i < ret_buf.ndim(); i++) {
183       indices.push_back(zero);
184     }
185     StmtPtr store_stmt = Store::make(ret_buf, indices, expr);
186     buffer_args_extended.emplace_back(ret_buf);
187     codegen_.reset(new CodeGenType(store_stmt, buffer_args_extended));
188   }
189 
190   template <typename... Ts>
operator()191   void operator()(Ts... ts) {
192     call(ts...);
193   }
194 
operator()195   void operator()(const std::vector<CallArg>& call_args) {
196     call(call_args);
197   }
198 
bindVar(VarPtr v,ExprPtr e)199   void bindVar(VarPtr v, ExprPtr e) {
200     codegen_->bindVar(v, e);
201   }
202 
bindVar(const VarHandle & v,const ExprHandle & e)203   void bindVar(const VarHandle& v, const ExprHandle& e) {
204     codegen_->bindVar(v.node(), e.node());
205   }
206 
207   template <typename... Ts>
call(Ts...ts)208   void call(Ts... ts) {
209     call({CallArg(ts)...});
210   }
211 
call(const std::vector<CallArg> & call_args)212   void call(const std::vector<CallArg>& call_args) {
213     std::vector<CallArg> call_args_extended = call_args;
214     switch (dtype_.scalar_type()) {
215 #define TYPE_CASE(Type, Name)                     \
216   case ScalarType::Name: {                        \
217     std::vector<Type> ret_val_arg(1);             \
218     call_args_extended.emplace_back(ret_val_arg); \
219     codegen_->call(call_args_extended);           \
220     ret_value_ = InterpValue(ret_val_arg[0]);     \
221   } break;
222       AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
223       TYPE_CASE(c10::quint8, QUInt8);
224       TYPE_CASE(c10::qint8, QInt8);
225 #undef TYPE_CASE
226       case ScalarType::Bool: {
227         std::vector<unsigned char> ret_val_arg(1);
228         call_args_extended.emplace_back(ret_val_arg.data());
229         codegen_->call(call_args_extended);
230         ret_value_ = InterpValue((bool)ret_val_arg[0]);
231       } break;
232       default:
233         throw unsupported_dtype();
234     }
235   }
236 
call_raw(const std::vector<void * > & args)237   void call_raw(const std::vector<void*>& args) {
238     std::vector<void*> args_extended = args;
239     switch (dtype_.scalar_type()) {
240 #define TYPE_CASE(Type, Name)                    \
241   case ScalarType::Name: {                       \
242     std::vector<Type> ret_val_arg(1);            \
243     args_extended.push_back(ret_val_arg.data()); \
244     codegen_->call_raw(args_extended);           \
245     ret_value_ = InterpValue(ret_val_arg[0]);    \
246   } break;
247       AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TYPE_CASE);
248       TYPE_CASE(c10::quint8, QUInt8);
249       TYPE_CASE(c10::qint8, QInt8);
250 #undef TYPE_CASE
251       case ScalarType::Bool: {
252         std::vector<unsigned char> ret_val_arg(1);
253         args_extended.push_back(ret_val_arg.data());
254         codegen_->call_raw(args_extended);
255         ret_value_ = InterpValue((bool)ret_val_arg[0]);
256       } break;
257       default:
258         throw unsupported_dtype();
259     }
260   }
261 
262   template <typename T>
value(const std::vector<void * > & args)263   T value(const std::vector<void*>& args) {
264     call_raw(args);
265     return ret_value_.as<T>();
266   }
267 
268   template <typename T, typename... Ts>
value(Ts...ts)269   T value(Ts... ts) {
270     call(std::forward<Ts>(ts)...);
271     return ret_value_.as<T>();
272   }
273 
dtype()274   Dtype dtype() {
275     return dtype_;
276   }
277 
278  private:
279   Dtype dtype_;
280   std::unique_ptr<CodeGenType> codegen_;
281   InterpValue ret_value_;
282 };
283 
// Evaluates the expression `e` and returns its value when the result is an
// int64_t. The std::optional return type allows the function to report that
// no such value is available; the exact conditions are set by the definition,
// which is not part of this header.
std::optional<int64_t> evalInt(ExprPtr e);
287 
288 // Substitutes the given vars with their corresponding expressions in the input
289 // expression.
Substitute(const ExprPtr & expr,const VarMapping & var_mapping)290 inline ExprPtr Substitute(const ExprPtr& expr, const VarMapping& var_mapping) {
291   VarSubMutator var_sub(var_mapping);
292   return expr->accept_mutator(&var_sub);
293 }
294 
295 // Substitutes the given vars with their corresponding expressions in the input
296 // statement.
Substitute(const StmtPtr & stmt,const VarMapping & var_mapping)297 inline StmtPtr Substitute(const StmtPtr& stmt, const VarMapping& var_mapping) {
298   VarSubMutator var_sub(var_mapping);
299   return stmt->accept_mutator(&var_sub);
300 }
301 
302 // Creates a clone of the input expression and substitutes the given vars with
303 // their corresponding expressions in the clone.
304 // NOTE: This works because cloning reuses variables and does not create new
305 // ones, and `VarMapping` input has variables as the key.
SubstituteInClone(const ExprPtr & expr,const VarMapping & var_mapping)306 inline ExprPtr SubstituteInClone(
307     const ExprPtr& expr,
308     const VarMapping& var_mapping) {
309   VarSubMutator var_sub(var_mapping);
310   return Expr::clone(expr)->accept_mutator(&var_sub);
311 }
312 
313 // Creates a clone of the input statement and substitutes the given vars with
314 // their corresponding expressions in the clone.
315 // NOTE: This works because cloning reuses variables and does not create new
316 // ones, and `VarMapping` input has variables as the key.
SubstituteInClone(const StmtPtr & stmt,const VarMapping & var_mapping)317 inline StmtPtr SubstituteInClone(
318     const StmtPtr& stmt,
319     const VarMapping& var_mapping) {
320   VarSubMutator var_sub(var_mapping);
321   return Stmt::clone(stmt)->accept_mutator(&var_sub);
322 }
323 
324 } // namespace torch::jit::tensorexpr
325