1 #include <torch/csrc/jit/runtime/interpreter.h>
2
3 #include <ATen/Parallel.h>
4 #include <ATen/core/ivalue.h>
5 #include <ATen/record_function.h>
6 #include <c10/core/thread_pool.h>
7 #include <c10/macros/Macros.h>
8 #include <c10/util/Exception.h>
9 #include <c10/util/irange.h>
10 #include <torch/csrc/autograd/edge.h>
11 #include <torch/csrc/autograd/grad_mode.h>
12 #include <torch/csrc/autograd/profiler.h>
13 #include <torch/csrc/autograd/variable.h>
14 #include <torch/csrc/jit/api/compilation_unit.h>
15 #include <torch/csrc/jit/api/function_impl.h>
16 #include <torch/csrc/jit/ir/constants.h>
17 #include <torch/csrc/jit/ir/ir.h>
18 #include <torch/csrc/jit/jit_log.h>
19 #include <torch/csrc/jit/mobile/promoted_prim_ops.h>
20 #include <torch/csrc/jit/runtime/exception_message.h>
21 #include <torch/csrc/jit/runtime/graph_executor.h>
22 #include <torch/csrc/jit/runtime/instruction.h>
23 #include <torch/csrc/jit/runtime/interpreter/code_impl.h>
24 #include <torch/csrc/jit/runtime/interpreter/frame.h>
25 #include <torch/csrc/jit/runtime/jit_exception.h>
26 #include <torch/csrc/jit/runtime/operator.h>
27 #include <torch/csrc/jit/runtime/profiling_record.h>
28 #include <torch/csrc/jit/runtime/script_profile.h>
29 #include <torch/csrc/jit/runtime/vararg_functions.h>
30 #include <torch/csrc/utils/cpp_stacktraces.h>
31 #include <string>
32
33 #ifdef USE_RPC
34 #include <torch/csrc/distributed/autograd/context/container.h>
35 using torch::distributed::autograd::DistAutogradContainer;
36 #endif
37
38 #include <exception>
39 #include <memory>
40 #include <mutex>
41 #include <ostream>
42 #include <stdexcept>
43 #include <typeinfo>
44 #include <unordered_map>
45 #include <unordered_set>
46 #include <utility>
47 #include <vector>
48
49 C10_DEFINE_bool(
50 torch_jit_enable_rethrow_caught_exception,
51 false,
52 "enable rethrowing caught exception");
53
54 C10_DEFINE_bool(
55 torch_jit_enable_expanded_stacks,
56 false,
57 "When true we will attempt to pre-expand node stacks and cache expanded stacks.");
58
59 namespace torch::jit {
60
61 using CodeImpl = interpreter::CodeImpl;
62
63 // Before we translate to interpreter instructions, we do
64 // some preprocessing of the graph to turn it into a form that is closer
65 // to what the instructions will look like.
66 // In particular we:
67 // * Compute whether an input to a node is the last use, so we can issue MOVE
68 // rather than LOAD instructions.
69 // * Drop nodes are inserted for any node that is unused to create a dummy use
70 // that will cause the interpreter to free the node.
71 // A drop node just pops its input off the stack to ensure the interpreter
72 // releases references to nodes that are never used. Drop nodes are also
73 // inserted when the last use of a node is in some conditionally run control
74 // flow (e.g. one side of an If) and the interpreter must free the node only
75 // after the control flow has reconverged.
76 // Outputs are:
77 // * graph - the post-processed copy of g
78 // * move_flags[n] - a list of booleans, one for each input,
79 // indicating whether this is the last use of the value. The interpreter
80 // should generate a move rather than a copy in this case.
81
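// Builds the TensorType describing `t` in the current execution context:
// undefined tensors map to an "undefined" TensorType, and when grad mode is
// disabled the recorded type is pinned to requires_grad=false.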
82 TensorTypePtr tensorTypeInCurrentExecutionContext(const at::Tensor& t) {
83 if (!t.defined()) {
84 return TensorType::get()->withUndefined();
85 }
86 auto r = TensorType::create(t);
87 if (!at::GradMode::is_enabled()) {
88 return r->withRequiresGrad(false);
89 }
90 return r;
91 }
92
93 namespace {
94 inline int64_t getDistAutogradContextId() {
95 #ifdef USE_RPC
96 return DistAutogradContainer::currentContextId();
97 #else
98 return 0;
99 #endif
100 }
101 } // namespace
102
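// The interpreter currently running on this thread, if any. The guard below
// saves and restores it so that nested interpreter runs behave correctly and
// helpers such as currentCallstack()/currentModuleHierarchy() (defined near
// the bottom of this file) can inspect the innermost active state.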
103 thread_local InterpreterStateImpl* tls_int_state_ptr_ = nullptr;
104 struct TLSCurrentInterpreterGuard {
105 TLSCurrentInterpreterGuard(InterpreterStateImpl* state)
106 : prev_state_(tls_int_state_ptr_) {
107 tls_int_state_ptr_ = state;
108 }
109
110 ~TLSCurrentInterpreterGuard() {
111 tls_int_state_ptr_ = prev_state_;
112 }
113
114 private:
115 InterpreterStateImpl* prev_state_;
116 };
117
118 // InterpreterStateImpl holds the interpreter state used to run a Code.
119 struct InterpreterStateImpl : c10::intrusive_ptr_target {
120 InterpreterStateImpl(const Code& code, TaskLauncher taskLauncher)
121 : taskLauncher_(std::move(taskLauncher)) {
122 enterFrame(code, 0);
123 }
124
125 private:
126 using Frame = torch::jit::interpreter::Frame;
127 struct WarnedNodes {
128 public:
129 // Inserts idx into warned_nodes_, returning a boolean that indicates whether
130 // insertion actually happened (idx wasn't originally in the set).
131 bool insert(int32_t idx) {
132 std::unique_lock<std::mutex> lock(mutex_);
133 return warned_nodes_.insert(idx).second;
134 }
135
136 private:
137 std::mutex mutex_;
138 std::unordered_set<int32_t> warned_nodes_;
139 };
140
141 WarnedNodes warned_nodes_;
142
143 // if we need to suspend, where do we reset the stack?
144 // answer: to where it was when we were called, not
145 // including any inputs to this function
146 int64_t stack_start_ = -1;
147 c10::intrusive_ptr<Future> future_;
148 TaskLauncher taskLauncher_;
149
150 // this holds all the tensors for this interpreter run
151 // we don't bother minimizing the size of this vector, since the extra
152 // memory used by the pointers in this will be small.
153 // instead we are very aggressive about releasing tensors when they become
154 // dead to make sure memory management happens efficiently. We optimize for
155 // the case where derivatives are run with retain_graph=False; in the case
156 // where it is true, the interpreter and this array get copied. If this
157 // ever becomes a bottleneck then we _should_ consider minimizing the total
158 // number of registers.
159 std::vector<IValue> registers;
160
161 // A stack of objects that have been __enter__'d.
162 std::vector<IValue> entered_objects;
163
164 std::vector<Frame> frames;
165
166 c10::intrusive_ptr<InterpreterStateImpl> intrusive_from_this() {
167 c10::raw::intrusive_ptr::incref(this);
168 return c10::intrusive_ptr<InterpreterStateImpl>::reclaim(this);
169 }
170
171 void enterFrame(const Code& code, size_t base_pointer) {
172 frames.emplace_back(Frame{code.pImpl, 0, base_pointer, std::nullopt});
173 registers.resize(registers.size() + code.pImpl->register_size_);
174 }
175
176 void leaveFrame() {
177 registers.resize(registers.size() - frames.back().function->register_size_);
178 frames.pop_back();
179 }
180
181 void callFunction(
182 Function& f,
183 Stack& stack,
184 std::optional<size_t> bailOut = std::nullopt,
185 bool next = true) {
186 bool newFrame = f.call(stack, bailOut, [&](const Code& code) {
187 enterFrame(code, stack.size() - code.num_inputs());
188 checkAndStartRecordFunction(frames.back(), stack);
189 });
190 if (next) {
191 (frames.rbegin() + (newFrame ? 1 : 0))->pc++;
192 }
193 }
194
195 // Registers are indexed relative to the end of the register list so that
196 // when we call functions we are referring to the registers of the currently
197 // executing function.
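// E.g. reg(1) is the last register of the innermost (currently executing)
// frame, since enterFrame() appends each frame's registers to the end.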
198 IValue& reg(size_t reg) {
199 return *(registers.end() - reg);
200 }
201
202 void dump(std::ostream& out, const Stack& stack) const {
203 out << "Stack:\n";
204 for (const auto& val : stack) {
205 out << val;
206 out << "\n";
207 }
208 }
209
210 class StackSizeDidntChangeGuard {
211 public:
212 StackSizeDidntChangeGuard(const StackSizeDidntChangeGuard&) = delete;
213 StackSizeDidntChangeGuard(StackSizeDidntChangeGuard&&) = delete;
214 StackSizeDidntChangeGuard& operator=(const StackSizeDidntChangeGuard&) =
215 delete;
216 StackSizeDidntChangeGuard& operator=(StackSizeDidntChangeGuard&&) = delete;
217
218 StackSizeDidntChangeGuard(
219 const Frame& frame,
220 const torch::jit::Stack& stack,
221 const Instruction& inst)
222 : frame_(frame), stack_(stack), instX_(inst.X) {
223 // The (void) casts below act as a portable [[maybe_unused]] for these members.
224 (void)frame_;
225 (void)stack_;
226 (void)instX_;
227 (void)initialSize_;
228 }
229
230 void callAssert() const {
231 #ifndef NDEBUG
232 frame_.function->assert_stack_size(instX_, initialSize_, stack_.size());
233 #endif
234 }
235
236 private:
237 const Frame& frame_;
238 const torch::jit::Stack& stack_;
239 std::uint32_t instX_;
240 std::size_t initialSize_{stack_.size()};
241 };
242
243 struct C10_UNUSED DoNothing {};
244
245 #if defined(__GNUC__) || defined(__clang__)
246 #define JIT_USE_COMPUTED_GOTO
247 #endif
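// GCC and Clang support "labels as values" (&&label / goto*), which lets the
// dispatch loop below jump directly from one instruction handler to the next
// (threaded dispatch) instead of going back through the switch. On other
// compilers we fall back to a plain switch + break.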
248 // Primitives for making interpreter internal state transitions.
249 // We maintain two local variables as the internal interpreter state:
250 // `frame` will be the current frame that the interpreter operates on.
251 // `inst` will be the current instruction pointed to by the program counter.
252 //
253 // Instruction blocks should always be declared through the `INST` macro and
254 // the instruction body should always start with an `instGuard()` declaration.
255 // Also blocks should be ended properly with either `INST_NEXT` (for going
256 // to the next instruction) or `INST_DISPATCH` (for jumping to a computed
257 // position using `instFetch`).
258 #if defined(JIT_USE_COMPUTED_GOTO)
259 #define INST(NAME) \
260 NAME: \
261 label_##NAME
262 #define INST_DISPATCH goto* dispatch_table[inst.op]
263 #else
264 #define INST(NAME) NAME
265 #define INST_DISPATCH break
266 #endif
267 #define INST_NEXT \
268 inst = instFetch(1); \
269 INST_DISPATCH
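// For example, with computed gotos `case INST(OP):` expands to
// `case OP: label_OP:`, so the same handler is reachable both from the
// switch and from `goto* dispatch_table[inst.op]`.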
270
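// runTemplate is instantiated twice (see runImpl below). With EnableProfiling
// the per-instruction instGuard() opens a profiling::InstructionSpan; without
// it the guard returns DoNothing and compiles away.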
271 template <bool EnableProfiling>
272 bool runTemplate(Stack& stack) {
273 // if we have never run before, then we might have to return the
274 // stack when we suspend; record where it starts so we return the right
275 // stack.
276 if (stack_start_ == -1) {
277 TORCH_INTERNAL_ASSERT(stack.size() >= frames.back().function->n_inputs);
278 stack_start_ = stack.size() - frames.back().function->n_inputs;
279 } else {
280 // during restarts, all of the stack is always our own, so we leave
281 // nothing
282 stack_start_ = 0;
283 }
284
285 TLSCurrentInterpreterGuard g(this);
286 if (frames.back().pc == 0 && stack_start_ == 0) {
287 checkAndStartRecordFunction(frames.back(), stack);
288 }
289
290 #if defined(JIT_USE_COMPUTED_GOTO)
291 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
292 static void* dispatch_table[] = {
293 #define DISPATCH_TABLE_ENTRY(op, _) &&label_##op,
294 FORALL_OPCODES(DISPATCH_TABLE_ENTRY)
295 #undef DISPATCH_TABLE_ENTRY
296 };
297 #endif
298
299 try {
300 while (true) {
301 Frame& frame = frames.back();
302
303 auto instFetch = [&](auto x) {
304 return frame.function->instructions_[frame.pc += x];
305 };
306
307 auto instGuard = [&] {
308 if constexpr (!EnableProfiling) {
309 return DoNothing{};
310 } else {
311 return profiling::InstructionSpan{
312 *frame.function->instructions_source()[frame.pc]};
313 }
314 };
315
316 Instruction inst = instFetch(0);
317
318 auto stackSizeAssertGuard = [&] {
319 return StackSizeDidntChangeGuard{frame, stack, inst};
320 };
321
322 switch (inst.op) {
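// ENTER/EXIT implement `with` blocks: ENTER remembers the context manager
// object, and EXIT calls its __exit__ with three None arguments (the
// normal, non-exception exit path) before popping it from entered_objects.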
323 case INST(ENTER): {
324 auto _ = instGuard();
325 const auto& obj = peek(stack, 0, 1);
326 TORCH_INTERNAL_ASSERT(obj.isObject());
327 entered_objects.push_back(obj);
328 }
329 INST_NEXT;
330 case INST(EXIT): {
331 auto _ = instGuard();
332 auto obj = entered_objects.back().toObject();
333 auto& f = obj->type()->getMethod("__exit__");
334 push(stack, std::move(obj));
335 entered_objects.pop_back();
336 push(stack, IValue());
337 push(stack, IValue());
338 push(stack, IValue());
339 callFunction(f, stack);
340 continue;
341 }
342 case INST(OP): {
343 auto _ = instGuard();
344 auto stackSizeGuard = stackSizeAssertGuard();
345 frame.function->operator_table_[inst.X](stack);
346 stackSizeGuard.callAssert();
347 }
348 INST_NEXT;
349 case INST(OPN): {
350 auto _ = instGuard();
351 stack.emplace_back(inst.N);
352 auto stackSizeGuard = stackSizeAssertGuard();
353 frame.function->operator_table_[inst.X](stack);
354 stackSizeGuard.callAssert();
355 }
356 INST_NEXT;
357 case INST(LOAD): {
358 auto _ = instGuard();
359 stack.emplace_back(reg(inst.X));
360 }
361 INST_NEXT;
362 case INST(MOVE): {
363 auto _ = instGuard();
364 stack.emplace_back(std::move(reg(inst.X)));
365 }
366 INST_NEXT;
367 case INST(STORE): {
368 auto _ = instGuard();
369 reg(inst.X) = pop(stack);
370 }
371 INST_NEXT;
372 case INST(STOREN): {
373 auto _ = instGuard();
374 TORCH_INTERNAL_ASSERT(stack.size() >= inst.N);
375 for (size_t i = inst.N; i > 0; --i) {
376 reg(inst.X + i - 1) = pop(stack);
377 }
378 }
379 INST_NEXT;
380 case INST(DROP): {
381 auto _ = instGuard();
382 stack.pop_back();
383 }
384 INST_NEXT;
385 case INST(DROPR): {
386 auto _ = instGuard();
387 reg(inst.X) = IValue();
388 }
389 INST_NEXT;
390 case INST(LOADC): {
391 auto _ = instGuard();
392 stack.emplace_back(frame.function->constant_table_[inst.X]);
393 }
394 INST_NEXT;
395 case INST(GET_ATTR): {
396 auto _ = instGuard();
397 const auto& userObj = stack.back().toObjectRef();
398 stack.back() = userObj.getSlot(inst.X);
399 }
400 INST_NEXT;
401 case INST(SET_ATTR): {
402 auto _ = instGuard();
403 auto v = pop(stack);
404 auto& userObj = stack.back().toObjectRef();
405 userObj.setSlot(inst.X, std::move(v));
406 stack.pop_back();
407 }
408 INST_NEXT;
409 case INST(JF): {
410 auto _ = instGuard();
411 if (pop(stack).toBool()) {
412 inst = instFetch(1);
413 } else {
414 inst = instFetch(inst.X);
415 }
416 }
417 INST_DISPATCH;
418 case INST(JMP): {
419 auto _ = instGuard();
420 inst = instFetch(inst.X);
421 }
422 INST_DISPATCH;
423 case INST(LOOP): {
424 auto _ = instGuard();
425 // stack: iteration_count, max_iter, cond, loop_carried_deps...
426 auto fr = stack.end() - (inst.N + 1);
427 int64_t trip_count = fr[0].toInt();
428 int64_t max_trip_count = fr[1].toInt();
429 bool cond = fr[2].toBool();
430 if (trip_count < max_trip_count && cond) {
431 fr[2] = trip_count;
432 fr[0] = trip_count + 1;
433 inst = instFetch(1);
434 } else {
435 size_t n_loop_carried = inst.N - 2;
436 for (const auto i : c10::irange(n_loop_carried)) {
437 fr[i] = std::move(fr[i + 3]);
438 }
439 drop(stack, 3); // iteration_count, max_iter, cond
440 inst = instFetch(inst.X);
441 }
442 }
443 INST_DISPATCH;
444 case INST(CALL): {
445 auto _ = instGuard();
446 Function* fn = frame.function->function_table_[inst.X];
447 callFunction(*fn, stack);
448 continue;
449 }
450 case INST(INTERFACE_CALL): {
451 auto _ = instGuard();
452 // note the hash table lookup to find the function.
453 // this could be optimized further if necessary, by caching parts
454 // of the hashing computation or storing the offset when
455 // the object is turned into an interface.
456
457 // consider passing
458 // `frames.back().function->remaining_bailout_depth_` into
459 // `get_executor().getPlanFor()` to propagate the caller's depth
460 // restrictions onto children. While this strategy has the potential to
461 // reduce the number of compilations for overly dynamic callers, we
462 // might miss opportunities where a caller is dynamic but a callee
463 // gets stable arguments.
464 Function& function =
465 peek(stack, 0, inst.N)
466 .toObject()
467 ->type()
468 ->getMethod(
469 frame.function->constant_table_[inst.X].toStringRef());
470 callFunction(function, stack);
471 continue;
472 }
473 case INST(RET): {
474 if (frames.size() > 1) {
475 leaveFrame();
476 continue;
477 }
478 if (future_) {
479 auto num_outputs = frames.back().function->n_outputs;
480 if (num_outputs == 1) {
481 future_->markCompleted(stack.back());
482 } else {
483 future_->markCompleted(
484 c10::ivalue::Tuple::create(jit::last(stack, num_outputs)));
485 }
486 }
487 // destroy the last frame and call RecordFunction's end callbacks
488 leaveFrame();
489 return false;
490 }
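// WAIT suspends this interpreter if the future is not ready: the
// interpreter's portion of the stack is handed to a callback that
// re-enqueues an InterpreterContinuation on taskLauncher_ once the future
// completes. Because the pc still points at this WAIT, the resumed run
// re-executes it and then takes the fast path below that pushes the value.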
491 case INST(WAIT): {
492 auto _ = instGuard();
493 auto future = stack.back().toFuture();
494 if (!future->completed()) {
495 getOrCreateFuture();
496
497 // callback needs to be a struct rather than a lambda so that
498 // we can move the stack to the other thread
499 struct Callback {
500 Callback(
501 c10::intrusive_ptr<InterpreterStateImpl> state,
502 Stack stack)
503 : stateImpl_(std::move(state)),
504 state_(stateImpl_),
505 stack_(std::move(stack)),
506 dist_autograd_context_id_(getDistAutogradContextId()) {
507 state_ = InterpreterState(stateImpl_);
508 }
509 void operator()(c10::ivalue::Future& /* unused */) {
510 stateImpl_->taskLauncher_(InterpreterContinuation(
511 state_,
512 std::move(stack_),
513 dist_autograd_context_id_,
514 std::move(tls_state_)));
515 }
516
517 private:
518 c10::intrusive_ptr<InterpreterStateImpl> stateImpl_;
519 InterpreterState state_;
520 Stack stack_;
521 int64_t dist_autograd_context_id_;
522 // preserve the original ThreadLocalState
523 at::ThreadLocalState tls_state_;
524 };
525
526 // we are suspending, so we need to reset the stack to where it was
527 // when we were called, not including our inputs. If it started empty
528 // we can avoid a true copy by swapping, which leaves the original stack empty.
529 Stack copied;
530 if (stack_start_ == 0) {
531 copied.swap(stack);
532 } else {
533 copied.insert(
534 copied.begin(),
535 std::make_move_iterator(stack.begin() + stack_start_),
536 std::make_move_iterator(stack.end()));
537 stack.resize(stack_start_);
538 }
539 // save pc into the frame so we continue here when restored
540 future->addCallback(
541 Callback(intrusive_from_this(), std::move(copied)));
542
543 return true;
544 }
545 stack.pop_back();
546 stack.emplace_back(future->value());
547 }
548 INST_NEXT;
549 case INST(PROFILE_OP): {
550 auto _ = instGuard();
551 auto& frame_id_ref = frame.id;
552 if (!frame_id_ref.has_value()) {
553 frame_id_ref = Frame::genId();
554 }
555 const auto& callback =
556 frame.function->profile_function_table_[inst.X];
557 push(stack, c10::IValue{static_cast<int64_t>(*frame_id_ref)});
558 callback(stack);
559 }
560 INST_NEXT;
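// FAIL_GUARD/TYPECHECK/GUARD support profile-guided specialization: they
// push a bool describing whether the runtime tensors still match the
// profiled types (presumably consumed by compiler-emitted control flow that
// falls back to an unoptimized path). FAIL_GUARD, patched in when a bailout
// is requested, forces a single failure and then rewrites itself back to
// GUARD.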
561 case INST(FAIL_GUARD): {
562 auto _ = instGuard();
563 // patch FAIL_GUARD back to GUARD
564 GRAPH_DEBUG(
565 "Bailout ", inst.X, " triggered via bailout_requests_!");
566 frame.function->instructions_[frame.pc].op = GUARD;
567 push(stack, false);
568 }
569 INST_NEXT;
570 case INST(TYPECHECK): {
571 auto _ = instGuard();
572 unsigned num_inputs = inst.N, i = 0;
573 TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs && num_inputs > 0);
574 // Check every input's shape against profiled (expected) shape.
575 for (i = 0; i < num_inputs; i++) {
576 auto& input = peek(stack, i, num_inputs);
577 auto& t = input.toTensor();
578 const TypePtr& expected = frame.function->type_table_[inst.X + i];
579 auto* expected_type = expected->castRaw<TensorType>();
580 if (t.defined() && !expected_type->matchTensor(t)) {
581 push(stack, false);
582 break;
583 }
584 }
585 if (i == num_inputs) {
586 push(stack, true);
587 }
588 }
589 INST_NEXT;
590 case INST(GUARD): {
591 auto _ = instGuard();
592 if (!stack.back().isTensor()) {
593 // stack.back() is an Uninitialized IValue and this is a guard
594 // on a block output. Uninitialized IValues are never used
595 // so it's safe to pass this guard check
596 push(stack, true);
597 } else {
598 auto& t = stack.back().toTensor();
599 const TypePtr& expected = frame.function->type_table_[inst.X];
600 auto* expected_type = expected->castRaw<TensorType>();
601 if (t.defined() &&
602 !frames.back().symbols2dims.bindSymbolicShapes(
603 t.sizes(), expected_type->symbolic_sizes())) {
604 push(stack, false);
605 } else {
606 push(stack, expected_type->matchTensor(t));
607 }
608 }
609 }
610 INST_NEXT;
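// TAIL_CALL replaces the current frame instead of stacking a new one: the
// callee's inputs are slid down to this frame's base pointer, the frame is
// popped, and the callee is invoked with a decremented bailout depth.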
611 case INST(TAIL_CALL): {
612 auto _ = instGuard();
613 GRAPH_DEBUG("running TAIL_CALL for ", inst.X);
614 frame.function->function_table_[inst.X]->ensure_defined();
615 size_t remaining_bailout_depth =
616 frame.function->remaining_bailout_depth_ > 0
617 ? frame.function->remaining_bailout_depth_ - 1
618 : 0;
619 auto& f = *frame.function->function_table_[inst.X];
620 size_t num_inputs = f.num_inputs();
621 size_t base_pointer = frame.base_pointer;
622 TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs);
623 size_t inputs_start = stack.size() - num_inputs;
624 for (const auto i : c10::irange(num_inputs)) {
625 stack.at(base_pointer + i) =
626 std::move(stack.at(inputs_start + i));
627 }
628 stack.resize(base_pointer + num_inputs);
629 leaveFrame();
630
631 callFunction(f, stack, remaining_bailout_depth, false);
632 continue;
633 }
634 case INST(LIST_UNPACK): {
635 auto _ = instGuard();
636 listUnpack(stack, inst.X);
637 }
638 INST_NEXT;
639 case INST(TUPLE_CONSTRUCT): {
640 auto _ = instGuard();
641 tupleConstruct(stack, inst.X);
642 }
643 INST_NEXT;
644 case INST(TUPLE_SLICE): {
645 auto _ = instGuard();
646 tupleSlice(stack, inst.X, inst.X + inst.N);
647 }
648 INST_NEXT;
649 case INST(NAMED_TUPLE_CONSTRUCT): {
650 auto _ = instGuard();
651 namedTupleConstruct(
652 stack,
653 frame.function->type_table_[inst.X]->expect<TupleType>(),
654 inst.N);
655 }
656 INST_NEXT;
657 case INST(LIST_CONSTRUCT): {
658 auto _ = instGuard();
659 const auto& type =
660 frame.function->type_table_[inst.X]->expectRef<ListType>();
661 listConstruct(stack, type, inst.N);
662 }
663 INST_NEXT;
664 case INST(DICT_CONSTRUCT): {
665 auto _ = instGuard();
666 const auto& type =
667 frame.function->type_table_[inst.X]->expectRef<DictType>();
668 dictConstruct(stack, type, inst.N);
669 }
670 INST_NEXT;
671 case INST(CREATE_OBJECT): {
672 auto _ = instGuard();
673 auto type =
674 frame.function->type_table_[inst.X]->expect<ClassType>();
675 createObject(stack, type);
676 }
677 INST_NEXT;
678 case INST(ISINSTANCE): {
679 auto _ = instGuard();
680 at::ArrayRef<TypePtr> types(
681 &frame.function->type_table_[inst.X],
682 &frame.function->type_table_[inst.X] + inst.N);
683 isinstance(stack, types);
684 }
685 INST_NEXT;
686 case INST(TUPLE_INDEX): {
687 auto _ = instGuard();
688 tupleIndex(stack);
689 }
690 INST_NEXT;
691 case INST(RAISE_EXCEPTION): {
692 auto _ = instGuard();
693 raiseExceptionWithMessage(stack);
694 }
695 INST_NEXT;
696 case INST(UNCHECKED_CAST): {
697 auto _ = instGuard();
698 noop(stack);
699 }
700 INST_NEXT;
701 case INST(__IS__): {
702 auto _ = instGuard();
703 is(stack);
704 }
705 INST_NEXT;
706 case INST(UN_INITIALIZED): {
707 auto _ = instGuard();
708 unInitialized(stack);
709 }
710 INST_NEXT;
711 case INST(__ISNOT__): {
712 auto _ = instGuard();
713 isNot(stack);
714 }
715 INST_NEXT;
716 case INST(FORMAT): {
717 auto _ = instGuard();
718 format(stack, inst.X);
719 }
720 INST_NEXT;
721 case INST(DEVICE): {
722 auto _ = instGuard();
723 device(stack);
724 }
725 INST_NEXT;
726 case INST(DTYPE): {
727 auto _ = instGuard();
728 TORCH_INTERNAL_ASSERT(!stack.empty());
729 dtype(stack);
730 }
731 INST_NEXT;
732 case INST(DIM): {
733 auto _ = instGuard();
734 TORCH_INTERNAL_ASSERT(!stack.empty());
735 dim(stack);
736 }
737 INST_NEXT;
738 case INST(__NOT__): {
739 auto _ = instGuard();
740 _not(stack);
741 }
742 INST_NEXT;
743 case INST(DICT_INDEX): {
744 auto _ = instGuard();
745 dictIndex(stack);
746 }
747 INST_NEXT;
748 case INST(TO_LIST): {
749 auto _ = instGuard();
750 toList(stack);
751 }
752 INST_NEXT;
753 case INST(NUM_TO_TENSOR): {
754 auto _ = instGuard();
755 numToTensorScalar(stack);
756 }
757 INST_NEXT;
758 case INST(IS_CUDA): {
759 auto _ = instGuard();
760 isCuda(stack);
761 }
762 INST_NEXT;
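// FORK runs the forked function on taskLauncher_ with its own
// InterpreterState; the child's Future is pushed in place of the inputs so
// the parent can wait on it later.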
763 case INST(FORK): {
764 auto _ = instGuard();
765 // Move inputs to a separate stack
766 auto& forked_fn =
767 toGraphFunction(*frame.function->function_table_[inst.X]);
768 InterpreterState forked_interpreter(
769 forked_fn.get_executor().getPlanFor(stack).code, taskLauncher_);
770 InterpreterContinuation continuation(
771 forked_interpreter,
772 Stack(stack.end() - inst.N, stack.end()),
773 getDistAutogradContextId());
774 drop(stack, inst.N);
775 push(stack, forked_interpreter.getFuture());
776 taskLauncher_(std::move(continuation));
777 }
778 INST_NEXT;
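// AWAITABLE builds a lazy c10::ivalue::Await: the arguments are captured
// now, and the stored function runs the callee in a fresh InterpreterState
// only when the await is forced.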
779 case INST(AWAITABLE): {
780 auto _ = instGuard();
781 auto fn_ptr = frame.function->function_table_[inst.X];
782 auto& fn = toGraphFunction(*fn_ptr);
783 auto num_outputs = fn.graph()->outputs().size();
784 TypePtr out_type;
785 if (num_outputs == 1) {
786 out_type = fn.graph()->outputs()[0]->type();
787 } else {
788 std::vector<TypePtr> out_types;
789 for (const auto& o : fn.graph()->outputs()) {
790 out_types.push_back(o->type());
791 }
792 out_type = TupleType::create(out_types);
793 }
794 auto args = std::vector<IValue>(stack.end() - inst.N, stack.end());
795 auto aw = c10::make_intrusive<c10::ivalue::Await>(out_type);
796 aw->setArgs(std::move(args));
797 aw->setFn(
798 [&args = aw->args(),
799 fn_ptr,
800 taskLauncher = taskLauncher_]() -> IValue {
801 auto& fn = toGraphFunction(*fn_ptr);
802 auto n_out = fn.graph()->outputs().size();
803 torch::jit::Stack s;
804 for (const auto& arg : args) {
805 s.push_back(arg);
806 }
807 InterpreterState await_interpreter(
808 fn.get_executor().getPlanFor(s).code, taskLauncher);
809 await_interpreter.run(s);
810 if (n_out == 1) {
811 return s.back();
812 }
813 return c10::ivalue::Tuple::create(jit::last(s, n_out));
814 });
815 drop(stack, inst.N);
816 push(stack, std::move(aw));
817 }
818 INST_NEXT;
819 case INST(WARN): {
820 auto _ = instGuard();
821 // Keep track of which WARN instructions have been executed before;
822 // we only want to execute each WARN once to match default Python
823 // warning behavior.
824 bool need_warn = true;
825 if (inst.X != -1) {
826 need_warn = warned_nodes_.insert(inst.X);
827 }
828
829 Node* node =
830 frames.back().function->instructions_source_.at(frame.pc);
831 auto range = node->sourceRange().source();
832 if (range->filename()) {
833 drop(stack, 1);
834 const auto& msg = stack.back().toStringRef();
835 if (need_warn) {
836 auto line = range->starting_line_no() +
837 range->lineno_for_offset(node->sourceRange().start());
838 c10::SourceLocation location{
839 "", range->filename()->c_str(), uint32_t(line)};
840 // Sends the warning to the warning handler with the
841 // "verbatim" flag. This flag ensures the warning handler
842 // will print the exception as configured.
843 c10::warn(c10::Warning(
844 c10::UserWarning(), location, msg, /*verbatim=*/true));
845 }
846 stack.pop_back();
847 } else {
848 if (need_warn) {
849 TORCH_WARN(stack.back().toStringRef());
850 }
851 stack.pop_back();
852 }
853 }
854 INST_NEXT;
855 }
856 }
857 } catch (std::exception& e) {
858 for (auto it = entered_objects.rbegin(), end = entered_objects.rend();
859 it != end;
860 ++it) {
861 auto& f = it->toObject()->type()->getMethod("__exit__");
862 Stack stack;
863 push(stack, *it);
864 push(stack, IValue());
865 push(stack, IValue());
866 push(stack, IValue());
867 try {
868 f.run(stack);
869 } catch (std::exception& _) {
870 // TODO(T98048876): Handle `_` correctly.
871 }
872 }
873 if (FLAGS_torch_jit_enable_rethrow_caught_exception) {
874 if (future_) {
875 future_->setError(std::current_exception());
876 return false;
877 }
878 throw;
879 }
880 auto* jit_exception = dynamic_cast<JITException*>(&e);
881 // Janky af. See https://github.com/pytorch/pytorch/issues/54612
882 auto* not_implemented_error = dynamic_cast<c10::NotImplementedError*>(&e);
883
884 std::optional<std::string> python_class_name;
885 if (jit_exception) {
886 python_class_name = jit_exception->getPythonClassName();
887 }
888 handleError(
889 e, (bool)jit_exception, not_implemented_error, python_class_name);
890 return false;
891 }
892 }
893
894 #undef INST_NEXT
895 #undef INST_DISPATCH
896 #undef INST
897 #undef JIT_USE_COMPUTED_GOTO
898
899 bool runImpl(Stack& stack) {
900 if (!profiling::isProfilingOngoing()) {
901 return runTemplate</*EnableProfiling*/ false>(stack);
902 } else {
903 return runTemplate</*EnableProfiling*/ true>(stack);
904 }
905 }
906
907 void formatStackTrace(std::ostream& out) {
908 format_stack_trace(out, callstack());
909 }
910
911 void handleError(
912 const std::exception& e,
913 bool is_jit_exception,
914 c10::NotImplementedError* not_implemented_error,
915 std::optional<std::string> python_class_name) {
916 ExceptionMessage msg(e);
917 std::ostringstream ss;
918 std::string class_name =
919 python_class_name ? *python_class_name : "RuntimeError";
920 ss << "The following operation failed in the TorchScript interpreter.\n";
921 formatStackTrace(ss);
922 ss << class_name << ": " << msg << "\n";
923 if (future_) {
924 future_->setError(std::make_exception_ptr(Future::FutureError(ss.str())));
925 } else if (is_jit_exception) {
926 // save the original exception's message when creating a new JITException
927 throw JITException(ss.str(), python_class_name, e.what());
928 } else if (not_implemented_error) {
929 throw c10::NotImplementedError(
930 ss.str(),
931 not_implemented_error->backtrace(),
932 not_implemented_error->caller());
933 } else {
934 if (get_cpp_stacktraces_enabled()) {
935 ss << e.what() << "\n";
936 }
937 throw std::runtime_error(ss.str());
938 }
939 }
940
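// Starts a RecordFunction for profiler/observer callbacks the first time a
// frame is (re)entered, if any TORCHSCRIPT_FUNCTION step callbacks are
// registered. The RecordFunction is owned by the frame, so its end callbacks
// fire when the frame is destroyed in leaveFrame().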
941 static void checkAndStartRecordFunction(Frame& frame, Stack& stack) {
942 if (!frame.record_function) {
943 auto step_callbacks = at::getStepCallbacksUnlessEmpty(
944 at::RecordScope::TORCHSCRIPT_FUNCTION);
945 if (C10_UNLIKELY(step_callbacks.has_value())) {
946 auto rec_fn =
947 std::make_unique<at::RecordFunction>(std::move(*step_callbacks));
948 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rec_fn->isActive());
949 if (rec_fn->needsInputs()) {
950 rec_fn->before(
951 frame.function->function_name_,
952 last(stack, frame.function->n_inputs));
953 } else {
954 rec_fn->before(frame.function->function_name_);
955 }
956 frame.record_function = std::move(rec_fn);
957 }
958 }
959 }
960
961 public:
962 // One way to avoid the overhead of forming strings would be to return
963 // a vector of frame.function, i.e. CodeImpl*.
964 // This is not exactly clean as it will expose internal details of the
965 // interpreter. But this way we hold onto graph/node and Function and
966 // we can create the module hierarchy string for each event in the autograd
967 // profiler at the end, when consolidating events.
968 // At the moment the overhead does not seem exorbitantly large.
969 // Another option would be return vector of (string, InlinedCallstackPtrs)
970 // string would contain function name and typename of self
971 // Format of the returned vector of strings:
972 // For each frame, the corresponding module name, type and function name
973 // are in following format:
974 // <module-instance-name>(module type)::<function-name>
975 // Special keys for module-instance-name:
976 // - TOP: for top level module
977 // - SELF: When method/function of the frame is associated with
978 // previous frame's module instance
979 // - INSTANCE_NAME_UNKNOWN: instance name cannot be figured out
980 // - CALL_FUNCTION: call to free function
981 std::vector<std::string> moduleHierarchy() const {
982 std::vector<std::string> module_function_list;
983 std::string module_hierarchy("TOP");
984 for (size_t i = 0; i < frames.size(); ++i) {
985 const Frame& frame = frames[i];
986 std::string fn_name = frame.function->function_name_;
987 // For each frame, the type of the class with which the function is
988 // associated is queried here, and the type name is added to the
989 // module hierarchy.
990 const auto& g = frame.function->graph_;
991 std::string g_self_type;
992 if (g && !g->inputs().empty()) {
993 const auto& g_self_type_ptr =
994 g->inputs()[0]->type()->cast<c10::ClassType>();
995 if (g_self_type_ptr) {
996 g_self_type = g_self_type_ptr->name()->qualifiedName();
997 g_self_type = g_self_type.substr(g_self_type.find_last_of('.') + 1);
998 }
999 }
1000 module_hierarchy.append("(")
1001 .append(g_self_type)
1002 .append(")::")
1003 .append(fn_name);
1004 module_function_list.emplace_back(std::move(module_hierarchy));
1005
1006 size_t pc = frame.pc;
1007 // CALL nodes have already advanced the pc, so
1008 // undo that to report the call node
1009 if (i + 1 < frames.size()) {
1010 --pc;
1011 }
1012
1013 Node* node = frame.function->instructions_source_[pc];
1014 if (node->callstack()) {
1015 for (const auto& p : (*node->callstack())->vec()) {
1016 fn_name = std::get<0>(p)->name();
1017 const auto& opt_module_info = std::get<2>(p);
1018 if (opt_module_info.has_value()) {
1019 const auto& module_instance_info = opt_module_info.value();
1020 module_hierarchy = utils::get_module_info(module_instance_info);
1021 module_hierarchy.append("::").append(fn_name);
1022 } else {
1023 // This is likely a call to a free function, not associated with
1024 // any class.
1025 module_hierarchy = "::";
1026 module_hierarchy.append(fn_name);
1027 }
1028 module_function_list.emplace_back(std::move(module_hierarchy));
1029 }
1030 }
1031
1032 module_hierarchy = std::string();
1033 // If this node is of type callMethod then the following frame
1034 // will contain the op being executed.
1035 // For such a callMethod node, we add the object instance name
1036 // associated with it, since the following frame will not have it.
1037 if (node->kind() == prim::CallMethod) {
1038 std::string class_instance_name;
1039 if (node->input(0)->node()->kind() == prim::GetAttr) {
1040 class_instance_name = node->input(0)->node()->s(attr::name);
1041 } else if (
1042 !node->owningGraph()->inputs().empty() &&
1043 node->input(0) == node->owningGraph()->inputs()[0]) {
1044 class_instance_name = "SELF";
1045 } else {
1046 class_instance_name = "INSTANCE_NAME_UNKNOWN";
1047 }
1048 module_hierarchy = std::move(class_instance_name);
1049 } else if (node->kind() == prim::CallFunction) {
1050 auto function_constant = node->input(0)->node();
1051 auto fun_type =
1052 function_constant->output()->type()->expect<FunctionType>();
1053 auto fun_name = fun_type->function()->name();
1054 module_hierarchy = "CALL_FUNCTION::";
1055 module_hierarchy.append(fun_name);
1056 }
1057 }
1058 return module_function_list;
1059 }
1060
1061 std::vector<StackEntry> callstack() const {
1062 std::vector<StackEntry> entries;
1063 for (const auto i : c10::irange(frames.size())) {
1064 const Frame& frame = frames[i];
1065 std::string previous_fn_name = frame.function->function_name_;
1066 size_t pc = frame.pc;
1067 // CALL nodes have already advanced the pc, so
1068 // undo that to report the call node
1069 if (i + 1 < frames.size()) {
1070 --pc;
1071 }
1072
1073 Node* node = frame.function->instructions_source_[pc];
1074 if (node->callstack()) {
1075 for (const auto& p : (*node->callstack())->vec()) {
1076 entries.emplace_back(StackEntry{previous_fn_name, std::get<1>(p)});
1077 previous_fn_name = std::get<0>(p)->name();
1078 }
1079 }
1080 entries.emplace_back(StackEntry{previous_fn_name, node->sourceRange()});
1081 }
1082 return entries;
1083 }
1084
1085 c10::intrusive_ptr<Future> getOrCreateFuture() {
1086 if (!future_) {
1087 future_ =
1088 c10::make_intrusive<Future>(frames.front().function->return_type_);
1089 }
1090 return future_;
1091 }
1092
1093 c10::intrusive_ptr<Future> runAsync(Stack& stack) {
1094 getOrCreateFuture();
1095 runImpl(stack);
1096 return future_;
1097 }
1098
1099 void run(Stack& stack) {
1100 // By the time the continuation completes the frame will be gone, so this
1101 // must be done before calling runImpl().
1102 TORCH_INTERNAL_ASSERT(!frames.empty());
1103 const auto num_outputs = frames.front().function->n_outputs;
1104 if (runImpl(stack)) {
1105 future_->wait();
1106
1107 if (num_outputs == 1) {
1108 push(stack, future_->value());
1109 } else {
1110 auto tuple = future_->value().toTuple();
1111 for (const IValue& value : tuple->elements()) {
1112 push(stack, value);
1113 }
1114 }
1115 }
1116 }
1117 };
1118
1119 std::vector<StackEntry> currentCallstack() {
1120 if (tls_int_state_ptr_) {
1121 auto cs = tls_int_state_ptr_->callstack();
1122 std::reverse(cs.begin(), cs.end());
1123 return cs;
1124 }
1125 return std::vector<StackEntry>();
1126 }
1127
1128 std::vector<std::string> currentModuleHierarchy() {
1129 if (tls_int_state_ptr_) {
1130 return tls_int_state_ptr_->moduleHierarchy();
1131 }
1132 return std::vector<std::string>();
1133 }
1134
1135 std::ostream& operator<<(std::ostream& out, const Code& code) {
1136 out << *code.pImpl->graph_ << "\n";
1137 code.pImpl->dump(out);
1138 return out;
1139 }
1140
1141 Code::Code(
1142 const std::shared_ptr<Graph>& graph,
1143 std::string function_name,
1144 size_t remaining_bailout_depth)
1145 : pImpl(new CodeImpl(
1146 graph,
1147 std::move(function_name),
1148 remaining_bailout_depth)) {}
1149
1150 Code::Code(CodeImpl* codeImpl) : pImpl(codeImpl) {}
1151
1152 MobileCode::MobileCode(
1153 const std::shared_ptr<Graph>& graph,
1154 std::string function_name,
1155 bool emit_default_input_instructions,
1156 bool support_default_args_before_out,
1157 bool emit_promoted_ops,
1158 size_t remaining_bailout_depth)
1159 : Code(new interpreter::MobileCodeImpl(
1160 graph,
1161 std::move(function_name),
1162 emit_default_input_instructions,
1163 support_default_args_before_out,
1164 emit_promoted_ops,
1165 remaining_bailout_depth)) {}
1166
1167 const std::vector<GraphExecutor*>& Code::grad_executors() {
1168 return pImpl->grad_executors();
1169 }
1170
1171 const std::vector<GraphExecutor*>& Code::diff_graph_op_executors() {
1172 return pImpl->diff_graph_op_executors();
1173 }
1174
1175 size_t Code::num_bailouts() const {
1176 return pImpl->type_table_.size();
1177 }
1178
1179 void Code::request_bailout(size_t index) {
1180 pImpl->request_bailout(index);
1181 }
1182
1183 size_t Code::num_inputs() const {
1184 return pImpl->n_inputs;
1185 }
1186
1187 size_t Code::num_outputs() const {
1188 return pImpl->n_outputs;
1189 }
1190
1191 const std::vector<c10::IValue>& Code::constant_table() const {
1192 return pImpl->constant_table();
1193 }
1194
1195 const std::vector<Instruction>& Code::instructions() const {
1196 return pImpl->instructions();
1197 }
1198
1199 const std::unordered_map<std::string, size_t>& Code::op_to_num_specified_args()
1200 const {
1201 return pImpl->op_to_num_specified_args();
1202 }
1203
1204 const std::vector<Node*>& Code::instructions_source() const {
1205 return pImpl->instructions_source();
1206 }
1207
1208 const std::vector<TypePtr>& Code::type_table() const {
1209 return pImpl->type_table_;
1210 }
1211
1212 size_t Code::register_size() const {
1213 return pImpl->register_size_;
1214 }
1215
1216 std::shared_ptr<Graph> Code::graph() const {
1217 return pImpl->preprocess_.graph;
1218 }
1219
1220 InterpreterState::InterpreterState(const Code& code, TaskLauncher taskLauncher)
1221 : pImpl(c10::make_intrusive<InterpreterStateImpl>(
1222 code,
1223 std::move(taskLauncher))) {}
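// Typical usage (a sketch; the real call sites live in GraphExecutor and
// friends): build a Code from a Graph once, then per invocation either call
// `InterpreterState(code).run(stack)` for synchronous execution or
// `runAsync(stack)` to get a Future.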
1224
1225 void InterpreterState::run(Stack& stack) {
1226 static_cast<InterpreterStateImpl*>(pImpl.get())->run(stack);
1227 }
1228
1229 c10::intrusive_ptr<Future> InterpreterState::runAsync(Stack& stack) {
1230 return static_cast<InterpreterStateImpl*>(pImpl.get())->runAsync(stack);
1231 }
1232
1233 c10::intrusive_ptr<Future> InterpreterState::getFuture() {
1234 return static_cast<InterpreterStateImpl*>(pImpl.get())->getOrCreateFuture();
1235 }
1236
1237 InterpreterState::InterpreterState(
1238 c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl_)
1239 : pImpl(std::move(pImpl_)) {}
1240
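// Re-enters the interpreter on the continuation's saved stack, restoring the
// distributed autograd context (when RPC is enabled) and, if one was
// captured, the original ThreadLocalState.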
1241 void InterpreterContinuation::operator()() {
1242 #ifdef USE_RPC
1243 auto prev_dist_id = DistAutogradContainer::currentContextId();
1244 DistAutogradContainer::forceCurrentContextId(dist_autograd_context_id_);
1245 #endif
1246 if (tls_state_ != std::nullopt) {
1247 at::ThreadLocalStateGuard g(*tls_state_);
1248 state.runAsync(stack);
1249 } else {
1250 state.runAsync(stack);
1251 }
1252 #ifdef USE_RPC
1253 DistAutogradContainer::forceCurrentContextId(prev_dist_id);
1254 #endif
1255 }
1256
1257 } // namespace torch::jit
1258