xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/python/python_tracer.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/frontend/source_range.h>
4 #include <torch/csrc/jit/frontend/tracer.h>
5 #include <torch/csrc/python_headers.h>
6 #include <torch/csrc/utils/pybind.h>
7 
8 #include <memory>
9 #include <string>
10 
11 namespace torch::jit {
12 
13 struct Module;
14 
15 namespace tracer {
16 void initPythonTracerBindings(PyObject* module);
17 
18 SourceRange getPythonInterpreterSourceRange();
19 
20 Node* preRecordPythonTrace(
21     THPObjectPtr pyobj,
22     const std::string& arg_types,
23     at::ArrayRef<autograd::Variable> inputs,
24     std::vector<THPObjectPtr> scalar_args);
25 
26 std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracingWithDict(
27     const py::function& func,
28     const py::dict& inputs_dict,
29     const Stack& inputs,
30     const py::function& var_name_lookup_fn,
31     bool strict,
32     bool force_outplace,
33     Module* self = nullptr,
34     const std::vector<std::string>& argument_names = {});
35 
36 std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracing(
37     const py::function& func,
38     Stack inputs,
39     const py::function& var_name_lookup_fn,
40     bool strict,
41     bool force_outplace,
42     Module* self = nullptr,
43     const std::vector<std::string>& argument_names = {});
44 } // namespace tracer
45 } // namespace torch::jit
46