#include <torch/csrc/jit/serialization/python_print.h>

#include <algorithm>

#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <c10/util/irange.h>
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/versioned_symbols.h>
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/operator_upgraders/version_map.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>

using c10::QualifiedName;

namespace torch::jit {

static bool isValidIdentifierChar(char c, size_t pos) {
  return islower(c) || isupper(c) || c == '_' || (pos > 0 && isdigit(c));
}

static bool isValidIdentifier(const std::string& name) {
  if (name.empty())
    return false;
  for (const auto i : c10::irange(name.size())) {
    if (!isValidIdentifierChar(name[i], i))
      return false;
  }
  return true;
}

// some names are valid identifiers but off limits because
// they are keywords or namespaces used in the output
const static std::unordered_set<std::string> reserved_names = {
    // identifiers in the environment while parsing
    "_", // avoid the confusing unnamed _
    "as",
    "aten",
    "attribute",
    "CONSTANTS",
    "fork",
    "getattr",
    "inf",
    "nan",
    "infj",
    "nanj",
    "ops",
    "__torch__",
    // the python keywords
    "and",
    "assert",
    "async",
    "await",
    "break",
    "class",
    "continue",
    "def",
    "del",
    "elif",
    "else",
    "except",
    "False",
    "finally",
    "for",
    "from",
    "global",
    "if",
    "import",
    "in",
    "is",
    "lambda",
    "None",
    "nonlocal",
    "not",
    "or",
    "pass",
    "raise",
    "return",
    "True",
    "try",
    "while",
    "with",
    "yield",
    "uninitialized",
    "unchecked_cast",
};

// Helper to avoid duplicating class types
void PrintDepsTable::add(const c10::NamedTypePtr& type) {
  // Despite the linear search below, we avoid wasteful work by only
  // trying to insert each instance once.
  if (!non_unique_.insert(type).second) {
    return;
  }
  // Need to do an actual equality comparison, not a pointer equality. This is
  // because for some types (e.g. FunctionType), we may have multiple
  // TypePtrs that represent the same underlying thing.
  // TODO: this should really be swapped for something more efficient
  auto it = std::find_if(
      table_.cbegin(), table_.cend(), [&](const c10::NamedTypePtr& dep) {
        return *dep == *type;
      });

  if (it == table_.cend()) {
    table_.push_back(type);
  }
}

struct PythonPrintImpl {
  using SourceRangeStack = std::vector<SourceRange>;
  SourceRangeStack source_range_stack_ = {SourceRange()};

  struct WithSourceRange {
    explicit WithSourceRange(SourceRangeStack* stack, Node* n) : stack(stack) {
      TORCH_INTERNAL_ASSERT(stack);
      if (auto gen_source = n->sourceRange().findSourceRangeThatGenerated()) {
        stack->push_back(std::move(gen_source.value()));
      } else {
        stack->push_back(n->sourceRange());
      }
    }

    ~WithSourceRange() {
      stack->pop_back();
    }

    SourceRangeStack* stack;
  };

  class TaggedStringStream {
   public:
    TaggedStringStream(const SourceRangeStack* srs) : srs_(srs) {}

    TaggedStringStream& operator<<(const std::string& s) {
      // This prevents having redundant entries at the same offset,
      // which can happen for example in printValueList when begin
      // and end are the empty string.
      if (s.empty()) {
        return *this;
      }

      if (ranges_.empty() || ranges_.back().range != srs_->back()) {
        ranges_.emplace_back((size_t)oss_.tellp(), srs_->back());
      }
      oss_ << s;
      return *this;
    }

    TaggedStringStream& operator<<(const TaggedStringStream& rhs) {
      for (const auto& range : rhs.ranges_) {
        if (ranges_.empty() || ranges_.back().range != range.range) {
          ranges_.emplace_back((size_t)oss_.tellp() + range.bytes, range.range);
        }
      }
      oss_ << rhs.oss_.str();
      return *this;
    }

    // This overload is here to prevent people from shooting themselves in the
    // foot. I would be highly surprised if someone actually wanted to write
    // out the address of a TaggedStringStream in the pretty print.
    TaggedStringStream& operator<<(
        const std::shared_ptr<TaggedStringStream>& rhs) {
      (*this) << *rhs;
      return *this;
    }

    template <typename T>
    TaggedStringStream& operator<<(const T& t) {
      if (ranges_.empty() || ranges_.back().range != srs_->back()) {
        ranges_.emplace_back((size_t)oss_.tellp(), srs_->back());
      }
      oss_ << t;
      return *this;
    }

    std::string str() const {
      return oss_.str();
    }

    const std::vector<TaggedRange>& ranges() const {
      return ranges_;
    }

   private:
    std::ostringstream oss_;
    std::vector<TaggedRange> ranges_;
    const SourceRangeStack* srs_;
  };

  // scanValue, scanNode, scanBlock:
  // decide if it is safe to omit the output of a temporary variable,
  // and inline the expression into its use
  // we only do this if
  //   (1) it is a constant, or
  //   (2) the temporary is unnamed, has a single output, is used once,
  //       and would appear in the same order when the expression tree is
  //       reparsed.
  // The last case can be checked
  // because when we emit an expression tree in the parser,
  // we do a left-to-right postorder traversal of the expression tree (emit
  // children, then emit op). The reverse of this is a right-to-left preorder
  // traversal of the tree. By doing a right-to-left preorder traversal of the
  // inputs of a node, while also scanning the list of emitted nodes backward,
  // we can see if they line up with what would happen when we parse the node
  // as an expression. While they line up, we collapse them into an inline
  // expression.

  // The inductive step is that the right-most input should be produced by the
  // node immediately before the current node if it is in tree order.
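  //
  // For example (hypothetical names), given the IR
  //   %3 = aten::mul(%1, %2)
  //   %4 = aten::add(%0, %3)
  // the scan visits %4's inputs right to left: %3 is single-use and produced
  // by the node immediately before %4, so it is inlined and the pair prints
  // as a single expression:
  //   _4 = torch.add(_0, torch.mul(_1, _2))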

  bool canInline(Value* v) {
    Node* n = v->node();
    // there must be only 1 value; otherwise we need an assignment to handle
    // the multiple output values
    if (n->outputs().size() != 1)
      return false;
    // if it is used more than once, then we need a variable
    if (v->uses().size() != 1)
      return false;
    auto use = v->uses().at(0);
    // if it has a name set, then it was written as a variable so preserve
    // that, unless it is being fed directly to the end of the block,
    // in which case it is not as useful to give it a name just to return it
    if (v->hasDebugName() && use.user->kind() != prim::Return)
      return false;
    // don't try to inline control blocks
    if (!n->blocks().empty())
      return false;
    // if it is a loop-carried input, we need a variable
    // otherwise the condition or trip count may be emitted in the wrong order
    // w.r.t. to it
    if (use.user->kind() == prim::Loop && use.offset >= 2)
      return false;

    // subgraph may use this more than once, so disable inlining
    if (use.user->kind() == prim::fork || use.user->kind() == prim::rpc_async ||
        use.user->kind() == prim::rpc_sync ||
        use.user->kind() == prim::rpc_remote)
      return false;

    // isinstance appearing in an if expression
    // causes type refinement to occur, but we have
    // already handled the refinement and inserted cast
    // expressions. By not inlining it into the if condition,
    // we prevent it from happening again.
    if (v->node()->kind() == prim::isinstance) {
      return false;
    }

    return true;
  }

  // block_point is the current node in the reverse linear scan of the emitted
  // nodes. v is the current value in the tree traversal that may match with
  // block_point's output.
  Node* scanValue(Node* block_point, Value* v) {
    Node* n = v->node();
    AT_ASSERT(n->kind() == prim::Constant || output_inline_.count(n) == 0);

    if (n == block_point &&
        canInline(v)) { // the node must be at the expected point of the
                        // typical tree traversal
      // recursively see if we can inline the inputs to this input
      block_point = scanNode(block_point);
      output_inline_.insert(n);
    } else if (n->kind() == prim::Constant) {
      // constant nodes can always be inlined, we will de-dup them on parsing
      // and put them at the top of the function regardless
      output_inline_.insert(n);
    }
    return block_point;
  }

  Node* previousNonConstant(Node* n) {
    do {
      n = n->prev();
    } while (n->kind() == prim::Constant);
    return n;
  }

  Node* scanNode(Node* n) {
    // don't bother to scan nodes we have already determined to be inline
    if (output_inline_.count(n)) {
      return n;
    }
    for (auto b : n->blocks()) {
      scanBlock(b);
    }
    Node* block_point = previousNonConstant(n);
    for (auto it = n->inputs().rbegin(), end = n->inputs().rend(); it != end;
         ++it) {
      block_point = scanValue(block_point, *it);
    }
    return block_point;
  }

  void scanBlock(Block* b) {
    scanNode(b->return_node());
    for (auto node : b->nodes().reverse()) {
      scanNode(node);
    }
  }

  size_t getOrAddConstant(at::IValue val) {
    // XXX - N^2 warning. This code does the exact same thing as
    // ConstantPool, which is also N^2 in the size of the constants,
    // because it doesn't hash any information about the tensors.
    // We will probably need to optimize this at some point using hashing.
    if (val.isTensor()) {
      auto& t = val.toTensor();
      for (const auto i : c10::irange(constant_table_.size())) {
        if (!constant_table_[i].isTensor()) {
          continue;
        }
        auto& t2 = constant_table_[i].toTensor();
        if (t.options().type_equal(t2.options()) && t.equal(t2)) {
          return i;
        }
      }
    }
    constant_table_.emplace_back(std::move(val));
    return constant_table_.size() - 1;
  }
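  // For example, two tensor constants with equal values and equal options are
  // de-duplicated by getOrAddConstant into one table slot, so both uses print
  // as the same CONSTANTS.cN reference (N being the slot index).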

  std::unordered_set<Node*> seen_constants;
  void buildConstantList(Node* n, std::vector<Node*>& constants) {
    for (auto input : n->inputs()) {
      if (input->node()->kind() == prim::Constant &&
          seen_constants.count(input->node()) == 0) {
        constants.push_back(input->node());
        seen_constants.insert(input->node());
      }
    }
    for (auto b : n->blocks()) {
      buildConstantList(b, constants);
    }
  }
  void buildConstantList(Block* b, std::vector<Node*>& constants) {
    for (auto n : b->nodes())
      buildConstantList(n, constants);
    buildConstantList(b->return_node(), constants);
  }

  // get a new name unique across calls to debugName() and
  // anything we have used.
  std::unordered_map<std::string, size_t> next_id;

  std::string genNameImpl(
      const std::string& candidate,
      std::unordered_set<std::string>& used) {
    std::string name = candidate;
    while (used.count(name) || reserved_names.count(name)) {
      // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
      name = candidate + std::to_string(next_id[name]++);
    }
    used.insert(name);
    return name;
  }
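  // For instance (assuming no prior uses), the first request for "x" yields
  // "x"; once "x" is taken, later requests yield "x0", "x1", and so on. A
  // reserved candidate such as "for" is suffixed the same way, giving "for0".
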
  std::string genName(const std::string& candidate) {
    return genNameImpl(candidate, used_names_);
  }

  // unique names might not be valid identifiers,
  // force them to be by rewriting them
  static std::string makeValidIdentifier(const std::string& candidate) {
    std::stringstream ss;
    if (candidate.empty() || isdigit(candidate[0]))
      ss << "_";
    for (char c : candidate) {
      if (isupper(c) || islower(c) || isdigit(c) || c == '_')
        ss << c;
      else
        ss << '_';
    }
    return ss.str();
  }
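  // e.g. makeValidIdentifier("my.name") returns "my_name", and
  // makeValidIdentifier("0out") returns "_0out" (a leading digit is prefixed
  // with an underscore).
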
  // if we have to assign 'v' a name, what should it be?
  // use the debugName if it was set, otherwise generate a name.
  std::string genUniqueNameFor(Value* v) {
    return genName(
        v->hasDebugName() ? makeValidIdentifier(v->debugNameBase()) : "_");
  }

  // map from Value to how it should be printed at each use
  std::unordered_map<Value*, std::shared_ptr<TaggedStringStream>> expr_table_;
  std::unordered_map<Value*, std::string> ident_refs_;

  // NB: we MUST pass around the shared pointers to these streams by value.
  // There is an interaction in splitLongInlines where the string value for
  // both the RHS and the LHS of an expression are live at the same time,
  // however the value for the RHS is overwritten in the table.
  std::shared_ptr<TaggedStringStream> useOf(Value* v) const {
    // Ident refs take precedence over expression refs, since presence in
    // the ident ref table indicates we have already emitted a statement
    // assigning the given value.
    if (ident_refs_.count(v)) {
      auto rv = std::make_shared<TaggedStringStream>(&source_range_stack_);
      (*rv) << ident_refs_.at(v);
      return rv;
    }
    if (expr_table_.count(v)) {
      return expr_table_.at(v);
    }
    TORCH_INTERNAL_ASSERT(
        false,
        "Value (debug name: \"",
        v->debugName(),
        "\") was not present in either expressions table or ident refs table");
  }
  void assignValue(Value* v, const std::string& s) {
    ident_refs_[v] = s;
  }
  void assignValue(Value* v, std::shared_ptr<TaggedStringStream> s) {
    expr_table_[v] = std::move(s);
  }
  void assignValue(Value* v, Value* w) {
    assignValue(v, useOf(w));
  }
  void assignValuesToTheirUniqueNames(at::ArrayRef<Value*> values) {
    for (auto v : values) {
      assignValue(v, genUniqueNameFor(v));
    }
  }

  size_t level = 0;
  // indent to the current indent level (two spaces per level, matching
  // the level * 2 accounting in isLongLine)
  TaggedStringStream& indent() {
    for (const auto i : c10::irange(level)) {
      (void)i; // Suppress unused variable warning
      body_ << "  ";
    }
    return body_;
  }

  ResourceGuard WithIndented() {
    level++;
    return ResourceGuard([this] { level--; });
  }

  template <class T0, class T1, class F>
  void zipWith(at::ArrayRef<T0> list_a, at::ArrayRef<T1> list_b, F action)
      const {
    auto it_a = list_a.begin();
    auto it_b = list_b.begin();

    if (list_a.size() != list_b.size()) {
      AT_ERROR("Python printer expected 2 lists of same size");
    }

    for (; it_a != list_a.end(); ++it_a, ++it_b) {
      action(*it_a, *it_b);
    }
  }

  void printValueList(
      TaggedStringStream& stmt,
      at::ArrayRef<Value*> list,
      const char* begin = "",
      const char* end = "") {
    stmt << begin;
    auto delimiter = "";
    for (auto* value : list) {
      stmt << delimiter;
      stmt << useOf(value);
      delimiter = ", ";
    }
    stmt << end;
  }

  void printValueIndex(TaggedStringStream& stmt, at::ArrayRef<Value*> inputs) {
    const std::string val_name = useOf(inputs[0])->str();
    if (isValidIdentifier(val_name)) {
      stmt << val_name;
    } else {
      stmt << "(" << val_name << ")";
    }
    stmt << "[";
    stmt << useOf(inputs[1]);
    stmt << "]";
  }
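  // e.g. printValueIndex emits `x[i]` when the indexed value is a plain
  // identifier, but parenthesizes compound expressions: `(torch.add(a, b))[i]`.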

  void printDict(
      TaggedStringStream& stmt,
      at::ArrayRef<Value*> key_value_pairs,
      const char* begin = "{",
      const char* end = "}") {
    stmt << begin;
    auto delimiter = "";
    for (size_t i = 0; i < key_value_pairs.size(); i += 2) {
      stmt << delimiter;
      auto key = key_value_pairs[i];
      auto value = key_value_pairs[i + 1];

      stmt << useOf(key) << ": " << useOf(value);

      delimiter = ", ";
    }
    stmt << end;
  }

  void printAssignment(at::ArrayRef<Value*> lhs, at::ArrayRef<Value*> rhs) {
    if (lhs.empty()) {
      return;
    }
    indent();
    printValueList(body_, lhs);
    // We need to preserve Union/Optional type annotations, but only if
    // we're not assigning values as part of a tuple unpacking statement
    // (Python doesn't allow type annotations in multiple assignment)
    if (lhs.size() == 1) {
      Value* v = lhs.at(0);
      if (!annotated_unions_.count(v) && !expr_table_.count(v) &&
          (v->type()->kind() == UnionType::Kind ||
           v->type()->kind() == OptionalType::Kind)) {
        body_ << " : " << v->type()->annotation_str();
        annotated_unions_.insert(v);
      }
    }
    body_ << " = ";
    printValueList(body_, rhs);
    body_ << "\n";
  }
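  // For example, the first assignment to an Optional-typed value prints with
  // an annotation, e.g. `y : Optional[int] = x` (hypothetical names), while
  // later assignments to the same value print as plain `y = x`.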

  bool requiresAnnotation(Value* lhs, Value* rhs) {
    if (lhs->type()->kind() == UnionType::Kind ||
        lhs->type()->kind() == OptionalType::Kind) {
      return annotated_unions_.insert(lhs).second;
    } else {
      return *lhs->type() != *rhs->type();
    }
  }

  void printAnnotatedAssignment(
      at::ArrayRef<Value*> lhs,
      at::ArrayRef<Value*> rhs) {
    for (const auto i : c10::irange(lhs.size())) {
      indent();
      body_ << useOf(lhs[i]);
      if (requiresAnnotation(lhs[i], rhs[i])) {
        body_ << ": " << lhs[i]->type()->annotation_str(type_printer_);
      }
      body_ << " = " << useOf(rhs[i]) << "\n";
    }
  }

  void printIf(IfView stmt) {
    assignValuesToTheirUniqueNames(stmt.outputs());
    indent() << "if " << useOf(stmt.cond()) << ":\n";
    {
      auto guard = WithIndented();
      // Print node contents
      printBlock(stmt.thenBlock(), !stmt.outputs().empty());
      printAssignment(stmt.outputs(), stmt.thenOutputs());
    }
    indent() << "else:\n";
    {
      auto guard = WithIndented();
      printBlock(stmt.elseBlock(), !stmt.outputs().empty());
      printAssignment(stmt.outputs(), stmt.elseOutputs());
    }
  }

  void printLoop(LoopView stmt) {
    // Loop carried dependencies are handled by assigning their initial
    // values to the node->outputs() before the loop, and assigning
    // node->outputs() to the new values at the end of each trip.

    auto loop_type = stmt.loopType();
    if (loop_type == LoopView::ModifiedLoop) {
      throw(
          ErrorReport(stmt.node()->sourceRange())
          << "loop cannot be printed as python "
          << "because it has gone through an optimization "
          << "that combined while and for loops. File a bug");
    }

    bool emit_as_for_loop = loop_type == LoopView::For;

    assignValuesToTheirUniqueNames(stmt.carriedOutputs());
    // Add aliases for loop-carried dependencies
    zipWith(
        stmt.bodyCarriedInputs(), // Start at 1 to ignore trip count
        stmt.carriedOutputs(),
        [&](Value* block_input, Value* node_output) {
          assignValue(block_input, node_output);
        });

    // Print initial assignments of loop node outputs = loop node inputs
    printAnnotatedAssignment(stmt.carriedOutputs(), stmt.carriedInputs());

    assignValuesToTheirUniqueNames(stmt.currentTripCount());
    // Loop header
    if (emit_as_for_loop) {
      indent();
      body_ << "for " << useOf(stmt.currentTripCount()) << " in range("
            << useOf(stmt.maxTripCount()) << "):\n";
    } else {
      // note: trip_count_in_block is unused because this is a while loop,
      // so we reuse the Value* as a stand-in for the loop condition
      printAssignment(stmt.currentTripCount(), stmt.inputCond());
      indent();
      body_ << "while " << useOf(stmt.currentTripCount()) << ":\n";
    }
    // Loop body
    {
      ResourceGuard indent = WithIndented();
      // Update block outputs to block inputs for next loop iteration
      // skip the assignment to the new condition in for loops because
      // the condition is always True
      size_t offset = emit_as_for_loop ? 1 : 0;
      auto body_block = stmt.bodyBlock();
      ArrayRef<Value*> loop_carried_block_inputs =
          body_block->inputs().slice(offset);
      printBlock(body_block, !loop_carried_block_inputs.empty());
      printAssignment(
          loop_carried_block_inputs, body_block->outputs().slice(offset));
    }
  }
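  // A for-style Loop node prints roughly as (hypothetical names):
  //   y = y0
  //   for i in range(10):
  //     y = torch.add(y, 1)
  // A while-style loop prints the condition assignment first, then `while c:`.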

  bool isLongLine(const std::string& str) {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    return str.size() + level * 2 >= 40;
  }

  bool isLongInline(Node* node) {
    return output_inline_.count(node) &&
        isLongLine(useOf(node->output())->str());
  }

  bool isNonConstantInline(Value* input) {
    return input->node()->kind() != prim::Constant &&
        output_inline_.count(input->node());
  }

  // [reordering of inlines]
  // We inline anything that is semantically legal to inline, but sometimes
  // we find that these lines get too long. In that case we break the lines
  // and it is important that we un-inline all the inputs preceding the long
  // input:
  //   r = foo(x.add_(b), some_long + expression)
  // wrong!
  //   _0 = some_long + expression
  //   r = foo(x.add_(b), _0) # wrong! _0 runs before mutating add_
  // legal!
  //   _0 = x.add_(b)
  //   _1 = some_long + expression
  //   r = foo(_0, _1)

  void splitLongInlines(Value* v) {
    std::vector<Value*> to_split_reversed;
    Use u = v->uses().at(0);
    scanLongInlines(u.user, u.offset, to_split_reversed);
    for (auto it = to_split_reversed.rbegin(), end = to_split_reversed.rend();
         it != end;
         ++it) {
      printOutputDefinition((*it)->node(), *useOf(*it));
    }
  }

  void scanLongInlines(
      Node* user,
      size_t offset,
      std::vector<Value*>& to_split_reversed) {
    auto it = visited_split_inline_uses_.find(user);
    bool present = it != visited_split_inline_uses_.end();
    for (int64_t i = static_cast<int64_t>(offset);
         i >= (present ? it->second + 1 : 0);
         --i) {
      Value* prev_arg = user->input(i);
      if (isNonConstantInline(prev_arg)) {
        to_split_reversed.push_back(prev_arg);
      }
    }
    visited_split_inline_uses_[user] = static_cast<int64_t>(offset);
    if (!present && output_inline_.count(user)) {
      Use u = user->output()->uses().at(0);
      scanLongInlines(u.user, u.offset - 1, to_split_reversed);
      // -1 because the actual use is still being
      // emitted so it cannot be split
    }
  }

  template <typename T>
  void printOutputDefinition(Node* node, const T& expr) {
    assignValuesToTheirUniqueNames(node->outputs());
    indent();
    // Print outputs
    if (!node->outputs().empty()) {
      printValueList(body_, node->outputs());
      body_ << " = ";
    }
    body_ << expr << "\n";
  }

  // Recursively check contained types for any class dependencies
  void registerClassDependencies(const TypePtr& type) {
    if (const auto classType = type->cast<ClassType>()) {
      deps_table_.add(classType);
    } else if (const auto tupleType = type->cast<TupleType>()) {
      if (tupleType->name()) {
        deps_table_.add(tupleType);
      }
    } else if (const auto interfaceType = type->cast<InterfaceType>()) {
      deps_table_.add(interfaceType);
    } else if (const auto enumType = type->cast<EnumType>()) {
      deps_table_.add(enumType);
    }
    for (const auto& containedType : type->containedTypes()) {
      registerClassDependencies(containedType);
    }
  }
  void scanTypeDependencies(Node* node) {
    // Check for class dependencies. If this node inputs or outputs a class
    // type, we need to add it to our table of dependencies.
    for (const auto input : node->inputs()) {
      registerClassDependencies(input->type());
    }
    for (const auto output : node->outputs()) {
      registerClassDependencies(output->type());
    }
    for (const auto& name : node->attributeNames()) {
      switch (node->kindOf(name)) {
        case AttributeKind::ty:
          registerClassDependencies(node->ty(name));
          break;
        case AttributeKind::tys:
          for (const TypePtr& t : node->tys(name)) {
            registerClassDependencies(t);
          }
          break;
        default:
          // noop
          break;
      }
    }
  }

  void checkVersion(Node* node) {
    if (auto schema = node->maybeSchema()) {
      auto schema_name = getFullSchemaName(*schema);
      auto version_entry = get_operator_version_map().find(schema_name);
      if (version_entry != get_operator_version_map().end()) {
        const auto& entry = version_entry->second;
        // TODO (tugsuu) move this calculation into a separate step.
        uint64_t current_version = entry[entry.size() - 1].bumped_at_version;
        uint64_t legacy_version_map_version =
            get_min_version_for_kind(node->kind());

        // True means we solely calculate based on upgrader version
        if (get_version_calculator_flag()) {
          min_version_ = std::max(min_version_, current_version);
        } else {
          if (legacy_version_map_version != 0) {
            min_version_ = std::max(min_version_, legacy_version_map_version);
          } else {
            min_version_ = std::max(min_version_, current_version);
          }
        }
      }
    }
  }

  void printNode(Node* node, bool print_const) {
    WithSourceRange guard(&source_range_stack_, node);
    scanTypeDependencies(node);
    checkVersion(node);
    if (!print_const && node->kind() == prim::Constant)
      return;
    switch (node->kind()) {
      case prim::Return:
        if (enforce_importable_ && node->inputs().size() != 1) {
          throw(
              ErrorReport(node->sourceRange())
              << "Exportable methods must have a single return value. "
              << "Normal use of ScriptMethods should enforce this");
        }
        if (!node->inputs().empty()) {
          indent();
          body_ << "return ";
          printValueList(body_, node->inputs());
          body_ << "\n";
        }
        break;
      case prim::Loop:
        printLoop(LoopView(node));
        break;
      case prim::If:
        printIf(IfView(node));
        break;
      case prim::TupleUnpack:
      case prim::ListUnpack:
        assignValuesToTheirUniqueNames(node->outputs());
        indent();
        // TupleUnpack(unpacked) turns into an assignment op that forces
        // the unpack to be inserted when parsed back in:
        //   a, b, = unpacked
        //   a, = unpacked # trailing comma forces an unpack to happen
        if (!node->outputs().empty()) {
          printValueList(body_, node->outputs(), "", ", = ");
        }
        body_ << useOf(node->input()) << "\n";
        break;
      case prim::SetAttr: {
        const auto obj = node->inputs().at(0);
        const auto newVal = node->inputs().at(1);
        const auto type = obj->type()->expect<ClassType>();
        const auto& attrname = node->s(attr::name);
        indent();
        body_ << useOf(obj) << "." << attrname << " = " << useOf(newVal)
              << "\n";
      } break;
      case prim::fork: {
        // the subgraph gets emitted as another function
        auto name = genName("__forked_function");
        std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
        indent();
        body_ << "def " << name << "():\n";
        for (size_t i = 0; i < node->inputs().size(); ++i) {
          assignValue(graph->inputs().at(i), node->inputs().at(i));
        }
        printBody(graph->block());
        std::stringstream ss;
        ss << "fork(" << name << ")";
        printOutputDefinition(node, ss.str());
      } break;
      case prim::awaitable: {
        // the subgraph gets emitted as another function
        auto name = genName("__awaitable_function");
        auto graph = node->g(attr::Subgraph);
        indent();
        body_ << "def " << name << "():\n";
        for (size_t i = 0; i < node->inputs().size(); ++i) {
          assignValue(graph->inputs().at(i), node->inputs().at(i));
        }
        printBody(graph->block());
        std::stringstream ss;
        ss << "awaitable(" << name << ")";
        printOutputDefinition(node, ss.str());
      } break;
      case prim::Enter: {
        const auto in = node->inputs().at(0);
        const auto out = node->outputs().at(0);
        indent();
        body_ << "with " << useOf(in);
        if (!out->uses().empty()) {
          assignValue(out, genUniqueNameFor(out));
          body_ << " as " << useOf(out);
        }
        body_ << ":\n";
        level++;
      } break;
      case prim::Exit: {
        // If the previous node is a prim::Enter, the with block that
        // generated this Enter/Exit pair must have been empty.
        if (node->prev()->kind() == prim::Enter) {
          indent();
          body_ << "pass\n";
        }
        level--;
      } break;
      case prim::Closure: {
        if (enforce_importable_) {
          throw(
              ErrorReport(node->sourceRange())
              << "closures are not exportable");
        }
        assignValuesToTheirUniqueNames(node->outputs());
        auto name = useOf(node->output())->str();
        std::shared_ptr<Graph> graph = node->g(attr::Subgraph);
        indent();
        body_ << "def " << name << "(";
        assignValuesToTheirUniqueNames(graph->inputs());
        for (size_t i = 0; i < graph->inputs().size(); ++i) {
          Value* v = graph->inputs().at(i);
          if (i > 0) {
            body_ << ", ";
          }
          body_ << useOf(v) << ": " << v->type()->annotation_str(type_printer_);
        }
        body_ << "):\n";
        printBody(graph->block());
      } break;
      case prim::ModuleContainerIndex: {
        const auto container = node->inputs().at(0);
        const auto key = node->inputs().at(1);
        const auto out = node->outputs().at(0);
        assignValuesToTheirUniqueNames(out);
        indent();
        body_ << useOf(out) << " : " << out->type()->annotation_str() << " = "
              << useOf(container) << "[" << useOf(key) << "]\n";
      } break;
      default:
        auto ss = std::make_shared<TaggedStringStream>(&source_range_stack_);
        printRHS(*ss, node);

        // we prevent long constants from inlining here.
        // it is not safe to do the same thing for non-constants here
        // because of [reordering of inlines]
        if (output_inline_.count(node) == 0 ||
            (node->kind() == prim::Constant && isLongLine(ss->str()))) {
          printOutputDefinition(node, *ss);
        } else {
          // this node is safe to inline, so assign the output value
          // to that expression directly
          assignValue(node->output(), ss);
          if (isLongLine(ss->str())) {
            splitLongInlines(node->output());
          }
        }
    }
  }

  static bool containsNonASCIIString(const IValue& val) {
    bool hasNonASCII = false;
    auto checkSubvalue = [&hasNonASCII](const IValue& val) {
      if (val.isString()) {
        // char has implementation-defined signedness, likely signed on x86
        // and unsigned on ARM. But as of C++11, it is guaranteed to be
        // two's complement. Therefore, converting to signed char gives us a
        // range of [-128, 127]. Thus, any negative number is non-ASCII.
        for (signed char c : val.toStringRef()) {
          if (c < 0) {
            hasNonASCII = true;
            return true;
          }
        }
      }
      return false;
    };

    val.visit(checkSubvalue);
    return hasNonASCII;
  }

  void printConstant(TaggedStringStream& stmt, const IValue& v) {
    const auto customFormatter = [&](std::ostream& ss, const IValue& v) {
      if (v.isTensor() || containsNonASCIIString(v) || v.isObject()) {
        TORCH_INTERNAL_ASSERT(!v.type<c10::Type>()->is_module());
        ss << "CONSTANTS.c" << getOrAddConstant(v);
        return true;
      }

      auto type = v.type();
      if (auto dyn = type->castRaw<c10::DynamicType>()) {
        type = dyn->fallback();
      }
      if (v.isTuple() && type->expectRef<TupleType>().schema()) {
        // print the namedtuple constructor and let rest of tuple printing
        // continue
        ss << type->expectRef<TupleType>().annotation_str(type_printer_);
      }
      return false;
    };

    std::stringstream ss;
    v.repr(ss, customFormatter);
    stmt << ss.str();
  }

  void printOpName(TaggedStringStream& stmt, Symbol kind) {
    // A set of ops that must be serialized differently from the default in
    // order to preserve the original code semantics. This will be handled
    // more properly once we have namespace semantics for serializing ops;
    // for now these ops are hard-coded to ensure consistency and avoid
    // breaking BC in the future.
    const static std::unordered_map<Symbol, std::string> override_symbols = {
        {aten::backward, "torch.autograd.backward"},
        {aten::grad, "torch.autograd.grad"},
    };
    if (override_symbols.find(kind) != override_symbols.end()) {
      stmt << override_symbols.at(kind);
    } else if (kind.is_aten()) {
      // special case aten -> torch because we want to rename
      // the aten namespace, but this change will take more time
      // doing it here ensures we do not have to fix up archives later
      stmt << "torch." << kind.toUnqualString();
    } else {
      stmt << "ops." << kind.ns().toUnqualString() << "."
           << kind.toUnqualString();
    }
  }
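  // e.g. aten::relu prints as `torch.relu`, aten::backward as
  // `torch.autograd.backward` (an override), and a custom op such as
  // my_ops::foo (hypothetical) prints as `ops.my_ops.foo`.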

  // Prints the RHS value of a Node, e.g. `aten.add(x, y)`
  void printRHS(TaggedStringStream& stmt, Node* node) {
    switch (node->kind()) {
      case prim::PythonOp: {
        auto value = static_cast<const PythonOp*>(node);
        if (enforce_importable_) {
          throw(
              ErrorReport(node->sourceRange())
              << "Could not export Python function call '" << value->name()
              << "'. Remove calls to Python functions before export. "
              << "Did you forget to add @script or @script_method annotation? "
              << "If this is a nn.ModuleList, add it to __constants__");
        }
        std::stringstream scalars_stream;
        stmt << "^" << value->name();
        value->writeScalars(scalars_stream);
        stmt << scalars_stream.str();
        printValueList(stmt, node->inputs(), "(", ")");
      } break;
      case prim::Uninitialized: {
        stmt << "uninitialized("
             << node->output()->type()->annotation_str(type_printer_) << ")";
      } break;
      case prim::Constant: {
        if (node->outputs().size() == 1 &&
            node->output()->type()->kind() == TypeKind::FunctionType) {
          auto fn = node->output()->type()->expect<FunctionType>();
          deps_table_.add(fn);
          stmt << fn->annotation_str(type_printer_);
        } else if (!node->mustBeNone()) {
          IValue v = toIValue(node->output()).value();
          printConstant(stmt, v);
        } else {
          stmt << "None";
        }
      } break;
      case aten::ScalarImplicit:
      case aten::FloatImplicit:
      case aten::IntImplicit: {
        stmt << "annotate("
             << node->output()->type()->annotation_str(type_printer_) << ", "
             << useOf(node->input()) << ")";
      } break;
      case aten::Int: {
        printValueList(stmt, node->inputs(), "int(", ")");
      } break;
      case aten::Float: {
        printValueList(stmt, node->inputs(), "float(", ")");
      } break;
      case aten::Bool: {
        printValueList(stmt, node->inputs(), "bool(", ")");
      } break;
      case aten::str: {
        printValueList(stmt, node->inputs(), "str(", ")");
      } break;
      case aten::__getitem__: {
        printValueIndex(stmt, node->inputs());
      } break;
      case prim::Print: {
        printValueList(stmt, node->inputs(), "print(", ")");
      } break;
      case aten::sorted: {
        printValueList(stmt, node->inputs(), "sorted(", ")");
      } break;
      case prim::TupleConstruct: {
        if (auto qualname =
                node->output()->type()->expectRef<TupleType>().name()) {
          stmt << node->output()->type()->annotation_str(type_printer_);
        }
        printValueList(
            stmt, node->inputs(), "(", node->inputs().size() == 1 ? ",)" : ")");
      } break;
      case prim::TupleIndex: {
        stmt << "(" << useOf(node->inputs().at(0)) << ")["
             << useOf(node->inputs().at(1)) << "]";
      } break;
      case prim::TupleSlice: {
        stmt << "(" << useOf(node->input()) << ")[" << node->i(attr::beg) << ":"
             << node->i(attr::end) << "]";
      } break;
      case prim::ListConstruct: {
        ListTypePtr list_type = node->output()->type()->expect<ListType>();
        TypePtr elem_type = list_type->getElementType();
        // Empty lists must be annotated with their type so the compiler knows
        // what type is supposed to be inside them
        if (node->inputs().empty()) {
          stmt << "annotate("
               << node->output()->type()->annotation_str(type_printer_)
               << ", [])";
          // If we can't infer the type based on what's inside, explicitly
          // annotate it to disambiguate.
          // This happens for List[Tensor] vs. List[Optional[Tensor]]
        } else if (!elementTypeCanBeInferredFromMembers(elem_type)) {
          stmt << "annotate("
               << node->output()->type()->annotation_str(type_printer_) << ", ";
          printValueList(stmt, node->inputs(), "[", "]");
          stmt << ")";
          // Otherwise just print a list
        } else {
          printValueList(stmt, node->inputs(), "[", "]");
        }
      } break;
      case prim::DictConstruct: {
        auto dict_type = node->output()->type()->expect<DictType>();
        // There are cases where we must annotate the dict with an explicit
        // type to help the compiler out:
        //   - the dict is empty
        //   - the dict has potentially ambiguous element types
        //     (e.g. Tensor vs. Optional[Tensor])
        if (node->inputs().empty() ||
            !elementTypeCanBeInferredFromMembers(dict_type->getKeyType()) ||
            !elementTypeCanBeInferredFromMembers(dict_type->getValueType())) {
          stmt << "annotate("
               << node->output()->type()->annotation_str(type_printer_) << ", ";
          printDict(stmt, node->inputs());
          stmt << ")";
          // Otherwise just print a dict
        } else {
          printDict(stmt, node->inputs());
        }
      } break;
      case prim::CreateObject: {
        const auto classType = node->output()->type()->expect<ClassType>();
        stmt << classType->annotation_str(type_printer_) << ".__new__("
             << classType->annotation_str(type_printer_) << ")";
      } break;
      case prim::GetAttr: {
        const auto obj = node->inputs().at(0);
        const auto classType = obj->type()->expect<ClassType>();
        const auto& field = node->s(attr::name);
        if (isValidIdentifier(field)) {
          stmt << useOf(obj) << "." << field;
        } else {
          stmt << "getattr(" << useOf(obj) << ", ";
          std::stringstream field_stream;
          c10::printQuotedString(field_stream, field);
          stmt << field_stream.str() << ")";
        }
      } break;
      case prim::CallFunction: {
        stmt << useOf(node->inputs().at(0)) << "(";
        for (size_t i = 1; i < node->inputs().size(); i++) {
          stmt << useOf(node->inputs()[i]) << ", ";
        }
        stmt << ")";
      } break;
      case prim::CallMethod: {
        const auto& self = node->inputs().at(0);
        const auto& methodName = node->s(attr::name);
        stmt << "(" << useOf(self) << ")"
             << "." << methodName << "(";
        for (size_t i = 1; i < node->inputs().size(); i++) {
          stmt << useOf(node->inputs()[i]) << ", ";
        }
        stmt << ")";

        if (auto selfClass = self->type()->cast<ClassType>()) {
          deps_table_.add(selfClass);
          const Function& method = selfClass->getMethod(node->s(attr::name));
          TORCH_INTERNAL_ASSERT(
              method.qualname() ==
              QualifiedName(selfClass->name()->qualifiedName(), methodName));
        } else if (auto selfInterface = self->type()->cast<InterfaceType>()) {
          deps_table_.add(selfInterface);
        } else {
          TORCH_INTERNAL_ASSERT(
              false, "method call to unhandled type in serialization");
        }

      } break;
      case aten::_unwrap_optional: {
        printOpName(stmt, node->kind());
        stmt << "(";
        // we cannot recover the type of unwrap_optional(None)
        // using normal schema matching, so we route around this by rewriting
        // the call to unwrap_optional(annotate(Optional[T], None))
        if (node->input()->type()->isSubtypeOf(*NoneType::get()) ||
            node->input()->mustBeNone()) {
          auto input_type = OptionalType::create(node->output()->type());
          stmt << "annotate(" << input_type->annotation_str(type_printer_)
               << ", " << useOf(node->input()) << ")";
        } else {
          stmt << useOf(node->input());
        }
        stmt << ")";
      } break;
      // unchecked_unwrap_optional is no longer generated by the compiler,
      // but may end up here if it was first loaded from an old model and
      // re-saved. On re-save we upgrade it to an unchecked_cast, which is an
      // equivalent op
      case prim::unchecked_unwrap_optional:
      case prim::unchecked_cast: {
        stmt << "unchecked_cast("
             << node->output()->type()->annotation_str(type_printer_) << ", "
             << useOf(node->input()) << ")";
      } break;
      case prim::isinstance: {
        stmt << "isinstance(" << useOf(node->input()) << ", ";
        const auto& types = node->tys(attr::types);
        if (types.size() == 1) {
          stmt << types.at(0)->annotation_str(type_printer_);
        } else {
          // check multiple things, e.g. (str, list, int)
          stmt << "(";
          bool first = true;
          for (const TypePtr& typ : types) {
            if (!first) {
              stmt << ", ";
            }
            stmt << typ->annotation_str(type_printer_);
            first = false;
          }
          stmt << ")";
        }
        stmt << ")";
      } break;
      case prim::tolist: {
        stmt << "annotate("
             << node->output()->type()->annotation_str(type_printer_) << ", ";
        stmt << useOf(node->input(0)) << ".tolist()"
             << ")";
      } break;
      case prim::EnumValue:
        // Note: This CANNOT be printed as the raw operator ops.prim.EnumValue
        // because its return type depends on the type of the enum and must be
        // further resolved, but ops.prim.EnumValue construction does not
        // provide such functionality.
        stmt << "(" << useOf(node->input()) << ").value";
        break;
      case prim::EnumName:
        stmt << "(" << useOf(node->input()) << ").name";
        break;
      default: {
        printOpName(stmt, node->kind());
        const FunctionSchema& schema = node->schema();
        stmt << "(";
        size_t num_schema_args = schema.arguments().size();

        // we only want to do this extra logic when necessary.
        if (num_schema_args > 0) {
          // calculate how many args are specified.
          // see (https://github.com/pytorch/pytorch/pull/56079) for more
          // details.
          auto specified_args =
              CalculateNecessaryArgs(schema.arguments(), node->inputs(), true);

          auto num_necessary = specified_args.first;
          auto num_out = specified_args.second;

          for (const auto i : c10::irange(static_cast<size_t>(num_necessary))) {
            if (i > 0)
              stmt << ", ";
            auto v = useOf(node->inputs().at(i));
            // print the kwarg name if it is a kwarg-only argument.
            if (i < num_schema_args) {
              auto arg = schema.arguments().at(i);
              if (arg.kwarg_only()) {
                stmt << arg.name() << "=";
              }
            } else {
              // vararg functions like format can have extra arguments
              AT_ASSERT(schema.is_vararg());
            }
            stmt << *v;
          }

          // print out args
          for (size_t i = num_schema_args - num_out; i < num_schema_args; i++) {
            stmt << ", ";
            auto arg = schema.arguments().at(i);
            TORCH_INTERNAL_ASSERT(arg.is_out());
            // figure out the corresponding input at this index
            auto input_idx = node->inputs().size() - (num_schema_args - i);
            if (input_idx < node->inputs().size()) {
              stmt << arg.name() << "=" << *useOf(node->inputs().at(input_idx));
            }
          }
        }
        stmt << ")";
      } break;
    }
  }

  TaggedStringStream& printBlock(Block* root, bool block_has_other_statements) {
    // Python's weird 'pass' syntax creates a bunch of places where we have to
    // check if this block would be empty. But not everything in a block is a
    // node. Sometimes if, loop, and return statements will follow this block
    // and block_has_other_statements == true.
    if (!block_has_other_statements &&
        root->nodes().begin() == root->nodes().end()) {
      indent();
      body_ << "pass\n";
    }
    for (auto* node : root->nodes()) {
      printNode(node, /*print_const=*/false);
    }
    return body_;
  }

  template <typename dtype>
  IValue createBroadList(dtype value, const int64_t& N) {
    c10::List<dtype> repeated;
    repeated.reserve(N);
    for (const auto i : c10::irange(N)) {
      (void)i; // Suppress unused variable warning
      repeated.push_back(value);
    }
    return repeated;
  }

  void printDefaultValue(
      const Argument& arg,
      TaggedStringStream& stmt,
      const IValue& value) {
    stmt << "=";
    // handle broadcasting lists
    if (arg.type()->kind() == ListType::Kind &&
        (value.isInt() || value.isDouble() || value.isBool())) {
      TORCH_INTERNAL_ASSERT(arg.N(), "expected broadcastinglist");
      if (value.isInt()) {
        printConstant(stmt, createBroadList<int64_t>(value.toInt(), *arg.N()));
      } else if (value.isBool()) {
        printConstant(stmt, createBroadList<bool>(value.toBool(), *arg.N()));
      } else if (value.isDouble()) {
        printConstant(
            stmt, createBroadList<double>(value.toDouble(), *arg.N()));
      }
    } else {
      printConstant(stmt, value);
    }
  }
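  // For example, a scalar default of 1 for a broadcasting list argument of
  // size 2 expands via createBroadList<int64_t>(1, 2) and prints as `=[1, 1]`,
  // so an argument like `kernel_size: List[int]=[1, 1]` round-trips
  // (hypothetical argument name).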

  void printBody(Block* body) {
    // we always print constants at the top of the function, in the order
    // in which they are used.
    std::vector<Node*> constants;
    buildConstantList(body, constants);

    // current graph is used to de-dup names within a single graph
    scanBlock(body);
    {
      auto guard = WithIndented();
      // Print initial constant table (most are just inlined into their use,
      // but some like long strings do get emitted)
      for (Node* n : constants) {
        printNode(n, /*print_const=*/true);
      }
      // Print body
      printBlock(body, !body->return_node()->inputs().empty());
      printNode(body->return_node(), /*print_const=*/false);
    }
  }

 public:
  void printFunction(
      const Function& func,
      bool print_first_argument_type = true) {
    const FunctionSchema& schema = func.getSchema();
    Graph& graph = *toGraphFunction(func).graph();
    used_names_.clear(); // each graph can reuse local names

    WithSourceRange guard(&source_range_stack_, graph.param_node());

    indent();
    body_ << "def " << func.name() << "(";
    auto param_it = graph.inputs().begin();
    for (const Argument& arg : schema.arguments()) {
      registerClassDependencies(arg.type());
      std::string arg_name = genName(arg.name());
      if (param_it == graph.inputs().begin()) {
        // the first argument may omit its type when it is implied by context
        // the flag print_first_argument_type determines when to do this
        body_ << arg_name;
        if (print_first_argument_type) {
          body_ << ": " << arg.type()->annotation_str(type_printer_);
          annotated_unions_.insert(*param_it);
        }
      } else {
        body_ << ",\n    " << arg_name << ": "
              << arg.type()->annotation_str(type_printer_);
        annotated_unions_.insert(*param_it);
      }
      if (arg.default_value()) {
        printDefaultValue(arg, body_, *arg.default_value());
      }
      assignValue(*param_it++, arg_name);
    }

    const auto& returnType = schema.returns().at(0).type();
    body_ << ") -> " << returnType->annotation_str(type_printer_) << ":\n";
    registerClassDependencies(returnType);

    printBody(graph.block());
  }

  void printMethod(const Function& func) {
    printFunction(func, /*print_first_argument_type=*/false);
  }

  PythonPrintImpl(
      std::vector<at::IValue>& constant_table,
      PrintDepsTable& deps_table,
      c10::TypePrinter type_printer,
      bool enforce_importable)
      : body_(&source_range_stack_),
        constant_table_(constant_table),
        deps_table_(deps_table),
        type_printer_(std::move(type_printer)),
        enforce_importable_(enforce_importable) {}

  void printClass(const ClassTypePtr& classType) {
    // If any of the methods are not Graph functions, this indicates that
    // this class is a custom-bound C++ class. Skip serialization
    // of this class, we will depend on the ClassType being defined
    // in the target process.
    for (auto& method : classType->methods()) {
      if (!method->isGraphFunction()) {
        return;
      }
    }

    bool is_module = classType->is_module();
    body_ << "class " << classType->name()->name();
    if (is_module) {
      body_ << "(Module)";
    }

    body_ << ":\n";
    {
      const auto guard = WithIndented();
      size_t numAttrs = classType->numAttributes();
      // For modules, we need to print special information about the module's
      // attributes and parameters.
      if (is_module) {
        std::vector<std::string> params;
        std::vector<std::string> buffers;
        // Populate the __parameters__ field. This tells the importer which
        // attributes are parameters.
        for (const auto i : c10::irange(numAttrs)) {
          if (classType->is_parameter(i)) {
            params.push_back(classType->getAttributeName(i));
          }
          if (classType->is_buffer(i)) {
            buffers.push_back(classType->getAttributeName(i));
          }
        }
        indent();
        body_ << "__parameters__ = [";
        for (const auto& param : params) {
          body_ << "\"" << param << "\", ";
        }
        body_ << "]\n";

        indent();
        body_ << "__buffers__ = [";
        for (const auto& buffer : buffers) {
          body_ << "\"" << buffer << "\", ";
        }
        body_ << "]\n";
        auto forwardPreHooks = classType->getForwardPreHooks();
        if (!forwardPreHooks.empty()) {
          indent();
          body_ << "__forward_pre_hooks__ = [";
          for (const auto& pre_hook : forwardPreHooks) {
            body_ << "\"" << pre_hook->name() << "\", ";
          }
          body_ << "]\n";
        }

        auto forwardHooks = classType->getForwardHooks();
        if (!forwardHooks.empty()) {
          indent();
          body_ << "__forward_hooks__ = [";
          for (const auto& hook : forwardHooks) {
            body_ << "\"" << hook->name() << "\", ";
          }
          body_ << "]\n";
        }
      }

      for (const auto i : c10::irange(numAttrs)) {
        const auto& name = classType->getAttributeName(i);
        const auto& type = classType->getAttribute(i);
        registerClassDependencies(type);

        indent();

        // Handling for when the attribute name is not a valid Python
        // identifier. This happens for, e.g. ModuleList.
        if (!isValidIdentifier(name)) {
          if (i == 0) {
            // Initialize the annotations dict if necessary.
            body_ << "__annotations__ = []\n";
            indent();
          }
          // Print out a direct manipulation of the annotations dict, like:
          //   __annotations__["0"] = SomeType
          body_ << "__annotations__["
                << "\"" << name
                << "\"] = " << type->annotation_str(type_printer_) << "\n";
        } else {
          // Otherwise: just emit a python 3 attribute annotation, like:
          //   foo : SomeType
          body_ << name << " : " << type->annotation_str(type_printer_) << "\n";
        }
      }

      size_t numConstants = classType->numConstants();
      for (const auto i : c10::irange(numConstants)) {
        const auto& name = classType->getConstantName(i);
        IValue v = classType->getConstant(i);

        indent();
        body_ << name << " : "
              << "Final[" << v.type()->annotation_str(type_printer_) << "] = ";
        auto ss = std::make_shared<TaggedStringStream>(&source_range_stack_);
        printConstant(*ss, v);
        body_ << ss->str() << "\n";
      }

      // TODO fields
      for (auto& method : classType->methods()) {
        printFunction(*method);
      }
      std::set<std::string> already_printed;
      for (auto& hook : classType->getForwardHooks()) {
        if (already_printed.count(hook->name()) == 0) {
          already_printed.insert(hook->name());
          printFunction(*hook);
        }
      }
      for (auto& pre_hook : classType->getForwardPreHooks()) {
        if (already_printed.count(pre_hook->name()) == 0) {
          already_printed.insert(pre_hook->name());
          printFunction(*pre_hook);
        }
      }
    }
  }
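  // A printed module class looks roughly like this (hypothetical example):
  //   class MyModule(Module):
  //     __parameters__ = ["weight", ]
  //     __buffers__ = []
  //     weight : Tensor
  //     def forward(self, x: Tensor) -> Tensor:
  //       ...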

  void printNamedType(const c10::NamedTypePtr& type) {
    if (auto functionType = type->cast<FunctionType>()) {
      printFunction(*functionType->function());
    } else if (auto classType = type->cast<ClassType>()) {
      printClass(classType);
    } else if (auto tupleType = type->cast<TupleType>()) {
      TORCH_INTERNAL_ASSERT(tupleType->schema());
      body_ << "class " << tupleType->name()->name();
      body_ << "(NamedTuple):\n";
      {
        const auto guard = WithIndented();
        for (const auto& attr : tupleType->schema()->arguments()) {
          TORCH_INTERNAL_ASSERT(attr.type());
          indent();
          body_ << attr.name() << " : "
                << attr.type()->annotation_str(type_printer_) << "\n";
        }
      }
    } else if (auto interfaceType = type->cast<InterfaceType>()) {
      body_ << "class " << interfaceType->name()->name();
      if (interfaceType->is_module()) {
        body_ << "(ModuleInterface):\n";
      } else {
        body_ << "(Interface):\n";
      }
      {
        auto guard = WithIndented();
        for (const FunctionSchema& method : interfaceType->methods()) {
          indent();
          body_ << "def " << method.name() << "(self";
          TORCH_INTERNAL_ASSERT(
              !method.arguments().empty() &&
              method.arguments().at(0).name() == "self");
          for (const Argument& arg :
               at::ArrayRef<Argument>(method.arguments()).slice(1)) {
            const auto& arg_type = arg.type();
            registerClassDependencies(arg_type);
            body_ << ", " << arg.name() << ": "
                  << arg_type->annotation_str(type_printer_);
          }
          auto return_type = method.returns().at(0).type();
          registerClassDependencies(return_type);
          body_ << ") -> " << return_type->annotation_str(type_printer_)
                << ":\n";
          indent();
          body_ << "  pass\n";
        }
      }
    } else if (auto enumType = type->cast<EnumType>()) {
      body_ << "class " << enumType->qualifiedClassName().name() << "(Enum):\n";

      std::string value_wrapper = "";
      if (enumType->getValueType() == StringType::get()) {
        value_wrapper = "\"";
      }

      {
        auto guard = WithIndented();
        for (const auto& name_value : enumType->enumNamesValues()) {
          indent();
          body_ << name_value.first << " = " << value_wrapper
                << name_value.second << value_wrapper << "\n";
        }
      }
    } else {
      TORCH_INTERNAL_ASSERT(false, "Unhandled NamedType");
    }
  }
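  // e.g. a NamedTuple type prints as (hypothetical example):
  //   class Point(NamedTuple):
  //     x : float
  //     y : float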

  ~PythonPrintImpl() = default;

  TaggedStringStream body_;
  // When printing this node, is it safe to write it inline (i.e. without
  // assigning a temporary variable)?
  std::unordered_set<Node*> output_inline_;

  // see [reordering of inlines]
  // used to track parts of an inline statement we already scanned
  // for splitting long lines, so that we do not revisit them causing n^2
  // behavior. stores the maximum offset into inputs that has already been
  // scanned for the node.
  std::unordered_map<Node*, int64_t> visited_split_inline_uses_;

  // what valid identifiers are in use for the current function
  std::unordered_set<std::string> used_names_;

  // constants are written to this table and then referenced as CONSTANTS.cN,
  // where N is the index into this table.
  std::vector<at::IValue>& constant_table_;

  // Any NamedTypes (classes, functions, NamedTuples) used are written to this
  // table.
  PrintDepsTable& deps_table_;

  // We need to preserve Union/Optional type annotations, but we should
  // only print the annotation on variable declaration (not on any
  // following uses). This set tracks the Value*s that we've already
  // printed with annotations
  std::unordered_set<Value*> annotated_unions_;

  // A function that, given a named type, returns us the correct string to
  // print for it.
  c10::TypePrinter type_printer_;

  // when we print this, should we error if the resulting output would
  // not be able to be reparsed?
  bool enforce_importable_;

  // The least version that supports all printed ops
  uint64_t min_version_ = caffe2::serialize::kMinSupportedFileFormatVersion;
};

PythonPrint::PythonPrint(
    std::vector<at::IValue>& constant_table,
    PrintDepsTable& deps_table,
    c10::TypePrinter type_printer,
    bool enforce_importable)
    : pImpl(std::make_shared<PythonPrintImpl>(
          constant_table,
          deps_table,
          std::move(type_printer),
          enforce_importable)) {}
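
// A minimal usage sketch (assuming a graph function `fn` is in scope; a null
// type printer is an assumption here and falls back to the default type
// annotation strings):
//   std::vector<at::IValue> constants;
//   PrintDepsTable deps;
//   PythonPrint pp(constants, deps, /*type_printer=*/nullptr,
//                  /*enforce_importable=*/true);
//   pp.printFunction(fn);
//   std::string source = pp.str();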

void PythonPrint::printNamedType(const c10::NamedTypePtr& type) {
  pImpl->printNamedType(type);
}

void PythonPrint::printFunction(const Function& func) {
  pImpl->printFunction(func);
}

void PythonPrint::printMethod(const Function& func) {
  pImpl->printMethod(func);
}

std::string PythonPrint::str() const {
  return pImpl->body_.str();
}

const SourceRangeRecords& PythonPrint::ranges() const {
  return pImpl->body_.ranges();
}

uint64_t PythonPrint::minVersion() const {
  return pImpl->min_version_;
}

static std::vector<IValue> traverseIValueAndGetObjects(const IValue& ivalue) {
  std::vector<IValue> result;
  std::vector<IValue> stack;
  stack.emplace_back(ivalue);
  while (!stack.empty()) {
    IValue head = stack.back();
    stack.pop_back();
    if (head.isObject()) {
      result.push_back(head);
      auto obj = head.toObject();
      ClassTypePtr type = obj->type();
      if (type->hasMethod("__getstate__")) {
        Function& getstate = type->getMethod("__getstate__");
        stack.emplace_back(getstate({obj}));
      } else {
        for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
          stack.emplace_back(obj->getSlot(i));
        }
      }
    } else if (head.isGenericDict()) {
      for (const auto& kv : head.toGenericDict()) {
        // skip the key because a key cannot be an object
        stack.emplace_back(kv.value());
      }
    } else if (head.isList()) {
      for (const auto& v : head.toList()) {
        stack.emplace_back(v);
      }
    } else if (head.isTuple()) {
      for (const auto& v : head.toTuple()->elements()) {
        stack.emplace_back(v);
      }
    }
  }
  return result;
}

static std::optional<std::string> printType(
    const c10::Type& type,
    torch::jit::TypeNameUniquer& type_name_uniquer) {
  if (auto dyn = type.castRaw<c10::DynamicType>()) {
    return dyn->fallback()->annotation_str(
        [&](auto&& t) { return printType(t, type_name_uniquer); });
  }
  auto namedType = type.cast<c10::NamedType>();
  if (namedType && namedType->name()) {
    return type_name_uniquer.getUniqueName(namedType).qualifiedName();
  }
  return std::nullopt;
}

void jitModuleToPythonCodeAndConstants(
    const Module& module,
    ExtraFilesMap* jit_sources, // output
    std::vector<IValue>* constants // output
) {
  std::vector<IValue> objects = traverseIValueAndGetObjects(module._ivalue());
  std::unordered_set<c10::QualifiedName> visited;
  PrintDepsTable class_deps;
  TypeNameUniquer uniquer;
  auto type_printer = [&](const c10::Type& t) { return printType(t, uniquer); };

  // Group by prefix, because every prefix is a file.
  std::unordered_map<std::string, PythonPrint> grouped_by_prefix;
  for (const IValue& obj : objects) {
    ObjectPtr obj_ptr = obj.toObject();
    ClassTypePtr class_type = obj_ptr->type();
    class_deps.add(class_type);
  }

  for (size_t i = 0; i < class_deps.size(); ++i) {
    // note: PythonPrint may extend class_deps, so re-checking .size() is
    // necessary
    auto type = class_deps[i];
    auto qualname = uniquer.getUniqueName(type);
    std::string qualifier = qualname.prefix();
    auto pp_iter = grouped_by_prefix.find(qualifier);
    if (pp_iter == grouped_by_prefix.end()) {
      pp_iter = grouped_by_prefix
                    .emplace(
                        qualifier,
                        PythonPrint(
                            *constants,
                            class_deps,
                            type_printer,
                            /*enforce_importable=*/true))
                    .first;
    }
    pp_iter->second.printNamedType(type);
  }
  for (const auto& kv : grouped_by_prefix) {
    (*jit_sources)[kv.first] = kv.second.str();
  }
}

} // namespace torch::jit