xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // GrpcMasterService implements the RPC service MasterService.
17 //
18 // A GrpcMasterService maintains the state of live graph computation
19 // sessions; each session orchestrates both local and remote devices
20 // to carry out the graph computation.
21 //
22 // A GrpcMasterService knows ahead of time which local devices are
23 // available as client devices.
24 //
25 // A GrpcMasterService discovers remote devices in the background and
26 // keeps track of statistics of those remote devices.
27 //
28 // Each session analyzes the graph, places nodes across available
29 // devices, and ultimately drives the graph computation by initiating
30 // RunGraph on workers.
31 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
32 
33 #include "grpcpp/alarm.h"
34 #include "grpcpp/server_builder.h"
35 #include "tensorflow/core/distributed_runtime/master.h"
36 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
37 #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
38 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
39 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/macros.h"
42 #include "tensorflow/core/platform/tracing.h"
43 #include "tensorflow/core/profiler/lib/traceme.h"
44 #include "tensorflow/core/protobuf/master.pb.h"
45 
46 namespace tensorflow {
47 
48 class GrpcMasterService : public AsyncServiceInterface {
49  public:
GrpcMasterService(Master * master,const ConfigProto & default_session_config,::grpc::ServerBuilder * builder)50   GrpcMasterService(Master* master, const ConfigProto& default_session_config,
51                     ::grpc::ServerBuilder* builder)
52       : master_impl_(master),
53         is_shutdown_(false),
54         default_session_config_(default_session_config) {
55     builder->RegisterService(&master_service_);
56     cq_ = builder->AddCompletionQueue();
57   }
58 
~GrpcMasterService()59   ~GrpcMasterService() override { delete shutdown_alarm_; }
60 
Shutdown()61   void Shutdown() override {
62     bool did_shutdown = false;
63     {
64       mutex_lock l(mu_);
65       if (!is_shutdown_) {
66         LOG(INFO) << "Shutting down GrpcMasterService.";
67         is_shutdown_ = true;
68         did_shutdown = true;
69       }
70     }
71     if (did_shutdown) {
72       // NOTE(mrry): This enqueues a special event (with a null tag)
73       // that causes the completion queue to be shut down on the
74       // polling thread.
75       shutdown_alarm_ =
76           new ::grpc::Alarm(cq_.get(), gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
77     }
78   }
79 
80 // This macro creates a new request for the given RPC method name
81 // (e.g., `ENQUEUE_REQUEST(RunStep);`), and enqueues it on
82 // `this->cq_`.
83 //
84 // This macro is invoked one or more times for each RPC method to
85 // ensure that there are sufficient completion queue entries to
86 // handle incoming requests without blocking.
87 //
88 // The implementation of the request handler for each RPC method
89 // must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
90 // to keep accepting new requests.
91 #define ENQUEUE_REQUEST(method, supports_cancel)                              \
92   do {                                                                        \
93     mutex_lock l(mu_);                                                        \
94     if (!is_shutdown_) {                                                      \
95       Call<GrpcMasterService, grpc::MasterService::AsyncService,              \
96            method##Request, method##Response>::                               \
97           EnqueueRequest(&master_service_, cq_.get(),                         \
98                          &grpc::MasterService::AsyncService::Request##method, \
99                          &GrpcMasterService::method##Handler,                 \
100                          (supports_cancel));                                  \
101     }                                                                         \
102   } while (0)
103 
HandleRPCsLoop()104   void HandleRPCsLoop() override {
105     ENQUEUE_REQUEST(CreateSession, true);
106     ENQUEUE_REQUEST(ExtendSession, false);
107     for (int i = 0; i < 100; ++i) {
108       ENQUEUE_REQUEST(PartialRunSetup, false);
109       ENQUEUE_REQUEST(RunStep, true);
110     }
111     ENQUEUE_REQUEST(CloseSession, false);
112     ENQUEUE_REQUEST(ListDevices, false);
113     ENQUEUE_REQUEST(Reset, false);
114     ENQUEUE_REQUEST(MakeCallable, false);
115     for (int i = 0; i < 100; ++i) {
116       ENQUEUE_REQUEST(RunCallable, true);
117     }
118     ENQUEUE_REQUEST(ReleaseCallable, false);
119 
120     void* tag;
121     bool ok;
122     while (cq_->Next(&tag, &ok)) {
123       UntypedCall<GrpcMasterService>::Tag* callback_tag =
124           static_cast<UntypedCall<GrpcMasterService>::Tag*>(tag);
125       if (callback_tag) {
126         callback_tag->OnCompleted(this, ok);
127       } else {
128         // NOTE(mrry): A null `callback_tag` indicates that this is
129         // the shutdown alarm.
130         cq_->Shutdown();
131       }
132     }
133   }
134 
135  private:
136   Master* master_impl_ = nullptr;  // Not owned.
137   std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
138   grpc::MasterService::AsyncService master_service_;
139 
140   mutex mu_;
141   bool is_shutdown_ TF_GUARDED_BY(mu_);
142   const ConfigProto default_session_config_;
143   ::grpc::Alarm* shutdown_alarm_ = nullptr;
144 
145   template <class RequestMessage, class ResponseMessage>
146   using MasterCall = Call<GrpcMasterService, grpc::MasterService::AsyncService,
147                           RequestMessage, ResponseMessage>;
148 
149   // RPC handler for creating a session.
CreateSessionHandler(MasterCall<CreateSessionRequest,CreateSessionResponse> * call)150   void CreateSessionHandler(
151       MasterCall<CreateSessionRequest, CreateSessionResponse>* call) {
152     CreateSessionRequest* rewritten_req = new CreateSessionRequest;
153     rewritten_req->mutable_config()->MergeFrom(default_session_config_);
154     rewritten_req->MergeFrom(call->request);
155     master_impl_->CreateSession(rewritten_req, &call->response,
156                                 [call, rewritten_req](const Status& status) {
157                                   call->SendResponse(ToGrpcStatus(status));
158                                   delete rewritten_req;
159                                 });
160     ENQUEUE_REQUEST(CreateSession, true);
161   }
162 
163   // RPC handler for extending a session.
ExtendSessionHandler(MasterCall<ExtendSessionRequest,ExtendSessionResponse> * call)164   void ExtendSessionHandler(
165       MasterCall<ExtendSessionRequest, ExtendSessionResponse>* call) {
166     master_impl_->ExtendSession(&call->request, &call->response,
167                                 [call](const Status& status) {
168                                   call->SendResponse(ToGrpcStatus(status));
169                                 });
170     ENQUEUE_REQUEST(ExtendSession, false);
171   }
172 
173   // RPC handler for setting up a partial run call.
PartialRunSetupHandler(MasterCall<PartialRunSetupRequest,PartialRunSetupResponse> * call)174   void PartialRunSetupHandler(
175       MasterCall<PartialRunSetupRequest, PartialRunSetupResponse>* call) {
176     master_impl_->PartialRunSetup(&call->request, &call->response,
177                                   [call](const Status& status) {
178                                     call->SendResponse(ToGrpcStatus(status));
179                                   });
180     ENQUEUE_REQUEST(PartialRunSetup, false);
181   }
182 
183   // RPC handler for running one step in a session.
RunStepHandler(MasterCall<RunStepRequest,RunStepResponse> * call)184   void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>* call) {
185     auto* trace = TraceRpc("RunStep/Server", call->client_metadata());
186     CallOptions* call_opts = new CallOptions;
187     if (call->request.options().timeout_in_ms() > 0) {
188       call_opts->SetTimeout(call->request.options().timeout_in_ms());
189     } else {
190       call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms());
191     }
192     RunStepRequestWrapper* wrapped_request =
193         new ProtoRunStepRequest(&call->request);
194     MutableRunStepResponseWrapper* wrapped_response =
195         new NonOwnedProtoRunStepResponse(&call->response);
196     call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
197     master_impl_->RunStep(
198         call_opts, wrapped_request, wrapped_response,
199         [call, call_opts, wrapped_request, wrapped_response,
200          trace](const Status& status) {
201           call->ClearCancelCallback();
202           delete call_opts;
203           delete wrapped_request;
204           delete wrapped_response;
205           delete trace;
206           if (call->request.store_errors_in_response_body() && !status.ok()) {
207             call->response.set_status_code(status.code());
208             call->response.set_status_error_message(status.error_message());
209             call->SendResponse(ToGrpcStatus(OkStatus()));
210           } else {
211             call->SendResponse(ToGrpcStatus(status));
212           }
213         });
214     ENQUEUE_REQUEST(RunStep, true);
215   }
216 
217   // RPC handler for deleting a session.
CloseSessionHandler(MasterCall<CloseSessionRequest,CloseSessionResponse> * call)218   void CloseSessionHandler(
219       MasterCall<CloseSessionRequest, CloseSessionResponse>* call) {
220     master_impl_->CloseSession(&call->request, &call->response,
221                                [call](const Status& status) {
222                                  call->SendResponse(ToGrpcStatus(status));
223                                });
224     ENQUEUE_REQUEST(CloseSession, false);
225   }
226 
227   // RPC handler for listing devices.
ListDevicesHandler(MasterCall<ListDevicesRequest,ListDevicesResponse> * call)228   void ListDevicesHandler(
229       MasterCall<ListDevicesRequest, ListDevicesResponse>* call) {
230     master_impl_->ListDevices(&call->request, &call->response,
231                               [call](const Status& status) {
232                                 call->SendResponse(ToGrpcStatus(status));
233                               });
234     ENQUEUE_REQUEST(ListDevices, false);
235   }
236 
237   // RPC handler for resetting all sessions.
ResetHandler(MasterCall<ResetRequest,ResetResponse> * call)238   void ResetHandler(MasterCall<ResetRequest, ResetResponse>* call) {
239     master_impl_->Reset(&call->request, &call->response,
240                         [call](const Status& status) {
241                           call->SendResponse(ToGrpcStatus(status));
242                         });
243     ENQUEUE_REQUEST(Reset, false);
244   }
245 
246   // RPC handler for making a callable.
MakeCallableHandler(MasterCall<MakeCallableRequest,MakeCallableResponse> * call)247   void MakeCallableHandler(
248       MasterCall<MakeCallableRequest, MakeCallableResponse>* call) {
249     master_impl_->MakeCallable(&call->request, &call->response,
250                                [call](const Status& status) {
251                                  call->SendResponse(ToGrpcStatus(status));
252                                });
253     ENQUEUE_REQUEST(MakeCallable, false);
254   }
255 
256   // RPC handler for running a callable.
RunCallableHandler(MasterCall<RunCallableRequest,RunCallableResponse> * call)257   void RunCallableHandler(
258       MasterCall<RunCallableRequest, RunCallableResponse>* call) {
259     auto* trace = TraceRpc("RunCallable/Server", call->client_metadata());
260     CallOptions* call_opts = new CallOptions;
261     // The timeout may be overridden by a non-zero timeout in the
262     // callable's `RunOptions`; this overriding will happen inside the
263     // `MasterSession` implementation.
264     call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms());
265     call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
266     master_impl_->RunCallable(call_opts, &call->request, &call->response,
267                               [call, call_opts, trace](const Status& status) {
268                                 call->ClearCancelCallback();
269                                 delete call_opts;
270                                 delete trace;
271                                 call->SendResponse(ToGrpcStatus(status));
272                               });
273     ENQUEUE_REQUEST(RunCallable, false);
274   }
275 
276   // RPC handler for making a callable.
ReleaseCallableHandler(MasterCall<ReleaseCallableRequest,ReleaseCallableResponse> * call)277   void ReleaseCallableHandler(
278       MasterCall<ReleaseCallableRequest, ReleaseCallableResponse>* call) {
279     master_impl_->ReleaseCallable(&call->request, &call->response,
280                                   [call](const Status& status) {
281                                     call->SendResponse(ToGrpcStatus(status));
282                                   });
283     ENQUEUE_REQUEST(ReleaseCallable, false);
284   }
285 
286 #undef ENQUEUE_REQUEST
287 
288   // Start tracing, including the ID attached to the RPC.
TraceRpc(StringPiece name,const std::multimap<::grpc::string_ref,::grpc::string_ref> & metadata)289   profiler::TraceMe* TraceRpc(
290       StringPiece name,
291       const std::multimap<::grpc::string_ref, ::grpc::string_ref>& metadata) {
292     StringPiece id;
293     auto it = metadata.find(GrpcIdKey());
294     if (it != metadata.end()) {
295       id = StringPiece(it->second.data(), it->second.size());
296     }
297     return new profiler::TraceMe([&] { return strings::StrCat(name, ":", id); },
298                                  profiler::TraceMeLevel::kInfo);
299   }
300 
301   TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterService);
302 };
303 
NewGrpcMasterService(Master * master,const ConfigProto & default_session_config,::grpc::ServerBuilder * builder)304 AsyncServiceInterface* NewGrpcMasterService(
305     Master* master, const ConfigProto& default_session_config,
306     ::grpc::ServerBuilder* builder) {
307   return new GrpcMasterService(master, default_session_config, builder);
308 }
309 
310 }  // end namespace tensorflow
311