/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DATA_SERVICE_SERVER_LIB_H_
#define TENSORFLOW_CORE_DATA_SERVICE_SERVER_LIB_H_

#include <memory>
#include <string>
#include <vector>

#include "grpcpp/server.h"
#include "grpcpp/server_builder.h"
#include "tensorflow/core/data/service/data_transfer.h"
#include "tensorflow/core/data/service/export.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
#include "tensorflow/core/protobuf/service_config.pb.h"

namespace tensorflow {
namespace data {

// Forward declared because transitively depending on .grpc.pb.h files causes
// issues in the pywrap build.
class GrpcDispatcherImpl;
class GrpcWorkerImpl;

// A grpc server for the tf.data service.
class GrpcDataServerBase {
 public:
  // Constructs a tf.data server with the specified port. If the port is 0, the
  // server will find an available port in `Start()`. The chosen port can be
  // found by calling `BoundPort()`.
  GrpcDataServerBase(
      int requested_port, const std::string& protocol,
      const std::string& server_type,
      std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options = {});
  virtual ~GrpcDataServerBase() = default;

  // Starts the server running asynchronously.
  Status Start();

  // Stops the server. This will block until all outstanding requests complete.
  void Stop();

  // Blocks until the server stops.
  void Join();

  // Returns the port bound by the server. Only valid after calling Start().
  int BoundPort();

  // Exports the server state to improve debuggability.
  virtual ServerStateExport ExportState() const = 0;

 protected:
  virtual void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) = 0;
  void AddProfilerServiceToBuilder(::grpc::ServerBuilder& builder);
  // Starts the service. This will be called after building the service, so
  // bound_port() will return the actual bound port.
  virtual Status StartServiceInternal() = 0;
  virtual void StopServiceInternal() {}

  int bound_port() { return bound_port_; }

  const int requested_port_;
  const std::string protocol_;
  const std::string server_type_;

 private:
  int bound_port_;
  bool started_ = false;
  bool stopped_ = false;

  std::unique_ptr<::grpc::Server> server_;
  // TensorFlow profiler service implementation.
  std::unique_ptr<grpc::ProfilerService::Service> profiler_service_ = nullptr;
  std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> server_options_;
};

class DispatchGrpcDataServer : public GrpcDataServerBase {
 public:
  explicit DispatchGrpcDataServer(
      const experimental::DispatcherConfig& config,
      std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options = {});
  ~DispatchGrpcDataServer() override;

  // Returns the number of workers registered with the dispatcher.
  Status NumWorkers(int* num_workers);
  // Returns the number of active (non-finished) iterations running on the
  // dispatcher.
  size_t NumActiveIterations();

  ServerStateExport ExportState() const override;

 protected:
  void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override;
  Status StartServiceInternal() override;

 private:
  const experimental::DispatcherConfig config_;
  // Owned. We use a raw pointer because GrpcDispatcherImpl is forward-declared.
  GrpcDispatcherImpl* service_;
};

class WorkerGrpcDataServer : public GrpcDataServerBase {
 public:
  explicit WorkerGrpcDataServer(
      const experimental::WorkerConfig& config,
      std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options = {});
  ~WorkerGrpcDataServer() override;

  // Returns the number of tasks currently being executed by the worker.
  Status NumTasks(int* num_tasks);

  ServerStateExport ExportState() const override;

 protected:
  void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override;
  Status StartServiceInternal() override;
  void StopServiceInternal() override;

 private:
  const experimental::WorkerConfig config_;
  // Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared.
  GrpcWorkerImpl* service_;
  std::shared_ptr<DataTransferServer> transfer_server_;
};

// Creates a dispatch tf.data server and stores it in `out_server`.
Status NewDispatchServer(const experimental::DispatcherConfig& config,
                         std::unique_ptr<DispatchGrpcDataServer>& out_server);

// Creates a worker tf.data server and stores it in `out_server`.
Status NewWorkerServer(const experimental::WorkerConfig& config,
                       std::unique_ptr<WorkerGrpcDataServer>& out_server);
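
// Example (a minimal sketch, not part of the tf.data service API): brings up
// an in-process dispatcher and a single worker that registers with it, using
// only the factory functions and `Start()`/`BoundPort()` declared above. The
// `set_port`, `set_protocol`, `set_dispatcher_address`, and
// `set_worker_address` setters are assumed to be generated from
// `service_config.proto`, and the "%port%" placeholder in the worker address
// is assumed to be substituted with the worker's bound port; the function name
// `StartLocalDataService` is hypothetical.
inline Status StartLocalDataService(
    std::unique_ptr<DispatchGrpcDataServer>& dispatcher,
    std::unique_ptr<WorkerGrpcDataServer>& worker) {
  // Dispatcher: port 0 asks `Start()` to pick any available port.
  experimental::DispatcherConfig dispatcher_config;
  dispatcher_config.set_port(0);
  dispatcher_config.set_protocol("grpc");
  Status s = NewDispatchServer(dispatcher_config, dispatcher);
  if (!s.ok()) return s;
  s = dispatcher->Start();
  if (!s.ok()) return s;

  // Worker: point it at the port the dispatcher actually bound.
  experimental::WorkerConfig worker_config;
  worker_config.set_port(0);
  worker_config.set_protocol("grpc");
  worker_config.set_dispatcher_address(
      "localhost:" + std::to_string(dispatcher->BoundPort()));
  worker_config.set_worker_address("localhost:%port%");
  s = NewWorkerServer(worker_config, worker);
  if (!s.ok()) return s;
  return worker->Start();
}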

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_SERVICE_SERVER_LIB_H_