xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/standalone.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/standalone.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 
24 #include "absl/memory/memory.h"
25 #include "tensorflow/core/common_runtime/device_factory.h"
26 #include "tensorflow/core/common_runtime/device_mgr.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/common_runtime/graph_runner.h"
30 #include "tensorflow/core/common_runtime/process_util.h"
31 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
32 #include "tensorflow/core/data/root_dataset.h"
33 #include "tensorflow/core/framework/dataset.h"
34 #include "tensorflow/core/framework/op_kernel.h"
35 #include "tensorflow/core/graph/graph.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/platform/refcount.h"
38 #include "tensorflow/core/public/version.h"
39 #include "tensorflow/core/util/ptr_util.h"
40 
41 namespace tensorflow {
42 namespace data {
43 namespace standalone {
44 
45 namespace {
46 
CreateParams(ProcessFunctionLibraryRuntime * pflr,DeviceMgr * device_mgr,std::function<void (std::function<void ()>)> * runner)47 OpKernelContext::Params CreateParams(
48     ProcessFunctionLibraryRuntime* pflr, DeviceMgr* device_mgr,
49     std::function<void(std::function<void()>)>* runner) {
50   OpKernelContext::Params params;
51   params.function_library = pflr->GetFLR("/device:CPU:0");
52   params.device = device_mgr->ListDevices()[0];
53   params.runner = runner;
54   return params;
55 }
56 
57 }  // namespace
58 
GetNext(std::vector<Tensor> * outputs,bool * end_of_input)59 Status Iterator::GetNext(std::vector<Tensor>* outputs, bool* end_of_input) {
60   return iterator_->GetNext(ctx_.get(), outputs, end_of_input);
61 }
62 
Iterator(IteratorBase * iterator,IteratorContext * ctx)63 Iterator::Iterator(IteratorBase* iterator, IteratorContext* ctx)
64     : iterator_(iterator), ctx_(ctx) {}
65 
FromGraph(Params params,const GraphDef & graph_def,std::unique_ptr<Dataset> * result)66 Status Dataset::FromGraph(Params params, const GraphDef& graph_def,
67                           std::unique_ptr<Dataset>* result) {
68   Graph graph(OpRegistry::Global());
69   TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
70 
71   // Instantiate enough of the TF runtime to run `graph` on a single CPU device.
72   auto device_mgr = std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
73       "CPU", params.session_options, "/job:localhost/replica:0/task:0"));
74   Device* device = device_mgr->ListDevices()[0];
75   // Create a copy of the `FunctionLibraryDefinition` to extend lifetime beyond
76   // the lifetime of `graph`.
77   auto flib_def = std::make_unique<FunctionLibraryDefinition>(
78       OpRegistry::Global(), graph_def.library());
79   auto pflr = std::make_unique<ProcessFunctionLibraryRuntime>(
80       device_mgr.get(), Env::Default(), /*config=*/nullptr,
81       TF_GRAPH_DEF_VERSION, flib_def.get(), OptimizerOptions{},
82       /*thread_pool=*/nullptr, /*parent=*/nullptr,
83       /*session_metadata=*/nullptr,
84       Rendezvous::Factory{
85           [](const int64_t, const DeviceMgr* device_mgr, Rendezvous** r) {
86             *r = new IntraProcessRendezvous(device_mgr);
87             return OkStatus();
88           }});
89 
90   string fetch_node = "";
91   for (const auto& node : graph_def.node()) {
92     if (node.op() == "_Retval") {
93       fetch_node = node.input(0);
94     }
95   }
96   if (fetch_node.empty()) {
97     return errors::NotFound("Failed to find a _Retval op in the given dataset");
98   }
99 
100   // Run graph up to `output_node` and extract the `DatasetBase` stored in the
101   // DT_VARIANT output tensor.
102   std::vector<Tensor> outputs;
103   GraphRunner graph_runner(device);
104   TF_RETURN_IF_ERROR(graph_runner.Run(&graph, pflr->GetFLR("/device:CPU:0"), {},
105                                       {fetch_node}, &outputs));
106   data::DatasetBase* dataset;
107   TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
108 
109   data::DatasetBase* finalized_dataset;
110   std::unique_ptr<thread::ThreadPool> pool(
111       NewThreadPoolFromSessionOptions(params.session_options));
112   std::function<void(std::function<void()>)> runner =
113       [&pool](std::function<void()> c) { pool->Schedule(std::move(c)); };
114   OpKernelContext::Params op_params =
115       CreateParams(pflr.get(), device_mgr.get(), &runner);
116   OpKernelContext ctx(&op_params, /*num_outputs=*/0);
117   TF_RETURN_IF_ERROR(data::FinalizeDataset(&ctx, dataset, &finalized_dataset));
118   core::ScopedUnref unref(finalized_dataset);
119   *result = WrapUnique(new Dataset(
120       finalized_dataset, dataset, device_mgr.release(), pflr.release(),
121       flib_def.release(), pool.release(), std::move(runner)));
122   return OkStatus();
123 }  // static
124 
MakeIterator(std::vector<std::unique_ptr<SplitProvider>> split_providers,std::unique_ptr<Iterator> * result)125 Status Dataset::MakeIterator(
126     std::vector<std::unique_ptr<SplitProvider>> split_providers,
127     std::unique_ptr<Iterator>* result) {
128   // Create an `IteratorContext`, which bundles together the necessary runtime
129   // support to create and get elements from an iterator.
130   std::unique_ptr<IteratorContext> ctx;
131   // NOTE(mrry): In the current API, an `IteratorContext` is always initially
132   // created from an `OpKernelContext*`, so we need to create `OpKernelContext`
133   // with a valid subset of parameters.
134   OpKernelContext::Params op_params =
135       CreateParams(pflr_.get(), device_mgr_.get(), &runner_);
136   OpKernelContext op_ctx(&op_params, /*num_outputs=*/0);
137   IteratorContext::Params params(&op_ctx);
138   params.cancellation_manager = &cancellation_manager_;
139   params.function_handle_cache = function_handle_cache_.get();
140   params.resource_mgr = &resource_mgr_;
141   std::move(split_providers.begin(), split_providers.end(),
142             std::back_inserter(params.split_providers));
143   params.thread_factory = unbounded_thread_pool_.get_thread_factory();
144   params.thread_pool = &unbounded_thread_pool_;
145   ctx = std::make_unique<IteratorContext>(std::move(params));
146 
147   // Create the iterator from the dataset.
148   std::unique_ptr<IteratorBase> iterator;
149   TF_RETURN_IF_ERROR(finalized_dataset_->MakeIterator(
150       ctx.get(), /*parent=*/nullptr, "Iterator", &iterator));
151   *result = WrapUnique(new Iterator(iterator.release(), ctx.release()));
152 
153   return OkStatus();
154 }
155 
MakeIterator(std::unique_ptr<Iterator> * result)156 Status Dataset::MakeIterator(std::unique_ptr<Iterator>* result) {
157   return MakeIterator(/*split_providers=*/{}, result);
158 }
159 
MakeSplitProviders(std::vector<std::unique_ptr<SplitProvider>> * result)160 Status Dataset::MakeSplitProviders(
161     std::vector<std::unique_ptr<SplitProvider>>* result) {
162   return finalized_dataset_->MakeSplitProviders(result);
163 }
164 
Get() const165 const DatasetBase* Dataset::Get() const { return finalized_dataset_; }
166 
Dataset(DatasetBase * finalized_dataset,DatasetBase * original_dataset,DeviceMgr * device_mgr,ProcessFunctionLibraryRuntime * pflr,FunctionLibraryDefinition * flib_def,thread::ThreadPool * pool,std::function<void (std::function<void ()>)> runner)167 Dataset::Dataset(DatasetBase* finalized_dataset, DatasetBase* original_dataset,
168                  DeviceMgr* device_mgr, ProcessFunctionLibraryRuntime* pflr,
169                  FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool,
170                  std::function<void(std::function<void()>)> runner)
171     : finalized_dataset_(finalized_dataset),
172       original_dataset_(original_dataset),
173       device_mgr_(device_mgr),
174       flib_def_(flib_def),
175       pflr_(pflr),
176       interop_threadpool_(pool),
177       runner_(std::move(runner)),
178       unbounded_thread_pool_(Env::Default(), "tf_data_standalone") {
179   finalized_dataset_->Ref();
180   original_dataset_->Ref();
181   function_handle_cache_ =
182       std::make_unique<FunctionHandleCache>(pflr_->GetFLR("/device:CPU:0"));
183 }
184 
~Dataset()185 Dataset::~Dataset() {
186   finalized_dataset_->Unref();
187   original_dataset_->Unref();
188 }
189 
190 }  // namespace standalone
191 }  // namespace data
192 }  // namespace tensorflow
193