#pragma once

#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/parse_operators.h>

#include <istream>
#include <memory>

#include <caffe2/serialize/file_adapter.h>

namespace torch::jit {
using caffe2::serialize::ReadAdapterInterface;
using ExtraFilesMap = std::unordered_map<std::string, std::string>;

constexpr const char* kArchiveNameBytecode = "bytecode";
constexpr const char* kArchiveNameConstants = "constants";
constexpr const char* kArchiveNameVersion = "version";

// The family of methods below load a serialized Mobile Module
// into a mobile::Module object.
TORCH_API mobile::Module _load_for_mobile(
    std::istream& in,
    std::optional<at::Device> device,
    ExtraFilesMap& extra_files,
    uint64_t module_load_options = kDefaultMobileLoadOptions);

TORCH_API mobile::Module _load_for_mobile(
    const std::string& filename,
    std::optional<at::Device> device,
    ExtraFilesMap& extra_files);

TORCH_API mobile::Module _load_for_mobile(
    std::unique_ptr<ReadAdapterInterface> rai,
    std::optional<c10::Device> device,
    ExtraFilesMap& extra_files,
    uint64_t module_load_options = kDefaultMobileLoadOptions);

TORCH_API mobile::Module _load_for_mobile(
    const std::string& filename,
    std::optional<at::Device> device,
    ExtraFilesMap& extra_files,
    uint64_t module_load_options);

TORCH_API mobile::Module _load_for_mobile(
    std::istream& in,
    std::optional<at::Device> device = std::nullopt);

TORCH_API mobile::Module _load_for_mobile(
    const std::string& filename,
    std::optional<at::Device> device = std::nullopt);

TORCH_API mobile::Module _load_for_mobile(
    std::unique_ptr<ReadAdapterInterface> rai,
    std::optional<c10::Device> device = std::nullopt);
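
// Illustrative usage sketch (not part of this header; the file name
// "model.ptl" and the extra-file key "metadata.json" are placeholders chosen
// for the example). It shows one way to use the filename overload that takes
// an ExtraFilesMap: the requested keys are filled in with the corresponding
// file contents when the call returns.
//
//   torch::jit::ExtraFilesMap extra_files{{"metadata.json", ""}};
//   torch::jit::mobile::Module m = torch::jit::_load_for_mobile(
//       "model.ptl", std::nullopt, extra_files);
//   // extra_files["metadata.json"] now holds the contents of
//   // extra/metadata.json from the archive; m.forward(...) runs the model.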

/**
 * Load only the contents of the "extra/" files whose names are
 * passed in the map (extra_files). Populate the corresponding values
 * with the contents of those files. Do not attempt to load the entire
 * model, and stop once the extra files have been extracted.
 *
 * This API is needed to load GPU models on Linux CPU machines and to
 * extract only the extra files, so that we can inspect the metadata
 * that was added to the .ptl archive when it was generated.
 */
void _load_extra_only_for_mobile(
    const std::string& filename,
    std::optional<at::Device> device,
    ExtraFilesMap& extra_files);

// Currently used by both mobile/import.cpp and model_compatibility.cpp.
// Should be removed once model_compatibility.cpp starts using the simplified
// type_resolver and obj_loader.
at::TypePtr resolveTypeNameMobile(
    const c10::QualifiedName& qn,
    const std::shared_ptr<CompilationUnit>& compilation_unit);
c10::StrongTypePtr typeResolverMobile(
    const c10::QualifiedName& qn,
    const std::shared_ptr<CompilationUnit>& compilation_unit);
c10::intrusive_ptr<c10::ivalue::Object> objLoaderMobile(
    const at::StrongTypePtr& type,
    const at::IValue& input,
    mobile::CompilationUnit& mobile_compilation_unit);

// Given a reader with access to a model file, return true if there are
// tensors in the `bytecode` archive.
bool isTensorInBytecodeArchive(
    caffe2::serialize::PyTorchStreamReader& stream_reader);

namespace mobile {

/**
 * Given a torch::jit::mobile::Module, return a set of operator names
 * (with overload name) that are used by any method in this mobile
 * Module. This method runs through the bytecode for all methods in the
 * specified model (module) and extracts all the root operator names.
 * Root operators are operators that are called directly by the model
 * (as opposed to non-root operators, which may be called transitively
 * by the root operators).
 */
TORCH_API std::set<std::string> _export_operator_list(
    torch::jit::mobile::Module& module);

} // namespace mobile

} // namespace torch::jit
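
// Illustrative usage sketch for torch::jit::mobile::_export_operator_list
// ("model.ptl" and the printed operator name are placeholders, not guaranteed
// output): load a module, then collect the root operator names it calls.
//
//   auto m = torch::jit::_load_for_mobile("model.ptl");
//   std::set<std::string> root_ops =
//       torch::jit::mobile::_export_operator_list(m);
//   for (const auto& op : root_ops) {
//     std::cout << op << '\n'; // e.g. "aten::conv2d"
//   }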