1 /*
2 * Copyright 2019 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/protocol/grpc_chunked_bidi_stream.h"
18
19 #include <cctype>
20 #include <string>
21 #include <tuple>
22
23 #include "gmock/gmock.h"
24 #include "gtest/gtest.h"
25 #include "absl/algorithm/container.h"
26 #include "absl/strings/str_cat.h"
27 #include "fcp/base/monitoring.h"
28 #include "fcp/client/fake_server.h"
29 #include "fcp/client/grpc_bidi_stream.h"
30 #include "fcp/protos/federated_api.pb.h"
31 #include "fcp/testing/testing.h"
32 #include "grpcpp/security/server_credentials.h"
33 #include "grpcpp/server.h"
34 #include "grpcpp/server_builder.h"
35
36 namespace fcp {
37 namespace client {
38 namespace test {
39 namespace {
40
41 using google::internal::federatedml::v2::ClientStreamMessage;
42 using google::internal::federatedml::v2::CompressionLevel;
43 using google::internal::federatedml::v2::ServerStreamMessage;
44 using ::testing::Gt;
45 using ::testing::Le;
46 using ::testing::Not;
47
// Returns a string of length `n` whose character at position i is
// static_cast<char>(i % 128), i.e. the sequence 0, 1, ..., 127, 0, 1, ...
// Because the content is a pure function of position, VerifyString() can check
// a received copy without needing the original.
std::string SimpleSelfVerifyingString(size_t n) {
  std::string str;
  str.reserve(n);
  // size_t index matches the type of `n`; an int index would be a
  // signed/unsigned comparison and would overflow for n > INT_MAX.
  for (size_t i = 0; i < n; ++i) str.push_back(static_cast<char>(i % 128));
  return str;
}
54
// Returns true iff every character of `str` equals static_cast<char>(i % 128)
// at its position i, i.e. `str` matches the pattern produced by
// SimpleSelfVerifyingString() for its length. An empty string verifies.
//
// Takes the string by const reference: callers pass payloads that can be many
// megabytes, and the check never needs its own copy.
bool VerifyString(const std::string& str) {
  // size_t index avoids a signed/unsigned mismatch with str.length().
  for (size_t i = 0; i < str.length(); ++i) {
    if (str[i] != static_cast<char>(i % 128)) return false;
  }
  return true;
}
62
PerformInitialCheckin(GrpcBidiStream * stream)63 Status PerformInitialCheckin(GrpcBidiStream* stream) {
64 ClientStreamMessage request_;
65 ServerStreamMessage reply;
66 Status status;
67
68 auto options =
69 request_.mutable_checkin_request()->mutable_protocol_options_request();
70 options->set_supports_chunked_blob_transfer(true);
71 options->add_supported_compression_levels(CompressionLevel::UNCOMPRESSED);
72 options->add_supported_compression_levels(CompressionLevel::ZLIB_DEFAULT);
73 options->add_supported_compression_levels(
74 CompressionLevel::ZLIB_BEST_COMPRESSION);
75 options->add_supported_compression_levels(CompressionLevel::ZLIB_BEST_SPEED);
76
77 EXPECT_THAT((status = stream->Send(&request_)), IsOk());
78 EXPECT_THAT((status = stream->Receive(&reply)), IsOk());
79 EXPECT_TRUE(reply.has_checkin_response()) << reply.DebugString();
80
81 return status;
82 }
83
84 using ChunkingParameters =
85 std::tuple<int32_t, /* chunk_size_for_upload */
86 int32_t, /* max_pending_chunks */
87 google::internal::federatedml::v2::CompressionLevel,
88 int32_t, /* request size */
89 size_t, /* request count */
90 int32_t, /* reply size */
91 size_t /* replies per request */
92 >;
93
94 class ByteVerifyingFakeServer : public FakeServer {
95 public:
ByteVerifyingFakeServer(const ChunkingParameters & params)96 explicit ByteVerifyingFakeServer(const ChunkingParameters& params)
97 : FakeServer(std::get<0>(params), std::get<1>(params),
98 std::get<2>(params)),
99 reply_size_(std::get<5>(params)),
100 replies_per_request_(std::get<6>(params)) {}
101
Handle(const ClientStreamMessage & request,ServerStreamMessage * first_reply,GrpcChunkedBidiStream<ServerStreamMessage,ClientStreamMessage> * stream)102 Status Handle(const ClientStreamMessage& request,
103 ServerStreamMessage* first_reply,
104 GrpcChunkedBidiStream<ServerStreamMessage, ClientStreamMessage>*
105 stream) override {
106 Status status;
107 if (request.has_checkin_request()) {
108 EXPECT_THAT((status = stream->Send(first_reply)), IsOk());
109 return status;
110 }
111 EXPECT_TRUE(
112 VerifyString(request.report_request().report().update_checkpoint()));
113 ServerStreamMessage reply;
114 reply.mutable_report_response()->mutable_retry_window()->set_retry_token(
115 SimpleSelfVerifyingString(reply_size_));
116 for (auto i = 0; i < replies_per_request_; ++i)
117 EXPECT_THAT((status = stream->Send(&reply)), IsOk());
118 return status;
119 }
120
121 private:
122 int32_t reply_size_;
123 size_t replies_per_request_;
124 };
125
126 class GrpcChunkedMessageStreamTest
127 : public ::testing::TestWithParam<ChunkingParameters> {
128 public:
GrpcChunkedMessageStreamTest()129 GrpcChunkedMessageStreamTest() : server_impl_(GetParam()) {
130 auto params = GetParam();
131 request_size_ = std::get<3>(params);
132 request_count_ = std::get<4>(params);
133 reply_size_ = std::get<5>(params);
134 replies_per_request_ = std::get<6>(params);
135
136 grpc::ServerBuilder builder;
137 builder.AddListeningPort("dns:///localhost:0",
138 grpc::InsecureServerCredentials(), &port_);
139 builder.RegisterService(&server_impl_);
140 grpc_server_ = builder.BuildAndStart();
141 client_stream_ =
142 std::make_unique<GrpcBidiStream>(addr_uri(), "none", "",
143 /*grpc_channel_deadline_seconds=*/600);
144 EXPECT_THAT(PerformInitialCheckin(client_stream_.get()), IsOk());
145
146 request_.mutable_report_request()->mutable_report()->set_update_checkpoint(
147 SimpleSelfVerifyingString(request_size_));
148 }
149
addr_uri()150 std::string addr_uri() { return absl::StrCat(kAddrUri, ":", port_); }
151
152 int32_t request_size_;
153 size_t request_count_;
154 int32_t reply_size_;
155 size_t replies_per_request_;
156
157 static constexpr char kAddrUri[] = "dns:///localhost";
158 ByteVerifyingFakeServer server_impl_;
159 int port_ = -1;
160 std::unique_ptr<grpc::Server> grpc_server_;
161 std::unique_ptr<GrpcBidiStream> client_stream_;
162
163 ClientStreamMessage request_;
164 ServerStreamMessage reply_;
165 };
166
// Sends `request_count_` requests. For each one, expects
// `replies_per_request_` replies carrying a valid self-verifying retry_token.
// After Close(), a further Receive must fail.
TEST_P(GrpcChunkedMessageStreamTest, RequestReply) {
  for (size_t i = 0; i < request_count_; ++i) {
    // ASSERT rather than EXPECT: after a failed Send/Receive, further Receive
    // calls can only add noise, or wait on a reply that never arrives.
    ASSERT_THAT(client_stream_->Send(&request_), IsOk());
    for (size_t j = 0; j < replies_per_request_; ++j) {
      ASSERT_THAT(client_stream_->Receive(&reply_), IsOk());
      EXPECT_TRUE(
          VerifyString(reply_.report_response().retry_window().retry_token()));
    }
  }
  client_stream_->Close();
  EXPECT_THAT(client_stream_->Receive(&reply_), Not(IsOk()));
}
179
// Checks that the chunking layer's sent/received byte counters grow with every
// exchange. For messages larger than 64 bytes, each exchange's growth must be
// at most the serialized proto size.
TEST_P(GrpcChunkedMessageStreamTest, RequestReplyChunkingLayerBandwidth) {
  int64_t bytes_sent_so_far = client_stream_->ChunkingLayerBytesSent();
  int64_t bytes_received_so_far = client_stream_->ChunkingLayerBytesReceived();
  for (size_t i = 0; i < request_count_; ++i) {
    // ASSERT rather than EXPECT: once a Send/Receive fails, the byte-count
    // deltas below are meaningless, and further Receive calls may block.
    ASSERT_THAT(client_stream_->Send(&request_), IsOk());
    int64_t request_message_size = request_.ByteSizeLong();
    // Sends may be deferred if flow control has paused the stream; in this
    // case, they will not be recorded in statistics until they are sent as part
    // of the next Receive(). Therefore, we assert sizes after the receives.

    for (size_t j = 0; j < replies_per_request_; ++j) {
      ASSERT_THAT(client_stream_->Receive(&reply_), IsOk());
      int64_t bytes_received_delta =
          client_stream_->ChunkingLayerBytesReceived() - bytes_received_so_far;
      EXPECT_THAT(bytes_received_delta, Gt(0));
      int64_t receive_message_size = reply_.ByteSizeLong();
      // Small messages may actually be expanded due to compression overhead.
      if (receive_message_size > 64) {
        EXPECT_THAT(bytes_received_delta, Le(receive_message_size));
      }
      bytes_received_so_far += bytes_received_delta;
    }

    int64_t bytes_sent_delta =
        client_stream_->ChunkingLayerBytesSent() - bytes_sent_so_far;
    EXPECT_THAT(client_stream_->ChunkingLayerBytesSent(), Gt(0));
    EXPECT_THAT(bytes_sent_delta, Gt(0));
    // Small messages may actually be expanded due to compression overhead.
    if (request_message_size > 64) {
      EXPECT_THAT(bytes_sent_delta, Le(request_message_size));
    }
    bytes_sent_so_far += bytes_sent_delta;
  }
  client_stream_->Close();
  EXPECT_THAT(client_stream_->Receive(&reply_), Not(IsOk()));
}
216
// #define GRPC_CHUNKED_EXPENSIVE_COMPRESSED_TESTS
#if defined(GRPC_CHUNKED_EXPENSIVE_COMPRESSED_TESTS)
// Sweep over small/edge parameter values with ZLIB_DEFAULT compression.
// Compiled only when GRPC_CHUNKED_EXPENSIVE_COMPRESSED_TESTS is defined, e.g.
// by uncommenting the #define above.
// Ideally we would generate a covering array rather than a Cartesian product.
INSTANTIATE_TEST_SUITE_P(
    CartesianProductExpensive, GrpcChunkedMessageStreamTest,
    testing::Combine(
        /* chunk_size_for_upload */
        testing::ValuesIn({0, 1, 129}),
        /* max_pending_chunks */
        testing::ValuesIn({0, 1, 129}),
        /* compression_level */
        testing::ValuesIn({CompressionLevel::ZLIB_DEFAULT}),
        /* request size */
        testing::ValuesIn({0, 1, 129}),
        /* request count */
        testing::ValuesIn({1ul, 129ul}),
        /* reply size */
        testing::ValuesIn({0, 1, 129}),
        /* replies per request */
        testing::ValuesIn({1ul, 129ul})),
    [](const testing::TestParamInfo<GrpcChunkedMessageStreamTest::ParamType>&
           info) {
      // Name is "<v0>csfu_<v1>mpc_<v2>cl_<v3>rqs_<v4>rqc_<v5>rps_<v6>rppr".
      // Any non-alphanumeric character is replaced by '_', because gtest only
      // accepts [A-Za-z0-9_] in parameter names.
      // clang-format off
      std::string name = absl::StrCat(
          std::get<0>(info.param), "csfu", "_",
          std::get<1>(info.param), "mpc", "_",
          std::get<2>(info.param), "cl", "_",
          std::get<3>(info.param), "rqs", "_",
          std::get<4>(info.param), "rqc", "_",
          std::get<5>(info.param), "rps", "_",
          std::get<6>(info.param), "rppr");
      // std::isalnum is undefined for negative char values; widen through
      // unsigned char first.
      absl::c_replace_if(
          name,
          [](char c) { return !std::isalnum(static_cast<unsigned char>(c)); },
          '_');
      // clang-format on
      return name;
    });
#endif
254
// #define GRPC_CHUNKED_EXPENSIVE_UNCOMPRESSED_TESTS
#if defined(GRPC_CHUNKED_EXPENSIVE_UNCOMPRESSED_TESTS)
// Sweep over small/edge parameter values with UNCOMPRESSED transfers.
// Compiled only when GRPC_CHUNKED_EXPENSIVE_UNCOMPRESSED_TESTS is defined,
// e.g. by uncommenting the #define above.
// Ideally we would generate a covering array rather than a Cartesian product.
INSTANTIATE_TEST_SUITE_P(
    CartesianProductUncompressed, GrpcChunkedMessageStreamTest,
    testing::Combine(
        /* chunk_size_for_upload */
        testing::ValuesIn({0, 1, 129}),
        /* max_pending_chunks */
        testing::ValuesIn({0, 1, 129}),
        /* compression_level */
        testing::ValuesIn({CompressionLevel::UNCOMPRESSED}),
        /* request size */
        testing::ValuesIn({0, 1, 129}),
        /* request count */
        testing::ValuesIn({1ul, 129ul}),
        /* reply size */
        testing::ValuesIn({0, 1, 129}),
        /* replies per request */
        testing::ValuesIn({1ul, 129ul})),
    [](const testing::TestParamInfo<GrpcChunkedMessageStreamTest::ParamType>&
           info) {
      // Name is "<v0>csfu_<v1>mpc_<v2>cl_<v3>rqs_<v4>rqc_<v5>rps_<v6>rppr".
      // Any non-alphanumeric character is replaced by '_', because gtest only
      // accepts [A-Za-z0-9_] in parameter names.
      // clang-format off
      std::string name = absl::StrCat(
          std::get<0>(info.param), "csfu", "_",
          std::get<1>(info.param), "mpc", "_",
          std::get<2>(info.param), "cl", "_",
          std::get<3>(info.param), "rqs", "_",
          std::get<4>(info.param), "rqc", "_",
          std::get<5>(info.param), "rps", "_",
          std::get<6>(info.param), "rppr");
      // std::isalnum is undefined for negative char values; widen through
      // unsigned char first.
      absl::c_replace_if(
          name,
          [](char c) { return !std::isalnum(static_cast<unsigned char>(c)); },
          '_');
      // clang-format on
      return name;
    });
#endif
292
293 INSTANTIATE_TEST_SUITE_P(
294 CartesianProductLargeChunks, GrpcChunkedMessageStreamTest,
295 testing::Combine(
296 /* chunk_size_for_upload */
297 testing::ValuesIn({8192}),
298 /* max_pending_chunks */
299 testing::ValuesIn({2}),
300 /* compression_level */
301 testing::ValuesIn({CompressionLevel::ZLIB_BEST_SPEED}),
302 /* request size */
303 testing::ValuesIn({1024 * 1024 * 10}),
304 /* request count */
305 testing::ValuesIn({2ul}),
306 /* reply size */
307 testing::ValuesIn({1024 * 1024 * 10}),
308 /* replies per request */
309 testing::ValuesIn({2ul})),
310 [](const testing::TestParamInfo<GrpcChunkedMessageStreamTest::ParamType>&
__anon1db5036f0602(const testing::TestParamInfo<GrpcChunkedMessageStreamTest::ParamType>& info) 311 info) {
312 // clang-format off
313 std::string name = absl::StrCat(
314 std::get<0>(info.param), "csfu" "_",
315 std::get<1>(info.param), "mpc", "_",
316 std::get<2>(info.param), "cl", "_",
317 std::get<3>(info.param), "rqs", "_",
318 std::get<4>(info.param), "rqc", "_",
319 std::get<5>(info.param), "rps", "_",
320 std::get<6>(info.param), "rppr");
321 absl::c_replace_if(
322 name, [](char c) { return !std::isalnum(c); }, '_');
323 // clang-format on
324 return name;
325 });
326
327 } // namespace
328 } // namespace test
329 } // namespace client
330 } // namespace fcp
331