1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ 18 19 #include <functional> 20 #include <list> 21 #include <memory> 22 #include <string> 23 #include <unordered_map> 24 #include <unordered_set> 25 26 #include "tensorflow/core/framework/node_def.pb.h" 27 #include "tensorflow/core/framework/step_stats.pb.h" 28 #include "tensorflow/core/grappler/costs/cost_estimator.h" 29 #include "tensorflow/core/grappler/costs/graph_properties.h" 30 #include "tensorflow/core/grappler/costs/op_context.h" 31 #include "tensorflow/core/grappler/costs/virtual_placer.h" 32 #include "tensorflow/core/grappler/grappler_item.h" 33 34 namespace tensorflow { 35 namespace grappler { 36 37 ABSL_CONST_INIT extern const char kAttrInputSrc[]; 38 ABSL_CONST_INIT extern const char kAttrSrcDevice[]; 39 ABSL_CONST_INIT extern const char kAttrDstDevice[]; 40 ABSL_CONST_INIT extern const char kAttrTensorName[]; 41 ABSL_CONST_INIT extern const char kChannelDevice[]; 42 ABSL_CONST_INIT extern const char kStreaming[]; 43 44 struct NodeState { 45 // A node (i.e., an op) takes a set of input:port pairs and produces 46 // a set of output ports. 47 48 // Cross references to input and output nodes from graphdef. 49 std::vector<std::pair<const NodeDef*, int>> inputs; // Input, port pairs. 50 // List of output nodes (a list of nodes that takes this output port as input) 51 // keyed by port_num. Note that port_num -1 is used for control dependency. 52 std::unordered_map<int, std::vector<const NodeDef*>> outputs; 53 54 // Info from GraphProperties. 55 std::vector<OpInfo::TensorProperties> input_properties; 56 std::vector<OpInfo::TensorProperties> output_properties; 57 58 // Canonical device name used within VirtualScheduler. 59 string device_name; 60 61 // States updated as scheduling nodes. 62 int num_inputs_ready; 63 std::unordered_map<int, int> num_outputs_executed; 64 Costs::Duration time_ready; 65 Costs::Duration time_scheduled; 66 Costs::Duration time_finished; 67 // Time that all the consumers are executed (hence, no need to keep this 68 // output in memory), keyed by port_num. 69 std::unordered_map<int, Costs::Duration> time_no_references; 70 71 // Note that a node may have multiple output ports. The length of outputs, 72 // num_outputs_executed, and time_no_references should be 73 // identical when a NodeState is fully initialized. 74 // They should be 1 + output_properties.size() as we add [-1] for control 75 // dependency. 76 77 // Node will be ready to be executed at time_ready, scheduled at 78 // time_scheduled, and finishes execution at time_finished. 79 // Each output port uses up memory space from time_scheduled to its 80 // time_no_references. 81 82 Costs node_costs; // Node costs per execution TotalNodeCostsNodeState83 Costs TotalNodeCosts() const { 84 return MultiplyCosts(node_costs, execution_count); 85 } 86 // How many times this node has been executed, e.g. in a while loop. 87 int execution_count; 88 89 // Output shape incompatible between shape annotation and shape inference. 90 bool shape_incompatible; 91 NodeStateNodeState92 NodeState() { 93 num_inputs_ready = 0; 94 time_ready = Costs::Duration::max(); 95 time_scheduled = Costs::Duration::max(); 96 time_finished = Costs::Duration::max(); 97 execution_count = 0; 98 shape_incompatible = false; 99 // Note that num_outputs_executed and time_no_references are not initialized 100 // here, since we don't know the size (i.e., # outputs for this node). 101 } 102 }; 103 104 struct DeviceState { 105 // Nodes executed on this device in execution order. 106 std::vector<const NodeDef*> nodes_executed; 107 108 struct NodePairHash { 109 public: operatorDeviceState::NodePairHash110 const std::size_t operator()( 111 const std::pair<const NodeDef*, int>& element) const { 112 return std::hash<const NodeDef*>()(element.first); 113 } 114 }; 115 116 // Nodes currently allocated in memory: set of NodeDef* and port_num pairs 117 // so that we can track which output of the node is in memory. 118 std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash> 119 nodes_in_memory; 120 121 // Nodes allocated in memory persistently: e.g., Variables. 122 std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash> 123 persistent_nodes; 124 125 // Snapshot of nodes_in_memory, when memory usage is at peak. 126 // Same to nodes_in_memory, it's a set of NodeDef* and port_num pairs. 127 std::unordered_set<std::pair<const NodeDef*, int>, NodePairHash> 128 mem_usage_snapshot_at_peak; 129 130 // Vector of temporary memory usage trace in execution order. 131 // Each pair represents the current node name and current (accumulated) 132 // temporary memory usage of the device when the node is scheduled. 133 // Only enabled when mem_usage_tracking is enabled. 134 // Note: CPU uses an inter-op threadpool, so the execution order on CPU may 135 // not be deterministic. 136 std::vector<std::pair<std::string, int64_t>> temporary_memory_usage_trace; 137 138 Costs device_costs; 139 std::map<string, Costs> op_to_cost; // Per-op cost. 140 141 int64_t memory_usage; // Current temporary memory usage 142 int64_t max_memory_usage; // Max temporary memory usage 143 144 // Shape annotation statistics. 145 struct ShapeAnnotationStats { 146 // Number of ops with shape annotated. 147 int64_t num_ops_annotated = 0; 148 // Number of ops executed multiple times (e.g. in a loop). 149 int64_t num_ops_executed_more_than_once = 0; 150 // Number of ops executed: account for execution count. 151 int64_t num_ops_executed = 0; 152 // Number of ops with dynamic shapes (e.g. shape changes in a loop). 153 int64_t num_ops_with_dynamic_shapes = 0; 154 // Number of ops with incompatible shapes between annotation and shape 155 // inference. 156 int64_t num_ops_with_incompatible_shapes = 0; 157 } shape_annotation_stats; 158 DeviceStateDeviceState159 DeviceState() { 160 device_costs = Costs::ZeroCosts(); 161 device_costs.num_ops_total = 0; 162 memory_usage = 0; 163 max_memory_usage = 0; 164 } 165 GetCurrTimeDeviceState166 Costs::Duration GetCurrTime() const { return device_costs.execution_time; } 167 }; 168 169 // ReadyNodeManager (abstract class): 170 // Keeps ready nodes and picks the best one to be scheduled. 171 class ReadyNodeManager { 172 public: ReadyNodeManager()173 ReadyNodeManager() {} ~ReadyNodeManager()174 virtual ~ReadyNodeManager() {} Init(const std::unordered_map<const NodeDef *,NodeState> * node_map)175 virtual Status Init( 176 const std::unordered_map<const NodeDef*, NodeState>* node_map) { 177 return OkStatus(); 178 } 179 virtual void AddNode(const NodeDef* node) = 0; 180 virtual const NodeDef* GetCurrNode() = 0; 181 virtual void RemoveCurrNode() = 0; 182 virtual bool Empty() const = 0; 183 }; 184 185 class FIFOManager : public ReadyNodeManager { 186 public: FIFOManager()187 FIFOManager() : ReadyNodeManager() {} ~FIFOManager()188 ~FIFOManager() override {} AddNode(const NodeDef * node)189 void AddNode(const NodeDef* node) override { nodes_.push_back(node); } GetCurrNode()190 const NodeDef* GetCurrNode() override { 191 CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node"; 192 return nodes_.front(); 193 } RemoveCurrNode()194 void RemoveCurrNode() override { nodes_.pop_front(); } Empty()195 bool Empty() const override { return nodes_.empty(); } 196 197 private: 198 std::list<const NodeDef*> nodes_; 199 }; 200 201 // The LIFOManager schedules nodes by returning the last one added to the 202 // scheduler. A node is executed and then its ready outputs are newly added to 203 // the scheduler, so the LIFOManager will return outputs to a node following 204 // that node's execution. 205 class LIFOManager : public ReadyNodeManager { 206 public: LIFOManager()207 LIFOManager() : ReadyNodeManager() {} ~LIFOManager()208 ~LIFOManager() override {} 209 void AddNode(const NodeDef* node) override; 210 const NodeDef* GetCurrNode() override; 211 void RemoveCurrNode() override; Empty()212 bool Empty() const override { return nodes_.empty(); } 213 214 private: 215 std::list<const NodeDef*> nodes_; 216 // Keep track of the current node being executed by saving its position. 217 // Necessary because nodes may be added to the end of the list while a node is 218 // executing, and we want to remove the correct node (the one that is 219 // executing) rather than the new ones being added. 220 std::list<const NodeDef*>::iterator curr_pos_ = nodes_.end(); 221 }; 222 223 // Abstract class that maintains a heap/priority queue for scheduling ready 224 // nodes. Derived class needs to implement the Greater() function which returns 225 // the comparator for the heap. 226 class HeapReadyManager : public ReadyNodeManager { 227 public: 228 HeapReadyManager(); 229 Status Init( 230 const std::unordered_map<const NodeDef*, NodeState>* node_map) override; ~HeapReadyManager()231 ~HeapReadyManager() override {} 232 void AddNode(const NodeDef* node) override; 233 const NodeDef* GetCurrNode() override; 234 void RemoveCurrNode() override; 235 bool Empty() const override; 236 237 protected: 238 virtual std::function<bool(const NodeDef*, const NodeDef*)> Greater() = 0; 239 240 // nodes_ is the main queue, where we construct heap, and the front is the 241 // current node. 242 std::vector<const NodeDef*> nodes_; 243 244 // Comparator functor for heap; stl heap is max heap, so we use "greater than" 245 // functor for keeping the smallest time_ready node at the front of heap. 246 std::function<bool(const NodeDef*, const NodeDef*)> greater_; 247 248 // NodeState structure from SchedulerState to get time_ready of ready nodes. 249 // Not owned by FirstReadyManager. 250 const std::unordered_map<const NodeDef*, NodeState>* node_map_; 251 252 // Cached curr node. Set back to nullptr from RemoveCurrNode(). 253 const NodeDef* curr_node_; 254 }; 255 256 // FirstReadyManager picks a node with the minimum time_ready value. 257 // Behavior is deterministic when there are more than one nodes with the minimum 258 // time_ready value with unique node names as the tie-breaker. 259 class FirstReadyManager : public HeapReadyManager { 260 public: FirstReadyManager()261 FirstReadyManager() : HeapReadyManager() {} ~FirstReadyManager()262 ~FirstReadyManager() override {} 263 264 protected: 265 std::function<bool(const NodeDef*, const NodeDef*)> Greater() override; 266 }; 267 268 // PriorityReadyManager uses the given node priorities when picking up next node 269 // from all the ready nodes. 270 class PriorityReadyManager : public HeapReadyManager { 271 public: PriorityReadyManager()272 PriorityReadyManager() : HeapReadyManager() {} ~PriorityReadyManager()273 ~PriorityReadyManager() override {} 274 void AddNode(const NodeDef* node) override; 275 276 // Note this should be called after Init(). 277 Status SetPriority(const std::unordered_map<string, int>& node_priority); 278 279 protected: 280 std::function<bool(const NodeDef*, const NodeDef*)> Greater() override; 281 282 private: 283 // A map from unique node name to priority. Lower number means higher 284 // priority. 285 std::unordered_map<string, int> node_priority_; 286 }; 287 288 // CompositeNodeManager has a few other NodeManagers: per-device LIFO for normal 289 // ops (neither _Send nor _Recv) and FirstReadyManagers for _Send ops and _Recv 290 // ops, and then it chooses FirstReady among the ops chosen from each 291 // internal NodeManagers. The objective is to maximize producer-consumer 292 // locality within device, while processing nodes across devices, including 293 // _Send and _Recv, fairly, in terms of their time_ready. 294 class CompositeNodeManager : public ReadyNodeManager { 295 public: 296 CompositeNodeManager(); ~CompositeNodeManager()297 ~CompositeNodeManager() override {} 298 299 Status Init( 300 const std::unordered_map<const NodeDef*, NodeState>* node_map) override; 301 void AddNode(const NodeDef* node) override; 302 const NodeDef* GetCurrNode() override; 303 void RemoveCurrNode() override; 304 bool Empty() const override; 305 306 private: 307 // Internal ready node managers: 308 // LIFO for normal ops to maximize producer consumer locality. 309 // One LIFO per device. 310 std::unordered_map<string, LIFOManager> ops_lifo_map_; 311 // FirstReady for send and recv. Handle send and recv separately ensures that 312 // send and recv do not block previously read ops with LIFO schedule. 313 FirstReadyManager send_manager_; 314 FirstReadyManager recv_manager_; 315 316 // NodeState structure from SchedulerState to get time_ready of ready nodes. 317 // Not owned by CompositeReadyManager. 318 const std::unordered_map<const NodeDef*, NodeState>* node_map_; 319 320 // Cached curr node. Set back to nullptr from RemoveCurrNode(). 321 const NodeDef* curr_node_; 322 }; 323 324 // Constructs a ready node manager from the given string. 325 std::unique_ptr<ReadyNodeManager> ReadyNodeManagerFactory( 326 const string& ready_node_manager); 327 328 // Encapsulates all of the various pieces uses to track state of a scheduler; 329 // enables reuse of all scheduler state-related utilities across different 330 // scheduler implementations. 331 class SchedulerState { 332 public: 333 SchedulerState(const bool use_static_shapes, 334 const bool use_aggressive_shape_inference, Cluster* cluster, 335 std::unique_ptr<VirtualPlacer> placer); 336 // Move constructor. Explicitly defined because it otherwise gets implicitly 337 // deleted. SchedulerState is a move-only class, as we have a <unique_ptr> 338 // for it in VirtualScheduler. A derivative of VirtualScheduler can move a 339 // <unique_ptr> SchedulerState to VirtualScheduler when it is constructed, 340 // which is where this move constructor is needed. 341 SchedulerState(SchedulerState&& arg) = default; 342 // We explicitly delete assinment and copy operators, this is done implicitly, 343 // but we state it here explicitly for clarity. 344 SchedulerState& operator=(SchedulerState&& arg) = delete; 345 SchedulerState(const SchedulerState&) = delete; 346 SchedulerState& operator=(const SchedulerState&) = delete; 347 // Destructor. Must be defined such that a derivative class can override it 348 // and allow proper desctruction of the derivative class. If this is not done 349 // properly, memory leaks can occur. 350 virtual ~SchedulerState(); 351 // Sets up the graph while also performing some necessary transformations 352 // initial_nodes is the set of nodes (primary inputs) discovered by Init() 353 // which may be added by a ReadyNodeManager (or related/derivative scheduler) 354 // to begin node schedule and graph simulation. 355 Status Init(const GrapplerItem* item, 356 std::vector<const NodeDef*>* initial_nodes, 357 bool create_explicit_channel_device = true); 358 359 virtual Costs Summary() const; 360 // Like the above, but writes detailed stats to RunMetadata. 361 // If metadata is nullptr, then just calls and return Summary(). 362 virtual Costs Summary(RunMetadata* metadata); 363 // Generates RunMetadata's step_stats and partition_graphs fields from results 364 // of the virtual execution of the graph. 365 // TODO(rdegruijl) See if we can make this function and caller Summary() 366 // const. 367 void GenerateRunMetadata(RunMetadata* metadata); 368 369 // Returns per device memory usage. 370 const std::unordered_map<string, int64_t> GetPeakMemoryUsage() const; 371 const std::unordered_map<string, int64_t> GetPersistentMemoryUsage() const; enable_mem_usage_tracking()372 void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; } 373 // Returns (read only) device and node states. GetDeviceStates()374 const std::unordered_map<string, DeviceState>* GetDeviceStates() const { 375 return &device_; 376 } 377 GetNodeStates()378 const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const { 379 return &node_map_; 380 } 381 382 virtual OpContext CreateOpContext(const NodeDef* node) const; 383 std::vector<const NodeDef*> MarkNodeExecuted( 384 const NodeDef* node, const Costs& node_costs, const OpContext& op_context, 385 bool extract_execution_count_attr = true, 386 const std::string& override_device_name = ""); 387 388 // Some getter functions. GetGrapplerItem()389 const GrapplerItem* GetGrapplerItem() { return grappler_item_; } GetGraphCost()390 Costs GetGraphCost() { return graph_costs_; } GetCluster()391 Cluster* GetCluster() { return cluster_; } GetUseStaticShape()392 bool GetUseStaticShape() { return use_static_shapes_; } GetUseAggressiveShapeInference()393 bool GetUseAggressiveShapeInference() { 394 return use_aggressive_shape_inference_; 395 } GetNodeMap()396 const std::unordered_map<const NodeDef*, NodeState>& GetNodeMap() { 397 return node_map_; 398 } 399 400 protected: 401 // Assigns the time_scheduled in the NodeState of node to the current 402 // execution_time of the device executing this node. 403 void SetNodeStateTimeScheduled(const NodeDef* node); 404 405 // This method can be used by a class derived from SchedulerState to 406 // access the device state map. GetMutableDeviceState()407 std::unordered_map<string, DeviceState>* GetMutableDeviceState() { 408 return &device_; 409 } 410 411 private: 412 // Methods called from Init(). Fails if initialize_ is set. 413 414 void MaybeUpdateInputOutput(const NodeDef* node); 415 NodeState& GetNodeStateOrCreateIt(const NodeDef* node); 416 // Creates a Send_ and Recv_ pair between from and to. The argument 417 // create_channel_device tells the function to create an explicit device for 418 // the channel. 419 std::pair<const NodeDef*, const NodeDef*> CreateSendRecv( 420 const NodeDef* from, const NodeDef* to, const NodeDef* input_node, 421 const string& input_name, bool create_channel_device); 422 string DeviceName(const NodeDef* node) const; 423 string SanitizedDeviceName(const NodeDef* node) const; 424 string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const; 425 426 // Helper methods. 427 void GetOutputNodes(const NodeDef* node, const Costs::Duration& curr_time, 428 std::vector<const NodeDef*>* output_nodes); 429 // Retrieves output size from node_cost at a port_num. If the output size has 430 // not been set, defaults back to CalculateOutputSize. 431 int64_t GetOrCalculateOutputSize(const NodeState& node_state, 432 int port_num) const; 433 434 std::unordered_map<const NodeDef*, NodeState> node_map_; 435 std::unordered_map<string, DeviceState> device_; 436 437 // Pool of NodeDefs for SendRecv and Identity ops created. 438 std::vector<std::unique_ptr<NodeDef>> additional_nodes_; 439 440 // Stats: 441 // Op counts with key with input shape. 442 // Example key: "[Op=AssignSub, input_shapes=[[7,1,160,160][7,1,160,160]]" 443 std::map<string, int> op_counts_; 444 // Individual op costs with key with input shape. 445 // Integer field for execution time in micro seconds. 446 // Boolean field for whether the cost is accurate. 447 std::map<string, std::pair<int, bool>> op_costs_; 448 449 Costs graph_costs_; // Graph cost. 450 std::map<string, Costs> op_to_cost_; // Per-op cost. 451 452 // Auxiliary data structures for constructing NodeState and DeviceState. 453 std::unique_ptr<GraphProperties> graph_properties_; // Initialized in Init(). 454 Cluster* cluster_; // Not owned. 455 const GrapplerItem* grappler_item_; // Not owned. 456 bool use_static_shapes_; 457 bool initialized_; 458 bool track_mem_usage_snapshot_; 459 const bool use_aggressive_shape_inference_; 460 std::unique_ptr<VirtualPlacer> placer_; 461 }; 462 463 // The virtual scheduler emulates execution of nodes in a graph, considering 464 // dependencies, device, etc. 465 class VirtualScheduler { 466 public: 467 // Does not take ownership of cluster or ready_nodes. 468 VirtualScheduler(const bool use_static_shapes, 469 const bool use_aggressive_shape_inference, Cluster* cluster, 470 ReadyNodeManager* ready_nodes, 471 std::unique_ptr<VirtualPlacer> placer); 472 // This constructor can be called by a derivative of VirtualScheduler to 473 // construct the base class. It lets VirtualScheduler take ownership of 474 // a new SchedulerState or a derivative thereof. 475 // Note that this constructor does not set a VirtualPlacer, in this 476 // constructor the VirtialPlacer is passed as a member of the SchedulerState 477 // that is passed as an argument. 478 VirtualScheduler(ReadyNodeManager* ready_nodes, 479 std::unique_ptr<SchedulerState> scheduler_state); 480 virtual ~VirtualScheduler(); 481 482 // Initializes the scheduler for the specific grappler item. 483 // Should be called immediately after the c'tor or when the scheduler will be 484 // reused for a new grappler item. All internal states of the scheduler 485 // related to the previous grappler item will be reset/cleared. 486 // 487 // This function should be called at least once after the scheduler is 488 // constructed. An uninitialized or failed-to-initialize scheduler will cause 489 // undefined behavior. 490 virtual Status Init(const GrapplerItem* item); 491 492 // Gets the current scheduled node for execution; the caller of this function 493 // can accordingly simulate the execution of the current scheduled node. 494 virtual OpContext GetCurrNode(); 495 // Marks the current scheduled node as executed. Note that we should call this 496 // function only after the execution of the node has been simulated; 497 // node_costs_ capture the simulated costs of the node. 498 // Returns true if there is any node to be scheduled. 499 virtual bool MarkCurrNodeExecuted(const Costs& node_costs); 500 501 // Prints out summary of execution (timing, memory usage, etc.) Summary()502 Costs Summary() const { return scheduler_state_->Summary(); } 503 // Like the above, but writes detailed stats to RunMetadata. 504 // If metadata is nullptr, then just calls and return Summary(). Summary(RunMetadata * metadata)505 Costs Summary(RunMetadata* metadata) { 506 return scheduler_state_->Summary(metadata); 507 } 508 // Generates RunMetadata's step_stats and partition_graphs fields from results 509 // of the virtual execution of the graph. GenerateRunMetadata(RunMetadata * metadata)510 void GenerateRunMetadata(RunMetadata* metadata) { 511 scheduler_state_->GenerateRunMetadata(metadata); 512 } 513 // Returns per device memory usage. GetPeakMemoryUsage()514 const std::unordered_map<string, int64_t> GetPeakMemoryUsage() const { 515 return scheduler_state_->GetPeakMemoryUsage(); 516 } GetPersistentMemoryUsage()517 const std::unordered_map<string, int64_t> GetPersistentMemoryUsage() const { 518 return scheduler_state_->GetPersistentMemoryUsage(); 519 } 520 // Returns VirtualScheduler (read only) device and node states. GetDeviceStates()521 const std::unordered_map<string, DeviceState>* GetDeviceStates() const { 522 return scheduler_state_->GetDeviceStates(); 523 } GetNodeStates()524 const std::unordered_map<const NodeDef*, NodeState>* GetNodeStates() const { 525 return scheduler_state_->GetNodeStates(); 526 } enable_mem_usage_tracking()527 void enable_mem_usage_tracking() { 528 scheduler_state_->enable_mem_usage_tracking(); 529 } 530 531 protected: 532 // The state of the scheduler and the execution of the graph is encapsulated 533 // by the scheduler_state_ object. 534 std::unique_ptr<SchedulerState> scheduler_state_; 535 // ready_nodes_ is responsible for ordering the traversal of the graph. 536 ReadyNodeManager* ready_nodes_; // Not owned. 537 }; 538 539 } // namespace grappler 540 } // end namespace tensorflow 541 542 #endif // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ 543