1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_ 18 19 #include <string> 20 #include <unordered_map> 21 #include <unordered_set> 22 23 #include "absl/container/flat_hash_map.h" 24 #include "absl/container/flat_hash_set.h" 25 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" 26 #include "tensorflow/core/distributed_runtime/worker_env.h" 27 #include "tensorflow/core/distributed_runtime/worker_session.h" 28 #include "tensorflow/core/framework/control_flow.h" 29 #include "tensorflow/core/framework/rendezvous.h" 30 #include "tensorflow/core/lib/core/status.h" 31 #include "tensorflow/core/lib/hash/hash.h" 32 #include "tensorflow/core/platform/macros.h" 33 #include "tensorflow/core/platform/mutex.h" 34 #include "tensorflow/core/platform/thread_annotations.h" 35 #include "tensorflow/core/platform/types.h" 36 #include "tensorflow/core/util/device_name_utils.h" 37 38 namespace tensorflow { 39 40 class BaseRemoteRendezvous; 41 class BaseRecvTensorCall; 42 class CancellationManager; 43 44 // RendezvousMgr keeps track of a set of local rendezvous instances. 45 // All tensors sent by this worker are buffered in a RendezvousMgr 46 // until the tensor is received. Each global unique "step_id" 47 // corresponds to one local rendezvous instance managed by a 48 // RendezvousMgr. 49 // 50 // E.g., 51 // Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935); 52 // fork execution of a graph executor using "rendez" on thread 1; 53 // fork execution of another graph executor using "rendez" on thread 2; 54 // ... 55 // join threads 1 and 2; 56 // 57 // In the example above, execution in thread 1 and 2 communicates with 58 // each other by send/recv operations through `rendez`. 59 // 60 // Tensors sent and received through a rendezvous managed by this 61 // RendezvousMgr must have keys generated by Rendezvous::CreateKey(). 62 class BaseRendezvousMgr : public RendezvousMgrInterface { 63 public: 64 explicit BaseRendezvousMgr(const WorkerEnv* worker_env); 65 66 ~BaseRendezvousMgr() override; 67 68 // Returns Rendezvous supporting send and recv among workers in the 69 // "step_id". The caller takes ownership of one reference on the 70 // returned Rendezvous instance. 71 // 72 // Note: the caller must guarantee to eventually call Initialize on the 73 // returned RemoteRendezvous 74 RemoteRendezvous* Find(int64_t step_id) override; 75 76 // Finds the local rendezvous instance for the "step_id". Runs 77 // "done" when the tensor for "key" is produced or an error occurs. 78 // 79 // This method is used by the rpc handler of RecvTensor. 80 void RecvLocalAsync(int64_t step_id, const Rendezvous::ParsedKey& parsed, 81 Rendezvous::DoneCallback done) override; 82 83 // Synchronous wrapper for RecvLocalAsync. 84 Status RecvLocal(int64_t step_id, const Rendezvous::ParsedKey& parsed, 85 Tensor* val, bool* is_dead) override; 86 87 // Removes rendezvous for "step_id". 88 // 89 // TODO(zhifengc): Have a background thread in worker that 90 // periodically calls CleanupAll(). 91 void Cleanup(int64_t step_id) override; 92 93 // Remove all rendezvous instances owned by the rendezvous_mgr. 94 void CleanupAll() override; 95 96 protected: 97 virtual BaseRemoteRendezvous* Create(int64_t step_id, 98 const WorkerEnv* worker_env) = 0; 99 100 private: 101 // Maps step_id to rendezvous. 102 typedef absl::flat_hash_map<int64_t, BaseRemoteRendezvous*> Table; 103 104 // Not owned. 105 const WorkerEnv* const worker_env_; 106 107 mutex mu_; 108 Table table_ TF_GUARDED_BY(mu_); 109 110 BaseRemoteRendezvous* FindOrCreate(int64_t step_id); 111 112 TF_DISALLOW_COPY_AND_ASSIGN(BaseRendezvousMgr); 113 }; 114 115 // RemoteRendezvous is a Rendezvous which can handle either 116 // the producer or consumer being in a remote process. 117 // 118 // Buffering of Tensor values is delegated to a "local" Rendezvous 119 // obtained from NewLocalRendezvous(). This class just adds 120 // functionality to coordinate with remote workers. 121 class BaseRemoteRendezvous : public RemoteRendezvous { 122 public: 123 BaseRemoteRendezvous(const WorkerEnv* env, int64_t step_id); 124 125 // Upgrades the BaseRemoteRendezvous to full initialization. 126 Status Initialize(WorkerSession* session) override; 127 SetRemoteEagerContextDefault()128 void SetRemoteEagerContextDefault() override { 129 remote_eager_context_default_ = true; 130 } IsRemoteEagerContextDefault()131 bool IsRemoteEagerContextDefault() override { 132 return remote_eager_context_default_; 133 } 134 135 // Forwards to local_, where the Tensor "val" will be buffered and 136 // any waiting callback stored. 137 Status Send(const ParsedKey& key, const Rendezvous::Args& args, 138 const Tensor& val, const bool is_dead) override; 139 140 // This method is called only by the RecvOp. It tests to see 141 // whether the value will be produced by a local or remote device 142 // and handles accordingly. In the local case it forwards to 143 // local_, in the remote case it initiates an RPC request. 144 void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, 145 DoneCallback done) override; 146 147 void StartAbort(const Status& status) override; 148 149 // This method is called only by the local Worker, forwarded through 150 // the same method on RendezvousMgr. This occurs when the Worker 151 // has received a RecvTensor request, either locally or over the 152 // network. In either case it needs to retrieve a locally buffered 153 // value from local_, and give it to its caller. 154 // 155 // Runs "done" as soon as the tensor for "parsed" is available or an error 156 // is detected. 157 // 158 // REQUIRES: "parsed" is one that will be Saved into the local rendezvous. 159 void RecvLocalAsync(const ParsedKey& parsed, DoneCallback done); 160 161 protected: 162 virtual void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, 163 const Rendezvous::Args& args, 164 DoneCallback done) = 0; 165 166 // Returns true if "src" and "dst" are located in the same worker, 167 // and hence may use a local rendezvous. 168 virtual bool IsSameWorker(DeviceNameUtils::ParsedName src, 169 DeviceNameUtils::ParsedName dst); 170 171 // If aborted, aborts "call". Otherwise, adds "call" into calls_. 172 void RegisterCall(BaseRecvTensorCall* call, const Rendezvous::Args& args); 173 174 // Removes "call" from calls_ if "call" is in calls_. 175 void DeregisterCall(BaseRecvTensorCall* call, const Rendezvous::Args& args); 176 177 WorkerSession* session(); 178 179 bool is_initialized(); 180 181 ~BaseRemoteRendezvous() override; 182 183 const WorkerEnv* const env_; // Not owned. 184 const int64_t step_id_; 185 186 private: 187 Rendezvous* local_; // Owns a Ref on this object. 188 // Indicates whether this remote rendezvous instance is used as the default 189 // rendezvous for remote eager op-by-op execution. Errors in eager op-by-op 190 // execution should not abort the rendezvous since it is a context-wide 191 // instance and needs to be reused; instead, the errors are propagated through 192 // eager executors. 193 bool remote_eager_context_default_ = false; 194 195 mutable mutex mu_; 196 mutable mutex calls_mu_; 197 198 // Status given by StartAbort() if any. 199 Status status_ TF_GUARDED_BY(mu_); 200 201 WorkerSession* session_ TF_GUARDED_BY(mu_); // Not owned. 202 203 // Data structures to handle calls when partially initialized. 204 struct DeferredCall { 205 const ParsedKey parsed; 206 DoneCallback done; 207 208 DeferredCall(const ParsedKey& parsed, DoneCallback done); 209 }; 210 std::vector<DeferredCall> deferred_calls_ TF_GUARDED_BY(mu_); 211 212 // "CancellationToken" is stored here so that when there's no active 213 // RecvTensorCalls, we can de-register the callback in the cancellation 214 // manager. 215 // 216 // Note: pointer to CancellationManager can be nullptr in certain use cases. 217 absl::flat_hash_map< 218 CancellationManager*, 219 std::pair<CancellationToken, absl::flat_hash_set<BaseRecvTensorCall*>>> 220 calls_ TF_GUARDED_BY(calls_mu_); 221 is_initialized_locked()222 bool is_initialized_locked() TF_SHARED_LOCKS_REQUIRED(mu_) { 223 return session_ != nullptr; 224 } 225 226 // If "is_src" is true, checks that the rendezvous key "parsed"'s 227 // source is in this process. If "is_src" is false, checks that the 228 // rendezvous key "parsed"'s destination is in this process. 229 Status ValidateDevices(const Rendezvous::ParsedKey& parsed, bool is_src); 230 231 // Callback handling the case when a rendezvous has been 232 // accomplished in local_ and the consumer is local to this process. 233 // Tensor "in" will be copied into "out". The key "parsed" encodes 234 // the src and dst devices. 235 void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed, 236 const Rendezvous::Args& in_args, 237 const Rendezvous::Args& out_args, const Tensor& in, 238 Tensor* out, StatusCallback done); 239 240 // Must be called only if fully initialized. 241 void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done); 242 243 TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous); 244 }; 245 246 class BaseRecvTensorCall { 247 public: BaseRecvTensorCall()248 BaseRecvTensorCall() {} ~BaseRecvTensorCall()249 virtual ~BaseRecvTensorCall() {} 250 251 virtual void Start(std::function<void()> recv_done) = 0; 252 253 virtual void StartAbort(const Status& s) = 0; 254 255 virtual Status status() const = 0; 256 257 private: 258 TF_DISALLOW_COPY_AND_ASSIGN(BaseRecvTensorCall); 259 }; 260 261 } // end namespace tensorflow 262 263 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_ 264