xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/ir_printer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
2 
3 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
4 #include <torch/csrc/jit/tensorexpr/reduction.h>
5 #include <torch/csrc/jit/tensorexpr/tensor.h>
6 
7 #include <c10/util/irange.h>
8 
9 #include <iostream>
10 
11 namespace torch::jit::tensorexpr {
12 
dtypeToCppString(const Dtype & dtype)13 std::string IRPrinter::dtypeToCppString(const Dtype& dtype) {
14   return dtype.ToCppString();
15 }
16 
print(ExprHandle expr)17 void IRPrinter::print(ExprHandle expr) {
18   expr.node()->accept(this);
19 }
20 
print(Expr & expr)21 void IRPrinter::print(Expr& expr) {
22   expr.accept(this);
23 }
24 
print(Stmt & stmt)25 void IRPrinter::print(Stmt& stmt) {
26   stmt.accept(this);
27 }
to_string(CompareSelectOperation op)28 std::string IRPrinter::to_string(CompareSelectOperation op) {
29   switch (op) {
30     case CompareSelectOperation::kEQ:
31       return "==";
32     case CompareSelectOperation::kNE:
33       return "!=";
34     case CompareSelectOperation::kGT:
35       return ">";
36     case CompareSelectOperation::kGE:
37       return ">=";
38     case CompareSelectOperation::kLT:
39       return "<";
40     case CompareSelectOperation::kLE:
41       return "<=";
42     default:
43       throw std::runtime_error("invalid compare select operator");
44   }
45 }
46 
47 // TODO: change whether to include the parenthesis to the parent expression,
48 // we need to look at the operator precedence to make the output simpler.
49 template <
50     typename Op,
51     std::enable_if_t<std::is_same_v<
52         decltype(detail::bin_op_deducer(std::declval<Op>())),
53         void>>* = nullptr>
visitBinaryOp(NodePtr<Op> v,const std::string & op_str,IRPrinter * printer,bool parens=true)54 void visitBinaryOp(
55     NodePtr<Op> v,
56     const std::string& op_str,
57     IRPrinter* printer,
58     bool parens = true) {
59   std::ostream& os = printer->os();
60   int self_prec = getPrecedence(v->expr_type());
61   int lhs_prec = getPrecedence(v->lhs()->expr_type());
62   int rhs_prec = getPrecedence(v->rhs()->expr_type());
63 
64   if (lhs_prec >= self_prec) {
65     os << "(";
66   }
67   v->lhs()->accept(printer);
68   if (lhs_prec >= self_prec) {
69     os << ")";
70   }
71 
72   os << " " << op_str << " ";
73 
74   if (rhs_prec >= self_prec) {
75     os << "(";
76   }
77   v->rhs()->accept(printer);
78   if (rhs_prec >= self_prec) {
79     os << ")";
80   }
81 }
82 
visit(const AddPtr & v)83 void IRPrinter::visit(const AddPtr& v) {
84   visitBinaryOp(v, "+", this);
85 }
86 
visit(const SubPtr & v)87 void IRPrinter::visit(const SubPtr& v) {
88   visitBinaryOp(v, "-", this);
89 }
90 
visit(const MulPtr & v)91 void IRPrinter::visit(const MulPtr& v) {
92   visitBinaryOp(v, "*", this);
93 }
94 
visit(const DivPtr & v)95 void IRPrinter::visit(const DivPtr& v) {
96   visitBinaryOp(v, "/", this);
97 }
98 
visit(const AndPtr & v)99 void IRPrinter::visit(const AndPtr& v) {
100   visitBinaryOp(v, "&", this);
101 }
102 
visit(const OrPtr & v)103 void IRPrinter::visit(const OrPtr& v) {
104   visitBinaryOp(v, "|", this);
105 }
106 
visit(const XorPtr & v)107 void IRPrinter::visit(const XorPtr& v) {
108   visitBinaryOp(v, "^", this);
109 }
110 
visit(const LshiftPtr & v)111 void IRPrinter::visit(const LshiftPtr& v) {
112   visitBinaryOp(v, "<<", this);
113 }
114 
visit(const RshiftPtr & v)115 void IRPrinter::visit(const RshiftPtr& v) {
116   visitBinaryOp(v, ">>", this);
117 }
118 
visit(const ModPtr & v)119 void IRPrinter::visit(const ModPtr& v) {
120   if (v->dtype().is_integral()) {
121     visitBinaryOp(v, "%", this);
122   } else if (v->dtype().is_floating_point()) {
123     os() << "mod(" << *v->lhs() << ", " << *v->rhs() << ")";
124   } else {
125     throw std::runtime_error("invalid dtype: " + std::to_string(v->dtype()));
126   }
127 }
128 
visit(const MaxPtr & v)129 void IRPrinter::visit(const MaxPtr& v) {
130   os() << "Max(";
131   v->lhs()->accept(this);
132   os() << ", ";
133   v->rhs()->accept(this);
134   os() << ", " << (unsigned int)v->propagate_nans() << ")";
135 }
136 
visit(const MinPtr & v)137 void IRPrinter::visit(const MinPtr& v) {
138   os() << "Min(";
139   v->lhs()->accept(this);
140   os() << ", ";
141   v->rhs()->accept(this);
142   os() << ", " << (unsigned int)v->propagate_nans() << ")";
143 }
144 
visit(const CompareSelectPtr & v)145 void IRPrinter::visit(const CompareSelectPtr& v) {
146   CompareSelectOperation cmp_op = v->compare_select_op();
147   int self_prec = getPrecedence(v->expr_type());
148   int lhs_prec = getPrecedence(v->lhs()->expr_type());
149   int rhs_prec = getPrecedence(v->rhs()->expr_type());
150 
151   if (lhs_prec >= self_prec) {
152     os() << "(";
153   }
154   v->lhs()->accept(this);
155   if (lhs_prec >= self_prec) {
156     os() << ")";
157   }
158 
159   os() << to_string(cmp_op);
160 
161   if (rhs_prec >= self_prec) {
162     os() << "(";
163   }
164   v->rhs()->accept(this);
165   if (rhs_prec >= self_prec) {
166     os() << ")";
167   }
168   os() << " ? ";
169 
170   auto withParens = [&](const ExprPtr& e) {
171     auto prec = getPrecedence(e->expr_type());
172     if (prec >= self_prec) {
173       os() << "(";
174     }
175     e->accept(this);
176     if (prec >= self_prec) {
177       os() << ")";
178     }
179   };
180   withParens(v->ret_val1());
181   os() << " : ";
182   withParens(v->ret_val2());
183 }
184 
formatFPSuffix(std::ostream & os,double v)185 static void formatFPSuffix(std::ostream& os, double v) {
186   os << (v == std::ceil(v) ? ".0" : "");
187 }
188 
189 template <typename T>
formatFPSuffix(std::ostream & os,T v)190 static void formatFPSuffix(std::ostream& os, T v) {
191   os << (v == std::ceil(v) ? ".f" : "f");
192 }
193 
194 template <typename T, std::enable_if_t<std::is_floating_point_v<T>>* = nullptr>
formatImm(std::ostream & os,T v)195 static void formatImm(std::ostream& os, T v) {
196   const int precision = 16;
197   if (std::isnan(v)) {
198     os << "NAN";
199   } else if (std::isinf(v)) {
200     os << (v > 0 ? "POS_INFINITY" : "NEG_INFINITY");
201   } else {
202     os << std::setprecision(precision) << v;
203     formatFPSuffix(os, v);
204   }
205 }
206 
formatIntSuffix(std::ostream & os,int64_t v)207 static void formatIntSuffix(std::ostream& os, int64_t v) {
208   os << "ll";
209 }
210 
211 template <typename T>
formatIntSuffix(std::ostream & os,T v)212 static void formatIntSuffix(std::ostream& os, T v) {}
213 
214 template <typename T, std::enable_if_t<!std::is_floating_point_v<T>>* = nullptr>
formatImm(std::ostream & os,T v)215 static void formatImm(std::ostream& os, T v) {
216   os << +v;
217   formatIntSuffix(os, v);
218 }
219 
220 // NOLINTNEXTLINE
221 #define IMM_PRINT_VISIT(Type, Name)              \
222   void IRPrinter::visit(const Name##ImmPtr& v) { \
223     formatImm(os(), v->value());                 \
224   }
225 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT);
226 #undef IMM_PRINT_VISIT
227 
visit(const CastPtr & v)228 void IRPrinter::visit(const CastPtr& v) {
229   auto dtype = v->dtype();
230   os() << dtypeToCppString(dtype) << "(";
231   v->src_value()->accept(this);
232   os() << ")";
233 }
234 
visit(const BitCastPtr & v)235 void IRPrinter::visit(const BitCastPtr& v) {
236   auto dtype = v->dtype();
237   os() << "BitCast<" << dtype.ToCppString() << ">(";
238   v->src_value()->accept(this);
239   os() << ")";
240 }
241 
visit(const VarPtr & v)242 void IRPrinter::visit(const VarPtr& v) {
243   os() << name_manager_.get_unique_name(v);
244 }
245 
visit(const BufPtr & v)246 void IRPrinter::visit(const BufPtr& v) {
247   auto dtype = v->dtype();
248   os() << *v->base_handle();
249   os() << "(dtype=" << dtypeToCppString(dtype);
250   if (v->qscale()) {
251     os() << ", qscale=";
252     v->qscale()->accept(this);
253   }
254   if (v->qscale()) {
255     os() << ", qzero=";
256     v->qzero()->accept(this);
257   }
258   os() << ", sizes=[";
259   size_t i = 0;
260   for (const ExprPtr& s : v->dims()) {
261     if (i++) {
262       os() << ", ";
263     }
264     s->accept(this);
265   }
266   os() << "]";
267   os() << ", strides=[";
268   i = 0;
269   for (const ExprPtr& s : v->strides()) {
270     if (i++) {
271       os() << ", ";
272     }
273     s->accept(this);
274   }
275   os() << "]";
276 
277   os() << ")";
278 }
279 
visit(const RampPtr & v)280 void IRPrinter::visit(const RampPtr& v) {
281   os() << "Ramp(" << *v->base() << ", " << *v->stride() << ", " << v->lanes()
282        << ")";
283 }
284 
visit(const LoadPtr & v)285 void IRPrinter::visit(const LoadPtr& v) {
286   // TODO: support the mask case
287   if (v->indices().empty()) {
288     os() << *v->base_handle();
289   } else {
290     os() << *v->base_handle() << "[";
291     size_t i = 0;
292     for (const ExprPtr& ind : v->indices()) {
293       if (i++) {
294         os() << ", ";
295       }
296       ind->accept(this);
297     }
298     if (v->indices().empty()) {
299       os() << "0";
300     }
301     os() << "]";
302   }
303 }
304 
visit(const BroadcastPtr & v)305 void IRPrinter::visit(const BroadcastPtr& v) {
306   os() << "Broadcast(" << *v->value() << ", " << v->lanes() << ")";
307 }
308 
visit(const IfThenElsePtr & v)309 void IRPrinter::visit(const IfThenElsePtr& v) {
310   os() << "IfThenElse(" << *v->condition() << ", " << *v->true_value() << ", "
311        << *v->false_value() << ")";
312 }
313 
visit(const IntrinsicsPtr & v)314 void IRPrinter::visit(const IntrinsicsPtr& v) {
315   os() << v->func_name() << "(";
316   for (const auto i : c10::irange(v->nparams())) {
317     if (i > 0) {
318       os() << ", ";
319     }
320     os() << *v->param(i);
321   }
322   os() << ")";
323 }
324 
visit(const TermPtr & v)325 void IRPrinter::visit(const TermPtr& v) {
326   os() << "Term(";
327   v->scalar()->accept(this);
328   for (const auto& t : v->variables()) {
329     os() << ",";
330     t->accept(this);
331   }
332   os() << ")";
333 }
334 
visit(const PolynomialPtr & v)335 void IRPrinter::visit(const PolynomialPtr& v) {
336   bool first = true;
337   os() << "Polynomial(";
338   for (const auto& t : v->variables()) {
339     if (!first) {
340       os() << " + ";
341     }
342     first = false;
343     t->accept(this);
344   }
345 
346   if (!first) {
347     os() << " + ";
348   }
349   v->scalar()->accept(this);
350   os() << ")";
351 }
352 
visit(const RoundOffPtr & v)353 void IRPrinter::visit(const RoundOffPtr& v) {
354   os() << "RoundOff(";
355   v->lhs()->accept(this);
356   os() << ", ";
357   v->rhs()->accept(this);
358   os() << ")";
359 }
360 
visit(const MaxTermPtr & v)361 void IRPrinter::visit(const MaxTermPtr& v) {
362   os() << "MaxTerm(";
363   if (v->scalar()) {
364     v->scalar()->accept(this);
365     os() << ", ";
366   }
367   for (size_t i = 0; i < v->variables().size(); ++i) {
368     v->variables()[i]->accept(this);
369     if (i < v->variables().size() - 1) {
370       os() << ", ";
371     }
372   }
373   os() << ")";
374 }
375 
visit(const MinTermPtr & v)376 void IRPrinter::visit(const MinTermPtr& v) {
377   os() << "MinTerm(";
378   if (v->scalar()) {
379     v->scalar()->accept(this);
380     os() << ", ";
381   }
382   for (size_t i = 0; i < v->variables().size(); ++i) {
383     v->variables()[i]->accept(this);
384     if (i < v->variables().size() - 1) {
385       os() << ", ";
386     }
387   }
388   os() << ")";
389 }
390 
visit(const ReduceOpPtr & v)391 void IRPrinter::visit(const ReduceOpPtr& v) {
392   os() << "ReduceOp(";
393   os() << *v->body() << ", ";
394 
395   bool first = true;
396   os() << "reduce_args={";
397   for (const auto& d : v->reduce_args()) {
398     if (!first) {
399       os() << ", ";
400     }
401     os() << *d;
402     first = false;
403   }
404   os() << "})";
405 }
406 
407 // === Stmt visitors below ===
408 
409 // Newlines and indentation are handled solely by the `Block` printer.  For
410 // each statement in a `Block` the printer will insert indentation before
411 // the statement and a newline after the statement.
412 
visit(const StorePtr & v)413 void IRPrinter::visit(const StorePtr& v) {
414   // TODO: handle the mask
415   if (v->indices().empty()) {
416     os() << *v->base_handle() << " = " << *v->value() << ";";
417     return;
418   }
419 
420   os() << *v->base_handle() << "[";
421   size_t i = 0;
422   for (const ExprPtr& ind : v->indices()) {
423     if (i++) {
424       os() << ", ";
425     }
426     ind->accept(this);
427   }
428   if (v->indices().empty()) {
429     os() << "0";
430   }
431   os() << "] = " << *v->value() << ";";
432 }
433 
visit(const ForPtr & v)434 void IRPrinter::visit(const ForPtr& v) {
435   VarPtr var = v->var();
436   VarHandle vv(var);
437   os() << "for (" << dtypeToCppString(var->dtype()) << " " << vv << " = "
438        << ExprHandle(v->start()) << "; " << vv << " < " << ExprHandle(v->stop())
439        << "; " << vv << "++) ";
440   std::string loop_options_str = v->loop_options().ToString();
441   if (!loop_options_str.empty()) {
442     os() << " /* " << loop_options_str << " */";
443   }
444   if (v->body()) {
445     os() << *v->body();
446   } else {
447     os() << "{}";
448   }
449 }
450 
visit(const BlockPtr & v)451 void IRPrinter::visit(const BlockPtr& v) {
452   os() << "{\n";
453   indent_++;
454 
455   for (const StmtPtr& s : *v) {
456     emitIndent();
457     os() << *s << "\n";
458   }
459   indent_--;
460   emitIndent();
461   os() << "}";
462 }
463 
visit(const AllocatePtr & v)464 void IRPrinter::visit(const AllocatePtr& v) {
465   os() << "Allocate(" << *v->buffer_var()
466        << "); // dtype=" << dtypeToCppString(v->dtype());
467   os() << ", dims=[";
468   const std::vector<ExprPtr>& dims = v->dims();
469   for (const auto i : c10::irange(dims.size())) {
470     if (i != 0) {
471       os() << ", ";
472     }
473     os() << *dims[i];
474   }
475   os() << "]";
476 }
477 
visit(const FreePtr & v)478 void IRPrinter::visit(const FreePtr& v) {
479   os() << "Free(" << *v->buffer_var() << ");";
480 }
481 
visit(const FreeExtPtr & v)482 void IRPrinter::visit(const FreeExtPtr& v) {
483   os() << "FreeExt(bufs={";
484   int i = 0;
485   for (const auto& buf : v->bufs()) {
486     if (i++ > 0) {
487       os() << ", ";
488     }
489     os() << *buf;
490   }
491 
492   os() << "});";
493 }
494 
visit(const PlacementAllocatePtr & v)495 void IRPrinter::visit(const PlacementAllocatePtr& v) {
496   os() << "Alias(" << *v->buf()->base_handle() << ","
497        << *v->buf_to_reuse()->base_handle() << ");";
498 }
499 
visit(const LetPtr & v)500 void IRPrinter::visit(const LetPtr& v) {
501   os() << dtypeToCppString(v->var()->dtype()) << " " << *v->var();
502   os() << " = " << *v->value() << ";";
503 }
504 
visit(const CondPtr & v)505 void IRPrinter::visit(const CondPtr& v) {
506   ExprPtr cond = v->condition();
507   StmtPtr true_stmt = v->true_stmt();
508   StmtPtr false_stmt = v->false_stmt();
509   if (!true_stmt) {
510     os() << "if (!" << *cond << ") ";
511     os() << *false_stmt;
512   } else {
513     os() << "if (" << *cond << ") ";
514     os() << *true_stmt;
515     if (false_stmt) {
516       os() << " else ";
517       os() << *false_stmt;
518     }
519   }
520 }
521 
visit(const AtomicAddPtr & v)522 void IRPrinter::visit(const AtomicAddPtr& v) {
523   os() << "atomicAdd(&" << *v->base_handle() << "[";
524   size_t i = 0;
525   for (const ExprPtr& ind : v->indices()) {
526     if (i++) {
527       os() << ", ";
528     }
529     ind->accept(this);
530   }
531   if (v->indices().empty()) {
532     os() << "0";
533   }
534   os() << "], " << *v->value() << ");";
535 }
536 
visit(const SyncThreadsPtr & v)537 void IRPrinter::visit(const SyncThreadsPtr& v) {
538   os() << "__syncthreads();";
539 }
540 
visit(const ExternalCallPtr & v)541 void IRPrinter::visit(const ExternalCallPtr& v) {
542   os() << *v->buf() << " = " << v->func_name() << "(";
543 
544   os() << "buf_args={";
545   int i = 0;
546   for (const BufPtr& buf_arg : v->buf_args()) {
547     if (i++ > 0) {
548       os() << ", ";
549     }
550     os() << *buf_arg;
551   }
552 
553   os() << "}, args={";
554   i = 0;
555   for (const ExprPtr& arg : v->args()) {
556     if (i++ > 0) {
557       os() << ", ";
558     }
559     os() << *arg;
560   }
561   os() << "})";
562 }
563 
visit(const ExternalCallWithAllocPtr & v)564 void IRPrinter::visit(const ExternalCallWithAllocPtr& v) {
565   int i = 0;
566   for (const auto& buf_out_arg : v->buf_out_args()) {
567     if (i++ > 0) {
568       os() << ", ";
569     }
570     os() << *buf_out_arg;
571   }
572 
573   os() << " := " << v->func_name() << "(";
574 
575   os() << "buf_args={";
576   i = 0;
577   for (const auto& buf_arg : v->buf_args()) {
578     if (i++ > 0) {
579       os() << ", ";
580     }
581     os() << *buf_arg;
582   }
583 
584   os() << "}, args={";
585   i = 0;
586   for (const auto& arg : v->args()) {
587     if (i++ > 0) {
588       os() << ", ";
589     }
590     os() << *arg;
591   }
592   os() << "})";
593 }
594 
emitIndent()595 void IRPrinter::emitIndent() {
596   os() << std::setw(2 * indent_) << "";
597 }
598 
operator <<(std::ostream & stream,const ExprHandle & expr)599 std::ostream& operator<<(std::ostream& stream, const ExprHandle& expr) {
600   IRPrinter::PrinterStream* printer_stream =
601       dynamic_cast<IRPrinter::PrinterStream*>(&stream);
602   ExprHandle& mutable_expr = const_cast<ExprHandle&>(expr);
603   if (printer_stream != nullptr) {
604     mutable_expr.node()->accept(printer_stream->printer());
605   } else {
606     IRPrinter p(stream);
607     p.print(mutable_expr);
608   }
609   return stream;
610 }
611 
operator <<(std::ostream & stream,const Expr & expr)612 std::ostream& operator<<(std::ostream& stream, const Expr& expr) {
613   IRPrinter::PrinterStream* printer_stream =
614       dynamic_cast<IRPrinter::PrinterStream*>(&stream);
615   Expr& mutable_expr = const_cast<Expr&>(expr);
616   if (printer_stream != nullptr) {
617     mutable_expr.accept(printer_stream->printer());
618   } else {
619     IRPrinter p(stream);
620     p.print(mutable_expr);
621   }
622   return stream;
623 }
624 
operator <<(std::ostream & stream,const Stmt & stmt)625 std::ostream& operator<<(std::ostream& stream, const Stmt& stmt) {
626   IRPrinter::PrinterStream* printer_stream =
627       dynamic_cast<IRPrinter::PrinterStream*>(&stream);
628   Stmt& mutable_stmt = const_cast<Stmt&>(stmt);
629   if (printer_stream != nullptr) {
630     mutable_stmt.accept(printer_stream->printer());
631   } else {
632     IRPrinter p(stream);
633     p.print(mutable_stmt);
634   }
635   return stream;
636 }
637 
operator <<(std::ostream & stream,const Tensor & t)638 std::ostream& operator<<(std::ostream& stream, const Tensor& t) {
639   stream << std::to_string(t);
640   return stream;
641 }
642 
print(const ExprPtr & expr)643 void print(const ExprPtr& expr) {
644   if (expr) {
645     IRPrinter p(std::cout);
646     p.print(*expr);
647   } else {
648     std::cout << "(null expr)";
649   }
650   std::cout << "\n";
651 }
652 
print(const StmtPtr & stmt)653 void print(const StmtPtr& stmt) {
654   if (stmt) {
655     IRPrinter p(std::cout);
656     p.print(*stmt);
657   } else {
658     std::cout << "(null stmt)\n";
659   }
660 }
661 
print(const Tensor & t)662 void print(const Tensor& t) {
663   std::cout << std::to_string(t);
664 }
665 
666 } // namespace torch::jit::tensorexpr
667 
668 namespace std {
to_string(const ExprPtr & expr)669 std::string to_string(const ExprPtr& expr) {
670   std::ostringstream oss;
671   oss << *expr;
672   return oss.str();
673 }
674 
to_string(const StmtPtr & stmt)675 std::string to_string(const StmtPtr& stmt) {
676   std::ostringstream oss;
677   oss << *stmt;
678   return oss.str();
679 }
680 
to_string(const Tensor & t)681 std::string to_string(const Tensor& t) {
682   std::ostringstream oss;
683   // TODO: move this to Buf printer
684   oss << "Tensor " << t.buf()->name_hint() << "[";
685   for (const auto i : c10::irange(t.buf()->ndim())) {
686     if (i != 0) {
687       oss << ", ";
688     }
689     oss << *t.buf()->dim(i);
690   }
691   oss << "]:\n" << *t.stmt() << "\n";
692   return oss.str();
693 }
694 } // namespace std
695