1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/strings/str_format.h"
24 #include "absl/strings/str_replace.h"
25 #include "tensorflow/core/framework/allocation_description.pb.h"
26 #include "tensorflow/core/framework/attr_value.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/tensor.pb.h"
29 #include "tensorflow/core/framework/tensor_description.pb.h"
30 #include "tensorflow/core/framework/tensor_shape.pb.h"
31 #include "tensorflow/core/grappler/clusters/utils.h"
32 #include "tensorflow/core/grappler/costs/utils.h"
33 #include "tensorflow/core/grappler/op_types.h"
34 #include "tensorflow/core/grappler/utils.h"
35 #include "tensorflow/core/grappler/utils/transitive_fanin.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/lib/strings/numbers.h"
38 #include "tensorflow/core/platform/logging.h"
39 #include "tensorflow/core/util/device_name_utils.h"
40 
41 namespace tensorflow {
42 namespace grappler {
43 
44 const char kAttrInputSrc[] = "input_source_";
45 const char kAttrSrcDevice[] = "send_device";
46 const char kAttrDstDevice[] = "recv_device";
47 const char kAttrTensorName[] = "tensor_name";
48 const char kChannelDevice[] = "Channel";
49 const char kStreaming[] = "_streaming";
50 
51 namespace {
52 
53 using ::tensorflow::strings::HumanReadableNumBytes;
54 
55 float Round2(const float x) {
56   // Not using std::round from <cmath> here because not all platforms seem to
57   // support that (specifically Android).
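  // For example, Round2(12.3456) returns 12.35 (two decimal places).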
58   return ::round(100.0 * x) / 100.0;
59 }
60 
61 Costs& FindOrCreateZero(const string& op_name,
62                         std::map<string, Costs>* op_cost) {
63   auto it = op_cost->find(op_name);
64   if (it == op_cost->end()) {
65     // Note that default constructor of Costs sets some memory related fields
66     // to unknown values so we should explicitly initialize it with ZeroCosts.
67     it = op_cost->emplace(op_name, Costs::ZeroCosts()).first;
68   }
69   return it->second;
70 }
71 
72 // Key to the cached _Recv ops map, and its hash and predicate structures.
73 struct RecvNodeDescriptor {
74   const NodeDef* node;
75   const int port_num;
76   const string device;
77 
78   RecvNodeDescriptor(const NodeDef* node_, const int port_num_,
79                      const string& device_)
80       : node(node_), port_num(port_num_), device(device_) {}
81 };
82 
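// Hash functor for RecvNodeDescriptor: it XORs the hashes of the node
// pointer, the port number, and the device name. A plain XOR combine is
// sufficient here; collisions only cost an extra run of the equality functor
// below.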
83 struct RecvNodeDescriptorHash {
84   std::size_t operator()(const RecvNodeDescriptor& recv_node) const {
85     return std::hash<const NodeDef*>()(recv_node.node) ^
86            std::hash<int>()(recv_node.port_num) ^
87            std::hash<string>()(recv_node.device);
88   }
89 };
90 
91 struct RecvNodeDescriptorEqual {
92   bool operator()(const RecvNodeDescriptor& a,
93                   const RecvNodeDescriptor& b) const {
94     return a.node == b.node && a.port_num == b.port_num && a.device == b.device;
95   }
96 };
97 
98 void UpdateDeviceAnnotationState(const NodeDef* node,
99                                  const NodeState& node_state,
100                                  DeviceState* device) {
101   if (node->attr().count(kOutputShapes) == 0) return;
102 
103   int64_t execution_count = node->attr().count(kExecutionCount) == 0
104                                 ? 1
105                                 : node->attr().at(kExecutionCount).i();
106 
107   auto& shape_annotation_stats = device->shape_annotation_stats;
108   shape_annotation_stats.num_ops_annotated += 1;
109   shape_annotation_stats.num_ops_executed += execution_count;
110   shape_annotation_stats.num_ops_executed_more_than_once +=
111       execution_count > 1 ? 1 : 0;
112   shape_annotation_stats.num_ops_with_incompatible_shapes +=
113       node_state.shape_incompatible ? 1 : 0;
114   shape_annotation_stats.num_ops_with_dynamic_shapes +=
115       (execution_count > 1 && node->attr().count(kOutputSame) == 0) ? 1 : 0;
116 }
117 
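// For example, a node with attr _streaming = list { b: [false, true] } treats
// port 1 as a streaming port and port 0 (or any out-of-range port) as a
// regular, memory-allocating output.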
118 bool IsStreamingPort(const NodeDef& node, const int port) {
119   if (!node.attr().contains(kStreaming)) return false;
120 
121   auto& attr_list = node.attr().at(kStreaming).list();
122   bool is_streaming_port = false;
123   if (port >= 0 && port < attr_list.b().size()) {
124     is_streaming_port = attr_list.b(port);
125   }
126   return is_streaming_port;
127 }
128 
129 }  // namespace
130 
131 void LIFOManager::AddNode(const NodeDef* node) {
132   // Merge nodes are scheduled with the lowest priority in LIFO manager; virtual
133   // scheduler may run multiple input nodes of Merge (when we don't have
134   // annotation, which is quite common); simply scheduling Merge after one of
135   // its inputs may break scheduling constraints; some inputs of Merge may be
136   // scheduled after the Merge. So, we place Merge at the beginning of the queue
137   // to guarantee all the inputs of Merge are scheduled before the Merge.
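  // Illustrative sequence (hypothetical nodes): after AddNode(&A),
  // AddNode(&Merge1), AddNode(&B), the deque holds [Merge1, A, B]; since
  // GetCurrNode() returns the back of the deque, B and A are scheduled
  // before Merge1.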
138   if (IsMerge(*node)) {
139     nodes_.push_front(node);
140   } else {
141     nodes_.push_back(node);
142   }
143 }
144 
145 const NodeDef* LIFOManager::GetCurrNode() {
146   CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
147   if (curr_pos_ == nodes_.end()) {
148     curr_pos_ = --(nodes_.rbegin().base());  // Last one in the list.
149   }
150   // Once curr_pos_ is set to a valid entry in the list, we keep using the
151   // cached curr_pos_ until RemoveCurrNode() is called. AddNode() will not
152   // change the GetCurrNode() return value.
153   return *curr_pos_;
154 }
155 
156 void LIFOManager::RemoveCurrNode() {
157   // Make sure we have curr_pos_ ready to be removed.
158   GetCurrNode();
159   // Note curr_pos_ may not be pointing to the last element if some nodes
160   // have been added after GetCurrNode() was called.
161   nodes_.erase(curr_pos_);
162 
163   curr_pos_ = nodes_.end();  // Reset curr_pos_.
164 }
165 
166 HeapReadyManager::HeapReadyManager() : ReadyNodeManager() {
167   std::make_heap(nodes_.begin(), nodes_.end());
168 }
169 
170 Status HeapReadyManager::Init(
171     const std::unordered_map<const NodeDef*, NodeState>* node_map) {
172   // Resets the node state since different instances of the scheduler can reuse
173   // the same node_manager.
174   node_map_ = node_map;
175   nodes_.clear();
176   curr_node_ = nullptr;
177 
178   // Sets up the comparator for the heap.
179   greater_ = Greater();
180 
181   return OkStatus();
182 }
183 
184 void HeapReadyManager::AddNode(const NodeDef* node) {
185   // push_heap in AddNode and pop_heap in RemoveCurrNode() guarantees that the
186   // first element is the node with minimum time_ready.
187   nodes_.push_back(node);
188   std::push_heap(nodes_.begin(), nodes_.end(), greater_);
189 }
190 
191 const NodeDef* HeapReadyManager::GetCurrNode() {
192   if (curr_node_) return curr_node_;
193   // There must be at least one ready node when GetCurrNode() is called.
194   CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
196   const std::string node_name = nodes_.front()->name();
197   // Next time we call GetCurrNode(), it just returns the cached copy
198   // curr_node_, until RemoveCurrNode() is called.
199   curr_node_ = nodes_.front();
200   // Remove the current node from the heap immediately, because if we wait
201   // until later the heap could get reorganized by a call to AddNode(). The
202   // current node is cached anyway, in case GetCurrNode() is called again.
203   std::pop_heap(nodes_.begin(), nodes_.end(), greater_);
204   nodes_.pop_back();
205   return curr_node_;
206 }
207 
208 void HeapReadyManager::RemoveCurrNode() {
209   if (curr_node_) {
210     // If cached copy exists, remove that.
211     // Reset curr_node_ so that GetCurrNode() finds another node.
212     curr_node_ = nullptr;
213   } else {
214     // If cached copy not present, then remove entry from the heap queue.
215     std::pop_heap(nodes_.begin(), nodes_.end(), greater_);
216     nodes_.pop_back();
217   }
218 }
219 
220 bool HeapReadyManager::Empty() const {
221   return nodes_.empty() && curr_node_ == nullptr;
222 }
223 
224 bool FirstReadyCmp(
225     const std::unordered_map<const NodeDef*, NodeState>* node_map,
226     const NodeDef* a, const NodeDef* b) {
227   if (node_map->at(a).time_ready == node_map->at(b).time_ready) {
228     // Use Node name as tie-breaker for deterministic node scheduling.
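    // Illustration (hypothetical names): with equal time_ready, "add_1" is
    // scheduled before "mul_0", since the min-heap built on this comparator
    // surfaces the lexicographically smallest name first.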
229     return a->name().compare(b->name()) > 0;
230   } else {
231     // Note: we need a node with minimum time_ready, not maximum; hence, using
232     // a > b for comparison function.
233     return node_map->at(a).time_ready > node_map->at(b).time_ready;
234   }
235 }
236 
237 std::function<bool(const NodeDef*, const NodeDef*)>
238 FirstReadyManager::Greater() {
239   auto greater = [this](const NodeDef* a, const NodeDef* b) -> bool {
240     return FirstReadyCmp(node_map_, a, b);
241   };
242   return greater;
243 }
244 
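// Note: HeapReadyManager builds a min-heap on the comparator returned here,
// so a numerically smaller priority value is scheduled earlier; equal
// priorities fall back to FirstReady ordering.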
245 std::function<bool(const NodeDef*, const NodeDef*)>
246 PriorityReadyManager::Greater() {
247   auto greater = [this](const NodeDef* a, const NodeDef* b) -> bool {
248     auto pri_a = node_priority_.at(a->name());
249     auto pri_b = node_priority_.at(b->name());
250     if (pri_a == pri_b) {
251       // Fallback to default (FirstReady) behaviour.
252       return FirstReadyCmp(node_map_, a, b);
253     }
254     return pri_a > pri_b;
255   };
256   return greater;
257 }
258 
259 void PriorityReadyManager::AddNode(const NodeDef* node) {
260   if (node_priority_.count(node->name()) == 0) {
261     VLOG(3) << "Priority of node " << node->name() << " not found.";
262     node_priority_[node->name()] = 0;
263   }
264   HeapReadyManager::AddNode(node);
265 }
266 
267 Status PriorityReadyManager::SetPriority(
268     const std::unordered_map<string, int>& node_priority) {
269   node_priority_ = node_priority;
270   return OkStatus();
271 }
272 
273 CompositeNodeManager::CompositeNodeManager()
274     : ReadyNodeManager(), send_manager_(), recv_manager_() {}
275 
276 Status CompositeNodeManager::Init(
277     const std::unordered_map<const NodeDef*, NodeState>* node_map) {
278   node_map_ = node_map;
279   TF_RETURN_IF_ERROR(send_manager_.Init(node_map));
280   TF_RETURN_IF_ERROR(recv_manager_.Init(node_map));
281   curr_node_ = nullptr;
282   return OkStatus();
283 }
284 
285 void CompositeNodeManager::AddNode(const NodeDef* node) {
286   if (IsSend(*node)) {
287     send_manager_.AddNode(node);
288   } else if (IsRecv(*node)) {
289     recv_manager_.AddNode(node);
290   } else {
291     const auto& device = node_map_->at(node).device_name;
292     ops_lifo_map_[device].AddNode(node);
293   }
294 }
295 
296 const NodeDef* CompositeNodeManager::GetCurrNode() {
297   if (curr_node_) return curr_node_;
298 
299   // Selection policy:
300   //  - Per-device LIFO for normal ops (not _Send / _Recv).
301   //  - FirstReady for _Send and _Recv, each in its own manager.
302   //  - Globally FirstReady among the per-device LIFO picks and _Send/_Recv.
303   //  - If time_ready ties, priority order is _Send, _Recv, then the rest.
304   std::vector<std::pair<const NodeDef*, Costs::Duration>> candidates;
305   for (auto& ops_lifo : ops_lifo_map_) {
306     if (!ops_lifo.second.Empty()) {
307       const auto* op = ops_lifo.second.GetCurrNode();
308       candidates.emplace_back(op, node_map_->at(op).time_ready);
309     }
310   }
311   if (!send_manager_.Empty()) {
312     const auto* send = send_manager_.GetCurrNode();
313     candidates.emplace_back(send, node_map_->at(send).time_ready);
314   }
315   if (!recv_manager_.Empty()) {
316     const auto* recv = recv_manager_.GetCurrNode();
317     candidates.emplace_back(recv, node_map_->at(recv).time_ready);
318   }
319   CHECK(!candidates.empty());
320   auto first_ready = std::min_element(
321       candidates.begin(), candidates.end(),
322       [](const std::pair<const NodeDef*, Costs::Duration>& a,
323          const std::pair<const NodeDef*, Costs::Duration>& b) {
324         if (a.second == b.second) {
325           // Note that there can be only 1 Send and only 1 Recv in candidates,
326           // at most; hence, score is 2 for Send, 1 for Recv, and 0 for a
327           // normal op, and a_score and b_score are equal only if both are
328           // normal ops.
329           int a_score = 2 * IsSend(*a.first) + IsRecv(*a.first);
330           int b_score = 2 * IsSend(*b.first) + IsRecv(*b.first);
331           if (a_score == b_score) {
332             // Both are normal ops; use node name as tie breaker.
333             return a.first->name().compare(b.first->name()) < 0;
334           } else {
335             // Prioritize by op type: _Send, _Recv, and normal ops.
336             return a_score > b_score;
337           }
338         } else {
339           return a.second < b.second;
340         }
341       });
342   // Next time we call GetCurrNode(), it just returns the cached one,
343   // curr_node_, until RemoveCurrNode() is called.
344   curr_node_ = first_ready->first;
345 
346   return curr_node_;
347 }
348 
349 void CompositeNodeManager::RemoveCurrNode() {
350   const auto* node = GetCurrNode();
351   if (IsSend(*node)) {
352     send_manager_.RemoveCurrNode();
353   } else if (IsRecv(*node)) {
354     recv_manager_.RemoveCurrNode();
355   } else {
356     const auto device = node_map_->at(node).device_name;
357     ops_lifo_map_[device].RemoveCurrNode();
358   }
359   // Reset curr_node_ so that GetCurrNode() finds another node.
360   curr_node_ = nullptr;
361 }
362 
363 bool CompositeNodeManager::Empty() const {
364   // Empty if all the ready managers are empty.
365   bool empty = true;
366   for (const auto& ops_lifo : ops_lifo_map_) {
367     empty &= ops_lifo.second.Empty();
368   }
369   return empty && send_manager_.Empty() && recv_manager_.Empty();
370 }
371 
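// Illustrative usage (not from the original code): a caller constructs a
// manager by name and then initializes it with the scheduler's node map,
// e.g.
//   std::unique_ptr<ReadyNodeManager> manager =
//       ReadyNodeManagerFactory("FirstReady");
//   TF_RETURN_IF_ERROR(manager->Init(&node_map));
// Unrecognized names are fatal (LOG(FATAL) below).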
372 std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
373     const string& ready_node_manager) {
374   if (ready_node_manager == "FIFO") {
375     return std::make_unique<FIFOManager>();
376   } else if (ready_node_manager == "LIFO") {
377     return std::make_unique<LIFOManager>();
378   } else if (ready_node_manager == "FirstReady") {
379     return std::make_unique<FirstReadyManager>();
380   } else if (ready_node_manager == "Composite") {
381     return std::make_unique<CompositeNodeManager>();
382   }
383   LOG(FATAL) << "Not a valid ready node manager: " << ready_node_manager;
384   return nullptr;
385 }
386 
387 SchedulerState::~SchedulerState() {}
388 
389 SchedulerState::SchedulerState(const bool use_static_shapes,
390                                const bool use_aggressive_shape_inference,
391                                Cluster* cluster,
392                                std::unique_ptr<VirtualPlacer> placer)
393     : graph_costs_(Costs::ZeroCosts()),
394       cluster_(cluster),
395       use_static_shapes_(use_static_shapes),
396       use_aggressive_shape_inference_(use_aggressive_shape_inference),
397       placer_(std::move(placer)) {
398   DCHECK(placer_);  // check if the pointer is valid.
399   graph_costs_.num_ops_total = 0;
400   initialized_ = false;
401   track_mem_usage_snapshot_ = VLOG_IS_ON(1);
402 }
403 
404 Status SchedulerState::Init(const GrapplerItem* item,
405                             std::vector<const NodeDef*>* initial_nodes,
406                             bool create_explicit_channel_device) {
407   initialized_ = false;
408 
409   // Clear all internal states so that the SchedulerState is reusable for
410   // different GrapplerItems
411   node_map_.clear();
412   device_.clear();
413   additional_nodes_.clear();
414 
415   graph_costs_ = Costs::ZeroCosts();
416   graph_costs_.num_ops_total = 0;
417   op_to_cost_.clear();
418 
419   op_counts_.clear();
420   op_costs_.clear();
421 
422   initial_nodes->clear();
423 
424   // Constructs graph properties and performs shape inference.
425   graph_properties_ = std::make_unique<GraphProperties>(*item);
426   // TODO(safeen,dyoon): Will we ever use InferDynamically? If not we may want
427   // to get rid of use_static_shapes_ and cluster_.
428   if (use_static_shapes_) {
429     TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
430         true, use_aggressive_shape_inference_, true));
431   } else {
432     TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_));
433   }
434 
435   grappler_item_ = item;
436   const auto& graph = grappler_item_->graph;
437   const auto& fetch_nodes = grappler_item_->fetch;
438   std::set<string> feed_nodes;
439 
440   for (const auto& f : grappler_item_->feed) {
441     auto iter_and_inserted_flag = feed_nodes.insert(f.first);
442     QCHECK(iter_and_inserted_flag.second)
443         << "Duplicate feed node found: " << f.first;
444   }
445 
446   // Get the nodes that would run to output fetch_nodes.
447   std::unordered_map<string, const NodeDef*> name_to_node;
448   std::vector<const NodeDef*> fetch_fanin_nodes;
449   TF_RETURN_IF_ERROR(ComputeTransitiveFanin(graph, fetch_nodes, &name_to_node,
450                                             &fetch_fanin_nodes));
451 
452   // Once ComputeTransitiveFanin is complete, only the nodes that can be reached
453   // from the fetch nodes are scheduled. So the scheduled nodes should be
454   // exactly the same as those executed for real. One possible discrepancy could
455   // be the control flow nodes, where TF only executes one path.
456 
457   // Traverses the graph to record _Send nodes.
458   // TODO(dyoon): Instead of identifying _Send node here manually, add _Send
459   // to _Recv as control dependency when creating GrapplerItem.
460   std::unordered_map<string, const NodeDef*> name_to_send;
461   for (const auto& node : graph.node()) {
462     if (IsSend(node)) {
463       const auto& attr = node.attr();
464       name_to_send[attr.at("tensor_name").s()] = &node;
465     }
466   }
467 
468   // To reuse _Recv ops.
469   std::unordered_map<RecvNodeDescriptor, const NodeDef*, RecvNodeDescriptorHash,
470                      RecvNodeDescriptorEqual>
471       cached_recv_nodes;
472 
473   // Build node_map; for each node, create its NodeState and connect its inputs
474   // and outputs.
475   for (const auto* curr_node : fetch_fanin_nodes) {
476     auto& curr_node_state = GetNodeStateOrCreateIt(curr_node);
477     const string curr_node_device = DeviceName(curr_node);
478     std::vector<string> inputs;
479     if (IsRecv(*curr_node)) {
480       const auto& attr = curr_node->attr();
481       if (attr.count("tensor_name")) {
482         const auto& send_node_name = attr.at("tensor_name").s();
483         auto it = name_to_send.find(send_node_name);
484         // If there is a _Send associated with the curr_node (_Recv), add it as
485         // input.
486         if (it != name_to_send.end()) {
487           const NodeDef* send = it->second;
488           inputs = {send->name()};
489         }
490       }
491     } else {
492       for (const string& input : curr_node->input()) {
493         inputs.push_back(input);
494       }
495     }
496     for (const string& input_node_name : inputs) {
497       // Note that input_node_name may be in <prefix><node_name>:<port_num>
498       // format, where <prefix> (e.g., "^" for control dependency) and
499       // ":<port_num>" may be omitted. NodeName() extracts only the node_name.
500       const NodeDef* input_node = name_to_node[NodeName(input_node_name)];
501 
502       CHECK(input_node);
503       const string in_device = DeviceName(input_node);
504       const auto input_node_port_num = NodePosition(input_node_name);
505 
506       // Control dependencies should be treated as high priority. Current
507       // Channel device doesn't model a separate virtual channel for control v/s
508       // data transfers. So in the interim, it may be okay to let control
509       // dependencies magically flow across devices bypassing the channel
510       // device.
511       if (curr_node_device == in_device || IsControlInput(input_node_name)) {
512         // Same device: connect input_node and curr_node directly.
513         curr_node_state.inputs.push_back(
514             std::make_pair(input_node, input_node_port_num));
515         auto& input_node_state = GetNodeStateOrCreateIt(input_node);
516         input_node_state.outputs[input_node_port_num].push_back(curr_node);
517       } else {
518         RecvNodeDescriptor recv_node(input_node, input_node_port_num,
519                                      curr_node_device);
520         auto it = cached_recv_nodes.find(recv_node);
521         if (it != cached_recv_nodes.end()) {
522           // Different device, but found an already-cached copy (a _Recv op);
523           // connect the _Recv to curr_node.
524           const NodeDef* recv_op = it->second;
525           // recv_op's output port is hard-coded to zero.
526           curr_node_state.inputs.push_back(std::make_pair(recv_op, 0));
527           auto& input_node_state = node_map_.at(recv_op);
528           input_node_state.outputs[0].push_back(curr_node);
529         } else {
530           // Different device, no cached copy; transfer input_node to the
531           // curr_node's device.
532           auto send_and_recv =
533               CreateSendRecv(input_node, curr_node, input_node, input_node_name,
534                              create_explicit_channel_device);
535           // Note that CreateSendRecv() already connected input/output between
536           // _Send and _Recv ops.
537           const auto* send = send_and_recv.first;
538           const auto* recv = send_and_recv.second;
539           // recv_op's output port is hard-coded to zero.
540           curr_node_state.inputs.push_back(std::make_pair(recv, 0));
541           auto& input_node_state = GetNodeStateOrCreateIt(input_node);
542           input_node_state.outputs[input_node_port_num].push_back(send);
543 
544           // Cache the _Recv op for future use.
545           cached_recv_nodes[recv_node] = recv;
546         }
547       }
548     }
549 
550     // Special case: given feed nodes are ready at time 0.
551     const bool given_as_feed =
552         feed_nodes.find(curr_node->name()) != feed_nodes.end();
553 
554   // Default case: nodes without inputs are ready at time 0.
555   // Note that we check the inputs vector, which may differ from
556   // curr_node->input(); e.g., we add _Send as an input to _Recv.
557     const bool has_no_inputs = inputs.empty();
558 
559     if (given_as_feed || has_no_inputs) {
560       curr_node_state.time_ready = Costs::Duration();
561       initial_nodes->push_back(curr_node);
562       VLOG(3) << "Added ready node: " << curr_node->name();
563     }
564     feed_nodes.erase(curr_node->name());
565 
566     if (IsPersistent(*curr_node)) {
567       auto& device_state = device_[curr_node_device];
568       for (int port_num = 0,
569                port_num_end = curr_node_state.output_properties.size();
570            port_num < port_num_end; ++port_num) {
571         device_state.persistent_nodes.insert(
572             std::make_pair(curr_node, port_num));
573       }
574     }
575   }
576 
577   if (initial_nodes->empty()) {
578     return errors::InvalidArgument("No ready nodes in the graph.");
579   }
580 
581   if (!feed_nodes.empty()) {
582     // This isn't always a bug: when the caller hasn't specified the exact list
583     // of feed and fetch nodes, by default we consider all placeholders as feed
584     // nodes, but some of them may not be needed for the default fetch node.
585     VLOG(1) << "Some feed nodes were not consumed by the fetch fanin: "
586             << absl::StrJoin(feed_nodes, ",");
587   }
588 
589   initialized_ = true;
590   return OkStatus();
591 }
592 
593 void SchedulerState::MaybeUpdateInputOutput(const NodeDef* node) {
594   CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
595   // This method is called when NodeState is created and adds input and output
596   // properties for a few exceptional cases where GraphProperties cannot
597   // provide input/output properties.
598   if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) {
599     // _Send and _Recv ops created from SchedulerState have kAttrInputSrc
600     // attr; normal _Send and _Recv ops (from the input graph) do not have that
601     // attr.
602     auto& node_state = node_map_[node];
603     auto& inputs = node_state.input_properties;
604     auto& outputs = node_state.output_properties;
605 
606     // _Send and _Recv ops are created from SchedulerState, so
607     // there should be no inputs TensorProperties.
608     CHECK(inputs.empty());
609     CHECK(outputs.empty());
610     const auto& attr = node->attr();
611     // This is the original input source to the _Send and _Recv, and this
612   // string includes "^" if it was a control dependency, and the output port
613   // (e.g., ":2") if the input source had multiple outputs.
614     const auto& input_source_name = attr.at(kAttrInputSrc).s();
615     if (IsControlInput(input_source_name)) {
616       // Control dependency; regardless of the input source tensor size,
617       // send 4B.
618       OpInfo::TensorProperties control_message;
619       control_message.set_dtype(DT_FLOAT);
620       control_message.mutable_shape()->add_dim()->set_size(1);
621       auto* value = control_message.mutable_value();
622       value->add_float_val(1);
623       inputs.push_back(control_message);
624       outputs.push_back(control_message);
625     } else {
626       const auto& output_properties =
627           graph_properties_->GetOutputProperties(NodeName(input_source_name));
628       // Like with HasInputProperties, if a node does not have output
629       // properties, it's likely it was pruned during the shape inference run.
630       if (!output_properties.empty()) {
631         const auto input_node_port_num = NodePosition(input_source_name);
632         // Use the input source's output property as _Send and _Recv's input
633         // property.
634         CHECK_GT(output_properties.size(), input_node_port_num);
635         inputs.push_back(output_properties[input_node_port_num]);
636         outputs.push_back(output_properties[input_node_port_num]);
637       }
638     }
639   }
640 }
641 
642 string SchedulerState::DeviceName(const NodeDef* node) const {
643   return placer_->get_canonical_device_name(*node);
644 }
645 
646 string SchedulerState::SanitizedDeviceName(const NodeDef* node) const {
647   // Replace the ":" characters that may be present in the device name with "_".
648   // This makes it possible to then use the resulting string in a node name.
649   return absl::StrReplaceAll(placer_->get_canonical_device_name(*node),
650                              {{":", "_"}});
651 }
652 
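// Illustrative result (assuming the placer returns canonical device names
// like "/job:w/replica:0/task:0/cpu:0" and "/job:w/replica:0/task:0/gpu:0"):
// the channel device below would be named
// "Channel_from_/job_w/replica_0/task_0/cpu_0_to_/job_w/replica_0/task_0/gpu_0".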
653 string SchedulerState::ChannelDeviceName(const NodeDef* from,
654                                          const NodeDef* to) const {
655   CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
656   return absl::StrCat(kChannelDevice, "_from_", SanitizedDeviceName(from),
657                       "_to_", SanitizedDeviceName(to));
658 }
659 
660 std::pair<const NodeDef*, const NodeDef*> SchedulerState::CreateSendRecv(
661     const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
662     const string& input_name, bool create_channel_device) {
663   CHECK(!initialized_) << "CreateSendRecv is called after Init().";
664 
665   // Connect "from" node to "to" node with _Send and _Recv such that
666   // from -> _Send -> _Recv -> to.
667   // _Send is placed on "Channel" device, and _Recv is on the same device
668   // as "to" node.
669   // input_name is the string the "to" node uses to identify which output
670   // it reads from the "from" node.
671 
672   // Note that we use NodeState for scheduling, so _Send and _Recv
673   // NodeDefs created here need not be correct: in terms of name,
674   // input names, attrs, etc.
675 
676   auto input_node_port_num = NodePosition(input_name);
677   string src_name;
678   bool control_input = false;
679   if (input_node_port_num >= 0) {
680     src_name = absl::StrCat(from->name(), "_", input_node_port_num);
681   } else {
682     src_name = absl::StrCat(from->name(), "_minus1");
683     control_input = true;
684   }
685 
686   // _Send op.
687   auto* send = new NodeDef();
688   send->set_name("Send_" + src_name + "_from_" + SanitizedDeviceName(from) +
689                  "_to_" + SanitizedDeviceName(to));
690   send->set_op("_Send");
691   send->add_input(from->name());
692   auto send_device =
693       create_channel_device ? ChannelDeviceName(from, to) : DeviceName(from);
694   send->set_device(send_device);
695   auto& send_attr = *(send->mutable_attr());
696   send_attr[kAttrInputSrc].set_s(input_name);
697   send_attr[kAttrSrcDevice].set_s(DeviceName(from));
698   send_attr[kAttrDstDevice].set_s(DeviceName(to));
699   // GraphDef generated by AutoGrappler has tensor_name field when removing
700   // _Send/_Recv nodes.
701   if (input_node->attr().count(kAttrTensorName)) {
702     send_attr[kAttrTensorName].set_s(
703         input_node->attr().at(kAttrTensorName).s());
704   }
705 
706   // _Recv op.
707   auto* recv = new NodeDef();
708   recv->set_name("Recv_" + src_name + "_on_" + SanitizedDeviceName(to));
709   recv->set_op("_Recv");
710   recv->add_input(send->name());
711   recv->set_device(DeviceName(to));
712   auto& recv_attr = *(recv->mutable_attr());
713   recv_attr[kAttrInputSrc].set_s(input_name);
714   if (input_node->attr().count(kAttrTensorName)) {
715     recv_attr[kAttrTensorName].set_s(
716         input_node->attr().at(kAttrTensorName).s());
717   }
718 
719   // Propagate the streaming attribute to the send/recv nodes.
720   if (from->attr().contains(kStreaming) && !control_input) {
721     if (input_node_port_num >= from->attr().at(kStreaming).list().b_size()) {
722       LOG(ERROR)
723           << from->name()
724           << " port index larger than length of _streaming attribute list.";
725     } else if (from->attr().at(kStreaming).list().b(input_node_port_num)) {
726       send_attr[kStreaming].mutable_list()->add_b(true);
727       recv_attr[kStreaming].mutable_list()->add_b(true);
728     }
729   }
730 
731   // NodeState for _Send op.
732   auto& send_node_state = GetNodeStateOrCreateIt(send);
733   send_node_state.device_name = send->device();  // Set Channel device.
734   send_node_state.inputs.push_back(std::make_pair(from, input_node_port_num));
735   send_node_state.outputs[0].push_back(recv);
736 
737   // NodeState for _Recv op.
738   auto& recv_node_state = GetNodeStateOrCreateIt(recv);
739   recv_node_state.inputs.push_back(std::make_pair(send, 0));
740   recv_node_state.outputs[0].push_back(to);
741 
742   // Keep the created nodes.
743   additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(send));
744   additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(recv));
745 
746   // Return _Send and _Recv.
747   return std::make_pair(send, recv);
748 }
749 
750 OpContext SchedulerState::CreateOpContext(const NodeDef* node) const {
751   // Get the device from the placer.
752   DeviceProperties device;
753   device = placer_->get_device(*node);
754 
755   // Special case for _Send op.
756   if (IsSend(*node)) {
757     device.set_type(kChannelDevice);
758   }
759 
760   // Construct OpContext.
761   OpContext op_context;
762   const auto& node_state = node_map_.at(node);
763   op_context.name = node->name();
764   op_context.device_name = node_state.device_name;
765   auto& op_info = op_context.op_info;
766   op_info.set_op(node->op());
767   *op_info.mutable_attr() = node->attr();
768   for (auto& input : node_state.input_properties) {
769     *op_info.add_inputs() = input;
770   }
771   for (auto& output : node_state.output_properties) {
772     *op_info.add_outputs() = output;
773   }
774   op_info.mutable_device()->Swap(&device);
775 
776   if (grappler_item_->graph.has_library()) {
777     op_context.function_library = &grappler_item_->graph.library();
778   }
779   return op_context;
780 }
781 
782 NodeState& SchedulerState::GetNodeStateOrCreateIt(const NodeDef* node) {
783   CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";
784 
785   auto it = node_map_.find(node);
786   if (it != node_map_.end()) {
787     return it->second;
788   }
789 
790   // Not found; create a NodeState for this node.
791   it = node_map_.emplace(node, NodeState()).first;
792   auto& node_state = it->second;
793   node_state.input_properties =
794       graph_properties_->GetInputProperties(node->name());
795   node_state.output_properties =
796       graph_properties_->GetOutputProperties(node->name());
797   node_state.shape_incompatible =
798       graph_properties_->CheckShapeIncompatible(node->name());
799 
800   // Some ops may need further processing to the input / output properties:
801   // _Send and _Recv.
802   MaybeUpdateInputOutput(node);
803 
804   if (!IsSend(*node)) {
805     node_state.device_name = DeviceName(node);
806     // For _Send op, device_name will be set to Channel in CreateSendRecv().
807   }
808 
809   // Initialize output port related data:
810   // Assume the size of OutputProperties represents the number of output ports
811   // of this node.
812   for (size_t i = 0; i < node_state.output_properties.size(); ++i) {
813     node_state.time_no_references[i] = Costs::Duration::max();
814     node_state.num_outputs_executed[i] = 0;
815     // Populate an empty vector for each port. The caller will add nodes
816     // that use this port as input.
817     node_state.outputs[i] = {};
818   }
819   // Port_num -1 is for control dependency.
820   node_state.time_no_references[-1] = Costs::Duration::max();
821   node_state.num_outputs_executed[-1] = 0;
822   node_state.outputs[-1] = {};
823 
824   // Initialize time_scheduled to infinity, so we know whether it has been
825   // assigned a non-default value later.
826   node_state.time_scheduled = Costs::Duration().infinity();
827 
828   return it->second;
829 }
830 
831 void SchedulerState::GetOutputNodes(const NodeDef* node,
832                                     const Costs::Duration& curr_time,
833                                     std::vector<const NodeDef*>* output_nodes) {
834   // Checks whether the Switch's output slots change over iterations.
835   int slot = -1;
836   if (IsSwitch(*node) && node->attr().count(kOutputSlots) > 0 &&
837       node->attr().at(kOutputSlots).list().i_size() > 0) {
838     slot = node->attr().at(kOutputSlots).list().i(0);
839     for (int i = 1; i < node->attr().at(kOutputSlots).list().i_size(); ++i) {
840       if (slot != node->attr().at(kOutputSlots).list().i(i)) {
841         slot = -1;
842         break;
843       }
844     }
845   }
846   // Increment num_inputs_ready of the output nodes and maybe add to ready
847   // nodes.
848   auto& node_state = node_map_[node];
849   for (const auto& port_num_output_pair : node_state.outputs) {
850     // If Switch is annotated and its output slots are always the same, we only
851     // schedule the slot that was executed. Otherwise, schedule both slots.
852     if (slot >= 0 && port_num_output_pair.first != slot) continue;
853 
854     for (auto* output_node : port_num_output_pair.second) {
855       auto& output_state = node_map_[output_node];
856       output_state.num_inputs_ready++;
857       // Execute a node as soon as all its inputs are ready. Merge nodes are
858       // special since they run as soon as one of their inputs becomes
859       // available.
860       int output_state_inputs_size = output_state.inputs.size();
861       if (output_state.num_inputs_ready == output_state_inputs_size ||
862           IsMerge(*output_node)) {
863         // This output node is now ready.
864         output_state.time_ready = curr_time;
865         output_nodes->push_back(output_node);
866         VLOG(3) << "  Add output: " << output_node->name();
867       }
868     }
869   }
870 }
871 
872 std::vector<const NodeDef*> SchedulerState::MarkNodeExecuted(
873     const NodeDef* node, const Costs& node_costs, const OpContext& op_context,
874     bool extract_execution_count_attr,
875     const std::string& override_device_name) {
876   auto& node_state = node_map_[node];
877   // TODO(dyoon, andiryxu): Consider to revisit node execution w.r.t. Switch and
878   // Merge -- it can create a loop which may include loop-carried dependency,
879   // diverge-merge, and other complex execution patterns.
880   bool previously_executed_merge =
881       IsMerge(*node) && (node_state.time_finished != Costs::Duration::max());
882 
883   // Our approach to modeling loops is to extract the annotated _execution_count
884   // attribute and to multiply node_costs by the value of the attribute. If
885   // the attribute is not found then we assume a default execution count of 1.
886   // Note that in some simulation flows we will perform this multiplication
887   // elsewhere, as such we only perform this multiplication here if
888   // extract_execution_count_attr is true. Otherwise node_costs are unmodified
889   // and we assume the multiplication has been correctly carried out elsewhere.
890   node_state.execution_count = 1;
891 
892   if (extract_execution_count_attr && node->attr().count(kExecutionCount) > 0) {
893     node_state.execution_count = node->attr().at(kExecutionCount).i();
894   }
895 
896   node_state.node_costs = node_costs;
897   // TotalNodeCosts() should be called after node_costs and execution_count are set.
898   Costs total_node_costs = node_state.TotalNodeCosts();
899 
900   graph_costs_ = CombineCosts(graph_costs_, total_node_costs);
901   const string& op_name = node->op();
902 
903   auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
904   op_cost = CombineCosts(op_cost, total_node_costs);
905 
906   if (VLOG_IS_ON(2)) {
907     // Also keep track of op counts and costs per op (with their shapes).
908     string node_description = GetOpDescription(op_context.op_info);
909     op_counts_[node_description] += 1;
910     op_costs_[node_description] =
911         std::make_pair(total_node_costs.execution_time.asMicroSeconds().count(),
912                        !node_costs.inaccurate);
913   }
914 
915   std::string device_name = node_state.device_name;
916   if (!override_device_name.empty()) {
917     // N.B. There's a chance that device_name doesn't exist in the device map
918     // (device_), but it's ok because we'll effectively create a new device the
919     // first time a new device is seen.
920     device_name = override_device_name;
921   }
922 
923   // Update node and device states.
924   auto& device = device_[device_name];
925   device.nodes_executed.push_back(node);
926   // Node is scheduled when the device is available AND all the inputs are
927   // ready; hence, time_scheduled is time_ready if time_ready > device curr
928   // time.
929   // NodeState times are assigned infinity at initialization. If they are
930   // still infinity here, we need to assign them. If not, it has been assigned
931   // already, so skip. This latter case may occur when a scheduler inlines
932   // function calls, and thus schedules only function sub-nodes.
933   if (node_state.time_scheduled == Costs::Duration().infinity()) {
934     node_state.time_scheduled =
935         std::max(device.GetCurrTime(), node_state.time_ready);
936     // Override device curr time with the time_scheduled.
937     device.device_costs.execution_time = node_state.time_scheduled;
938   }
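  // Illustration: if the device is busy until t = 10us and this node's inputs
  // became ready at t = 4us, the node is scheduled at t = 10us; the device
  // clock then advances by the node's total (execution-count-scaled) cost.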
939   device.device_costs = CombineCosts(device.device_costs, total_node_costs);
940   auto curr_time = device.GetCurrTime();
941   node_state.time_finished = curr_time;
942 
943   // Update shape annotation states.
944   UpdateDeviceAnnotationState(node, node_state, &device);
945 
946   // Update device memory usage.
947   if (!IsPersistent(*node)) {
948     for (const auto& port_num_output_pair : node_state.outputs) {
949       int port_num = port_num_output_pair.first;
950       // There's a chance that a specific output is not used at all.
951       if (node_state.outputs[port_num].empty()) {
952         node_state.time_no_references[port_num] = curr_time;
953       } else {
954         // Allow for the possibility that some ports may be persistent even if
955         // the entire node is not labeled persistent.
956         if (node_state.node_costs.persistent_output_ports.contains(port_num)) {
957           continue;
958         }
959 
960         // Streaming outputs do not allocate memory, they are directly consumed
961         // by the target node.
962         if (!IsStreamingPort(*node, port_num)) {
963           // If possible use the node output size calculations done by the
964           // more specific CostEstimator over the general CalculateOutputSize.
965           device.memory_usage += GetOrCalculateOutputSize(node_state, port_num);
966         }
967         device.nodes_in_memory.insert(std::make_pair(node, port_num));
968       }
969     }
970   }
971 
972   // Update device state persistent node map.
973   for (const auto& port : node_costs.persistent_output_ports) {
974     device.persistent_nodes.insert({node, port});
975   }
976 
977   // Update device's per-op cost.
978   auto& device_op_cost = FindOrCreateZero(op_name, &device.op_to_cost);
979   device_op_cost = CombineCosts(device_op_cost, total_node_costs);
980 
981   VLOG(3) << "Op scheduled -- name: " << node->name() << ", op: " << node->op()
982           << ", device: " << node->device()
983           << ", execution_count: " << node_state.execution_count
984           << ", ready: " << node_state.time_ready.count()
985           << ", scheduled: " << node_state.time_scheduled.count()
986           << ", finished: " << node_state.time_finished.count();
987   VLOG(5) << "  Current device memory usage (before deallocation): "
988           << device.memory_usage;
989   std::vector<const NodeDef*> new_nodes;
990   if (previously_executed_merge) {
991     // Skip AddOutputNodesToReadyQueue; this is due to Switch-Merge.
992     VLOG(1) << "node [ " << node->name() << ", " << node->op() << " ] "
993             << "is executed more than once. "
994             << "Skip scheduling its output nodes.";
995   } else {
996     // Checks outputs, and adds ready nodes to queue.
997     GetOutputNodes(node, curr_time, &new_nodes);
998   }
999 
1000   // When op is scheduled, both input and output tensors must be allocated in
1001   // memory. Now that output memory is added, check max memory usage.
1002   if (!IsPersistent(*node)) {
1003     if (device.memory_usage > device.max_memory_usage) {
1004       device.max_memory_usage = device.memory_usage;
1005 
1006       if (track_mem_usage_snapshot_) {
1007         device.mem_usage_snapshot_at_peak = device.nodes_in_memory;
1008       }
1009     }
1010   }
1011 
1012   // Append the current temporary memory usage of the device to the memory usage
1013   // trace.
1014   if (track_mem_usage_snapshot_) {
1015     device.temporary_memory_usage_trace.push_back(
1016         {node->name(), device.memory_usage});
1017   }
1018 
1019   // Increment num_outputs_executed of the input nodes and maybe update memory.
1020   for (const auto& input_port : node_state.inputs) {
1021     auto* input = input_port.first;
1022     auto port = input_port.second;
1023 
1024     auto& input_state = node_map_[input];
1025     input_state.num_outputs_executed[port]++;
1026     int input_state_outputs_size_ = input_state.outputs[port].size();
1027 
1028     // Allow for the possibility that some outputs may be persistent even if the
1029     // entire node is not labeled persistent.
1030     if (input_state.node_costs.persistent_output_ports.contains(port)) continue;
1031 
1032     if (input_state.num_outputs_executed[port] == input_state_outputs_size_ &&
1033         !IsPersistent(*input)) {
1034       // All the outputs are executed; no reference to this output port of
1035       // input node.
1036       input_state.time_no_references[port] = curr_time;
1037       auto& input_device = device_[input_state.device_name];
1038       // If the node input is marked as streaming, then it wasn't allocated
1039       // in memory. A streaming input is still reference counted, but it doesn't
1040       // de-allocate memory.
1041       if (!IsStreamingPort(*input, port)) {
1042         input_device.memory_usage -=
1043             GetOrCalculateOutputSize(input_state, port);
1044       }
1045 
1046       input_device.nodes_in_memory.erase(std::make_pair(input, port));
1047     }
1048   }
1049 
1050   return new_nodes;
1051 }
1052 
1053 Costs SchedulerState::Summary() const {
1054   // Overall statement about accuracy
1055   VLOG(1) << graph_costs_.num_ops_total << " ops processed in total, with "
1056           << graph_costs_.num_ops_with_unknown_shapes
1057           << " having unknown shapes";
1058 
1059   // Print out basic execution summary.
1060   VLOG(1) << "Expected execution time: " << graph_costs_.execution_time.count();
1061   VLOG(1) << "Expected compute time: " << graph_costs_.compute_time.count();
1062   VLOG(1) << "Expected memory time: " << graph_costs_.memory_time.count();
1063   VLOG(1) << "Expected intermediate memory time: "
1064           << graph_costs_.intermediate_memory_time.count();
1065   VLOG(1) << "Expected max memory: " << graph_costs_.max_memory;
1066   VLOG(1) << "Expected max per-op buffers: " << graph_costs_.max_per_op_buffers;
1067   VLOG(1) << "Expected max per-op streaming buffers: "
1068           << graph_costs_.max_per_op_streaming;
1069 
1070   VLOG(1) << "Per-op execution time / compute time / memory time"
1071           << " / intermediate memory time:";
1072   for (const auto& op_cost_pair : op_to_cost_) {
1073     const auto& op = op_cost_pair.first;
1074     const auto& cost = op_cost_pair.second.execution_time.count();
1075     const auto& compute_cost = op_cost_pair.second.compute_time.count();
1076     const auto& memory_cost = op_cost_pair.second.memory_time.count();
1077     const auto& intermediate_memory_cost =
1078         op_cost_pair.second.intermediate_memory_time.count();
1079     const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
1080     if (cost) {  // Skip printing out zero-cost ops.
1081       VLOG(1) << absl::StrFormat(" + %30s : %c %10d / %10d / %10d / %10d", op,
1082                                  (is_op_cost_accurate ? ' ' : '~'), cost,
1083                                  compute_cost, memory_cost,
1084                                  intermediate_memory_cost);
1085     }
1086   }
1087 
1088   // Print per device summary
1089   VLOG(1) << "Devices:";
1090   Costs critical_path_costs = Costs::ZeroCosts();
1091   std::vector<string> device_names;
1092   device_names.reserve(device_.size());
1093   for (auto& it : device_) {
1094     device_names.push_back(it.first);
1095   }
1096   std::sort(device_names.begin(), device_names.end());
1097 
1098   for (const auto& name : device_names) {
1099     const auto& state = device_.at(name);
1100 
1101     std::map<string, int64_t> op_to_memory;
1102     // First profile only persistent memory usage.
1103     int64_t persistent_memory_usage = 0;
1104     std::set<string> persistent_ops;
1105     for (const auto& node_port : state.persistent_nodes) {
1106       const auto* node = node_port.first;
1107       const auto port = node_port.second;
1108       int64_t output_size = 0;
1109       // Check if the node is in the node_map. It may be that the node executed
1110       // on this device was executed by a different Scheduler.
1111       auto it = node_map_.find(node);
1112       if (it != node_map_.end()) {
1113         output_size = GetOrCalculateOutputSize(it->second, port);
1114       }
1115       persistent_memory_usage += output_size;
1116       op_to_memory[node->op()] += output_size;
1117       persistent_ops.insert(node->op());
1118     }
1119     int64_t max_memory_usage = persistent_memory_usage + state.max_memory_usage;
1120     critical_path_costs.estimated_max_memory_per_device[name] =
1121         max_memory_usage;
1122 
1123     const Costs::NanoSeconds wall_time_ns = state.GetCurrTime();
1124     VLOG(1) << "Device = " << name
1125             << ", num_nodes = " << state.nodes_executed.size()
1126             << ", wall_time_ns = " << wall_time_ns.count() << ", memory usage: "
1127             << "persistent = " << HumanReadableNumBytes(persistent_memory_usage)
1128             << ", peak = " << HumanReadableNumBytes(state.max_memory_usage)
1129             << ", total = " << HumanReadableNumBytes(max_memory_usage)
1130             << ", at the end: " << HumanReadableNumBytes(state.memory_usage);
1131 
1132     // Overall statement about accuracy
1133     VLOG(1) << state.device_costs.num_ops_total
1134             << " ops processed in total, with "
1135             << state.device_costs.num_ops_with_unknown_shapes
1136             << " having unknown shapes";
1137 
1138     // Device shape annotation statistics.
1139     const auto& device_annotation_stats = state.shape_annotation_stats;
1140     if (device_annotation_stats.num_ops_annotated > 0) {
1141       VLOG(1) << device_annotation_stats.num_ops_annotated
1142               << " ops with shape annotation, with "
1143               << device_annotation_stats.num_ops_executed_more_than_once
1144               << " executed more than once, "
1145               << device_annotation_stats.num_ops_with_dynamic_shapes
1146               << " with dynamic shapes, "
1147               << device_annotation_stats.num_ops_with_incompatible_shapes
1148               << " with incompatible shapes, "
1149               << device_annotation_stats.num_ops_executed
1150               << " ops executed in total.";
1151     }
1152 
1153     VLOG(1) << "Per-op execution time / compute time / memory time "
1154             << " / intermediate memory time"
1155             << " (and memory usage at peak memory usage):";
1156 
1157     // Profile non-persistent op memory usage.
1158     for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
1159       const auto* node = node_port.first;
1160       const auto port = node_port.second;
1161       // Check if the node is in the node_map. It may be that the node executed
1162       // on this device was executed by a different Scheduler.
1163       auto it = node_map_.find(node);
1164       if (it != node_map_.end()) {
1165         op_to_memory[node->op()] += GetOrCalculateOutputSize(it->second, port);
1166       }
1167     }
1168     Costs::NanoSeconds total_compute_time_ns;
1169     bool is_total_cost_accurate = true;
1170     for (const auto& op_cost_pair : state.op_to_cost) {
1171       const auto& op = op_cost_pair.first;
1172       const auto& cost = op_cost_pair.second.execution_time.count();
1173       const auto& compute_cost = op_cost_pair.second.compute_time.count();
1174       const auto& memory_cost = op_cost_pair.second.memory_time.count();
1175       const auto& intermediate_memory_cost =
1176           op_cost_pair.second.intermediate_memory_time.count();
1177       total_compute_time_ns += op_cost_pair.second.execution_time;
1178       const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
1179       if (!is_op_cost_accurate) {
1180         is_total_cost_accurate = false;
1181       }
1182 
1183       int64_t op_mem_usage = 0;
1184       auto it = op_to_memory.find(op);
1185       if (it != op_to_memory.end()) {
1186         op_mem_usage = it->second;
1187       }
1188 
1189       const float mem_usage_percent =
1190           max_memory_usage > 0 ? Round2(100.0 * op_mem_usage / max_memory_usage)
1191                                : 0.0;
1192       if (cost || mem_usage_percent > 1.0) {
1193         // Print out only non-zero cost ops or ops with > 1% memory usage.
1194         VLOG(1) << absl::StrFormat(
1195                        " + %30s : %c %10d / %10d / %10d / %10d", op.c_str(),
1196                        (is_op_cost_accurate ? ' ' : '~'), cost, compute_cost,
1197                        memory_cost, intermediate_memory_cost)
1198                 << " (" << HumanReadableNumBytes(op_mem_usage) << " ["
1199                 << mem_usage_percent << "%] "
1200                 << (persistent_ops.count(op) > 0 ? ": persistent op)" : ")");
1201       }
1202     }
1203 
1204     int utilization = 0;
1205     if (wall_time_ns.count() > 0) {
1206       utilization = total_compute_time_ns.count() * 100 / wall_time_ns.count();
1207     }
1208     VLOG(1) << "Device = " << name << ", total_compute_time_ns = "
1209             << (is_total_cost_accurate ? "" : "~")
1210             << total_compute_time_ns.count()
1211             << ", utilization = " << utilization << "%";
1212 
1213     if (critical_path_costs.execution_time <= state.GetCurrTime()) {
1214       critical_path_costs = state.device_costs;
1215       critical_path_costs.persistent_memory = persistent_memory_usage;
1216       critical_path_costs.temporary_memory = state.max_memory_usage;
1217       critical_path_costs.max_memory = max_memory_usage;
1218     }
1219   }
1220 
1221   if (VLOG_IS_ON(2)) {
1222     // Also log the op description and their corresponding counts.
    VLOG(2) << "Node description, counts, cost:";
    for (const auto& item : op_counts_) {
      int cost;
      bool is_cost_accurate;
      std::tie(cost, is_cost_accurate) = op_costs_.at(item.first);
      VLOG(2) << "Node: " << item.first << ", Count: " << item.second
              << ", Individual Cost: " << (is_cost_accurate ? "" : "~") << cost
              << " us";
    }
  }

  VLOG(1) << "Critical path execution time: "
          << critical_path_costs.execution_time.count();
  return critical_path_costs;
}

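// Convenience overload: fills in the given RunMetadata (step stats and
// partition graphs) via GenerateRunMetadata() before computing the summary.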
Costs SchedulerState::Summary(RunMetadata* metadata) {
  if (metadata) GenerateRunMetadata(metadata);
  return Summary();
}

void SchedulerState::GenerateRunMetadata(RunMetadata* metadata) {
  // Fill RunMetadata's step_stats and partition_graphs fields.
  StepStats* stepstats = metadata->mutable_step_stats();
  for (const auto& device : device_) {
    GraphDef* device_partition_graph = metadata->add_partition_graphs();
    DeviceStepStats* device_stepstats = stepstats->add_dev_stats();
    device_stepstats->set_device(device.first);
    for (const auto& node_def : device.second.nodes_executed) {
      // Only proceed if the node is in the node_map. This is to cover the case
      // where a device has executed a node that is not in the node_map of
      // this scheduler.
      if (node_map_.find(node_def) == node_map_.end()) {
        continue;
      }
      const NodeState& nodestate = node_map_.at(node_def);
      NodeExecStats* node_stats = device_stepstats->add_node_stats();
      uint64_t total_output_size = 0;
      uint64_t persistent_output_size = 0;
      for (int slot = 0, slot_end = nodestate.output_properties.size();
           slot < slot_end; slot++) {
        const auto& properties = nodestate.output_properties[slot];
        NodeOutput* no = node_stats->add_output();
        no->set_slot(slot);
        TensorDescription* tensor_descr = no->mutable_tensor_description();
        tensor_descr->set_dtype(properties.dtype());
        *tensor_descr->mutable_shape() = properties.shape();
        // Optional allocation description.
        const int64_t tensor_size_requested =
            CalculateOutputSize(nodestate.output_properties, slot);
        const int64_t tensor_size_allocated =
            GetOrCalculateOutputSize(nodestate, slot);
        total_output_size += tensor_size_allocated;
        if (nodestate.node_costs.persistent_output_ports.contains(slot)) {
          persistent_output_size += tensor_size_allocated;
        }
        tensor_descr->mutable_allocation_description()->set_requested_bytes(
            tensor_size_requested);
        tensor_descr->mutable_allocation_description()->set_allocated_bytes(
            tensor_size_allocated);
      }
      if (node_def->op() != "HloGenericOp") {
        node_stats->set_timeline_label(node_def->op());
      } else {
        // For HloGenericOp, display hlo_opcode as timeline label.
        string timeline_label;
        if (node_def->attr().count("hlo_opcode") > 0) {
          absl::StrAppend(&timeline_label,
                          node_def->attr().at("hlo_opcode").s());
        }
        if (node_def->attr().count("_hlo_metadata_op_type") > 0) {
          absl::StrAppend(&timeline_label, "/",
                          node_def->attr().at("_hlo_metadata_op_type").s());
        }
        node_stats->set_timeline_label(timeline_label);
      }
      node_stats->set_node_name(node_def->name());
      // Timestamps in microseconds (can be used by timeline_server).
      node_stats->set_op_start_rel_micros(0);
      node_stats->set_all_start_micros(
          nodestate.time_scheduled.asMicroSeconds().count());
      node_stats->set_op_end_rel_micros(
          nodestate.time_finished.asMicroSeconds().count() -
          nodestate.time_scheduled.asMicroSeconds().count());
      node_stats->set_all_end_rel_micros(
          nodestate.time_finished.asMicroSeconds().count() -
          nodestate.time_scheduled.asMicroSeconds().count());
      // Timestamps in nanoseconds (can be used by xprof trace).
      node_stats->set_op_start_rel_nanos(0);
      node_stats->set_all_start_nanos(nodestate.time_scheduled.count());
      node_stats->set_op_end_rel_nanos(nodestate.time_finished.count() -
                                       nodestate.time_scheduled.count());
      node_stats->set_all_end_rel_nanos(nodestate.time_finished.count() -
                                        nodestate.time_scheduled.count());

      auto* mem_stats = node_stats->mutable_memory_stats();
      // SchedulerState does not specify scratch pad memory usage.
      mem_stats->set_temp_memory_size(0);
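      // Persistent ops count all of their outputs as persistent memory;
      // otherwise only the output ports flagged persistent in the node's
      // costs are counted.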
      int64_t persistent_memory_size = 0;
      if (IsPersistent(*node_def)) {
        persistent_memory_size = total_output_size;
      } else {
        persistent_memory_size = persistent_output_size;
      }
      mem_stats->set_persistent_memory_size(persistent_memory_size);
      *device_partition_graph->add_node() = *node_def;
    }
  }
}

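// Returns the peak memory usage recorded for each device, keyed by device
// name.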
const std::unordered_map<string, int64_t> SchedulerState::GetPeakMemoryUsage()
    const {
  std::unordered_map<string, int64_t> result;
  for (const auto& device : device_) {
    const string& name = device.first;
    const DeviceState& state = device.second;
    result[name] = state.max_memory_usage;
  }
  return result;
}

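// Returns, for each device, the total output size of the persistent nodes
// placed on it, using measured sizes where available.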
const std::unordered_map<string, int64_t>
SchedulerState::GetPersistentMemoryUsage() const {
  std::unordered_map<string, int64_t> result;
  for (const auto& device : device_) {
    const string& name = device.first;
    const DeviceState& state = device.second;
    int64_t persistent_memory_usage = 0;
    for (const auto& node_port : state.persistent_nodes) {
      const auto* node = node_port.first;
      const auto port = node_port.second;
      const auto& node_state = node_map_.at(node);
      persistent_memory_usage += GetOrCalculateOutputSize(node_state, port);
    }
    result[name] = persistent_memory_usage;
  }
  return result;
}

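// Stamps the node's scheduled time with the current simulated time of the
// device it is placed on.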
void SchedulerState::SetNodeStateTimeScheduled(const NodeDef* node) {
  auto& node_state = node_map_.at(node);
  auto& device = device_[node_state.device_name];
  node_state.time_scheduled = device.GetCurrTime();
}

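// Returns the output size in bytes for the given port, preferring the size
// recorded in the node's costs and falling back to a calculation from the
// node's output properties.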
int64_t SchedulerState::GetOrCalculateOutputSize(const NodeState& node_state,
                                                 int port_num) const {
  auto& node_costs = node_state.node_costs;
  auto it = node_costs.output_tensor_size_bytes.find(port_num);
  if (it != node_costs.output_tensor_size_bytes.end()) {
    return it->second;
  }
  return CalculateOutputSize(node_state.output_properties, port_num);
}

VirtualScheduler::~VirtualScheduler() {}

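// Note: the scheduler owns the SchedulerState it builds here, but not the
// ReadyNodeManager passed in.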
VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
                                   const bool use_aggressive_shape_inference,
                                   Cluster* cluster,
                                   ReadyNodeManager* ready_nodes,
                                   std::unique_ptr<VirtualPlacer> placer)
    : scheduler_state_(std::make_unique<SchedulerState>(
          use_static_shapes, use_aggressive_shape_inference, cluster,
          std::move(placer))),
      ready_nodes_(ready_nodes) {}

VirtualScheduler::VirtualScheduler(
    ReadyNodeManager* ready_nodes,
    std::unique_ptr<SchedulerState> scheduler_state)
    : scheduler_state_(std::move(scheduler_state)), ready_nodes_(ready_nodes) {}

Status VirtualScheduler::Init(const GrapplerItem* item) {
  // SchedulerState::Init() preprocesses the input grappler_item and
  // graph_properties to extract the information needed to emulate TensorFlow
  // op scheduling, and constructs the internal data structures (NodeState and
  // DeviceState) used for virtual scheduling.
  TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
  std::vector<const NodeDef*> initial_nodes;
  auto status = scheduler_state_->Init(item, &initial_nodes);
  if (status.ok()) {
    // Add the set of initial nodes to ready_nodes_.
    for (auto node : initial_nodes) {
      ready_nodes_->AddNode(node);
    }
  }
  return status;
}

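// Returns an OpContext for the node currently at the head of ready_nodes_.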
OpContext VirtualScheduler::GetCurrNode() {
  const NodeDef* node = ready_nodes_->GetCurrNode();
  return scheduler_state_->CreateOpContext(node);
}

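// Records the costs of the current node, moves any nodes it made ready into
// ready_nodes_, removes the current node from the ready queue, and returns
// true while there are still nodes left to schedule.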
bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
  // Update graph_costs_ and per-op costs.
  const NodeDef* node = ready_nodes_->GetCurrNode();
  auto new_nodes = scheduler_state_->MarkNodeExecuted(
      node, node_costs,
      scheduler_state_->CreateOpContext(ready_nodes_->GetCurrNode()));
  // Add the set of new nodes obtained from MarkNodeExecuted() to ready_nodes_.
  for (auto node : new_nodes) {
    ready_nodes_->AddNode(node);
  }
  ready_nodes_->RemoveCurrNode();
  return !ready_nodes_->Empty();
}
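
// A minimal sketch of a driver loop over this API, for illustration only.
// The per-op cost source (`cost_estimator.PredictCosts`) is an assumption,
// standing in for any estimator that maps an OpContext to Costs; cluster,
// ready_node_manager, placer, grappler_item, and run_metadata are
// placeholders for caller-provided objects.
//
//   VirtualScheduler scheduler(/*use_static_shapes=*/true,
//                              /*use_aggressive_shape_inference=*/true,
//                              cluster, &ready_node_manager,
//                              std::move(placer));
//   TF_RETURN_IF_ERROR(scheduler.Init(&grappler_item));
//   bool more_nodes = true;
//   do {
//     OpContext op_context = scheduler.GetCurrNode();
//     Costs node_costs = cost_estimator.PredictCosts(op_context);  // assumed
//     more_nodes = scheduler.MarkCurrNodeExecuted(node_costs);
//   } while (more_nodes);
//   Costs total_costs = scheduler.Summary(&run_metadata);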

}  // end namespace grappler
}  // end namespace tensorflow