// torch/csrc/jit/serialization/pickler.h
#pragma once

#include <ATen/core/qualified_name.h>
#include <array>
#include <cstring>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <ATen/Utils.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/FbcodeMaps.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/string_view.h>
#include <torch/csrc/Export.h>

namespace torch::jit {

// See Python's pickletools.py for a detailed description of each of these
// codes.
enum class PickleOpCode : char {
  MARK = '(',
  STOP = '.',
  POP = '0',
  POP_MARK = '1',
  DUP = '2',
  FLOAT = 'F',
  INT = 'I',
  BININT = 'J',
  BININT1 = 'K',
  LONG = 'L',
  BININT2 = 'M',
  NONE = 'N',
  PERSID = 'P',
  BINPERSID = 'Q',
  REDUCE = 'R',
  STRING = 'S',
  BINSTRING = 'T',
  SHORT_BINSTRING = 'U',
  // NB: Avoid using UNICODE as it is a macro in the Windows API
  UNICODE_ = 'V',
  BINUNICODE = 'X',
  APPEND = 'a',
  BUILD = 'b',
  GLOBAL = 'c',
  DICT = 'd',
  EMPTY_DICT = '}',
  APPENDS = 'e',
  GET = 'g',
  BINGET = 'h',
  INST = 'i',
  LONG_BINGET = 'j',
  LIST = 'l',
  EMPTY_LIST = ']',
  OBJ = 'o',
  PUT = 'p',
  BINPUT = 'q',
  LONG_BINPUT = 'r',
  SETITEM = 's',
  TUPLE = 't',
  EMPTY_TUPLE = ')',
  SETITEMS = 'u',
  BINFLOAT = 'G',

  // Protocol 2
  PROTO = char('\x80'),
  NEWOBJ = '\x81',
  EXT1 = '\x82',
  EXT2 = '\x83',
  EXT4 = '\x84',
  TUPLE1 = '\x85',
  TUPLE2 = '\x86',
  TUPLE3 = '\x87',
  NEWTRUE = '\x88',
  NEWFALSE = '\x89',
  LONG1 = '\x8a',
  LONG4 = '\x8b',

  // Protocol 3 (Python 3.x)
  BINBYTES = 'B',
  SHORT_BINBYTES = 'C',

  // Protocol 4
  SHORT_BINUNICODE = char('\x8c'),
  BINUNICODE8 = '\x8d',
  BINBYTES8 = '\x8e',
  EMPTY_SET = '\x8f',
  ADDITEMS = '\x90',
  FROZENSET = '\x91',
  NEWOBJ_EX = '\x92',
  STACK_GLOBAL = '\x93',
  MEMOIZE = '\x94',
  FRAME = '\x95'
};
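
// Worked example: `pickle.dumps(True, protocol=2)` produces the four-byte
// stream b'\x80\x02\x88.', which reads, opcode by opcode, as
//   PROTO 2  (\x80 \x02) -- declare pickle protocol 2
//   NEWTRUE  (\x88)      -- push True onto the unpickler's stack
//   STOP     ('.')       -- end of the pickle; top of stack is the result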

using ::c10::IValue;

struct WriteableTensorData {
  const char* data() const {
    return static_cast<const char*>(tensor_.storage().data());
  }
  size_t sizeInBytes() const {
    return size_;
  }
  size_t nbytes() const {
    return tensor_.storage().nbytes();
  }
  bool storageHasDeleter() const {
    return tensor_.storage().data_ptr().get_context() != nullptr;
  }

 private:
  friend TORCH_API WriteableTensorData
  getWriteableTensorData(const at::Tensor& tensor, bool to_cpu);
  at::Tensor tensor_;
  uint64_t size_;
};

void setTypeTags(bool state);
bool getTypeTags();

class TORCH_API Pickler {
  AT_DISALLOW_COPY_AND_ASSIGN(Pickler);

 public:
  Pickler(std::function<void(const char*, size_t)> writer)
      : Pickler(std::move(writer), nullptr, nullptr, nullptr) {}

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Pickler(
      std::function<void(const char*, size_t)> writer,
      std::vector<at::Tensor>* tensor_table,
      std::function<c10::QualifiedName(const c10::ClassTypePtr&)> type_renamer,
      std::vector<c10::ClassTypePtr>* memoized_class_types,
      std::function<std::string(const at::Tensor&)> get_tensor_id = nullptr,
      bool tag_aggregates = true)
      : writer_(std::move(writer)),
        tensor_table_(tensor_table),
        type_renamer_(std::move(type_renamer)),
        memoized_class_types_(memoized_class_types),
        get_tensor_id_(std::move(get_tensor_id)),
        tag_aggregates_(tag_aggregates) {}
  ~Pickler();

  // Push protocol onto the stack
  void protocol();

  // Push STOP PickleOpCode onto the stack
  void stop();

  void pushIValue(const IValue& ivalue);

  void startTuple();
  void endTuple();

  const std::vector<at::Tensor>& tensorData() {
    return tensor_data_;
  }

  void pushEmptyDict();
  void pushDict(const IValue& ivalue);
  void pushInt(int64_t value);
  void pushLong(const std::string& data);

 private:
  void pushIValueImpl(const IValue& ivalue);
  void startTypeTag();
  void endTypeTag(const IValue& value);
  void pushBool(bool value);
  void pushDouble(double value);
  void pushComplexDouble(const IValue& value);
  void pushGenericList(const IValue& ivalue);
  void pushIntList(const IValue& ivalue);
  void pushList(const IValue& ivalue);
  void pushTensor(const IValue& ivalue);
  void pushTensorReference(const IValue& ivalue);
  void pushLiteralTensor(const IValue& ivalue);
  void pushLiteralSparseTensor(const at::Tensor& tensor);
  void pushTuple(const IValue& ivalue);
  void pushString(const std::string& string);
  void pushDevice(const IValue& ivalue);
#ifdef USE_DISTRIBUTED
  void pushRRef(const IValue& ivalue);
#endif
  // unmemoized version
  void pushStringImpl(const std::string& string);
  void pushStorageOfTensor(const at::Tensor& tensor);

  void pushBinGet(uint32_t memo_id);
  void pushSpecializedList(
      const IValue& ivalue,
      const char* list_name,
      const std::function<void(const IValue&)>& item_pusher);
  void pushGlobal(c10::string_view module_name, c10::string_view class_name);
  // raw string data is appended directly to the byte stream
  void pushBytes(const std::string& string);
  void pushTensorData(const at::Tensor& tensor);

  // Add a BINPUT op and return the memoization id used
  size_t pushNextBinPut();

  const void* getPointer(const IValue& ivalue);

  // Caller checks that bufferPos_ > 0
  void flushNonEmpty() {
    writer_(buffer_.data(), bufferPos_);
    bufferPos_ = 0;
  }

  void flush() {
    if (bufferPos_ != 0) {
      flushNonEmpty();
    }
  }


  // These convert values to bytes and add them to the stack (NB: since T is
  // to the left of a '::', its type cannot be deduced by the compiler, so one
  // must explicitly instantiate the template, i.e. push<int>(int) works,
  // push(int) does not).
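  // For example, push<uint16_t>(memo_id) appends the two raw bytes of memo_id
  // to the buffer, while an unqualified call push(memo_id) fails to compile
  // because T cannot be deduced through std::common_type_t.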
  static CONSTEXPR_EXCEPT_WIN_CUDA size_t kBufferSize = 256;
  template <typename T>
  void push(std::common_type_t<T> value) {
    const char* begin = reinterpret_cast<const char*>(&value);
    if (bufferPos_ + sizeof(T) > buffer_.size()) {
      flushNonEmpty();
    }
    static_assert(sizeof(T) <= kBufferSize, "Buffer size assumption");
    memcpy(buffer_.data() + bufferPos_, begin, sizeof(T));
    bufferPos_ += sizeof(T);
  }

  // Stream to write binary data to
  // Code shouldn't call writer_ directly without first flushing.
  std::function<void(const char*, size_t)> writer_;

  // Buffer to avoid calling writer_ on a per-byte basis.
  std::array<char, kBufferSize> buffer_;
  size_t bufferPos_{0};

  // Stack of opcodes/data
  std::vector<char> stack_;

  // External table of tensors to serialize. If this is missing, then tensors
  // are serialized directly into the pickle.
  std::vector<at::Tensor>* tensor_table_;

  // TODO: only use this if necessary (add a pass to find all shared ivalues,
  // and only memoize those)
  uint32_t memo_id_ = 0;

  // Memoization of IValues that have been written (the index in the table is
  // used for BINPUT opcodes) to enable shared references.
  c10::FastMap<const void*, uint32_t> memoized_ivalue_map_;

  // Because we de-dup IValues based on their raw pointer address in the above
  // map, we need to keep all the memoized values alive during the pickle.
  // Otherwise, it is possible that a raw address gets reused for another
  // object, and we would alias it to the old object at that address.
  std::vector<IValue> memoized_ivalues_;

  std::function<c10::QualifiedName(const c10::ClassTypePtr&)> type_renamer_;

  // List of all the types written by this pickler, gathered from the IValues
  // it serialized.
  std::vector<c10::ClassTypePtr>* memoized_class_types_;

  // Function to grab the next id_name for a tensor storage; the function is
  // responsible for returning unique ids.
  std::function<std::string(const at::Tensor&)> get_tensor_id_;

  // List of tensor storages to serialize in the same binary as the pickle
  // data. Similar to ivalues, they are memoized using BINPUT.
  std::vector<at::Tensor> tensor_data_;
  c10::FastMap<const void*, uint32_t> memoized_storage_map_;

  c10::FastMap<std::string, uint32_t> memoized_globals_map_;
  c10::FastMap<std::string, uint32_t> memoized_strings_map_;
  c10::FastMap<std::string, uint32_t> memoized_devices_map_;
  // When true, List and Dict objects will be wrapped in a
  // torch.jit._pickle.restore_type_tag call to correctly set the dynamic
  // TorchScript type for the object. When true, whatever unpickles the data
  // must have torch installed.
  bool tag_aggregates_;
};
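
// A minimal usage sketch (illustrative only; `pickleToString` is a
// hypothetical helper, not part of the serialization API): serialize one
// IValue to an in-memory string by letting the Pickler buffer writes and
// flush them through the writer callback.
inline std::string pickleToString(const IValue& ivalue) {
  std::string out;
  Pickler pickler(
      [&out](const char* data, size_t size) { out.append(data, size); });
  pickler.protocol(); // emits the PROTO opcode
  pickler.pushIValue(ivalue);
  pickler.stop(); // emits STOP and flushes the buffer
  return out;
}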

// Returns a (tensor, record_size) pair for a tensor, converting it to a CPU
// tensor if it was CUDA and to_cpu is true.
TORCH_API WriteableTensorData
getWriteableTensorData(const at::Tensor& tensor, bool to_cpu = true);
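
// Usage sketch (the `writeRecord` sink and `name` below are hypothetical):
// grab the contiguous CPU bytes of a tensor before writing them out.
//
//   WriteableTensorData wtd = getWriteableTensorData(tensor); // to_cpu = true
//   writeRecord(name, wtd.data(), wtd.sizeInBytes());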

// Return the value of the tensor's storage pointer.
uint64_t getStorageKey(const at::Tensor& tensor);

// If the class has __getstate__/__setstate__, assert they have the right
// schema and return true; otherwise return false.
bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls);
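
// Sketch of the schema this check enforces, in TorchScript pseudocode (T
// stands for any single pickleable type; this is an illustration, not a
// definition from this header):
//
//   def __getstate__(self) -> T
//   def __setstate__(self, state: T) -> None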

// Declare BackendMeta serialization and deserialization function pointer
// types.
using BackendMetaPtr = std::function<
    void(const at::Tensor&, std::unordered_map<std::string, bool>&)>;

// An allowlist of device types; currently only PrivateUse1 is available.
inline std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist() {
  static std::unordered_set<c10::DeviceType> DeviceTypeAllowlist{
      c10::DeviceType::PrivateUse1};
  return DeviceTypeAllowlist;
}

// Dynamically obtain the serialization function pairs
// that the corresponding backend requires.
inline std::array<
    std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
    at::COMPILE_TIME_MAX_DEVICE_TYPES>&
GetBackendMetaSerialization() {
  // The array that saves the function pointers for BackendMeta serialization.
  // The key is the DeviceType; the value is a std::pair where value.first is
  // the get function and value.second is the set function.
  static std::array<
      std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
      at::COMPILE_TIME_MAX_DEVICE_TYPES>
      BackendMetaSerialization;
  return BackendMetaSerialization;
}

// Register function pointers of Tensor BackendMetadata for serialization.
TORCH_API inline void TensorBackendMetaRegistry(
    c10::DeviceType t,
    const BackendMetaPtr& get_fptr,
    const BackendMetaPtr& set_fptr) {
  // Allowlist verification: only if the device type is in the allowlist do we
  // allow a serialization extension to be registered for BackendMeta data.
  const auto& DeviceTypeAllowlist = GetBackendMetaAllowlist();
  TORCH_CHECK(
      DeviceTypeAllowlist.find(t) != DeviceTypeAllowlist.end(),
      "It is not allowed to register the serialization method ",
      "of backendMeta data for ",
      t,
      ". If you have related serialization requirements, ",
      "please expand the allowlist");
  // Register the function pointers
  int device_type = static_cast<int>(t);
  auto& BackendMetaSerialization = GetBackendMetaSerialization();
  TORCH_CHECK(
      !BackendMetaSerialization[device_type].has_value(),
      "The tensor BackendMeta serialization function pointer for ",
      t,
      " has already been registered.");
  BackendMetaSerialization[device_type] =
      std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>(
          std::make_pair(get_fptr, set_fptr));
}
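
// Registration sketch (the hook names and the "custom_flag" key are made up
// for illustration): a PrivateUse1 backend would typically run this once
// during backend initialization.
//
//   void getCustomMeta(
//       const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
//     m["custom_flag"] = true; // record backend-specific state
//   }
//   void setCustomMeta(
//       const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
//     if (m.count("custom_flag")) { /* restore backend-specific state */ }
//   }
//   TensorBackendMetaRegistry(
//       c10::DeviceType::PrivateUse1, &getCustomMeta, &setCustomMeta);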

// Return a map of tensor metadata, including BackendMetaData, for
// serialization. For now, it only takes care of the `conj` and `neg` bits.
inline std::unordered_map<std::string, bool> getTensorMetadata(
    const at::Tensor& t) {
  // We don't support serializing `ZeroTensor` as it is not public
  // facing yet.
  TORCH_CHECK(
      !t._is_zerotensor(),
      "ZeroTensor is not serializable,",
      " please file an issue if required.");
  std::unordered_map<std::string, bool> metadata{};

  // Only add metadata if the value is not the default.
  if (t.is_conj()) {
    metadata["conj"] = true;
  }
  if (t.is_neg()) {
    metadata["neg"] = true;
  }
  // Only add BackendMetaData for a custom backend if the function pointer is
  // registered.
  int device_type = static_cast<int>(t.device().type());
  const auto& BackendMetaSerialization = GetBackendMetaSerialization();
  if (BackendMetaSerialization[device_type].has_value()) {
    // Pass the tensor and metadata map references as parameters to the custom
    // serialization function.
    BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().first;
    fptr(t, metadata);
  }
  return metadata;
}
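
// For example, for a conjugate view `auto t = x.conj();` of a complex tensor,
// the returned map is {{"conj", true}}; for a plain tensor with no registered
// backend hooks it is empty.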

// Set tensor metadata based on the map.
// Refer: getTensorMetadata
inline void setTensorMetadata(
    const at::Tensor& t,
    std::unordered_map<std::string, bool> metadata) {
  auto iter_end = metadata.end();
  auto iter_temp = metadata.find("conj");
  if (iter_temp != iter_end) {
    t._set_conj(true);
    metadata.erase(iter_temp);
  }
  iter_temp = metadata.find("neg");
  if (iter_temp != iter_end) {
    t._set_neg(true);
    metadata.erase(iter_temp);
  }
  // Only set BackendMetaData for a custom backend if the function pointer is
  // registered.
  int device_type = static_cast<int>(t.device().type());
  const auto& BackendMetaSerialization = GetBackendMetaSerialization();
  if (BackendMetaSerialization[device_type].has_value()) {
    // Pass the tensor and metadata map references as parameters to the custom
    // deserialization function.
    BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().second;
    fptr(t, metadata);
  }
}

// Set tensor metadata based on the map.
// NOTE: This overload is required by unpickler.cpp
inline void setTensorMetadata(
    const at::Tensor& t,
    const c10::Dict<c10::IValue, c10::IValue>& metadata_idict) {
  std::unordered_map<std::string, bool> metadata;
  for (auto& pair : metadata_idict) {
    auto key = *pair.key().toString();
    metadata[key] = pair.value().toBool();
  }
  setTensorMetadata(t, std::move(metadata));
}

} // namespace torch::jit