1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // GrpcMasterService implements the RPC service MasterService.
17 //
// A GrpcMasterService maintains the state of live graph computation
// sessions; each session orchestrates both local and remote devices
// to carry out the graph computation.
21 //
// A GrpcMasterService knows ahead of time which local devices are
// available as client devices.
24 //
25 // A GrpcMasterService discovers remote devices in the background and
26 // keeps track of statistics of those remote devices.
27 //
28 // Each session analyzes the graph, places nodes across available
29 // devices, and ultimately drives the graph computation by initiating
30 // RunGraph on workers.
31 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
32
33 #include "grpcpp/alarm.h"
34 #include "grpcpp/server_builder.h"
35 #include "tensorflow/core/distributed_runtime/master.h"
36 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
37 #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
38 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
39 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/macros.h"
42 #include "tensorflow/core/platform/tracing.h"
43 #include "tensorflow/core/profiler/lib/traceme.h"
44 #include "tensorflow/core/protobuf/master.pb.h"
45
46 namespace tensorflow {
47
48 class GrpcMasterService : public AsyncServiceInterface {
49 public:
GrpcMasterService(Master * master,const ConfigProto & default_session_config,::grpc::ServerBuilder * builder)50 GrpcMasterService(Master* master, const ConfigProto& default_session_config,
51 ::grpc::ServerBuilder* builder)
52 : master_impl_(master),
53 is_shutdown_(false),
54 default_session_config_(default_session_config) {
55 builder->RegisterService(&master_service_);
56 cq_ = builder->AddCompletionQueue();
57 }
58
~GrpcMasterService()59 ~GrpcMasterService() override { delete shutdown_alarm_; }
60
Shutdown()61 void Shutdown() override {
62 bool did_shutdown = false;
63 {
64 mutex_lock l(mu_);
65 if (!is_shutdown_) {
66 LOG(INFO) << "Shutting down GrpcMasterService.";
67 is_shutdown_ = true;
68 did_shutdown = true;
69 }
70 }
71 if (did_shutdown) {
72 // NOTE(mrry): This enqueues a special event (with a null tag)
73 // that causes the completion queue to be shut down on the
74 // polling thread.
75 shutdown_alarm_ =
76 new ::grpc::Alarm(cq_.get(), gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
77 }
78 }
79
80 // This macro creates a new request for the given RPC method name
81 // (e.g., `ENQUEUE_REQUEST(RunStep);`), and enqueues it on
82 // `this->cq_`.
83 //
84 // This macro is invoked one or more times for each RPC method to
85 // ensure that there are sufficient completion queue entries to
86 // handle incoming requests without blocking.
87 //
88 // The implementation of the request handler for each RPC method
89 // must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
90 // to keep accepting new requests.
91 #define ENQUEUE_REQUEST(method, supports_cancel) \
92 do { \
93 mutex_lock l(mu_); \
94 if (!is_shutdown_) { \
95 Call<GrpcMasterService, grpc::MasterService::AsyncService, \
96 method##Request, method##Response>:: \
97 EnqueueRequest(&master_service_, cq_.get(), \
98 &grpc::MasterService::AsyncService::Request##method, \
99 &GrpcMasterService::method##Handler, \
100 (supports_cancel)); \
101 } \
102 } while (0)
103
HandleRPCsLoop()104 void HandleRPCsLoop() override {
105 ENQUEUE_REQUEST(CreateSession, true);
106 ENQUEUE_REQUEST(ExtendSession, false);
107 for (int i = 0; i < 100; ++i) {
108 ENQUEUE_REQUEST(PartialRunSetup, false);
109 ENQUEUE_REQUEST(RunStep, true);
110 }
111 ENQUEUE_REQUEST(CloseSession, false);
112 ENQUEUE_REQUEST(ListDevices, false);
113 ENQUEUE_REQUEST(Reset, false);
114 ENQUEUE_REQUEST(MakeCallable, false);
115 for (int i = 0; i < 100; ++i) {
116 ENQUEUE_REQUEST(RunCallable, true);
117 }
118 ENQUEUE_REQUEST(ReleaseCallable, false);
119
120 void* tag;
121 bool ok;
122 while (cq_->Next(&tag, &ok)) {
123 UntypedCall<GrpcMasterService>::Tag* callback_tag =
124 static_cast<UntypedCall<GrpcMasterService>::Tag*>(tag);
125 if (callback_tag) {
126 callback_tag->OnCompleted(this, ok);
127 } else {
128 // NOTE(mrry): A null `callback_tag` indicates that this is
129 // the shutdown alarm.
130 cq_->Shutdown();
131 }
132 }
133 }
134
135 private:
136 Master* master_impl_ = nullptr; // Not owned.
137 std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
138 grpc::MasterService::AsyncService master_service_;
139
140 mutex mu_;
141 bool is_shutdown_ TF_GUARDED_BY(mu_);
142 const ConfigProto default_session_config_;
143 ::grpc::Alarm* shutdown_alarm_ = nullptr;
144
145 template <class RequestMessage, class ResponseMessage>
146 using MasterCall = Call<GrpcMasterService, grpc::MasterService::AsyncService,
147 RequestMessage, ResponseMessage>;
148
149 // RPC handler for creating a session.
CreateSessionHandler(MasterCall<CreateSessionRequest,CreateSessionResponse> * call)150 void CreateSessionHandler(
151 MasterCall<CreateSessionRequest, CreateSessionResponse>* call) {
152 CreateSessionRequest* rewritten_req = new CreateSessionRequest;
153 rewritten_req->mutable_config()->MergeFrom(default_session_config_);
154 rewritten_req->MergeFrom(call->request);
155 master_impl_->CreateSession(rewritten_req, &call->response,
156 [call, rewritten_req](const Status& status) {
157 call->SendResponse(ToGrpcStatus(status));
158 delete rewritten_req;
159 });
160 ENQUEUE_REQUEST(CreateSession, true);
161 }
162
163 // RPC handler for extending a session.
ExtendSessionHandler(MasterCall<ExtendSessionRequest,ExtendSessionResponse> * call)164 void ExtendSessionHandler(
165 MasterCall<ExtendSessionRequest, ExtendSessionResponse>* call) {
166 master_impl_->ExtendSession(&call->request, &call->response,
167 [call](const Status& status) {
168 call->SendResponse(ToGrpcStatus(status));
169 });
170 ENQUEUE_REQUEST(ExtendSession, false);
171 }
172
173 // RPC handler for setting up a partial run call.
PartialRunSetupHandler(MasterCall<PartialRunSetupRequest,PartialRunSetupResponse> * call)174 void PartialRunSetupHandler(
175 MasterCall<PartialRunSetupRequest, PartialRunSetupResponse>* call) {
176 master_impl_->PartialRunSetup(&call->request, &call->response,
177 [call](const Status& status) {
178 call->SendResponse(ToGrpcStatus(status));
179 });
180 ENQUEUE_REQUEST(PartialRunSetup, false);
181 }
182
183 // RPC handler for running one step in a session.
RunStepHandler(MasterCall<RunStepRequest,RunStepResponse> * call)184 void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>* call) {
185 auto* trace = TraceRpc("RunStep/Server", call->client_metadata());
186 CallOptions* call_opts = new CallOptions;
187 if (call->request.options().timeout_in_ms() > 0) {
188 call_opts->SetTimeout(call->request.options().timeout_in_ms());
189 } else {
190 call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms());
191 }
192 RunStepRequestWrapper* wrapped_request =
193 new ProtoRunStepRequest(&call->request);
194 MutableRunStepResponseWrapper* wrapped_response =
195 new NonOwnedProtoRunStepResponse(&call->response);
196 call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
197 master_impl_->RunStep(
198 call_opts, wrapped_request, wrapped_response,
199 [call, call_opts, wrapped_request, wrapped_response,
200 trace](const Status& status) {
201 call->ClearCancelCallback();
202 delete call_opts;
203 delete wrapped_request;
204 delete wrapped_response;
205 delete trace;
206 if (call->request.store_errors_in_response_body() && !status.ok()) {
207 call->response.set_status_code(status.code());
208 call->response.set_status_error_message(status.error_message());
209 call->SendResponse(ToGrpcStatus(OkStatus()));
210 } else {
211 call->SendResponse(ToGrpcStatus(status));
212 }
213 });
214 ENQUEUE_REQUEST(RunStep, true);
215 }
216
217 // RPC handler for deleting a session.
CloseSessionHandler(MasterCall<CloseSessionRequest,CloseSessionResponse> * call)218 void CloseSessionHandler(
219 MasterCall<CloseSessionRequest, CloseSessionResponse>* call) {
220 master_impl_->CloseSession(&call->request, &call->response,
221 [call](const Status& status) {
222 call->SendResponse(ToGrpcStatus(status));
223 });
224 ENQUEUE_REQUEST(CloseSession, false);
225 }
226
227 // RPC handler for listing devices.
ListDevicesHandler(MasterCall<ListDevicesRequest,ListDevicesResponse> * call)228 void ListDevicesHandler(
229 MasterCall<ListDevicesRequest, ListDevicesResponse>* call) {
230 master_impl_->ListDevices(&call->request, &call->response,
231 [call](const Status& status) {
232 call->SendResponse(ToGrpcStatus(status));
233 });
234 ENQUEUE_REQUEST(ListDevices, false);
235 }
236
237 // RPC handler for resetting all sessions.
ResetHandler(MasterCall<ResetRequest,ResetResponse> * call)238 void ResetHandler(MasterCall<ResetRequest, ResetResponse>* call) {
239 master_impl_->Reset(&call->request, &call->response,
240 [call](const Status& status) {
241 call->SendResponse(ToGrpcStatus(status));
242 });
243 ENQUEUE_REQUEST(Reset, false);
244 }
245
246 // RPC handler for making a callable.
MakeCallableHandler(MasterCall<MakeCallableRequest,MakeCallableResponse> * call)247 void MakeCallableHandler(
248 MasterCall<MakeCallableRequest, MakeCallableResponse>* call) {
249 master_impl_->MakeCallable(&call->request, &call->response,
250 [call](const Status& status) {
251 call->SendResponse(ToGrpcStatus(status));
252 });
253 ENQUEUE_REQUEST(MakeCallable, false);
254 }
255
256 // RPC handler for running a callable.
RunCallableHandler(MasterCall<RunCallableRequest,RunCallableResponse> * call)257 void RunCallableHandler(
258 MasterCall<RunCallableRequest, RunCallableResponse>* call) {
259 auto* trace = TraceRpc("RunCallable/Server", call->client_metadata());
260 CallOptions* call_opts = new CallOptions;
261 // The timeout may be overridden by a non-zero timeout in the
262 // callable's `RunOptions`; this overriding will happen inside the
263 // `MasterSession` implementation.
264 call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms());
265 call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
266 master_impl_->RunCallable(call_opts, &call->request, &call->response,
267 [call, call_opts, trace](const Status& status) {
268 call->ClearCancelCallback();
269 delete call_opts;
270 delete trace;
271 call->SendResponse(ToGrpcStatus(status));
272 });
273 ENQUEUE_REQUEST(RunCallable, false);
274 }
275
276 // RPC handler for making a callable.
ReleaseCallableHandler(MasterCall<ReleaseCallableRequest,ReleaseCallableResponse> * call)277 void ReleaseCallableHandler(
278 MasterCall<ReleaseCallableRequest, ReleaseCallableResponse>* call) {
279 master_impl_->ReleaseCallable(&call->request, &call->response,
280 [call](const Status& status) {
281 call->SendResponse(ToGrpcStatus(status));
282 });
283 ENQUEUE_REQUEST(ReleaseCallable, false);
284 }
285
286 #undef ENQUEUE_REQUEST
287
288 // Start tracing, including the ID attached to the RPC.
TraceRpc(StringPiece name,const std::multimap<::grpc::string_ref,::grpc::string_ref> & metadata)289 profiler::TraceMe* TraceRpc(
290 StringPiece name,
291 const std::multimap<::grpc::string_ref, ::grpc::string_ref>& metadata) {
292 StringPiece id;
293 auto it = metadata.find(GrpcIdKey());
294 if (it != metadata.end()) {
295 id = StringPiece(it->second.data(), it->second.size());
296 }
297 return new profiler::TraceMe([&] { return strings::StrCat(name, ":", id); },
298 profiler::TraceMeLevel::kInfo);
299 }
300
301 TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterService);
302 };
303
NewGrpcMasterService(Master * master,const ConfigProto & default_session_config,::grpc::ServerBuilder * builder)304 AsyncServiceInterface* NewGrpcMasterService(
305 Master* master, const ConfigProto& default_session_config,
306 ::grpc::ServerBuilder* builder) {
307 return new GrpcMasterService(master, default_session_config, builder);
308 }
309
310 } // end namespace tensorflow
311