/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_

#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/debug.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

class ExecutorOpts;
class StepStatsCollector;
class RendezvousMgrInterface;
class DeviceMgr;
class WorkerSession;
class CoordinationServiceAgent;

// GraphMgr keeps track of a set of graphs that are registered with a
// TensorFlow worker. Each registered graph is identified by a handle
// that is generated by GraphMgr and returned to the caller.
//
// After a successful registration, the caller executes a graph using
// the graph handle. Each execution is distinguished from others by a
// caller-generated, globally unique id "step_id". Multiple executions
// can use the same graph concurrently and independently as long as
// the "step_id"s used are different.
//
// Multiple threads can call GraphMgr methods concurrently.
//
// E.g.,
//   GraphMgr gmgr(worker_env);
//   string handle;
//   TF_CHECK_OK(gmgr.Register("session", { graph computes c = a + b },
//                             &handle));
//   GraphMgr::NamedTensors in = { { "a", Tensor({1, 2}) },
//                                 { "b", Tensor({3, 4}) } };
//   GraphMgr::NamedTensors out = { { "c", Tensor() } };
//   TF_CHECK_OK(gmgr.Execute(handle, 0x0001, in, &out));
//   EXPECT_EQ(out["c"], Tensor({4, 6}));
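//
// The "Execute" call above is illustrative shorthand; the actual execution
// API is asynchronous (see ExecuteAsync below). A rough sketch of the same
// step against the real declaration, assuming the caller has prepared
// "opts", "session", and "response", and uses a tensorflow::Notification
// "n" to block until the step finishes:
//
//   Notification n;
//   Status status;
//   gmgr.ExecuteAsync(handle, /*step_id=*/0x0001, opts, in, session,
//                     /*collector=*/nullptr, response,
//                     /*cancellation_manager=*/nullptr,
//                     /*coordination_service_agent=*/nullptr,
//                     [&](const Status& s) { status = s; n.Notify(); });
//   n.WaitForNotification();
//   TF_CHECK_OK(status);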
class GraphMgr {
 public:
  explicit GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr);
  ~GraphMgr();

  // Registers a graph. Fills in "graph_handle". The registered graph retains
  // a reference to cluster_flr to do cross-process function calls.
  Status Register(const string& handle, const GraphDef& gdef,
                  const GraphOptions& graph_options,
                  const DebugOptions& debug_options,
                  const ConfigProto& config_proto,
                  int64_t collective_graph_key, WorkerSession* session,
                  DistributedFunctionLibraryRuntime* cluster_flr,
                  string* graph_handle);

  typedef std::map<string, Tensor> NamedTensors;
  typedef std::function<void(const Status&)> StatusCallback;

  // Executes one step of a registered graph "handle".
  //
  // Tensors in "in" are sent to the step's rendezvous before execution
  // starts; fetched outputs can be retrieved with RecvOutputs() after
  // "done" is invoked with an OK status.
  void ExecuteAsync(const string& handle, const int64_t step_id,
                    const ExecutorOpts& opts, const NamedTensors& in,
                    WorkerSession* session, StepStatsCollector* collector,
                    MutableRunGraphResponseWrapper* response,
                    CancellationManager* cancellation_manager,
                    CoordinationServiceAgent* coordination_service_agent,
                    StatusCallback done);

  // Sends additional tensors in "in" to the rendezvous of the step
  // "step_id".
  Status SendInputs(const int64_t step_id, const NamedTensors& in);

  // Receives the tensors keyed in "*out" from the rendezvous of the step
  // "step_id", blocking until they are available.
  Status RecvOutputs(const int64_t step_id, NamedTensors* out);

  // Asynchronous version of RecvOutputs(); invokes "done" once all keys in
  // "*out" have been received, or on error.
  void RecvOutputsAsync(const int64_t step_id, NamedTensors* out,
                        StatusCallback done);

  // Deregisters a graph.
  Status Deregister(const string& handle);

  // Deregisters all graphs.
  Status DeregisterAll();
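
  // E.g., a caller can feed late inputs and fetch outputs for a running
  // step out of band (a rough sketch, assuming "gmgr" and "step_id" refer
  // to an executing step; error handling elided):
  //
  //   TF_CHECK_OK(gmgr.SendInputs(step_id, {{"b", Tensor({3, 4})}}));
  //   GraphMgr::NamedTensors out = {{"c", Tensor()}};
  //   TF_CHECK_OK(gmgr.RecvOutputs(step_id, &out));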

 private:
  typedef GraphMgr ME;

  struct ExecutionUnit {
    std::unique_ptr<Graph> graph = nullptr;
    Device* device = nullptr;               // not owned.
    Executor* root = nullptr;               // not owned.
    FunctionLibraryRuntime* lib = nullptr;  // not owned.
    // Build the cost model if this value is strictly positive.
    int64_t build_cost_model = 0;
  };

  struct Item : public core::RefCounted {
    // TODO(zhifengc): Keeps a copy of the original graph if the need arises.
    // TODO(zhifengc): Stats, updated by multiple runs potentially.
    // TODO(zhifengc): Dup-detection. Ensure each step_id is run only once.
    ~Item() override;

    // Session handle.
    string session;

    // Graph handle.
    string handle;

    std::unique_ptr<FunctionLibraryDefinition> lib_def;
    // Owns the FunctionLibraryRuntime objects needed to execute functions,
    // one per device.
    std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
    // A graph is partitioned over multiple devices. Each partition
    // has a root executor which may call into the runtime library.
    std::vector<ExecutionUnit> units;

    // Back-pointer to the owning GraphMgr, used to deregister this item's
    // cost model when cost models are enabled.
    GraphMgr* graph_mgr;

    int64_t collective_graph_key;
  };

  const WorkerEnv* worker_env_;  // Not owned.
  const DeviceMgr* device_mgr_;

  CostModelManager cost_model_manager_;

  // Owned.
  mutex mu_;
  int64_t next_id_ TF_GUARDED_BY(mu_) = 0;

  // If true, blocks until the device has finished all queued operations in a
  // step.
  bool sync_on_finish_ = true;

  // Table mapping graph handles to registered graphs.
  //
  // TODO(zhifengc): If the client does not call Deregister, we'll
  // lose memory over time. We should implement a timeout-based
  // mechanism to gc these graphs.
  std::unordered_map<string, Item*> table_;

  void StartParallelExecutors(
      const string& handle, int64_t step_id, Item* item,
      Rendezvous* rendezvous, CollectiveExecutor::Handle* ce_handle,
      StepStatsCollector* collector, CostGraphDef* cost_graph,
      CancellationManager* cancellation_manager, WorkerSession* session,
      int64_t start_time_usecs,
      CoordinationServiceAgent* coordination_service_agent,
      StatusCallback done);

  // Don't attempt to process cost models unless explicitly requested for at
  // least one of the items.
  bool skip_cost_models_ = true;

  void BuildCostModel(Item* item, StepStatsCollector* collector,
                      CostGraphDef* cost_graph);

  Status InitItem(const string& handle, const GraphDef& gdef,
                  const GraphOptions& graph_options,
                  const DebugOptions& debug_options,
                  const ConfigProto& config_proto,
                  int64_t collective_graph_key, WorkerSession* session,
                  DistributedFunctionLibraryRuntime* cluster_flr, Item* item);

  Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options,
                                         Graph* graph, Device* device);

  TF_DISALLOW_COPY_AND_ASSIGN(GraphMgr);
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_