1 #pragma once 2 3 #include <ATen/core/ivalue_inl.h> 4 #include <ATen/core/qualified_name.h> 5 #include <torch/csrc/jit/api/module.h> 6 #include <torch/csrc/jit/frontend/parser.h> 7 #include <torch/csrc/jit/frontend/resolver.h> 8 #include <torch/csrc/jit/frontend/script_type_parser.h> 9 #include <torch/csrc/jit/frontend/source_range.h> 10 #include <torch/csrc/jit/ir/ir.h> 11 #include <torch/csrc/jit/serialization/export.h> 12 #include <torch/custom_class.h> 13 #include <functional> 14 #include <memory> 15 #include <optional> 16 #include <string> 17 #include <vector> 18 19 namespace torch::jit { 20 21 using SourceLoader = std::function<std::shared_ptr<Source>(const std::string&)>; 22 23 struct SourceImporterImpl : public Resolver, 24 std::enable_shared_from_this<SourceImporterImpl> { 25 SourceImporterImpl( 26 std::shared_ptr<CompilationUnit> cu, 27 const std::vector<at::IValue>* constant_table, 28 SourceLoader source_loader, 29 size_t version); 30 TypePtr findNamedType(const QualifiedName& name); 31 Function* findFunction(const QualifiedName& name); 32 void parseSourceIfNeeded(const std::string& qualifier); 33 void LEGACY_import_methods( 34 const Module& mod, 35 const std::shared_ptr<Source>& src); 36 37 std::shared_ptr<SugaredValue> resolveValue( 38 const std::string& name, 39 GraphFunction& m, 40 const SourceRange& loc) override; 41 TypePtr resolveType(const std::string& name, const SourceRange& loc) override; 42 43 private: 44 void importFunction(const std::string& qualifier, const Def& def); 45 void importNamedType(const std::string& qualifier, const ClassDef& class_def); 46 std::optional<Assign> attributeAssignmentSpecialHandlingHack( 47 const QualifiedName& qualified_classname, 48 const Assign& assign); 49 void importClass( 50 const QualifiedName& qualified_classname, 51 const ClassDef& class_def, 52 bool is_module); 53 void importEnum( 54 const QualifiedName& qualified_name, 55 const ClassDef& enum_def); 56 void importNamedTuple( 57 const QualifiedName& qualified_name, 58 const ClassDef& named_tuple_def); 59 60 void parsePossibleVersionNumber(Lexer& L); 61 62 void parseImports(Lexer& L); 63 64 std::shared_ptr<CompilationUnit> cu_; 65 std::unordered_map<std::string, std::shared_ptr<SugaredValue>> env_; 66 SourceLoader source_loader_; 67 std::optional<size_t> version_ = std::nullopt; 68 std::unordered_set<std::string> loaded_sources_; 69 // named types and functions loaded from a file but not yet defined because 70 // their type has not been requested yet. 71 std::unordered_map<QualifiedName, TreeRef> to_be_defined_; 72 }; 73 74 // Given a directory of serialized TorchScript sources, 75 // This class allows the loading of individual named types in source. 76 // Resolves the dependencies between source files and parses 77 // the source files as necessary. 78 79 struct TORCH_API SourceImporter { 80 SourceImporter( 81 // The compilation unit that will own the imported source 82 std::shared_ptr<CompilationUnit> cu, 83 const std::vector<at::IValue>* constant_table, 84 SourceLoader loader, 85 size_t version); 86 87 TypePtr loadType(const QualifiedName& name) const; 88 89 // Add the methods defined in `src` to the module `mod`, using SourceImporter 90 // to resolve any classes via loadType 91 void LEGACY_import_methods( 92 const Module& mod, 93 const std::shared_ptr<Source>& src); 94 ~SourceImporter(); 95 96 private: 97 std::shared_ptr<SourceImporterImpl> pImpl; 98 }; 99 100 } // namespace torch::jit 101