/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/cache_dataset_ops.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/serialization_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/dataset_options.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/cache_ops.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level description of
// the following op.

/* static */ constexpr const char* const CacheDatasetOp::kDatasetType;
/* static */ constexpr const char* const CacheDatasetOp::kInputDataset;
/* static */ constexpr const char* const CacheDatasetOp::kFileName;
/* static */ constexpr const char* const CacheDatasetOp::kOutputTypes;
/* static */ constexpr const char* const CacheDatasetOp::kOutputShapes;

namespace {

constexpr char kKeyStrFormat[] = "%%%zuzu_%%%zuzu";
constexpr char kPaddingSizeStrFormat[] = "%zu";
constexpr char kFileDatasetPrefix[] = "File";
constexpr char kMode[] = "Mode";
constexpr char kLockFileSuffix[] = ".lockfile";
constexpr char kIterationCompleted[] = "iteration_completed";
constexpr char kCurIndex[] = "cur_index";
constexpr char kShardId[] = "shard_id";
constexpr char kCreatedAt[] = "Created at";
constexpr char kMemoryDatasetPrefix[] = "Memory";
constexpr char kMemoryCache[] = "MemoryCache";
constexpr char kCacheCompleted[] = "cache_completed";
constexpr char kIndex[] = "index";
constexpr char kImpl[] = "Impl";
constexpr char kCacheDataset[] = "CacheDataset";
constexpr char kIncompleteCacheErrorMessage[] =
    "The calling iterator did not fully read the dataset being cached. In "
    "order to avoid unexpected truncation of the dataset, the partially cached "
    "contents of the dataset will be discarded. This can happen if you have "
    "an input pipeline similar to `dataset.cache().take(k).repeat()`. You "
    "should use `dataset.take(k).cache().repeat()` instead.";
}  // namespace

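// Caches elements produced by a dataset iterator in memory so that they can
// be looked up by index. `MemoryDatasetBase::Get` uses this to serve
// random-access reads while the cache is still being populated.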
class PartialCache {
 public:
  explicit PartialCache(const DatasetBase* dataset) : input_(dataset) {}

  // Extends the temporary cache up to a given index and then updates
  // out_tensors with the element at that index.
  Status Get(OpKernelContext* ctx, int64 index,
             std::vector<Tensor>* out_tensors) {
    if (!iter_resource_) {
      TF_ASSIGN_OR_RETURN(iter_resource_,
                          GetIteratorResourceFromDataset(ctx, input_));
      TF_RETURN_IF_ERROR(iter_resource_->SetIteratorFromDataset(ctx, input_));
    }
    if (index >= cache_.size()) {
      TF_RETURN_IF_ERROR(ExtendTempCacheToIndex(index, ctx));
    }
    *out_tensors = cache_.at(index);
    return OkStatus();
  }

  // Returns the data which has been cached up to this point.
  std::vector<std::vector<Tensor>> GetCacheData() { return cache_; }

 private:
  Status ExtendTempCacheToIndex(int64 index, OpKernelContext* ctx) {
    bool end_of_sequence;
    while (cache_.size() <= index) {
      std::vector<Tensor> out_tensors;
      TF_RETURN_IF_ERROR(
          iter_resource_->GetNext(ctx, &out_tensors, &end_of_sequence));
      if (end_of_sequence) {
        return tensorflow::errors::OutOfRange("Index out of range [0, ",
                                              cache_.size(), "):", index);
      }
      cache_.push_back(out_tensors);
    }
    return OkStatus();
  }

  StatusOr<core::RefCountPtr<IteratorResource>> GetIteratorResourceFromDataset(
      OpKernelContext* ctx, const DatasetBase* dataset) {
    FunctionLibraryRuntime* flr;
    std::unique_ptr<DeviceMgr> device_mgr(nullptr);
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> plfr(nullptr);
    TF_RETURN_IF_ERROR(
        ctx->function_library()->Clone(&flib_def, &plfr, &flr, true));

    core::RefCountPtr<IteratorResource> iter_resource(new IteratorResource(
        ctx->env(), dataset->output_dtypes(), dataset->output_shapes(),
        std::move(device_mgr), std::move(flib_def), std::move(plfr), flr));
    return iter_resource;
  }

  const DatasetBase* input_;  // Not owned.
  core::RefCountPtr<IteratorResource> iter_resource_;
  std::vector<std::vector<Tensor>> cache_;
};

class CacheDatasetOp::FileDatasetBase : public DatasetBase {
 public:
  FileDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
                  string filename, Env* env)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        filename_(std::move(filename)),
        env_(env),
        num_tensors_(input->output_dtypes().size()),
        tensor_index_padding_size_(StringPaddingSize(num_tensors_)),
        item_index_padding_size_(StringPaddingSize(kMaxItems)),
        tensor_format_string_(strings::Printf(kKeyStrFormat,
                                              item_index_padding_size_,
                                              tensor_index_padding_size_)) {
    input_->Ref();
    DCHECK_EQ(item_index_padding_size_, 7);
  }

  ~FileDatasetBase() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.dataset_prefix = kFileDatasetPrefix;
    return std::make_unique<FileIterator>(FileIterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.dataset_prefix = kFileDatasetPrefix;
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  int64_t CardinalityInternal() const override { return input_->Cardinality(); }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  const DatasetBase* const input_;
  const tstring filename_;

 private:
  static size_t StringPaddingSize(size_t num_tensors) {
    return strings::Printf(kPaddingSizeStrFormat, num_tensors - 1).size();
  }

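  // Formats the bundle key under which tensor `tensor_index` of element
  // `item_index` is stored. For illustration (assuming the default kMaxItems
  // and a single output tensor), the padding sizes are 7 and 1, so
  // `tensor_format_string_` is "%7zu_%1zu" and FormatName(3, 0) yields a key
  // like "      3_0".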
  string FormatName(size_t item_index, size_t tensor_index) const {
    return strings::Printf(tensor_format_string_.c_str(), item_index,
                           tensor_index);
  }

  class FileIterator : public DatasetIterator<FileDatasetBase> {
   public:
    explicit FileIterator(const Params& params)
        : DatasetIterator<FileDatasetBase>(params) {
      if (params.dataset->env_
              ->FileExists(MetaFilename(params.dataset->filename_))
              .ok()) {
        mode_ = Mode::read;
      } else {
        mode_ = Mode::write;
      }
    }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(mu_);
      return InitializeIterator(ctx);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args),
                                       /*ratio=*/1);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kMode), mode_));
      return SaveInput(ctx, writer, iterator_);
    }
    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      {
        int64_t temp;
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kMode), &temp));
        mode_ = static_cast<Mode>(temp);
      }
      if (mode_ == Mode::write &&
          dataset()
              ->env_->FileExists(MetaFilename(dataset()->filename_))
              .ok()) {
        // This could happen if the cache was completely written after the
        // checkpoint was saved.
        LOG(WARNING)
            << "It looks like the cache was already completely written ("
            << MetaFilename(dataset()->filename_)
            << ") after the last checkpoint was saved. Attempting to read "
            << "the cache instead of continuing to write. If this is a "
            << "mistake, please remove the above file and try running again.";
        mode_ = Mode::read;
      }
      TF_RETURN_IF_ERROR(InitializeIterator(ctx));
      return RestoreInput(ctx, reader, iterator_);
    }

   private:
    // FileWriterIterator passes through and caches items from the input
    // FileDatasetBase.
    //
    // This iterator is used when the cache directory is not found on disk. It
    // creates the cache directory, and passes on the underlying iterator's
    // elements.
    //
    // Caching is performed by writing the input tensors to disk using the
    // `BundleWriter`. Note that the cache gets fully flushed to disk only
    // after the input iterator has been fully exhausted. If the program
    // exits before completion of an epoch, the cached state will be lost.
    // To ensure that the partial cache persists across sessions, one should
    // checkpoint the input pipeline. On each call to `SaveInternal` the
    // partial cache gets flushed to disk in files with prefix
    // <filename>_<shard_id> where shard_id is unique for each checkpoint.
    // When all elements have been produced, these shards get coalesced.
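    //
    // For example, a run that is checkpointed twice mid-epoch leaves flushed
    // shards <filename>_0 and <filename>_1 on disk and keeps writing to
    // <filename>_2; once the input is exhausted, `Finish()` merges all of
    // the shards into a single bundle with prefix <filename> and deletes the
    // per-shard lockfiles.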
    class FileWriterIterator : public DatasetIterator<FileDatasetBase> {
     public:
      explicit FileWriterIterator(const Params& params)
          : DatasetIterator<FileDatasetBase>(params),
            cur_index_(0),
            shard_id_(0),
            filename_(
                strings::StrCat(params.dataset->filename_, "_", shard_id_)),
            lockfile_(strings::StrCat(filename_, kLockFileSuffix)),
            lockfile_created_(false),
            iteration_completed_(false) {}

      ~FileWriterIterator() override {
        if (!dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
          LOG(WARNING) << kIncompleteCacheErrorMessage;
          std::vector<string> cache_files;
          Status s = dataset()->env_->GetMatchingPaths(
              strings::StrCat(filename_, "*"), &cache_files);
          if (!s.ok()) {
            LOG(WARNING) << "Failed to get matching files on " << filename_
                         << "* : " << s.ToString();
          }
          for (const string& path : cache_files) {
            s = dataset()->env_->DeleteFile(path);
            if (!s.ok()) {
              LOG(WARNING) << "Failed to delete " << path << " : "
                           << s.ToString();
            }
          }
        }
      }

      Status Initialize(IteratorContext* ctx) override {
        return dataset()->input_->MakeIterator(ctx, this, prefix(),
                                               &input_impl_);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        *end_of_sequence = false;
        TF_RETURN_IF_ERROR(EnsureLockFileExists(end_of_sequence));
        if (*end_of_sequence) {
          return OkStatus();
        }
        TF_RETURN_IF_ERROR(writer_->status());
        if (cur_index_ >= kMaxItems) {
          // As a courtesy, close the [truncated] cache file.
          Status s = Finish();
          if (!s.ok()) {
            LOG(ERROR) << s;
          }
          return errors::InvalidArgument(
              "Upstream iterator is producing more than ", kMaxItems,
              " items, which is more than the cache limit.");
        }

        TF_RETURN_IF_ERROR(
            input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
        if (*end_of_sequence && out_tensors->empty()) {
          TF_RETURN_IF_ERROR(Finish());
          cur_index_++;
          return OkStatus();
        }
        if (out_tensors->size() != dataset()->num_tensors_) {
          return errors::Internal(
              "Upstream iterator returned invalid number of tensors. "
              "Expected ",
              dataset()->num_tensors_, " got: ", out_tensors->size());
        }
        size_t tensor_index = 0;
        for (const Tensor& t : *out_tensors) {
          DCHECK_LT(tensor_index, dataset()->num_tensors_);
          string key = dataset()->FormatName(cur_index_, tensor_index++);
          TF_RETURN_IF_ERROR(writer_->Add(key, t));
        }
        if (*end_of_sequence) {
          TF_RETURN_IF_ERROR(Finish());
        }
        cur_index_++;
        return OkStatus();
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kCurIndex), cur_index_));

        if (iteration_completed_) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(kIterationCompleted), ""));
          return OkStatus();
        }

        // lockfile is created on the first call to GetNextInternal. The
        // absence of a lockfile means that GetNextInternal was not called
        // and hence nothing was written to cache. So we don't need to worry
        // about flushing the current shard. This ensures that we never write
        // empty shards.
        if (lockfile_created_) {
          // Flush the current bundle.
          TF_RETURN_IF_ERROR(writer_->Finish());

          // Note: We do not delete the lockfile here. We keep lockfiles of
          // all shards around until the entire cache has been written to
          // prevent concurrent iterators from corrupting any of the shards.

          // Start caching to a new shard.
          shard_id_++;
          filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_);
          lockfile_ = strings::StrCat(filename_, kLockFileSuffix);
          lockfile_created_ = false;
        }
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kShardId), shard_id_));
        return OkStatus();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        int64_t temp;
        // TODO(b/78048575): Update this when saving size_t tensors directly
        // is supported.
        {
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIndex), &temp));
          cur_index_ = static_cast<size_t>(temp);
          if (cur_index_ != temp) {
            return errors::Internal("Invalid value for cur_index ", temp);
          }
        }

        if (reader->Contains(full_name(kIterationCompleted))) {
          iteration_completed_ = true;
          return OkStatus();
        }

        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));

        // TODO(b/78048575): Update this when saving size_t tensors directly
        // is supported.
        {
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kShardId), &temp));
          shard_id_ = static_cast<size_t>(temp);
          if (shard_id_ != temp) {
            return errors::Internal("Invalid value for shard_id ", temp);
          }
        }
        filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_);
        lockfile_ = strings::StrCat(filename_, kLockFileSuffix);
        writer_ = std::make_unique<BundleWriter>(dataset()->env_, filename_);
        return OkStatus();
      }

     private:
      Status EnsureLockFileExists(bool* end_of_sequence)
          TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        if (iteration_completed_) {
          *end_of_sequence = true;
          return OkStatus();
        }
        if (lockfile_created_) {
          return OkStatus();
        }

        // Perform rudimentary locking to help catch concurrent writes to the
        // same cache files.

        // 1. Check that a checkpoint for the shard has not already been
        // written.
        if (dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
          return errors::AlreadyExists("Existing cache files found: \n",
                                       MetaFilename(filename_), "\n",
                                       DataFilename(filename_, 0, 1), "\n",
                                       "To continue, delete the above files.");
        }

        // 2. Check that there isn't a concurrent iterator that is writing
        // to cache.
        if (dataset()->env_->FileExists(lockfile_).ok()) {
          // Attempt to read the contents of the lockfile.
          char contents_scratch[151] = {0};  // Initialize all to 0.
          StringPiece contents;
          std::unique_ptr<RandomAccessFile> file;
          if (dataset()->env_->NewRandomAccessFile(lockfile_, &file).ok()) {
            file->Read(0, 150, &contents, contents_scratch).IgnoreError();
          }
          return errors::AlreadyExists(
              "There appears to be a concurrent caching iterator running - "
              "cache lockfile already exists ('",
              lockfile_,
              "'). If you are sure no other running TF computations are "
              "using this cache prefix, delete the lockfile and "
              "re-initialize the iterator. Lockfile contents: ",
              contents);
        }
        // Create the file, and write some basic contents.
        std::unique_ptr<WritableFile> lockfile;
        TF_RETURN_IF_ERROR(
            dataset()->env_->NewWritableFile(lockfile_, &lockfile));
        TF_RETURN_IF_ERROR(lockfile->Append(
            strings::StrCat(kCreatedAt, ": ", EnvTime::NowSeconds())));

        // At this point we know that
        // 1. There is no conflicting checkpoint with prefix `filename_`.
        // 2. There is no concurrent session that is trying to write a ckpt
        //    to filename.
        // So it is safe to create a BundleWriter here. Note that it is
        // unsafe to initialize the BundleWriter anywhere the above
        // conditions are not met since BundleWriter's constructor creates
        // new temp files which can delete the temp files created by a
        // BundleWriter in another Session.
        writer_ = std::make_unique<BundleWriter>(dataset()->env_, filename_);
        lockfile_created_ = true;
        return OkStatus();
      }

      Status Finish() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        iteration_completed_ = true;
        // Flush the current bundle.
        TF_RETURN_IF_ERROR(writer_->Finish());
        // Merge all the bundles.
        // Currently there are `shard_id_ + 1` bundles, one for each
        // checkpoint. Each bundle has prefix <filename>_<id> where `id` is an
        // integer starting at 0 and incremented by 1 for each new checkpoint.
        // We merge all these bundles into a bundle with prefix <filename> so
        // that the next call to `MakeIterator` can build a
        // `FileReaderIterator`.
        {
          std::vector<tstring> prefixes;
          prefixes.reserve(shard_id_ + 1);
          for (size_t i = 0; i <= shard_id_; ++i) {
            prefixes.emplace_back(
                strings::StrCat(dataset()->filename_, "_", i));
          }
          TF_RETURN_IF_ERROR(
              MergeBundles(dataset()->env_, prefixes, dataset()->filename_));
        }
        // Delete all lockfiles.
        for (size_t i = 0; i <= shard_id_; ++i) {
          TF_RETURN_IF_ERROR(dataset()->env_->DeleteFile(
              strings::StrCat(dataset()->filename_, "_", i, kLockFileSuffix)));
        }
        return OkStatus();
      }

      mutex mu_;
      size_t cur_index_ TF_GUARDED_BY(mu_);
      // Index of the current shard. This gets incremented whenever a new
      // cache shard is saved.
      size_t shard_id_ TF_GUARDED_BY(mu_);
      std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
      // The current prefix for the cache file. This is equal to
      // `StrCat(dataset()->filename_, "_", shard_id_)`.
      string filename_;
      std::unique_ptr<BundleWriter> writer_ TF_GUARDED_BY(mu_);
      string lockfile_ TF_GUARDED_BY(mu_);
      bool lockfile_created_ TF_GUARDED_BY(mu_);
      bool iteration_completed_ TF_GUARDED_BY(mu_);
    };  // FileWriterIterator

    class FileReaderIterator : public DatasetIterator<FileDatasetBase> {
     public:
      explicit FileReaderIterator(const Params& params)
          : DatasetIterator<FileDatasetBase>(params),
            cur_index_(0),
            reader_(dataset()->env_, dataset()->filename_),
            iterator_restored_(false) {}

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        *end_of_sequence = false;
        TF_RETURN_IF_ERROR(reader_.status());
        if (!reader_.Valid()) {
          *end_of_sequence = true;
          return OkStatus();
        }
        out_tensors->clear();
        out_tensors->resize(dataset()->num_tensors_);

        for (size_t i = 0; i < dataset()->num_tensors_; ++i) {
          // When the iterator is restored from the checkpoint, `reader_` is
          // already pointing at `key` so we do not need to skip the header
          // entry.
          if (!iterator_restored_) {
            reader_.Next();  // The first entry in the table is a header.
          } else {
            iterator_restored_ = false;
          }
          if (!reader_.Valid()) {
            out_tensors->clear();
            *end_of_sequence = true;
            return OkStatus();
          }
          StringPiece key = reader_.key();
          DCHECK_EQ(key, dataset()->FormatName(cur_index_, i));
          TF_RETURN_IF_ERROR(reader_.ReadCurrent(&(*out_tensors)[i]));
          TF_RETURN_IF_ERROR(reader_.status());
        }
        cur_index_++;
        return OkStatus();
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kCurIndex), cur_index_));
        return OkStatus();
      }

      Status RestoreInternal(
          IteratorContext* ctx,
          IteratorStateReader* iterator_state_reader) override {
        mutex_lock l(mu_);
        {
          // TODO(b/78048575): Update this when saving size_t tensors directly
          // is supported.
          int64_t temp;
          TF_RETURN_IF_ERROR(
              iterator_state_reader->ReadScalar(full_name(kCurIndex), &temp));
          cur_index_ = static_cast<size_t>(temp);
          if (cur_index_ != temp) {
            return errors::Internal("Invalid value for cur_index ", temp);
          }
        }
        if (!reader_.Valid()) {
          return errors::Internal("Error initializing BundleReader.");
        }
        reader_.Seek(dataset()->FormatName(cur_index_, 0));
        iterator_restored_ = true;
        return OkStatus();
      }

     private:
      mutex mu_;
      size_t cur_index_ TF_GUARDED_BY(mu_);
      BundleReader reader_ TF_GUARDED_BY(mu_);
      bool iterator_restored_ TF_GUARDED_BY(mu_);
    };  // FileReaderIterator

    Status InitializeIterator(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      // We intentionally use the same prefix for both `FileReaderIterator` and
      // `FileWriterIterator`. Since at any time there will be at most one of
      // them alive, there should be no conflicts. This allows both iterators to
      // use a common key for `cur_index`. We leverage this in the corner case
      // when this iterator is restored from an old checkpoint in `write` mode
      // and the cache has been completely flushed to disk since then. In that
      // case we simply build a `FileReaderIterator` and seek to the
      // `cur_index`.
      switch (mode_) {
        case Mode::read:
          iterator_ =
              std::make_unique<FileReaderIterator>(FileReaderIterator::Params{
                  dataset(), strings::StrCat(prefix(), kImpl)});
          break;
        case Mode::write:
          iterator_ =
              std::make_unique<FileWriterIterator>(FileWriterIterator::Params{
                  dataset(), strings::StrCat(prefix(), kImpl)});
      }
      TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
      return iterator_->Initialize(ctx);
    }

    mutex mu_;
    enum Mode { read, write };
    Mode mode_ TF_GUARDED_BY(mu_);
    std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);
  };  // FileIterator

  Env* const env_;
  const size_t num_tensors_;
  const size_t tensor_index_padding_size_;
  static constexpr size_t kMaxItems = 10000000;  // 10 million
  const size_t item_index_padding_size_;
  const string tensor_format_string_;
};  // FileDatasetBase

class CacheDatasetOp::FileDataset : public CacheDatasetOp::FileDatasetBase {
 public:
  using FileDatasetBase::FileDatasetBase;

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
    Node* filename = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename));
    TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output));
    return OkStatus();
  }
};

class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDatasetBase {
 public:
  explicit FileDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
                         string filename, Env* env,
                         const Tensor& resource_handle)
      : FileDatasetBase(ctx, input, filename, env),
        resource_handle_(resource_handle) {}

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    Node* filename_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename_node));
    Node* resource_handle_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {input_node, filename_node, resource_handle_node}, output));
    return OkStatus();
  }

 private:
  const Tensor resource_handle_;
};

class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
 public:
  explicit MemoryDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
                             std::shared_ptr<MemoryCache> cache)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        cache_(std::move(cache)) {
    input_->Ref();
  }

  ~MemoryDatasetBase() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.dataset_prefix = kMemoryDatasetPrefix;
    return std::make_unique<MemoryIterator>(
        MemoryIterator::Params{
            this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
        cache_.get());
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.dataset_prefix = kMemoryDatasetPrefix;
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  int64_t CardinalityInternal() const override {
    return input_->Cardinality();
  }

  int64_t CardinalityInternal(CardinalityOptions options) const override {
    return input_->Cardinality(options);
  }

  Status Get(OpKernelContext* ctx, int64 index,
             std::vector<Tensor>* out_tensors) const override {
    mutex_lock l(mu_);

    CardinalityOptions options;
    options.set_compute_level(CardinalityOptions::CARDINALITY_COMPUTE_LOW);
    int64_t cardinality = Cardinality(options);

    if (cardinality != kUnknownCardinality &&
        cardinality != kInfiniteCardinality && index >= cardinality) {
      return errors::OutOfRange("Index out of range [0, ", cardinality,
                                "):", index);
    }
    if (!partial_cache_) {
      partial_cache_ = std::make_unique<PartialCache>(input_);
    }
    return partial_cache_->Get(ctx, index, out_tensors);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  class MemoryIterator : public DatasetIterator<MemoryDatasetBase> {
   public:
    explicit MemoryIterator(const Params& params, MemoryCache* cache)
        : DatasetIterator<MemoryDatasetBase>(params), cache_(cache) {}

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(mu_);
      return InitializeIterator(ctx);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args),
                                       /*ratio=*/1);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (cache_->IsCompleted()) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheCompleted), ""));
        TF_RETURN_IF_ERROR(
            WriteElementsToCheckpoint(writer, prefix(), cache_->data()));
      }
      return SaveInput(ctx, writer, iterator_);
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      iterator_.reset();
      cache_->Reset();
      if (reader->Contains(full_name(kCacheCompleted))) {
        std::vector<std::vector<Tensor>> temp_cache;
        TF_RETURN_IF_ERROR(
            ReadElementsFromCheckpoint(ctx, reader, prefix(), &temp_cache));
        cache_->Complete(std::move(temp_cache));
      }
      TF_RETURN_IF_ERROR(InitializeIterator(ctx));
      return RestoreInput(ctx, reader, iterator_);
    }

   private:
    class MemoryWriterIterator : public DatasetIterator<MemoryDatasetBase> {
     public:
      explicit MemoryWriterIterator(const Params& params, MemoryCache* cache)
          : DatasetIterator<MemoryDatasetBase>(params), cache_(cache) {}

      ~MemoryWriterIterator() override {
        mutex_lock l(mu_);
        if (!temp_cache_.empty() && !cache_->IsCompleted()) {
          LOG(WARNING) << kIncompleteCacheErrorMessage;
          cache_->Reset();
        }
      }

      Status Initialize(IteratorContext* ctx) override {
        return dataset()->input_->MakeIterator(ctx, this, prefix(),
                                               &input_impl_);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(
            input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
        if (*end_of_sequence) {
          if (!cache_->IsCompleted()) {
            VLOG(2) << "Finalizing the cache because EOF has been reached.";
            cache_->Complete(std::move(temp_cache_));
          }
          return OkStatus();
        }
        RecordBufferEnqueue(ctx, *out_tensors);
        temp_cache_.emplace_back(*out_tensors);
        if (temp_cache_.size() == dataset()->input_->Cardinality()) {
          VLOG(2) << "Finalizing the cache because its size matches the "
                     "expected input cardinality.";
          cache_->Complete(std::move(temp_cache_));
        }
        return OkStatus();
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        if (!cache_->IsCompleted()) {
          TF_RETURN_IF_ERROR(
              WriteElementsToCheckpoint(writer, prefix(), temp_cache_));
        }
        return SaveInput(ctx, writer, input_impl_);
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        if (!reader->Contains(full_name(kCacheCompleted))) {
          TF_RETURN_IF_ERROR(
              ReadElementsFromCheckpoint(ctx, reader, prefix(), &temp_cache_));
        }
        return RestoreInput(ctx, reader, input_impl_);
      }

     private:
      mutex mu_;
      std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
      MemoryCache* const cache_ TF_GUARDED_BY(mu_);  // not owned.
      std::vector<std::vector<Tensor>> temp_cache_ TF_GUARDED_BY(mu_);
    };  // MemoryWriterIterator

    class MemoryReaderIterator : public DatasetIterator<MemoryDatasetBase> {
     public:
      explicit MemoryReaderIterator(const Params& params, MemoryCache* cache)
          : DatasetIterator<MemoryDatasetBase>(params),
            cache_(cache),
            index_(0) {}

      Status Initialize(IteratorContext* ctx) override {
        // The memory allocated for the cache is owned by the parent
        // dataset but performance modeling uses the iterator abstraction and
        // thus we record the memory allocated for the cache here. The caveat
        // is that this is incorrect if there are concurrent instances of this
        // iterator.
        tf_shared_lock l(mu_);
        for (size_t i = 0; i < cache_->size(); ++i) {
          RecordBufferEnqueue(ctx, cache_->at(i));
        }
        return OkStatus();
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        if (index_ < cache_->size()) {
          const std::vector<Tensor>& cache_tensors = cache_->at(index_);
          out_tensors->insert(out_tensors->begin(), cache_tensors.begin(),
                              cache_tensors.end());
          index_++;
          *end_of_sequence = false;
          return OkStatus();
        } else {
          *end_of_sequence = true;
          return OkStatus();
        }
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_));
        return OkStatus();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        {
          // kIndex will not be set if we are restoring from a checkpoint
          // written by a MemoryWriterIterator that has completed its cache.
          int64_t temp = cache_->size();
          if (reader->Contains(full_name(kIndex))) {
            TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &temp));
          }
          index_ = static_cast<size_t>(temp);
        }
        return OkStatus();
      }

     private:
      mutex mu_;
      MemoryCache* const cache_ TF_GUARDED_BY(mu_);  // not owned.
      size_t index_ TF_GUARDED_BY(mu_);
    };  // MemoryReaderIterator

    Status InitializeIterator(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (cache_->IsCompleted()) {
        iterator_ = std::make_unique<MemoryReaderIterator>(
            MemoryReaderIterator::Params{dataset(),
                                         strings::StrCat(prefix(), kImpl)},
            cache_);
      } else {
        iterator_ = std::make_unique<MemoryWriterIterator>(
            MemoryWriterIterator::Params{dataset(),
                                         strings::StrCat(prefix(), kImpl)},
            cache_);
      }
      TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
      return iterator_->Initialize(ctx);
    }

    mutex mu_;
    MemoryCache* cache_ TF_GUARDED_BY(mu_);  // not owned.
    std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);
  };  // MemoryIterator

  mutable mutex mu_;
  const DatasetBase* const input_;
  const std::shared_ptr<MemoryCache> cache_;
  mutable std::unique_ptr<PartialCache> partial_cache_ TF_GUARDED_BY(mu_);
};  // MemoryDatasetBase

// This version of the memory dataset has exclusive ownership of the memory
// cache resource. It supports sharing of the cache across different
// iterations of the `repeat` transformation but not across different
// iterators.
class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase {
 public:
  MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
                MemoryCacheManager* manager, ResourceHandle&& resource_handle)
      : MemoryDatasetBase(ctx, input, manager->get()),
        manager_(manager),
        resource_handle_(std::move(resource_handle)),
        resource_mgr_(ctx->resource_manager()) {}

  ~MemoryDataset() override {
    manager_->Unref();
    Status s = resource_mgr_->Delete<MemoryCacheManager>(
        resource_handle_.container(), resource_handle_.name());
    if (!s.ok()) {
      LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
    }
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    Node* filename_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node));
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_node, filename_node}, output));
    return OkStatus();
  }

 private:
  MemoryCacheManager* const manager_;  // Owned.
  const ResourceHandle resource_handle_;
  ResourceMgr* const resource_mgr_;  // Not owned.
};

// This version of the memory dataset has shared ownership of the memory cache
// resource. It supports sharing of the cache across different iterations of
// the `repeat` transformation and also across different iterators.
class CacheDatasetOp::MemoryDatasetV2
    : public CacheDatasetOp::MemoryDatasetBase {
 public:
  MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
                  MemoryCacheManager* manager, ResourceHandle&& resource_handle,
                  bool owns_resource)
      : MemoryDatasetBase(ctx, input, manager->get()),
        manager_(manager),
        owns_resource_(owns_resource),
        resource_handle_(std::move(resource_handle)),
        resource_mgr_(ctx->resource_manager()) {}

  ~MemoryDatasetV2() override {
    manager_->Unref();
    if (owns_resource_) {
      Status s = resource_mgr_->Delete<MemoryCacheManager>(
          resource_handle_.container(), resource_handle_.name());
      if (!s.ok()) {
        LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
      }
    }
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    Node* filename_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node));
    Node* resource_handle_node = nullptr;
    Tensor handle(DT_RESOURCE, TensorShape({}));
    handle.scalar<ResourceHandle>()() = resource_handle_;
    TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {input_node, filename_node, resource_handle_node}, output));
    return OkStatus();
  }

 private:
  MemoryCacheManager* const manager_;  // Owned.
  const bool owns_resource_;
  const ResourceHandle resource_handle_;
  ResourceMgr* const resource_mgr_;  // Not owned.
};

CacheDatasetOp::CacheDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->def().op() == kCacheDataset ? 1 : 2) {}

void CacheDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                 DatasetBase** output) {
  // Parse out the filename tensor.
  tstring filename;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, kFileName, &filename));
  if (filename.empty()) {
    static std::atomic<int64_t> resource_id_counter(0);
    const string& container = ctx->resource_manager()->default_container();
    auto name = strings::StrCat(ctx->op_kernel().name(), "/", kMemoryCache, "_",
                                resource_id_counter.fetch_add(1));
    if (op_version_ == 2) {
      bool owns_resource = false;
      MemoryCacheManager* manager = nullptr;
      auto handle = HandleFromInput(ctx, 2);
      Status s = ctx->resource_manager()->Lookup<MemoryCacheManager>(
          handle.container(), handle.name(), &manager);
      if (errors::IsNotFound(s)) {
        owns_resource = true;
        OP_REQUIRES_OK(
            ctx,
            ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
                container, name, &manager, [](MemoryCacheManager** manager) {
                  *manager = new MemoryCacheManager();
                  return OkStatus();
                }));
        handle = MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
      } else {
        OP_REQUIRES_OK(ctx, s);
      }
      // Ownership of `manager` is transferred to `MemoryDatasetV2`.
      *output = new MemoryDatasetV2(ctx, input, manager, std::move(handle),
                                    owns_resource);
    } else {
      MemoryCacheManager* manager;
      OP_REQUIRES_OK(
          ctx, ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
                   container, name, &manager, [](MemoryCacheManager** manager) {
                     *manager = new MemoryCacheManager();
                     return OkStatus();
                   }));
      auto handle =
          MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
      // Ownership of `manager` is transferred to `MemoryDataset`.
      *output = new MemoryDataset(ctx, input, manager, std::move(handle));
    }
  } else {
    if (op_version_ == 2) {
      *output =
          new FileDatasetV2(ctx, input, filename, ctx->env(), ctx->input(2));
    } else {
      *output = new FileDataset(ctx, input, filename, ctx->env());
    }
  }
}

namespace {
REGISTER_KERNEL_BUILDER(Name("CacheDataset").Device(DEVICE_CPU),
                        CacheDatasetOp);
REGISTER_KERNEL_BUILDER(Name("CacheDatasetV2").Device(DEVICE_CPU),
                        CacheDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow