xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/server_lib.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/data/service/server_lib.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "grpcpp/server.h"
24 #include "grpcpp/server_builder.h"
25 #include "tensorflow/core/data/service/credentials_factory.h"
26 #include "tensorflow/core/data/service/export.pb.h"
27 #include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
28 #include "tensorflow/core/data/service/grpc_util.h"
29 #include "tensorflow/core/data/service/grpc_worker_impl.h"
30 #include "tensorflow/core/platform/errors.h"
31 
32 namespace tensorflow {
33 namespace data {
34 
namespace {
// Token that WorkerGrpcDataServer::StartServiceInternal() substitutes with a
// concrete port number when it builds the worker and transfer addresses.
constexpr char kPortPlaceholder[] = "%port%";
}  // namespace
38 
GrpcDataServerBase(int port,const std::string & protocol,const std::string server_type,std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options)39 GrpcDataServerBase::GrpcDataServerBase(
40     int port, const std::string& protocol, const std::string server_type,
41     std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options)
42     : requested_port_(port),
43       protocol_(protocol),
44       server_type_(server_type),
45       bound_port_(port),
46       server_options_(std::move(options)) {}
47 
Start()48 Status GrpcDataServerBase::Start() {
49   if (stopped_) {
50     return errors::FailedPrecondition(
51         "Server cannot be started after it has been stopped.");
52   }
53   if (started_) {
54     return OkStatus();
55   }
56   ::grpc::ServerBuilder builder;
57   for (std::unique_ptr<::grpc::ServerBuilderOption>& option : server_options_) {
58     builder.SetOption(std::move(option));
59   }
60   server_options_.clear();
61 
62   std::shared_ptr<::grpc::ServerCredentials> credentials;
63   TF_RETURN_IF_ERROR(
64       CredentialsFactory::CreateServerCredentials(protocol_, &credentials));
65   builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port_),
66                            credentials, &bound_port_);
67   builder.SetMaxReceiveMessageSize(-1);
68 
69   AddDataServiceToBuilder(builder);
70   AddProfilerServiceToBuilder(builder);
71   server_ = builder.BuildAndStart();
72   if (!server_) {
73     return errors::Internal("Could not start gRPC server");
74   }
75 
76   TF_RETURN_IF_ERROR(StartServiceInternal());
77 
78   started_ = true;
79   LOG(INFO) << "Started tf.data " << server_type_
80             << " running at 0.0.0.0:" << BoundPort();
81   return OkStatus();
82 }
83 
Stop()84 void GrpcDataServerBase::Stop() {
85   if (stopped_) {
86     return;
87   }
88   if (server_) {
89     StopServiceInternal();
90     server_->Shutdown();
91     LOG(INFO) << "Shut down " << server_type_ << " server running at port "
92               << BoundPort();
93   }
94   stopped_ = true;
95 }
96 
Join()97 void GrpcDataServerBase::Join() { server_->Wait(); }
98 
BoundPort()99 int GrpcDataServerBase::BoundPort() { return bound_port(); }
100 
AddProfilerServiceToBuilder(::grpc::ServerBuilder & builder)101 void GrpcDataServerBase::AddProfilerServiceToBuilder(
102     ::grpc::ServerBuilder& builder) {
103   profiler_service_ = profiler::CreateProfilerService();
104   builder.RegisterService(profiler_service_.get());
105 }
106 
DispatchGrpcDataServer(const experimental::DispatcherConfig & config,std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options)107 DispatchGrpcDataServer::DispatchGrpcDataServer(
108     const experimental::DispatcherConfig& config,
109     std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options)
110     : GrpcDataServerBase(config.port(), config.protocol(), "DispatchServer",
111                          std::move(options)),
112       config_(config) {}
113 
~DispatchGrpcDataServer()114 DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
115 
AddDataServiceToBuilder(::grpc::ServerBuilder & builder)116 void DispatchGrpcDataServer::AddDataServiceToBuilder(
117     ::grpc::ServerBuilder& builder) {
118   service_ = std::make_unique<GrpcDispatcherImpl>(config_, builder).release();
119 }
120 
StartServiceInternal()121 Status DispatchGrpcDataServer::StartServiceInternal() {
122   return service_->Start();
123 }
124 
NumWorkers(int * num_workers)125 Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
126   GetWorkersRequest req;
127   GetWorkersResponse resp;
128   ::grpc::ServerContext ctx;
129   ::grpc::Status s = service_->GetWorkers(&ctx, &req, &resp);
130   if (!s.ok()) {
131     return grpc_util::WrapError("Failed to get workers", s);
132   }
133   *num_workers = resp.workers_size();
134   return OkStatus();
135 }
136 
NumActiveIterations()137 size_t DispatchGrpcDataServer::NumActiveIterations() {
138   return service_->NumActiveIterations();
139 }
140 
ExportState() const141 ServerStateExport DispatchGrpcDataServer::ExportState() const {
142   ServerStateExport server_state_export;
143   *server_state_export.mutable_dispatcher_state_export() =
144       service_->ExportState();
145   return server_state_export;
146 }
147 
WorkerGrpcDataServer(const experimental::WorkerConfig & config,std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options)148 WorkerGrpcDataServer::WorkerGrpcDataServer(
149     const experimental::WorkerConfig& config,
150     std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options)
151     : GrpcDataServerBase(config.port(), config.protocol(), "WorkerServer",
152                          std::move(options)),
153       config_(config) {}
154 
~WorkerGrpcDataServer()155 WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
156 
AddDataServiceToBuilder(::grpc::ServerBuilder & builder)157 void WorkerGrpcDataServer::AddDataServiceToBuilder(
158     ::grpc::ServerBuilder& builder) {
159   service_ = std::make_unique<GrpcWorkerImpl>(config_, builder).release();
160 }
161 
StartServiceInternal()162 Status WorkerGrpcDataServer::StartServiceInternal() {
163   std::string base_address = config_.worker_address();
164   if (base_address.empty()) {
165     base_address = absl::StrCat("localhost:", kPortPlaceholder);
166   }
167   std::string worker_address = str_util::StringReplace(
168       base_address, kPortPlaceholder, absl::StrCat(bound_port()),
169       /*replace_all=*/false);
170   std::string transfer_address = worker_address;
171   std::string transfer_protocol = config_.data_transfer_protocol();
172   if (!transfer_protocol.empty() && transfer_protocol != "grpc") {
173     TF_RETURN_IF_ERROR(DataTransferServer::Build(
174         transfer_protocol, service_->get_element_getter(), &transfer_server_));
175     TF_RETURN_IF_ERROR(transfer_server_->Start());
176     LOG(INFO) << "Data transfer server started at 0.0.0.0:"
177               << transfer_server_->get_port();
178     transfer_address = str_util::StringReplace(
179         config_.data_transfer_address(), kPortPlaceholder,
180         absl::StrCat(transfer_server_->get_port()),
181         /*replace_all=*/false);
182   }
183   TF_RETURN_IF_ERROR(service_->Start(worker_address, transfer_address));
184   return OkStatus();
185 }
186 
StopServiceInternal()187 void WorkerGrpcDataServer::StopServiceInternal() { service_->Stop(); }
188 
NumTasks(int * num_tasks)189 Status WorkerGrpcDataServer::NumTasks(int* num_tasks) {
190   GetWorkerTasksRequest req;
191   GetWorkerTasksResponse resp;
192   ::grpc::ServerContext ctx;
193   ::grpc::Status s = service_->GetWorkerTasks(&ctx, &req, &resp);
194   if (!s.ok()) {
195     return grpc_util::WrapError("Failed to get tasks", s);
196   }
197   *num_tasks = resp.tasks_size();
198   return OkStatus();
199 }
200 
ExportState() const201 ServerStateExport WorkerGrpcDataServer::ExportState() const {
202   ServerStateExport server_state_export;
203   *server_state_export.mutable_worker_state_export() = service_->ExportState();
204   return server_state_export;
205 }
206 
NewDispatchServer(const experimental::DispatcherConfig & config,std::unique_ptr<DispatchGrpcDataServer> & out_server)207 Status NewDispatchServer(const experimental::DispatcherConfig& config,
208                          std::unique_ptr<DispatchGrpcDataServer>& out_server) {
209   out_server = std::make_unique<DispatchGrpcDataServer>(config);
210   return OkStatus();
211 }
212 
NewWorkerServer(const experimental::WorkerConfig & config,std::unique_ptr<WorkerGrpcDataServer> & out_server)213 Status NewWorkerServer(const experimental::WorkerConfig& config,
214                        std::unique_ptr<WorkerGrpcDataServer>& out_server) {
215   out_server = std::make_unique<WorkerGrpcDataServer>(config);
216   return OkStatus();
217 }
218 
219 }  // namespace data
220 }  // namespace tensorflow
221