xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/serialize/input-archive.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <c10/core/Device.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/types.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iosfwd>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
13 
// Lightweight declarations of the types named by `InputArchive`. The complete
// definitions come from <torch/types.h> and <torch/csrc/jit/api/module.h>.
namespace at {
class Tensor;
} // namespace at

// Make `at::Tensor` available under the `torch` namespace.
namespace torch {
using at::Tensor;
} // namespace torch

namespace torch::jit {
struct Module;
} // namespace torch::jit
24 
25 namespace torch {
26 namespace serialize {
27 
28 /// A recursive representation of tensors that can be deserialized from a file
29 /// or stream. In most cases, users should not have to interact with this class,
30 /// and should instead use `torch::load`.
31 class TORCH_API InputArchive final {
32  public:
33   /// Default-constructs the `InputArchive`.
34   InputArchive();
35 
36   // Move is allowed.
37   InputArchive(InputArchive&&) = default;
38   InputArchive& operator=(InputArchive&&) = default;
39 
40   // Copy is disallowed.
41   InputArchive(InputArchive&) = delete;
42   InputArchive& operator=(InputArchive&) = delete;
43 
44   ~InputArchive() = default;
45 
46   /// Reads an `IValue` associated with a given `key`.
47   void read(const std::string& key, c10::IValue& ivalue);
48 
49   /// Reads an `IValue` associated with a given `key`. If there is no `IValue`
50   /// associated with the `key`, this returns false, otherwise it returns true.
51   bool try_read(const std::string& key, c10::IValue& ivalue);
52 
53   /// Reads a `tensor` associated with a given `key`. If there is no `tensor`
54   /// associated with the `key`, this returns false, otherwise it returns true.
55   /// If the tensor is expected to be a buffer (not differentiable), `is_buffer`
56   /// must be `true`.
57   bool try_read(const std::string& key, Tensor& tensor, bool is_buffer = false);
58 
59   /// Reads a `tensor` associated with a given `key`.
60   /// If the tensor is expected to be a buffer (not differentiable), `is_buffer`
61   /// must be `true`.
62   void read(const std::string& key, Tensor& tensor, bool is_buffer = false);
63 
64   /// Reads a `InputArchive` associated with a given `key`. If there is no
65   /// `InputArchive` associated with the `key`, this returns false, otherwise
66   /// it returns true.
67   bool try_read(const std::string& key, InputArchive& archive);
68 
69   /// Reads an `InputArchive` associated with a given `key`.
70   /// The archive can thereafter be used for further deserialization of the
71   /// nested data.
72   void read(const std::string& key, InputArchive& archive);
73 
74   /// Loads the `InputArchive` from a serialized representation stored in the
75   /// file at `filename`. Storage are remapped using device option. If device
76   /// is not specified, the module is loaded to the original device.
77   void load_from(
78       const std::string& filename,
79       std::optional<torch::Device> device = std::nullopt);
80 
81   /// Loads the `InputArchive` from a serialized representation stored in the
82   /// given `stream`. Storage are remapped using device option. If device
83   /// is not specified, the module is loaded to the original device.
84   void load_from(
85       std::istream& stream,
86       std::optional<torch::Device> device = std::nullopt);
87 
88   // Loads given the specified flat array.
89   void load_from(
90       const char* data,
91       size_t size,
92       std::optional<torch::Device> device = std::nullopt);
93 
94   // Loads given the specified read and size functions.
95   void load_from(
96       const std::function<size_t(uint64_t pos, void* buf, size_t nbytes)>&
97           read_func,
98       const std::function<size_t(void)>& size_func,
99       std::optional<torch::Device> device = std::nullopt);
100 
101   // Returns the vector of keys in the input archive.
102   std::vector<std::string> keys();
103 
104   /// Forwards all arguments to `read()`.
105   /// Useful for generic code that can be re-used for both `InputArchive` and
106   /// `OutputArchive` (where `operator()` forwards to `write()`).
107   template <typename... Ts>
operator()108   void operator()(Ts&&... ts) {
109     read(std::forward<Ts>(ts)...);
110   }
111 
112  private:
113   jit::Module module_;
114   std::string hierarchy_prefix_;
115 };
116 } // namespace serialize
117 } // namespace torch
118