xref: /aosp_15_r20/external/federated-compute/fcp/client/grpc_bidi_stream.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/client/grpc_bidi_stream.h"
18 
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/status/status.h"
24 #include "fcp/base/status_converters.h"
25 #include "fcp/client/grpc_bidi_channel.h"
26 #include "grpcpp/support/time.h"
27 
28 namespace fcp {
29 namespace client {
30 
31 using fcp::base::FromGrpcStatus;
32 using google::internal::federatedml::v2::ClientStreamMessage;
33 using google::internal::federatedml::v2::FederatedTrainingApi;
34 using google::internal::federatedml::v2::ServerStreamMessage;
35 using grpc::ChannelInterface;
36 
GrpcBidiStream(const std::string & target,const std::string & api_key,const std::string & population_name,int64_t grpc_channel_deadline_seconds,std::string cert_path)37 GrpcBidiStream::GrpcBidiStream(const std::string& target,
38                                const std::string& api_key,
39                                const std::string& population_name,
40                                int64_t grpc_channel_deadline_seconds,
41                                std::string cert_path)
42     : GrpcBidiStream(GrpcBidiChannel::Create(target, std::move(cert_path)),
43                      api_key, population_name, grpc_channel_deadline_seconds) {}
44 
GrpcBidiStream(const std::shared_ptr<grpc::ChannelInterface> & channel,const std::string & api_key,const std::string & population_name,int64_t grpc_channel_deadline_seconds)45 GrpcBidiStream::GrpcBidiStream(
46     const std::shared_ptr<grpc::ChannelInterface>& channel,
47     const std::string& api_key, const std::string& population_name,
48     int64_t grpc_channel_deadline_seconds)
49     : mu_(), stub_(FederatedTrainingApi::NewStub(channel)) {
50   FCP_LOG(INFO) << "Connecting to stub: " << stub_.get();
51   gpr_timespec deadline = gpr_time_add(
52       gpr_now(GPR_CLOCK_REALTIME),
53       gpr_time_from_seconds(grpc_channel_deadline_seconds, GPR_TIMESPAN));
54   client_context_.set_deadline(deadline);
55   client_context_.AddMetadata(kApiKeyHeader, api_key);
56   client_context_.AddMetadata(kPopulationNameHeader, population_name);
57   client_reader_writer_ = stub_->Session(&client_context_);
58   GrpcChunkedBidiStream<ClientStreamMessage,
59                         ServerStreamMessage>::GrpcChunkedBidiStreamOptions
60       options;
61   chunked_bidi_stream_ = std::make_unique<
62       GrpcChunkedBidiStream<ClientStreamMessage, ServerStreamMessage>>(
63       client_reader_writer_.get(), client_reader_writer_.get(), options);
64   if (!channel) Close();
65 }
66 
Send(ClientStreamMessage * message)67 absl::Status GrpcBidiStream::Send(ClientStreamMessage* message) {
68   absl::Status status;
69   {
70     absl::MutexLock _(&mu_);
71     if (client_reader_writer_ == nullptr) {
72       return absl::CancelledError(
73           "Send failed because GrpcBidiStream was closed.");
74     }
75     status = chunked_bidi_stream_->Send(message);
76     if (status.code() == absl::StatusCode::kAborted) {
77       FCP_LOG(INFO) << "Send aborted: " << status.code();
78       auto finish_status = FromGrpcStatus(client_reader_writer_->Finish());
79       // If the connection aborts early or harshly enough, there will be no
80       // error status from Finish().
81       if (!finish_status.ok()) status = finish_status;
82     }
83   }
84   if (!status.ok()) {
85     FCP_LOG(INFO) << "Closing; error on send: " << status.message();
86     Close();
87   }
88   return status;
89 }
90 
Receive(ServerStreamMessage * message)91 absl::Status GrpcBidiStream::Receive(ServerStreamMessage* message) {
92   absl::Status status;
93   {
94     absl::MutexLock _(&mu_);
95     if (client_reader_writer_ == nullptr) {
96       return absl::CancelledError(
97           "Receive failed because GrpcBidiStream was closed.");
98     }
99     status = chunked_bidi_stream_->Receive(message);
100     if (status.code() == absl::StatusCode::kAborted) {
101       FCP_LOG(INFO) << "Receive aborted: " << status.code();
102       auto finish_status = FromGrpcStatus(client_reader_writer_->Finish());
103       // If the connection aborts early or harshly enough, there will be no
104       // error status from Finish().
105       if (!finish_status.ok()) status = finish_status;
106     }
107   }
108   if (!status.ok()) {
109     FCP_LOG(INFO) << "Closing; error on receive: " << status.message();
110     Close();
111   }
112   return status;
113 }
114 
Close()115 void GrpcBidiStream::Close() {
116   if (!mu_.TryLock()) {
117     client_context_.TryCancel();
118     mu_.Lock();
119   }
120   chunked_bidi_stream_->Close();
121   if (client_reader_writer_) client_reader_writer_->WritesDone();
122   client_reader_writer_.reset();
123   FCP_LOG(INFO) << "Closing stub: " << stub_.get();
124   stub_.reset();
125   mu_.Unlock();
126 }
127 
ChunkingLayerBytesReceived()128 int64_t GrpcBidiStream::ChunkingLayerBytesReceived() {
129   absl::MutexLock _(&mu_);
130   return chunked_bidi_stream_->ChunkingLayerBytesReceived();
131 }
132 
ChunkingLayerBytesSent()133 int64_t GrpcBidiStream::ChunkingLayerBytesSent() {
134   absl::MutexLock _(&mu_);
135   return chunked_bidi_stream_->ChunkingLayerBytesSent();
136 }
137 
138 }  // namespace client
139 }  // namespace fcp
140