xref: /aosp_15_r20/external/federated-compute/fcp/client/grpc_bidi_stream.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2019 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker 
17*14675a02SAndroid Build Coastguard Worker #ifndef FCP_CLIENT_GRPC_BIDI_STREAM_H_
18*14675a02SAndroid Build Coastguard Worker #define FCP_CLIENT_GRPC_BIDI_STREAM_H_
19*14675a02SAndroid Build Coastguard Worker 
20*14675a02SAndroid Build Coastguard Worker #include <memory>
21*14675a02SAndroid Build Coastguard Worker #include <string>
22*14675a02SAndroid Build Coastguard Worker 
23*14675a02SAndroid Build Coastguard Worker #include "absl/base/attributes.h"
24*14675a02SAndroid Build Coastguard Worker #include "absl/base/thread_annotations.h"
25*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
26*14675a02SAndroid Build Coastguard Worker #include "absl/synchronization/mutex.h"
27*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h"
28*14675a02SAndroid Build Coastguard Worker #include "fcp/protocol/grpc_chunked_bidi_stream.h"
29*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/federated_api.grpc.pb.h"
30*14675a02SAndroid Build Coastguard Worker #include "grpcpp/impl/codegen/channel_interface.h"
31*14675a02SAndroid Build Coastguard Worker #include "grpcpp/impl/codegen/client_context.h"
32*14675a02SAndroid Build Coastguard Worker 
33*14675a02SAndroid Build Coastguard Worker namespace fcp {
34*14675a02SAndroid Build Coastguard Worker namespace client {
35*14675a02SAndroid Build Coastguard Worker 
36*14675a02SAndroid Build Coastguard Worker /**
37*14675a02SAndroid Build Coastguard Worker  * Interface to support dependency injection and hence testing
38*14675a02SAndroid Build Coastguard Worker  */
39*14675a02SAndroid Build Coastguard Worker class GrpcBidiStreamInterface {
40*14675a02SAndroid Build Coastguard Worker  public:
41*14675a02SAndroid Build Coastguard Worker   virtual ~GrpcBidiStreamInterface() = default;
42*14675a02SAndroid Build Coastguard Worker 
43*14675a02SAndroid Build Coastguard Worker   virtual ABSL_MUST_USE_RESULT absl::Status Send(
44*14675a02SAndroid Build Coastguard Worker       google::internal::federatedml::v2::ClientStreamMessage* message) = 0;
45*14675a02SAndroid Build Coastguard Worker 
46*14675a02SAndroid Build Coastguard Worker   virtual ABSL_MUST_USE_RESULT absl::Status Receive(
47*14675a02SAndroid Build Coastguard Worker       google::internal::federatedml::v2::ServerStreamMessage* message) = 0;
48*14675a02SAndroid Build Coastguard Worker 
49*14675a02SAndroid Build Coastguard Worker   virtual void Close() = 0;
50*14675a02SAndroid Build Coastguard Worker 
51*14675a02SAndroid Build Coastguard Worker   virtual int64_t ChunkingLayerBytesSent() = 0;
52*14675a02SAndroid Build Coastguard Worker 
53*14675a02SAndroid Build Coastguard Worker   virtual int64_t ChunkingLayerBytesReceived() = 0;
54*14675a02SAndroid Build Coastguard Worker };
55*14675a02SAndroid Build Coastguard Worker 
56*14675a02SAndroid Build Coastguard Worker /**
57*14675a02SAndroid Build Coastguard Worker  * A class which encapsulates a chunking gRPC endpoint for the federated
58*14675a02SAndroid Build Coastguard Worker  * learning API.
59*14675a02SAndroid Build Coastguard Worker  *
60*14675a02SAndroid Build Coastguard Worker  * This class is thread-safe, but note that calls to Send() and Receive() are
61*14675a02SAndroid Build Coastguard Worker  * serialized *and* blocking.
62*14675a02SAndroid Build Coastguard Worker  */
63*14675a02SAndroid Build Coastguard Worker class GrpcBidiStream : public GrpcBidiStreamInterface {
64*14675a02SAndroid Build Coastguard Worker  public:
65*14675a02SAndroid Build Coastguard Worker   /**
66*14675a02SAndroid Build Coastguard Worker    * Create a chunking gRPC endpoint for the federated learning API.
67*14675a02SAndroid Build Coastguard Worker    * @param target The URI of the target endpoint.
68*14675a02SAndroid Build Coastguard Worker    * @param api_key The API key of the target endpoint.
69*14675a02SAndroid Build Coastguard Worker    * @param population_name The population this connection is associated with.
70*14675a02SAndroid Build Coastguard Worker    * This param will not be empty if the include_population_in_header flag is
71*14675a02SAndroid Build Coastguard Worker    * False.
72*14675a02SAndroid Build Coastguard Worker    * @param grpc_channel_deadline_seconds The deadline (in seconds) for the gRPC
73*14675a02SAndroid Build Coastguard Worker    * channel.
74*14675a02SAndroid Build Coastguard Worker    * @param cert_path Test-only path to a CA certificate root, to be used in
75*14675a02SAndroid Build Coastguard Worker    * combination with an "https+test://" URI scheme.
76*14675a02SAndroid Build Coastguard Worker    */
77*14675a02SAndroid Build Coastguard Worker   GrpcBidiStream(const std::string& target, const std::string& api_key,
78*14675a02SAndroid Build Coastguard Worker                  const std::string& population_name,
79*14675a02SAndroid Build Coastguard Worker                  int64_t grpc_channel_deadline_seconds,
80*14675a02SAndroid Build Coastguard Worker                  std::string cert_path = "");
81*14675a02SAndroid Build Coastguard Worker 
82*14675a02SAndroid Build Coastguard Worker   /**
83*14675a02SAndroid Build Coastguard Worker    * @param channel A preexisting channel to the target endpoint.
84*14675a02SAndroid Build Coastguard Worker    * @param api_key The API of the target endpoint.
85*14675a02SAndroid Build Coastguard Worker    * @param population_name The population this connection is associated with.
86*14675a02SAndroid Build Coastguard Worker    * This param will not be empty if the include_population_in_header flag is
87*14675a02SAndroid Build Coastguard Worker    * False.
88*14675a02SAndroid Build Coastguard Worker    * @param grpc_channel_deadline_seconds The deadline (in seconds) for the gRPC
89*14675a02SAndroid Build Coastguard Worker    * channel.
90*14675a02SAndroid Build Coastguard Worker    */
91*14675a02SAndroid Build Coastguard Worker   GrpcBidiStream(const std::shared_ptr<grpc::ChannelInterface>& channel,
92*14675a02SAndroid Build Coastguard Worker                  const std::string& api_key, const std::string& population_name,
93*14675a02SAndroid Build Coastguard Worker                  int64_t grpc_channel_deadline_seconds);
94*14675a02SAndroid Build Coastguard Worker   ~GrpcBidiStream() override = default;
95*14675a02SAndroid Build Coastguard Worker 
96*14675a02SAndroid Build Coastguard Worker   // GrpcBidiStream is neither copyable nor movable.
97*14675a02SAndroid Build Coastguard Worker   GrpcBidiStream(const GrpcBidiStream&) = delete;
98*14675a02SAndroid Build Coastguard Worker   GrpcBidiStream& operator=(const GrpcBidiStream&) = delete;
99*14675a02SAndroid Build Coastguard Worker 
100*14675a02SAndroid Build Coastguard Worker   /**
101*14675a02SAndroid Build Coastguard Worker    * Send a ClientStreamMessage to the remote endpoint.
102*14675a02SAndroid Build Coastguard Worker    * @param message The message to send.
103*14675a02SAndroid Build Coastguard Worker    * @return absl::Status, which will have code OK if the message was sent
104*14675a02SAndroid Build Coastguard Worker    *   successfully.
105*14675a02SAndroid Build Coastguard Worker    */
106*14675a02SAndroid Build Coastguard Worker   ABSL_MUST_USE_RESULT absl::Status Send(
107*14675a02SAndroid Build Coastguard Worker       google::internal::federatedml::v2::ClientStreamMessage* message) override
108*14675a02SAndroid Build Coastguard Worker       ABSL_LOCKS_EXCLUDED(mu_);
109*14675a02SAndroid Build Coastguard Worker 
110*14675a02SAndroid Build Coastguard Worker   /**
111*14675a02SAndroid Build Coastguard Worker    * Receive a ServerStreamMessage from the remote endpoint.  Blocking.
112*14675a02SAndroid Build Coastguard Worker    * @param message The message to receive.
113*14675a02SAndroid Build Coastguard Worker    * @return absl::Status. This may be a translation of the status returned by
114*14675a02SAndroid Build Coastguard Worker    * the server, or a status generated during execution of the chunking
115*14675a02SAndroid Build Coastguard Worker    * protocol.
116*14675a02SAndroid Build Coastguard Worker    */
117*14675a02SAndroid Build Coastguard Worker   ABSL_MUST_USE_RESULT absl::Status Receive(
118*14675a02SAndroid Build Coastguard Worker       google::internal::federatedml::v2::ServerStreamMessage* message) override
119*14675a02SAndroid Build Coastguard Worker       ABSL_LOCKS_EXCLUDED(mu_);
120*14675a02SAndroid Build Coastguard Worker 
121*14675a02SAndroid Build Coastguard Worker   /**
122*14675a02SAndroid Build Coastguard Worker    * Close this stream.
123*14675a02SAndroid Build Coastguard Worker    * Releases any blocked readers. Thread safe.
124*14675a02SAndroid Build Coastguard Worker    */
125*14675a02SAndroid Build Coastguard Worker   void Close() override ABSL_LOCKS_EXCLUDED(mu_);
126*14675a02SAndroid Build Coastguard Worker 
127*14675a02SAndroid Build Coastguard Worker   /**
128*14675a02SAndroid Build Coastguard Worker    * Returns the number of bytes sent from the chunking layer.
129*14675a02SAndroid Build Coastguard Worker    * Flow control means this value may not increment until Receive() is called.
130*14675a02SAndroid Build Coastguard Worker    */
131*14675a02SAndroid Build Coastguard Worker   int64_t ChunkingLayerBytesSent() override;
132*14675a02SAndroid Build Coastguard Worker 
133*14675a02SAndroid Build Coastguard Worker   /**
134*14675a02SAndroid Build Coastguard Worker    * Returns the number of bytes received by the chunking layer.
135*14675a02SAndroid Build Coastguard Worker    */
136*14675a02SAndroid Build Coastguard Worker   int64_t ChunkingLayerBytesReceived() override;
137*14675a02SAndroid Build Coastguard Worker 
138*14675a02SAndroid Build Coastguard Worker   // Note: Must be lowercase:
139*14675a02SAndroid Build Coastguard Worker   static constexpr char kApiKeyHeader[] = "x-goog-api-key";
140*14675a02SAndroid Build Coastguard Worker   static constexpr char kPopulationNameHeader[] = "x-goog-population";
141*14675a02SAndroid Build Coastguard Worker 
142*14675a02SAndroid Build Coastguard Worker  private:
143*14675a02SAndroid Build Coastguard Worker   absl::Mutex mu_;
144*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<google::internal::federatedml::v2::FederatedTrainingApi::Stub>
145*14675a02SAndroid Build Coastguard Worker       stub_;
146*14675a02SAndroid Build Coastguard Worker   grpc::ClientContext client_context_;
147*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<grpc::ClientReaderWriter<
148*14675a02SAndroid Build Coastguard Worker       google::internal::federatedml::v2::ClientStreamMessage,
149*14675a02SAndroid Build Coastguard Worker       google::internal::federatedml::v2::ServerStreamMessage>>
150*14675a02SAndroid Build Coastguard Worker       client_reader_writer_ ABSL_GUARDED_BY(mu_);
151*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<GrpcChunkedBidiStream<
152*14675a02SAndroid Build Coastguard Worker       google::internal::federatedml::v2::ClientStreamMessage,
153*14675a02SAndroid Build Coastguard Worker       google::internal::federatedml::v2::ServerStreamMessage>>
154*14675a02SAndroid Build Coastguard Worker       chunked_bidi_stream_ ABSL_GUARDED_BY(mu_);
155*14675a02SAndroid Build Coastguard Worker };
156*14675a02SAndroid Build Coastguard Worker 
157*14675a02SAndroid Build Coastguard Worker }  // namespace client
158*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
159*14675a02SAndroid Build Coastguard Worker 
160*14675a02SAndroid Build Coastguard Worker #endif  // FCP_CLIENT_GRPC_BIDI_STREAM_H_
161