xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
17 
18 #include <utility>
19 
20 #include "grpcpp/generic/generic_stub.h"
21 #include "grpcpp/grpcpp.h"
22 #include "tensorflow/core/common_runtime/process_util.h"
23 #include "tensorflow/core/distributed_runtime/call_options.h"
24 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
27 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
28 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
29 #include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
30 #include "tensorflow/core/distributed_runtime/worker_interface.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/lib/core/threadpool.h"
34 #include "tensorflow/core/lib/strings/str_util.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/tracing.h"
37 #include "tensorflow/core/protobuf/transport_options.pb.h"
38 #include "tensorflow/core/protobuf/worker.pb.h"
39 #include "tensorflow/core/util/env_var.h"
40 
41 namespace tensorflow {
42 
43 class GrpcRemoteWorker : public WorkerInterface {
44  public:
GrpcRemoteWorker(SharedGrpcChannelPtr channel,::grpc::CompletionQueue * completion_queue,thread::ThreadPool * callback_threadpool,WorkerCacheLogger * logger,const string & target)45   explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
46                             ::grpc::CompletionQueue* completion_queue,
47                             thread::ThreadPool* callback_threadpool,
48                             WorkerCacheLogger* logger, const string& target)
49       : channel_(std::move(channel)),
50         stub_(channel_),
51         cq_(completion_queue),
52         callback_threadpool_(callback_threadpool),
53         getstatus_(Method(GrpcWorkerMethod::kGetStatus)),
54         createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)),
55         deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)),
56         registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)),
57         deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)),
58         rungraph_(Method(GrpcWorkerMethod::kRunGraph)),
59         cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)),
60         cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)),
61         recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)),
62         recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)),
63         logging_(Method(GrpcWorkerMethod::kLogging)),
64         tracing_(Method(GrpcWorkerMethod::kTracing)),
65         completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
66         instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
67         getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
68         markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
69         logger_(logger),
70         target_(target) {}
71 
~GrpcRemoteWorker()72   ~GrpcRemoteWorker() override {}
73 
GetStatusAsync(CallOptions * call_opts,const GetStatusRequest * request,GetStatusResponse * response,bool fail_fast,StatusCallback done)74   void GetStatusAsync(CallOptions* call_opts, const GetStatusRequest* request,
75                       GetStatusResponse* response, bool fail_fast,
76                       StatusCallback done) override {
77     IssueRequest(request, response, getstatus_, std::move(done), call_opts,
78                  fail_fast);
79   }
80 
CreateWorkerSessionAsync(const CreateWorkerSessionRequest * request,CreateWorkerSessionResponse * response,StatusCallback done)81   void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
82                                 CreateWorkerSessionResponse* response,
83                                 StatusCallback done) override {
84     IssueRequest(request, response, createworkersession_, std::move(done));
85   }
86 
DeleteWorkerSessionAsync(CallOptions * call_opts,const DeleteWorkerSessionRequest * request,DeleteWorkerSessionResponse * response,StatusCallback done)87   void DeleteWorkerSessionAsync(CallOptions* call_opts,
88                                 const DeleteWorkerSessionRequest* request,
89                                 DeleteWorkerSessionResponse* response,
90                                 StatusCallback done) override {
91     IssueRequest(request, response, deleteworkersession_, std::move(done),
92                  call_opts);
93   }
94 
RegisterGraphAsync(const RegisterGraphRequest * request,RegisterGraphResponse * response,StatusCallback done)95   void RegisterGraphAsync(const RegisterGraphRequest* request,
96                           RegisterGraphResponse* response,
97                           StatusCallback done) override {
98     IssueRequest(request, response, registergraph_, std::move(done));
99   }
100 
DeregisterGraphAsync(const DeregisterGraphRequest * request,DeregisterGraphResponse * response,StatusCallback done)101   void DeregisterGraphAsync(const DeregisterGraphRequest* request,
102                             DeregisterGraphResponse* response,
103                             StatusCallback done) override {
104     IssueRequest(request, response, deregistergraph_, std::move(done));
105   }
106 
RunGraphAsync(CallOptions * call_opts,const RunGraphRequest * request,RunGraphResponse * response,StatusCallback done)107   void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
108                      RunGraphResponse* response, StatusCallback done) override {
109     IssueRequest(request, response, rungraph_, std::move(done), call_opts);
110   }
RunGraphAsync(CallOptions * call_opts,RunGraphRequestWrapper * request,MutableRunGraphResponseWrapper * response,StatusCallback done)111   void RunGraphAsync(CallOptions* call_opts, RunGraphRequestWrapper* request,
112                      MutableRunGraphResponseWrapper* response,
113                      StatusCallback done) override {
114     IssueRequest(&request->ToProto(), get_proto_from_wrapper(response),
115                  rungraph_, std::move(done), call_opts);
116   }
117 
CleanupGraphAsync(const CleanupGraphRequest * request,CleanupGraphResponse * response,StatusCallback done)118   void CleanupGraphAsync(const CleanupGraphRequest* request,
119                          CleanupGraphResponse* response,
120                          StatusCallback done) override {
121     IssueRequest(request, response, cleanupgraph_, std::move(done));
122   }
123 
CleanupAllAsync(const CleanupAllRequest * request,CleanupAllResponse * response,StatusCallback done)124   void CleanupAllAsync(const CleanupAllRequest* request,
125                        CleanupAllResponse* response,
126                        StatusCallback done) override {
127     IssueRequest(request, response, cleanupall_, std::move(done));
128   }
129 
RecvBufAsync(CallOptions * call_opts,const RecvBufRequest * request,RecvBufResponse * response,StatusCallback done)130   void RecvBufAsync(CallOptions* call_opts, const RecvBufRequest* request,
131                     RecvBufResponse* response, StatusCallback done) override {
132     int64_t start_usec = Env::Default()->NowMicros();
133     // Type-specialized logging for this method.
134     bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
135 
136     auto callback = [this, request, response, done, start_usec,
137                      logging_active](Status s) {
138       if (logging_active) {
139         if (logger_->LoggingActive()) {
140           int64_t end_usec = Env::Default()->NowMicros();
141           int64_t step_id = request->step_id();
142           RecvBufRespExtra extra;
143           response->transport_options().UnpackTo(&extra);
144           int64_t num_bytes = 0;
145           for (const auto& chunk : extra.tensor_content()) {
146             num_bytes += chunk.size();
147           }
148           int64_t send_start_usec = start_usec;
149           // Prefer start time reported by the sender, if available.
150           if (response->send_start_micros()) {
151             send_start_usec =
152                 std::max(start_usec,
153                          static_cast<int64_t>(response->send_start_micros()));
154             send_start_usec = std::min(send_start_usec, end_usec - 1);
155           }
156           const string& key = request->buf_rendezvous_key();
157           logger_->RecordDataTransfer(
158               step_id, send_start_usec, end_usec, key, request->src_device(),
159               request->dst_device(), num_bytes, "", "RecvBuf");
160         }
161         VLOG(2) << "done callback, req: " << request->DebugString()
162                 << " response " << response->DebugString();
163       }
164 
165       // Note done() can delete this worker object, so we need to call done()
166       // last.
167       if (response->require_ack()) {
168         IssueMarkRecvFinishedRequest(request->request_id());
169       }
170       done(s);
171     };
172 
173     IssueRequest(request, response, recvbuf_, callback, call_opts);
174   }
175 
CompleteGroupAsync(CallOptions * call_opts,const CompleteGroupRequest * request,CompleteGroupResponse * response,StatusCallback done)176   void CompleteGroupAsync(CallOptions* call_opts,
177                           const CompleteGroupRequest* request,
178                           CompleteGroupResponse* response,
179                           StatusCallback done) override {
180     IssueRequest(request, response, completegroup_, std::move(done), call_opts,
181                  /*fail_fast=*/false);
182   }
183 
CompleteInstanceAsync(CallOptions * call_opts,const CompleteInstanceRequest * request,CompleteInstanceResponse * response,StatusCallback done)184   void CompleteInstanceAsync(CallOptions* call_opts,
185                              const CompleteInstanceRequest* request,
186                              CompleteInstanceResponse* response,
187                              StatusCallback done) override {
188     IssueRequest(request, response, instancesource_, std::move(done),
189                  call_opts);
190   }
191 
GetStepSequenceAsync(const GetStepSequenceRequest * request,GetStepSequenceResponse * response,StatusCallback done)192   void GetStepSequenceAsync(const GetStepSequenceRequest* request,
193                             GetStepSequenceResponse* response,
194                             StatusCallback done) override {
195     IssueRequest(request, response, getstepsequence_, std::move(done));
196   }
197 
RecvTensorAsync(CallOptions * call_opts,const RecvTensorRequest * request,TensorResponse * response,StatusCallback done)198   void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request,
199                        TensorResponse* response, StatusCallback done) override {
200     VLOG(1) << "RecvTensorAsync req: " << request->DebugString();
201     int64_t start_usec = Env::Default()->NowMicros();
202     // Type-specialized logging for this method.
203     bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
204 
205     auto callback = [this, request, response, done, start_usec,
206                      logging_active](Status s) {
207       if (logging_active) {
208         if (logger_->LoggingActive()) {
209           int64_t end_usec = Env::Default()->NowMicros();
210           int64_t step_id = request->step_id();
211           int64_t bytes = response->tensor().TotalBytes();
212           int64_t send_start_usec = start_usec;
213           // If a send start time was reported by the other side, use
214           // that instead.  Maybe we should mark the display if we're using
215           // our local time instead of the remote start time?
216           if (response->metadata().send_start_micros()) {
217             // send_start_micros is the timestamp taken when the
218             // remote machine began to send the RecvTensor response.
219             // Due to clock skew between source and dest machines, it
220             // is possible that send_start_micros can be larger than
221             // end_usec or less than start_usec.
222             //
223             // To respect causality, we enforce the invariants that
224             // the RecvTensor response can not have been sent before
225             // the RecvTensor request, and must have been sent before
226             // it was received.
227             send_start_usec = std::max(
228                 start_usec,
229                 static_cast<int64_t>(response->metadata().send_start_micros()));
230             send_start_usec = std::min(send_start_usec, end_usec - 1);
231           }
232           const string& key = request->rendezvous_key();
233           std::vector<string> key_parts = str_util::Split(key, ';');
234           if (key_parts.size() != 5) {
235             LOG(WARNING) << "Bad key: " << key;
236           } else {
237             logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
238                                       key_parts[3],  // tensor name
239                                       key_parts[0],  // src_device
240                                       key_parts[2],  // dst_device
241                                       bytes);
242           }
243         }
244         VLOG(2) << "done callback, req: " << request->DebugString()
245                 << " response " << response->metadata().DebugString();
246       }
247 
248       // Note done() can delete this worker object, so we need to call done()
249       // last.
250       if (response->metadata().require_ack()) {
251         IssueMarkRecvFinishedRequest(request->request_id());
252       }
253       done(s);
254     };
255 
256     IssueRequest(request, response, recvtensor_, callback, call_opts);
257   }
258 
LoggingAsync(const LoggingRequest * request,LoggingResponse * response,StatusCallback done)259   void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
260                     StatusCallback done) override {
261     IssueRequest(request, response, logging_, done);
262   }
263 
TracingAsync(const TracingRequest * request,TracingResponse * response,StatusCallback done)264   void TracingAsync(const TracingRequest* request, TracingResponse* response,
265                     StatusCallback done) override {
266     IssueRequest(request, response, tracing_, done);
267   }
268 
269  private:
270   // Utility method for issuing a generic asynchronous request. The
271   // given callback, `done`, will be called when the RPC completes.
IssueRequest(const protobuf::Message * request,protobuf::Message * response,const::grpc::string & method,StatusCallback done,CallOptions * call_opts=nullptr,bool fail_fast=true)272   void IssueRequest(const protobuf::Message* request,
273                     protobuf::Message* response, const ::grpc::string& method,
274                     StatusCallback done, CallOptions* call_opts = nullptr,
275                     bool fail_fast = true) {
276     new RPCState<protobuf::Message>(
277         &stub_, cq_, method, *request, response, std::move(done), call_opts,
278         callback_threadpool_, MaxRetries(), fail_fast, &target_);
279   }
280 
IssueRequest(const protobuf::Message * request,TensorResponse * response,const::grpc::string & method,StatusCallback done,CallOptions * call_opts=nullptr)281   void IssueRequest(const protobuf::Message* request, TensorResponse* response,
282                     const ::grpc::string& method, StatusCallback done,
283                     CallOptions* call_opts = nullptr) {
284     new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
285                                  std::move(done), call_opts,
286                                  callback_threadpool_, MaxRetries(),
287                                  /*fail_fast=*/true, &target_);
288   }
289 
IssueMarkRecvFinishedRequest(int64_t request_id)290   void IssueMarkRecvFinishedRequest(int64_t request_id) {
291     VLOG(2) << "Send MarkRecvFinishedRequest for request " << request_id;
292     MarkRecvFinishedRequest request;
293     request.set_request_id(request_id);
294 
295     MarkRecvFinishedResponse* response = new MarkRecvFinishedResponse();
296     auto done = [response](Status status) { delete response; };
297     IssueRequest(&request, response, markrecvfinished_, done);
298   }
299 
300   // Helper function for initializing the RpcMethod objects below.
Method(GrpcWorkerMethod id)301   const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); }
302 
303   // Helper function for configuring max GRPC retries. Defaults to 0 (no
304   // retries).
MaxRetries()305   const int64_t MaxRetries() {
306     int64_t max_retries = -1;
307     TF_CHECK_OK(ReadInt64FromEnvVar("GRPC_MAX_RETRIES", 0, &max_retries));
308     return max_retries;
309   }
310 
311   SharedGrpcChannelPtr channel_;
312   ::grpc::GenericStub stub_;
313   ::grpc::CompletionQueue* cq_;
314   thread::ThreadPool* callback_threadpool_;
315 
316   const ::grpc::string getstatus_;
317   const ::grpc::string createworkersession_;
318   const ::grpc::string deleteworkersession_;
319   const ::grpc::string registergraph_;
320   const ::grpc::string deregistergraph_;
321   const ::grpc::string rungraph_;
322   const ::grpc::string cleanupgraph_;
323   const ::grpc::string cleanupall_;
324   const ::grpc::string recvtensor_;
325   const ::grpc::string recvbuf_;
326   const ::grpc::string logging_;
327   const ::grpc::string tracing_;
328   const ::grpc::string completegroup_;
329   const ::grpc::string instancesource_;
330   const ::grpc::string getstepsequence_;
331   const ::grpc::string markrecvfinished_;
332 
333   // Support for logging.
334   WorkerCacheLogger* logger_;
335   const string target_;
336 
337   TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
338 };
339 
NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,::grpc::CompletionQueue * completion_queue,thread::ThreadPool * callback_threadpool,WorkerCacheLogger * logger,const string & target)340 WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
341                                      ::grpc::CompletionQueue* completion_queue,
342                                      thread::ThreadPool* callback_threadpool,
343                                      WorkerCacheLogger* logger,
344                                      const string& target) {
345   return new GrpcRemoteWorker(std::move(channel), completion_queue,
346                               callback_threadpool, logger, target);
347 }
348 
349 }  // namespace tensorflow
350