xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/import.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/ivalue.h>
4 #include <caffe2/serialize/inline_container.h>
5 #include <torch/csrc/jit/api/module.h>
6 #include <torch/csrc/jit/ir/ir.h>
7 
8 #include <istream>
9 
namespace caffe2 {
namespace serialize {
// Forward declaration only: the declarations in this header take the reader
// adapter by smart pointer, so its complete type is not required here.
class ReadAdapterInterface;
} // namespace serialize
} // namespace caffe2
13 
14 namespace torch::jit {
15 
16 class DeserializationStorageContext;
17 
18 TORCH_API Module import_ir_module(
19     std::shared_ptr<CompilationUnit> cu,
20     const std::string& filename,
21     std::optional<c10::Device> device = std::nullopt,
22     bool load_debug_files = true);
23 
24 TORCH_API Module import_ir_module(
25     std::shared_ptr<CompilationUnit> cu,
26     std::istream& in,
27     std::optional<c10::Device> device = std::nullopt,
28     bool load_debug_files = true);
29 
30 TORCH_API Module import_ir_module(
31     std::shared_ptr<CompilationUnit> cu,
32     std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
33     std::optional<c10::Device> device = std::nullopt,
34     bool load_debug_files = true);
35 
36 TORCH_API Module import_ir_module(
37     std::shared_ptr<CompilationUnit> cu,
38     const std::string& filename,
39     std::optional<c10::Device> device,
40     ExtraFilesMap& extra_files,
41     bool load_debug_files = true,
42     bool restore_shapes = false);
43 
44 // For reading unified serialization format from torch.Package
45 TORCH_API Module import_ir_module(
46     std::shared_ptr<CompilationUnit> cu,
47     std::shared_ptr<caffe2::serialize::PyTorchStreamReader> reader,
48     std::shared_ptr<torch::jit::DeserializationStorageContext> storage_context,
49     std::optional<at::Device> device,
50     const std::string& ts_id /* torchscript identifier inside package */);
51 
52 TORCH_API Module import_ir_module(
53     std::shared_ptr<CompilationUnit> cu,
54     std::istream& in,
55     std::optional<c10::Device> device,
56     ExtraFilesMap& extra_files,
57     bool load_debug_files = true,
58     bool restore_shapes = false);
59 
60 TORCH_API Module import_ir_module(
61     std::shared_ptr<CompilationUnit> cu,
62     std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
63     std::optional<c10::Device> device,
64     ExtraFilesMap& extra_files,
65     bool load_debug_files = true);
66 
67 TORCH_API Module import_ir_module(
68     std::shared_ptr<CompilationUnit> cu,
69     std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai,
70     std::optional<c10::Device> device,
71     ExtraFilesMap& extra_files,
72     bool load_debug_files = true);
73 
74 /// Loads a serialized `Module` from the given `istream`.
75 ///
76 /// The istream must contain a serialized `Module`, exported via
77 /// `torch::jit::ExportModule` in C++.
78 TORCH_API Module load(
79     std::istream& in,
80     std::optional<c10::Device> device = std::nullopt,
81     bool load_debug_files = true);
82 
83 TORCH_API Module load(
84     std::istream& in,
85     std::optional<c10::Device> device,
86     ExtraFilesMap& extra_files,
87     bool load_debug_files = true);
88 
89 /// Loads a serialized `Module` from the given `filename`.
90 ///
91 /// The file stored at the location given in `filename` must contain a
92 /// serialized `Module`, exported either via `ScriptModule.save()` in
93 /// Python or `torch::jit::ExportModule` in C++.
94 TORCH_API Module load(
95     const std::string& filename,
96     std::optional<c10::Device> device = std::nullopt,
97     bool load_debug_files = true);
98 
99 TORCH_API Module load(
100     const std::string& filename,
101     std::optional<c10::Device> device,
102     ExtraFilesMap& extra_files,
103     bool load_debug_files = true);
104 
105 /// Loads a serialized `Module` from the given shared_ptr `rai`.
106 ///
107 /// The reader adapter, which is for customized input stream, must contain a
108 /// serialized `Module`, exported either via `ScriptModule.save()` in
109 /// Python or `torch::jit::ExportModule` in C++.
110 TORCH_API Module load(
111     std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai,
112     std::optional<c10::Device> device = std::nullopt,
113     bool load_debug_files = true);
114 
115 TORCH_API Module load(
116     std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai,
117     std::optional<c10::Device> device,
118     ExtraFilesMap& extra_files,
119     bool load_debug_files = true);
120 
121 TORCH_API Module jitModuleFromSourceAndConstants(
122     const IValue& ivalue,
123     const ExtraFilesMap& source,
124     const std::vector<IValue>& constants,
125     int32_t version);
126 
127 TORCH_API Module parse_and_initialize_jit_module(
128     const std::shared_ptr<char>& data,
129     size_t size,
130     ExtraFilesMap& extra_files,
131     std::optional<at::Device> device = std::nullopt);
132 
133 TORCH_API Module load_jit_module_from_file(
134     const std::string& filename,
135     ExtraFilesMap& extra_files,
136     std::optional<at::Device> device = std::nullopt);
137 
138 TORCH_API Module load_jit_module_from_stream(
139     std::istream& in,
140     ExtraFilesMap& extra_files,
141     std::optional<at::Device> device = std::nullopt);
142 
143 TORCH_API Module parse_and_initialize_jit_module(
144     const std::shared_ptr<char>& data,
145     size_t size,
146     ExtraFilesMap& extra_files,
147     std::optional<at::Device> device);
148 
149 TORCH_API c10::intrusive_ptr<c10::ivalue::Object> ObjLoaderFunc(
150     const at::StrongTypePtr& type,
151     IValue input);
152 
153 } // namespace torch::jit
154