xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/model.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
17 
#include <algorithm>
#include <atomic>
#include <deque>
#include <functional>
#include <limits>
#include <list>
#include <memory>
#include <string>
// TODO(b/114492873): Move this include into core/platform.
#include <thread>  // NOLINT
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/model.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringprintf.h"
43 
44 namespace tensorflow {
45 namespace data {
46 namespace model {
47 
// A constant that can be used to enable auto-tuning. A `SharedState` created
// with this value is marked as tunable.
constexpr int64_t kAutotune = -1;

// Well-known string constants.
// NOTE(review): presumably these are used as `Parameter` names by the dataset
// implementations -- confirm against the callers.
constexpr char kParallelism[] = "parallelism";
constexpr char kBufferSize[] = "buffer_size";
constexpr char kCycleLength[] = "cycle_length";
constexpr char kDeterministic[] = "deterministic";
constexpr char kMaxBufferedElements[] = "max_buffered_elements";

// A key used to identify the input time of the model.
constexpr char kModelInputTimeKey[] = "model_input_time";

// Default share of available RAM that can be used by model's internal buffers.
constexpr double kRamBudgetShare = 0.5;

// Weight of the latest processing time used in computing the exponential moving
// average of processing time per element. The previous average is weighted by
// `1 - kProcessingTimeEmaWeight`.
constexpr double kProcessingTimeEmaWeight = 0.1;
65 
// Order in which `Node::CollectNodes` returns the nodes of a subtree.
enum class TraversalOrder {
  BFS = 0,          // Breadth-first search order.
  REVERSE_BFS = 1,  // Reverse breadth-first search order.
};
70 
71 // Represents thread-safe state that can be shared between an input pipeline and
72 // the performance model.
73 struct SharedState {
74  public:
SharedStateSharedState75   SharedState(int64_t value, std::shared_ptr<mutex> mu,
76               std::shared_ptr<condition_variable> cond_var)
77       : value(value),
78         mu(std::move(mu)),
79         cond_var(std::move(cond_var)),
80         tunable(value == kAutotune) {}
81 
82   double value;
83   const std::shared_ptr<mutex> mu;
84   const std::shared_ptr<condition_variable> cond_var;
85   const bool tunable;
86 };
87 
88 // Represents a parameter.
89 struct Parameter {
ParameterParameter90   Parameter(const string& name, std::shared_ptr<SharedState> state, double min,
91             double max)
92       : name(name),
93         // Sometimes non-autotune nodes (with `autotune_=false`) may contain
94         // parameters (for example inputs of parallel interleave dataset which
95         // are not in the current cycle). To avoid unrealistic situation
96         // (say `buffer_size=-1` or `parallelism=-1`) in the optimization
97         // computation, if the state value is `kAutotune=-1` (just to indicate
98         // the `SharedState` is tunable), we initialize the parameter value to
99         // be the minimal value of the state.
100         value(state == nullptr || state->value == kAutotune ? min
101                                                             : state->value),
102         min(min),
103         max(max),
104         state(std::move(state)) {}
105 
106   // Human-readable name of the parameter.
107   const string name;
108 
109   // Identifies the model value of the parameter. This can be different from
110   // the actual value (e.g. during optimization search).
111   double value;
112 
113   // Identifies the minimum value of the parameter.
114   const double min;
115 
116   // Identifies the maximum value of the parameter.
117   const double max;
118 
119   // Shared state of the parameter.
120   std::shared_ptr<SharedState> state;
121 };
122 
// Returns a new tunable parameter.
// The arguments mirror those of the `Parameter` constructor: the parameter
// name, its shared state, and the range [`min`, `max`] of admissible values.
std::shared_ptr<Parameter> MakeParameter(const string& name,
                                         std::shared_ptr<SharedState> state,
                                         double min, double max);

// Returns a new non-tunable parameter.
// NOTE(review): presumably the result is fixed at `value` and carries no
// shared state -- confirm against the definition.
std::shared_ptr<Parameter> MakeNonTunableParameter(const string& name,
                                                   double value);
131 
132 // Abstract representation of a TensorFlow input pipeline node. It collects
133 // information about inputs to this node, processing time spent executing the
134 // node logic, number of elements produced by the node, various other
135 // information (e.g. batch size or execution parallelism).
136 //
137 // Developers of tf.data transformations are not expected to interact with
138 // this class directly. Boiler plate code for creating the abstract
139 // representation of the input pipeline and collecting common information has
140 // been added to the implementation of `DatasetBase` and `DatasetBaseIterator`
141 // respectively.
142 //
143 // In addition, `DatasetBaseIterator` provides wrappers that can be used for
144 // transformation-specific information collection. The `SetMetadata` wrapper
145 // can be used to pass arbitrary metadata to the modeling framework, while the
146 // `StartWork` and `StopWork` wrappers should be used to correctly account for
147 // processing time of multi-threaded transformation that yield the CPU; such
148 // transformations should invoke `StartWork()` when a transformation thread
149 // starts executing (e.g. when created or woken up) and `StopWork()` when a
150 // transformation thread stops executing (e.g. when returning or waiting).
151 class Node {
152  public:
  // Arguments for `Node` constructor.
  struct Args {
    // Unique node ID (see `Node::id()`).
    int64_t id;
    // Node name (see `Node::name()`).
    string name;
    // Node that consumes the output of the node being created. The new node
    // keeps only a raw pointer to it.
    std::shared_ptr<Node> output;
  };

  // Callable that builds a node from `Args`.
  using Factory = std::function<std::shared_ptr<Node>(Args)>;
  // Sequence of nodes, e.g. the result of `CollectNodes`.
  using NodeVector = std::vector<std::shared_ptr<Node>>;
  // List of node pairs; `SnapshotHelper` fills it with (input, copy) pairs.
  using NodePairList =
      std::list<std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>>>;
  // Parameters paired with a string key.
  // NOTE(review): the key is presumably the owning node's long name -- confirm.
  using ModelParameters =
      std::vector<std::pair<string, std::shared_ptr<Parameter>>>;
  // Per-node numeric values keyed by a string.
  // NOTE(review): the key is presumably the node's long name -- confirm.
  using NodeValues = absl::flat_hash_map<string, double>;
  // Gradient values keyed by a pair of strings.
  // NOTE(review): presumably (node name, parameter name) -- confirm.
  using ParameterGradients =
      absl::flat_hash_map<std::pair<string, string>, double>;
169 
  // Constructs a node from `args`.
  //
  // Autotuning and metric recording start out enabled and all counters start
  // at zero. The buffer watermarks start at the extreme `int64_t` values so
  // that the first min/max comparison in `record_buffer_event` replaces them.
  // Only a raw, non-owning pointer to `args.output` is retained.
  explicit Node(Args args)
      : id_(args.id),
        name_(std::move(args.name)),
        autotune_(true),
        buffered_bytes_(0),
        buffered_elements_(0),
        buffered_elements_low_(std::numeric_limits<int64_t>::max()),
        buffered_elements_high_(std::numeric_limits<int64_t>::min()),
        bytes_consumed_(0),
        bytes_produced_(0),
        num_elements_(0),
        processing_time_(0),
        record_metrics_(true),
        // `name_` is declared before `metrics_`, so it is already initialized.
        metrics_(name_),
        output_(args.output.get()) {}
185 
~Node()186   virtual ~Node() {
187     // Clear the sub-nodes instead of relying on implicit shared pointer
188     // destructor to avoid potential stack overflow when the tree is deep.
189     std::deque<std::shared_ptr<Node>> queue;
190     {
191       mutex_lock l(mu_);
192       while (inputs_.size() > 0) {
193         queue.push_back(inputs_.front());
194         inputs_.pop_front();
195       }
196     }
197     while (!queue.empty()) {
198       auto node = queue.back();
199       queue.pop_back();
200       {
201         mutex_lock l(node->mu_);
202         while (node->inputs_.size() > 0) {
203           queue.push_back(node->inputs_.front());
204           node->inputs_.pop_front();
205         }
206       }
207     }
208 
209     FlushMetrics();
210   }
211 
212   // Adds an input.
add_input(std::shared_ptr<Node> node)213   void add_input(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_) {
214     mutex_lock l(mu_);
215     inputs_.push_back(node);
216   }
217 
218   // Increments the aggregate processing time by the given delta.
add_processing_time(int64_t delta)219   void add_processing_time(int64_t delta) TF_LOCKS_EXCLUDED(mu_) {
220     processing_time_ += delta;
221   }
222 
223   // Returns an indication whether autotuning is enabled for this node.
autotune()224   bool autotune() const TF_LOCKS_EXCLUDED(mu_) { return autotune_; }
225 
226   // Returns the number of bytes stored in this node's buffer.
buffered_bytes()227   int64_t buffered_bytes() const TF_LOCKS_EXCLUDED(mu_) {
228     return buffered_bytes_;
229   }
230 
231   // Returns the number of elements stored in this node's buffer.
buffered_elements()232   int64_t buffered_elements() const TF_LOCKS_EXCLUDED(mu_) {
233     return buffered_elements_;
234   }
235 
236   // Returns the low watermark of the number of elements stored in this node's
237   // buffer. The watermarks are reset at the beginning of the execution time and
238   // each time the buffer is upsized or downsized.
buffered_elements_low()239   int64_t buffered_elements_low() const TF_LOCKS_EXCLUDED(mu_) {
240     return buffered_elements_low_;
241   }
242 
243   // Returns the high watermark of the number of elements stored in this node's
244   // buffer. The watermarks are reset at the beginning of the execution time and
245   // each time the buffer is upsized or downsized.
buffered_elements_high()246   int64_t buffered_elements_high() const TF_LOCKS_EXCLUDED(mu_) {
247     return buffered_elements_high_;
248   }
249 
250   // Returns the number of bytes consumed by the node.
bytes_consumed()251   int64_t bytes_consumed() const TF_LOCKS_EXCLUDED(mu_) {
252     return bytes_consumed_;
253   }
254 
255   // Returns the number of bytes produced by the node.
bytes_produced()256   int64_t bytes_produced() const TF_LOCKS_EXCLUDED(mu_) {
257     return bytes_produced_;
258   }
259 
260   // Indicates whether the node has tunable parameters.
has_tunable_parameters()261   bool has_tunable_parameters() const TF_LOCKS_EXCLUDED(mu_) {
262     tf_shared_lock l(mu_);
263     for (const auto& pair : parameters_) {
264       if (pair.second->state->tunable) return true;
265     }
266     return false;
267   }
268 
269   // Returns the unique node ID.
id()270   int64_t id() const TF_LOCKS_EXCLUDED(mu_) { return id_; }
271 
272   // Returns the node inputs.
inputs()273   std::list<std::shared_ptr<Node>> inputs() const TF_LOCKS_EXCLUDED(mu_) {
274     tf_shared_lock l(mu_);
275     return inputs_;
276   }
277 
278   // Returns a longer node name that is guaranteed to be unique.
long_name()279   string long_name() const { return strings::StrCat(name_, "(id:", id_, ")"); }
280 
281   // Returns the node name.
name()282   const string& name() const { return name_; }
283 
284   // Returns the number of elements produced by the node.
num_elements()285   int64_t num_elements() const TF_LOCKS_EXCLUDED(mu_) { return num_elements_; }
286 
287   // Returns the node output.
output()288   Node* output() const { return output_; }
289 
290   // Returns the parameter value.
parameter_value(const string & name)291   double parameter_value(const string& name) const TF_LOCKS_EXCLUDED(mu_) {
292     tf_shared_lock l(mu_);
293     return parameters_.at(name)->state->value;
294   }
295 
296   // Returns the aggregate processing time.
processing_time()297   int64_t processing_time() const TF_LOCKS_EXCLUDED(mu_) {
298     return processing_time_;
299   }
300 
301   // Records that the node consumed the given number of bytes.
record_bytes_consumed(int64_t num_bytes)302   void record_bytes_consumed(int64_t num_bytes) {
303     bytes_consumed_ += num_bytes;
304   }
305 
306   // Records that the node produced the given number of bytes.
record_bytes_produced(int64_t num_bytes)307   void record_bytes_produced(int64_t num_bytes) {
308     bytes_produced_ += num_bytes;
309   }
310 
311   // Records the change in this node's buffer.
record_buffer_event(int64_t bytes_delta,int64_t elements_delta)312   void record_buffer_event(int64_t bytes_delta, int64_t elements_delta) {
313     buffered_bytes_ += bytes_delta;
314     buffered_elements_ += elements_delta;
315     // There is no need to maintain watermarks for synchronous ops because we
316     // will not upsize or downsize the buffers of synchronous ops.
317     if (IsAsync()) {
318       int64_t low_watermark =
319           std::min(buffered_elements_low_, buffered_elements_);
320       buffered_elements_low_ = low_watermark;
321       int64_t high_watermark =
322           std::max(buffered_elements_high_, buffered_elements_);
323       buffered_elements_high_ = high_watermark;
324     }
325   }
326 
327   // Records that the node produced an element.
record_element()328   void record_element() TF_LOCKS_EXCLUDED(mu_) {
329     num_elements_++;
330     {
331       mutex_lock l(mu_);
332       UpdateProcessingTimeEma();
333     }
334   }
335 
336   // Records that a node thread has started executing.
record_start(int64_t time_nanos)337   void record_start(int64_t time_nanos) TF_LOCKS_EXCLUDED(mu_) {
338     DCHECK_EQ(work_start_, 0);
339     work_start_ = time_nanos;
340   }
341 
342   // Records that a node thread has stopped executing.
record_stop(int64_t time_nanos)343   void record_stop(int64_t time_nanos) TF_LOCKS_EXCLUDED(mu_) {
344     // TODO(jsimsa): Use DCHECK_NE(work_start_, 0) here.
345     if (work_start_ != 0) {
346       processing_time_ += time_nanos - work_start_;
347       work_start_ = 0;
348     } else {
349       VLOG(1) << "Encountered a stop event without a matching start event.";
350     }
351   }
352 
353   // Returns whether work is currently being recorded, i.e. whether we are
354   // currently between a `record_start` and a `record_stop`.
is_recording()355   bool is_recording() TF_LOCKS_EXCLUDED(mu_) { return work_start_ > 0; }
356 
357   // Removes an input.
remove_input(std::shared_ptr<Node> input)358   void remove_input(std::shared_ptr<Node> input) TF_LOCKS_EXCLUDED(mu_) {
359     mutex_lock l(mu_);
360     inputs_.remove(input);
361   }
362 
363   // Sets the value that determines whether autotuning is enabled for this node.
set_autotune(bool autotune)364   void set_autotune(bool autotune) TF_LOCKS_EXCLUDED(mu_) {
365     autotune_.store(autotune);
366   }
367 
368   // Resets buffer watermarks to the current buffered elements.
ResetBufferWatermarks()369   void ResetBufferWatermarks() {
370     if (!IsAsync()) {
371       return;
372     }
373     int64_t current_buffer_size = buffered_elements_;
374     buffered_elements_low_ = current_buffer_size;
375     buffered_elements_high_ = current_buffer_size;
376   }
377 
  // Returns true for asynchronous nodes; false otherwise. The base
  // implementation reports a synchronous node; buffer watermarks are only
  // maintained for nodes that return true here.
  virtual bool IsAsync() const { return false; }

  // Returns the ratio of the node, which is defined as the number of elements
  // per input needed by the node to produce an element, e.g. batch size of a
  // `Batch`. It can be 0 if the ratio is unknown. The base implementation
  // returns 1.
  virtual double Ratio() const { return 1.0; }

  // Computes the self time in nanoseconds of the node to produce one element.
  virtual double ComputeSelfTime() const;
388 
389   // Returns the parameter value if it exists, not ok status otherwise.
ParameterValue(const std::string & parameter_name)390   StatusOr<double> ParameterValue(const std::string& parameter_name) const
391       TF_LOCKS_EXCLUDED(mu_) {
392     tf_shared_lock l(mu_);
393     if (parameters_.contains(parameter_name)) {
394       return parameters_.at(parameter_name)->value;
395     }
396     return errors::NotFound("Parameter ", parameter_name,
397                             " was not found in model node ", long_name());
398   }
399 
  // Given the average time between events when the elements in the buffer are
  // produced (`producer_time`), the average time between events when elements
  // in the buffer are consumed (`consumer_time`) and the buffer size, the
  // method computes the expected time a consumer event will have to wait.
  //
  // The wait time is approximated as the product of the probability the buffer
  // will be empty and the time it takes to produce an element into the buffer.
  //
  // The formula used for computing the probability is derived by modeling the
  // problem as an M/M/1/K queue
  // (https://en.wikipedia.org/wiki/Birth%E2%80%93death_process#M/M/1/K_queue).
  //
  // Collects derivatives of `ComputeWaitTime` w.r.t. `producer_time`,
  // `consumer_time` and `buffer_size` if the corresponding pointers are not
  // `nullptr`.
  static double ComputeWaitTime(const double& producer_time,
                                const double& consumer_time,
                                const double& buffer_size,
                                double* producer_time_derivative,
                                double* consumer_time_derivative,
                                double* buffer_size_derivative);

  // Collects tunable parameters in the subtree rooted in this node.
  ModelParameters CollectTunableParameters() const TF_LOCKS_EXCLUDED(mu_);

  // Collects tunable parameters in this node.
  ModelParameters CollectNodeTunableParameters() const TF_LOCKS_EXCLUDED(mu_);

  // Returns a human-readable representation of this node.
  string DebugString() const TF_LOCKS_EXCLUDED(mu_);

  // Flushes the metrics recorded by this node. Also invoked by the
  // destructor.
  void FlushMetrics() TF_LOCKS_EXCLUDED(mu_);

  // Returns the per-element output time for this node and if `gradients` is not
  // `nullptr`, collects the output time gradient w.r.t. tunable parameters of
  // the subtree rooted in this node.
  double OutputTime(NodeValues* input_times,
                    ParameterGradients* gradients) const TF_LOCKS_EXCLUDED(mu_);

  // Returns a copy of this node, making a deep copy of its inputs and a
  // shallow copy of its tunable parameters.
  //
  // The purpose for this method is to allow the model optimization logic to
  // operate over immutable state while allowing concurrent model updates.
  std::shared_ptr<Node> Snapshot() const TF_LOCKS_EXCLUDED(mu_);

  // Returns the per-element processing time in nanoseconds spent in this node.
  double SelfProcessingTime() const TF_LOCKS_EXCLUDED(mu_);

  // Returns the total number of bytes buffered in all nodes in the subtree for
  // which autotuning is enabled.
  double TotalBufferedBytes() const TF_LOCKS_EXCLUDED(mu_);

  // Collects the total buffer limit of all nodes in the subtree for which
  // autotuning is enabled. This number represents the amount of memory that
  // would be used by the subtree nodes if all of their buffers were full.
  double TotalMaximumBufferedBytes() const TF_LOCKS_EXCLUDED(mu_);

  // Returns the per-element CPU time in nanoseconds spent in the subtree rooted
  // in this node. If `processing_times` is not `nullptr`, collects the
  // per-element CPU time spent in each node of the subtree.
  double TotalProcessingTime(NodeValues* processing_times)
      TF_LOCKS_EXCLUDED(mu_);

  // Produces a proto for this node. Does not produce a proto for input nodes.
  virtual Status ToProto(ModelProto::Node* node_proto) const;

  // Restores a node from the proto. Does not restore input nodes. The restored
  // node is returned through `node`.
  static Status FromProto(ModelProto::Node node_proto,
                          std::shared_ptr<Node> output,
                          std::shared_ptr<Node>* node);

  // Returns a vector of nodes of the subtree rooted in this node. The nodes are
  // either in breadth-first search or reverse breadth-first search order
  // depending on the `order` argument. The nodes are collected based on the
  // results of the `collect_node` predicate: if the predicate returns `false`
  // for a given node, then the subtree rooted in this node is excluded. The
  // root node itself is not collected.
  NodeVector CollectNodes(TraversalOrder order,
                          bool collect_node(const std::shared_ptr<Node>)) const
      TF_LOCKS_EXCLUDED(mu_);

  // Downsizes buffer parameters of this node. Returns true if any buffer is
  // downsized.
  bool TryDownsizeBuffer();

  // Collects buffer parameters of this node that should be upsized. The
  // results are added to `node_parameters`.
  void CollectBufferParametersToUpsize(
      absl::flat_hash_map<Node*, Parameter*>& node_parameters);
490 
491  protected:
492   // Used for (incrementally) recording metrics. The class is thread-safe.
493   class Metrics {
494    public:
Metrics(const string & name)495     explicit Metrics(const string& name)
496         : bytes_consumed_counter_(metrics::GetTFDataBytesConsumedCounter(name)),
497           bytes_produced_counter_(metrics::GetTFDataBytesProducedCounter(name)),
498           num_elements_counter_(metrics::GetTFDataElementsCounter(name)),
499           recorded_bytes_consumed_(0),
500           recorded_bytes_produced_(0),
501           recorded_num_elements_(0) {}
502 
503     // Expects the total number of bytes consumed and records the delta since
504     // last invocation.
record_bytes_consumed(int64_t total_bytes)505     void record_bytes_consumed(int64_t total_bytes) {
506       int64_t delta =
507           total_bytes - recorded_bytes_consumed_.exchange(total_bytes);
508       bytes_consumed_counter_->IncrementBy(delta);
509     }
510 
511     // Expects the total number of bytes produced and records the delta since
512     // last invocation.
record_bytes_produced(int64_t total_bytes)513     void record_bytes_produced(int64_t total_bytes) {
514       int64_t delta =
515           total_bytes - recorded_bytes_produced_.exchange(total_bytes);
516       bytes_produced_counter_->IncrementBy(delta);
517     }
518 
519     // Expects the total number of elements produced and records the delta since
520     // last invocation.
record_num_elements(int64_t total_elements)521     void record_num_elements(int64_t total_elements) {
522       int64_t delta =
523           total_elements - recorded_num_elements_.exchange(total_elements);
524       num_elements_counter_->IncrementBy(delta);
525     }
526 
527    private:
528     monitoring::CounterCell* const bytes_consumed_counter_;
529     monitoring::CounterCell* const bytes_produced_counter_;
530     monitoring::CounterCell* const num_elements_counter_;
531     std::atomic<int64_t> recorded_bytes_consumed_;
532     std::atomic<int64_t> recorded_bytes_produced_;
533     std::atomic<int64_t> recorded_num_elements_;
534   };
535 
536   // Computes the exponential moving average of processing time per element.
UpdateProcessingTimeEma()537   void UpdateProcessingTimeEma() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
538     if (previous_processing_time_ == 0) {
539       if (num_elements_ > 0) {
540         processing_time_ema_ = static_cast<double>(processing_time_) /
541                                static_cast<double>(num_elements_);
542       } else {
543         processing_time_ema_ = static_cast<double>(processing_time_);
544       }
545     } else {
546       processing_time_ema_ =
547           (1.0 - kProcessingTimeEmaWeight) * processing_time_ema_ +
548           kProcessingTimeEmaWeight *
549               static_cast<double>(processing_time_ - previous_processing_time_);
550     }
551     previous_processing_time_ = processing_time_;
552   }
553 
554   // Returns the number of inputs.
num_inputs()555   int64_t num_inputs() const TF_SHARED_LOCKS_REQUIRED(mu_) {
556     int64_t num_inputs = 0;
557     for (auto& input : inputs_) {
558       // Inputs for which autotuning is disabled are excluded.
559       if (input->autotune()) {
560         ++num_inputs;
561       }
562     }
563     return num_inputs;
564   }
565 
  // Creates a clone of this node. `output` becomes the output of the clone.
  virtual std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const
      TF_SHARED_LOCKS_REQUIRED(mu_) = 0;

  // Returns the average size of an element buffered in this node.
  double AverageBufferedElementSize() const TF_SHARED_LOCKS_REQUIRED(mu_);

  // Returns the sum of per-element output time for the tunable inputs of this
  // node.
  double OutputTimeForInputs(const NodeValues& output_times) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Returns the sum of output time gradient w.r.t. input time for the tunable
  // inputs of this node.
  double OutputTimeGradientsForInputs(const NodeValues& output_time_gradients)
      const TF_SHARED_LOCKS_REQUIRED(mu_);

  // Computes the input time for this node and stores it in `input_times`.
  virtual void InputTimeLocked(NodeValues* input_times) const
      TF_SHARED_LOCKS_REQUIRED(mu_) = 0;

  // Computes the per-element output time for this node and stores it in
  // `output_times`. If `gradients` is not `nullptr`, computes the output time
  // gradient w.r.t. tunable parameters of the subtree rooted in this node and
  // stores it in `gradients`, also computes the output time gradient w.r.t.
  // input time and stores it in `output_time_gradients`.
  virtual void OutputTimeLocked(const NodeValues& input_times,
                                ParameterGradients* gradients,
                                NodeValues* output_times,
                                NodeValues* output_time_gradients) const
      TF_SHARED_LOCKS_REQUIRED(mu_) = 0;

  // Returns the sum of per-element processing time for the inputs of this node
  // by adding values for input nodes in `total_processing_times`. Processing
  // time for a given input is a weighted combination of a statistic based on
  // history of input processing time and the actual time. This is done to
  // improve accuracy of processing time estimation for newly created inputs.
  //
  // Uniform distribution of per-element processing times across different
  // inputs is assumed.
  double TotalProcessingTimeForInputs(const NodeValues& total_processing_times)
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Returns the per-element processing time spent in this node. This is the
  // variant of `SelfProcessingTime` for callers that already hold `mu_`.
  double SelfProcessingTimeLocked() const TF_SHARED_LOCKS_REQUIRED(mu_);

  // Computes the per-element CPU time spent in the subtree rooted in this node
  // and stores it in `total_processing_times`. If `processing_times` is not
  // `nullptr`, collects the per-element CPU time spent in each node of the
  // subtree.
  virtual void TotalProcessingTimeLocked(NodeValues* processing_times,
                                         NodeValues* total_processing_times)
      TF_SHARED_LOCKS_REQUIRED(mu_) = 0;

  // This is the locked version of the public `CollectNodes`.
  NodeVector CollectNodesLocked(TraversalOrder order,
                                bool collect_node(const std::shared_ptr<Node>))
      const TF_SHARED_LOCKS_REQUIRED(mu_);

  // Collects tunable parameters in the subtree rooted in this node assuming
  // mutex locked.
  ModelParameters CollectTunableParametersLocked() const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Collects tunable parameters on the nodes which have recorded elements and
  // appends them to `parameters`.
  void CollectTunableParametersHelper(ModelParameters* parameters) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Builds up the debug string for the node and stores it in the debug strings
  // map.
  void DebugStringHelper(absl::flat_hash_map<string, string>* debug_strings)
      const TF_SHARED_LOCKS_REQUIRED(mu_);

  // Copies the node and adds the (input, copy) pairs to the NodePairList.
  std::shared_ptr<Node> SnapshotHelper(std::shared_ptr<Node> cloned_output,
                                       NodePairList* node_pairs) const;

  // Computes total buffered bytes for the node and stores it in the total
  // bytes map.
  void TotalBufferedBytesHelper(NodeValues* total_bytes) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Computes total maximum buffered bytes for the node and stores it in the
  // total bytes map.
  void TotalMaximumBufferedBytesHelper(NodeValues* total_bytes) const
      TF_SHARED_LOCKS_REQUIRED(mu_);

  // Computes and returns the maximum buffered bytes on the node itself. By
  // default non-tunable nodes are assumed not to buffer any bytes, so the
  // tunable nodes as subclasses are expected to override this method to ensure
  // that the optimization algorithm respects the memory budget.
  virtual double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_);

  // Restores node from the proto. Note that this is not done recursively, i.e.
  // input nodes are not restored.
  static Status FromProtoHelper(ModelProto::Node node_proto,
                                std::shared_ptr<Node> node);
662 
663   // Stores the time passed to the last call to `Node::record_start()` on the
664   // current thread.
665   //
666   // NOTE: This thread-local variable is shared between all instances of `Node`
667   // on which the same thread calls `record_start()` or `record_stop()`. It
668   // relies on the invariant that at most one `Node` can be "active" on a
669   // particular thread at any time. Therefore if `n->record_start()` is called
670   // on thread `t`, then `n->record_stop()` must be called before another call
671   // to `Node::record_start()` (for any node).
672   static thread_local int64_t work_start_;  // Will be initialized to zero.
673 
674   mutable mutex mu_;
675   const int64_t id_;
676   const string name_;
677 
678   // Indicates whether the subtree rooted in this node should be included in
679   // autotuning. In particular, if this is `false`, then the subtree is excluded
680   // from computation of output time and processing time.
681   std::atomic<bool> autotune_;
682   std::atomic<int64_t> buffered_bytes_;
683   std::atomic<int64_t> buffered_elements_;
684   std::atomic<int64_t> buffered_elements_low_;
685   std::atomic<int64_t> buffered_elements_high_;
686   std::atomic<int64_t> bytes_consumed_;
687   std::atomic<int64_t> bytes_produced_;
688   std::atomic<int64_t> num_elements_;
689   std::atomic<int64_t> processing_time_;
690   std::atomic<bool> record_metrics_;
691   Metrics metrics_;
692   absl::flat_hash_map<string, std::shared_ptr<Parameter>> parameters_
693       TF_GUARDED_BY(mu_);
694 
695   // Statistic of inputs processing time history.
696   double input_processing_time_sum_ = 0.0L;
697   int64_t input_processing_time_count_ = 0;
698 
699   // Holds the previous processing time and the per element processing time
700   // exponential moving average.
701   int64_t previous_processing_time_ TF_GUARDED_BY(mu_) = 0;
702   double processing_time_ema_ TF_GUARDED_BY(mu_) = 0.0;
703 
704   // Inputs of this node. These can represent an iterator created from the input
705   // dataset but also other input iterators (e.g. created by the user-defined
706   // functions of `flat_map` or `interleave`).
707   std::list<std::shared_ptr<Node>> inputs_ TF_GUARDED_BY(mu_);
708 
709   // The reference to the output node is not owned so that deletion of a
710   // node results in recursive deletion of the subtree rooted in the node.
711   Node* const output_;
712 };
713 
714 // InterleaveMany is used to model datasets whose inputs are used to create
715 // datasets whose elements are then interleaved.
716 std::shared_ptr<Node> MakeInterleaveManyNode(
717     Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters);
718 
719 // AsyncInterleaveMany nodes are the asynchronous version of InterleaveMany
720 // nodes.
721 std::shared_ptr<Node> MakeAsyncInterleaveManyNode(
722     Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters);
723 
724 // KnownRatio nodes model datasets that synchronously consume a known number
725 // of input elements per output element.
726 std::shared_ptr<Node> MakeKnownRatioNode(Node::Args args, double ratio);
727 
728 // AsyncKnownRatio nodes are the asynchronous version of KnownRatio nodes.
729 std::shared_ptr<Node> MakeAsyncKnownRatioNode(
730     Node::Args args, double ratio, double memory_ratio,
731     std::vector<std::shared_ptr<Parameter>> parameters);
732 
// Overload of `MakeAsyncKnownRatioNode` that takes no `memory_ratio` argument.
733 std::shared_ptr<Node> MakeAsyncKnownRatioNode(
734     Node::Args args, double ratio,
735     std::vector<std::shared_ptr<Parameter>> parameters);
736 
737 // Source nodes represent data sources.
738 std::shared_ptr<Node> MakeSourceNode(Node::Args args);
739 
740 // UnknownRatio nodes represent datasets that synchronously consume an
741 // unknown number of input elements per output.
742 //
743 // Unlike KnownRatio nodes, which expect the ratio between inputs and outputs
744 // to be specified as a parameter, UnknownRatio estimates the ratio empirically.
745 std::shared_ptr<Node> MakeUnknownRatioNode(Node::Args args);
746 
747 // AsyncUnknownRatio nodes are the asynchronous version of UnknownRatio nodes.
748 std::shared_ptr<Node> MakeAsyncUnknownRatioNode(
749     Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters);
750 
751 // Unknown nodes represent datasets for which we do not have a model. It acts
752 // as pass-through between inputs and output.
753 std::shared_ptr<Node> MakeUnknownNode(Node::Args args);
754 
755 // Abstract representation of a TensorFlow input pipeline that can be used
756 // for collecting runtime information and optimizing performance. It collects
757 // runtime information about execution of the input pipeline that is used to
758 // create a performance model, which is in turn used to identify optimal values
759 // of tunable parameters.
760 //
761 // Developers of tf.data transformations are not expected to interact with this
762 // class directly. Boilerplate code for creating the abstract representation of
763 // the input pipeline and collecting runtime information has been added to the
764 // implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
765 //
766 // The order of locks acquired is SharedState lock, Model lock, Node lock.
767 // SharedState lock is acquired first because it shares the same lock as the
768 // dataset iterator that contains it.
769 class Model {
770  public:
771   using OptimizationParams = ModelProto::OptimizationParams;
772   using ModelParameters = Node::ModelParameters;
773   using NodeValues = Node::NodeValues;
774   using ParameterGradients = Node::ParameterGradients;
775 
776   Model();
777   ~Model();
778 
779   // Returns a pointer to the model's output node.
      // `output_` is read while holding `mu_`.
output()780   const std::shared_ptr<Node> output() const {
781     mutex_lock l(mu_);
782     return output_;
783   }
784 
785   // Sets the experiment that this job is part of.
      // NOTE(review): `experiment_` is assigned here without holding any lock;
      // presumably this is called before the model is used from multiple
      // threads -- confirm against callers.
SetExperiment(const string & experiment)786   void SetExperiment(const string& experiment) { experiment_ = experiment; }
787 
788   // Adds a node with the given name and given parent.
789   void AddNode(Node::Factory factory, const string& name,
790                std::shared_ptr<Node> parent, std::shared_ptr<Node>* out_node)
791       TF_LOCKS_EXCLUDED(mu_);
792 
793   // Returns a human-readable string representation of the model. This method
794   // can be invoked automatically by monitoring gauges and to avoid frequent
795   // recomputation, the implementation caches the result.
796   std::string DebugString();
797 
798   // Uses the given algorithm and resource budgets to periodically perform the
799   // autotuning optimization.
800   //
801   // To terminate the execution of the optimization loop, the caller needs to
802   // invoke `cancellation_manager->StartCancel()`.
803   Status OptimizeLoop(AutotuneAlgorithm algorithm, int64_t cpu_budget,
804                       int64_t ram_budget,
805                       CancellationManager* cancellation_manager);
806 
807   // Uses the given algorithm and resource budgets to perform the autotuning
808   // optimization.
809   void Optimize(AutotuneAlgorithm algorithm, int64_t cpu_budget,
810                 int64_t ram_budget, double model_input_time,
811                 CancellationManager* cancellation_manager);
812 
813   // Optimizes buffers in the pipeline rooted at `snapshot`. It downsizes
814   // buffers that are too large and upsizes buffers that are too small while
815   // respecting the ram budget. If any node is downsized or upsized, the
816   // watermarks of all nodes are reset to the buffered elements.
817   void OptimizeBuffers(std::shared_ptr<Node> snapshot, int64_t ram_budget);
818 
819   // Collects the output time and if `gradients` is not `nullptr`, the output
820   // time gradient w.r.t. tunable parameters of the subtree rooted in the given
821   // node.
822   double OutputTime(std::shared_ptr<Node> node, double model_input_time,
823                     ParameterGradients* gradients);
824 
825   // Removes the given node.
826   void RemoveNode(std::shared_ptr<Node> node) TF_LOCKS_EXCLUDED(mu_);
827 
828   // Produces a proto for this model.
829   Status ToProto(ModelProto* model_proto);
830 
831   // Restores a model from the proto.
832   static Status FromProto(ModelProto model_proto,
833                           std::unique_ptr<Model>* model);
834 
835   // Saves this model with a given snapshot and its optimization parameters to a
836   // file. Note that the file directory must already exist.
837   Status Save(const string& fname, std::shared_ptr<Node> snapshot,
838               const OptimizationParams& optimization_params);
839 
840   // Loads a model and its optimization parameters from a file with the given
841   // name.
842   static Status Load(const string& fname, std::unique_ptr<Model>* model,
843                      OptimizationParams* optimization_params);
844 
845   // Records gap time between consecutive `GetNext()` calls.
846   void RecordIteratorGapTime(uint64_t duration_usec);
847 
848   // Computes the target time in nsecs to use for `STAGE_BASED` autotune
849   // algorithm.
850   double ComputeTargetTimeNsec();
851 
852  private:
853   // Determines whether optimization should stop given total processing time,
854   // estimated output time, and estimated number of buffered bytes.
      // The predicate receives a `ModelParameters` reference followed by three
      // `double` values.
855   using StopPredicate =
856       std::function<bool(const ModelParameters&, double, double, double)>;
857 
      // NOTE(review): presumably the lower and upper bounds, in milliseconds,
      // for `optimization_period_ms_` -- confirm against the implementation.
858   static constexpr int64_t kOptimizationPeriodMinMs = 10;
859   static constexpr int64_t kOptimizationPeriodMaxMs =
860       60 * EnvTime::kSecondsToMillis;
861 
862   // Collects tunable parameters in the tree rooted in the given node, returning
863   // a vector which contains pairs of node names and tunable parameters.
864   ModelParameters CollectTunableParameters(std::shared_ptr<Node> node);
865 
866   // Downsizes buffers that are too large for all nodes rooted at `snapshot`.
867   // Returns true if any buffer is downsized.
868   bool DownsizeBuffers(std::shared_ptr<Node> snapshot);
869 
870   // Upsizes buffers that are too small for all nodes rooted at `snapshot` while
871   // respecting the ram budget. Returns true if any buffer is upsized.
872   bool UpsizeBuffers(std::shared_ptr<Node> snapshot, int64_t ram_budget);
873 
874   // Reset buffer watermarks of all asynchronous nodes to their buffered
875   // elements.
876   void ResetBufferWatermarks();
877 
878   // Collects buffer parameters of all nodes in the model that should be
879   // upsized.
880   absl::flat_hash_map<Node*, Parameter*> CollectBufferParametersToUpsize(
881       std::shared_ptr<Node> snapshot);
882 
883   // Flushes metrics recorded by the model.
884   void FlushMetrics() TF_LOCKS_EXCLUDED(mu_);
885 
886   // This optimization algorithm starts by setting all tunable parallelism
887   // parameters to the minimum value. It then improves current parameters by
888   // making a step in the direction opposite to the gradient of `OutputTime` and
889   // projecting resulting values on the feasible intervals. Improvement step is
890   // repeated until either the output time improvement is smaller than threshold
891   // value or the output time is less than the processing time needed to produce
892   // an element divided by CPU budget.
893   void OptimizeGradientDescent(std::shared_ptr<Node> snapshot,
894                                const OptimizationParams& optimization_params,
895                                CancellationManager* cancellation_manager);
896 
897   // Helper method for implementing hill-climb optimization that can be
898   // parametrized by a predicate to use for stopping the optimization.
899   void OptimizeHillClimbHelper(std::shared_ptr<Node> snapshot,
900                                const OptimizationParams& optimization_params,
901                                CancellationManager* cancellation_manager,
902                                StopPredicate should_stop);
903 
904   // This optimization algorithm starts by setting all tunable parallelism
905   // parameters to the minimum value. It then repeatedly identifies the
906   // parameter whose increase in parallelism decreases the output time the most.
907   // This process is repeated until all parameters reach their maximum values or
908   // the projected output time is less than or equal to the processing time
909   // needed to produce an element divided by CPU budget.
910   void OptimizeHillClimb(std::shared_ptr<Node> snapshot,
911                          const OptimizationParams& optimization_params,
912                          CancellationManager* cancellation_manager);
913 
914   // This optimization behaves similarly to the hill climb optimization but uses
915   // a relaxed stopping condition, allowing the optimization to oversubscribe
916   // CPU.
917   void OptimizeMaxParallelism(std::shared_ptr<Node> snapshot,
918                               const OptimizationParams& optimization_params,
919                               CancellationManager* cancellation_manager);
920 
921   // This optimization starts by setting all tunable parallelism parameters to
922   // their minimum values. It then repeatedly increases the parallelism
923   // parameter of the longest stage by 1 until either the longest stage is
924   // faster than the target time or the memory or CPU budget is fully utilized.
925   // TODO(b/226910071): The second part of this algorithm optimizes the buffer
926   // sizes of parallel ops.
927   void OptimizeStageBased(std::shared_ptr<Node> snapshot,
928                           const OptimizationParams& optimization_params,
929                           CancellationManager* cancellation_manager);
930 
931   // This is the first part of the stage-based optimization that optimizes
932   // tunable parallelism parameters.
933   void OptimizeStageBasedParallelism(
934       std::shared_ptr<Node> snapshot, double target_time_nsec,
935       const OptimizationParams& optimization_params,
936       CancellationManager* cancellation_manager);
937 
938   // Determines if we should stop the gradient descent optimization iterations
939   // based on number of increasable parameters, CPU budget, RAM budget and
940   // current resource usage.
941   bool ShouldStop(int64_t cpu_budget, int64_t ram_budget,
942                   const ModelParameters& parameters,
943                   const ModelParameters& parallelism_parameters,
944                   const ModelParameters& buffer_size_parameters,
945                   std::shared_ptr<Node> snapshot, bool* cpu_budget_reached);
946 
947   // Collects the processing time for the given node.
948   double TotalProcessingTime(std::shared_ptr<Node> node);
949 
950   // Collects the total number of bytes buffered in all nodes in the subtree
951   // rooted in the given node for which autotuning is enabled.
952   double TotalBufferedBytes(std::shared_ptr<Node> node);
953 
954   // Collects the total buffer limit of all nodes in the subtree rooted in the
955   // given node for which autotuning is enabled. This number represents the
956   // amount of memory that would be used by the subtree nodes if all of their
957   // buffers were full.
958   double TotalMaximumBufferedBytes(std::shared_ptr<Node> node);
959 
960   // Used for coordination between different input pipeline threads. Exclusive
961   // access is required only when adding or removing nodes. Concurrent access to
962   // existing nodes is protected by a node mutex.
963   mutable mutex mu_;
964   // Used for coordinating the optimization loop and model modifications.
965   condition_variable optimize_cond_var_;
966   int64_t id_counter_ TF_GUARDED_BY(mu_) = 1;
967   std::shared_ptr<Node> output_ TF_GUARDED_BY(mu_) = nullptr;
968 
969   // Determines the time the optimization loop should wait between
970   // running optimizations.
      // NOTE(review): unlike the neighboring members, this one has no in-class
      // initializer; presumably it is set by the constructor -- confirm.
971   int64_t optimization_period_ms_ TF_GUARDED_BY(mu_);
972 
973   // Gauge cell that can be used to collect the state of the model.
974   monitoring::GaugeCell<std::function<std::string()>>* model_gauge_cell_ =
975       nullptr;
976   // Time used for rate limiting the recomputation of the human-readable string
977   // representation of the model.
978   absl::Time cache_until_ = absl::InfinitePast();
979   // Cached result of the `DebugString()` invocation used to implement rate
980   // limiting of the computation.
981   std::string cached_debug_string_ = "";
982   // Used to coordinate gap time updates between different threads. Gap time is
983   // the time between the completion of the previous `GetNext()` and the start
984   // of the next `GetNext()`.
985   mutable mutex gap_mu_;
986   // Stores the latest gap times between consecutive `GetNext()`.
987   std::deque<uint64_t> gap_times_usec_ TF_GUARDED_BY(gap_mu_);
988   // The experiment that this job is part of.
989   std::string experiment_ = "";
990 };
991 
992 // Class to compute timing information for a model.
// Per-node results are kept in `timing_nodes_`, keyed by node address.
993 class ModelTiming {
994  public:
995   struct NodeTiming {
996     // Pipeline ratio is the number of elements this node needs to produce in
997     // order to produce an element at the root of the pipeline.
998     double pipeline_ratio = 0.0;
999     // The self time it takes this node to produce the elements needed to
1000     // produce one element of the root of the pipeline.
1001     double self_time_nsec = 0.0;
1002     // The total time it takes this node and the subtree rooted at this node to
1003     // produce the elements needed to produce one element at the root of the
1004     // pipeline.
1005     double total_time_nsec = 0.0;
1006   };
1007 
1008   explicit ModelTiming(std::shared_ptr<Node> root);
1009 
1010   // Returns the timing data for `node`.
       // NOTE(review): the result is a pointer rather than a reference, so
       // presumably `nullptr` is returned when `node` has no timing entry --
       // confirm against the implementation.
1011   const NodeTiming* GetTiming(const Node* node) const;
1012 
1013   // Returns the root nodes of all stages.
1014   std::vector<std::shared_ptr<Node>> GetStageRoots() const;
1015 
1016   // Returns all the nodes of a stage given the stage root.
1017   std::vector<std::shared_ptr<Node>> GetStageNodes(
1018       std::shared_ptr<Node> stage_root) const;
1019 
1020   // Computes the total time for a node.
1021   void ComputeNodeTotalTime(const Node& node);
1022 
1023  private:
1024   // Computes the pipeline ratios of all nodes.
1025   void ComputePipelineRatios(const Node::NodeVector& bfs_nodes);
1026 
1027   // Computes the total time for all nodes. The `reverse_bfs_nodes` are assumed
1028   // to be a vector of model nodes in reverse BFS order.
1029   void ComputeTotalTimes(const Node::NodeVector& reverse_bfs_nodes);
1030 
1031   // Computes the total time of a node that is not an async interleave node.
1032   void ComputeNonAsyncInterleaveManyTotalTime(const Node& node);
1033 
1034   // Computes the total time of an async interleave node.
1035   void ComputeAsyncInterleaveManyTotalTime(const Node& node);
1036 
1037   // Returns a vector of all nodes in the model. The nodes are either in
1038   // breadth-first search or reverse breadth-first search order depending on the
1039   // `order` argument. The nodes are collected based on the results of the
1040   // `collect_node` predicate: if the predicate returns `false` for a given
1041   // node, then the subtree rooted in this node is excluded. The root node
1042   // itself is not collected.
       //
       // `collect_node` is declared with a function type, which the language
       // adjusts to a pointer to function.
1043   Node::NodeVector CollectNodes(
1044       std::shared_ptr<Node> root, TraversalOrder order,
1045       bool collect_node(const std::shared_ptr<Node>)) const;
1046 
1047   // Stores a pointer to the root of a model.
1048   std::shared_ptr<Node> root_;
1049 
1050   // Holds a mapping from node to its timing node.
1051   absl::flat_hash_map<const Node*, NodeTiming> timing_nodes_;
1052 };
1053 
1054 }  // namespace model
1055 }  // namespace data
1056 }  // namespace tensorflow
1057 
1058 #endif  // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
1059