1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/data/service/server_lib.h"
17
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "grpcpp/server.h"
24 #include "grpcpp/server_builder.h"
25 #include "tensorflow/core/data/service/credentials_factory.h"
26 #include "tensorflow/core/data/service/export.pb.h"
27 #include "tensorflow/core/data/service/grpc_dispatcher_impl.h"
28 #include "tensorflow/core/data/service/grpc_util.h"
29 #include "tensorflow/core/data/service/grpc_worker_impl.h"
30 #include "tensorflow/core/platform/errors.h"
31
32 namespace tensorflow {
33 namespace data {
34
namespace {

// Token inside a configured address that is substituted with an actual port
// number when a worker computes the addresses it advertises.
constexpr char kPortPlaceholder[] = "%port%";

}  // namespace
38
GrpcDataServerBase(int port,const std::string & protocol,const std::string server_type,std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options)39 GrpcDataServerBase::GrpcDataServerBase(
40 int port, const std::string& protocol, const std::string server_type,
41 std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options)
42 : requested_port_(port),
43 protocol_(protocol),
44 server_type_(server_type),
45 bound_port_(port),
46 server_options_(std::move(options)) {}
47
Start()48 Status GrpcDataServerBase::Start() {
49 if (stopped_) {
50 return errors::FailedPrecondition(
51 "Server cannot be started after it has been stopped.");
52 }
53 if (started_) {
54 return OkStatus();
55 }
56 ::grpc::ServerBuilder builder;
57 for (std::unique_ptr<::grpc::ServerBuilderOption>& option : server_options_) {
58 builder.SetOption(std::move(option));
59 }
60 server_options_.clear();
61
62 std::shared_ptr<::grpc::ServerCredentials> credentials;
63 TF_RETURN_IF_ERROR(
64 CredentialsFactory::CreateServerCredentials(protocol_, &credentials));
65 builder.AddListeningPort(strings::StrCat("0.0.0.0:", requested_port_),
66 credentials, &bound_port_);
67 builder.SetMaxReceiveMessageSize(-1);
68
69 AddDataServiceToBuilder(builder);
70 AddProfilerServiceToBuilder(builder);
71 server_ = builder.BuildAndStart();
72 if (!server_) {
73 return errors::Internal("Could not start gRPC server");
74 }
75
76 TF_RETURN_IF_ERROR(StartServiceInternal());
77
78 started_ = true;
79 LOG(INFO) << "Started tf.data " << server_type_
80 << " running at 0.0.0.0:" << BoundPort();
81 return OkStatus();
82 }
83
Stop()84 void GrpcDataServerBase::Stop() {
85 if (stopped_) {
86 return;
87 }
88 if (server_) {
89 StopServiceInternal();
90 server_->Shutdown();
91 LOG(INFO) << "Shut down " << server_type_ << " server running at port "
92 << BoundPort();
93 }
94 stopped_ = true;
95 }
96
Join()97 void GrpcDataServerBase::Join() { server_->Wait(); }
98
BoundPort()99 int GrpcDataServerBase::BoundPort() { return bound_port(); }
100
AddProfilerServiceToBuilder(::grpc::ServerBuilder & builder)101 void GrpcDataServerBase::AddProfilerServiceToBuilder(
102 ::grpc::ServerBuilder& builder) {
103 profiler_service_ = profiler::CreateProfilerService();
104 builder.RegisterService(profiler_service_.get());
105 }
106
DispatchGrpcDataServer(const experimental::DispatcherConfig & config,std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options)107 DispatchGrpcDataServer::DispatchGrpcDataServer(
108 const experimental::DispatcherConfig& config,
109 std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options)
110 : GrpcDataServerBase(config.port(), config.protocol(), "DispatchServer",
111 std::move(options)),
112 config_(config) {}
113
~DispatchGrpcDataServer()114 DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
115
AddDataServiceToBuilder(::grpc::ServerBuilder & builder)116 void DispatchGrpcDataServer::AddDataServiceToBuilder(
117 ::grpc::ServerBuilder& builder) {
118 service_ = std::make_unique<GrpcDispatcherImpl>(config_, builder).release();
119 }
120
StartServiceInternal()121 Status DispatchGrpcDataServer::StartServiceInternal() {
122 return service_->Start();
123 }
124
NumWorkers(int * num_workers)125 Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
126 GetWorkersRequest req;
127 GetWorkersResponse resp;
128 ::grpc::ServerContext ctx;
129 ::grpc::Status s = service_->GetWorkers(&ctx, &req, &resp);
130 if (!s.ok()) {
131 return grpc_util::WrapError("Failed to get workers", s);
132 }
133 *num_workers = resp.workers_size();
134 return OkStatus();
135 }
136
NumActiveIterations()137 size_t DispatchGrpcDataServer::NumActiveIterations() {
138 return service_->NumActiveIterations();
139 }
140
ExportState() const141 ServerStateExport DispatchGrpcDataServer::ExportState() const {
142 ServerStateExport server_state_export;
143 *server_state_export.mutable_dispatcher_state_export() =
144 service_->ExportState();
145 return server_state_export;
146 }
147
WorkerGrpcDataServer(const experimental::WorkerConfig & config,std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options)148 WorkerGrpcDataServer::WorkerGrpcDataServer(
149 const experimental::WorkerConfig& config,
150 std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options)
151 : GrpcDataServerBase(config.port(), config.protocol(), "WorkerServer",
152 std::move(options)),
153 config_(config) {}
154
~WorkerGrpcDataServer()155 WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
156
AddDataServiceToBuilder(::grpc::ServerBuilder & builder)157 void WorkerGrpcDataServer::AddDataServiceToBuilder(
158 ::grpc::ServerBuilder& builder) {
159 service_ = std::make_unique<GrpcWorkerImpl>(config_, builder).release();
160 }
161
StartServiceInternal()162 Status WorkerGrpcDataServer::StartServiceInternal() {
163 std::string base_address = config_.worker_address();
164 if (base_address.empty()) {
165 base_address = absl::StrCat("localhost:", kPortPlaceholder);
166 }
167 std::string worker_address = str_util::StringReplace(
168 base_address, kPortPlaceholder, absl::StrCat(bound_port()),
169 /*replace_all=*/false);
170 std::string transfer_address = worker_address;
171 std::string transfer_protocol = config_.data_transfer_protocol();
172 if (!transfer_protocol.empty() && transfer_protocol != "grpc") {
173 TF_RETURN_IF_ERROR(DataTransferServer::Build(
174 transfer_protocol, service_->get_element_getter(), &transfer_server_));
175 TF_RETURN_IF_ERROR(transfer_server_->Start());
176 LOG(INFO) << "Data transfer server started at 0.0.0.0:"
177 << transfer_server_->get_port();
178 transfer_address = str_util::StringReplace(
179 config_.data_transfer_address(), kPortPlaceholder,
180 absl::StrCat(transfer_server_->get_port()),
181 /*replace_all=*/false);
182 }
183 TF_RETURN_IF_ERROR(service_->Start(worker_address, transfer_address));
184 return OkStatus();
185 }
186
StopServiceInternal()187 void WorkerGrpcDataServer::StopServiceInternal() { service_->Stop(); }
188
NumTasks(int * num_tasks)189 Status WorkerGrpcDataServer::NumTasks(int* num_tasks) {
190 GetWorkerTasksRequest req;
191 GetWorkerTasksResponse resp;
192 ::grpc::ServerContext ctx;
193 ::grpc::Status s = service_->GetWorkerTasks(&ctx, &req, &resp);
194 if (!s.ok()) {
195 return grpc_util::WrapError("Failed to get tasks", s);
196 }
197 *num_tasks = resp.tasks_size();
198 return OkStatus();
199 }
200
ExportState() const201 ServerStateExport WorkerGrpcDataServer::ExportState() const {
202 ServerStateExport server_state_export;
203 *server_state_export.mutable_worker_state_export() = service_->ExportState();
204 return server_state_export;
205 }
206
NewDispatchServer(const experimental::DispatcherConfig & config,std::unique_ptr<DispatchGrpcDataServer> & out_server)207 Status NewDispatchServer(const experimental::DispatcherConfig& config,
208 std::unique_ptr<DispatchGrpcDataServer>& out_server) {
209 out_server = std::make_unique<DispatchGrpcDataServer>(config);
210 return OkStatus();
211 }
212
NewWorkerServer(const experimental::WorkerConfig & config,std::unique_ptr<WorkerGrpcDataServer> & out_server)213 Status NewWorkerServer(const experimental::WorkerConfig& config,
214 std::unique_ptr<WorkerGrpcDataServer>& out_server) {
215 out_server = std::make_unique<WorkerGrpcDataServer>(config);
216 return OkStatus();
217 }
218
219 } // namespace data
220 } // namespace tensorflow
221