xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/root_dataset.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/root_dataset.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <string>
21 #include <utility>
22 
23 #include "tensorflow/core/data/dataset_utils.h"
24 #include "tensorflow/core/data/name_utils.h"
25 #include "tensorflow/core/data/rewrite_utils.h"
26 #include "tensorflow/core/framework/model.pb.h"
27 #include "tensorflow/core/platform/errors.h"
28 #include "tensorflow/core/platform/host_info.h"
29 #include "tensorflow/core/platform/refcount.h"
30 #include "tensorflow/core/platform/stringprintf.h"
31 
32 namespace tensorflow {
33 namespace data {
34 namespace {
35 
constexpr char kDatasetType[] = "Root";

// Keys for the TraceMe metadata entries attached to this dataset's profiler
// traces (see `AddTraceMetadata` and `Iterator::GetTraceMeMetadata` below).
constexpr char kAlgorithm[] = "algorithm";
constexpr char kCpuBudget[] = "cpu_budget";
constexpr char kExperiments[] = "experiments";
// Name of the no-op optimization used to probe whether the `inject_prefetch`
// rewrite would modify the graph (see `FinalizeDataset`).
constexpr char kInjectPrefetchEligibleOpt[] = "inject_prefetch_eligible";
constexpr char kIntraOpParallelism[] = "intra_op_parallelism";
constexpr char kMemBandwidth[] = "mem_bw_used_megabytes_per_sec";
constexpr char kPrivateThreadpoolSize[] = "threadpool_size";
constexpr char kRamBudget[] = "ram_budget_megabytes";
constexpr char kRamUsage[] = "ram_usage_megabytes";
constexpr char kMaxBufferBytes[] = "max_buffered_megabytes";
48 
// Returns the fallback `z` when `x` equals the sentinel `y`; otherwise
// returns `x` unchanged.
inline int64_t value_or_default(int64_t x, int64_t y, int64_t z) {
  if (x == y) {
    return z;
  }
  return x;
}
53 
// Translates user-provided `options` into the tuning parameters consumed by
// `RootDataset`. Later assignments intentionally override earlier ones: an
// explicitly configured autotune algorithm takes precedence over the
// "stage_based_autotune" experiment, which takes precedence over DEFAULT.
void SetRootDatasetParams(const Options& options, RootDataset::Params* params) {
  if (ShouldConfigureMaxIntraOpParallelism(options)) {
    params->max_intra_op_parallelism =
        options.threading_options().max_intra_op_parallelism();
  }
  if (ShouldUsePrivateThreadPool(options)) {
    params->private_threadpool_size =
        options.threading_options().private_threadpool_size();
  }
  params->autotune = ShouldUseAutotuning(options);
  if (params->autotune) {
    params->autotune_algorithm = model::AutotuneAlgorithm::DEFAULT;
    if (GetExperiments().contains("stage_based_autotune")) {
      params->autotune_algorithm = model::AutotuneAlgorithm::STAGE_BASED;
    }
    if (options.autotune_options().optional_autotune_algorithm_case() ==
        AutotuneOptions::kAutotuneAlgorithm) {
      params->autotune_algorithm =
          options.autotune_options().autotune_algorithm();
    }
    // A budget of 0 means "unset"; fall back to the process-level defaults.
    params->autotune_cpu_budget = value_or_default(
        options.autotune_options().cpu_budget(), 0, GetCpuBudget());
    params->autotune_ram_budget =
        value_or_default(options.autotune_options().ram_budget(), 0,
                         model::kRamBudgetShare * port::AvailableRam());
  }
}
81 
// Converts `params` into human-readable key/value pairs and appends them to
// `trace_metadata` for display in profiler (TraceMe) output.
void AddTraceMetadata(const RootDataset::Params& params,
                      TraceMeMetadata* trace_metadata) {
  if (params.autotune) {
    trace_metadata->push_back(std::make_pair(
        kAlgorithm, model::AutotuneAlgorithm_Name(params.autotune_algorithm)));
    trace_metadata->push_back(std::make_pair(
        kCpuBudget, strings::Printf("%lld", static_cast<long long>(
                                                params.autotune_cpu_budget))));
    // RAM budget is reported in megabytes (hence the /1.0e6), matching the
    // "ram_budget_megabytes" key.
    trace_metadata->push_back(std::make_pair(
        kRamBudget,
        strings::Printf("%lld", static_cast<long long>(
                                    params.autotune_ram_budget / 1.0e6))));
  }
  if (params.max_intra_op_parallelism >= 0) {
    // A configured value of 0 means "host's maximum parallelism".
    trace_metadata->push_back(std::make_pair(
        kIntraOpParallelism,
        strings::Printf("%lld", static_cast<long long>(value_or_default(
                                    params.max_intra_op_parallelism, 0,
                                    port::MaxParallelism())))));
  }
  if (params.private_threadpool_size >= 0) {
    // Same convention: 0 maps to the host's maximum parallelism.
    trace_metadata->push_back(std::make_pair(
        kPrivateThreadpoolSize,
        strings::Printf("%lld", static_cast<long long>(value_or_default(
                                    params.private_threadpool_size, 0,
                                    port::MaxParallelism())))));
  }
  auto experiments = GetExperiments();
  if (!experiments.empty()) {
    trace_metadata->push_back(
        std::make_pair(kExperiments, absl::StrJoin(experiments, " ")));
  }
}
115 }  // namespace
116 
117 // static
FromOptions(const DatasetBase * input,DatasetBase ** output)118 Status RootDataset::FromOptions(const DatasetBase* input,
119                                 DatasetBase** output) {
120   Params params;
121   SetRootDatasetParams(input->options(), &params);
122   *output = new RootDataset(input, params);
123   (*output)->Initialize(/*metadata=*/{});
124   return OkStatus();
125 }
126 
FromOptions(core::RefCountPtr<DatasetBase> input,DatasetBase ** output)127 Status RootDataset::FromOptions(core::RefCountPtr<DatasetBase> input,
128                                 DatasetBase** output) {
129   Params params;
130   SetRootDatasetParams(input->options(), &params);
131   *output = new RootDataset(std::move(input), params);
132   (*output)->Initialize(/*metadata=*/{});
133   return OkStatus();
134 }
135 
// Iterator for `RootDataset`. Wraps the input dataset's iterator and layers
// on (a) optional autotuning via a background model-optimization thread,
// (b) an optional private threadpool used as the runner, and (c) an optional
// cap on intra-op parallelism.
class RootDataset::Iterator : public DatasetIterator<RootDataset> {
 public:
  explicit Iterator(const Params& params)
      : DatasetIterator<RootDataset>(params) {
    if (dataset()->params_.autotune) {
      model_ = std::make_shared<model::Model>();
      if (GetExperiments().contains("autotune_buffer_optimization")) {
        model_->SetExperiment("autotune_buffer_optimization");
      }
    }
    if (dataset()->params_.max_intra_op_parallelism >= 0) {
      // A configured value of 0 means "use the host's maximum parallelism".
      max_intra_op_parallelism_ =
          value_or_default(dataset()->params_.max_intra_op_parallelism, 0,
                           port::MaxParallelism());
    }
    if (dataset()->params_.private_threadpool_size >= 0) {
      // Likewise, 0 requests a pool sized to the host's maximum parallelism.
      threadpool_size_ =
          value_or_default(dataset()->params_.private_threadpool_size, 0,
                           port::MaxParallelism());
      thread_pool_ = std::make_unique<thread::ThreadPool>(
          Env::Default(), ThreadOptions{}, "data_private_threadpool",
          threadpool_size_);
    }
    cancellation_manager_ = std::make_unique<CancellationManager>();
  }

  // Signals cancellation so the background autotuning thread (if started)
  // exits before members are destroyed.
  ~Iterator() override { cancellation_manager_->StartCancel(); }

  Status Initialize(IteratorContext* ctx) override {
    return dataset()->input_->MakeIterator(IteratorContext(CreateParams(ctx)),
                                           this, prefix(), &input_impl_);
  }

  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                         bool* end_of_sequence) override {
    {
      // Record the gap between the end of the previous GetNext call and the
      // start of this one; the autotuning model consumes these gap times.
      tf_shared_lock l(mu_);
      if (model_ != nullptr && end_time_usec_ > 0) {
        model_->RecordIteratorGapTime(ctx->env()->NowMicros() - end_time_usec_);
      }
    }
    if (dataset()->params_.autotune) {
      // Lazily start the optimization thread on first use.
      TF_RETURN_IF_ERROR(EnsureModelThreadStarted(ctx));
    }
    TF_RETURN_IF_ERROR(input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
                                            out_tensors, end_of_sequence));
    {
      mutex_lock l(mu_);
      // max() guards against clock readings that go backwards relative to a
      // previously recorded end time.
      end_time_usec_ = std::max(ctx->env()->NowMicros(), end_time_usec_);
    }
    return OkStatus();
  }

 protected:
  std::shared_ptr<model::Node> CreateNode(
      IteratorContext* ctx, model::Node::Args args) const override {
    // Root produces exactly one output element per input element.
    return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
  }

  Status SaveInternal(SerializationContext* ctx,
                      IteratorStateWriter* writer) override {
    TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
    return OkStatus();
  }

  Status RestoreInternal(IteratorContext* ctx,
                         IteratorStateReader* reader) override {
    TF_RETURN_IF_ERROR(
        RestoreInput(IteratorContext(CreateParams(ctx)), reader, input_impl_));
    return OkStatus();
  }

  TraceMeMetadata GetTraceMeMetadata() const override {
    // Start from the dataset-level metadata and append live resource stats.
    tensorflow::data::TraceMeMetadata traceme_metadata =
        dataset()->traceme_metadata_;
    const int64_t mem_bw = port::GetMemoryBandwidthInfo().bw_used;
    if (mem_bw != INT64_MAX) {  // INT64_MAX signals "bandwidth unavailable".
      traceme_metadata.push_back(std::make_pair(
          kMemBandwidth,
          strings::Printf("%lld", static_cast<long long>(mem_bw))));
    }
    const auto memory_info = port::GetMemoryInfo();
    const auto memory_usage = memory_info.total - memory_info.free;
    traceme_metadata.push_back(std::make_pair(
        kRamUsage,
        strings::Printf("%lld out of %lld (%.2f%%)",
                        static_cast<long long>(memory_usage / 1.0e6),
                        static_cast<long long>(memory_info.total / 1.0e6),
                        static_cast<double>(100 * memory_usage) /
                            static_cast<double>(memory_info.total))));
    if (model_node() != nullptr) {
      traceme_metadata.push_back(std::make_pair(
          kMaxBufferBytes,
          strings::Printf(
              "%lld", static_cast<long long>(
                          model_node()->TotalMaximumBufferedBytes() / 1.0e6))));
    }
    return traceme_metadata;
  }

 private:
  // Builds the `IteratorContext::Params` passed to the wrapped iterator,
  // injecting the autotuning model, the private threadpool runner, and the
  // intra-op parallelism cap as configured on the dataset.
  IteratorContext::Params CreateParams(IteratorContext* ctx) {
    IteratorContext::Params params(ctx);
    if (dataset()->params_.autotune) {
      params.model = model_;
    }
    if (dataset()->params_.private_threadpool_size >= 0) {
      params.runner = [pool = thread_pool_.get()](std::function<void()> c) {
        pool->Schedule(std::move(c));
      };
      params.runner_threadpool_size = threadpool_size_;
    }
    if (dataset()->params_.max_intra_op_parallelism >= 0) {
      // Wraps whichever runner is in effect (possibly the one set just
      // above) to cap its intra-op parallelism.
      params.runner =
          RunnerWithMaxParallelism(params.runner, max_intra_op_parallelism_);
    }
    params.options = &dataset()->options();
    return params;
  }

  // Starts the background thread running the autotuning optimization loop if
  // it has not been started yet. Safe to call from concurrent GetNext calls.
  Status EnsureModelThreadStarted(IteratorContext* ctx) {
    mutex_lock l(mu_);
    if (!model_thread_) {
      model_thread_ = ctx->StartThread("tf_data_model", [this]() {
        Status status =
            model_->OptimizeLoop(dataset()->params_.autotune_algorithm,
                                 dataset()->params_.autotune_cpu_budget,
                                 dataset()->params_.autotune_ram_budget,
                                 cancellation_manager_.get());
        if (!status.ok()) {
          // Autotuning is best-effort; a failure must not fail iteration.
          LOG(WARNING) << "Optimization loop failed: " << status.ToString();
        }
      });
    }
    return OkStatus();
  }

  // Autotuning model; null when autotuning is disabled.
  std::shared_ptr<model::Model> model_ = nullptr;
  // Controls cancellation of `model_thread_`. Must be ordered before
  // `model_thread_` so that `model_thread_` is destroyed first.
  std::unique_ptr<CancellationManager> cancellation_manager_;
  mutex mu_;
  std::unique_ptr<Thread> model_thread_ TF_GUARDED_BY(mu_);
  // Only meaningful when the corresponding dataset params are >= 0; see the
  // constructor for how they are resolved.
  int64_t max_intra_op_parallelism_;
  int64_t threadpool_size_;
  std::unique_ptr<thread::ThreadPool> thread_pool_;

  // The end time of the previous `GetNextInternal` call.
  uint64_t end_time_usec_ TF_GUARDED_BY(mu_) = 0;

  // Must be ordered last as its execution may depend on other members.
  std::unique_ptr<IteratorBase> input_impl_;
};
289 
RootDataset(const DatasetBase * input,const Params & params)290 RootDataset::RootDataset(const DatasetBase* input, const Params& params)
291     : DatasetBase(DatasetContext({name_utils::OpName(kDatasetType),
292                                   name_utils::OpName(kDatasetType)})),
293       input_(input),
294       params_(std::move(params)) {
295   AddTraceMetadata(params_, &traceme_metadata_);
296 }
297 
RootDataset(core::RefCountPtr<DatasetBase> input,const Params & params)298 RootDataset::RootDataset(core::RefCountPtr<DatasetBase> input,
299                          const Params& params)
300     : DatasetBase(DatasetContext({name_utils::OpName(kDatasetType),
301                                   name_utils::OpName(kDatasetType)})),
302       params_(std::move(params)) {
303   owned_input_ = std::move(input);
304   input_ = owned_input_.get();
305   AddTraceMetadata(params_, &traceme_metadata_);
306 }
307 
~RootDataset()308 RootDataset::~RootDataset() {}
309 
MakeIteratorInternal(const string & prefix) const310 std::unique_ptr<IteratorBase> RootDataset::MakeIteratorInternal(
311     const string& prefix) const {
312   return std::make_unique<Iterator>(
313       Iterator::Params{this, name_utils::IteratorPrefix(kDatasetType, prefix)});
314 }
315 
// Forwards to the wrapped input dataset's output dtypes.
const DataTypeVector& RootDataset::output_dtypes() const {
  return input_->output_dtypes();
}
319 
// Forwards to the wrapped input dataset's output shapes.
const std::vector<PartialTensorShape>& RootDataset::output_shapes() const {
  return input_->output_shapes();
}
323 
// Human-readable name used in logging and error messages.
string RootDataset::DebugString() const {
  return name_utils::DatasetDebugString(kDatasetType);
}
327 
// Cardinality is that of the wrapped input dataset.
int64_t RootDataset::CardinalityInternal() const {
  return input_->Cardinality();
}
331 
// Cardinality (with options) is that of the wrapped input dataset.
int64_t RootDataset::CardinalityInternal(CardinalityOptions options) const {
  return input_->Cardinality(options);
}
335 
Get(OpKernelContext * ctx,int64 index,std::vector<Tensor> * out_tensors) const336 Status RootDataset::Get(OpKernelContext* ctx, int64 index,
337                         std::vector<Tensor>* out_tensors) const {
338   std::vector<const DatasetBase*> inputs;
339   TF_RETURN_IF_ERROR(this->InputDatasets(&inputs));
340   return inputs[0]->Get(ctx, index, out_tensors);
341 }
342 
// Reports the wrapped dataset as this dataset's sole input.
Status RootDataset::InputDatasets(
    std::vector<const DatasetBase*>* inputs) const {
  inputs->push_back(input_);
  return OkStatus();
}
348 
// Forwards the external-state check to the wrapped input dataset.
Status RootDataset::CheckExternalState() const {
  return input_->CheckExternalState();
}
352 
// RootDataset is an execution-time-only wrapper and is never written back to
// a GraphDef, so serialization is deliberately unimplemented.
Status RootDataset::AsGraphDefInternal(SerializationContext* ctx,
                                       DatasetGraphDefBuilder* b,
                                       Node** output) const {
  return errors::Unimplemented("RootDataset does not support serialization.");
}
358 
359 #if !defined(IS_MOBILE_PLATFORM)
FinalizeDataset(OpKernelContext * ctx,const DatasetBase * input,DatasetBase ** output)360 Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input,
361                        DatasetBase** output) {
362   const Options& options = input->options();
363   absl::flat_hash_set<tstring> optimizations_enabled;
364   absl::flat_hash_set<tstring> optimizations_disabled;
365   absl::flat_hash_set<tstring> optimizations_default;
366   GetOptimizations(options, &optimizations_enabled, &optimizations_disabled,
367                    &optimizations_default);
368   // Disable `enable_gradient_descent` as it assumes presence of ModelDatasetOp.
369   optimizations_disabled.insert("enable_gradient_descent");
370   if (!port::JobName().empty()) {
371     // Enable kInjectPrefetchEligibleOpt that does not modify the graph and is
372     // used to check whether the `inject_prefetch` optimization would modify the
373     // graph.
374     optimizations_enabled.insert(kInjectPrefetchEligibleOpt);
375   }
376 
377   auto experiments = GetExperiments();
378   LogAndRecordExperiments(experiments);
379   auto optimizations =
380       SelectOptimizations(experiments, optimizations_enabled,
381                           optimizations_disabled, optimizations_default);
382   if (optimizations.empty()) {
383     return RootDataset::FromOptions(input, output);
384   }
385 
386   auto optimization_configs = CreateGraphRewriteConfigs(options);
387   auto config_factory = [&optimizations, &optimization_configs]() {
388     return CreateRewriterConfig(optimizations, optimization_configs);
389   };
390   core::RefCountPtr<DatasetBase> rewritten_output;
391   Status s = RewriteDataset(ctx, input, std::move(config_factory),
392                             /*record_fingerprint=*/true, &rewritten_output);
393 
394   *output = rewritten_output.get();
395   bool rewritten = (*output != input);
396   if (errors::IsDeadlineExceeded(s)) {
397     // Ignore DeadlineExceeded as it implies that the attempted rewrite took too
398     // long which should not prevent further computation.
399     LOG(WARNING) << s.ToString();
400   } else if (!s.ok()) {
401     return s;
402   }
403   if (!rewritten) {
404     return RootDataset::FromOptions(input, output);
405   } else {
406     return RootDataset::FromOptions(std::move(rewritten_output), output);
407   }
408   return OkStatus();
409 }
410 
411 #else   // !IS_MOBILE_PLATFORM
// Mobile build: graph-rewrite optimizations are unavailable, so simply wrap
// the input in a `RootDataset` without attempting any rewrite.
Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input,
                       DatasetBase** output) {
  return RootDataset::FromOptions(input, output);
}
416 #endif  // !IS_MOBILE_PLATFORM
417 
418 }  // namespace data
419 }  // namespace tensorflow
420