#pragma once

#include <memory>
#include <unordered_map>
#include <vector>

#include <c10/util/irange.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/bailout_graph.h>
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/interpreter/preprocess_graph.h>

C10_DECLARE_bool(torch_jit_enable_expanded_stacks);

namespace torch::jit {

std::ostream& operator<<(std::ostream& out, Instruction inst);

namespace interpreter {

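// Narrowing cast used when packing instruction operands into the smaller
// integer fields of Instruction. The value is round-tripped through the
// target type; if it does not survive the trip, a warning is emitted and the
// value is returned anyway (it is still narrowed by the return type).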
template <class Ttarget, class Tsource>
Ttarget safe_narrow_cast(Tsource v) {
  Ttarget res = static_cast<Ttarget>(v);
  // Cast it back to check whether it overflowed.
  if (static_cast<Tsource>(res) != v) {
    TORCH_WARN(
        "ATTENTION: your model computation is overflowing, safe_narrow_cast<>() failed");
    return v;
  }
  return res;
}

// BailoutBlocks are used to temporarily store
// instructions (typically, argument LOADs and TAIL_CALL)
// generated for prim::BailOut nodes
// before they are merged back into
// CodeImpl._instructions_ by insertBailoutBlocks
struct BailoutBlock {
  size_t jf_instruction_index; // this node gets patched to jump here on failure
  std::vector<Instruction> instructions; // ends in a TAIL_CALL

  explicit BailoutBlock(size_t jf_index) : jf_instruction_index(jf_index) {}
};

// for keeping track of the current node
struct WithCurrentNode {
  WithCurrentNode(Node** loc, Node* new_value) : loc_(loc), old_value_(*loc_) {
    *loc = new_value;
  }
  ~WithCurrentNode() {
    *loc_ = old_value_;
  }

 private:
  Node** loc_;
  Node* old_value_;
};

struct NodeSourceInfo {
  const char* func_name_;
  const char* file_name_;
  size_t line_;
  NodeSourceInfo() : func_name_(nullptr), file_name_(nullptr), line_(0) {}
};

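// CodeImpl lowers a Graph into the flat form the interpreter executes: a
// linear list of Instructions plus the side tables (constants, operators,
// functions, types, registers) those instructions index into. InterpreterState
// reads these tables directly, which is why it is declared a friend below.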
struct CodeImpl {
  friend struct InterpreterState;
  std::vector<Instruction> instructions_;

  const c10::unique_t node_stack_attr_symbol_ =
      static_cast<c10::unique_t>(attr::node_stack_idx);
  // Expanded inlined stacks as pointers to values in inlined call stack.
  std::vector<std::vector<NodeSourceInfo>> expanded_node_stacks_;

  // same length as instructions.
  // which node in the graph caused this
  // instruction to be emitted?
  std::vector<Node*> instructions_source_;
  std::vector<IValue> constant_table_;
  std::vector<Operation> operator_table_;
#ifndef NDEBUG
  std::vector<Operator> full_operator_table_;
#endif
  // map<(op name, num inputs), index in operator table>, to avoid duplicates,
  // not including vararg operators
  std::unordered_map<
      std::pair<std::string, int>,
      int,
      std::function<size_t(const std::pair<std::string, int>& p)>>
      operator_table_inv_;
  std::vector<Function*> function_table_;
  std::vector<std::unique_ptr<GraphFunction>> forked_functions_;
  std::vector<std::unique_ptr<GraphFunction>> awaited_functions_;
  std::vector<TypePtr> type_table_;
  std::vector<std::function<void(std::vector<IValue>&)>>
      profile_function_table_;

  int register_size_ = 0;
  size_t n_outputs;
  size_t n_inputs;
  TypePtr return_type_;
  std::string function_name_;

  // We MUST hold onto graph here because some Operators stored in the
  // instruction lists have dependencies on meta-data stored in the graph
  // that would be dead otherwise.
  // It is also very useful for debugging interpreter problems to
  // keep this around.
  std::shared_ptr<Graph> graph_;
  std::optional<std::vector<GraphExecutor*>> grad_executors_;
  std::optional<std::vector<GraphExecutor*>> forward_executors_;
  PreprocessGraph preprocess_;

  // map from values to registers in the register table
  std::unordered_map<Value*, int> value_to_reg_;

  // map from operator name to specified arguments
  // Example: for a schema of aten::foo.str
  // aten::foo.str(arg0: str="default", arg1: int=0,
  //               arg2: bool=False, arg3: float=0.0)
  // If the usages in a graph are:
  // aten::foo("somestr", arg1=0, arg2=True, arg3=0.0)
  // aten::foo("somestr", arg1=1, arg2=False, arg3=0.0)
  // op_to_num_specified_args_["aten::foo.str"] = 3
  // This is because for all usages, at most 3 args are used.
  std::unordered_map<std::string, size_t> op_to_num_specified_args_;

  std::unordered_map<std::string, size_t> op_to_num_out_args_;

  // running count of uses as we emit. When we reach use_count_[v] =
  // v.uses().size() we know it is the final use and we can move rather than
  // load.
  std::unordered_map<Value*, size_t> use_count_;

  Node* current_node_; // used in creation of code to keep track
                       // of node being emitted
  Node* last_inserted_op_ = nullptr;

  // out-of-line jumps for bailouts that are patched in at the end
  std::vector<BailoutBlock> bailout_blocks_;
  std::vector<std::unique_ptr<Function>> bailout_functions_;
  size_t remaining_bailout_depth_;

  CodeImpl(
      const std::shared_ptr<Graph>& graph,
      std::string function_name,
      size_t remaining_bailout_depth,
      bool emit_instructions = true)
      : operator_table_inv_(
            0,
            [](const std::pair<std::string, int>& p) {
              return std::hash<std::string>()(p.first) ^
                  std::hash<int>()(p.second);
            }),
        function_name_(std::move(function_name)),
        preprocess_(*graph),
        current_node_(preprocess_.graph->return_node()),
        remaining_bailout_depth_(remaining_bailout_depth) {
    graph_ = preprocess_.graph;
    n_outputs = graph_->outputs().size();
    if (n_outputs == 1) {
      return_type_ = graph->outputs().at(0)->type();
    } else {
      return_type_ = TupleType::create(
          fmap(graph->outputs(), [](const Value* v) { return v->type(); }));
    }
    n_inputs = graph_->inputs().size();
    if (emit_instructions) {
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      run();
    }
  }

  virtual ~CodeImpl() = default;

  // since a subclass of CodeImpl needs to populate
  // op_to_num_specified_args, we separate the calls
  // that change the internals of CodeImpl into a separate
  // function.
  virtual void run() {
    emitCodeForBlock(graph_->block());
    insertInstruction(RET);
    // we deferred the emission of bailout blocks so they appear at the end
    // emit them now and patch up the jumps
    insertBailoutBlocks();
  }

  const std::vector<c10::IValue>& constant_table() const {
    return constant_table_;
  }

  void request_bailout(size_t index) {
    auto count = index;
    for (const auto instr_index : c10::irange(instructions_.size())) {
      if (instructions_[instr_index].op == GUARD ||
          instructions_[instr_index].op == FAIL_GUARD) {
        if (count-- == 0) {
          // patching GUARD to FAIL_GUARD
          instructions_[instr_index].op = FAIL_GUARD;
          GRAPH_DEBUG(
              "Added a bailout request for ",
              index,
              " at instruction ",
              instr_index);
          break;
        }
      }
    }
  }

  const std::vector<Instruction>& instructions() const {
    return instructions_;
  }

  const std::unordered_map<std::string, size_t>& op_to_num_specified_args()
      const {
    return op_to_num_specified_args_;
  }

  const std::vector<Node*>& instructions_source() const {
    return instructions_source_;
  }

  NodeSourceInfo getSourceInfoFromSourceRange(const SourceRange& range) {
    NodeSourceInfo nodeSource;
    SourceRange r = range;
    if (range.source()) {
      if (auto orig = range.source()->findSourceRangeThatGenerated(r)) {
        r = *orig;
      }
    }
    if (r.source()) {
      auto lineno = r.source()->lineno_for_offset(r.start());
      nodeSource.line_ = r.source()->lineno_to_source_lineno(lineno);
      if (r.source()->filename()) {
        nodeSource.file_name_ = r.source()->filename().value().c_str();
      }
    }
    return nodeSource;
  }

  void expandInlinedNodeStack(
      const InlinedCallStackPtr& cs,
      std::vector<NodeSourceInfo>* expandedstack) {
    auto nodeSourceInfo = getSourceInfoFromSourceRange(cs->source_range());
    nodeSourceInfo.func_name_ = cs->function_name().c_str();
    expandedstack->emplace_back(nodeSourceInfo);

    if (cs->callee()) {
      expandInlinedNodeStack(cs->callee().value(), expandedstack);
    }
  }

  void getNodeStack(
      const Node* node,
      std::vector<NodeSourceInfo>* expandedstack) {
    if (current_node_->callstack()) {
      expandInlinedNodeStack(current_node_->callstack().value(), expandedstack);
    }
    auto nodeSourceInfo = getSourceInfoFromSourceRange(node->sourceRange());
    expandedstack->emplace_back(nodeSourceInfo);
  }

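  // Append one instruction. X and N are narrowed into the 32-bit / 16-bit
  // operand fields of Instruction via safe_narrow_cast. Each instruction also
  // records its originating node in instructions_source_, and, when expanded
  // stacks are enabled, the node is tagged with an index into
  // expanded_node_stacks_ for richer source attribution.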
  void insertInstruction(OpCode op, int64_t X = 0, uint64_t N = 0) {
    instructions_.emplace_back(
        op,
        safe_narrow_cast<int32_t, int64_t>(X),
        safe_narrow_cast<uint16_t, uint64_t>(N));
    instructions_source_.emplace_back(current_node_);

    if (FLAGS_torch_jit_enable_expanded_stacks &&
        !current_node_->hasAttribute(attr::node_stack_idx)) {
      std::vector<NodeSourceInfo> expandedStack;
      getNodeStack(current_node_, &expandedStack);
      auto insertIdx = expanded_node_stacks_.size();
      expanded_node_stacks_.emplace_back(expandedStack);
      current_node_->i_(attr::node_stack_idx, insertIdx);
    }

    // check that we didn't accidentally emit nodes out of topological order
    if (op == OP) {
      if (last_inserted_op_ != nullptr && current_node_ != last_inserted_op_ &&
          current_node_->owningBlock() == last_inserted_op_->owningBlock()) {
        TORCH_INTERNAL_ASSERT(
            current_node_->isAfter(last_inserted_op_),
            *current_node_,
            " is not after ",
            *last_inserted_op_);
      }
      last_inserted_op_ = current_node_;
    }
  }

  void truncateInstructions(size_t size) {
    while (instructions_.size() > size) {
      instructions_.pop_back();
      instructions_source_.pop_back();
    }
  }

  void createBailoutBlock(size_t jf_index) {
    bailout_blocks_.emplace_back(jf_index);
    auto& bailout_instructions = bailout_blocks_.back().instructions;

    bailout_instructions.insert(
        bailout_instructions.end(),
        instructions_.begin() + jf_index + 1,
        instructions_.end());
    truncateInstructions(jf_index + 1);
  }

  int allocRegs(at::ArrayRef<Value*> vs) {
    int result = register_size_ + 1;
    for (Value* v : vs) {
      AT_ASSERT(value_to_reg_.count(v) == 0);
      value_to_reg_[v] = ++register_size_;
    }
    return result;
  }

  int registerFor(Value* v) {
    return value_to_reg_.at(v);
  }

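  // Emit the instruction that puts `input` on the stack: producers that the
  // preprocess pass marked as inlinable are emitted in place; otherwise we
  // emit LOADC for constants, MOVE on the final use of a value (tracked via
  // use_count_), and LOAD for earlier uses. With drop=true the value is
  // discarded instead (DROP / DROPR).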
  void emitUse(Value* input, bool drop) {
    // drop - if true, we are not actually going to use this thing
    // and we can short circuit doing many instructions here
    // by either clearing the register (DROPR) or just popping the stack
    // (DROP)
    if (preprocess_.can_emit_inline[input->node()]) {
      emitNode(input->node());
      if (drop) {
        insertInstruction(DROP);
      }
    } else {
      int reg = registerFor(input);
      bool moved = input->uses().size() == ++use_count_[input];

      OpCode op{};
      if (input->node()->kind() == prim::Constant) {
        op = LOADC;
      } else if (moved) {
        op = MOVE;
      } else {
        op = LOAD;
      }

      if (drop) {
        op = DROPR;
      }
      insertInstruction(op, reg);
    }
  }

  void emitLoadInputs(at::ArrayRef<Value*> inputs) {
    for (Value* input : inputs) {
      emitUse(input, false);
    }
  }

  void emitLoadInputs(at::ArrayRef<Value*> inputs, int num_include) {
    int count = 0;
    for (Value* input : inputs) {
      if (count < num_include) {
        emitUse(input, false);
        count++;
      }
    }
  }

  void emitLoadInputs(at::ArrayRef<Value*> inputs, size_t start, size_t end) {
    for (size_t i = start; i < end; i++) {
      emitUse(inputs[i], false);
    }
  }

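  // Emit a call to a registered operator. Inputs are loaded onto the stack
  // first; vararg operators with a registered Operation become OPN (which
  // carries the input count, since the schema alone cannot determine it),
  // everything else becomes OP with an index into operator_table_.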
  virtual void emitOperator(Node* node) {
    emitLoadInputs(node->inputs());
    const Operator& op = node->getOperator();
    int num_inputs = node->inputs().size();
    bool is_vararg = op.schema().is_vararg();

    int operation_index = add_to_operator_table(
        op,
        node,
        c10::toString(op.schema().operator_name()),
        num_inputs,
        is_vararg);

    if (op.hasOperation() && is_vararg) {
      insertInstruction(OPN, operation_index, num_inputs);
    } else {
      insertInstruction(OP, operation_index);
    }
  }

  void emitWait(Node* node) {
    emitLoadInputs(node->inputs());
    insertInstruction(WAIT);
  }

  void emitDrop(at::ArrayRef<Value*> to_drop) {
    for (Value* input : to_drop) {
      emitUse(input, true);
    }
  }

  void emitStoreOutputs(Node* node) {
    size_t N = node->outputs().size();
    if (N == 0) {
      return;
    }
    int regs = allocRegs(node->outputs());
    if (N == 1) {
      insertInstruction(STORE, regs);
    } else {
      insertInstruction(STOREN, regs, node->outputs().size());
    }
  }

  int insertConstant(IValue value) {
    int result = constant_table_.size();
    constant_table_.emplace_back(std::move(value));
    return result;
  }

  virtual void emitOperatorOrInstruction(
      Node* node,
      OpCode op,
      int64_t X = 0,
      uint64_t N = 0,
      bool emit_inputs = true) {
    if (emit_inputs) {
      emitLoadInputs(node->inputs());
    }
    insertInstruction(op, X, N);
  }

  void emitFormat(Node* node) {
    emitOperatorOrInstruction(node, FORMAT, node->inputs().size(), 0);
  }

  void checkNodeAndEmit(Node* node) {
    // check if the node should be emitted as instruction or operator
    const Operator& op = node->getOperator();
    std::string unique_op_name = c10::toString(op.schema().operator_name());
    if (unique_op_name.find("aten::__getitem__.Dict") == 0) {
      // the __getitem__ overload for Dict
      // needs to be emitted as an instruction
      emitOperatorOrInstruction(node, DICT_INDEX);
    } else {
      emitOperator(node);
    }
  }

  void emitConstant(Node* node) {
    if (node->output()->type()->kind() == FunctionType::Kind) {
      return;
    }
    // constants are just put in the constant table
    value_to_reg_[node->output()] =
        insertConstant(toIValue(node->output()).value());
  }

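  // Lower prim::If: the condition is loaded, a JF jumps over the then-branch
  // to the else-branch when the condition is false, and a JMP at the end of
  // the then-branch skips past the else-branch. Both jump offsets are
  // relative and are patched once the target positions are known.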
  void emitIf(Node* node) {
    emitLoadInputs(node->inputs());
    size_t start_if = instructions_.size();
    insertInstruction(JF, 0); // dummy offset to be filled in
    emitCodeForBlock(node->blocks().at(0));
    insertInstruction(JMP, 0); // dummy offset
    size_t start_else = instructions_.size();
    instructions_[start_if].X = start_else - start_if;
    emitCodeForBlock(node->blocks().at(1));
    instructions_[start_else - 1].X = instructions_.size() - (start_else - 1);
  }

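  // Lower prim::Loop: a LOADC of the constant 0 seeds the iteration counter,
  // the loop inputs (max trip count, initial condition, loop-carried values)
  // are loaded, then LOOP decides whether to run another iteration and
  // otherwise jumps past the body; the body ends with a JMP back to the LOOP
  // instruction, and LOOP's forward offset is patched afterwards.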
  void emitLoop(Node* loop) {
    insertInstruction(LOADC, insertConstant(0));
    emitLoadInputs(loop->inputs());
    size_t start = instructions_.size();
    insertInstruction(LOOP, 0, loop->inputs().size()); // dummy offset
    emitCodeForBlock(loop->blocks().at(0));
    insertInstruction(JMP, start - instructions_.size());
    instructions_[start].X = instructions_.size() - start;
  }

  void emitCall(Function* func, at::ArrayRef<Value*> inputs) {
    emitLoadInputs(inputs);
    insertInstruction(CALL, function_table_.size());
    function_table_.emplace_back(func);
  }

  void emitNodeAtBlockLevel(Node* node) {
    WithCurrentNode guard(&current_node_, node);
    switch (node->kind()) {
      case prim::Constant:
        emitConstant(node);
        break;
      case prim::Return:
        emitLoadInputs(node->inputs());
        break;
      default:
        if (!preprocess_.can_emit_inline[node]) {
          emitNode(node);
          emitStoreOutputs(node);
        }
        break;
    }
  }

  size_t emitType(TypePtr t) {
    size_t r = type_table_.size();
    type_table_.emplace_back(std::move(t));
    return r;
  }

  void emitTypeCheck(Node* node) {
    auto num_inputs = node->inputs().size();

    // Check that TypeCheck has at least one input.
    TORCH_INTERNAL_ASSERT(
        num_inputs && num_inputs + 1 == node->outputs().size());
    emitLoadInputs(node->inputs());

    // Emit the expected type.
    size_t types_start = type_table_.size();
    auto types = node->tys(attr::types);
    for (const auto i : c10::irange(num_inputs)) {
      emitType(types[i]);
    }
    insertInstruction(TYPECHECK, types_start, num_inputs);
  }

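  // Guard/bailout lowering: GUARD checks the guarded input against the
  // expected type recorded in the type table; the JF that follows jumps to an
  // out-of-line bailout block when the check fails. The bailout block reloads
  // the remaining arguments and TAIL_CALLs a GraphFunction that is lazily
  // built from the unoptimized graph (see createBailoutBlock and
  // insertBailoutBlocks).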
  size_t emitGuard(Node* node) {
    // unoptimized graph is at index 0
    // guarded input is at index 1
    // the rest of args follow
    emitLoadInputs(node->inputs().slice(1, 1));
    insertInstruction(GUARD, emitType(node->outputs().at(0)->type()));
    insertInstruction(JF, 0 /* to be patched */);
    return instructions_.size() - 1;
  }

  void emitBailOut(Node* node) {
    auto jf_index = emitGuard(node);
    auto unoptimized_graph = node->inputs().at(0)->node()->g(attr::Subgraph);
    // note, the guarded input is already loaded onto the stack
    // for the GUARD instruction
    emitLoadInputs(node->inputs().slice(2));
    insertInstruction(TAIL_CALL, function_table_.size());
    TORCH_INTERNAL_ASSERT(node->kind() == prim::BailOut);
    auto bailout_index = node->i(attr::index);
    TORCH_INTERNAL_ASSERT(bailout_index >= 0);

    auto build_bailout_graph = [bailout_index,
                                unoptimized_graph](GraphFunction& func) {
      BuildBailOutGraphFrom(bailout_index, unoptimized_graph, func.graph());
    };

    auto empty_graph = std::make_shared<Graph>();
    auto func = std::make_unique<GraphFunction>(
        "bailout", empty_graph, build_bailout_graph);
    function_table_.emplace_back(func.get());
    bailout_functions_.emplace_back(std::move(func));
    createBailoutBlock(jf_index);
  }

  void emitProfile(Node* node) {
    emitLoadInputs(node->inputs());
    insertInstruction(PROFILE_OP, profile_function_table_.size());
    if (node->cast<ProfileOp>()) {
      profile_function_table_.push_back(node->cast<ProfileOp>()->getCallback());
    } else if (node->cast<ProfileIValueOp>()) {
      profile_function_table_.push_back(
          node->cast<ProfileIValueOp>()->getCallback());
    } else {
      TORCH_INTERNAL_ASSERT(false);
    }
  }

  void emitGetAttr(Node* node) {
    emitLoadInputs(node->inputs());
    const auto type = node->input()->type()->expect<ClassType>();
    const auto& field = node->s(attr::name);
    const auto slot = type->getAttributeSlot(field);
    insertInstruction(GET_ATTR, slot);
  }

  void emitSetAttr(Node* node) {
    emitLoadInputs(node->inputs());
    const auto type = node->inputs().at(0)->type()->expect<ClassType>();
    const auto& field = node->s(attr::name);
    const auto slot = type->getAttributeSlot(field);
    insertInstruction(SET_ATTR, slot);
  }

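  // Append the deferred bailout blocks at the end of the instruction stream
  // and patch each owning JF so that a failed guard jumps into its block.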
  void insertBailoutBlocks() {
    for (const BailoutBlock& block : bailout_blocks_) {
      TORCH_INTERNAL_ASSERT(instructions_[block.jf_instruction_index].op == JF)
      instructions_[block.jf_instruction_index].X =
          instructions_.size() - block.jf_instruction_index;
      instructions_.insert(
          instructions_.end(),
          block.instructions.begin(),
          block.instructions.end());
      instructions_source_.insert(
          instructions_source_.end(),
          block.instructions.size(),
          instructions_source_[block.jf_instruction_index]);
    }
  }
  void emitInterfaceCall(
      std::string method_name_str,
      c10::ArrayRef<Value*> inputs) {
    emitLoadInputs(inputs);
    auto method_name = insertConstant(std::move(method_name_str));
    insertInstruction(INTERFACE_CALL, method_name, inputs.size());
  }

  void emitListUnpack(Node* node) {
    emitLoadInputs(node->inputs());
    insertInstruction(LIST_UNPACK, node->outputs().size());
  }

  void emitTupleConstruct(Node* node) {
    bool named =
        node->output()->type()->expectRef<TupleType>().name().has_value();
    if (named) {
      emitContainerConstruct(NAMED_TUPLE_CONSTRUCT, node);
    } else {
      emitLoadInputs(node->inputs());
      insertInstruction(TUPLE_CONSTRUCT, node->inputs().size());
    }
  }

  void emitContainerConstruct(OpCode op, Node* node) {
    emitLoadInputs(node->inputs());
    insertInstruction(
        op, emitType(node->output()->type()), node->inputs().size());
  }

  void emitCreateObject(Node* node) {
    insertInstruction(CREATE_OBJECT, emitType(node->output()->type()));
  }
  void emitIsinstance(Node* node) {
    emitLoadInputs(node->inputs());
    std::vector<TypePtr> types = node->tys(attr::types);
    size_t types_start = type_table_.size();
    for (const auto& typ : types) {
      emitType(typ);
    }
    insertInstruction(ISINSTANCE, types_start, types.size());
  }

  void emitTupleSlice(Node* node) {
    emitLoadInputs(node->inputs());
    int64_t beg_ind = node->i(attr::beg);
    int64_t end_ind = node->i(attr::end);
    insertInstruction(TUPLE_SLICE, beg_ind, end_ind - beg_ind);
  }

  void emitFork(Node* node) {
    emitLoadInputs(node->inputs());
    auto forked_fn = std::make_unique<GraphFunction>(
        "<forked function>", node->g(attr::Subgraph), nullptr);
    forked_functions_.emplace_back(std::move(forked_fn));
    function_table_.emplace_back(forked_functions_.back().get());
    insertInstruction(FORK, function_table_.size() - 1, node->inputs().size());
  }

  void emitAwaitable(Node* node) {
    emitLoadInputs(node->inputs());
    auto await_fn = std::make_unique<GraphFunction>(
        "<awaitable function>", node->g(attr::Subgraph), nullptr);
    awaited_functions_.emplace_back(std::move(await_fn));
    function_table_.emplace_back(awaited_functions_.back().get());
    insertInstruction(
        AWAITABLE, function_table_.size() - 1, node->inputs().size());
  }

  void emitWarn(Node* node) {
    if (FLAGS_torch_jit_disable_warning_prints) {
      return;
    }

    emitLoadInputs(node->inputs());
    int32_t idx = -1;
    if (node->hasAttribute(attr::warn_id)) {
      idx = static_cast<int32_t>(node->i(attr::warn_id));
    }
    insertInstruction(WARN, idx);
  }

  void emitEnter(Node* node) {
    emitLoadInputs(node->inputs());
    insertInstruction(ENTER);
  }

  void emitExit(Node* /* node */) {
    insertInstruction(EXIT);
  }

  void emitNode(Node* node) {
    WithCurrentNode guard(&current_node_, node);
    switch (node->kind()) {
      default:
        // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
        checkNodeAndEmit(node);
        // emitOperator(node);
        break;
      case prim::RaiseException:
        emitOperatorOrInstruction(node, RAISE_EXCEPTION);
        break;
      case prim::TupleIndex:
        emitOperatorOrInstruction(node, TUPLE_INDEX);
        break;
      case prim::Drop:
        emitDrop(node->inputs());
        break;
      case prim::Constant:
        emitConstant(node);
        break;
      case prim::If:
        emitIf(node);
        break;
      case prim::Loop:
        emitLoop(node);
        break;
      case aten::wait:
        emitWait(node);
        break;
      case prim::Param:
        break;
      case prim::CallFunction:
        emitCall(
            node->inputs().at(0)->type()->expectRef<FunctionType>().function(),
            node->inputs().slice(1));
        break;
      case prim::CallMethod:
        if (auto class_type = node->inputs().at(0)->type()->cast<ClassType>()) {
          emitCall(&class_type->getMethod(node->s(attr::name)), node->inputs());
        } else {
          emitInterfaceCall(node->s(attr::name), node->inputs());
        }
        break;
      case prim::TypeCheck:
        emitTypeCheck(node);
        break;
      case prim::BailOut:
        emitBailOut(node);
        break;
      case prim::profile_ivalue:
      case prim::profile:
        emitProfile(node);
        break;
      case prim::GetAttr:
        emitGetAttr(node);
        break;
      case prim::SetAttr:
        emitSetAttr(node);
        break;
      case prim::ListUnpack:
        emitListUnpack(node);
        break;
      case prim::TupleConstruct:
        emitTupleConstruct(node);
        break;
      case prim::ListConstruct:
        emitContainerConstruct(LIST_CONSTRUCT, node);
        break;
      case prim::DictConstruct:
        emitContainerConstruct(DICT_CONSTRUCT, node);
        break;
      case prim::CreateObject:
        emitCreateObject(node);
        break;
      case prim::isinstance:
        emitIsinstance(node);
        break;
      case prim::TupleSlice:
        emitTupleSlice(node);
        break;
      case prim::fork:
        emitFork(node);
        break;
      case prim::awaitable:
        emitAwaitable(node);
        break;
      case aten::warn:
        emitWarn(node);
        break;
      case prim::Enter:
        emitEnter(node);
        break;
      case prim::Exit:
        emitExit(node);
        break;
      case prim::Uninitialized:
        emitOperatorOrInstruction(node, UN_INITIALIZED, 0, 0, false);
        break;
      case prim::dtype:
        emitOperatorOrInstruction(node, DTYPE);
        break;
      case prim::device:
        emitOperatorOrInstruction(node, DEVICE);
        break;
      case aten::dim:
        emitOperatorOrInstruction(node, DIM);
        break;
      case prim::is_cuda:
        emitOperatorOrInstruction(node, IS_CUDA);
        break;
      case aten::__not__:
        emitOperatorOrInstruction(node, __NOT__);
        break;
      case aten::format:
        emitFormat(node);
        break;
      case aten::__is__:
        emitOperatorOrInstruction(node, __IS__);
        break;
      case aten::__isnot__:
        emitOperatorOrInstruction(node, __ISNOT__);
        break;
      case prim::NumToTensor:
        emitOperatorOrInstruction(node, NUM_TO_TENSOR);
        break;
      case prim::tolist:
        emitOperatorOrInstruction(node, TO_LIST);
        break;
    }
  }

  void emitCodeForBlock(Block* block) {
    emitNodeAtBlockLevel(block->param_node());
    for (auto node : block->nodes()) {
      emitNodeAtBlockLevel(node);
    }
    emitNodeAtBlockLevel(block->return_node());
  }

  const std::vector<GraphExecutor*>& grad_executors() {
    if (!grad_executors_) {
      grad_executors_.emplace();
      for (Operation& op : operator_table_) {
        if (auto executor = detail::getGradExecutor(op)) {
          grad_executors_->push_back(executor);
        }
      }
    }
    return *grad_executors_;
  }

  const std::vector<GraphExecutor*>& diff_graph_op_executors() {
    if (!forward_executors_) {
      forward_executors_.emplace();
      for (Operation& op : operator_table_) {
        if (auto executor = detail::getDifferentiableGraphOpExecutor(op)) {
          forward_executors_->push_back(executor);
        }
      }
    }
    return *forward_executors_;
  }

  void dump(std::ostream& out, size_t i) const {
    out << i << " " << instructions_[i];
    if (instructions_[i].op == OP || instructions_[i].op == CALL ||
        instructions_[i].op == OPN) {
      out << " # " << *instructions_source_[i];
    } else {
      out << "\n";
    }
  }

  void dump(std::ostream& out) const {
    out << *graph_ << "\n";
    for (const auto i : c10::irange(instructions_.size())) {
      dump(out, i);
    }
  }

  /**
   * Add an operation to operator_table_ if not a duplicate and return its index
   */
  int add_to_operator_table(
      const Operator& op,
      const Node* node,
      const std::string& op_name,
      const int num_inputs,
      const bool is_vararg) {
    int size = operator_table_.size();

    const Operation& oper = op.getOperation(node);

    if (!is_vararg) {
      std::pair<std::string, int> key(op_name, num_inputs);
      auto found = operator_table_inv_.find(key);

      if (found != operator_table_inv_.end()) {
        return found->second;
      }

      operator_table_inv_.emplace(key, size);
    }

    operator_table_.emplace_back(oper);
#ifndef NDEBUG
    full_operator_table_.emplace_back(op);
#endif
    return size;
  }

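  // Debug-only sanity check (backed by full_operator_table_, which is only
  // kept in !NDEBUG builds): after running an OP, the stack size should equal
  // its size before the call minus the schema's arguments plus its returns,
  // unless the schema is vararg or varret.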
  inline void assert_stack_size(
      int32_t instruction_index,
      size_t init_size,
      size_t actual_size) const {
#ifndef NDEBUG
    const auto& schema = full_operator_table_[instruction_index].schema();
    int64_t expected_size = static_cast<int64_t>(init_size) -
        static_cast<int64_t>(schema.arguments().size()) +
        static_cast<int64_t>(schema.returns().size());
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        static_cast<size_t>(expected_size) == actual_size ||
            schema.is_varret() || schema.is_vararg(),
        "Expected to find ",
        expected_size,
        " values on the stack, but found ",
        actual_size,
        " on the stack after ",
        toString(full_operator_table_[instruction_index].schema()));
#endif
  }
};

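// MobileCodeImpl specializes code emission for the mobile bytecode format.
// Before emitting, it scans the graph to compute how many arguments are
// actually specified at each operator's call sites (op_to_num_specified_args_
// and op_to_num_out_args_), so trailing default arguments can be dropped from
// the serialized instruction stream. The emit_* flags preserve forward
// compatibility across bytecode version bumps.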
struct MobileCodeImpl : CodeImpl {
  MobileCodeImpl(
      const std::shared_ptr<Graph>& graph,
      std::string function_name,
      bool emit_default_input_instructions,
      bool support_default_args_before_out,
      bool emit_promoted_ops,
      size_t remaining_bailout_depth)
      : CodeImpl(graph, function_name, remaining_bailout_depth, false),
        emit_default_input_instructions_(emit_default_input_instructions),
        support_default_args_before_out_(support_default_args_before_out),
        emit_promoted_ops_(emit_promoted_ops) {
    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
    run();
  }

  void run() override {
    process_ops_for_mobile();
    emitCodeForBlock(graph_->block());
    insertInstruction(RET);
    // we deferred the emission of bailout blocks so they appear at the end
    // emit them now and patch up the jumps
    insertBailoutBlocks();
  }

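  // For every non-vararg operator in the graph, record (keyed by the op's
  // qualified name, including the overload) the maximum number of arguments
  // that any call site actually specifies, plus the number of out arguments,
  // using CalculateNecessaryArgs. emitOperator below uses these counts to
  // trim trailing default arguments.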
  void process_ops_for_mobile() {
    DepthFirstGraphNodeIterator graph_it(graph_);
    Node* node = graph_it.next();
    while (node) {
      if (node->maybeOperator()) {
        auto op_schema = node->getOperator().schema();
        // skip if schema has vararg
        if (!op_schema.is_vararg()) {
          auto specifiedArgs = CalculateNecessaryArgs(
              op_schema.arguments(),
              node->inputs(),
              support_default_args_before_out_);

          size_t numInclude = specifiedArgs.first +
              (support_default_args_before_out_ ? specifiedArgs.second : 0);
          auto unique_name = !op_schema.overload_name().empty()
              ? op_schema.name() + "." + op_schema.overload_name()
              : op_schema.name();
          auto it = op_to_num_specified_args_.insert(
              std::pair<std::string, size_t>(unique_name, 0));
          op_to_num_out_args_.insert(std::pair<std::string, size_t>(
              unique_name, specifiedArgs.second));
          auto prev_value = it.first->second;
          it.first->second = std::max(numInclude, prev_value);
        }
      }
      node = graph_it.next();
    }
  }

 private:
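  // Unlike CodeImpl::emitOperator, this trims argument loading to the number
  // of arguments actually specified across the graph (per
  // op_to_num_specified_args_), leaving trailing defaults to be supplied from
  // the schema when the bytecode is executed. Vararg operators still load
  // every input and are emitted as OPN. When default args may precede out
  // args, the leading specified args and the out args are loaded as two
  // separate ranges.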
  void emitOperator(Node* node) override {
    if (emit_default_input_instructions_) {
      CodeImpl::emitOperator(node);
    } else {
      const Operator& op = node->getOperator();
      std::string unique_op_name = c10::toString(op.schema().operator_name());
      int num_inputs = node->inputs().size();
      bool is_vararg = op.schema().is_vararg();

      if (op.hasOperation() && is_vararg) {
        emitLoadInputs(node->inputs());
        int operation_index = add_to_operator_table(
            op,
            node,
            unique_op_name,
            num_inputs,
            /* is_vararg */ true);
        insertInstruction(OPN, operation_index, num_inputs);
      } else {
        auto num_include = num_inputs;
        auto it = op_to_num_specified_args_.find(unique_op_name);
        if (it != op_to_num_specified_args_.end()) {
          num_include = it->second;
        }
        if (support_default_args_before_out_) {
          auto num_out = op_to_num_out_args_.find(unique_op_name)->second;
          auto num_specified_before_out = num_include - num_out;
          emitLoadInputs(node->inputs(), 0, num_specified_before_out);
          emitLoadInputs(
              node->inputs(),
              node->inputs().size() - num_out,
              node->inputs().size());
        } else {
          emitLoadInputs(node->inputs(), num_include);
        }
        int operation_index = add_to_operator_table(
            op, node, unique_op_name, num_inputs, is_vararg);
        insertInstruction(OP, operation_index);
      }
    }
  }

  void emitOperatorOrInstruction(
      Node* node,
      OpCode op,
      int64_t X = 0,
      uint64_t N = 0,
      bool emit_inputs = true) override {
    if (emit_promoted_ops_) {
      CodeImpl::emitOperatorOrInstruction(node, op, X, N, emit_inputs);
    } else {
      CodeImpl::emitOperator(node);
    }
  }

  // To support forward compatibility for bytecode version bump from v5 to v6
  bool emit_default_input_instructions_;
  // To support forward compatibility for bytecode version bump from v6 to v7
  bool support_default_args_before_out_;
  // To support forward compatibility for bytecode version bump from v7 to v8
  bool emit_promoted_ops_;
};

} // namespace interpreter
} // namespace torch::jit