// xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/graph_mgr.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
18 
19 #include <unordered_map>
20 #include <vector>
21 
22 #include "tensorflow/core/common_runtime/costmodel_manager.h"
23 #include "tensorflow/core/common_runtime/executor.h"
24 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
25 #include "tensorflow/core/distributed_runtime/message_wrappers.h"
26 #include "tensorflow/core/distributed_runtime/worker_env.h"
27 #include "tensorflow/core/framework/cancellation.h"
28 #include "tensorflow/core/framework/collective.h"
29 #include "tensorflow/core/framework/cost_graph.pb.h"
30 #include "tensorflow/core/framework/function.h"
31 #include "tensorflow/core/lib/core/refcount.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/types.h"
36 #include "tensorflow/core/protobuf/config.pb.h"
37 #include "tensorflow/core/protobuf/debug.pb.h"
38 #include "tensorflow/core/protobuf/worker.pb.h"
39 
40 namespace tensorflow {
41 
42 class ExecutorOpts;
43 class StepStatsCollector;
44 class RendezvousMgrInterface;
45 class DeviceMgr;
46 class WorkerSession;
47 class CoordinationServiceAgent;
48 
49 // GraphMgr keeps track of a set of graphs that are registered with a
50 // TensorFlow worker. Each registered graph is identified by a handle
51 // that is generated by GraphMgr and returned to the caller.
52 //
53 // After a successful registration, the caller executes a graph using
54 // the graph handle. Each execution is distinguished from others by a
55 // caller generated global unique id "step_id". Multiple executions
56 // can use the same graph concurrently and independently as long as
57 // "step_id" used are different.
58 //
59 // Multiple threads can call GraphMgr methods concurrently.
60 //
61 // E.g.,
62 //   GraphMgr gmgr(worker_env);
63 //   string handle;
64 //   TF_CHECK_OK(gmgr.Register("session", { graph computes c = a + b },
65 //   &handle));
66 //   GraphMgr::NamedTensors in = { { "a", Tensor({1, 2}) },
67 //                                { "b", Tensor({3, 4}) } };
68 //   GraphMgr::NamedTensors out = { { "c", Tensor() } };
69 //   TF_CHECK_OK(gmgr.Execute(handle, 0x0001, in, &out));
70 //   EXPECT_EQ(out["c"], Tensor({4, 6}));
71 class GraphMgr {
72  public:
73   explicit GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr);
74   ~GraphMgr();
75 
76   // Registers a graph. Fills in "handle". The registered graph retains a
77   // reference to cluster_flr to do cross process function calls.
78   Status Register(const string& handle, const GraphDef& gdef,
79                   const GraphOptions& graph_options,
80                   const DebugOptions& debug_options,
81                   const ConfigProto& config_proto, int64_t collective_graph_key,
82                   WorkerSession* session,
83                   DistributedFunctionLibraryRuntime* cluster_flr,
84                   string* graph_handle);
85 
86   // Executes one step of a registered graph "handle".
87   //
88   // If "out" is not nullptr, "out" specifies all keys the execution
89   // should receive upon finish.
90   typedef std::map<string, Tensor> NamedTensors;
91   typedef std::function<void(const Status&)> StatusCallback;
92   void ExecuteAsync(const string& handle, const int64_t step_id,
93                     const ExecutorOpts& opts, const NamedTensors& in,
94                     WorkerSession* session, StepStatsCollector* collector,
95                     MutableRunGraphResponseWrapper* response,
96                     CancellationManager* cancellation_manager,
97                     CoordinationServiceAgent* coordination_service_agent,
98                     StatusCallback done);
99 
100   Status SendInputs(const int64_t step_id, const NamedTensors& in);
101   Status RecvOutputs(const int64_t step_id, NamedTensors* out);
102   void RecvOutputsAsync(const int64_t step_id, NamedTensors* out,
103                         StatusCallback done);
104 
105   // Deregisters a graph.
106   Status Deregister(const string& handle);
107 
108   // Deregister all graphs.
109   Status DeregisterAll();
110 
111  private:
112   typedef GraphMgr ME;
113 
114   struct ExecutionUnit {
115     std::unique_ptr<Graph> graph = nullptr;
116     Device* device = nullptr;               // not owned.
117     Executor* root = nullptr;               // not owned.
118     FunctionLibraryRuntime* lib = nullptr;  // not owned.
119     // Build the cost model if this value is strictly positive.
120     int64_t build_cost_model = 0;
121   };
122 
123   struct Item : public core::RefCounted {
124     // TODO(zhifengc): Keeps a copy of the original graph if the need arises.
125     // TODO(zhifengc): Stats, updated by multiple runs potentially.
126     // TODO(zhifengc): Dup-detection. Ensure step_id only run once.
127     ~Item() override;
128 
129     // Session handle.
130     string session;
131 
132     // Graph handle.
133     string handle;
134 
135     std::unique_ptr<FunctionLibraryDefinition> lib_def;
136     // Owns the FunctionLibraryRuntime objects needed to execute functions, one
137     // per device.
138     std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
139     // A graph is partitioned over multiple devices.  Each partition
140     // has a root executor which may call into the runtime library.
141     std::vector<ExecutionUnit> units;
142 
143     // Used to deregister a cost model when cost model is required in graph
144     // manager.
145     GraphMgr* graph_mgr;
146 
147     int64_t collective_graph_key;
148   };
149 
150   const WorkerEnv* worker_env_;  // Not owned.
151   const DeviceMgr* device_mgr_;
152 
153   CostModelManager cost_model_manager_;
154 
155   // Owned.
156   mutex mu_;
157   int64_t next_id_ TF_GUARDED_BY(mu_) = 0;
158 
159   // If true, blocks until device has finished all queued operations in a step.
160   bool sync_on_finish_ = true;
161 
162   // Table mapping graph handles to registered graphs.
163   //
164   // TODO(zhifengc): If the client does not call Deregister, we'll
165   // lose memory over time. We should implement a timeout-based
166   // mechanism to gc these graphs.
167   std::unordered_map<string, Item*> table_;
168 
169   void StartParallelExecutors(
170       const string& handle, int64_t step_id, Item* item, Rendezvous* rendezvous,
171       CollectiveExecutor::Handle* ce_handle, StepStatsCollector* collector,
172       CostGraphDef* cost_graph, CancellationManager* cancellation_manager,
173       WorkerSession* session, int64_t start_time_usecs,
174       CoordinationServiceAgent* coordination_service_agent,
175       StatusCallback done);
176 
177   // Don't attempt to process cost models unless explicitly requested for at
178   // least one of the items.
179   bool skip_cost_models_ = true;
180 
181   void BuildCostModel(Item* item, StepStatsCollector* collector,
182                       CostGraphDef* cost_graph);
183 
184   Status InitItem(const string& handle, const GraphDef& gdef,
185                   const GraphOptions& graph_options,
186                   const DebugOptions& debug_options,
187                   const ConfigProto& config_proto, int64_t collective_graph_key,
188                   WorkerSession* session,
189                   DistributedFunctionLibraryRuntime* cluster_flr, Item* item);
190 
191   Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options,
192                                          Graph* graph, Device* device);
193 
194   TF_DISALLOW_COPY_AND_ASSIGN(GraphMgr);
195 };
196 
197 }  // end namespace tensorflow
198 
199 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
200