1 #pragma once 2 #include <torch/csrc/Export.h> 3 #include <torch/csrc/jit/api/module.h> 4 #include <torch/csrc/jit/ir/ir.h> 5 #include <vector> 6 7 namespace torch::jit { 8 9 struct Method; 10 struct Module; 11 struct PythonPrintImpl; 12 13 struct PrintDepsTable { 14 void add(const c10::NamedTypePtr& type); 15 sizePrintDepsTable16 size_t size() const { 17 return table_.size(); 18 } 19 20 const c10::NamedTypePtr& operator[](size_t index) const { 21 return table_[index]; 22 } 23 24 private: 25 std::vector<c10::NamedTypePtr> table_; 26 std::unordered_set<c10::NamedTypePtr> non_unique_; 27 }; 28 29 struct TORCH_API PythonPrint { 30 PythonPrint( 31 std::vector<IValue>& constant_table, 32 PrintDepsTable& deps_table, 33 c10::TypePrinter type_printer = nullptr, 34 bool enforce_importable = false); 35 36 void printNamedType(const c10::NamedTypePtr& classType); 37 void printFunction(const Function& callee); 38 void printMethod(const Function& callee); 39 40 std::string str() const; 41 const SourceRangeRecords& ranges() const; 42 uint64_t minVersion() const; 43 44 private: 45 std::shared_ptr<PythonPrintImpl> pImpl; 46 }; 47 48 TORCH_API bool printerHasSpecialCaseFor(c10::Symbol sym); 49 50 TORCH_API void jitModuleToPythonCodeAndConstants( 51 const Module& module, 52 ExtraFilesMap* jit_sources, // output 53 std::vector<IValue>* constants // output 54 ); 55 56 } // namespace torch::jit 57