1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <queue>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/base/const_init.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/lib/io/table.h"
#include "tensorflow/core/lib/io/table_builder.h"
#include "tensorflow/core/lib/io/table_options.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/saved_tensor_slice.pb.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
#include "tensorflow/core/util/tensor_slice_writer.h"
50
51 namespace fcp {
52 namespace {
53
54 using ::tensorflow::OpKernel;
55 using ::tensorflow::OpKernelConstruction;
56 using ::tensorflow::OpKernelContext;
57
// Key of the table entry holding `SavedTensorSlices` metadata. The empty
// string sorts before every other key, so this entry is always the first one
// encountered when iterating a chunk's table.
constexpr absl::string_view kSavedTensorSlicesKey = "";
59
60 // Returns the host-endian byte representation of `value`.
61 //
62 // The `value` must be non-null and must continue to be valid as long as the
63 // return value is used.
Int64ToHostEndianBytes(int64_t * value)64 absl::string_view Int64ToHostEndianBytes(int64_t* value) {
65 return absl::string_view(reinterpret_cast<const char*>(value),
66 sizeof(int64_t));
67 }
68
// Returns `value` interpreted as the host-endian bytes of an `int64_t`.
//
// Uses `memcpy` instead of `*reinterpret_cast<const int64_t*>(value)`: the
// bytes typically come from a file-read buffer, so the load through a cast
// pointer would be undefined behavior on both alignment and strict-aliasing
// grounds. Compilers lower this `memcpy` to a single load.
int64_t Int64FromHostEndianBytes(const char value[sizeof(int64_t)]) {
  int64_t result;
  std::memcpy(&result, value, sizeof(int64_t));
  return result;
}
73
// Implementation of the save ops.
//
// This is copied without change from save_restore_tensor.cc because that target
// cannot be included in `tf_custom_op_library` targets due to its dependency
// on `//third_party/tensorflow/core:framework`.
//
// Input layout: input 0 is the filename (string scalar), input 1 the tensor
// names; when `save_slices` is true input 2 holds shape-and-slice specs; all
// remaining inputs are the tensors to save. `builder_func` supplies the
// `TensorSliceWriter::Builder` used to produce the output file.
void SaveTensors(
    OpKernelContext* context,
    tensorflow::checkpoint::TensorSliceWriter::CreateBuilderFunction
        builder_func,
    bool save_slices) {
  const tensorflow::Tensor& filename_t = context->input(0);
  {
    const int64_t size = filename_t.NumElements();
    OP_REQUIRES(
        context, size == 1,
        tensorflow::errors::InvalidArgument(
            // NOTE(review): message lacks a space before "elements"; kept
            // as-is since this block is a verbatim copy.
            "Input 0 (filename) must be a string scalar; got a tensor of ",
            size, "elements"));
  }
  const std::string& filename = filename_t.scalar<tensorflow::tstring>()();

  // Path, names, and slices if save_slices is true.
  const int kFixedInputs = save_slices ? 3 : 2;
  const tensorflow::Tensor& tensor_names_t = context->input(1);
  // Input indices below are `int`; reject counts that would overflow it.
  OP_REQUIRES(
      context,
      tensorflow::FastBoundsCheck(tensor_names_t.NumElements() + kFixedInputs,
                                  std::numeric_limits<int>::max()),
      tensorflow::errors::InvalidArgument("Too many inputs to SaveTensors"));
  const int N = static_cast<int>(tensor_names_t.NumElements());
  const tensorflow::tstring* tensor_shapes_and_slices_ptr = nullptr;
  if (save_slices) {
    // One shape-and-slice spec per named tensor.
    const tensorflow::Tensor& tensor_shapes_and_slices_t = context->input(2);
    OP_REQUIRES(
        context,
        tensor_shapes_and_slices_t.NumElements() == static_cast<int64_t>(N),
        tensorflow::errors::InvalidArgument(
            "Expected ", N,
            " elements for the tensor "
            "shapes and slices but got ",
            tensor_shapes_and_slices_t.NumElements()));
    tensor_shapes_and_slices_ptr =
        tensor_shapes_and_slices_t.flat<tensorflow::tstring>().data();
  }
  OP_REQUIRES(
      context, context->num_inputs() == N + kFixedInputs,
      tensorflow::errors::InvalidArgument(
          "Expected totally ", N + kFixedInputs,
          " inputs as input #1 (which is a string "
          "tensor of saved names) contains ",
          N, " names, but received ", context->num_inputs(), " inputs"));

  VLOG(1) << "About to save tensors to file " << filename << "...";
  tensorflow::checkpoint::TensorSliceWriter writer(filename,
                                                   std::move(builder_func));

  tensorflow::Status s;
  auto tensor_names_flat = tensor_names_t.flat<tensorflow::tstring>();

  // Process tensors in sorted name order. This allows us to avoid seeking
  // during restoration in the common case where we are restoring a full
  // checkpoint.
  // RestoreTensorsV2 was changed to sort by file offset, so this sorting isn't
  // strictly necessary anymore. However, restores with TF version <= 2.7 will
  // still benefit.
  std::vector<int> sorted_name_idx(tensor_names_flat.size());
  std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
  std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
            [&tensor_names_flat](size_t a, size_t b) {
              return tensor_names_flat(a) < tensor_names_flat(b);
            });

  for (const int i : sorted_name_idx) {
    const std::string& name = tensor_names_flat(i);
    const tensorflow::Tensor& input = context->input(i + kFixedInputs);
    // Default: save the whole tensor as one full slice.
    tensorflow::TensorShape shape(input.shape());
    tensorflow::TensorSlice slice(input.dims());
    if (save_slices && !tensor_shapes_and_slices_ptr[i].empty()) {
      // A non-empty spec overrides shape/slice with the parsed values.
      const tensorflow::tstring& shape_spec = tensor_shapes_and_slices_ptr[i];
      tensorflow::TensorShape slice_shape;
      OP_REQUIRES_OK(context, tensorflow::checkpoint::ParseShapeAndSlice(
                                  shape_spec, &shape, &slice, &slice_shape));
      OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()),
                  tensorflow::errors::InvalidArgument(
                      "Slice in shape_and_slice "
                      "specification does not match the "
                      "shape of the tensor to save: ",
                      shape_spec, ", tensor: ", input.shape().DebugString()));
    }

    // Dispatch `writer.Add` on the tensor's runtime dtype.
#define WRITER_ADD(T)                                           \
  case tensorflow::DataTypeToEnum<T>::value:                    \
    s = writer.Add(name, shape, slice, input.flat<T>().data()); \
    break;

    switch (input.dtype()) {
      TF_CALL_SAVE_RESTORE_TYPES(WRITER_ADD)
      default:
        context->SetStatus(tensorflow::errors::Unimplemented(
            "Saving data type ", DataTypeString(input.dtype()),
            " not yet supported"));
        return;
    }
#undef WRITER_ADD
    if (!s.ok()) {
      context->SetStatus(s);
      return;
    }
  }

  s = writer.Finish();
  if (!s.ok()) {
    context->SetStatus(s);
  }
}
189
190 // A `WritableFile` that wraps an existing file, appending a chunk with a length
191 // footer to the end of it.
192 //
193 // File start position is stored as a footer since `WritableFile` does not allow
194 // `Seek`ing to modify an earlier position in the file.
195 class AppendedFileWithStartPosFooter : public tensorflow::WritableFile {
196 public:
FromFile(std::unique_ptr<tensorflow::WritableFile> file,std::unique_ptr<tensorflow::WritableFile> & wrapped_file_out)197 static tensorflow::Status FromFile(
198 std::unique_ptr<tensorflow::WritableFile> file,
199 std::unique_ptr<tensorflow::WritableFile>& wrapped_file_out) {
200 int64_t body_start;
201 TF_RETURN_IF_ERROR(file->Tell(&body_start));
202 VLOG(1) << "Appending to checkpoint with starting position " << body_start;
203 // Note: cannot use `make_unique` due to private constructor.
204 wrapped_file_out = std::unique_ptr<tensorflow::WritableFile>(
205 new AppendedFileWithStartPosFooter(std::move(file), body_start));
206 return tensorflow::OkStatus();
207 }
Append(tensorflow::StringPiece data)208 tensorflow::Status Append(tensorflow::StringPiece data) override {
209 return file_->Append(data);
210 }
Close()211 tensorflow::Status Close() override {
212 TF_RETURN_IF_ERROR(file_->Append(Int64ToHostEndianBytes(&body_start_)));
213 return file_->Close();
214 }
Flush()215 tensorflow::Status Flush() override { return file_->Flush(); }
Sync()216 tensorflow::Status Sync() override { return file_->Sync(); }
Tell(int64_t * position)217 tensorflow::Status Tell(int64_t* position) override {
218 int64_t internal_position;
219 TF_RETURN_IF_ERROR(file_->Tell(&internal_position));
220 *position = internal_position - body_start_;
221 return tensorflow::OkStatus();
222 }
223
224 private:
AppendedFileWithStartPosFooter(std::unique_ptr<tensorflow::WritableFile> file,int64_t body_start)225 AppendedFileWithStartPosFooter(std::unique_ptr<tensorflow::WritableFile> file,
226 int64_t body_start)
227 : file_(std::move(file)), body_start_(body_start) {}
228
229 std::unique_ptr<tensorflow::WritableFile> file_;
230 int64_t body_start_;
231 };
232
// An implementation of the `TensorSliceWriter::Builder` interface which
// delegates to `tensorflow::table::TableBuilder`.
class TableBuilder : public tensorflow::checkpoint::TensorSliceWriter::Builder {
 public:
  // `name` is used only in error messages; `file` is the destination the
  // table is written to, and is closed by `Finish`.
  TableBuilder(std::string name, std::unique_ptr<tensorflow::WritableFile> file)
      : name_(std::move(name)), file_(std::move(file)) {
    tensorflow::table::Options option;
    // No compression: entries are written verbatim.
    option.compression = tensorflow::table::kNoCompression;
    builder_ =
        std::make_unique<tensorflow::table::TableBuilder>(option, file_.get());
  }
  void Add(tensorflow::StringPiece key, tensorflow::StringPiece val) override {
    builder_->Add(key, val);
  }
  // Finalizes the table and closes the file. On success `*file_size` holds
  // the finished table's size; on any failure it is left at -1.
  tensorflow::Status Finish(int64_t* file_size) override {
    *file_size = -1;
    tensorflow::Status s = builder_->Finish();
    if (s.ok()) {
      s = file_->Close();
      if (s.ok()) {
        *file_size = builder_->FileSize();
      }
    }
    if (!s.ok()) {
      // The version check selects between the old `error_message()` accessor
      // and the newer `message()` so this builds against both TF versions.
      s = tensorflow::errors::Internal(
#if TF_GRAPH_DEF_VERSION < 1467
          "Error writing (tmp) checkpoint file: ", name_, ": ",
          s.error_message());
#else
          "Error writing (tmp) checkpoint file: ", name_, ": ", s.message());
#endif
    }
    return s;
  }

 private:
  std::string name_;
  std::unique_ptr<tensorflow::WritableFile> file_;
  std::unique_ptr<tensorflow::table::TableBuilder> builder_;
};
273
274 // Creates a new `TensorSliceWriter::Builder` which will append the tensor
275 // slices to `filename` along with a footer indicating the start position of
276 // this particular chunk of slices.
277 //
278 // If this method returns `OK`, `builder` will contain a new owned pointer to
279 // a `TensorSliceWriter::Builder`.
CreateAppendingTensorSliceBuilder(const std::string & filename,tensorflow::checkpoint::TensorSliceWriter::Builder ** builder)280 tensorflow::Status CreateAppendingTensorSliceBuilder(
281 const std::string& filename,
282 tensorflow::checkpoint::TensorSliceWriter::Builder** builder) {
283 *builder = nullptr;
284 if (VLOG_IS_ON(1)) {
285 uint64_t file_size = 0;
286 if (tensorflow::Env::Default()->GetFileSize(filename, &file_size).ok()) {
287 VLOG(1) << "Appending checkpoint to file " << filename << " with size "
288 << file_size;
289 } else {
290 VLOG(1) << "Appending checkpoint to new file " << filename;
291 }
292 }
293 std::unique_ptr<tensorflow::WritableFile> file;
294 TF_RETURN_IF_ERROR(
295 tensorflow::Env::Default()->NewAppendableFile(filename, &file));
296 std::unique_ptr<tensorflow::WritableFile> wrapped_file;
297 TF_RETURN_IF_ERROR(
298 AppendedFileWithStartPosFooter::FromFile(std::move(file), wrapped_file));
299 *builder = new TableBuilder(filename, std::move(wrapped_file));
300 return tensorflow::OkStatus();
301 }
302
303 // A `RandomAccessFile` which wraps another `RandomAccessFile`, providing access
304 // to only a portion of the file.
305 class PartialRandomAccessFile : public tensorflow::RandomAccessFile {
306 public:
307 // Constructs a `PartialRandomAccessFile` pointing to a segment of `file`.
308 //
309 // `file` must be non-null and must continue to be valid as long as the
310 // return value is used.
PartialRandomAccessFile(tensorflow::RandomAccessFile * file,int64_t start,int64_t end)311 PartialRandomAccessFile(tensorflow::RandomAccessFile* file, int64_t start,
312 int64_t end)
313 : file_(file), start_(start), end_(end) {}
314 ~PartialRandomAccessFile() override = default;
Read(uint64_t offset,size_t n,tensorflow::StringPiece * result,char * scratch) const315 tensorflow::Status Read(uint64_t offset, size_t n,
316 tensorflow::StringPiece* result,
317 char* scratch) const override {
318 const size_t max_allowable_n = end_ - (start_ + offset);
319 bool read_too_long = n > max_allowable_n;
320 if (read_too_long) {
321 n = max_allowable_n;
322 }
323 TF_RETURN_WITH_CONTEXT_IF_ERROR(
324 file_->Read(offset + start_, n, result, scratch),
325 absl::StrCat("Reading from PartialRandomAccessFile at offset ", offset,
326 " from start position ", start_));
327 if (read_too_long) {
328 return tensorflow::Status(
329 static_cast<tensorflow::errors::Code>(absl::StatusCode::kOutOfRange),
330 "Attempted to read past end of file chunk.");
331 }
332 return tensorflow::OkStatus();
333 }
334
335 private:
336 tensorflow::RandomAccessFile* file_;
337 int64_t start_;
338 int64_t end_;
339 };
340
341 struct TableIteratorComparator {
342 // Returns whether `i1` should come after `i2` in the priority queue.
343 // That is, whether `i1` has *lower* priority than `i2`.
operator ()fcp::__anoncdcc8cda0111::TableIteratorComparator344 bool operator()(const std::unique_ptr<tensorflow::table::Iterator>& i1,
345 const std::unique_ptr<tensorflow::table::Iterator>& i2) {
346 // Ensure that iterators which have no remaining elements go last in the
347 // list.
348 if (!i2->Valid()) {
349 return false;
350 }
351 if (!i1->Valid()) {
352 return true;
353 }
354 if ((i2->key() == kSavedTensorSlicesKey) &&
355 (i1->key() != kSavedTensorSlicesKey)) {
356 return true;
357 }
358 return i1->key() > i2->key();
359 }
360 };
361
// Removes the highest-priority element of `queue` and returns it by value.
template <class Element, class Container, class Comparator>
Element PopWithElement(
    std::priority_queue<Element, Container, Comparator>& queue) {
  // `top()` only exposes a const reference; the const_cast is safe because
  // the element is removed from the queue immediately afterwards.
  Element popped = std::move(const_cast<Element&>(queue.top()));
  queue.pop();
  return popped;
}
370
371 // Parses a `serialized` into a `SavedTensorSlices` stored in `meta_out`.
MetadataFromString(absl::string_view serialized,tensorflow::SavedTensorSlices & meta_out)372 tensorflow::Status MetadataFromString(absl::string_view serialized,
373 tensorflow::SavedTensorSlices& meta_out) {
374 // NOTE: The conversion to `std::string` is unfortunately necessary here
375 // because the OSS version of `ParseFromString` takes a `const std::string&`
376 // rather than a `absl::string_view`.
377 if (!meta_out.ParseFromString(std::string(serialized))) {
378 return tensorflow::Status(
379 static_cast<tensorflow::errors::Code>(absl::StatusCode::kInternal),
380 absl::StrCat("Failed to parse table entry as `SavedTensorSlices`: ",
381 serialized));
382 }
383 return tensorflow::OkStatus();
384 }
385
// Merges appended checkpoints in `filename` into a single checkpoint.
//
// The file is expected to hold one or more chunks, each a checkpoint table
// followed by an 8-byte footer giving the chunk's start offset (the layout
// produced by `AppendedFileWithStartPosFooter`). Chunks are discovered by
// walking the footers backwards from the end of the file, then merged into a
// single table written back to `filename`.
//
// Note: this function accepts `filename` as a `const std::string&` rather than
// `string_view` because that is the type accepted by the functions it calls
// (`GetFileSize` and `NewRandomAccessFile`). This avoids unnecessary
// allocation.
tensorflow::Status LoadAndMergeAppendedSlices(const std::string& filename) {
  tensorflow::Env* env = tensorflow::Env::Default();
  uint64_t file_size;
  TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
  // Short-circuit on empty files so that we can assume at least a single entry
  // below.
  if (file_size == 0) {
    return tensorflow::OkStatus();
  }
  std::unique_ptr<tensorflow::RandomAccessFile> file;
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));

  // Overwrite the underlying file, relying on `file` above to provide a handle
  // into the old file contents even after it is overwritten.
  // NOTE(review): this relies on delete-while-open semantics (the open handle
  // continues to read the old contents) — presumably POSIX-like filesystems
  // only; confirm for all target platforms.
  TF_RETURN_IF_ERROR(tensorflow::Env::Default()->DeleteFile(filename));

  // `chunk_files` and `chunk_tables` must be kept around since they are
  // referenced internally by `chunk_iterators`.
  std::vector<std::unique_ptr<tensorflow::RandomAccessFile>> chunk_files;
  std::vector<std::unique_ptr<tensorflow::table::Table>> chunk_tables;
  // Yields entries across all chunks: metadata entries first, then ascending
  // key order; exhausted iterators sink to the bottom (see
  // `TableIteratorComparator`).
  std::priority_queue<std::unique_ptr<tensorflow::table::Iterator>,
                      std::vector<std::unique_ptr<tensorflow::table::Iterator>>,
                      TableIteratorComparator>
      chunk_iterators;

  tensorflow::SavedTensorSlices merged_sts;
  tensorflow::SavedTensorSliceMeta* merged_meta = merged_sts.mutable_meta();
  // Slice names seen so far, used to reject duplicate entries across chunks.
  std::set<std::string> slices_added;

  // Read all of the chunks into tables, walking footers backwards from the
  // end of the file.
  int64_t chunk_footer_end = file_size;
  bool version_was_set = false;
  while (chunk_footer_end > 0) {
    // Read in the footer telling us where the chunk started.
    char footer_scratch[sizeof(int64_t)];
    tensorflow::StringPiece chunk_footer;
    TF_RETURN_IF_ERROR(file->Read(chunk_footer_end - sizeof(int64_t),
                                  sizeof(int64_t), &chunk_footer,
                                  footer_scratch));
    int64_t chunk_start = Int64FromHostEndianBytes(chunk_footer.data());
    int64_t chunk_end = chunk_footer_end - sizeof(int64_t);
    int64_t chunk_len = chunk_end - chunk_start;
    std::unique_ptr<tensorflow::RandomAccessFile> chunk_file =
        std::make_unique<PartialRandomAccessFile>(file.get(), chunk_start,
                                                  chunk_end);
    tensorflow::table::Options options;
    tensorflow::table::Table* raw_table;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        tensorflow::table::Table::Open(options, chunk_file.get(), chunk_len,
                                       &raw_table),
        absl::StrCat("Error opening sub-table of file ", filename,
                     " starting at ", chunk_start, " and ending at ", chunk_end,
                     ". Total file size: ", file_size));
    std::unique_ptr<tensorflow::table::Table> table(raw_table);
    tensorflow::table::Iterator* raw_iterator = table->NewIterator();
    std::unique_ptr<tensorflow::table::Iterator> iterator(raw_iterator);
    iterator->SeekToFirst();
    // Every chunk must begin with a metadata entry under
    // `kSavedTensorSlicesKey`.
    if (!iterator->Valid()) {
      return tensorflow::Status(
          static_cast<tensorflow::errors::Code>(absl::StatusCode::kInternal),
          "Unexpected immediately-invalid iterator. "
          "Expected table to iterator to have at least a "
          "single entry (metadata)");
    }
    if (iterator->key() != kSavedTensorSlicesKey) {
      return tensorflow::Status(
          static_cast<tensorflow::errors::Code>(absl::StatusCode::kInternal),
          absl::StrCat("Expected table iterator to have an initial metadata "
                       "entry with key `",
                       kSavedTensorSlicesKey, "`, found key `", iterator->key(),
                       "`"));
    }
    tensorflow::SavedTensorSlices sts;
    TF_RETURN_IF_ERROR(MetadataFromString(iterator->value(), sts));
    // Advance past the metadata entry; the merged metadata is written
    // separately below.
    iterator->Next();
    // Versions are taken from the first chunk encountered, i.e. the one that
    // was appended last.
    if (!version_was_set) {
      version_was_set = true;
      *merged_meta->mutable_versions() = sts.meta().versions();
    }
    for (const tensorflow::SavedSliceMeta& slice_meta : sts.meta().tensor()) {
      if (slices_added.find(slice_meta.name()) != slices_added.end()) {
        return tensorflow::Status(
            // Remove the cast after TF 2.12 is released and used in FCP.
            static_cast<tensorflow::errors::Code>(
                absl::StatusCode::kInvalidArgument),
            absl::StrCat(
                "Attempted to merge two checkpoint entries for slice name: `",
                slice_meta.name(), "`. Only one entry per name is permitted."));
      }
      slices_added.insert(slice_meta.name());
    }
    merged_meta->mutable_tensor()->MergeFrom(sts.meta().tensor());
    chunk_iterators.push(std::move(iterator));
    chunk_files.push_back(std::move(chunk_file));
    chunk_tables.push_back(std::move(table));
    // The previous chunk's footer ends where this chunk begins.
    chunk_footer_end = chunk_start;
  }
  VLOG(1) << "Merging " << chunk_files.size() << " checkpoint chunks from file "
          << filename;

  tensorflow::checkpoint::TensorSliceWriter::Builder* raw_builder;
  TF_RETURN_IF_ERROR(tensorflow::checkpoint::CreateTableTensorSliceBuilder(
      filename, &raw_builder));
  std::unique_ptr<tensorflow::checkpoint::TensorSliceWriter::Builder> builder(
      raw_builder);

  // First, we add the merged entry which holds a `SavedTensorSlices` proto.
  builder->Add(kSavedTensorSlicesKey, merged_sts.SerializeAsString());

  // Then the remaining entries are concatenated alphabetically. The comparator
  // keeps exhausted iterators at the bottom, so the loop terminates once the
  // top iterator is invalid.
  while (chunk_iterators.top()->Valid()) {
    std::unique_ptr<tensorflow::table::Iterator> iter =
        PopWithElement(chunk_iterators);
    VLOG(2) << "Merging table entry for key " << iter->key();
    builder->Add(iter->key(), iter->value());
    iter->Next();
    chunk_iterators.push(std::move(iter));
  }
  int64_t resulting_file_size;
  TF_RETURN_WITH_CONTEXT_IF_ERROR(builder->Finish(&resulting_file_size),
                                  "Finishing TensorSliceWriter::Builder");
  return tensorflow::OkStatus();
}
515
// Serializes all append and merge operations within this process; taken by
// both `AppendSlicesOp::Compute` and `MergeAppendedSlicesOp::Compute`.
ABSL_CONST_INIT absl::Mutex append_mutex(absl::kConstInit);
517
518 } // namespace
519
// Op kernel that appends the given tensor slices as a new chunk to the
// checkpoint file named by input 0.
class AppendSlicesOp : public OpKernel {
 public:
  explicit AppendSlicesOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Serialize against concurrent appends/merges in this process.
    absl::MutexLock lock(&append_mutex);
    const tensorflow::Tensor& filename_t = context->input(0);
    // NOTE(review): `filename` is unused below — `SaveTensors` and the lambda
    // both re-read input 0 themselves.
    tensorflow::tstring filename = filename_t.flat<tensorflow::tstring>()(0);
    SaveTensors(
        context,
        [context](
            const std::string& target_filename,
            tensorflow::checkpoint::TensorSliceWriter::Builder** builder) {
          // `TensorSliceWriter` targets writing to a new temporary file which
          // it then moves into the location of the final file once complete.
          // In order to comply with this behavior while still retaining
          // "append" semantics, the original file (if it exists) is first moved
          // into the temporary target location.
          tensorflow::tstring original_filename =
              context->input(0).scalar<tensorflow::tstring>()();
          tensorflow::Status status = tensorflow::Env::Default()->RenameFile(
              original_filename, target_filename);
          if (status.ok()) {
            VLOG(1) << "Appending to existing file " << original_filename
                    << " via move to temporary location " << target_filename;
          } else if (status.code() == tensorflow::error::NOT_FOUND) {
            // No pre-existing checkpoint: the append degenerates to writing a
            // fresh file at the temporary location.
            VLOG(1) << "Appending to new file " << original_filename
                    << " in temporary location " << target_filename;
          } else {
            return status;
          }
          return CreateAppendingTensorSliceBuilder(target_filename, builder);
        },
        /*save_slices=*/true);
  }
};
556
557 class MergeAppendedSlicesOp : public OpKernel {
558 public:
MergeAppendedSlicesOp(OpKernelConstruction * context)559 explicit MergeAppendedSlicesOp(OpKernelConstruction* context)
560 : OpKernel(context) {}
561
Compute(OpKernelContext * context)562 void Compute(OpKernelContext* context) override {
563 absl::MutexLock lock(&append_mutex);
564 const tensorflow::Tensor* filename_tensor;
565 OP_REQUIRES_OK(context, context->input("filename", &filename_tensor));
566 const tensorflow::tstring filename =
567 filename_tensor->scalar<tensorflow::tstring>()();
568 OP_REQUIRES_OK(context, LoadAndMergeAppendedSlices(filename));
569 }
570 };
571
// Note: `key` *must* come last so that the indices of the other arguments are
// as expected by `SaveTensors`.
// NOTE(review): the op has no input named `key`; presumably the comment refers
// to the `data` input (inputs 0..2 are the fixed inputs `SaveTensors`
// expects) — confirm and update the wording.
REGISTER_OP("AppendSlices")
    .Input("filename: string")
    .Input("tensor_names: string")
    .Input("shapes_and_slices: string")
    .Input("data: T")
    .Attr("T: list(type)")
    .SetIsStateful();

REGISTER_KERNEL_BUILDER(Name("AppendSlices").Device(tensorflow::DEVICE_CPU),
                        AppendSlicesOp);

// Merges all chunks previously appended to `filename` into one checkpoint.
REGISTER_OP("MergeAppendedSlices").Input("filename: string").SetIsStateful();

REGISTER_KERNEL_BUILDER(
    Name("MergeAppendedSlices").Device(tensorflow::DEVICE_CPU),
    MergeAppendedSlicesOp);
590
591 } // namespace fcp
592