xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/import.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <torch/csrc/jit/mobile/module.h>
3 #include <torch/csrc/jit/mobile/parse_operators.h>
4 
5 #include <istream>
6 #include <memory>
7 
8 #include <caffe2/serialize/file_adapter.h>
9 
10 namespace torch::jit {
11 using caffe2::serialize::ReadAdapterInterface;
12 using ExtraFilesMap = std::unordered_map<std::string, std::string>;
13 
14 constexpr const char* kArchiveNameBytecode = "bytecode";
15 constexpr const char* kArchiveNameConstants = "constants";
16 constexpr const char* kArchiveNameVersion = "version";
17 
18 // The family of methods below load a serialized Mobile Module
19 // into a mobile::Module object.
20 TORCH_API mobile::Module _load_for_mobile(
21     std::istream& in,
22     std::optional<at::Device> device,
23     ExtraFilesMap& extra_file,
24     uint64_t module_load_options = kDefaultMobileLoadOptions);
25 
26 TORCH_API mobile::Module _load_for_mobile(
27     const std::string& filename,
28     std::optional<at::Device> device,
29     ExtraFilesMap& extra_files);
30 
31 TORCH_API mobile::Module _load_for_mobile(
32     std::unique_ptr<ReadAdapterInterface> rai,
33     std::optional<c10::Device> device,
34     ExtraFilesMap& extra_files,
35     uint64_t module_load_options = kDefaultMobileLoadOptions);
36 
37 TORCH_API mobile::Module _load_for_mobile(
38     const std::string& filename,
39     std::optional<at::Device> device,
40     ExtraFilesMap& extra_files,
41     uint64_t module_load_options);
42 
43 TORCH_API mobile::Module _load_for_mobile(
44     std::istream& in,
45     std::optional<at::Device> device = std::nullopt);
46 
47 TORCH_API mobile::Module _load_for_mobile(
48     const std::string& filename,
49     std::optional<at::Device> device = std::nullopt);
50 
51 TORCH_API mobile::Module _load_for_mobile(
52     std::unique_ptr<ReadAdapterInterface> rai,
53     std::optional<c10::Device> device = std::nullopt);
54 
55 /**
56  * Load only the contents of the "extra/" files whose names are
57  * passed in the map (extra_files). Populate the corresponding values
58  * with the contents of those files. Do not attempt to load the entire
59  * model, and stop once the extra files have been extracted.
60  *
61  * This API is needed to be able to load GPU models on linux CPU
62  * machines and extract only the extra files so that we can inspect
63  * the metadata that was added to the .ptl archive when it was
64  * generated.
65  *
66  */
67 void _load_extra_only_for_mobile(
68     const std::string& filename,
69     std::optional<at::Device> device,
70     ExtraFilesMap& extra_files);
71 
72 // Currently used by both mobile/import.cpp and model_compatibility.cpp.
73 // Should be removed after model_compatibility.cpp start using simplified
74 // version type_resolver and obj_loader.
75 at::TypePtr resolveTypeNameMobile(
76     const c10::QualifiedName& qn,
77     const std::shared_ptr<CompilationUnit>& compilation_unit);
78 c10::StrongTypePtr typeResolverMobile(
79     const c10::QualifiedName& qn,
80     const std::shared_ptr<CompilationUnit>& compilation_unit);
81 c10::intrusive_ptr<c10::ivalue::Object> objLoaderMobile(
82     const at::StrongTypePtr& type,
83     const at::IValue& input,
84     mobile::CompilationUnit& mobile_compilation_unit);
85 
86 // Given a reader, which has access to a model file,
87 // return true if there exists tensors in `bytecode` archive
88 bool isTensorInBytecodeArchive(
89     caffe2::serialize::PyTorchStreamReader& stream_reader);
90 
91 namespace mobile {
92 
93 /**
94  * Given a torch::jit::mobile::Module, return a set of operator names
95  * (with overload name) that are used by any method in this mobile
96  * Mobile. This method runs through the bytecode for all methods
97  * in the specified model (module), and extracts all the root
98  * operator names. Root operators are operators that are called
99  * directly by the model (as opposed to non-root operators, which
100  * may be called transitively by the root operators).
101  *
102  */
103 TORCH_API std::set<std::string> _export_operator_list(
104     torch::jit::mobile::Module& module);
105 
106 } // namespace mobile
107 
108 } // namespace torch::jit
109