xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/local_master.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/local_master.h"
17 
18 #include <unordered_map>
19 
20 #include "tensorflow/core/distributed_runtime/master.h"
21 #include "tensorflow/core/platform/mutex.h"
22 
23 namespace tensorflow {
24 
25 namespace {
WaitForNotification(CallOptions * call_options,const int64_t default_timeout_in_ms,Notification * n)26 Status WaitForNotification(CallOptions* call_options,
27                            const int64_t default_timeout_in_ms,
28                            Notification* n) {
29   int64_t timeout_in_ms = call_options->GetTimeout();
30   if (timeout_in_ms == 0) {
31     timeout_in_ms = default_timeout_in_ms;
32   }
33   if (timeout_in_ms > 0) {
34     int64_t timeout_in_us = timeout_in_ms * 1000;
35     bool notified = WaitForNotificationWithTimeout(n, timeout_in_us);
36     if (!notified) {
37       call_options->StartCancel();
38       // The call has borrowed pointers to the request and response
39       // messages, so we must still wait for the call to complete.
40       n->WaitForNotification();
41       return errors::DeadlineExceeded("Operation timed out.");
42     }
43   } else {
44     n->WaitForNotification();
45   }
46   return OkStatus();
47 }
48 }  // namespace
49 
LocalMaster(Master * master_impl,const int64_t default_timeout_in_ms)50 LocalMaster::LocalMaster(Master* master_impl,
51                          const int64_t default_timeout_in_ms)
52     : master_impl_(master_impl),
53       default_timeout_in_ms_(default_timeout_in_ms) {}
54 
CreateSession(CallOptions * call_options,const CreateSessionRequest * request,CreateSessionResponse * response)55 Status LocalMaster::CreateSession(CallOptions* call_options,
56                                   const CreateSessionRequest* request,
57                                   CreateSessionResponse* response) {
58   Notification n;
59   Status ret;
60   master_impl_->CreateSession(request, response, [&n, &ret](const Status& s) {
61     ret.Update(s);
62     n.Notify();
63   });
64   TF_RETURN_IF_ERROR(
65       WaitForNotification(call_options, default_timeout_in_ms_, &n));
66   return ret;
67 }
68 
ExtendSession(CallOptions * call_options,const ExtendSessionRequest * request,ExtendSessionResponse * response)69 Status LocalMaster::ExtendSession(CallOptions* call_options,
70                                   const ExtendSessionRequest* request,
71                                   ExtendSessionResponse* response) {
72   Notification n;
73   Status ret;
74   master_impl_->ExtendSession(request, response, [&n, &ret](const Status& s) {
75     ret.Update(s);
76     n.Notify();
77   });
78   TF_RETURN_IF_ERROR(
79       WaitForNotification(call_options, default_timeout_in_ms_, &n));
80   return ret;
81 }
82 
PartialRunSetup(CallOptions * call_options,const PartialRunSetupRequest * request,PartialRunSetupResponse * response)83 Status LocalMaster::PartialRunSetup(CallOptions* call_options,
84                                     const PartialRunSetupRequest* request,
85                                     PartialRunSetupResponse* response) {
86   Notification n;
87   Status ret;
88   master_impl_->PartialRunSetup(request, response, [&n, &ret](const Status& s) {
89     ret.Update(s);
90     n.Notify();
91   });
92   TF_RETURN_IF_ERROR(
93       WaitForNotification(call_options, default_timeout_in_ms_, &n));
94   return ret;
95 }
96 
RunStep(CallOptions * call_options,RunStepRequestWrapper * request,MutableRunStepResponseWrapper * response)97 Status LocalMaster::RunStep(CallOptions* call_options,
98                             RunStepRequestWrapper* request,
99                             MutableRunStepResponseWrapper* response) {
100   Notification n;
101   Status ret;
102   master_impl_->RunStep(call_options, request, response,
103                         [&n, &ret](const Status& s) {
104                           ret.Update(s);
105                           n.Notify();
106                         });
107   TF_RETURN_IF_ERROR(
108       WaitForNotification(call_options, default_timeout_in_ms_, &n));
109   return ret;
110 }
111 
CreateRunStepRequest()112 MutableRunStepRequestWrapper* LocalMaster::CreateRunStepRequest() {
113   return new InMemoryRunStepRequest;
114 }
115 
CreateRunStepResponse()116 MutableRunStepResponseWrapper* LocalMaster::CreateRunStepResponse() {
117   return new InMemoryRunStepResponse;
118 }
119 
CloseSession(CallOptions * call_options,const CloseSessionRequest * request,CloseSessionResponse * response)120 Status LocalMaster::CloseSession(CallOptions* call_options,
121                                  const CloseSessionRequest* request,
122                                  CloseSessionResponse* response) {
123   Notification n;
124   Status ret;
125   master_impl_->CloseSession(request, response, [&n, &ret](const Status& s) {
126     ret.Update(s);
127     n.Notify();
128   });
129   TF_RETURN_IF_ERROR(
130       WaitForNotification(call_options, default_timeout_in_ms_, &n));
131   return ret;
132 }
133 
ListDevices(CallOptions * call_options,const ListDevicesRequest * request,ListDevicesResponse * response)134 Status LocalMaster::ListDevices(CallOptions* call_options,
135                                 const ListDevicesRequest* request,
136                                 ListDevicesResponse* response) {
137   Notification n;
138   Status ret;
139   master_impl_->ListDevices(request, response, [&n, &ret](const Status& s) {
140     ret.Update(s);
141     n.Notify();
142   });
143   TF_RETURN_IF_ERROR(
144       WaitForNotification(call_options, default_timeout_in_ms_, &n));
145   return ret;
146 }
147 
Reset(CallOptions * call_options,const ResetRequest * request,ResetResponse * response)148 Status LocalMaster::Reset(CallOptions* call_options,
149                           const ResetRequest* request,
150                           ResetResponse* response) {
151   Notification n;
152   Status ret;
153   master_impl_->Reset(request, response, [&n, &ret](const Status& s) {
154     ret.Update(s);
155     n.Notify();
156   });
157   TF_RETURN_IF_ERROR(
158       WaitForNotification(call_options, default_timeout_in_ms_, &n));
159   return ret;
160 }
161 
MakeCallable(CallOptions * call_options,const MakeCallableRequest * request,MakeCallableResponse * response)162 Status LocalMaster::MakeCallable(CallOptions* call_options,
163                                  const MakeCallableRequest* request,
164                                  MakeCallableResponse* response) {
165   Notification n;
166   Status ret;
167   master_impl_->MakeCallable(request, response, [&n, &ret](const Status& s) {
168     ret.Update(s);
169     n.Notify();
170   });
171   TF_RETURN_IF_ERROR(
172       WaitForNotification(call_options, default_timeout_in_ms_, &n));
173   return ret;
174 }
RunCallable(CallOptions * call_options,const RunCallableRequest * request,RunCallableResponse * response)175 Status LocalMaster::RunCallable(CallOptions* call_options,
176                                 const RunCallableRequest* request,
177                                 RunCallableResponse* response) {
178   Notification n;
179   Status ret;
180   master_impl_->RunCallable(call_options, request, response,
181                             [&n, &ret](const Status& s) {
182                               ret.Update(s);
183                               n.Notify();
184                             });
185   TF_RETURN_IF_ERROR(
186       WaitForNotification(call_options, default_timeout_in_ms_, &n));
187   return ret;
188 }
ReleaseCallable(CallOptions * call_options,const ReleaseCallableRequest * request,ReleaseCallableResponse * response)189 Status LocalMaster::ReleaseCallable(CallOptions* call_options,
190                                     const ReleaseCallableRequest* request,
191                                     ReleaseCallableResponse* response) {
192   Notification n;
193   Status ret;
194   master_impl_->ReleaseCallable(request, response, [&n, &ret](const Status& s) {
195     ret.Update(s);
196     n.Notify();
197   });
198   TF_RETURN_IF_ERROR(
199       WaitForNotification(call_options, default_timeout_in_ms_, &n));
200   return ret;
201 }
202 
203 namespace {
get_local_master_registry_lock()204 mutex* get_local_master_registry_lock() {
205   static mutex local_master_registry_lock(LINKER_INITIALIZED);
206   return &local_master_registry_lock;
207 }
208 
209 struct MasterInfo {
210   Master* master;
211   const int64_t default_timeout_in_ms;
212 
MasterInfotensorflow::__anon47cddf220c11::MasterInfo213   MasterInfo(Master* master, const int64_t default_timeout_in_ms)
214       : master(master), default_timeout_in_ms(default_timeout_in_ms) {}
215 };
216 
217 typedef std::unordered_map<string, MasterInfo> LocalMasterRegistry;
local_master_registry()218 LocalMasterRegistry* local_master_registry() {
219   static LocalMasterRegistry* local_master_registry_ = new LocalMasterRegistry;
220   return local_master_registry_;
221 }
222 }  // namespace
223 
224 /* static */
Register(const string & target,Master * master,int64_t default_timeout_in_ms)225 void LocalMaster::Register(const string& target, Master* master,
226                            int64_t default_timeout_in_ms) {
227   mutex_lock l(*get_local_master_registry_lock());
228   local_master_registry()->insert(
229       {target, MasterInfo(master, default_timeout_in_ms)});
230 }
231 
232 /* static */
Lookup(const string & target)233 std::unique_ptr<LocalMaster> LocalMaster::Lookup(const string& target) {
234   std::unique_ptr<LocalMaster> ret;
235   mutex_lock l(*get_local_master_registry_lock());
236   auto iter = local_master_registry()->find(target);
237   if (iter != local_master_registry()->end()) {
238     ret.reset(new LocalMaster(iter->second.master,
239                               iter->second.default_timeout_in_ms));
240   }
241   return ret;
242 }
243 
244 }  // namespace tensorflow
245