xref: /aosp_15_r20/external/federated-compute/fcp/client/grpc_bidi_stream.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_CLIENT_GRPC_BIDI_STREAM_H_
18 #define FCP_CLIENT_GRPC_BIDI_STREAM_H_
19 
#include <cstdint>
#include <memory>
#include <string>

#include "absl/base/attributes.h"
#include "absl/base/thread_annotations.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "fcp/base/monitoring.h"
#include "fcp/protocol/grpc_chunked_bidi_stream.h"
#include "fcp/protos/federated_api.grpc.pb.h"
#include "grpcpp/impl/codegen/channel_interface.h"
#include "grpcpp/impl/codegen/client_context.h"
32 
33 namespace fcp {
34 namespace client {
35 
36 /**
37  * Interface to support dependency injection and hence testing
38  */
39 class GrpcBidiStreamInterface {
40  public:
41   virtual ~GrpcBidiStreamInterface() = default;
42 
43   virtual ABSL_MUST_USE_RESULT absl::Status Send(
44       google::internal::federatedml::v2::ClientStreamMessage* message) = 0;
45 
46   virtual ABSL_MUST_USE_RESULT absl::Status Receive(
47       google::internal::federatedml::v2::ServerStreamMessage* message) = 0;
48 
49   virtual void Close() = 0;
50 
51   virtual int64_t ChunkingLayerBytesSent() = 0;
52 
53   virtual int64_t ChunkingLayerBytesReceived() = 0;
54 };
55 
56 /**
57  * A class which encapsulates a chunking gRPC endpoint for the federated
58  * learning API.
59  *
60  * This class is thread-safe, but note that calls to Send() and Receive() are
61  * serialized *and* blocking.
62  */
63 class GrpcBidiStream : public GrpcBidiStreamInterface {
64  public:
65   /**
66    * Create a chunking gRPC endpoint for the federated learning API.
67    * @param target The URI of the target endpoint.
68    * @param api_key The API key of the target endpoint.
69    * @param population_name The population this connection is associated with.
70    * This param will not be empty if the include_population_in_header flag is
71    * False.
72    * @param grpc_channel_deadline_seconds The deadline (in seconds) for the gRPC
73    * channel.
74    * @param cert_path Test-only path to a CA certificate root, to be used in
75    * combination with an "https+test://" URI scheme.
76    */
77   GrpcBidiStream(const std::string& target, const std::string& api_key,
78                  const std::string& population_name,
79                  int64_t grpc_channel_deadline_seconds,
80                  std::string cert_path = "");
81 
82   /**
83    * @param channel A preexisting channel to the target endpoint.
84    * @param api_key The API of the target endpoint.
85    * @param population_name The population this connection is associated with.
86    * This param will not be empty if the include_population_in_header flag is
87    * False.
88    * @param grpc_channel_deadline_seconds The deadline (in seconds) for the gRPC
89    * channel.
90    */
91   GrpcBidiStream(const std::shared_ptr<grpc::ChannelInterface>& channel,
92                  const std::string& api_key, const std::string& population_name,
93                  int64_t grpc_channel_deadline_seconds);
94   ~GrpcBidiStream() override = default;
95 
96   // GrpcBidiStream is neither copyable nor movable.
97   GrpcBidiStream(const GrpcBidiStream&) = delete;
98   GrpcBidiStream& operator=(const GrpcBidiStream&) = delete;
99 
100   /**
101    * Send a ClientStreamMessage to the remote endpoint.
102    * @param message The message to send.
103    * @return absl::Status, which will have code OK if the message was sent
104    *   successfully.
105    */
106   ABSL_MUST_USE_RESULT absl::Status Send(
107       google::internal::federatedml::v2::ClientStreamMessage* message) override
108       ABSL_LOCKS_EXCLUDED(mu_);
109 
110   /**
111    * Receive a ServerStreamMessage from the remote endpoint.  Blocking.
112    * @param message The message to receive.
113    * @return absl::Status. This may be a translation of the status returned by
114    * the server, or a status generated during execution of the chunking
115    * protocol.
116    */
117   ABSL_MUST_USE_RESULT absl::Status Receive(
118       google::internal::federatedml::v2::ServerStreamMessage* message) override
119       ABSL_LOCKS_EXCLUDED(mu_);
120 
121   /**
122    * Close this stream.
123    * Releases any blocked readers. Thread safe.
124    */
125   void Close() override ABSL_LOCKS_EXCLUDED(mu_);
126 
127   /**
128    * Returns the number of bytes sent from the chunking layer.
129    * Flow control means this value may not increment until Receive() is called.
130    */
131   int64_t ChunkingLayerBytesSent() override;
132 
133   /**
134    * Returns the number of bytes received by the chunking layer.
135    */
136   int64_t ChunkingLayerBytesReceived() override;
137 
138   // Note: Must be lowercase:
139   static constexpr char kApiKeyHeader[] = "x-goog-api-key";
140   static constexpr char kPopulationNameHeader[] = "x-goog-population";
141 
142  private:
143   absl::Mutex mu_;
144   std::unique_ptr<google::internal::federatedml::v2::FederatedTrainingApi::Stub>
145       stub_;
146   grpc::ClientContext client_context_;
147   std::unique_ptr<grpc::ClientReaderWriter<
148       google::internal::federatedml::v2::ClientStreamMessage,
149       google::internal::federatedml::v2::ServerStreamMessage>>
150       client_reader_writer_ ABSL_GUARDED_BY(mu_);
151   std::unique_ptr<GrpcChunkedBidiStream<
152       google::internal::federatedml::v2::ClientStreamMessage,
153       google::internal::federatedml::v2::ServerStreamMessage>>
154       chunked_bidi_stream_ ABSL_GUARDED_BY(mu_);
155 };
156 
157 }  // namespace client
158 }  // namespace fcp
159 
160 #endif  // FCP_CLIENT_GRPC_BIDI_STREAM_H_
161