xref: /aosp_15_r20/external/federated-compute/fcp/protocol/grpc_chunked_bidi_stream_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/protocol/grpc_chunked_bidi_stream.h"
18 
19 #include <cctype>
20 #include <string>
21 #include <tuple>
22 
23 #include "gmock/gmock.h"
24 #include "gtest/gtest.h"
25 #include "absl/algorithm/container.h"
26 #include "absl/strings/str_cat.h"
27 #include "fcp/base/monitoring.h"
28 #include "fcp/client/fake_server.h"
29 #include "fcp/client/grpc_bidi_stream.h"
30 #include "fcp/protos/federated_api.pb.h"
31 #include "fcp/testing/testing.h"
32 #include "grpcpp/security/server_credentials.h"
33 #include "grpcpp/server.h"
34 #include "grpcpp/server_builder.h"
35 
36 namespace fcp {
37 namespace client {
38 namespace test {
39 namespace {
40 
41 using google::internal::federatedml::v2::ClientStreamMessage;
42 using google::internal::federatedml::v2::CompressionLevel;
43 using google::internal::federatedml::v2::ServerStreamMessage;
44 using ::testing::Gt;
45 using ::testing::Le;
46 using ::testing::Not;
47 
// Returns a string of length `n` whose i-th character is
// static_cast<char>(i % 128).
//
// Each character depends only on its position, so VerifyString() can check a
// received payload without the sender's copy.
std::string SimpleSelfVerifyingString(size_t n) {
  std::string str;
  str.reserve(n);
  // size_t index: an `int` counter compared against size_t `n` is a
  // signed/unsigned mismatch and would overflow for n > INT_MAX.
  for (size_t i = 0; i < n; ++i) str.push_back(static_cast<char>(i % 128));
  return str;
}
54 
// Returns true iff every character of `str` matches the pattern produced by
// SimpleSelfVerifyingString(), i.e. str[i] == static_cast<char>(i % 128).
//
// Only the pattern is checked, not the length: an empty or truncated (but
// otherwise intact) string verifies successfully.
//
// Takes the string by const reference; the payloads checked by these tests
// can be several megabytes, and the previous by-value parameter copied them.
bool VerifyString(const std::string& str) {
  for (size_t i = 0; i < str.size(); ++i) {
    if (str[i] != static_cast<char>(i % 128)) return false;
  }
  return true;
}
62 
PerformInitialCheckin(GrpcBidiStream * stream)63 Status PerformInitialCheckin(GrpcBidiStream* stream) {
64   ClientStreamMessage request_;
65   ServerStreamMessage reply;
66   Status status;
67 
68   auto options =
69       request_.mutable_checkin_request()->mutable_protocol_options_request();
70   options->set_supports_chunked_blob_transfer(true);
71   options->add_supported_compression_levels(CompressionLevel::UNCOMPRESSED);
72   options->add_supported_compression_levels(CompressionLevel::ZLIB_DEFAULT);
73   options->add_supported_compression_levels(
74       CompressionLevel::ZLIB_BEST_COMPRESSION);
75   options->add_supported_compression_levels(CompressionLevel::ZLIB_BEST_SPEED);
76 
77   EXPECT_THAT((status = stream->Send(&request_)), IsOk());
78   EXPECT_THAT((status = stream->Receive(&reply)), IsOk());
79   EXPECT_TRUE(reply.has_checkin_response()) << reply.DebugString();
80 
81   return status;
82 }
83 
84 using ChunkingParameters =
85     std::tuple<int32_t, /* chunk_size_for_upload */
86                int32_t, /* max_pending_chunks */
87                google::internal::federatedml::v2::CompressionLevel,
88                int32_t, /* request size */
89                size_t,  /* request count */
90                int32_t, /* reply size */
91                size_t   /* replies per request */
92                >;
93 
94 class ByteVerifyingFakeServer : public FakeServer {
95  public:
ByteVerifyingFakeServer(const ChunkingParameters & params)96   explicit ByteVerifyingFakeServer(const ChunkingParameters& params)
97       : FakeServer(std::get<0>(params), std::get<1>(params),
98                    std::get<2>(params)),
99         reply_size_(std::get<5>(params)),
100         replies_per_request_(std::get<6>(params)) {}
101 
Handle(const ClientStreamMessage & request,ServerStreamMessage * first_reply,GrpcChunkedBidiStream<ServerStreamMessage,ClientStreamMessage> * stream)102   Status Handle(const ClientStreamMessage& request,
103                 ServerStreamMessage* first_reply,
104                 GrpcChunkedBidiStream<ServerStreamMessage, ClientStreamMessage>*
105                     stream) override {
106     Status status;
107     if (request.has_checkin_request()) {
108       EXPECT_THAT((status = stream->Send(first_reply)), IsOk());
109       return status;
110     }
111     EXPECT_TRUE(
112         VerifyString(request.report_request().report().update_checkpoint()));
113     ServerStreamMessage reply;
114     reply.mutable_report_response()->mutable_retry_window()->set_retry_token(
115         SimpleSelfVerifyingString(reply_size_));
116     for (auto i = 0; i < replies_per_request_; ++i)
117       EXPECT_THAT((status = stream->Send(&reply)), IsOk());
118     return status;
119   }
120 
121  private:
122   int32_t reply_size_;
123   size_t replies_per_request_;
124 };
125 
126 class GrpcChunkedMessageStreamTest
127     : public ::testing::TestWithParam<ChunkingParameters> {
128  public:
GrpcChunkedMessageStreamTest()129   GrpcChunkedMessageStreamTest() : server_impl_(GetParam()) {
130     auto params = GetParam();
131     request_size_ = std::get<3>(params);
132     request_count_ = std::get<4>(params);
133     reply_size_ = std::get<5>(params);
134     replies_per_request_ = std::get<6>(params);
135 
136     grpc::ServerBuilder builder;
137     builder.AddListeningPort("dns:///localhost:0",
138                              grpc::InsecureServerCredentials(), &port_);
139     builder.RegisterService(&server_impl_);
140     grpc_server_ = builder.BuildAndStart();
141     client_stream_ =
142         std::make_unique<GrpcBidiStream>(addr_uri(), "none", "",
143                                          /*grpc_channel_deadline_seconds=*/600);
144     EXPECT_THAT(PerformInitialCheckin(client_stream_.get()), IsOk());
145 
146     request_.mutable_report_request()->mutable_report()->set_update_checkpoint(
147         SimpleSelfVerifyingString(request_size_));
148   }
149 
addr_uri()150   std::string addr_uri() { return absl::StrCat(kAddrUri, ":", port_); }
151 
152   int32_t request_size_;
153   size_t request_count_;
154   int32_t reply_size_;
155   size_t replies_per_request_;
156 
157   static constexpr char kAddrUri[] = "dns:///localhost";
158   ByteVerifyingFakeServer server_impl_;
159   int port_ = -1;
160   std::unique_ptr<grpc::Server> grpc_server_;
161   std::unique_ptr<GrpcBidiStream> client_stream_;
162 
163   ClientStreamMessage request_;
164   ServerStreamMessage reply_;
165 };
166 
// Sends `request_` request_count_ times. After each send, receives
// replies_per_request_ replies and checks each retry_token against the
// self-verifying pattern. Finally closes the client and expects a further
// Receive() to fail.
TEST_P(GrpcChunkedMessageStreamTest, RequestReply) {
  for (size_t i = 0; i < request_count_; ++i) {
    EXPECT_THAT(client_stream_->Send(&request_), IsOk());
    // Separate index `j` so the request counter `i` is not shadowed.
    for (size_t j = 0; j < replies_per_request_; ++j) {
      EXPECT_THAT(client_stream_->Receive(&reply_), IsOk());
      EXPECT_TRUE(
          VerifyString(reply_.report_response().retry_window().retry_token()))
          << "request " << i << ", reply " << j;
    }
  }
  client_stream_->Close();
  EXPECT_THAT(client_stream_->Receive(&reply_), Not(IsOk()));
}
179 
// Runs the same request/reply exchange as RequestReply while tracking the
// client's ChunkingLayerBytesSent()/ChunkingLayerBytesReceived() counters.
// Every send and receive must increase the matching counter. For messages
// larger than 64 bytes, the increase must not exceed the message's
// ByteSizeLong().
TEST_P(GrpcChunkedMessageStreamTest, RequestReplyChunkingLayerBandwidth) {
  int64_t bytes_sent_so_far = client_stream_->ChunkingLayerBytesSent();
  int64_t bytes_received_so_far = client_stream_->ChunkingLayerBytesReceived();
  for (size_t i = 0; i < request_count_; ++i) {
    EXPECT_THAT(client_stream_->Send(&request_), IsOk());
    // Measured after Send(): Send() takes a mutable pointer to request_.
    const int64_t request_message_size = request_.ByteSizeLong();
    // Sends may be deferred if flow control has paused the stream; in this
    // case, they will not be recorded in statistics until they are sent as part
    // of the next Receive(). Therefore, we assert sizes after the receives.

    // Separate index `j` so the request counter `i` is not shadowed.
    for (size_t j = 0; j < replies_per_request_; ++j) {
      EXPECT_THAT(client_stream_->Receive(&reply_), IsOk());
      const int64_t bytes_received_delta =
          client_stream_->ChunkingLayerBytesReceived() - bytes_received_so_far;
      EXPECT_THAT(bytes_received_delta, Gt(0));
      const int64_t receive_message_size = reply_.ByteSizeLong();
      // Small messages may actually be expanded due to compression overhead.
      if (receive_message_size > 64) {
        EXPECT_THAT(bytes_received_delta, Le(receive_message_size));
      }
      bytes_received_so_far += bytes_received_delta;
    }

    const int64_t bytes_sent_delta =
        client_stream_->ChunkingLayerBytesSent() - bytes_sent_so_far;
    EXPECT_THAT(client_stream_->ChunkingLayerBytesSent(), Gt(0));
    EXPECT_THAT(bytes_sent_delta, Gt(0));
    // Small messages may actually be expanded due to compression overhead.
    if (request_message_size > 64) {
      EXPECT_THAT(bytes_sent_delta, Le(request_message_size));
    }
    bytes_sent_so_far += bytes_sent_delta;
  }
  client_stream_->Close();
  EXPECT_THAT(client_stream_->Receive(&reply_), Not(IsOk()));
}
216 
// Uncomment to run the full Cartesian product with ZLIB_DEFAULT compression.
// #define GRPC_CHUNKED_EXPENSIVE_COMPRESSED_TESTS
#if defined(GRPC_CHUNKED_EXPENSIVE_COMPRESSED_TESTS)
// Ideally we would generate a covering array rather than a Cartesian product.
INSTANTIATE_TEST_SUITE_P(
    CartesianProductExpensive, GrpcChunkedMessageStreamTest,
    testing::Combine(
        /* chunk_size_for_upload */
        testing::ValuesIn({0, 1, 129}),
        /* max_pending_chunks */
        testing::ValuesIn({0, 1, 129}),
        /* compression_level */
        testing::ValuesIn({CompressionLevel::ZLIB_DEFAULT}),
        /* request size */
        testing::ValuesIn({0, 1, 129}),
        /* request count */
        testing::ValuesIn({1ul, 129ul}),
        /* reply size */
        testing::ValuesIn({0, 1, 129}),
        /* replies per request */
        testing::ValuesIn({1ul, 129ul})),
    [](const testing::TestParamInfo<GrpcChunkedMessageStreamTest::ParamType>&
           info) {
      // clang-format off
      std::string name = absl::StrCat(
          std::get<0>(info.param), "csfu", "_",
          std::get<1>(info.param), "mpc", "_",
          std::get<2>(info.param), "cl", "_",
          std::get<3>(info.param), "rqs", "_",
          std::get<4>(info.param), "rqc", "_",
          std::get<5>(info.param), "rps", "_",
          std::get<6>(info.param), "rppr");
      // clang-format on
      // gtest parameter names may contain only alphanumerics and '_'. The
      // predicate takes unsigned char because std::isalnum has undefined
      // behavior for negative char values.
      absl::c_replace_if(
          name, [](unsigned char c) { return !std::isalnum(c); }, '_');
      return name;
    });
#endif
254 
// Uncomment to run the full Cartesian product without compression.
// #define GRPC_CHUNKED_EXPENSIVE_UNCOMPRESSED_TESTS
#if defined(GRPC_CHUNKED_EXPENSIVE_UNCOMPRESSED_TESTS)
// Ideally we would generate a covering array rather than a Cartesian product.
INSTANTIATE_TEST_SUITE_P(
    CartesianProductUncompressed, GrpcChunkedMessageStreamTest,
    testing::Combine(
        /* chunk_size_for_upload */
        testing::ValuesIn({0, 1, 129}),
        /* max_pending_chunks */
        testing::ValuesIn({0, 1, 129}),
        /* compression_level */
        testing::ValuesIn({CompressionLevel::UNCOMPRESSED}),
        /* request size */
        testing::ValuesIn({0, 1, 129}),
        /* request count */
        testing::ValuesIn({1ul, 129ul}),
        /* reply size */
        testing::ValuesIn({0, 1, 129}),
        /* replies per request */
        testing::ValuesIn({1ul, 129ul})),
    [](const testing::TestParamInfo<GrpcChunkedMessageStreamTest::ParamType>&
           info) {
      // clang-format off
      std::string name = absl::StrCat(
          std::get<0>(info.param), "csfu", "_",
          std::get<1>(info.param), "mpc", "_",
          std::get<2>(info.param), "cl", "_",
          std::get<3>(info.param), "rqs", "_",
          std::get<4>(info.param), "rqc", "_",
          std::get<5>(info.param), "rps", "_",
          std::get<6>(info.param), "rppr");
      // clang-format on
      // gtest parameter names may contain only alphanumerics and '_'. The
      // predicate takes unsigned char because std::isalnum has undefined
      // behavior for negative char values.
      absl::c_replace_if(
          name, [](unsigned char c) { return !std::isalnum(c); }, '_');
      return name;
    });
#endif
292 
293 INSTANTIATE_TEST_SUITE_P(
294     CartesianProductLargeChunks, GrpcChunkedMessageStreamTest,
295     testing::Combine(
296         /* chunk_size_for_upload */
297         testing::ValuesIn({8192}),
298         /* max_pending_chunks */
299         testing::ValuesIn({2}),
300         /* compression_level */
301         testing::ValuesIn({CompressionLevel::ZLIB_BEST_SPEED}),
302         /* request size */
303         testing::ValuesIn({1024 * 1024 * 10}),
304         /* request count */
305         testing::ValuesIn({2ul}),
306         /* reply size */
307         testing::ValuesIn({1024 * 1024 * 10}),
308         /* replies per request */
309         testing::ValuesIn({2ul})),
310     [](const testing::TestParamInfo<GrpcChunkedMessageStreamTest::ParamType>&
__anon1db5036f0602(const testing::TestParamInfo<GrpcChunkedMessageStreamTest::ParamType>& info) 311            info) {
312       // clang-format off
313       std::string name = absl::StrCat(
314           std::get<0>(info.param), "csfu" "_",
315           std::get<1>(info.param), "mpc", "_",
316           std::get<2>(info.param), "cl", "_",
317           std::get<3>(info.param), "rqs", "_",
318           std::get<4>(info.param), "rqc", "_",
319           std::get<5>(info.param), "rps", "_",
320           std::get<6>(info.param), "rppr");
321       absl::c_replace_if(
322           name, [](char c) { return !std::isalnum(c); }, '_');
323       // clang-format on
324       return name;
325     });
326 
327 }  // namespace
328 }  // namespace test
329 }  // namespace client
330 }  // namespace fcp
331