/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/worker.h"

#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/error_payloads.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/device_profiler_session.h"
#include "tensorflow/core/protobuf/distributed_runtime_payloads.pb.h"

namespace tensorflow {

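// Remembers the ids of the most recent requests (100000 here) so that
// retried RPCs, e.g. RunGraph, can be detected and rejected as duplicates.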
Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) {
  // Enable log history collection in StatusGroup so that recent warning and
  // error log messages will be attached to the root error status to be
  // forwarded to the master.
  StatusGroup::ConfigureLogHistory();
}

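// Reports the attributes of every device known to this worker's DeviceMgr.
// Each DeviceAttributes proto is swapped into the response to avoid a copy.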
void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                            GetStatusResponse* response, bool fail_fast,
                            StatusCallback done) {
  const DeviceMgr* dm = env_->device_mgr;
  std::vector<DeviceAttributes> devices;
  dm->ListDeviceAttributes(&devices);
  response->mutable_device_attributes()->Reserve(devices.size());
  for (auto& d : devices) {
    response->add_device_attributes()->Swap(&d);
  }
  done(OkStatus());
}

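// Creates a WorkerSession for the given session handle by delegating to the
// SessionMgr.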
void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                      CreateWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->CreateSession(
      request->session_handle(), request->server_def(),
      request->cluster_device_attributes(), request->isolate_session_state(),
      request->master_task(), request->master_incarnation());
  done(s);
}

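// Tears down the WorkerSession identified by the request's session handle.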
void Worker::DeleteWorkerSessionAsync(CallOptions* opts,
                                      const DeleteWorkerSessionRequest* request,
                                      DeleteWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->DeleteSession(request->session_handle());
  done(s);
}

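// Registers a graph partition with this worker's GraphMgr and returns the
// graph handle used by subsequent RunGraph calls. Falls back to the legacy
// session when the master did not explicitly create a worker session.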
void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
                                RegisterGraphResponse* response,
                                StatusCallback done) {
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (s.ok()) {
    s = session->graph_mgr()->Register(
        request->session_handle(), request->graph_def(),
        request->graph_options(), request->debug_options(),
        request->config_proto(), request->collective_graph_key(), session.get(),
        session->cluster_flr(), response->mutable_graph_handle());
  }
  done(s);
}

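// Removes a previously registered graph from the GraphMgr, using the same
// legacy-session fallback as RegisterGraphAsync.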
void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
                                  DeregisterGraphResponse* response,
                                  StatusCallback done) {
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (s.ok()) {
    s = session->graph_mgr()->Deregister(request->graph_handle());
  }

  done(s);
}

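// Aborts the rendezvous for `step_id`, unblocking any send/recv operations
// that are still pending for that step.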
void Worker::AbortStep(int64_t step_id) {
  RemoteRendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
  // Do not abort if it's a context-global instance for eager op-by-op
  // execution.
  if (rendez->IsRemoteEagerContextDefault()) return;
  SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
    // Delay a bit before aborting the step. This way, the root-cause error
    // may reach the client before the abort error generated by this
    // cancellation.
    rendez->StartAbort(errors::Aborted("Step ", step_id,
                                       " cancelled.  Cancelling rendezvous."));
    rendez->Unref();
  });
}

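// Stages the inputs and outputs of a RunGraph call: copies each sent tensor
// into `in` and seeds `out` with a placeholder entry for every recv key.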
Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req,
                               GraphMgr::NamedTensors* in,
                               GraphMgr::NamedTensors* out) {
  static Tensor empty_tensor(DT_FLOAT);
  if (req->num_sends() > 0) {
    Tensor val;
    for (size_t i = 0; i < req->num_sends(); ++i) {
      TF_RETURN_IF_ERROR(req->SendValue(i, &val));
      in->insert({req->send_key(i), val});
    }
  }
  for (size_t i = 0; i < req->num_recvs(); ++i) {
    out->insert({req->recv_key(i), empty_tensor});
  }
  return OkStatus();
}

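// Entry point for RunGraph requests. If the caller asked for errors to be
// stored in the response body, wraps `done` so that the RPC itself reports
// OK and the real status travels in the response, then dispatches to the
// full or partial run path.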
void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                           MutableRunGraphResponseWrapper* response,
                           StatusCallback done) {
  if (request->store_errors_in_response_body()) {
    done = [response, done](const Status& status) {
      response->set_status(status);
      done(OkStatus());
    };
  }
  if (request->is_partial()) {
    DoPartialRunGraph(opts, request, response, std::move(done));
  } else {
    DoRunGraph(opts, request, response, std::move(done));
  }
}

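// The base Worker serves in-process callers, so it hands out in-memory
// request/response wrappers; RPC transports may override these factories
// with protobuf-backed implementations.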
MutableRunGraphRequestWrapper* Worker::CreateRunGraphRequest() {
  return new InMemoryRunGraphRequest;
}

MutableRunGraphResponseWrapper* Worker::CreateRunGraphResponse() {
  return new InMemoryRunGraphResponse;
}

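// Runs a full (non-partial) step: executes the registered graph, receives
// the requested outputs, and tears down the cancellation state, stats
// collector, and profiler session regardless of the outcome.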
void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                        MutableRunGraphResponseWrapper* response,
                        StatusCallback done) {
  const int64_t step_id = request->step_id();
  TRACEPRINTF("RunGraph: %lld", step_id);
  Status s = recent_request_ids_.TrackUnique(request->request_id(),
                                             "RunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }

  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }
  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  if (!s.ok()) {
    delete out;
    done(s);
    return;
  }
  StepStatsCollector* collector = nullptr;
  if (request->exec_opts().report_tensor_allocations_upon_oom() ||
      request->exec_opts().record_timeline() ||
      request->exec_opts().record_costs()) {
    collector = new StepStatsCollector(response->mutable_step_stats());
  }
  DeviceProfilerSession* device_profiler_session = nullptr;
  if (collector && request->exec_opts().record_timeline()) {
    // If a timeline was requested, assume we want hardware-level tracing.
    device_profiler_session = DeviceProfilerSession::Create().release();
  }
  CancellationManager* cm = new CancellationManager;
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for RunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });
  CancellationToken token;
  token = cancellation_manager_.get_cancellation_token();
  bool already_cancelled = !cancellation_manager_.RegisterCallback(
      token, [cm]() { cm->StartCancel(); });
  if (already_cancelled) {
    opts->ClearCancelCallback();
    delete cm;
    delete collector;
    delete device_profiler_session;
    delete out;
    done(errors::Aborted("Call was aborted"));
    return;
  }
  session->graph_mgr()->ExecuteAsync(
      request->graph_handle(), step_id, request->exec_opts(), in, session.get(),
      collector, response, cm, env_->session_mgr->GetCoordinationServiceAgent(),
      [this, step_id, response, session, cm, out, token, collector,
       device_profiler_session, opts, done](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = session->graph_mgr()->RecvOutputs(step_id, out);
        }

        opts->ClearCancelCallback();
        cancellation_manager_.DeregisterCallback(token);
        delete cm;

        if (device_profiler_session) {
          device_profiler_session->CollectData(response->mutable_step_stats())
              .IgnoreError();
        }

        if (s.ok()) {
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }

        if (collector) collector->Finalize();
        delete collector;
        delete device_profiler_session;
        delete out;
        done(s);
      });
}

// TODO(suharshs): Add stats collection support to partial run.
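// Runs one step of a partial run. The first call for a step starts the
// executors; subsequent calls feed new inputs and fetch additional outputs
// until a request with is_last_partial_run set arrives.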
void Worker::DoPartialRunGraph(CallOptions* opts,
                               RunGraphRequestWrapper* request,
                               MutableRunGraphResponseWrapper* response,
                               StatusCallback done) {
  const int64_t step_id = request->step_id();
  const string& graph_handle = request->graph_handle();
  TRACEPRINTF("PartialRunGraph: %lld", step_id);
  Status s = recent_request_ids_.TrackUnique(
      request->request_id(), "PartialRunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }

  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }

  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  auto finish = [done, out, opts](const Status& s) {
    opts->ClearCancelCallback();
    delete out;
    done(s);
  };
  if (!s.ok()) {
    finish(s);
    return;
  }

  CancellationManager* cm = nullptr;
  bool is_new_partial_run = partial_run_mgr_.FindOrCreate(step_id, &cm);

  // Before doing anything else, set the RPC cancellation callback.
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for PartialRunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });

  // If this is a new partial run request, the executors need to be started.
  if (is_new_partial_run) {
    CancellationToken token;
    token = cancellation_manager_.get_cancellation_token();
    cancellation_manager_.RegisterCallback(token,
                                           [cm]() { cm->StartCancel(); });
    session->graph_mgr()->ExecuteAsync(
        graph_handle, step_id, request->exec_opts(), in, session.get(),
        /*collector=*/nullptr, /*response=*/nullptr, cm,
        env_->session_mgr->GetCoordinationServiceAgent(),
        [this, token, step_id, session](Status s) {
          cancellation_manager_.DeregisterCallback(token);
          partial_run_mgr_.ExecutorDone(step_id, s);
        });
  } else {
    // Send the partial run's new inputs.
    s = session->graph_mgr()->SendInputs(step_id, in);
    if (!s.ok()) {
      finish(s);
      return;
    }
  }

  session->graph_mgr()->RecvOutputsAsync(
      step_id, out, [this, out, request, response, step_id, finish](Status s) {
        if (s.ok()) {
          // Construct and return the response.
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        if (request->is_last_partial_run()) {
          partial_run_mgr_.PartialRunDone(step_id, finish, s);
        } else {
          finish(s);
        }
      });
}

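// Releases per-step resources once a step completes: the rendezvous, any
// collective executor state, and any scoped-allocator instances on the
// local devices.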
void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
                               CleanupGraphResponse* response,
                               StatusCallback done) {
  const int64_t step_id = request->step_id();
  env_->rendezvous_mgr->Cleanup(step_id);
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->Cleanup(step_id);
  }
  for (Device* d : env_->local_devices) {
    ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
    if (sam) {
      sam->Cleanup(step_id);
    }
  }
  done(OkStatus());
}

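// Clears the named resource containers on every device owned by this worker.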
void Worker::CleanupAllAsync(const CleanupAllRequest* request,
                             CleanupAllResponse* response,
                             StatusCallback done) {
  std::vector<string> containers;
  for (const auto& c : request->container()) containers.push_back(c);
  env_->device_mgr->ClearContainers(containers);
  done(OkStatus());
}

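// The base Worker does not support the logging or tracing RPCs.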
void Worker::LoggingAsync(const LoggingRequest* request,
                          LoggingResponse* response, StatusCallback done) {
  done(errors::Unimplemented("Logging"));
}

void Worker::TracingAsync(const TracingRequest* request,
                          TracingResponse* response, StatusCallback done) {
  done(errors::Unimplemented("Tracing"));
}

void Worker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                          RecvBufResponse* response, StatusCallback done) {
  // The base Worker class does not implement RecvBufAsync because
  // it is not currently used for worker-to-worker communication. Use a
  // transport-specific implementation (such as `GrpcWorker::RecvBufAsync()`)
  // instead.
  done(errors::Unimplemented("Worker::RecvBufAsync()"));
}

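// Resolves membership of a collective group via the collective parameter
// resolver, filling in the group size, device type, and member devices on
// success.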
void Worker::CompleteGroupAsync(CallOptions* opts,
                                const CompleteGroupRequest* request,
                                CompleteGroupResponse* response,
                                StatusCallback done) {
  if (!request->has_device_attributes()) {
    done(errors::Internal(
        "CompleteGroupRequest device_attributes is not set. Make sure you're "
        "running the same version of TensorFlow on all workers."));
    return;
  }
  if (env_->collective_executor_mgr) {
    auto group_params = new CollGroupParams();
    group_params->group_key = request->group_key();
    group_params->group_size = request->group_size();
    group_params->device_type = DeviceType(request->device_type());
    env_->collective_executor_mgr->GetParamResolver()->CompleteGroupAsync(
        request->device_attributes(), group_params, &cancellation_manager_,
        [response, group_params, done = std::move(done)](const Status& s) {
          if (s.ok()) {
            response->set_group_key(group_params->group_key);
            response->set_group_size(group_params->group_size);
            response->set_device_type(group_params->device_type.type_string());
            response->set_num_tasks(group_params->num_tasks);
            for (const CollGroupMember& member : group_params->members) {
              *response->add_device_attributes() = member.device;
            }
            response->set_communicator_key(
                group_params->runtime_details.communicator_key);
          } else {
            LOG(ERROR) << "Bad status from CompleteGroupDistributed: " << s;
          }
          delete group_params;
          done(s);
        });
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

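// Resolves the parameters of a single collective instance; the work is
// delegated to the collective parameter resolver.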
void Worker::CompleteInstanceAsync(CallOptions* opts,
                                   const CompleteInstanceRequest* request,
                                   CompleteInstanceResponse* response,
                                   StatusCallback done) {
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->GetParamResolver()->CompleteInstanceAsync(
        request, response, &cancellation_manager_, done);
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

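// Reports the step id sequence used by collective ops, delegating to the
// CollectiveExecutorMgr.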
void Worker::GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                  GetStepSequenceResponse* response,
                                  StatusCallback done) {
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->GetStepSequenceAsync(request, response,
                                                        done);
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

// Helper for RecvTensor. Validates the parsed rendezvous key and returns
// the source device in "*src_dev".
Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
                                 Device** src_dev) {
  // Figure out which device the tensor is hosted on.
  string local_name = DeviceNameUtils::LocalName(parsed.src_device);
  TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev));

  // Check that the device has the incarnation number we expect.
  if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
    return errors::AbortedWithPayloads(
        strings::StrCat("RecvTensor expects a different device incarnation: ",
                        parsed.src_incarnation, " vs. ",
                        (*src_dev)->attributes().incarnation(),
                        ". Your worker job (\"",
                        env_->session_mgr->LegacySession()->worker_name(),
                        "\") was probably restarted. Check your "
                        "worker job for the reason why it was restarted."),
        {{kWorkerPossiblyRestarted,
          distributed_runtime::WorkerPossiblyRestarted().SerializeAsString()}});
  }

  return OkStatus();
}

void Worker::RecvTensorAsync(CallOptions* opts,
                             const RecvTensorRequest* request,
                             TensorResponse* response, StatusCallback done) {
  // The base Worker class does not implement RecvTensorAsync, because
  // it is not currently used for worker-to-worker communication. Use a
  // transport-specific implementation (such as `GrpcWorker::RecvTensorAsync()`)
  // instead.
  done(errors::Unimplemented("Worker::RecvTensorAsync()"));
}

}  // namespace tensorflow