xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/interpreter.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/runtime/interpreter.h>
2 
3 #include <ATen/Parallel.h>
4 #include <ATen/core/ivalue.h>
5 #include <ATen/record_function.h>
6 #include <c10/core/thread_pool.h>
7 #include <c10/macros/Macros.h>
8 #include <c10/util/Exception.h>
9 #include <c10/util/irange.h>
10 #include <torch/csrc/autograd/edge.h>
11 #include <torch/csrc/autograd/grad_mode.h>
12 #include <torch/csrc/autograd/profiler.h>
13 #include <torch/csrc/autograd/variable.h>
14 #include <torch/csrc/jit/api/compilation_unit.h>
15 #include <torch/csrc/jit/api/function_impl.h>
16 #include <torch/csrc/jit/ir/constants.h>
17 #include <torch/csrc/jit/ir/ir.h>
18 #include <torch/csrc/jit/jit_log.h>
19 #include <torch/csrc/jit/mobile/promoted_prim_ops.h>
20 #include <torch/csrc/jit/runtime/exception_message.h>
21 #include <torch/csrc/jit/runtime/graph_executor.h>
22 #include <torch/csrc/jit/runtime/instruction.h>
23 #include <torch/csrc/jit/runtime/interpreter/code_impl.h>
24 #include <torch/csrc/jit/runtime/interpreter/frame.h>
25 #include <torch/csrc/jit/runtime/jit_exception.h>
26 #include <torch/csrc/jit/runtime/operator.h>
27 #include <torch/csrc/jit/runtime/profiling_record.h>
28 #include <torch/csrc/jit/runtime/script_profile.h>
29 #include <torch/csrc/jit/runtime/vararg_functions.h>
30 #include <torch/csrc/utils/cpp_stacktraces.h>
31 #include <string>
32 
33 #ifdef USE_RPC
34 #include <torch/csrc/distributed/autograd/context/container.h>
35 using torch::distributed::autograd::DistAutogradContainer;
36 #endif
37 
38 #include <exception>
39 #include <memory>
40 #include <mutex>
41 #include <ostream>
42 #include <stdexcept>
43 #include <typeinfo>
44 #include <unordered_map>
45 #include <unordered_set>
46 #include <utility>
47 #include <vector>
48 
49 C10_DEFINE_bool(
50     torch_jit_enable_rethrow_caught_exception,
51     false,
52     "enable rethrowing caught exception");
53 
54 C10_DEFINE_bool(
55     torch_jit_enable_expanded_stacks,
56     false,
57     "When true we will attempt to pre-expand node stacks and cache expanded stacks.");
58 
59 namespace torch::jit {
60 
61 using CodeImpl = interpreter::CodeImpl;
62 
63 // Before we translate to interpreter instructions, we do
64 // some preprocessing of the graph to turn it into a form that is closer
65 // to what the instructions will look like.
66 // In particular we:
67 // *  Compute whether an input to a node is the last use, so we can issue MOVE
68 //    rather than LOAD instructions.
69 // *  Insert Drop nodes for any node that is unused, to create a dummy use
70 //    that will cause the interpreter to free the node.
71 //    A drop node just pops its input off the stack to ensure the interpreter
72 //    releases references to nodes that are never used. Drop nodes are also
73 //    inserted when the last use of a node is in some conditionally run control
74 //    flow (e.g. one side of an If) and the interpreter must free the node only
75 //    after the control flow has reconverged.
76 // Outputs are:
77 // * graph - the post-processed copy of g
78 // * move_flags[n] - a list of booleans, one for each input,
79 //   indicating whether this is the last use of the value. The interpreter
80 //   should generate a move rather than a copy in this case.
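//
// For example (an illustrative sketch, not taken from a real graph): in
//
//   %b = aten::mul(%a, %a)   // the second read of %a is its last use
//   %c = aten::neg(%b)       // %c is never used afterwards
//
// the first read of %a is emitted as LOAD, the last read as MOVE, and a Drop
// node is inserted after %c so the interpreter releases it promptly.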
81 
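// A note on the helper below (a sketch of its behavior as written here): an
// undefined tensor maps to TensorType::get()->withUndefined(), and when
// at::GradMode is disabled the returned type reports requires_grad == false
// regardless of the tensor's own requires_grad flag.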
82 TensorTypePtr tensorTypeInCurrentExecutionContext(const at::Tensor& t) {
83   if (!t.defined()) {
84     return TensorType::get()->withUndefined();
85   }
86   auto r = TensorType::create(t);
87   if (!at::GradMode::is_enabled()) {
88     return r->withRequiresGrad(false);
89   }
90   return r;
91 }
92 
93 namespace {
94 inline int64_t getDistAutogradContextId() {
95 #ifdef USE_RPC
96   return DistAutogradContainer::currentContextId();
97 #else
98   return 0;
99 #endif
100 }
101 } // namespace
102 
103 thread_local InterpreterStateImpl* tls_int_state_ptr_ = nullptr;
104 struct TLSCurrentInterpreterGuard {
105   TLSCurrentInterpreterGuard(InterpreterStateImpl* state)
106       : prev_state_(tls_int_state_ptr_) {
107     tls_int_state_ptr_ = state;
108   }
109 
110   ~TLSCurrentInterpreterGuard() {
111     tls_int_state_ptr_ = prev_state_;
112   }
113 
114  private:
115   InterpreterStateImpl* prev_state_;
116 };
117 
118 // InterpreterStateImpl is the interpreter state used to run a Code.
119 struct InterpreterStateImpl : c10::intrusive_ptr_target {
120   InterpreterStateImpl(const Code& code, TaskLauncher taskLauncher)
121       : taskLauncher_(std::move(taskLauncher)) {
122     enterFrame(code, 0);
123   }
124 
125  private:
126   using Frame = torch::jit::interpreter::Frame;
127   struct WarnedNodes {
128    public:
129     // Inserts idx into warned_nodes_, returns a boolean indicating whether
130     // insertion actually happened (idx wasn't originally in the set).
131     bool insert(int32_t idx) {
132       std::unique_lock<std::mutex> lock(mutex_);
133       return warned_nodes_.insert(idx).second;
134     }
135 
136    private:
137     std::mutex mutex_;
138     std::unordered_set<int32_t> warned_nodes_;
139   };
140 
141   WarnedNodes warned_nodes_;
142 
143   // if we need to suspend, where do we reset the stack?
144   // answer: to where it was when we were called, not
145   // including any inputs to this function
146   int64_t stack_start_ = -1;
147   c10::intrusive_ptr<Future> future_;
148   TaskLauncher taskLauncher_;
149 
150   // this holds all the tensors for this interpreter run
151   // we don't bother minimizing the size of this vector, since the extra
152   // memory used by the pointers in this will be small.
153   // Instead we are very aggressive about releasing tensors when they become
154   // dead to make sure memory management happens efficiently. We optimize for
155   // the case where derivatives are run with retain_graph=False; in the case
156   // where it is true, the interpreter and this array get copied. If this
157   // ever becomes a bottleneck then we _should_ consider minimizing the total
158   // number of registers.
159   std::vector<IValue> registers;
160 
161   // A stack of objects that have been __enter__'d.
162   std::vector<IValue> entered_objects;
163 
164   std::vector<Frame> frames;
165 
166   c10::intrusive_ptr<InterpreterStateImpl> intrusive_from_this() {
167     c10::raw::intrusive_ptr::incref(this);
168     return c10::intrusive_ptr<InterpreterStateImpl>::reclaim(this);
169   }
170 
171   void enterFrame(const Code& code, size_t base_pointer) {
172     frames.emplace_back(Frame{code.pImpl, 0, base_pointer, std::nullopt});
173     registers.resize(registers.size() + code.pImpl->register_size_);
174   }
175 
176   void leaveFrame() {
177     registers.resize(registers.size() - frames.back().function->register_size_);
178     frames.pop_back();
179   }
180 
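  // callFunction pushes a new frame for `f` whenever f.call() reports that one
  // was created; the pc that is then advanced belongs to the calling frame, so
  // execution resumes at the instruction after the call once the callee RETs.
  // For TAIL_CALL, which has already torn down its own frame, `next` is false
  // so that no pc underneath the new frame is bumped a second time.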
181   void callFunction(
182       Function& f,
183       Stack& stack,
184       std::optional<size_t> bailOut = std::nullopt,
185       bool next = true) {
186     bool newFrame = f.call(stack, bailOut, [&](const Code& code) {
187       enterFrame(code, stack.size() - code.num_inputs());
188       checkAndStartRecordFunction(frames.back(), stack);
189     });
190     if (next) {
191       (frames.rbegin() + (newFrame ? 1 : 0))->pc++;
192     }
193   }
194 
195   // Registers are indexed relative to the end of the register list so that
196   // when we call functions we are referring to the registers of the
197   // currently executing function.
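  // For example (illustrative): reg(1) is the last entry of `registers`, and a
  // frame whose function has register_size_ == N owns reg(N) .. reg(1).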
198   IValue& reg(size_t reg) {
199     return *(registers.end() - reg);
200   }
201 
202   void dump(std::ostream& out, const Stack& stack) const {
203     out << "Stack:\n";
204     for (const auto& val : stack) {
205       out << val;
206       out << "\n";
207     }
208   }
209 
210   class StackSizeDidntChangeGuard {
211    public:
212     StackSizeDidntChangeGuard(const StackSizeDidntChangeGuard&) = delete;
213     StackSizeDidntChangeGuard(StackSizeDidntChangeGuard&&) = delete;
214     StackSizeDidntChangeGuard& operator=(const StackSizeDidntChangeGuard&) =
215         delete;
216     StackSizeDidntChangeGuard& operator=(StackSizeDidntChangeGuard&&) = delete;
217 
218     StackSizeDidntChangeGuard(
219         const Frame& frame,
220         const torch::jit::Stack& stack,
221         const Instruction& inst)
222         : frame_(frame), stack_(stack), instX_(inst.X) {
223       // portable maybe_unused attribute.
224       (void)frame_;
225       (void)stack_;
226       (void)instX_;
227       (void)initialSize_;
228     }
229 
230     void callAssert() const {
231 #ifndef NDEBUG
232       frame_.function->assert_stack_size(instX_, initialSize_, stack_.size());
233 #endif
234     }
235 
236    private:
237     const Frame& frame_;
238     const torch::jit::Stack& stack_;
239     std::uint32_t instX_;
240     std::size_t initialSize_{stack_.size()};
241   };
242 
243   struct C10_UNUSED DoNothing {};
244 
245 #if defined(__GNUC__) || defined(__clang__)
246 #define JIT_USE_COMPUTED_GOTO
247 #endif
248 // Primitives for making interpreter internal state transitions.
249 // We maintain two local variables as the internal interpreter state:
250 // `frame` will be the current frame that the interpreter operates on.
251 // `inst` will be the current instruction pointed to by the program counter.
252 //
253 // Instruction blocks should always be declared through the `INST` macro and
254 // the instruction body should always start with an `instGuard()` declaration.
255 // Blocks should also be ended properly with either `INST_NEXT` (for going
256 // to the next instruction), or `INST_DISPATCH` (for jumping to a computed
257 // position using `instFetch`).
258 #if defined(JIT_USE_COMPUTED_GOTO)
259 #define INST(NAME) \
260   NAME:            \
261   label_##NAME
262 #define INST_DISPATCH goto* dispatch_table[inst.op]
263 #else
264 #define INST(NAME) NAME
265 #define INST_DISPATCH break
266 #endif
267 #define INST_NEXT      \
268   inst = instFetch(1); \
269   INST_DISPATCH
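// As a rough sketch of how these macros compose under computed goto, a handler
// written as
//
//   case INST(LOAD): { ... } INST_NEXT;
//
// expands to approximately
//
//   case LOAD:
//   label_LOAD: { ... }
//     inst = instFetch(1);
//     goto* dispatch_table[inst.op];
//
// so each handler jumps straight to the next opcode's label instead of going
// back through the switch.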
270 
271   template <bool EnableProfiling>
272   bool runTemplate(Stack& stack) {
273     // if we have never run before, then we might have to return the
274     // stack when we suspend; record where it starts so we return the right
275     // stack
276     if (stack_start_ == -1) {
277       TORCH_INTERNAL_ASSERT(stack.size() >= frames.back().function->n_inputs);
278       stack_start_ = stack.size() - frames.back().function->n_inputs;
279     } else {
280       // during restarts, all of the stack is always our own, so we leave
281       // nothing
282       stack_start_ = 0;
283     }
284 
285     TLSCurrentInterpreterGuard g(this);
286     if (frames.back().pc == 0 && stack_start_ == 0) {
287       checkAndStartRecordFunction(frames.back(), stack);
288     }
289 
290 #if defined(JIT_USE_COMPUTED_GOTO)
291     // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
292     static void* dispatch_table[] = {
293 #define DISPATCH_TABLE_ENTRY(op, _) &&label_##op,
294         FORALL_OPCODES(DISPATCH_TABLE_ENTRY)
295 #undef DISPATCH_TABLE_ENTRY
296     };
297 #endif
298 
299     try {
300       while (true) {
301         Frame& frame = frames.back();
302 
303         auto instFetch = [&](auto x) {
304           return frame.function->instructions_[frame.pc += x];
305         };
306 
307         auto instGuard = [&] {
308           if constexpr (!EnableProfiling) {
309             return DoNothing{};
310           } else {
311             return profiling::InstructionSpan{
312                 *frame.function->instructions_source()[frame.pc]};
313           }
314         };
315 
316         Instruction inst = instFetch(0);
317 
318         auto stackSizeAssertGuard = [&] {
319           return StackSizeDidntChangeGuard{frame, stack, inst};
320         };
321 
322         switch (inst.op) {
323           case INST(ENTER): {
324             auto _ = instGuard();
325             const auto& obj = peek(stack, 0, 1);
326             TORCH_INTERNAL_ASSERT(obj.isObject());
327             entered_objects.push_back(obj);
328           }
329             INST_NEXT;
330           case INST(EXIT): {
331             auto _ = instGuard();
332             auto obj = entered_objects.back().toObject();
333             auto& f = obj->type()->getMethod("__exit__");
334             push(stack, std::move(obj));
335             entered_objects.pop_back();
336             push(stack, IValue());
337             push(stack, IValue());
338             push(stack, IValue());
339             callFunction(f, stack);
340             continue;
341           }
342           case INST(OP): {
343             auto _ = instGuard();
344             auto stackSizeGuard = stackSizeAssertGuard();
345             frame.function->operator_table_[inst.X](stack);
346             stackSizeGuard.callAssert();
347           }
348             INST_NEXT;
349           case INST(OPN): {
350             auto _ = instGuard();
351             stack.emplace_back(inst.N);
352             auto stackSizeGuard = stackSizeAssertGuard();
353             frame.function->operator_table_[inst.X](stack);
354             stackSizeGuard.callAssert();
355           }
356             INST_NEXT;
357           case INST(LOAD): {
358             auto _ = instGuard();
359             stack.emplace_back(reg(inst.X));
360           }
361             INST_NEXT;
362           case INST(MOVE): {
363             auto _ = instGuard();
364             stack.emplace_back(std::move(reg(inst.X)));
365           }
366             INST_NEXT;
367           case INST(STORE): {
368             auto _ = instGuard();
369             reg(inst.X) = pop(stack);
370           }
371             INST_NEXT;
372           case INST(STOREN): {
373             auto _ = instGuard();
374             TORCH_INTERNAL_ASSERT(stack.size() >= inst.N);
375             for (size_t i = inst.N; i > 0; --i) {
376               reg(inst.X + i - 1) = pop(stack);
377             }
378           }
379             INST_NEXT;
380           case INST(DROP): {
381             auto _ = instGuard();
382             stack.pop_back();
383           }
384             INST_NEXT;
385           case INST(DROPR): {
386             auto _ = instGuard();
387             reg(inst.X) = IValue();
388           }
389             INST_NEXT;
390           case INST(LOADC): {
391             auto _ = instGuard();
392             stack.emplace_back(frame.function->constant_table_[inst.X]);
393           }
394             INST_NEXT;
395           case INST(GET_ATTR): {
396             auto _ = instGuard();
397             const auto& userObj = stack.back().toObjectRef();
398             stack.back() = userObj.getSlot(inst.X);
399           }
400             INST_NEXT;
401           case INST(SET_ATTR): {
402             auto _ = instGuard();
403             auto v = pop(stack);
404             auto& userObj = stack.back().toObjectRef();
405             userObj.setSlot(inst.X, std::move(v));
406             stack.pop_back();
407           }
408             INST_NEXT;
409           case INST(JF): {
410             auto _ = instGuard();
411             if (pop(stack).toBool()) {
412               inst = instFetch(1);
413             } else {
414               inst = instFetch(inst.X);
415             }
416           }
417             INST_DISPATCH;
418           case INST(JMP): {
419             auto _ = instGuard();
420             inst = instFetch(inst.X);
421           }
422             INST_DISPATCH;
423           case INST(LOOP): {
424             auto _ = instGuard();
425             // stack: iteration_count, max_iter, cond, loop_carried_deps...
426             auto fr = stack.end() - (inst.N + 1);
427             int64_t trip_count = fr[0].toInt();
428             int64_t max_trip_count = fr[1].toInt();
429             bool cond = fr[2].toBool();
430             if (trip_count < max_trip_count && cond) {
431               fr[2] = trip_count;
432               fr[0] = trip_count + 1;
433               inst = instFetch(1);
434             } else {
435               size_t n_loop_carried = inst.N - 2;
436               for (const auto i : c10::irange(n_loop_carried)) {
437                 fr[i] = std::move(fr[i + 3]);
438               }
439               drop(stack, 3); // iteration_count, max_iter, cond
440               inst = instFetch(inst.X);
441             }
442           }
443             INST_DISPATCH;
444           case INST(CALL): {
445             auto _ = instGuard();
446             Function* fn = frame.function->function_table_[inst.X];
447             callFunction(*fn, stack);
448             continue;
449           }
450           case INST(INTERFACE_CALL): {
451             auto _ = instGuard();
452             // note the hash table lookup to find the function
453             // this can be more optimized if necessary, caching parts
454             // of the hashing computation or storing the offset when
455             // the object is turned into an interface
456 
457             // consider passing
458             // `frames.back().function->remaining_bailout_depth_` into
459             // `get_executor().getPlanFor()` to propagate the caller's depth
460             // restrictions onto children. While this strategy has the potential
461             // to reduce the number of compilations for too-dynamic callers, we
462             // might miss opportunities where a caller is dynamic but a callee
463             // gets stable arguments.
464             Function& function =
465                 peek(stack, 0, inst.N)
466                     .toObject()
467                     ->type()
468                     ->getMethod(
469                         frame.function->constant_table_[inst.X].toStringRef());
470             callFunction(function, stack);
471             continue;
472           }
473           case INST(RET): {
474             if (frames.size() > 1) {
475               leaveFrame();
476               continue;
477             }
478             if (future_) {
479               auto num_outputs = frames.back().function->n_outputs;
480               if (num_outputs == 1) {
481                 future_->markCompleted(stack.back());
482               } else {
483                 future_->markCompleted(
484                     c10::ivalue::Tuple::create(jit::last(stack, num_outputs)));
485               }
486             }
487             // destroy the last frame and call RecordFunction's end callbacks
488             leaveFrame();
489             return false;
490           }
491           case INST(WAIT): {
492             auto _ = instGuard();
493             auto future = stack.back().toFuture();
494             if (!future->completed()) {
495               getOrCreateFuture();
496 
497               // callback needs to be a struct rather than a lambda so that
498               // we can move the stack to the other thread
499               struct Callback {
500                 Callback(
501                     c10::intrusive_ptr<InterpreterStateImpl> state,
502                     Stack stack)
503                     : stateImpl_(std::move(state)),
504                       state_(stateImpl_),
505                       stack_(std::move(stack)),
506                       dist_autograd_context_id_(getDistAutogradContextId()) {
507                   state_ = InterpreterState(stateImpl_);
508                 }
509                 void operator()(c10::ivalue::Future& /* unused */) {
510                   stateImpl_->taskLauncher_(InterpreterContinuation(
511                       state_,
512                       std::move(stack_),
513                       dist_autograd_context_id_,
514                       std::move(tls_state_)));
515                 }
516 
517                private:
518                 c10::intrusive_ptr<InterpreterStateImpl> stateImpl_;
519                 InterpreterState state_;
520                 Stack stack_;
521                 int64_t dist_autograd_context_id_;
522                 // preserve the original ThreadLocalState
523                 at::ThreadLocalState tls_state_;
524               };
525 
526               // we are suspending, so we need to reset the stack to what it was
527               // when we were called (minus our inputs). If the whole stack is ours,
528               // we can avoid a true copy by swapping, which leaves the original stack empty.
529               Stack copied;
530               if (stack_start_ == 0) {
531                 copied.swap(stack);
532               } else {
533                 copied.insert(
534                     copied.begin(),
535                     std::make_move_iterator(stack.begin() + stack_start_),
536                     std::make_move_iterator(stack.end()));
537                 stack.resize(stack_start_);
538               }
539               // save pc into the frame so we continue here when restored
540               future->addCallback(
541                   Callback(intrusive_from_this(), std::move(copied)));
542 
543               return true;
544             }
545             stack.pop_back();
546             stack.emplace_back(future->value());
547           }
548             INST_NEXT;
549           case INST(PROFILE_OP): {
550             auto _ = instGuard();
551             auto& frame_id_ref = frame.id;
552             if (!frame_id_ref.has_value()) {
553               frame_id_ref = Frame::genId();
554             }
555             const auto& callback =
556                 frame.function->profile_function_table_[inst.X];
557             push(stack, c10::IValue{static_cast<int64_t>(*frame_id_ref)});
558             callback(stack);
559           }
560             INST_NEXT;
561           case INST(FAIL_GUARD): {
562             auto _ = instGuard();
563             // patch FAIL_GUARD back to GUARD
564             GRAPH_DEBUG(
565                 "Bailout ", inst.X, " triggered via bailout_requests_!");
566             frame.function->instructions_[frame.pc].op = GUARD;
567             push(stack, false);
568           }
569             INST_NEXT;
570           case INST(TYPECHECK): {
571             auto _ = instGuard();
572             unsigned num_inputs = inst.N, i = 0;
573             TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs && num_inputs > 0);
574             // Check every input's shape against profiled (expected) shape.
575             for (i = 0; i < num_inputs; i++) {
576               auto& input = peek(stack, i, num_inputs);
577               auto& t = input.toTensor();
578               const TypePtr& expected = frame.function->type_table_[inst.X + i];
579               auto* expected_type = expected->castRaw<TensorType>();
580               if (t.defined() && !expected_type->matchTensor(t)) {
581                 push(stack, false);
582                 break;
583               }
584             }
585             if (i == num_inputs) {
586               push(stack, true);
587             }
588           }
589             INST_NEXT;
590           case INST(GUARD): {
591             auto _ = instGuard();
592             if (!stack.back().isTensor()) {
593               // stack.back() is an Uninitialized IValue and this is a guard
594               // on a block output. Uninitialized IValues are never used
595               // so it's safe to pass this guard check
596               push(stack, true);
597             } else {
598               auto& t = stack.back().toTensor();
599               const TypePtr& expected = frame.function->type_table_[inst.X];
600               auto* expected_type = expected->castRaw<TensorType>();
601               if (t.defined() &&
602                   !frames.back().symbols2dims.bindSymbolicShapes(
603                       t.sizes(), expected_type->symbolic_sizes())) {
604                 push(stack, false);
605               } else {
606                 push(stack, expected_type->matchTensor(t));
607               }
608             }
609           }
610             INST_NEXT;
611           case INST(TAIL_CALL): {
612             auto _ = instGuard();
613             GRAPH_DEBUG("running TAIL_CALL for ", inst.X);
614             frame.function->function_table_[inst.X]->ensure_defined();
615             size_t remaining_bailout_depth =
616                 frame.function->remaining_bailout_depth_ > 0
617                 ? frame.function->remaining_bailout_depth_ - 1
618                 : 0;
619             auto& f = *frame.function->function_table_[inst.X];
620             size_t num_inputs = f.num_inputs();
621             size_t base_pointer = frame.base_pointer;
622             TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs);
623             size_t inputs_start = stack.size() - num_inputs;
624             for (const auto i : c10::irange(num_inputs)) {
625               stack.at(base_pointer + i) =
626                   std::move(stack.at(inputs_start + i));
627             }
628             stack.resize(base_pointer + num_inputs);
629             leaveFrame();
630 
631             callFunction(f, stack, remaining_bailout_depth, false);
632             continue;
633           }
634           case INST(LIST_UNPACK): {
635             auto _ = instGuard();
636             listUnpack(stack, inst.X);
637           }
638             INST_NEXT;
639           case INST(TUPLE_CONSTRUCT): {
640             auto _ = instGuard();
641             tupleConstruct(stack, inst.X);
642           }
643             INST_NEXT;
644           case INST(TUPLE_SLICE): {
645             auto _ = instGuard();
646             tupleSlice(stack, inst.X, inst.X + inst.N);
647           }
648             INST_NEXT;
649           case INST(NAMED_TUPLE_CONSTRUCT): {
650             auto _ = instGuard();
651             namedTupleConstruct(
652                 stack,
653                 frame.function->type_table_[inst.X]->expect<TupleType>(),
654                 inst.N);
655           }
656             INST_NEXT;
657           case INST(LIST_CONSTRUCT): {
658             auto _ = instGuard();
659             const auto& type =
660                 frame.function->type_table_[inst.X]->expectRef<ListType>();
661             listConstruct(stack, type, inst.N);
662           }
663             INST_NEXT;
664           case INST(DICT_CONSTRUCT): {
665             auto _ = instGuard();
666             const auto& type =
667                 frame.function->type_table_[inst.X]->expectRef<DictType>();
668             dictConstruct(stack, type, inst.N);
669           }
670             INST_NEXT;
671           case INST(CREATE_OBJECT): {
672             auto _ = instGuard();
673             auto type =
674                 frame.function->type_table_[inst.X]->expect<ClassType>();
675             createObject(stack, type);
676           }
677             INST_NEXT;
678           case INST(ISINSTANCE): {
679             auto _ = instGuard();
680             at::ArrayRef<TypePtr> types(
681                 &frame.function->type_table_[inst.X],
682                 &frame.function->type_table_[inst.X] + inst.N);
683             isinstance(stack, types);
684           }
685             INST_NEXT;
686           case INST(TUPLE_INDEX): {
687             auto _ = instGuard();
688             tupleIndex(stack);
689           }
690             INST_NEXT;
691           case INST(RAISE_EXCEPTION): {
692             auto _ = instGuard();
693             raiseExceptionWithMessage(stack);
694           }
695             INST_NEXT;
696           case INST(UNCHECKED_CAST): {
697             auto _ = instGuard();
698             noop(stack);
699           }
700             INST_NEXT;
701           case INST(__IS__): {
702             auto _ = instGuard();
703             is(stack);
704           }
705             INST_NEXT;
706           case INST(UN_INITIALIZED): {
707             auto _ = instGuard();
708             unInitialized(stack);
709           }
710             INST_NEXT;
711           case INST(__ISNOT__): {
712             auto _ = instGuard();
713             isNot(stack);
714           }
715             INST_NEXT;
716           case INST(FORMAT): {
717             auto _ = instGuard();
718             format(stack, inst.X);
719           }
720             INST_NEXT;
721           case INST(DEVICE): {
722             auto _ = instGuard();
723             device(stack);
724           }
725             INST_NEXT;
726           case INST(DTYPE): {
727             auto _ = instGuard();
728             TORCH_INTERNAL_ASSERT(!stack.empty());
729             dtype(stack);
730           }
731             INST_NEXT;
732           case INST(DIM): {
733             auto _ = instGuard();
734             TORCH_INTERNAL_ASSERT(!stack.empty());
735             dim(stack);
736           }
737             INST_NEXT;
738           case INST(__NOT__): {
739             auto _ = instGuard();
740             _not(stack);
741           }
742             INST_NEXT;
743           case INST(DICT_INDEX): {
744             auto _ = instGuard();
745             dictIndex(stack);
746           }
747             INST_NEXT;
748           case INST(TO_LIST): {
749             auto _ = instGuard();
750             toList(stack);
751           }
752             INST_NEXT;
753           case INST(NUM_TO_TENSOR): {
754             auto _ = instGuard();
755             numToTensorScalar(stack);
756           }
757             INST_NEXT;
758           case INST(IS_CUDA): {
759             auto _ = instGuard();
760             isCuda(stack);
761           }
762             INST_NEXT;
763           case INST(FORK): {
764             auto _ = instGuard();
765             // Move inputs to a separate stack
766             auto& forked_fn =
767                 toGraphFunction(*frame.function->function_table_[inst.X]);
768             InterpreterState forked_interpreter(
769                 forked_fn.get_executor().getPlanFor(stack).code, taskLauncher_);
770             InterpreterContinuation continuation(
771                 forked_interpreter,
772                 Stack(stack.end() - inst.N, stack.end()),
773                 getDistAutogradContextId());
774             drop(stack, inst.N);
775             push(stack, forked_interpreter.getFuture());
776             taskLauncher_(std::move(continuation));
777           }
778             INST_NEXT;
779           case INST(AWAITABLE): {
780             auto _ = instGuard();
781             auto fn_ptr = frame.function->function_table_[inst.X];
782             auto& fn = toGraphFunction(*fn_ptr);
783             auto num_outputs = fn.graph()->outputs().size();
784             TypePtr out_type;
785             if (num_outputs == 1) {
786               out_type = fn.graph()->outputs()[0]->type();
787             } else {
788               std::vector<TypePtr> out_types;
789               for (const auto& o : fn.graph()->outputs()) {
790                 out_types.push_back(o->type());
791               }
792               out_type = TupleType::create(out_types);
793             }
794             auto args = std::vector<IValue>(stack.end() - inst.N, stack.end());
795             auto aw = c10::make_intrusive<c10::ivalue::Await>(out_type);
796             aw->setArgs(std::move(args));
797             aw->setFn(
798                 [&args = aw->args(),
799                  fn_ptr,
800                  taskLauncher = taskLauncher_]() -> IValue {
801                   auto& fn = toGraphFunction(*fn_ptr);
802                   auto n_out = fn.graph()->outputs().size();
803                   torch::jit::Stack s;
804                   for (const auto& arg : args) {
805                     s.push_back(arg);
806                   }
807                   InterpreterState await_interpreter(
808                       fn.get_executor().getPlanFor(s).code, taskLauncher);
809                   await_interpreter.run(s);
810                   if (n_out == 1) {
811                     return s.back();
812                   }
813                   return c10::ivalue::Tuple::create(jit::last(s, n_out));
814                 });
815             drop(stack, inst.N);
816             push(stack, std::move(aw));
817           }
818             INST_NEXT;
819           case INST(WARN): {
820             auto _ = instGuard();
821             // Keeps track of which WARN instruction has been executed before;
822             // we only want to execute each WARN once to match default Python
823             // warning behavior.
824             bool need_warn = true;
825             if (inst.X != -1) {
826               need_warn = warned_nodes_.insert(inst.X);
827             }
828 
829             Node* node =
830                 frames.back().function->instructions_source_.at(frame.pc);
831             auto range = node->sourceRange().source();
832             if (range->filename()) {
833               drop(stack, 1);
834               const auto& msg = stack.back().toStringRef();
835               if (need_warn) {
836                 auto line = range->starting_line_no() +
837                     range->lineno_for_offset(node->sourceRange().start());
838                 c10::SourceLocation location{
839                     "", range->filename()->c_str(), uint32_t(line)};
840                 // Sends the warning to the warning handler with the
841                 // "verbatim" flag. This flag ensures the warning handler
842                 // will print the exception as configured.
843                 c10::warn(c10::Warning(
844                     c10::UserWarning(), location, msg, /*verbatim=*/true));
845               }
846               stack.pop_back();
847             } else {
848               if (need_warn) {
849                 TORCH_WARN(stack.back().toStringRef());
850               }
851               stack.pop_back();
852             }
853           }
854             INST_NEXT;
855         }
856       }
857     } catch (std::exception& e) {
858       for (auto it = entered_objects.rbegin(), end = entered_objects.rend();
859            it != end;
860            ++it) {
861         auto& f = it->toObject()->type()->getMethod("__exit__");
862         Stack stack;
863         push(stack, *it);
864         push(stack, IValue());
865         push(stack, IValue());
866         push(stack, IValue());
867         try {
868           f.run(stack);
869         } catch (std::exception& _) {
870           // TODO(T98048876): Handle `_` correctly.
871         }
872       }
873       if (FLAGS_torch_jit_enable_rethrow_caught_exception) {
874         if (future_) {
875           future_->setError(std::current_exception());
876           return false;
877         }
878         throw;
879       }
880       auto* jit_exception = dynamic_cast<JITException*>(&e);
881       // Janky af.  See https://github.com/pytorch/pytorch/issues/54612
882       auto* not_implemented_error = dynamic_cast<c10::NotImplementedError*>(&e);
883 
884       std::optional<std::string> python_class_name;
885       if (jit_exception) {
886         python_class_name = jit_exception->getPythonClassName();
887       }
888       handleError(
889           e, (bool)jit_exception, not_implemented_error, python_class_name);
890       return false;
891     }
892   }
893 
894 #undef INST_NEXT
895 #undef INST_DISPATCH
896 #undef INST
897 #undef JIT_USE_COMPUTED_GOTO
898 
899   bool runImpl(Stack& stack) {
900     if (!profiling::isProfilingOngoing()) {
901       return runTemplate</*EnableProfiling*/ false>(stack);
902     } else {
903       return runTemplate</*EnableProfiling*/ true>(stack);
904     }
905   }
906 
907   void formatStackTrace(std::ostream& out) {
908     format_stack_trace(out, callstack());
909   }
910 
911   void handleError(
912       const std::exception& e,
913       bool is_jit_exception,
914       c10::NotImplementedError* not_implemented_error,
915       std::optional<std::string> python_class_name) {
916     ExceptionMessage msg(e);
917     std::ostringstream ss;
918     std::string class_name =
919         python_class_name ? *python_class_name : "RuntimeError";
920     ss << "The following operation failed in the TorchScript interpreter.\n";
921     formatStackTrace(ss);
922     ss << class_name << ": " << msg << "\n";
923     if (future_) {
924       future_->setError(std::make_exception_ptr(Future::FutureError(ss.str())));
925     } else if (is_jit_exception) {
926       // save the original exception's message when creating a new JITException
927       throw JITException(ss.str(), python_class_name, e.what());
928     } else if (not_implemented_error) {
929       throw c10::NotImplementedError(
930           ss.str(),
931           not_implemented_error->backtrace(),
932           not_implemented_error->caller());
933     } else {
934       if (get_cpp_stacktraces_enabled()) {
935         ss << e.what() << "\n";
936       }
937       throw std::runtime_error(ss.str());
938     }
939   }
940 
941   static void checkAndStartRecordFunction(Frame& frame, Stack& stack) {
942     if (!frame.record_function) {
943       auto step_callbacks = at::getStepCallbacksUnlessEmpty(
944           at::RecordScope::TORCHSCRIPT_FUNCTION);
945       if (C10_UNLIKELY(step_callbacks.has_value())) {
946         auto rec_fn =
947             std::make_unique<at::RecordFunction>(std::move(*step_callbacks));
948         TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rec_fn->isActive());
949         if (rec_fn->needsInputs()) {
950           rec_fn->before(
951               frame.function->function_name_,
952               last(stack, frame.function->n_inputs));
953         } else {
954           rec_fn->before(frame.function->function_name_);
955         }
956         frame.record_function = std::move(rec_fn);
957       }
958     }
959   }
960 
961  public:
962   // One way to avoid the overhead of forming strings would be to return
963   // a vector of frame.function, i.e. CodeImpl*.
964   // This is not exactly clean as it will expose internal details of the
965   // interpreter. But this way we hold onto graph/node and Function, and
966   // we can create the module hierarchy string for each event in the autograd
967   // profiler at the end, when consolidating events.
968   // At the moment the overhead does not seem exorbitantly large.
969   // Another option would be to return a vector of (string, InlinedCallstackPtrs),
970   // where the string would contain the function name and the typename of self.
971   // Format of the returned vector of strings:
972   // For each frame, the corresponding module name, type and function name
973   // are in the following format:
974   // <module-instance-name>(module type)::<function-name>
975   // Special keys for module-instance-name:
976   //   - TOP: for top level module
977   //   - SELF: When method/function of the frame is associated with
978   //           previous frame's module instance
979   //   - INSTANCE_NAME_UNKNOWN: instance name cannot be figured out
980   //   - CALL_FUNCTION: call to free function
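  //
  // For example (illustrative, with made-up module and function names), a call
  // chain from the top level module into a submodule method and then into a
  // free function could produce entries such as:
  //   TOP(MyModule)::forward
  //   encoder(EncoderModule)::forward
  //   CALL_FUNCTION::helper_fn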
981   std::vector<std::string> moduleHierarchy() const {
982     std::vector<std::string> module_function_list;
983     std::string module_hierarchy("TOP");
984     for (size_t i = 0; i < frames.size(); ++i) {
985       const Frame& frame = frames[i];
986       std::string fn_name = frame.function->function_name_;
987       // For each frame, the type of the class with which the function is
988       // associated is queried here, and the type name is added to the
989       // module hierarchy.
990       const auto& g = frame.function->graph_;
991       std::string g_self_type;
992       if (g && !g->inputs().empty()) {
993         const auto& g_self_type_ptr =
994             g->inputs()[0]->type()->cast<c10::ClassType>();
995         if (g_self_type_ptr) {
996           g_self_type = g_self_type_ptr->name()->qualifiedName();
997           g_self_type = g_self_type.substr(g_self_type.find_last_of('.') + 1);
998         }
999       }
1000       module_hierarchy.append("(")
1001           .append(g_self_type)
1002           .append(")::")
1003           .append(fn_name);
1004       module_function_list.emplace_back(std::move(module_hierarchy));
1005 
1006       size_t pc = frame.pc;
1007       // CALL nodes have already advanced the pc, so
1008       // undo that to report the call node
1009       if (i + 1 < frames.size()) {
1010         --pc;
1011       }
1012 
1013       Node* node = frame.function->instructions_source_[pc];
1014       if (node->callstack()) {
1015         for (const auto& p : (*node->callstack())->vec()) {
1016           fn_name = std::get<0>(p)->name();
1017           const auto& opt_module_info = std::get<2>(p);
1018           if (opt_module_info.has_value()) {
1019             const auto& module_instance_info = opt_module_info.value();
1020             module_hierarchy = utils::get_module_info(module_instance_info);
1021             module_hierarchy.append("::").append(fn_name);
1022           } else {
1023             // This is likely a call to a free function, not associated with
1024             // any class
1025             module_hierarchy = "::";
1026             module_hierarchy.append(fn_name);
1027           }
1028           module_function_list.emplace_back(std::move(module_hierarchy));
1029         }
1030       }
1031 
1032       module_hierarchy = std::string();
1033       // If this node is of type callMethod then the following frame
1034       // will contain the op being executed.
1035       // For such a callMethod node, we add the object instance name
1036       // associated with it, since the following frame will not have it.
1037       if (node->kind() == prim::CallMethod) {
1038         std::string class_instance_name;
1039         if (node->input(0)->node()->kind() == prim::GetAttr) {
1040           class_instance_name = node->input(0)->node()->s(attr::name);
1041         } else if (
1042             !node->owningGraph()->inputs().empty() &&
1043             node->input(0) == node->owningGraph()->inputs()[0]) {
1044           class_instance_name = "SELF";
1045         } else {
1046           class_instance_name = "INSTANCE_NAME_UNKNOWN";
1047         }
1048         module_hierarchy = std::move(class_instance_name);
1049       } else if (node->kind() == prim::CallFunction) {
1050         auto function_constant = node->input(0)->node();
1051         auto fun_type =
1052             function_constant->output()->type()->expect<FunctionType>();
1053         auto fun_name = fun_type->function()->name();
1054         module_hierarchy = "CALL_FUNCTION::";
1055         module_hierarchy.append(fun_name);
1056       }
1057     }
1058     return module_function_list;
1059   }
1060 
1061   std::vector<StackEntry> callstack() const {
1062     std::vector<StackEntry> entries;
1063     for (const auto i : c10::irange(frames.size())) {
1064       const Frame& frame = frames[i];
1065       std::string previous_fn_name = frame.function->function_name_;
1066       size_t pc = frame.pc;
1067       // CALL nodes have already advanced the pc, so
1068       // undo that to report the call node
1069       if (i + 1 < frames.size()) {
1070         --pc;
1071       }
1072 
1073       Node* node = frame.function->instructions_source_[pc];
1074       if (node->callstack()) {
1075         for (const auto& p : (*node->callstack())->vec()) {
1076           entries.emplace_back(StackEntry{previous_fn_name, std::get<1>(p)});
1077           previous_fn_name = std::get<0>(p)->name();
1078         }
1079       }
1080       entries.emplace_back(StackEntry{previous_fn_name, node->sourceRange()});
1081     }
1082     return entries;
1083   }
1084 
1085   c10::intrusive_ptr<Future> getOrCreateFuture() {
1086     if (!future_) {
1087       future_ =
1088           c10::make_intrusive<Future>(frames.front().function->return_type_);
1089     }
1090     return future_;
1091   }
1092 
1093   c10::intrusive_ptr<Future> runAsync(Stack& stack) {
1094     getOrCreateFuture();
1095     runImpl(stack);
1096     return future_;
1097   }
1098 
1099   void run(Stack& stack) {
1100     // By the time the continuation completes the frame will be gone, so this
1101     // must be done before calling runImpl().
1102     TORCH_INTERNAL_ASSERT(!frames.empty());
1103     const auto num_outputs = frames.front().function->n_outputs;
1104     if (runImpl(stack)) {
1105       future_->wait();
1106 
1107       if (num_outputs == 1) {
1108         push(stack, future_->value());
1109       } else {
1110         auto tuple = future_->value().toTuple();
1111         for (const IValue& value : tuple->elements()) {
1112           push(stack, value);
1113         }
1114       }
1115     }
1116   }
1117 };
1118 
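// A usage sketch for currentCallstack (below): code running on the interpreter
// thread, e.g. a custom operator, can retrieve and format the TorchScript call
// stack of its caller; outside of an interpreter run the result is empty.
//
//   std::ostringstream ss;
//   torch::jit::format_stack_trace(ss, torch::jit::currentCallstack());
//
// format_stack_trace is the same helper used by formatStackTrace() above.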
1119 std::vector<StackEntry> currentCallstack() {
1120   if (tls_int_state_ptr_) {
1121     auto cs = tls_int_state_ptr_->callstack();
1122     std::reverse(cs.begin(), cs.end());
1123     return cs;
1124   }
1125   return std::vector<StackEntry>();
1126 }
1127 
1128 std::vector<std::string> currentModuleHierarchy() {
1129   if (tls_int_state_ptr_) {
1130     return tls_int_state_ptr_->moduleHierarchy();
1131   }
1132   return std::vector<std::string>();
1133 }
1134 
1135 std::ostream& operator<<(std::ostream& out, const Code& code) {
1136   out << *code.pImpl->graph_ << "\n";
1137   code.pImpl->dump(out);
1138   return out;
1139 }
1140 
1141 Code::Code(
1142     const std::shared_ptr<Graph>& graph,
1143     std::string function_name,
1144     size_t remaining_bailout_depth)
1145     : pImpl(new CodeImpl(
1146           graph,
1147           std::move(function_name),
1148           remaining_bailout_depth)) {}
1149 
1150 Code::Code(CodeImpl* codeImpl) : pImpl(codeImpl) {}
1151 
1152 MobileCode::MobileCode(
1153     const std::shared_ptr<Graph>& graph,
1154     std::string function_name,
1155     bool emit_default_input_instructions,
1156     bool support_default_args_before_out,
1157     bool emit_promoted_ops,
1158     size_t remaining_bailout_depth)
1159     : Code(new interpreter::MobileCodeImpl(
1160           graph,
1161           std::move(function_name),
1162           emit_default_input_instructions,
1163           support_default_args_before_out,
1164           emit_promoted_ops,
1165           remaining_bailout_depth)) {}
1166 
1167 const std::vector<GraphExecutor*>& Code::grad_executors() {
1168   return pImpl->grad_executors();
1169 }
1170 
1171 const std::vector<GraphExecutor*>& Code::diff_graph_op_executors() {
1172   return pImpl->diff_graph_op_executors();
1173 }
1174 
1175 size_t Code::num_bailouts() const {
1176   return pImpl->type_table_.size();
1177 }
1178 
1179 void Code::request_bailout(size_t index) {
1180   pImpl->request_bailout(index);
1181 }
1182 
1183 size_t Code::num_inputs() const {
1184   return pImpl->n_inputs;
1185 }
1186 
1187 size_t Code::num_outputs() const {
1188   return pImpl->n_outputs;
1189 }
1190 
1191 const std::vector<c10::IValue>& Code::constant_table() const {
1192   return pImpl->constant_table();
1193 }
1194 
1195 const std::vector<Instruction>& Code::instructions() const {
1196   return pImpl->instructions();
1197 }
1198 
1199 const std::unordered_map<std::string, size_t>& Code::op_to_num_specified_args()
1200     const {
1201   return pImpl->op_to_num_specified_args();
1202 }
1203 
1204 const std::vector<Node*>& Code::instructions_source() const {
1205   return pImpl->instructions_source();
1206 }
1207 
1208 const std::vector<TypePtr>& Code::type_table() const {
1209   return pImpl->type_table_;
1210 }
1211 
1212 size_t Code::register_size() const {
1213   return pImpl->register_size_;
1214 }
1215 
1216 std::shared_ptr<Graph> Code::graph() const {
1217   return pImpl->preprocess_.graph;
1218 }
1219 
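// Example usage of Code together with InterpreterState (a minimal sketch;
// `graph` stands for an existing std::shared_ptr<Graph> that is already in a
// runnable form, and the tensor values are illustrative only):
//
//   Code code(graph, "forward", /*remaining_bailout_depth=*/0);
//   InterpreterState state(code, at::launch);
//   Stack stack;
//   stack.emplace_back(at::ones({2, 2}));
//   state.run(stack);  // outputs replace the inputs on `stack`
//
// runAsync() behaves the same way but returns a Future instead of blocking.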
1220 InterpreterState::InterpreterState(const Code& code, TaskLauncher taskLauncher)
1221     : pImpl(c10::make_intrusive<InterpreterStateImpl>(
1222           code,
1223           std::move(taskLauncher))) {}
1224 
1225 void InterpreterState::run(Stack& stack) {
1226   static_cast<InterpreterStateImpl*>(pImpl.get())->run(stack);
1227 }
1228 
1229 c10::intrusive_ptr<Future> InterpreterState::runAsync(Stack& stack) {
1230   return static_cast<InterpreterStateImpl*>(pImpl.get())->runAsync(stack);
1231 }
1232 
1233 c10::intrusive_ptr<Future> InterpreterState::getFuture() {
1234   return static_cast<InterpreterStateImpl*>(pImpl.get())->getOrCreateFuture();
1235 }
1236 
1237 InterpreterState::InterpreterState(
1238     c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl_)
1239     : pImpl(std::move(pImpl_)) {}
1240 
1241 void InterpreterContinuation::operator()() {
1242 #ifdef USE_RPC
1243   auto prev_dist_id = DistAutogradContainer::currentContextId();
1244   DistAutogradContainer::forceCurrentContextId(dist_autograd_context_id_);
1245 #endif
1246   if (tls_state_ != std::nullopt) {
1247     at::ThreadLocalStateGuard g(*tls_state_);
1248     state.runAsync(stack);
1249   } else {
1250     state.runAsync(stack);
1251   }
1252 #ifdef USE_RPC
1253   DistAutogradContainer::forceCurrentContextId(prev_dist_id);
1254 #endif
1255 }
1256 
1257 } // namespace torch::jit
1258