1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
16 
17 #include "absl/strings/escaping.h"
18 #include "tensorflow/core/common_runtime/device.h"
19 #include "tensorflow/core/common_runtime/device_mgr.h"
20 #include "tensorflow/core/distributed_runtime/cancellable_call.h"
21 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
22 #include "tensorflow/core/distributed_runtime/worker_cache.h"
23 #include "tensorflow/core/framework/device_attributes.pb.h"
24 #include "tensorflow/core/platform/errors.h"
25 #include "tensorflow/core/platform/status.h"
26 #include "tensorflow/core/protobuf/config.pb.h"
27 #include "tensorflow/core/util/device_name_utils.h"
28 
29 namespace tensorflow {
30 namespace {
31 
32 class CompleteGroupCall : public CancellableCall {
33  public:
CompleteGroupCall(const CollGroupParams & group,const DeviceAttributes & device,CancellationManager * cancel_mgr,const string & remote_worker,WorkerCacheInterface * wc)34   CompleteGroupCall(const CollGroupParams& group,
35                     const DeviceAttributes& device,
36                     CancellationManager* cancel_mgr,
37                     const string& remote_worker, WorkerCacheInterface* wc)
38       : CancellableCall(cancel_mgr, remote_worker, wc) {
39     req_.set_group_key(group.group_key);
40     req_.set_group_size(group.group_size);
41     req_.set_device_type(group.device_type.type_string());
42     *req_.mutable_device_attributes() = device;
43   }
~CompleteGroupCall()44   ~CompleteGroupCall() override {}
45 
IssueCall(const StatusCallback & done)46   void IssueCall(const StatusCallback& done) override {
47     wi_->CompleteGroupAsync(&opts_, &req_, &resp_, done);
48   }
49 
50   CompleteGroupRequest req_;
51   CompleteGroupResponse resp_;
52 };
53 
54 class CompleteInstanceCall : public CancellableCall {
55  public:
CompleteInstanceCall(const CollGroupParams & group,const CollInstanceParams & instance,const string & node_name,const string & device_name,bool is_source,CancellationManager * cancel_mgr,const string & remote_worker,WorkerCacheInterface * wc)56   CompleteInstanceCall(const CollGroupParams& group,
57                        const CollInstanceParams& instance,
58                        const string& node_name, const string& device_name,
59                        bool is_source, CancellationManager* cancel_mgr,
60                        const string& remote_worker, WorkerCacheInterface* wc)
61       : CancellableCall(cancel_mgr, remote_worker, wc) {
62     req_.set_name(node_name);
63     req_.set_type(instance.type);
64     req_.set_data_type(instance.data_type);
65     instance.shape.AsProto(req_.mutable_shape());
66     req_.set_group_key(group.group_key);
67     req_.set_group_size(group.group_size);
68     req_.set_instance_key(instance.instance_key);
69     req_.set_device_type(group.device_type.type_string());
70     for (int32_t offset : instance.impl_details.subdiv_offsets) {
71       req_.add_subdiv_offset(offset);
72     }
73     req_.set_device(device_name);
74     req_.set_is_source(is_source);
75   }
76 
~CompleteInstanceCall()77   ~CompleteInstanceCall() override {}
78 
IssueCall(const StatusCallback & done)79   void IssueCall(const StatusCallback& done) override {
80     wi_->CompleteInstanceAsync(&opts_, &req_, &resp_, done);
81   }
82 
83   CompleteInstanceRequest req_;
84   CompleteInstanceResponse resp_;
85 };
86 
87 }  // namespace
88 
CollectiveParamResolverDistributed(const ConfigProto & config,const DeviceMgr * dev_mgr,DeviceResolverDistributed * dev_resolver,NcclCommunicatorInterface * nccl_communicator,WorkerCacheInterface * worker_cache,const string & task_name)89 CollectiveParamResolverDistributed::CollectiveParamResolverDistributed(
90     const ConfigProto& config, const DeviceMgr* dev_mgr,
91     DeviceResolverDistributed* dev_resolver,
92     NcclCommunicatorInterface* nccl_communicator,
93     WorkerCacheInterface* worker_cache, const string& task_name)
94     : CollectiveParamResolverLocal(config, dev_mgr, dev_resolver,
95                                    nccl_communicator, task_name),
96       worker_cache_(worker_cache),
97       group_leader_(task_name == config.experimental().collective_group_leader()
98                         ? ""
99                         : config.experimental().collective_group_leader()) {
100   VLOG(1) << "CompleteParamResolverDistributed ctor task={" << task_name
101           << "} config.collective_group_leader={"
102           << config.experimental().collective_group_leader() << "}"
103           << " config.collective_nccl={"
104           << config.experimental().collective_nccl() << "}";
105 }
106 
// Completes both group and instance params for `device`, invoking `done`
// with the final status. Group resolution may involve an RPC to the group
// leader; instance resolution follows once group members are known.
void CollectiveParamResolverDistributed::CompleteParamsAsync(
    const DeviceAttributes& device, CollectiveParams* cp,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  VLOG(1) << "CompleteParams distributed " << device.name() << " for " << cp
          << ": " << cp->ToString();
  if (cp->run_group_initialization) {
    // Resolve the group first, then publish the resolved members' device
    // attributes to the device resolver before completing the instance.
    CompleteGroupDistributed(
        device, &cp->group, cancel_mgr,
        [this, device, cp, cancel_mgr, done](Status s) {
          if (s.ok()) {
            std::vector<DeviceAttributes> devices;
            devices.reserve(cp->group.group_size);
            for (const CollGroupMember& m : cp->group.members) {
              devices.push_back(m.device);
            }
            s = dev_resolver_->UpdateDeviceAttributes(devices);
          }
          if (s.ok()) {
            CompleteInstanceDistributed(device.name(), cp, cancel_mgr, done);
          } else {
            done(s);
          }
        });
  } else {
    // For Collective V3 ops, group is already initialized. Fetch attributes
    // for the already initialized group to pass to instance initialization.
    auto s = LookupGroup(cp->group.group_key, &cp->group);
    if (s.ok()) {
      CompleteInstanceDistributed(device.name(), cp, cancel_mgr, done);
    } else {
      done(s);
    }
  }
}
141 
// Public entry point for group resolution; simply forwards to the
// distributed implementation (local when this task is the leader).
void CollectiveParamResolverDistributed::CompleteGroupAsync(
    const DeviceAttributes& device, CollGroupParams* group_params,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  CompleteGroupDistributed(device, group_params, cancel_mgr, done);
}
147 
CompleteInstanceAsync(const CompleteInstanceRequest * request,CompleteInstanceResponse * response,CancellationManager * cancel_mgr,const StatusCallback & done)148 void CollectiveParamResolverDistributed::CompleteInstanceAsync(
149     const CompleteInstanceRequest* request, CompleteInstanceResponse* response,
150     CancellationManager* cancel_mgr, const StatusCallback& done) {
151   GroupRec* gr = GetCachedGroup(request->group_key());
152   if (gr == nullptr) {
153     done(errors::FailedPrecondition(
154         "group ", request->group_key(),
155         " not found. This normally means the server has restarted"));
156     return;
157   }
158   CollectiveParams* cp = new CollectiveParams;
159   {
160     mutex_lock l(gr->mu);
161     if (!gr->status.ok()) {
162       done(gr->status);
163       return;
164     } else if (gr->group.members.size() != gr->group.group_size) {
165       done(errors::FailedPrecondition(
166           "group ", request->group_key(),
167           " failed to resolve. This normally means the server has restarted"));
168       return;
169     }
170     cp->group = gr->group;
171   }
172   cp->name = request->name();
173   cp->instance.type = CollectiveType(request->type());
174   cp->instance.instance_key = request->instance_key();
175   cp->instance.data_type = request->data_type();
176   cp->instance.shape = TensorShape(request->shape());
177   cp->is_source = request->is_source();
178   for (int32_t offset : request->subdiv_offset()) {
179     cp->instance.impl_details.subdiv_offsets.push_back(offset);
180   }
181   StatusCallback done_and_cleanup = [cp, done](const Status& s) {
182     done(s);
183     cp->Unref();
184   };
185   CompleteInstanceDistributed(
186       request->device(), cp, cancel_mgr,
187       [this, cp, response, done_and_cleanup](Status status) {
188         if (status.ok()) {
189           // Now source_rank should be known, so retrieve it.
190           bool created_irec;
191           InstanceRec* ir = GetOrCreateInstanceRec(cp, &created_irec);
192           {
193             mutex_lock l(ir->mu);
194             status = ir->status;
195             if (ir->status.ok()) {
196               response->set_instance_key(cp->instance.instance_key);
197               response->set_source_rank(ir->source_rank);
198             }
199           }
200         }
201         done_and_cleanup(status);
202       });
203 }
204 
205 CollectiveParamResolverDistributed::GroupRec*
GetCachedGroup(int32_t group_key)206 CollectiveParamResolverDistributed::GetCachedGroup(int32_t group_key) {
207   mutex_lock l(group_mu_);
208   auto it = group_table_.find(group_key);
209   if (it == group_table_.end()) {
210     return nullptr;
211   }
212   return it->second.get();
213 }
214 
// Installs the leader's CompleteGroupResponse into the local group cache.
// Validates the response, builds a fresh GroupRec, and inserts it if the
// group is not already cached; if it is, verifies that the cached
// communicator_key matches the response. Returns an Internal error on any
// inconsistency.
Status CollectiveParamResolverDistributed::UpdateGroupCache(
    const CompleteGroupResponse& resp) {
  // Build a new record from resp.
  std::unique_ptr<GroupRec> gr(new GroupRec);
  {
    mutex_lock grl(gr->mu);
    gr->group.device_type = DeviceType(resp.device_type());
    gr->group.group_key = resp.group_key();
    gr->group.group_size = resp.group_size();
    gr->group.num_tasks = resp.num_tasks();
    if (resp.device_attributes().empty()) {
      return errors::Internal(
          "CompleteGroupResponse device_attributes is empty. Make sure you're "
          "running the same version of Tensorflow on all workers.");
    }
    if (resp.device_attributes_size() != gr->group.group_size) {
      return errors::Internal(
          "CompleteGroupResponse group_size doesn't match device_name list");
    }
    gr->group.members.reserve(resp.device_attributes().size());
    for (const DeviceAttributes& device : resp.device_attributes()) {
      CollGroupMember member;
      member.device = device;
      gr->group.members.push_back(std::move(member));
      gr->incarnations_by_device_name[device.name()] = device.incarnation();
    }
    gr->group.runtime_details.communicator_key = resp.communicator_key();
    FinishGroup(gr.get());
  }
  GroupRec* previous_gr = nullptr;
  {
    // Group membership should never change. Once a record is in group_table_
    // it never gets removed.
    mutex_lock l(group_mu_);
    auto it = group_table_.find(resp.group_key());
    if (it == group_table_.end()) {
      VLOG(2) << "UpdateGroupCache: communicator_key="
              << absl::CEscape(resp.communicator_key());
      group_table_[gr->group.group_key] = std::move(gr);
    } else {
      previous_gr = it->second.get();
    }
  }
  // If the group was already cached, the leader's answer must agree with it;
  // a mismatched communicator_key indicates an inconsistent cluster state.
  if (previous_gr != nullptr) {
    mutex_lock grl(previous_gr->mu);
    if (previous_gr->group.runtime_details.communicator_key !=
        resp.communicator_key()) {
      return errors::Internal(
          "UpdateGroupCache: CompleteGroupResponse for group ",
          resp.group_key(),
          " gives communicator_key=", absl::CEscape(resp.communicator_key()),
          " but cache already holds communicator_key=",
          absl::CEscape(previous_gr->group.runtime_details.communicator_key));
    }
  }
  return OkStatus();
}
272 
// Resolves the group for `device`. Three cases:
//  1. This task is the leader: resolve locally.
//  2. Group not yet cached here: RPC the leader, install the response into
//     the cache, then resolve locally.
//  3. Group already cached: resolve locally.
void CollectiveParamResolverDistributed::CompleteGroupDistributed(
    const DeviceAttributes& device, CollGroupParams* group_params,
    CancellationManager* cancel_mgr, const StatusCallback& done) {
  VLOG(1) << "CompleteGroupDistributed group_key=" << group_params->group_key
          << " dev: " << device.name()
          << " is_leader=" << (group_leader_.empty());
  if (group_leader_.empty()) {
    // This is the group leader, so resolution is local.
    return CompleteGroupLocal(device, group_params, cancel_mgr, done);
  } else if (GetCachedGroup(group_params->group_key) == nullptr) {
    // Need to update Group cache from the leader.
    CompleteGroupCall* call = new CompleteGroupCall(
        *group_params, device, cancel_mgr, group_leader_, worker_cache_);
    // Register the in-flight RPC for cancellation if the resolver aborts.
    CancellationToken abortion_token =
        abortion_cancel_mgr_.get_cancellation_token();
    bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
        abortion_token, [call] { call->Cancel(); });
    if (already_aborted) {
      done(errors::Cancelled("collective ops already aborted"));
      delete call;
      return;
    }
    // `call` owns req_/resp_ and is deleted once the callback completes.
    call->Start([this, device, group_params, call, cancel_mgr, abortion_token,
                 done](const Status& s) {
      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
      if (s.ok()) {
        Status status = UpdateGroupCache(call->resp_);
        if (status.ok()) {
          CompleteGroupLocal(device, group_params, cancel_mgr, done);
        } else {
          done(status);
        }
      } else {
        done(s);
      }
      delete call;
    });
    return;
  } else {
    return CompleteGroupLocal(device, group_params, cancel_mgr, done);
  }
}
315 
InstanceIsCached(int32_t group_key,int32_t instance_key)316 bool CollectiveParamResolverDistributed::InstanceIsCached(
317     int32_t group_key, int32_t instance_key) {
318   mutex_lock l(instance_mu_);
319   auto group_it = instance_table_.find(group_key);
320   if (group_it == instance_table_.end()) {
321     return false;
322   }
323   auto instance_it = group_it->second.find(instance_key);
324   return instance_it != group_it->second.end();
325 }
326 
UpdateInstanceCache(CollectiveParams * cp,const CompleteInstanceResponse & resp)327 Status CollectiveParamResolverDistributed::UpdateInstanceCache(
328     CollectiveParams* cp, const CompleteInstanceResponse& resp) {
329   int32_t source_rank = resp.source_rank();
330   bool created_irec;
331   InstanceRec* ir = GetOrCreateInstanceRec(cp, &created_irec);
332   mutex_lock l(ir->mu);
333   if (!ir->status.ok()) {
334     return ir->status;
335   }
336   if (ir->source_rank != source_rank) {
337     if (ir->source_rank >= 0) {
338       ir->status = errors::Internal(
339           "UpdateInstanceCache: CompleteInstanceResponse for instance ",
340           cp->instance.instance_key, " gives source_rank=", source_rank,
341           " but cache already holds value=", ir->source_rank);
342       return ir->status;
343     }
344     ir->source_rank = source_rank;
345   }
346   if (ir->known_count < cp->group.group_size) {
347     ir->known_count = cp->group.group_size;
348     const int ir_known_size = ir->known.size();
349     if (ir_known_size != cp->group.group_size) {
350       ir->status = errors::Internal(
351           "UpdateInstanceCache:: CompleteInstanceResponse for instance ",
352           cp->instance.instance_key, " has known.size()=", ir->known.size(),
353           " < group_size=", cp->group.group_size);
354       return ir->status;
355     }
356     for (int i = 0; i < ir_known_size; ++i) {
357       ir->known[i] = true;
358     }
359   }
360   return ir->status;
361 }
362 
// Resolves instance params for `device`. Local when this task is the leader
// or the instance is already cached; otherwise RPCs the leader, folds the
// response into the instance cache, then completes locally.
void CollectiveParamResolverDistributed::CompleteInstanceDistributed(
    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
    const StatusCallback& done) {
  if (group_leader_.empty()) {
    // This is the group leader so resolution is local.
    return CompleteInstanceLocal(device, cp, done);
  } else if (InstanceIsCached(cp->group.group_key, cp->instance.instance_key)) {
    return CompleteInstanceLocal(device, cp, done);
  } else {
    CompleteInstanceCall* call = new CompleteInstanceCall(
        cp->group, cp->instance, cp->name, device, cp->is_source, cancel_mgr,
        group_leader_, worker_cache_);
    // Register the in-flight RPC for cancellation if the resolver aborts.
    CancellationToken abortion_token =
        abortion_cancel_mgr_.get_cancellation_token();
    bool already_aborted = !abortion_cancel_mgr_.RegisterCallback(
        abortion_token, [call] { call->Cancel(); });
    if (already_aborted) {
      done(errors::Cancelled("collective ops already aborted"));
      delete call;
      return;
    }
    // `call` owns req_/resp_ and is deleted once the callback completes.
    call->Start([this, device, cp, call, abortion_token, done](Status s) {
      abortion_cancel_mgr_.DeregisterCallback(abortion_token);
      if (s.ok()) {
        s = UpdateInstanceCache(cp, call->resp_);
      }
      if (s.ok()) {
        CompleteInstanceLocal(device, cp, done);
      } else {
        done(s);
      }
      delete call;
    });
    return;
  }
}
399 
StartAbort(const Status & s)400 void CollectiveParamResolverDistributed::StartAbort(const Status& s) {
401   {
402     mutex_lock l(status_mu_);
403     if (!status_.ok()) {
404       VLOG(2) << "CollectiveParamResolverDistributed already aborted. Ignoring "
405                  "subsequent abortion with status: "
406               << s;
407       return;
408     }
409     status_ = s;
410   }
411   StartAbortLocal(s);
412   abortion_cancel_mgr_.StartCancel();
413 }
414 
415 }  // namespace tensorflow
416