#pragma once

#include <memory>
#include <unordered_map>
#include <vector>

#include <c10/util/irange.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/bailout_graph.h>
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/interpreter/preprocess_graph.h>

C10_DECLARE_bool(torch_jit_enable_expanded_stacks);

namespace torch::jit {

std::ostream& operator<<(std::ostream& out, Instruction inst);

namespace interpreter {

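// Narrow a wide integer operand into a smaller field, warning instead of
// hard-failing when the value does not survive the round trip. Used by
// insertInstruction below to pack 64-bit X and N operands into the 32-bit
// and 16-bit fields of Instruction.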
template <class Ttarget, class Tsource>
Ttarget safe_narrow_cast(Tsource v) {
  Ttarget res = static_cast<Ttarget>(v);
  // Cast it back to check whether it overflowed.
  if (static_cast<Tsource>(res) != v) {
    TORCH_WARN(
        "ATTENTION: your model computation is overflowing, safe_narrow_cast<>() failed");
    return v;
  }
  return res;
}

// BailoutBlocks are used to temporarily store
// instructions (typically, argument LOADs and TAIL_CALL)
// generated for prim::BailOut nodes
// before they are merged back into
// CodeImpl::instructions_ by insertBailoutBlocks
struct BailoutBlock {
  size_t jf_instruction_index; // this node gets patched to jump here on failure
  std::vector<Instruction> instructions; // ends in a TAIL_CALL

  explicit BailoutBlock(size_t jf_index) : jf_instruction_index(jf_index) {}
};

// RAII guard that tracks the node currently being emitted: it points
// *loc at new_value and restores the previous node on destruction.
struct WithCurrentNode {
  WithCurrentNode(Node** loc, Node* new_value) : loc_(loc), old_value_(*loc_) {
    *loc = new_value;
  }
  ~WithCurrentNode() {
    *loc_ = old_value_;
  }

 private:
  Node** loc_;
  Node* old_value_;
};

struct NodeSourceInfo {
  const char* func_name_;
  const char* file_name_;
  size_t line_;
  NodeSourceInfo() : func_name_(nullptr), file_name_(nullptr), line_(0) {}
};

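// CodeImpl lowers a Graph into the flat instruction list executed by the
// interpreter, along with the constant, operator, type, and function
// tables that those instructions index into.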
struct CodeImpl {
  friend struct InterpreterState;
  std::vector<Instruction> instructions_;

  const c10::unique_t node_stack_attr_symbol_ =
      static_cast<c10::unique_t>(attr::node_stack_idx);
  // Expanded inlined stacks; each entry holds the source-info frames for
  // one emitted instruction's inlined call stack.
  std::vector<std::vector<NodeSourceInfo>> expanded_node_stacks_;

  // same length as instructions.
  // what node in the graph caused this
  // instruction to be emitted?
  std::vector<Node*> instructions_source_;
  std::vector<IValue> constant_table_;
  std::vector<Operation> operator_table_;
#ifndef NDEBUG
  std::vector<Operator> full_operator_table_;
#endif
  // map<(op name, num inputs), index in operator table>, to avoid duplicates,
  // not including vararg operators
  std::unordered_map<
      std::pair<std::string, int>,
      int,
      std::function<size_t(const std::pair<std::string, int>& p)>>
      operator_table_inv_;
  std::vector<Function*> function_table_;
  std::vector<std::unique_ptr<GraphFunction>> forked_functions_;
  std::vector<std::unique_ptr<GraphFunction>> awaited_functions_;
  std::vector<TypePtr> type_table_;
  std::vector<std::function<void(std::vector<IValue>&)>>
      profile_function_table_;

  int register_size_ = 0;
  size_t n_outputs;
  size_t n_inputs;
  TypePtr return_type_;
  std::string function_name_;
  // We MUST hold onto the graph here because some Operators stored in the
  // instruction lists have dependencies on metadata stored in the graph
  // that would otherwise be dead.
  // It is also very useful for debugging interpreter problems to
  // keep this around.
  std::shared_ptr<Graph> graph_;
  std::optional<std::vector<GraphExecutor*>> grad_executors_;
  std::optional<std::vector<GraphExecutor*>> forward_executors_;
  PreprocessGraph preprocess_;

  // map from a Value (by its unique id) to its register in the register table
  std::unordered_map<Value*, int> value_to_reg_;

  // map from operator name to the number of specified arguments
  // Example: for a schema of aten::foo.str
  // aten::foo.str(arg0: str="default", arg1: int=0,
  //               arg2: bool=False, arg3: float=0.0)
  // If the usages in a graph are:
  //    aten::foo("somestr", arg1=0, arg2=True, arg3=0.0)
  //    aten::foo("somestr", arg1=1, arg2=False, arg3=0.0)
  // then op_to_num_specified_args_["aten::foo.str"] = 3,
  // because after trailing arguments that match their defaults are
  // trimmed, at most 3 args are used across all usages.
  std::unordered_map<std::string, size_t> op_to_num_specified_args_;

  std::unordered_map<std::string, size_t> op_to_num_out_args_;

  // running count of uses as we emit. When we reach use_count_[v] ==
  // v.uses().size() we know it is the final use and we can move rather than
  // load.
  std::unordered_map<Value*, size_t> use_count_;

  Node* current_node_; // used in creation of code to keep track
                       // of the node being emitted
  Node* last_inserted_op_ = nullptr;

  // out-of-line jumps for bailouts that are patched in at the end
  std::vector<BailoutBlock> bailout_blocks_;
  std::vector<std::unique_ptr<Function>> bailout_functions_;
  size_t remaining_bailout_depth_;

  CodeImpl(
      const std::shared_ptr<Graph>& graph,
      std::string function_name,
      size_t remaining_bailout_depth,
      bool emit_instructions = true)
      : operator_table_inv_(
            0,
            [](const std::pair<std::string, int>& p) {
              return std::hash<std::string>()(p.first) ^
                  std::hash<int>()(p.second);
            }),
        function_name_(std::move(function_name)),
        preprocess_(*graph),
        current_node_(preprocess_.graph->return_node()),
        remaining_bailout_depth_(remaining_bailout_depth) {
    graph_ = preprocess_.graph;
    n_outputs = graph_->outputs().size();
    if (n_outputs == 1) {
      return_type_ = graph->outputs().at(0)->type();
    } else {
      return_type_ = TupleType::create(
          fmap(graph->outputs(), [](const Value* v) { return v->type(); }));
    }
    n_inputs = graph_->inputs().size();
    if (emit_instructions) {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      run();
    }
  }

  virtual ~CodeImpl() = default;

  // Since a subclass of CodeImpl needs to populate
  // op_to_num_specified_args_ first, the calls that
  // change the internals of CodeImpl are separated
  // into this function.
  virtual void run() {
    emitCodeForBlock(graph_->block());
    insertInstruction(RET);
    // we deferred the emission of bailout blocks so they appear at the end
    // emit them now and patch up the jumps
    insertBailoutBlocks();
  }

  const std::vector<c10::IValue>& constant_table() const {
    return constant_table_;
  }

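  // Patch the index-th guard (counting both GUARDs and FAIL_GUARDs in
  // emission order, so indices stay stable across repeated requests) into
  // a FAIL_GUARD, forcing a bailout the next time it is reached.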
  void request_bailout(size_t index) {
    auto count = index;
    for (const auto instr_index : c10::irange(instructions_.size())) {
      if (instructions_[instr_index].op == GUARD ||
          instructions_[instr_index].op == FAIL_GUARD) {
        if (count-- == 0) {
          // patching GUARD to FAIL_GUARD
          instructions_[instr_index].op = FAIL_GUARD;
          GRAPH_DEBUG(
              "Added a bailout request for ",
              index,
              " at instruction ",
              instr_index);
          break;
        }
      }
    }
  }

  const std::vector<Instruction>& instructions() const {
    return instructions_;
  }

  const std::unordered_map<std::string, size_t>& op_to_num_specified_args()
      const {
    return op_to_num_specified_args_;
  }

  const std::vector<Node*>& instructions_source() const {
    return instructions_source_;
  }

  NodeSourceInfo getSourceInfoFromSourceRange(const SourceRange& range) {
    NodeSourceInfo nodeSource;
    SourceRange r = range;
    if (range.source()) {
      if (auto orig = range.source()->findSourceRangeThatGenerated(r)) {
        r = *orig;
      }
    }
    if (r.source()) {
      auto lineno = r.source()->lineno_for_offset(r.start());
      nodeSource.line_ = r.source()->lineno_to_source_lineno(lineno);
      if (r.source()->filename()) {
        nodeSource.file_name_ = r.source()->filename().value().c_str();
      }
    }
    return nodeSource;
  }

  void expandInlinedNodeStack(
      const InlinedCallStackPtr& cs,
      std::vector<NodeSourceInfo>* expandedstack) {
    auto nodeSourceInfo = getSourceInfoFromSourceRange(cs->source_range());
    nodeSourceInfo.func_name_ = cs->function_name().c_str();
    expandedstack->emplace_back(nodeSourceInfo);

    if (cs->callee()) {
      expandInlinedNodeStack(cs->callee().value(), expandedstack);
    }
  }

  void getNodeStack(
      const Node* node,
      std::vector<NodeSourceInfo>* expandedstack) {
    if (current_node_->callstack()) {
      expandInlinedNodeStack(current_node_->callstack().value(), expandedstack);
    }
    auto nodeSourceInfo = getSourceInfoFromSourceRange(node->sourceRange());
    expandedstack->emplace_back(nodeSourceInfo);
  }

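  // Append a single instruction; X and N are narrowed into the 32-bit and
  // 16-bit operand fields (see safe_narrow_cast above).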
  void insertInstruction(OpCode op, int64_t X = 0, uint64_t N = 0) {
    instructions_.emplace_back(
        op,
        safe_narrow_cast<int32_t, int64_t>(X),
        safe_narrow_cast<uint16_t, uint64_t>(N));
    instructions_source_.emplace_back(current_node_);

    if (FLAGS_torch_jit_enable_expanded_stacks &&
        !current_node_->hasAttribute(attr::node_stack_idx)) {
      std::vector<NodeSourceInfo> expandedStack;
      getNodeStack(current_node_, &expandedStack);
      auto insertIdx = expanded_node_stacks_.size();
      expanded_node_stacks_.emplace_back(expandedStack);
      current_node_->i_(attr::node_stack_idx, insertIdx);
    }

    // check that we didn't accidentally emit nodes out of topological order
    if (op == OP) {
      if (last_inserted_op_ != nullptr && current_node_ != last_inserted_op_ &&
          current_node_->owningBlock() == last_inserted_op_->owningBlock()) {
        TORCH_INTERNAL_ASSERT(
            current_node_->isAfter(last_inserted_op_),
            *current_node_,
            " is not after ",
            *last_inserted_op_);
      }
      last_inserted_op_ = current_node_;
    }
  }

  void truncateInstructions(size_t size) {
    while (instructions_.size() > size) {
      instructions_.pop_back();
      instructions_source_.pop_back();
    }
  }

  void createBailoutBlock(size_t jf_index) {
    bailout_blocks_.emplace_back(jf_index);
    auto& bailout_instructions = bailout_blocks_.back().instructions;

    bailout_instructions.insert(
        bailout_instructions.end(),
        instructions_.begin() + jf_index + 1,
        instructions_.end());
    truncateInstructions(jf_index + 1);
  }

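  // Allocate one fresh register per value and return the index of the
  // first one. Registers are 1-based; register 0 is never handed out.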
  int allocRegs(at::ArrayRef<Value*> vs) {
    int result = register_size_ + 1;
    for (Value* v : vs) {
      AT_ASSERT(value_to_reg_.count(v) == 0);
      value_to_reg_[v] = ++register_size_;
    }
    return result;
  }

  int registerFor(Value* v) {
    return value_to_reg_.at(v);
  }

  void emitUse(Value* input, bool drop) {
    // drop - if true, we are not actually going to use this thing
    // and we can short circuit doing many instructions here
    // by either clearing the register (DROPR) or just popping the stack
    // (DROP)
    if (preprocess_.can_emit_inline[input->node()]) {
      emitNode(input->node());
      if (drop) {
        insertInstruction(DROP);
      }
    } else {
      int reg = registerFor(input);
      bool moved = input->uses().size() == ++use_count_[input];

      OpCode op{};
      if (input->node()->kind() == prim::Constant) {
        op = LOADC;
      } else if (moved) {
        op = MOVE;
      } else {
        op = LOAD;
      }

      if (drop) {
        op = DROPR;
      }
      insertInstruction(op, reg);
    }
  }

  void emitLoadInputs(at::ArrayRef<Value*> inputs) {
    for (Value* input : inputs) {
      emitUse(input, false);
    }
  }

  void emitLoadInputs(at::ArrayRef<Value*> inputs, int num_include) {
    int count = 0;
    for (Value* input : inputs) {
      if (count < num_include) {
        emitUse(input, false);
        count++;
      }
    }
  }

  void emitLoadInputs(at::ArrayRef<Value*> inputs, size_t start, size_t end) {
    for (size_t i = start; i < end; i++) {
      emitUse(inputs[i], false);
    }
  }

  virtual void emitOperator(Node* node) {
    emitLoadInputs(node->inputs());
    const Operator& op = node->getOperator();
    int num_inputs = node->inputs().size();
    bool is_vararg = op.schema().is_vararg();

    int operation_index = add_to_operator_table(
        op,
        node,
        c10::toString(op.schema().operator_name()),
        num_inputs,
        is_vararg);

    if (op.hasOperation() && is_vararg) {
      insertInstruction(OPN, operation_index, num_inputs);
    } else {
      insertInstruction(OP, operation_index);
    }
  }

  void emitWait(Node* node) {
    emitLoadInputs(node->inputs());
    insertInstruction(WAIT);
  }

  void emitDrop(at::ArrayRef<Value*> to_drop) {
    for (Value* input : to_drop) {
      emitUse(input, true);
    }
  }

  void emitStoreOutputs(Node* node) {
    size_t N = node->outputs().size();
    if (N == 0) {
      return;
    }
    int regs = allocRegs(node->outputs());
    if (N == 1) {
      insertInstruction(STORE, regs);
    } else {
      insertInstruction(STOREN, regs, node->outputs().size());
    }
  }

  int insertConstant(IValue value) {
    int result = constant_table_.size();
    constant_table_.emplace_back(std::move(value));
    return result;
  }

  virtual void emitOperatorOrInstruction(
      Node* node,
      OpCode op,
      int64_t X = 0,
      uint64_t N = 0,
      bool emit_inputs = true) {
    if (emit_inputs) {
      emitLoadInputs(node->inputs());
    }
    insertInstruction(op, X, N);
  }

  void emitFormat(Node* node) {
    emitOperatorOrInstruction(node, FORMAT, node->inputs().size(), 0);
  }

  void checkNodeAndEmit(Node* node) {
    // check whether the node should be emitted as an instruction or an
    // operator
    const Operator& op = node->getOperator();
    std::string unique_op_name = c10::toString(op.schema().operator_name());
    if (unique_op_name.find("aten::__getitem__.Dict") == 0) {
      // the __getitem__ overload for Dict
      // needs to be emitted as an instruction
      emitOperatorOrInstruction(node, DICT_INDEX);
    } else {
      emitOperator(node);
    }
  }

  void emitConstant(Node* node) {
    if (node->output()->type()->kind() == FunctionType::Kind) {
      return;
    }
    // constants are just put in the constant table
    value_to_reg_[node->output()] =
        insertConstant(toIValue(node->output()).value());
  }

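  // Lowers prim::If into (roughly):
  //   <cond>; JF else; <then-block>; JMP end; else: <else-block>; end:
  // The JF and JMP offsets are back-patched once each block's size is
  // known.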
  void emitIf(Node* node) {
    emitLoadInputs(node->inputs());
    size_t start_if = instructions_.size();
    insertInstruction(JF, 0); // dummy offset to be filled in
    emitCodeForBlock(node->blocks().at(0));
    insertInstruction(JMP, 0); // dummy offset
    size_t start_else = instructions_.size();
    instructions_[start_if].X = start_else - start_if;
    emitCodeForBlock(node->blocks().at(1));
    instructions_[start_else - 1].X = instructions_.size() - (start_else - 1);
  }

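  // Lowers prim::Loop into (roughly):
  //   LOADC 0 (trip count); <inputs>; begin: LOOP end; <body>; JMP begin; end:
  // The LOOP exit offset is back-patched after the body is emitted.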
  void emitLoop(Node* loop) {
    insertInstruction(LOADC, insertConstant(0));
    emitLoadInputs(loop->inputs());
    size_t start = instructions_.size();
    insertInstruction(LOOP, 0, loop->inputs().size()); // dummy offset
    emitCodeForBlock(loop->blocks().at(0));
    insertInstruction(JMP, start - instructions_.size());
    instructions_[start].X = instructions_.size() - start;
  }

  void emitCall(Function* func, at::ArrayRef<Value*> inputs) {
    emitLoadInputs(inputs);
    insertInstruction(CALL, function_table_.size());
    function_table_.emplace_back(func);
  }

  void emitNodeAtBlockLevel(Node* node) {
    WithCurrentNode guard(&current_node_, node);
    switch (node->kind()) {
      case prim::Constant:
        emitConstant(node);
        break;
      case prim::Return:
        emitLoadInputs(node->inputs());
        break;
      default:
        if (!preprocess_.can_emit_inline[node]) {
          emitNode(node);
          emitStoreOutputs(node);
        }
        break;
    }
  }

  size_t emitType(TypePtr t) {
    size_t r = type_table_.size();
    type_table_.emplace_back(std::move(t));
    return r;
  }

  void emitTypeCheck(Node* node) {
    auto num_inputs = node->inputs().size();

    // Check that TypeCheck has at least one input.
    TORCH_INTERNAL_ASSERT(
        num_inputs && num_inputs + 1 == node->outputs().size());
    emitLoadInputs(node->inputs());

    // Emit the expected type.
    size_t types_start = type_table_.size();
    auto types = node->tys(attr::types);
    for (const auto i : c10::irange(num_inputs)) {
      emitType(types[i]);
    }
    insertInstruction(TYPECHECK, types_start, num_inputs);
  }

  size_t emitGuard(Node* node) {
    // unoptimized graph is at index 0
    // guarded input is at index 1
    // the rest of the args follow
    emitLoadInputs(node->inputs().slice(1, 1));
    insertInstruction(GUARD, emitType(node->outputs().at(0)->type()));
    insertInstruction(JF, 0 /* to be patched */);
    return instructions_.size() - 1;
  }

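  // Emit a guard plus a TAIL_CALL into a lazily built bailout function
  // that resumes execution in the unoptimized graph. Everything emitted
  // after the guard's JF is then moved into a BailoutBlock
  // (createBailoutBlock) and re-appended at the end of the instruction
  // stream (insertBailoutBlocks).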
  void emitBailOut(Node* node) {
    auto jf_index = emitGuard(node);
    auto unoptimized_graph = node->inputs().at(0)->node()->g(attr::Subgraph);
    // note, the guarded input is already loaded onto the stack
    // for the GUARD instruction
    emitLoadInputs(node->inputs().slice(2));
    insertInstruction(TAIL_CALL, function_table_.size());
    TORCH_INTERNAL_ASSERT(node->kind() == prim::BailOut);
    auto bailout_index = node->i(attr::index);
    TORCH_INTERNAL_ASSERT(bailout_index >= 0);

    auto build_bailout_graph = [bailout_index,
                                unoptimized_graph](GraphFunction& func) {
      BuildBailOutGraphFrom(bailout_index, unoptimized_graph, func.graph());
    };

    auto empty_graph = std::make_shared<Graph>();
    auto func = std::make_unique<GraphFunction>(
        "bailout", empty_graph, build_bailout_graph);
    function_table_.emplace_back(func.get());
    bailout_functions_.emplace_back(std::move(func));
    createBailoutBlock(jf_index);
  }

  void emitProfile(Node* node) {
    emitLoadInputs(node->inputs());
    insertInstruction(PROFILE_OP, profile_function_table_.size());
    if (node->cast<ProfileOp>()) {
      profile_function_table_.push_back(node->cast<ProfileOp>()->getCallback());
    } else if (node->cast<ProfileIValueOp>()) {
      profile_function_table_.push_back(
          node->cast<ProfileIValueOp>()->getCallback());
    } else {
      TORCH_INTERNAL_ASSERT(false);
    }
  }

  void emitGetAttr(Node* node) {
    emitLoadInputs(node->inputs());
    const auto type = node->input()->type()->expect<ClassType>();
    const auto& field = node->s(attr::name);
    const auto slot = type->getAttributeSlot(field);
    insertInstruction(GET_ATTR, slot);
  }

  void emitSetAttr(Node* node) {
    emitLoadInputs(node->inputs());
    const auto type = node->inputs().at(0)->type()->expect<ClassType>();
    const auto& field = node->s(attr::name);
    const auto slot = type->getAttributeSlot(field);
    insertInstruction(SET_ATTR, slot);
  }

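  // Append the deferred bailout blocks and patch each guarding JF to jump
  // to its block.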
  void insertBailoutBlocks() {
    for (const BailoutBlock& block : bailout_blocks_) {
      TORCH_INTERNAL_ASSERT(instructions_[block.jf_instruction_index].op == JF);
      instructions_[block.jf_instruction_index].X =
          instructions_.size() - block.jf_instruction_index;
      instructions_.insert(
          instructions_.end(),
          block.instructions.begin(),
          block.instructions.end());
      instructions_source_.insert(
          instructions_source_.end(),
          block.instructions.size(),
          instructions_source_[block.jf_instruction_index]);
    }
  }

  void emitInterfaceCall(
      std::string method_name_str,
      c10::ArrayRef<Value*> inputs) {
    emitLoadInputs(inputs);
    auto method_name = insertConstant(std::move(method_name_str));
    insertInstruction(INTERFACE_CALL, method_name, inputs.size());
  }

  void emitListUnpack(Node* node) {
    emitLoadInputs(node->inputs());
    insertInstruction(LIST_UNPACK, node->outputs().size());
  }

  void emitTupleConstruct(Node* node) {
    bool named =
        node->output()->type()->expectRef<TupleType>().name().has_value();
    if (named) {
      emitContainerConstruct(NAMED_TUPLE_CONSTRUCT, node);
    } else {
      emitLoadInputs(node->inputs());
      insertInstruction(TUPLE_CONSTRUCT, node->inputs().size());
    }
  }

  void emitContainerConstruct(OpCode op, Node* node) {
    emitLoadInputs(node->inputs());
    insertInstruction(
        op, emitType(node->output()->type()), node->inputs().size());
  }

  void emitCreateObject(Node* node) {
    insertInstruction(CREATE_OBJECT, emitType(node->output()->type()));
  }

  void emitIsinstance(Node* node) {
    emitLoadInputs(node->inputs());
    std::vector<TypePtr> types = node->tys(attr::types);
    size_t types_start = type_table_.size();
    for (const auto& typ : types) {
      emitType(typ);
    }
    insertInstruction(ISINSTANCE, types_start, types.size());
  }

  void emitTupleSlice(Node* node) {
    emitLoadInputs(node->inputs());
    int64_t beg_ind = node->i(attr::beg);
    int64_t end_ind = node->i(attr::end);
    insertInstruction(TUPLE_SLICE, beg_ind, end_ind - beg_ind);
  }

  void emitFork(Node* node) {
    emitLoadInputs(node->inputs());
    auto forked_fn = std::make_unique<GraphFunction>(
        "<forked function>", node->g(attr::Subgraph), nullptr);
    forked_functions_.emplace_back(std::move(forked_fn));
    function_table_.emplace_back(forked_functions_.back().get());
    insertInstruction(FORK, function_table_.size() - 1, node->inputs().size());
  }

  void emitAwaitable(Node* node) {
    emitLoadInputs(node->inputs());
    auto await_fn = std::make_unique<GraphFunction>(
        "<awaitable function>", node->g(attr::Subgraph), nullptr);
    awaited_functions_.emplace_back(std::move(await_fn));
    function_table_.emplace_back(awaited_functions_.back().get());
    insertInstruction(
        AWAITABLE, function_table_.size() - 1, node->inputs().size());
  }

  void emitWarn(Node* node) {
    if (FLAGS_torch_jit_disable_warning_prints) {
      return;
    }

    emitLoadInputs(node->inputs());
    int32_t idx = -1;
    if (node->hasAttribute(attr::warn_id)) {
      idx = static_cast<int32_t>(node->i(attr::warn_id));
    }
    insertInstruction(WARN, idx);
  }

  void emitEnter(Node* node) {
    emitLoadInputs(node->inputs());
    insertInstruction(ENTER);
  }

  void emitExit(Node* /* node */) {
    insertInstruction(EXIT);
  }

  void emitNode(Node* node) {
    WithCurrentNode guard(&current_node_, node);
    switch (node->kind()) {
      default:
        // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
        checkNodeAndEmit(node);
        // emitOperator(node);
        break;
      case prim::RaiseException:
        emitOperatorOrInstruction(node, RAISE_EXCEPTION);
        break;
      case prim::TupleIndex:
        emitOperatorOrInstruction(node, TUPLE_INDEX);
        break;
      case prim::Drop:
        emitDrop(node->inputs());
        break;
      case prim::Constant:
        emitConstant(node);
        break;
      case prim::If:
        emitIf(node);
        break;
      case prim::Loop:
        emitLoop(node);
        break;
      case aten::wait:
        emitWait(node);
        break;
      case prim::Param:
        break;
      case prim::CallFunction:
        emitCall(
            node->inputs().at(0)->type()->expectRef<FunctionType>().function(),
            node->inputs().slice(1));
        break;
      case prim::CallMethod:
        if (auto class_type = node->inputs().at(0)->type()->cast<ClassType>()) {
          emitCall(&class_type->getMethod(node->s(attr::name)), node->inputs());
        } else {
          emitInterfaceCall(node->s(attr::name), node->inputs());
        }
        break;
      case prim::TypeCheck:
        emitTypeCheck(node);
        break;
      case prim::BailOut:
        emitBailOut(node);
        break;
      case prim::profile_ivalue:
      case prim::profile:
        emitProfile(node);
        break;
      case prim::GetAttr:
        emitGetAttr(node);
        break;
      case prim::SetAttr:
        emitSetAttr(node);
        break;
      case prim::ListUnpack:
        emitListUnpack(node);
        break;
      case prim::TupleConstruct:
        emitTupleConstruct(node);
        break;
      case prim::ListConstruct:
        emitContainerConstruct(LIST_CONSTRUCT, node);
        break;
      case prim::DictConstruct:
        emitContainerConstruct(DICT_CONSTRUCT, node);
        break;
      case prim::CreateObject:
        emitCreateObject(node);
        break;
      case prim::isinstance:
        emitIsinstance(node);
        break;
      case prim::TupleSlice:
        emitTupleSlice(node);
        break;
      case prim::fork:
        emitFork(node);
        break;
      case prim::awaitable:
        emitAwaitable(node);
        break;
      case aten::warn:
        emitWarn(node);
        break;
      case prim::Enter:
        emitEnter(node);
        break;
      case prim::Exit:
        emitExit(node);
        break;
      case prim::Uninitialized:
        emitOperatorOrInstruction(node, UN_INITIALIZED, 0, 0, false);
        break;
      case prim::dtype:
        emitOperatorOrInstruction(node, DTYPE);
        break;
      case prim::device:
        emitOperatorOrInstruction(node, DEVICE);
        break;
      case aten::dim:
        emitOperatorOrInstruction(node, DIM);
        break;
      case prim::is_cuda:
        emitOperatorOrInstruction(node, IS_CUDA);
        break;
      case aten::__not__:
        emitOperatorOrInstruction(node, __NOT__);
        break;
      case aten::format:
        emitFormat(node);
        break;
      case aten::__is__:
        emitOperatorOrInstruction(node, __IS__);
        break;
      case aten::__isnot__:
        emitOperatorOrInstruction(node, __ISNOT__);
        break;
      case prim::NumToTensor:
        emitOperatorOrInstruction(node, NUM_TO_TENSOR);
        break;
      case prim::tolist:
        emitOperatorOrInstruction(node, TO_LIST);
        break;
    }
  }

  void emitCodeForBlock(Block* block) {
    emitNodeAtBlockLevel(block->param_node());
    for (auto node : block->nodes()) {
      emitNodeAtBlockLevel(node);
    }
    emitNodeAtBlockLevel(block->return_node());
  }

  const std::vector<GraphExecutor*>& grad_executors() {
    if (!grad_executors_) {
      grad_executors_.emplace();
      for (Operation& op : operator_table_) {
        if (auto executor = detail::getGradExecutor(op)) {
          grad_executors_->push_back(executor);
        }
      }
    }
    return *grad_executors_;
  }

  const std::vector<GraphExecutor*>& diff_graph_op_executors() {
    if (!forward_executors_) {
      forward_executors_.emplace();
      for (Operation& op : operator_table_) {
        if (auto executor = detail::getDifferentiableGraphOpExecutor(op)) {
          forward_executors_->push_back(executor);
        }
      }
    }
    return *forward_executors_;
  }

  void dump(std::ostream& out, size_t i) const {
    out << i << " " << instructions_[i];
    if (instructions_[i].op == OP || instructions_[i].op == CALL ||
        instructions_[i].op == OPN) {
      out << " # " << *instructions_source_[i];
    } else {
      out << "\n";
    }
  }

  void dump(std::ostream& out) const {
    out << *graph_ << "\n";
    for (const auto i : c10::irange(instructions_.size())) {
      dump(out, i);
    }
  }

  /**
   * Add an operation to operator_table_ if not a duplicate and return its
   * index.
   */
  int add_to_operator_table(
      const Operator& op,
      const Node* node,
      const std::string& op_name,
      const int num_inputs,
      const bool is_vararg) {
    int size = operator_table_.size();

    const Operation& oper = op.getOperation(node);

    if (!is_vararg) {
      std::pair<std::string, int> key(op_name, num_inputs);
      auto found = operator_table_inv_.find(key);

      if (found != operator_table_inv_.end()) {
        return found->second;
      }

      operator_table_inv_.emplace(key, size);
    }

    operator_table_.emplace_back(oper);
#ifndef NDEBUG
    full_operator_table_.emplace_back(op);
#endif
    return size;
  }

  inline void assert_stack_size(
      int32_t instruction_index,
      size_t init_size,
      size_t actual_size) const {
#ifndef NDEBUG
    const auto& schema = full_operator_table_[instruction_index].schema();
    int64_t expected_size = static_cast<int64_t>(init_size) -
        static_cast<int64_t>(schema.arguments().size()) +
        static_cast<int64_t>(schema.returns().size());
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        static_cast<size_t>(expected_size) == actual_size ||
            schema.is_varret() || schema.is_vararg(),
        "Expected to find ",
        expected_size,
        " values on the stack, but found ",
        actual_size,
        " on the stack after ",
        toString(full_operator_table_[instruction_index].schema()));
#endif
  }
};

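// CodeImpl variant used when compiling bytecode for the mobile (lite)
// interpreter. The three flags at the bottom gate behaviors tied to
// bytecode version bumps: emitting defaulted inputs, handling default
// args that precede out args, and emitting promoted instructions in
// place of full operator calls.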
struct MobileCodeImpl : CodeImpl {
  MobileCodeImpl(
      const std::shared_ptr<Graph>& graph,
      std::string function_name,
      bool emit_default_input_instructions,
      bool support_default_args_before_out,
      bool emit_promoted_ops,
      size_t remaining_bailout_depth)
      : CodeImpl(graph, function_name, remaining_bailout_depth, false),
        emit_default_input_instructions_(emit_default_input_instructions),
        support_default_args_before_out_(support_default_args_before_out),
        emit_promoted_ops_(emit_promoted_ops) {
    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
    run();
  }

  void run() override {
    process_ops_for_mobile();
    emitCodeForBlock(graph_->block());
    insertInstruction(RET);
    // we deferred the emission of bailout blocks so they appear at the end
    // emit them now and patch up the jumps
    insertBailoutBlocks();
  }

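  // Walk the graph once and, for each non-vararg op, record the maximum
  // number of arguments any call site actually needs to specify (and the
  // number of out args). emitOperator below uses these counts to avoid
  // emitting trailing defaulted inputs.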
  void process_ops_for_mobile() {
    DepthFirstGraphNodeIterator graph_it(graph_);
    Node* node = graph_it.next();
    while (node) {
      if (node->maybeOperator()) {
        auto op_schema = node->getOperator().schema();
        // skip if schema has vararg
        if (!op_schema.is_vararg()) {
          auto specifiedArgs = CalculateNecessaryArgs(
              op_schema.arguments(),
              node->inputs(),
              support_default_args_before_out_);

          size_t numInclude = specifiedArgs.first +
              (support_default_args_before_out_ ? specifiedArgs.second : 0);
          auto unique_name = !op_schema.overload_name().empty()
              ? op_schema.name() + "." + op_schema.overload_name()
              : op_schema.name();
          auto it = op_to_num_specified_args_.insert(
              std::pair<std::string, size_t>(unique_name, 0));
          op_to_num_out_args_.insert(std::pair<std::string, size_t>(
              unique_name, specifiedArgs.second));
          auto prev_value = it.first->second;
          it.first->second = std::max(numInclude, prev_value);
        }
      }
      node = graph_it.next();
    }
  }

 private:
  void emitOperator(Node* node) override {
    if (emit_default_input_instructions_) {
      CodeImpl::emitOperator(node);
    } else {
      const Operator& op = node->getOperator();
      std::string unique_op_name = c10::toString(op.schema().operator_name());
      int num_inputs = node->inputs().size();
      bool is_vararg = op.schema().is_vararg();

      if (op.hasOperation() && is_vararg) {
        emitLoadInputs(node->inputs());
        int operation_index = add_to_operator_table(
            op,
            node,
            unique_op_name,
            num_inputs,
            /* is_vararg */ true);
        insertInstruction(OPN, operation_index, num_inputs);
      } else {
        auto num_include = num_inputs;
        auto it = op_to_num_specified_args_.find(unique_op_name);
        if (it != op_to_num_specified_args_.end()) {
          num_include = it->second;
        }
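        // Load the explicitly specified non-out args first, then the
        // trailing out args; defaulted args in between are not emitted.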
        if (support_default_args_before_out_) {
          auto num_out = op_to_num_out_args_.find(unique_op_name)->second;
          auto num_specified_before_out = num_include - num_out;
          emitLoadInputs(node->inputs(), 0, num_specified_before_out);
          emitLoadInputs(
              node->inputs(),
              node->inputs().size() - num_out,
              node->inputs().size());
        } else {
          emitLoadInputs(node->inputs(), num_include);
        }
        int operation_index = add_to_operator_table(
            op, node, unique_op_name, num_inputs, is_vararg);
        insertInstruction(OP, operation_index);
      }
    }
  }

  void emitOperatorOrInstruction(
      Node* node,
      OpCode op,
      int64_t X = 0,
      uint64_t N = 0,
      bool emit_inputs = true) override {
    if (emit_promoted_ops_) {
      CodeImpl::emitOperatorOrInstruction(node, op, X, N, emit_inputs);
    } else {
      CodeImpl::emitOperator(node);
    }
  }

  // To support forward compatibility for bytecode version bump from v5 to v6
  bool emit_default_input_instructions_;
  // To support forward compatibility for bytecode version bump from v6 to v7
  bool support_default_args_before_out_;
  // To support forward compatibility for bytecode version bump from v7 to v8
  bool emit_promoted_ops_;
};

} // namespace interpreter
} // namespace torch::jit