/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"

#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {

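// group_leader_ is left empty when this task is itself the group leader;
// otherwise it holds the leader's task name from the config, which is where
// step-id sequences are fetched from.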
RpcCollectiveExecutorMgr::RpcCollectiveExecutorMgr(
    const ConfigProto& config, const DeviceMgr* dev_mgr,
    std::unique_ptr<DeviceResolverDistributed> dev_resolver,
    std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
    std::unique_ptr<NcclCommunicatorInterface> nccl_communicator,
    WorkerCacheInterface* worker_cache, const string& task_name)
    : CollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
                            std::move(param_resolver),
                            std::move(nccl_communicator)),
      worker_cache_(worker_cache),
      task_name_(task_name) {
  group_leader_ = (task_name == config.experimental().collective_group_leader())
                      ? ""
                      : config.experimental().collective_group_leader();
}

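// The sequence table owns its GraphKeySequence values.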
RpcCollectiveExecutorMgr::~RpcCollectiveExecutorMgr() {
  for (auto it : sequence_table_) {
    delete it.second;
  }
}

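// Creates a per-step CollectiveExecutor whose remote-memory-access object
// reaches peer tasks through worker_cache_.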
CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64_t step_id) {
  CollectiveRemoteAccessDistributed* rma =
      new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
                                            work_queue_, worker_cache_, step_id,
                                            task_name_);
  return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_, work_queue_);
}

namespace {
// StepId must leave the most-significant 7 bits empty for future use.
static const int64_t kStepIdMask = (((1uLL << 56) - 1) | (1uLL << 56));

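// Returns a fresh random step id, constrained to the low 57 bits by
// kStepIdMask.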
int64_t NewRandomStepId() {
  int64_t step_id = random::New64();
  // Leave the most-significant 7 bits clear for future use.
  step_id &= kStepIdMask;
  return step_id;
}
}  // namespace

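// If this task is the group leader, reseed the sequence for graph_key with a
// fresh random step id locally.  Otherwise ask the leader for its current
// sequence over RPC and merge the response into sequence_table_.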
void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync(
    int64_t graph_key, const StatusCallback& done) {
  if (group_leader_.empty()) {
    mutex_lock l(sequence_mu_);
    GraphKeySequence* gks = nullptr;
    auto it = sequence_table_.find(graph_key);
    if (it == sequence_table_.end()) {
      gks = new GraphKeySequence(graph_key);
      sequence_table_[graph_key] = gks;
    } else {
      gks = it->second;
    }
    gks->next_step_id_ = NewRandomStepId();
    done(OkStatus());
  } else {
    WorkerInterface* wi = worker_cache_->GetOrCreateWorker(group_leader_);
    GetStepSequenceRequest* req = new GetStepSequenceRequest;
    GetStepSequenceResponse* resp = new GetStepSequenceResponse;
    req->add_graph_key(graph_key);
    wi->GetStepSequenceAsync(
        req, resp, [this, req, resp, done](const Status& s) {
          if (!s.ok()) {
            LOG(ERROR) << "Bad response [" << s
                       << "] from GetStepSequenceAsync call to "
                       << group_leader_;
            done(s);
          } else {
            done(UpdateStepSequences(*resp));
          }
          delete req;
          delete resp;
        });
  }
}

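// Leader-side handler for step-sequence requests.  Fails with an Internal
// error when called on a non-leader task; otherwise reports the next step id
// for every requested graph key, creating (and randomly seeding) sequences
// for keys not seen before.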
void RpcCollectiveExecutorMgr::GetStepSequenceAsync(
    const GetStepSequenceRequest* request, GetStepSequenceResponse* response,
    const StatusCallback& done) {
  if (!group_leader_.empty()) {
    LOG(ERROR) << "GetStepSequence called at non-group-leader";
    done(errors::Internal("GetStepSequenceAsync called at non-group-leader"));
  } else {
    mutex_lock l(sequence_mu_);
    for (int64_t graph_key : request->graph_key()) {
      auto it = sequence_table_.find(graph_key);
      GraphKeySequence* gks = nullptr;
      if (it == sequence_table_.end()) {
        gks = new GraphKeySequence(graph_key);
        gks->next_step_id_ = NewRandomStepId();
        sequence_table_[graph_key] = gks;
      } else {
        gks = it->second;
      }
      StepSequence* ss = response->add_step_sequence();
      ss->set_graph_key(graph_key);
      ss->set_next_step_id(gks->next_step_id_);
    }
    done(OkStatus());
  }
}

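// Merges a leader's response into the local table, overwriting the next step
// id of every graph key it mentions.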
Status RpcCollectiveExecutorMgr::UpdateStepSequences(
    const GetStepSequenceResponse& resp) {
  mutex_lock l(sequence_mu_);
  for (const StepSequence& ss : resp.step_sequence()) {
    GraphKeySequence* gks = nullptr;
    auto it = sequence_table_.find(ss.graph_key());
    if (it == sequence_table_.end()) {
      gks = new GraphKeySequence(ss.graph_key());
      sequence_table_[ss.graph_key()] = gks;
    } else {
      gks = it->second;
    }
    gks->next_step_id_ = ss.next_step_id();
  }
  return OkStatus();
}

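// Returns the cached next step id for graph_key, or
// CollectiveExecutor::kInvalidId if no sequence has been established yet.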
int64_t RpcCollectiveExecutorMgr::NextStepId(int64_t graph_key) {
  mutex_lock l(sequence_mu_);
  auto it = sequence_table_.find(graph_key);
  if (it != sequence_table_.end()) {
    return it->second->next_step_id_;
  }
  return CollectiveExecutor::kInvalidId;
}

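// If step_id is still the current head of the sequence, advance to the next
// id (wrapping within kStepIdMask).  A mismatch means the cached sequence is
// stale, so it is invalidated and must be refreshed before reuse.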
void RpcCollectiveExecutorMgr::RetireStepId(int64_t graph_key,
                                            int64_t step_id) {
  mutex_lock l(sequence_mu_);
  auto it = sequence_table_.find(graph_key);
  if (it != sequence_table_.end()) {
    if (step_id == it->second->next_step_id_) {
      it->second->next_step_id_ = (it->second->next_step_id_ + 1) & kStepIdMask;
    } else {
      it->second->next_step_id_ = CollectiveExecutor::kInvalidId;
    }
  } else {
    LOG(ERROR) << "Failed to find graph_key " << graph_key << " to retire.";
  }
}

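// Convenience factory for production use: builds the distributed device and
// param resolvers around worker_cache before assembling the manager.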
std::unique_ptr<RpcCollectiveExecutorMgr> CreateProdRpcCollectiveExecutorMgr(
    const ConfigProto& config, const DeviceMgr* device_mgr,
    std::unique_ptr<NcclCommunicatorInterface> nccl_communicator,
    WorkerCacheInterface* worker_cache, const string& default_worker_name) {
  auto dev_resolver = std::make_unique<DeviceResolverDistributed>(device_mgr);
  auto param_resolver = std::make_unique<CollectiveParamResolverDistributed>(
      config, device_mgr, dev_resolver.get(), nccl_communicator.get(),
      worker_cache, default_worker_name);
  return std::make_unique<RpcCollectiveExecutorMgr>(
      config, device_mgr, std::move(dev_resolver), std::move(param_resolver),
      std::move(nccl_communicator), worker_cache, default_worker_name);
}

}  // namespace tensorflow