/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/worker.h"

#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/error_payloads.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/device_profiler_session.h"
#include "tensorflow/core/protobuf/distributed_runtime_payloads.pb.h"

namespace tensorflow {

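// The worker remembers the ids of the most recent 100000 requests so that
// retried RPCs (e.g. RunGraph) can be detected and rejected as duplicates.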
Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) {
  // Enable log history collection in StatusGroup so that recent warning and
  // error log messages will be attached to the root error status to be
  // forwarded to the master.
  StatusGroup::ConfigureLogHistory();
}

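// Replies with the attributes of all devices managed by this worker. The
// `opts` and `fail_fast` arguments are unused by this base implementation.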
void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                            GetStatusResponse* response, bool fail_fast,
                            StatusCallback done) {
  const DeviceMgr* dm = env_->device_mgr;
  std::vector<DeviceAttributes> devices;
  dm->ListDeviceAttributes(&devices);
  response->mutable_device_attributes()->Reserve(devices.size());
  for (auto& d : devices) {
    response->add_device_attributes()->Swap(&d);
  }
  done(OkStatus());
}

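// Creates a WorkerSession identified by `session_handle` via the SessionMgr.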
void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                      CreateWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->CreateSession(
      request->session_handle(), request->server_def(),
      request->cluster_device_attributes(), request->isolate_session_state(),
      request->master_task(), request->master_incarnation());
  done(s);
}

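// Deletes the WorkerSession identified by `session_handle` via the
// SessionMgr.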
void Worker::DeleteWorkerSessionAsync(CallOptions* opts,
                                      const DeleteWorkerSessionRequest* request,
                                      DeleteWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->DeleteSession(request->session_handle());
  done(s);
}

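// Registers one graph partition with the session's GraphMgr and returns the
// assigned graph handle in the response. Falls back to the legacy session
// when the master did not explicitly create a worker session.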
void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
                                RegisterGraphResponse* response,
                                StatusCallback done) {
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (s.ok()) {
    s = session->graph_mgr()->Register(
        request->session_handle(), request->graph_def(),
        request->graph_options(), request->debug_options(),
        request->config_proto(), request->collective_graph_key(), session.get(),
        session->cluster_flr(), response->mutable_graph_handle());
  }
  done(s);
}

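// Removes a graph previously registered with RegisterGraphAsync from the
// session's GraphMgr.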
void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
                                  DeregisterGraphResponse* response,
                                  StatusCallback done) {
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (s.ok()) {
    s = session->graph_mgr()->Deregister(request->graph_handle());
  }

  done(s);
}

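// Schedules an abort of the rendezvous for `step_id` roughly one second from
// now, unless the rendezvous is the eager context's default instance.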
void Worker::AbortStep(int64_t step_id) {
  RemoteRendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
  // Do not abort if this is the context-global rendezvous used for eager
  // op-by-op execution.
  if (rendez->IsRemoteEagerContextDefault()) return;
  SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
    // Delay a bit before aborting the step. This way, the root-cause error
    // may get back to the client before the abort error generated by this
    // cancellation.
    rendez->StartAbort(errors::Aborted("Step ", step_id,
                                       " cancelled. Cancelling rendezvous."));
    rendez->Unref();
  });
}

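// Copies the feed tensors from `req` into `in`, and seeds `out` with an
// empty placeholder tensor for every fetch key.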
Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req,
                               GraphMgr::NamedTensors* in,
                               GraphMgr::NamedTensors* out) {
  static Tensor empty_tensor(DT_FLOAT);
  if (req->num_sends() > 0) {
    Tensor val;
    for (size_t i = 0; i < req->num_sends(); ++i) {
      TF_RETURN_IF_ERROR(req->SendValue(i, &val));
      in->insert({req->send_key(i), val});
    }
  }
  for (size_t i = 0; i < req->num_recvs(); ++i) {
    out->insert({req->recv_key(i), empty_tensor});
  }
  return OkStatus();
}

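// Dispatches a RunGraph request to either the full-run or the partial-run
// path. When the caller asked for errors to be stored in the response body,
// the final status is recorded there and the RPC itself completes with
// OkStatus.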
void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                           MutableRunGraphResponseWrapper* response,
                           StatusCallback done) {
  if (request->store_errors_in_response_body()) {
    done = [response, done](const Status& status) {
      response->set_status(status);
      done(OkStatus());
    };
  }
  if (request->is_partial()) {
    DoPartialRunGraph(opts, request, response, std::move(done));
  } else {
    DoRunGraph(opts, request, response, std::move(done));
  }
}

MutableRunGraphRequestWrapper* Worker::CreateRunGraphRequest() {
  return new InMemoryRunGraphRequest;
}

MutableRunGraphResponseWrapper* Worker::CreateRunGraphResponse() {
  return new InMemoryRunGraphResponse;
}

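// Runs one full step of a registered graph: prepares the feeds and fetches,
// wires RPC cancellation into this worker's CancellationManager, executes the
// graph via GraphMgr, and copies the fetched tensors into the response.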
void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                        MutableRunGraphResponseWrapper* response,
                        StatusCallback done) {
  const int64_t step_id = request->step_id();
  TRACEPRINTF("RunGraph: %lld", step_id);
  Status s = recent_request_ids_.TrackUnique(request->request_id(),
                                             "RunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }

  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }
  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  if (!s.ok()) {
    delete out;
    done(s);
    return;
  }
  StepStatsCollector* collector = nullptr;
  if (request->exec_opts().report_tensor_allocations_upon_oom() ||
      request->exec_opts().record_timeline() ||
      request->exec_opts().record_costs()) {
    collector = new StepStatsCollector(response->mutable_step_stats());
  }
  DeviceProfilerSession* device_profiler_session = nullptr;
  if (collector && request->exec_opts().record_timeline()) {
    // If a timeline was requested, assume we want hardware-level tracing.
    device_profiler_session = DeviceProfilerSession::Create().release();
  }
  CancellationManager* cm = new CancellationManager;
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for RunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });
  CancellationToken token;
  token = cancellation_manager_.get_cancellation_token();
  bool already_cancelled = !cancellation_manager_.RegisterCallback(
      token, [cm]() { cm->StartCancel(); });
  if (already_cancelled) {
    opts->ClearCancelCallback();
    delete cm;
    delete collector;
    delete device_profiler_session;
    delete out;
    done(errors::Aborted("Call was aborted"));
    return;
  }
  session->graph_mgr()->ExecuteAsync(
      request->graph_handle(), step_id, request->exec_opts(), in, session.get(),
      collector, response, cm, env_->session_mgr->GetCoordinationServiceAgent(),
      [this, step_id, response, session, cm, out, token, collector,
       device_profiler_session, opts, done](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = session->graph_mgr()->RecvOutputs(step_id, out);
        }

        opts->ClearCancelCallback();
        cancellation_manager_.DeregisterCallback(token);
        delete cm;

        if (device_profiler_session) {
          device_profiler_session->CollectData(response->mutable_step_stats())
              .IgnoreError();
        }

        if (s.ok()) {
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }

        if (collector) collector->Finalize();
        delete collector;
        delete device_profiler_session;
        delete out;
        done(s);
      });
}

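// Runs one round of a partial run. The first request for a step starts the
// executors; later requests feed new inputs into the still-running step.
// Each round then waits for the outputs requested by that round.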
// TODO(suharshs): Add stats collection support to partial run.
void Worker::DoPartialRunGraph(CallOptions* opts,
                               RunGraphRequestWrapper* request,
                               MutableRunGraphResponseWrapper* response,
                               StatusCallback done) {
  const int64_t step_id = request->step_id();
  const string& graph_handle = request->graph_handle();
  TRACEPRINTF("PartialRunGraph: %lld", step_id);
  Status s = recent_request_ids_.TrackUnique(
      request->request_id(), "PartialRunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }

  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }

  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  auto finish = [done, out, opts](const Status& s) {
    opts->ClearCancelCallback();
    delete out;
    done(s);
  };
  if (!s.ok()) {
    finish(s);
    return;
  }

  CancellationManager* cm = nullptr;
  bool is_new_partial_run = partial_run_mgr_.FindOrCreate(step_id, &cm);

  // Before we start doing anything, we set up the RPC cancellation callback.
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for PartialRunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });

  // If this is a new partial run request, we need to start the executors.
  if (is_new_partial_run) {
    CancellationToken token;
    token = cancellation_manager_.get_cancellation_token();
    cancellation_manager_.RegisterCallback(token,
                                           [cm]() { cm->StartCancel(); });
    session->graph_mgr()->ExecuteAsync(
        graph_handle, step_id, request->exec_opts(), in, session.get(),
        /*collector=*/nullptr, /*response=*/nullptr, cm,
        env_->session_mgr->GetCoordinationServiceAgent(),
        [this, token, step_id, session](Status s) {
          cancellation_manager_.DeregisterCallback(token);
          partial_run_mgr_.ExecutorDone(step_id, s);
        });
  } else {
    // Send the partial run's new inputs.
    s = session->graph_mgr()->SendInputs(step_id, in);
    if (!s.ok()) {
      finish(s);
      return;
    }
  }

  session->graph_mgr()->RecvOutputsAsync(
      step_id, out, [this, out, request, response, step_id, finish](Status s) {
        if (s.ok()) {
          // Construct and return the response.
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        if (request->is_last_partial_run()) {
          partial_run_mgr_.PartialRunDone(step_id, finish, s);
        } else {
          finish(s);
        }
      });
}

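// Releases per-step resources: the step's rendezvous, its collective
// executor (if any), and any scoped allocators on local devices.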
void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
                               CleanupGraphResponse* response,
                               StatusCallback done) {
  const int64_t step_id = request->step_id();
  env_->rendezvous_mgr->Cleanup(step_id);
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->Cleanup(step_id);
  }
  for (Device* d : env_->local_devices) {
    ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
    if (sam) {
      sam->Cleanup(step_id);
    }
  }
  done(OkStatus());
}

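// Clears the resource containers named in the request on the local
// DeviceMgr.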
void Worker::CleanupAllAsync(const CleanupAllRequest* request,
                             CleanupAllResponse* response,
                             StatusCallback done) {
  std::vector<string> containers;
  for (const auto& c : request->container()) containers.push_back(c);
  env_->device_mgr->ClearContainers(containers);
  done(OkStatus());
}

void Worker::LoggingAsync(const LoggingRequest* request,
                          LoggingResponse* response, StatusCallback done) {
  done(errors::Unimplemented("Logging"));
}

void Worker::TracingAsync(const TracingRequest* request,
                          TracingResponse* response, StatusCallback done) {
  done(errors::Unimplemented("Tracing"));
}

void Worker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                          RecvBufResponse* response, StatusCallback done) {
  // The base Worker class does not implement RecvBufAsync because
  // it is not currently used for worker-to-worker communication. Use a
  // transport-specific implementation (such as
  // `GrpcWorker::RecvBufAsync()`) instead.
  done(errors::Unimplemented("Worker::RecvBufAsync()"));
}

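// Resolves the membership of a collective group. On success, the response
// carries the resolved group key, group size, device type, number of tasks,
// member device attributes, and runtime communicator key.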
void Worker::CompleteGroupAsync(CallOptions* opts,
                                const CompleteGroupRequest* request,
                                CompleteGroupResponse* response,
                                StatusCallback done) {
  if (!request->has_device_attributes()) {
    done(errors::Internal(
        "CompleteGroupRequest device_attributes is not set. Make sure you're "
        "running the same version of Tensorflow on all workers."));
    return;
  }
  if (env_->collective_executor_mgr) {
    auto group_params = new CollGroupParams();
    group_params->group_key = request->group_key();
    group_params->group_size = request->group_size();
    group_params->device_type = DeviceType(request->device_type());
    env_->collective_executor_mgr->GetParamResolver()->CompleteGroupAsync(
        request->device_attributes(), group_params, &cancellation_manager_,
        [response, group_params, done = std::move(done)](const Status& s) {
          if (s.ok()) {
            response->set_group_key(group_params->group_key);
            response->set_group_size(group_params->group_size);
            response->set_device_type(group_params->device_type.type_string());
            response->set_num_tasks(group_params->num_tasks);
            for (const CollGroupMember& member : group_params->members) {
              *response->add_device_attributes() = member.device;
            }
            response->set_communicator_key(
                group_params->runtime_details.communicator_key);
          } else {
            LOG(ERROR) << "Bad status from CompleteGroupDistributed: " << s;
          }
          delete group_params;
          done(s);
        });
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

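// Delegates instance-level collective parameter resolution to the param
// resolver of the collective executor manager.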
void Worker::CompleteInstanceAsync(CallOptions* opts,
                                   const CompleteInstanceRequest* request,
                                   CompleteInstanceResponse* response,
                                   StatusCallback done) {
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->GetParamResolver()->CompleteInstanceAsync(
        request, response, &cancellation_manager_, done);
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

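// Retrieves the per-graph-key step id sequence from the collective executor
// manager.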
void Worker::GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                  GetStepSequenceResponse* response,
                                  StatusCallback done) {
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->GetStepSequenceAsync(request, response,
                                                        done);
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

// Helper for RecvTensor. Validates the parsed rendezvous key and returns the
// source device in "*src_dev".
Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
                                 Device** src_dev) {
  // Figure out which device the tensor is hosted on.
  string local_name = DeviceNameUtils::LocalName(parsed.src_device);
  TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev));

  // Does the device have the incarnation number we expect?
  if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
    return errors::AbortedWithPayloads(
        strings::StrCat("RecvTensor expects a different device incarnation: ",
                        parsed.src_incarnation, " vs. ",
                        (*src_dev)->attributes().incarnation(),
                        ". Your worker job (\"",
                        env_->session_mgr->LegacySession()->worker_name(),
                        "\") was probably restarted. Check your "
                        "worker job for the reason why it was restarted."),
        {{kWorkerPossiblyRestarted,
          distributed_runtime::WorkerPossiblyRestarted().SerializeAsString()}});
  }

  return OkStatus();
}

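// A transport-specific subclass provides the real implementation by
// overriding this hook. A minimal sketch, assuming a hypothetical subclass
// (the production gRPC path is `GrpcWorker::RecvTensorAsync()`):
//
//   class MyTransportWorker : public Worker {
//    public:
//     using Worker::Worker;
//     void RecvTensorAsync(CallOptions* opts,
//                          const RecvTensorRequest* request,
//                          TensorResponse* response,
//                          StatusCallback done) override {
//       Rendezvous::ParsedKey parsed;
//       Status s = Rendezvous::ParseKey(request->rendezvous_key(), &parsed);
//       Device* src_dev = nullptr;
//       if (s.ok()) s = PrepareRecvTensor(parsed, &src_dev);
//       if (!s.ok()) {
//         done(s);
//         return;
//       }
//       // ... look up the tensor in the rendezvous for request->step_id()
//       // and serialize it into `response` before invoking `done`.
//     }
//   };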
void Worker::RecvTensorAsync(CallOptions* opts,
                             const RecvTensorRequest* request,
                             TensorResponse* response, StatusCallback done) {
  // The base Worker class does not implement RecvTensorAsync, because
  // it is not currently used for worker-to-worker communication. Use a
  // transport-specific implementation (such as
  // `GrpcWorker::RecvTensorAsync()`) instead.
  done(errors::Unimplemented("Worker::RecvTensorAsync()"));
}

}  // namespace tensorflow