xref: /aosp_15_r20/external/pytorch/caffe2/serialize/inline_container.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <sys/stat.h>
#include <sys/types.h>

#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <exception>
#include <fstream>
#include <istream>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <thread>
12*da0073e9SAndroid Build Coastguard Worker 
13*da0073e9SAndroid Build Coastguard Worker #include <c10/core/Allocator.h>
14*da0073e9SAndroid Build Coastguard Worker #include <c10/core/Backend.h>
15*da0073e9SAndroid Build Coastguard Worker #include <c10/core/CPUAllocator.h>
16*da0073e9SAndroid Build Coastguard Worker #include <c10/core/Backend.h>
17*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
18*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Logging.h>
19*da0073e9SAndroid Build Coastguard Worker #include <c10/util/hash.h>
20*da0073e9SAndroid Build Coastguard Worker 
21*da0073e9SAndroid Build Coastguard Worker #include "caffe2/core/common.h"
22*da0073e9SAndroid Build Coastguard Worker #include "caffe2/serialize/file_adapter.h"
23*da0073e9SAndroid Build Coastguard Worker #include "caffe2/serialize/inline_container.h"
24*da0073e9SAndroid Build Coastguard Worker #include "caffe2/serialize/istream_adapter.h"
25*da0073e9SAndroid Build Coastguard Worker #include "caffe2/serialize/read_adapter_interface.h"
26*da0073e9SAndroid Build Coastguard Worker 
27*da0073e9SAndroid Build Coastguard Worker #include "caffe2/serialize/versions.h"
28*da0073e9SAndroid Build Coastguard Worker #include "miniz.h"
29*da0073e9SAndroid Build Coastguard Worker 
30*da0073e9SAndroid Build Coastguard Worker namespace caffe2 {
31*da0073e9SAndroid Build Coastguard Worker namespace serialize {
32*da0073e9SAndroid Build Coastguard Worker constexpr c10::string_view kDebugPklSuffix(".debug_pkl");
33*da0073e9SAndroid Build Coastguard Worker 
34*da0073e9SAndroid Build Coastguard Worker struct MzZipReaderIterWrapper {
MzZipReaderIterWrappercaffe2::serialize::MzZipReaderIterWrapper35*da0073e9SAndroid Build Coastguard Worker   MzZipReaderIterWrapper(mz_zip_reader_extract_iter_state* iter) : impl(iter) {}
36*da0073e9SAndroid Build Coastguard Worker   mz_zip_reader_extract_iter_state* impl;
37*da0073e9SAndroid Build Coastguard Worker };
38*da0073e9SAndroid Build Coastguard Worker 
ChunkRecordIterator(size_t recordSize,size_t chunkSize,std::unique_ptr<MzZipReaderIterWrapper> iter)39*da0073e9SAndroid Build Coastguard Worker ChunkRecordIterator::ChunkRecordIterator(
40*da0073e9SAndroid Build Coastguard Worker     size_t recordSize,
41*da0073e9SAndroid Build Coastguard Worker     size_t chunkSize,
42*da0073e9SAndroid Build Coastguard Worker     std::unique_ptr<MzZipReaderIterWrapper> iter)
43*da0073e9SAndroid Build Coastguard Worker     : recordSize_(recordSize),
44*da0073e9SAndroid Build Coastguard Worker       chunkSize_(chunkSize),
45*da0073e9SAndroid Build Coastguard Worker       offset_(0),
46*da0073e9SAndroid Build Coastguard Worker       iter_(std::move(iter)) {}
47*da0073e9SAndroid Build Coastguard Worker 
~ChunkRecordIterator()48*da0073e9SAndroid Build Coastguard Worker ChunkRecordIterator::~ChunkRecordIterator() {
49*da0073e9SAndroid Build Coastguard Worker   mz_zip_reader_extract_iter_free(iter_->impl);
50*da0073e9SAndroid Build Coastguard Worker }
51*da0073e9SAndroid Build Coastguard Worker 
next(void * buf)52*da0073e9SAndroid Build Coastguard Worker size_t ChunkRecordIterator::next(void* buf){
53*da0073e9SAndroid Build Coastguard Worker   size_t want_size = std::min(chunkSize_, recordSize_ - offset_);
54*da0073e9SAndroid Build Coastguard Worker   if (want_size == 0) {
55*da0073e9SAndroid Build Coastguard Worker     return 0;
56*da0073e9SAndroid Build Coastguard Worker   }
57*da0073e9SAndroid Build Coastguard Worker   size_t read_size = mz_zip_reader_extract_iter_read(iter_->impl, buf, want_size);
58*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(read_size > 0, "Read bytes should be larger than 0");
59*da0073e9SAndroid Build Coastguard Worker   offset_ += read_size;
60*da0073e9SAndroid Build Coastguard Worker   return read_size;
61*da0073e9SAndroid Build Coastguard Worker }
62*da0073e9SAndroid Build Coastguard Worker 
istream_read_func(void * pOpaque,mz_uint64 file_ofs,void * pBuf,size_t n)63*da0073e9SAndroid Build Coastguard Worker size_t istream_read_func(void* pOpaque, mz_uint64 file_ofs, void* pBuf, size_t n) {
64*da0073e9SAndroid Build Coastguard Worker   auto self = static_cast<PyTorchStreamReader*>(pOpaque);
65*da0073e9SAndroid Build Coastguard Worker   return self->read(file_ofs, static_cast<char*>(pBuf), n);
66*da0073e9SAndroid Build Coastguard Worker }
67*da0073e9SAndroid Build Coastguard Worker 
// Returns the final path component of `name` with its last extension removed.
//
// Both '/' and '\\' are treated as separators. Only the text after the last
// '.' inside the final component is stripped ("m.tar.gz" -> "m.tar"); dots in
// directory names are ignored. Returns "" when `name` is empty or ends with a
// separator, and also when the final component starts with its only dot
// (".hidden" -> "").
static std::string basename(const std::string& name) {
  const size_t last_sep = name.find_last_of("/\\");
  const size_t start = (last_sep == std::string::npos) ? 0 : last_sep + 1;
  if (start >= name.size()) {
    return "";
  }

  // A dot only counts as an extension marker if it lies in the last component.
  const size_t last_dot = name.rfind('.');
  const bool has_extension =
      last_dot != std::string::npos && last_dot >= start;
  const size_t end = has_extension ? last_dot : name.size();
  return name.substr(start, end - start);
}
89*da0073e9SAndroid Build Coastguard Worker 
// Returns everything in `name` before its last '/' separator; if there is no
// '/', the last '\\' is used instead. Returns "" when `name` contains neither.
//
// On Windows builds, a separator that directly follows a drive colon
// ("C:\\file") is kept so the result is the root directory ("C:\\").
static std::string parentdir(const std::string& name) {
  size_t end = name.find_last_of('/');
  if (end == std::string::npos) {
    end = name.find_last_of('\\');
  }
  if (end == std::string::npos) {
    return "";
  }

  // _WIN32 is the macro the compiler predefines for Windows targets; WIN32 is
  // only present when an SDK header or the build system defines it, so accept
  // either.
#if defined(_WIN32) || defined(WIN32)
  if (end > 1 && name[end - 1] == ':') {
    // This is a Windows root directory, so include the slash in
    // the parent directory
    end++;
  }
#endif

  return name.substr(0, end);
}
110*da0073e9SAndroid Build Coastguard Worker 
read(uint64_t pos,char * buf,size_t n)111*da0073e9SAndroid Build Coastguard Worker size_t PyTorchStreamReader::read(uint64_t pos, char* buf, size_t n) {
112*da0073e9SAndroid Build Coastguard Worker   return in_->read(pos, buf, n, "reading file");
113*da0073e9SAndroid Build Coastguard Worker }
114*da0073e9SAndroid Build Coastguard Worker 
115*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PyTorchStreamReader(const std::string & file_name)116*da0073e9SAndroid Build Coastguard Worker PyTorchStreamReader::PyTorchStreamReader(const std::string& file_name)
117*da0073e9SAndroid Build Coastguard Worker     : ar_(std::make_unique<mz_zip_archive>()),
118*da0073e9SAndroid Build Coastguard Worker       in_(std::make_unique<FileAdapter>(file_name)) {
119*da0073e9SAndroid Build Coastguard Worker   init();
120*da0073e9SAndroid Build Coastguard Worker }
121*da0073e9SAndroid Build Coastguard Worker 
122*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PyTorchStreamReader(std::istream * in)123*da0073e9SAndroid Build Coastguard Worker PyTorchStreamReader::PyTorchStreamReader(std::istream* in)
124*da0073e9SAndroid Build Coastguard Worker     : ar_(std::make_unique<mz_zip_archive>()),
125*da0073e9SAndroid Build Coastguard Worker       in_(std::make_unique<IStreamAdapter>(in)) {
126*da0073e9SAndroid Build Coastguard Worker   init();
127*da0073e9SAndroid Build Coastguard Worker }
128*da0073e9SAndroid Build Coastguard Worker 
129*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PyTorchStreamReader(std::shared_ptr<ReadAdapterInterface> in)130*da0073e9SAndroid Build Coastguard Worker PyTorchStreamReader::PyTorchStreamReader(
131*da0073e9SAndroid Build Coastguard Worker     std::shared_ptr<ReadAdapterInterface> in)
132*da0073e9SAndroid Build Coastguard Worker     : ar_(std::make_unique<mz_zip_archive>()), in_(std::move(in)) {
133*da0073e9SAndroid Build Coastguard Worker   init();
134*da0073e9SAndroid Build Coastguard Worker }
135*da0073e9SAndroid Build Coastguard Worker 
init()136*da0073e9SAndroid Build Coastguard Worker void PyTorchStreamReader::init() {
137*da0073e9SAndroid Build Coastguard Worker   AT_ASSERT(in_ != nullptr);
138*da0073e9SAndroid Build Coastguard Worker   AT_ASSERT(ar_ != nullptr);
139*da0073e9SAndroid Build Coastguard Worker   memset(ar_.get(), 0, sizeof(mz_zip_archive));
140*da0073e9SAndroid Build Coastguard Worker 
141*da0073e9SAndroid Build Coastguard Worker   size_t size = in_->size();
142*da0073e9SAndroid Build Coastguard Worker 
143*da0073e9SAndroid Build Coastguard Worker   // check for the old magic number,
144*da0073e9SAndroid Build Coastguard Worker   constexpr size_t kMagicValueLength = 8;
145*da0073e9SAndroid Build Coastguard Worker   if (size > kMagicValueLength) {
146*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
147*da0073e9SAndroid Build Coastguard Worker     char buf[kMagicValueLength];
148*da0073e9SAndroid Build Coastguard Worker     read(0, buf, kMagicValueLength);
149*da0073e9SAndroid Build Coastguard Worker     valid("checking magic number");
150*da0073e9SAndroid Build Coastguard Worker     AT_ASSERTM(
151*da0073e9SAndroid Build Coastguard Worker         memcmp("PYTORCH1", buf, kMagicValueLength) != 0,
152*da0073e9SAndroid Build Coastguard Worker         "File is an unsupported archive format from the preview release.");
153*da0073e9SAndroid Build Coastguard Worker   }
154*da0073e9SAndroid Build Coastguard Worker 
155*da0073e9SAndroid Build Coastguard Worker   ar_->m_pIO_opaque = this;
156*da0073e9SAndroid Build Coastguard Worker   ar_->m_pRead = istream_read_func;
157*da0073e9SAndroid Build Coastguard Worker 
158*da0073e9SAndroid Build Coastguard Worker   mz_zip_reader_init(ar_.get(), size, 0);
159*da0073e9SAndroid Build Coastguard Worker   valid("reading zip archive");
160*da0073e9SAndroid Build Coastguard Worker 
161*da0073e9SAndroid Build Coastguard Worker   // figure out the archive_name (i.e. the zip folder all the other files are in)
162*da0073e9SAndroid Build Coastguard Worker   // all lookups to getRecord will be prefixed by this folder
163*da0073e9SAndroid Build Coastguard Worker   mz_uint n = mz_zip_reader_get_num_files(ar_.get());
164*da0073e9SAndroid Build Coastguard Worker   if (n == 0) {
165*da0073e9SAndroid Build Coastguard Worker     CAFFE_THROW("archive does not contain any files");
166*da0073e9SAndroid Build Coastguard Worker   }
167*da0073e9SAndroid Build Coastguard Worker   size_t name_size = mz_zip_reader_get_filename(ar_.get(), 0, nullptr, 0);
168*da0073e9SAndroid Build Coastguard Worker   valid("getting filename");
169*da0073e9SAndroid Build Coastguard Worker   std::string buf(name_size, '\0');
170*da0073e9SAndroid Build Coastguard Worker   mz_zip_reader_get_filename(ar_.get(), 0, &buf[0], name_size);
171*da0073e9SAndroid Build Coastguard Worker   valid("getting filename");
172*da0073e9SAndroid Build Coastguard Worker   auto pos = buf.find_first_of('/');
173*da0073e9SAndroid Build Coastguard Worker   if (pos == std::string::npos) {
174*da0073e9SAndroid Build Coastguard Worker     CAFFE_THROW("file in archive is not in a subdirectory: ", buf);
175*da0073e9SAndroid Build Coastguard Worker   }
176*da0073e9SAndroid Build Coastguard Worker   archive_name_ = buf.substr(0, pos);
177*da0073e9SAndroid Build Coastguard Worker   archive_name_plus_slash_ = archive_name_ + "/";
178*da0073e9SAndroid Build Coastguard Worker 
179*da0073e9SAndroid Build Coastguard Worker   // read serialization id
180*da0073e9SAndroid Build Coastguard Worker   if (hasRecord(kSerializationIdRecordName)) {
181*da0073e9SAndroid Build Coastguard Worker     at::DataPtr serialization_id_ptr;
182*da0073e9SAndroid Build Coastguard Worker     size_t serialization_id_size = 0;
183*da0073e9SAndroid Build Coastguard Worker     std::tie(serialization_id_ptr, serialization_id_size) =
184*da0073e9SAndroid Build Coastguard Worker         getRecord(kSerializationIdRecordName);
185*da0073e9SAndroid Build Coastguard Worker     serialization_id_.assign(
186*da0073e9SAndroid Build Coastguard Worker         static_cast<const char*>(serialization_id_ptr.get()),
187*da0073e9SAndroid Build Coastguard Worker         serialization_id_size);
188*da0073e9SAndroid Build Coastguard Worker   }
189*da0073e9SAndroid Build Coastguard Worker   c10::LogAPIUsageMetadata(
190*da0073e9SAndroid Build Coastguard Worker       "pytorch.stream.reader.metadata",
191*da0073e9SAndroid Build Coastguard Worker       {{"serialization_id", serialization_id_},
192*da0073e9SAndroid Build Coastguard Worker        {"file_name", archive_name_},
193*da0073e9SAndroid Build Coastguard Worker        {"file_size", str(mz_zip_get_archive_size(ar_.get()))}});
194*da0073e9SAndroid Build Coastguard Worker 
195*da0073e9SAndroid Build Coastguard Worker   // version check
196*da0073e9SAndroid Build Coastguard Worker   at::DataPtr version_ptr;
197*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
198*da0073e9SAndroid Build Coastguard Worker   size_t version_size;
199*da0073e9SAndroid Build Coastguard Worker   if (hasRecord(".data/version")) {
200*da0073e9SAndroid Build Coastguard Worker     std::tie(version_ptr, version_size) = getRecord(".data/version");
201*da0073e9SAndroid Build Coastguard Worker   } else {
202*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(hasRecord("version"))
203*da0073e9SAndroid Build Coastguard Worker     std::tie(version_ptr, version_size) = getRecord("version");
204*da0073e9SAndroid Build Coastguard Worker   }
205*da0073e9SAndroid Build Coastguard Worker   std::string version(static_cast<const char*>(version_ptr.get()), version_size);
206*da0073e9SAndroid Build Coastguard Worker   try {
207*da0073e9SAndroid Build Coastguard Worker     version_ = std::stoull(version);
208*da0073e9SAndroid Build Coastguard Worker   } catch (const std::invalid_argument& e) {
209*da0073e9SAndroid Build Coastguard Worker     CAFFE_THROW("Couldn't parse the version ",
210*da0073e9SAndroid Build Coastguard Worker                  version,
211*da0073e9SAndroid Build Coastguard Worker                  " as Long Long.");
212*da0073e9SAndroid Build Coastguard Worker   }
213*da0073e9SAndroid Build Coastguard Worker   if (version_ < static_cast<decltype(version_)>(kMinSupportedFileFormatVersion)) {
214*da0073e9SAndroid Build Coastguard Worker     CAFFE_THROW(
215*da0073e9SAndroid Build Coastguard Worker         "Attempted to read a PyTorch file with version ",
216*da0073e9SAndroid Build Coastguard Worker         std::to_string(version_),
217*da0073e9SAndroid Build Coastguard Worker         ", but the minimum supported version for reading is ",
218*da0073e9SAndroid Build Coastguard Worker         std::to_string(kMinSupportedFileFormatVersion),
219*da0073e9SAndroid Build Coastguard Worker         ". Your PyTorch script module file is too old. Please regenerate it",
220*da0073e9SAndroid Build Coastguard Worker         " with latest version of PyTorch to mitigate this issue.");
221*da0073e9SAndroid Build Coastguard Worker   }
222*da0073e9SAndroid Build Coastguard Worker 
223*da0073e9SAndroid Build Coastguard Worker   if (version_ > static_cast<decltype(version_)>(kMaxSupportedFileFormatVersion)) {
224*da0073e9SAndroid Build Coastguard Worker     CAFFE_THROW(
225*da0073e9SAndroid Build Coastguard Worker         "Attempted to read a PyTorch file with version ",
226*da0073e9SAndroid Build Coastguard Worker         version_,
227*da0073e9SAndroid Build Coastguard Worker         ", but the maximum supported version for reading is ",
228*da0073e9SAndroid Build Coastguard Worker         kMaxSupportedFileFormatVersion,
229*da0073e9SAndroid Build Coastguard Worker         ". The version of your PyTorch installation may be too old, ",
230*da0073e9SAndroid Build Coastguard Worker         "please upgrade PyTorch to latest version to mitigate this issue.");
231*da0073e9SAndroid Build Coastguard Worker   }
232*da0073e9SAndroid Build Coastguard Worker }
233*da0073e9SAndroid Build Coastguard Worker 
valid(const char * what,const char * info)234*da0073e9SAndroid Build Coastguard Worker void PyTorchStreamReader::valid(const char* what, const char* info) {
235*da0073e9SAndroid Build Coastguard Worker   const auto err = mz_zip_get_last_error(ar_.get());
236*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
237*da0073e9SAndroid Build Coastguard Worker       err == MZ_ZIP_NO_ERROR,
238*da0073e9SAndroid Build Coastguard Worker       "PytorchStreamReader failed ",
239*da0073e9SAndroid Build Coastguard Worker       what,
240*da0073e9SAndroid Build Coastguard Worker       info,
241*da0073e9SAndroid Build Coastguard Worker       ": ",
242*da0073e9SAndroid Build Coastguard Worker       mz_zip_get_error_string(err));
243*da0073e9SAndroid Build Coastguard Worker }
244*da0073e9SAndroid Build Coastguard Worker 
// Zip local-directory-header layout constants used by this file.
// Size in bytes of the fixed part of a local directory header.
constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
// Byte offset of the filename-length field within the local header.
constexpr int MZ_ZIP_LDH_FILENAME_LEN_OFS = 26;
// Byte offset of the extra-field-length field within the local header.
constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28;
// Signature value identifying a data descriptor.
constexpr int MZ_ZIP_DATA_DESCRIPTOR_ID = 0x08074b50;
249*da0073e9SAndroid Build Coastguard Worker 
250*da0073e9SAndroid Build Coastguard Worker namespace detail {
getPadding(size_t cursor,size_t filename_size,size_t size,std::string & padding_buf)251*da0073e9SAndroid Build Coastguard Worker size_t getPadding(
252*da0073e9SAndroid Build Coastguard Worker     size_t cursor,
253*da0073e9SAndroid Build Coastguard Worker     size_t filename_size,
254*da0073e9SAndroid Build Coastguard Worker     size_t size,
255*da0073e9SAndroid Build Coastguard Worker     std::string& padding_buf) {
256*da0073e9SAndroid Build Coastguard Worker   size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size +
257*da0073e9SAndroid Build Coastguard Worker       sizeof(mz_uint16) * 2;
258*da0073e9SAndroid Build Coastguard Worker   if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
259*da0073e9SAndroid Build Coastguard Worker     start += sizeof(mz_uint16) * 2;
260*da0073e9SAndroid Build Coastguard Worker     if (size >= MZ_UINT32_MAX) {
261*da0073e9SAndroid Build Coastguard Worker       start += 2 * sizeof(mz_uint64);
262*da0073e9SAndroid Build Coastguard Worker     }
263*da0073e9SAndroid Build Coastguard Worker     if (cursor >= MZ_UINT32_MAX) {
264*da0073e9SAndroid Build Coastguard Worker       start += sizeof(mz_uint64);
265*da0073e9SAndroid Build Coastguard Worker     }
266*da0073e9SAndroid Build Coastguard Worker   }
267*da0073e9SAndroid Build Coastguard Worker   size_t mod = start % kFieldAlignment;
268*da0073e9SAndroid Build Coastguard Worker   size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
269*da0073e9SAndroid Build Coastguard Worker   size_t padding_size = next_offset - start;
270*da0073e9SAndroid Build Coastguard Worker   size_t padding_size_plus_fbxx = padding_size + 4;
271*da0073e9SAndroid Build Coastguard Worker   if (padding_buf.size() < padding_size_plus_fbxx) {
272*da0073e9SAndroid Build Coastguard Worker     padding_buf.append(padding_size_plus_fbxx - padding_buf.size(), 'Z');
273*da0073e9SAndroid Build Coastguard Worker   }
274*da0073e9SAndroid Build Coastguard Worker   // zip extra encoding (key, size_of_extra_bytes)
275*da0073e9SAndroid Build Coastguard Worker   padding_buf[0] = 'F';
276*da0073e9SAndroid Build Coastguard Worker   padding_buf[1] = 'B';
277*da0073e9SAndroid Build Coastguard Worker   padding_buf[2] = (uint8_t)padding_size;
278*da0073e9SAndroid Build Coastguard Worker   padding_buf[3] = (uint8_t)(padding_size >> 8);
279*da0073e9SAndroid Build Coastguard Worker   return padding_size_plus_fbxx;
280*da0073e9SAndroid Build Coastguard Worker }
281*da0073e9SAndroid Build Coastguard Worker }
282*da0073e9SAndroid Build Coastguard Worker 
hasRecord(const std::string & name)283*da0073e9SAndroid Build Coastguard Worker bool PyTorchStreamReader::hasRecord(const std::string& name) {
284*da0073e9SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> guard(reader_lock_);
285*da0073e9SAndroid Build Coastguard Worker 
286*da0073e9SAndroid Build Coastguard Worker   if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
287*da0073e9SAndroid Build Coastguard Worker     return false;
288*da0073e9SAndroid Build Coastguard Worker   }
289*da0073e9SAndroid Build Coastguard Worker   std::string ss = archive_name_plus_slash_ + name;
290*da0073e9SAndroid Build Coastguard Worker   mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0);
291*da0073e9SAndroid Build Coastguard Worker   const mz_zip_error err = mz_zip_get_last_error(ar_.get());
292*da0073e9SAndroid Build Coastguard Worker 
293*da0073e9SAndroid Build Coastguard Worker   if (err == MZ_ZIP_NO_ERROR) {
294*da0073e9SAndroid Build Coastguard Worker     return true;
295*da0073e9SAndroid Build Coastguard Worker   } else if (err == MZ_ZIP_FILE_NOT_FOUND) {
296*da0073e9SAndroid Build Coastguard Worker     return false;
297*da0073e9SAndroid Build Coastguard Worker   } else {
298*da0073e9SAndroid Build Coastguard Worker     // A different error happened, raise it.
299*da0073e9SAndroid Build Coastguard Worker     valid("attempting to locate file ", name.c_str());
300*da0073e9SAndroid Build Coastguard Worker   }
301*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(false, "should not reach here");
302*da0073e9SAndroid Build Coastguard Worker }
303*da0073e9SAndroid Build Coastguard Worker 
getAllRecords()304*da0073e9SAndroid Build Coastguard Worker std::vector<std::string> PyTorchStreamReader::getAllRecords() {
305*da0073e9SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> guard(reader_lock_);
306*da0073e9SAndroid Build Coastguard Worker   mz_uint num_files = mz_zip_reader_get_num_files(ar_.get());
307*da0073e9SAndroid Build Coastguard Worker   std::vector<std::string> out;
308*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
309*da0073e9SAndroid Build Coastguard Worker   char buf[MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE];
310*da0073e9SAndroid Build Coastguard Worker   for (size_t i = 0; i < num_files; i++) {
311*da0073e9SAndroid Build Coastguard Worker     mz_zip_reader_get_filename(ar_.get(), i, buf, MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE);
312*da0073e9SAndroid Build Coastguard Worker     if (strncmp(
313*da0073e9SAndroid Build Coastguard Worker             buf,
314*da0073e9SAndroid Build Coastguard Worker             archive_name_plus_slash_.data(),
315*da0073e9SAndroid Build Coastguard Worker             archive_name_plus_slash_.size()) != 0) {
316*da0073e9SAndroid Build Coastguard Worker       CAFFE_THROW(
317*da0073e9SAndroid Build Coastguard Worker           "file in archive is not in a subdirectory ",
318*da0073e9SAndroid Build Coastguard Worker           archive_name_plus_slash_,
319*da0073e9SAndroid Build Coastguard Worker           ": ",
320*da0073e9SAndroid Build Coastguard Worker           buf);
321*da0073e9SAndroid Build Coastguard Worker     }
322*da0073e9SAndroid Build Coastguard Worker     if ((load_debug_symbol_) ||
323*da0073e9SAndroid Build Coastguard Worker         (!c10::string_view(buf + archive_name_plus_slash_.size()).ends_with(kDebugPklSuffix))) {
324*da0073e9SAndroid Build Coastguard Worker       // NOLINTNEXTLINE(modernize-use-emplace)
325*da0073e9SAndroid Build Coastguard Worker       out.push_back(buf + archive_name_plus_slash_.size());
326*da0073e9SAndroid Build Coastguard Worker     }
327*da0073e9SAndroid Build Coastguard Worker   }
328*da0073e9SAndroid Build Coastguard Worker   return out;
329*da0073e9SAndroid Build Coastguard Worker }
330*da0073e9SAndroid Build Coastguard Worker 
331*da0073e9SAndroid Build Coastguard Worker const std::unordered_set<std::string>&
getAllWrittenRecords()332*da0073e9SAndroid Build Coastguard Worker PyTorchStreamWriter::getAllWrittenRecords() {
333*da0073e9SAndroid Build Coastguard Worker   return files_written_;
334*da0073e9SAndroid Build Coastguard Worker }
335*da0073e9SAndroid Build Coastguard Worker 
getRecordID(const std::string & name)336*da0073e9SAndroid Build Coastguard Worker size_t PyTorchStreamReader::getRecordID(const std::string& name) {
337*da0073e9SAndroid Build Coastguard Worker   std::string ss = archive_name_plus_slash_ + name;
338*da0073e9SAndroid Build Coastguard Worker   size_t result = mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0);
339*da0073e9SAndroid Build Coastguard Worker   valid("locating file ", name.c_str());
340*da0073e9SAndroid Build Coastguard Worker   return result;
341*da0073e9SAndroid Build Coastguard Worker }
342*da0073e9SAndroid Build Coastguard Worker 
343*da0073e9SAndroid Build Coastguard Worker // return dataptr, size
getRecord(const std::string & name)344*da0073e9SAndroid Build Coastguard Worker std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
345*da0073e9SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> guard(reader_lock_);
346*da0073e9SAndroid Build Coastguard Worker   if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
347*da0073e9SAndroid Build Coastguard Worker     at::DataPtr retval;
348*da0073e9SAndroid Build Coastguard Worker     return std::make_tuple(std::move(retval), 0);
349*da0073e9SAndroid Build Coastguard Worker   }
350*da0073e9SAndroid Build Coastguard Worker   size_t key = getRecordID(name);
351*da0073e9SAndroid Build Coastguard Worker   mz_zip_archive_file_stat stat;
352*da0073e9SAndroid Build Coastguard Worker   mz_zip_reader_file_stat(ar_.get(), key, &stat);
353*da0073e9SAndroid Build Coastguard Worker   valid("retrieving file meta-data for ", name.c_str());
354*da0073e9SAndroid Build Coastguard Worker   at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
355*da0073e9SAndroid Build Coastguard Worker   mz_zip_reader_extract_to_mem(ar_.get(), key, retval.get(), stat.m_uncomp_size, 0);
356*da0073e9SAndroid Build Coastguard Worker   valid("reading file ", name.c_str());
357*da0073e9SAndroid Build Coastguard Worker 
358*da0073e9SAndroid Build Coastguard Worker   return std::make_tuple(std::move(retval), stat.m_uncomp_size);
359*da0073e9SAndroid Build Coastguard Worker }
360*da0073e9SAndroid Build Coastguard Worker 
361*da0073e9SAndroid Build Coastguard Worker size_t
getRecordMultiReaders(const std::string & name,std::vector<std::shared_ptr<ReadAdapterInterface>> & additionalReaders,void * dst,size_t n)362*da0073e9SAndroid Build Coastguard Worker PyTorchStreamReader::getRecordMultiReaders(const std::string& name,
363*da0073e9SAndroid Build Coastguard Worker   std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
364*da0073e9SAndroid Build Coastguard Worker   void *dst, size_t n){
365*da0073e9SAndroid Build Coastguard Worker 
366*da0073e9SAndroid Build Coastguard Worker   size_t nthread = additionalReaders.size()+1;
367*da0073e9SAndroid Build Coastguard Worker   size_t recordOff = getRecordOffset(name);
368*da0073e9SAndroid Build Coastguard Worker   std::vector<std::thread> loaderThreads;
369*da0073e9SAndroid Build Coastguard Worker   size_t perThreadSize = (n+nthread-1)/nthread;
370*da0073e9SAndroid Build Coastguard Worker   std::vector<size_t> readSizes(nthread, 0);
371*da0073e9SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> guard(reader_lock_);
372*da0073e9SAndroid Build Coastguard Worker   for(size_t i = 0; i < nthread ; i++){
373*da0073e9SAndroid Build Coastguard Worker     loaderThreads.emplace_back([this, name, i, n, recordOff, perThreadSize, dst, &additionalReaders, &readSizes]{
374*da0073e9SAndroid Build Coastguard Worker       size_t startPos = i*perThreadSize;
375*da0073e9SAndroid Build Coastguard Worker       size_t endPos = std::min((i+1)*perThreadSize,n);
376*da0073e9SAndroid Build Coastguard Worker       if (startPos < endPos){
377*da0073e9SAndroid Build Coastguard Worker         size_t threadReadSize = endPos - startPos;
378*da0073e9SAndroid Build Coastguard Worker         size_t size = 0;
379*da0073e9SAndroid Build Coastguard Worker         if (i==0){
380*da0073e9SAndroid Build Coastguard Worker           size = read(recordOff+startPos, (char *)dst+startPos, threadReadSize);
381*da0073e9SAndroid Build Coastguard Worker         }else{
382*da0073e9SAndroid Build Coastguard Worker           auto reader = additionalReaders[i-1];
383*da0073e9SAndroid Build Coastguard Worker           size = reader->read(recordOff+startPos, (char *)dst+startPos, threadReadSize);
384*da0073e9SAndroid Build Coastguard Worker         }
385*da0073e9SAndroid Build Coastguard Worker         readSizes[i] = size;
386*da0073e9SAndroid Build Coastguard Worker         LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos << "] "
387*da0073e9SAndroid Build Coastguard Worker             << "from " << name << " of size " << n;
388*da0073e9SAndroid Build Coastguard Worker         TORCH_CHECK(
389*da0073e9SAndroid Build Coastguard Worker               threadReadSize == size,
390*da0073e9SAndroid Build Coastguard Worker               "record size ",
391*da0073e9SAndroid Build Coastguard Worker               threadReadSize,
392*da0073e9SAndroid Build Coastguard Worker               " mismatch with read size ",
393*da0073e9SAndroid Build Coastguard Worker               size);
394*da0073e9SAndroid Build Coastguard Worker       }
395*da0073e9SAndroid Build Coastguard Worker     });
396*da0073e9SAndroid Build Coastguard Worker   }
397*da0073e9SAndroid Build Coastguard Worker 
398*da0073e9SAndroid Build Coastguard Worker   for (auto& thread : loaderThreads) {
399*da0073e9SAndroid Build Coastguard Worker     thread.join();
400*da0073e9SAndroid Build Coastguard Worker   }
401*da0073e9SAndroid Build Coastguard Worker   loaderThreads.clear();
402*da0073e9SAndroid Build Coastguard Worker 
403*da0073e9SAndroid Build Coastguard Worker   size_t total_read_n = 0;
404*da0073e9SAndroid Build Coastguard Worker   for (auto& r : readSizes){
405*da0073e9SAndroid Build Coastguard Worker     total_read_n += r;
406*da0073e9SAndroid Build Coastguard Worker   }
407*da0073e9SAndroid Build Coastguard Worker 
408*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
409*da0073e9SAndroid Build Coastguard Worker       n == total_read_n,
410*da0073e9SAndroid Build Coastguard Worker       "Multi reader total read size ",
411*da0073e9SAndroid Build Coastguard Worker       total_read_n,
412*da0073e9SAndroid Build Coastguard Worker       " mismatch with dst size ",
413*da0073e9SAndroid Build Coastguard Worker       n);
414*da0073e9SAndroid Build Coastguard Worker 
415*da0073e9SAndroid Build Coastguard Worker   return total_read_n;
416*da0073e9SAndroid Build Coastguard Worker }
417*da0073e9SAndroid Build Coastguard Worker 
418*da0073e9SAndroid Build Coastguard Worker // read record with multi clients
419*da0073e9SAndroid Build Coastguard Worker std::tuple<at::DataPtr, size_t>
getRecord(const std::string & name,std::vector<std::shared_ptr<ReadAdapterInterface>> & additionalReaders)420*da0073e9SAndroid Build Coastguard Worker PyTorchStreamReader::getRecord(const std::string& name,
421*da0073e9SAndroid Build Coastguard Worker   std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
422*da0073e9SAndroid Build Coastguard Worker   if(additionalReaders.empty()){
423*da0073e9SAndroid Build Coastguard Worker     // No additional readers or record too small, use single threaded version
424*da0073e9SAndroid Build Coastguard Worker     return getRecord(name);
425*da0073e9SAndroid Build Coastguard Worker   }
426*da0073e9SAndroid Build Coastguard Worker 
427*da0073e9SAndroid Build Coastguard Worker   if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
428*da0073e9SAndroid Build Coastguard Worker     at::DataPtr retval;
429*da0073e9SAndroid Build Coastguard Worker     return std::make_tuple(std::move(retval), 0);
430*da0073e9SAndroid Build Coastguard Worker   }
431*da0073e9SAndroid Build Coastguard Worker   size_t key = getRecordID(name);
432*da0073e9SAndroid Build Coastguard Worker   mz_zip_archive_file_stat stat;
433*da0073e9SAndroid Build Coastguard Worker   mz_zip_reader_file_stat(ar_.get(), key, &stat);
434*da0073e9SAndroid Build Coastguard Worker   auto n = stat.m_uncomp_size;
435*da0073e9SAndroid Build Coastguard Worker   valid("retrieving file meta-data for ", name.c_str());
436*da0073e9SAndroid Build Coastguard Worker   if(n < additional_reader_size_threshold_){
437*da0073e9SAndroid Build Coastguard Worker     // Reader size too small, use single threaded version
438*da0073e9SAndroid Build Coastguard Worker     return getRecord(name);
439*da0073e9SAndroid Build Coastguard Worker   }
440*da0073e9SAndroid Build Coastguard Worker 
441*da0073e9SAndroid Build Coastguard Worker   at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
442*da0073e9SAndroid Build Coastguard Worker   void* dst = retval.get();
443*da0073e9SAndroid Build Coastguard Worker   PyTorchStreamReader::getRecordMultiReaders(name, additionalReaders, dst, n);
444*da0073e9SAndroid Build Coastguard Worker   return std::make_tuple(std::move(retval), stat.m_uncomp_size);
445*da0073e9SAndroid Build Coastguard Worker }
446*da0073e9SAndroid Build Coastguard Worker 
447*da0073e9SAndroid Build Coastguard Worker // inplace memory writing
448*da0073e9SAndroid Build Coastguard Worker size_t
getRecord(const std::string & name,void * dst,size_t n)449*da0073e9SAndroid Build Coastguard Worker PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) {
450*da0073e9SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> guard(reader_lock_);
451*da0073e9SAndroid Build Coastguard Worker   if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
452*da0073e9SAndroid Build Coastguard Worker     return 0;
453*da0073e9SAndroid Build Coastguard Worker   }
454*da0073e9SAndroid Build Coastguard Worker   size_t key = getRecordID(name);
455*da0073e9SAndroid Build Coastguard Worker   mz_zip_archive_file_stat stat;
456*da0073e9SAndroid Build Coastguard Worker   mz_zip_reader_file_stat(ar_.get(), key, &stat);
457*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
458*da0073e9SAndroid Build Coastguard Worker       n == stat.m_uncomp_size,
459*da0073e9SAndroid Build Coastguard Worker       "record size ",
460*da0073e9SAndroid Build Coastguard Worker       stat.m_uncomp_size,
461*da0073e9SAndroid Build Coastguard Worker       " mismatch with dst size ",
462*da0073e9SAndroid Build Coastguard Worker       n);
463*da0073e9SAndroid Build Coastguard Worker   valid("retrieving file meta-data for ", name.c_str());
464*da0073e9SAndroid Build Coastguard Worker   mz_zip_reader_extract_to_mem(ar_.get(), key, dst, stat.m_uncomp_size, 0);
465*da0073e9SAndroid Build Coastguard Worker   valid("reading file ", name.c_str());
466*da0073e9SAndroid Build Coastguard Worker 
467*da0073e9SAndroid Build Coastguard Worker   return stat.m_uncomp_size;
468*da0073e9SAndroid Build Coastguard Worker }
469*da0073e9SAndroid Build Coastguard Worker 
470*da0073e9SAndroid Build Coastguard Worker 
471*da0073e9SAndroid Build Coastguard Worker // inplace memory writing, in-tensor multi-threads, can be used for large tensor.
472*da0073e9SAndroid Build Coastguard Worker size_t
getRecord(const std::string & name,void * dst,size_t n,std::vector<std::shared_ptr<ReadAdapterInterface>> & additionalReaders)473*da0073e9SAndroid Build Coastguard Worker PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n,
474*da0073e9SAndroid Build Coastguard Worker   std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
475*da0073e9SAndroid Build Coastguard Worker   if(additionalReaders.empty()){
476*da0073e9SAndroid Build Coastguard Worker     // No additional readers, use single threaded version
477*da0073e9SAndroid Build Coastguard Worker     return getRecord(name, dst, n);
478*da0073e9SAndroid Build Coastguard Worker   }
479*da0073e9SAndroid Build Coastguard Worker 
480*da0073e9SAndroid Build Coastguard Worker   if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
481*da0073e9SAndroid Build Coastguard Worker     return 0;
482*da0073e9SAndroid Build Coastguard Worker   }
483*da0073e9SAndroid Build Coastguard Worker   size_t key = getRecordID(name);
484*da0073e9SAndroid Build Coastguard Worker   mz_zip_archive_file_stat stat;
485*da0073e9SAndroid Build Coastguard Worker   mz_zip_reader_file_stat(ar_.get(), key, &stat);
486*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
487*da0073e9SAndroid Build Coastguard Worker       n == stat.m_uncomp_size,
488*da0073e9SAndroid Build Coastguard Worker       "record size ",
489*da0073e9SAndroid Build Coastguard Worker       stat.m_uncomp_size,
490*da0073e9SAndroid Build Coastguard Worker       " mismatch with dst size ",
491*da0073e9SAndroid Build Coastguard Worker       n);
492*da0073e9SAndroid Build Coastguard Worker   valid("retrieving file meta-data for ", name.c_str());
493*da0073e9SAndroid Build Coastguard Worker 
494*da0073e9SAndroid Build Coastguard Worker   if(n < additional_reader_size_threshold_){
495*da0073e9SAndroid Build Coastguard Worker     // Reader size too small, use single threaded version
496*da0073e9SAndroid Build Coastguard Worker     return getRecord(name, dst, n);
497*da0073e9SAndroid Build Coastguard Worker   }
498*da0073e9SAndroid Build Coastguard Worker 
499*da0073e9SAndroid Build Coastguard Worker   PyTorchStreamReader::getRecordMultiReaders(name, additionalReaders, dst, n);
500*da0073e9SAndroid Build Coastguard Worker   return stat.m_uncomp_size;
501*da0073e9SAndroid Build Coastguard Worker }
502*da0073e9SAndroid Build Coastguard Worker 
getRecord(const std::string & name,void * dst,size_t n,size_t chunk_size,void * buf,const std::function<void (void *,const void *,size_t)> & memcpy_func)503*da0073e9SAndroid Build Coastguard Worker size_t PyTorchStreamReader::getRecord(
504*da0073e9SAndroid Build Coastguard Worker     const std::string& name,
505*da0073e9SAndroid Build Coastguard Worker     void* dst,
506*da0073e9SAndroid Build Coastguard Worker     size_t n,
507*da0073e9SAndroid Build Coastguard Worker     size_t chunk_size,
508*da0073e9SAndroid Build Coastguard Worker     void* buf,
509*da0073e9SAndroid Build Coastguard Worker     const std::function<void(void*, const void*, size_t)>& memcpy_func) {
510*da0073e9SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> guard(reader_lock_);
511*da0073e9SAndroid Build Coastguard Worker   if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
512*da0073e9SAndroid Build Coastguard Worker     return 0;
513*da0073e9SAndroid Build Coastguard Worker   }
514*da0073e9SAndroid Build Coastguard Worker   if (chunk_size <= 0) {
515*da0073e9SAndroid Build Coastguard Worker     chunk_size = n;
516*da0073e9SAndroid Build Coastguard Worker   }
517*da0073e9SAndroid Build Coastguard Worker   size_t key = getRecordID(name);
518*da0073e9SAndroid Build Coastguard Worker   mz_zip_archive_file_stat stat;
519*da0073e9SAndroid Build Coastguard Worker   mz_zip_reader_file_stat(ar_.get(), key, &stat);
520*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
521*da0073e9SAndroid Build Coastguard Worker       n == stat.m_uncomp_size,
522*da0073e9SAndroid Build Coastguard Worker       "record size ",
523*da0073e9SAndroid Build Coastguard Worker       stat.m_uncomp_size,
524*da0073e9SAndroid Build Coastguard Worker       " mismatch with dst size ",
525*da0073e9SAndroid Build Coastguard Worker       n);
526*da0073e9SAndroid Build Coastguard Worker   valid("retrieving file meta-data for ", name.c_str());
527*da0073e9SAndroid Build Coastguard Worker 
528*da0073e9SAndroid Build Coastguard Worker   std::vector<uint8_t> buffer;
529*da0073e9SAndroid Build Coastguard Worker   if (buf == nullptr) {
530*da0073e9SAndroid Build Coastguard Worker     buffer.resize(chunk_size);
531*da0073e9SAndroid Build Coastguard Worker     buf = buffer.data();
532*da0073e9SAndroid Build Coastguard Worker   }
533*da0073e9SAndroid Build Coastguard Worker 
534*da0073e9SAndroid Build Coastguard Worker   auto chunkIterator =
535*da0073e9SAndroid Build Coastguard Worker       createChunkReaderIter(name, (size_t)stat.m_uncomp_size, chunk_size);
536*da0073e9SAndroid Build Coastguard Worker   while (auto readSize = chunkIterator.next(buf)) {
537*da0073e9SAndroid Build Coastguard Worker     memcpy_func((char*)dst + chunkIterator.offset_ - readSize, buf, readSize);
538*da0073e9SAndroid Build Coastguard Worker   }
539*da0073e9SAndroid Build Coastguard Worker   valid("reading file ", name.c_str());
540*da0073e9SAndroid Build Coastguard Worker 
541*da0073e9SAndroid Build Coastguard Worker   return stat.m_uncomp_size;
542*da0073e9SAndroid Build Coastguard Worker }
543*da0073e9SAndroid Build Coastguard Worker 
createChunkReaderIter(const std::string & name,const size_t recordSize,const size_t chunkSize)544*da0073e9SAndroid Build Coastguard Worker ChunkRecordIterator PyTorchStreamReader::createChunkReaderIter(
545*da0073e9SAndroid Build Coastguard Worker     const std::string& name,
546*da0073e9SAndroid Build Coastguard Worker     const size_t recordSize,
547*da0073e9SAndroid Build Coastguard Worker     const size_t chunkSize) {
548*da0073e9SAndroid Build Coastguard Worker   // Create zip reader iterator
549*da0073e9SAndroid Build Coastguard Worker   size_t key = getRecordID(name);
550*da0073e9SAndroid Build Coastguard Worker   mz_zip_reader_extract_iter_state* zipReaderIter =
551*da0073e9SAndroid Build Coastguard Worker       mz_zip_reader_extract_iter_new(ar_.get(), key, 0);
552*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(
553*da0073e9SAndroid Build Coastguard Worker       zipReaderIter != nullptr,
554*da0073e9SAndroid Build Coastguard Worker       "Failed to create zip reader iter: ",
555*da0073e9SAndroid Build Coastguard Worker       mz_zip_get_error_string(mz_zip_get_last_error(ar_.get())));
556*da0073e9SAndroid Build Coastguard Worker 
557*da0073e9SAndroid Build Coastguard Worker   return ChunkRecordIterator(
558*da0073e9SAndroid Build Coastguard Worker       recordSize,
559*da0073e9SAndroid Build Coastguard Worker       chunkSize,
560*da0073e9SAndroid Build Coastguard Worker       std::make_unique<MzZipReaderIterWrapper>(zipReaderIter));
561*da0073e9SAndroid Build Coastguard Worker }
562*da0073e9SAndroid Build Coastguard Worker 
// Decodes the two bytes at `buf` as an unsigned little-endian 16-bit integer
// (buf[0] is the low byte, buf[1] the high byte).
static int64_t read_le_16(uint8_t* buf) {
  const int64_t lowByte = buf[0];
  const int64_t highByte = buf[1];
  return lowByte | (highByte << 8);
}
566*da0073e9SAndroid Build Coastguard Worker 
getRecordOffset(const std::string & name)567*da0073e9SAndroid Build Coastguard Worker size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
568*da0073e9SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> guard(reader_lock_);
569*da0073e9SAndroid Build Coastguard Worker   mz_zip_archive_file_stat stat;
570*da0073e9SAndroid Build Coastguard Worker   mz_zip_reader_file_stat(ar_.get(), getRecordID(name), &stat);
571*da0073e9SAndroid Build Coastguard Worker   valid("retrieving file meta-data for ", name.c_str());
572*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
573*da0073e9SAndroid Build Coastguard Worker   uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
574*da0073e9SAndroid Build Coastguard Worker   in_->read(
575*da0073e9SAndroid Build Coastguard Worker       stat.m_local_header_ofs,
576*da0073e9SAndroid Build Coastguard Worker       local_header,
577*da0073e9SAndroid Build Coastguard Worker       MZ_ZIP_LOCAL_DIR_HEADER_SIZE,
578*da0073e9SAndroid Build Coastguard Worker       "reading file header");
579*da0073e9SAndroid Build Coastguard Worker   size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
580*da0073e9SAndroid Build Coastguard Worker   size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
581*da0073e9SAndroid Build Coastguard Worker   return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
582*da0073e9SAndroid Build Coastguard Worker }
583*da0073e9SAndroid Build Coastguard Worker 
getRecordSize(const std::string & name)584*da0073e9SAndroid Build Coastguard Worker size_t PyTorchStreamReader::getRecordSize(const std::string& name) {
585*da0073e9SAndroid Build Coastguard Worker   mz_zip_archive_file_stat stat;
586*da0073e9SAndroid Build Coastguard Worker   mz_zip_reader_file_stat(ar_.get(), getRecordID(name), &stat);
587*da0073e9SAndroid Build Coastguard Worker   return stat.m_uncomp_size;
588*da0073e9SAndroid Build Coastguard Worker }
589*da0073e9SAndroid Build Coastguard Worker 
~PyTorchStreamReader()590*da0073e9SAndroid Build Coastguard Worker PyTorchStreamReader::~PyTorchStreamReader() {
591*da0073e9SAndroid Build Coastguard Worker   mz_zip_clear_last_error(ar_.get());
592*da0073e9SAndroid Build Coastguard Worker   mz_zip_reader_end(ar_.get());
593*da0073e9SAndroid Build Coastguard Worker   valid("closing reader for archive ", archive_name_.c_str());
594*da0073e9SAndroid Build Coastguard Worker }
595*da0073e9SAndroid Build Coastguard Worker 
ostream_write_func(void * pOpaque,mz_uint64 file_ofs,const void * pBuf,size_t n)596*da0073e9SAndroid Build Coastguard Worker size_t ostream_write_func(
597*da0073e9SAndroid Build Coastguard Worker     void* pOpaque,
598*da0073e9SAndroid Build Coastguard Worker     mz_uint64 file_ofs,
599*da0073e9SAndroid Build Coastguard Worker     const void* pBuf,
600*da0073e9SAndroid Build Coastguard Worker     size_t n) {
601*da0073e9SAndroid Build Coastguard Worker   auto self = static_cast<PyTorchStreamWriter*>(pOpaque);
602*da0073e9SAndroid Build Coastguard Worker   if (self->current_pos_ != file_ofs) {
603*da0073e9SAndroid Build Coastguard Worker     CAFFE_THROW("unexpected pos ", self->current_pos_, " vs ", file_ofs);
604*da0073e9SAndroid Build Coastguard Worker   }
605*da0073e9SAndroid Build Coastguard Worker   size_t ret = self->writer_func_(pBuf, n);
606*da0073e9SAndroid Build Coastguard Worker   if (n != ret) {
607*da0073e9SAndroid Build Coastguard Worker     self->err_seen_ = true;
608*da0073e9SAndroid Build Coastguard Worker   }
609*da0073e9SAndroid Build Coastguard Worker   self->current_pos_ += ret;
610*da0073e9SAndroid Build Coastguard Worker 
611*da0073e9SAndroid Build Coastguard Worker   // Get the CRC32 of uncompressed data from the data descriptor, if the written
612*da0073e9SAndroid Build Coastguard Worker   // data is identified as the data descriptor block.
613*da0073e9SAndroid Build Coastguard Worker   // See [Note: write_record_metadata] for why we check for non-null pBuf here
614*da0073e9SAndroid Build Coastguard Worker   if (pBuf && n >= 8 && MZ_READ_LE32(pBuf) == MZ_ZIP_DATA_DESCRIPTOR_ID) {
615*da0073e9SAndroid Build Coastguard Worker     const int8_t* pInt8Buf = (const int8_t*)pBuf;
616*da0073e9SAndroid Build Coastguard Worker     const uint32_t uncomp_crc32 = MZ_READ_LE32(pInt8Buf + 4);
617*da0073e9SAndroid Build Coastguard Worker     self->combined_uncomp_crc32_ =
618*da0073e9SAndroid Build Coastguard Worker         c10::hash_combine(self->combined_uncomp_crc32_, uncomp_crc32);
619*da0073e9SAndroid Build Coastguard Worker   }
620*da0073e9SAndroid Build Coastguard Worker 
621*da0073e9SAndroid Build Coastguard Worker   return ret;
622*da0073e9SAndroid Build Coastguard Worker }
623*da0073e9SAndroid Build Coastguard Worker 
PyTorchStreamWriter(const std::string & file_name)624*da0073e9SAndroid Build Coastguard Worker PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name)
625*da0073e9SAndroid Build Coastguard Worker     : archive_name_(basename(file_name)) {
626*da0073e9SAndroid Build Coastguard Worker   setup(file_name);
627*da0073e9SAndroid Build Coastguard Worker }
628*da0073e9SAndroid Build Coastguard Worker 
PyTorchStreamWriter(const std::function<size_t (const void *,size_t)> writer_func)629*da0073e9SAndroid Build Coastguard Worker PyTorchStreamWriter::PyTorchStreamWriter(
630*da0073e9SAndroid Build Coastguard Worker     const std::function<size_t(const void*, size_t)> writer_func)
631*da0073e9SAndroid Build Coastguard Worker     : archive_name_("archive"),
632*da0073e9SAndroid Build Coastguard Worker       writer_func_(writer_func) {
633*da0073e9SAndroid Build Coastguard Worker   setup(archive_name_);
634*da0073e9SAndroid Build Coastguard Worker }
635*da0073e9SAndroid Build Coastguard Worker 
setup(const string & file_name)636*da0073e9SAndroid Build Coastguard Worker void PyTorchStreamWriter::setup(const string& file_name) {
637*da0073e9SAndroid Build Coastguard Worker   ar_ = std::make_unique<mz_zip_archive>();
638*da0073e9SAndroid Build Coastguard Worker   memset(ar_.get(), 0, sizeof(mz_zip_archive));
639*da0073e9SAndroid Build Coastguard Worker   archive_name_plus_slash_ = archive_name_ + "/"; // for writeRecord().
640*da0073e9SAndroid Build Coastguard Worker 
641*da0073e9SAndroid Build Coastguard Worker   if (archive_name_.size() == 0) {
642*da0073e9SAndroid Build Coastguard Worker     CAFFE_THROW("invalid file name: ", file_name);
643*da0073e9SAndroid Build Coastguard Worker   }
644*da0073e9SAndroid Build Coastguard Worker   if (!writer_func_) {
645*da0073e9SAndroid Build Coastguard Worker     file_stream_.open(
646*da0073e9SAndroid Build Coastguard Worker         file_name,
647*da0073e9SAndroid Build Coastguard Worker         std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
648*da0073e9SAndroid Build Coastguard Worker     valid("opening archive ", file_name.c_str());
649*da0073e9SAndroid Build Coastguard Worker 
650*da0073e9SAndroid Build Coastguard Worker     const std::string dir_name = parentdir(file_name);
651*da0073e9SAndroid Build Coastguard Worker     if(!dir_name.empty()) {
652*da0073e9SAndroid Build Coastguard Worker       struct stat st;
653*da0073e9SAndroid Build Coastguard Worker       bool dir_exists = (stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR));
654*da0073e9SAndroid Build Coastguard Worker       TORCH_CHECK(dir_exists, "Parent directory ", dir_name, " does not exist.");
655*da0073e9SAndroid Build Coastguard Worker     }
656*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened.");
657*da0073e9SAndroid Build Coastguard Worker     writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
658*da0073e9SAndroid Build Coastguard Worker       if (!buf) {
659*da0073e9SAndroid Build Coastguard Worker         // See [Note: write_record_metadata]
660*da0073e9SAndroid Build Coastguard Worker         file_stream_.seekp(nbytes, std::ios_base::cur);
661*da0073e9SAndroid Build Coastguard Worker       } else {
662*da0073e9SAndroid Build Coastguard Worker         file_stream_.write(static_cast<const char*>(buf), nbytes);
663*da0073e9SAndroid Build Coastguard Worker       }
664*da0073e9SAndroid Build Coastguard Worker       return !file_stream_ ? 0 : nbytes;
665*da0073e9SAndroid Build Coastguard Worker     };
666*da0073e9SAndroid Build Coastguard Worker   }
667*da0073e9SAndroid Build Coastguard Worker 
668*da0073e9SAndroid Build Coastguard Worker   ar_->m_pIO_opaque = this;
669*da0073e9SAndroid Build Coastguard Worker   ar_->m_pWrite = ostream_write_func;
670*da0073e9SAndroid Build Coastguard Worker 
671*da0073e9SAndroid Build Coastguard Worker   mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
672*da0073e9SAndroid Build Coastguard Worker   valid("initializing archive ", file_name.c_str());
673*da0073e9SAndroid Build Coastguard Worker }
674*da0073e9SAndroid Build Coastguard Worker 
setMinVersion(const uint64_t version)675*da0073e9SAndroid Build Coastguard Worker void PyTorchStreamWriter::setMinVersion(const uint64_t version) {
676*da0073e9SAndroid Build Coastguard Worker   version_ = std::max(version, version_);
677*da0073e9SAndroid Build Coastguard Worker }
678*da0073e9SAndroid Build Coastguard Worker 
writeRecord(const std::string & name,const void * data,size_t size,bool compress)679*da0073e9SAndroid Build Coastguard Worker void PyTorchStreamWriter::writeRecord(
680*da0073e9SAndroid Build Coastguard Worker     const std::string& name,
681*da0073e9SAndroid Build Coastguard Worker     const void* data,
682*da0073e9SAndroid Build Coastguard Worker     size_t size,
683*da0073e9SAndroid Build Coastguard Worker     bool compress) {
684*da0073e9SAndroid Build Coastguard Worker   AT_ASSERT(!finalized_);
685*da0073e9SAndroid Build Coastguard Worker   AT_ASSERT(!archive_name_plus_slash_.empty());
686*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(
687*da0073e9SAndroid Build Coastguard Worker       files_written_.count(name) == 0, "Tried to serialize file twice: ", name);
688*da0073e9SAndroid Build Coastguard Worker   if (name == kSerializationIdRecordName && serialization_id_.empty()) {
689*da0073e9SAndroid Build Coastguard Worker     // In case of copying records from another file, skip writing a different
690*da0073e9SAndroid Build Coastguard Worker     // serialization_id than the one computed in this writer.
691*da0073e9SAndroid Build Coastguard Worker     // This is to ensure serialization_id is unique per serialization output.
692*da0073e9SAndroid Build Coastguard Worker     return;
693*da0073e9SAndroid Build Coastguard Worker   }
694*da0073e9SAndroid Build Coastguard Worker   std::string full_name = archive_name_plus_slash_ + name;
695*da0073e9SAndroid Build Coastguard Worker   size_t padding_size =
696*da0073e9SAndroid Build Coastguard Worker       detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
697*da0073e9SAndroid Build Coastguard Worker   uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
698*da0073e9SAndroid Build Coastguard Worker   mz_zip_writer_add_mem_ex_v2(
699*da0073e9SAndroid Build Coastguard Worker       /*pZip=*/ar_.get(),
700*da0073e9SAndroid Build Coastguard Worker       /*pArchive_name=*/full_name.c_str(),
701*da0073e9SAndroid Build Coastguard Worker       /*pBuf=*/data,
702*da0073e9SAndroid Build Coastguard Worker       /*buf_size=*/size,
703*da0073e9SAndroid Build Coastguard Worker       /*pComment=*/nullptr,
704*da0073e9SAndroid Build Coastguard Worker       /*comment_size=*/0,
705*da0073e9SAndroid Build Coastguard Worker       /*level_and_flags=*/flags,
706*da0073e9SAndroid Build Coastguard Worker       /*uncomp_size=*/0,
707*da0073e9SAndroid Build Coastguard Worker       /*uncomp_crc32=*/0,
708*da0073e9SAndroid Build Coastguard Worker       /*last_modified=*/nullptr,
709*da0073e9SAndroid Build Coastguard Worker       /*user_extra_data=*/padding_.c_str(),
710*da0073e9SAndroid Build Coastguard Worker       /*user_extra_data_len=*/padding_size,
711*da0073e9SAndroid Build Coastguard Worker       /*user_extra_data_central=*/nullptr,
712*da0073e9SAndroid Build Coastguard Worker       /*user_extra_data_central_len=*/0);
713*da0073e9SAndroid Build Coastguard Worker   valid("writing file ", name.c_str());
714*da0073e9SAndroid Build Coastguard Worker   files_written_.insert(name);
715*da0073e9SAndroid Build Coastguard Worker }
716*da0073e9SAndroid Build Coastguard Worker 
writeEndOfFile()717*da0073e9SAndroid Build Coastguard Worker void PyTorchStreamWriter::writeEndOfFile() {
718*da0073e9SAndroid Build Coastguard Worker   // Ensurers that finalized is set to true even
719*da0073e9SAndroid Build Coastguard Worker   // exception is raised during the method call.
720*da0073e9SAndroid Build Coastguard Worker   // I.e. even partial call to writeEndOfFile() should mark
721*da0073e9SAndroid Build Coastguard Worker   // file as finalized, otherwise double exception raised from
722*da0073e9SAndroid Build Coastguard Worker   // destructor would would result in `std::terminate()`
723*da0073e9SAndroid Build Coastguard Worker   // See https://github.com/pytorch/pytorch/issues/87997/
724*da0073e9SAndroid Build Coastguard Worker   struct Finalizer {
725*da0073e9SAndroid Build Coastguard Worker     Finalizer(bool& var): var_(var) {}
726*da0073e9SAndroid Build Coastguard Worker     ~Finalizer() {
727*da0073e9SAndroid Build Coastguard Worker       var_ = true;
728*da0073e9SAndroid Build Coastguard Worker     }
729*da0073e9SAndroid Build Coastguard Worker    private:
730*da0073e9SAndroid Build Coastguard Worker     bool& var_;
731*da0073e9SAndroid Build Coastguard Worker   } f(finalized_);
732*da0073e9SAndroid Build Coastguard Worker 
733*da0073e9SAndroid Build Coastguard Worker   auto allRecords = getAllWrittenRecords();
734*da0073e9SAndroid Build Coastguard Worker   // If no ".data/version" or "version" record in the output model, rewrites version info
735*da0073e9SAndroid Build Coastguard Worker   if(allRecords.find(".data/version") == allRecords.end() && allRecords.find("version") == allRecords.end()) {
736*da0073e9SAndroid Build Coastguard Worker     std::string version = std::to_string(version_);
737*da0073e9SAndroid Build Coastguard Worker     version.push_back('\n');
738*da0073e9SAndroid Build Coastguard Worker     if (version_ >= 0x6L) {
739*da0073e9SAndroid Build Coastguard Worker       writeRecord(".data/version", version.c_str(), version.size());
740*da0073e9SAndroid Build Coastguard Worker     } else {
741*da0073e9SAndroid Build Coastguard Worker       writeRecord("version", version.c_str(), version.size());
742*da0073e9SAndroid Build Coastguard Worker     }
743*da0073e9SAndroid Build Coastguard Worker   }
744*da0073e9SAndroid Build Coastguard Worker 
745*da0073e9SAndroid Build Coastguard Worker   // If no "byteorder" record in the output model, rewrites byteorder info
746*da0073e9SAndroid Build Coastguard Worker   if(allRecords.find("byteorder") == allRecords.end()) {
747*da0073e9SAndroid Build Coastguard Worker #if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
748*da0073e9SAndroid Build Coastguard Worker     std::string byteorder = "little";
749*da0073e9SAndroid Build Coastguard Worker #elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
750*da0073e9SAndroid Build Coastguard Worker     std::string byteorder = "big";
751*da0073e9SAndroid Build Coastguard Worker #else
752*da0073e9SAndroid Build Coastguard Worker #error Unexpected or undefined __BYTE_ORDER__
753*da0073e9SAndroid Build Coastguard Worker #endif
754*da0073e9SAndroid Build Coastguard Worker     writeRecord("byteorder", byteorder.c_str(), byteorder.size());
755*da0073e9SAndroid Build Coastguard Worker   }
756*da0073e9SAndroid Build Coastguard Worker 
757*da0073e9SAndroid Build Coastguard Worker   writeSerializationId();
758*da0073e9SAndroid Build Coastguard Worker 
759*da0073e9SAndroid Build Coastguard Worker   AT_ASSERT(!finalized_);
760*da0073e9SAndroid Build Coastguard Worker   finalized_ = true;
761*da0073e9SAndroid Build Coastguard Worker 
762*da0073e9SAndroid Build Coastguard Worker   mz_zip_writer_finalize_archive(ar_.get());
763*da0073e9SAndroid Build Coastguard Worker   mz_zip_writer_end(ar_.get());
764*da0073e9SAndroid Build Coastguard Worker   valid("writing central directory for archive ", archive_name_.c_str());
765*da0073e9SAndroid Build Coastguard Worker   c10::LogAPIUsageMetadata(
766*da0073e9SAndroid Build Coastguard Worker       "pytorch.stream.writer.metadata",
767*da0073e9SAndroid Build Coastguard Worker       {{"serialization_id", serialization_id_},
768*da0073e9SAndroid Build Coastguard Worker        {"file_name", archive_name_},
769*da0073e9SAndroid Build Coastguard Worker        {"file_size", str(mz_zip_get_archive_size(ar_.get()))}});
770*da0073e9SAndroid Build Coastguard Worker   if (file_stream_.is_open()) {
771*da0073e9SAndroid Build Coastguard Worker     file_stream_.close();
772*da0073e9SAndroid Build Coastguard Worker   }
773*da0073e9SAndroid Build Coastguard Worker }
774*da0073e9SAndroid Build Coastguard Worker 
valid(const char * what,const char * info)775*da0073e9SAndroid Build Coastguard Worker void PyTorchStreamWriter::valid(const char* what, const char* info) {
776*da0073e9SAndroid Build Coastguard Worker   auto err = mz_zip_get_last_error(ar_.get());
777*da0073e9SAndroid Build Coastguard Worker   if (err != MZ_ZIP_NO_ERROR) {
778*da0073e9SAndroid Build Coastguard Worker     CAFFE_THROW(
779*da0073e9SAndroid Build Coastguard Worker         "PytorchStreamWriter failed ",
780*da0073e9SAndroid Build Coastguard Worker         what,
781*da0073e9SAndroid Build Coastguard Worker         info,
782*da0073e9SAndroid Build Coastguard Worker         ": ",
783*da0073e9SAndroid Build Coastguard Worker         mz_zip_get_error_string(err));
784*da0073e9SAndroid Build Coastguard Worker   }
785*da0073e9SAndroid Build Coastguard Worker   if (err_seen_) {
786*da0073e9SAndroid Build Coastguard Worker     CAFFE_THROW("PytorchStreamWriter failed ", what, info, ".");
787*da0073e9SAndroid Build Coastguard Worker   }
788*da0073e9SAndroid Build Coastguard Worker }
789*da0073e9SAndroid Build Coastguard Worker 
writeSerializationId()790*da0073e9SAndroid Build Coastguard Worker void PyTorchStreamWriter::writeSerializationId() {
791*da0073e9SAndroid Build Coastguard Worker   // Serialization id is computed based on all files written, and is composed of
792*da0073e9SAndroid Build Coastguard Worker   // 1) a combined hash of record name hashes
793*da0073e9SAndroid Build Coastguard Worker   // 2) a combined crc32 of the record uncompressed data
794*da0073e9SAndroid Build Coastguard Worker   // This is best effort to create a fixed-length, unique and deterministic id
795*da0073e9SAndroid Build Coastguard Worker   // for the serialized files without incurring additional computation overhead.
796*da0073e9SAndroid Build Coastguard Worker   if (files_written_.find(kSerializationIdRecordName) == files_written_.end()) {
797*da0073e9SAndroid Build Coastguard Worker     uint64_t combined_record_name_hash = 0;
798*da0073e9SAndroid Build Coastguard Worker     for (const std::string& record_name : files_written_) {
799*da0073e9SAndroid Build Coastguard Worker       size_t record_name_hash = c10::hash<std::string>{}(record_name);
800*da0073e9SAndroid Build Coastguard Worker       combined_record_name_hash =
801*da0073e9SAndroid Build Coastguard Worker           c10::hash_combine(combined_record_name_hash, record_name_hash);
802*da0073e9SAndroid Build Coastguard Worker     }
803*da0073e9SAndroid Build Coastguard Worker     std::ostringstream serialization_id_oss;
804*da0073e9SAndroid Build Coastguard Worker     serialization_id_oss << std::setfill('0') << std::setw(20)
805*da0073e9SAndroid Build Coastguard Worker                          << combined_record_name_hash
806*da0073e9SAndroid Build Coastguard Worker                          << std::setfill('0') << std::setw(20)
807*da0073e9SAndroid Build Coastguard Worker                          << combined_uncomp_crc32_;
808*da0073e9SAndroid Build Coastguard Worker     serialization_id_ = serialization_id_oss.str();
809*da0073e9SAndroid Build Coastguard Worker     writeRecord(
810*da0073e9SAndroid Build Coastguard Worker         kSerializationIdRecordName,
811*da0073e9SAndroid Build Coastguard Worker         serialization_id_.c_str(),
812*da0073e9SAndroid Build Coastguard Worker         serialization_id_.size());
813*da0073e9SAndroid Build Coastguard Worker   }
814*da0073e9SAndroid Build Coastguard Worker }
815*da0073e9SAndroid Build Coastguard Worker 
816*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(bugprone-exception-escape)
~PyTorchStreamWriter()817*da0073e9SAndroid Build Coastguard Worker PyTorchStreamWriter::~PyTorchStreamWriter() {
818*da0073e9SAndroid Build Coastguard Worker   if (!finalized_) {
819*da0073e9SAndroid Build Coastguard Worker     writeEndOfFile();
820*da0073e9SAndroid Build Coastguard Worker   }
821*da0073e9SAndroid Build Coastguard Worker }
822*da0073e9SAndroid Build Coastguard Worker 
823*da0073e9SAndroid Build Coastguard Worker } // namespace serialize
824*da0073e9SAndroid Build Coastguard Worker } // namespace caffe2
825