1 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
2
3 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
4 #include <torch/csrc/jit/tensorexpr/reduction.h>
5 #include <torch/csrc/jit/tensorexpr/tensor.h>
6
7 #include <c10/util/irange.h>
8
9 #include <iostream>
10
11 namespace torch::jit::tensorexpr {
12
dtypeToCppString(const Dtype & dtype)13 std::string IRPrinter::dtypeToCppString(const Dtype& dtype) {
14 return dtype.ToCppString();
15 }
16
print(ExprHandle expr)17 void IRPrinter::print(ExprHandle expr) {
18 expr.node()->accept(this);
19 }
20
print(Expr & expr)21 void IRPrinter::print(Expr& expr) {
22 expr.accept(this);
23 }
24
print(Stmt & stmt)25 void IRPrinter::print(Stmt& stmt) {
26 stmt.accept(this);
27 }
to_string(CompareSelectOperation op)28 std::string IRPrinter::to_string(CompareSelectOperation op) {
29 switch (op) {
30 case CompareSelectOperation::kEQ:
31 return "==";
32 case CompareSelectOperation::kNE:
33 return "!=";
34 case CompareSelectOperation::kGT:
35 return ">";
36 case CompareSelectOperation::kGE:
37 return ">=";
38 case CompareSelectOperation::kLT:
39 return "<";
40 case CompareSelectOperation::kLE:
41 return "<=";
42 default:
43 throw std::runtime_error("invalid compare select operator");
44 }
45 }
46
47 // TODO: change whether to include the parenthesis to the parent expression,
48 // we need to look at the operator precedence to make the output simpler.
49 template <
50 typename Op,
51 std::enable_if_t<std::is_same_v<
52 decltype(detail::bin_op_deducer(std::declval<Op>())),
53 void>>* = nullptr>
visitBinaryOp(NodePtr<Op> v,const std::string & op_str,IRPrinter * printer,bool parens=true)54 void visitBinaryOp(
55 NodePtr<Op> v,
56 const std::string& op_str,
57 IRPrinter* printer,
58 bool parens = true) {
59 std::ostream& os = printer->os();
60 int self_prec = getPrecedence(v->expr_type());
61 int lhs_prec = getPrecedence(v->lhs()->expr_type());
62 int rhs_prec = getPrecedence(v->rhs()->expr_type());
63
64 if (lhs_prec >= self_prec) {
65 os << "(";
66 }
67 v->lhs()->accept(printer);
68 if (lhs_prec >= self_prec) {
69 os << ")";
70 }
71
72 os << " " << op_str << " ";
73
74 if (rhs_prec >= self_prec) {
75 os << "(";
76 }
77 v->rhs()->accept(printer);
78 if (rhs_prec >= self_prec) {
79 os << ")";
80 }
81 }
82
visit(const AddPtr & v)83 void IRPrinter::visit(const AddPtr& v) {
84 visitBinaryOp(v, "+", this);
85 }
86
visit(const SubPtr & v)87 void IRPrinter::visit(const SubPtr& v) {
88 visitBinaryOp(v, "-", this);
89 }
90
visit(const MulPtr & v)91 void IRPrinter::visit(const MulPtr& v) {
92 visitBinaryOp(v, "*", this);
93 }
94
visit(const DivPtr & v)95 void IRPrinter::visit(const DivPtr& v) {
96 visitBinaryOp(v, "/", this);
97 }
98
visit(const AndPtr & v)99 void IRPrinter::visit(const AndPtr& v) {
100 visitBinaryOp(v, "&", this);
101 }
102
visit(const OrPtr & v)103 void IRPrinter::visit(const OrPtr& v) {
104 visitBinaryOp(v, "|", this);
105 }
106
visit(const XorPtr & v)107 void IRPrinter::visit(const XorPtr& v) {
108 visitBinaryOp(v, "^", this);
109 }
110
visit(const LshiftPtr & v)111 void IRPrinter::visit(const LshiftPtr& v) {
112 visitBinaryOp(v, "<<", this);
113 }
114
visit(const RshiftPtr & v)115 void IRPrinter::visit(const RshiftPtr& v) {
116 visitBinaryOp(v, ">>", this);
117 }
118
visit(const ModPtr & v)119 void IRPrinter::visit(const ModPtr& v) {
120 if (v->dtype().is_integral()) {
121 visitBinaryOp(v, "%", this);
122 } else if (v->dtype().is_floating_point()) {
123 os() << "mod(" << *v->lhs() << ", " << *v->rhs() << ")";
124 } else {
125 throw std::runtime_error("invalid dtype: " + std::to_string(v->dtype()));
126 }
127 }
128
visit(const MaxPtr & v)129 void IRPrinter::visit(const MaxPtr& v) {
130 os() << "Max(";
131 v->lhs()->accept(this);
132 os() << ", ";
133 v->rhs()->accept(this);
134 os() << ", " << (unsigned int)v->propagate_nans() << ")";
135 }
136
visit(const MinPtr & v)137 void IRPrinter::visit(const MinPtr& v) {
138 os() << "Min(";
139 v->lhs()->accept(this);
140 os() << ", ";
141 v->rhs()->accept(this);
142 os() << ", " << (unsigned int)v->propagate_nans() << ")";
143 }
144
visit(const CompareSelectPtr & v)145 void IRPrinter::visit(const CompareSelectPtr& v) {
146 CompareSelectOperation cmp_op = v->compare_select_op();
147 int self_prec = getPrecedence(v->expr_type());
148 int lhs_prec = getPrecedence(v->lhs()->expr_type());
149 int rhs_prec = getPrecedence(v->rhs()->expr_type());
150
151 if (lhs_prec >= self_prec) {
152 os() << "(";
153 }
154 v->lhs()->accept(this);
155 if (lhs_prec >= self_prec) {
156 os() << ")";
157 }
158
159 os() << to_string(cmp_op);
160
161 if (rhs_prec >= self_prec) {
162 os() << "(";
163 }
164 v->rhs()->accept(this);
165 if (rhs_prec >= self_prec) {
166 os() << ")";
167 }
168 os() << " ? ";
169
170 auto withParens = [&](const ExprPtr& e) {
171 auto prec = getPrecedence(e->expr_type());
172 if (prec >= self_prec) {
173 os() << "(";
174 }
175 e->accept(this);
176 if (prec >= self_prec) {
177 os() << ")";
178 }
179 };
180 withParens(v->ret_val1());
181 os() << " : ";
182 withParens(v->ret_val2());
183 }
184
// Suffix for a printed double: ".0" when the value equals its own ceiling
// (i.e. it has no fractional part), otherwise nothing.
static void formatFPSuffix(std::ostream& os, double v) {
  const char* suffix = "";
  if (v == std::ceil(v)) {
    suffix = ".0";
  }
  os << suffix;
}

// Suffix for any other type routed here: ".f" when the value equals its own
// ceiling, otherwise "f".
template <typename T>
static void formatFPSuffix(std::ostream& os, T v) {
  const char* suffix = "f";
  if (v == std::ceil(v)) {
    suffix = ".f";
  }
  os << suffix;
}

// Prints a floating-point immediate.
// NaN prints as "NAN"; infinities print as "POS_INFINITY" / "NEG_INFINITY".
// Every other value is streamed with precision 16 and followed by the
// type-dependent suffix from formatFPSuffix. The precision set on `os` is not
// restored afterwards.
template <typename T, std::enable_if_t<std::is_floating_point_v<T>>* = nullptr>
static void formatImm(std::ostream& os, T v) {
  if (std::isnan(v)) {
    os << "NAN";
    return;
  }
  if (std::isinf(v)) {
    os << (v > 0 ? "POS_INFINITY" : "NEG_INFINITY");
    return;
  }
  constexpr int kDigits = 16;
  os << std::setprecision(kDigits) << v;
  formatFPSuffix(os, v);
}
206
// Suffix for a printed int64_t immediate: always "ll". The value itself is
// not inspected.
static void formatIntSuffix(std::ostream& os, int64_t v) {
  os << "ll";
}

// All other types get no suffix.
template <typename T>
static void formatIntSuffix(std::ostream& os, T v) {}

// Prints an immediate of any type that is not std::is_floating_point.
// Unary plus promotes the value before streaming, so narrow integer types and
// bool print as numbers rather than characters/booleans. The suffix chosen by
// formatIntSuffix is appended afterwards.
template <typename T, std::enable_if_t<!std::is_floating_point_v<T>>* = nullptr>
static void formatImm(std::ostream& os, T v) {
  const auto promoted = +v;
  os << promoted;
  formatIntSuffix(os, v);
}
219
// Generates one IRPrinter::visit overload per `<Name>ImmPtr` type supplied by
// AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ...). Each generated
// visitor prints the immediate's value() through formatImm. The macro is
// undefined again immediately after use.
// NOLINTNEXTLINE
#define IMM_PRINT_VISIT(Type, Name) \
  void IRPrinter::visit(const Name##ImmPtr& v) { \
    formatImm(os(), v->value()); \
  }
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT);
#undef IMM_PRINT_VISIT
227
visit(const CastPtr & v)228 void IRPrinter::visit(const CastPtr& v) {
229 auto dtype = v->dtype();
230 os() << dtypeToCppString(dtype) << "(";
231 v->src_value()->accept(this);
232 os() << ")";
233 }
234
visit(const BitCastPtr & v)235 void IRPrinter::visit(const BitCastPtr& v) {
236 auto dtype = v->dtype();
237 os() << "BitCast<" << dtype.ToCppString() << ">(";
238 v->src_value()->accept(this);
239 os() << ")";
240 }
241
visit(const VarPtr & v)242 void IRPrinter::visit(const VarPtr& v) {
243 os() << name_manager_.get_unique_name(v);
244 }
245
visit(const BufPtr & v)246 void IRPrinter::visit(const BufPtr& v) {
247 auto dtype = v->dtype();
248 os() << *v->base_handle();
249 os() << "(dtype=" << dtypeToCppString(dtype);
250 if (v->qscale()) {
251 os() << ", qscale=";
252 v->qscale()->accept(this);
253 }
254 if (v->qscale()) {
255 os() << ", qzero=";
256 v->qzero()->accept(this);
257 }
258 os() << ", sizes=[";
259 size_t i = 0;
260 for (const ExprPtr& s : v->dims()) {
261 if (i++) {
262 os() << ", ";
263 }
264 s->accept(this);
265 }
266 os() << "]";
267 os() << ", strides=[";
268 i = 0;
269 for (const ExprPtr& s : v->strides()) {
270 if (i++) {
271 os() << ", ";
272 }
273 s->accept(this);
274 }
275 os() << "]";
276
277 os() << ")";
278 }
279
visit(const RampPtr & v)280 void IRPrinter::visit(const RampPtr& v) {
281 os() << "Ramp(" << *v->base() << ", " << *v->stride() << ", " << v->lanes()
282 << ")";
283 }
284
visit(const LoadPtr & v)285 void IRPrinter::visit(const LoadPtr& v) {
286 // TODO: support the mask case
287 if (v->indices().empty()) {
288 os() << *v->base_handle();
289 } else {
290 os() << *v->base_handle() << "[";
291 size_t i = 0;
292 for (const ExprPtr& ind : v->indices()) {
293 if (i++) {
294 os() << ", ";
295 }
296 ind->accept(this);
297 }
298 if (v->indices().empty()) {
299 os() << "0";
300 }
301 os() << "]";
302 }
303 }
304
visit(const BroadcastPtr & v)305 void IRPrinter::visit(const BroadcastPtr& v) {
306 os() << "Broadcast(" << *v->value() << ", " << v->lanes() << ")";
307 }
308
visit(const IfThenElsePtr & v)309 void IRPrinter::visit(const IfThenElsePtr& v) {
310 os() << "IfThenElse(" << *v->condition() << ", " << *v->true_value() << ", "
311 << *v->false_value() << ")";
312 }
313
visit(const IntrinsicsPtr & v)314 void IRPrinter::visit(const IntrinsicsPtr& v) {
315 os() << v->func_name() << "(";
316 for (const auto i : c10::irange(v->nparams())) {
317 if (i > 0) {
318 os() << ", ";
319 }
320 os() << *v->param(i);
321 }
322 os() << ")";
323 }
324
visit(const TermPtr & v)325 void IRPrinter::visit(const TermPtr& v) {
326 os() << "Term(";
327 v->scalar()->accept(this);
328 for (const auto& t : v->variables()) {
329 os() << ",";
330 t->accept(this);
331 }
332 os() << ")";
333 }
334
visit(const PolynomialPtr & v)335 void IRPrinter::visit(const PolynomialPtr& v) {
336 bool first = true;
337 os() << "Polynomial(";
338 for (const auto& t : v->variables()) {
339 if (!first) {
340 os() << " + ";
341 }
342 first = false;
343 t->accept(this);
344 }
345
346 if (!first) {
347 os() << " + ";
348 }
349 v->scalar()->accept(this);
350 os() << ")";
351 }
352
visit(const RoundOffPtr & v)353 void IRPrinter::visit(const RoundOffPtr& v) {
354 os() << "RoundOff(";
355 v->lhs()->accept(this);
356 os() << ", ";
357 v->rhs()->accept(this);
358 os() << ")";
359 }
360
visit(const MaxTermPtr & v)361 void IRPrinter::visit(const MaxTermPtr& v) {
362 os() << "MaxTerm(";
363 if (v->scalar()) {
364 v->scalar()->accept(this);
365 os() << ", ";
366 }
367 for (size_t i = 0; i < v->variables().size(); ++i) {
368 v->variables()[i]->accept(this);
369 if (i < v->variables().size() - 1) {
370 os() << ", ";
371 }
372 }
373 os() << ")";
374 }
375
visit(const MinTermPtr & v)376 void IRPrinter::visit(const MinTermPtr& v) {
377 os() << "MinTerm(";
378 if (v->scalar()) {
379 v->scalar()->accept(this);
380 os() << ", ";
381 }
382 for (size_t i = 0; i < v->variables().size(); ++i) {
383 v->variables()[i]->accept(this);
384 if (i < v->variables().size() - 1) {
385 os() << ", ";
386 }
387 }
388 os() << ")";
389 }
390
visit(const ReduceOpPtr & v)391 void IRPrinter::visit(const ReduceOpPtr& v) {
392 os() << "ReduceOp(";
393 os() << *v->body() << ", ";
394
395 bool first = true;
396 os() << "reduce_args={";
397 for (const auto& d : v->reduce_args()) {
398 if (!first) {
399 os() << ", ";
400 }
401 os() << *d;
402 first = false;
403 }
404 os() << "})";
405 }
406
407 // === Stmt visitors below ===
408
409 // Newlines and indentation are handled solely by the `Block` printer. For
410 // each statement in a `Block` the printer will insert indentation before
411 // the statement and a newline after the statement.
412
visit(const StorePtr & v)413 void IRPrinter::visit(const StorePtr& v) {
414 // TODO: handle the mask
415 if (v->indices().empty()) {
416 os() << *v->base_handle() << " = " << *v->value() << ";";
417 return;
418 }
419
420 os() << *v->base_handle() << "[";
421 size_t i = 0;
422 for (const ExprPtr& ind : v->indices()) {
423 if (i++) {
424 os() << ", ";
425 }
426 ind->accept(this);
427 }
428 if (v->indices().empty()) {
429 os() << "0";
430 }
431 os() << "] = " << *v->value() << ";";
432 }
433
visit(const ForPtr & v)434 void IRPrinter::visit(const ForPtr& v) {
435 VarPtr var = v->var();
436 VarHandle vv(var);
437 os() << "for (" << dtypeToCppString(var->dtype()) << " " << vv << " = "
438 << ExprHandle(v->start()) << "; " << vv << " < " << ExprHandle(v->stop())
439 << "; " << vv << "++) ";
440 std::string loop_options_str = v->loop_options().ToString();
441 if (!loop_options_str.empty()) {
442 os() << " /* " << loop_options_str << " */";
443 }
444 if (v->body()) {
445 os() << *v->body();
446 } else {
447 os() << "{}";
448 }
449 }
450
visit(const BlockPtr & v)451 void IRPrinter::visit(const BlockPtr& v) {
452 os() << "{\n";
453 indent_++;
454
455 for (const StmtPtr& s : *v) {
456 emitIndent();
457 os() << *s << "\n";
458 }
459 indent_--;
460 emitIndent();
461 os() << "}";
462 }
463
visit(const AllocatePtr & v)464 void IRPrinter::visit(const AllocatePtr& v) {
465 os() << "Allocate(" << *v->buffer_var()
466 << "); // dtype=" << dtypeToCppString(v->dtype());
467 os() << ", dims=[";
468 const std::vector<ExprPtr>& dims = v->dims();
469 for (const auto i : c10::irange(dims.size())) {
470 if (i != 0) {
471 os() << ", ";
472 }
473 os() << *dims[i];
474 }
475 os() << "]";
476 }
477
visit(const FreePtr & v)478 void IRPrinter::visit(const FreePtr& v) {
479 os() << "Free(" << *v->buffer_var() << ");";
480 }
481
visit(const FreeExtPtr & v)482 void IRPrinter::visit(const FreeExtPtr& v) {
483 os() << "FreeExt(bufs={";
484 int i = 0;
485 for (const auto& buf : v->bufs()) {
486 if (i++ > 0) {
487 os() << ", ";
488 }
489 os() << *buf;
490 }
491
492 os() << "});";
493 }
494
visit(const PlacementAllocatePtr & v)495 void IRPrinter::visit(const PlacementAllocatePtr& v) {
496 os() << "Alias(" << *v->buf()->base_handle() << ","
497 << *v->buf_to_reuse()->base_handle() << ");";
498 }
499
visit(const LetPtr & v)500 void IRPrinter::visit(const LetPtr& v) {
501 os() << dtypeToCppString(v->var()->dtype()) << " " << *v->var();
502 os() << " = " << *v->value() << ";";
503 }
504
visit(const CondPtr & v)505 void IRPrinter::visit(const CondPtr& v) {
506 ExprPtr cond = v->condition();
507 StmtPtr true_stmt = v->true_stmt();
508 StmtPtr false_stmt = v->false_stmt();
509 if (!true_stmt) {
510 os() << "if (!" << *cond << ") ";
511 os() << *false_stmt;
512 } else {
513 os() << "if (" << *cond << ") ";
514 os() << *true_stmt;
515 if (false_stmt) {
516 os() << " else ";
517 os() << *false_stmt;
518 }
519 }
520 }
521
visit(const AtomicAddPtr & v)522 void IRPrinter::visit(const AtomicAddPtr& v) {
523 os() << "atomicAdd(&" << *v->base_handle() << "[";
524 size_t i = 0;
525 for (const ExprPtr& ind : v->indices()) {
526 if (i++) {
527 os() << ", ";
528 }
529 ind->accept(this);
530 }
531 if (v->indices().empty()) {
532 os() << "0";
533 }
534 os() << "], " << *v->value() << ");";
535 }
536
visit(const SyncThreadsPtr & v)537 void IRPrinter::visit(const SyncThreadsPtr& v) {
538 os() << "__syncthreads();";
539 }
540
visit(const ExternalCallPtr & v)541 void IRPrinter::visit(const ExternalCallPtr& v) {
542 os() << *v->buf() << " = " << v->func_name() << "(";
543
544 os() << "buf_args={";
545 int i = 0;
546 for (const BufPtr& buf_arg : v->buf_args()) {
547 if (i++ > 0) {
548 os() << ", ";
549 }
550 os() << *buf_arg;
551 }
552
553 os() << "}, args={";
554 i = 0;
555 for (const ExprPtr& arg : v->args()) {
556 if (i++ > 0) {
557 os() << ", ";
558 }
559 os() << *arg;
560 }
561 os() << "})";
562 }
563
visit(const ExternalCallWithAllocPtr & v)564 void IRPrinter::visit(const ExternalCallWithAllocPtr& v) {
565 int i = 0;
566 for (const auto& buf_out_arg : v->buf_out_args()) {
567 if (i++ > 0) {
568 os() << ", ";
569 }
570 os() << *buf_out_arg;
571 }
572
573 os() << " := " << v->func_name() << "(";
574
575 os() << "buf_args={";
576 i = 0;
577 for (const auto& buf_arg : v->buf_args()) {
578 if (i++ > 0) {
579 os() << ", ";
580 }
581 os() << *buf_arg;
582 }
583
584 os() << "}, args={";
585 i = 0;
586 for (const auto& arg : v->args()) {
587 if (i++ > 0) {
588 os() << ", ";
589 }
590 os() << *arg;
591 }
592 os() << "})";
593 }
594
emitIndent()595 void IRPrinter::emitIndent() {
596 os() << std::setw(2 * indent_) << "";
597 }
598
operator <<(std::ostream & stream,const ExprHandle & expr)599 std::ostream& operator<<(std::ostream& stream, const ExprHandle& expr) {
600 IRPrinter::PrinterStream* printer_stream =
601 dynamic_cast<IRPrinter::PrinterStream*>(&stream);
602 ExprHandle& mutable_expr = const_cast<ExprHandle&>(expr);
603 if (printer_stream != nullptr) {
604 mutable_expr.node()->accept(printer_stream->printer());
605 } else {
606 IRPrinter p(stream);
607 p.print(mutable_expr);
608 }
609 return stream;
610 }
611
operator <<(std::ostream & stream,const Expr & expr)612 std::ostream& operator<<(std::ostream& stream, const Expr& expr) {
613 IRPrinter::PrinterStream* printer_stream =
614 dynamic_cast<IRPrinter::PrinterStream*>(&stream);
615 Expr& mutable_expr = const_cast<Expr&>(expr);
616 if (printer_stream != nullptr) {
617 mutable_expr.accept(printer_stream->printer());
618 } else {
619 IRPrinter p(stream);
620 p.print(mutable_expr);
621 }
622 return stream;
623 }
624
operator <<(std::ostream & stream,const Stmt & stmt)625 std::ostream& operator<<(std::ostream& stream, const Stmt& stmt) {
626 IRPrinter::PrinterStream* printer_stream =
627 dynamic_cast<IRPrinter::PrinterStream*>(&stream);
628 Stmt& mutable_stmt = const_cast<Stmt&>(stmt);
629 if (printer_stream != nullptr) {
630 mutable_stmt.accept(printer_stream->printer());
631 } else {
632 IRPrinter p(stream);
633 p.print(mutable_stmt);
634 }
635 return stream;
636 }
637
// Streams the std::to_string() rendering of a Tensor.
std::ostream& operator<<(std::ostream& stream, const Tensor& t) {
  const std::string text = std::to_string(t);
  stream << text;
  return stream;
}
642
print(const ExprPtr & expr)643 void print(const ExprPtr& expr) {
644 if (expr) {
645 IRPrinter p(std::cout);
646 p.print(*expr);
647 } else {
648 std::cout << "(null expr)";
649 }
650 std::cout << "\n";
651 }
652
print(const StmtPtr & stmt)653 void print(const StmtPtr& stmt) {
654 if (stmt) {
655 IRPrinter p(std::cout);
656 p.print(*stmt);
657 } else {
658 std::cout << "(null stmt)\n";
659 }
660 }
661
// Writes the std::to_string() rendering of `t` to std::cout.
void print(const Tensor& t) {
  const std::string text = std::to_string(t);
  std::cout << text;
}
665
666 } // namespace torch::jit::tensorexpr
667
668 namespace std {
to_string(const ExprPtr & expr)669 std::string to_string(const ExprPtr& expr) {
670 std::ostringstream oss;
671 oss << *expr;
672 return oss.str();
673 }
674
to_string(const StmtPtr & stmt)675 std::string to_string(const StmtPtr& stmt) {
676 std::ostringstream oss;
677 oss << *stmt;
678 return oss.str();
679 }
680
to_string(const Tensor & t)681 std::string to_string(const Tensor& t) {
682 std::ostringstream oss;
683 // TODO: move this to Buf printer
684 oss << "Tensor " << t.buf()->name_hint() << "[";
685 for (const auto i : c10::irange(t.buf()->ndim())) {
686 if (i != 0) {
687 oss << ", ";
688 }
689 oss << *t.buf()->dim(i);
690 }
691 oss << "]:\n" << *t.stmt() << "\n";
692 return oss.str();
693 }
694 } // namespace std
695