xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/model.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/model.h"
17 
18 #include <algorithm>
19 #include <cmath>
20 #include <memory>
21 #include <queue>
22 
23 #include "absl/time/clock.h"
24 #include "tensorflow/core/framework/cancellation.h"
25 #include "tensorflow/core/framework/model.pb.h"
26 #include "tensorflow/core/lib/gtl/cleanup.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/platform/host_info.h"
29 #include "tensorflow/core/platform/mem.h"
30 #include "tensorflow/core/platform/statusor.h"
31 
32 namespace tensorflow {
33 namespace data {
34 namespace model {
35 
// Out-of-line definitions of the `kOptimizationPeriodMinMs` and
// `kOptimizationPeriodMaxMs` constants that belong to `Model`.
// NOTE(review): presumably needed so the constants can be ODR-used when the
// file is built in a pre-C++17 mode (where constexpr static data members are
// not implicitly inline) -- confirm before removing.
constexpr int64_t Model::kOptimizationPeriodMinMs;
constexpr int64_t Model::kOptimizationPeriodMaxMs;
38 
39 namespace {
40 
// Number of the most recent gap times that are used to compute the target time
// for stage based optimization.
constexpr int32_t kGapTimeWindow{100};
// Gap time threshold in microseconds (10 seconds): any gap time over this
// duration will be dropped.
constexpr uint64_t kGapDurationThresholdUsec{10'000'000};
// In outlier computation, points that are larger than `kOutlierSigmas` standard
// deviations are considered outliers.
constexpr double kOutlierSigmas{2.0};
49 
50 // A class to prune outliers given a set of points. To use it, instantiate an
51 // object and call the `GetCleanPoints()` method.
52 class OutlierPruner {
53  public:
OutlierPruner(const std::vector<uint64_t> & points)54   explicit OutlierPruner(const std::vector<uint64_t>& points)
55       : points_(points.begin(), points.end()) {}
56 
57   // Returns the remaining points after removing outliers from the original set
58   // of points.
GetCleanPoints()59   std::vector<uint64_t> GetCleanPoints() {
60     if (points_.empty()) {
61       return points_;
62     }
63     // Compute the outlier threshold
64     double mean;
65     double standard_deviation;
66     ComputeMeanAndStandardDeviation(&mean, &standard_deviation);
67     double threshold = mean + standard_deviation * kOutlierSigmas;
68     std::vector<uint64_t> clean_points;
69     for (auto point : points_) {
70       if (static_cast<double>(point) > threshold) {
71         continue;
72       }
73       clean_points.push_back(point);
74     }
75     return clean_points;
76   }
77 
78  private:
ComputeMeanAndStandardDeviation(double * mean,double * standard_deviation)79   void ComputeMeanAndStandardDeviation(double* mean,
80                                        double* standard_deviation) {
81     uint64_t sum = std::accumulate(points_.begin(), points_.end(), 0);
82     *mean = static_cast<double>(sum) / static_cast<double>(points_.size());
83     double accum = 0.0;
84     for (auto point : points_) {
85       accum += (static_cast<double>(point) - *mean) *
86                (static_cast<double>(point) - *mean);
87     }
88     *standard_deviation = std::sqrt(accum / (points_.size() - 1));
89   }
90 
91   // Points to cluster.
92   std::vector<uint64_t> points_;
93 };
94 
95 // A priority queue that holds stage roots where the top of the priority queue
96 // is the node with the largest total time.
97 class ModelTimingPriorityQueue {
98  public:
ModelTimingPriorityQueue(ModelTiming & model_timing)99   explicit ModelTimingPriorityQueue(ModelTiming& model_timing) {
100     std::vector<std::shared_ptr<Node>> stage_roots =
101         model_timing.GetStageRoots();
102     if (stage_roots.empty()) {
103       return;
104     }
105     for (auto& root : stage_roots) {
106       DCHECK(model_timing.GetTiming(root.get()) != nullptr);
107       const ModelTiming::NodeTiming* root_timing =
108           model_timing.GetTiming(root.get());
109       stage_roots_queue_.emplace(
110           root_timing->total_time_nsec * root_timing->pipeline_ratio,
111           root.get());
112     }
113   }
114 
115   // Pops the top item from the queue, i.e. node with the largest total time.
PopSlowestStageRoot()116   StatusOr<std::pair<double, Node*>> PopSlowestStageRoot() {
117     if (stage_roots_queue_.empty()) {
118       return errors::Internal(
119           "Model timing priority queue is empty during stage-based "
120           "optimization");
121     }
122     std::pair<double, Node*> top_item = stage_roots_queue_.top();
123     stage_roots_queue_.pop();
124     return top_item;
125   }
126 
127   // Push a node together with its total time onto the queue.
Push(Node * node,const ModelTiming::NodeTiming & node_timing)128   void Push(Node* node, const ModelTiming::NodeTiming& node_timing) {
129     stage_roots_queue_.emplace(
130         node_timing.total_time_nsec * node_timing.pipeline_ratio, node);
131   }
132 
133  private:
134   std::priority_queue<std::pair<double, Node*>> stage_roots_queue_;
135 };
136 
137 // A cache that looks up the `parallelism` parameters of nodes the first time
138 // they are requested and saves them for subsequent requests.
139 class NodeParallelismParameters {
140  public:
NodeParallelismParameters()141   NodeParallelismParameters() {}
142 
143   // Returns the `parallelism` parameter given a node.
Get(const Node * node)144   Parameter* Get(const Node* node) {
145     if (node_parallelism_.contains(node)) {
146       // Look for the `parallelism` parameter of this node in the cache.
147       return node_parallelism_.at(node);
148     }
149     // Find the `parallelism` parameter of this node and cache it.
150     Node::ModelParameters parameters = node->CollectNodeTunableParameters();
151     Node::ModelParameters::iterator parameter_pair = std::find_if(
152         parameters.begin(), parameters.end(),
153         [](const std::pair<std::string, std::shared_ptr<Parameter>>&
154                parameter) { return parameter.second->name == kParallelism; });
155     if (parameter_pair == parameters.end()) {
156       return nullptr;
157     }
158     node_parallelism_[node] = parameter_pair->second.get();
159     return parameter_pair->second.get();
160   }
161 
162  private:
163   absl::flat_hash_map<const Node*, Parameter*> node_parallelism_;
164 };
165 
166 // Returns true if all parameters have reached their max values.
AreAllParametersMax(const Model::ModelParameters & parameters)167 bool AreAllParametersMax(const Model::ModelParameters& parameters) {
168   for (const auto& pair : parameters) {
169     if (pair.second->value < pair.second->max) {
170       return false;
171     }
172   }
173   return true;
174 }
175 
176 // Records the ram usage of hill climbing algorithm.
RecordAutotuneRamUsage(int64 ram_budget,double max_buffered_bytes)177 void RecordAutotuneRamUsage(int64 ram_budget, double max_buffered_bytes) {
178   if (ram_budget == 0) {
179     return;
180   }
181   const auto memory_info = port::GetMemoryInfo();
182   // Records ratio of memory used since RootDataset was created over the ram
183   // budget.
184   const auto original_free_memory = ram_budget / kRamBudgetShare;
185   const auto current_free_memory = memory_info.free;
186   metrics::RecordTFDataAutotuneUsedRamBudgetRatio(
187       (original_free_memory - current_free_memory) / ram_budget);
188   // Records ratio of maximum buffer bytes tf.data could use over the ram
189   // budget.
190   metrics::RecordTFDataAutotuneMaxBufferBudgetRatio(
191       max_buffered_bytes / static_cast<double>(ram_budget));
192 }
193 
194 // Helper function for node traversal that doesn't skip any nodes.
IsAnyNode(const std::shared_ptr<Node> node)195 inline bool IsAnyNode(const std::shared_ptr<Node> node) { return true; }
196 
197 // Helper function for node traversal that filters out nodes for which
198 // autotuning is disabled.
IsAutotuneNode(const std::shared_ptr<Node> node)199 inline bool IsAutotuneNode(const std::shared_ptr<Node> node) {
200   return node->autotune();
201 }
202 
203 // Helper function for node traversal that returns only synchronous nodes.
IsSyncNode(const std::shared_ptr<Node> node)204 inline bool IsSyncNode(const std::shared_ptr<Node> node) {
205   return !node->IsAsync();
206 }
207 
208 // Helper function for node traversal that returns only asynchronous nodes.
IsAsyncNode(const std::shared_ptr<Node> node)209 inline bool IsAsyncNode(const std::shared_ptr<Node> node) {
210   return node->IsAsync();
211 }
212 
// Returns `x` multiplied by itself. Exists only to reduce verbosity at call
// sites.
inline double Square(double x) {
  const double squared = x * x;
  return squared;
}
215 
216 // Collects "essential" parallelism parameters and buffer size parameters in the
217 // tree rooted in the given node. Which parallelism parameters are essential is
218 // determined by the relative processing time spent in the corresponding
219 // transformation. The collected parameters are returned via maps that map node
220 // names to their respective parameters.
CollectParameters(std::shared_ptr<Node> node,const Node::ModelParameters & parameters,Node::ModelParameters * parallelism_parameters,Node::ModelParameters * buffer_size_parameters)221 inline void CollectParameters(std::shared_ptr<Node> node,
222                               const Node::ModelParameters& parameters,
223                               Node::ModelParameters* parallelism_parameters,
224                               Node::ModelParameters* buffer_size_parameters) {
225   // Parallelism parameter is considered to be essential if the corresponding
226   // transformations's processing time is greater than essential rate times the
227   // average transformation self processing time.
228   constexpr double kEssentialRate = 0.3L;
229 
230   Node::NodeValues processing_times;
231   double processing_time = node->TotalProcessingTime(&processing_times);
232   double uniform_share =
233       processing_time / static_cast<double>(processing_times.size());
234   for (auto& pair : parameters) {
235     if (pair.second->name == kParallelism &&
236         processing_times[pair.first] > kEssentialRate * uniform_share) {
237       parallelism_parameters->push_back(pair);
238     } else if (pair.second->name == kBufferSize) {
239       buffer_size_parameters->push_back(pair);
240     }
241   }
242 }
243 
244 // Applies the gradient descent method once and updates the parameter values. If
245 // the new value is out of the range, bound it within the range between the
246 // minimal and maximum values.
UpdateParameterValues(const Node::ParameterGradients & gradients,Node::ModelParameters * parameters)247 inline void UpdateParameterValues(const Node::ParameterGradients& gradients,
248                                   Node::ModelParameters* parameters) {
249   // Gradient descent step size.
250   constexpr double kDescentStep = 0.1L;
251   double new_value;
252 
253   double max_abs_derivative = 1.0;
254   for (auto& pair : *parameters) {
255     if (std::round(pair.second->value) != pair.second->max) {
256       auto* gradient = gtl::FindOrNull(
257           gradients, std::make_pair(pair.first, pair.second->name));
258       if (gradient) {
259         max_abs_derivative = std::max(max_abs_derivative, std::abs(*gradient));
260       }
261     }
262   }
263   for (auto& pair : *parameters) {
264     auto* gradient = gtl::FindOrNull(
265         gradients, std::make_pair(pair.first, pair.second->name));
266     if (gradient) {
267       new_value =
268           pair.second->value - kDescentStep * (*gradient) / max_abs_derivative;
269       // Projection on a feasible interval.
270       if (new_value > pair.second->max) {
271         pair.second->value = pair.second->max;
272       } else if (new_value < pair.second->min) {
273         pair.second->value = pair.second->min;
274       } else {
275         pair.second->value = new_value;
276       }
277     }
278   }
279 }
280 
281 // Copies the parameter values (which are for optimization tuning) and updates
282 // the state values (which are for the input pipeline to follow).
UpdateStateValues(Node::ModelParameters * parameters)283 inline void UpdateStateValues(Node::ModelParameters* parameters) {
284   for (auto& pair : *parameters) {
285     auto& parameter = pair.second;
286     VLOG(2) << "Setting tunable parameter " << pair.first
287             << ":: " << parameter->name << " to " << parameter->value;
288     mutex_lock l(*parameter->state->mu);
289     parameter->state->value = parameter->value;
290     parameter->state->cond_var->notify_all();
291   }
292 }
293 
294 // Recursively produces protos for nodes in a subtree of `output` node and
295 // appends them to nodes of the given model.
ModelToProtoHelper(std::shared_ptr<Node> output,ModelProto * model)296 Status ModelToProtoHelper(std::shared_ptr<Node> output, ModelProto* model) {
297   model->set_output(output->id());
298   std::list<std::shared_ptr<Node>> to_serialize = {output};
299   auto& nodes = *model->mutable_nodes();
300   while (!to_serialize.empty()) {
301     const std::shared_ptr<Node> node = to_serialize.front();
302     to_serialize.pop_front();
303     TF_RETURN_IF_ERROR(node->ToProto(&(nodes[node->id()])));
304     for (auto input : node->inputs()) {
305       to_serialize.push_back(input);
306     }
307   }
308   return OkStatus();
309 }
310 
311 // Recursively produces node tree rooted in `output` from the given model proto.
ModelFromProtoHelper(ModelProto model,std::shared_ptr<Node> * output)312 Status ModelFromProtoHelper(ModelProto model, std::shared_ptr<Node>* output) {
313   if (model.nodes().empty()) {
314     return errors::Internal(
315         "Cannot restore model from proto because it has no nodes.");
316   }
317   TF_RETURN_IF_ERROR(Node::FromProto(model.nodes().at(model.output()),
318                                      /*output=*/nullptr, output));
319   std::list<std::shared_ptr<Node>> to_restore_inputs = {*output};
320   while (!to_restore_inputs.empty()) {
321     std::shared_ptr<Node> node = to_restore_inputs.front();
322     to_restore_inputs.pop_front();
323     for (int64_t input_id : model.nodes().at(node->id()).inputs()) {
324       std::shared_ptr<Node> input;
325       TF_RETURN_IF_ERROR(
326           Node::FromProto(model.nodes().at(input_id), node, &input));
327       node->add_input(input);
328       to_restore_inputs.push_back(input);
329     }
330   }
331   return OkStatus();
332 }
333 
334 // The first input of InterleaveMany corresponds to the input dataset whose
335 // elements are used to create the (derived) input datasets whose elements are
336 // interleaved as output.
337 //
338 // TODO(jsimsa): model the first input
339 class InterleaveMany : public Node {
340  public:
341   using Node::Node;
342 
InterleaveMany(Node::Args args,std::vector<std::shared_ptr<Parameter>> parameters)343   InterleaveMany(Node::Args args,
344                  std::vector<std::shared_ptr<Parameter>> parameters)
345       : Node(args) {
346     for (auto& parameter : parameters) {
347       parameters_[parameter->name] = std::move(parameter);
348     }
349   }
350 
~InterleaveMany()351   virtual ~InterleaveMany() {}
352 
353   // The ratio of an InterleaveMany node is `1/cycle_length`. If cycle length is
354   // not available, we approximate it by `1/input_size`. The input size does not
355   // include the original input dataset that generates other input datasets of
356   // interleave nodes.
Ratio() const357   double Ratio() const override {
358     auto* cycle_length = gtl::FindOrNull(parameters_, kCycleLength);
359     if (cycle_length != nullptr) {
360       return 1.0 / (*cycle_length)->value;
361     }
362     // After cl/436244658, `cycle_length` can not be `nullptr`. The remaining
363     // part of this function is used to approximate `Ratio()` of this node for
364     // model proto that was created before the CL.
365 
366     // Cycle length is not available, use 1/input_size as the ratio.
367     std::size_t input_size = 1;
368     {
369       mutex_lock l(mu_);
370       if (inputs_.size() >= 2) {
371         auto first_input = inputs_.begin();
372         auto second_input = std::next(first_input);
373         // Some interleave datasets have 2 different inputs: the original input
374         // dataset and the generated input datasets when interleave is iterated,
375         // and some do not.
376         if ((*first_input)->name() == (*second_input)->name()) {
377           input_size = std::max(inputs_.size(), input_size);
378         } else {
379           input_size = std::max(inputs_.size() - 1, input_size);
380         }
381       }
382     }
383     if (input_size == 0) {
384       return 1.0;
385     }
386     return 1.0 / static_cast<double>(input_size);
387   }
388 
389  protected:
Clone(std::shared_ptr<Node> output) const390   std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
391       TF_SHARED_LOCKS_REQUIRED(mu_) {
392     return std::make_shared<InterleaveMany>(
393         Args{id_, name_, std::move(output)});
394   }
395 
InputTimeLocked(NodeValues * input_times) const396   void InputTimeLocked(NodeValues* input_times) const override
397       TF_SHARED_LOCKS_REQUIRED(mu_) {
398     double inherited_input_time;
399     if (output_) {
400       inherited_input_time = (*input_times)[output_->long_name()];
401     } else {
402       inherited_input_time = (*input_times)[kModelInputTimeKey];
403     }
404 
405     if (num_inputs() <= 1) {
406       (*input_times)[long_name()] = inherited_input_time;
407       return;
408     }
409     // Here `inherited_input_time + SelfProcessingTimeLocked()` is the average
410     // input time for InterleaveMany node to call one of the `(num_inputs() -
411     // 1)` input nodes (except first input) to return an element. Regardless of
412     // the `block_length` parameter of InterleaveMany node, the average input
413     // time for any of the `(num_inputs() - 1)` input nodes to be called is
414     // computed as:
415     double input_time = (inherited_input_time + SelfProcessingTimeLocked()) *
416                         static_cast<double>(num_inputs() - 1);
417     (*input_times)[long_name()] = input_time;
418   }
419 
420   // The output time is the sum of the self processing time and the average
421   // output time of inputs comprising the interleave "cycle".
OutputTimeLocked(const NodeValues & input_times,ParameterGradients * gradients,NodeValues * output_times,NodeValues * output_time_gradients) const422   void OutputTimeLocked(const NodeValues& input_times,
423                         ParameterGradients* gradients, NodeValues* output_times,
424                         NodeValues* output_time_gradients) const override
425       TF_SHARED_LOCKS_REQUIRED(mu_) {
426     double self_processing_time = SelfProcessingTimeLocked();
427     if (num_inputs() <= 1) {
428       (*output_times)[long_name()] = self_processing_time;
429       if (gradients) {
430         for (const auto& pair : CollectTunableParametersLocked()) {
431           gradients->erase(std::make_pair(pair.first, pair.second->name));
432         }
433       }
434       return;
435     }
436 
437     double inputs_output_time =
438         (OutputTimeForInputs(*output_times) -
439          (*output_times)[inputs_.front()->long_name()]) /
440         static_cast<double>(num_inputs() - 1);
441     if (gradients) {
442       for (const auto& pair : CollectTunableParametersLocked()) {
443         auto* gradient = gtl::FindOrNull(
444             *gradients, std::make_pair(pair.first, pair.second->name));
445         if (gradient) {
446           *gradient /= static_cast<double>(num_inputs() - 1);
447         }
448       }
449 
450       (*output_time_gradients)[long_name()] =
451           OutputTimeGradientsForInputs(*output_time_gradients) -
452           (*output_time_gradients)[inputs_.front()->long_name()];
453 
454       // Set derivatives w.r.t. tunable parameters of the subtree rooted in the
455       // first input equal to 0 since its output time is excluded from
456       // computations.
457       for (auto& pair : inputs_.front()->CollectTunableParameters()) {
458         (*gradients)[std::make_pair(pair.first, pair.second->name)] = 0.0L;
459       }
460     }
461     (*output_times)[long_name()] = self_processing_time + inputs_output_time;
462   }
463 
464   // The processing time is the sum of the self processing time and the average
465   // processing time of inputs comprising the interleave "cycle".
TotalProcessingTimeLocked(NodeValues * processing_times,NodeValues * total_processing_times)466   void TotalProcessingTimeLocked(NodeValues* processing_times,
467                                  NodeValues* total_processing_times) override
468       TF_SHARED_LOCKS_REQUIRED(mu_) {
469     double self_processing_time = SelfProcessingTimeLocked();
470     if (processing_times) {
471       (*processing_times)[long_name()] = self_processing_time;
472     }
473     if (num_inputs() <= 1) {
474       (*total_processing_times)[long_name()] = self_processing_time;
475       return;
476     }
477     double inputs_processing_time =
478         (TotalProcessingTimeForInputs(*total_processing_times) -
479          (*total_processing_times)[inputs_.front()->long_name()]) /
480         static_cast<double>(num_inputs() - 1);
481     (*total_processing_times)[long_name()] =
482         self_processing_time + inputs_processing_time;
483   }
484 
ToProto(ModelProto::Node * node_proto) const485   Status ToProto(ModelProto::Node* node_proto) const {
486     TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
487     node_proto->set_node_class(NodeClass::INTERLEAVE_MANY);
488     return OkStatus();
489   }
490 };
491 
492 // The first input of AsyncInterleaveMany corresponds to the input dataset whose
493 // elements are used to create the (derived) input datasets whose elements are
494 // interleaved as output.
495 //
496 // TODO(jsimsa): model the first input
497 class AsyncInterleaveMany : public Node {
498  public:
AsyncInterleaveMany(Node::Args args,std::vector<std::shared_ptr<Parameter>> parameters)499   AsyncInterleaveMany(Node::Args args,
500                       std::vector<std::shared_ptr<Parameter>> parameters)
501       : Node(args) {
502     for (auto& parameter : parameters) {
503       parameters_[parameter->name] = std::move(parameter);
504     }
505   }
506 
~AsyncInterleaveMany()507   virtual ~AsyncInterleaveMany() {}
508 
IsAsync() const509   bool IsAsync() const override { return true; }
510 
511   // The ratio of an AsyncInterleaveMany node is 1/`cycle_length`. If cycle
512   // length is not available, we use 1/parallelism.
Ratio() const513   double Ratio() const override {
514     auto* cycle_length = gtl::FindOrNull(parameters_, kCycleLength);
515     if (cycle_length != nullptr) {
516       return 1.0 / (*cycle_length)->value;
517     }
518     // After cl/436244658, `cycle_length` can not be `nullptr`. The remaining
519     // part of this function is used to approximate `Ratio()` of this node for
520     // model proto that was created before the CL.
521 
522     // Cycle length is not available, use 1/min(input_size, parallelism) as the
523     // ratio.
524     double parallelism = 1.0;
525     {
526       mutex_lock l(mu_);
527       if (inputs_.size() >= 2) {
528         auto first_input = inputs_.begin();
529         auto second_input = std::next(first_input);
530         // Some interleave datasets have 2 different inputs: the original input
531         // dataset and the generated input datasets when interleave is iterated,
532         // and some do not.
533         if ((*first_input)->name() == (*second_input)->name()) {
534           parallelism = std::max(inputs_.size(), size_t{1});
535         } else {
536           parallelism = std::max(inputs_.size() - 1, size_t{1});
537         }
538       }
539     }
540     auto* parameter = gtl::FindOrNull(parameters_, kParallelism);
541     if (parameter) {
542       parallelism = std::min(parallelism, (*parameter)->value);
543     }
544     return 1.0 / parallelism;
545   }
546 
ComputeSelfTime() const547   double ComputeSelfTime() const override {
548     double parallelism = 1.0;
549     auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism);
550     if (parallelism_parameter) {
551       parallelism = (*parallelism_parameter)->value;
552     }
553     if (num_elements_ == 0) {
554       return 0;
555     }
556     {
557       tf_shared_lock l(mu_);
558       return processing_time_ema_ / parallelism;
559     }
560   }
561 
562  protected:
Clone(std::shared_ptr<Node> output) const563   std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
564       TF_SHARED_LOCKS_REQUIRED(mu_) {
565     std::vector<std::shared_ptr<Parameter>> parameters;
566     for (auto& pair : parameters_) {
567       parameters.push_back(pair.second);
568     }
569     return std::make_shared<AsyncInterleaveMany>(
570         Args{id_, name_, std::move(output)}, parameters);
571   }
572 
InputTimeLocked(NodeValues * input_times) const573   void InputTimeLocked(NodeValues* input_times) const override
574       TF_SHARED_LOCKS_REQUIRED(mu_) {
575     double inherited_input_time;
576     if (output_) {
577       inherited_input_time = (*input_times)[output_->long_name()];
578     } else {
579       inherited_input_time = (*input_times)[kModelInputTimeKey];
580     }
581 
582     if (num_inputs() <= 1) {
583       (*input_times)[long_name()] = inherited_input_time;
584       return;
585     }
586     // Here `inherited_input_time + SelfProcessingTimeLocked()` is the average
587     // input time for AsyncInterleaveMany node to call one of the `(num_inputs()
588     // - 1)` input nodes (except first input) to return an element. Regardless
589     // of the `block_length` parameter of AsyncInterleaveMany node, the average
590     // input time for any of the `(num_inputs() - 1)` input nodes to be called
591     // is computed as:
592     double input_time = (inherited_input_time + SelfProcessingTimeLocked()) *
593                         static_cast<double>(num_inputs() - 1);
594     (*input_times)[long_name()] = input_time;
595   }
596 
597   // The output time is the sum of self processing time and expected wait time
598   // from the buffer model estimated using `ComputeWaitTime(producer_time,
599   // consumer_time, parallelism, ...)`, where `producer_time` is the average
600   // output time of inputs comprising the interleave "cycle" divided by
601   // `parallelism`, `consumer_time` is the `input_time` specified through
602   // `input_times` divided by `num_inputs() - 1`, and if the node has
603   // parallelism parameter, then `buffer_size` is derived from `parallelism`.
OutputTimeLocked(const NodeValues & input_times,ParameterGradients * gradients,NodeValues * output_times,NodeValues * output_time_gradients) const604   void OutputTimeLocked(const NodeValues& input_times,
605                         ParameterGradients* gradients, NodeValues* output_times,
606                         NodeValues* output_time_gradients) const override
607       TF_SHARED_LOCKS_REQUIRED(mu_) {
608     double self_processing_time = SelfProcessingTimeLocked();
609     if (num_inputs() <= 1) {
610       (*output_times)[long_name()] = self_processing_time;
611       if (gradients) {
612         for (const auto& pair : CollectTunableParametersLocked()) {
613           gradients->erase(std::make_pair(pair.first, pair.second->name));
614         }
615       }
616       return;
617     }
618 
619     double output_time, wait_time, consumer_time, producer_time;
620     double input_time = input_times.at(long_name());
621     consumer_time = input_time / static_cast<double>(num_inputs() - 1);
622     double parallelism = num_inputs() - 1;  // default to cycle length
623     auto* parameter = gtl::FindOrNull(parameters_, kParallelism);
624     if (parameter) {
625       parallelism = std::min(parallelism, (*parameter)->value);
626     }
627     double output_time_for_inputs =
628         OutputTimeForInputs(*output_times) -
629         (*output_times)[inputs_.front()->long_name()];
630     producer_time = output_time_for_inputs /
631                     static_cast<double>(num_inputs() - 1) / parallelism;
632 
633     if (gradients) {
634       double producer_time_der = 0.0L;
635       double consumer_time_der = 0.0L;
636       double buffer_size_der = 0.0L;
637       wait_time = ComputeWaitTime(producer_time, consumer_time, parallelism,
638                                   &producer_time_der, &consumer_time_der,
639                                   &buffer_size_der);
640       double inputs_time_der_sum =
641           OutputTimeGradientsForInputs(*output_time_gradients);
642       (*output_time_gradients)[long_name()] =
643           consumer_time_der +
644           producer_time_der * inputs_time_der_sum / parallelism;
645 
646       for (const auto& pair : CollectTunableParametersLocked()) {
647         auto* gradient = gtl::FindOrNull(
648             *gradients, std::make_pair(pair.first, pair.second->name));
649         if (gradient) {
650           *gradient *= (producer_time_der /
651                         static_cast<double>(num_inputs() - 1) / parallelism);
652         }
653       }
654 
655       // Set derivatives w.r.t. tunable parameters of the subtree rooted in the
656       // first input equal to 0 since its output time is excluded from
657       // computations.
658       for (auto& pair : inputs_.front()->CollectTunableParameters()) {
659         (*gradients)[std::make_pair(pair.first, pair.second->name)] = 0.0L;
660       }
661       // Add derivative w.r.t. own parallelism parameter.
662       if (parameter && (*parameter)->state->tunable) {
663         (*gradients)[std::make_pair(long_name(), (*parameter)->name)] =
664             buffer_size_der - producer_time_der * producer_time / parallelism;
665       }
666     } else {
667       wait_time = ComputeWaitTime(producer_time, consumer_time, parallelism,
668                                   /*producer_time_derivative=*/nullptr,
669                                   /*consumer_time_derivative=*/nullptr,
670                                   /*buffer_size_derivative=*/nullptr);
671     }
672     output_time = self_processing_time + wait_time;
673     (*output_times)[long_name()] = output_time;
674   }
675 
676   // The processing time is the sum of the self processing time and the average
677   // processing time of inputs comprising the interleave "cycle".
TotalProcessingTimeLocked(NodeValues * processing_times,NodeValues * total_processing_times)678   void TotalProcessingTimeLocked(NodeValues* processing_times,
679                                  NodeValues* total_processing_times) override
680       TF_SHARED_LOCKS_REQUIRED(mu_) {
681     double self_processing_time = SelfProcessingTimeLocked();
682     if (processing_times) {
683       (*processing_times)[long_name()] = self_processing_time;
684     }
685     if (num_inputs() <= 1) {
686       (*total_processing_times)[long_name()] = self_processing_time;
687       return;
688     }
689     double inputs_processing_time =
690         (TotalProcessingTimeForInputs(*total_processing_times) -
691          (*total_processing_times)[inputs_.front()->long_name()]) /
692         static_cast<double>(num_inputs() - 1);
693     (*total_processing_times)[long_name()] =
694         self_processing_time + inputs_processing_time;
695   }
696 
MaximumBufferedBytes() const697   double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
698     auto* parameter = gtl::FindOrNull(parameters_, kMaxBufferedElements);
699     if (parameter == nullptr) {
700       parameter = gtl::FindOrNull(parameters_, kParallelism);
701       if (parameter == nullptr) {
702         return 0.0;
703       }
704     }
705     return (*parameter)->value * AverageBufferedElementSize();
706   }
707 
ToProto(ModelProto::Node * node_proto) const708   Status ToProto(ModelProto::Node* node_proto) const {
709     TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
710     node_proto->set_node_class(NodeClass::ASYNC_INTERLEAVE_MANY);
711     return OkStatus();
712   }
713 };
714 
715 class KnownRatio : public Node {
716  public:
KnownRatio(Node::Args args,double ratio)717   KnownRatio(Node::Args args, double ratio) : Node(args), ratio_(ratio) {}
718 
~KnownRatio()719   virtual ~KnownRatio() {}
720 
Ratio() const721   double Ratio() const override { return ratio_; }
722 
723  protected:
Clone(std::shared_ptr<Node> output) const724   std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
725       TF_SHARED_LOCKS_REQUIRED(mu_) {
726     return std::make_shared<KnownRatio>(Args{id_, name_, std::move(output)},
727                                         ratio_);
728   }
729 
730   // The input time is the sum of inherited input time and self processing time,
731   // divided by `ratio_`.
InputTimeLocked(NodeValues * input_times) const732   void InputTimeLocked(NodeValues* input_times) const override
733       TF_SHARED_LOCKS_REQUIRED(mu_) {
734     double inherited_input_time;
735     if (output_) {
736       inherited_input_time = (*input_times)[output_->long_name()];
737     } else {
738       inherited_input_time = (*input_times)[kModelInputTimeKey];
739     }
740 
741     if (ratio_ == 0) {
742       (*input_times)[long_name()] = inherited_input_time;
743       return;
744     }
745     double input_time =
746         (inherited_input_time + SelfProcessingTimeLocked()) / ratio_;
747     (*input_times)[long_name()] = input_time;
748   }
749 
750   // The output time is the sum of the self processing time and the product of
751   // `ratio_` and the sum of output times of inputs.
OutputTimeLocked(const NodeValues & input_times,ParameterGradients * gradients,NodeValues * output_times,NodeValues * output_time_gradients) const752   void OutputTimeLocked(const NodeValues& input_times,
753                         ParameterGradients* gradients, NodeValues* output_times,
754                         NodeValues* output_time_gradients) const override
755       TF_SHARED_LOCKS_REQUIRED(mu_) {
756     double self_processing_time = SelfProcessingTimeLocked();
757     if (ratio_ == 0) {
758       (*output_times)[long_name()] = self_processing_time;
759       if (gradients) {
760         for (const auto& pair : CollectTunableParametersLocked()) {
761           gradients->erase(std::make_pair(pair.first, pair.second->name));
762         }
763       }
764       return;
765     }
766     if (gradients) {
767       for (const auto& pair : CollectTunableParametersLocked()) {
768         auto* gradient = gtl::FindOrNull(
769             *gradients, std::make_pair(pair.first, pair.second->name));
770         if (gradient) {
771           *gradient *= ratio_;
772         }
773       }
774       (*output_time_gradients)[long_name()] =
775           OutputTimeGradientsForInputs(*output_time_gradients);
776     }
777     double inputs_output_time = ratio_ * OutputTimeForInputs(*output_times);
778     (*output_times)[long_name()] = self_processing_time + inputs_output_time;
779   }
780 
781   // The processing time is the sum of the self processing time and the product
782   // of `ratio_` and the sum of processing times of inputs.
TotalProcessingTimeLocked(NodeValues * processing_times,NodeValues * total_processing_times)783   void TotalProcessingTimeLocked(NodeValues* processing_times,
784                                  NodeValues* total_processing_times) override
785       TF_SHARED_LOCKS_REQUIRED(mu_) {
786     double self_processing_time = SelfProcessingTimeLocked();
787     if (processing_times) {
788       (*processing_times)[long_name()] = self_processing_time;
789     }
790     if (ratio_ == 0) {
791       (*total_processing_times)[long_name()] = self_processing_time;
792       return;
793     }
794     double inputs_processing_time =
795         ratio_ * TotalProcessingTimeForInputs(*total_processing_times);
796     (*total_processing_times)[long_name()] =
797         self_processing_time + inputs_processing_time;
798   }
799 
ToProto(ModelProto::Node * node_proto) const800   Status ToProto(ModelProto::Node* node_proto) const {
801     TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
802     node_proto->set_node_class(NodeClass::KNOWN_RATIO);
803     node_proto->set_ratio(ratio_);
804     return OkStatus();
805   }
806 
807  private:
808   const double ratio_;
809 };
810 
811 class AsyncRatio : public Node {
812  public:
AsyncRatio(Node::Args args,double ratio,double memory_ratio,std::vector<std::shared_ptr<Parameter>> parameters)813   AsyncRatio(Node::Args args, double ratio, double memory_ratio,
814              std::vector<std::shared_ptr<Parameter>> parameters)
815       : Node(args), ratio_(ratio), memory_ratio_(memory_ratio) {
816     for (auto& parameter : parameters) {
817       parameters_[parameter->name] = std::move(parameter);
818     }
819   }
820 
~AsyncRatio()821   virtual ~AsyncRatio() {}
822 
IsAsync() const823   bool IsAsync() const override { return true; }
824 
Ratio() const825   double Ratio() const override { return ratio_; }
826 
ComputeSelfTime() const827   double ComputeSelfTime() const override {
828     double parallelism = 1.0;
829     auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism);
830     if (parallelism_parameter) {
831       parallelism = (*parallelism_parameter)->value;
832     }
833     if (num_elements_ == 0) {
834       return 0;
835     }
836     {
837       tf_shared_lock l(mu_);
838       return processing_time_ema_ / parallelism;
839     }
840   }
841 
842  protected:
RatioLocked() const843   virtual double RatioLocked() const TF_SHARED_LOCKS_REQUIRED(mu_) {
844     return ratio_;
845   }
846 
MemoryRatio() const847   double MemoryRatio() const { return memory_ratio_; }
848 
849   // The input time is the sum of inherited input time and parallelism adjusted
850   // self processing time, divided by `Ratio()`.
InputTimeLocked(NodeValues * input_times) const851   void InputTimeLocked(NodeValues* input_times) const override
852       TF_SHARED_LOCKS_REQUIRED(mu_) {
853     double inherited_input_time;
854     if (output_) {
855       inherited_input_time = (*input_times)[output_->long_name()];
856     } else {
857       inherited_input_time = (*input_times)[kModelInputTimeKey];
858     }
859     double parallelism = 1.0;
860     auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism);
861     if (parallelism_parameter) {
862       parallelism = (*parallelism_parameter)->value;
863     }
864 
865     auto ratio = RatioLocked();
866     if (ratio == 0.0) {
867       (*input_times)[long_name()] =
868           inherited_input_time + SelfProcessingTimeLocked() / parallelism;
869       return;
870     }
871     double input_time =
872         (inherited_input_time + SelfProcessingTimeLocked() / parallelism) /
873         ratio;
874     (*input_times)[long_name()] = input_time;
875   }
876 
877   // The output time is the sum of parallelism adjusted self processing time and
878   // expected wait time from the buffer model estimated using
879   // `ComputeWaitTime(producer_time, consumer_time, parallelism, ...)`, where
880   // `producer_time` is the product of `Ratio()` and the sum of output times of
881   // inputs, `consumer_time` is the product of `Ratio()` and the `input_time`
882   // specified through `input_times` (since for each element stored in the
883   // buffer, the inputs need to be called `Ratio()` times), and if the node has
884   // parallelism parameter, then `buffer_size` is derived from `parallelism`.
885   //
886   // Current implementation assumes that there is at most 1 parameter per node.
OutputTimeLocked(const NodeValues & input_times,ParameterGradients * gradients,NodeValues * output_times,NodeValues * output_time_gradients) const887   void OutputTimeLocked(const NodeValues& input_times,
888                         ParameterGradients* gradients, NodeValues* output_times,
889                         NodeValues* output_time_gradients) const override
890       TF_SHARED_LOCKS_REQUIRED(mu_) {
891     auto ratio = RatioLocked();
892     double parallelism = 1.0;
893     double buffer_size = 0.0;
894     auto* parallelism_parameter = gtl::FindOrNull(parameters_, kParallelism);
895     auto* buffer_size_parameter = gtl::FindOrNull(parameters_, kBufferSize);
896     if (parallelism_parameter) {
897       parallelism = (*parallelism_parameter)->value;
898       if (ratio == 0.0) {
899         buffer_size = parallelism;
900       } else {
901         // Currently, MapAndBatch is the only transformation creates
902         // AsyncKnownRatio nodes with ratio >= 1. For MapAndBatch, we create
903         // `parallelism` threads to apply the function on elements from input
904         // dataset, while one element in the buffer actually corresponds to
905         // `Ratio()` elements from input dataset. So we adjust the `buffer_size`
906         // by dividing `Ratio()`.
907         buffer_size = parallelism / ratio;
908       }
909     } else if (buffer_size_parameter) {
910       buffer_size = (*buffer_size_parameter)->value;
911     }
912     double self_processing_time = SelfProcessingTimeLocked();
913     double output_time, wait_time, consumer_time, producer_time;
914     double input_time = input_times.at(long_name());
915 
916     if (ratio == 0.0) {
917       consumer_time = input_time;
918       producer_time = 0.0L;
919       if (gradients) {
920         for (const auto& pair : CollectTunableParametersLocked()) {
921           gradients->erase(std::make_pair(pair.first, pair.second->name));
922         }
923 
924         double producer_time_der = 0.0L;
925         double consumer_time_der = 0.0L;
926         double buffer_size_der = 0.0L;
927         wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
928                                     &producer_time_der, &consumer_time_der,
929                                     &buffer_size_der);
930         (*output_time_gradients)[long_name()] = consumer_time_der;
931         if (parallelism_parameter && (*parallelism_parameter)->state->tunable) {
932           (*gradients)[std::make_pair(long_name(),
933                                       (*parallelism_parameter)->name)] =
934               -(1.0L + consumer_time_der) * self_processing_time /
935                   Square(parallelism) +
936               buffer_size_der;
937         } else if (buffer_size_parameter &&
938                    (*buffer_size_parameter)->state->tunable) {
939           (*gradients)[std::make_pair(
940               long_name(), (*buffer_size_parameter)->name)] = buffer_size_der;
941         }
942       } else {
943         wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
944                                     /*producer_time_derivative=*/nullptr,
945                                     /*consumer_time_derivative=*/nullptr,
946                                     /*buffer_size_derivative=*/nullptr);
947       }
948       output_time = self_processing_time / parallelism + wait_time;
949       (*output_times)[long_name()] = output_time;
950       return;
951     }
952 
953     consumer_time = input_time * ratio;
954     producer_time = ratio * OutputTimeForInputs(*output_times);
955     if (gradients) {
956       double producer_time_der = 0.0L;
957       double consumer_time_der = 0.0L;
958       double buffer_size_der = 0.0L;
959       wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
960                                   &producer_time_der, &consumer_time_der,
961                                   &buffer_size_der);
962       double inputs_time_der_sum =
963           OutputTimeGradientsForInputs(*output_time_gradients);
964       (*output_time_gradients)[long_name()] =
965           consumer_time_der + producer_time_der * inputs_time_der_sum;
966 
967       for (const auto& pair : CollectTunableParametersLocked()) {
968         auto* gradient = gtl::FindOrNull(
969             *gradients, std::make_pair(pair.first, pair.second->name));
970         if (gradient) {
971           *gradient *= (ratio * producer_time_der);
972         }
973       }
974 
975       // Add derivative w.r.t. own parameter if it's tunable.
976       if (parallelism_parameter && (*parallelism_parameter)->state->tunable) {
977         (*gradients)[std::make_pair(long_name(),
978                                     (*parallelism_parameter)->name)] =
979             buffer_size_der / ratio -
980             (1.0L + consumer_time_der +
981              producer_time_der * inputs_time_der_sum) *
982                 self_processing_time / Square(parallelism);
983       } else if (buffer_size_parameter &&
984                  (*buffer_size_parameter)->state->tunable) {
985         (*gradients)[std::make_pair(
986             long_name(), (*buffer_size_parameter)->name)] = buffer_size_der;
987       }
988     } else {
989       wait_time = ComputeWaitTime(producer_time, consumer_time, buffer_size,
990                                   /*producer_time_derivative=*/nullptr,
991                                   /*consumer_time_derivative=*/nullptr,
992                                   /*buffer_size_derivative=*/nullptr);
993     }
994     output_time = self_processing_time / parallelism + wait_time;
995     (*output_times)[long_name()] = output_time;
996   }
997 
998   // The processing time is the sum of the self processing time and the product
999   // of `Ratio()` and the sum of processing times of inputs.
TotalProcessingTimeLocked(NodeValues * processing_times,NodeValues * total_processing_times)1000   void TotalProcessingTimeLocked(NodeValues* processing_times,
1001                                  NodeValues* total_processing_times) override
1002       TF_SHARED_LOCKS_REQUIRED(mu_) {
1003     double self_processing_time = SelfProcessingTimeLocked();
1004     if (processing_times) {
1005       (*processing_times)[long_name()] = self_processing_time;
1006     }
1007     auto ratio = RatioLocked();
1008     if (ratio == 0) {
1009       (*total_processing_times)[long_name()] = self_processing_time;
1010       return;
1011     }
1012     double inputs_processing_time =
1013         ratio * TotalProcessingTimeForInputs(*total_processing_times);
1014     (*total_processing_times)[long_name()] =
1015         self_processing_time + inputs_processing_time;
1016   }
1017 
MaximumBufferedBytes() const1018   double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
1019     double result = 0;
1020     auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
1021     if (!parameter) {
1022       parameter = gtl::FindOrNull(parameters_, kParallelism);
1023     }
1024 
1025     if (parameter) {
1026       if (memory_ratio_ == 0) {
1027         result += (*parameter)->value * AverageBufferedElementSize();
1028       } else {
1029         // The estimation is currently not accurate for MapAndBatchDataset for
1030         // the maximum buffer size does not match `num_parallel_calls`
1031         // parameter.
1032         result +=
1033             (*parameter)->value * AverageBufferedElementSize() / memory_ratio_;
1034       }
1035     }
1036     return result;
1037   }
1038 
1039  private:
1040   // Identifies how many input elements need to be created to construct an
1041   // element for the dataset.
1042   //
1043   // Currently the value is 1 for PrefetchDataset and ParallelMapDataset,
1044   // batch_size for MapAndBatchDataset and ParallelBatchDataset.
1045   const double ratio_;
1046   // For parallelism nodes, identifies how many parallelism calls are introduced
1047   // by one buffered element. The value is defined to correctly estimate RAM
1048   // budget bound with given num_parallel_calls (or buffer_size) combined with
1049   // the estimated average size of buffered elements.
1050   const double memory_ratio_;
1051 };
1052 
1053 class UnknownRatio : public Node {
1054  public:
1055   using Node::Node;
1056 
~UnknownRatio()1057   virtual ~UnknownRatio() {}
1058 
Ratio() const1059   double Ratio() const override {
1060     tf_shared_lock l(mu_);
1061     return RatioLocked();
1062   }
1063 
1064  protected:
RatioLocked() const1065   double RatioLocked() const TF_SHARED_LOCKS_REQUIRED(mu_) {
1066     // TODO(wilsin): Consistent with UnknownRatio, current implementation
1067     // assumes that the number of input elements consumed per output is the same
1068     // across all inputs.
1069     if (num_elements_ == 0 || inputs_.empty() ||
1070         inputs_.front()->num_elements() == 0) {
1071       return 0.0;
1072     }
1073     return static_cast<double>(inputs_.front()->num_elements()) /
1074            static_cast<double>(num_elements_);
1075   }
1076 
Clone(std::shared_ptr<Node> output) const1077   std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
1078       TF_SHARED_LOCKS_REQUIRED(mu_) {
1079     return std::make_shared<UnknownRatio>(Args{id_, name_, std::move(output)});
1080   }
1081 
1082   // The input time is the sum of inherited input time and self processing time,
1083   // divided by the ratio estimate.
InputTimeLocked(NodeValues * input_times) const1084   void InputTimeLocked(NodeValues* input_times) const override
1085       TF_SHARED_LOCKS_REQUIRED(mu_) {
1086     double inherited_input_time;
1087     if (output_) {
1088       inherited_input_time = (*input_times)[output_->long_name()];
1089     } else {
1090       inherited_input_time = (*input_times)[kModelInputTimeKey];
1091     }
1092 
1093     if (num_elements_ == 0 || inputs_.empty() ||
1094         inputs_.front()->num_elements() == 0) {
1095       (*input_times)[long_name()] = inherited_input_time;
1096       return;
1097     }
1098     std::shared_ptr<Node> input = inputs_.front();
1099     double ratio = static_cast<double>(input->num_elements()) /
1100                    static_cast<double>(num_elements_);
1101     double input_time =
1102         (inherited_input_time + SelfProcessingTimeLocked()) / ratio;
1103     (*input_times)[long_name()] = input_time;
1104   }
1105 
1106   // The output time is the sum of the self processing time and the product of
1107   // the ratio estimate and the sum of output times of inputs.
OutputTimeLocked(const NodeValues & input_times,ParameterGradients * gradients,NodeValues * output_times,NodeValues * output_time_gradients) const1108   void OutputTimeLocked(const NodeValues& input_times,
1109                         ParameterGradients* gradients, NodeValues* output_times,
1110                         NodeValues* output_time_gradients) const override
1111       TF_SHARED_LOCKS_REQUIRED(mu_) {
1112     double self_processing_time = SelfProcessingTimeLocked();
1113     if (num_elements_ == 0 || inputs_.empty() ||
1114         inputs_.front()->num_elements() == 0) {
1115       (*output_times)[long_name()] = self_processing_time;
1116       if (gradients) {
1117         for (const auto& pair : CollectTunableParametersLocked()) {
1118           gradients->erase(std::make_pair(pair.first, pair.second->name));
1119         }
1120       }
1121       return;
1122     }
1123     // TODO(jsimsa): The current implementation assumes that the number of input
1124     // elements consumed per output is the same across all inputs.
1125     double ratio = static_cast<double>(inputs_.front()->num_elements()) /
1126                    static_cast<double>(num_elements_);
1127     if (gradients) {
1128       for (const auto& pair : CollectTunableParametersLocked()) {
1129         auto* gradient = gtl::FindOrNull(
1130             *gradients, std::make_pair(pair.first, pair.second->name));
1131         if (gradient) {
1132           *gradient *= ratio;
1133         }
1134       }
1135       (*output_time_gradients)[long_name()] =
1136           OutputTimeGradientsForInputs(*output_time_gradients);
1137     }
1138     double inputs_output_time = ratio * OutputTimeForInputs(*output_times);
1139     (*output_times)[long_name()] = self_processing_time + inputs_output_time;
1140   }
1141 
1142   // The processing time is the sum of the self processing time and the product
1143   // of the ratio estimate and the sum of processing times of inputs.
TotalProcessingTimeLocked(absl::flat_hash_map<string,double> * processing_times,absl::flat_hash_map<string,double> * total_processing_times)1144   void TotalProcessingTimeLocked(
1145       absl::flat_hash_map<string, double>* processing_times,
1146       absl::flat_hash_map<string, double>* total_processing_times) override
1147       TF_SHARED_LOCKS_REQUIRED(mu_) {
1148     double self_processing_time = SelfProcessingTimeLocked();
1149     if (processing_times) {
1150       (*processing_times)[long_name()] = self_processing_time;
1151     }
1152     if (inputs_.empty() || num_elements_ == 0) {
1153       (*total_processing_times)[long_name()] = self_processing_time;
1154       return;
1155     }
1156     // TODO(jsimsa): The current implementation assumes that the number of input
1157     // elements consumed per output is the same across all inputs.
1158     std::shared_ptr<Node> input = inputs_.front();
1159     double ratio = static_cast<double>(input->num_elements()) /
1160                    static_cast<double>(num_elements_);
1161     double inputs_processing_time =
1162         ratio * TotalProcessingTimeForInputs(*total_processing_times);
1163     (*total_processing_times)[long_name()] =
1164         self_processing_time + inputs_processing_time;
1165   }
1166 
ToProto(ModelProto::Node * node_proto) const1167   Status ToProto(ModelProto::Node* node_proto) const {
1168     TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
1169     node_proto->set_node_class(NodeClass::UNKNOWN_RATIO);
1170     return OkStatus();
1171   }
1172 };
1173 
1174 class Unknown : public Node {
1175  public:
1176   using Node::Node;
1177 
~Unknown()1178   virtual ~Unknown() {}
1179 
1180  protected:
Clone(std::shared_ptr<Node> output) const1181   std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
1182       TF_SHARED_LOCKS_REQUIRED(mu_) {
1183     return std::make_shared<Unknown>(Args{id_, name_, std::move(output)});
1184   }
1185 
1186   // The input time is the inherited input time.
InputTimeLocked(NodeValues * input_times) const1187   void InputTimeLocked(NodeValues* input_times) const override
1188       TF_SHARED_LOCKS_REQUIRED(mu_) {
1189     double inherited_input_time;
1190     if (output_) {
1191       inherited_input_time = (*input_times)[output_->long_name()];
1192     } else {
1193       inherited_input_time = (*input_times)[kModelInputTimeKey];
1194     }
1195     (*input_times)[long_name()] = inherited_input_time;
1196   }
1197 
1198   // The output time is the sum of output times of inputs.
OutputTimeLocked(const NodeValues & input_times,ParameterGradients * gradients,NodeValues * output_times,NodeValues * output_time_gradients) const1199   void OutputTimeLocked(const NodeValues& input_times,
1200                         ParameterGradients* gradients, NodeValues* output_times,
1201                         NodeValues* output_time_gradients) const override
1202       TF_SHARED_LOCKS_REQUIRED(mu_) {
1203     (*output_times)[long_name()] = OutputTimeForInputs(*output_times);
1204     if (gradients) {
1205       (*output_time_gradients)[long_name()] =
1206           OutputTimeGradientsForInputs(*output_time_gradients);
1207     }
1208   }
1209 
1210   // The processing time is the sum of processing times of inputs.
TotalProcessingTimeLocked(NodeValues * processing_times,NodeValues * total_processing_times)1211   void TotalProcessingTimeLocked(NodeValues* processing_times,
1212                                  NodeValues* total_processing_times) override
1213       TF_SHARED_LOCKS_REQUIRED(mu_) {
1214     if (processing_times) {
1215       (*processing_times)[long_name()] = SelfProcessingTimeLocked();
1216     }
1217     (*total_processing_times)[long_name()] =
1218         TotalProcessingTimeForInputs(*total_processing_times);
1219   }
1220 
ToProto(ModelProto::Node * node_proto) const1221   Status ToProto(ModelProto::Node* node_proto) const {
1222     TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
1223     node_proto->set_node_class(NodeClass::UNKNOWN);
1224     return OkStatus();
1225   }
1226 };
1227 
1228 class AsyncKnownRatio : public AsyncRatio {
1229  public:
AsyncKnownRatio(Node::Args args,double ratio,double memory_ratio,std::vector<std::shared_ptr<Parameter>> parameters)1230   AsyncKnownRatio(Node::Args args, double ratio, double memory_ratio,
1231                   std::vector<std::shared_ptr<Parameter>> parameters)
1232       : AsyncRatio(args, ratio, memory_ratio, parameters) {}
1233 
~AsyncKnownRatio()1234   virtual ~AsyncKnownRatio() {}
1235 
1236  protected:
Clone(std::shared_ptr<Node> output) const1237   std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
1238       TF_SHARED_LOCKS_REQUIRED(mu_) {
1239     std::vector<std::shared_ptr<Parameter>> parameters;
1240     for (auto& pair : parameters_) {
1241       parameters.push_back(pair.second);
1242     }
1243     return std::make_shared<AsyncKnownRatio>(
1244         Args{id_, name_, std::move(output)}, Ratio(), MemoryRatio(),
1245         parameters);
1246   }
1247 
ToProto(ModelProto::Node * node_proto) const1248   Status ToProto(ModelProto::Node* node_proto) const {
1249     TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
1250     node_proto->set_node_class(NodeClass::ASYNC_KNOWN_RATIO);
1251     node_proto->set_ratio(Ratio());
1252     node_proto->set_memory_ratio(MemoryRatio());
1253     return OkStatus();
1254   }
1255 };
1256 
1257 class AsyncUnknownRatio : public AsyncRatio {
1258  public:
AsyncUnknownRatio(Node::Args args,std::vector<std::shared_ptr<Parameter>> parameters)1259   AsyncUnknownRatio(Node::Args args,
1260                     std::vector<std::shared_ptr<Parameter>> parameters)
1261       : AsyncRatio(args, /*ratio=*/0.0, /*memory_ratio=*/0.0, parameters) {}
1262 
~AsyncUnknownRatio()1263   virtual ~AsyncUnknownRatio() {}
1264 
Ratio() const1265   double Ratio() const override {
1266     tf_shared_lock l(mu_);
1267     return RatioLocked();
1268   }
1269 
1270  protected:
RatioLocked() const1271   double RatioLocked() const TF_SHARED_LOCKS_REQUIRED(mu_) override {
1272     // TODO(wilsin): Consistent with UnknownRatio, current implementation
1273     // assumes that the number of input elements consumed per output is the same
1274     // across all inputs.
1275     if (num_elements_ == 0 || inputs_.empty() ||
1276         inputs_.front()->num_elements() == 0) {
1277       return 0.0;
1278     }
1279     return static_cast<double>(inputs_.front()->num_elements()) /
1280            static_cast<double>(num_elements_);
1281   }
1282 
Clone(std::shared_ptr<Node> output) const1283   std::shared_ptr<Node> Clone(std::shared_ptr<Node> output) const override
1284       TF_SHARED_LOCKS_REQUIRED(mu_) {
1285     std::vector<std::shared_ptr<Parameter>> parameters;
1286     for (auto& pair : parameters_) {
1287       parameters.push_back(pair.second);
1288     }
1289     return std::make_shared<AsyncUnknownRatio>(
1290         Args{id_, name_, std::move(output)}, parameters);
1291   }
1292 
ToProto(ModelProto::Node * node_proto) const1293   Status ToProto(ModelProto::Node* node_proto) const {
1294     TF_RETURN_IF_ERROR(Node::ToProto(node_proto));
1295     node_proto->set_node_class(NodeClass::ASYNC_UNKNOWN_RATIO);
1296     return OkStatus();
1297   }
1298 };
1299 
1300 }  // namespace
1301 
1302 thread_local int64_t Node::work_start_;
1303 
MakeParameter(const string & name,std::shared_ptr<SharedState> state,double min,double max)1304 std::shared_ptr<Parameter> MakeParameter(const string& name,
1305                                          std::shared_ptr<SharedState> state,
1306                                          double min, double max) {
1307   return std::make_shared<Parameter>(name, state, min, max);
1308 }
1309 
MakeNonTunableParameter(const string & name,double value)1310 std::shared_ptr<Parameter> MakeNonTunableParameter(const string& name,
1311                                                    double value) {
1312   return std::make_shared<Parameter>(name, nullptr, /*min=*/value,
1313                                      /*max=*/value);
1314 }
1315 
MakeInterleaveManyNode(Node::Args args,std::vector<std::shared_ptr<Parameter>> parameters)1316 std::shared_ptr<Node> MakeInterleaveManyNode(
1317     Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters) {
1318   DCHECK(absl::c_any_of(parameters,
1319                         [](const std::shared_ptr<Parameter>& parameter) {
1320                           return parameter->name == kCycleLength;
1321                         }));
1322   return std::make_shared<InterleaveMany>(std::move(args),
1323                                           std::move(parameters));
1324 }
1325 
MakeAsyncInterleaveManyNode(Node::Args args,std::vector<std::shared_ptr<Parameter>> parameters)1326 std::shared_ptr<Node> MakeAsyncInterleaveManyNode(
1327     Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters) {
1328   DCHECK(absl::c_any_of(parameters,
1329                         [](const std::shared_ptr<Parameter>& parameter) {
1330                           return parameter->name == kCycleLength;
1331                         }));
1332   return std::make_shared<AsyncInterleaveMany>(std::move(args),
1333                                                std::move(parameters));
1334 }
1335 
MakeKnownRatioNode(Node::Args args,double ratio)1336 std::shared_ptr<Node> MakeKnownRatioNode(Node::Args args, double ratio) {
1337   return std::make_shared<KnownRatio>(std::move(args), ratio);
1338 }
1339 
MakeAsyncKnownRatioNode(Node::Args args,double ratio,double memory_ratio,std::vector<std::shared_ptr<Parameter>> parameters)1340 std::shared_ptr<Node> MakeAsyncKnownRatioNode(
1341     Node::Args args, double ratio, double memory_ratio,
1342     std::vector<std::shared_ptr<Parameter>> parameters) {
1343   return std::make_shared<AsyncKnownRatio>(std::move(args), ratio, memory_ratio,
1344                                            std::move(parameters));
1345 }
1346 
MakeAsyncKnownRatioNode(Node::Args args,double ratio,std::vector<std::shared_ptr<Parameter>> parameters)1347 std::shared_ptr<Node> MakeAsyncKnownRatioNode(
1348     Node::Args args, double ratio,
1349     std::vector<std::shared_ptr<Parameter>> parameters) {
1350   return MakeAsyncKnownRatioNode(std::move(args), /*ratio=*/ratio,
1351                                  /*memory_ratio=*/ratio, std::move(parameters));
1352 }
1353 
MakeSourceNode(Node::Args args)1354 std::shared_ptr<Node> MakeSourceNode(Node::Args args) {
1355   return MakeKnownRatioNode(std::move(args), 0);
1356 }
1357 
MakeUnknownRatioNode(Node::Args args)1358 std::shared_ptr<Node> MakeUnknownRatioNode(Node::Args args) {
1359   return std::make_shared<UnknownRatio>(std::move(args));
1360 }
1361 
MakeAsyncUnknownRatioNode(Node::Args args,std::vector<std::shared_ptr<Parameter>> parameters)1362 std::shared_ptr<Node> MakeAsyncUnknownRatioNode(
1363     Node::Args args, std::vector<std::shared_ptr<Parameter>> parameters) {
1364   return std::make_shared<AsyncUnknownRatio>(std::move(args),
1365                                              std::move(parameters));
1366 }
1367 
MakeUnknownNode(Node::Args args)1368 std::shared_ptr<Node> MakeUnknownNode(Node::Args args) {
1369   return std::make_shared<Unknown>(std::move(args));
1370 }
1371 
ComputeWaitTime(const double & producer_time,const double & consumer_time,const double & buffer_size,double * producer_time_derivative,double * consumer_time_derivative,double * buffer_size_derivative)1372 double Node::ComputeWaitTime(const double& producer_time,
1373                              const double& consumer_time,
1374                              const double& buffer_size,
1375                              double* producer_time_derivative,
1376                              double* consumer_time_derivative,
1377                              double* buffer_size_derivative) {
1378   // If we set x=`consumer_time`, y=`producer_time`, n=`buffer_size`,
1379   // p=`p_buffer_empty`, T=`wait_time`, then we have:
1380   // if y = 0, then p = 0;
1381   // elif x = 0, then p = 1;
1382   // elif x = y, then p = 1 / (n+1);
1383   // else p = [1 - x/y] / [1 - power(x/y, n+1)].
1384   //
1385   // We also have T = p * y, and derivatives of T w.r.t. x, y, n are computed:
1386   // dT/dx = dp/dx * y,
1387   // dT/dy = p + dp/dy * y,
1388   // dT/dn = dp/dn * y.
1389   // Then the remaining work is to compute dp/dx, dp/dy, dp/dn by considering
1390   // different cases and substitute the values into above formulas.
1391 
1392   // Case 1: if producer is infinitely fast. The buffer will always be full.
1393   // Wait time will always be 0.
1394   if (producer_time == 0) {
1395     if (producer_time_derivative) {
1396       // Note a common error is `*producer_time_derivative = 0` since p=0 on the
1397       // line y=0 doesn't imply dp/dy = 0 there. Actually to compute dp/dy at
1398       // (x,0), we need to consider lim_{dy->0+} [p(x,dy)-p(x,0)] / dy, where
1399       // p(x,0)=0 and p(x,dy) = [1 - x/dy] / [1 - power(x/dy, n+1)].
1400       if (buffer_size == 0 || consumer_time == 0) {
1401         *producer_time_derivative = 1.0L;
1402       } else {
1403         *producer_time_derivative = 0.0L;
1404       }
1405     }
1406     if (consumer_time_derivative) {
1407       *consumer_time_derivative = 0.0L;
1408     }
1409     if (buffer_size_derivative) {
1410       *buffer_size_derivative = 0.0L;
1411     }
1412     return 0.0L;
1413   }
1414 
1415   // Case 2: if consumer is infinitely fast. Wait time is always the time to
1416   // produce an output.
1417   if (consumer_time == 0) {
1418     if (producer_time_derivative) {
1419       *producer_time_derivative = 1.0L;
1420     }
1421     if (consumer_time_derivative) {
1422       // Note a common error is `*consumer_time_derivative = 0` since p=1 on the
1423       // line x=0 doesn't imply dp/dx = 0 there. Actually to compute dp/dx at
1424       // (0,y), we need to consider lim_{dx->0+} [p(dx,y)-p(0,y)] / dx, where
1425       // p(0,y)=1, p(dx,y) = [1 - dx/y] / [1 - power(dx/y, n+1)] if y!=0.
1426       if (buffer_size == 0) {
1427         *consumer_time_derivative = 0.0L;
1428       } else {
1429         *consumer_time_derivative = -1.0L;
1430       }
1431     }
1432     if (buffer_size_derivative) {
1433       *buffer_size_derivative = 0.0L;
1434     }
1435     return producer_time;
1436   }
1437 
1438   // Case 3: the consumer and the producer are equally fast. Expected wait time
1439   // decreases linearly with the size of the buffer.
1440   if (consumer_time == producer_time) {
1441     const double p_buffer_empty = 1.0L / (buffer_size + 1.0L);
1442     const double p_buffer_empty_der =
1443         -buffer_size / (2.0L * buffer_size + 2.0L);
1444     if (producer_time_derivative) {
1445       // Note a common error is `*producer_time_derivative = p_buffer_empty`
1446       // since p=1/(n+1) on the line x=y doesn't imply dp/dy = 0 there. Actually
1447       // to compute dp/dy at (y,y), we need to consider lim_{dy->0}
1448       // [p(y,y+dy)-p(y,y)] / dy, where p(y,y)=1/(n+1), p(y,y+dy) = [1 -
1449       // y/(y+dy)] / [1 - power(y/(y+dy), n+1)].
1450       *producer_time_derivative = p_buffer_empty - p_buffer_empty_der;
1451     }
1452     if (consumer_time_derivative) {
1453       // Note a common error is `*consumer_time_derivative = 0` since p=1/(n+1)
1454       // on the line x=y doesn't imply dp/dx = 0 there. Actually to compute
1455       // dp/dx at (x,x), we need to consider lim_{dx->0} [p(x+dx,x)-p(x,x)] /
1456       // dx, where p(x,x)=1/(n+1), p(x+dx,x) = [1 - (x+dx)/x] / [1 -
1457       // power((x+dx)/x, n+1)].
1458       *consumer_time_derivative = p_buffer_empty_der;
1459     }
1460     if (buffer_size_derivative) {
1461       *buffer_size_derivative = -producer_time / Square(buffer_size + 1.0L);
1462     }
1463     return p_buffer_empty * producer_time;
1464   }
1465 
1466   // Case 4: the consumer is slower than the producer and neither is infinitely
1467   // fast. Case 4 and Case 5 actually follow same formula. Separate them for
1468   // numerical computation reasons.
1469   if (consumer_time > producer_time) {
1470     const double ratio = producer_time / consumer_time;
1471     const double ratio_pow = std::pow(ratio, buffer_size);
1472     const double p_buffer_empty =
1473         ratio_pow * (1.0L - ratio) / (1.0L - ratio * ratio_pow);
1474     const double p_buffer_empty_der =
1475         (buffer_size - (buffer_size + 1.0L) * ratio + ratio_pow * ratio) *
1476         ratio_pow / ratio / Square(1.0L - ratio_pow * ratio);
1477     if (producer_time_derivative) {
1478       *producer_time_derivative = p_buffer_empty + p_buffer_empty_der * ratio;
1479     }
1480     if (consumer_time_derivative) {
1481       *consumer_time_derivative = -p_buffer_empty_der * Square(ratio);
1482     }
1483     if (buffer_size_derivative) {
1484       *buffer_size_derivative = p_buffer_empty / (1.0L - ratio_pow * ratio) *
1485                                 std::log(ratio) * producer_time;
1486     }
1487     return p_buffer_empty * producer_time;
1488   }
1489 
1490   // Case 5: the producer is slower than the consumer and neither is infinitely
1491   // fast.
1492   const double ratio = consumer_time / producer_time;
1493   const double ratio_pow = std::pow(ratio, buffer_size);
1494   const double p_buffer_empty = (1.0L - ratio) / (1.0L - ratio_pow * ratio);
1495   const double p_buffer_empty_der =
1496       ((buffer_size + 1.0L - buffer_size * ratio) * ratio_pow - 1.0L) /
1497       Square(1.0L - ratio_pow * ratio);
1498   if (producer_time_derivative) {
1499     *producer_time_derivative = p_buffer_empty - p_buffer_empty_der * ratio;
1500   }
1501   if (consumer_time_derivative) {
1502     *consumer_time_derivative = p_buffer_empty_der;
1503   }
1504   if (buffer_size_derivative) {
1505     *buffer_size_derivative = p_buffer_empty / (1.0L - ratio_pow * ratio) *
1506                               ratio_pow * ratio * std::log(ratio) *
1507                               producer_time;
1508   }
1509   return p_buffer_empty * producer_time;
1510 }
1511 
CollectTunableParametersLocked() const1512 Node::ModelParameters Node::CollectTunableParametersLocked() const {
1513   Node::ModelParameters parameters;
1514   // Collect tunable parameters from the leaves of the nodes tree to the root.
1515   for (const auto& node :
1516        CollectNodesLocked(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
1517     tf_shared_lock l(node->mu_);
1518     node->CollectTunableParametersHelper(&parameters);
1519   }
1520   CollectTunableParametersHelper(&parameters);
1521   return parameters;
1522 }
1523 
CollectTunableParameters() const1524 Node::ModelParameters Node::CollectTunableParameters() const {
1525   tf_shared_lock l(mu_);
1526   return CollectTunableParametersLocked();
1527 }
1528 
CollectNodeTunableParameters() const1529 Node::ModelParameters Node::CollectNodeTunableParameters() const {
1530   tf_shared_lock l(mu_);
1531   Node::ModelParameters parameters;
1532   CollectTunableParametersHelper(&parameters);
1533   return parameters;
1534 }
1535 
DebugString() const1536 string Node::DebugString() const {
1537   absl::flat_hash_map<string, string> debug_strings;
1538   tf_shared_lock l(mu_);
1539   // Build up the debug string from the leaves of the nodes tree to the root.
1540   for (const auto& node :
1541        CollectNodesLocked(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1542     tf_shared_lock l(node->mu_);
1543     node->DebugStringHelper(&debug_strings);
1544   }
1545   DebugStringHelper(&debug_strings);
1546 
1547   return debug_strings[long_name()];
1548 }
1549 
FlushMetrics()1550 void Node::FlushMetrics() {
1551   if (!record_metrics_) {
1552     return;
1553   }
1554   metrics_.record_bytes_consumed(bytes_consumed_);
1555   metrics_.record_bytes_produced(bytes_produced_);
1556   metrics_.record_num_elements(num_elements_);
1557 }
1558 
OutputTime(Node::NodeValues * input_times,Node::ParameterGradients * gradients) const1559 double Node::OutputTime(Node::NodeValues* input_times,
1560                         Node::ParameterGradients* gradients) const {
1561   // To store the output time gradient w.r.t. input time (if `gradients` is not
1562   // `nullptr`) and the output time for each node.
1563   Node::NodeValues output_time_gradients, output_times;
1564   tf_shared_lock l(mu_);
1565   auto nodes = CollectNodesLocked(TraversalOrder::BFS, IsAutotuneNode);
1566 
1567   // Computes and stores input time for each node from the root to leaves of the
1568   // nodes tree.
1569   InputTimeLocked(input_times);
1570   for (const auto& node : nodes) {
1571     tf_shared_lock l(node->mu_);
1572     node->InputTimeLocked(input_times);
1573   }
1574 
1575   std::reverse(nodes.begin(), nodes.end());
1576   // Computes and stores the output time and output time gradient w.r.t. input
1577   // time (if `gradients` is not `nullptr`) for each node from leaves of the
1578   // nodes tree to the root.
1579   for (const auto& node : nodes) {
1580     tf_shared_lock l(node->mu_);
1581     node->OutputTimeLocked(*input_times, gradients, &output_times,
1582                            &output_time_gradients);
1583   }
1584   OutputTimeLocked(*input_times, gradients, &output_times,
1585                    &output_time_gradients);
1586 
1587   return output_times[long_name()];
1588 }
1589 
ComputeSelfTime() const1590 double Node::ComputeSelfTime() const {
1591   if (num_elements_ == 0) {
1592     return 0;
1593   }
1594   tf_shared_lock l(mu_);
1595   return processing_time_ema_;
1596 }
1597 
Snapshot() const1598 std::shared_ptr<Node> Node::Snapshot() const {
1599   NodePairList node_pairs;
1600   auto result = SnapshotHelper(nullptr, &node_pairs);
1601 
1602   while (!node_pairs.empty()) {
1603     auto node_pair = node_pairs.front();
1604     node_pairs.pop_front();
1605     std::shared_ptr<Node> current = node_pair.first,
1606                           cloned_output = node_pair.second;
1607     cloned_output->add_input(
1608         current->SnapshotHelper(cloned_output, &node_pairs));
1609   }
1610   return result;
1611 }
1612 
SelfProcessingTime() const1613 double Node::SelfProcessingTime() const {
1614   tf_shared_lock l(mu_);
1615   return SelfProcessingTimeLocked();
1616 }
1617 
TotalBufferedBytes() const1618 double Node::TotalBufferedBytes() const {
1619   Node::NodeValues total_bytes;
1620   tf_shared_lock l(mu_);
1621   // Compute total buffered bytes from the leaves of the nodes tree to the root.
1622   for (const auto& node :
1623        CollectNodesLocked(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1624     tf_shared_lock l(node->mu_);
1625     node->TotalBufferedBytesHelper(&total_bytes);
1626   }
1627   TotalBufferedBytesHelper(&total_bytes);
1628 
1629   return total_bytes[long_name()];
1630 }
1631 
TotalMaximumBufferedBytes() const1632 double Node::TotalMaximumBufferedBytes() const {
1633   Node::NodeValues total_bytes;
1634   tf_shared_lock l(mu_);
1635   // Compute total maximum buffered bytes from the leaves of the nodes tree to
1636   // the root.
1637   for (const auto& node :
1638        CollectNodesLocked(TraversalOrder::REVERSE_BFS, IsAnyNode)) {
1639     tf_shared_lock l(node->mu_);
1640     node->TotalMaximumBufferedBytesHelper(&total_bytes);
1641   }
1642   TotalMaximumBufferedBytesHelper(&total_bytes);
1643 
1644   return total_bytes[long_name()];
1645 }
1646 
TotalProcessingTime(Node::NodeValues * processing_times)1647 double Node::TotalProcessingTime(Node::NodeValues* processing_times) {
1648   // Create a hash map to store the per-element CPU time spent in the subtree
1649   // rooted in each node.
1650   Node::NodeValues total_processing_times;
1651   tf_shared_lock l(mu_);
1652 
1653   // Computes per-element CPU time spent in the subtree rooted in the node from
1654   // the leaves of the nodes tree to the root.
1655   for (const auto& node :
1656        CollectNodesLocked(TraversalOrder::REVERSE_BFS, IsAutotuneNode)) {
1657     tf_shared_lock l(node->mu_);
1658     node->TotalProcessingTimeLocked(processing_times, &total_processing_times);
1659   }
1660   TotalProcessingTimeLocked(processing_times, &total_processing_times);
1661 
1662   return total_processing_times[long_name()];
1663 }
1664 
AverageBufferedElementSize() const1665 double Node::AverageBufferedElementSize() const {
1666   DCHECK_GE(num_elements_, 0);
1667   DCHECK_GE(buffered_elements_, 0);
1668   if (num_elements_ <= 0) {
1669     if (buffered_elements_ <= 0) {
1670       // If there are no produced elements or buffered elements recorded, return
1671       // 0.
1672       return 0;
1673     }
1674     // If there are no produced elements but some buffered elements, return the
1675     // average size of all buffered elements.
1676     return static_cast<double>(buffered_bytes_) /
1677            static_cast<double>(buffered_elements_);
1678   }
1679 
1680   if (buffered_elements_ <= 0) {
1681     // If there are no buffered elements but some produced elements, return the
1682     // average size of all produced elements.
1683     return static_cast<double>(bytes_produced_) /
1684            static_cast<double>(num_elements_);
1685   }
1686 
1687   // Otherwise, return the mean value of average size of all produced elements
1688   // and average size of all buffered elements.
1689   return (static_cast<double>(bytes_produced_) /
1690               static_cast<double>(num_elements_) +
1691           static_cast<double>(buffered_bytes_) /
1692               static_cast<double>(buffered_elements_)) /
1693          2.0;
1694 }
1695 
OutputTimeForInputs(const Node::NodeValues & output_times) const1696 double Node::OutputTimeForInputs(const Node::NodeValues& output_times) const {
1697   double sum = 0;
1698   for (auto& input : inputs_) {
1699     // Inputs for which autotuning is disabled are excluded.
1700     if (input->autotune()) {
1701       sum += output_times.at(input->long_name());
1702     }
1703   }
1704   return sum;
1705 }
1706 
OutputTimeGradientsForInputs(const Node::NodeValues & output_time_gradients) const1707 double Node::OutputTimeGradientsForInputs(
1708     const Node::NodeValues& output_time_gradients) const {
1709   double sum = 0;
1710   for (auto& input : inputs_) {
1711     // Inputs for which autotuning is disabled are excluded.
1712     if (input->autotune()) {
1713       sum +=
1714           gtl::FindWithDefault(output_time_gradients, input->long_name(), 0.0L);
1715     }
1716   }
1717   return sum;
1718 }
1719 
TotalProcessingTimeForInputs(const Node::NodeValues & total_processing_times)1720 double Node::TotalProcessingTimeForInputs(
1721     const Node::NodeValues& total_processing_times) {
1722   // If the number of elements produced by an input is smaller than this
1723   // constant, then its processing time is estimated using a weighted average of
1724   // the empirical processing time and processing time history.
1725   constexpr int kNumElementsThreshold = 30;
1726 
1727   // Identifies the minimum number of input processing times to collect before
1728   // the processing time history is used as a prior.
1729   constexpr int kCountThreshold = 30;
1730 
1731   double sum = 0;
1732   for (auto& input : inputs_) {
1733     // Inputs for which autotuning is disabled are excluded.
1734     if (input->autotune()) {
1735       double input_processing_time =
1736           total_processing_times.at(input->long_name());
1737       int64_t num_elements = input->num_elements();
1738       if (num_elements < kNumElementsThreshold) {
1739         if (input_processing_time_count_ < kCountThreshold) {
1740           sum += input_processing_time;
1741         } else {
1742           // The fewer elements the input has produced so far, the more weight
1743           // is assigned to the prior to reduce volatility.
1744           double prior_weight = 1.0L / static_cast<double>(2 << num_elements);
1745           double prior =
1746               input_processing_time_sum_ / input_processing_time_count_;
1747           sum += (1.0L - prior_weight) * input_processing_time +
1748                  prior_weight * prior;
1749         }
1750       } else {
1751         sum += input_processing_time;
1752         input_processing_time_count_++;
1753         input_processing_time_sum_ += input_processing_time;
1754       }
1755     }
1756   }
1757   return sum;
1758 }
1759 
SelfProcessingTimeLocked() const1760 double Node::SelfProcessingTimeLocked() const {
1761   if (num_elements_ == 0) {
1762     return 0;
1763   }
1764   return static_cast<double>(processing_time_) /
1765          static_cast<double>(num_elements_);
1766 }
1767 
CollectNodes(TraversalOrder order,bool collect_node (const std::shared_ptr<Node>)) const1768 Node::NodeVector Node::CollectNodes(
1769     TraversalOrder order,
1770     bool collect_node(const std::shared_ptr<Node>)) const {
1771   tf_shared_lock l(mu_);
1772   return CollectNodesLocked(order, collect_node);
1773 }
1774 
TryDownsizeBuffer()1775 bool Node::TryDownsizeBuffer() {
1776   if (!IsAsync()) {
1777     return false;
1778   }
1779   Node::ModelParameters tunable_parameters;
1780   {
1781     tf_shared_lock l(mu_);
1782     if (buffered_elements_low_ > buffered_elements_high_) {
1783       // No element is stored in the buffer yet. Do nothing.
1784       return false;
1785     }
1786     CollectTunableParametersHelper(&tunable_parameters);
1787   }
1788   Node::ModelParameters buffer_size_parameters;
1789   for (auto& parameter : tunable_parameters) {
1790     if (parameter.second->name != kBufferSize) {
1791       continue;
1792     }
1793     buffer_size_parameters.push_back(std::move(parameter));
1794   }
1795   bool downsized = false;
1796   // Sync buffer state values to parameter values
1797   for (auto& [node_name, parameter] : buffer_size_parameters) {
1798     tf_shared_lock l(*parameter->state->mu);
1799     parameter->value = parameter->state->value;
1800   }
1801   {
1802     // Downsize buffers
1803     tf_shared_lock l(mu_);
1804     for (auto& [node_name, parameter] : buffer_size_parameters) {
1805       if (buffered_elements_low_ > 0 &&
1806           (buffered_elements_high_ - buffered_elements_low_ + 1) <
1807               parameter->value) {
1808         double old_value = parameter->value;
1809         // By default, we double buffer sizes if there is enough RAM in
1810         // upsize. We cap the downsize by 1/4 of the current size to avoid
1811         // undoing the previous upsize.
1812         parameter->value =
1813             std::max(buffered_elements_high_ - buffered_elements_low_ + 1,
1814                      static_cast<int64_t>(old_value * 0.75));
1815         if (old_value != parameter->value) {
1816           VLOG(2) << "Downsize buffer " << long_name()
1817                   << "::" << parameter->name << " from " << old_value << " to "
1818                   << parameter->value;
1819           downsized = true;
1820         }
1821       }
1822     }
1823   }
1824   // Since SharedState locks are the same as the Ops iterator locks, locking of
1825   // the SharedState locks should be minimized in the optimization thread.
1826   if (downsized) {
1827     UpdateStateValues(&buffer_size_parameters);
1828   }
1829   return downsized;
1830 }
1831 
CollectBufferParametersToUpsize(absl::flat_hash_map<Node *,Parameter * > & node_parameters)1832 void Node::CollectBufferParametersToUpsize(
1833     absl::flat_hash_map<Node*, Parameter*>& node_parameters) {
1834   {
1835     tf_shared_lock l(mu_);
1836     for (auto& [node_name, parameter] : parameters_) {
1837       if ((parameter->name != kBufferSize) ||
1838           (parameter->state == nullptr || !parameter->state->tunable)) {
1839         continue;
1840       }
1841       if (buffered_elements_low_ <= 0 &&
1842           buffered_elements_high_ >= parameter->value) {
1843         parameter->value = parameter->state->value;
1844         node_parameters[this] = parameter.get();
1845       }
1846     }
1847   }
1848   for (auto& [node, parameter] : node_parameters) {
1849     tf_shared_lock l(*parameter->state->mu);
1850     parameter->value = parameter->state->value;
1851   }
1852 }
1853 
CollectNodesLocked(TraversalOrder order,bool collect_node (const std::shared_ptr<Node>)) const1854 Node::NodeVector Node::CollectNodesLocked(
1855     TraversalOrder order, bool collect_node(const std::shared_ptr<Node>)) const
1856     TF_SHARED_LOCKS_REQUIRED(mu_) {
1857   NodeVector node_vector;
1858   std::list<std::shared_ptr<Node>> temp_list;
1859 
1860   for (auto& input : inputs_) {
1861     if (collect_node(input)) {
1862       node_vector.push_back(input);
1863       temp_list.push_back(input);
1864     }
1865   }
1866 
1867   while (!temp_list.empty()) {
1868     auto cur_node = temp_list.front();
1869     temp_list.pop_front();
1870     tf_shared_lock l(cur_node->mu_);
1871     for (auto& input : cur_node->inputs_) {
1872       if (collect_node(input)) {
1873         node_vector.push_back(input);
1874         temp_list.push_back(input);
1875       }
1876     }
1877   }
1878 
1879   if (order == TraversalOrder::REVERSE_BFS) {
1880     std::reverse(node_vector.begin(), node_vector.end());
1881   }
1882   return node_vector;
1883 }
1884 
CollectTunableParametersHelper(Node::ModelParameters * parameters) const1885 void Node::CollectTunableParametersHelper(
1886     Node::ModelParameters* parameters) const TF_SHARED_LOCKS_REQUIRED(mu_) {
1887   // If autotune is turned off or there are no elements recorded, we don't
1888   // collect the parameters on the node.
1889   if (!autotune_ || num_elements_ <= 0) {
1890     return;
1891   }
1892   for (auto& pair : parameters_) {
1893     if (pair.second->state != nullptr && pair.second->state->tunable) {
1894       parameters->push_back(std::make_pair(long_name(), pair.second));
1895     }
1896   }
1897 }
1898 
DebugStringHelper(absl::flat_hash_map<string,string> * debug_strings) const1899 void Node::DebugStringHelper(absl::flat_hash_map<string, string>* debug_strings)
1900     const TF_SHARED_LOCKS_REQUIRED(mu_) {
1901   string result;
1902   strings::StrAppend(&result, long_name(), ":\n");
1903   strings::StrAppend(&result, "  autotune=", autotune_.load(), "\n");
1904   strings::StrAppend(&result, "  buffered_bytes=", buffered_bytes_.load(),
1905                      "\n");
1906   strings::StrAppend(&result, "  buffered_elements=", buffered_elements_.load(),
1907                      "\n");
1908   strings::StrAppend(&result, "  bytes_consumed=", bytes_consumed_.load(),
1909                      "\n");
1910   strings::StrAppend(&result, "  bytes_produced=", bytes_produced_.load(),
1911                      "\n");
1912   strings::StrAppend(&result, "  processing_time=", processing_time_.load(),
1913                      "\n");
1914   strings::StrAppend(&result, "  num_elements=", num_elements_.load(), "\n");
1915   string inputs;
1916   for (auto& input : inputs_) {
1917     strings::StrAppend(&inputs, input->long_name(), ",");
1918   }
1919   strings::StrAppend(&result, "  inputs={", inputs, "}\n");
1920   for (auto& input : inputs_) {
1921     strings::StrAppend(&result, debug_strings->at(input->long_name()));
1922   }
1923   debug_strings->insert(std::make_pair(long_name(), result));
1924 }
1925 
SnapshotHelper(std::shared_ptr<Node> cloned_output,Node::NodePairList * node_pairs) const1926 std::shared_ptr<Node> Node::SnapshotHelper(
1927     std::shared_ptr<Node> cloned_output, Node::NodePairList* node_pairs) const {
1928   tf_shared_lock l(mu_);
1929 
1930   // Clone current node(`this`), also set clone of its output node
1931   // (`cloned_output`) to be the output node of the cloned node
1932   // (`cloned_current`).
1933   std::shared_ptr<Node> cloned_current = Clone(cloned_output);
1934   {
1935     cloned_current->autotune_.store(autotune_);
1936     cloned_current->buffered_bytes_.store(buffered_bytes_);
1937     cloned_current->buffered_elements_.store(buffered_elements_);
1938     cloned_current->buffered_elements_low_.store(buffered_elements_low_);
1939     cloned_current->buffered_elements_high_.store(buffered_elements_high_);
1940     cloned_current->bytes_consumed_.store(bytes_consumed_);
1941     cloned_current->bytes_produced_.store(bytes_produced_);
1942     cloned_current->num_elements_.store(num_elements_);
1943     cloned_current->record_metrics_.store(false);
1944     cloned_current->processing_time_.store(processing_time_);
1945     {
1946       mutex_lock l2(cloned_current->mu_);
1947       cloned_current->parameters_ = parameters_;
1948       cloned_current->previous_processing_time_ = previous_processing_time_;
1949       cloned_current->processing_time_ema_ = processing_time_ema_;
1950     }
1951   }
1952 
1953   for (auto& input : inputs_) {
1954     node_pairs->push_back(std::make_pair(input, cloned_current));
1955   }
1956   return cloned_current;
1957 }
1958 
TotalBufferedBytesHelper(Node::NodeValues * total_bytes) const1959 void Node::TotalBufferedBytesHelper(Node::NodeValues* total_bytes) const
1960     TF_SHARED_LOCKS_REQUIRED(mu_) {
1961   if (!autotune_) {
1962     total_bytes->insert(std::make_pair(long_name(), 0));
1963     return;
1964   }
1965 
1966   double result = 0;
1967   auto* parameter = gtl::FindOrNull(parameters_, kBufferSize);
1968   if (!parameter) {
1969     parameter = gtl::FindOrNull(parameters_, kParallelism);
1970   }
1971   if (parameter) {
1972     result = buffered_bytes_;
1973   }
1974   for (auto& input : inputs_) {
1975     result += total_bytes->at(input->long_name());
1976   }
1977   total_bytes->insert(std::make_pair(long_name(), result));
1978 }
1979 
TotalMaximumBufferedBytesHelper(Node::NodeValues * total_bytes) const1980 void Node::TotalMaximumBufferedBytesHelper(Node::NodeValues* total_bytes) const
1981     TF_SHARED_LOCKS_REQUIRED(mu_) {
1982   if (!autotune_) {
1983     total_bytes->insert(std::make_pair(long_name(), 0));
1984     return;
1985   }
1986 
1987   double result = MaximumBufferedBytes();
1988   for (auto& input : inputs_) {
1989     result += total_bytes->at(input->long_name());
1990   }
1991   total_bytes->insert(std::make_pair(long_name(), result));
1992 }
1993 
MaximumBufferedBytes() const1994 double Node::MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_) {
1995   return 0;
1996 }
1997 
ToProto(ModelProto::Node * node_proto) const1998 Status Node::ToProto(ModelProto::Node* node_proto) const {
1999   tf_shared_lock l(mu_);
2000   node_proto->set_id(id_);
2001   node_proto->set_name(name_);
2002   node_proto->set_autotune(autotune_);
2003   node_proto->set_buffered_bytes(buffered_bytes_);
2004   node_proto->set_buffered_elements(buffered_elements_);
2005   node_proto->set_bytes_consumed(bytes_consumed_);
2006   node_proto->set_bytes_produced(bytes_produced_);
2007   node_proto->set_num_elements(num_elements_);
2008   node_proto->set_processing_time(processing_time_);
2009   node_proto->set_record_metrics(record_metrics_);
2010 
2011   // Produce protos for all parameters.
2012   for (auto const& parameter : parameters_) {
2013     ModelProto::Node::Parameter* parameter_proto = node_proto->add_parameters();
2014     parameter_proto->set_name(parameter.first);
2015     parameter_proto->set_value(parameter.second->value);
2016     parameter_proto->set_min(parameter.second->min);
2017     parameter_proto->set_max(parameter.second->max);
2018     if (parameter.second->state != nullptr) {
2019       parameter_proto->set_state_value(parameter.second->state->value);
2020       parameter_proto->set_tunable(parameter.second->state->tunable);
2021     }
2022   }
2023 
2024   // Add input node ids.
2025   for (auto const& input : inputs_) {
2026     node_proto->add_inputs(input->id());
2027   }
2028   return OkStatus();
2029 }
2030 
FromProtoHelper(ModelProto::Node node_proto,std::shared_ptr<Node> node)2031 Status Node::FromProtoHelper(ModelProto::Node node_proto,
2032                              std::shared_ptr<Node> node) {
2033   {
2034     tf_shared_lock l(node->mu_);
2035     node->autotune_.store(node_proto.autotune());
2036     node->buffered_bytes_.store(node_proto.buffered_bytes());
2037     node->buffered_elements_.store(node_proto.buffered_elements());
2038     if (node_proto.buffered_elements() == 0) {
2039       node->buffered_elements_low_.store(std::numeric_limits<int64_t>::max());
2040       node->buffered_elements_high_.store(std::numeric_limits<int64_t>::min());
2041     } else {
2042       node->buffered_elements_low_.store(node_proto.buffered_elements());
2043       node->buffered_elements_high_.store(node_proto.buffered_elements());
2044     }
2045     node->bytes_consumed_.store(node_proto.bytes_consumed());
2046     node->bytes_produced_.store(node_proto.bytes_produced());
2047     node->num_elements_.store(node_proto.num_elements());
2048     node->processing_time_.store(node_proto.processing_time());
2049     node->record_metrics_.store(node_proto.record_metrics());
2050 
2051     // Restore parameters.
2052     int64_t num_parameters = node_proto.parameters_size();
2053     for (int i = 0; i < num_parameters; i++) {
2054       const ModelProto::Node::Parameter& parameter_proto =
2055           node_proto.parameters(i);
2056       std::shared_ptr<SharedState> state;
2057       if (parameter_proto.tunable()) {
2058         state = std::make_shared<SharedState>(
2059             kAutotune, std::make_shared<mutex>(),
2060             std::make_shared<condition_variable>());
2061         state->value = parameter_proto.state_value();
2062       } else {
2063         state = std::make_shared<SharedState>(
2064             parameter_proto.state_value(), std::make_shared<mutex>(),
2065             std::make_shared<condition_variable>());
2066       }
2067       node->parameters_[parameter_proto.name()] =
2068           MakeParameter(parameter_proto.name(), state, parameter_proto.min(),
2069                         parameter_proto.max());
2070       node->parameters_[parameter_proto.name()]->value =
2071           std::max(parameter_proto.min(), parameter_proto.value());
2072     }
2073   }
2074   {
2075     mutex_lock l(node->mu_);
2076     node->UpdateProcessingTimeEma();
2077   }
2078   return OkStatus();
2079 }
2080 
FromProto(ModelProto::Node node_proto,std::shared_ptr<Node> output,std::shared_ptr<Node> * node)2081 Status Node::FromProto(ModelProto::Node node_proto,
2082                        std::shared_ptr<Node> output,
2083                        std::shared_ptr<Node>* node) {
2084   // Note that parameters are restored in `FromProtoHelper`.
2085   Args args = {node_proto.id(), node_proto.name(), std::move(output)};
2086   switch (node_proto.node_class()) {
2087     case NodeClass::INTERLEAVE_MANY:
2088       *node = std::make_shared<InterleaveMany>(args);
2089       break;
2090     case NodeClass::ASYNC_INTERLEAVE_MANY:
2091       *node = std::make_shared<AsyncInterleaveMany>(
2092           args, /*parameters=*/std::vector<std::shared_ptr<Parameter>>());
2093       break;
2094     case NodeClass::KNOWN_RATIO:
2095       *node = std::make_shared<KnownRatio>(args, node_proto.ratio());
2096       break;
2097     case NodeClass::ASYNC_KNOWN_RATIO:
2098       *node = std::make_shared<AsyncKnownRatio>(
2099           args, node_proto.ratio(), node_proto.memory_ratio(),
2100           /*parameters=*/std::vector<std::shared_ptr<Parameter>>());
2101       break;
2102     case NodeClass::UNKNOWN_RATIO:
2103       *node = std::make_shared<UnknownRatio>(args);
2104       break;
2105     case NodeClass::ASYNC_UNKNOWN_RATIO:
2106       *node = std::make_shared<AsyncUnknownRatio>(
2107           args, /*parameters=*/std::vector<std::shared_ptr<Parameter>>());
2108       break;
2109     default:
2110       *node = std::make_shared<Unknown>(args);
2111   }
2112   return FromProtoHelper(node_proto, *node);
2113 }
2114 
Model()2115 Model::Model() : optimization_period_ms_(kOptimizationPeriodMinMs) {
2116   model_gauge_cell_ = metrics::GetTFDataModelGauge(
2117       strings::StrCat(reinterpret_cast<uint64>(this)));
2118   model_gauge_cell_->Set([&]() { return DebugString(); });
2119 }
2120 
~Model()2121 Model::~Model() {
2122   // Before the model is destroyed, we record an empty string in the gauge to
2123   // prevent race condition where the gauge callback is called after the Model
2124   // is destroyed.
2125   model_gauge_cell_->Set([]() { return std::string(); });
2126 }
2127 
AddNode(Node::Factory factory,const string & name,std::shared_ptr<Node> parent,std::shared_ptr<Node> * out_node)2128 void Model::AddNode(Node::Factory factory, const string& name,
2129                     std::shared_ptr<Node> parent,
2130                     std::shared_ptr<Node>* out_node) {
2131   // The name captures the sequence of iterators joined by `::`. We only use the
2132   // last element of the sequence as the name node.
2133   auto node_name = str_util::Split(name, ':', str_util::SkipEmpty()).back();
2134   mutex_lock l(mu_);
2135   std::shared_ptr<Node> node = factory({id_counter_++, node_name, parent});
2136   if (!output_) {
2137     output_ = node;
2138   }
2139   if (parent) {
2140     VLOG(3) << "Adding " << node->long_name() << " as input for "
2141             << parent->long_name();
2142     parent->add_input(node);
2143   } else {
2144     VLOG(3) << "Adding " << node->long_name();
2145   }
2146   *out_node = std::move(node);
2147   // TODO(jsimsa): Reset the optimization period when a node is added so that
2148   // autotuning adapts to changes to the input pipeline faster. Initial attempt
2149   // to enable this functionality caused a regression (see b/179812091).
2150 }
2151 
FlushMetrics()2152 void Model::FlushMetrics() {
2153   std::deque<std::shared_ptr<Node>> queue;
2154   {
2155     tf_shared_lock l(mu_);
2156     if (output_) queue.push_back(output_);
2157   }
2158   while (!queue.empty()) {
2159     auto node = queue.front();
2160     queue.pop_front();
2161     node->FlushMetrics();
2162     for (auto input : node->inputs()) {
2163       queue.push_back(input);
2164     }
2165   }
2166 }
2167 
Optimize(AutotuneAlgorithm algorithm,int64_t cpu_budget,int64_t ram_budget,double model_input_time,CancellationManager * cancellation_manager)2168 void Model::Optimize(AutotuneAlgorithm algorithm, int64_t cpu_budget,
2169                      int64_t ram_budget, double model_input_time,
2170                      CancellationManager* cancellation_manager) {
2171   std::shared_ptr<Node> snapshot;
2172   {
2173     tf_shared_lock l(mu_);
2174     snapshot = output_->Snapshot();
2175   }
2176   if (!port::JobName().empty()) {
2177     RecordAutotuneRamUsage(ram_budget, TotalMaximumBufferedBytes(snapshot));
2178   }
2179   OptimizationParams optimization_params;
2180   optimization_params.set_algorithm(algorithm);
2181   optimization_params.set_cpu_budget(cpu_budget);
2182   optimization_params.set_ram_budget(ram_budget);
2183   optimization_params.set_model_input_time(model_input_time);
2184   switch (algorithm) {
2185     case AutotuneAlgorithm::DEFAULT:
2186     case AutotuneAlgorithm::MAX_PARALLELISM:
2187       OptimizeMaxParallelism(snapshot, optimization_params,
2188                              cancellation_manager);
2189       break;
2190     case AutotuneAlgorithm::HILL_CLIMB:
2191       OptimizeHillClimb(snapshot, optimization_params, cancellation_manager);
2192       break;
2193     case AutotuneAlgorithm::GRADIENT_DESCENT:
2194       OptimizeGradientDescent(snapshot, optimization_params,
2195                               cancellation_manager);
2196       break;
2197     case AutotuneAlgorithm::STAGE_BASED:
2198       OptimizeStageBased(snapshot, optimization_params, cancellation_manager);
2199       break;
2200     default:
2201       VLOG(2) << "Autotuning algorithm was not recognized. Aborting "
2202                  "optimization.";
2203       return;
2204   }
2205   if (experiment_ == "autotune_buffer_optimization") {
2206     OptimizeBuffers(snapshot, optimization_params.ram_budget());
2207   }
2208 }
2209 
RemoveNode(std::shared_ptr<Node> node)2210 void Model::RemoveNode(std::shared_ptr<Node> node) {
2211   mutex_lock l(mu_);
2212   if (node) {
2213     if (node->output()) {
2214       node->output()->remove_input(node);
2215     }
2216     VLOG(3) << "Removing " << node->long_name();
2217   }
2218 }
2219 
CollectTunableParameters(std::shared_ptr<Node> node)2220 Model::ModelParameters Model::CollectTunableParameters(
2221     std::shared_ptr<Node> node) {
2222   return node->CollectTunableParameters();
2223 }
2224 
DownsizeBuffers(std::shared_ptr<Node> snapshot)2225 bool Model::DownsizeBuffers(std::shared_ptr<Node> snapshot) {
2226   Node::NodeVector nodes =
2227       snapshot->CollectNodes(TraversalOrder::BFS, IsAsyncNode);
2228   nodes.push_back(snapshot);
2229   bool downsized = false;
2230   for (auto& node : nodes) {
2231     if (node->TryDownsizeBuffer()) {
2232       downsized = true;
2233     }
2234   }
2235   return downsized;
2236 }
2237 
CollectBufferParametersToUpsize(std::shared_ptr<Node> snapshot)2238 absl::flat_hash_map<Node*, Parameter*> Model::CollectBufferParametersToUpsize(
2239     std::shared_ptr<Node> snapshot) {
2240   Node::NodeVector nodes =
2241       snapshot->CollectNodes(TraversalOrder::BFS, IsAsyncNode);
2242   absl::flat_hash_map<Node*, Parameter*> node_parameters;
2243   if (snapshot->IsAsync()) {
2244     snapshot->CollectBufferParametersToUpsize(node_parameters);
2245   }
2246   for (auto& node : nodes) {
2247     node->CollectBufferParametersToUpsize(node_parameters);
2248   }
2249   return node_parameters;
2250 }
2251 
ShouldStop(int64_t cpu_budget,int64_t ram_budget,const Model::ModelParameters & parameters,const Model::ModelParameters & parallelism_parameters,const Model::ModelParameters & buffer_size_parameters,std::shared_ptr<Node> snapshot,bool * cpu_budget_reached)2252 bool Model::ShouldStop(int64_t cpu_budget, int64_t ram_budget,
2253                        const Model::ModelParameters& parameters,
2254                        const Model::ModelParameters& parallelism_parameters,
2255                        const Model::ModelParameters& buffer_size_parameters,
2256                        std::shared_ptr<Node> snapshot,
2257                        bool* cpu_budget_reached) {
2258   if (!(*cpu_budget_reached)) {
2259     // If those essential transformations' parallelism reaches the CPU budget,
2260     // we will only tune the buffer size parameters in future iterations.
2261     int64_t model_parallelism = 0;
2262     for (auto& pair : parallelism_parameters) {
2263       model_parallelism += std::round(pair.second->value);
2264     }
2265     *cpu_budget_reached = (model_parallelism > cpu_budget);
2266   }
2267 
2268   bool all_max = AreAllParametersMax(
2269       *cpu_budget_reached ? buffer_size_parameters : parameters);
2270 
2271   // If all parameters have reached their maximum values or RAM budget is
2272   // reached, we stop the iterations.
2273   return all_max || TotalMaximumBufferedBytes(snapshot) > ram_budget;
2274 }
2275 
2276 // TODO(jsimsa): Add support for tracking and using the model input time.
OptimizeLoop(AutotuneAlgorithm algorithm,int64_t cpu_budget,int64_t ram_budget,CancellationManager * cancellation_manager)2277 Status Model::OptimizeLoop(AutotuneAlgorithm algorithm, int64_t cpu_budget,
2278                            int64_t ram_budget,
2279                            CancellationManager* cancellation_manager) {
2280   std::function<void()> unused;
2281   TF_RETURN_IF_ERROR(RegisterCancellationCallback(
2282       cancellation_manager,
2283       [this]() {
2284         mutex_lock l(mu_);
2285         optimize_cond_var_.notify_all();
2286       },
2287       /*deregister_fn=*/&unused));
2288 
2289   int64_t last_optimization_ms = 0;
2290   int64_t current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
2291   while (true) {
2292     {
2293       mutex_lock l(mu_);
2294       while (!cancellation_manager->IsCancelled() &&
2295              last_optimization_ms + optimization_period_ms_ > current_time_ms) {
2296         auto wait_ms =
2297             last_optimization_ms + optimization_period_ms_ - current_time_ms;
2298         VLOG(2) << "Waiting for " << wait_ms << " ms.";
2299         optimize_cond_var_.wait_for(l, std::chrono::milliseconds(wait_ms));
2300         current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
2301       }
2302       if (cancellation_manager->IsCancelled()) {
2303         return OkStatus();
2304       }
2305     }
2306 
2307     int64_t start_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
2308     double model_input_time = 0.0;
2309     // Model input time is set to 0 for all optimization algorithms except for
2310     // stage-based optimization algorithm for historical reason. In stage-based
2311     // optimization algorithm, the model input time is used as a target
2312     // optimization time of all stages in the pipeline.
2313     if (algorithm == AutotuneAlgorithm::STAGE_BASED) {
2314       model_input_time = ComputeTargetTimeNsec();
2315     }
2316     Optimize(algorithm, cpu_budget, ram_budget, model_input_time,
2317              cancellation_manager);
2318     int64_t end_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
2319     VLOG(2) << "Optimized for " << end_ms - start_ms << " ms.";
2320 
2321     // Exponentially increase the period of running the optimization until a
2322     // threshold is reached.
2323     {
2324       mutex_lock l(mu_);
2325       optimization_period_ms_ =
2326           std::min(optimization_period_ms_ << 1, kOptimizationPeriodMaxMs);
2327     }
2328     current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros;
2329     last_optimization_ms = current_time_ms;
2330     FlushMetrics();
2331   }
2332 }
2333 
OptimizeGradientDescent(std::shared_ptr<Node> snapshot,const OptimizationParams & optimization_params,CancellationManager * cancellation_manager)2334 void Model::OptimizeGradientDescent(
2335     std::shared_ptr<Node> snapshot,
2336     const OptimizationParams& optimization_params,
2337     CancellationManager* cancellation_manager) {
2338   VLOG(2) << "Starting optimization of tunable parameters with Gradient "
2339              "Descent.";
2340   auto parameters = CollectTunableParameters(snapshot);
2341   if (parameters.empty()) {
2342     VLOG(2) << "The Gradient Descent optimization is terminated since no node "
2343                "with tunable parameters has recorded elements.";
2344     return;
2345   }
2346   VLOG(2) << "Number of tunable parameters: " << parameters.size();
2347 
2348   // The vectors of "essential" parallelism parameters and buffer size
2349   // parameters.
2350   Model::ModelParameters parallelism_parameters, buffer_size_parameters;
2351   CollectParameters(snapshot, parameters, &parallelism_parameters,
2352                     &buffer_size_parameters);
2353 
2354   // Initialize the parameter values to minimal before tuning.
2355   for (auto& pair : parameters) {
2356     pair.second->value = pair.second->min;
2357   }
2358 
2359   // Optimization is stopped once the `OutputTime` improvement is smaller than
2360   // this value.
2361   constexpr double kOptimizationPrecision = 100.0L;
2362 
2363   // Maximum number of iterations for optimization.
2364   constexpr int64_t kMaxIterations = 1000;
2365 
2366   double output_time = 0;
2367   double new_output_time;
2368 
2369   // When the CPU budget is reached, the parallelism parameter values are fixed
2370   // and we only increase the buffer size parameters.
2371   bool cpu_budget_reached = false;
2372 
2373   for (int i = 0; i < kMaxIterations; ++i) {
2374     if (cancellation_manager->IsCancelled() ||
2375         ShouldStop(optimization_params.cpu_budget(),
2376                    optimization_params.ram_budget(), parameters,
2377                    parallelism_parameters, buffer_size_parameters, snapshot,
2378                    &cpu_budget_reached)) {
2379       break;
2380     }
2381     Model::ParameterGradients gradients;
2382     new_output_time = OutputTime(
2383         snapshot, optimization_params.model_input_time(), &gradients);
2384     // We also terminate once the improvement of the output latency is too
2385     // small.
2386     if (std::abs(output_time - new_output_time) < kOptimizationPrecision) {
2387       break;
2388     }
2389 
2390     UpdateParameterValues(
2391         gradients, &(cpu_budget_reached ? buffer_size_parameters : parameters));
2392     output_time = new_output_time;
2393   }
2394 
2395   for (auto& pair : parameters) {
2396     pair.second->value = std::round(pair.second->value);
2397   }
2398   UpdateStateValues(&parameters);
2399 }
2400 
OptimizeHillClimbHelper(std::shared_ptr<Node> snapshot,const OptimizationParams & optimization_params,CancellationManager * cancellation_manager,StopPredicate should_stop)2401 void Model::OptimizeHillClimbHelper(
2402     std::shared_ptr<Node> snapshot,
2403     const OptimizationParams& optimization_params,
2404     CancellationManager* cancellation_manager, StopPredicate should_stop) {
2405   VLOG(2) << "Starting optimization of tunable parameters with Hill Climb.";
2406   const double processing_time = TotalProcessingTime(snapshot);
2407   auto parameters = CollectTunableParameters(snapshot);
2408   if (parameters.empty()) {
2409     VLOG(2) << "There are no tunable parameters.";
2410     return;
2411   }
2412   VLOG(2) << "Number of tunable parameters: " << parameters.size();
2413 
2414   // Buffer size parameter will only be incremented if the output latency
2415   // improvement is greater than this constant.
2416   constexpr double kBufferSizeMinDelta = 1.0L;
2417 
2418   // Skip buffer size optimization if we are running the new buffering
2419   // algorithm.
2420   bool skip_buffer_sizes = (experiment_ == "autotune_buffer_optimization");
2421   if (skip_buffer_sizes) {
2422     constexpr float TEN_MINUTES = 60.0 * 10.0;
2423     LOG_EVERY_N_SEC(INFO, TEN_MINUTES)
2424         << "Skipping buffer_size parameters in HillClimb (message logged "
2425            "every "
2426            "10 minutes).";
2427   }
2428   // Initialize the parameter values to minimal before tuning.
2429   for (auto& pair : parameters) {
2430     if (skip_buffer_sizes && (pair.second->name == kBufferSize)) {
2431       continue;
2432     }
2433     pair.second->value = pair.second->min;
2434   }
2435   while (!cancellation_manager->IsCancelled()) {
2436     const double output_time =
2437         OutputTime(snapshot, optimization_params.model_input_time(),
2438                    /*gradients=*/nullptr);
2439     if (should_stop(parameters, processing_time, output_time,
2440                     TotalMaximumBufferedBytes(snapshot))) {
2441       break;
2442     }
2443 
2444     double best_delta = -1.0L;
2445     Parameter* best_parameter = nullptr;
2446     for (auto& pair : parameters) {
2447       if (pair.second->value >= pair.second->max ||
2448           (skip_buffer_sizes && (pair.second->name == kBufferSize))) {
2449         continue;
2450       }
2451       pair.second->value++;
2452       double new_output_time =
2453           OutputTime(snapshot, optimization_params.model_input_time(),
2454                      /*gradients=*/nullptr);
2455       double delta = output_time - new_output_time;
2456       if (delta > best_delta &&
2457           (delta > kBufferSizeMinDelta || pair.second->name != kBufferSize)) {
2458         best_delta = delta;
2459         best_parameter = pair.second.get();
2460       }
2461       pair.second->value--;
2462     }
2463     if (!best_parameter) {
2464       VLOG(2) << "Failed to find a tunable parameter that would further "
2465                  "decrease the output time. This suggests that the hill-climb "
2466                  "optimization got stuck in a local maximum. The optimization "
2467                  "attempt will stop now.";
2468       break;
2469     }
2470     best_parameter->value++;
2471   }
2472   UpdateStateValues(&parameters);
2473 }
RecordIteratorGapTime(uint64_t duration_usec)2474 void Model::RecordIteratorGapTime(uint64_t duration_usec) {
2475   mutex_lock l(gap_mu_);
2476   // Drop duration if it is too large.
2477   if (duration_usec >= kGapDurationThresholdUsec) {
2478     return;
2479   }
2480   gap_times_usec_.push_back(duration_usec);
2481   // Keep only the latest `window` gap times. Drop the oldest one.
2482   while (gap_times_usec_.size() > kGapTimeWindow) {
2483     gap_times_usec_.pop_front();
2484   }
2485 }
2486 
ComputeTargetTimeNsec()2487 double Model::ComputeTargetTimeNsec() {
2488   tf_shared_lock l(gap_mu_);
2489   if (gap_times_usec_.empty()) {
2490     return 0.0;
2491   }
2492   // Remove outliers.
2493   std::vector<uint64_t> clean_gap_times_usec =
2494       OutlierPruner({gap_times_usec_.begin(), gap_times_usec_.end()})
2495           .GetCleanPoints();
2496   if (clean_gap_times_usec.empty()) {
2497     return 0.0;
2498   }
2499   // Compute mean after outliers are removed.
2500   double sum_gap_time_usec = std::accumulate(clean_gap_times_usec.begin(),
2501                                              clean_gap_times_usec.end(), 0);
2502   return sum_gap_time_usec / static_cast<double>(clean_gap_times_usec.size()) *
2503          1.0e3;
2504 }
2505 
OptimizeStageBased(std::shared_ptr<Node> snapshot,const OptimizationParams & optimization_params,CancellationManager * cancellation_manager)2506 void Model::OptimizeStageBased(std::shared_ptr<Node> snapshot,
2507                                const OptimizationParams& optimization_params,
2508                                CancellationManager* cancellation_manager) {
2509   return OptimizeStageBasedParallelism(
2510       snapshot, optimization_params.model_input_time(), optimization_params,
2511       cancellation_manager);
2512 }
2513 
OptimizeStageBasedParallelism(std::shared_ptr<Node> snapshot,double target_time_nsec,const OptimizationParams & optimization_params,CancellationManager * cancellation_manager)2514 void Model::OptimizeStageBasedParallelism(
2515     std::shared_ptr<Node> snapshot, double target_time_nsec,
2516     const OptimizationParams& optimization_params,
2517     CancellationManager* cancellation_manager) {
2518   VLOG(2) << "Starting optimization of tunable parameters with Stage-Based "
2519              "optimization with a target time of "
2520           << optimization_params.model_input_time() << " nanoseconds.";
2521   Node::ModelParameters tunable_parameters = CollectTunableParameters(snapshot);
2522   // Initialize the parallelism parameter values to minimal before tuning.
2523   for (std::pair<string, std::shared_ptr<Parameter>>& pair :
2524        tunable_parameters) {
2525     if (pair.second->name != kParallelism) {
2526       continue;
2527     }
2528     pair.second->value = pair.second->min;
2529   }
2530   ModelTiming model_timing(snapshot);
2531   ModelTimingPriorityQueue priority_queue(model_timing);
2532   StatusOr<std::pair<double, Node*>> critical_root_status =
2533       priority_queue.PopSlowestStageRoot();
2534   if (!critical_root_status.ok()) {
2535     return;
2536   }
2537   NodeParallelismParameters node_parallelism;
2538   std::pair<double, Node*> critical_root = critical_root_status.ValueOrDie();
2539   while (critical_root.first > target_time_nsec) {
2540     Parameter* parallelism_parameter =
2541         node_parallelism.Get(critical_root.second);
2542     // Stop optimization if the critical stage has no `parallelism` parameter or
2543     // it has reached the max parallelism value.
2544     if (parallelism_parameter == nullptr ||
2545         parallelism_parameter->value >= parallelism_parameter->max) {
2546       break;
2547     }
2548     parallelism_parameter->value += 1.0;
2549     if (TotalMaximumBufferedBytes(snapshot) >
2550         optimization_params.ram_budget()) {
2551       // Increasing the parallelism by 1 exceeded ram budget. Reduce it back and
2552       // stop optimization because we cannot improve the most critical stage.
2553       // There is also a decent chance that the current optimization iteration
2554       // is under-optimized. For that reason, return immediately without
2555       // updating the parameter state values.
2556       parallelism_parameter->value -= 1.0;
2557       return;
2558     }
2559     // Compute the new total time and put the node back in the queue after its
2560     // parallelism value has been increased by 1.
2561     model_timing.ComputeNodeTotalTime(*critical_root.second);
2562     const ModelTiming::NodeTiming* root_timing =
2563         model_timing.GetTiming(critical_root.second);
2564     // If timing has not improved, stop optimizing.
2565     if (critical_root.first <= root_timing->total_time_nsec) {
2566       parallelism_parameter->value -= 1.0;
2567       break;
2568     }
2569     // Push it back to the priority queue.
2570     priority_queue.Push(critical_root.second, *root_timing);
2571     // Get the next critical stage root.
2572     critical_root_status = priority_queue.PopSlowestStageRoot();
2573     if (!critical_root_status.ok()) {
2574       break;
2575     }
2576     critical_root = critical_root_status.ValueOrDie();
2577   }
2578   UpdateStateValues(&tunable_parameters);
2579 }
2580 
OptimizeBuffers(std::shared_ptr<Node> snapshot,int64_t ram_budget)2581 void Model::OptimizeBuffers(std::shared_ptr<Node> snapshot,
2582                             int64_t ram_budget) {
2583   VLOG(2) << "Starting optimization of buffer_size parameters.";
2584   constexpr float TEN_MINUTES = 60.0 * 10.0;
2585   LOG_EVERY_N_SEC(INFO, TEN_MINUTES)
2586       << "Starting optimization of buffer_size parameters (message logged "
2587          "every 10 minutes).";
2588   // Reset node watermarks if any node's buffer is upsized or downsized. We
2589   // reset the watermarks of not only those nodes whose sizes change but all
2590   // nodes. The reason is that the optimization algorithm works on a snapshot of
2591   // nodes. There is no back references from snapshot of nodes to nodes. We
2592   // could add these back references but it is probably not necessary.
2593   bool downsized = DownsizeBuffers(snapshot);
2594   bool upsized = UpsizeBuffers(snapshot, ram_budget);
2595   if (downsized || upsized) {
2596     ResetBufferWatermarks();
2597   }
2598 }
2599 
UpsizeBuffers(std::shared_ptr<Node> snapshot,int64_t ram_budget)2600 bool Model::UpsizeBuffers(std::shared_ptr<Node> snapshot, int64_t ram_budget) {
2601   // Find buffers that should be up-sized.
2602   absl::flat_hash_map<Node*, Parameter*> node_parameters =
2603       CollectBufferParametersToUpsize(snapshot);
2604 
2605   // Compute available memory.
2606   double available_ram_bytes =
2607       static_cast<double>(ram_budget) - TotalMaximumBufferedBytes(snapshot);
2608 
2609   // Compute the max memory used by all buffers that should be upsized.
2610   double max_buffered_bytes = 0;
2611   for (auto& [node, parameter] : node_parameters) {
2612     if (node->buffered_elements() == 0) {
2613       continue;
2614     }
2615     max_buffered_bytes += static_cast<double>(node->buffered_bytes()) /
2616                           static_cast<double>(node->buffered_elements()) *
2617                           parameter->value;
2618   }
2619 
2620   // Compute a uniform scaling factor for all buffers. Cap the factor at 2.
2621   double scaling_factor = 2.0;
2622   if (max_buffered_bytes > 0) {
2623     scaling_factor =
2624         1.0 + std::min(1.0, available_ram_bytes / max_buffered_bytes);
2625   }
2626 
2627   bool upsized = false;
2628   // Up-size all buffers by the scaling factor.
2629   for (auto& [node, parameter] : node_parameters) {
2630     double old_value = parameter->value;
2631     // Scale the new buffer_size value. Use 1 if it is less than 1.
2632     double new_value = std::max(1.0, static_cast<double>(static_cast<int64_t>(
2633                                          parameter->value * scaling_factor)));
2634     // Cap the new buffer_size value at its max value.
2635     parameter->value = std::min(parameter->max, new_value);
2636     VLOG(2) << "Upsize buffer " << node->long_name() << "::" << parameter->name
2637             << " from " << old_value << " to " << parameter->value;
2638     if (parameter->value != parameter->state->value) {
2639       {
2640         mutex_lock l(*parameter->state->mu);
2641         parameter->state->value = parameter->value;
2642         parameter->state->cond_var->notify_all();
2643       }
2644       upsized = true;
2645     }
2646   }
2647   return upsized;
2648 }
2649 
ResetBufferWatermarks()2650 void Model::ResetBufferWatermarks() {
2651   Node::NodeVector nodes =
2652       output()->CollectNodes(TraversalOrder::BFS, IsAsyncNode);
2653   nodes.push_back(output());
2654   for (auto& node : nodes) {
2655     node->ResetBufferWatermarks();
2656   }
2657 }
2658 
OptimizeHillClimb(std::shared_ptr<Node> snapshot,const OptimizationParams & optimization_params,CancellationManager * cancellation_manager)2659 void Model::OptimizeHillClimb(std::shared_ptr<Node> snapshot,
2660                               const OptimizationParams& optimization_params,
2661                               CancellationManager* cancellation_manager) {
2662   auto should_stop = [&optimization_params](const ModelParameters& parameters,
2663                                             double processing_time,
2664                                             double output_time,
2665                                             double buffered_bytes) {
2666     const bool all_max = AreAllParametersMax(parameters);
2667     const bool output_time_budget_exceeded =
2668         output_time < processing_time / optimization_params.cpu_budget();
2669     const bool ram_budget_exceeded =
2670         buffered_bytes > optimization_params.ram_budget();
2671     if (all_max) {
2672       metrics::RecordTFDataAutotuneStoppingCriteria("all_max");
2673     }
2674     if (output_time_budget_exceeded) {
2675       metrics::RecordTFDataAutotuneStoppingCriteria("output_time");
2676     }
2677     if (ram_budget_exceeded) {
2678       metrics::RecordTFDataAutotuneStoppingCriteria("max_buffered_bytes");
2679     }
2680     return all_max || output_time_budget_exceeded || ram_budget_exceeded;
2681   };
2682   OptimizeHillClimbHelper(snapshot, optimization_params, cancellation_manager,
2683                           should_stop);
2684 }
2685 
OptimizeMaxParallelism(std::shared_ptr<Node> snapshot,const OptimizationParams & optimization_params,CancellationManager * cancellation_manager)2686 void Model::OptimizeMaxParallelism(
2687     std::shared_ptr<Node> snapshot,
2688     const OptimizationParams& optimization_params,
2689     CancellationManager* cancellation_manager) {
2690   auto should_stop = [&optimization_params](const ModelParameters& parameters,
2691                                             double processing_time,
2692                                             double output_time,
2693                                             double buffered_bytes) {
2694     const bool all_max = AreAllParametersMax(parameters);
2695     const bool ram_budget_exceeded =
2696         buffered_bytes > optimization_params.ram_budget();
2697     if (all_max) {
2698       metrics::RecordTFDataAutotuneStoppingCriteria("all_max");
2699     }
2700     if (ram_budget_exceeded) {
2701       metrics::RecordTFDataAutotuneStoppingCriteria("max_buffered_bytes");
2702     }
2703     return all_max || ram_budget_exceeded;
2704   };
2705   OptimizeHillClimbHelper(snapshot, optimization_params, cancellation_manager,
2706                           should_stop);
2707 }
2708 
OutputTime(std::shared_ptr<Node> node,double model_input_time,Model::ParameterGradients * gradients)2709 double Model::OutputTime(std::shared_ptr<Node> node, double model_input_time,
2710                          Model::ParameterGradients* gradients) {
2711   // To store the input time for each node.
2712   Model::NodeValues input_times = {{kModelInputTimeKey, model_input_time}};
2713 
2714   // TODO(jsimsa): Now that we are accounting for buffer size in wait time
2715   // computation, assuming that the input is infinitely fast will result in
2716   // inaccurate estimates of the output latency.
2717   //
2718   // We should compute the output latency as a fix-point of the following
2719   // equation: `output_time = node(OutputTime(input_times(1, output_time))`.
2720 
2721   return node->OutputTime(&input_times, gradients);
2722 }
2723 
TotalBufferedBytes(std::shared_ptr<Node> node)2724 double Model::TotalBufferedBytes(std::shared_ptr<Node> node) {
2725   return node->TotalBufferedBytes();
2726 }
2727 
TotalMaximumBufferedBytes(std::shared_ptr<Node> node)2728 double Model::TotalMaximumBufferedBytes(std::shared_ptr<Node> node) {
2729   return node->TotalMaximumBufferedBytes();
2730 }
2731 
TotalProcessingTime(std::shared_ptr<Node> node)2732 double Model::TotalProcessingTime(std::shared_ptr<Node> node) {
2733   return node->TotalProcessingTime(/*processing_times=*/nullptr);
2734 }
2735 
ToProto(ModelProto * model_proto)2736 Status Model::ToProto(ModelProto* model_proto) {
2737   tf_shared_lock l(mu_);
2738   model_proto->set_id_counter(id_counter_);
2739   return ModelToProtoHelper(output_, model_proto);
2740 }
2741 
FromProto(ModelProto model_proto,std::unique_ptr<Model> * model)2742 Status Model::FromProto(ModelProto model_proto, std::unique_ptr<Model>* model) {
2743   std::unique_ptr<Model> restored_model = std::make_unique<Model>();
2744   mutex_lock l(restored_model->mu_);
2745   TF_RETURN_IF_ERROR(
2746       ModelFromProtoHelper(model_proto, &restored_model->output_));
2747   restored_model->id_counter_ = model_proto.id_counter();
2748   *model = std::move(restored_model);
2749   return OkStatus();
2750 }
2751 
Save(const string & fname,std::shared_ptr<Node> snapshot,const OptimizationParams & optimization_params)2752 Status Model::Save(const string& fname, std::shared_ptr<Node> snapshot,
2753                    const OptimizationParams& optimization_params) {
2754   ModelProto model_proto;
2755   std::unique_ptr<Model> model_snapshot = std::make_unique<Model>();
2756   {
2757     mutex_lock l(model_snapshot->mu_);
2758     model_snapshot->output_ = std::move(snapshot);
2759     model_snapshot->id_counter_ = id_counter_;
2760   }
2761   TF_RETURN_IF_ERROR(model_snapshot->ToProto(&model_proto));
2762   OptimizationParams* saved_optimization_params =
2763       model_proto.mutable_optimization_params();
2764   *saved_optimization_params = optimization_params;
2765   return WriteBinaryProto(Env::Default(), fname, model_proto);
2766 }
2767 
Load(const string & fname,std::unique_ptr<Model> * model,OptimizationParams * optimization_params)2768 Status Model::Load(const string& fname, std::unique_ptr<Model>* model,
2769                    OptimizationParams* optimization_params) {
2770   ModelProto model_proto;
2771   TF_RETURN_IF_ERROR(
2772       ReadTextOrBinaryProto(Env::Default(), fname, &model_proto));
2773   TF_RETURN_IF_ERROR(FromProto(model_proto, model));
2774   const OptimizationParams restored_optimization_params =
2775       model_proto.optimization_params();
2776   *optimization_params = restored_optimization_params;
2777   return OkStatus();
2778 }
2779 
DebugString()2780 std::string Model::DebugString() {
2781   constexpr int64_t kMinSecondsBetweenCalls = 30;
2782   if (absl::Now() < cache_until_) return cached_debug_string_;
2783   std::shared_ptr<Node> snapshot;
2784   {
2785     tf_shared_lock l(mu_);
2786     if (!output_) return cached_debug_string_;
2787     snapshot = output_->Snapshot();
2788   }
2789   // TODO(jsimsa): Populate OptimizationParams.
2790   ModelProto model_proto;
2791   Status s = ModelToProtoHelper(snapshot, &model_proto);
2792   if (s.ok()) {
2793     cached_debug_string_ = model_proto.DebugString();
2794   } else {
2795     LOG(WARNING) << s.error_message();
2796   }
2797   cache_until_ = absl::Now() + absl::Seconds(kMinSecondsBetweenCalls);
2798   return cached_debug_string_;
2799 }
2800 
ModelTiming(std::shared_ptr<Node> root)2801 ModelTiming::ModelTiming(std::shared_ptr<Node> root) : root_(root) {
2802   DCHECK(root_.get() != nullptr);
2803   auto bfs_nodes = CollectNodes(root_, TraversalOrder::BFS, IsAnyNode);
2804   auto reverse_bfs_nodes = bfs_nodes;
2805   std::reverse(reverse_bfs_nodes.begin(), reverse_bfs_nodes.end());
2806   ComputePipelineRatios(bfs_nodes);
2807   ComputeTotalTimes(reverse_bfs_nodes);
2808 }
2809 
CollectNodes(std::shared_ptr<Node> root,TraversalOrder order,bool collect_node (const std::shared_ptr<Node>)) const2810 Node::NodeVector ModelTiming::CollectNodes(
2811     std::shared_ptr<Node> root, TraversalOrder order,
2812     bool collect_node(const std::shared_ptr<Node>)) const {
2813   if (root == nullptr) {
2814     return Node::NodeVector({});
2815   }
2816   auto subtree_nodes = root->CollectNodes(order, collect_node);
2817   Node::NodeVector nodes;
2818   if (order == TraversalOrder::BFS) {
2819     nodes.push_back(root);
2820     nodes.insert(nodes.end(), subtree_nodes.begin(), subtree_nodes.end());
2821   } else {
2822     nodes.insert(nodes.end(), subtree_nodes.begin(), subtree_nodes.end());
2823     nodes.push_back(root);
2824   }
2825   return nodes;
2826 }
2827 
GetTiming(const Node * node) const2828 const ModelTiming::NodeTiming* ModelTiming::GetTiming(const Node* node) const {
2829   if (timing_nodes_.find(node) == timing_nodes_.end()) {
2830     return nullptr;
2831   }
2832   return &(timing_nodes_.at(node));
2833 }
2834 
ComputePipelineRatios(const Node::NodeVector & bfs_nodes)2835 void ModelTiming::ComputePipelineRatios(const Node::NodeVector& bfs_nodes) {
2836   for (const auto& node : bfs_nodes) {
2837     auto& node_timing = timing_nodes_[node.get()];
2838     if (!node->autotune()) {
2839       // These are inactive nodes marked by parallel interleave
2840       // transformations.
2841       node_timing.pipeline_ratio = 0.0;
2842       continue;
2843     }
2844     double parent_pipeline_ratio = 1.0;
2845     double parent_ratio = 1.0;
2846     if (node->output() != nullptr || timing_nodes_.contains(node->output())) {
2847       const auto& output_timing = timing_nodes_[node->output()];
2848       parent_pipeline_ratio = output_timing.pipeline_ratio;
2849       parent_ratio = node->output()->Ratio();
2850       if (parent_ratio <= 0.0) {
2851         // Parent ratio is unknown, we use 1.0 as a guess.
2852         parent_ratio = 1.0;
2853       }
2854     }
2855     node_timing.pipeline_ratio = parent_pipeline_ratio * parent_ratio;
2856   }
2857 }
2858 
ComputeNonAsyncInterleaveManyTotalTime(const Node & node)2859 void ModelTiming::ComputeNonAsyncInterleaveManyTotalTime(const Node& node) {
2860   DCHECK(timing_nodes_.contains(&node));
2861   auto& node_timing = timing_nodes_[&node];
2862   double input_total_time_nsec = 0.0;
2863   for (auto input : node.inputs()) {
2864     if (input->IsAsync()) {
2865       continue;
2866     }
2867     if (!input->autotune() || input->num_elements() <= 0) {
2868       continue;
2869     }
2870     DCHECK(timing_nodes_.contains(input.get()))
2871         << "Input " << input->long_name() << " of node " << node.long_name()
2872         << " has no timing node.";
2873 
2874     input_total_time_nsec += timing_nodes_[input.get()].total_time_nsec;
2875   }
2876   node_timing.total_time_nsec =
2877       node_timing.self_time_nsec + input_total_time_nsec * node.Ratio();
2878 }
2879 
ComputeAsyncInterleaveManyTotalTime(const Node & node)2880 void ModelTiming::ComputeAsyncInterleaveManyTotalTime(const Node& node) {
2881   DCHECK(timing_nodes_.contains(&node));
2882   auto& node_timing = timing_nodes_[&node];
2883   double max_input_total_time_nsec = 0.0;
2884   double sum_input_throughput = 0.0;
2885   auto inputs = node.inputs();
2886   // `ParallelInterleave` is often used to interleave processing of datasets
2887   // generated from the first input, e.g. reading from IO where the first input
2888   // has the list of all filenames. The first input is typically not the
2889   // bottleneck. We exclude the timing of the first input in the throughput
2890   // computation of the remaining input. It also excluded from the total time
2891   // computation of the async interleave node.
2892   auto input = std::next(inputs.begin());
2893   // `num_active_inputs` holds the number of inputs that the
2894   // `ParallelInterleave` is reading from, not including those that are warm
2895   // starting, which can be detected by checking the value of `autotune()`. It
2896   // also does not count async inputs because they would be in their own
2897   // stages. This number is typically the same as `cycle_length`. It will be
2898   // used below to scale the throughput of inputs if `cycle_length` is smaller
2899   // than `num_active_inputs`.
2900   int num_active_inputs = 0;
2901   for (; input != inputs.end(); ++input) {
2902     if ((*input)->IsAsync()) {
2903       continue;
2904     }
2905     if (!(*input)->autotune() || (*input)->num_elements() <= 0) {
2906       continue;
2907     }
2908     DCHECK(timing_nodes_.contains((*input).get()))
2909         << "Input " << (*input)->long_name() << " of node " << node.long_name()
2910         << " has no timing node.";
2911     auto input_total_time_nsec = timing_nodes_[(*input).get()].total_time_nsec;
2912     max_input_total_time_nsec =
2913         std::max(input_total_time_nsec, max_input_total_time_nsec);
2914     if (input_total_time_nsec > 0.0) {
2915       sum_input_throughput += 1.0 / input_total_time_nsec;
2916     }
2917     ++num_active_inputs;
2918   }
2919   auto parallelism_param = node.ParameterValue(kParallelism);
2920   double parallelism = num_active_inputs;
2921   if (parallelism_param.ok()) {
2922     parallelism = parallelism_param.ValueOrDie();
2923   }
2924   // After cl/445005635, there should always be `deterministic` parameter for an
2925   // ASYNC_INTERLEAVE_MANY node. The "not-ok" check is to allow the code to work
2926   // with protos saved and restored before that CL. Similarly for `cycle_length`
2927   // with cl/436244658.
2928   auto deterministic_param = node.ParameterValue(kDeterministic);
2929   bool deterministic = false;
2930   if (deterministic_param.ok()) {
2931     deterministic = deterministic_param.ValueOrDie() == 1.0;
2932   }
2933   auto cycle_length_param = node.ParameterValue(kCycleLength);
2934   double cycle_length = num_active_inputs;
2935   if (cycle_length_param.ok()) {
2936     cycle_length = cycle_length_param.ValueOrDie();
2937   }
2938   double input_total_time_nsec = 0.0;
2939   if (deterministic) {
2940     // If deterministic = true, then the total time is `max input total time /
2941     // min(parallelism, cycle_length)`.
2942     input_total_time_nsec =
2943         max_input_total_time_nsec / std::min(parallelism, cycle_length);
2944   } else if (sum_input_throughput > 0.0) {
2945     // If deterministic = false, then the total time is
2946     // `1/sum_input_throughput`. Scale the throughput according to `parallelism`
2947     // and `cycle_length` if `cycle_length` or `parallelism` is smaller than
2948     // active inputs. `cycle_length` and `parallelism` could theoretically be
2949     // larger than active inputs when some inputs are async and some are sync.
2950     if (std::min(cycle_length, parallelism) < num_active_inputs) {
2951       sum_input_throughput *=
2952           std::min(parallelism, cycle_length) / num_active_inputs;
2953     }
2954     input_total_time_nsec = 1.0 / sum_input_throughput;
2955   }
2956   node_timing.total_time_nsec =
2957       node_timing.self_time_nsec + input_total_time_nsec;
2958 }
2959 
ComputeTotalTimes(const Node::NodeVector & reverse_bfs_nodes)2960 void ModelTiming::ComputeTotalTimes(const Node::NodeVector& reverse_bfs_nodes) {
2961   for (const auto& node : reverse_bfs_nodes) {
2962     ComputeNodeTotalTime(*(node.get()));
2963   }
2964 }
2965 
ComputeNodeTotalTime(const Node & node)2966 void ModelTiming::ComputeNodeTotalTime(const Node& node) {
2967   NodeTiming& node_timing = timing_nodes_[&node];
2968   node_timing.self_time_nsec = node.ComputeSelfTime();
2969   if (!node.autotune() || node.num_elements() <= 0) {
2970     return;
2971   }
2972 #if !defined(IS_MOBILE_PLATFORM)
2973   // This block of code is defined only for non-mobile platform because mobile
2974   // platform lacks RTTI, i.e. the use of `dynamic_cast`.
2975   if (dynamic_cast<const AsyncInterleaveMany*>(&node) != nullptr) {
2976     ComputeAsyncInterleaveManyTotalTime(node);
2977   } else {
2978     ComputeNonAsyncInterleaveManyTotalTime(node);
2979   }
2980 #else   // !IS_MOBILE_PLATFORM
2981   ComputeNonAsyncInterleaveManyTotalTime(node);
2982 #endif  // !IS_MOBILE_PLATFORM
2983 }
2984 
GetStageRoots() const2985 std::vector<std::shared_ptr<Node>> ModelTiming::GetStageRoots() const {
2986   auto bfs_nodes = CollectNodes(root_, TraversalOrder::BFS, IsAnyNode);
2987   std::vector<std::shared_ptr<Node>> roots;
2988   if (!bfs_nodes.empty() && !bfs_nodes[0]->IsAsync()) {
2989     roots.push_back(bfs_nodes[0]);
2990   }
2991   for (auto& node : bfs_nodes) {
2992     if (node->IsAsync()) {
2993       roots.push_back(node);
2994     }
2995   }
2996   return roots;
2997 }
2998 
GetStageNodes(std::shared_ptr<Node> stage_root) const2999 std::vector<std::shared_ptr<Node>> ModelTiming::GetStageNodes(
3000     std::shared_ptr<Node> stage_root) const {
3001   return CollectNodes(stage_root, TraversalOrder::BFS, IsSyncNode);
3002 }
3003 
3004 }  // namespace model
3005 }  // namespace data
3006 }  // namespace tensorflow
3007