xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/worker_client.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/worker_client.h"
16 
17 #include <memory>
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 #include "grpcpp/client_context.h"
23 #include "grpcpp/create_channel.h"
24 #include "grpcpp/security/credentials.h"
25 #include "grpcpp/support/channel_arguments.h"
26 #include "grpcpp/support/status.h"
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/string_view.h"
30 #include "absl/strings/substitute.h"
31 #include "tensorflow/core/data/dataset.pb.h"
32 #include "tensorflow/core/data/service/credentials_factory.h"
33 #include "tensorflow/core/data/service/data_transfer.h"
34 #include "tensorflow/core/data/service/grpc_util.h"
35 #include "tensorflow/core/data/service/worker.grpc.pb.h"
36 #include "tensorflow/core/data/service/worker.pb.h"
37 #include "tensorflow/core/data/service/worker_impl.h"
38 #include "tensorflow/core/framework/tensor.h"
39 #include "tensorflow/core/framework/tensor_shape.h"
40 #include "tensorflow/core/framework/tensor_types.h"
41 #include "tensorflow/core/framework/types.pb.h"
42 #include "tensorflow/core/framework/variant.h"
43 #include "tensorflow/core/platform/errors.h"
44 #include "tensorflow/core/platform/mutex.h"
45 #include "tensorflow/core/platform/status.h"
46 #include "tensorflow/core/platform/statusor.h"
47 #include "tensorflow/core/platform/thread_annotations.h"
48 #include "tensorflow/core/platform/types.h"
49 
50 namespace tensorflow {
51 namespace data {
52 
53 StatusOr<std::unique_ptr<DataServiceWorkerClient>>
CreateDataServiceWorkerClient(const std::string & address,const std::string & protocol,const std::string & transfer_protocol)54 CreateDataServiceWorkerClient(const std::string& address,
55                               const std::string& protocol,
56                               const std::string& transfer_protocol) {
57   auto client = std::make_unique<DataServiceWorkerClient>(address, protocol,
58                                                           transfer_protocol);
59   TF_RETURN_IF_ERROR(client->Initialize());
60   return client;
61 }
62 
GetElement(const GetElementRequest & req,GetElementResult & result)63 Status DataServiceWorkerClient::GetElement(const GetElementRequest& req,
64                                            GetElementResult& result) {
65   TF_RETURN_IF_ERROR(EnsureInitialized());
66   return client_->GetElement(req, result);
67 }
68 
EnsureInitialized()69 Status DataServiceWorkerClient::EnsureInitialized() {
70   mutex_lock l(mu_);
71   if (client_) {
72     return OkStatus();
73   }
74   TF_RETURN_IF_ERROR(DataTransferClient::Build(
75       GetDataTransferProtocol(), {protocol_, address_}, &client_));
76   return OkStatus();
77 }
78 
GetDataTransferProtocol() const79 std::string DataServiceWorkerClient::GetDataTransferProtocol() const {
80   if (transfer_protocol_ == kGrpcTransferProtocol &&
81       LocalWorkers::Get(address_) != nullptr) {
82     return kLocalTransferProtocol;
83   }
84   return transfer_protocol_;
85 }
86 
TryCancel()87 void DataServiceWorkerClient::TryCancel() { client_->TryCancel(); }
88 
89 class GrpcDataTransferClient : public DataTransferClient {
90  public:
GrpcDataTransferClient(std::shared_ptr<grpc::ChannelCredentials> credentials,std::string address)91   GrpcDataTransferClient(std::shared_ptr<grpc::ChannelCredentials> credentials,
92                          std::string address) {
93     VLOG(2) << "Create GrpcDataTransferClient for worker " << address << ".";
94     grpc::ChannelArguments args;
95     args.SetMaxReceiveMessageSize(-1);
96     auto channel = grpc::CreateCustomChannel(address, credentials, args);
97     stub_ = WorkerService::NewStub(channel);
98   }
99 
GetElement(const GetElementRequest & req,GetElementResult & result)100   Status GetElement(const GetElementRequest& req,
101                     GetElementResult& result) override {
102     VLOG(3) << "GetElement for task " << req.task_id() << " from gRPC worker "
103             << "server.";
104     {
105       mutex_lock l(mu_);
106       if (cancelled_) {
107         return errors::Cancelled("Client was cancelled.");
108       }
109     }
110     grpc::ClientContext ctx;
111     {
112       mutex_lock l(mu_);
113       active_contexts_.insert(&ctx);
114     }
115     GetElementResponse resp;
116     grpc::Status s = stub_->GetElement(&ctx, req, &resp);
117     result.end_of_sequence = resp.end_of_sequence();
118     result.skip = resp.skip_task();
119     switch (resp.element_case()) {
120       case GetElementResponse::kCompressed: {
121         Tensor tensor(DT_VARIANT, TensorShape{});
122         tensor.scalar<Variant>()() = std::move(resp.compressed());
123         result.components.push_back(tensor);
124         break;
125       }
126       case GetElementResponse::kUncompressed:
127         for (const auto& component : resp.uncompressed().components()) {
128           result.components.emplace_back();
129           if (!result.components.back().FromProto(component)) {
130             return errors::Internal("Failed to parse tensor.");
131           }
132         }
133         break;
134       case GetElementResponse::ELEMENT_NOT_SET:
135         break;
136     }
137     {
138       mutex_lock l(mu_);
139       active_contexts_.erase(&ctx);
140     }
141     if (!s.ok()) {
142       return grpc_util::WrapError("Failed to get element", s);
143     }
144     return OkStatus();
145   }
146 
TryCancel()147   void TryCancel() override {
148     VLOG(2) << "Cancel GrpcDataTransferClient.";
149     mutex_lock l(mu_);
150     cancelled_ = true;
151     for (const auto& ctx : active_contexts_) {
152       ctx->TryCancel();
153     }
154   }
155 
156  private:
157   mutex mu_;
158   std::unique_ptr<WorkerService::Stub> stub_;
159   // Set of all currently active clients contexts. Used to support
160   // cancellation.
161   absl::flat_hash_set<::grpc::ClientContext*> active_contexts_
162       TF_GUARDED_BY(mu_);
163   // Indicates that the client has been cancelled, so no further requests should
164   // be accepted.
165   bool cancelled_ TF_GUARDED_BY(mu_) = false;
166 };
167 
168 class GrpcTransferClientRegistrar {
169  public:
GrpcTransferClientRegistrar()170   GrpcTransferClientRegistrar() {
171     DataTransferClient::Register(
172         kGrpcTransferProtocol, [](DataTransferClient::Config config,
173                                   std::unique_ptr<DataTransferClient>* out) {
174           std::shared_ptr<grpc::ChannelCredentials> credentials;
175           TF_RETURN_IF_ERROR(CredentialsFactory::CreateClientCredentials(
176               config.protocol, &credentials));
177           *out = std::make_unique<GrpcDataTransferClient>(credentials,
178                                                           config.address);
179           return OkStatus();
180         });
181   }
182 };
183 static GrpcTransferClientRegistrar gprc_client_registrar;
184 
185 class LocalDataTransferClient : public DataTransferClient {
186  public:
LocalDataTransferClient(absl::string_view worker_address)187   explicit LocalDataTransferClient(absl::string_view worker_address)
188       : worker_address_(worker_address) {
189     VLOG(2) << "Create LocalDataTransferClient for worker " << worker_address_
190             << ".";
191   }
192 
GetElement(const GetElementRequest & req,GetElementResult & result)193   Status GetElement(const GetElementRequest& req,
194                     GetElementResult& result) override {
195     VLOG(3) << "GetElement for task " << req.task_id() << " from local worker.";
196     TF_RETURN_IF_ERROR(VerifyClientIsNotCancelled());
197     TF_ASSIGN_OR_RETURN(std::shared_ptr<DataServiceWorkerImpl> worker,
198                         GetWorker(req));
199     return worker->GetElementResult(&req, &result);
200   }
201 
TryCancel()202   void TryCancel() override {
203     VLOG(2) << "Cancel LocalDataTransferClient for worker " << worker_address_
204             << ".";
205     // Cancels incoming requests. Currently local reads assume the requests are
206     // first-come-first-served. If we need to support coordinated reads, we need
207     // to cancel in-flight requests since they may wait infinitely.
208     mutex_lock l(mu_);
209     cancelled_ = true;
210   }
211 
212  private:
VerifyClientIsNotCancelled()213   Status VerifyClientIsNotCancelled() TF_LOCKS_EXCLUDED(mu_) {
214     mutex_lock l(mu_);
215     if (cancelled_) {
216       return errors::Cancelled(absl::Substitute(
217           "Client for worker $0 has been cancelled.", worker_address_));
218     }
219     return OkStatus();
220   }
221 
GetWorker(const GetElementRequest & req) const222   StatusOr<std::shared_ptr<DataServiceWorkerImpl>> GetWorker(
223       const GetElementRequest& req) const {
224     std::shared_ptr<DataServiceWorkerImpl> worker =
225         LocalWorkers::Get(worker_address_);
226     if (!worker) {
227       return errors::Cancelled(absl::Substitute(
228           "Local worker at address $0 is no longer available; cancel request "
229           "for task $1.",
230           worker_address_, req.task_id()));
231     }
232     return worker;
233   }
234 
235   const std::string worker_address_;
236 
237   mutex mu_;
238   bool cancelled_ TF_GUARDED_BY(mu_) = false;
239 };
240 
241 class LocalTransferClientRegistrar {
242  public:
LocalTransferClientRegistrar()243   LocalTransferClientRegistrar() {
244     DataTransferClient::Register(
245         kLocalTransferProtocol, [](DataTransferClient::Config config,
246                                    std::unique_ptr<DataTransferClient>* out) {
247           *out = std::make_unique<LocalDataTransferClient>(config.address);
248           return OkStatus();
249         });
250   }
251 };
252 static LocalTransferClientRegistrar local_client_registrar;
253 
254 }  // namespace data
255 }  // namespace tensorflow
256