#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
#include <torch/csrc/utils/pybind.h>
#ifdef USE_CUDA
#include <torch/csrc/jit/tensorexpr/cuda_codegen.h>
#endif
#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/lowerings.h>
#include <torch/csrc/jit/tensorexpr/reduction.h>

#include <utility>

template <>
struct pybind11::detail::type_caster<torch::jit::tensorexpr::ArgValue>
    : public type_caster_base<torch::jit::tensorexpr::ArgValue> {};

namespace torch::jit {
using namespace torch::jit::tensorexpr;
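
// Converts a Python argument into the ArgValue variant consumed by the NNC
// lowering functions. Supported inputs: BufHandle, VarHandle, bool, float,
// int, None, and homogeneous lists of ints or BufHandles.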
ArgValue convertPyToArgValue(py::handle inp) {
  if (py::isinstance<BufHandle>(inp)) {
    return py::cast<BufHandle>(inp);
  } else if (py::isinstance<VarHandle>(inp)) {
    return py::cast<VarHandle>(inp);
  } else if (py::isinstance<py::bool_>(inp)) {
    return py::cast<bool>(inp);
  } else if (py::isinstance<py::float_>(inp)) {
    return py::cast<double>(inp);
  } else if (py::isinstance<py::int_>(inp)) {
    return py::cast<int64_t>(inp);
  } else if (py::isinstance<py::none>(inp)) {
    return ArgNone();
  } else if (py::isinstance<py::list>(inp)) {
    auto l = py::cast<py::list>(inp);
    if (l.empty()) {
      return std::vector<BufHandle>();
    } else if (py::isinstance<py::int_>(l[0])) {
      return py::cast<IntList>(inp);
    } else if (py::isinstance<BufHandle>(l[0])) {
      return py::cast<BufList>(inp);
    } else {
      throw std::runtime_error("vector conversion failed");
    }
  } else {
    throw std::runtime_error("conversion not yet implemented");
  }
}
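
// Parses a torch.dtype object into the corresponding tensorexpr Dtype.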
Dtype parsePythonDtype(py::handle obj) {
  if (THPDtype_Check(obj.ptr())) {
    return Dtype(reinterpret_cast<THPDtype*>(obj.ptr())->scalar_type);
  } else {
    throw std::runtime_error("expected a torch.dtype instance");
  }
}
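
// Registers the tensor expression ("_te") submodule and all of its Python
// bindings on the given module object.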
void initTensorExprBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  // Tensor Expr Classes
  auto te = m.def_submodule("_te");

  auto dtype_class =
      py::class_<Dtype>(te, "Dtype").def(py::init(&parsePythonDtype));
  py::implicitly_convertible<py::object, Dtype>();

#define DTYPE_SINGLETON_ACCESSOR(ctype, name) \
  dtype_class.def_property_readonly_static(   \
      #name, [](const py::object&) { return k##name; });
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DTYPE_SINGLETON_ACCESSOR)
#undef DTYPE_SINGLETON_ACCESSOR
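
  // ExprHandle is a scalar expression handle: arithmetic, comparison, and
  // bitwise operators are overloaded, and the usual unary math functions are
  // exposed as methods (and, further below, as free functions on the module).
  // A rough Python sketch, assuming the submodule is exposed as torch._C._te:
  //
  //   from torch._C import _te as te
  //   x = te.VarHandle("x", te.Dtype.Int)
  //   one = te.ExprHandle.int(1)
  //   e = (x + one).sin()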
  auto expr_handle_class =
      py::class_<ExprHandle>(te, "ExprHandle")
          .def(
              "__str__",
              [](const ExprHandle& self) {
                std::stringstream ss;
                ss << self;
                return ss.str();
              })
          .def(py::self + py::self)
          .def(py::self * py::self)
          .def(py::self - py::self)
          .def(py::self / py::self)
          .def(py::self % py::self)
          .def(py::self == py::self)
          .def(py::self != py::self)
          .def(py::self > py::self)
          .def(py::self >= py::self)
          .def(py::self < py::self)
          .def(py::self <= py::self)
          .def(py::self & py::self)
          .def(py::self | py::self)
          .def(py::self ^ py::self)
          .def(py::self << py::self)
          .def(py::self >> py::self)
          .def(
              "__pow__",
              [](const ExprHandle& self, const ExprHandle& other) {
                return pow(self, other);
              })
          .def("sin", [](const ExprHandle& self) { return sin(self); })
          .def("cos", [](const ExprHandle& self) { return cos(self); })
          .def("tan", [](const ExprHandle& self) { return tan(self); })
          .def("asin", [](const ExprHandle& self) { return asin(self); })
          .def("acos", [](const ExprHandle& self) { return acos(self); })
          .def("atan", [](const ExprHandle& self) { return atan(self); })
          .def("sinh", [](const ExprHandle& self) { return sinh(self); })
          .def("cosh", [](const ExprHandle& self) { return cosh(self); })
          .def("tanh", [](const ExprHandle& self) { return tanh(self); })
          .def("sigmoid", [](const ExprHandle& self) { return sigmoid(self); })
          .def("exp", [](const ExprHandle& self) { return exp(self); })
          .def("expm1", [](const ExprHandle& self) { return expm1(self); })
          .def(
              "abs",
              [](const ExprHandle& self) { return tensorexpr::abs(self); })
          .def("log", [](const ExprHandle& self) { return log(self); })
          .def(
              "fast_tanh",
              [](const ExprHandle& self) { return fast_tanh(self); })
          .def(
              "fast_sigmoid",
              [](const ExprHandle& self) { return fast_sigmoid(self); })
          .def(
              "fast_log",
              [](const ExprHandle& self) { return fast_log(self); })
          .def("log_vml", [](const ExprHandle& self) { return log_vml(self); })
          .def("log2", [](const ExprHandle& self) { return log2(self); })
          .def("log10", [](const ExprHandle& self) { return log10(self); })
          .def("log1p", [](const ExprHandle& self) { return log1p(self); })
          .def("erf", [](const ExprHandle& self) { return erf(self); })
          .def("erfc", [](const ExprHandle& self) { return erfc(self); })
          .def(
              "sqrt",
              [](const ExprHandle& self) { return tensorexpr::sqrt(self); })
          .def("rsqrt", [](const ExprHandle& self) { return rsqrt(self); })
          .def("ceil", [](const ExprHandle& self) { return ceil(self); })
          .def("floor", [](const ExprHandle& self) { return floor(self); })
          .def("round", [](const ExprHandle& self) { return round(self); })
          .def("trunc", [](const ExprHandle& self) { return trunc(self); })
          .def("frac", [](const ExprHandle& self) { return frac(self); })
          .def("lgamma", [](const ExprHandle& self) { return lgamma(self); })
          .def("isnan", [](const ExprHandle& self) { return isnan(self); })
          .def(
              "cast",
              [](const ExprHandle& self, const Dtype& dt) {
                return Cast::make(dt, self);
              })
#define EXPRHANDLE_INIT(ctype, name) \
  .def(py::init([](ctype val) { return name##Imm::make(val); }))
              AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_INIT)
#undef EXPRHANDLE_INIT
      ;

#define EXPRHANDLE_IMPL_CONV(ctype, name) \
  py::implicitly_convertible<ctype, ExprHandle>();
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_IMPL_CONV)
#undef EXPRHANDLE_IMPL_CONV
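
  // The same math helpers are also exposed as free functions on the submodule
  // (e.g. te.sin(x), te.pow(x, y)), plus ifThenElse for conditional selects.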
  te.def(
      "ifThenElse",
      [](const ExprHandle& c, const ExprHandle& t, const ExprHandle& f) {
        return ifThenElse(c, t, f);
      });

  te.def("sin", [](const ExprHandle& v1) { return sin(v1); });
  te.def("cos", [](const ExprHandle& v1) { return cos(v1); });
  te.def("tan", [](const ExprHandle& v1) { return tan(v1); });
  te.def("asin", [](const ExprHandle& v1) { return asin(v1); });
  te.def("acos", [](const ExprHandle& v1) { return acos(v1); });
  te.def("atan", [](const ExprHandle& v1) { return atan(v1); });
  te.def("sinh", [](const ExprHandle& v1) { return sinh(v1); });
  te.def("cosh", [](const ExprHandle& v1) { return cosh(v1); });
  te.def("tanh", [](const ExprHandle& v1) { return tanh(v1); });
  te.def("sigmoid", [](const ExprHandle& v1) { return sigmoid(v1); });
  te.def("exp", [](const ExprHandle& v1) { return exp(v1); });
  te.def("expm1", [](const ExprHandle& v1) { return expm1(v1); });
  te.def("abs", [](const ExprHandle& v1) { return abs(v1); });
  te.def("log", [](const ExprHandle& v1) { return log(v1); });
  te.def("log2", [](const ExprHandle& v1) { return log2(v1); });
  te.def("log10", [](const ExprHandle& v1) { return log10(v1); });
  te.def("log1p", [](const ExprHandle& v1) { return log1p(v1); });
  te.def("erf", [](const ExprHandle& v1) { return erf(v1); });
  te.def("erfc", [](const ExprHandle& v1) { return erfc(v1); });
  te.def("sqrt", [](const ExprHandle& v1) { return sqrt(v1); });
  te.def("rsqrt", [](const ExprHandle& v1) { return rsqrt(v1); });
  te.def("ceil", [](const ExprHandle& v1) { return ceil(v1); });
  te.def("floor", [](const ExprHandle& v1) { return floor(v1); });
  te.def("round", [](const ExprHandle& v1) { return round(v1); });
  te.def("trunc", [](const ExprHandle& v1) { return trunc(v1); });
  te.def("frac", [](const ExprHandle& v1) { return frac(v1); });
  te.def("lgamma", [](const ExprHandle& v1) { return lgamma(v1); });
  te.def("isnan", [](const ExprHandle& v1) { return isnan(v1); });

  te.def("atan2", [](const ExprHandle& v1, const ExprHandle& v2) {
    return atan2(v1, v2);
  });
  te.def("pow", [](const ExprHandle& v1, const ExprHandle& v2) {
    return pow(v1, v2);
  });
  te.def("fmod", [](const ExprHandle& v1, const ExprHandle& v2) {
    return fmod(v1, v2);
  });
  te.def("remainder", [](const ExprHandle& v1, const ExprHandle& v2) {
    return remainder(v1, v2);
  });

#define EXPRHANDLE_CTOR(ctype, name) \
  expr_handle_class.def_static(#ctype, [](ctype v) { return ExprHandle(v); });
  AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, EXPRHANDLE_CTOR)
#undef EXPRHANDLE_CTOR
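
  // VarHandle names a loop/index variable; BufHandle names an N-dimensional
  // buffer and can build element-wise load/store expressions. A rough Python
  // sketch, assuming torch._C._te:
  //
  //   A = te.BufHandle("A", [te.ExprHandle.int(64)], te.Dtype.Float)
  //   B = te.BufHandle("B", [te.ExprHandle.int(64)], te.Dtype.Float)
  //   i = te.VarHandle("i", te.Dtype.Int)
  //   s = B.store([i], A.load([i]) + A.load([i]))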
  py::class_<VarHandle, ExprHandle>(te, "VarHandle")
      .def(
          "__str__",
          [](const ExprHandle& self) {
            std::stringstream ss;
            ss << self;
            return ss.str();
          })
      .def(py::init<Dtype>())
      .def(py::init<const std::string&, Dtype>());
  py::class_<BufHandle, ExprHandle>(te, "BufHandle")
      .def(
          py::init<const std::string&, const std::vector<ExprHandle>&, Dtype>())
      .def(py::init<const std::vector<ExprHandle>&, Dtype>())
      .def(py::init<Dtype>())
      .def(
          "__hash__",
          [](const BufHandle& self) {
            return std::hash<BufPtr>()(self.node());
          })
      .def(
          "__eq__",
          [](const BufHandle& self, const BufHandle& other) {
            return self.node() == other.node();
          })
      .def(
          "load",
          [](BufHandle& self, const std::vector<ExprHandle>& v) {
            return Load::make(self, v);
          })
      .def(
          "load",
          [](BufHandle& self, const ExprHandle& v) {
            return Load::make(self, {v});
          })
      .def(
          "store",
          [](BufHandle& self,
             const std::vector<ExprHandle>& i,
             const ExprHandle& v) { return Store::make(self, i, v); })
      .def(
          "store",
          [](BufHandle& self, const ExprHandle& i, const ExprHandle& v) {
            return Store::make(self, {i}, v);
          });
  py::class_<Tensor>(te, "Tensor")
      .def(py::init([](const BufHandle& b, const StmtPtr& s) {
        return Tensor(b.node(), s);
      }))
      .def(
          "load",
          [](Tensor& self, const std::vector<ExprHandle>& v) {
            return self.load(v);
          })
      .def("buf", [](Tensor& self) { return BufHandle(self.buf()); })
      .def("stmt", &Tensor::stmt);
  py::class_<Cast, std::shared_ptr<Cast>>(te, "Cast")
      .def_static("make", &Cast::make)
      .def(
          "src_value",
          [](CastPtr& self) { return ExprHandle(self->src_value()); })
      .def("set_src_value", [](CastPtr& self, const ExprHandle& value) {
        self->set_src_value(value.node());
      });
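
  // te.Compute builds a Tensor from a Python callable over one to four index
  // variables; te.Compute2 passes the callable a list of indices and therefore
  // supports arbitrary rank. A rough Python sketch, assuming torch._C._te:
  //
  //   def body(i):
  //       return A.load([i]) + B.load([i])
  //   C = te.Compute("C", [te.ExprHandle.int(64)], body)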
  te.def(
      "Compute",
      [](const std::string& func_name,
         const std::vector<ExprHandle>& dim_args,
         const py::function& func) {
        if (dim_args.size() == 1) {
          return Compute(func_name, dim_args, [&func](const VarHandle& a) {
            return py::cast<ExprHandle>(func(a));
          });
        } else if (dim_args.size() == 2) {
          return Compute(
              func_name,
              dim_args,
              [&func](const VarHandle& a, const VarHandle& b) {
                return py::cast<ExprHandle>(func(a, b));
              });
        } else if (dim_args.size() == 3) {
          return Compute(
              func_name,
              dim_args,
              [&func](
                  const VarHandle& a, const VarHandle& b, const VarHandle& c) {
                return py::cast<ExprHandle>(func(a, b, c));
              });
        } else if (dim_args.size() == 4) {
          return Compute(
              func_name,
              dim_args,
              [&func](
                  const VarHandle& a,
                  const VarHandle& b,
                  const VarHandle& c,
                  const VarHandle& d) {
                return py::cast<ExprHandle>(func(a, b, c, d));
              });
        } else {
          throw std::runtime_error("Too many args");
        }
      },
      py::return_value_policy::reference);

  te.def(
      "Compute2",
      [](const std::string& func_name,
         const std::vector<ExprHandle>& dim_args,
         const py::function& func) {
        return Compute(
            func_name, dim_args, [&func](const std::vector<VarHandle>& dims) {
              return py::cast<ExprHandle>(func(dims));
            });
      },
      py::return_value_policy::reference);
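
  // Reducers (Sum, Maximum, or a custom Reducer built from an initial value
  // and a binary combiner) feed the te.Reduce overloads, which reduce a
  // Tensor, a BufHandle, or a body function over the given reduction axes.
  // A rough Python sketch, assuming torch._C._te:
  //
  //   s = te.Reduce("s", [], te.Sum(), C, [te.ExprHandle.int(64)])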
  py::class_<Reducer>(te, "Reducer")
      .def(py::init<
           ExprHandle,
           std::function<ExprHandle(ExprHandle, ExprHandle)>>());

  py::class_<Sum, Reducer>(te, "Sum").def(py::init<>());
  py::class_<Maximum, Reducer>(te, "Maximum").def(py::init<Dtype>());
  te.def(
      "Reduce",
      [](const std::string& func_name,
         const std::vector<ExprHandle>& dim_args,
         const Reducer& reducer,
         const Tensor& buffer,
         const std::vector<ExprHandle>& reduce_args) {
        return Reduce(func_name, dim_args, reducer, buffer, reduce_args);
      },
      py::return_value_policy::reference);

  te.def(
      "Reduce",
      [](const std::string& func_name,
         const std::vector<ExprHandle>& dim_args,
         const Reducer& reducer,
         const BufHandle& buffer,
         const std::vector<ExprHandle>& reduce_args) {
        return Reduce(func_name, dim_args, reducer, buffer, reduce_args);
      },
      py::return_value_policy::reference);
  te.def(
      "Reduce",
      [](const std::string& func_name,
         const std::vector<ExprHandle>& dim_args,
         const Reducer& reducer,
         const std::function<ExprHandle(const std::vector<VarHandle>&)>&
             body_func,
         const std::vector<ExprHandle>& reduce_args) {
        return Reduce(func_name, dim_args, reducer, body_func, reduce_args);
      },
      py::return_value_policy::reference);
  te.def(
      "Reduce",
      [](const std::string& func_name,
         const std::vector<ExprHandle>& dim_args,
         const Reducer& reducer,
         const std::function<ExprHandle(const std::vector<VarHandle>&)>&
             init_func,
         const std::function<ExprHandle(const std::vector<VarHandle>&)>&
             body_func,
         const std::vector<ExprHandle>& reduce_args) {
        return Reduce(func_name, dim_args, reducer, body_func, reduce_args);
      },
      py::return_value_policy::reference);
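
  // Statement (Stmt) nodes: Store, For, Cond, Block, and ExternalCall form the
  // imperative IR that LoopNest manipulates and the codegen backends consume.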
  py::class_<Stmt, std::shared_ptr<Stmt>>(te, "Stmt")
      .def(py::init([](const std::vector<StmtPtr>& stmts) {
        return tensorexpr::Block::make(stmts);
      }))
      .def("__str__", [](Stmt& self) {
        std::stringstream ss;
        ss << self;
        return ss.str();
      });
  py::class_<Store, Stmt, std::shared_ptr<Store>>(te, "Store")
      .def_static(
          "make",
          [](const BufHandle& buf,
             std::vector<ExprHandle>& indices,
             const ExprHandle& value) {
            return Store::make(buf, indices, value);
          });

  py::class_<For, Stmt, std::shared_ptr<For>>(te, "For")
      .def("index_var", [](For& self) { return VarHandle(self.var()); })
      .def("body", &For::body)
      .def("set_parallel", &For::set_parallel)
      .def(
          "set_gpu_block_index",
          [](For& self, int block_index) {
            self.set_gpu_block_index(block_index);
          })
      .def(
          "set_gpu_thread_index",
          [](For& self, int thread_index) {
            self.set_gpu_thread_index(thread_index);
          })
      .def_static(
          "make",
          [](const VarHandle& var,
             const ExprHandle& start,
             const ExprHandle& stop,
             const StmtPtr& body) {
            return For::make(var, start, stop, body);
          });

  py::class_<Cond, Stmt, std::shared_ptr<Cond>>(te, "Cond")
      .def_static(
          "make",
          [](const ExprHandle& condition,
             const StmtPtr& true_stmt,
             const StmtPtr& false_stmt) {
            return Cond::make(condition, true_stmt, false_stmt);
          })
      .def("true_stmt", &Cond::true_stmt)
      .def("false_stmt", &Cond::false_stmt);

  py::class_<tensorexpr::Block, Stmt, std::shared_ptr<tensorexpr::Block>>(
      te, "Block")
      .def(py::init([](const std::vector<StmtPtr>& stmts) {
        return tensorexpr::Block::make(stmts);
      }))
      .def("stmts", &tensorexpr::Block::stmts);
  py::class_<ExternalCall, Stmt, std::shared_ptr<ExternalCall>>(
      te, "ExternalCall")
      .def(py::init(&ExternalCall::make));
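
  // LoopNest wraps the root statement of one or more Tensors and exposes the
  // scheduling transforms (split, fuse, reorder, tile, vectorize, inline,
  // rfactor, ...). A rough Python sketch, assuming torch._C._te:
  //
  //   nest = te.LoopNest([C])
  //   nest.prepare_for_codegen()
  //   stmt = te.simplify(nest.root_stmt())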
  py::class_<LoopNest>(te, "LoopNest")
      .def(py::init<const std::vector<Tensor>&>())
      .def(py::init<const std::vector<Tensor>&, const std::vector<Tensor>&>())
      .def(py::init([](const StmtPtr& s, const std::vector<BufHandle>& bufs) {
        std::unordered_set<BufPtr> buf_nodes;
        buf_nodes.reserve(bufs.size());
        for (auto& buf : bufs) {
          buf_nodes.insert(buf.node());
        }
        return std::make_unique<LoopNest>(s, buf_nodes);
      }))
      .def("vectorize_inner_loops", &LoopNest::vectorizeInnerLoops)
      .def(
          "prepare_for_codegen",
          [](LoopNest& self) { return self.prepareForCodegen(); },
          py::return_value_policy::reference)
      .def(
          "get_loop_body_for",
          [](const LoopNest& self, const Tensor& t) {
            return self.getLoopBodyFor(t);
          },
          py::return_value_policy::reference)
      .def(
          "get_loop_body_for",
          [](const LoopNest& self, BufHandle& b) {
            return self.getLoopBodyFor(b.node());
          },
          py::return_value_policy::reference)
      .def(
          "get_loops_for",
          [](const LoopNest& self, const Tensor& t) {
            return self.getLoopStmtsFor(t);
          },
          py::return_value_policy::reference)
      .def(
          "get_all_loopnests_for",
          [](const LoopNest& self, const BufHandle& b) {
            return self.getAllLoopNestsWritingToBuf(b.node());
          },
          py::return_value_policy::reference)
      .def(
          "get_enclosing_loopnest",
          [](const LoopNest& self, const StmtPtr& s) {
            return self.getEnclosingLoopNest(s);
          },
          py::return_value_policy::reference)
      .def(
          "get_innermost_loops_for",
          [](const LoopNest& self, const BufHandle& b) {
            return self.getAllInnermostLoopsWritingToBuf(b.node());
          },
          py::return_value_policy::reference)
      .def(
          "get_writes_for",
          [](const LoopNest& self, const BufHandle& b) {
            return self.getAllWritesToBuf(b.node());
          },
          py::return_value_policy::reference)
      .def(
          "get_loop_at",
          [](const LoopNest& self,
             ForPtr root,
             const std::vector<int>& indices) {
            return self.getLoopAt(std::move(root), indices);
          },
          py::return_value_policy::reference)
      .def(
          "get_parent_loop",
          [](const LoopNest& self, const StmtPtr& s) {
            return self.getParentLoop(s);
          },
          py::return_value_policy::reference)
      .def_static(
          "get_loop_stmts_in_loopnest",
          [](const ForPtr& f, size_t num) {
            return LoopNest::getLoopStmtsInLoopNest(f, num);
          },
          py::return_value_policy::reference)
      .def(
          "split_with_tail",
          [](const ForPtr& f, int factor) {
            ForPtr inner = nullptr, tail = nullptr;
            LoopNest::splitWithTail(f, factor, &inner, &tail);
            return std::make_tuple(std::move(inner), std::move(tail));
          },
          py::return_value_policy::reference)
      .def(
          "split_with_mask",
          [](const ForPtr& f, int factor) {
            ForPtr inner = nullptr;
            LoopNest::splitWithMask(f, factor, &inner);
            return inner;
          },
          py::return_value_policy::reference)
      .def(
          "slice_head",
          [](const ForPtr& f, int factor) {
            ForPtr head = nullptr, tail = nullptr;
            LoopNest::sliceHead(f, factor, &head, &tail);
            return std::make_tuple(std::move(head), std::move(tail));
          },
          py::return_value_policy::reference)
      .def(
          "slice_tail",
          [](const ForPtr& f, int factor) {
            ForPtr head = nullptr, tail = nullptr;
            LoopNest::sliceTail(f, factor, &head, &tail);
            return std::make_tuple(std::move(head), std::move(tail));
          },
          py::return_value_policy::reference)
      .def_static(
          "normalize",
          [](const ForPtr& f) {
            LoopNest::normalize(f);
            return f;
          },
          py::return_value_policy::reference)
      .def(
          "tile",
          [](LoopNest& self,
             const ForPtr& x,
             const ForPtr& y,
             int x_factor,
             int y_factor) { return self.tile(x, y, x_factor, y_factor); },
          py::return_value_policy::reference)
      .def_static(
          "distribute_loop",
          [](const ForPtr& f) { return LoopNest::distributeLoop(f); },
          py::return_value_policy::reference)
      .def_static(
          "distribute_loop",
          [](const ForPtr& f, const std::unordered_set<StmtPtr>& pivots) {
            return LoopNest::distributeLoop(f, pivots);
          },
          py::return_value_policy::reference)
      .def_static(
          "distribute_loop_over_inner_loops",
          [](const ForPtr& f) {
            return LoopNest::distributeLoopOverInnerLoops(f);
          },
          py::return_value_policy::reference)
      .def_static(
          "unsafe_fuse_loops",
          [](const std::vector<ForPtr>& loops) {
            ForPtr fused_loop = nullptr;
            LoopNest::unsafeFuseLoops(loops, &fused_loop);
            return fused_loop;
          },
          py::return_value_policy::reference)
      .def_static(
          "fuse_loops",
          [](const std::vector<ForPtr>& loops) {
            ForPtr fused_loop = nullptr;
            LoopNest::fuseLoops(loops, &fused_loop);
            return fused_loop;
          },
          py::return_value_policy::reference)
      .def_static(
          "reorder",
          [](const std::vector<ForPtr>& loops,
             const std::vector<size_t>& permutation) {
            return LoopNest::reorder(loops, permutation);
          },
          py::return_value_policy::reference)
      .def(
          "fullUnroll",
          [](const ForPtr& f) {
            StmtPtr unrolled = nullptr;
            LoopNest::fullUnroll(f, &unrolled);
            return unrolled;
          },
          py::return_value_policy::reference)
      .def(
          "unroll",
          [](const ForPtr& f, int factor) {
            LoopNest::unroll(f, factor);
            return f;
          },
          py::return_value_policy::reference)
      .def(
          "vectorize",
          [](const ForPtr& f) { LoopNest::vectorize(f); },
          py::return_value_policy::reference)
      .def_static(
          "compress_buffer",
          [](BufHandle& buf, const StmtPtr& stmt) {
            return LoopNest::compressBuffer(buf.node(), stmt);
          },
          py::return_value_policy::reference)
      .def_static(
          "cache_accesses",
          [](const BufHandle& producer,
             const std::string& name,
             const StmtPtr& consumer) {
            std::pair<BufPtr, StmtPtr> ret =
                LoopNest::cacheAccesses(producer.node(), name, consumer);
            return std::make_pair(BufHandle(ret.first), ret.second);
          },
          py::return_value_policy::reference)
      .def_static(
          "compute_at",
          [](const StmtPtr& s, const ForPtr& at) {
            LoopNest::computeAt(s, at);
          })
      .def(
          "compute_inline",
          [](LoopNest& self, const StmtPtr& s) { self.computeInline(s); },
          py::return_value_policy::reference)
      .def(
          "compute_inline",
          [](LoopNest& self, const BufHandle& b) {
            self.computeInline(b.node());
          },
          py::return_value_policy::reference)
      .def(
          "rfactor",
          [](const StmtPtr& s, const ForPtr& target_for) {
            BufPtr rfac_buf = nullptr;
            LoopNest::rfactor(s, target_for, &rfac_buf);
            return BufHandle(rfac_buf);
          },
          py::return_value_policy::reference)
      .def(
          "flatten",
          [](LoopNest& self, const std::vector<ForPtr>& loops) {
            ForPtr flattened = nullptr;
            LoopNest::flatten(loops, &flattened);
            return flattened;
          },
          py::return_value_policy::reference)
      .def(
          "reorder_axis",
          &LoopNest::reorderAxis,
          py::return_value_policy::reference)
      .def("simplify", &LoopNest::simplify, py::return_value_policy::reference)
      .def_static("sanitize_names", &LoopNest::sanitizeNames)
      .def(
          "inline_intermediate_bufs",
          [](LoopNest& self, bool allow_duplicated_work) {
            self.inlineIntermediateBufs(allow_duplicated_work);
          })
      .def(
          "eliminate_dead_stores",
          [](LoopNest& self) { self.eliminateDeadStores(); })
      .def(
          "__str__",
          [](const LoopNest& self) {
            std::stringstream ss;
            ss << *self.root_stmt();
            return ss.str();
          })
      .def(
          "root_stmt",
          &LoopNest::root_stmt,
          py::return_value_policy::reference);
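
  // te.simplify runs the IR simplifier on a statement. te.lower maps an aten
  // operator name plus ArgValue-convertible inputs through the standard NNC
  // lowering for that op and returns the lowered Tensor.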
  te.def(
      "simplify",
      [](const StmtPtr& stmt) { return IRSimplifier::simplify(stmt); },
      py::return_value_policy::reference);

  te.def(
      "lower",
      [](const std::string& op_str,
         const py::list& inputs,
         const std::vector<ExprHandle>& outputShape,
         Dtype outputType) {
        auto op = c10::Symbol::fromQualString(op_str);
        std::vector<ArgValue> argInputs;
        for (auto inp : inputs) {
          argInputs.push_back(convertPyToArgValue(inp));
        }
        if (NNCLoweringFunction lowering =
                getStandardLoweringFor(op.toQualString())) {
          std::vector<ExprHandle> outputStrides =
              c10::fmap<ExprHandle>(make_channels_last_strides(outputShape));
          return lowering(
              argInputs,
              outputShape,
              outputStrides,
              outputType.scalar_type(),
              at::kCPU);
        }
        std::string msg = std::string("Unhandled node kind (in te.lower): ") +
            op.toQualString();
        throw malformed_input(msg);
      });

  py::class_<ArgValue>(te, "ArgValue")
      .def(py::init([](py::handle inp) {
        return std::make_unique<ArgValue>(convertPyToArgValue(inp));
      }))
      .def(
          "as_buf",
          [](const ArgValue& self) { return std::get<BufHandle>(self); })
      .def(
          "as_var",
          [](const ArgValue& self) { return std::get<VarHandle>(self); })
      .def(
          "as_float",
          [](const ArgValue& self) { return std::get<double>(self); })
      .def(
          "as_int",
          [](const ArgValue& self) { return std::get<int64_t>(self); })
      .def("as_bool", [](const ArgValue& self) { return std::get<bool>(self); })
      .def(
          "as_none",
          [](const ArgValue& self) { return std::get<ArgNone>(self); })
      .def(
          "as_buflist",
          [](const ArgValue& self) { return std::get<BufList>(self); })
      .def("as_intlist", [](const ArgValue& self) {
        return std::get<IntList>(self);
      });

  py::class_<c10::ScalarType> give_me_a_name(te, "ScalarType");
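
  // TensorExprKernel compiles a TorchScript graph with NNC; run() executes the
  // compiled kernel and fallback() routes through the interpreter instead. A
  // rough Python sketch, assuming torch._C._te and a graph that already
  // carries complete shape/dtype information (see annotate_input_shapes
  // below):
  //
  //   kernel = te.TensorExprKernel(scripted_fn.graph)
  //   out = kernel.run((a, b))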
  using TSGraph = std::shared_ptr<Graph>;
  py::class_<TensorExprKernel>(te, "TensorExprKernel")
      .def(py::init<const TSGraph&>())
      .def(
          py::init(
              [](const TSGraph& g,
                 const std::unordered_map<std::string, NNCLoweringFunction>&
                     custom_lowerings_str,
                 std::vector<int64_t> symbolic_shape_inputs,
                 bool pre_alloc = false) {
                std::unordered_map<c10::Symbol, NNCLoweringFunction>
                    custom_lowerings;
                custom_lowerings.reserve(custom_lowerings_str.size());
                for (auto& kv : custom_lowerings_str) {
                  custom_lowerings[c10::Symbol::fromQualString(kv.first)] =
                      kv.second;
                }
                return std::make_unique<TensorExprKernel>(
                    g,
                    std::move(custom_lowerings),
                    std::move(symbolic_shape_inputs),
                    pre_alloc);
              }),
          py::arg("g"),
          py::arg("custom_lowerings_str"),
          py::arg("symbolic_shape_inputs") = std::vector<int64_t>(),
          py::arg("pre_alloc") = false)
      .def(
          "run",
          [](TensorExprKernel& self, const py::tuple& inputs) {
            Stack stack;
            stack.reserve(inputs.size()); // captures?
            for (auto& obj : inputs) {
              stack.push_back(toTypeInferredIValue(obj));
            }
            auto g_inputs = self.graph()->inputs();
            for (size_t i = 0; i < inputs.size(); ++i) {
              if (stack[i].isTensor()) {
                g_inputs[i]->setType(stack[i].type());
              }
            }
            self.run(stack);
            return createPyObjectForStack(std::move(stack));
          })
      .def(
          "fallback",
          [](TensorExprKernel& self, const py::tuple& inputs) {
            Stack stack;
            stack.reserve(inputs.size()); // captures?
            for (auto& obj : inputs) {
              stack.push_back(toTypeInferredIValue(obj));
            }
            auto g_inputs = self.graph()->inputs();
            for (size_t i = 0; i < inputs.size(); ++i) {
              if (stack[i].isTensor()) {
                g_inputs[i]->setType(stack[i].type());
              }
            }
            self.fallback(stack);
            return createPyObjectForStack(std::move(stack));
          })
      .def(
          "get_codegen_stmt",
          [](TensorExprKernel& self) { return self.getCodeGenStmt(); },
          py::return_value_policy::reference)
      .def(
          "get_code_text",
          [](TensorExprKernel& self, const std::string& attr = "") {
            return self.getCodeText(attr);
          },
          py::arg("attr") = "")
      .def("recompile", [](TensorExprKernel& self) { self.recompile(); });
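
  // CodeGen is the base class of the executable backends. call() accepts
  // Python ints and tensors and marshals them into CallArgs, while call_raw()
  // accepts raw data pointers (e.g. tensor.data_ptr()) passed as Python ints.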
  py::class_<CodeGen>(te, "CodeGen")
      .def(
          "call",
          [](CodeGen& self, const py::sequence& values) {
            std::vector<CodeGen::CallArg> value_ptrs;
            value_ptrs.reserve(py::len(values));
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
            for (const auto& value : values) {
              if (py::isinstance<py::int_>(value)) {
                value_ptrs.emplace_back(value.cast<int64_t>());
              } else {
                value_ptrs.emplace_back(value.cast<at::Tensor>().data_ptr());
              }
            }
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
            if (py::len(values) != self.buffer_args().size()) {
              throw malformed_input("bad args in CodeGen.call function");
            }
            for (size_t i = 0; i < py::len(values); i++) {
              const auto& value = values[i];
              const auto& bufArg = self.buffer_args()[i];
              if (py::isinstance<py::int_>(value)) {
                if (!bufArg.isVar()) {
                  throw malformed_input(
                      "Integer variable expected in CodeGen.call function");
                }
                switch (bufArg.dtype().scalar_type()) {
#define TYPE_CASE(Type, Name)                    \
  case ScalarType::Name: {                       \
    value_ptrs.emplace_back(value.cast<Type>()); \
    break;                                       \
  }
                  AT_FORALL_INT_TYPES(TYPE_CASE);
                  default:
                    throw unsupported_dtype();
                }
              } else {
                value_ptrs.emplace_back(value.cast<at::Tensor>().data_ptr());
              }
            }
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
            self.call(value_ptrs);
          })
      .def(
          "call_raw",
          [](CodeGen& self, const py::sequence& values) {
            std::vector<void*> value_ptrs;
            value_ptrs.reserve(py::len(values));
            for (const auto& value : values) {
              // Tensor.data_ptr() returns an int in python
              value_ptrs.emplace_back(
                  reinterpret_cast<void*>(value.cast<intptr_t>()));
            }
            self.call_raw(value_ptrs);
          })
      .def(
          "get_code_text",
          [](CodeGen& self, const std::string& attr = "") {
            return self.getCodeText(attr);
          },
          py::arg("attr") = "");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<SimpleIREvaluator, CodeGen>(te, "SimpleIREvaluator");
#ifdef TORCH_ENABLE_LLVM
  py::class_<LLVMCodeGen, CodeGen>(te, "LLVMCodeGen");
#endif

  py::class_<CodeGen::BufferArg>(te, "BufferArg")
      .def(py::init<Tensor>())
      .def(py::init<const VarHandle&>())
      .def(py::init<const BufHandle&>());

  py::implicitly_convertible<Tensor, CodeGen::BufferArg>();
  py::implicitly_convertible<VarHandle, CodeGen::BufferArg>();
  py::implicitly_convertible<BufHandle, CodeGen::BufferArg>();
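
  // construct_codegen selects a backend by name ("llvm", "cuda", or "ir_eval")
  // and compiles a statement against the given buffer arguments. A rough
  // Python sketch, assuming torch._C._te:
  //
  //   cg = te.construct_codegen(
  //       "ir_eval", stmt, [te.BufferArg(A), te.BufferArg(B)])
  //   cg.call([tensor_a, tensor_b])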
  te.def(
      "construct_codegen",
      [](const std::string& name,
         const StmtPtr& stmt,
         const std::vector<CodeGen::BufferArg>& args) {
        CodeGen* cg = nullptr;
        if (name == "llvm") {
#ifdef TORCH_ENABLE_LLVM
          cg = new LLVMCodeGen(stmt, args);
#else
          throw std::runtime_error("PyTorch not compiled with LLVM support!");
#endif
        } else if (name == "cuda") {
#ifdef USE_CUDA
          cg = new CudaCodeGen(stmt, args);
#else
          throw std::runtime_error("PyTorch not compiled with CUDA support!");
#endif
        } else if (name == "ir_eval") {
          cg = new SimpleIREvaluator(stmt, args);
        } else {
          throw std::runtime_error(
              "construct_codegen() expects 'llvm', 'cuda', or 'ir_eval'");
        }
        return cg;
      });
  te.def("annotate_input_shapes", &tensorexpr::annotateInputShapes);
  te.def("remove_unused_self_argument", &tensorexpr::removeUnusedSelfArgument);
  te.def("make_shapes_symbolic", &tensorexpr::makeShapesSymbolic);
  te.def("is_graph_compilable", &tensorexpr::isGraphCompilable);
  te.def("fixup_missing_shape_info", &tensorexpr::fixupMissingShapeInfo);
  te.def("remove_graph_output", &tensorexpr::removeGraphOutput);
  te.def(
      "replace_list_output_with_tuple",
      &tensorexpr::replaceListOutputWithTuple);
  te.def("trim_graph", &tensorexpr::trimGraph);
#ifdef TORCH_ENABLE_LLVM
  te.def("set_llvm_target_triple", [](const std::optional<std::string>& val) {
    tensorexpr::LLVMTargetTriple() = val;
  });
  te.def("set_llvm_target_cpu", [](const std::optional<std::string>& val) {
    tensorexpr::LLVMTargetCPU() = val;
  });
  te.def("set_llvm_target_attrs", [](const std::optional<std::string>& val) {
    tensorexpr::LLVMTargetAttrs() = val;
  });
  te.def("set_llvm_aot_workflow", [](bool val) {
    tensorexpr::LLVMAOTWorkflow() = val;
  });
#endif
}

} // namespace torch::jit