xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/standalone.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DATA_STANDALONE_H_
17 #define TENSORFLOW_CORE_DATA_STANDALONE_H_
18 
19 #include <functional>
20 #include <memory>
21 
22 #include "tensorflow/core/common_runtime/device_mgr.h"
23 #include "tensorflow/core/data/unbounded_thread_pool.h"
24 #include "tensorflow/core/framework/dataset.h"
25 #include "tensorflow/core/framework/function_handle_cache.h"
26 #include "tensorflow/core/lib/core/threadpool.h"
27 #include "tensorflow/core/public/session_options.h"
28 
29 namespace tensorflow {
30 namespace data {
31 namespace standalone {
32 
33 // The purpose of the API in this file is to facilitate standalone execution of
34 // a tf.data input pipeline graph.
35 //
36 // The API exposes two abstractions -- a `Dataset` and an `Iterator` -- which
37 // encapsulate the TensorFlow runtime.
38 //
39 // The `Dataset` abstraction represents an input pipeline as a collection
40 // of data sources and a logical plan of transformations that operate over the
41 // data.
42 //
43 // The `Iterator` abstraction represents an execution of an input pipeline that
44 // can be used to enumerate its elements.
45 //
46 // Example usage:
47 //
48 //   // Create a `Dataset` by running the `graph_def` graph.
49 //   tensorflow::data::standalone::Dataset::Params params;
50 //   std::unique_ptr<tensorflow::data::standalone::Dataset> dataset;
51 //   Status s = tensorflow::data::standalone::Dataset::FromGraph(
52 //      params, graph_def, &dataset);
53 //   if (!s.ok()) { /* error handling */ }
54 //
55 //   std::unique_ptr<tensorflow::data::standalone::Iterator> iterator;
56 //   s = dataset->MakeIterator(&iterator);
57 //   if (!s.ok()) { /* error handling */ }
58 //
59 //   bool end_of_input = false;
60 //   while (!end_of_input) {
61 //     std::vector<tensorflow::Tensor> outputs;
62 //     s = iterator->GetNext(&outputs, &end_of_input);
63 //     if (!s.ok()) { /* error handling */ }
64 //     if (!end_of_input) { /* output handling */ }
65 //   }
66 
67 class Dataset;
68 
69 // Represents an execution of an input pipeline that can be used to enumerate
70 // its elements.
//
// The only constructor is private and `Dataset` is a friend, so client code
// obtains instances through `Dataset::MakeIterator` rather than constructing
// them directly.
71 class Iterator {
72  public:
73   // Returns the next element of the input pipeline (if there is one) and an
74   // indication of whether the end of the input pipeline has been reached.
  //
  // `outputs` is the out-parameter for the element's tensors and
  // `end_of_input` is the out-parameter for the end-of-sequence flag.
  // Following the example usage at the top of this file, callers check the
  // returned `Status` first and only consume `*outputs` when `*end_of_input`
  // is false.
75   Status GetNext(std::vector<Tensor>* outputs, bool* end_of_input);
76
77  private:
78   friend class Dataset;
79
  // Private: only `Dataset` (a friend) can create iterators.
  // NOTE(review): presumably takes ownership of both raw pointers, since the
  // members below are `std::unique_ptr`s of the same types -- confirm against
  // the constructor definition, which is not in this header.
80   Iterator(IteratorBase* iterator, IteratorContext* ctx);
81
  // The wrapped tf.data iterator.
82   std::unique_ptr<IteratorBase> iterator_;
  // Iterator context held alongside `iterator_`.
  // NOTE(review): presumably handed to `iterator_` on each `GetNext` call --
  // confirm in the implementation.
83   std::unique_ptr<IteratorContext> ctx_;
84 };
85 
86 // Represents an input pipeline as a collection of data sources and a logical
87 // plan of transformations that operate over the data.
//
// The constructor is private; instances are created through the static
// `FromGraph` factory.
88 class Dataset {
89  public:
90   // Parameters for `Dataset` creation (e.g. TensorFlow runtime configuration).
91   struct Params {
    // Session-level options supplied by the caller of `FromGraph`.
92     SessionOptions session_options;
93   };
94
95   // Creates a new `Dataset` instance by running the given dataset graph.
  //
  // `params` is taken by value, `graph_def` is the graph to run, and `result`
  // is the out-parameter that receives the new instance. As in the example
  // usage at the top of this file, check the returned `Status` before using
  // `*result`.
96   static Status FromGraph(Params params, const GraphDef& graph_def,
97                           std::unique_ptr<Dataset>* result);
98
  // NOTE(review): user-declared destructor; presumably releases the raw
  // `finalized_dataset_` / `original_dataset_` pointers marked `// owned`
  // below -- confirm against the definition in the .cc file.
99   ~Dataset();
100
101   // Creates an iterator for this dataset.
  // `result` is the out-parameter that receives the new iterator.
102   Status MakeIterator(std::unique_ptr<Iterator>* result);
103   // Creates an iterator for this dataset using the given split providers.
  // The vector of `std::unique_ptr`s is taken by value, so the caller hands
  // the providers over to this call. `result` receives the new iterator.
104   Status MakeIterator(
105       std::vector<std::unique_ptr<SplitProvider>> split_providers,
106       std::unique_ptr<Iterator>* result);
107
108   // Creates split providers for this dataset.
  // `result` is the out-parameter that receives the created providers.
109   Status MakeSplitProviders(
110       std::vector<std::unique_ptr<SplitProvider>>* result);
111   // Returns a pointer to the underlying dataset.
  // NOTE(review): this header does not say whether that is
  // `finalized_dataset_` or `original_dataset_` -- confirm in the .cc file.
112   const DatasetBase* Get() const;
113
114  private:
  // Private: used by the `FromGraph` factory.
  // NOTE(review): presumably takes ownership of the raw pointer arguments,
  // which mirror the owned / `std::unique_ptr` members below -- confirm
  // against the constructor definition.
115   Dataset(DatasetBase* finalized_dataset, DatasetBase* original_dataset,
116           DeviceMgr* device_mgr, ProcessFunctionLibraryRuntime* pflr,
117           FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool,
118           std::function<void(std::function<void()>)> runner);
119
  // NOTE(review): C++ destroys data members in reverse declaration order, so
  // check the .cc file for dependencies between these members before
  // reordering them.
120   DatasetBase* finalized_dataset_;  // owned
121   DatasetBase* original_dataset_;   // owned
122   std::unique_ptr<DeviceMgr> device_mgr_;
123   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
124   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
125   std::unique_ptr<thread::ThreadPool> interop_threadpool_;
126   std::unique_ptr<FunctionHandleCache> function_handle_cache_;
  // Callable that accepts a unit of work (`std::function<void()>`).
  // NOTE(review): presumably schedules the work on `interop_threadpool_` --
  // confirm.
127   std::function<void(std::function<void()>)> runner_;
128   ResourceMgr resource_mgr_;
129   CancellationManager cancellation_manager_;
130   UnboundedThreadPool unbounded_thread_pool_;
131 };
132 
133 }  // namespace standalone
134 }  // namespace data
135 }  // namespace tensorflow
136 
137 #endif  // TENSORFLOW_CORE_DATA_STANDALONE_H_
138