xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/python/python_tracer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/python_headers.h>
2 
3 #include <torch/csrc/jit/frontend/tracer.h>
4 #include <torch/csrc/jit/passes/dead_code_elimination.h>
5 #include <torch/csrc/jit/passes/inliner.h>
6 #include <torch/csrc/jit/passes/lower_tuples.h>
7 #include <torch/csrc/jit/python/pybind.h>
8 #include <torch/csrc/jit/python/python_tracer.h>
9 #include <torch/csrc/jit/serialization/export.h>
10 #include <torch/csrc/utils/python_strings.h>
11 
12 #include <c10/util/Exception.h>
13 #include <c10/util/irange.h>
14 
15 #include <sstream>
16 
17 using namespace torch::autograd;
18 using namespace torch::jit;
19 using namespace torch::jit::tracer;
20 
21 namespace torch::jit::tracer {
22 
23 // Python interpreter retrieval routine adapted from
24 // https://stackoverflow.com/a/8706144
_pythonCallstack()25 std::vector<StackEntry> _pythonCallstack() {
26   pybind11::gil_scoped_acquire gil;
27   PyFrameObject* frame = PyEval_GetFrame();
28   Py_XINCREF(frame);
29   std::vector<StackEntry> entries;
30 
31   while (nullptr != frame) {
32     auto code = THPCodeObjectPtr(PyFrame_GetCode(frame));
33     size_t line = PyCode_Addr2Line(code.get(), PyFrame_GetLasti(frame));
34     std::string filename = THPUtils_unpackString(code->co_filename);
35     std::string funcname = THPUtils_unpackString(code->co_name);
36     auto source = std::make_shared<Source>(funcname, filename, line);
37     entries.emplace_back(
38         StackEntry{funcname, SourceRange(source, 0, funcname.size())});
39     auto new_frame = PyFrame_GetBack(frame);
40     Py_DECREF(frame);
41     frame = new_frame;
42   }
43   return entries;
44 }
45 
getPythonInterpreterSourceRange()46 SourceRange getPythonInterpreterSourceRange() {
47   auto cs = pythonCallstack();
48   std::optional<std::string> source_filename;
49   size_t source_line = 0;
50   std::stringstream stack_trace;
51   for (const auto& entry : cs) {
52     auto& range = entry.range;
53     if (range.source()) {
54       auto& src = range.source();
55       if (src && src->filename()) {
56         auto line =
57             src->starting_line_no() + src->lineno_for_offset(range.start());
58         stack_trace << *(src->filename()) << "(" << line
59                     << "): " << entry.filename << "\n";
60         if (!source_filename) {
61           source_filename = *(src->filename());
62           source_line = line;
63         }
64       }
65     }
66   }
67 
68   auto stack_trace_text = stack_trace.str();
69   auto source =
70       std::make_shared<Source>(stack_trace_text, source_filename, source_line);
71   return SourceRange(source, 0, stack_trace_text.size());
72 }
73 
createGraphByTracingWithDict(const py::function & func,const py::dict & inputs_dict,const Stack & trace_inputs,const py::function & var_name_lookup_fn,bool strict,bool force_outplace,Module * self,const std::vector<std::string> & argument_names)74 std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracingWithDict(
75     const py::function& func,
76     const py::dict& inputs_dict,
77     const Stack& trace_inputs,
78     const py::function& var_name_lookup_fn,
79     bool strict,
80     bool force_outplace,
81     Module* self,
82     const std::vector<std::string>& argument_names) {
83   C10_LOG_API_USAGE_ONCE("torch.tracer");
84 
85   auto lookup_fn_adapter =
86       [var_name_lookup_fn](const Variable& var) -> std::string {
87     pybind11::gil_scoped_acquire ag;
88     return py::cast<std::string>(var_name_lookup_fn(var));
89   };
90 
91   // The argument_names parameter is parsed in python and its order
92   // is the same as the arguments' decalaration order in forward() method.
93   // These name shall be added to the graph as debug name and the order
94   // should align with the traceable stack we generated by the python dict.
95   std::vector<std::string> compact_argument_names;
96   Stack compact_trace_inputs;
97   for (const auto& argument_name : argument_names) {
98     if (inputs_dict.contains(argument_name)) {
99       compact_argument_names.push_back(argument_name);
100     }
101   }
102   for (const auto& compact_argument_name : compact_argument_names) {
103     for (auto it = inputs_dict.begin(); it != inputs_dict.end(); it++) {
104       if (py::cast<std::string>(it->first) == compact_argument_name) {
105         compact_trace_inputs.push_back(
106             toIValue(it->second, tryToInferType(it->second).type()));
107       }
108     }
109   }
110 
111   auto outs = tracer::trace(
112       std::move(compact_trace_inputs),
113       [&](const Stack& inputs) -> Stack {
114         // We just leave the inputs_dict as it was and pass it to forward
115         // method.
116         auto out = func(**inputs_dict);
117         if (out.ptr() == Py_None) {
118           AT_ERROR(
119               "The traced function didn't return any values! Side-effects are not "
120               "captured in traces, so it would be a no-op.");
121         }
122         return {toTypeInferredIValue(out)};
123       },
124       lookup_fn_adapter,
125       strict,
126       force_outplace,
127       self,
128       compact_argument_names);
129   return std::make_pair(std::get<0>(outs)->graph, std::get<1>(outs));
130 }
131 
createGraphByTracing(const py::function & func,Stack trace_inputs,const py::function & var_name_lookup_fn,bool strict,bool force_outplace,Module * self,const std::vector<std::string> & argument_names)132 std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracing(
133     const py::function& func,
134     Stack trace_inputs,
135     const py::function& var_name_lookup_fn,
136     bool strict,
137     bool force_outplace,
138     Module* self,
139     const std::vector<std::string>& argument_names) {
140   C10_LOG_API_USAGE_ONCE("torch.tracer");
141 
142   auto lookup_fn_adapter =
143       [var_name_lookup_fn](const Variable& var) -> std::string {
144     pybind11::gil_scoped_acquire ag;
145     return py::cast<std::string>(var_name_lookup_fn(var));
146   };
147 
148   auto outs = tracer::trace(
149       std::move(trace_inputs),
150       [&func](Stack inputs) -> Stack {
151         size_t num_func_inputs = inputs.size();
152         py::tuple py_inputs(num_func_inputs);
153         for (const auto i : c10::irange(num_func_inputs)) {
154           py_inputs[i] = py::cast(inputs[i]);
155         }
156         auto out = func(*py_inputs);
157         if (out.ptr() == Py_None) {
158           AT_ERROR(
159               "The traced function didn't return any values! Side-effects are not "
160               "captured in traces, so it would be a no-op.");
161         }
162         return {toTypeInferredIValue(out)};
163       },
164       lookup_fn_adapter,
165       strict,
166       force_outplace,
167       self,
168       argument_names);
169   return std::make_pair(std::get<0>(outs)->graph, std::get<1>(outs));
170 }
171 
preRecordPythonTrace(THPObjectPtr pyobj,const std::string & arg_types,at::ArrayRef<Variable> inputs,pyobj_list scalar_args)172 Node* preRecordPythonTrace(
173     THPObjectPtr pyobj,
174     const std::string& arg_types,
175     at::ArrayRef<Variable> inputs,
176     pyobj_list scalar_args) {
177   THPObjectPtr apply(PyObject_GetAttrString(pyobj.get(), "apply"));
178   if (!apply) {
179     throw python_error();
180   }
181 
182   auto& graph = getTracingState()->graph;
183 
184   Node* n = graph->createPythonOp(
185       std::move(apply), arg_types, std::move(scalar_args));
186   recordSourceLocation(n);
187 
188   for (const Variable& input : inputs) {
189     n->addInput(getValueTrace(input));
190   }
191 
192   graph->insertNode(n);
193 
194   return n;
195 }
196 
pythonRecordSourceLocation(Node * n)197 void pythonRecordSourceLocation(Node* n) {
198   n->setSourceRange(getPythonInterpreterSourceRange());
199 }
200 
pythonWarn(const std::string & reason)201 void pythonWarn(const std::string& reason) {
202   pybind11::gil_scoped_acquire gil;
203   auto warn_class = py::module::import("torch.jit").attr("TracerWarning");
204   PyErr_WarnEx(warn_class.ptr(), reason.c_str(), 1);
205 }
206 
initPythonTracerBindings(PyObject * module)207 void initPythonTracerBindings(PyObject* module) {
208   setPythonCallstack(_pythonCallstack);
209   setRecordSourceLocation(pythonRecordSourceLocation);
210 
211   auto m = py::handle(module).cast<py::module>();
212   py::class_<TracingState, std::shared_ptr<TracingState>>(
213       m, "TracingState", py::dynamic_attr())
214       // NB: no constructor; you have to get it from C++ code
215       .def(
216           "__repr__",
217           [](const TracingState& s) {
218             std::ostringstream ss;
219             ss << "<TracingState " << (const void*)&s << ">";
220             return ss.str();
221           })
222       .def(
223           "__str__",
224           [](const TracingState& s) -> std::string {
225             std::ostringstream ss;
226             ss << *s.graph;
227             return ss.str();
228           })
229       .def(
230           "push_scope",
231           [](TracingState& s, const std::string& scope_name) {
232             s.graph->push_scope(scope_name);
233           })
234       .def("pop_scope", [](TracingState& s) { s.graph->pop_scope(); })
235       .def(
236           "current_scope",
237           [](TracingState& s) {
238             return s.graph->current_scope()->name().toUnqualString();
239           })
240       .def(
241           "set_graph",
242           [](TracingState& s, std::shared_ptr<Graph> g) {
243             s.graph = std::move(g);
244           })
245       .def("graph", [](TracingState& s) { return s.graph; });
246 
247   m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); });
248   m.def(
249       "_create_graph_by_tracing",
250       createGraphByTracing,
251       py::arg("func"),
252       py::arg("inputs"),
253       py::arg("var_name_lookup_fn"),
254       py::arg("strict"),
255       py::arg("force_outplace"),
256       py::arg("self") = nullptr,
257       py::arg("argument_names") = std::vector<std::string>());
258   m.def("_get_tracing_state", []() { return getTracingState(); });
259   m.def("_set_tracing_state", [](std::shared_ptr<TracingState> state) {
260     return setTracingState(std::move(state));
261   });
262   m.def("_get_value_trace", [](const Variable& var) {
263     return getValueTrace(var);
264   });
265   m.def("_set_value_trace", [](const Variable& var, Value* value) {
266     return setValueTrace(var, value);
267   });
268   m.def("_tracer_set_get_unique_name_fn", [](const py::function& func) {
269     const auto& tracing_state = getTracingState();
270     AT_ASSERT(tracing_state);
271     tracing_state->lookup_var_name_fn =
272         [func](const Variable& var) -> std::string {
273       pybind11::gil_scoped_acquire ag;
274       return py::cast<std::string>(func(var));
275     };
276   });
277   m.def("_tracer_set_force_outplace", [](bool force_outplace) {
278     const auto& tracing_state = getTracingState();
279     AT_ASSERT(tracing_state);
280     tracing_state->force_outplace = force_outplace;
281   });
282 }
283 
284 } // namespace torch::jit::tracer
285