#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/utils/pybind.h>
#ifdef USE_CUDA
#include <torch/csrc/jit/tensorexpr/cuda_codegen.h>
#endif
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/lowerings.h>
#include <torch/csrc/jit/tensorexpr/reduction.h>

#include <utility>

template <>
struct pybind11::detail::type_caster<torch::jit::tensorexpr::ArgValue>
    : public type_caster_base<torch::jit::tensorexpr::ArgValue> {};

namespace torch::jit {
using namespace torch::jit::tensorexpr;

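// Convert a Python value to the ArgValue variant consumed by NNC lowering
// functions. Only the types lowerings understand are handled: BufHandle,
// VarHandle, bool/float/int scalars, None, and homogeneous lists of ints or
// BufHandles; anything else raises.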
ArgValue convertPyToArgValue(py::handle inp) {
  if (py::isinstance<BufHandle>(inp)) {
    return py::cast<BufHandle>(inp);
  } else if (py::isinstance<VarHandle>(inp)) {
    return py::cast<VarHandle>(inp);
  } else if (py::isinstance<py::bool_>(inp)) {
    return py::cast<bool>(inp);
  } else if (py::isinstance<py::float_>(inp)) {
    return py::cast<double>(inp);
  } else if (py::isinstance<py::int_>(inp)) {
    return py::cast<int64_t>(inp);
  } else if (py::isinstance<py::none>(inp)) {
    return ArgNone();
  } else if (py::isinstance<py::list>(inp)) {
    auto l = py::cast<py::list>(inp);
    if (l.empty()) {
      return std::vector<BufHandle>();
    } else if (py::isinstance<py::int_>(l[0])) {
      return py::cast<IntList>(inp);
    } else if (py::isinstance<BufHandle>(l[0])) {
      return py::cast<BufList>(inp);
    } else {
      throw std::runtime_error("vector conversion failed");
    }
  } else {
    throw std::runtime_error("conversion not yet implemented");
  }
}

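// Map a Python torch.dtype object to the corresponding tensorexpr Dtype.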
Dtype parsePythonDtype(py::handle obj) {
  if (THPDtype_Check(obj.ptr())) {
    return Dtype(reinterpret_cast<THPDtype*>(obj.ptr())->scalar_type);
  } else {
    throw std::runtime_error("expected a torch.dtype instance");
  }
}

void initTensorExprBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  // Tensor Expr Classes
  auto te = m.def_submodule("_te");

  auto dtype_class =
      py::class_<Dtype>(te, "Dtype").def(py::init(&parsePythonDtype));
  py::implicitly_convertible<py::object, Dtype>();

#define DTYPE_SINGLETON_ACCESSOR(ctype, name) \
  dtype_class.def_property_readonly_static(   \
      #name, [](const py::object&) { return k##name; });
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DTYPE_SINGLETON_ACCESSOR)
#undef DTYPE_SINGLETON_ACCESSOR

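  // Illustrative Python usage of the Dtype bindings above (a sketch; assumes
  // the extension is importable, e.g. as torch._C._te):
  //
  //   dt = te.Dtype(torch.float32)  # converted via parsePythonDtype
  //   dt = te.Dtype.Float           # singleton from DTYPE_SINGLETON_ACCESSOR
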
  auto expr_handle_class =
      py::class_<ExprHandle>(te, "ExprHandle")
          .def(
              "__str__",
              [](const ExprHandle& self) {
                std::stringstream ss;
                ss << self;
                return ss.str();
              })
          .def(py::self + py::self)
          .def(py::self * py::self)
          .def(py::self - py::self)
          .def(py::self / py::self)
          .def(py::self % py::self)
          .def(py::self == py::self)
          .def(py::self != py::self)
          .def(py::self > py::self)
          .def(py::self >= py::self)
          .def(py::self < py::self)
          .def(py::self <= py::self)
          .def(py::self & py::self)
          .def(py::self | py::self)
          .def(py::self ^ py::self)
          .def(py::self << py::self)
          .def(py::self >> py::self)
          .def(
              "__pow__",
              [](const ExprHandle& self, const ExprHandle& other) {
                return pow(self, other);
              })
          .def("sin", [](const ExprHandle& self) { return sin(self); })
          .def("cos", [](const ExprHandle& self) { return cos(self); })
          .def("tan", [](const ExprHandle& self) { return tan(self); })
          .def("asin", [](const ExprHandle& self) { return asin(self); })
          .def("acos", [](const ExprHandle& self) { return acos(self); })
          .def("atan", [](const ExprHandle& self) { return atan(self); })
          .def("sinh", [](const ExprHandle& self) { return sinh(self); })
          .def("cosh", [](const ExprHandle& self) { return cosh(self); })
          .def("tanh", [](const ExprHandle& self) { return tanh(self); })
          .def("sigmoid", [](const ExprHandle& self) { return sigmoid(self); })
          .def("exp", [](const ExprHandle& self) { return exp(self); })
          .def("expm1", [](const ExprHandle& self) { return expm1(self); })
          .def(
              "abs",
              [](const ExprHandle& self) { return tensorexpr::abs(self); })
          .def("log", [](const ExprHandle& self) { return log(self); })
          .def(
              "fast_tanh",
              [](const ExprHandle& self) { return fast_tanh(self); })
          .def(
              "fast_sigmoid",
              [](const ExprHandle& self) { return fast_sigmoid(self); })
          .def(
              "fast_log", [](const ExprHandle& self) { return fast_log(self); })
          .def("log_vml", [](const ExprHandle& self) { return log_vml(self); })
          .def("log2", [](const ExprHandle& self) { return log2(self); })
          .def("log10", [](const ExprHandle& self) { return log10(self); })
          .def("log1p", [](const ExprHandle& self) { return log1p(self); })
          .def("erf", [](const ExprHandle& self) { return erf(self); })
          .def("erfc", [](const ExprHandle& self) { return erfc(self); })
          .def(
              "sqrt",
              [](const ExprHandle& self) { return tensorexpr::sqrt(self); })
          .def("rsqrt", [](const ExprHandle& self) { return rsqrt(self); })
          .def("ceil", [](const ExprHandle& self) { return ceil(self); })
          .def("floor", [](const ExprHandle& self) { return floor(self); })
          .def("round", [](const ExprHandle& self) { return round(self); })
          .def("trunc", [](const ExprHandle& self) { return trunc(self); })
          .def("frac", [](const ExprHandle& self) { return frac(self); })
          .def("lgamma", [](const ExprHandle& self) { return lgamma(self); })
          .def("isnan", [](const ExprHandle& self) { return isnan(self); })
          .def(
              "cast",
              [](const ExprHandle& self, const Dtype& dt) {
                return Cast::make(dt, self);
              })
#define EXPRHANDLE_INIT(ctype, name) \
  .def(py::init([](ctype val) { return name##Imm::make(val); }))
              AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_INIT)
#undef EXPRHANDLE_INIT
      ;

#define EXPRHANDLE_IMPL_CONV(ctype, name) \
  py::implicitly_convertible<ctype, ExprHandle>();
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_IMPL_CONV)
#undef EXPRHANDLE_IMPL_CONV

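  // Minimal sketch of building expressions from Python (illustrative only;
  // the per-type factories such as ExprHandle.float are registered by
  // EXPRHANDLE_CTOR further down):
  //
  //   x = te.ExprHandle.float(2.0)
  //   y = ((x + x) * x).sin()
  //   z = y.cast(te.Dtype.Half)
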
  te.def(
      "ifThenElse",
      [](const ExprHandle& c, const ExprHandle& t, const ExprHandle& f) {
        return ifThenElse(c, t, f);
      });

  te.def("sin", [](const ExprHandle& v1) { return sin(v1); });
  te.def("cos", [](const ExprHandle& v1) { return cos(v1); });
  te.def("tan", [](const ExprHandle& v1) { return tan(v1); });
  te.def("asin", [](const ExprHandle& v1) { return asin(v1); });
  te.def("acos", [](const ExprHandle& v1) { return acos(v1); });
  te.def("atan", [](const ExprHandle& v1) { return atan(v1); });
  te.def("sinh", [](const ExprHandle& v1) { return sinh(v1); });
  te.def("cosh", [](const ExprHandle& v1) { return cosh(v1); });
  te.def("tanh", [](const ExprHandle& v1) { return tanh(v1); });
  te.def("sigmoid", [](const ExprHandle& v1) { return sigmoid(v1); });
  te.def("exp", [](const ExprHandle& v1) { return exp(v1); });
  te.def("expm1", [](const ExprHandle& v1) { return expm1(v1); });
  te.def("abs", [](const ExprHandle& v1) { return abs(v1); });
  te.def("log", [](const ExprHandle& v1) { return log(v1); });
  te.def("log2", [](const ExprHandle& v1) { return log2(v1); });
  te.def("log10", [](const ExprHandle& v1) { return log10(v1); });
  te.def("log1p", [](const ExprHandle& v1) { return log1p(v1); });
  te.def("erf", [](const ExprHandle& v1) { return erf(v1); });
  te.def("erfc", [](const ExprHandle& v1) { return erfc(v1); });
  te.def("sqrt", [](const ExprHandle& v1) { return sqrt(v1); });
  te.def("rsqrt", [](const ExprHandle& v1) { return rsqrt(v1); });
  te.def("ceil", [](const ExprHandle& v1) { return ceil(v1); });
  te.def("floor", [](const ExprHandle& v1) { return floor(v1); });
  te.def("round", [](const ExprHandle& v1) { return round(v1); });
  te.def("trunc", [](const ExprHandle& v1) { return trunc(v1); });
  te.def("frac", [](const ExprHandle& v1) { return frac(v1); });
  te.def("lgamma", [](const ExprHandle& v1) { return lgamma(v1); });
  te.def("isnan", [](const ExprHandle& v1) { return isnan(v1); });

  te.def("atan2", [](const ExprHandle& v1, const ExprHandle& v2) {
    return atan2(v1, v2);
  });
  te.def("pow", [](const ExprHandle& v1, const ExprHandle& v2) {
    return pow(v1, v2);
  });
  te.def("fmod", [](const ExprHandle& v1, const ExprHandle& v2) {
    return fmod(v1, v2);
  });
  te.def("remainder", [](const ExprHandle& v1, const ExprHandle& v2) {
    return remainder(v1, v2);
  });

#define EXPRHANDLE_CTOR(ctype, name) \
  expr_handle_class.def_static(#ctype, [](ctype v) { return ExprHandle(v); });
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_CTOR)
#undef EXPRHANDLE_CTOR

  py::class_<VarHandle, ExprHandle>(te, "VarHandle")
      .def(
          "__str__",
          [](const ExprHandle& self) {
            std::stringstream ss;
            ss << self;
            return ss.str();
          })
      .def(py::init<Dtype>())
      .def(py::init<const std::string&, Dtype>());
  py::class_<BufHandle, ExprHandle>(te, "BufHandle")
      .def(
          py::init<const std::string&, const std::vector<ExprHandle>&, Dtype>())
      .def(py::init<const std::vector<ExprHandle>&, Dtype>())
      .def(py::init<Dtype>())
      .def(
          "__hash__",
          [](const BufHandle& self) {
            return std::hash<BufPtr>()(self.node());
          })
      .def(
          "__eq__",
          [](const BufHandle& self, const BufHandle& other) {
            return self.node() == other.node();
          })
      .def(
          "load",
          [](BufHandle& self, const std::vector<ExprHandle>& v) {
            return Load::make(self, v);
          })
      .def(
          "load",
          [](BufHandle& self, const ExprHandle& v) {
            return Load::make(self, {v});
          })
      .def(
          "store",
          [](BufHandle& self,
             const std::vector<ExprHandle>& i,
             const ExprHandle& v) { return Store::make(self, i, v); })
      .def(
          "store",
          [](BufHandle& self, const ExprHandle& i, const ExprHandle& v) {
            return Store::make(self, {i}, v);
          });
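
  // Load/store sketch (illustrative): given a 1-D float buffer, build a
  // statement that increments each element.
  //
  //   A = te.BufHandle("A", [te.ExprHandle.int(64)], te.Dtype.Float)
  //   i = te.VarHandle("i", te.Dtype.Int)
  //   st = A.store([i], A.load([i]) + te.ExprHandle.float(1.0))
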
  py::class_<Tensor>(te, "Tensor")
      .def(py::init([](const BufHandle& b, const StmtPtr& s) {
        return Tensor(b.node(), s);
      }))
      .def(
          "load",
          [](Tensor& self, const std::vector<ExprHandle>& v) {
            return self.load(v);
          })
      .def("buf", [](Tensor& self) { return BufHandle(self.buf()); })
      .def("stmt", &Tensor::stmt);
  py::class_<Cast, std::shared_ptr<Cast>>(te, "Cast")
      .def_static("make", &Cast::make)
      .def(
          "src_value",
          [](CastPtr& self) { return ExprHandle(self->src_value()); })
      .def("set_src_value", [](CastPtr& self, const ExprHandle& value) {
        self->set_src_value(value.node());
      });

  te.def(
      "Compute",
      [](const std::string& func_name,
         const std::vector<ExprHandle>& dim_args,
         const py::function& func) {
        if (dim_args.size() == 1) {
          return Compute(func_name, dim_args, [&func](const VarHandle& a) {
            return py::cast<ExprHandle>(func(a));
          });
        } else if (dim_args.size() == 2) {
          return Compute(
              func_name,
              dim_args,
              [&func](const VarHandle& a, const VarHandle& b) {
                return py::cast<ExprHandle>(func(a, b));
              });
        } else if (dim_args.size() == 3) {
          return Compute(
              func_name,
              dim_args,
              [&func](
                  const VarHandle& a, const VarHandle& b, const VarHandle& c) {
                return py::cast<ExprHandle>(func(a, b, c));
              });
        } else if (dim_args.size() == 4) {
          return Compute(
              func_name,
              dim_args,
              [&func](
                  const VarHandle& a,
                  const VarHandle& b,
                  const VarHandle& c,
                  const VarHandle& d) {
                return py::cast<ExprHandle>(func(a, b, c, d));
              });
        } else {
          throw std::runtime_error(
              "Compute only supports 1 to 4 dims; use Compute2 instead");
        }
      },
      py::return_value_policy::reference);

  te.def(
      "Compute2",
      [](const std::string& func_name,
         const std::vector<ExprHandle>& dim_args,
         const py::function& func) {
        return Compute(
            func_name, dim_args, [&func](const std::vector<VarHandle>& dims) {
              return py::cast<ExprHandle>(func(dims));
            });
      },
      py::return_value_policy::reference);

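  // Compute sketch (illustrative, with A the BufHandle from the sketch
  // above): Compute dispatches on arity for callables of one to four index
  // variables; Compute2 takes a callable over a list of indices instead.
  //
  //   n = te.ExprHandle.int(64)
  //   t = te.Compute("plus_one", [n],
  //                  lambda i: A.load([i]) + te.ExprHandle.float(1.0))
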
  py::class_<Reducer>(te, "Reducer")
      .def(py::init<
           ExprHandle,
           std::function<ExprHandle(ExprHandle, ExprHandle)>>());

  py::class_<Sum, Reducer>(te, "Sum").def(py::init<>());
  py::class_<Maximum, Reducer>(te, "Maximum").def(py::init<Dtype>());
  te.def(
      "Reduce",
      [](const std::string& func_name,
         const std::vector<ExprHandle>& dim_args,
         const Reducer& reducer,
         const Tensor& buffer,
         const std::vector<ExprHandle>& reduce_args) {
        return Reduce(func_name, dim_args, reducer, buffer, reduce_args);
      },
      py::return_value_policy::reference);

  te.def(
      "Reduce",
      [](const std::string& func_name,
         const std::vector<ExprHandle>& dim_args,
         const Reducer& reducer,
         const BufHandle& buffer,
         const std::vector<ExprHandle>& reduce_args) {
        return Reduce(func_name, dim_args, reducer, buffer, reduce_args);
      },
      py::return_value_policy::reference);
  te.def(
      "Reduce",
      [](const std::string& func_name,
         const std::vector<ExprHandle>& dim_args,
         const Reducer& reducer,
         const std::function<ExprHandle(const std::vector<VarHandle>&)>&
             body_func,
         const std::vector<ExprHandle>& reduce_args) {
        return Reduce(func_name, dim_args, reducer, body_func, reduce_args);
      },
      py::return_value_policy::reference);
  te.def(
      "Reduce",
      [](const std::string& func_name,
         const std::vector<ExprHandle>& dim_args,
         const Reducer& reducer,
         const std::function<ExprHandle(const std::vector<VarHandle>&)>&
             init_func,
         const std::function<ExprHandle(const std::vector<VarHandle>&)>&
             body_func,
         const std::vector<ExprHandle>& reduce_args) {
        // Forward both the custom initializer and the body.
        return Reduce(
            func_name, dim_args, reducer, init_func, body_func, reduce_args);
      },
      py::return_value_policy::reference);

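  // Reduction sketch (illustrative, using the body-function overload to sum
  // a 1-D buffer A over one axis of extent n):
  //
  //   n = te.ExprHandle.int(64)
  //   s = te.Reduce("sum", [], te.Sum(), lambda axes: A.load(axes), [n])
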
  py::class_<Stmt, std::shared_ptr<Stmt>>(te, "Stmt")
      .def(py::init([](const std::vector<StmtPtr>& stmts) {
        return tensorexpr::Block::make(stmts);
      }))
      .def("__str__", [](Stmt& self) {
        std::stringstream ss;
        ss << self;
        return ss.str();
      });
  py::class_<Store, Stmt, std::shared_ptr<Store>>(te, "Store")
      .def_static(
          "make",
          [](const BufHandle& buf,
             std::vector<ExprHandle>& indices,
             const ExprHandle& value) {
            return Store::make(buf, indices, value);
          });

  py::class_<For, Stmt, std::shared_ptr<For>>(te, "For")
      .def("index_var", [](For& self) { return VarHandle(self.var()); })
      .def("body", &For::body)
      .def("set_parallel", &For::set_parallel)
      .def(
          "set_gpu_block_index",
          [](For& self, int block_index) {
            self.set_gpu_block_index(block_index);
          })
      .def(
          "set_gpu_thread_index",
          [](For& self, int thread_index) {
            self.set_gpu_thread_index(thread_index);
          })
      .def_static(
          "make",
          [](const VarHandle& var,
             const ExprHandle& start,
             const ExprHandle& stop,
             const StmtPtr& body) {
            return For::make(var, start, stop, body);
          });

  py::class_<Cond, Stmt, std::shared_ptr<Cond>>(te, "Cond")
      .def_static(
          "make",
          [](const ExprHandle& condition,
             const StmtPtr& true_stmt,
             const StmtPtr& false_stmt) {
            return Cond::make(condition, true_stmt, false_stmt);
          })
      .def("true_stmt", &Cond::true_stmt)
      .def("false_stmt", &Cond::false_stmt);

  py::class_<tensorexpr::Block, Stmt, std::shared_ptr<tensorexpr::Block>>(
      te, "Block")
      .def(py::init([](const std::vector<StmtPtr>& stmts) {
        return tensorexpr::Block::make(stmts);
      }))
      .def("stmts", &tensorexpr::Block::stmts);
  py::class_<ExternalCall, Stmt, std::shared_ptr<ExternalCall>>(
      te, "ExternalCall")
      .def(py::init(&ExternalCall::make));

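  // Statement-construction sketch (illustrative): wrap a store in an
  // explicit loop.
  //
  //   i = te.VarHandle("i", te.Dtype.Int)
  //   body = te.Block([A.store([i], te.ExprHandle.float(0.0))])
  //   loop = te.For.make(i, te.ExprHandle.int(0), te.ExprHandle.int(64), body)
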
  py::class_<LoopNest>(te, "LoopNest")
      .def(py::init<const std::vector<Tensor>&>())
      .def(py::init<const std::vector<Tensor>&, const std::vector<Tensor>&>())
      .def(py::init([](const StmtPtr& s, const std::vector<BufHandle>& bufs) {
        std::unordered_set<BufPtr> buf_nodes;
        buf_nodes.reserve(bufs.size());
        for (auto& buf : bufs) {
          buf_nodes.insert(buf.node());
        }
        return std::make_unique<LoopNest>(s, buf_nodes);
      }))
      .def("vectorize_inner_loops", &LoopNest::vectorizeInnerLoops)
      .def(
          "prepare_for_codegen",
          [](LoopNest& self) { return self.prepareForCodegen(); },
          py::return_value_policy::reference)
      .def(
          "get_loop_body_for",
          [](const LoopNest& self, const Tensor& t) {
            return self.getLoopBodyFor(t);
          },
          py::return_value_policy::reference)
      .def(
          "get_loop_body_for",
          [](const LoopNest& self, BufHandle& b) {
            return self.getLoopBodyFor(b.node());
          },
          py::return_value_policy::reference)
      .def(
          "get_loops_for",
          [](const LoopNest& self, const Tensor& t) {
            return self.getLoopStmtsFor(t);
          },
          py::return_value_policy::reference)
      .def(
          "get_all_loopnests_for",
          [](const LoopNest& self, const BufHandle& b) {
            return self.getAllLoopNestsWritingToBuf(b.node());
          },
          py::return_value_policy::reference)
      .def(
          "get_enclosing_loopnest",
          [](const LoopNest& self, const StmtPtr& s) {
            return self.getEnclosingLoopNest(s);
          },
          py::return_value_policy::reference)
      .def(
          "get_innermost_loops_for",
          [](const LoopNest& self, const BufHandle& b) {
            return self.getAllInnermostLoopsWritingToBuf(b.node());
          },
          py::return_value_policy::reference)
      .def(
          "get_writes_for",
          [](const LoopNest& self, const BufHandle& b) {
            return self.getAllWritesToBuf(b.node());
          },
          py::return_value_policy::reference)
      .def(
          "get_loop_at",
          [](const LoopNest& self,
             ForPtr root,
             const std::vector<int>& indices) {
            return self.getLoopAt(std::move(root), indices);
          },
          py::return_value_policy::reference)
      .def(
          "get_parent_loop",
          [](const LoopNest& self, const StmtPtr& s) {
            return self.getParentLoop(s);
          },
          py::return_value_policy::reference)
      .def_static(
          "get_loop_stmts_in_loopnest",
          [](const ForPtr& f, size_t num) {
            return LoopNest::getLoopStmtsInLoopNest(f, num);
          },
          py::return_value_policy::reference)
      .def(
          "split_with_tail",
          [](const ForPtr& f, int factor) {
            ForPtr inner = nullptr, tail = nullptr;
            LoopNest::splitWithTail(f, factor, &inner, &tail);
            return std::make_tuple(std::move(inner), std::move(tail));
          },
          py::return_value_policy::reference)
      .def(
          "split_with_mask",
          [](const ForPtr& f, int factor) {
            ForPtr inner = nullptr;
            LoopNest::splitWithMask(f, factor, &inner);
            return inner;
          },
          py::return_value_policy::reference)
      .def(
          "slice_head",
          [](const ForPtr& f, int factor) {
            ForPtr head = nullptr, tail = nullptr;
            LoopNest::sliceHead(f, factor, &head, &tail);
            return std::make_tuple(std::move(head), std::move(tail));
          },
          py::return_value_policy::reference)
      .def(
          "slice_tail",
          [](const ForPtr& f, int factor) {
            ForPtr head = nullptr, tail = nullptr;
            LoopNest::sliceTail(f, factor, &head, &tail);
            return std::make_tuple(std::move(head), std::move(tail));
          },
          py::return_value_policy::reference)
      .def_static(
          "normalize",
          [](const ForPtr& f) {
            LoopNest::normalize(f);
            return f;
          },
          py::return_value_policy::reference)
      .def(
          "tile",
          [](LoopNest& self,
             const ForPtr& x,
             const ForPtr& y,
             int x_factor,
             int y_factor) { return self.tile(x, y, x_factor, y_factor); },
          py::return_value_policy::reference)
      .def_static(
          "distribute_loop",
          [](const ForPtr& f) { return LoopNest::distributeLoop(f); },
          py::return_value_policy::reference)
      .def_static(
          "distribute_loop",
          [](const ForPtr& f, const std::unordered_set<StmtPtr>& pivots) {
            return LoopNest::distributeLoop(f, pivots);
          },
          py::return_value_policy::reference)
      .def_static(
          "distribute_loop_over_inner_loops",
          [](const ForPtr& f) {
            return LoopNest::distributeLoopOverInnerLoops(f);
          },
          py::return_value_policy::reference)
      .def_static(
          "unsafe_fuse_loops",
          [](const std::vector<ForPtr>& loops) {
            ForPtr fused_loop = nullptr;
            LoopNest::unsafeFuseLoops(loops, &fused_loop);
            return fused_loop;
          },
          py::return_value_policy::reference)
      .def_static(
          "fuse_loops",
          [](const std::vector<ForPtr>& loops) {
            ForPtr fused_loop = nullptr;
            LoopNest::fuseLoops(loops, &fused_loop);
            return fused_loop;
          },
          py::return_value_policy::reference)
      .def_static(
          "reorder",
          [](const std::vector<ForPtr>& loops,
             const std::vector<size_t>& permutation) {
            return LoopNest::reorder(loops, permutation);
          },
          py::return_value_policy::reference)
      .def(
          "fullUnroll",
          [](const ForPtr& f) {
            StmtPtr unrolled = nullptr;
            LoopNest::fullUnroll(f, &unrolled);
            return unrolled;
          },
          py::return_value_policy::reference)
      .def(
          "unroll",
          [](const ForPtr& f, int factor) {
            LoopNest::unroll(f, factor);
            return f;
          },
          py::return_value_policy::reference)
      .def(
          "vectorize",
          [](const ForPtr& f) { LoopNest::vectorize(f); },
          py::return_value_policy::reference)
      .def_static(
          "compress_buffer",
          [](BufHandle& buf, const StmtPtr& stmt) {
            return LoopNest::compressBuffer(buf.node(), stmt);
          },
          py::return_value_policy::reference)
      .def_static(
          "cache_accesses",
          [](const BufHandle& producer,
             const std::string& name,
             const StmtPtr& consumer) {
            std::pair<BufPtr, StmtPtr> ret =
                LoopNest::cacheAccesses(producer.node(), name, consumer);
            return std::make_pair(BufHandle(ret.first), ret.second);
          },
          py::return_value_policy::reference)
      .def_static(
          "compute_at",
          [](const StmtPtr& s, const ForPtr& at) {
            LoopNest::computeAt(s, at);
          })
      .def(
          "compute_inline",
          [](LoopNest& self, const StmtPtr& s) { self.computeInline(s); },
          py::return_value_policy::reference)
      .def(
          "compute_inline",
          [](LoopNest& self, const BufHandle& b) {
            self.computeInline(b.node());
          },
          py::return_value_policy::reference)
      .def(
          "rfactor",
          [](const StmtPtr& s, const ForPtr& target_for) {
            BufPtr rfac_buf = nullptr;
            LoopNest::rfactor(s, target_for, &rfac_buf);
            return BufHandle(rfac_buf);
          },
          py::return_value_policy::reference)
      .def(
          "flatten",
          [](LoopNest& self, const std::vector<ForPtr>& loops) {
            ForPtr flattened = nullptr;
            LoopNest::flatten(loops, &flattened);
            return flattened;
          },
          py::return_value_policy::reference)
      .def(
          "reorder_axis",
          &LoopNest::reorderAxis,
          py::return_value_policy::reference)
      .def("simplify", &LoopNest::simplify, py::return_value_policy::reference)
      .def_static("sanitize_names", &LoopNest::sanitizeNames)
      .def(
          "inline_intermediate_bufs",
          [](LoopNest& self, bool allow_duplicated_work) {
            self.inlineIntermediateBufs(allow_duplicated_work);
          })
      .def(
          "eliminate_dead_stores",
          [](LoopNest& self) { self.eliminateDeadStores(); })
      .def(
          "__str__",
          [](const LoopNest& self) {
            std::stringstream ss;
            ss << *self.root_stmt();
            return ss.str();
          })
      .def(
          "root_stmt",
          &LoopNest::root_stmt,
          py::return_value_policy::reference);

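  // Typical scheduling flow from Python (sketch):
  //
  //   ln = te.LoopNest([tensor])
  //   ln.prepare_for_codegen()
  //   stmt = te.simplify(ln.root_stmt())   # te.simplify is defined just below
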
  te.def(
      "simplify",
      [](const StmtPtr& stmt) { return IRSimplifier::simplify(stmt); },
      py::return_value_policy::reference);

  te.def(
      "lower",
      [](const std::string& op_str,
         const py::list& inputs,
         const std::vector<ExprHandle>& outputShape,
         Dtype outputType) {
        auto op = c10::Symbol::fromQualString(op_str);
        std::vector<ArgValue> argInputs;
        for (auto inp : inputs) {
          argInputs.push_back(convertPyToArgValue(inp));
        }
        if (NNCLoweringFunction lowering =
                getStandardLoweringFor(op.toQualString())) {
          std::vector<ExprHandle> outputStrides =
              c10::fmap<ExprHandle>(make_channels_last_strides(outputShape));
          return lowering(
              argInputs,
              outputShape,
              outputStrides,
              outputType.scalar_type(),
              at::kCPU);
        }
        std::string msg = std::string("Unhandled node kind (in te.lower): ") +
            op.toQualString();
        throw malformed_input(msg);
      });

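  // Lowering sketch (illustrative): map a single aten op straight to a
  // Tensor via the standard NNC lowerings, without building a full graph.
  //
  //   t = te.lower("aten::relu", [A], [te.ExprHandle.int(64)], te.Dtype.Float)
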
  py::class_<ArgValue>(te, "ArgValue")
      .def(py::init([](py::handle inp) {
        return std::make_unique<ArgValue>(convertPyToArgValue(inp));
      }))
      .def(
          "as_buf",
          [](const ArgValue& self) { return std::get<BufHandle>(self); })
      .def(
          "as_var",
          [](const ArgValue& self) { return std::get<VarHandle>(self); })
      .def(
          "as_float",
          [](const ArgValue& self) { return std::get<double>(self); })
      .def(
          "as_int",
          [](const ArgValue& self) { return std::get<int64_t>(self); })
      .def("as_bool", [](const ArgValue& self) { return std::get<bool>(self); })
      .def(
          "as_none",
          [](const ArgValue& self) { return std::get<ArgNone>(self); })
      .def(
          "as_buflist",
          [](const ArgValue& self) { return std::get<BufList>(self); })
      .def("as_intlist", [](const ArgValue& self) {
        return std::get<IntList>(self);
      });

  py::class_<c10::ScalarType> give_me_a_name(te, "ScalarType");

  using TSGraph = std::shared_ptr<Graph>;
  py::class_<TensorExprKernel>(te, "TensorExprKernel")
      .def(py::init<const TSGraph&>())
      .def(
          py::init(
              [](const TSGraph& g,
                 const std::unordered_map<std::string, NNCLoweringFunction>&
                     custom_lowerings_str,
                 std::vector<int64_t> symbolic_shape_inputs,
                 bool pre_alloc = false) {
                std::unordered_map<c10::Symbol, NNCLoweringFunction>
                    custom_lowerings;
                custom_lowerings.reserve(custom_lowerings_str.size());
                for (auto& kv : custom_lowerings_str) {
                  custom_lowerings[c10::Symbol::fromQualString(kv.first)] =
                      kv.second;
                }
                return std::make_unique<TensorExprKernel>(
                    g,
                    std::move(custom_lowerings),
                    std::move(symbolic_shape_inputs),
                    pre_alloc);
              }),
          py::arg("g"),
          py::arg("custom_lowerings_str"),
          py::arg("symbolic_shape_inputs") = std::vector<int64_t>(),
          py::arg("pre_alloc") = false)
      .def(
          "run",
          [](TensorExprKernel& self, const py::tuple& inputs) {
            Stack stack;
            stack.reserve(inputs.size());
            for (auto& obj : inputs) {
              stack.push_back(toTypeInferredIValue(obj));
            }
            auto g_inputs = self.graph()->inputs();
            for (size_t i = 0; i < inputs.size(); ++i) {
              if (stack[i].isTensor()) {
                g_inputs[i]->setType(stack[i].type());
              }
            }
            self.run(stack);
            return createPyObjectForStack(std::move(stack));
          })
      .def(
          "fallback",
          [](TensorExprKernel& self, const py::tuple& inputs) {
            Stack stack;
            stack.reserve(inputs.size());
            for (auto& obj : inputs) {
              stack.push_back(toTypeInferredIValue(obj));
            }
            auto g_inputs = self.graph()->inputs();
            for (size_t i = 0; i < inputs.size(); ++i) {
              if (stack[i].isTensor()) {
                g_inputs[i]->setType(stack[i].type());
              }
            }
            self.fallback(stack);
            return createPyObjectForStack(std::move(stack));
          })
      .def(
          "get_codegen_stmt",
          [](TensorExprKernel& self) { return self.getCodeGenStmt(); },
          py::return_value_policy::reference)
      .def(
          "get_code_text",
          [](TensorExprKernel& self, const std::string& attr = "") {
            return self.getCodeText(attr);
          },
          py::arg("attr") = "")
      .def("recompile", [](TensorExprKernel& self) { self.recompile(); });

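  // Kernel sketch (illustrative): compile and run a scripted graph. The
  // graph must carry complete shape and dtype information on its inputs.
  //
  //   graph = torch.jit.script(fn).graph
  //   kernel = te.TensorExprKernel(graph)
  //   out = kernel.run((a, b))
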
  py::class_<CodeGen>(te, "CodeGen")
      .def(
          "call",
          [](CodeGen& self, const py::sequence& values) {
            std::vector<CodeGen::CallArg> value_ptrs;
            value_ptrs.reserve(py::len(values));
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
            for (const auto& value : values) {
              if (py::isinstance<py::int_>(value)) {
                value_ptrs.emplace_back(value.cast<int64_t>());
              } else {
                value_ptrs.emplace_back(value.cast<at::Tensor>().data_ptr());
              }
            }
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
            if (py::len(values) != self.buffer_args().size()) {
              throw malformed_input("bad args in CodeGen.call function");
            }
            for (size_t i = 0; i < py::len(values); i++) {
              const auto& value = values[i];
              const auto& bufArg = self.buffer_args()[i];
              if (py::isinstance<py::int_>(value)) {
                if (!bufArg.isVar()) {
                  throw malformed_input(
                      "Integer variable expected in CodeGen.call function");
                }
                switch (bufArg.dtype().scalar_type()) {
#define TYPE_CASE(Type, Name)                    \
  case ScalarType::Name: {                       \
    value_ptrs.emplace_back(value.cast<Type>()); \
    break;                                       \
  }
                  AT_FORALL_INT_TYPES(TYPE_CASE);
                  default:
                    throw unsupported_dtype();
                }
              } else {
                value_ptrs.emplace_back(value.cast<at::Tensor>().data_ptr());
              }
            }
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
            self.call(value_ptrs);
          })
      .def(
          "call_raw",
          [](CodeGen& self, const py::sequence& values) {
            std::vector<void*> value_ptrs;
            value_ptrs.reserve(py::len(values));
            for (const auto& value : values) {
              // Tensor.data_ptr() returns an int in Python
              value_ptrs.emplace_back(
                  reinterpret_cast<void*>(value.cast<intptr_t>()));
            }
            self.call_raw(value_ptrs);
          })
      .def(
          "get_code_text",
          [](CodeGen& self, const std::string& attr = "") {
            return self.getCodeText(attr);
          },
          py::arg("attr") = "");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<SimpleIREvaluator, CodeGen>(te, "SimpleIREvaluator");
#ifdef TORCH_ENABLE_LLVM
  py::class_<LLVMCodeGen, CodeGen>(te, "LLVMCodeGen");
#endif

  py::class_<CodeGen::BufferArg>(te, "BufferArg")
      .def(py::init<Tensor>())
      .def(py::init<const VarHandle&>())
      .def(py::init<const BufHandle&>());

  py::implicitly_convertible<Tensor, CodeGen::BufferArg>();
  py::implicitly_convertible<VarHandle, CodeGen::BufferArg>();
  py::implicitly_convertible<BufHandle, CodeGen::BufferArg>();

  te.def(
      "construct_codegen",
      [](const std::string& name,
         const StmtPtr& stmt,
         const std::vector<CodeGen::BufferArg>& args) {
        CodeGen* cg = nullptr;
        if (name == "llvm") {
#ifdef TORCH_ENABLE_LLVM
          cg = new LLVMCodeGen(stmt, args);
#else
          throw std::runtime_error("PyTorch not compiled with LLVM support!");
#endif
        } else if (name == "cuda") {
#ifdef USE_CUDA
          cg = new CudaCodeGen(stmt, args);
#else
          throw std::runtime_error("PyTorch not compiled with CUDA support!");
#endif
        } else if (name == "ir_eval") {
          cg = new SimpleIREvaluator(stmt, args);
        } else {
          throw std::runtime_error(
              "construct_codegen() expects 'llvm', 'cuda', or 'ir_eval'");
        }
        return cg;
      });
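
  // End-to-end sketch (illustrative): compile a statement with the
  // interpreter backend and execute it on real tensors.
  //
  //   cg = te.construct_codegen(
  //       "ir_eval", stmt, [te.BufferArg(A), te.BufferArg(B)])
  //   cg.call([a, b])   # torch tensors (and ints for scalar vars)
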
  te.def("annotate_input_shapes", &tensorexpr::annotateInputShapes);
  te.def("remove_unused_self_argument", &tensorexpr::removeUnusedSelfArgument);
  te.def("make_shapes_symbolic", &tensorexpr::makeShapesSymbolic);
  te.def("is_graph_compilable", &tensorexpr::isGraphCompilable);
  te.def("fixup_missing_shape_info", &tensorexpr::fixupMissingShapeInfo);
  te.def("remove_graph_output", &tensorexpr::removeGraphOutput);
  te.def(
      "replace_list_output_with_tuple",
      &tensorexpr::replaceListOutputWithTuple);
  te.def("trim_graph", &tensorexpr::trimGraph);
#ifdef TORCH_ENABLE_LLVM
  te.def("set_llvm_target_triple", [](const std::optional<std::string>& val) {
    tensorexpr::LLVMTargetTriple() = val;
  });
  te.def("set_llvm_target_cpu", [](const std::optional<std::string>& val) {
    tensorexpr::LLVMTargetCPU() = val;
  });
  te.def("set_llvm_target_attrs", [](const std::optional<std::string>& val) {
    tensorexpr::LLVMTargetAttrs() = val;
  });
  te.def("set_llvm_aot_workflow", [](bool val) {
    tensorexpr::LLVMAOTWorkflow() = val;
  });
#endif
}

} // namespace torch::jit