1 #pragma once 2 3 #include <c10/core/Device.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/csrc/jit/api/module.h> 6 #include <torch/types.h> 7 #include <optional> 8 9 #include <iosfwd> 10 #include <memory> 11 #include <string> 12 #include <utility> 13 14 namespace at { 15 class Tensor; 16 } // namespace at 17 18 namespace torch { 19 using at::Tensor; 20 namespace jit { 21 struct Module; 22 } // namespace jit 23 } // namespace torch 24 25 namespace torch { 26 namespace serialize { 27 28 /// A recursive representation of tensors that can be deserialized from a file 29 /// or stream. In most cases, users should not have to interact with this class, 30 /// and should instead use `torch::load`. 31 class TORCH_API InputArchive final { 32 public: 33 /// Default-constructs the `InputArchive`. 34 InputArchive(); 35 36 // Move is allowed. 37 InputArchive(InputArchive&&) = default; 38 InputArchive& operator=(InputArchive&&) = default; 39 40 // Copy is disallowed. 41 InputArchive(InputArchive&) = delete; 42 InputArchive& operator=(InputArchive&) = delete; 43 44 ~InputArchive() = default; 45 46 /// Reads an `IValue` associated with a given `key`. 47 void read(const std::string& key, c10::IValue& ivalue); 48 49 /// Reads an `IValue` associated with a given `key`. If there is no `IValue` 50 /// associated with the `key`, this returns false, otherwise it returns true. 51 bool try_read(const std::string& key, c10::IValue& ivalue); 52 53 /// Reads a `tensor` associated with a given `key`. If there is no `tensor` 54 /// associated with the `key`, this returns false, otherwise it returns true. 55 /// If the tensor is expected to be a buffer (not differentiable), `is_buffer` 56 /// must be `true`. 57 bool try_read(const std::string& key, Tensor& tensor, bool is_buffer = false); 58 59 /// Reads a `tensor` associated with a given `key`. 60 /// If the tensor is expected to be a buffer (not differentiable), `is_buffer` 61 /// must be `true`. 62 void read(const std::string& key, Tensor& tensor, bool is_buffer = false); 63 64 /// Reads a `InputArchive` associated with a given `key`. If there is no 65 /// `InputArchive` associated with the `key`, this returns false, otherwise 66 /// it returns true. 67 bool try_read(const std::string& key, InputArchive& archive); 68 69 /// Reads an `InputArchive` associated with a given `key`. 70 /// The archive can thereafter be used for further deserialization of the 71 /// nested data. 72 void read(const std::string& key, InputArchive& archive); 73 74 /// Loads the `InputArchive` from a serialized representation stored in the 75 /// file at `filename`. Storage are remapped using device option. If device 76 /// is not specified, the module is loaded to the original device. 77 void load_from( 78 const std::string& filename, 79 std::optional<torch::Device> device = std::nullopt); 80 81 /// Loads the `InputArchive` from a serialized representation stored in the 82 /// given `stream`. Storage are remapped using device option. If device 83 /// is not specified, the module is loaded to the original device. 84 void load_from( 85 std::istream& stream, 86 std::optional<torch::Device> device = std::nullopt); 87 88 // Loads given the specified flat array. 89 void load_from( 90 const char* data, 91 size_t size, 92 std::optional<torch::Device> device = std::nullopt); 93 94 // Loads given the specified read and size functions. 95 void load_from( 96 const std::function<size_t(uint64_t pos, void* buf, size_t nbytes)>& 97 read_func, 98 const std::function<size_t(void)>& size_func, 99 std::optional<torch::Device> device = std::nullopt); 100 101 // Returns the vector of keys in the input archive. 102 std::vector<std::string> keys(); 103 104 /// Forwards all arguments to `read()`. 105 /// Useful for generic code that can be re-used for both `InputArchive` and 106 /// `OutputArchive` (where `operator()` forwards to `write()`). 107 template <typename... Ts> operator()108 void operator()(Ts&&... ts) { 109 read(std::forward<Ts>(ts)...); 110 } 111 112 private: 113 jit::Module module_; 114 std::string hierarchy_prefix_; 115 }; 116 } // namespace serialize 117 } // namespace torch 118