/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/virtual_scheduler.h"

#include <algorithm>
#include <functional>
#include <string>
#include <utility>

#include "absl/strings/str_format.h"
#include "absl/strings/str_replace.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {

const char kAttrInputSrc[] = "input_source_";
const char kAttrSrcDevice[] = "send_device";
const char kAttrDstDevice[] = "recv_device";
const char kAttrTensorName[] = "tensor_name";
const char kChannelDevice[] = "Channel";
const char kStreaming[] = "_streaming";

namespace {

using ::tensorflow::strings::HumanReadableNumBytes;

float Round2(const float x) {
  // Not using std::round from <cmath> here because not all platforms seem to
  // support that (specifically Android).
  return ::round(100.0 * x) / 100.0;
}

Costs& FindOrCreateZero(const string& op_name,
                        std::map<string, Costs>* op_cost) {
  auto it = op_cost->find(op_name);
  if (it == op_cost->end()) {
    // Note that default constructor of Costs sets some memory related fields
    // to unknown values so we should explicitly initialize it with ZeroCosts.
    it = op_cost->emplace(op_name, Costs::ZeroCosts()).first;
  }
  return it->second;
}

// Key to the cached _Recv ops map, and its hash and predicate structures.
struct RecvNodeDescriptor {
  const NodeDef* node;
  const int port_num;
  const string device;

  RecvNodeDescriptor(const NodeDef* node_, const int port_num_,
                     const string& device_)
      : node(node_), port_num(port_num_), device(device_) {}
};

struct RecvNodeDescriptorHash {
  std::size_t operator()(const RecvNodeDescriptor& recv_node) const {
    return std::hash<const NodeDef*>()(recv_node.node) ^
           std::hash<int>()(recv_node.port_num) ^
           std::hash<string>()(recv_node.device);
  }
};

struct RecvNodeDescriptorEqual {
  bool operator()(const RecvNodeDescriptor& a,
                  const RecvNodeDescriptor& b) const {
    return a.node == b.node && a.port_num == b.port_num && a.device == b.device;
  }
};
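
// Note on the structs above: RecvNodeDescriptor identifies a cached _Recv by
// (source node, output port, destination device). SchedulerState::Init() uses
// it as the key of its cached_recv_nodes map so that multiple consumers of the
// same remote tensor on one device share a single _Send/_Recv pair. The hash
// simply XORs the member hashes, which is adequate here since collisions only
// cost extra equality comparisons.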

void UpdateDeviceAnnotationState(const NodeDef* node,
                                 const NodeState& node_state,
                                 DeviceState* device) {
  if (node->attr().count(kOutputShapes) == 0) return;

  int64_t execution_count = node->attr().count(kExecutionCount) == 0
                                ? 1
                                : node->attr().at(kExecutionCount).i();

  auto& shape_annotation_stats = device->shape_annotation_stats;
  shape_annotation_stats.num_ops_annotated += 1;
  shape_annotation_stats.num_ops_executed += execution_count;
  shape_annotation_stats.num_ops_executed_more_than_once +=
      execution_count > 1 ? 1 : 0;
  shape_annotation_stats.num_ops_with_incompatible_shapes +=
      node_state.shape_incompatible ? 1 : 0;
  shape_annotation_stats.num_ops_with_dynamic_shapes +=
      (execution_count > 1 && node->attr().count(kOutputSame) == 0) ? 1 : 0;
}

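// Note on the "_streaming" annotation read by IsStreamingPort() below: the
// attr is expected to hold a list of bools, one entry per output port;
// negative ports (control outputs) and ports beyond the list length are
// treated as non-streaming. For example, a node annotated with
// _streaming = [false, true] would stream only its port 1 output.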
bool IsStreamingPort(const NodeDef& node, const int port) {
  if (!node.attr().contains(kStreaming)) return false;

  auto& attr_list = node.attr().at(kStreaming).list();
  bool is_streaming_port = false;
  if (port >= 0 && port < attr_list.b().size()) {
    is_streaming_port = attr_list.b(port);
  }
  return is_streaming_port;
}

}  // namespace

void LIFOManager::AddNode(const NodeDef* node) {
  // Merge nodes are scheduled with the lowest priority in the LIFO manager;
  // the virtual scheduler may run multiple input nodes of a Merge (when we
  // don't have annotation, which is quite common); simply scheduling Merge
  // after one of its inputs may break scheduling constraints, since some
  // inputs of the Merge may be scheduled after the Merge. So, we place Merge
  // at the beginning of the queue to guarantee that all the inputs of the
  // Merge are scheduled before the Merge itself.
  if (IsMerge(*node)) {
    nodes_.push_front(node);
  } else {
    nodes_.push_back(node);
  }
}

const NodeDef* LIFOManager::GetCurrNode() {
  CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
  if (curr_pos_ == nodes_.end()) {
    curr_pos_ = --(nodes_.rbegin().base());  // Last one in the list.
  }
  // Once curr_pos_ is set to a valid entry in the list, we keep using the
  // cached curr_pos_ until RemoveCurrNode() is called. AddNode() will not
  // change the GetCurrNode() return value.
  return *curr_pos_;
}

void LIFOManager::RemoveCurrNode() {
  // Make sure we have curr_pos_ ready to be removed.
  GetCurrNode();
  // Note curr_pos_ may not be pointing to the last element if some nodes were
  // added after it was cached.
  nodes_.erase(curr_pos_);

  curr_pos_ = nodes_.end();  // Reset curr_pos_.
}

HeapReadyManager::HeapReadyManager() : ReadyNodeManager() {
  std::make_heap(nodes_.begin(), nodes_.end());
}

Status HeapReadyManager::Init(
    const std::unordered_map<const NodeDef*, NodeState>* node_map) {
  // Resets the node state since different instances of the scheduler can reuse
  // the same node_manager.
  node_map_ = node_map;
  nodes_.clear();
  curr_node_ = nullptr;

  // Sets up the comparator for the heap.
  greater_ = Greater();

  return OkStatus();
}

void HeapReadyManager::AddNode(const NodeDef* node) {
  // push_heap in AddNode() and pop_heap in RemoveCurrNode() guarantee that the
  // first element is the node with minimum time_ready.
  nodes_.push_back(node);
  std::push_heap(nodes_.begin(), nodes_.end(), greater_);
}

const NodeDef* HeapReadyManager::GetCurrNode() {
  if (curr_node_) return curr_node_;
  if (nodes_.empty()) {
    CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
  }
  const std::string node_name = nodes_.front()->name();
  // Next time we call GetCurrNode(), it just returns the cached copy
  // curr_node_, until we call RemoveCurrNode().
  curr_node_ = nodes_.front();
  // Remove the current node from the heap immediately, because the heap could
  // get re-organized if AddNode() is called before the removal. The current
  // node is cached anyway, in case GetCurrNode() is called again.
  std::pop_heap(nodes_.begin(), nodes_.end(), greater_);
  nodes_.pop_back();
  return curr_node_;
}

void HeapReadyManager::RemoveCurrNode() {
  if (curr_node_) {
    // If a cached copy exists, remove that.
    // Reset curr_node_ so that GetCurrNode() finds another node.
    curr_node_ = nullptr;
  } else {
    // If no cached copy is present, remove the entry from the heap queue.
    std::pop_heap(nodes_.begin(), nodes_.end(), greater_);
    nodes_.pop_back();
  }
}

bool HeapReadyManager::Empty() const {
  return nodes_.empty() && curr_node_ == nullptr;
}

bool FirstReadyCmp(
    const std::unordered_map<const NodeDef*, NodeState>* node_map,
    const NodeDef* a, const NodeDef* b) {
  if (node_map->at(a).time_ready == node_map->at(b).time_ready) {
    // Use the node name as a tie-breaker for deterministic node scheduling.
    return a->name().compare(b->name()) > 0;
  } else {
    // Note: we need the node with minimum time_ready, not maximum; hence,
    // a > b is used as the comparison function.
    return node_map->at(a).time_ready > node_map->at(b).time_ready;
  }
}
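
// Note: the comparators returned by Greater() below are used with
// std::push_heap / std::pop_heap in HeapReadyManager, which maintain a
// max-heap with respect to the comparator; supplying a "greater-than"
// comparison therefore keeps the node with the minimum time_ready at the
// front of nodes_.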

std::function<bool(const NodeDef*, const NodeDef*)>
FirstReadyManager::Greater() {
  auto greater = [this](const NodeDef* a, const NodeDef* b) -> bool {
    return FirstReadyCmp(node_map_, a, b);
  };
  return greater;
}

std::function<bool(const NodeDef*, const NodeDef*)>
PriorityReadyManager::Greater() {
  auto greater = [this](const NodeDef* a, const NodeDef* b) -> bool {
    auto pri_a = node_priority_.at(a->name());
    auto pri_b = node_priority_.at(b->name());
    if (pri_a == pri_b) {
      // Fall back to the default (FirstReady) behaviour.
      return FirstReadyCmp(node_map_, a, b);
    }
    return pri_a > pri_b;
  };
  return greater;
}

void PriorityReadyManager::AddNode(const NodeDef* node) {
  if (node_priority_.count(node->name()) == 0) {
    VLOG(3) << "Priority of node " << node->name() << " not found.";
    node_priority_[node->name()] = 0;
  }
  HeapReadyManager::AddNode(node);
}

Status PriorityReadyManager::SetPriority(
    const std::unordered_map<string, int>& node_priority) {
  node_priority_ = node_priority;
  return OkStatus();
}

CompositeNodeManager::CompositeNodeManager()
    : ReadyNodeManager(), send_manager_(), recv_manager_() {}

Status CompositeNodeManager::Init(
    const std::unordered_map<const NodeDef*, NodeState>* node_map) {
  node_map_ = node_map;
  TF_RETURN_IF_ERROR(send_manager_.Init(node_map));
  TF_RETURN_IF_ERROR(recv_manager_.Init(node_map));
  curr_node_ = nullptr;
  return OkStatus();
}

void CompositeNodeManager::AddNode(const NodeDef* node) {
  if (IsSend(*node)) {
    send_manager_.AddNode(node);
  } else if (IsRecv(*node)) {
    recv_manager_.AddNode(node);
  } else {
    const auto& device = node_map_->at(node).device_name;
    ops_lifo_map_[device].AddNode(node);
  }
}

const NodeDef* CompositeNodeManager::GetCurrNode() {
  if (curr_node_) return curr_node_;

  // Selection policy:
  // - per-device LIFO for normal ops (not _Send / _Recv),
  // - FirstReady for _Send and _Recv (tracked separately),
  // - globally FirstReady among the LIFO-selected op from each device and the
  //   selected _Send and _Recv,
  // - if time_ready is equal, priority order is _Send, _Recv, then the rest.
  std::vector<std::pair<const NodeDef*, Costs::Duration>> candidates;
  for (auto& ops_lifo : ops_lifo_map_) {
    if (!ops_lifo.second.Empty()) {
      const auto* op = ops_lifo.second.GetCurrNode();
      candidates.emplace_back(op, node_map_->at(op).time_ready);
    }
  }
  if (!send_manager_.Empty()) {
    const auto* send = send_manager_.GetCurrNode();
    candidates.emplace_back(send, node_map_->at(send).time_ready);
  }
  if (!recv_manager_.Empty()) {
    const auto* recv = recv_manager_.GetCurrNode();
    candidates.emplace_back(recv, node_map_->at(recv).time_ready);
  }
  CHECK(!candidates.empty());
  auto first_ready = std::min_element(
      candidates.begin(), candidates.end(),
      [](const std::pair<const NodeDef*, Costs::Duration>& a,
         const std::pair<const NodeDef*, Costs::Duration>& b) {
        if (a.second == b.second) {
          // Note that there can be at most one _Send and one _Recv in
          // candidates; hence, the score is 2 for _Send, 1 for _Recv, and 0
          // for a normal op, and a_score and b_score are equal only if both
          // are normal ops.
          int a_score = 2 * IsSend(*a.first) + IsRecv(*a.first);
          int b_score = 2 * IsSend(*b.first) + IsRecv(*b.first);
          if (a_score == b_score) {
            // Both are normal ops; use the node name as a tie-breaker.
            return a.first->name().compare(b.first->name()) < 0;
          } else {
            // Prioritize by op type: _Send, _Recv, and then normal ops.
            return a_score > b_score;
          }
        } else {
          return a.second < b.second;
        }
      });
  // Next time we call GetCurrNode(), it just returns the cached one,
  // curr_node_, until we call RemoveCurrNode().
  curr_node_ = first_ready->first;

  return curr_node_;
}

void CompositeNodeManager::RemoveCurrNode() {
  const auto* node = GetCurrNode();
  if (IsSend(*node)) {
    send_manager_.RemoveCurrNode();
  } else if (IsRecv(*node)) {
    recv_manager_.RemoveCurrNode();
  } else {
    const auto device = node_map_->at(node).device_name;
    ops_lifo_map_[device].RemoveCurrNode();
  }
  // Reset curr_node_ so that GetCurrNode() finds another node.
  curr_node_ = nullptr;
}

bool CompositeNodeManager::Empty() const {
  // Empty if all the ready managers are empty.
  bool empty = true;
  for (const auto& ops_lifo : ops_lifo_map_) {
    empty &= ops_lifo.second.Empty();
  }
  return empty && send_manager_.Empty() && recv_manager_.Empty();
}

std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory(
    const string& ready_node_manager) {
  if (ready_node_manager == "FIFO") {
    return std::make_unique<FIFOManager>();
  } else if (ready_node_manager == "LIFO") {
    return std::make_unique<LIFOManager>();
  } else if (ready_node_manager == "FirstReady") {
    return std::make_unique<FirstReadyManager>();
  } else if (ready_node_manager == "Composite") {
    return std::make_unique<CompositeNodeManager>();
  }
  LOG(FATAL) << "Not a valid ready node manager: " << ready_node_manager;
  return nullptr;
}
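
// Illustrative usage sketch (not part of this file's API surface): a caller
// typically constructs a manager by name and passes it to a VirtualScheduler,
// along with the caller's own cluster and placer objects, e.g.
//
//   std::unique_ptr<ReadyNodeManager> manager =
//       ReadyNodeManagerFactory("FirstReady");
//   VirtualScheduler scheduler(/*use_static_shapes=*/true,
//                              /*use_aggressive_shape_inference=*/true,
//                              cluster, manager.get(), std::move(placer));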

SchedulerState::~SchedulerState() {}

SchedulerState::SchedulerState(const bool use_static_shapes,
                               const bool use_aggressive_shape_inference,
                               Cluster* cluster,
                               std::unique_ptr<VirtualPlacer> placer)
    : graph_costs_(Costs::ZeroCosts()),
      cluster_(cluster),
      use_static_shapes_(use_static_shapes),
      use_aggressive_shape_inference_(use_aggressive_shape_inference),
      placer_(std::move(placer)) {
  DCHECK(placer_);  // check if the pointer is valid.
  graph_costs_.num_ops_total = 0;
  initialized_ = false;
  track_mem_usage_snapshot_ = VLOG_IS_ON(1);
}

Status SchedulerState::Init(const GrapplerItem* item,
                            std::vector<const NodeDef*>* initial_nodes,
                            bool create_explicit_channel_device) {
  initialized_ = false;

  // Clear all internal states so that the SchedulerState is reusable for
  // different GrapplerItems.
  node_map_.clear();
  device_.clear();
  additional_nodes_.clear();

  graph_costs_ = Costs::ZeroCosts();
  graph_costs_.num_ops_total = 0;
  op_to_cost_.clear();

  op_counts_.clear();
  op_costs_.clear();

  initial_nodes->clear();

  // Constructs graph properties and performs shape inference.
  graph_properties_ = std::make_unique<GraphProperties>(*item);
  // TODO(safeen,dyoon): Will we ever use InferDynamically? If not, we may want
  // to get rid of use_static_shapes_ and cluster_.
  if (use_static_shapes_) {
    TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
        true, use_aggressive_shape_inference_, true));
  } else {
    TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_));
  }

  grappler_item_ = item;
  const auto& graph = grappler_item_->graph;
  const auto& fetch_nodes = grappler_item_->fetch;
  std::set<string> feed_nodes;

  for (const auto& f : grappler_item_->feed) {
    auto iter_and_inserted_flag = feed_nodes.insert(f.first);
    QCHECK(iter_and_inserted_flag.second)
        << "Duplicate feed node found: " << f.first;
  }

  // Get the nodes that would run to output fetch_nodes.
  std::unordered_map<string, const NodeDef*> name_to_node;
  std::vector<const NodeDef*> fetch_fanin_nodes;
  TF_RETURN_IF_ERROR(ComputeTransitiveFanin(graph, fetch_nodes, &name_to_node,
                                            &fetch_fanin_nodes));

  // Once ComputeTransitiveFanin is complete, only the nodes that can be
  // reached from the fetch nodes are scheduled. So the scheduled nodes should
  // be exactly the same as those executed for real. One possible discrepancy
  // could be control flow nodes, where TensorFlow executes only one path.

  // Traverses the graph to record _Send nodes.
  // TODO(dyoon): Instead of identifying _Send nodes here manually, add _Send
  // to _Recv as a control dependency when creating the GrapplerItem.
  std::unordered_map<string, const NodeDef*> name_to_send;
  for (const auto& node : graph.node()) {
    if (IsSend(node)) {
      const auto& attr = node.attr();
      name_to_send[attr.at("tensor_name").s()] = &node;
    }
  }

  // To reuse _Recv ops.
  std::unordered_map<RecvNodeDescriptor, const NodeDef*,
                     RecvNodeDescriptorHash, RecvNodeDescriptorEqual>
      cached_recv_nodes;

  // Build node_map; for each node, create its NodeState and connect its inputs
  // and outputs.
  for (const auto* curr_node : fetch_fanin_nodes) {
    auto& curr_node_state = GetNodeStateOrCreateIt(curr_node);
    const string curr_node_device = DeviceName(curr_node);
    std::vector<string> inputs;
    if (IsRecv(*curr_node)) {
      const auto& attr = curr_node->attr();
      if (attr.count("tensor_name")) {
        const auto& send_node_name = attr.at("tensor_name").s();
        auto it = name_to_send.find(send_node_name);
        // If there is a _Send associated with the curr_node (_Recv), add it as
        // an input.
        if (it != name_to_send.end()) {
          const NodeDef* send = it->second;
          inputs = {send->name()};
        }
      }
    } else {
      for (const string& input : curr_node->input()) {
        inputs.push_back(input);
      }
    }
    for (const string& input_node_name : inputs) {
      // Note that input_node_name may be in <prefix><node_name>:<port_num>
      // format, where <prefix> (e.g., "^" for control dependency) and
      // ":<port_num>" may be omitted. NodeName() extracts only the node_name.
      const NodeDef* input_node = name_to_node[NodeName(input_node_name)];

      CHECK(input_node);
      const string in_device = DeviceName(input_node);
      const auto input_node_port_num = NodePosition(input_node_name);

      // Control dependencies should be treated as high priority. The current
      // Channel device doesn't model a separate virtual channel for control
      // vs. data transfers. So, in the interim, it may be okay to let control
      // dependencies magically flow across devices, bypassing the channel
      // device.
      if (curr_node_device == in_device || IsControlInput(input_node_name)) {
        // Same device: connect input_node and curr_node directly.
        curr_node_state.inputs.push_back(
            std::make_pair(input_node, input_node_port_num));
        auto& input_node_state = GetNodeStateOrCreateIt(input_node);
        input_node_state.outputs[input_node_port_num].push_back(curr_node);
      } else {
        RecvNodeDescriptor recv_node(input_node, input_node_port_num,
                                     curr_node_device);
        auto it = cached_recv_nodes.find(recv_node);
        if (it != cached_recv_nodes.end()) {
          // Different device, but found an already-cached copy (a _Recv op);
          // connect the _Recv to curr_node.
          const NodeDef* recv_op = it->second;
          // recv_op's output port is hard-coded to zero.
          curr_node_state.inputs.push_back(std::make_pair(recv_op, 0));
          auto& input_node_state = node_map_.at(recv_op);
          input_node_state.outputs[0].push_back(curr_node);
        } else {
          // Different device, no cached copy; transfer input_node to the
          // curr_node's device.
          auto send_and_recv =
              CreateSendRecv(input_node, curr_node, input_node,
                             input_node_name, create_explicit_channel_device);
          // Note that CreateSendRecv() already connected input/output between
          // the _Send and _Recv ops.
          const auto* send = send_and_recv.first;
          const auto* recv = send_and_recv.second;
          // recv_op's output port is hard-coded to zero.
          curr_node_state.inputs.push_back(std::make_pair(recv, 0));
          auto& input_node_state = GetNodeStateOrCreateIt(input_node);
          input_node_state.outputs[input_node_port_num].push_back(send);

          // Cache the _Recv op for future use.
          cached_recv_nodes[recv_node] = recv;
        }
      }
    }

    // Special case: given feed nodes are ready at time 0.
    const bool given_as_feed =
        feed_nodes.find(curr_node->name()) != feed_nodes.end();

    // Default case: nodes without inputs are ready at time 0.
    // Note that we check the inputs vector, which may differ from
    // curr_node->input(); e.g., we add the _Send as an input to the _Recv.
    const bool has_no_inputs = inputs.empty();

    if (given_as_feed || has_no_inputs) {
      curr_node_state.time_ready = Costs::Duration();
      initial_nodes->push_back(curr_node);
      VLOG(3) << "Added ready node: " << curr_node->name();
    }
    feed_nodes.erase(curr_node->name());

    if (IsPersistent(*curr_node)) {
      auto& device_state = device_[curr_node_device];
      for (int port_num = 0,
               port_num_end = curr_node_state.output_properties.size();
           port_num < port_num_end; ++port_num) {
        device_state.persistent_nodes.insert(
            std::make_pair(curr_node, port_num));
      }
    }
  }

  if (initial_nodes->empty()) {
    return errors::InvalidArgument("No ready nodes in the graph.");
  }

  if (!feed_nodes.empty()) {
    // This isn't always a bug: when the caller hasn't specified the exact list
    // of feed and fetch nodes, by default we consider all placeholders as feed
    // nodes, but some of them may not be needed for the default fetch node.
    VLOG(1) << "Some feed nodes were not consumed by the fetch fanin: "
            << absl::StrJoin(feed_nodes, ",");
  }

  initialized_ = true;
  return OkStatus();
}

void SchedulerState::MaybeUpdateInputOutput(const NodeDef* node) {
  CHECK(!initialized_) << "MaybeUpdateInputOutput is called after Init().";
  // This method is called when a NodeState is created; it adds input and
  // output properties for a few exceptional cases where GraphProperties
  // cannot provide them.
  if ((IsSend(*node) || IsRecv(*node)) && node->attr().count(kAttrInputSrc)) {
    // _Send and _Recv ops created from SchedulerState have the kAttrInputSrc
    // attr; normal _Send and _Recv ops (from the input graph) do not have that
    // attr.
    auto& node_state = node_map_[node];
    auto& inputs = node_state.input_properties;
    auto& outputs = node_state.output_properties;

    // _Send and _Recv ops are created from SchedulerState, so
    // there should be no input TensorProperties.
    CHECK(inputs.empty());
    CHECK(outputs.empty());
    const auto& attr = node->attr();
    // This is the original input source to the _Send and _Recv; this string
    // includes "^" if it was a control dependency, and the output port
    // (e.g., ":2") if the input source had multiple outputs.
    const auto& input_source_name = attr.at(kAttrInputSrc).s();
    if (IsControlInput(input_source_name)) {
      // Control dependency; regardless of the input source tensor size,
      // send 4B.
      OpInfo::TensorProperties control_message;
      control_message.set_dtype(DT_FLOAT);
      control_message.mutable_shape()->add_dim()->set_size(1);
      auto* value = control_message.mutable_value();
      value->add_float_val(1);
      inputs.push_back(control_message);
      outputs.push_back(control_message);
    } else {
      const auto& output_properties =
          graph_properties_->GetOutputProperties(NodeName(input_source_name));
      // Like with HasInputProperties, if a node does not have output
      // properties, it's likely it was pruned during the shape inference run.
      if (!output_properties.empty()) {
        const auto input_node_port_num = NodePosition(input_source_name);
        // Use the input source's output property as the _Send's and _Recv's
        // input property.
        CHECK_GT(output_properties.size(), input_node_port_num);
        inputs.push_back(output_properties[input_node_port_num]);
        outputs.push_back(output_properties[input_node_port_num]);
      }
    }
  }
}

string SchedulerState::DeviceName(const NodeDef* node) const {
  return placer_->get_canonical_device_name(*node);
}

string SchedulerState::SanitizedDeviceName(const NodeDef* node) const {
  // Replace the ":" characters that may be present in the device name with
  // "_". This makes it possible to then use the resulting string in a node
  // name.
  return absl::StrReplaceAll(placer_->get_canonical_device_name(*node),
                             {{":", "_"}});
}

string SchedulerState::ChannelDeviceName(const NodeDef* from,
                                         const NodeDef* to) const {
  CHECK(!initialized_) << "ChannelDeviceName is called after Init().";
  return absl::StrCat(kChannelDevice, "_from_", SanitizedDeviceName(from),
                      "_to_", SanitizedDeviceName(to));
}
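
// For illustration: with canonical device names such as
// "/job:localhost/replica:0/task:0/cpu:0" and ".../gpu:0",
// ChannelDeviceName() above yields something like
// "Channel_from_/job_localhost/replica_0/task_0/cpu_0_to_..._gpu_0"
// (the exact form depends on the placer's canonical names).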

std::pair<const NodeDef*, const NodeDef*> SchedulerState::CreateSendRecv(
    const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
    const string& input_name, bool create_channel_device) {
  CHECK(!initialized_) << "CreateSendRecv is called after Init().";

  // Connect the "from" node to the "to" node with _Send and _Recv such that
  // from -> _Send -> _Recv -> to.
  // _Send is placed on the "Channel" device, and _Recv is on the same device
  // as the "to" node.
  // input_name is the string from the "to" node identifying which output
  // we get from the "from" node.

  // Note that we use NodeState for scheduling, so the _Send and _Recv
  // NodeDefs created here need not be correct in terms of name,
  // input names, attrs, etc.

  auto input_node_port_num = NodePosition(input_name);
  string src_name;
  bool control_input = false;
  if (input_node_port_num >= 0) {
    src_name = absl::StrCat(from->name(), "_", input_node_port_num);
  } else {
    src_name = absl::StrCat(from->name(), "_minus1");
    control_input = true;
  }

  // _Send op.
  auto* send = new NodeDef();
  send->set_name("Send_" + src_name + "_from_" + SanitizedDeviceName(from) +
                 "_to_" + SanitizedDeviceName(to));
  send->set_op("_Send");
  send->add_input(from->name());
  auto send_device =
      create_channel_device ? ChannelDeviceName(from, to) : DeviceName(from);
  send->set_device(send_device);
  auto& send_attr = *(send->mutable_attr());
  send_attr[kAttrInputSrc].set_s(input_name);
  send_attr[kAttrSrcDevice].set_s(DeviceName(from));
  send_attr[kAttrDstDevice].set_s(DeviceName(to));
  // A GraphDef generated by AutoGrappler has the tensor_name field when
  // removing _Send/_Recv nodes.
  if (input_node->attr().count(kAttrTensorName)) {
    send_attr[kAttrTensorName].set_s(
        input_node->attr().at(kAttrTensorName).s());
  }

  // _Recv op.
  auto* recv = new NodeDef();
  recv->set_name("Recv_" + src_name + "_on_" + SanitizedDeviceName(to));
  recv->set_op("_Recv");
  recv->add_input(send->name());
  recv->set_device(DeviceName(to));
  auto& recv_attr = *(recv->mutable_attr());
  recv_attr[kAttrInputSrc].set_s(input_name);
  if (input_node->attr().count(kAttrTensorName)) {
    recv_attr[kAttrTensorName].set_s(
        input_node->attr().at(kAttrTensorName).s());
  }

  // Propagate the streaming attribute to the send/recv nodes.
  if (from->attr().contains(kStreaming) && !control_input) {
    if (input_node_port_num >= from->attr().at(kStreaming).list().b_size()) {
      LOG(ERROR)
          << from->name()
          << " port index larger than length of _streaming attribute list.";
    } else if (from->attr().at(kStreaming).list().b(input_node_port_num)) {
      send_attr[kStreaming].mutable_list()->add_b(true);
      recv_attr[kStreaming].mutable_list()->add_b(true);
    }
  }

  // NodeState for the _Send op.
  auto& send_node_state = GetNodeStateOrCreateIt(send);
  send_node_state.device_name = send->device();  // Set the Channel device.
  send_node_state.inputs.push_back(std::make_pair(from, input_node_port_num));
  send_node_state.outputs[0].push_back(recv);

  // NodeState for the _Recv op.
  auto& recv_node_state = GetNodeStateOrCreateIt(recv);
  recv_node_state.inputs.push_back(std::make_pair(send, 0));
  recv_node_state.outputs[0].push_back(to);

  // Keep the created nodes.
  additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(send));
  additional_nodes_.emplace_back(std::unique_ptr<NodeDef>(recv));

  // Return _Send and _Recv.
  return std::make_pair(send, recv);
}

OpContext SchedulerState::CreateOpContext(const NodeDef* node) const {
  // Get the device from the placer.
  DeviceProperties device;
  device = placer_->get_device(*node);

  // Special case for the _Send op.
  if (IsSend(*node)) {
    device.set_type(kChannelDevice);
  }

  // Construct OpContext.
  OpContext op_context;
  const auto& node_state = node_map_.at(node);
  op_context.name = node->name();
  op_context.device_name = node_state.device_name;
  auto& op_info = op_context.op_info;
  op_info.set_op(node->op());
  *op_info.mutable_attr() = node->attr();
  for (auto& input : node_state.input_properties) {
    *op_info.add_inputs() = input;
  }
  for (auto& output : node_state.output_properties) {
    *op_info.add_outputs() = output;
  }
  op_info.mutable_device()->Swap(&device);

  if (grappler_item_->graph.has_library()) {
    op_context.function_library = &grappler_item_->graph.library();
  }
  return op_context;
}

NodeState& SchedulerState::GetNodeStateOrCreateIt(const NodeDef* node) {
  CHECK(!initialized_) << "GetNodeStateOrCreateIt is called after Init().";

  auto it = node_map_.find(node);
  if (it != node_map_.end()) {
    return it->second;
  }

  // Not found; create a NodeState for this node.
  it = node_map_.emplace(node, NodeState()).first;
  auto& node_state = it->second;
  node_state.input_properties =
      graph_properties_->GetInputProperties(node->name());
  node_state.output_properties =
      graph_properties_->GetOutputProperties(node->name());
  node_state.shape_incompatible =
      graph_properties_->CheckShapeIncompatible(node->name());

  // Some ops may need further processing of the input / output properties:
  // _Send and _Recv.
  MaybeUpdateInputOutput(node);

  if (!IsSend(*node)) {
    node_state.device_name = DeviceName(node);
    // For the _Send op, device_name will be set to Channel in
    // CreateSendRecv().
  }

  // Initialize output port related data:
  // Assume the size of OutputProperties represents the number of output ports
  // of this node.
  for (size_t i = 0; i < node_state.output_properties.size(); ++i) {
    node_state.time_no_references[i] = Costs::Duration::max();
    node_state.num_outputs_executed[i] = 0;
    // Populate an empty vector for each port. The caller will add nodes
    // that use this port as input.
    node_state.outputs[i] = {};
  }
  // Port_num -1 is for control dependency.
  node_state.time_no_references[-1] = Costs::Duration::max();
  node_state.num_outputs_executed[-1] = 0;
  node_state.outputs[-1] = {};

  // Initialize time_scheduled to infinity, so we know whether it has been
  // assigned a non-default value later.
  node_state.time_scheduled = Costs::Duration().infinity();

  return it->second;
}

void SchedulerState::GetOutputNodes(const NodeDef* node,
                                    const Costs::Duration& curr_time,
                                    std::vector<const NodeDef*>* output_nodes) {
  // Checks whether the Switch's output slots change over iterations.
  int slot = -1;
  if (IsSwitch(*node) && node->attr().count(kOutputSlots) > 0 &&
      node->attr().at(kOutputSlots).list().i_size() > 0) {
    slot = node->attr().at(kOutputSlots).list().i(0);
    for (int i = 1; i < node->attr().at(kOutputSlots).list().i_size(); ++i) {
      if (slot != node->attr().at(kOutputSlots).list().i(i)) {
        slot = -1;
        break;
      }
    }
  }
  // Increment num_inputs_ready of the output nodes and maybe add to ready
  // nodes.
  auto& node_state = node_map_[node];
  for (const auto& port_num_output_pair : node_state.outputs) {
    // If the Switch is annotated and its output slots are always the same, we
    // only schedule the slot that was executed. Otherwise, schedule both
    // slots.
    if (slot >= 0 && port_num_output_pair.first != slot) continue;

    for (auto* output_node : port_num_output_pair.second) {
      auto& output_state = node_map_[output_node];
      output_state.num_inputs_ready++;
      // Execute a node as soon as all its inputs are ready. Merge nodes are
      // special since they run as soon as one of their inputs becomes
      // available.
      int output_state_inputs_size = output_state.inputs.size();
      if (output_state.num_inputs_ready == output_state_inputs_size ||
          IsMerge(*output_node)) {
        // This output node is now ready.
        output_state.time_ready = curr_time;
        output_nodes->push_back(output_node);
        VLOG(3) << " Add output: " << output_node->name();
      }
    }
  }
}

std::vector<const NodeDef*> SchedulerState::MarkNodeExecuted(
    const NodeDef* node, const Costs& node_costs, const OpContext& op_context,
    bool extract_execution_count_attr,
    const std::string& override_device_name) {
  auto& node_state = node_map_[node];
  // TODO(dyoon, andiryxu): Consider revisiting node execution w.r.t. Switch
  // and Merge -- it can create a loop which may include loop-carried
  // dependency, diverge-merge, and other complex execution patterns.
  bool previously_executed_merge =
      IsMerge(*node) && (node_state.time_finished != Costs::Duration::max());

  // Our approach to modeling loops is to extract the annotated
  // _execution_count attribute and to multiply node_costs by the value of the
  // attribute. If the attribute is not found, we assume a default execution
  // count of 1. Note that in some simulation flows we perform this
  // multiplication elsewhere; as such, we only perform it here if
  // extract_execution_count_attr is true. Otherwise node_costs are unmodified
  // and we assume the multiplication has been correctly carried out elsewhere.
  node_state.execution_count = 1;

  if (extract_execution_count_attr &&
      node->attr().count(kExecutionCount) > 0) {
    node_state.execution_count = node->attr().at(kExecutionCount).i();
  }

  node_state.node_costs = node_costs;
  // TotalNodeCosts() should be called after node_costs and execution_count are
  // set.
  Costs total_node_costs = node_state.TotalNodeCosts();

  graph_costs_ = CombineCosts(graph_costs_, total_node_costs);
  const string& op_name = node->op();

  auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
  op_cost = CombineCosts(op_cost, total_node_costs);

  if (VLOG_IS_ON(2)) {
    // Also keep track of op counts and costs per op (with their shapes).
    string node_description = GetOpDescription(op_context.op_info);
    op_counts_[node_description] += 1;
    op_costs_[node_description] = std::make_pair(
        total_node_costs.execution_time.asMicroSeconds().count(),
        !node_costs.inaccurate);
  }

  std::string device_name = node_state.device_name;
  if (!override_device_name.empty()) {
    // N.B. There's a chance that device_name doesn't exist in the device map
    // (device_), but that's OK because we'll effectively create a new device
    // the first time a new device is seen.
    device_name = override_device_name;
  }

  // Update node and device states.
  auto& device = device_[device_name];
  device.nodes_executed.push_back(node);
  // A node is scheduled when the device is available AND all its inputs are
  // ready; hence, time_scheduled is time_ready if time_ready > device current
  // time.
  // NodeState times are assigned infinity at initialization. If they are
  // still infinity here, we need to assign them. If not, they have been
  // assigned already, so skip. This latter case may occur when a scheduler
  // in-lines function calls, and thus schedules only function sub-nodes.
  if (node_state.time_scheduled == Costs::Duration().infinity()) {
    node_state.time_scheduled =
        std::max(device.GetCurrTime(), node_state.time_ready);
    // Override the device current time with time_scheduled.
    device.device_costs.execution_time = node_state.time_scheduled;
  }
  device.device_costs = CombineCosts(device.device_costs, total_node_costs);
  auto curr_time = device.GetCurrTime();
  node_state.time_finished = curr_time;

  // Update shape annotation states.
  UpdateDeviceAnnotationState(node, node_state, &device);

  // Update device memory usage.
  if (!IsPersistent(*node)) {
    for (const auto& port_num_output_pair : node_state.outputs) {
      int port_num = port_num_output_pair.first;
      // There's a chance that a specific output is not used at all.
      if (node_state.outputs[port_num].empty()) {
        node_state.time_no_references[port_num] = curr_time;
      } else {
        // Allow for the possibility that some ports may be persistent even if
        // the entire node is not labeled persistent.
        if (node_state.node_costs.persistent_output_ports.contains(port_num)) {
          continue;
        }

        // Streaming outputs do not allocate memory, they are directly consumed
        // by the target node.
        if (!IsStreamingPort(*node, port_num)) {
          // If possible use the node output size calculations done by the
          // more specific CostEstimator over the general CalculateOutputSize.
          device.memory_usage +=
              GetOrCalculateOutputSize(node_state, port_num);
        }
        device.nodes_in_memory.insert(std::make_pair(node, port_num));
      }
    }
  }

  // Update the device state's persistent node map.
  for (const auto& port : node_costs.persistent_output_ports) {
    device.persistent_nodes.insert({node, port});
  }

  // Update the device's per-op cost.
  auto& device_op_cost = FindOrCreateZero(op_name, &device.op_to_cost);
  device_op_cost = CombineCosts(device_op_cost, total_node_costs);

  VLOG(3) << "Op scheduled -- name: " << node->name() << ", op: " << node->op()
          << ", device: " << node->device()
          << ", execution_count: " << node_state.execution_count
          << ", ready: " << node_state.time_ready.count()
          << ", scheduled: " << node_state.time_scheduled.count()
          << ", finished: " << node_state.time_finished.count();
  VLOG(5) << " Current device memory usage (before deallocation): "
          << device.memory_usage;
  std::vector<const NodeDef*> new_nodes;
  if (previously_executed_merge) {
    // Skip AddOutputNodesToReadyQueue; this is due to Switch-Merge.
    VLOG(1) << "node [ " << node->name() << ", " << node->op() << " ] "
            << "is executed more than once. "
            << "Skip scheduling its output nodes.";
  } else {
    // Checks outputs, and adds ready nodes to the queue.
    GetOutputNodes(node, curr_time, &new_nodes);
  }

  // When an op is scheduled, both its input and output tensors must be
  // allocated in memory. Now that output memory is added, check max memory
  // usage.
  if (!IsPersistent(*node)) {
    if (device.memory_usage > device.max_memory_usage) {
      device.max_memory_usage = device.memory_usage;

      if (track_mem_usage_snapshot_) {
        device.mem_usage_snapshot_at_peak = device.nodes_in_memory;
      }
    }
  }

  // Append the current temporary memory usage of the device to the memory
  // usage trace.
  if (track_mem_usage_snapshot_) {
    device.temporary_memory_usage_trace.push_back(
        {node->name(), device.memory_usage});
  }

  // Increment num_outputs_executed of the input nodes and maybe update memory.
  for (const auto& input_port : node_state.inputs) {
    auto* input = input_port.first;
    auto port = input_port.second;

    auto& input_state = node_map_[input];
    input_state.num_outputs_executed[port]++;
    int input_state_outputs_size_ = input_state.outputs[port].size();

    // Allow for the possibility that some outputs may be persistent even if
    // the entire node is not labeled persistent.
    if (input_state.node_costs.persistent_output_ports.contains(port)) continue;

    if (input_state.num_outputs_executed[port] == input_state_outputs_size_ &&
        !IsPersistent(*input)) {
      // All the outputs are executed; there is no reference to this output
      // port of the input node.
      input_state.time_no_references[port] = curr_time;
      auto& input_device = device_[input_state.device_name];
      // If the node input is marked as streaming, then it wasn't allocated
      // in memory. A streaming input is still reference counted, but it
      // doesn't de-allocate memory.
      if (!IsStreamingPort(*input, port)) {
        input_device.memory_usage -=
            GetOrCalculateOutputSize(input_state, port);
      }

      input_device.nodes_in_memory.erase(std::make_pair(input, port));
    }
  }

  return new_nodes;
}

Costs SchedulerState::Summary() const {
  // Overall statement about accuracy
  VLOG(1) << graph_costs_.num_ops_total << " ops processed in total, with "
          << graph_costs_.num_ops_with_unknown_shapes
          << " having unknown shapes";

  // Print out basic execution summary.
  VLOG(1) << "Expected execution time: " << graph_costs_.execution_time.count();
  VLOG(1) << "Expected compute time: " << graph_costs_.compute_time.count();
  VLOG(1) << "Expected memory time: " << graph_costs_.memory_time.count();
  VLOG(1) << "Expected intermediate memory time: "
          << graph_costs_.intermediate_memory_time.count();
  VLOG(1) << "Expected max memory: " << graph_costs_.max_memory;
  VLOG(1) << "Expected max per-op buffers: " << graph_costs_.max_per_op_buffers;
  VLOG(1) << "Expected max per-op streaming buffers: "
          << graph_costs_.max_per_op_streaming;

  VLOG(1) << "Per-op execution time / compute time / memory time"
          << " / intermediate memory time:";
  for (const auto& op_cost_pair : op_to_cost_) {
    const auto& op = op_cost_pair.first;
    const auto& cost = op_cost_pair.second.execution_time.count();
    const auto& compute_cost = op_cost_pair.second.compute_time.count();
    const auto& memory_cost = op_cost_pair.second.memory_time.count();
    const auto& intermediate_memory_cost =
        op_cost_pair.second.intermediate_memory_time.count();
    const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
    if (cost) {  // Skip printing out zero-cost ops.
      VLOG(1) << absl::StrFormat(" + %30s : %c %10d / %10d / %10d / %10d", op,
                                 (is_op_cost_accurate ? ' ' : '~'), cost,
                                 compute_cost, memory_cost,
                                 intermediate_memory_cost);
    }
  }

  // Print per device summary
  VLOG(1) << "Devices:";
  Costs critical_path_costs = Costs::ZeroCosts();
  std::vector<string> device_names;
  device_names.reserve(device_.size());
  for (auto& it : device_) {
    device_names.push_back(it.first);
  }
  std::sort(device_names.begin(), device_names.end());

  for (const auto& name : device_names) {
    const auto& state = device_.at(name);

    std::map<string, int64_t> op_to_memory;
    // First profile only persistent memory usage.
    int64_t persistent_memory_usage = 0;
    std::set<string> persistent_ops;
    for (const auto& node_port : state.persistent_nodes) {
      const auto* node = node_port.first;
      const auto port = node_port.second;
      int64_t output_size = 0;
      // Check if the node is in the node_map. It may be that the node executed
      // on this device was executed by a different Scheduler.
      auto it = node_map_.find(node);
      if (it != node_map_.end()) {
        output_size = GetOrCalculateOutputSize(it->second, port);
      }
      persistent_memory_usage += output_size;
      op_to_memory[node->op()] += output_size;
      persistent_ops.insert(node->op());
    }
    int64_t max_memory_usage = persistent_memory_usage + state.max_memory_usage;
    critical_path_costs.estimated_max_memory_per_device[name] =
        max_memory_usage;

    const Costs::NanoSeconds wall_time_ns = state.GetCurrTime();
    VLOG(1) << "Device = " << name
            << ", num_nodes = " << state.nodes_executed.size()
            << ", wall_time_ns = " << wall_time_ns.count() << ", memory usage: "
            << "persistent = " << HumanReadableNumBytes(persistent_memory_usage)
            << ", peak = " << HumanReadableNumBytes(state.max_memory_usage)
            << ", total = " << HumanReadableNumBytes(max_memory_usage)
            << ", at the end: " << HumanReadableNumBytes(state.memory_usage);

    // Overall statement about accuracy
    VLOG(1) << state.device_costs.num_ops_total
            << " ops processed in total, with "
            << state.device_costs.num_ops_with_unknown_shapes
            << " having unknown shapes";

    // Device shape annotation statistics.
    const auto& device_annotation_stats = state.shape_annotation_stats;
    if (device_annotation_stats.num_ops_annotated > 0) {
      VLOG(1) << device_annotation_stats.num_ops_annotated
              << " ops with shape annotation, with "
              << device_annotation_stats.num_ops_executed_more_than_once
              << " executed more than once, "
              << device_annotation_stats.num_ops_with_dynamic_shapes
              << " with dynamic shapes, "
              << device_annotation_stats.num_ops_with_incompatible_shapes
              << " with incompatible shapes, "
              << device_annotation_stats.num_ops_executed
              << " ops executed in total.";
    }

    VLOG(1) << "Per-op execution time / compute time / memory time "
            << " / intermediate memory time"
            << " (and memory usage at peak memory usage):";

    // Profile non-persistent op memory usage.
    for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
      const auto* node = node_port.first;
      const auto port = node_port.second;
      // Check if the node is in the node_map. It may be that the node executed
      // on this device was executed by a different Scheduler.
      auto it = node_map_.find(node);
      if (it != node_map_.end()) {
        op_to_memory[node->op()] += GetOrCalculateOutputSize(it->second, port);
      }
    }
    Costs::NanoSeconds total_compute_time_ns;
    bool is_total_cost_accurate = true;
    for (const auto& op_cost_pair : state.op_to_cost) {
      const auto& op = op_cost_pair.first;
      const auto& cost = op_cost_pair.second.execution_time.count();
      const auto& compute_cost = op_cost_pair.second.compute_time.count();
      const auto& memory_cost = op_cost_pair.second.memory_time.count();
      const auto& intermediate_memory_cost =
          op_cost_pair.second.intermediate_memory_time.count();
      total_compute_time_ns += op_cost_pair.second.execution_time;
      const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
      if (!is_op_cost_accurate) {
        is_total_cost_accurate = false;
      }

      int64_t op_mem_usage = 0;
      auto it = op_to_memory.find(op);
      if (it != op_to_memory.end()) {
        op_mem_usage = it->second;
      }

      const float mem_usage_percent =
          max_memory_usage > 0
              ? Round2(100.0 * op_mem_usage / max_memory_usage)
              : 0.0;
      if (cost || mem_usage_percent > 1.0) {
        // Print out only non-zero cost ops or ops with > 1% memory usage.
        VLOG(1) << absl::StrFormat(
                       " + %30s : %c %10d / %10d / %10d / %10d", op.c_str(),
                       (is_op_cost_accurate ? ' ' : '~'), cost, compute_cost,
                       memory_cost, intermediate_memory_cost)
                << " (" << HumanReadableNumBytes(op_mem_usage) << " ["
                << mem_usage_percent << "%] "
                << (persistent_ops.count(op) > 0 ? ": persistent op)" : ")");
      }
    }

    int utilization = 0;
    if (wall_time_ns.count() > 0) {
      utilization = total_compute_time_ns.count() * 100 / wall_time_ns.count();
    }
    VLOG(1) << "Device = " << name << ", total_compute_time_ns = "
            << (is_total_cost_accurate ? "" : "~")
            << total_compute_time_ns.count()
            << ", utilization = " << utilization << "%";

    if (critical_path_costs.execution_time <= state.GetCurrTime()) {
      critical_path_costs = state.device_costs;
      critical_path_costs.persistent_memory = persistent_memory_usage;
      critical_path_costs.temporary_memory = state.max_memory_usage;
      critical_path_costs.max_memory = max_memory_usage;
    }
  }

  if (VLOG_IS_ON(2)) {
    // Also log the op descriptions and their corresponding counts.
    VLOG(2) << "Node description, counts, cost:";
    for (const auto& item : op_counts_) {
      int cost;
      bool is_cost_accurate;
      std::tie(cost, is_cost_accurate) = op_costs_.at(item.first);
      VLOG(2) << "Node: " << item.first << ", Count: " << item.second
              << ", Individual Cost: " << (is_cost_accurate ? "" : "~") << cost
              << " us";
    }
  }

  VLOG(1) << "Critical path execution time: "
          << critical_path_costs.execution_time.count();
  return critical_path_costs;
}

Costs SchedulerState::Summary(RunMetadata* metadata) {
  if (metadata) GenerateRunMetadata(metadata);
  return Summary();
}

void SchedulerState::GenerateRunMetadata(RunMetadata* metadata) {
  // Fill RunMetadata's step_stats and partition_graphs fields.
  StepStats* stepstats = metadata->mutable_step_stats();
  for (const auto& device : device_) {
    GraphDef* device_partition_graph = metadata->add_partition_graphs();
    DeviceStepStats* device_stepstats = stepstats->add_dev_stats();
    device_stepstats->set_device(device.first);
    for (const auto& node_def : device.second.nodes_executed) {
      // Only proceed if the node is in the node_map. This is to cover the case
      // where a device has executed a node that is not in the node_map of
      // this scheduler.
      if (node_map_.find(node_def) == node_map_.end()) {
        continue;
      }
      const NodeState& nodestate = node_map_.at(node_def);
      NodeExecStats* node_stats = device_stepstats->add_node_stats();
      uint64 total_output_size = 0;
      uint64_t persistent_output_size = 0;
      for (int slot = 0, slot_end = nodestate.output_properties.size();
           slot < slot_end; slot++) {
        const auto& properties = nodestate.output_properties[slot];
        NodeOutput* no = node_stats->add_output();
        no->set_slot(slot);
        TensorDescription* tensor_descr = no->mutable_tensor_description();
        tensor_descr->set_dtype(properties.dtype());
        *tensor_descr->mutable_shape() = properties.shape();
        // Optional allocation description.
        const int64_t tensor_size_requested =
            CalculateOutputSize(nodestate.output_properties, slot);
        const int64_t tensor_size_allocated =
            GetOrCalculateOutputSize(nodestate, slot);
        total_output_size += tensor_size_allocated;
        if (nodestate.node_costs.persistent_output_ports.contains(slot)) {
          persistent_output_size += tensor_size_allocated;
        }
        tensor_descr->mutable_allocation_description()->set_requested_bytes(
            tensor_size_requested);
        tensor_descr->mutable_allocation_description()->set_allocated_bytes(
            tensor_size_allocated);
      }
      if (node_def->op() != "HloGenericOp") {
        node_stats->set_timeline_label(node_def->op());
      } else {
        // For HloGenericOp, display hlo_opcode as timeline label.
        string timeline_label;
        if (node_def->attr().count("hlo_opcode") > 0) {
          absl::StrAppend(&timeline_label,
                          node_def->attr().at("hlo_opcode").s());
        }
        if (node_def->attr().count("_hlo_metadata_op_type") > 0) {
          absl::StrAppend(&timeline_label, "/",
                          node_def->attr().at("_hlo_metadata_op_type").s());
        }
        node_stats->set_timeline_label(timeline_label);
      }
      node_stats->set_node_name(node_def->name());
      // Timestamps in microseconds (can be used by timeline_server).
      node_stats->set_op_start_rel_micros(0);
      node_stats->set_all_start_micros(
          nodestate.time_scheduled.asMicroSeconds().count());
      node_stats->set_op_end_rel_micros(
          nodestate.time_finished.asMicroSeconds().count() -
          nodestate.time_scheduled.asMicroSeconds().count());
      node_stats->set_all_end_rel_micros(
          nodestate.time_finished.asMicroSeconds().count() -
          nodestate.time_scheduled.asMicroSeconds().count());
      // Timestamps in nanoseconds (can be used by xprof trace).
      node_stats->set_op_start_rel_nanos(0);
      node_stats->set_all_start_nanos(nodestate.time_scheduled.count());
      node_stats->set_op_end_rel_nanos(nodestate.time_finished.count() -
                                       nodestate.time_scheduled.count());
      node_stats->set_all_end_rel_nanos(nodestate.time_finished.count() -
                                        nodestate.time_scheduled.count());

      auto* mem_stats = node_stats->mutable_memory_stats();
      // SchedulerState does not specify scratch pad memory usage.
      mem_stats->set_temp_memory_size(0);
      int64_t persistent_memory_size = 0;
      if (IsPersistent(*node_def)) {
        persistent_memory_size = total_output_size;
      } else {
        persistent_memory_size = persistent_output_size;
      }
      mem_stats->set_persistent_memory_size(persistent_memory_size);
      *device_partition_graph->add_node() = *node_def;
    }
  }
}

const std::unordered_map<string, int64_t> SchedulerState::GetPeakMemoryUsage()
    const {
  std::unordered_map<string, int64_t> result;
  for (const auto& device : device_) {
    const string& name = device.first;
    const DeviceState& state = device.second;
    result[name] = state.max_memory_usage;
  }
  return result;
}

const std::unordered_map<string, int64_t>
SchedulerState::GetPersistentMemoryUsage() const {
  std::unordered_map<string, int64_t> result;
  for (const auto& device : device_) {
    const string& name = device.first;
    const DeviceState& state = device.second;
    int64_t persistent_memory_usage = 0;
    for (const auto& node_port : state.persistent_nodes) {
      const auto* node = node_port.first;
      const auto port = node_port.second;
      const auto& node_state = node_map_.at(node);
      persistent_memory_usage += GetOrCalculateOutputSize(node_state, port);
    }
    result[name] = persistent_memory_usage;
  }
  return result;
}

void SchedulerState::SetNodeStateTimeScheduled(const NodeDef* node) {
  auto& node_state = node_map_.at(node);
  auto& device = device_[node_state.device_name];
  node_state.time_scheduled = device.GetCurrTime();
}

int64_t SchedulerState::GetOrCalculateOutputSize(const NodeState& node_state,
                                                 int port_num) const {
  auto& node_costs = node_state.node_costs;
  auto it = node_costs.output_tensor_size_bytes.find(port_num);
  if (it != node_costs.output_tensor_size_bytes.end()) {
    return it->second;
  }
  return CalculateOutputSize(node_state.output_properties, port_num);
}

VirtualScheduler::~VirtualScheduler() {}

VirtualScheduler::VirtualScheduler(const bool use_static_shapes,
                                   const bool use_aggressive_shape_inference,
                                   Cluster* cluster,
                                   ReadyNodeManager* ready_nodes,
                                   std::unique_ptr<VirtualPlacer> placer)
    : scheduler_state_(std::make_unique<SchedulerState>(
          use_static_shapes, use_aggressive_shape_inference, cluster,
          std::move(placer))),
      ready_nodes_(ready_nodes) {}

VirtualScheduler::VirtualScheduler(
    ReadyNodeManager* ready_nodes,
    std::unique_ptr<SchedulerState> scheduler_state)
    : scheduler_state_(std::move(scheduler_state)), ready_nodes_(ready_nodes) {}

Status VirtualScheduler::Init(const GrapplerItem* item) {
  // SchedulerState::Init() preprocesses the input grappler_item and
  // graph_properties to extract the information necessary for emulating
  // TensorFlow op scheduling, and constructs the internal data structures
  // (NodeState and DeviceState) used for virtual scheduling.
  TF_RETURN_IF_ERROR(ready_nodes_->Init(GetNodeStates()));
  std::vector<const NodeDef*> initial_nodes;
  auto status = scheduler_state_->Init(item, &initial_nodes);
  if (status.ok()) {
    // Add the set of initial nodes to ready_nodes_.
    for (auto node : initial_nodes) {
      ready_nodes_->AddNode(node);
    }
  }
  return status;
}

OpContext VirtualScheduler::GetCurrNode() {
  const NodeDef* node = ready_nodes_->GetCurrNode();
  return scheduler_state_->CreateOpContext(node);
}

bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
  // Update graph_costs_ and per-op costs.
  const NodeDef* node = ready_nodes_->GetCurrNode();
  auto new_nodes = scheduler_state_->MarkNodeExecuted(
      node, node_costs,
      scheduler_state_->CreateOpContext(ready_nodes_->GetCurrNode()));
  // Add the set of new nodes obtained from MarkNodeExecuted() to ready_nodes_.
  for (auto node : new_nodes) {
    ready_nodes_->AddNode(node);
  }
  ready_nodes_->RemoveCurrNode();
  return !ready_nodes_->Empty();
}
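
// Illustrative driver loop (a sketch; the real call sites live in the cost
// estimators that own a VirtualScheduler, and the per-node cost prediction is
// not shown here):
//
//   TF_RETURN_IF_ERROR(scheduler.Init(&item));
//   bool more_nodes = true;
//   while (more_nodes) {
//     OpContext op_context = scheduler.GetCurrNode();
//     Costs node_costs = ...;  // predicted cost of op_context
//     more_nodes = scheduler.MarkCurrNodeExecuted(node_costs);
//   }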

}  // end namespace grappler
}  // end namespace tensorflow