1*14675a02SAndroid Build Coastguard Worker /* 2*14675a02SAndroid Build Coastguard Worker * Copyright 2019 Google LLC 3*14675a02SAndroid Build Coastguard Worker * 4*14675a02SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License"); 5*14675a02SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License. 6*14675a02SAndroid Build Coastguard Worker * You may obtain a copy of the License at 7*14675a02SAndroid Build Coastguard Worker * 8*14675a02SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0 9*14675a02SAndroid Build Coastguard Worker * 10*14675a02SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software 11*14675a02SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS, 12*14675a02SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13*14675a02SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and 14*14675a02SAndroid Build Coastguard Worker * limitations under the License. 15*14675a02SAndroid Build Coastguard Worker */ 16*14675a02SAndroid Build Coastguard Worker 17*14675a02SAndroid Build Coastguard Worker #ifndef FCP_CLIENT_GRPC_BIDI_STREAM_H_ 18*14675a02SAndroid Build Coastguard Worker #define FCP_CLIENT_GRPC_BIDI_STREAM_H_ 19*14675a02SAndroid Build Coastguard Worker 20*14675a02SAndroid Build Coastguard Worker #include <memory> 21*14675a02SAndroid Build Coastguard Worker #include <string> 22*14675a02SAndroid Build Coastguard Worker 23*14675a02SAndroid Build Coastguard Worker #include "absl/base/attributes.h" 24*14675a02SAndroid Build Coastguard Worker #include "absl/base/thread_annotations.h" 25*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h" 26*14675a02SAndroid Build Coastguard Worker #include "absl/synchronization/mutex.h" 27*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h" 28*14675a02SAndroid Build Coastguard Worker #include "fcp/protocol/grpc_chunked_bidi_stream.h" 29*14675a02SAndroid Build Coastguard Worker #include "fcp/protos/federated_api.grpc.pb.h" 30*14675a02SAndroid Build Coastguard Worker #include "grpcpp/impl/codegen/channel_interface.h" 31*14675a02SAndroid Build Coastguard Worker #include "grpcpp/impl/codegen/client_context.h" 32*14675a02SAndroid Build Coastguard Worker 33*14675a02SAndroid Build Coastguard Worker namespace fcp { 34*14675a02SAndroid Build Coastguard Worker namespace client { 35*14675a02SAndroid Build Coastguard Worker 36*14675a02SAndroid Build Coastguard Worker /** 37*14675a02SAndroid Build Coastguard Worker * Interface to support dependency injection and hence testing 38*14675a02SAndroid Build Coastguard Worker */ 39*14675a02SAndroid Build Coastguard Worker class GrpcBidiStreamInterface { 40*14675a02SAndroid Build Coastguard Worker public: 41*14675a02SAndroid Build Coastguard Worker virtual ~GrpcBidiStreamInterface() = default; 42*14675a02SAndroid Build Coastguard Worker 43*14675a02SAndroid Build Coastguard Worker virtual ABSL_MUST_USE_RESULT absl::Status Send( 44*14675a02SAndroid Build Coastguard Worker google::internal::federatedml::v2::ClientStreamMessage* message) = 0; 45*14675a02SAndroid Build Coastguard Worker 46*14675a02SAndroid Build Coastguard Worker virtual ABSL_MUST_USE_RESULT absl::Status Receive( 47*14675a02SAndroid Build Coastguard Worker google::internal::federatedml::v2::ServerStreamMessage* message) = 0; 48*14675a02SAndroid Build Coastguard Worker 49*14675a02SAndroid Build Coastguard Worker virtual void Close() = 0; 50*14675a02SAndroid Build Coastguard Worker 51*14675a02SAndroid Build Coastguard Worker virtual int64_t ChunkingLayerBytesSent() = 0; 52*14675a02SAndroid Build Coastguard Worker 53*14675a02SAndroid Build Coastguard Worker virtual int64_t ChunkingLayerBytesReceived() = 0; 54*14675a02SAndroid Build Coastguard Worker }; 55*14675a02SAndroid Build Coastguard Worker 56*14675a02SAndroid Build Coastguard Worker /** 57*14675a02SAndroid Build Coastguard Worker * A class which encapsulates a chunking gRPC endpoint for the federated 58*14675a02SAndroid Build Coastguard Worker * learning API. 59*14675a02SAndroid Build Coastguard Worker * 60*14675a02SAndroid Build Coastguard Worker * This class is thread-safe, but note that calls to Send() and Receive() are 61*14675a02SAndroid Build Coastguard Worker * serialized *and* blocking. 62*14675a02SAndroid Build Coastguard Worker */ 63*14675a02SAndroid Build Coastguard Worker class GrpcBidiStream : public GrpcBidiStreamInterface { 64*14675a02SAndroid Build Coastguard Worker public: 65*14675a02SAndroid Build Coastguard Worker /** 66*14675a02SAndroid Build Coastguard Worker * Create a chunking gRPC endpoint for the federated learning API. 67*14675a02SAndroid Build Coastguard Worker * @param target The URI of the target endpoint. 68*14675a02SAndroid Build Coastguard Worker * @param api_key The API key of the target endpoint. 69*14675a02SAndroid Build Coastguard Worker * @param population_name The population this connection is associated with. 70*14675a02SAndroid Build Coastguard Worker * This param will not be empty if the include_population_in_header flag is 71*14675a02SAndroid Build Coastguard Worker * False. 72*14675a02SAndroid Build Coastguard Worker * @param grpc_channel_deadline_seconds The deadline (in seconds) for the gRPC 73*14675a02SAndroid Build Coastguard Worker * channel. 74*14675a02SAndroid Build Coastguard Worker * @param cert_path Test-only path to a CA certificate root, to be used in 75*14675a02SAndroid Build Coastguard Worker * combination with an "https+test://" URI scheme. 76*14675a02SAndroid Build Coastguard Worker */ 77*14675a02SAndroid Build Coastguard Worker GrpcBidiStream(const std::string& target, const std::string& api_key, 78*14675a02SAndroid Build Coastguard Worker const std::string& population_name, 79*14675a02SAndroid Build Coastguard Worker int64_t grpc_channel_deadline_seconds, 80*14675a02SAndroid Build Coastguard Worker std::string cert_path = ""); 81*14675a02SAndroid Build Coastguard Worker 82*14675a02SAndroid Build Coastguard Worker /** 83*14675a02SAndroid Build Coastguard Worker * @param channel A preexisting channel to the target endpoint. 84*14675a02SAndroid Build Coastguard Worker * @param api_key The API of the target endpoint. 85*14675a02SAndroid Build Coastguard Worker * @param population_name The population this connection is associated with. 86*14675a02SAndroid Build Coastguard Worker * This param will not be empty if the include_population_in_header flag is 87*14675a02SAndroid Build Coastguard Worker * False. 88*14675a02SAndroid Build Coastguard Worker * @param grpc_channel_deadline_seconds The deadline (in seconds) for the gRPC 89*14675a02SAndroid Build Coastguard Worker * channel. 90*14675a02SAndroid Build Coastguard Worker */ 91*14675a02SAndroid Build Coastguard Worker GrpcBidiStream(const std::shared_ptr<grpc::ChannelInterface>& channel, 92*14675a02SAndroid Build Coastguard Worker const std::string& api_key, const std::string& population_name, 93*14675a02SAndroid Build Coastguard Worker int64_t grpc_channel_deadline_seconds); 94*14675a02SAndroid Build Coastguard Worker ~GrpcBidiStream() override = default; 95*14675a02SAndroid Build Coastguard Worker 96*14675a02SAndroid Build Coastguard Worker // GrpcBidiStream is neither copyable nor movable. 97*14675a02SAndroid Build Coastguard Worker GrpcBidiStream(const GrpcBidiStream&) = delete; 98*14675a02SAndroid Build Coastguard Worker GrpcBidiStream& operator=(const GrpcBidiStream&) = delete; 99*14675a02SAndroid Build Coastguard Worker 100*14675a02SAndroid Build Coastguard Worker /** 101*14675a02SAndroid Build Coastguard Worker * Send a ClientStreamMessage to the remote endpoint. 102*14675a02SAndroid Build Coastguard Worker * @param message The message to send. 103*14675a02SAndroid Build Coastguard Worker * @return absl::Status, which will have code OK if the message was sent 104*14675a02SAndroid Build Coastguard Worker * successfully. 105*14675a02SAndroid Build Coastguard Worker */ 106*14675a02SAndroid Build Coastguard Worker ABSL_MUST_USE_RESULT absl::Status Send( 107*14675a02SAndroid Build Coastguard Worker google::internal::federatedml::v2::ClientStreamMessage* message) override 108*14675a02SAndroid Build Coastguard Worker ABSL_LOCKS_EXCLUDED(mu_); 109*14675a02SAndroid Build Coastguard Worker 110*14675a02SAndroid Build Coastguard Worker /** 111*14675a02SAndroid Build Coastguard Worker * Receive a ServerStreamMessage from the remote endpoint. Blocking. 112*14675a02SAndroid Build Coastguard Worker * @param message The message to receive. 113*14675a02SAndroid Build Coastguard Worker * @return absl::Status. This may be a translation of the status returned by 114*14675a02SAndroid Build Coastguard Worker * the server, or a status generated during execution of the chunking 115*14675a02SAndroid Build Coastguard Worker * protocol. 116*14675a02SAndroid Build Coastguard Worker */ 117*14675a02SAndroid Build Coastguard Worker ABSL_MUST_USE_RESULT absl::Status Receive( 118*14675a02SAndroid Build Coastguard Worker google::internal::federatedml::v2::ServerStreamMessage* message) override 119*14675a02SAndroid Build Coastguard Worker ABSL_LOCKS_EXCLUDED(mu_); 120*14675a02SAndroid Build Coastguard Worker 121*14675a02SAndroid Build Coastguard Worker /** 122*14675a02SAndroid Build Coastguard Worker * Close this stream. 123*14675a02SAndroid Build Coastguard Worker * Releases any blocked readers. Thread safe. 124*14675a02SAndroid Build Coastguard Worker */ 125*14675a02SAndroid Build Coastguard Worker void Close() override ABSL_LOCKS_EXCLUDED(mu_); 126*14675a02SAndroid Build Coastguard Worker 127*14675a02SAndroid Build Coastguard Worker /** 128*14675a02SAndroid Build Coastguard Worker * Returns the number of bytes sent from the chunking layer. 129*14675a02SAndroid Build Coastguard Worker * Flow control means this value may not increment until Receive() is called. 130*14675a02SAndroid Build Coastguard Worker */ 131*14675a02SAndroid Build Coastguard Worker int64_t ChunkingLayerBytesSent() override; 132*14675a02SAndroid Build Coastguard Worker 133*14675a02SAndroid Build Coastguard Worker /** 134*14675a02SAndroid Build Coastguard Worker * Returns the number of bytes received by the chunking layer. 135*14675a02SAndroid Build Coastguard Worker */ 136*14675a02SAndroid Build Coastguard Worker int64_t ChunkingLayerBytesReceived() override; 137*14675a02SAndroid Build Coastguard Worker 138*14675a02SAndroid Build Coastguard Worker // Note: Must be lowercase: 139*14675a02SAndroid Build Coastguard Worker static constexpr char kApiKeyHeader[] = "x-goog-api-key"; 140*14675a02SAndroid Build Coastguard Worker static constexpr char kPopulationNameHeader[] = "x-goog-population"; 141*14675a02SAndroid Build Coastguard Worker 142*14675a02SAndroid Build Coastguard Worker private: 143*14675a02SAndroid Build Coastguard Worker absl::Mutex mu_; 144*14675a02SAndroid Build Coastguard Worker std::unique_ptr<google::internal::federatedml::v2::FederatedTrainingApi::Stub> 145*14675a02SAndroid Build Coastguard Worker stub_; 146*14675a02SAndroid Build Coastguard Worker grpc::ClientContext client_context_; 147*14675a02SAndroid Build Coastguard Worker std::unique_ptr<grpc::ClientReaderWriter< 148*14675a02SAndroid Build Coastguard Worker google::internal::federatedml::v2::ClientStreamMessage, 149*14675a02SAndroid Build Coastguard Worker google::internal::federatedml::v2::ServerStreamMessage>> 150*14675a02SAndroid Build Coastguard Worker client_reader_writer_ ABSL_GUARDED_BY(mu_); 151*14675a02SAndroid Build Coastguard Worker std::unique_ptr<GrpcChunkedBidiStream< 152*14675a02SAndroid Build Coastguard Worker google::internal::federatedml::v2::ClientStreamMessage, 153*14675a02SAndroid Build Coastguard Worker google::internal::federatedml::v2::ServerStreamMessage>> 154*14675a02SAndroid Build Coastguard Worker chunked_bidi_stream_ ABSL_GUARDED_BY(mu_); 155*14675a02SAndroid Build Coastguard Worker }; 156*14675a02SAndroid Build Coastguard Worker 157*14675a02SAndroid Build Coastguard Worker } // namespace client 158*14675a02SAndroid Build Coastguard Worker } // namespace fcp 159*14675a02SAndroid Build Coastguard Worker 160*14675a02SAndroid Build Coastguard Worker #endif // FCP_CLIENT_GRPC_BIDI_STREAM_H_ 161