// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/python_print.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

7 namespace torch::jit {
8 
9 struct Method;
10 struct Module;
11 struct PythonPrintImpl;
12 
13 struct PrintDepsTable {
14   void add(const c10::NamedTypePtr& type);
15 
sizePrintDepsTable16   size_t size() const {
17     return table_.size();
18   }
19 
20   const c10::NamedTypePtr& operator[](size_t index) const {
21     return table_[index];
22   }
23 
24  private:
25   std::vector<c10::NamedTypePtr> table_;
26   std::unordered_set<c10::NamedTypePtr> non_unique_;
27 };
28 
29 struct TORCH_API PythonPrint {
30   PythonPrint(
31       std::vector<IValue>& constant_table,
32       PrintDepsTable& deps_table,
33       c10::TypePrinter type_printer = nullptr,
34       bool enforce_importable = false);
35 
36   void printNamedType(const c10::NamedTypePtr& classType);
37   void printFunction(const Function& callee);
38   void printMethod(const Function& callee);
39 
40   std::string str() const;
41   const SourceRangeRecords& ranges() const;
42   uint64_t minVersion() const;
43 
44  private:
45   std::shared_ptr<PythonPrintImpl> pImpl;
46 };
47 
48 TORCH_API bool printerHasSpecialCaseFor(c10::Symbol sym);
49 
50 TORCH_API void jitModuleToPythonCodeAndConstants(
51     const Module& module,
52     ExtraFilesMap* jit_sources, // output
53     std::vector<IValue>* constants // output
54 );
55 
56 } // namespace torch::jit
57