1 #pragma once 2 3 #include <ATen/core/ivalue.h> 4 #include <c10/util/ArrayRef.h> 5 #include <caffe2/serialize/inline_container.h> 6 #include <torch/csrc/Export.h> 7 #include <torch/csrc/jit/serialization/pickler.h> 8 #include <torch/csrc/jit/serialization/unpickler.h> 9 10 namespace torch::jit { 11 12 /// Pickle an IValue by calling a function to handle writing the data. 13 /// 14 /// `writer` is a function that takes in a pointer to a chunk of memory and its 15 /// size and consumes it. 16 /// 17 /// See `jit::pickle` for more details. 18 TORCH_API void pickle( 19 std::function<void(const char* data_start, size_t data_len)> writer, 20 const IValue& ivalue, 21 std::vector<at::Tensor>* tensor_table = nullptr); 22 23 /// Save a `torch::IValue` in a format compatible with Python's `pickle` module 24 /// 25 /// If present, `tensor_table` is a pointer to a table in which tensors that 26 /// are contained within `ivalue` are stored, and the bytes returned by the 27 /// pickler will only include references to these tensors in the table. This can 28 /// be used to keep the binary blob size small. 29 /// If not provided, tensors are stored in the same byte stream as the pickle 30 /// data, similar to `torch.save()` in eager Python. 31 /// 32 /// Pickled values can be loaded in Python and C++: 33 /// \rst 34 /// .. code-block:: cpp 35 /// 36 /// torch::IValue float_value(2.3); 37 /// 38 /// // TODO: when tensors are stored in the pickle, delete this 39 /// std::vector<at::Tensor> tensor_table; 40 /// auto data = torch::jit::pickle(float_value, &tensor_table); 41 /// 42 /// std::vector<torch::IValue> ivalues = 43 /// torch::jit::unpickle(data.data(), data.size()); 44 /// 45 /// .. code-block:: python 46 /// 47 /// values = torch.load('data.pkl') 48 /// print(values) 49 /// 50 /// \endrst 51 TORCH_API std::vector<char> pickle( 52 const IValue& ivalue, 53 std::vector<at::Tensor>* tensor_table = nullptr); 54 55 /// Save a `torch::IValue` in a format that can be loaded by both 56 /// `torch::pickle_load` in C++ and `torch.load` in Python. 57 TORCH_API std::vector<char> pickle_save(const IValue& ivalue); 58 59 /// Deserialize a `torch::IValue` from bytes produced by either 60 /// `torch::pickle_save` in C++ or `torch.save` in Python 61 TORCH_API IValue pickle_load(const std::vector<char>& data); 62 63 /// Deserialize a `torch::IValue` from bytes produced by either 64 /// `torch::pickle_save` in C++ or `torch.save` in Python with custom object. 65 TORCH_API IValue pickle_load_obj(std::string_view data); 66 67 /// `reader` is a function that takes in a size to read from some pickled 68 /// binary. `reader` should remember where it last read, and return 69 /// the number of bytes read. 70 /// See `torch::pickle` for details. 71 /// type_resolver is used to resolve any JIT type based on type str 72 TORCH_API IValue unpickle( 73 std::function<size_t(char*, size_t)> reader, 74 TypeResolver type_resolver, 75 c10::ArrayRef<at::Tensor> tensor_table, 76 c10::TypePtr (*type_parser)(const std::string&) = 77 Unpickler::defaultTypeParser, 78 ObjLoader obj_loader = nullptr); 79 80 /// Decode a chunk of memory containing pickled data into its `torch::IValue`s. 81 /// 82 /// If any `torch::IValue`s in the pickled data are `Object`s, then a 83 /// `class_resolver` function must be provided. 84 /// 85 /// See `torch::pickle` for details. 86 TORCH_API IValue unpickle( 87 const char* data, 88 size_t size, 89 TypeResolver type_resolver = nullptr, 90 c10::ArrayRef<at::Tensor> tensor_table = {}, 91 c10::TypePtr (*type_parser)(const std::string&) = 92 Unpickler::defaultTypeParser); 93 94 /// Decode a chunk of memory containing pickled data into its `torch::IValue`s. 95 /// 96 /// If any `torch::IValue`s in the pickled data are `Object`s, then a 97 /// `class_resolver` function must be provided. 98 /// 99 /// See `torch::pickle` for details. 100 TORCH_API IValue unpickle( 101 const char* data, 102 size_t size, 103 ObjLoader obj_loader, 104 TypeResolver type_resolver = nullptr, 105 c10::ArrayRef<at::Tensor> tensor_table = {}, 106 c10::TypePtr (*type_parser)(const std::string&) = 107 Unpickler::defaultTypeParser); 108 109 #ifndef C10_MOBILE 110 class VectorReader : public caffe2::serialize::ReadAdapterInterface { 111 public: VectorReader(std::vector<char> data)112 VectorReader(std::vector<char> data) : data_(std::move(data)) {} 113 size()114 size_t size() const override { 115 return data_.size(); 116 } 117 118 size_t read(uint64_t pos, void* buf, size_t n, const char* what) 119 const override; 120 121 private: 122 std::vector<char> data_; 123 }; 124 125 class StringViewReader : public caffe2::serialize::ReadAdapterInterface { 126 public: StringViewReader(std::string_view data)127 StringViewReader(std::string_view data) : data_(data) {} 128 size()129 size_t size() const override { 130 return data_.size(); 131 } 132 133 size_t read(uint64_t pos, void* buf, size_t n, const char* what) 134 const override; 135 136 private: 137 std::string_view data_; 138 }; 139 #endif 140 } // namespace torch::jit 141