xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/tensor_bundle/tensor_bundle.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7 http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // A tensor bundle is a set of immutable persistent files storing a set of named
17 // tensors.  It is designed for checkpointing TensorFlow tensors.
18 //
19 // The paths of the managed files share a common prefix; e.g., with the prefix:
20 //   /fs/model/train/ckpt-step/ckpt
21 //
22 // the bundle may contain a metadata file, and sharded data files:
23 //   /fs/model/train/ckpt-step/
24 //       ckpt.index
25 //       ckpt.data-00000-of-00020
26 //       ckpt.data-00001-of-00020
27 //       ...
28 //       ckpt.data-00019-of-00020
29 //
30 // The ".index" file is a string-string immutable table
31 // (tensorflow::table::Table).  Each key is a name of a tensor and its value is
32 // a serialized BundleEntryProto.  Each BundleEntryProto describes the metadata
33 // of a tensor: which of the "data" files contains the content of a tensor, the
34 // offset into that file, checksum, some auxiliary data, etc.
35 //
36 // A tensor bundle can be accessed randomly using a BundleReader.  Usage:
37 //
38 //   BundleReader reader(env, "/fs/model/train/ckpt-step/ckpt");
39 //   reader.Lookup("name", &tensor);
40 //
41 // A tensor bundle can be built using BundleWriter.  Each BundleWriter builds a
42 // single data file bundle.  Multiple bundles can then be merged by
43 // MergeBundles() without reading and writing large chunks of data: it reads
44 // the metadata files and outputs a single merged metadata.  Typical usage:
45 //
46 //   worker 0:
47 //     BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker0-step");
48 //     writer.Add(...);  // Adds the tensors on this worker.
49 //     writer.Finish();  // Flushes.
50 //   worker 1:
51 //     BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker1-step");
52 //     writer.Add(...);
53 //     writer.Finish();
54 //   worker 2:
55 //     MergeBundles(env,
56 //       {"/fs/model/train/ckpt-step/tmp/worker0-step",
57 //        "/fs/model/train/ckpt-step/tmp/worker1-step"},
58 //       "/fs/model/train/ckpt-step/ckpt" /* merged prefix */);
59 //
60 
61 #ifndef TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
62 #define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
63 
64 #include <map>
65 #include <string>
66 #include <unordered_map>
67 
68 #include "absl/algorithm/container.h"
69 #include "absl/container/flat_hash_map.h"
70 #include "absl/functional/function_ref.h"
71 #include "tensorflow/core/framework/tensor.h"
72 #include "tensorflow/core/framework/tensor_shape.h"
73 #include "tensorflow/core/framework/tensor_slice.h"
74 #include "tensorflow/core/lib/core/status.h"
75 #include "tensorflow/core/lib/gtl/array_slice.h"
76 #include "tensorflow/core/lib/io/cache.h"
77 #include "tensorflow/core/lib/io/inputbuffer.h"
78 #include "tensorflow/core/lib/io/table.h"
79 #include "tensorflow/core/platform/cord.h"
80 #include "tensorflow/core/platform/env.h"
81 #include "tensorflow/core/platform/file_system.h"
82 #include "tensorflow/core/platform/macros.h"
83 #include "tensorflow/core/platform/types.h"
84 #include "tensorflow/core/protobuf/tensor_bundle.pb.h"
85 #include "tensorflow/core/util/tensor_bundle/naming.h"
86 #include "tensorflow/core/util/tensor_slice_set.h"
87 
88 namespace tensorflow {
89 
// Forward declaration; the class itself is defined further down in this
// header.
90 class FileOutputBuffer;
91 
92 // Versioning of the tensor bundle format.
93 // Follows the same rules as 3p/tf/core/public/version.h.
94 //
95 // History:
96 // 0. Any tensor bundles produced before this field was added.
97 // 1. Added this field (2016-09-14).
//
// These are declarations only (extern); the values are defined outside this
// header.
98 extern const int kTensorBundleMinProducer;
99 extern const int kTensorBundleMinConsumer;
100 extern const int kTensorBundleVersion;
101 
102 // The empty string, hence always the first key in the metadata table.  Its
103 // corresponding value is a BundleHeaderProto.
// Declared here, defined outside this header.
104 extern const char* const kHeaderEntryKey;
105 
106 // Builds a string-string table of tensor names to BundleEntryProto (metadata).
107 //
108 // On construction, attempts to create a directory given by the dirname of
109 // "prefix", so "status()" must be checked before calling any member functions.
110 //
111 // All threads accessing the same BundleWriter must synchronize.
112 class BundleWriter {
113  public:
  // Tunables for a BundleWriter; passed by const reference to the constructor.
114   struct Options {
OptionsOptions115     Options() {}
116     // Alignment, in bytes, for tensor data.
117     // Must be >= 1. The default size of 1 densely packs tensors.
118     int data_alignment{1};
119   };
  // "env" is not owned by the writer.  "options" defaults to a
  // default-constructed Options (data_alignment == 1).
120   BundleWriter(Env* env, StringPiece prefix,
121                const Options& options = Options());
122 
123   // Adds the tensor "val" under key "key".
124   // Across calls "key" must be unique but can be added in any order.
125   Status Add(StringPiece key, const Tensor& val);
126 
127   // Partitioned variables support.
128   // A slice of a full tensor is stored in two entries in the metadata table:
129   //
130   //   full_tensor_key   -> BundleEntryProto, describing all stored slices
131   //                        of this full tensor.  Does not append to the data
132   //                        file.
133   //   encoded slice key -> BundleEntryProto, describing one particular slice.
134   //                        Appends values of this slice to the data file.
135   //
136   // Slices of a full tensor can be added in any order.
137   //
138   // If a full tensor has slices placed on N devices and N BundleWriters are
139   // concurrently used, the caller must use MergeBundles() to ensure that a
140   // consistent entry for "full_tensor_key" is produced.
141   //
142   // Returns an error if the same slice is added a second time.
143   Status AddSlice(StringPiece full_tensor_key,
144                   const TensorShape& full_tensor_shape,
145                   const TensorSlice& slice_spec, const Tensor& slice_tensor);
146 
147   // Finishes the writer and flushes.
  // The returned Status must not be ignored (TF_MUST_USE_RESULT).
148   Status Finish() TF_MUST_USE_RESULT;
149 
  // Returns a copy of the writer's current status.  Per the class comment,
  // check this after construction and before calling other members.
status()150   Status status() const { return status_; }
151 
152  private:
153   Env* const env_;  // Not owned.
154   const Options options_;
155   const string prefix_;
  // NOTE(review): the two paths below are presumably derived from prefix_ in
  // the .cc (index file and data file respectively) -- confirm there.
156   string metadata_path_;
157   string data_path_;
  // NOTE(review): semantics are defined in the .cc; not visible from here.
158   bool use_temp_file_;
159   std::unique_ptr<FileOutputBuffer> out_;
160   int64_t size_;  // Number of bytes written into out_.
  // Per-key metadata protos.  std::map keeps the keys in sorted order.
161   std::map<string, BundleEntryProto> entries_;
162   Status status_;  // Reported through status().
163 
164   TF_DISALLOW_COPY_AND_ASSIGN(BundleWriter);
165 };
166 
167 // Merges a set of bundles (given their prefixes) into a single bundle with the
168 // given "merged_prefix".  The merged metadata is guaranteed to be consistent.
169 //
170 // If there are N bundles in "prefixes", during the merge the data files will be
171 // renamed to contain a proper sharded file spec, with num_shards set to the sum
172 // of num_shards across the N input bundles.
173 //
174 // The caller should only rely on the metadata file of the merged bundle to
175 // query information about a tensor.  In particular, this function may
176 // re-order the input data files.
177 //
178 // Once merged, makes a best effort to delete the old metadata files.
179 // Returns OK iff all bundles are successfully merged.
180 //
181 // "allow_missing_files": If set to true, merges "prefixes" as long as
182 // at least one file exists. (Defaults to false.)
183 //
184 // Returns an InvalidArgumentError when "allow_missing_files" is set to true
185 // and none of the data files named in "prefixes" exist.
186 //
187 // Returns a NotFoundError when "allow_missing_files" is set to false and
188 // any data file named in "prefixes" does not exist.
189 Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
190                     StringPiece merged_prefix,
191                     bool allow_missing_files = false);
192 
193 // On construction, silently attempts to read the metadata associated with
194 // "prefix".  If the caller intends to call any function afterwards, "status()"
195 // must be checked.
196 // All threads accessing the same BundleReader must synchronize.
197 class BundleReader {
198  public:
199   BundleReader(Env* const env, StringPiece prefix);
200   ~BundleReader();
201 
202   // Is ok() iff the reader construction is successful (completed the read of
203   // the metadata).
status()204   Status status() const { return status_; }
205 
206   // Queries whether the bundle contains an entry keyed by "key".  Calls Seek()
207   // internally, so this call invalidates the reader's current position.
208   // REQUIRES: status().ok()
209   bool Contains(StringPiece key);
210 
211   // Sorts a `container` of tensors to read such that when `Seek(key)` is called
212   // on the elements of the sorted container, the underlying file access is
213   // sequential. Sorting can greatly improve overall read speed.
214   //
215   // `get_key` should be a function that, when passed an element in `container`,
216   // returns the `key` of the tensor.
217   //
  // Defined at the bottom of this header (it is a member template).
  //
218   // REQUIRES: status().ok()
219   template <class T>
220   Status SortForSequentialAccess(std::vector<T>& container,
221                                  absl::FunctionRef<string(const T&)> get_key);
222 
223   // Looks up the dtype and the shape of the tensor keyed by "key".
224   // REQUIRES: status().ok()
225   Status LookupDtypeAndShape(StringPiece key, DataType* dtype,
226                              TensorShape* shape) TF_MUST_USE_RESULT;
227 
228   // Looks up the shape of the tensor keyed by "key".
229   // Clears "shape" if not found.
230   // REQUIRES: status().ok()
231   Status LookupTensorShape(StringPiece key,
232                            TensorShape* shape) TF_MUST_USE_RESULT;
233 
234   // Looks up the tensor keyed by "key".  If "key" refers to a partitioned
235   // tensor, attempts to look up the full contents using all stored slices.
236   //
237   // Caller must make sure "val" has the same shape and dtype as the
238   // corresponding contents, so that its buffer can be filled without needing
239   // extra allocation.  These can be queried via "LookupDtypeAndShape()".
240   //
241   // On error, "val" may contain nonsense data.  Returns a NotFound error if
242   // tensor keyed by "key" does not exist in this bundle.
243   //
244   // Validates the stored crc32c checksum against the restored bytes.
245   // REQUIRES: status().ok()
246   Status Lookup(StringPiece key, Tensor* val) TF_MUST_USE_RESULT;
247 
248   // Looks up the tensor pointed to by the internal iterator.
249   //
250   // On error, "val" may contain nonsense data.
251   //
252   // Validates the stored crc32c checksum against the restored bytes.
253   // REQUIRES: status().ok() && Valid()
254   Status ReadCurrent(Tensor* val) TF_MUST_USE_RESULT;
255 
256   // Looks up the slices of the tensor keyed by "key".  On OK, "slices"
257   // is non-empty if and only if the tensor is a partitioned tensor.
258   //
259   // Warning - there is no guaranteed ordering for the returned slices, so
260   // a slice with a larger start index in some dimension could come before
261   // another slice with a smaller start index in the same dimension.
262   // REQUIRES: status().ok()
263   Status LookupTensorSlices(StringPiece key, std::vector<TensorSlice>* slices)
264       TF_MUST_USE_RESULT;
265 
266   // Looks up a specific slice of a partitioned tensor.
267   // It is only required that the stored slices cover the requested slice,
268   // namely "slice_spec" is a subset of the union of the stored slices.
269   // REQUIRES: status().ok()
270   Status LookupSlice(StringPiece full_tensor_key, const TensorSlice& slice_spec,
271                      Tensor* val) TF_MUST_USE_RESULT;
272 
273   // Seeks to the first position in the bundle whose key is no less than "key".
  // Forwards to the metadata table iterator (iter_).
274   // REQUIRES: status().ok()
Seek(StringPiece key)275   void Seek(StringPiece key) { return iter_->Seek(key); }
276   // Moves to the next position in the bundle.
  // Note: const-qualified even though it advances the iterator; only the
  // object iter_ points to changes, not the pointer member itself.
277   // REQUIRES: status().ok()
Next()278   void Next() const { iter_->Next(); }
279   // Returns true iff the reader is positioned to a key/val pair.
280   // REQUIRES: status().ok()
Valid()281   bool Valid() const { return iter_->Valid(); }
282 
283   // Returns the key at the current position.
284   // REQUIRES: status().ok() && Valid()
key()285   StringPiece key() const { return iter_->key(); }
286   // Returns the raw value at the current position.
287   // REQUIRES: status().ok() && Valid()
value()288   StringPiece value() const { return iter_->value(); }
289 
  // NOTE(review): presumably returns a human-readable dump of the bundle; the
  // format is defined in the .cc and is not visible from this header.
290   string DebugString();
291 
292  private:
293   // Seeks for "key" and reads the metadata proto.
294   // On non-OK return, clears "entry" for the caller.
295   // REQUIRES: status().ok()
296   Status GetBundleEntryProto(StringPiece key,
297                              BundleEntryProto* entry) TF_MUST_USE_RESULT;
298 
299   // Reads the tensor value described by the metadata proto "entry".
300   // Usage for "val" follows the comment of "Lookup()".
301   Status GetValue(const BundleEntryProto& entry,
302                   Tensor* val) TF_MUST_USE_RESULT;
303 
304   // Reads the slice described by "slice_spec".  The corresponding full tensor
305   // has key "full_tensor_key" and metadata proto "full_tensor_entry".
306   // REQUIRES: full_tensor_entry.slices_size() > 0
307   Status GetSliceValue(StringPiece full_tensor_key,
308                        const BundleEntryProto& full_tensor_entry,
309                        const TensorSlice& slice_spec,
310                        Tensor* val) TF_MUST_USE_RESULT;
311 
312   Env* env_;  // Not owned.
313   const string prefix_;
314 
315   Status status_;  // Reported through status().
316   RandomAccessFile* metadata_;  // Owned.
  // NOTE(review): ownership of the three raw pointers below is not stated in
  // this header -- confirm against the destructor in the .cc.
317   table::Table* table_;
318   table::Cache* index_cache_;
319   table::Iterator* iter_;  // Backs Seek()/Next()/Valid()/key()/value().
320   // Owns the InputBuffer objects and their underlying RandomAccessFile's.
  // NOTE(review): the int32 key is presumably the data shard id -- confirm.
321   std::unordered_map<int32, io::InputBuffer*> data_;
322 
323   // Maps each partitioned tensor's key to its stored slices (represented in a
324   // TensorSliceSet).  Populated on-demand.
325   std::unordered_map<string, checkpoint::TensorSliceSet*> tensor_slices_;
326 
327   // Expected number of data file shards in the bundle.  Extracted by reading
328   // the header entry in the metadata table.
329   int num_shards_;
330 
331   // Flag that this class sets to true when the endianness of the target bundle
332   // differs from that of the current system's processor architecture.
333   bool need_to_swap_bytes_;
334 
335   friend class TensorBundleAlignmentTest;  // For testing data alignment.
336 
337   TF_DISALLOW_COPY_AND_ASSIGN(BundleReader);
338 };
339 
340 // A buffering wrapper for a WritableFile.  Useful if the caller wishes to issue
341 // small writes to a file (e.g. writing out a list of small varints).
342 // External synchronization must be used in the presence of concurrent callers.
343 class FileOutputBuffer {
344  public:
  // "file" becomes owned by the buffer (see file_ below).
  // NOTE(review): "buffer_size" is presumably the capacity in bytes of the
  // internal buffer (buffer_size_) -- confirm in the .cc.
345   FileOutputBuffer(WritableFile* file, size_t buffer_size);
346   ~FileOutputBuffer();
347 
348   // Buffered append.
349   Status Append(StringPiece data);
350 
351   // Returns the running crc32c checksum of all currently appended bytes.
crc32c()352   uint32 crc32c() { return crc32c_; }
353   // Clears the running crc32c checksum.
clear_crc32c()354   void clear_crc32c() { crc32c_ = 0; }
355 
356   // Appends the buffered data, then closes the underlying file.
357   Status Close();
358 
359  private:
360   // Appends the buffered data to the underlying file. Does NOT flush the file.
  // NOTE(review): the effect of "closing" is defined in the .cc; presumably it
  // is true only when called from Close() -- confirm.
361   Status FlushBuffer(bool closing);
362 
363   WritableFile* file_;  // Owned.
364 
365   // buffer_ptr_[0, position_) holds the buffered data not yet appended to the
366   // underlying file.
367   size_t position_;
368   const size_t buffer_size_;
369   char* buffer_ptr_;
370 
371   // Checksum of all appended bytes since construction or last clear_crc32c().
372   uint32 crc32c_ = 0;
373 };
374 
375 template <class T>
SortForSequentialAccess(std::vector<T> & container,absl::FunctionRef<string (const T &)> get_key)376 Status BundleReader::SortForSequentialAccess(
377     std::vector<T>& container, absl::FunctionRef<string(const T&)> get_key) {
378   struct FileOffset {
379     int32_t shard_id;
380     int64_t offset;
381   };
382   absl::flat_hash_map<string, FileOffset> file_offsets;
383   for (const T& element : container) {
384     BundleEntryProto entry;
385     TF_RETURN_IF_ERROR(GetBundleEntryProto(get_key(element), &entry));
386     file_offsets[get_key(element)] = {entry.shard_id(), entry.offset()};
387   }
388   absl::c_sort(container, [&get_key, &file_offsets](const T& a, const T& b) {
389     const FileOffset& file_offset_a = file_offsets[get_key(a)];
390     const FileOffset& file_offset_b = file_offsets[get_key(b)];
391     if (file_offset_a.shard_id == file_offset_b.shard_id) {
392       return file_offset_a.offset < file_offset_b.offset;
393     } else {
394       return file_offset_a.shard_id < file_offset_b.shard_id;
395     }
396   });
397   return OkStatus();
398 }
399 
400 }  // namespace tensorflow
401 
402 #endif  // TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
403