1 #pragma once 2 3 #include <ATen/core/ivalue.h> 4 #include <caffe2/serialize/inline_container.h> 5 #include <torch/csrc/jit/api/module.h> 6 #include <torch/csrc/jit/ir/ir.h> 7 8 #include <istream> 9 10 namespace caffe2::serialize { 11 class ReadAdapterInterface; 12 } // namespace caffe2::serialize 13 14 namespace torch::jit { 15 16 class DeserializationStorageContext; 17 18 TORCH_API Module import_ir_module( 19 std::shared_ptr<CompilationUnit> cu, 20 const std::string& filename, 21 std::optional<c10::Device> device = std::nullopt, 22 bool load_debug_files = true); 23 24 TORCH_API Module import_ir_module( 25 std::shared_ptr<CompilationUnit> cu, 26 std::istream& in, 27 std::optional<c10::Device> device = std::nullopt, 28 bool load_debug_files = true); 29 30 TORCH_API Module import_ir_module( 31 std::shared_ptr<CompilationUnit> cu, 32 std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai, 33 std::optional<c10::Device> device = std::nullopt, 34 bool load_debug_files = true); 35 36 TORCH_API Module import_ir_module( 37 std::shared_ptr<CompilationUnit> cu, 38 const std::string& filename, 39 std::optional<c10::Device> device, 40 ExtraFilesMap& extra_files, 41 bool load_debug_files = true, 42 bool restore_shapes = false); 43 44 // For reading unified serialization format from torch.Package 45 TORCH_API Module import_ir_module( 46 std::shared_ptr<CompilationUnit> cu, 47 std::shared_ptr<caffe2::serialize::PyTorchStreamReader> reader, 48 std::shared_ptr<torch::jit::DeserializationStorageContext> storage_context, 49 std::optional<at::Device> device, 50 const std::string& ts_id /* torchscript identifier inside package */); 51 52 TORCH_API Module import_ir_module( 53 std::shared_ptr<CompilationUnit> cu, 54 std::istream& in, 55 std::optional<c10::Device> device, 56 ExtraFilesMap& extra_files, 57 bool load_debug_files = true, 58 bool restore_shapes = false); 59 60 TORCH_API Module import_ir_module( 61 std::shared_ptr<CompilationUnit> cu, 62 std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai, 63 std::optional<c10::Device> device, 64 ExtraFilesMap& extra_files, 65 bool load_debug_files = true); 66 67 TORCH_API Module import_ir_module( 68 std::shared_ptr<CompilationUnit> cu, 69 std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai, 70 std::optional<c10::Device> device, 71 ExtraFilesMap& extra_files, 72 bool load_debug_files = true); 73 74 /// Loads a serialized `Module` from the given `istream`. 75 /// 76 /// The istream must contain a serialized `Module`, exported via 77 /// `torch::jit::ExportModule` in C++. 78 TORCH_API Module load( 79 std::istream& in, 80 std::optional<c10::Device> device = std::nullopt, 81 bool load_debug_files = true); 82 83 TORCH_API Module load( 84 std::istream& in, 85 std::optional<c10::Device> device, 86 ExtraFilesMap& extra_files, 87 bool load_debug_files = true); 88 89 /// Loads a serialized `Module` from the given `filename`. 90 /// 91 /// The file stored at the location given in `filename` must contain a 92 /// serialized `Module`, exported either via `ScriptModule.save()` in 93 /// Python or `torch::jit::ExportModule` in C++. 94 TORCH_API Module load( 95 const std::string& filename, 96 std::optional<c10::Device> device = std::nullopt, 97 bool load_debug_files = true); 98 99 TORCH_API Module load( 100 const std::string& filename, 101 std::optional<c10::Device> device, 102 ExtraFilesMap& extra_files, 103 bool load_debug_files = true); 104 105 /// Loads a serialized `Module` from the given shared_ptr `rai`. 106 /// 107 /// The reader adapter, which is for customized input stream, must contain a 108 /// serialized `Module`, exported either via `ScriptModule.save()` in 109 /// Python or `torch::jit::ExportModule` in C++. 110 TORCH_API Module load( 111 std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai, 112 std::optional<c10::Device> device = std::nullopt, 113 bool load_debug_files = true); 114 115 TORCH_API Module load( 116 std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai, 117 std::optional<c10::Device> device, 118 ExtraFilesMap& extra_files, 119 bool load_debug_files = true); 120 121 TORCH_API Module jitModuleFromSourceAndConstants( 122 const IValue& ivalue, 123 const ExtraFilesMap& source, 124 const std::vector<IValue>& constants, 125 int32_t version); 126 127 TORCH_API Module parse_and_initialize_jit_module( 128 const std::shared_ptr<char>& data, 129 size_t size, 130 ExtraFilesMap& extra_files, 131 std::optional<at::Device> device = std::nullopt); 132 133 TORCH_API Module load_jit_module_from_file( 134 const std::string& filename, 135 ExtraFilesMap& extra_files, 136 std::optional<at::Device> device = std::nullopt); 137 138 TORCH_API Module load_jit_module_from_stream( 139 std::istream& in, 140 ExtraFilesMap& extra_files, 141 std::optional<at::Device> device = std::nullopt); 142 143 TORCH_API Module parse_and_initialize_jit_module( 144 const std::shared_ptr<char>& data, 145 size_t size, 146 ExtraFilesMap& extra_files, 147 std::optional<at::Device> device); 148 149 TORCH_API c10::intrusive_ptr<c10::ivalue::Object> ObjLoaderFunc( 150 const at::StrongTypePtr& type, 151 IValue input); 152 153 } // namespace torch::jit 154