#include <torch/csrc/jit/serialization/python_print.h>

#include <algorithm>

#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <c10/util/irange.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/versioned_symbols.h>
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/operator_upgraders/version_map.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>

using c10::QualifiedName;

namespace torch::jit {

static bool isValidIdentifierChar(char c, size_t pos) {
  return islower(c) || isupper(c) || c == '_' || (pos > 0 && isdigit(c));
}

static bool isValidIdentifier(const std::string& name) {
  if (name.empty())
    return false;
  for (const auto i : c10::irange(name.size())) {
    if (!isValidIdentifierChar(name[i], i))
      return false;
  }
  return true;
}
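
// e.g. (illustrative) isValidIdentifier("foo_1") is true, while "1foo"
// (leading digit) and "my-var" (dash) are rejected.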

// some names are valid identifiers but off limits because
// they are keywords or namespaces used in the output
const static std::unordered_set<std::string> reserved_names = {
    // identifiers in the environment while parsing
    "_", // avoid the confusing unnamed _
    "as",
    "aten",
    "attribute",
    "CONSTANTS",
    "fork",
    "getattr",
    "inf",
    "nan",
    "infj",
    "nanj",
    "ops",
    "__torch__",
    // the python keywords
    "and",
    "as",
    "assert",
    "async",
    "await",
    "break",
    "class",
    "continue",
    "def",
    "del",
    "elif",
    "else",
    "except",
    "False",
    "finally",
    "for",
    "from",
    "global",
    "if",
    "import",
    "in",
    "is",
    "lambda",
    "None",
    "nonlocal",
    "not",
    "or",
    "pass",
    "raise",
    "return",
    "True",
    "try",
    "while",
    "with",
    "yield",
    "uninitialized",
    "unchecked_cast",
};

// Helper to avoid duplicating class types
void PrintDepsTable::add(const c10::NamedTypePtr& type) {
  // Despite the linear search below, we avoid wasted work by trying to
  // insert each instance only once.
  if (!non_unique_.insert(type).second) {
    return;
  }
  // Need to do an actual equality comparison, not a pointer equality. This is
  // because for some types (e.g. FunctionType), we may have multiple
  // TypePtr's that represent the same underlying thing.
  // TODO: this should really be swapped for something more efficient
  auto it = std::find_if(
      table_.cbegin(), table_.cend(), [&](const c10::NamedTypePtr& dep) {
        return *dep == *type;
      });

  if (it == table_.cend()) {
    table_.push_back(type);
  }
}
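
// e.g. two distinct FunctionType TypePtrs wrapping the same function compare
// equal even though the pointers differ, so only the first one is recorded
// (illustrative of why the deep equality above is needed).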

struct PythonPrintImpl {
  using SourceRangeStack = std::vector<SourceRange>;
  SourceRangeStack source_range_stack_ = {SourceRange()};

  struct WithSourceRange {
    explicit WithSourceRange(SourceRangeStack* stack, Node* n) : stack(stack) {
      TORCH_INTERNAL_ASSERT(stack);
      if (auto gen_source = n->sourceRange().findSourceRangeThatGenerated()) {
        stack->push_back(std::move(gen_source.value()));
      } else {
        stack->push_back(n->sourceRange());
      }
    }

    ~WithSourceRange() {
      stack->pop_back();
    }

    SourceRangeStack* stack;
  };

  class TaggedStringStream {
   public:
    TaggedStringStream(const SourceRangeStack* srs) : srs_(srs) {}

    TaggedStringStream& operator<<(const std::string& s) {
      // This prevents having redundant entries at the same offset,
      // which can happen for example in printValueList when begin
      // and end are the empty string.
      if (s.empty()) {
        return *this;
      }

      if (ranges_.empty() || ranges_.back().range != srs_->back()) {
        ranges_.emplace_back((size_t)oss_.tellp(), srs_->back());
      }
      oss_ << s;
      return *this;
    }

    TaggedStringStream& operator<<(const TaggedStringStream& rhs) {
      for (const auto& range : rhs.ranges_) {
        if (ranges_.empty() || ranges_.back().range != range.range) {
          ranges_.emplace_back((size_t)oss_.tellp() + range.bytes, range.range);
        }
      }
      oss_ << rhs.oss_.str();
      return *this;
    }

    // This overload is here to prevent people from shooting themselves in the
    // foot. I would be highly surprised if someone actually wanted to write out
    // the address of a TaggedStringStream in the pretty print.
    TaggedStringStream& operator<<(
        const std::shared_ptr<TaggedStringStream>& rhs) {
      (*this) << *rhs;
      return *this;
    }

    template <typename T>
    TaggedStringStream& operator<<(const T& t) {
      if (ranges_.empty() || ranges_.back().range != srs_->back()) {
        ranges_.emplace_back((size_t)oss_.tellp(), srs_->back());
      }
      oss_ << t;
      return *this;
    }

    std::string str() const {
      return oss_.str();
    }

    const std::vector<TaggedRange>& ranges() const {
      return ranges_;
    }

   private:
    std::ostringstream oss_;
    std::vector<TaggedRange> ranges_;
    const SourceRangeStack* srs_;
  };
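
  // Illustrative sketch (not code from this file): every chunk written to a
  // TaggedStringStream is tagged with the SourceRange on top of the stack,
  //   TaggedStringStream ts(&source_range_stack_);
  //   ts << "x = " << "1\n";
  // records a TaggedRange at byte offset 0 pointing at the active range, so
  // offsets in the emitted text can be mapped back to the original source.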

  // scanValue, scanNode, scanBlock:
  // decide if it is safe to omit the output of a temporary variable,
  // and inline the expression into its use
  // we only do this if
  // (1) it is a constant, or
  // (2) the temporary is unnamed, is single output, is used once,
  //     and would appear in the same order when the expression tree is
  //     reparsed.
  // The last case can be checked
  // because when we emit an expression tree in the parser,
  // we do a left-to-right postorder traversal of the expression tree (emit
  // children, then emit op). The reverse of this is a right-to-left preorder
  // traversal of the tree. By doing a right-to-left preorder traversal of the
  // inputs of a node, while also scanning the list of emitted nodes backward,
  // we can see if they line up with what would happen when the node is parsed
  // as an expression. While they line up we collapse them into an inline
  // expression.

  // The inductive step is that the right-most input should be produced by the
  // node immediately before the current node if it is in tree order.

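  // For example (illustrative): with an unnamed, single-use temporary
  //   %t = aten::mul(%a, %b)
  //   %r = aten::add(%t, %c)
  // the printer collapses the temporary and emits
  //   r = torch.add(torch.mul(a, b), c)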
  bool canInline(Value* v) {
    Node* n = v->node();
    // there must be only one output value, otherwise we need an assignment
    // to handle the multiple output values
    if (n->outputs().size() != 1)
      return false;
    // if it is used more than once, then we need a variable
    if (v->uses().size() != 1)
      return false;
    auto use = v->uses().at(0);
    // if it has a name set, then it was written as a variable so preserve that
    // unless it is being fed directly to the end of the block.
    // in which case it is not as useful to give it a name just to return it
    if (v->hasDebugName() && use.user->kind() != prim::Return)
      return false;
    // don't try to inline control blocks
    if (!n->blocks().empty())
      return false;
    // if it is a loop-carried input, we need a variable
    // otherwise the condition or trip count may be emitted in the wrong order
    // w.r.t. it
    if (use.user->kind() == prim::Loop && use.offset >= 2)
      return false;

    // subgraph may use this more than once, so disable inlining
    if (use.user->kind() == prim::fork || use.user->kind() == prim::rpc_async ||
        use.user->kind() == prim::rpc_sync ||
        use.user->kind() == prim::rpc_remote)
      return false;

    // isinstance appearing in an if expression
    // causes type refinement to occur, but we have
    // already handled the refinement and inserted cast
    // expressions. By not inlining it into the if condition,
    // we prevent it from happening again.
    if (v->node()->kind() == prim::isinstance) {
      return false;
    }

    return true;
  }

  // block_point is the current node in the reverse linear scan of the emitted
  // nodes. v is the current value in the tree traversal that may match with
  // block_point's output.
  Node* scanValue(Node* block_point, Value* v) {
    Node* n = v->node();
    AT_ASSERT(n->kind() == prim::Constant || output_inline_.count(n) == 0);

    if (n == block_point &&
        canInline(v)) { // the node must be at the expected point of the typical
                        // tree traversal
      // recursively see if we can inline the inputs to this input
      block_point = scanNode(block_point);
      output_inline_.insert(n);
    } else if (n->kind() == prim::Constant) {
      // constant nodes can always be inlined, we will de-dup them on parsing
      // and put them at the top of the function regardless
      output_inline_.insert(n);
    }
    return block_point;
  }

  Node* previousNonConstant(Node* n) {
    do {
      n = n->prev();
    } while (n->kind() == prim::Constant);
    return n;
  }

  Node* scanNode(Node* n) {
    // don't bother to scan nodes we have already determined to be inline
    if (output_inline_.count(n)) {
      return n;
    }
    for (auto b : n->blocks()) {
      scanBlock(b);
    }
    Node* block_point = previousNonConstant(n);
    for (auto it = n->inputs().rbegin(), end = n->inputs().rend(); it != end;
         ++it) {
      block_point = scanValue(block_point, *it);
    }
    return block_point;
  }

  void scanBlock(Block* b) {
    scanNode(b->return_node());
    for (auto node : b->nodes().reverse()) {
      scanNode(node);
    }
  }

  size_t getOrAddConstant(at::IValue val) {
    // XXX - N^2 warning. This code does the exact same thing as
    // ConstantPool, which is also N^2 in the size of the constants,
    // because it doesn't hash any information about the tensors.
    // We will probably need to optimize this at some point using hashing.
    if (val.isTensor()) {
      auto& t = val.toTensor();
      for (const auto i : c10::irange(constant_table_.size())) {
        if (!constant_table_[i].isTensor()) {
          continue;
        }
        auto& t2 = constant_table_[i].toTensor();
        if (t.options().type_equal(t2.options()) && t.equal(t2)) {
          return i;
        }
      }
    }
    constant_table_.emplace_back(std::move(val));
    return constant_table_.size() - 1;
  }

  std::unordered_set<Node*> seen_constants;
  void buildConstantList(Node* n, std::vector<Node*>& constants) {
    for (auto input : n->inputs()) {
      if (input->node()->kind() == prim::Constant &&
          seen_constants.count(input->node()) == 0) {
        constants.push_back(input->node());
        seen_constants.insert(input->node());
      }
    }
    for (auto b : n->blocks()) {
      buildConstantList(b, constants);
    }
  }
  void buildConstantList(Block* b, std::vector<Node*>& constants) {
    for (auto n : b->nodes())
      buildConstantList(n, constants);
    buildConstantList(b->return_node(), constants);
  }

  // get a new name unique across calls to debugName() and
  // anything we have used.
  std::unordered_map<std::string, size_t> next_id;

  std::string genNameImpl(
      const std::string& candidate,
      std::unordered_set<std::string>& used) {
    std::string name = candidate;
    while (used.count(name) || reserved_names.count(name)) {
      // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
      name = candidate + std::to_string(next_id[name]++);
    }
    used.insert(name);
    return name;
  }
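
  // e.g. (illustrative) genName("x") yields "x" on first use, then "x0",
  // "x1", ...; a reserved candidate like "while" is immediately suffixed,
  // yielding "while0".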
  std::string genName(const std::string& candidate) {
    return genNameImpl(candidate, used_names_);
  }

  // unique names might not be valid identifiers,
  // force them to be by rewriting them
  static std::string makeValidIdentifier(const std::string& candidate) {
    std::stringstream ss;
    if (candidate.empty() || isdigit(candidate[0]))
      ss << "_";
    for (char c : candidate) {
      if (isupper(c) || islower(c) || isdigit(c) || c == '_')
        ss << c;
      else
        ss << '_';
    }
    return ss.str();
  }
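
  // e.g. makeValidIdentifier("my-var") returns "my_var" and
  // makeValidIdentifier("0out") returns "_0out" (illustrative).
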
  // if we have to assign 'v' a name, what should it be?
  // use the debugName if it was set, otherwise generate a name.
  std::string genUniqueNameFor(Value* v) {
    return genName(
        v->hasDebugName() ? makeValidIdentifier(v->debugNameBase()) : "_");
  }

  // map from Value to how it should be printed at each use
  std::unordered_map<Value*, std::shared_ptr<TaggedStringStream>> expr_table_;
  std::unordered_map<Value*, std::string> ident_refs_;

  // NB: we MUST pass around the shared pointers to these streams by value.
  // There is an interaction in splitLongInlines where the string value for
  // both the RHS and the LHS of an expression are live at the same time,
  // however the value for the RHS is overwritten in the table.
  std::shared_ptr<TaggedStringStream> useOf(Value* v) const {
    // Ident refs take precedence over expression refs, since presence in
    // the ident ref table indicates we have already emitted a statement
    // assigning the given value.
    if (ident_refs_.count(v)) {
      auto rv = std::make_shared<TaggedStringStream>(&source_range_stack_);
      (*rv) << ident_refs_.at(v);
      return rv;
    }
    if (expr_table_.count(v)) {
      return expr_table_.at(v);
    }
    TORCH_INTERNAL_ASSERT(
        false,
        "Value (debug name: \"",
        v->debugName(),
        "\") was not present in either expressions table or ident refs table");
  }
  void assignValue(Value* v, const std::string& s) {
    ident_refs_[v] = s;
  }
  void assignValue(Value* v, std::shared_ptr<TaggedStringStream> s) {
    expr_table_[v] = std::move(s);
  }
  void assignValue(Value* v, Value* w) {
    assignValue(v, useOf(w));
  }
  void assignValuesToTheirUniqueNames(at::ArrayRef<Value*> values) {
    for (auto v : values) {
      assignValue(v, genUniqueNameFor(v));
    }
  }

  size_t level = 0;
  // indent to the current indent level
  TaggedStringStream& indent() {
    for (const auto i : c10::irange(level)) {
      (void)i; // Suppress unused variable warning
      body_ << "  ";
    }
    return body_;
  }

  ResourceGuard WithIndented() {
    level++;
    return ResourceGuard([this] { level--; });
  }

  template <class T0, class T1, class F>
  void zipWith(at::ArrayRef<T0> list_a, at::ArrayRef<T1> list_b, F action)
      const {
    auto it_a = list_a.begin();
    auto it_b = list_b.begin();

    if (list_a.size() != list_b.size()) {
      AT_ERROR("Python printer expected 2 lists of same size");
    }

    for (; it_a != list_a.end(); ++it_a, ++it_b) {
      action(*it_a, *it_b);
    }
  }

  void printValueList(
      TaggedStringStream& stmt,
      at::ArrayRef<Value*> list,
      const char* begin = "",
      const char* end = "") {
    stmt << begin;
    auto delimiter = "";
    for (auto* value : list) {
      stmt << delimiter;
      stmt << useOf(value);
      delimiter = ", ";
    }
    stmt << end;
  }

  void printValueIndex(TaggedStringStream& stmt, at::ArrayRef<Value*> inputs) {
    const std::string val_name = useOf(inputs[0])->str();
    if (isValidIdentifier(val_name)) {
      stmt << val_name;
    } else {
      stmt << "(" << val_name << ")";
    }
    stmt << "[";
    stmt << useOf(inputs[1]);
    stmt << "]";
  }
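
  // e.g. emits "x[i]" when the indexed value is a plain identifier, or
  // "(torch.add(x, y))[i]" when it is an inlined expression (illustrative).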

  void printDict(
      TaggedStringStream& stmt,
      at::ArrayRef<Value*> key_value_pairs,
      const char* begin = "{",
      const char* end = "}") {
    stmt << begin;
    auto delimiter = "";
    for (size_t i = 0; i < key_value_pairs.size(); i += 2) {
      stmt << delimiter;
      auto key = key_value_pairs[i];
      auto value = key_value_pairs[i + 1];

      stmt << useOf(key) << ": " << useOf(value);

      delimiter = ", ";
    }
    stmt << end;
  }

  void printAssignment(at::ArrayRef<Value*> lhs, at::ArrayRef<Value*> rhs) {
    if (lhs.empty()) {
      return;
    }
    indent();
    printValueList(body_, lhs);
    // We need to preserve Union/Optional type annotations, but only if
    // we're not assigning values as part of a tuple unpacking statement
    // (Python doesn't allow type annotations in multiple assignment)
    if (lhs.size() == 1) {
      Value* v = lhs.at(0);
      if (!annotated_unions_.count(v) && !expr_table_.count(v) &&
          (v->type()->kind() == UnionType::Kind ||
           v->type()->kind() == OptionalType::Kind)) {
        body_ << " : " << v->type()->annotation_str();
        annotated_unions_.insert(v);
      }
    }
    body_ << " = ";
    // or if value is being assigned to something of a union type
    printValueList(body_, rhs);
    body_ << "\n";
  }

  bool requiresAnnotation(Value* lhs, Value* rhs) {
    if (lhs->type()->kind() == UnionType::Kind ||
        lhs->type()->kind() == OptionalType::Kind) {
      return annotated_unions_.insert(lhs).second;
    } else {
      return *lhs->type() != *rhs->type();
    }
  }

  void printAnnotatedAssignment(
      at::ArrayRef<Value*> lhs,
      at::ArrayRef<Value*> rhs) {
    for (const auto i : c10::irange(lhs.size())) {
      indent();
      body_ << useOf(lhs[i]);
      if (requiresAnnotation(lhs[i], rhs[i])) {
        body_ << ": " << lhs[i]->type()->annotation_str(type_printer_);
      }
      body_ << " = " << useOf(rhs[i]) << "\n";
    }
  }

  void printIf(IfView stmt) {
    assignValuesToTheirUniqueNames(stmt.outputs());
    indent() << "if " << useOf(stmt.cond()) << ":\n";
    {
      auto guard = WithIndented();
      // Print node contents
      printBlock(stmt.thenBlock(), !stmt.outputs().empty());
      printAssignment(stmt.outputs(), stmt.thenOutputs());
    }
    indent() << "else:\n";
    {
      auto guard = WithIndented();
      printBlock(stmt.elseBlock(), !stmt.outputs().empty());
      printAssignment(stmt.outputs(), stmt.elseOutputs());
    }
  }
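
  // An If node with one output prints roughly as (illustrative):
  //   if cond:
  //     y = torch.neg(x)
  //   else:
  //     y = x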

  void printLoop(LoopView stmt) {
    // Loop carried dependencies are handled by assigning their initial
    // values to the node->outputs() before the loop, and assigning
    // node->outputs() to the new values at the end of each trip.

    auto loop_type = stmt.loopType();
    if (loop_type == LoopView::ModifiedLoop) {
      throw(
          ErrorReport(stmt.node()->sourceRange())
          << "loop cannot be printed as python "
          << "because it has gone through an optimization "
          << "that combined while and for loops. File a bug");
    }

    bool emit_as_for_loop = loop_type == LoopView::For;

    assignValuesToTheirUniqueNames(stmt.carriedOutputs());
    // Add aliases for loop-carried dependencies
    zipWith(
        stmt.bodyCarriedInputs(), // Start at 1 to ignore trip count
        stmt.carriedOutputs(),
        [&](Value* block_input, Value* node_output) {
          assignValue(block_input, node_output);
        });

    // Print initial assignments of loop node outputs = loop node inputs
    printAnnotatedAssignment(stmt.carriedOutputs(), stmt.carriedInputs());

    assignValuesToTheirUniqueNames(stmt.currentTripCount());
    // Loop header
    if (emit_as_for_loop) {
      indent();
      body_ << "for " << useOf(stmt.currentTripCount()) << " in range("
            << useOf(stmt.maxTripCount()) << "):\n";
    } else {
      // note: trip_count_in_block is unused because this is a while loop,
      // so we reuse the Value* as a stand-in for the loop condition
      printAssignment(stmt.currentTripCount(), stmt.inputCond());
      indent();
      body_ << "while " << useOf(stmt.currentTripCount()) << ":\n";
    }
    // Loop body
    {
      ResourceGuard indent = WithIndented();
      // Update block outputs to block inputs for next loop iteration
      // skip the assignment to the new condition in for loops because
      // the condition is always True
      size_t offset = emit_as_for_loop ? 1 : 0;
      auto body_block = stmt.bodyBlock();
      ArrayRef<Value*> loop_carried_block_inputs =
          body_block->inputs().slice(offset);
      printBlock(body_block, !loop_carried_block_inputs.empty());
      printAssignment(
          loop_carried_block_inputs, body_block->outputs().slice(offset));
    }
  }
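
  // e.g. (illustrative) a For-typed loop prints as:
  //   x = x0
  //   for i in range(max_trip):
  //     x = torch.add(x, 1)
  // while a While-typed loop instead prints a condition assignment followed
  // by "while cond:".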

  bool isLongLine(const std::string& str) {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    return str.size() + level * 2 >= 40;
  }

  bool isLongInline(Node* node) {
    return output_inline_.count(node) &&
        isLongLine(useOf(node->output())->str());
  }

  bool isNonConstantInline(Value* input) {
    return input->node()->kind() != prim::Constant &&
        output_inline_.count(input->node());
  }

  // [reordering of inlines]
  // We inline anything that is semantically legal to inline, but sometimes
  // we find that these lines get too long. In that case we break the lines
  // and it is important that we un-inline all the inputs preceding the long
  // input:
  //   r = foo(x.add_(b), some_long + expression)
  // wrong:
  //   _0 = some_long + expression
  //   r = foo(x.add_(b), _0) # wrong! _0 runs before mutating add_
  // legal:
  //   _0 = x.add_(b)
  //   _1 = some_long + expression
  //   r = foo(_0, _1)

  void splitLongInlines(Value* v) {
    std::vector<Value*> to_split_reversed;
    Use u = v->uses().at(0);
    scanLongInlines(u.user, u.offset, to_split_reversed);
    for (auto it = to_split_reversed.rbegin(), end = to_split_reversed.rend();
         it != end;
         ++it) {
      printOutputDefinition((*it)->node(), *useOf(*it));
    }
  }

  void scanLongInlines(
      Node* user,
      size_t offset,
      std::vector<Value*>& to_split_reversed) {
    auto it = visited_split_inline_uses_.find(user);
    bool present = it != visited_split_inline_uses_.end();
    for (int64_t i = static_cast<int64_t>(offset);
         i >= (present ? it->second + 1 : 0);
         --i) {
      Value* prev_arg = user->input(i);
      if (isNonConstantInline(prev_arg)) {
        to_split_reversed.push_back(prev_arg);
      }
    }
    visited_split_inline_uses_[user] = static_cast<int64_t>(offset);
    if (!present && output_inline_.count(user)) {
      Use u = user->output()->uses().at(0);
      // -1 because the actual use is still being
      // emitted so it cannot be split
      scanLongInlines(u.user, u.offset - 1, to_split_reversed);
    }
  }

  template <typename T>
  void printOutputDefinition(Node* node, const T& expr) {
    assignValuesToTheirUniqueNames(node->outputs());
    indent();
    // Print outputs
    if (!node->outputs().empty()) {
      printValueList(body_, node->outputs());
      body_ << " = ";
    }
    body_ << expr << "\n";
  }

  // Recursively check contained types for any class dependencies
  void registerClassDependencies(const TypePtr& type) {
    if (const auto classType = type->cast<ClassType>()) {
      deps_table_.add(classType);
    } else if (const auto tupleType = type->cast<TupleType>()) {
      if (tupleType->name()) {
        deps_table_.add(tupleType);
      }
    } else if (const auto interfaceType = type->cast<InterfaceType>()) {
      deps_table_.add(interfaceType);
    } else if (const auto enumType = type->cast<EnumType>()) {
      deps_table_.add(enumType);
    }
    for (const auto& containedType : type->containedTypes()) {
      registerClassDependencies(containedType);
    }
  }

  void scanTypeDependencies(Node* node) {
    // Check for class dependencies. If this node inputs or outputs a class
    // type, we need to add it to our table of dependencies.
    for (const auto input : node->inputs()) {
      registerClassDependencies(input->type());
    }
    for (const auto output : node->outputs()) {
      registerClassDependencies(output->type());
    }
    for (const auto& name : node->attributeNames()) {
      switch (node->kindOf(name)) {
        case AttributeKind::ty:
          registerClassDependencies(node->ty(name));
          break;
        case AttributeKind::tys:
          for (const TypePtr& t : node->tys(name)) {
            registerClassDependencies(t);
          }
          break;
        default:
          // noop
          break;
      }
    }
  }

  void checkVersion(Node* node) {
    if (auto schema = node->maybeSchema()) {
      auto schema_name = getFullSchemaName(*schema);
      auto version_entry = get_operator_version_map().find(schema_name);
      if (version_entry != get_operator_version_map().end()) {
        const auto& entry = version_entry->second;
        // TODO (tugsuu) move this calculation into a separate step.
        uint64_t current_version = entry[entry.size() - 1].bumped_at_version;
        uint64_t legacy_version_map_version =
            get_min_version_for_kind(node->kind());

        // True means we solely calculate based on upgrader version
        if (get_version_calculator_flag()) {
          min_version_ = std::max(min_version_, current_version);
        } else {
          if (legacy_version_map_version != 0) {
            min_version_ = std::max(min_version_, legacy_version_map_version);
          } else {
            min_version_ = std::max(min_version_, current_version);
          }
        }
      }
    }
  }

  void printNode(Node* node, bool print_const) {
    WithSourceRange guard(&source_range_stack_, node);
    scanTypeDependencies(node);
    checkVersion(node);
    if (!print_const && node->kind() == prim::Constant)
      return;
    switch (node->kind()) {
      case prim::Return:
        if (enforce_importable_ && node->inputs().size() != 1) {
          throw(
              ErrorReport(node->sourceRange())
              << "Exportable methods must have a single return value. "
              << "Normal use of ScriptMethods should enforce this");
        }
        if (!node->inputs().empty()) {
          indent();
          body_ << "return ";
          printValueList(body_, node->inputs());
          body_ << "\n";
        }
        break;
      case prim::Loop:
        printLoop(LoopView(node));
        break;
      case prim::If:
        printIf(IfView(node));
        break;
      case prim::TupleUnpack:
      case prim::ListUnpack:
        assignValuesToTheirUniqueNames(node->outputs());
        indent();
        // TupleUnpack(unpacked) turns into an assignment op that forces
        // the unpack to be inserted when parsed back in:
        // a, b, = unpacked
        // a, = unpacked # trailing comma forces an unpack to happen
        if (!node->outputs().empty()) {
          printValueList(body_, node->outputs(), "", ", = ");
        }
        body_ << useOf(node->input()) << "\n";
        break;
      case prim::SetAttr: {
        const auto obj = node->inputs().at(0);
        const auto newVal = node->inputs().at(1);
        const auto type = obj->type()->expect<ClassType>();
        const auto& attrname = node->s(attr::name);
        indent();
        body_ << useOf(obj) << "." << attrname << " = " << useOf(newVal)
              << "\n";
      } break;
      case prim::fork: {
        // the subgraph gets emitted as another function
        auto name = genName("__forked_function");
        std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
        indent();
        body_ << "def " << name << "():\n";
        for (size_t i = 0; i < node->inputs().size(); ++i) {
          assignValue(graph->inputs().at(i), node->inputs().at(i));
        }
        printBody(graph->block());
        std::stringstream ss;
        ss << "fork(" << name << ")";
        printOutputDefinition(node, ss.str());
      } break;
      case prim::awaitable: {
        // the subgraph gets emitted as another function
        auto name = genName("__awaitable_function");
        auto graph = node->g(attr::Subgraph);
        indent();
        body_ << "def " << name << "():\n";
        for (size_t i = 0; i < node->inputs().size(); ++i) {
          assignValue(graph->inputs().at(i), node->inputs().at(i));
        }
        printBody(graph->block());
        std::stringstream ss;
        ss << "awaitable(" << name << ")";
        printOutputDefinition(node, ss.str());
      } break;
      case prim::Enter: {
        const auto in = node->inputs().at(0);
        const auto out = node->outputs().at(0);
        indent();
        body_ << "with " << useOf(in);
        if (!out->uses().empty()) {
          assignValue(out, genUniqueNameFor(out));
          body_ << " as " << useOf(out);
        }
        body_ << ":\n";
        level++;
      } break;
      case prim::Exit: {
        // If the previous node is a prim::Enter, the with block that
        // generated this Enter/Exit pair must have been empty.
        if (node->prev()->kind() == prim::Enter) {
          indent();
          body_ << "pass\n";
        }
        level--;
      } break;
      case prim::Closure: {
        if (enforce_importable_) {
          throw(
              ErrorReport(node->sourceRange())
              << "closures are not exportable");
        }
        assignValuesToTheirUniqueNames(node->outputs());
        auto name = useOf(node->output())->str();
        std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
        indent();
        body_ << "def " << name << "(";
        assignValuesToTheirUniqueNames(graph->inputs());
        for (size_t i = 0; i < graph->inputs().size(); ++i) {
          Value* v = graph->inputs().at(i);
          if (i > 0) {
            body_ << ", ";
          }
          body_ << useOf(v) << ": " << v->type()->annotation_str(type_printer_);
        }
        body_ << "):\n";
        printBody(graph->block());
      } break;
      case prim::ModuleContainerIndex: {
        const auto container = node->inputs().at(0);
        const auto key = node->inputs().at(1);
        const auto out = node->outputs().at(0);
        assignValuesToTheirUniqueNames(out);
        indent();
        body_ << useOf(out) << " : " << out->type()->annotation_str() << " = "
              << useOf(container) << "[" << useOf(key) << "]\n";
      } break;
      default:
        auto ss = std::make_shared<TaggedStringStream>(&source_range_stack_);
        printRHS(*ss, node);

        // we prevent long constants from inlining here.
        // it is not safe to do the same thing for non-constants here
        // because of [reordering of inlines]
        if (output_inline_.count(node) == 0 ||
            (node->kind() == prim::Constant && isLongLine(ss->str()))) {
          printOutputDefinition(node, *ss);
        } else {
          // this node is safe to inline, so assign the output value
          // to that expression directly
          assignValue(node->output(), ss);
          if (isLongLine(ss->str())) {
            splitLongInlines(node->output());
          }
        }
    }
  }

  static bool containsNonASCIIString(const IValue& val) {
    bool hasNonASCII = false;
    auto checkSubvalue = [&hasNonASCII](const IValue& val) {
      if (val.isString()) {
        // char's signedness is implementation-defined, likely
        // signed on x86 and unsigned on ARM. But as of C++11, it is
        // guaranteed to be two's complement. Therefore, converting to
        // signed char gives us a range of [-128, 127]. Thus, any
        // negative number is non-ASCII.
        for (signed char c : val.toStringRef()) {
          if (c < 0) {
            hasNonASCII = true;
            return true;
          }
        }
      }
      return false;
    };

    val.visit(checkSubvalue);
    return hasNonASCII;
  }

  void printConstant(TaggedStringStream& stmt, const IValue& v) {
    const auto customFormatter = [&](std::ostream& ss, const IValue& v) {
      if (v.isTensor() || containsNonASCIIString(v) || v.isObject()) {
        TORCH_INTERNAL_ASSERT(!v.type<c10::Type>()->is_module());
        ss << "CONSTANTS.c" << getOrAddConstant(v);
        return true;
      }

      auto type = v.type();
      if (auto dyn = type->castRaw<c10::DynamicType>()) {
        type = dyn->fallback();
      }
      if (v.isTuple() && type->expectRef<TupleType>().schema()) {
        // print the namedtuple constructor and let rest of tuple printing
        // continue
        ss << type->expectRef<TupleType>().annotation_str(type_printer_);
      }
      return false;
    };

    std::stringstream ss;
    v.repr(ss, customFormatter);
    stmt << ss.str();
  }

  void printOpName(TaggedStringStream& stmt, Symbol kind) {
    // Special overriding ops set that requires serializing differently to
    // preserve the original code semantics.
    // This will be handled more properly once we have namespace semantics
    // for serializing ops; for now these ops are hard-coded to ensure
    // consistency and avoid breaking BC in the future.
    const static std::unordered_map<Symbol, std::string> override_symbols = {
        {aten::backward, "torch.autograd.backward"},
        {aten::grad, "torch.autograd.grad"},
    };
    if (override_symbols.find(kind) != override_symbols.end()) {
      stmt << override_symbols.at(kind);
    } else if (kind.is_aten()) {
      // special case aten -> torch because we want to rename
      // the aten namespace, but this change will take more time
      // doing it here ensures we do not have to fix up archives later
      stmt << "torch." << kind.toUnqualString();
    } else {
      stmt << "ops." << kind.ns().toUnqualString() << "."
           << kind.toUnqualString();
    }
  }
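
  // e.g. aten::relu prints as "torch.relu", aten::backward as
  // "torch.autograd.backward", and a custom op my_ns::my_op as
  // "ops.my_ns.my_op" (illustrative).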

  // Prints the RHS value of a Node, e.g. `aten.add(x, y)`
  void printRHS(TaggedStringStream& stmt, Node* node) {
    switch (node->kind()) {
      case prim::PythonOp: {
        auto value = static_cast<const PythonOp*>(node);
        if (enforce_importable_) {
          throw(
              ErrorReport(node->sourceRange())
              << "Could not export Python function call '" << value->name()
              << "'. Remove calls to Python functions before export. "
              << "Did you forget to add @script or @script_method annotation? "
              << "If this is a nn.ModuleList, add it to __constants__");
        }
        std::stringstream scalars_stream;
        stmt << "^" << value->name();
        value->writeScalars(scalars_stream);
        stmt << scalars_stream.str();
        printValueList(stmt, node->inputs(), "(", ")");
      } break;
      case prim::Uninitialized: {
        stmt << "uninitialized("
             << node->output()->type()->annotation_str(type_printer_) << ")";
      } break;
      case prim::Constant: {
        if (node->outputs().size() == 1 &&
            node->output()->type()->kind() == TypeKind::FunctionType) {
          auto fn = node->output()->type()->expect<FunctionType>();
          deps_table_.add(fn);
          stmt << fn->annotation_str(type_printer_);
        } else if (!node->mustBeNone()) {
          IValue v = toIValue(node->output()).value();
          printConstant(stmt, v);
        } else {
          stmt << "None";
        }
      } break;
      case aten::ScalarImplicit:
      case aten::FloatImplicit:
      case aten::IntImplicit: {
        stmt << "annotate("
             << node->output()->type()->annotation_str(type_printer_) << ", "
             << useOf(node->input()) << ")";
      } break;
      case aten::Int: {
        printValueList(stmt, node->inputs(), "int(", ")");
      } break;
      case aten::Float: {
        printValueList(stmt, node->inputs(), "float(", ")");
      } break;
      case aten::Bool: {
        printValueList(stmt, node->inputs(), "bool(", ")");
      } break;
      case aten::str: {
        printValueList(stmt, node->inputs(), "str(", ")");
      } break;
      case aten::__getitem__: {
        printValueIndex(stmt, node->inputs());
      } break;
      case prim::Print: {
        printValueList(stmt, node->inputs(), "print(", ")");
      } break;
      case aten::sorted: {
        printValueList(stmt, node->inputs(), "sorted(", ")");
      } break;
      case prim::TupleConstruct: {
        if (auto qualname =
                node->output()->type()->expectRef<TupleType>().name()) {
          stmt << node->output()->type()->annotation_str(type_printer_);
        }
        printValueList(
            stmt, node->inputs(), "(", node->inputs().size() == 1 ? ",)" : ")");
      } break;
      case prim::TupleIndex: {
        stmt << "(" << useOf(node->inputs().at(0)) << ")["
             << useOf(node->inputs().at(1)) << "]";
      } break;
      case prim::TupleSlice: {
        stmt << "(" << useOf(node->input()) << ")[" << node->i(attr::beg) << ":"
             << node->i(attr::end) << "]";
      } break;
      case prim::ListConstruct: {
        ListTypePtr list_type = node->output()->type()->expect<ListType>();
        TypePtr elem_type = list_type->getElementType();
        // Empty lists must be annotated with their type so the compiler knows
        // what type is supposed to be inside them
        if (node->inputs().empty()) {
          stmt << "annotate("
               << node->output()->type()->annotation_str(type_printer_)
               << ", [])";
          // If we can't infer the type based on what's inside, explicitly
          // annotate it to disambiguate.
          // This happens for List[Tensor] vs. List[Optional[Tensor]]
        } else if (!elementTypeCanBeInferredFromMembers(elem_type)) {
          stmt << "annotate("
               << node->output()->type()->annotation_str(type_printer_) << ", ";
          printValueList(stmt, node->inputs(), "[", "]");
          stmt << ")";
          // Otherwise just print a list
        } else {
          printValueList(stmt, node->inputs(), "[", "]");
        }
      } break;
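      // e.g. (illustrative) an empty List[int] prints as
      // "annotate(List[int], [])", and a List[Optional[Tensor]] with
      // elements prints as "annotate(List[Optional[Tensor]], [x])".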
      case prim::DictConstruct: {
        auto dict_type = node->output()->type()->expect<DictType>();
        // There are cases where we must annotate the dict with an explicit type
        // to help the compiler out:
        //   - the dict is empty
        //   - the dict has potentially ambiguous element types
        //       (e.g. Tensor vs. Optional[Tensor])
        if (node->inputs().empty() ||
            !elementTypeCanBeInferredFromMembers(dict_type->getKeyType()) ||
            !elementTypeCanBeInferredFromMembers(dict_type->getValueType())) {
          stmt << "annotate("
               << node->output()->type()->annotation_str(type_printer_) << ", ";
          printDict(stmt, node->inputs());
          stmt << ")";
          // Otherwise just print a dict
        } else {
          printDict(stmt, node->inputs());
        }
      } break;
      case prim::CreateObject: {
        const auto classType = node->output()->type()->expect<ClassType>();
        stmt << classType->annotation_str(type_printer_) << ".__new__("
             << classType->annotation_str(type_printer_) << ")";
      } break;
      case prim::GetAttr: {
        const auto obj = node->inputs().at(0);
        const auto classType = obj->type()->expect<ClassType>();
        const auto& field = node->s(attr::name);
        if (isValidIdentifier(field)) {
          stmt << useOf(obj) << "." << field;
        } else {
          stmt << "getattr(" << useOf(obj) << ", ";
          std::stringstream field_stream;
          c10::printQuotedString(field_stream, field);
          stmt << field_stream.str() << ")";
        }
      } break;
      case prim::CallFunction: {
        stmt << useOf(node->inputs().at(0)) << "(";
        for (size_t i = 1; i < node->inputs().size(); i++) {
          stmt << useOf(node->inputs()[i]) << ", ";
        }
        stmt << ")";
      } break;
      case prim::CallMethod: {
        const auto& self = node->inputs().at(0);
        const auto& methodName = node->s(attr::name);
        stmt << "(" << useOf(self) << ")"
             << "." << methodName << "(";
        for (size_t i = 1; i < node->inputs().size(); i++) {
          stmt << useOf(node->inputs()[i]) << ", ";
        }
        stmt << ")";

        if (auto selfClass = self->type()->cast<ClassType>()) {
          deps_table_.add(selfClass);
          const Function& method = selfClass->getMethod(node->s(attr::name));
          TORCH_INTERNAL_ASSERT(
              method.qualname() ==
              QualifiedName(selfClass->name()->qualifiedName(), methodName));
        } else if (auto selfInterface = self->type()->cast<InterfaceType>()) {
          deps_table_.add(selfInterface);
        } else {
          TORCH_INTERNAL_ASSERT(
              false, "method call to unhandled type in serialization");
        }

      } break;
      case aten::_unwrap_optional: {
        printOpName(stmt, node->kind());
        stmt << "(";
        // we cannot recover the type of unwrap_optional(None)
        // using normal schema matching, so we route around this by rewriting
        // the call to unwrap_optional(annotated(Optional[T], None))
        if (node->input()->type()->isSubtypeOf(*NoneType::get()) ||
            node->input()->mustBeNone()) {
          auto input_type = OptionalType::create(node->output()->type());
          stmt << "annotate(" << input_type->annotation_str(type_printer_)
               << ", " << useOf(node->input()) << ")";
        } else {
          stmt << useOf(node->input());
        }
        stmt << ")";
      } break;
      // unchecked_unwrap_optional is no longer generated by the compiler,
      // but may end up here if it was first loaded from an old model and
      // re-saved. On re-save we upgrade it to an unchecked_cast, which is an
      // equivalent op
      case prim::unchecked_unwrap_optional:
      case prim::unchecked_cast: {
        stmt << "unchecked_cast("
             << node->output()->type()->annotation_str(type_printer_) << ", "
             << useOf(node->input()) << ")";
      } break;
      case prim::isinstance: {
        stmt << "isinstance(" << useOf(node->input()) << ", ";
        const auto& types = node->tys(attr::types);
        if (types.size() == 1) {
          stmt << types.at(0)->annotation_str(type_printer_);
        } else {
          // check multiple things, e.g. (str, list, int)
          stmt << "(";
          bool first = true;
          for (const TypePtr& typ : types) {
            if (!first) {
              stmt << ", ";
            }
            stmt << typ->annotation_str(type_printer_);
            first = false;
          }
          stmt << ")";
        }
        stmt << ")";
      } break;
      case prim::tolist: {
        stmt << "annotate("
             << node->output()->type()->annotation_str(type_printer_) << ", ";
        stmt << useOf(node->input(0)) << ".tolist()"
             << ")";
      } break;
      case prim::EnumValue:
        // Note: This CAN NOT be printed as the raw operator ops.prim.EnumValue
        // because its return type depends on the type of the enum and must be
        // further resolved, but ops.prim.EnumValue construction does not
        // provide such functionality.
        stmt << "(" << useOf(node->input()) << ").value";
        break;
      case prim::EnumName:
        stmt << "(" << useOf(node->input()) << ").name";
        break;
      default: {
        printOpName(stmt, node->kind());
        const FunctionSchema& schema = node->schema();
        stmt << "(";
        size_t num_schema_args = schema.arguments().size();

        // we only want to do this extra logic when necessary.
        if (num_schema_args > 0) {
          // calculate how many args are specified.
          // see (https://github.com/pytorch/pytorch/pull/56079) for more
          // details.
          auto specified_args =
              CalculateNecessaryArgs(schema.arguments(), node->inputs(), true);

          auto num_necessary = specified_args.first;
          auto num_out = specified_args.second;

          for (const auto i : c10::irange(static_cast<size_t>(num_necessary))) {
            if (i > 0)
              stmt << ", ";
            auto v = useOf(node->inputs().at(i));
            // print the kwarg name if it is a kwarg only argument.
            if (i < num_schema_args) {
              auto arg = schema.arguments().at(i);
              if (arg.kwarg_only()) {
                stmt << arg.name() << "=";
              }
            } else {
              // vararg functions like format can have extra arguments
              AT_ASSERT(schema.is_vararg());
            }
            stmt << *v;
          }

          // print out args
          for (size_t i = num_schema_args - num_out; i < num_schema_args; i++) {
            stmt << ", ";
            auto arg = schema.arguments().at(i);
            TORCH_INTERNAL_ASSERT(arg.is_out());
            // figure out the corresponding input at this index
            auto input_idx = node->inputs().size() - (num_schema_args - i);
            if (input_idx < node->inputs().size()) {
              stmt << arg.name() << "=" << *useOf(node->inputs().at(input_idx));
            }
          }
        }
        stmt << ")";
      } break;
    }
  }

  TaggedStringStream& printBlock(Block* root, bool block_has_other_statements) {
    // Python's quirky 'pass' syntax creates a bunch of places where we have to
    // check if this block would be empty. But not everything in a block is a
    // node. Sometimes if, loop, and return statements will follow this block
    // and block_has_other_statements == true.
    if (!block_has_other_statements &&
        root->nodes().begin() == root->nodes().end()) {
      indent();
      body_ << "pass\n";
    }
    for (auto* node : root->nodes()) {
      printNode(node, /*print_const=*/false);
    }
    return body_;
  }

  template <typename dtype>
  IValue createBroadList(dtype value, const int64_t& N) {
    c10::List<dtype> repeated;
    repeated.reserve(N);
    for (const auto i : c10::irange(N)) {
      (void)i; // Suppress unused variable warning
      repeated.push_back(value);
    }
    return repeated;
  }

  void printDefaultValue(
      const Argument& arg,
      TaggedStringStream& stmt,
      const IValue& value) {
    stmt << "=";
    // handle broadcasting lists
    if (arg.type()->kind() == ListType::Kind &&
        (value.isInt() || value.isDouble() || value.isBool())) {
      TORCH_INTERNAL_ASSERT(arg.N(), "expected a broadcasting list");
      if (value.isInt()) {
        printConstant(stmt, createBroadList<int64_t>(value.toInt(), *arg.N()));
      } else if (value.isBool()) {
        printConstant(stmt, createBroadList<bool>(value.toBool(), *arg.N()));
      } else if (value.isDouble()) {
        printConstant(
            stmt, createBroadList<double>(value.toDouble(), *arg.N()));
      }
    } else {
      printConstant(stmt, value);
    }
  }
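
  // e.g. (illustrative) a scalar int default 1 for a size-3 broadcasting
  // list argument prints as "=[1, 1, 1]".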

  void printBody(Block* body) {
    // we always print constants at the top of the function, in the order
    // in which they are used.
    std::vector<Node*> constants;
    buildConstantList(body, constants);

    // current graph is used to de-dup names within a single graph
    scanBlock(body);
    {
      auto guard = WithIndented();
      // Print initial constant table (most are just inlined into their use,
      // but some like long strings do get emitted)
      for (Node* n : constants) {
        printNode(n, /*print_const=*/true);
      }
      // Print body
      printBlock(body, !body->return_node()->inputs().empty());
      printNode(body->return_node(), /*print_const=*/false);
    }
  }

1352  public:
printFunctiontorch::jit::PythonPrintImpl1353   void printFunction(
1354       const Function& func,
1355       bool print_first_argument_type = true) {
1356     const FunctionSchema& schema = func.getSchema();
1357     Graph& graph = *toGraphFunction(func).graph();
1358     used_names_.clear(); // each graph can reuse local names
1359 
1360     WithSourceRange guard(&source_range_stack_, graph.param_node());
1361 
1362     indent();
1363     body_ << "def " << func.name() << "(";
1364     auto param_it = graph.inputs().begin();
1365     for (const Argument& arg : schema.arguments()) {
1366       registerClassDependencies(arg.type());
1367       std::string arg_name = genName(arg.name());
1368       if (param_it == graph.inputs().begin()) {
1369         // the first argument may omit its type when it is implied by context
1370         // the flag print_first_argument_type determines when to do this
1371         body_ << arg_name;
1372         if (print_first_argument_type) {
1373           body_ << ": " << arg.type()->annotation_str(type_printer_);
1374           annotated_unions_.insert(*param_it);
1375         }
1376       } else {
1377         body_ << ",\n    " << arg_name << ": "
1378               << arg.type()->annotation_str(type_printer_);
1379         annotated_unions_.insert(*param_it);
1380       }
1381       if (arg.default_value()) {
1382         printDefaultValue(arg, body_, *arg.default_value());
1383       }
1384       assignValue(*param_it++, arg_name);
1385     }
1386 
1387     const auto& returnType = schema.returns().at(0).type();
1388     body_ << ") -> " << returnType->annotation_str(type_printer_) << ":\n";
1389     registerClassDependencies(returnType);
1390 
1391     printBody(graph.block());
1392   }
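
  // Sketch of the emitted header, assuming a hypothetical method
  // `forward(Tensor x, int scale=1)` printed with
  // print_first_argument_type=false:
  //
  //   def forward(self,
  //       x: Tensor,
  //       scale: int=1) -> Tensor:
  //
  // Each argument after the first starts on a new line indented four
  // spaces, per the ",\n    " separator above.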

  void printMethod(const Function& func) {
    printFunction(func, /*print_first_argument_type=*/false);
  }

  PythonPrintImpl(
      std::vector<at::IValue>& constant_table,
      PrintDepsTable& deps_table,
      c10::TypePrinter type_printer,
      bool enforce_importable)
      : body_(&source_range_stack_),
        constant_table_(constant_table),
        deps_table_(deps_table),
        type_printer_(std::move(type_printer)),
        enforce_importable_(enforce_importable) {}

  void printClass(const ClassTypePtr& classType) {
    // If any of the methods are not graph functions, this indicates that
    // this class is a custom-bound C++ class. Skip serialization of this
    // class; we will depend on the ClassType being defined in the target
    // process.
    for (auto& method : classType->methods()) {
      if (!method->isGraphFunction()) {
        return;
      }
    }

    bool is_module = classType->is_module();
    body_ << "class " << classType->name()->name();
    if (is_module) {
      body_ << "(Module)";
    }

    body_ << ":\n";
    {
      const auto guard = WithIndented();
      size_t numAttrs = classType->numAttributes();
      // For modules, we need to print special information about the module's
      // attributes and parameters.
      if (is_module) {
        std::vector<std::string> params;
        std::vector<std::string> buffers;
        // Populate the __parameters__ field. This tells the importer which
        // attributes are parameters.
        for (const auto i : c10::irange(numAttrs)) {
          if (classType->is_parameter(i)) {
            params.push_back(classType->getAttributeName(i));
          }
          if (classType->is_buffer(i)) {
            buffers.push_back(classType->getAttributeName(i));
          }
        }
        indent();
        body_ << "__parameters__ = [";
        for (const auto& param : params) {
          body_ << "\"" << param << "\", ";
        }
        body_ << "]\n";

        indent();
        body_ << "__buffers__ = [";
        for (const auto& buffer : buffers) {
          body_ << "\"" << buffer << "\", ";
        }
        body_ << "]\n";
        auto forwardPreHooks = classType->getForwardPreHooks();
        if (!forwardPreHooks.empty()) {
          indent();
          body_ << "__forward_pre_hooks__ = [";
          for (const auto& pre_hook : forwardPreHooks) {
            body_ << "\"" << pre_hook->name() << "\", ";
          }
          body_ << "]\n";
        }

        auto forwardHooks = classType->getForwardHooks();
        if (!forwardHooks.empty()) {
          indent();
          body_ << "__forward_hooks__ = [";
          for (const auto& hook : forwardHooks) {
            body_ << "\"" << hook->name() << "\", ";
          }
          body_ << "]\n";
        }
      }

      for (const auto i : c10::irange(numAttrs)) {
        const auto& name = classType->getAttributeName(i);
        const auto& type = classType->getAttribute(i);
        registerClassDependencies(type);

        indent();

        // Handling for when the attribute name is not a valid Python
        // identifier. This happens for, e.g., ModuleList.
        if (!isValidIdentifier(name)) {
          if (i == 0) {
            // Initialize the annotations dict if necessary.
            body_ << "__annotations__ = []\n";
            indent();
          }
          // Print out a direct manipulation of the annotations dict, like:
          //   __annotations__["0"] = SomeType
          body_ << "__annotations__["
                << "\"" << name
                << "\"] = " << type->annotation_str(type_printer_) << "\n";
        } else {
          // Otherwise: just emit a python 3 attribute annotation, like:
          //   foo : SomeType
          body_ << name << " : " << type->annotation_str(type_printer_) << "\n";
        }
      }

      size_t numConstants = classType->numConstants();
      for (const auto i : c10::irange(numConstants)) {
        const auto& name = classType->getConstantName(i);
        IValue v = classType->getConstant(i);

        indent();
        body_ << name << " : "
              << "Final[" << v.type()->annotation_str(type_printer_) << "] = ";
        auto ss = std::make_shared<TaggedStringStream>(&source_range_stack_);
        printConstant(*ss, v);
        body_ << ss->str() << "\n";
      }

      // TODO fields
      for (auto& method : classType->methods()) {
        printFunction(*method);
      }
      std::set<std::string> already_printed;
      for (auto& hook : classType->getForwardHooks()) {
        if (already_printed.count(hook->name()) == 0) {
          already_printed.insert(hook->name());
          printFunction(*hook);
        }
      }
      for (auto& pre_hook : classType->getForwardPreHooks()) {
        if (already_printed.count(pre_hook->name()) == 0) {
          already_printed.insert(pre_hook->name());
          printFunction(*pre_hook);
        }
      }
    }
  }
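
  // Sketch of the class body this emits for a module (hypothetical names):
  //
  //   class MyModule(Module):
  //     __parameters__ = ["weight", ]
  //     __buffers__ = ["running_mean", ]
  //     weight : Tensor
  //     running_mean : Tensor
  //     def forward(self, ...) -> ...:
  //       ...
  //
  // Note the trailing `", "` after each list element, which the importer
  // tolerates.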

  void printNamedType(const c10::NamedTypePtr& type) {
    if (auto functionType = type->cast<FunctionType>()) {
      printFunction(*functionType->function());
    } else if (auto classType = type->cast<ClassType>()) {
      printClass(classType);
    } else if (auto tupleType = type->cast<TupleType>()) {
      TORCH_INTERNAL_ASSERT(tupleType->schema());
      body_ << "class " << tupleType->name()->name();
      body_ << "(NamedTuple):\n";
      {
        const auto guard = WithIndented();
        for (const auto& attr : tupleType->schema()->arguments()) {
          TORCH_INTERNAL_ASSERT(attr.type());
          indent();
          body_ << attr.name() << " : "
                << attr.type()->annotation_str(type_printer_) << "\n";
        }
      }
    } else if (auto interfaceType = type->cast<InterfaceType>()) {
      body_ << "class " << interfaceType->name()->name();
      if (interfaceType->is_module()) {
        body_ << "(ModuleInterface):\n";
      } else {
        body_ << "(Interface):\n";
      }
      {
        auto guard = WithIndented();
        for (const FunctionSchema& method : interfaceType->methods()) {
          indent();
          body_ << "def " << method.name() << "(self";
          TORCH_INTERNAL_ASSERT(
              !method.arguments().empty() &&
              method.arguments().at(0).name() == "self");
          for (const Argument& arg :
               at::ArrayRef<Argument>(method.arguments()).slice(1)) {
            const auto& arg_type = arg.type();
            registerClassDependencies(arg_type);
            body_ << ", " << arg.name() << ": "
                  << arg_type->annotation_str(type_printer_);
          }
          auto return_type = method.returns().at(0).type();
          registerClassDependencies(return_type);
          body_ << ") -> " << return_type->annotation_str(type_printer_)
                << ":\n";
          indent();
          body_ << "  pass\n";
        }
      }
    } else if (auto enumType = type->cast<EnumType>()) {
      body_ << "class " << enumType->qualifiedClassName().name() << "(Enum):\n";

      std::string value_wrapper = "";
      if (enumType->getValueType() == StringType::get()) {
        value_wrapper = "\"";
      }

      {
        auto guard = WithIndented();
        for (const auto& name_value : enumType->enumNamesValues()) {
          indent();
          body_ << name_value.first << " = " << value_wrapper
                << name_value.second << value_wrapper << "\n";
        }
      }
    } else {
      TORCH_INTERNAL_ASSERT(false, "Unhandled NamedType");
    }
  }
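
  // Sketches of what the NamedTuple and Enum branches above emit
  // (hypothetical names):
  //
  //   class Point(NamedTuple):
  //     x : float
  //     y : float
  //
  //   class Color(Enum):
  //     RED = "red"
  //     GREEN = "green"
  //
  // String-valued enums wrap each value in quotes via value_wrapper;
  // integer-valued enums print the number bare.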

  ~PythonPrintImpl() = default;

  TaggedStringStream body_;
  // When printing this node, is it safe to write it inline (i.e. without
  // assigning a temporary variable)?
  std::unordered_set<Node*> output_inline_;

  // see [reordering of inlines]
  // used to track parts of an inline statement we already scanned
  // for splitting long lines, so that we do not revisit them causing n^2
  // behavior. stores the maximum offset into inputs that has already been
  // scanned for the node.
  std::unordered_map<Node*, int64_t> visited_split_inline_uses_;

  // what valid identifiers are in use for the current function
  std::unordered_set<std::string> used_names_;

  // constants are written to this table, and given the name CONSTANTS.cN,
  // where N is the index into this table.
  std::vector<at::IValue>& constant_table_;

  // Any NamedTypes (classes, functions, NamedTuples) used are written to this
  // table.
  PrintDepsTable& deps_table_;

  // We need to preserve Union/Optional type annotations, but we should
  // only print the annotation on variable declaration (not on any
  // following uses). This set tracks the Value*s that we've already
  // printed with annotations.
  std::unordered_set<Value*> annotated_unions_;

  // A function that, given a named type, returns the correct string to
  // print for it.
  c10::TypePrinter type_printer_;

  // when we print this, should we error if the resulting output would
  // not be able to be reparsed?
  bool enforce_importable_;

  // The least version that supports all printed ops
  uint64_t min_version_ = caffe2::serialize::kMinSupportedFileFormatVersion;
};

PythonPrint::PythonPrint(
    std::vector<at::IValue>& constant_table,
    PrintDepsTable& deps_table,
    c10::TypePrinter type_printer,
    bool enforce_importable)
    : pImpl(std::make_shared<PythonPrintImpl>(
          constant_table,
          deps_table,
          std::move(type_printer),
          enforce_importable)) {}

void PythonPrint::printNamedType(const c10::NamedTypePtr& type) {
  pImpl->printNamedType(type);
}

void PythonPrint::printFunction(const Function& func) {
  pImpl->printFunction(func);
}

void PythonPrint::printMethod(const Function& func) {
  pImpl->printMethod(func);
}

std::string PythonPrint::str() const {
  return pImpl->body_.str();
}

const SourceRangeRecords& PythonPrint::ranges() const {
  return pImpl->body_.ranges();
}

uint64_t PythonPrint::minVersion() const {
  return pImpl->min_version_;
}
static std::vector<IValue> traverseIValueAndGetObjects(const IValue& ivalue) {
  std::vector<IValue> result;
  std::vector<IValue> stack;
  stack.emplace_back(ivalue);
  while (!stack.empty()) {
    IValue head = stack.back();
    stack.pop_back();
    if (head.isObject()) {
      result.push_back(head);
      auto obj = head.toObject();
      ClassTypePtr type = obj->type();
      if (type->hasMethod("__getstate__")) {
        Function& getstate = type->getMethod("__getstate__");
        stack.emplace_back(getstate({obj}));
      } else {
        for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
          stack.emplace_back(obj->getSlot(i));
        }
      }
    } else if (head.isGenericDict()) {
      // Inspect the popped element `head`, not the root `ivalue`; checking
      // `ivalue` here would re-expand the root's children on every
      // iteration.
      for (const auto& kv : head.toGenericDict()) {
        // skip key because key cannot be an object
        stack.emplace_back(kv.value());
      }
    } else if (head.isList()) {
      for (const auto& v : head.toList()) {
        stack.emplace_back(v);
      }
    } else if (head.isTuple()) {
      for (const auto& v : head.toTuple()->elements()) {
        stack.emplace_back(v);
      }
    }
  }
  return result;
}
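
// Illustrative trace of the traversal above: given a root module object
// whose attribute slots hold a tensor, a submodule, and a list of objects,
// the walk records the root, then the submodule, then each object inside
// the list; objects defining __getstate__ are expanded through their
// __getstate__ result instead of their raw slots.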

static std::optional<std::string> printType(
    const c10::Type& type,
    torch::jit::TypeNameUniquer& type_name_uniquer) {
  if (auto dyn = type.castRaw<c10::DynamicType>()) {
    return dyn->fallback()->annotation_str(
        [&](auto&& t) { return printType(t, type_name_uniquer); });
  }
  auto namedType = type.cast<c10::NamedType>();
  if (namedType && namedType->name()) {
    return type_name_uniquer.getUniqueName(namedType).qualifiedName();
  }
  return std::nullopt;
}

void jitModuleToPythonCodeAndConstants(
    const Module& module,
    ExtraFilesMap* jit_sources, // output
    std::vector<IValue>* constants // output
) {
  std::vector<IValue> objects = traverseIValueAndGetObjects(module._ivalue());
  std::unordered_set<c10::QualifiedName> visited;
  PrintDepsTable class_deps;
  TypeNameUniquer uniquer;
  auto type_printer = [&](const c10::Type& t) { return printType(t, uniquer); };

  // Group by prefix, because every prefix corresponds to a file.
  std::unordered_map<std::string, PythonPrint> grouped_by_prefix;
  for (const IValue& obj : objects) {
    ObjectPtr obj_ptr = obj.toObject();
    ClassTypePtr class_type = obj_ptr->type();
    class_deps.add(class_type);
  }

  for (size_t i = 0; i < class_deps.size(); ++i) {
    // note: PythonPrint may extend class_deps, so re-checking .size() is
    // necessary
    auto type = class_deps[i];
    auto qualname = uniquer.getUniqueName(type);
    std::string qualifier = qualname.prefix();
    auto pp_iter = grouped_by_prefix.find(qualifier);
    if (pp_iter == grouped_by_prefix.end()) {
      pp_iter = grouped_by_prefix
                    .emplace(
                        qualifier,
                        PythonPrint(
                            *constants,
                            class_deps,
                            type_printer,
                            /*enforce_importable=*/true))
                    .first;
    }
    pp_iter->second.printNamedType(type);
  }
  for (const auto& kv : grouped_by_prefix) {
    (*jit_sources)[kv.first] = kv.second.str();
  }
}
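
// Illustrative example of the grouping above: a class with unique qualified
// name `__torch__.my.module.MyModule` has prefix `__torch__.my.module`, so
// its definition is printed into the PythonPrint for that prefix and ends
// up under (*jit_sources)["__torch__.my.module"].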

} // namespace torch::jit