// xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/server_lib.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #ifndef TENSORFLOW_CORE_DATA_SERVICE_SERVER_LIB_H_
17 #define TENSORFLOW_CORE_DATA_SERVICE_SERVER_LIB_H_
18 
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 #include "grpcpp/server.h"
24 #include "grpcpp/server_builder.h"
25 #include "tensorflow/core/data/service/data_transfer.h"
26 #include "tensorflow/core/data/service/export.pb.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
29 #include "tensorflow/core/protobuf/service_config.pb.h"
30 
31 namespace tensorflow {
32 namespace data {
33 
// Forward declared because transitively depending on .grpc.pb.h files causes
// issues in the pywrap build. Both types are only used through pointers in
// this header, so incomplete types are sufficient here.
class GrpcDispatcherImpl;
class GrpcWorkerImpl;
38 
39 // A grpc server for the tf.data service.
40 class GrpcDataServerBase {
41  public:
42   // Constructs a tf.data server with the specified port. If the port is 0, the
43   // server will find an available port in `Start()`. The chosen port can be
44   // found by calling `BoundPort()`.
45   GrpcDataServerBase(
46       int requested_port, const std::string& protocol,
47       const std::string server_type,
48       std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options = {});
49   virtual ~GrpcDataServerBase() = default;
50 
51   // Starts the server running asynchronously.
52   Status Start();
53 
54   // Stops the server. This will block until all outstanding requests complete.
55   void Stop();
56 
57   // Blocks until the server stops.
58   void Join();
59 
60   // Returns the port bound by the server. Only valid after calling Start().
61   int BoundPort();
62 
63   // Exports the server state to improve debuggability.
64   virtual ServerStateExport ExportState() const = 0;
65 
66  protected:
67   virtual void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) = 0;
68   void AddProfilerServiceToBuilder(::grpc::ServerBuilder& builder);
69   // Starts the service. This will be called after building the service, so
70   // bound_port() will return the actual bound port.
71   virtual Status StartServiceInternal() = 0;
StopServiceInternal()72   virtual void StopServiceInternal() {}
73 
bound_port()74   int bound_port() { return bound_port_; }
75 
76   const int requested_port_;
77   const std::string protocol_;
78   const std::string server_type_;
79 
80  private:
81   int bound_port_;
82   bool started_ = false;
83   bool stopped_ = false;
84 
85   std::unique_ptr<::grpc::Server> server_;
86   // TensorFlow profiler service implementation.
87   std::unique_ptr<grpc::ProfilerService::Service> profiler_service_ = nullptr;
88   std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> server_options_;
89 };
90 
91 class DispatchGrpcDataServer : public GrpcDataServerBase {
92  public:
93   explicit DispatchGrpcDataServer(
94       const experimental::DispatcherConfig& config,
95       std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options = {});
96   ~DispatchGrpcDataServer() override;
97 
98   // Returns the number of workers registered with the dispatcher.
99   Status NumWorkers(int* num_workers);
100   // Returns the number of active (non-finished) iterations running on the
101   // dispatcher.
102   size_t NumActiveIterations();
103 
104   ServerStateExport ExportState() const override;
105 
106  protected:
107   void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override;
108   Status StartServiceInternal() override;
109 
110  private:
111   const experimental::DispatcherConfig config_;
112   // Owned. We use a raw pointer because GrpcDispatcherImpl is forward-declared.
113   GrpcDispatcherImpl* service_;
114 };
115 
116 class WorkerGrpcDataServer : public GrpcDataServerBase {
117  public:
118   explicit WorkerGrpcDataServer(
119       const experimental::WorkerConfig& config,
120       std::vector<std::unique_ptr<::grpc::ServerBuilderOption>> options = {});
121   ~WorkerGrpcDataServer() override;
122 
123   // Returns the number of tasks currently being executed by the worker.
124   Status NumTasks(int* num_tasks);
125 
126   ServerStateExport ExportState() const override;
127 
128  protected:
129   void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override;
130   Status StartServiceInternal() override;
131   void StopServiceInternal() override;
132 
133  private:
134   const experimental::WorkerConfig config_;
135   // Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared.
136   GrpcWorkerImpl* service_;
137   std::shared_ptr<DataTransferServer> transfer_server_;
138 };
139 
140 // Creates a dispatch tf.data server and stores it in `out_server`.
141 Status NewDispatchServer(const experimental::DispatcherConfig& config,
142                          std::unique_ptr<DispatchGrpcDataServer>& out_server);
143 
144 // Creates a worker tf.data server and stores it in `out_server`.
145 Status NewWorkerServer(const experimental::WorkerConfig& config,
146                        std::unique_ptr<WorkerGrpcDataServer>& out_server);
147 
148 }  // namespace data
149 }  // namespace tensorflow
150 
151 #endif  // TENSORFLOW_CORE_DATA_SERVICE_SERVER_LIB_H_
152