1 /*
2 * Copyright 2019 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/client/grpc_bidi_stream.h"
18
19 #include <memory>
20 #include <string>
21 #include <utility>
22
23 #include "absl/status/status.h"
24 #include "fcp/base/status_converters.h"
25 #include "fcp/client/grpc_bidi_channel.h"
26 #include "grpcpp/support/time.h"
27
28 namespace fcp {
29 namespace client {
30
31 using fcp::base::FromGrpcStatus;
32 using google::internal::federatedml::v2::ClientStreamMessage;
33 using google::internal::federatedml::v2::FederatedTrainingApi;
34 using google::internal::federatedml::v2::ServerStreamMessage;
35 using grpc::ChannelInterface;
36
GrpcBidiStream(const std::string & target,const std::string & api_key,const std::string & population_name,int64_t grpc_channel_deadline_seconds,std::string cert_path)37 GrpcBidiStream::GrpcBidiStream(const std::string& target,
38 const std::string& api_key,
39 const std::string& population_name,
40 int64_t grpc_channel_deadline_seconds,
41 std::string cert_path)
42 : GrpcBidiStream(GrpcBidiChannel::Create(target, std::move(cert_path)),
43 api_key, population_name, grpc_channel_deadline_seconds) {}
44
GrpcBidiStream(const std::shared_ptr<grpc::ChannelInterface> & channel,const std::string & api_key,const std::string & population_name,int64_t grpc_channel_deadline_seconds)45 GrpcBidiStream::GrpcBidiStream(
46 const std::shared_ptr<grpc::ChannelInterface>& channel,
47 const std::string& api_key, const std::string& population_name,
48 int64_t grpc_channel_deadline_seconds)
49 : mu_(), stub_(FederatedTrainingApi::NewStub(channel)) {
50 FCP_LOG(INFO) << "Connecting to stub: " << stub_.get();
51 gpr_timespec deadline = gpr_time_add(
52 gpr_now(GPR_CLOCK_REALTIME),
53 gpr_time_from_seconds(grpc_channel_deadline_seconds, GPR_TIMESPAN));
54 client_context_.set_deadline(deadline);
55 client_context_.AddMetadata(kApiKeyHeader, api_key);
56 client_context_.AddMetadata(kPopulationNameHeader, population_name);
57 client_reader_writer_ = stub_->Session(&client_context_);
58 GrpcChunkedBidiStream<ClientStreamMessage,
59 ServerStreamMessage>::GrpcChunkedBidiStreamOptions
60 options;
61 chunked_bidi_stream_ = std::make_unique<
62 GrpcChunkedBidiStream<ClientStreamMessage, ServerStreamMessage>>(
63 client_reader_writer_.get(), client_reader_writer_.get(), options);
64 if (!channel) Close();
65 }
66
Send(ClientStreamMessage * message)67 absl::Status GrpcBidiStream::Send(ClientStreamMessage* message) {
68 absl::Status status;
69 {
70 absl::MutexLock _(&mu_);
71 if (client_reader_writer_ == nullptr) {
72 return absl::CancelledError(
73 "Send failed because GrpcBidiStream was closed.");
74 }
75 status = chunked_bidi_stream_->Send(message);
76 if (status.code() == absl::StatusCode::kAborted) {
77 FCP_LOG(INFO) << "Send aborted: " << status.code();
78 auto finish_status = FromGrpcStatus(client_reader_writer_->Finish());
79 // If the connection aborts early or harshly enough, there will be no
80 // error status from Finish().
81 if (!finish_status.ok()) status = finish_status;
82 }
83 }
84 if (!status.ok()) {
85 FCP_LOG(INFO) << "Closing; error on send: " << status.message();
86 Close();
87 }
88 return status;
89 }
90
Receive(ServerStreamMessage * message)91 absl::Status GrpcBidiStream::Receive(ServerStreamMessage* message) {
92 absl::Status status;
93 {
94 absl::MutexLock _(&mu_);
95 if (client_reader_writer_ == nullptr) {
96 return absl::CancelledError(
97 "Receive failed because GrpcBidiStream was closed.");
98 }
99 status = chunked_bidi_stream_->Receive(message);
100 if (status.code() == absl::StatusCode::kAborted) {
101 FCP_LOG(INFO) << "Receive aborted: " << status.code();
102 auto finish_status = FromGrpcStatus(client_reader_writer_->Finish());
103 // If the connection aborts early or harshly enough, there will be no
104 // error status from Finish().
105 if (!finish_status.ok()) status = finish_status;
106 }
107 }
108 if (!status.ok()) {
109 FCP_LOG(INFO) << "Closing; error on receive: " << status.message();
110 Close();
111 }
112 return status;
113 }
114
Close()115 void GrpcBidiStream::Close() {
116 if (!mu_.TryLock()) {
117 client_context_.TryCancel();
118 mu_.Lock();
119 }
120 chunked_bidi_stream_->Close();
121 if (client_reader_writer_) client_reader_writer_->WritesDone();
122 client_reader_writer_.reset();
123 FCP_LOG(INFO) << "Closing stub: " << stub_.get();
124 stub_.reset();
125 mu_.Unlock();
126 }
127
ChunkingLayerBytesReceived()128 int64_t GrpcBidiStream::ChunkingLayerBytesReceived() {
129 absl::MutexLock _(&mu_);
130 return chunked_bidi_stream_->ChunkingLayerBytesReceived();
131 }
132
ChunkingLayerBytesSent()133 int64_t GrpcBidiStream::ChunkingLayerBytesSent() {
134 absl::MutexLock _(&mu_);
135 return chunked_bidi_stream_->ChunkingLayerBytesSent();
136 }
137
138 } // namespace client
139 } // namespace fcp
140