/*
 * Copyright 2022 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <queue>
#include <string>
#include <utility>

#include "absl/base/attributes.h"
#include "absl/base/const_init.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/io/table.h"
#include "tensorflow/core/lib/io/table_builder.h"
#include "tensorflow/core/lib/io/table_options.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/saved_tensor_slice.pb.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
#include "tensorflow/core/util/tensor_slice_writer.h"

namespace fcp {
namespace {

using ::tensorflow::OpKernel;
using ::tensorflow::OpKernelConstruction;
using ::tensorflow::OpKernelContext;

constexpr absl::string_view kSavedTensorSlicesKey = "";

// Returns the host-endian byte representation of `value`.
//
// The `value` must be non-null and must continue to be valid as long as the
// return value is used.
absl::string_view Int64ToHostEndianBytes(int64_t* value) {
  return absl::string_view(reinterpret_cast<const char*>(value),
                           sizeof(int64_t));
}

// Returns `value` interpreted as the host-endian bytes of an `int64_t`.
int64_t Int64FromHostEndianBytes(const char value[sizeof(int64_t)]) {
  return *reinterpret_cast<const int64_t*>(value);
}

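// A minimal round-trip sketch of the two helpers above (illustrative only):
//
//   int64_t pos = 42;
//   assert(Int64FromHostEndianBytes(Int64ToHostEndianBytes(&pos).data()) ==
//          pos);
//
// Because the representation is host-endian, footer bytes written via these
// helpers are only portable across machines with the same endianness.
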
// Implementation of the save ops.
//
// This is copied without change from save_restore_tensor.cc because that target
// cannot be included in `tf_custom_op_library` targets due to its dependency
// on `//third_party/tensorflow/core:framework`.
void SaveTensors(
    OpKernelContext* context,
    tensorflow::checkpoint::TensorSliceWriter::CreateBuilderFunction
        builder_func,
    bool save_slices) {
  const tensorflow::Tensor& filename_t = context->input(0);
  {
    const int64_t size = filename_t.NumElements();
    OP_REQUIRES(
        context, size == 1,
        tensorflow::errors::InvalidArgument(
            "Input 0 (filename) must be a string scalar; got a tensor of ",
            size, "elements"));
  }
  const std::string& filename = filename_t.scalar<tensorflow::tstring>()();

  // Path, names, and slices if save_slices is true.
  const int kFixedInputs = save_slices ? 3 : 2;
  const tensorflow::Tensor& tensor_names_t = context->input(1);
  OP_REQUIRES(
      context,
      tensorflow::FastBoundsCheck(tensor_names_t.NumElements() + kFixedInputs,
                                  std::numeric_limits<int>::max()),
      tensorflow::errors::InvalidArgument("Too many inputs to SaveTensors"));
  const int N = static_cast<int>(tensor_names_t.NumElements());
  const tensorflow::tstring* tensor_shapes_and_slices_ptr = nullptr;
  if (save_slices) {
    const tensorflow::Tensor& tensor_shapes_and_slices_t = context->input(2);
    OP_REQUIRES(
        context,
        tensor_shapes_and_slices_t.NumElements() == static_cast<int64_t>(N),
        tensorflow::errors::InvalidArgument(
            "Expected ", N,
            " elements for the tensor "
            "shapes and slices but got ",
            tensor_shapes_and_slices_t.NumElements()));
    tensor_shapes_and_slices_ptr =
        tensor_shapes_and_slices_t.flat<tensorflow::tstring>().data();
  }
  OP_REQUIRES(
      context, context->num_inputs() == N + kFixedInputs,
      tensorflow::errors::InvalidArgument(
          "Expected totally ", N + kFixedInputs,
          " inputs as input #1 (which is a string "
          "tensor of saved names) contains ",
          N, " names, but received ", context->num_inputs(), " inputs"));

  VLOG(1) << "About to save tensors to file " << filename << "...";
  tensorflow::checkpoint::TensorSliceWriter writer(filename,
                                                   std::move(builder_func));

  tensorflow::Status s;
  auto tensor_names_flat = tensor_names_t.flat<tensorflow::tstring>();

  // Process tensors in sorted name order.  This allows us to avoid seeking
  // during restoration in the common case where we are restoring a full
  // checkpoint.
  // RestoreTensorsV2 was changed to sort by file offset, so this sorting isn't
  // strictly necessary anymore. However, restores with TF version <= 2.7 will
  // still benefit.
  std::vector<int> sorted_name_idx(tensor_names_flat.size());
  std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
  std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
            [&tensor_names_flat](size_t a, size_t b) {
              return tensor_names_flat(a) < tensor_names_flat(b);
            });

  for (const int i : sorted_name_idx) {
    const std::string& name = tensor_names_flat(i);
    const tensorflow::Tensor& input = context->input(i + kFixedInputs);
    tensorflow::TensorShape shape(input.shape());
    tensorflow::TensorSlice slice(input.dims());
    if (save_slices && !tensor_shapes_and_slices_ptr[i].empty()) {
      const tensorflow::tstring& shape_spec = tensor_shapes_and_slices_ptr[i];
      tensorflow::TensorShape slice_shape;
      OP_REQUIRES_OK(context, tensorflow::checkpoint::ParseShapeAndSlice(
                                  shape_spec, &shape, &slice, &slice_shape));
      OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()),
                  tensorflow::errors::InvalidArgument(
                      "Slice in shape_and_slice "
                      "specification does not match the "
                      "shape of the tensor to  save: ",
                      shape_spec, ", tensor: ", input.shape().DebugString()));
    }

#define WRITER_ADD(T)                                           \
  case tensorflow::DataTypeToEnum<T>::value:                    \
    s = writer.Add(name, shape, slice, input.flat<T>().data()); \
    break;

    switch (input.dtype()) {
      TF_CALL_SAVE_RESTORE_TYPES(WRITER_ADD)
      default:
        context->SetStatus(tensorflow::errors::Unimplemented(
            "Saving data type ", DataTypeString(input.dtype()),
            " not yet supported"));
        return;
    }
#undef WRITER_ADD
    if (!s.ok()) {
      context->SetStatus(s);
      return;
    }
  }

  s = writer.Finish();
  if (!s.ok()) {
    context->SetStatus(s);
  }
}

// A `WritableFile` that wraps an existing file, appending a chunk followed by
// a footer recording the chunk's start position.
//
// The chunk's start position is stored as a footer since `WritableFile` does
// not allow `Seek`ing to modify an earlier position in the file.
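//
// A rough sketch of the resulting on-disk layout (derived from `Close` below,
// which appends the 8-byte start position before closing): a file with two
// appended chunks looks like
//
//   [chunk 0: table bytes][int64 footer = start offset of chunk 0]
//   [chunk 1: table bytes][int64 footer = start offset of chunk 1]
//
// `LoadAndMergeAppendedSlices` below recovers the chunks by reading these
// footers backwards from the end of the file.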
class AppendedFileWithStartPosFooter : public tensorflow::WritableFile {
 public:
  static tensorflow::Status FromFile(
      std::unique_ptr<tensorflow::WritableFile> file,
      std::unique_ptr<tensorflow::WritableFile>& wrapped_file_out) {
    int64_t body_start;
    TF_RETURN_IF_ERROR(file->Tell(&body_start));
    VLOG(1) << "Appending to checkpoint with starting position " << body_start;
    // Note: cannot use `make_unique` due to private constructor.
    wrapped_file_out = std::unique_ptr<tensorflow::WritableFile>(
        new AppendedFileWithStartPosFooter(std::move(file), body_start));
    return tensorflow::OkStatus();
  }
  tensorflow::Status Append(tensorflow::StringPiece data) override {
    return file_->Append(data);
  }
  tensorflow::Status Close() override {
    TF_RETURN_IF_ERROR(file_->Append(Int64ToHostEndianBytes(&body_start_)));
    return file_->Close();
  }
  tensorflow::Status Flush() override { return file_->Flush(); }
  tensorflow::Status Sync() override { return file_->Sync(); }
  tensorflow::Status Tell(int64_t* position) override {
    int64_t internal_position;
    TF_RETURN_IF_ERROR(file_->Tell(&internal_position));
    *position = internal_position - body_start_;
    return tensorflow::OkStatus();
  }

 private:
  AppendedFileWithStartPosFooter(std::unique_ptr<tensorflow::WritableFile> file,
                                 int64_t body_start)
      : file_(std::move(file)), body_start_(body_start) {}

  std::unique_ptr<tensorflow::WritableFile> file_;
  int64_t body_start_;
};

// An implementation of the `TensorSliceWriter::Builder` interface which
// delegates to `tensorflow::table::TableBuilder`.
class TableBuilder : public tensorflow::checkpoint::TensorSliceWriter::Builder {
 public:
  TableBuilder(std::string name, std::unique_ptr<tensorflow::WritableFile> file)
      : name_(std::move(name)), file_(std::move(file)) {
    tensorflow::table::Options option;
    option.compression = tensorflow::table::kNoCompression;
    builder_ =
        std::make_unique<tensorflow::table::TableBuilder>(option, file_.get());
  }
  void Add(tensorflow::StringPiece key, tensorflow::StringPiece val) override {
    builder_->Add(key, val);
  }
  tensorflow::Status Finish(int64_t* file_size) override {
    *file_size = -1;
    tensorflow::Status s = builder_->Finish();
    if (s.ok()) {
      s = file_->Close();
      if (s.ok()) {
        *file_size = builder_->FileSize();
      }
    }
    if (!s.ok()) {
      s = tensorflow::errors::Internal(
#if TF_GRAPH_DEF_VERSION < 1467
          "Error writing (tmp) checkpoint file: ", name_, ": ",
          s.error_message());
#else
          "Error writing (tmp) checkpoint file: ", name_, ": ", s.message());
#endif
    }
    return s;
  }

 private:
  std::string name_;
  std::unique_ptr<tensorflow::WritableFile> file_;
  std::unique_ptr<tensorflow::table::TableBuilder> builder_;
};

// Creates a new `TensorSliceWriter::Builder` which will append the tensor
// slices to `filename` along with a footer indicating the start position of
// this particular chunk of slices.
//
// If this method returns `OK`, `builder` will contain a new owned pointer to
// a `TensorSliceWriter::Builder`.
tensorflow::Status CreateAppendingTensorSliceBuilder(
    const std::string& filename,
    tensorflow::checkpoint::TensorSliceWriter::Builder** builder) {
  *builder = nullptr;
  if (VLOG_IS_ON(1)) {
    uint64_t file_size = 0;
    if (tensorflow::Env::Default()->GetFileSize(filename, &file_size).ok()) {
      VLOG(1) << "Appending checkpoint to file " << filename << " with size "
              << file_size;
    } else {
      VLOG(1) << "Appending checkpoint to new file " << filename;
    }
  }
  std::unique_ptr<tensorflow::WritableFile> file;
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->NewAppendableFile(filename, &file));
  std::unique_ptr<tensorflow::WritableFile> wrapped_file;
  TF_RETURN_IF_ERROR(
      AppendedFileWithStartPosFooter::FromFile(std::move(file), wrapped_file));
  *builder = new TableBuilder(filename, std::move(wrapped_file));
  return tensorflow::OkStatus();
}

// A `RandomAccessFile` which wraps another `RandomAccessFile`, providing access
// to only a portion of the file.
class PartialRandomAccessFile : public tensorflow::RandomAccessFile {
 public:
  // Constructs a `PartialRandomAccessFile` pointing to a segment of `file`.
  //
  // `file` must be non-null and must continue to be valid as long as the
  // constructed object is used.
  PartialRandomAccessFile(tensorflow::RandomAccessFile* file, int64_t start,
                          int64_t end)
      : file_(file), start_(start), end_(end) {}
  ~PartialRandomAccessFile() override = default;
  tensorflow::Status Read(uint64_t offset, size_t n,
                          tensorflow::StringPiece* result,
                          char* scratch) const override {
    const size_t max_allowable_n = end_ - (start_ + offset);
    bool read_too_long = n > max_allowable_n;
    if (read_too_long) {
      n = max_allowable_n;
    }
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        file_->Read(offset + start_, n, result, scratch),
        absl::StrCat("Reading from PartialRandomAccessFile at offset ", offset,
                     " from start position ", start_));
    if (read_too_long) {
      return tensorflow::Status(
          static_cast<tensorflow::errors::Code>(absl::StatusCode::kOutOfRange),
          "Attempted to read past end of file chunk.");
    }
    return tensorflow::OkStatus();
  }

 private:
  tensorflow::RandomAccessFile* file_;
  int64_t start_;
  int64_t end_;
};

struct TableIteratorComparator {
  // Returns whether `i1` should come after `i2` in the priority queue.
  // That is, whether `i1` has *lower* priority than `i2`.
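  //
  // For example, given iterators positioned at keys {"b", "" (the metadata
  // key), <exhausted>}, the metadata iterator has the highest priority, "b"
  // comes next, and exhausted iterators sort last.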
  bool operator()(const std::unique_ptr<tensorflow::table::Iterator>& i1,
                  const std::unique_ptr<tensorflow::table::Iterator>& i2) {
    // Ensure that iterators which have no remaining elements go last in the
    // list.
    if (!i2->Valid()) {
      return false;
    }
    if (!i1->Valid()) {
      return true;
    }
    if ((i2->key() == kSavedTensorSlicesKey) &&
        (i1->key() != kSavedTensorSlicesKey)) {
      return true;
    }
    return i1->key() > i2->key();
  }
};

// Pops and returns the top element of a `std::priority_queue`.
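//
// Note that `top()` only exposes a const reference, so the element is moved
// out via `const_cast`; this is safe because `pop()` removes the (moved-from)
// element immediately afterwards.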
template <class Element, class Container, class Comparator>
Element PopWithElement(
    std::priority_queue<Element, Container, Comparator>& queue) {
  Element e = std::move(const_cast<Element&>(queue.top()));
  queue.pop();
  return e;
}

// Parses `serialized` into a `SavedTensorSlices` stored in `meta_out`.
tensorflow::Status MetadataFromString(absl::string_view serialized,
                                      tensorflow::SavedTensorSlices& meta_out) {
  // NOTE: The conversion to `std::string` is unfortunately necessary here
  // because the OSS version of `ParseFromString` takes a `const std::string&`
  // rather than an `absl::string_view`.
  if (!meta_out.ParseFromString(std::string(serialized))) {
    return tensorflow::Status(
        static_cast<tensorflow::errors::Code>(absl::StatusCode::kInternal),
        absl::StrCat("Failed to parse table entry as `SavedTensorSlices`: ",
                     serialized));
  }
  return tensorflow::OkStatus();
}

// Merges appended checkpoints in `filename` into a single checkpoint.
//
// Note: this function accepts `filename` as a `const std::string&` rather than
// `string_view` because that is the type accepted by the functions it calls
// (`GetFileSize` and `NewRandomAccessFile`). This avoids unnecessary
// allocation.
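//
// At a high level the merge proceeds roughly as follows:
//   1. Walk the chunk footers backwards from the end of the file, opening each
//      chunk as a `tensorflow::table::Table` and collecting its metadata entry
//      into a single merged `SavedTensorSlices` proto.
//   2. Delete the original file and create a fresh `TensorSliceWriter::Builder`
//      writing to the same filename.
//   3. Emit the merged metadata entry, then emit the remaining data entries
//      from all chunks in sorted key order using a priority queue of chunk
//      iterators.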
tensorflow::Status LoadAndMergeAppendedSlices(const std::string& filename) {
  tensorflow::Env* env = tensorflow::Env::Default();
  uint64_t file_size;
  TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
  // Short-circuit on empty files so that we can assume at least a single entry
  // below.
  if (file_size == 0) {
    return tensorflow::OkStatus();
  }
  std::unique_ptr<tensorflow::RandomAccessFile> file;
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));

  // Overwrite the underlying file, relying on `file` above to provide a handle
  // into the old file contents even after it is overwritten.
  TF_RETURN_IF_ERROR(tensorflow::Env::Default()->DeleteFile(filename));

  // `chunk_files` and `chunk_tables` must be kept around since they are
  // referenced internally by `chunk_iterators`.
  std::vector<std::unique_ptr<tensorflow::RandomAccessFile>> chunk_files;
  std::vector<std::unique_ptr<tensorflow::table::Table>> chunk_tables;
  std::priority_queue<std::unique_ptr<tensorflow::table::Iterator>,
                      std::vector<std::unique_ptr<tensorflow::table::Iterator>>,
                      TableIteratorComparator>
      chunk_iterators;

  tensorflow::SavedTensorSlices merged_sts;
  tensorflow::SavedTensorSliceMeta* merged_meta = merged_sts.mutable_meta();
  std::set<std::string> slices_added;

  // Read all of the chunks into tables.
  int64_t chunk_footer_end = file_size;
  bool version_was_set = false;
  while (chunk_footer_end > 0) {
    // Read in the footer telling us where the chunk started.
    char footer_scratch[sizeof(int64_t)];
    tensorflow::StringPiece chunk_footer;
    TF_RETURN_IF_ERROR(file->Read(chunk_footer_end - sizeof(int64_t),
                                  sizeof(int64_t), &chunk_footer,
                                  footer_scratch));
    int64_t chunk_start = Int64FromHostEndianBytes(chunk_footer.data());
    int64_t chunk_end = chunk_footer_end - sizeof(int64_t);
    int64_t chunk_len = chunk_end - chunk_start;
    std::unique_ptr<tensorflow::RandomAccessFile> chunk_file =
        std::make_unique<PartialRandomAccessFile>(file.get(), chunk_start,
                                                  chunk_end);
    tensorflow::table::Options options;
    tensorflow::table::Table* raw_table;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        tensorflow::table::Table::Open(options, chunk_file.get(), chunk_len,
                                       &raw_table),
        absl::StrCat("Error opening sub-table of file ", filename,
                     " starting at ", chunk_start, " and ending at ", chunk_end,
                     ". Total file size: ", file_size));
    std::unique_ptr<tensorflow::table::Table> table(raw_table);
    tensorflow::table::Iterator* raw_iterator = table->NewIterator();
    std::unique_ptr<tensorflow::table::Iterator> iterator(raw_iterator);
    iterator->SeekToFirst();
450       return tensorflow::Status(
451           static_cast<tensorflow::errors::Code>(absl::StatusCode::kInternal),
452           "Unexpected immediately-invalid iterator. "
453           "Expected table to iterator to have at least a "
454           "single entry (metadata)");
455     }
456     if (iterator->key() != kSavedTensorSlicesKey) {
457       return tensorflow::Status(
458           static_cast<tensorflow::errors::Code>(absl::StatusCode::kInternal),
459           absl::StrCat("Expected table iterator to have an initial metadata "
460                        "entry with key `",
461                        kSavedTensorSlicesKey, "`, found key `", iterator->key(),
462                        "`"));
463     }
464     tensorflow::SavedTensorSlices sts;
465     TF_RETURN_IF_ERROR(MetadataFromString(iterator->value(), sts));
466     iterator->Next();
467     if (!version_was_set) {
468       version_was_set = true;
469       *merged_meta->mutable_versions() = sts.meta().versions();
470     }
471     for (const tensorflow::SavedSliceMeta& slice_meta : sts.meta().tensor()) {
472       if (slices_added.find(slice_meta.name()) != slices_added.end()) {
473         return tensorflow::Status(
474             // Remove the cast after TF 2.12 is released and used in FCP.
475             static_cast<tensorflow::errors::Code>(
476                 absl::StatusCode::kInvalidArgument),
477             absl::StrCat(
478                 "Attempted to merge two checkpoint entries for slice name: `",
479                 slice_meta.name(), "`. Only one entry per name is permitted."));
480       }
481       slices_added.insert(slice_meta.name());
482     }
483     merged_meta->mutable_tensor()->MergeFrom(sts.meta().tensor());
484     chunk_iterators.push(std::move(iterator));
485     chunk_files.push_back(std::move(chunk_file));
486     chunk_tables.push_back(std::move(table));
487     chunk_footer_end = chunk_start;
488   }
489   VLOG(1) << "Merging " << chunk_files.size() << " checkpoint chunks from file "
490           << filename;
491 
492   tensorflow::checkpoint::TensorSliceWriter::Builder* raw_builder;
493   TF_RETURN_IF_ERROR(tensorflow::checkpoint::CreateTableTensorSliceBuilder(
494       filename, &raw_builder));
495   std::unique_ptr<tensorflow::checkpoint::TensorSliceWriter::Builder> builder(
496       raw_builder);
497 
498   // First, we add the merged entry which holds a `SavedTensorSlices` proto.
499   builder->Add(kSavedTensorSlicesKey, merged_sts.SerializeAsString());
500 
501   // Then the remaining entries are concatenated alphabetically.
502   while (chunk_iterators.top()->Valid()) {
503     std::unique_ptr<tensorflow::table::Iterator> iter =
504         PopWithElement(chunk_iterators);
505     VLOG(2) << "Merging table entry for key " << iter->key();
506     builder->Add(iter->key(), iter->value());
507     iter->Next();
508     chunk_iterators.push(std::move(iter));
509   }
510   int64_t resulting_file_size;
511   TF_RETURN_WITH_CONTEXT_IF_ERROR(builder->Finish(&resulting_file_size),
512                                   "Finishing TensorSliceWriter::Builder");
513   return tensorflow::OkStatus();
514 }
515 
516 ABSL_CONST_INIT absl::Mutex append_mutex(absl::kConstInit);
517 
518 }  // namespace
519 
520 class AppendSlicesOp : public OpKernel {
521  public:
AppendSlicesOp(OpKernelConstruction * context)522   explicit AppendSlicesOp(OpKernelConstruction* context) : OpKernel(context) {}
523 
Compute(OpKernelContext * context)524   void Compute(OpKernelContext* context) override {
525     absl::MutexLock lock(&append_mutex);
526     const tensorflow::Tensor& filename_t = context->input(0);
527     tensorflow::tstring filename = filename_t.flat<tensorflow::tstring>()(0);
528     SaveTensors(
529         context,
530         [context](
531             const std::string& target_filename,
532             tensorflow::checkpoint::TensorSliceWriter::Builder** builder) {
533           // `TensorSliceWriter` targets writing to a new temporary file which
534           // it then moves into the location of the final file once complete.
535           // In order to comply with this behavior while still retaining
536           // "append" semantics, the original file (if it exists) is first moved
537           // into the temporary target location.
538           tensorflow::tstring original_filename =
539               context->input(0).scalar<tensorflow::tstring>()();
540           tensorflow::Status status = tensorflow::Env::Default()->RenameFile(
541               original_filename, target_filename);
542           if (status.ok()) {
543             VLOG(1) << "Appending to existing file " << original_filename
544                     << " via move to temporary location " << target_filename;
545           } else if (status.code() == tensorflow::error::NOT_FOUND) {
546             VLOG(1) << "Appending to new file " << original_filename
547                     << " in temporary location " << target_filename;
548           } else {
549             return status;
550           }
551           return CreateAppendingTensorSliceBuilder(target_filename, builder);
552         },
553         /*save_slices=*/true);
554   }
555 };
556 
557 class MergeAppendedSlicesOp : public OpKernel {
558  public:
MergeAppendedSlicesOp(OpKernelConstruction * context)559   explicit MergeAppendedSlicesOp(OpKernelConstruction* context)
560       : OpKernel(context) {}
561 
Compute(OpKernelContext * context)562   void Compute(OpKernelContext* context) override {
563     absl::MutexLock lock(&append_mutex);
564     const tensorflow::Tensor* filename_tensor;
565     OP_REQUIRES_OK(context, context->input("filename", &filename_tensor));
566     const tensorflow::tstring filename =
567         filename_tensor->scalar<tensorflow::tstring>()();
568     OP_REQUIRES_OK(context, LoadAndMergeAppendedSlices(filename));
569   }
570 };
571 
572 // Note: `key` *must* come last so that the indices of the other arguments are
573 // as expected by `SaveTensors`.
574 REGISTER_OP("AppendSlices")
575     .Input("filename: string")
576     .Input("tensor_names: string")
577     .Input("shapes_and_slices: string")
578     .Input("data: T")
579     .Attr("T: list(type)")
580     .SetIsStateful();
581 
582 REGISTER_KERNEL_BUILDER(Name("AppendSlices").Device(tensorflow::DEVICE_CPU),
583                         AppendSlicesOp);
584 
585 REGISTER_OP("MergeAppendedSlices").Input("filename: string").SetIsStateful();
586 
587 REGISTER_KERNEL_BUILDER(
588     Name("MergeAppendedSlices").Device(tensorflow::DEVICE_CPU),
589     MergeAppendedSlicesOp);
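
// Typical usage (a sketch based on the op semantics above, not something this
// file enforces): run `AppendSlices` once per partial checkpoint against the
// same `filename` input, then run `MergeAppendedSlices` on that filename
// exactly once to rewrite it as a single tensor-slice checkpoint that standard
// checkpoint readers can load.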

}  // namespace fcp