1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/worker_client.h"
16
17 #include <memory>
18 #include <string>
19 #include <utility>
20 #include <vector>
21
22 #include "grpcpp/client_context.h"
23 #include "grpcpp/create_channel.h"
24 #include "grpcpp/security/credentials.h"
25 #include "grpcpp/support/channel_arguments.h"
26 #include "grpcpp/support/status.h"
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/string_view.h"
30 #include "absl/strings/substitute.h"
31 #include "tensorflow/core/data/dataset.pb.h"
32 #include "tensorflow/core/data/service/credentials_factory.h"
33 #include "tensorflow/core/data/service/data_transfer.h"
34 #include "tensorflow/core/data/service/grpc_util.h"
35 #include "tensorflow/core/data/service/worker.grpc.pb.h"
36 #include "tensorflow/core/data/service/worker.pb.h"
37 #include "tensorflow/core/data/service/worker_impl.h"
38 #include "tensorflow/core/framework/tensor.h"
39 #include "tensorflow/core/framework/tensor_shape.h"
40 #include "tensorflow/core/framework/tensor_types.h"
41 #include "tensorflow/core/framework/types.pb.h"
42 #include "tensorflow/core/framework/variant.h"
43 #include "tensorflow/core/platform/errors.h"
44 #include "tensorflow/core/platform/mutex.h"
45 #include "tensorflow/core/platform/status.h"
46 #include "tensorflow/core/platform/statusor.h"
47 #include "tensorflow/core/platform/thread_annotations.h"
48 #include "tensorflow/core/platform/types.h"
49
50 namespace tensorflow {
51 namespace data {
52
53 StatusOr<std::unique_ptr<DataServiceWorkerClient>>
CreateDataServiceWorkerClient(const std::string & address,const std::string & protocol,const std::string & transfer_protocol)54 CreateDataServiceWorkerClient(const std::string& address,
55 const std::string& protocol,
56 const std::string& transfer_protocol) {
57 auto client = std::make_unique<DataServiceWorkerClient>(address, protocol,
58 transfer_protocol);
59 TF_RETURN_IF_ERROR(client->Initialize());
60 return client;
61 }
62
GetElement(const GetElementRequest & req,GetElementResult & result)63 Status DataServiceWorkerClient::GetElement(const GetElementRequest& req,
64 GetElementResult& result) {
65 TF_RETURN_IF_ERROR(EnsureInitialized());
66 return client_->GetElement(req, result);
67 }
68
EnsureInitialized()69 Status DataServiceWorkerClient::EnsureInitialized() {
70 mutex_lock l(mu_);
71 if (client_) {
72 return OkStatus();
73 }
74 TF_RETURN_IF_ERROR(DataTransferClient::Build(
75 GetDataTransferProtocol(), {protocol_, address_}, &client_));
76 return OkStatus();
77 }
78
GetDataTransferProtocol() const79 std::string DataServiceWorkerClient::GetDataTransferProtocol() const {
80 if (transfer_protocol_ == kGrpcTransferProtocol &&
81 LocalWorkers::Get(address_) != nullptr) {
82 return kLocalTransferProtocol;
83 }
84 return transfer_protocol_;
85 }
86
TryCancel()87 void DataServiceWorkerClient::TryCancel() { client_->TryCancel(); }
88
89 class GrpcDataTransferClient : public DataTransferClient {
90 public:
GrpcDataTransferClient(std::shared_ptr<grpc::ChannelCredentials> credentials,std::string address)91 GrpcDataTransferClient(std::shared_ptr<grpc::ChannelCredentials> credentials,
92 std::string address) {
93 VLOG(2) << "Create GrpcDataTransferClient for worker " << address << ".";
94 grpc::ChannelArguments args;
95 args.SetMaxReceiveMessageSize(-1);
96 auto channel = grpc::CreateCustomChannel(address, credentials, args);
97 stub_ = WorkerService::NewStub(channel);
98 }
99
GetElement(const GetElementRequest & req,GetElementResult & result)100 Status GetElement(const GetElementRequest& req,
101 GetElementResult& result) override {
102 VLOG(3) << "GetElement for task " << req.task_id() << " from gRPC worker "
103 << "server.";
104 {
105 mutex_lock l(mu_);
106 if (cancelled_) {
107 return errors::Cancelled("Client was cancelled.");
108 }
109 }
110 grpc::ClientContext ctx;
111 {
112 mutex_lock l(mu_);
113 active_contexts_.insert(&ctx);
114 }
115 GetElementResponse resp;
116 grpc::Status s = stub_->GetElement(&ctx, req, &resp);
117 result.end_of_sequence = resp.end_of_sequence();
118 result.skip = resp.skip_task();
119 switch (resp.element_case()) {
120 case GetElementResponse::kCompressed: {
121 Tensor tensor(DT_VARIANT, TensorShape{});
122 tensor.scalar<Variant>()() = std::move(resp.compressed());
123 result.components.push_back(tensor);
124 break;
125 }
126 case GetElementResponse::kUncompressed:
127 for (const auto& component : resp.uncompressed().components()) {
128 result.components.emplace_back();
129 if (!result.components.back().FromProto(component)) {
130 return errors::Internal("Failed to parse tensor.");
131 }
132 }
133 break;
134 case GetElementResponse::ELEMENT_NOT_SET:
135 break;
136 }
137 {
138 mutex_lock l(mu_);
139 active_contexts_.erase(&ctx);
140 }
141 if (!s.ok()) {
142 return grpc_util::WrapError("Failed to get element", s);
143 }
144 return OkStatus();
145 }
146
TryCancel()147 void TryCancel() override {
148 VLOG(2) << "Cancel GrpcDataTransferClient.";
149 mutex_lock l(mu_);
150 cancelled_ = true;
151 for (const auto& ctx : active_contexts_) {
152 ctx->TryCancel();
153 }
154 }
155
156 private:
157 mutex mu_;
158 std::unique_ptr<WorkerService::Stub> stub_;
159 // Set of all currently active clients contexts. Used to support
160 // cancellation.
161 absl::flat_hash_set<::grpc::ClientContext*> active_contexts_
162 TF_GUARDED_BY(mu_);
163 // Indicates that the client has been cancelled, so no further requests should
164 // be accepted.
165 bool cancelled_ TF_GUARDED_BY(mu_) = false;
166 };
167
168 class GrpcTransferClientRegistrar {
169 public:
GrpcTransferClientRegistrar()170 GrpcTransferClientRegistrar() {
171 DataTransferClient::Register(
172 kGrpcTransferProtocol, [](DataTransferClient::Config config,
173 std::unique_ptr<DataTransferClient>* out) {
174 std::shared_ptr<grpc::ChannelCredentials> credentials;
175 TF_RETURN_IF_ERROR(CredentialsFactory::CreateClientCredentials(
176 config.protocol, &credentials));
177 *out = std::make_unique<GrpcDataTransferClient>(credentials,
178 config.address);
179 return OkStatus();
180 });
181 }
182 };
183 static GrpcTransferClientRegistrar gprc_client_registrar;
184
185 class LocalDataTransferClient : public DataTransferClient {
186 public:
LocalDataTransferClient(absl::string_view worker_address)187 explicit LocalDataTransferClient(absl::string_view worker_address)
188 : worker_address_(worker_address) {
189 VLOG(2) << "Create LocalDataTransferClient for worker " << worker_address_
190 << ".";
191 }
192
GetElement(const GetElementRequest & req,GetElementResult & result)193 Status GetElement(const GetElementRequest& req,
194 GetElementResult& result) override {
195 VLOG(3) << "GetElement for task " << req.task_id() << " from local worker.";
196 TF_RETURN_IF_ERROR(VerifyClientIsNotCancelled());
197 TF_ASSIGN_OR_RETURN(std::shared_ptr<DataServiceWorkerImpl> worker,
198 GetWorker(req));
199 return worker->GetElementResult(&req, &result);
200 }
201
TryCancel()202 void TryCancel() override {
203 VLOG(2) << "Cancel LocalDataTransferClient for worker " << worker_address_
204 << ".";
205 // Cancels incoming requests. Currently local reads assume the requests are
206 // first-come-first-served. If we need to support coordinated reads, we need
207 // to cancel in-flight requests since they may wait infinitely.
208 mutex_lock l(mu_);
209 cancelled_ = true;
210 }
211
212 private:
VerifyClientIsNotCancelled()213 Status VerifyClientIsNotCancelled() TF_LOCKS_EXCLUDED(mu_) {
214 mutex_lock l(mu_);
215 if (cancelled_) {
216 return errors::Cancelled(absl::Substitute(
217 "Client for worker $0 has been cancelled.", worker_address_));
218 }
219 return OkStatus();
220 }
221
GetWorker(const GetElementRequest & req) const222 StatusOr<std::shared_ptr<DataServiceWorkerImpl>> GetWorker(
223 const GetElementRequest& req) const {
224 std::shared_ptr<DataServiceWorkerImpl> worker =
225 LocalWorkers::Get(worker_address_);
226 if (!worker) {
227 return errors::Cancelled(absl::Substitute(
228 "Local worker at address $0 is no longer available; cancel request "
229 "for task $1.",
230 worker_address_, req.task_id()));
231 }
232 return worker;
233 }
234
235 const std::string worker_address_;
236
237 mutex mu_;
238 bool cancelled_ TF_GUARDED_BY(mu_) = false;
239 };
240
241 class LocalTransferClientRegistrar {
242 public:
LocalTransferClientRegistrar()243 LocalTransferClientRegistrar() {
244 DataTransferClient::Register(
245 kLocalTransferProtocol, [](DataTransferClient::Config config,
246 std::unique_ptr<DataTransferClient>* out) {
247 *out = std::make_unique<LocalDataTransferClient>(config.address);
248 return OkStatus();
249 });
250 }
251 };
252 static LocalTransferClientRegistrar local_client_registrar;
253
254 } // namespace data
255 } // namespace tensorflow
256