xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/session_mgr.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/session_mgr.h"
17 
18 #include <algorithm>
19 #include <string>
20 #include <utility>
21 
22 #include "tensorflow/core/activity_watcher/activity.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/renamed_device.h"
25 #include "tensorflow/core/distributed_runtime/coordination/coordination_service.h"
26 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h"
27 #include "tensorflow/core/distributed_runtime/error_payloads.h"
28 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
29 #include "tensorflow/core/distributed_runtime/remote_device.h"
30 #include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/protobuf/cluster.pb.h"
33 #include "tensorflow/core/protobuf/coordination_config.pb.h"
34 #include "tensorflow/core/protobuf/coordination_service.pb.h"
35 #include "tensorflow/core/protobuf/distributed_runtime_payloads.pb.h"
36 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
37 #include "tensorflow/core/util/ptr_util.h"
38 
39 namespace tensorflow {
40 
SessionMgr(WorkerEnv * worker_env,const std::string & default_worker_name,std::unique_ptr<WorkerCacheInterface> default_worker_cache,WorkerCacheFactory worker_cache_factory)41 SessionMgr::SessionMgr(
42     WorkerEnv* worker_env, const std::string& default_worker_name,
43     std::unique_ptr<WorkerCacheInterface> default_worker_cache,
44     WorkerCacheFactory worker_cache_factory)
45     : worker_env_(worker_env),
46       default_worker_cache_(std::move(default_worker_cache)),
47       legacy_session_(WorkerSession::CreateWithBorrowedDeviceMgr(
48           "", default_worker_name,
49           std::unique_ptr<WorkerCacheInterface>(
50               new WorkerCacheWrapper(default_worker_cache_.get())),
51           worker_env->device_mgr,
52           std::unique_ptr<GraphMgr>(
53               new GraphMgr(worker_env, worker_env->device_mgr)),
54           nullptr)),
55       worker_cache_factory_(std::move(worker_cache_factory)) {}
56 
57 /* static */
WorkerNameFromServerDef(const ServerDef & server_def)58 std::string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
59   return strings::StrCat("/job:", server_def.job_name(),
60                          "/replica:0/task:", server_def.task_index());
61 }
62 
CreateSession(const std::string & session,const ServerDef & server_def,bool isolate_session_state,StatusCallback coordination_error_callback)63 Status SessionMgr::CreateSession(const std::string& session,
64                                  const ServerDef& server_def,
65                                  bool isolate_session_state,
66                                  StatusCallback coordination_error_callback) {
67   return CreateSession(session, server_def, {}, isolate_session_state,
68                        /*master_task=*/"",
69                        /*master_incarnation=*/0, coordination_error_callback);
70 }
71 
CreateSession(const std::string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state)72 Status SessionMgr::CreateSession(
73     const std::string& session, const ServerDef& server_def,
74     const protobuf::RepeatedPtrField<DeviceAttributes>&
75         cluster_device_attributes,
76     bool isolate_session_state) {
77   return CreateSession(session, server_def, cluster_device_attributes,
78                        isolate_session_state,
79                        /*master_task=*/"",
80                        /*master_incarnation=*/0);
81 }
82 
CreateSession(const std::string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes,bool isolate_session_state,std::string master_task,int64_t master_incarnation,StatusCallback coordination_error_callback)83 Status SessionMgr::CreateSession(
84     const std::string& session, const ServerDef& server_def,
85     const protobuf::RepeatedPtrField<DeviceAttributes>&
86         cluster_device_attributes,
87     bool isolate_session_state, std::string master_task,
88     int64_t master_incarnation, StatusCallback coordination_error_callback) {
89   mutex_lock l(mu_);
90   if (session.empty()) {
91     return errors::InvalidArgument("Session must be non-empty.");
92   }
93 
94   // For given master task name, check if one or more `WorkerSession`s have been
95   // created previously on this worker, and if so garbage collect the expired
96   // `WorkerSession`s. This happens when the master fails before sending
97   // `DeleteSession` requests, which can cause `WorkerSession`s to be leaked.
98   if (!master_task.empty()) {
99     auto it_range = master_to_associated_sessions_.equal_range(master_task);
100     if (it_range.first != it_range.second &&
101         it_range.first->second.master_incarnation != master_incarnation) {
102       LOG(INFO) << "When creating WorkerSession for master task " << master_task
103                 << ", found old WorkerSessions created by the same master task "
104                 << "with a different incarnation. These sessions will "
105                 << "be garbage collected. Current WorkerSession count: "
106                 << sessions_.size();
107 
108       auto it = it_range.first;
109       while (it != it_range.second) {
110         auto session_it = sessions_.find(it->second.session_handle);
111         if (session_it != sessions_.end()) {
112           sessions_.erase(session_it);
113         }
114         it = master_to_associated_sessions_.erase(it);
115       }
116     }
117   }
118 
119   WorkerCacheInterface* worker_cache = nullptr;
120   std::string worker_name;
121   if (server_def.cluster().job().empty()) {
122     worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
123     worker_name = legacy_session_->worker_name();
124   } else {
125     TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
126     worker_name = WorkerNameFromServerDef(server_def);
127   }
128 
129   if (worker_cache != nullptr && default_worker_cache_ != nullptr) {
130     worker_cache->SetLogging(this->is_logging_active_);
131   }
132 
133   CHECK(!worker_env_->local_devices.empty())
134       << "The WorkerEnv must have at least one device in `local_devices`.";
135 
136   std::shared_ptr<WorkerSession> worker_session;
137   std::vector<std::unique_ptr<Device>> cluster_devices;
138 
139   if (isolate_session_state || server_def.cluster().job_size()) {
140     if (server_def.cluster().job_size()) {
141       VLOG(1) << "ClusterSpec propagation is enabled.";
142     }
143     if (!isolate_session_state) {
144       VLOG(1) << "Session state isolation is disabled.";
145     }
146 
147     // Create a private copy of the DeviceMgr for the WorkerSession.
148     std::vector<std::unique_ptr<Device>> renamed_devices;
149     for (Device* d : worker_env_->local_devices) {
150       renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
151           worker_name, d, false, isolate_session_state));
152     }
153     auto device_mgr = MakeUnique<StaticDeviceMgr>(std::move(renamed_devices));
154     LookupLocalDevice cb = [&device_mgr](StringPiece name, Device** device) {
155       return device_mgr->LookupDevice(name, device);
156     };
157     AsRemoteDevices(worker_env_->env, cluster_device_attributes, cb,
158                     &cluster_devices);
159     std::unique_ptr<DynamicDeviceMgr> remote_devices;
160     if (!cluster_device_attributes.empty()) {
161       remote_devices = MakeUnique<DynamicDeviceMgr>();
162       TF_RETURN_IF_ERROR(
163           remote_devices->AddDevices(std::move(cluster_devices)));
164     }
165 
166     auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
167     worker_session.reset(
168         new WorkerSession(session, worker_name,
169                           std::unique_ptr<WorkerCacheInterface>(worker_cache),
170                           std::move(device_mgr), std::move(graph_mgr),
171                           std::move(remote_devices)));
172   } else {
173     AsRemoteDevices(worker_env_->env, cluster_device_attributes, nullptr,
174                     &cluster_devices);
175     std::unique_ptr<DynamicDeviceMgr> remote_devices;
176     if (!cluster_device_attributes.empty()) {
177       remote_devices = MakeUnique<DynamicDeviceMgr>();
178       TF_RETURN_IF_ERROR(
179           remote_devices->AddDevices(std::move(cluster_devices)));
180     }
181     // Borrow the WorkerEnv's DeviceMgr for the WorkerSession, so
182     // that resources using it can use its devices after the
183     // WorkerSession has been deleted.
184     auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, worker_env_->device_mgr);
185     worker_session = WorkerSession::CreateWithBorrowedDeviceMgr(
186         session, worker_name,
187         std::unique_ptr<WorkerCacheInterface>(worker_cache),
188         worker_env_->device_mgr, std::move(graph_mgr),
189         std::move(remote_devices));
190   }
191 
192   sessions_.insert(std::make_pair(session, std::move(worker_session)));
193   if (!master_task.empty()) {
194     MasterAssociatedSession s{master_incarnation, session};
195     master_to_associated_sessions_.emplace(master_task, s);
196   }
197 
198   // If configured, enable coordination service and agent in the first worker
199   // session.
200   const CoordinationServiceConfig& coordination_service_config =
201       server_def.default_session_config().experimental().coordination_config();
202   if (!coordination_service_config.service_type().empty() &&
203       coordination_service_agent_ == nullptr) {
204     std::unique_ptr<CoordinationClientCache> client_cache;
205     TF_RETURN_IF_ERROR(worker_cache->GetCoordinationClientCache(&client_cache));
206     // Note: If this worker is not the leader, no service instance will be
207     // returned. Hence, only the worker leader in the cluster would hold the
208     // coordination service instance.
209     coordination_service_ =
210         CoordinationServiceInterface::EnableCoordinationService(
211             coordination_service_config.service_type(), worker_env_->env,
212             server_def, std::move(client_cache));
213 
214     std::unique_ptr<CoordinationClientCache> agent_cache;
215     TF_RETURN_IF_ERROR(worker_cache->GetCoordinationClientCache(&agent_cache));
216     coordination_service_agent_ = CreateCoordinationServiceAgent();
217     TF_RETURN_IF_ERROR(coordination_service_agent_->Initialize(
218         worker_env_->env, server_def, std::move(agent_cache),
219         std::move(coordination_error_callback)));
220     activity_watcher::MaybeEnableMultiWorkersWatching(
221         coordination_service_agent_.get());
222   }
223   return OkStatus();
224 }
225 
ResetDefaultWorkerCache(WorkerCacheInterface * worker_cache)226 void SessionMgr::ResetDefaultWorkerCache(WorkerCacheInterface* worker_cache) {
227   default_worker_cache_.reset(worker_cache);
228 }
229 
UpdateSession(const std::string & session,const ServerDef & server_def,const protobuf::RepeatedPtrField<DeviceAttributes> & cluster_device_attributes)230 Status SessionMgr::UpdateSession(
231     const std::string& session, const ServerDef& server_def,
232     const protobuf::RepeatedPtrField<DeviceAttributes>&
233         cluster_device_attributes) {
234   mutex_lock l(mu_);
235   if (session.empty()) {
236     return errors::InvalidArgument("Session must be non-empty.");
237   }
238   auto it = sessions_.find(session);
239   if (it == sessions_.end()) {
240     return errors::InvalidArgument("Cannot update session ", session,
241                                    " because it does not exist.");
242   }
243   std::shared_ptr<WorkerSession> worker_session = it->second;
244 
245   WorkerCacheInterface* worker_cache = nullptr;
246   if (server_def.cluster().job().empty()) {
247     worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
248   } else {
249     TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
250   }
251   std::vector<std::string> updated_remote_workers;
252   worker_cache->ListWorkers(&updated_remote_workers);
253 
254   std::vector<std::unique_ptr<Device>> cluster_devices;
255 
256   const DeviceMgr* local_device_mgr = worker_session->device_mgr();
257   DeviceMgr* remote_device_mgr = worker_session->remote_device_mgr();
258   std::vector<Device*> curr_remote_devices = remote_device_mgr->ListDevices();
259   std::vector<std::unique_ptr<Device>> added_remote_devices;
260   std::vector<Device*> removed_remote_devices;
261 
262   std::vector<DeviceAttributes> added_cluster_device_attrs;
263   for (const auto& da : cluster_device_attributes) {
264     Device* device;
265     if (!local_device_mgr->LookupDevice(da.name(), &device).ok() &&
266         !remote_device_mgr->LookupDevice(da.name(), &device).ok()) {
267       added_cluster_device_attrs.emplace_back(da);
268     } else if (device != nullptr &&
269                device->attributes().incarnation() != da.incarnation()) {
270       removed_remote_devices.emplace_back(device);
271       added_cluster_device_attrs.emplace_back(da);
272     }
273   }
274   for (Device* device : curr_remote_devices) {
275     std::string task_name;
276     DeviceNameUtils::GetTaskName(device->parsed_name(), &task_name);
277     if (std::find(updated_remote_workers.begin(), updated_remote_workers.end(),
278                   task_name) == updated_remote_workers.end()) {
279       removed_remote_devices.emplace_back(device);
280     }
281   }
282   protobuf::RepeatedPtrField<DeviceAttributes> added_cluster_device_attrs_pb(
283       added_cluster_device_attrs.begin(), added_cluster_device_attrs.end());
284   AsRemoteDevices(worker_env_->env, added_cluster_device_attrs_pb, nullptr,
285                   &added_remote_devices);
286 
287   TF_RETURN_IF_ERROR(worker_session->UpdateWorkerCacheAndDevices(
288       std::unique_ptr<WorkerCacheInterface>(worker_cache),
289       std::move(added_remote_devices), removed_remote_devices));
290   return OkStatus();
291 }
292 
DeleteSession(const std::string & session)293 Status SessionMgr::DeleteSession(const std::string& session) {
294   mutex_lock l(mu_);
295   auto it = sessions_.find(session);
296   if (it != sessions_.end()) {
297     sessions_.erase(it);
298   }
299   return OkStatus();
300 }
301 
WorkerSessionForSessionLocked(const std::string & session_handle,std::shared_ptr<WorkerSession> * out_session)302 Status SessionMgr::WorkerSessionForSessionLocked(
303     const std::string& session_handle,
304     std::shared_ptr<WorkerSession>* out_session) {
305   if (session_handle.empty()) {
306     *out_session = legacy_session_;
307   } else {
308     auto it = sessions_.find(session_handle);
309     if (it == sessions_.end()) {
310       return errors::AbortedWithPayloads(
311           strings::StrCat("Session handle is not found: ", session_handle,
312                           ". Possibly this worker (\"",
313                           legacy_session_->worker_name(),
314                           "\") just restarted."),
315           {{kWorkerPossiblyRestarted,
316             distributed_runtime::WorkerPossiblyRestarted()
317                 .SerializeAsString()}});
318     } else {
319       *out_session = it->second;
320     }
321   }
322   return OkStatus();
323 }
324 
WorkerSessionForSession(const std::string & session_handle,std::shared_ptr<WorkerSession> * out_session)325 Status SessionMgr::WorkerSessionForSession(
326     const std::string& session_handle,
327     std::shared_ptr<WorkerSession>* out_session) {
328   mutex_lock l(mu_);
329   return WorkerSessionForSessionLocked(session_handle, out_session);
330 }
331 
LegacySession()332 std::shared_ptr<WorkerSession> SessionMgr::LegacySession() {
333   return legacy_session_;
334 }
335 
GetCoordinationServiceAgent()336 CoordinationServiceAgent* SessionMgr::GetCoordinationServiceAgent() {
337   return coordination_service_agent_.get();
338 }
339 
SetLogging(bool active)340 void SessionMgr::SetLogging(bool active) {
341   mutex_lock l(mu_);
342   this->is_logging_active_ = active;
343   // Legacy Session
344   if (legacy_session_) {
345     auto* worker_cache = legacy_session_->worker_cache();
346     if (worker_cache) {
347       worker_cache->SetLogging(active);
348     }
349   }
350 
351   for (const auto& session_kv : sessions_) {
352     auto session = session_kv.second.get();
353     if (session) {
354       auto* worker_cache = session->worker_cache();
355       if (worker_cache) {
356         worker_cache->SetLogging(active);
357       }
358     }
359   }
360 }
361 
RetrieveLogs(int64_t step_id,LoggingResponse * response)362 void SessionMgr::RetrieveLogs(int64_t step_id, LoggingResponse* response) {
363   mutex_lock l(mu_);
364   // Legacy Session
365   if (legacy_session_) {
366     auto* worker_cache = legacy_session_->worker_cache();
367     if (worker_cache) {
368       auto step_stats = StepStats();
369       if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
370         auto* labeled_step_stats = response->add_step();
371         labeled_step_stats->set_step_id(step_id);
372         labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
373       }
374     }
375   }
376   for (const auto& session_kv : sessions_) {
377     auto session = session_kv.second.get();
378     if (session) {
379       auto* worker_cache = session->worker_cache();
380       if (worker_cache) {
381         auto step_stats = StepStats();
382         if (worker_cache->RetrieveLogs(step_id, &step_stats)) {
383           auto* labeled_step_stats = response->add_step();
384           labeled_step_stats->set_step_id(step_id);
385           labeled_step_stats->mutable_step_stats()->Swap(&step_stats);
386         }
387       }
388     }
389   }
390 }
391 
ClearLogs()392 void SessionMgr::ClearLogs() {
393   mutex_lock l(mu_);
394   // Legacy Session
395   if (legacy_session_) {
396     auto* worker_cache = legacy_session_->worker_cache();
397     if (worker_cache) {
398       worker_cache->ClearLogs();
399     }
400   }
401 
402   for (const auto& session_kv : sessions_) {
403     auto session = session_kv.second.get();
404     if (session) {
405       auto* worker_cache = session->worker_cache();
406       if (worker_cache) {
407         worker_cache->ClearLogs();
408       }
409     }
410   }
411 }
412 
TeardownCoordinationService()413 void SessionMgr::TeardownCoordinationService() {
414   coordination_service_ = nullptr;
415 }
416 
TeardownCoordinationServiceAgent()417 void SessionMgr::TeardownCoordinationServiceAgent() {
418   coordination_service_agent_ = nullptr;
419 }
420 }  // namespace tensorflow
421