1 /* 2 * Copyright 2019 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FCP_CLIENT_GRPC_BIDI_STREAM_H_ 18 #define FCP_CLIENT_GRPC_BIDI_STREAM_H_ 19 20 #include <memory> 21 #include <string> 22 23 #include "absl/base/attributes.h" 24 #include "absl/base/thread_annotations.h" 25 #include "absl/status/status.h" 26 #include "absl/synchronization/mutex.h" 27 #include "fcp/base/monitoring.h" 28 #include "fcp/protocol/grpc_chunked_bidi_stream.h" 29 #include "fcp/protos/federated_api.grpc.pb.h" 30 #include "grpcpp/impl/codegen/channel_interface.h" 31 #include "grpcpp/impl/codegen/client_context.h" 32 33 namespace fcp { 34 namespace client { 35 36 /** 37 * Interface to support dependency injection and hence testing 38 */ 39 class GrpcBidiStreamInterface { 40 public: 41 virtual ~GrpcBidiStreamInterface() = default; 42 43 virtual ABSL_MUST_USE_RESULT absl::Status Send( 44 google::internal::federatedml::v2::ClientStreamMessage* message) = 0; 45 46 virtual ABSL_MUST_USE_RESULT absl::Status Receive( 47 google::internal::federatedml::v2::ServerStreamMessage* message) = 0; 48 49 virtual void Close() = 0; 50 51 virtual int64_t ChunkingLayerBytesSent() = 0; 52 53 virtual int64_t ChunkingLayerBytesReceived() = 0; 54 }; 55 56 /** 57 * A class which encapsulates a chunking gRPC endpoint for the federated 58 * learning API. 59 * 60 * This class is thread-safe, but note that calls to Send() and Receive() are 61 * serialized *and* blocking. 62 */ 63 class GrpcBidiStream : public GrpcBidiStreamInterface { 64 public: 65 /** 66 * Create a chunking gRPC endpoint for the federated learning API. 67 * @param target The URI of the target endpoint. 68 * @param api_key The API key of the target endpoint. 69 * @param population_name The population this connection is associated with. 70 * This param will not be empty if the include_population_in_header flag is 71 * False. 72 * @param grpc_channel_deadline_seconds The deadline (in seconds) for the gRPC 73 * channel. 74 * @param cert_path Test-only path to a CA certificate root, to be used in 75 * combination with an "https+test://" URI scheme. 76 */ 77 GrpcBidiStream(const std::string& target, const std::string& api_key, 78 const std::string& population_name, 79 int64_t grpc_channel_deadline_seconds, 80 std::string cert_path = ""); 81 82 /** 83 * @param channel A preexisting channel to the target endpoint. 84 * @param api_key The API of the target endpoint. 85 * @param population_name The population this connection is associated with. 86 * This param will not be empty if the include_population_in_header flag is 87 * False. 88 * @param grpc_channel_deadline_seconds The deadline (in seconds) for the gRPC 89 * channel. 90 */ 91 GrpcBidiStream(const std::shared_ptr<grpc::ChannelInterface>& channel, 92 const std::string& api_key, const std::string& population_name, 93 int64_t grpc_channel_deadline_seconds); 94 ~GrpcBidiStream() override = default; 95 96 // GrpcBidiStream is neither copyable nor movable. 97 GrpcBidiStream(const GrpcBidiStream&) = delete; 98 GrpcBidiStream& operator=(const GrpcBidiStream&) = delete; 99 100 /** 101 * Send a ClientStreamMessage to the remote endpoint. 102 * @param message The message to send. 103 * @return absl::Status, which will have code OK if the message was sent 104 * successfully. 105 */ 106 ABSL_MUST_USE_RESULT absl::Status Send( 107 google::internal::federatedml::v2::ClientStreamMessage* message) override 108 ABSL_LOCKS_EXCLUDED(mu_); 109 110 /** 111 * Receive a ServerStreamMessage from the remote endpoint. Blocking. 112 * @param message The message to receive. 113 * @return absl::Status. This may be a translation of the status returned by 114 * the server, or a status generated during execution of the chunking 115 * protocol. 116 */ 117 ABSL_MUST_USE_RESULT absl::Status Receive( 118 google::internal::federatedml::v2::ServerStreamMessage* message) override 119 ABSL_LOCKS_EXCLUDED(mu_); 120 121 /** 122 * Close this stream. 123 * Releases any blocked readers. Thread safe. 124 */ 125 void Close() override ABSL_LOCKS_EXCLUDED(mu_); 126 127 /** 128 * Returns the number of bytes sent from the chunking layer. 129 * Flow control means this value may not increment until Receive() is called. 130 */ 131 int64_t ChunkingLayerBytesSent() override; 132 133 /** 134 * Returns the number of bytes received by the chunking layer. 135 */ 136 int64_t ChunkingLayerBytesReceived() override; 137 138 // Note: Must be lowercase: 139 static constexpr char kApiKeyHeader[] = "x-goog-api-key"; 140 static constexpr char kPopulationNameHeader[] = "x-goog-population"; 141 142 private: 143 absl::Mutex mu_; 144 std::unique_ptr<google::internal::federatedml::v2::FederatedTrainingApi::Stub> 145 stub_; 146 grpc::ClientContext client_context_; 147 std::unique_ptr<grpc::ClientReaderWriter< 148 google::internal::federatedml::v2::ClientStreamMessage, 149 google::internal::federatedml::v2::ServerStreamMessage>> 150 client_reader_writer_ ABSL_GUARDED_BY(mu_); 151 std::unique_ptr<GrpcChunkedBidiStream< 152 google::internal::federatedml::v2::ClientStreamMessage, 153 google::internal::federatedml::v2::ServerStreamMessage>> 154 chunked_bidi_stream_ ABSL_GUARDED_BY(mu_); 155 }; 156 157 } // namespace client 158 } // namespace fcp 159 160 #endif // FCP_CLIENT_GRPC_BIDI_STREAM_H_ 161