xref: /aosp_15_r20/external/federated-compute/fcp/client/grpc_bidi_stream.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
/*
 * Copyright 2019 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*14675a02SAndroid Build Coastguard Worker 
17*14675a02SAndroid Build Coastguard Worker #include "fcp/client/grpc_bidi_stream.h"
18*14675a02SAndroid Build Coastguard Worker 
19*14675a02SAndroid Build Coastguard Worker #include <memory>
20*14675a02SAndroid Build Coastguard Worker #include <string>
21*14675a02SAndroid Build Coastguard Worker #include <utility>
22*14675a02SAndroid Build Coastguard Worker 
23*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
24*14675a02SAndroid Build Coastguard Worker #include "fcp/base/status_converters.h"
25*14675a02SAndroid Build Coastguard Worker #include "fcp/client/grpc_bidi_channel.h"
26*14675a02SAndroid Build Coastguard Worker #include "grpcpp/support/time.h"
27*14675a02SAndroid Build Coastguard Worker 
28*14675a02SAndroid Build Coastguard Worker namespace fcp {
29*14675a02SAndroid Build Coastguard Worker namespace client {
30*14675a02SAndroid Build Coastguard Worker 
31*14675a02SAndroid Build Coastguard Worker using fcp::base::FromGrpcStatus;
32*14675a02SAndroid Build Coastguard Worker using google::internal::federatedml::v2::ClientStreamMessage;
33*14675a02SAndroid Build Coastguard Worker using google::internal::federatedml::v2::FederatedTrainingApi;
34*14675a02SAndroid Build Coastguard Worker using google::internal::federatedml::v2::ServerStreamMessage;
35*14675a02SAndroid Build Coastguard Worker using grpc::ChannelInterface;
36*14675a02SAndroid Build Coastguard Worker 
GrpcBidiStream(const std::string & target,const std::string & api_key,const std::string & population_name,int64_t grpc_channel_deadline_seconds,std::string cert_path)37*14675a02SAndroid Build Coastguard Worker GrpcBidiStream::GrpcBidiStream(const std::string& target,
38*14675a02SAndroid Build Coastguard Worker                                const std::string& api_key,
39*14675a02SAndroid Build Coastguard Worker                                const std::string& population_name,
40*14675a02SAndroid Build Coastguard Worker                                int64_t grpc_channel_deadline_seconds,
41*14675a02SAndroid Build Coastguard Worker                                std::string cert_path)
42*14675a02SAndroid Build Coastguard Worker     : GrpcBidiStream(GrpcBidiChannel::Create(target, std::move(cert_path)),
43*14675a02SAndroid Build Coastguard Worker                      api_key, population_name, grpc_channel_deadline_seconds) {}
44*14675a02SAndroid Build Coastguard Worker 
GrpcBidiStream(const std::shared_ptr<grpc::ChannelInterface> & channel,const std::string & api_key,const std::string & population_name,int64_t grpc_channel_deadline_seconds)45*14675a02SAndroid Build Coastguard Worker GrpcBidiStream::GrpcBidiStream(
46*14675a02SAndroid Build Coastguard Worker     const std::shared_ptr<grpc::ChannelInterface>& channel,
47*14675a02SAndroid Build Coastguard Worker     const std::string& api_key, const std::string& population_name,
48*14675a02SAndroid Build Coastguard Worker     int64_t grpc_channel_deadline_seconds)
49*14675a02SAndroid Build Coastguard Worker     : mu_(), stub_(FederatedTrainingApi::NewStub(channel)) {
50*14675a02SAndroid Build Coastguard Worker   FCP_LOG(INFO) << "Connecting to stub: " << stub_.get();
51*14675a02SAndroid Build Coastguard Worker   gpr_timespec deadline = gpr_time_add(
52*14675a02SAndroid Build Coastguard Worker       gpr_now(GPR_CLOCK_REALTIME),
53*14675a02SAndroid Build Coastguard Worker       gpr_time_from_seconds(grpc_channel_deadline_seconds, GPR_TIMESPAN));
54*14675a02SAndroid Build Coastguard Worker   client_context_.set_deadline(deadline);
55*14675a02SAndroid Build Coastguard Worker   client_context_.AddMetadata(kApiKeyHeader, api_key);
56*14675a02SAndroid Build Coastguard Worker   client_context_.AddMetadata(kPopulationNameHeader, population_name);
57*14675a02SAndroid Build Coastguard Worker   client_reader_writer_ = stub_->Session(&client_context_);
58*14675a02SAndroid Build Coastguard Worker   GrpcChunkedBidiStream<ClientStreamMessage,
59*14675a02SAndroid Build Coastguard Worker                         ServerStreamMessage>::GrpcChunkedBidiStreamOptions
60*14675a02SAndroid Build Coastguard Worker       options;
61*14675a02SAndroid Build Coastguard Worker   chunked_bidi_stream_ = std::make_unique<
62*14675a02SAndroid Build Coastguard Worker       GrpcChunkedBidiStream<ClientStreamMessage, ServerStreamMessage>>(
63*14675a02SAndroid Build Coastguard Worker       client_reader_writer_.get(), client_reader_writer_.get(), options);
64*14675a02SAndroid Build Coastguard Worker   if (!channel) Close();
65*14675a02SAndroid Build Coastguard Worker }
66*14675a02SAndroid Build Coastguard Worker 
Send(ClientStreamMessage * message)67*14675a02SAndroid Build Coastguard Worker absl::Status GrpcBidiStream::Send(ClientStreamMessage* message) {
68*14675a02SAndroid Build Coastguard Worker   absl::Status status;
69*14675a02SAndroid Build Coastguard Worker   {
70*14675a02SAndroid Build Coastguard Worker     absl::MutexLock _(&mu_);
71*14675a02SAndroid Build Coastguard Worker     if (client_reader_writer_ == nullptr) {
72*14675a02SAndroid Build Coastguard Worker       return absl::CancelledError(
73*14675a02SAndroid Build Coastguard Worker           "Send failed because GrpcBidiStream was closed.");
74*14675a02SAndroid Build Coastguard Worker     }
75*14675a02SAndroid Build Coastguard Worker     status = chunked_bidi_stream_->Send(message);
76*14675a02SAndroid Build Coastguard Worker     if (status.code() == absl::StatusCode::kAborted) {
77*14675a02SAndroid Build Coastguard Worker       FCP_LOG(INFO) << "Send aborted: " << status.code();
78*14675a02SAndroid Build Coastguard Worker       auto finish_status = FromGrpcStatus(client_reader_writer_->Finish());
79*14675a02SAndroid Build Coastguard Worker       // If the connection aborts early or harshly enough, there will be no
80*14675a02SAndroid Build Coastguard Worker       // error status from Finish().
81*14675a02SAndroid Build Coastguard Worker       if (!finish_status.ok()) status = finish_status;
82*14675a02SAndroid Build Coastguard Worker     }
83*14675a02SAndroid Build Coastguard Worker   }
84*14675a02SAndroid Build Coastguard Worker   if (!status.ok()) {
85*14675a02SAndroid Build Coastguard Worker     FCP_LOG(INFO) << "Closing; error on send: " << status.message();
86*14675a02SAndroid Build Coastguard Worker     Close();
87*14675a02SAndroid Build Coastguard Worker   }
88*14675a02SAndroid Build Coastguard Worker   return status;
89*14675a02SAndroid Build Coastguard Worker }
90*14675a02SAndroid Build Coastguard Worker 
Receive(ServerStreamMessage * message)91*14675a02SAndroid Build Coastguard Worker absl::Status GrpcBidiStream::Receive(ServerStreamMessage* message) {
92*14675a02SAndroid Build Coastguard Worker   absl::Status status;
93*14675a02SAndroid Build Coastguard Worker   {
94*14675a02SAndroid Build Coastguard Worker     absl::MutexLock _(&mu_);
95*14675a02SAndroid Build Coastguard Worker     if (client_reader_writer_ == nullptr) {
96*14675a02SAndroid Build Coastguard Worker       return absl::CancelledError(
97*14675a02SAndroid Build Coastguard Worker           "Receive failed because GrpcBidiStream was closed.");
98*14675a02SAndroid Build Coastguard Worker     }
99*14675a02SAndroid Build Coastguard Worker     status = chunked_bidi_stream_->Receive(message);
100*14675a02SAndroid Build Coastguard Worker     if (status.code() == absl::StatusCode::kAborted) {
101*14675a02SAndroid Build Coastguard Worker       FCP_LOG(INFO) << "Receive aborted: " << status.code();
102*14675a02SAndroid Build Coastguard Worker       auto finish_status = FromGrpcStatus(client_reader_writer_->Finish());
103*14675a02SAndroid Build Coastguard Worker       // If the connection aborts early or harshly enough, there will be no
104*14675a02SAndroid Build Coastguard Worker       // error status from Finish().
105*14675a02SAndroid Build Coastguard Worker       if (!finish_status.ok()) status = finish_status;
106*14675a02SAndroid Build Coastguard Worker     }
107*14675a02SAndroid Build Coastguard Worker   }
108*14675a02SAndroid Build Coastguard Worker   if (!status.ok()) {
109*14675a02SAndroid Build Coastguard Worker     FCP_LOG(INFO) << "Closing; error on receive: " << status.message();
110*14675a02SAndroid Build Coastguard Worker     Close();
111*14675a02SAndroid Build Coastguard Worker   }
112*14675a02SAndroid Build Coastguard Worker   return status;
113*14675a02SAndroid Build Coastguard Worker }
114*14675a02SAndroid Build Coastguard Worker 
Close()115*14675a02SAndroid Build Coastguard Worker void GrpcBidiStream::Close() {
116*14675a02SAndroid Build Coastguard Worker   if (!mu_.TryLock()) {
117*14675a02SAndroid Build Coastguard Worker     client_context_.TryCancel();
118*14675a02SAndroid Build Coastguard Worker     mu_.Lock();
119*14675a02SAndroid Build Coastguard Worker   }
120*14675a02SAndroid Build Coastguard Worker   chunked_bidi_stream_->Close();
121*14675a02SAndroid Build Coastguard Worker   if (client_reader_writer_) client_reader_writer_->WritesDone();
122*14675a02SAndroid Build Coastguard Worker   client_reader_writer_.reset();
123*14675a02SAndroid Build Coastguard Worker   FCP_LOG(INFO) << "Closing stub: " << stub_.get();
124*14675a02SAndroid Build Coastguard Worker   stub_.reset();
125*14675a02SAndroid Build Coastguard Worker   mu_.Unlock();
126*14675a02SAndroid Build Coastguard Worker }
127*14675a02SAndroid Build Coastguard Worker 
ChunkingLayerBytesReceived()128*14675a02SAndroid Build Coastguard Worker int64_t GrpcBidiStream::ChunkingLayerBytesReceived() {
129*14675a02SAndroid Build Coastguard Worker   absl::MutexLock _(&mu_);
130*14675a02SAndroid Build Coastguard Worker   return chunked_bidi_stream_->ChunkingLayerBytesReceived();
131*14675a02SAndroid Build Coastguard Worker }
132*14675a02SAndroid Build Coastguard Worker 
ChunkingLayerBytesSent()133*14675a02SAndroid Build Coastguard Worker int64_t GrpcBidiStream::ChunkingLayerBytesSent() {
134*14675a02SAndroid Build Coastguard Worker   absl::MutexLock _(&mu_);
135*14675a02SAndroid Build Coastguard Worker   return chunked_bidi_stream_->ChunkingLayerBytesSent();
136*14675a02SAndroid Build Coastguard Worker }
137*14675a02SAndroid Build Coastguard Worker 
138*14675a02SAndroid Build Coastguard Worker }  // namespace client
139*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
140