#pragma once

#include <ATen/core/qualified_name.h>
#include <string>
#include <utility>
#include <vector>

#include <ATen/Utils.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/FbcodeMaps.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/string_view.h>
#include <torch/csrc/Export.h>

namespace torch::jit {

// See Python's pickletools.py for a detailed description of each of these
// codes.
enum class PickleOpCode : char {
  MARK = '(',
  STOP = '.',
  POP = '0',
  POP_MARK = '1',
  DUP = '2',
  FLOAT = 'F',
  INT = 'I',
  BININT = 'J',
  BININT1 = 'K',
  LONG = 'L',
  BININT2 = 'M',
  NONE = 'N',
  PERSID = 'P',
  BINPERSID = 'Q',
  REDUCE = 'R',
  STRING = 'S',
  BINSTRING = 'T',
  SHORT_BINSTRING = 'U',
  // NB: Avoid using UNICODE as it is a macro in the Windows API
  UNICODE_ = 'V',
  BINUNICODE = 'X',
  APPEND = 'a',
  BUILD = 'b',
  GLOBAL = 'c',
  DICT = 'd',
  EMPTY_DICT = '}',
  APPENDS = 'e',
  GET = 'g',
  BINGET = 'h',
  INST = 'i',
  LONG_BINGET = 'j',
  LIST = 'l',
  EMPTY_LIST = ']',
  OBJ = 'o',
  PUT = 'p',
  BINPUT = 'q',
  LONG_BINPUT = 'r',
  SETITEM = 's',
  TUPLE = 't',
  EMPTY_TUPLE = ')',
  SETITEMS = 'u',
  BINFLOAT = 'G',

  // Protocol 2
  PROTO = char('\x80'),
  NEWOBJ = '\x81',
  EXT1 = '\x82',
  EXT2 = '\x83',
  EXT4 = '\x84',
  TUPLE1 = '\x85',
  TUPLE2 = '\x86',
  TUPLE3 = '\x87',
  NEWTRUE = '\x88',
  NEWFALSE = '\x89',
  LONG1 = '\x8a',
  LONG4 = '\x8b',

  // Protocol 3 (Python 3.x)
  BINBYTES = 'B',
  SHORT_BINBYTES = 'C',

  // Protocol 4
  SHORT_BINUNICODE = char('\x8c'),
  BINUNICODE8 = '\x8d',
  BINBYTES8 = '\x8e',
  EMPTY_SET = '\x8f',
  ADDITEMS = '\x90',
  FROZENSET = '\x91',
  NEWOBJ_EX = '\x92',
  STACK_GLOBAL = '\x93',
  MEMOIZE = '\x94',
  FRAME = '\x95'
};
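
// As a concrete example (assuming CPython's pickle module), the Python
// expression pickle.dumps(5, protocol=2) produces the byte stream
// b'\x80\x02K\x05.', i.e. PROTO 2, BININT1 5, STOP.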

using ::c10::IValue;

struct WriteableTensorData {
  const char* data() const {
    return static_cast<const char*>(tensor_.storage().data());
  }
  size_t sizeInBytes() const {
    return size_;
  }
  size_t nbytes() const {
    return tensor_.storage().nbytes();
  }
  bool storageHasDeleter() const {
    return tensor_.storage().data_ptr().get_context() != nullptr;
  }

 private:
  friend TORCH_API WriteableTensorData
  getWriteableTensorData(const at::Tensor& tensor, bool to_cpu);
  at::Tensor tensor_;
  uint64_t size_;
};

// Toggle and query the global flag that controls whether type tags are
// written during pickling.
void setTypeTags(bool state);
bool getTypeTags();

class TORCH_API Pickler {
  AT_DISALLOW_COPY_AND_ASSIGN(Pickler);

 public:
  Pickler(std::function<void(const char*, size_t)> writer)
      : Pickler(std::move(writer), nullptr, nullptr, nullptr) {}

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  Pickler(
      std::function<void(const char*, size_t)> writer,
      std::vector<at::Tensor>* tensor_table,
      std::function<c10::QualifiedName(const c10::ClassTypePtr&)> type_renamer,
      std::vector<c10::ClassTypePtr>* memoized_class_types,
      std::function<std::string(const at::Tensor&)> get_tensor_id = nullptr,
      bool tag_aggregates = true)
      : writer_(std::move(writer)),
        tensor_table_(tensor_table),
        type_renamer_(std::move(type_renamer)),
        memoized_class_types_(memoized_class_types),
        get_tensor_id_(std::move(get_tensor_id)),
        tag_aggregates_(tag_aggregates) {}
  ~Pickler();

  // Push protocol onto the stack
  void protocol();

  // Push STOP PickleOpCode onto the stack
  void stop();

  void pushIValue(const IValue& ivalue);

  void startTuple();
  void endTuple();

  const std::vector<at::Tensor>& tensorData() {
    return tensor_data_;
  }

  void pushEmptyDict();
  void pushDict(const IValue& ivalue);
  void pushInt(int64_t value);
  void pushLong(const std::string& data);

 private:
  void pushIValueImpl(const IValue& ivalue);
  void startTypeTag();
  void endTypeTag(const IValue& value);
  void pushBool(bool value);
  void pushDouble(double value);
  void pushComplexDouble(const IValue& value);
  void pushGenericList(const IValue& ivalue);
  void pushIntList(const IValue& ivalue);
  void pushList(const IValue& ivalue);
  void pushTensor(const IValue& ivalue);
  void pushTensorReference(const IValue& ivalue);
  void pushLiteralTensor(const IValue& ivalue);
  void pushLiteralSparseTensor(const at::Tensor& tensor);
  void pushTuple(const IValue& ivalue);
  void pushString(const std::string& string);
  void pushDevice(const IValue& ivalue);
#ifdef USE_DISTRIBUTED
  void pushRRef(const IValue& ivalue);
#endif
  // unmemoized version
  void pushStringImpl(const std::string& string);
  void pushStorageOfTensor(const at::Tensor& tensor);

  void pushBinGet(uint32_t memo_id);
  void pushSpecializedList(
      const IValue& ivalue,
      const char* list_name,
      const std::function<void(const IValue&)>& item_pusher);
  void pushGlobal(c10::string_view module_name, c10::string_view class_name);
  // raw string data is appended directly to the byte stream
  void pushBytes(const std::string& string);
  void pushTensorData(const at::Tensor& tensor);

  // Add a BINPUT op and return the memoization id used
  size_t pushNextBinPut();

  const void* getPointer(const IValue& ivalue);

  // Caller checks that bufferPos_ > 0
  void flushNonEmpty() {
    writer_(buffer_.data(), bufferPos_);
    bufferPos_ = 0;
  }

  void flush() {
    if (bufferPos_ != 0) {
      flushNonEmpty();
    }
  }

  // These convert values to bytes and add them to the stack (NB: since T is to
  // the left of a '::', its type cannot be deduced by the compiler so one must
  // explicitly instantiate the template, i.e. push<int>(int) works, push(int)
  // does not)
  static CONSTEXPR_EXCEPT_WIN_CUDA size_t kBufferSize = 256;
  template <typename T>
  void push(std::common_type_t<T> value) {
    const char* begin = reinterpret_cast<const char*>(&value);
    if (bufferPos_ + sizeof(T) > buffer_.size()) {
      flushNonEmpty();
    }
    static_assert(sizeof(T) <= kBufferSize, "Buffer size assumption");
    memcpy(buffer_.data() + bufferPos_, begin, sizeof(T));
    bufferPos_ += sizeof(T);
  }

  // Stream to write binary data to
  // Code shouldn't call writer_ directly without first flushing.
  std::function<void(const char*, size_t)> writer_;

  // Buffer to avoid calling writer_ on a per-byte basis.
  std::array<char, kBufferSize> buffer_;
  size_t bufferPos_{0};

  // Stack of opcodes/data
  std::vector<char> stack_;

  // External table of tensors to serialize. If this is missing, then tensors
  // are serialized directly into the pickle
  std::vector<at::Tensor>* tensor_table_;

  // TODO: only use this if necessary (add a pass to find all shared ivalues,
  // and only memoize those)
  uint32_t memo_id_ = 0;

  // Memoization of IValues that have been written (index in table is used for
  // BINPUT opcodes) to enable shared references
  c10::FastMap<const void*, uint32_t> memoized_ivalue_map_;

  // Because we de-dup IValues based on their raw pointer address in the above
  // map, we need to keep all the memoized values alive during the pickle.
  // Otherwise, it is possible that a raw address gets reused for another
  // object, and we would alias it to the old object at that address.
  std::vector<IValue> memoized_ivalues_;

  std::function<c10::QualifiedName(const c10::ClassTypePtr&)> type_renamer_;

  // List of all the types written, collected from the IValues that were
  // pickled.
  std::vector<c10::ClassTypePtr>* memoized_class_types_;

  // Function that returns the next id_name for a tensor storage; it is
  // responsible for returning unique ids.
  std::function<std::string(const at::Tensor&)> get_tensor_id_;

  // List of tensor storages to serialize in the same binary as the pickle
  // data. Similar to ivalues, they are memoized using BINPUT.
  std::vector<at::Tensor> tensor_data_;
  c10::FastMap<const void*, uint32_t> memoized_storage_map_;

  c10::FastMap<std::string, uint32_t> memoized_globals_map_;
  c10::FastMap<std::string, uint32_t> memoized_strings_map_;
  c10::FastMap<std::string, uint32_t> memoized_devices_map_;
  // When true, List and Dict objects will be wrapped in a
  // torch.jit._pickle.restore_type_tag call to correctly set the dynamic
  // TorchScript type for the object. This means that whatever unpickles the
  // result must have torch installed.
  bool tag_aggregates_;
};
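
// A minimal usage sketch (not part of the API surface): pickle a single
// IValue into an in-memory byte buffer via the writer callback.
//
//   std::vector<char> out;
//   Pickler pickler([&](const char* buf, size_t n) {
//     out.insert(out.end(), buf, buf + n);
//   });
//   pickler.protocol();
//   pickler.pushIValue(IValue(42));
//   pickler.stop();
//   // `out` now holds the pickle byte stream; tensors written inline (no
//   // external tensor table) are available via pickler.tensorData().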

// Returns a (tensor, record_size) pair for a tensor, converting it to a CPU
// tensor if it was a CUDA tensor and to_cpu is true.
TORCH_API WriteableTensorData
getWriteableTensorData(const at::Tensor& tensor, bool to_cpu = true);
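
// A usage sketch (assumes an ATen build where at::ones is available): grab
// the raw bytes of a tensor's storage for writing into a separate record.
//
//   at::Tensor t = at::ones({2, 2});
//   WriteableTensorData wtd = getWriteableTensorData(t);
//   // wtd.data() points to wtd.sizeInBytes() bytes that can be written out.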

// Returns the value of the tensor's storage pointer.
uint64_t getStorageKey(const at::Tensor& tensor);

// If the class has __getstate__/__setstate__, assert that they have the right
// schema and return true; otherwise return false.
bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls);

// Declare BackendMeta serialization and deserialization function pointer
// types.
using BackendMetaPtr = std::function<
    void(const at::Tensor&, std::unordered_map<std::string, bool>&)>;

// An allowlist of device types; currently only PrivateUse1 is available.
inline std::unordered_set<c10::DeviceType>& GetBackendMetaAllowlist() {
  static std::unordered_set<c10::DeviceType> DeviceTypeAllowlist{
      c10::DeviceType::PrivateUse1};
  return DeviceTypeAllowlist;
}

// Dynamically obtain the pair of serialization functions required by the
// corresponding backend.
inline std::array<
    std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
    at::COMPILE_TIME_MAX_DEVICE_TYPES>&
GetBackendMetaSerialization() {
  // The array that stores the function pointers for BackendMeta serialization.
  // The index is the DeviceType; the value is a std::pair whose .first is the
  // get function and whose .second is the set function.
  static std::array<
      std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>,
      at::COMPILE_TIME_MAX_DEVICE_TYPES>
      BackendMetaSerialization;
  return BackendMetaSerialization;
}

// Register function pointers of Tensor BackendMetadata for serialization.
TORCH_API inline void TensorBackendMetaRegistry(
    c10::DeviceType t,
    const BackendMetaPtr& get_fptr,
    const BackendMetaPtr& set_fptr) {
  // Allowlist verification: the serialization extension for BackendMeta data
  // may only be registered for device types in the allowlist.
  const auto& DeviceTypeAllowlist = GetBackendMetaAllowlist();
  TORCH_CHECK(
      DeviceTypeAllowlist.find(t) != DeviceTypeAllowlist.end(),
      "Registering a BackendMeta serialization method is only allowed for ",
      "device types in the allowlist (currently PrivateUse1). ",
      "If you have related serialization requirements, ",
      "please expand the allowlist.");
  // Register function pointer
  int device_type = static_cast<int>(t);
  auto& BackendMetaSerialization = GetBackendMetaSerialization();
  TORCH_CHECK(
      !BackendMetaSerialization[device_type].has_value(),
      "The tensor BackendMeta serialization function pointer for ",
      t,
      " has been registered.");
  BackendMetaSerialization[device_type] =
      std::optional<std::pair<BackendMetaPtr, BackendMetaPtr>>(
          std::make_pair(get_fptr, set_fptr));
}
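
// A registration sketch for a hypothetical PrivateUse1 backend; the hook and
// flag names below are placeholders, not an existing API.
//
//   static void myGetBackendMeta(
//       const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
//     m["my_backend_flag"] = true; // query backend-specific state of `t` here
//   }
//   static void mySetBackendMeta(
//       const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
//     if (m.count("my_backend_flag")) {
//       // restore backend-specific state on `t` here
//     }
//   }
//   // Call once during backend initialization:
//   TensorBackendMetaRegistry(
//       c10::DeviceType::PrivateUse1, myGetBackendMeta, mySetBackendMeta);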

// Returns a map of tensor metadata, including BackendMetaData, for
// serialization. For now, it only takes care of the `conj` and `neg` bits.
inline std::unordered_map<std::string, bool> getTensorMetadata(
    const at::Tensor& t) {
  // We don't support serializing `ZeroTensor` as it is not public
  // facing yet.
  TORCH_CHECK(
      !t._is_zerotensor(),
      "ZeroTensor is not serializable,",
      " please file an issue if required.");
  std::unordered_map<std::string, bool> metadata{};

  // Only add metadata if the value is not the default.
  if (t.is_conj()) {
    metadata["conj"] = true;
  }
  if (t.is_neg()) {
    metadata["neg"] = true;
  }
  // Only add BackendMetaData for a custom backend if the function pointer is
  // registered.
  int device_type = static_cast<int>(t.device().type());
  const auto& BackendMetaSerialization = GetBackendMetaSerialization();
  if (BackendMetaSerialization[device_type].has_value()) {
    // Pass the tensor and metadata map references as parameters to the custom
    // serialization function.
    BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().first;
    fptr(t, metadata);
  }
  return metadata;
}
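
// A round-trip sketch (assumes a complex tensor so the conj bit is
// meaningful): copy the conj/neg view bits from one tensor to another via the
// metadata map.
//
//   at::Tensor src = at::randn({3}, at::kComplexFloat).conj();
//   auto meta = getTensorMetadata(src); // contains {"conj": true}
//   at::Tensor dst = at::randn({3}, at::kComplexFloat);
//   setTensorMetadata(dst, meta);       // sets the conj bit on dst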

// Sets tensor metadata based on the map.
// Refer to getTensorMetadata.
inline void setTensorMetadata(
    const at::Tensor& t,
    std::unordered_map<std::string, bool> metadata) {
  auto iter_end = metadata.end();
  auto iter_temp = metadata.find("conj");
  if (iter_temp != iter_end) {
    t._set_conj(true);
    metadata.erase(iter_temp);
  }
  iter_temp = metadata.find("neg");
  if (iter_temp != iter_end) {
    t._set_neg(true);
    metadata.erase(iter_temp);
  }
  // Only set BackendMetaData for a custom backend if the function pointer is
  // registered.
  int device_type = static_cast<int>(t.device().type());
  const auto& BackendMetaSerialization = GetBackendMetaSerialization();
  if (BackendMetaSerialization[device_type].has_value()) {
    // Pass the tensor and metadata map references as parameters to the custom
    // deserialization function.
    BackendMetaPtr fptr = BackendMetaSerialization[device_type].value().second;
    fptr(t, metadata);
  }
}

// Sets tensor metadata based on the map.
// NOTE: This overload is required by unpickler.cpp
inline void setTensorMetadata(
    const at::Tensor& t,
    const c10::Dict<c10::IValue, c10::IValue>& metadata_idict) {
  std::unordered_map<std::string, bool> metadata;
  for (auto& pair : metadata_idict) {
    auto key = *pair.key().toString();
    metadata[key] = pair.value().toBool();
  }
  setTensorMetadata(t, std::move(metadata));
}

} // namespace torch::jit