xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/interpreter/frame.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <atomic>
4 #include <memory>
5 
6 #include <torch/csrc/jit/runtime/interpreter/code_impl.h>
7 #include <torch/csrc/jit/runtime/profiling_record.h>
8 
9 namespace torch::jit::interpreter {
10 
11 // A Frame captures function's state
12 // (e.g. `pc` and `base_pointer`)
13 // Each Frame corresponds to a call to a `Frame::function`
14 // which has not yet returned
15 // The arguments for `Frame::function`
16 // are located at [base_pointer + arg_number]
17 struct Frame {
18   std::shared_ptr<CodeImpl> function;
19   // program counter corresponds to the index
20   // of the currently executed instruction
21   size_t pc;
22   // marks the start index of the frame
23   // base_pointer is used by TAIL_CALL
24   // to replace the current frame
25   // with a frame of a bailout graph
26   size_t base_pointer;
27 
28   // unique to every frame with prim::profile across all threads
29   std::optional<size_t> id;
30 
31   // RecordFunction object associated with this frame
32   std::unique_ptr<at::RecordFunction> record_function;
33 
34   // symbol table for a frame
35   ShapeSymbolTable symbols2dims;
36 
37   static size_t genId();
38 };
39 
40 } // namespace torch::jit::interpreter
41