1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "grpcpp/generic/generic_stub.h"
#include "grpcpp/grpcpp.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/util/env_var.h"
40
41 namespace tensorflow {
42
43 class GrpcRemoteWorker : public WorkerInterface {
44 public:
GrpcRemoteWorker(SharedGrpcChannelPtr channel,::grpc::CompletionQueue * completion_queue,thread::ThreadPool * callback_threadpool,WorkerCacheLogger * logger,const string & target)45 explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
46 ::grpc::CompletionQueue* completion_queue,
47 thread::ThreadPool* callback_threadpool,
48 WorkerCacheLogger* logger, const string& target)
49 : channel_(std::move(channel)),
50 stub_(channel_),
51 cq_(completion_queue),
52 callback_threadpool_(callback_threadpool),
53 getstatus_(Method(GrpcWorkerMethod::kGetStatus)),
54 createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)),
55 deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)),
56 registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)),
57 deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)),
58 rungraph_(Method(GrpcWorkerMethod::kRunGraph)),
59 cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)),
60 cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)),
61 recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)),
62 recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)),
63 logging_(Method(GrpcWorkerMethod::kLogging)),
64 tracing_(Method(GrpcWorkerMethod::kTracing)),
65 completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
66 instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
67 getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
68 markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
69 logger_(logger),
70 target_(target) {}
71
~GrpcRemoteWorker()72 ~GrpcRemoteWorker() override {}
73
GetStatusAsync(CallOptions * call_opts,const GetStatusRequest * request,GetStatusResponse * response,bool fail_fast,StatusCallback done)74 void GetStatusAsync(CallOptions* call_opts, const GetStatusRequest* request,
75 GetStatusResponse* response, bool fail_fast,
76 StatusCallback done) override {
77 IssueRequest(request, response, getstatus_, std::move(done), call_opts,
78 fail_fast);
79 }
80
CreateWorkerSessionAsync(const CreateWorkerSessionRequest * request,CreateWorkerSessionResponse * response,StatusCallback done)81 void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
82 CreateWorkerSessionResponse* response,
83 StatusCallback done) override {
84 IssueRequest(request, response, createworkersession_, std::move(done));
85 }
86
DeleteWorkerSessionAsync(CallOptions * call_opts,const DeleteWorkerSessionRequest * request,DeleteWorkerSessionResponse * response,StatusCallback done)87 void DeleteWorkerSessionAsync(CallOptions* call_opts,
88 const DeleteWorkerSessionRequest* request,
89 DeleteWorkerSessionResponse* response,
90 StatusCallback done) override {
91 IssueRequest(request, response, deleteworkersession_, std::move(done),
92 call_opts);
93 }
94
RegisterGraphAsync(const RegisterGraphRequest * request,RegisterGraphResponse * response,StatusCallback done)95 void RegisterGraphAsync(const RegisterGraphRequest* request,
96 RegisterGraphResponse* response,
97 StatusCallback done) override {
98 IssueRequest(request, response, registergraph_, std::move(done));
99 }
100
DeregisterGraphAsync(const DeregisterGraphRequest * request,DeregisterGraphResponse * response,StatusCallback done)101 void DeregisterGraphAsync(const DeregisterGraphRequest* request,
102 DeregisterGraphResponse* response,
103 StatusCallback done) override {
104 IssueRequest(request, response, deregistergraph_, std::move(done));
105 }
106
RunGraphAsync(CallOptions * call_opts,const RunGraphRequest * request,RunGraphResponse * response,StatusCallback done)107 void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
108 RunGraphResponse* response, StatusCallback done) override {
109 IssueRequest(request, response, rungraph_, std::move(done), call_opts);
110 }
RunGraphAsync(CallOptions * call_opts,RunGraphRequestWrapper * request,MutableRunGraphResponseWrapper * response,StatusCallback done)111 void RunGraphAsync(CallOptions* call_opts, RunGraphRequestWrapper* request,
112 MutableRunGraphResponseWrapper* response,
113 StatusCallback done) override {
114 IssueRequest(&request->ToProto(), get_proto_from_wrapper(response),
115 rungraph_, std::move(done), call_opts);
116 }
117
CleanupGraphAsync(const CleanupGraphRequest * request,CleanupGraphResponse * response,StatusCallback done)118 void CleanupGraphAsync(const CleanupGraphRequest* request,
119 CleanupGraphResponse* response,
120 StatusCallback done) override {
121 IssueRequest(request, response, cleanupgraph_, std::move(done));
122 }
123
CleanupAllAsync(const CleanupAllRequest * request,CleanupAllResponse * response,StatusCallback done)124 void CleanupAllAsync(const CleanupAllRequest* request,
125 CleanupAllResponse* response,
126 StatusCallback done) override {
127 IssueRequest(request, response, cleanupall_, std::move(done));
128 }
129
RecvBufAsync(CallOptions * call_opts,const RecvBufRequest * request,RecvBufResponse * response,StatusCallback done)130 void RecvBufAsync(CallOptions* call_opts, const RecvBufRequest* request,
131 RecvBufResponse* response, StatusCallback done) override {
132 int64_t start_usec = Env::Default()->NowMicros();
133 // Type-specialized logging for this method.
134 bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
135
136 auto callback = [this, request, response, done, start_usec,
137 logging_active](Status s) {
138 if (logging_active) {
139 if (logger_->LoggingActive()) {
140 int64_t end_usec = Env::Default()->NowMicros();
141 int64_t step_id = request->step_id();
142 RecvBufRespExtra extra;
143 response->transport_options().UnpackTo(&extra);
144 int64_t num_bytes = 0;
145 for (const auto& chunk : extra.tensor_content()) {
146 num_bytes += chunk.size();
147 }
148 int64_t send_start_usec = start_usec;
149 // Prefer start time reported by the sender, if available.
150 if (response->send_start_micros()) {
151 send_start_usec =
152 std::max(start_usec,
153 static_cast<int64_t>(response->send_start_micros()));
154 send_start_usec = std::min(send_start_usec, end_usec - 1);
155 }
156 const string& key = request->buf_rendezvous_key();
157 logger_->RecordDataTransfer(
158 step_id, send_start_usec, end_usec, key, request->src_device(),
159 request->dst_device(), num_bytes, "", "RecvBuf");
160 }
161 VLOG(2) << "done callback, req: " << request->DebugString()
162 << " response " << response->DebugString();
163 }
164
165 // Note done() can delete this worker object, so we need to call done()
166 // last.
167 if (response->require_ack()) {
168 IssueMarkRecvFinishedRequest(request->request_id());
169 }
170 done(s);
171 };
172
173 IssueRequest(request, response, recvbuf_, callback, call_opts);
174 }
175
CompleteGroupAsync(CallOptions * call_opts,const CompleteGroupRequest * request,CompleteGroupResponse * response,StatusCallback done)176 void CompleteGroupAsync(CallOptions* call_opts,
177 const CompleteGroupRequest* request,
178 CompleteGroupResponse* response,
179 StatusCallback done) override {
180 IssueRequest(request, response, completegroup_, std::move(done), call_opts,
181 /*fail_fast=*/false);
182 }
183
CompleteInstanceAsync(CallOptions * call_opts,const CompleteInstanceRequest * request,CompleteInstanceResponse * response,StatusCallback done)184 void CompleteInstanceAsync(CallOptions* call_opts,
185 const CompleteInstanceRequest* request,
186 CompleteInstanceResponse* response,
187 StatusCallback done) override {
188 IssueRequest(request, response, instancesource_, std::move(done),
189 call_opts);
190 }
191
GetStepSequenceAsync(const GetStepSequenceRequest * request,GetStepSequenceResponse * response,StatusCallback done)192 void GetStepSequenceAsync(const GetStepSequenceRequest* request,
193 GetStepSequenceResponse* response,
194 StatusCallback done) override {
195 IssueRequest(request, response, getstepsequence_, std::move(done));
196 }
197
RecvTensorAsync(CallOptions * call_opts,const RecvTensorRequest * request,TensorResponse * response,StatusCallback done)198 void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request,
199 TensorResponse* response, StatusCallback done) override {
200 VLOG(1) << "RecvTensorAsync req: " << request->DebugString();
201 int64_t start_usec = Env::Default()->NowMicros();
202 // Type-specialized logging for this method.
203 bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
204
205 auto callback = [this, request, response, done, start_usec,
206 logging_active](Status s) {
207 if (logging_active) {
208 if (logger_->LoggingActive()) {
209 int64_t end_usec = Env::Default()->NowMicros();
210 int64_t step_id = request->step_id();
211 int64_t bytes = response->tensor().TotalBytes();
212 int64_t send_start_usec = start_usec;
213 // If a send start time was reported by the other side, use
214 // that instead. Maybe we should mark the display if we're using
215 // our local time instead of the remote start time?
216 if (response->metadata().send_start_micros()) {
217 // send_start_micros is the timestamp taken when the
218 // remote machine began to send the RecvTensor response.
219 // Due to clock skew between source and dest machines, it
220 // is possible that send_start_micros can be larger than
221 // end_usec or less than start_usec.
222 //
223 // To respect causality, we enforce the invariants that
224 // the RecvTensor response can not have been sent before
225 // the RecvTensor request, and must have been sent before
226 // it was received.
227 send_start_usec = std::max(
228 start_usec,
229 static_cast<int64_t>(response->metadata().send_start_micros()));
230 send_start_usec = std::min(send_start_usec, end_usec - 1);
231 }
232 const string& key = request->rendezvous_key();
233 std::vector<string> key_parts = str_util::Split(key, ';');
234 if (key_parts.size() != 5) {
235 LOG(WARNING) << "Bad key: " << key;
236 } else {
237 logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
238 key_parts[3], // tensor name
239 key_parts[0], // src_device
240 key_parts[2], // dst_device
241 bytes);
242 }
243 }
244 VLOG(2) << "done callback, req: " << request->DebugString()
245 << " response " << response->metadata().DebugString();
246 }
247
248 // Note done() can delete this worker object, so we need to call done()
249 // last.
250 if (response->metadata().require_ack()) {
251 IssueMarkRecvFinishedRequest(request->request_id());
252 }
253 done(s);
254 };
255
256 IssueRequest(request, response, recvtensor_, callback, call_opts);
257 }
258
LoggingAsync(const LoggingRequest * request,LoggingResponse * response,StatusCallback done)259 void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
260 StatusCallback done) override {
261 IssueRequest(request, response, logging_, done);
262 }
263
TracingAsync(const TracingRequest * request,TracingResponse * response,StatusCallback done)264 void TracingAsync(const TracingRequest* request, TracingResponse* response,
265 StatusCallback done) override {
266 IssueRequest(request, response, tracing_, done);
267 }
268
269 private:
270 // Utility method for issuing a generic asynchronous request. The
271 // given callback, `done`, will be called when the RPC completes.
IssueRequest(const protobuf::Message * request,protobuf::Message * response,const::grpc::string & method,StatusCallback done,CallOptions * call_opts=nullptr,bool fail_fast=true)272 void IssueRequest(const protobuf::Message* request,
273 protobuf::Message* response, const ::grpc::string& method,
274 StatusCallback done, CallOptions* call_opts = nullptr,
275 bool fail_fast = true) {
276 new RPCState<protobuf::Message>(
277 &stub_, cq_, method, *request, response, std::move(done), call_opts,
278 callback_threadpool_, MaxRetries(), fail_fast, &target_);
279 }
280
IssueRequest(const protobuf::Message * request,TensorResponse * response,const::grpc::string & method,StatusCallback done,CallOptions * call_opts=nullptr)281 void IssueRequest(const protobuf::Message* request, TensorResponse* response,
282 const ::grpc::string& method, StatusCallback done,
283 CallOptions* call_opts = nullptr) {
284 new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
285 std::move(done), call_opts,
286 callback_threadpool_, MaxRetries(),
287 /*fail_fast=*/true, &target_);
288 }
289
IssueMarkRecvFinishedRequest(int64_t request_id)290 void IssueMarkRecvFinishedRequest(int64_t request_id) {
291 VLOG(2) << "Send MarkRecvFinishedRequest for request " << request_id;
292 MarkRecvFinishedRequest request;
293 request.set_request_id(request_id);
294
295 MarkRecvFinishedResponse* response = new MarkRecvFinishedResponse();
296 auto done = [response](Status status) { delete response; };
297 IssueRequest(&request, response, markrecvfinished_, done);
298 }
299
300 // Helper function for initializing the RpcMethod objects below.
Method(GrpcWorkerMethod id)301 const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); }
302
303 // Helper function for configuring max GRPC retries. Defaults to 0 (no
304 // retries).
MaxRetries()305 const int64_t MaxRetries() {
306 int64_t max_retries = -1;
307 TF_CHECK_OK(ReadInt64FromEnvVar("GRPC_MAX_RETRIES", 0, &max_retries));
308 return max_retries;
309 }
310
311 SharedGrpcChannelPtr channel_;
312 ::grpc::GenericStub stub_;
313 ::grpc::CompletionQueue* cq_;
314 thread::ThreadPool* callback_threadpool_;
315
316 const ::grpc::string getstatus_;
317 const ::grpc::string createworkersession_;
318 const ::grpc::string deleteworkersession_;
319 const ::grpc::string registergraph_;
320 const ::grpc::string deregistergraph_;
321 const ::grpc::string rungraph_;
322 const ::grpc::string cleanupgraph_;
323 const ::grpc::string cleanupall_;
324 const ::grpc::string recvtensor_;
325 const ::grpc::string recvbuf_;
326 const ::grpc::string logging_;
327 const ::grpc::string tracing_;
328 const ::grpc::string completegroup_;
329 const ::grpc::string instancesource_;
330 const ::grpc::string getstepsequence_;
331 const ::grpc::string markrecvfinished_;
332
333 // Support for logging.
334 WorkerCacheLogger* logger_;
335 const string target_;
336
337 TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
338 };
339
NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,::grpc::CompletionQueue * completion_queue,thread::ThreadPool * callback_threadpool,WorkerCacheLogger * logger,const string & target)340 WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
341 ::grpc::CompletionQueue* completion_queue,
342 thread::ThreadPool* callback_threadpool,
343 WorkerCacheLogger* logger,
344 const string& target) {
345 return new GrpcRemoteWorker(std::move(channel), completion_queue,
346 callback_threadpool, logger, target);
347 }
348
349 } // namespace tensorflow
350