xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/import_source.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/ivalue_inl.h>
4 #include <ATen/core/qualified_name.h>
5 #include <torch/csrc/jit/api/module.h>
6 #include <torch/csrc/jit/frontend/parser.h>
7 #include <torch/csrc/jit/frontend/resolver.h>
8 #include <torch/csrc/jit/frontend/script_type_parser.h>
9 #include <torch/csrc/jit/frontend/source_range.h>
10 #include <torch/csrc/jit/ir/ir.h>
11 #include <torch/csrc/jit/serialization/export.h>
12 #include <torch/custom_class.h>
13 #include <functional>
14 #include <memory>
15 #include <optional>
16 #include <string>
17 #include <vector>
18 
19 namespace torch::jit {
20 
// Callback type used to obtain a Source on demand: it is invoked with a
// string and returns a std::shared_ptr<Source>.
// NOTE(review): the string is presumably the qualifier/name of the source to
// load, and a null result presumably means "no such source" -- neither is
// established by this header; confirm against the callers.
21 using SourceLoader = std::function<std::shared_ptr<Source>(const std::string&)>;
22 
// Implementation object behind SourceImporter (which holds it through a
// std::shared_ptr). It is a Resolver -- it overrides resolveValue() and
// resolveType() -- and derives from std::enable_shared_from_this. Only
// declarations are visible here; the definitions live outside this header.
// NOTE(review): std::unordered_map / std::unordered_set are used below, but
// <unordered_map> / <unordered_set> are not in this header's visible include
// list; presumably they arrive transitively -- confirm.
23 struct SourceImporterImpl : public Resolver,
24                             std::enable_shared_from_this<SourceImporterImpl> {
     // Parameters:
     //   cu             - shared CompilationUnit (presumably stored in cu_).
     //   constant_table - raw pointer to a vector of IValues. No data member
     //                    of this exact type is declared below, so how it is
     //                    used and how long it must stay valid is not visible
     //                    here -- TODO confirm in the definition.
     //   source_loader  - callback used to obtain Source objects (presumably
     //                    stored in source_loader_).
     //   version        - a size_t; its relation to the optional version_
     //                    member is defined in the constructor body.
25   SourceImporterImpl(
26       std::shared_ptr<CompilationUnit> cu,
27       const std::vector<at::IValue>* constant_table,
28       SourceLoader source_loader,
29       size_t version);
     // Look up a type / function by qualified name. Behavior when the name is
     // unknown (e.g. a null return) is not visible from this header.
30   TypePtr findNamedType(const QualifiedName& name);
31   Function* findFunction(const QualifiedName& name);
     // NOTE(review): judging by the name and the loaded_sources_ member, this
     // presumably parses the source for `qualifier` at most once -- confirm.
32   void parseSourceIfNeeded(const std::string& qualifier);
     // Legacy entry point taking a module and a source; SourceImporter exposes
     // a method with the same name and signature.
33   void LEGACY_import_methods(
34       const Module& mod,
35       const std::shared_ptr<Source>& src);
36 
     // Resolver interface overrides.
37   std::shared_ptr<SugaredValue> resolveValue(
38       const std::string& name,
39       GraphFunction& m,
40       const SourceRange& loc) override;
41   TypePtr resolveType(const std::string& name, const SourceRange& loc) override;
42 
43  private:
     // Import helpers for the individual kinds of definitions. Each takes the
     // parsed tree node (Def / ClassDef) plus the name it is imported under.
44   void importFunction(const std::string& qualifier, const Def& def);
45   void importNamedType(const std::string& qualifier, const ClassDef& class_def);
     // Returns an optional Assign for the given class name and assignment.
     // NOTE(review): presumably nullopt means "drop this assignment" and a
     // value means "use this (possibly rewritten) assignment" -- confirm.
46   std::optional<Assign> attributeAssignmentSpecialHandlingHack(
47       const QualifiedName& qualified_classname,
48       const Assign& assign);
49   void importClass(
50       const QualifiedName& qualified_classname,
51       const ClassDef& class_def,
52       bool is_module);
53   void importEnum(
54       const QualifiedName& qualified_name,
55       const ClassDef& enum_def);
56   void importNamedTuple(
57       const QualifiedName& qualified_name,
58       const ClassDef& named_tuple_def);
59 
     // Lexer-driven helpers; both take the Lexer by mutable reference.
60   void parsePossibleVersionNumber(Lexer& L);
61 
62   void parseImports(Lexer& L);
63 
     // Shared compilation unit.
64   std::shared_ptr<CompilationUnit> cu_;
     // Map from a string name to a SugaredValue (presumably the names that
     // resolveValue() can answer -- confirm).
65   std::unordered_map<std::string, std::shared_ptr<SugaredValue>> env_;
66   SourceLoader source_loader_;
     // Empty (std::nullopt) by default.
67   std::optional<size_t> version_ = std::nullopt;
     // Set of strings; presumably the qualifiers of sources already loaded.
68   std::unordered_set<std::string> loaded_sources_;
69   // named types and functions loaded from a file but not yet defined because
70   // their type has not been requested yet.
71   std::unordered_map<QualifiedName, TreeRef> to_be_defined_;
72 };
73 
74 // Given a directory of serialized TorchScript sources, this class allows
75 // individual named types to be loaded from source. It resolves the
76 // dependencies between source files and parses the source files as
77 // necessary.
   //
   // The class is exported with TORCH_API and keeps all of its state in a
   // SourceImporterImpl held through a std::shared_ptr (pImpl).
78 
79 struct TORCH_API SourceImporter {
     // Takes the same four arguments as SourceImporterImpl's constructor.
     // NOTE(review): constant_table is a raw pointer; who owns it and how long
     // it must remain valid is not visible here -- confirm.
80   SourceImporter(
81       // The compilation unit that will own the imported source
82       std::shared_ptr<CompilationUnit> cu,
83       const std::vector<at::IValue>* constant_table,
84       SourceLoader loader,
85       size_t version);
86 
     // Returns the type registered under `name`. Declared const; the result
     // for an unknown name is not visible from this header.
87   TypePtr loadType(const QualifiedName& name) const;
88 
89   // Add the methods defined in `src` to the module `mod`, using SourceImporter
90   // to resolve any classes via loadType
91   void LEGACY_import_methods(
92       const Module& mod,
93       const std::shared_ptr<Source>& src);
     // Declared here and defined elsewhere.
94   ~SourceImporter();
95 
96  private:
     // Shared handle to the implementation object.
97   std::shared_ptr<SourceImporterImpl> pImpl;
98 };
99 
100 } // namespace torch::jit
101