1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_WORKER_CLIENT_H_ 16 #define TENSORFLOW_CORE_DATA_SERVICE_WORKER_CLIENT_H_ 17 18 #include <memory> 19 #include <string> 20 21 #include "tensorflow/core/data/service/common.h" 22 #include "tensorflow/core/data/service/data_transfer.h" 23 #include "tensorflow/core/data/service/worker.pb.h" 24 #include "tensorflow/core/platform/mutex.h" 25 #include "tensorflow/core/platform/status.h" 26 #include "tensorflow/core/platform/statusor.h" 27 #include "tensorflow/core/platform/types.h" 28 29 namespace tensorflow { 30 namespace data { 31 32 constexpr const char kLocalTransferProtocol[] = "local"; 33 constexpr const char kGrpcTransferProtocol[] = "grpc"; 34 35 // Client for communicating with the tf.data service worker. 36 class DataServiceWorkerClient : public DataServiceClientBase { 37 public: DataServiceWorkerClient(const std::string & address,const std::string & protocol,const std::string & transfer_protocol)38 DataServiceWorkerClient(const std::string& address, 39 const std::string& protocol, 40 const std::string& transfer_protocol) 41 : DataServiceClientBase(address, protocol), 42 transfer_protocol_(transfer_protocol) {} 43 44 // Fetches an element from the worker. 45 Status GetElement(const GetElementRequest& req, GetElementResult& result); 46 47 // Makes a best effort to cancel all outstanding calls in progress for the 48 // client, and causes further calls to return Cancelled status. 49 void TryCancel(); 50 51 protected: 52 Status EnsureInitialized() override; 53 54 private: 55 // Returns the data transfer protocol, preferring to use the local transfer 56 // protocol if a local tf.data worker exists. 57 std::string GetDataTransferProtocol() const; 58 59 const std::string transfer_protocol_; 60 mutex mu_; 61 // Initialization is guarded by `mu_`, but using the stub does not require 62 // holding `mu_` 63 std::unique_ptr<DataTransferClient> client_; 64 }; 65 66 // Creates and initializes a new tf.data service worker client. 67 StatusOr<std::unique_ptr<DataServiceWorkerClient>> 68 CreateDataServiceWorkerClient(const std::string& address, 69 const std::string& protocol, 70 const std::string& transfer_protocol); 71 72 } // namespace data 73 } // namespace tensorflow 74 75 #endif // TENSORFLOW_CORE_DATA_SERVICE_WORKER_CLIENT_H_ 76