/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"

#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
RpcCollectiveExecutorMgr::RpcCollectiveExecutorMgr(
    const ConfigProto& config, const DeviceMgr* dev_mgr,
    std::unique_ptr<DeviceResolverDistributed> dev_resolver,
    std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
    std::unique_ptr<NcclCommunicatorInterface> nccl_communicator,
    WorkerCacheInterface* worker_cache, const string& task_name)
    : CollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
                            std::move(param_resolver),
                            std::move(nccl_communicator)),
      worker_cache_(worker_cache),
      task_name_(task_name) {
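  // An empty group_leader_ means this task is itself the group leader and can
  // generate step-id sequences locally; otherwise it names the leader task to
  // query over RPC.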
  group_leader_ = (task_name == config.experimental().collective_group_leader())
                      ? ""
                      : config.experimental().collective_group_leader();
}

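// The sequence table owns its GraphKeySequence entries, so delete them here.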
RpcCollectiveExecutorMgr::~RpcCollectiveExecutorMgr() {
  for (auto it : sequence_table_) {
    delete it.second;
  }
}

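// Creates a per-step CollectiveExecutor whose remote-memory-access object can
// reach peer devices on other tasks through the worker cache.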
CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64_t step_id) {
  CollectiveRemoteAccessDistributed* rma =
      new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
                                            work_queue_, worker_cache_, step_id,
                                            task_name_);
  return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_, work_queue_);
}

namespace {
// StepId must leave the most-significant 7 bits empty for future use.
static const int64_t kStepIdMask = (((1uLL << 56) - 1) | (1uLL << 56));

int64_t NewRandomStepId() {
  int64_t step_id = random::New64();
  // Leave the most-significant 7 bits clear for future use.
  step_id &= kStepIdMask;
  return step_id;
}
}  // namespace

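// Resets the step-id sequence for graph_key. The group leader generates a new
// random step id locally; non-leader tasks request the current sequence from
// the leader via GetStepSequenceAsync.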
void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync(
    int64_t graph_key, const StatusCallback& done) {
  if (group_leader_.empty()) {
    mutex_lock l(sequence_mu_);
    GraphKeySequence* gks = nullptr;
    auto it = sequence_table_.find(graph_key);
    if (it == sequence_table_.end()) {
      gks = new GraphKeySequence(graph_key);
      sequence_table_[graph_key] = gks;
    } else {
      gks = it->second;
    }
    gks->next_step_id_ = NewRandomStepId();
    done(OkStatus());
  } else {
    WorkerInterface* wi = worker_cache_->GetOrCreateWorker(group_leader_);
    GetStepSequenceRequest* req = new GetStepSequenceRequest;
    GetStepSequenceResponse* resp = new GetStepSequenceResponse;
    req->add_graph_key(graph_key);
    wi->GetStepSequenceAsync(
        req, resp, [this, req, resp, done](const Status& s) {
          if (!s.ok()) {
            LOG(ERROR) << "Bad response [" << s
                       << "] from GetStepSequenceAsync call to "
                       << group_leader_;
            done(s);
          } else {
            done(UpdateStepSequences(*resp));
          }
          delete req;
          delete resp;
        });
  }
}

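// Services step-sequence requests from other tasks. Only valid on the group
// leader; graph keys not yet in the table get a freshly generated step id.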
void RpcCollectiveExecutorMgr::GetStepSequenceAsync(
    const GetStepSequenceRequest* request, GetStepSequenceResponse* response,
    const StatusCallback& done) {
  if (!group_leader_.empty()) {
    LOG(ERROR) << "GetStepSequence called at non-group-leader";
    done(errors::Internal("GetStepSequenceAsync called at non-group-leader"));
  } else {
    mutex_lock l(sequence_mu_);
    for (int64_t graph_key : request->graph_key()) {
      auto it = sequence_table_.find(graph_key);
      GraphKeySequence* gks = nullptr;
      if (it == sequence_table_.end()) {
        gks = new GraphKeySequence(graph_key);
        gks->next_step_id_ = NewRandomStepId();
        sequence_table_[graph_key] = gks;
      } else {
        gks = it->second;
      }
      StepSequence* ss = response->add_step_sequence();
      ss->set_graph_key(graph_key);
      ss->set_next_step_id(gks->next_step_id_);
    }
    done(OkStatus());
  }
}

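// Copies the step-id sequences returned by the group leader into the local
// sequence table.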
Status RpcCollectiveExecutorMgr::UpdateStepSequences(
    const GetStepSequenceResponse& resp) {
  mutex_lock l(sequence_mu_);
  for (const StepSequence& ss : resp.step_sequence()) {
    GraphKeySequence* gks = nullptr;
    auto it = sequence_table_.find(ss.graph_key());
    if (it == sequence_table_.end()) {
      gks = new GraphKeySequence(ss.graph_key());
      sequence_table_[ss.graph_key()] = gks;
    } else {
      gks = it->second;
    }
    gks->next_step_id_ = ss.next_step_id();
  }
  return OkStatus();
}

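// Returns the cached next step id for graph_key, or kInvalidId if no sequence
// has been established yet.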
int64_t RpcCollectiveExecutorMgr::NextStepId(int64_t graph_key) {
  mutex_lock l(sequence_mu_);
  auto it = sequence_table_.find(graph_key);
  if (it != sequence_table_.end()) {
    return it->second->next_step_id_;
  }
  return CollectiveExecutor::kInvalidId;
}

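// Marks step_id as used for graph_key. If it matches the cached next step id,
// the sequence advances by one (masked with kStepIdMask); otherwise the
// sequence is invalidated so it must be refreshed before reuse.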
void RpcCollectiveExecutorMgr::RetireStepId(int64_t graph_key,
                                            int64_t step_id) {
  mutex_lock l(sequence_mu_);
  auto it = sequence_table_.find(graph_key);
  if (it != sequence_table_.end()) {
    if (step_id == it->second->next_step_id_) {
      it->second->next_step_id_ = (it->second->next_step_id_ + 1) & kStepIdMask;
    } else {
      it->second->next_step_id_ = CollectiveExecutor::kInvalidId;
    }
  } else {
    LOG(ERROR) << "Failed to find graph_key " << graph_key << " to retire.";
  }
}

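// Convenience factory that wires up the distributed device and parameter
// resolvers before constructing an RpcCollectiveExecutorMgr.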
std::unique_ptr<RpcCollectiveExecutorMgr> CreateProdRpcCollectiveExecutorMgr(
    const ConfigProto& config, const DeviceMgr* device_mgr,
    std::unique_ptr<NcclCommunicatorInterface> nccl_communicator,
    WorkerCacheInterface* worker_cache, const string& default_worker_name) {
  auto dev_resolver = std::make_unique<DeviceResolverDistributed>(device_mgr);
  auto param_resolver = std::make_unique<CollectiveParamResolverDistributed>(
      config, device_mgr, dev_resolver.get(), nccl_communicator.get(),
      worker_cache, default_worker_name);
  return std::make_unique<RpcCollectiveExecutorMgr>(
      config, device_mgr, std::move(dev_resolver), std::move(param_resolver),
      std::move(nccl_communicator), worker_cache, default_worker_name);
}

}  // namespace tensorflow