/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/cache_dataset_ops.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/serialization_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/dataset_options.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/cache_ops.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level description
// of the following op.

/* static */ constexpr const char* const CacheDatasetOp::kDatasetType;
/* static */ constexpr const char* const CacheDatasetOp::kInputDataset;
/* static */ constexpr const char* const CacheDatasetOp::kFileName;
/* static */ constexpr const char* const CacheDatasetOp::kOutputTypes;
/* static */ constexpr const char* const CacheDatasetOp::kOutputShapes;

namespace {

constexpr char kKeyStrFormat[] = "%%%zuzu_%%%zuzu";
constexpr char kPaddingSizeStrFormat[] = "%zu";
constexpr char kFileDatasetPrefix[] = "File";
constexpr char kMode[] = "Mode";
constexpr char kLockFileSuffix[] = ".lockfile";
constexpr char kIterationCompleted[] = "iteration_completed";
constexpr char kCurIndex[] = "cur_index";
constexpr char kShardId[] = "shard_id";
constexpr char kCreatedAt[] = "Created at";
constexpr char kMemoryDatasetPrefix[] = "Memory";
constexpr char kMemoryCache[] = "MemoryCache";
constexpr char kCacheCompleted[] = "cache_completed";
constexpr char kIndex[] = "index";
constexpr char kImpl[] = "Impl";
constexpr char kCacheDataset[] = "CacheDataset";
constexpr char kIncompleteCacheErrorMessage[] =
    "The calling iterator did not fully read the dataset being cached. In "
    "order to avoid unexpected truncation of the dataset, the partially cached "
    "contents of the dataset will be discarded. This can happen if you have "
    "an input pipeline similar to `dataset.cache().take(k).repeat()`. You "
    "should use `dataset.take(k).cache().repeat()` instead.";
}  // namespace

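// Lazily materializes the elements of `input_` so that `Get(index)` can serve
// random-access reads: asking for index i pulls elements from an input
// iterator until element i has been cached.
//
// A minimal usage sketch (illustrative only; `ctx` and `dataset` stand in for
// a real OpKernelContext and DatasetBase):
//
//   PartialCache partial_cache(dataset);
//   std::vector<Tensor> element;
//   TF_RETURN_IF_ERROR(partial_cache.Get(ctx, /*index=*/5, &element));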
class PartialCache {
 public:
  explicit PartialCache(const DatasetBase* dataset) : input_(dataset) {}

  // Extends the temporary cache up to a given index and then updates
  // out_tensors with the element at that index.
  Status Get(OpKernelContext* ctx, int64 index,
             std::vector<Tensor>* out_tensors) {
    if (!iter_resource_) {
      TF_ASSIGN_OR_RETURN(iter_resource_,
                          GetIteratorResourceFromDataset(ctx, input_));
      TF_RETURN_IF_ERROR(iter_resource_->SetIteratorFromDataset(ctx, input_));
    }
    if (index >= cache_.size()) {
      TF_RETURN_IF_ERROR(ExtendTempCacheToIndex(index, ctx));
    }
    *out_tensors = cache_.at(index);
    return OkStatus();
  }

  // Returns the data which has been cached up to this point.
  std::vector<std::vector<Tensor>> GetCacheData() { return cache_; }

 private:
  Status ExtendTempCacheToIndex(int64 index, OpKernelContext* ctx) {
    bool end_of_sequence;
    while (cache_.size() <= index) {
      std::vector<Tensor> out_tensors;
      TF_RETURN_IF_ERROR(
          iter_resource_->GetNext(ctx, &out_tensors, &end_of_sequence));
      if (end_of_sequence) {
        return tensorflow::errors::OutOfRange("Index out of range [0, ",
                                              cache_.size(), "): ", index);
      }
      cache_.push_back(out_tensors);
    }
    return OkStatus();
  }

  StatusOr<core::RefCountPtr<IteratorResource>> GetIteratorResourceFromDataset(
      OpKernelContext* ctx, const DatasetBase* dataset) {
    FunctionLibraryRuntime* flr;
    std::unique_ptr<DeviceMgr> device_mgr(nullptr);
    std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
    std::unique_ptr<ProcessFunctionLibraryRuntime> plfr(nullptr);
    TF_RETURN_IF_ERROR(
        ctx->function_library()->Clone(&flib_def, &plfr, &flr, true));

    core::RefCountPtr<IteratorResource> iter_resource(new IteratorResource(
        ctx->env(), dataset->output_dtypes(), dataset->output_shapes(),
        std::move(device_mgr), std::move(flib_def), std::move(plfr), flr));
    return iter_resource;
  }

  const DatasetBase* input_;  // Not owned.
  core::RefCountPtr<IteratorResource> iter_resource_;
  std::vector<std::vector<Tensor>> cache_;
};

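// Base class for the file-backed variants of the cache dataset. Elements are
// persisted to a TensorBundle keyed by `FormatName(item_index, tensor_index)`,
// so that a later iterator created over the same `filename` reads them back.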
class CacheDatasetOp::FileDatasetBase : public DatasetBase {
 public:
  FileDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
                  string filename, Env* env)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        filename_(std::move(filename)),
        env_(env),
        num_tensors_(input->output_dtypes().size()),
        tensor_index_padding_size_(StringPaddingSize(num_tensors_)),
        item_index_padding_size_(StringPaddingSize(kMaxItems)),
        tensor_format_string_(strings::Printf(kKeyStrFormat,
                                              item_index_padding_size_,
                                              tensor_index_padding_size_)) {
    input_->Ref();
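    // kMaxItems - 1 == 9999999 renders as 7 characters under
    // kPaddingSizeStrFormat, hence the check below.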
    DCHECK_EQ(item_index_padding_size_, 7);
  }

  ~FileDatasetBase() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.dataset_prefix = kFileDatasetPrefix;
    return std::make_unique<FileIterator>(FileIterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.dataset_prefix = kFileDatasetPrefix;
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  int64_t CardinalityInternal() const override { return input_->Cardinality(); }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  const DatasetBase* const input_;
  const tstring filename_;

 private:
  static size_t StringPaddingSize(size_t num_tensors) {
    return strings::Printf(kPaddingSizeStrFormat, num_tensors - 1).size();
  }

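  // Builds a fixed-width bundle key. For example, with
  // item_index_padding_size_ == 7 (see the DCHECK in the constructor) and
  // tensor_index_padding_size_ == 1, `tensor_format_string_` is "%7zu_%1zu",
  // so FormatName(42, 0) yields "     42_0".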
  string FormatName(size_t item_index, size_t tensor_index) const {
    return strings::Printf(tensor_format_string_.c_str(), item_index,
                           tensor_index);
  }

  class FileIterator : public DatasetIterator<FileDatasetBase> {
   public:
    explicit FileIterator(const Params& params)
        : DatasetIterator<FileDatasetBase>(params) {
      if (params.dataset->env_
              ->FileExists(MetaFilename(params.dataset->filename_))
              .ok()) {
        mode_ = Mode::read;
      } else {
        mode_ = Mode::write;
      }
    }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(mu_);
      return InitializeIterator(ctx);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args),
                                       /*ratio=*/1);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kMode), mode_));
      return SaveInput(ctx, writer, iterator_);
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      {
        int64_t temp;
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kMode), &temp));
        mode_ = static_cast<Mode>(temp);
      }
      if (mode_ == Mode::write &&
          dataset()
              ->env_->FileExists(MetaFilename(dataset()->filename_))
              .ok()) {
        // This could happen if the cache was completely written after the
        // checkpoint was saved.
        LOG(WARNING)
            << "It looks like the cache was already completely written ("
            << MetaFilename(dataset()->filename_)
            << ") after the last checkpoint was saved. Attempting to read "
            << "the cache instead of continuing to write. If this is a "
            << "mistake, please remove the above file and try running again.";
        mode_ = Mode::read;
      }
      TF_RETURN_IF_ERROR(InitializeIterator(ctx));
      return RestoreInput(ctx, reader, iterator_);
    }

   private:
    // FileWriterIterator passes through and caches items from the input
    // FileDatasetBase.
    //
    // This iterator is used when the cache directory is not found on disk. It
    // creates the cache directory, and passes on the underlying iterator's
    // elements.
    //
    // Caching is performed by writing the input tensors to disk using the
    // `BundleWriter`. Note that the cache gets fully flushed to disk only
    // after the input iterator has been fully exhausted. If the program
    // exits before completion of an epoch, the cached state will be lost.
    // To ensure that the partial cache persists across sessions, one should
    // checkpoint the input pipeline. On each call to `SaveInternal`, the
    // partial cache gets flushed to disk in files with prefix
    // <filename>_<shard_id>, where shard_id is unique for each checkpoint.
    // When all elements have been produced, these shards get coalesced.
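    //
    // For example, assuming `filename_` is "/tmp/cache" and two checkpoints
    // were taken before the input was exhausted, shards "/tmp/cache_0",
    // "/tmp/cache_1", and "/tmp/cache_2" exist at completion and are merged
    // into a single bundle with prefix "/tmp/cache" by `Finish()`.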
    class FileWriterIterator : public DatasetIterator<FileDatasetBase> {
     public:
      explicit FileWriterIterator(const Params& params)
          : DatasetIterator<FileDatasetBase>(params),
            cur_index_(0),
            shard_id_(0),
            filename_(
                strings::StrCat(params.dataset->filename_, "_", shard_id_)),
            lockfile_(strings::StrCat(filename_, kLockFileSuffix)),
            lockfile_created_(false),
            iteration_completed_(false) {}

      ~FileWriterIterator() override {
        if (!dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
          LOG(WARNING) << kIncompleteCacheErrorMessage;
          std::vector<string> cache_files;
          Status s = dataset()->env_->GetMatchingPaths(
              strings::StrCat(filename_, "*"), &cache_files);
          if (!s.ok()) {
            LOG(WARNING) << "Failed to get matching files on " << filename_
                         << "* : " << s.ToString();
          }
          for (const string& path : cache_files) {
            s = dataset()->env_->DeleteFile(path);
            if (!s.ok()) {
              LOG(WARNING) << "Failed to delete " << path << " : "
                           << s.ToString();
            }
          }
        }
      }

      Status Initialize(IteratorContext* ctx) override {
        return dataset()->input_->MakeIterator(ctx, this, prefix(),
                                               &input_impl_);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        *end_of_sequence = false;
        TF_RETURN_IF_ERROR(EnsureLockFileExists(end_of_sequence));
        if (*end_of_sequence) {
          return OkStatus();
        }
        TF_RETURN_IF_ERROR(writer_->status());
        if (cur_index_ >= kMaxItems) {
          // As a courtesy, close the [truncated] cache file.
          Status s = Finish();
          if (!s.ok()) {
            LOG(ERROR) << s;
          }
          return errors::InvalidArgument(
              "Upstream iterator is producing more than ", kMaxItems,
              " items, which is more than the cache limit.");
        }

        TF_RETURN_IF_ERROR(
            input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
        if (*end_of_sequence && out_tensors->empty()) {
          TF_RETURN_IF_ERROR(Finish());
          cur_index_++;
          return OkStatus();
        }
        if (out_tensors->size() != dataset()->num_tensors_) {
          return errors::Internal(
              "Upstream iterator returned an invalid number of tensors. "
              "Expected ",
              dataset()->num_tensors_, ", got: ", out_tensors->size());
        }
        size_t tensor_index = 0;
        for (const Tensor& t : *out_tensors) {
          DCHECK_LT(tensor_index, dataset()->num_tensors_);
          string key = dataset()->FormatName(cur_index_, tensor_index++);
          TF_RETURN_IF_ERROR(writer_->Add(key, t));
        }
        if (*end_of_sequence) {
          TF_RETURN_IF_ERROR(Finish());
        }
        cur_index_++;
        return OkStatus();
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kCurIndex), cur_index_));

        if (iteration_completed_) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name(kIterationCompleted), ""));
          return OkStatus();
        }

        // The lockfile is created on the first call to GetNextInternal. The
        // absence of a lockfile means that GetNextInternal was not called
        // and hence nothing was written to cache. So we don't need to worry
        // about flushing the current shard. This ensures that we never write
        // empty shards.
        if (lockfile_created_) {
          // Flush the current bundle.
          TF_RETURN_IF_ERROR(writer_->Finish());

          // Note: We do not delete the lockfile here. We keep lockfiles of
          // all shards around until the entire cache has been written to
          // prevent concurrent iterators from corrupting any of the shards.

          // Start caching to a new shard.
          shard_id_++;
          filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_);
          lockfile_ = strings::StrCat(filename_, kLockFileSuffix);
          lockfile_created_ = false;
        }
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kShardId), shard_id_));
        return OkStatus();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        int64_t temp;
        // TODO(b/78048575): Update this when saving size_t tensors directly
        // is supported.
        {
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurIndex), &temp));
          cur_index_ = static_cast<size_t>(temp);
          if (cur_index_ != temp) {
            return errors::Internal("Invalid value for cur_index ", temp);
          }
        }

        if (reader->Contains(full_name(kIterationCompleted))) {
          iteration_completed_ = true;
          return OkStatus();
        }

        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));

        // TODO(b/78048575): Update this when saving size_t tensors directly
        // is supported.
        {
          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kShardId), &temp));
          shard_id_ = static_cast<size_t>(temp);
          if (shard_id_ != temp) {
            return errors::Internal("Invalid value for shard_id ", temp);
          }
        }
        filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_);
        lockfile_ = strings::StrCat(filename_, kLockFileSuffix);
        writer_ = std::make_unique<BundleWriter>(dataset()->env_, filename_);
        return OkStatus();
      }

     private:
      Status EnsureLockFileExists(bool* end_of_sequence)
          TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        if (iteration_completed_) {
          *end_of_sequence = true;
          return OkStatus();
        }
        if (lockfile_created_) {
          return OkStatus();
        }

        // Perform rudimentary locking to help catch concurrent writes to the
        // same cache files.

        // 1. Check that a checkpoint for the shard has not already been
        // written.
        if (dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
          return errors::AlreadyExists("Existing cache files found: \n",
                                       MetaFilename(filename_), "\n",
                                       DataFilename(filename_, 0, 1), "\n",
                                       "To continue, delete the above files.");
        }

        // 2. Check that there isn't a concurrent iterator that is writing
        // to cache.
        if (dataset()->env_->FileExists(lockfile_).ok()) {
          // Attempt to read the contents of the lockfile.
          char contents_scratch[151] = {0};  // Initialize all to 0.
          StringPiece contents;
          std::unique_ptr<RandomAccessFile> file;
          if (dataset()->env_->NewRandomAccessFile(lockfile_, &file).ok()) {
            file->Read(0, 150, &contents, contents_scratch).IgnoreError();
          }
          return errors::AlreadyExists(
              "There appears to be a concurrent caching iterator running - "
              "cache lockfile already exists ('",
              lockfile_,
              "'). If you are sure no other running TF computations are "
              "using this cache prefix, delete the lockfile and "
              "re-initialize the iterator. Lockfile contents: ",
              contents);
        }
        // Create the file, and write some basic contents.
        std::unique_ptr<WritableFile> lockfile;
        TF_RETURN_IF_ERROR(
            dataset()->env_->NewWritableFile(lockfile_, &lockfile));
        TF_RETURN_IF_ERROR(lockfile->Append(
            strings::StrCat(kCreatedAt, ": ", EnvTime::NowSeconds())));

        // At this point we know that
        // 1. There is no conflicting checkpoint with prefix `filename_`.
        // 2. There is no concurrent session that is trying to write a ckpt
        //    to filename.
        // So it is safe to create a BundleWriter here. Note that it is
        // unsafe to initialize the BundleWriter anywhere the above
        // conditions are not met since BundleWriter's constructor creates
        // new temp files which can delete the temp files created by a
        // BundleWriter in another Session.
        writer_ = std::make_unique<BundleWriter>(dataset()->env_, filename_);
        lockfile_created_ = true;
        return OkStatus();
      }

      Status Finish() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        iteration_completed_ = true;
        // Flush the current bundle.
        TF_RETURN_IF_ERROR(writer_->Finish());
        // Merge all the bundles.
        // Currently there are `shard_id_ + 1` bundles, one for each
        // checkpoint. Each bundle has prefix <filename>_<id> where `id` is an
        // integer starting at 0 and incremented by 1 for each new checkpoint.
        // We merge all these bundles into a bundle with prefix <filename> so
        // that the next call to `MakeIterator` can build a
        // `FileReaderIterator`.
        {
          std::vector<tstring> prefixes;
          prefixes.reserve(shard_id_ + 1);
          for (size_t i = 0; i <= shard_id_; ++i) {
            prefixes.emplace_back(
                strings::StrCat(dataset()->filename_, "_", i));
          }
          TF_RETURN_IF_ERROR(
              MergeBundles(dataset()->env_, prefixes, dataset()->filename_));
        }
        // Delete all lockfiles.
        for (size_t i = 0; i <= shard_id_; ++i) {
          TF_RETURN_IF_ERROR(dataset()->env_->DeleteFile(
              strings::StrCat(dataset()->filename_, "_", i, kLockFileSuffix)));
        }
        return OkStatus();
      }

      mutex mu_;
      size_t cur_index_ TF_GUARDED_BY(mu_);
      // Index of the current shard. This gets incremented whenever a new
      // cache shard is saved.
      size_t shard_id_ TF_GUARDED_BY(mu_);
      std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
      // The current prefix for the cache file. This is equal to
      // `StrCat(dataset()->filename_, "_", shard_id_)`.
      string filename_;
      std::unique_ptr<BundleWriter> writer_ TF_GUARDED_BY(mu_);
      string lockfile_ TF_GUARDED_BY(mu_);
      bool lockfile_created_ TF_GUARDED_BY(mu_);
      bool iteration_completed_ TF_GUARDED_BY(mu_);
    };  // FileWriterIterator

    class FileReaderIterator : public DatasetIterator<FileDatasetBase> {
     public:
      explicit FileReaderIterator(const Params& params)
          : DatasetIterator<FileDatasetBase>(params),
            cur_index_(0),
            reader_(dataset()->env_, dataset()->filename_),
            iterator_restored_(false) {}

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        *end_of_sequence = false;
        TF_RETURN_IF_ERROR(reader_.status());
        if (!reader_.Valid()) {
          *end_of_sequence = true;
          return OkStatus();
        }
        out_tensors->clear();
        out_tensors->resize(dataset()->num_tensors_);

        for (size_t i = 0; i < dataset()->num_tensors_; ++i) {
          // When the iterator is restored from the checkpoint, `reader_` is
          // already pointing at `key` so we do not need to skip the header
          // entry.
          if (!iterator_restored_) {
            reader_.Next();  // The first entry in the table is a header.
          } else {
            iterator_restored_ = false;
          }
          if (!reader_.Valid()) {
            out_tensors->clear();
            *end_of_sequence = true;
            return OkStatus();
          }
          StringPiece key = reader_.key();
          DCHECK_EQ(key, dataset()->FormatName(cur_index_, i));
          TF_RETURN_IF_ERROR(reader_.ReadCurrent(&(*out_tensors)[i]));
          TF_RETURN_IF_ERROR(reader_.status());
        }
        cur_index_++;
        return OkStatus();
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name(kCurIndex), cur_index_));
        return OkStatus();
      }

      Status RestoreInternal(
          IteratorContext* ctx,
          IteratorStateReader* iterator_state_reader) override {
        mutex_lock l(mu_);
        {
          // TODO(b/78048575): Update this when saving size_t tensors directly
          // is supported.
          int64_t temp;
          TF_RETURN_IF_ERROR(
              iterator_state_reader->ReadScalar(full_name(kCurIndex), &temp));
          cur_index_ = static_cast<size_t>(temp);
          if (cur_index_ != temp) {
            return errors::Internal("Invalid value for cur_index ", temp);
          }
        }
        if (!reader_.Valid()) {
          return errors::Internal("Error initializing BundleReader.");
        }
        reader_.Seek(dataset()->FormatName(cur_index_, 0));
        iterator_restored_ = true;
        return OkStatus();
      }

     private:
      mutex mu_;
      size_t cur_index_ TF_GUARDED_BY(mu_);
      BundleReader reader_ TF_GUARDED_BY(mu_);
      bool iterator_restored_ TF_GUARDED_BY(mu_);
    };  // FileReaderIterator

    Status InitializeIterator(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      // We intentionally use the same prefix for both `FileReaderIterator` and
      // `FileWriterIterator`. Since at any time there will be at most one of
      // them alive, there should be no conflicts. This allows both iterators to
      // use a common key for `cur_index`. We leverage this in the corner case
      // when this iterator is restored from an old checkpoint in `write` mode
      // and the cache has been completely flushed to disk since then. In that
      // case we simply build a `FileReaderIterator` and seek to the
      // `cur_index`.
      switch (mode_) {
        case Mode::read:
          iterator_ =
              std::make_unique<FileReaderIterator>(FileReaderIterator::Params{
                  dataset(), strings::StrCat(prefix(), kImpl)});
          break;
        case Mode::write:
          iterator_ =
              std::make_unique<FileWriterIterator>(FileWriterIterator::Params{
                  dataset(), strings::StrCat(prefix(), kImpl)});
      }
      TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
      return iterator_->Initialize(ctx);
    }

    mutex mu_;
    enum Mode { read, write };
    Mode mode_ TF_GUARDED_BY(mu_);
    std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);
  };  // FileIterator

  Env* const env_;
  const size_t num_tensors_;
  const size_t tensor_index_padding_size_;
  static constexpr size_t kMaxItems = 10000000;  // 10 million
  const size_t item_index_padding_size_;
  const string tensor_format_string_;
};  // FileDatasetBase

class CacheDatasetOp::FileDataset : public CacheDatasetOp::FileDatasetBase {
 public:
  using FileDatasetBase::FileDatasetBase;

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph));
    Node* filename = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename));
    TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output));
    return OkStatus();
  }
};

class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDatasetBase {
 public:
  explicit FileDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
                         string filename, Env* env,
                         const Tensor& resource_handle)
      : FileDatasetBase(ctx, input, filename, env),
        resource_handle_(resource_handle) {}

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    Node* filename_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename_node));
    Node* resource_handle_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {input_node, filename_node, resource_handle_node}, output));
    return OkStatus();
  }

 private:
  const Tensor resource_handle_;
};

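// Base class for the in-memory variants of the cache dataset. Produced
// elements accumulate in a shared `MemoryCache`; once the input is exhausted
// the cache is marked complete and later iterators read from it directly.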
class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
 public:
  explicit MemoryDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
                             std::shared_ptr<MemoryCache> cache)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        cache_(std::move(cache)) {
    input_->Ref();
  }

  ~MemoryDatasetBase() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.dataset_prefix = kMemoryDatasetPrefix;
    return std::make_unique<MemoryIterator>(
        MemoryIterator::Params{
            this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
        cache_.get());
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return input_->output_shapes();
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.dataset_prefix = kMemoryDatasetPrefix;
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  int64_t CardinalityInternal() const override {
    return input_->Cardinality();
  }

  int64_t CardinalityInternal(CardinalityOptions options) const override {
    return input_->Cardinality(options);
  }

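  // Random-access lookup. Elements that have not been produced yet are pulled
  // from the input on demand through the lazily created `partial_cache_`.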
  Status Get(OpKernelContext* ctx, int64 index,
             std::vector<Tensor>* out_tensors) const override {
    mutex_lock l(mu_);

    CardinalityOptions options;
    options.set_compute_level(CardinalityOptions::CARDINALITY_COMPUTE_LOW);
    int64_t cardinality = Cardinality(options);

    if (cardinality != kUnknownCardinality &&
        cardinality != kInfiniteCardinality && index >= cardinality) {
      return errors::OutOfRange("Index out of range [0, ", cardinality,
                                "): ", index);
    }
    if (!partial_cache_) {
      partial_cache_ = std::make_unique<PartialCache>(input_);
    }
    return partial_cache_->Get(ctx, index, out_tensors);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  class MemoryIterator : public DatasetIterator<MemoryDatasetBase> {
   public:
    explicit MemoryIterator(const Params& params, MemoryCache* cache)
        : DatasetIterator<MemoryDatasetBase>(params), cache_(cache) {}

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(mu_);
      return InitializeIterator(ctx);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args),
                                       /*ratio=*/1);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (cache_->IsCompleted()) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheCompleted), ""));
        TF_RETURN_IF_ERROR(
            WriteElementsToCheckpoint(writer, prefix(), cache_->data()));
      }
      return SaveInput(ctx, writer, iterator_);
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      iterator_.reset();
      cache_->Reset();
      if (reader->Contains(full_name(kCacheCompleted))) {
        std::vector<std::vector<Tensor>> temp_cache;
        TF_RETURN_IF_ERROR(
            ReadElementsFromCheckpoint(ctx, reader, prefix(), &temp_cache));
        cache_->Complete(std::move(temp_cache));
      }
      TF_RETURN_IF_ERROR(InitializeIterator(ctx));
      return RestoreInput(ctx, reader, iterator_);
    }

   private:
    class MemoryWriterIterator : public DatasetIterator<MemoryDatasetBase> {
     public:
      explicit MemoryWriterIterator(const Params& params, MemoryCache* cache)
          : DatasetIterator<MemoryDatasetBase>(params), cache_(cache) {}

      ~MemoryWriterIterator() override {
        mutex_lock l(mu_);
        if (!temp_cache_.empty() && !cache_->IsCompleted()) {
          LOG(WARNING) << kIncompleteCacheErrorMessage;
          cache_->Reset();
        }
      }

      Status Initialize(IteratorContext* ctx) override {
        return dataset()->input_->MakeIterator(ctx, this, prefix(),
                                               &input_impl_);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(
            input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
        if (*end_of_sequence) {
          if (!cache_->IsCompleted()) {
            VLOG(2) << "Finalizing the cache because EOF has been reached.";
            cache_->Complete(std::move(temp_cache_));
          }
          return OkStatus();
        }
        RecordBufferEnqueue(ctx, *out_tensors);
        temp_cache_.emplace_back(*out_tensors);
        if (temp_cache_.size() == dataset()->input_->Cardinality()) {
          VLOG(2) << "Finalizing the cache because its size matches the "
                     "expected input cardinality.";
          cache_->Complete(std::move(temp_cache_));
        }
        return OkStatus();
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        if (!cache_->IsCompleted()) {
          TF_RETURN_IF_ERROR(
              WriteElementsToCheckpoint(writer, prefix(), temp_cache_));
        }
        return SaveInput(ctx, writer, input_impl_);
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        if (!reader->Contains(full_name(kCacheCompleted))) {
          TF_RETURN_IF_ERROR(
              ReadElementsFromCheckpoint(ctx, reader, prefix(), &temp_cache_));
        }
        return RestoreInput(ctx, reader, input_impl_);
      }

     private:
      mutex mu_;
      std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
      MemoryCache* const cache_ TF_GUARDED_BY(mu_);  // not owned.
      std::vector<std::vector<Tensor>> temp_cache_ TF_GUARDED_BY(mu_);
    };  // MemoryWriterIterator

    class MemoryReaderIterator : public DatasetIterator<MemoryDatasetBase> {
     public:
      explicit MemoryReaderIterator(const Params& params, MemoryCache* cache)
          : DatasetIterator<MemoryDatasetBase>(params),
            cache_(cache),
            index_(0) {}

      Status Initialize(IteratorContext* ctx) override {
        // The memory allocated for the cache is owned by the parent
        // dataset, but performance modeling uses the iterator abstraction,
        // so we record the memory allocated for the cache here. The caveat
        // is that this is incorrect if there are concurrent instances of this
        // iterator.
        tf_shared_lock l(mu_);
        for (size_t i = 0; i < cache_->size(); ++i) {
          RecordBufferEnqueue(ctx, cache_->at(i));
        }
        return OkStatus();
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        if (index_ < cache_->size()) {
          const std::vector<Tensor>& cache_tensors = cache_->at(index_);
          out_tensors->insert(out_tensors->begin(), cache_tensors.begin(),
                              cache_tensors.end());
          index_++;
          *end_of_sequence = false;
          return OkStatus();
        } else {
          *end_of_sequence = true;
          return OkStatus();
        }
      }

     protected:
      std::shared_ptr<model::Node> CreateNode(
          IteratorContext* ctx, model::Node::Args args) const override {
        return model::MakeKnownRatioNode(std::move(args),
                                         /*ratio=*/1);
      }

      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_));
        return OkStatus();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        {
          // kIndex will not be set if we are restoring from a checkpoint
          // written by a MemoryWriterIterator that has completed its cache.
          int64_t temp = cache_->size();
          if (reader->Contains(full_name(kIndex))) {
            TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &temp));
          }
          index_ = static_cast<size_t>(temp);
        }
        return OkStatus();
      }

     private:
      mutex mu_;
      MemoryCache* const cache_ TF_GUARDED_BY(mu_);  // not owned.
      size_t index_ TF_GUARDED_BY(mu_);
    };  // MemoryReaderIterator

    Status InitializeIterator(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (cache_->IsCompleted()) {
        iterator_ = std::make_unique<MemoryReaderIterator>(
            MemoryReaderIterator::Params{dataset(),
                                         strings::StrCat(prefix(), kImpl)},
            cache_);
      } else {
        iterator_ = std::make_unique<MemoryWriterIterator>(
            MemoryWriterIterator::Params{dataset(),
                                         strings::StrCat(prefix(), kImpl)},
            cache_);
      }
      TF_RETURN_IF_ERROR(iterator_->InitializeBase(ctx, this));
      return iterator_->Initialize(ctx);
    }

    mutex mu_;
    MemoryCache* cache_ TF_GUARDED_BY(mu_);  // not owned.
    std::unique_ptr<IteratorBase> iterator_ TF_GUARDED_BY(mu_);
  };  // MemoryIterator

  mutable mutex mu_;
  const DatasetBase* const input_;
  const std::shared_ptr<MemoryCache> cache_;
  mutable std::unique_ptr<PartialCache> partial_cache_ TF_GUARDED_BY(mu_);
};  // MemoryDatasetBase

// This version of the memory dataset has exclusive ownership of the memory
// cache resource. It supports sharing of the cache across different iterations
// of the `repeat` transformation but not across different iterators.
class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase {
 public:
  MemoryDataset(OpKernelContext* ctx, const DatasetBase* input,
                MemoryCacheManager* manager, ResourceHandle&& resource_handle)
      : MemoryDatasetBase(ctx, input, manager->get()),
        manager_(manager),
        resource_handle_(std::move(resource_handle)),
        resource_mgr_(ctx->resource_manager()) {}

  ~MemoryDataset() override {
    manager_->Unref();
    Status s = resource_mgr_->Delete<MemoryCacheManager>(
        resource_handle_.container(), resource_handle_.name());
    if (!s.ok()) {
      LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
    }
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    Node* filename_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node));
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_node, filename_node}, output));
    return OkStatus();
  }

 private:
  MemoryCacheManager* const manager_;  // Owned.
  const ResourceHandle resource_handle_;
  ResourceMgr* const resource_mgr_;  // Not owned.
};

// This version of the memory dataset has shared ownership of the memory cache
// resource. It supports sharing of the cache across different iterations of
// the `repeat` transformation and also across different iterators.
class CacheDatasetOp::MemoryDatasetV2
    : public CacheDatasetOp::MemoryDatasetBase {
 public:
  MemoryDatasetV2(OpKernelContext* ctx, const DatasetBase* input,
                  MemoryCacheManager* manager, ResourceHandle&& resource_handle,
                  bool owns_resource)
      : MemoryDatasetBase(ctx, input, manager->get()),
        manager_(manager),
        owns_resource_(owns_resource),
        resource_handle_(std::move(resource_handle)),
        resource_mgr_(ctx->resource_manager()) {}

  ~MemoryDatasetV2() override {
    manager_->Unref();
    if (owns_resource_) {
      Status s = resource_mgr_->Delete<MemoryCacheManager>(
          resource_handle_.container(), resource_handle_.name());
      if (!s.ok()) {
        LOG(WARNING) << "Failed to delete cache resource: " << s.ToString();
      }
    }
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    Node* filename_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(tstring(""), &filename_node));
    Node* resource_handle_node = nullptr;
    Tensor handle(DT_RESOURCE, TensorShape({}));
    handle.scalar<ResourceHandle>()() = resource_handle_;
    TF_RETURN_IF_ERROR(b->AddTensor(handle, &resource_handle_node));
    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {input_node, filename_node, resource_handle_node}, output));
    return OkStatus();
  }

 private:
  MemoryCacheManager* const manager_;  // Owned.
  const bool owns_resource_;
  const ResourceHandle resource_handle_;
  ResourceMgr* const resource_mgr_;  // Not owned.
};

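// An op registered under the name "CacheDataset" is treated as v1; any other
// registered name (here, "CacheDatasetV2") is treated as v2, which consumes an
// additional input: a handle to the cache resource.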
CacheDatasetOp::CacheDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->def().op() == kCacheDataset ? 1 : 2) {}

void CacheDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                 DatasetBase** output) {
  // Parse out the filename tensor.
  tstring filename;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<tstring>(ctx, kFileName, &filename));
  if (filename.empty()) {
    static std::atomic<int64_t> resource_id_counter(0);
    const string& container = ctx->resource_manager()->default_container();
    auto name = strings::StrCat(ctx->op_kernel().name(), "/", kMemoryCache, "_",
                                resource_id_counter.fetch_add(1));
    if (op_version_ == 2) {
      bool owns_resource = false;
      MemoryCacheManager* manager = nullptr;
      auto handle = HandleFromInput(ctx, 2);
      Status s = ctx->resource_manager()->Lookup<MemoryCacheManager>(
          handle.container(), handle.name(), &manager);
      if (errors::IsNotFound(s)) {
        owns_resource = true;
        OP_REQUIRES_OK(
            ctx,
            ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
                container, name, &manager, [](MemoryCacheManager** manager) {
                  *manager = new MemoryCacheManager();
                  return OkStatus();
                }));
        handle = MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
      } else {
        OP_REQUIRES_OK(ctx, s);
      }
      // Ownership of manager is transferred onto `MemoryDatasetV2`.
      *output = new MemoryDatasetV2(ctx, input, manager, std::move(handle),
                                    owns_resource);
    } else {
      MemoryCacheManager* manager;
      OP_REQUIRES_OK(
          ctx, ctx->resource_manager()->LookupOrCreate<MemoryCacheManager>(
                   container, name, &manager, [](MemoryCacheManager** manager) {
                     *manager = new MemoryCacheManager();
                     return OkStatus();
                   }));
      auto handle =
          MakeResourceHandle<MemoryCacheManager>(ctx, container, name);
      // Ownership of manager is transferred onto `MemoryDataset`.
      *output = new MemoryDataset(ctx, input, manager, std::move(handle));
    }
  } else {
    if (op_version_ == 2) {
      *output =
          new FileDatasetV2(ctx, input, filename, ctx->env(), ctx->input(2));
    } else {
      *output = new FileDataset(ctx, input, filename, ctx->env());
    }
  }
}

namespace {
REGISTER_KERNEL_BUILDER(Name("CacheDataset").Device(DEVICE_CPU),
                        CacheDatasetOp);
REGISTER_KERNEL_BUILDER(Name("CacheDatasetV2").Device(DEVICE_CPU),
                        CacheDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow