1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
#include "pw_rpc_transport/socket_rpc_transport.h"

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <random>
#include <thread>
#include <vector>

#include "pw_bytes/span.h"
#include "pw_log/log.h"
#include "pw_rpc_transport/socket_rpc_transport.h"
#include "pw_span/span.h"
#include "pw_status/status.h"
#include "pw_stream/socket_stream.h"
#include "pw_sync/thread_notification.h"
#include "pw_thread/thread.h"
#include "pw_thread_stl/options.h"
#include "pw_unit_test/framework.h"
30
31 namespace pw::rpc {
32 namespace {
33
// Upper bound on the number of bytes in a single frame built by the test
// sender (it is the upper end of the frame-size distribution).
constexpr size_t kMaxWriteSize = 64;
// Template argument used for every SocketRpcTransport in these tests;
// presumably the transport's read buffer size, per the name — TODO confirm.
constexpr size_t kReadBufferSize = 64;
// Let the kernel pick the port number. Tests read the actual port back from
// the server with port() once it is ready.
constexpr uint16_t kServerPort = 0;
40
41 class TestIngress : public RpcIngressHandler {
42 public:
TestIngress(size_t num_bytes_expected)43 explicit TestIngress(size_t num_bytes_expected)
44 : num_bytes_expected_(num_bytes_expected) {}
45
ProcessIncomingData(ConstByteSpan buffer)46 Status ProcessIncomingData(ConstByteSpan buffer) override {
47 if (num_bytes_expected_ > 0) {
48 std::copy(buffer.begin(), buffer.end(), std::back_inserter(received_));
49 num_bytes_expected_ -= std::min(num_bytes_expected_, buffer.size());
50 }
51 if (num_bytes_expected_ == 0) {
52 done_.release();
53 }
54 return OkStatus();
55 }
56
received() const57 std::vector<std::byte> received() const { return received_; }
Wait()58 void Wait() { done_.acquire(); }
59
60 private:
61 size_t num_bytes_expected_ = 0;
62 sync::ThreadNotification done_;
63 std::vector<std::byte> received_;
64 };
65
66 class SocketSender {
67 public:
SocketSender(SocketRpcTransport<kReadBufferSize> & transport)68 SocketSender(SocketRpcTransport<kReadBufferSize>& transport)
69 : transport_(transport) {
70 unsigned char c = 0;
71 for (auto& i : data_) {
72 i = std::byte{c++};
73 }
74 std::mt19937 rg{0x12345678};
75 std::shuffle(data_.begin(), data_.end(), rg);
76 }
77
sent()78 std::vector<std::byte> sent() { return sent_; }
79
MakeFrame(size_t max_size)80 RpcFrame MakeFrame(size_t max_size) {
81 std::mt19937 rg{0x12345678};
82 size_t offset = offset_dist_(rg);
83 size_t message_size = std::min(size_dist_(rg), max_size);
84 size_t header_size = message_size > 4 ? 4 : message_size;
85 size_t payload_size = message_size > 4 ? message_size - 4 : 0;
86
87 return RpcFrame{.header = span(data_).subspan(offset, header_size),
88 .payload = span(data_).subspan(offset, payload_size)};
89 }
90
Send(size_t num_bytes)91 void Send(size_t num_bytes) {
92 size_t bytes_written = 0;
93 while (bytes_written < num_bytes) {
94 auto frame = MakeFrame(num_bytes - bytes_written);
95 std::copy(
96 frame.header.begin(), frame.header.end(), std::back_inserter(sent_));
97 std::copy(frame.payload.begin(),
98 frame.payload.end(),
99 std::back_inserter(sent_));
100
101 // Tests below expect to see all data written to the socket to be received
102 // by the other end, so we keep retrying on any errors that could happen
103 // during reconnection: in reality it would be up to the higher level
104 // abstractions to do this depending on how they manage buffers etc. For
105 // the tests we just keep retrying indefinitely: if there is a
106 // non-transient problem then the test will eventually time out.
107 while (true) {
108 const auto send_status = transport_.Send(frame);
109 if (send_status.ok()) {
110 break;
111 }
112 }
113
114 bytes_written += frame.header.size() + frame.payload.size();
115 }
116 }
117
118 private:
119 SocketRpcTransport<kReadBufferSize>& transport_;
120 std::vector<std::byte> sent_;
121 std::array<std::byte, 256> data_{};
122 std::uniform_int_distribution<size_t> offset_dist_{0, 255};
123 std::uniform_int_distribution<size_t> size_dist_{1, kMaxWriteSize};
124 };
125
126 class SocketSenderThreadCore : public SocketSender, public thread::ThreadCore {
127 public:
SocketSenderThreadCore(SocketRpcTransport<kReadBufferSize> & transport,size_t write_size)128 SocketSenderThreadCore(SocketRpcTransport<kReadBufferSize>& transport,
129 size_t write_size)
130 : SocketSender(transport), write_size_(write_size) {}
131
132 private:
Run()133 void Run() override { Send(write_size_); }
134 size_t write_size_;
135 };
136
TEST(SocketRpcTransportTest, SendAndReceiveFramesOverSocketConnection) {
  // A server and a client exchange kWriteSize bytes in both directions at
  // the same time; each side must receive exactly what the other sent.
  constexpr size_t kWriteSize = 8192;

  TestIngress server_ingress(kWriteSize);
  TestIngress client_ingress(kWriteSize);

  auto server = SocketRpcTransport<kReadBufferSize>(
      SocketRpcTransport<kReadBufferSize>::kAsServer,
      kServerPort,
      server_ingress);
  auto server_thread = Thread(thread::stl::Options(), server);

  // The kernel assigns the port, so read it back once the server is ready.
  server.WaitUntilReady();
  const auto server_port = server.port();

  auto client = SocketRpcTransport<kReadBufferSize>(
      SocketRpcTransport<kReadBufferSize>::kAsClient,
      "localhost",
      server_port,
      client_ingress);
  auto client_thread = Thread(thread::stl::Options(), client);

  client.WaitUntilConnected();
  server.WaitUntilConnected();

  SocketSenderThreadCore client_sender(client, kWriteSize);
  SocketSenderThreadCore server_sender(server, kWriteSize);

  auto client_sender_thread = Thread(thread::stl::Options(), client_sender);
  auto server_sender_thread = Thread(thread::stl::Options(), server_sender);

  client_sender_thread.join();
  server_sender_thread.join();

  server_ingress.Wait();
  client_ingress.Wait();

  server.Stop();
  client.Stop();

  server_thread.join();
  client_thread.join();

  // The four-iterator std::equal reports a length mismatch as a failure
  // instead of reading past the end of the shorter vector.
  const auto received_by_server = server_ingress.received();
  const auto sent_by_client = client_sender.sent();
  EXPECT_EQ(received_by_server.size(), kWriteSize);
  EXPECT_TRUE(std::equal(received_by_server.begin(),
                         received_by_server.end(),
                         sent_by_client.begin(),
                         sent_by_client.end()));

  const auto received_by_client = client_ingress.received();
  const auto sent_by_server = server_sender.sent();
  EXPECT_EQ(received_by_client.size(), kWriteSize);
  EXPECT_TRUE(std::equal(received_by_client.begin(),
                         received_by_client.end(),
                         sent_by_server.begin(),
                         sent_by_server.end()));
}
192
TEST(SocketRpcTransportTest, ServerReconnects) {
  // Set up a server and a client that reconnects multiple times. The server
  // must accept the new connection gracefully.
  constexpr size_t kWriteSize = 8192;
  std::vector<std::byte> received;

  TestIngress server_ingress(0);
  auto server = SocketRpcTransport<kReadBufferSize>(
      SocketRpcTransport<kReadBufferSize>::kAsServer,
      kServerPort,
      server_ingress);
  auto server_thread = Thread(thread::stl::Options(), server);

  server.WaitUntilReady();
  const auto server_port = server.port();
  // One sender is reused for both client connections, so its sent() covers
  // everything the server pushed across both sessions.
  SocketSender server_sender(server);

  {
    TestIngress client_ingress(kWriteSize);
    auto client = SocketRpcTransport<kReadBufferSize>(
        SocketRpcTransport<kReadBufferSize>::kAsClient,
        "localhost",
        server_port,
        client_ingress);
    auto client_thread = Thread(thread::stl::Options(), client);

    client.WaitUntilConnected();
    server.WaitUntilConnected();

    server_sender.Send(kWriteSize);
    client_ingress.Wait();
    const auto client_received = client_ingress.received();
    received.insert(
        received.end(), client_received.begin(), client_received.end());
    EXPECT_EQ(received.size(), kWriteSize);

    // Stop the client but not the server: we're re-using the same server
    // with a new client below.
    client.Stop();
    client_thread.join();
  }

  // Reconnect to the server and keep sending frames.
  {
    TestIngress client_ingress(kWriteSize);
    auto client = SocketRpcTransport<kReadBufferSize>(
        SocketRpcTransport<kReadBufferSize>::kAsClient,
        "localhost",
        server_port,
        client_ingress);
    auto client_thread = Thread(thread::stl::Options(), client);

    client.WaitUntilConnected();
    server.WaitUntilConnected();

    server_sender.Send(kWriteSize);
    client_ingress.Wait();
    const auto client_received = client_ingress.received();
    received.insert(
        received.end(), client_received.begin(), client_received.end());

    client.Stop();
    client_thread.join();

    // This time stop the server as well.
    server.Stop();
    server_thread.join();
  }

  // The four-iterator std::equal reports a length mismatch as a failure
  // instead of reading past the end of the shorter vector.
  const auto sent_by_server = server_sender.sent();
  EXPECT_EQ(received.size(), 2 * kWriteSize);
  EXPECT_EQ(sent_by_server.size(), 2 * kWriteSize);
  EXPECT_TRUE(std::equal(received.begin(),
                         received.end(),
                         sent_by_server.begin(),
                         sent_by_server.end()));
}
270
TEST(SocketRpcTransportTest, ClientReconnects) {
  // Set up a server and a client, then recycle the server. The client must
  // reconnect gracefully.
  constexpr size_t kWriteSize = 8192;

  TestIngress server_ingress(0);
  TestIngress client_ingress(2 * kWriteSize);

  auto server = std::make_unique<SocketRpcTransport<kReadBufferSize>>(
      SocketRpcTransport<kReadBufferSize>::kAsServer,
      kServerPort,
      server_ingress);
  auto server_thread = Thread(thread::stl::Options(), *server);

  server->WaitUntilReady();
  // The second server is bound to this kernel-assigned port so the client can
  // reconnect to the same address.
  const uint16_t server_port = server->port();

  auto client = SocketRpcTransport<kReadBufferSize>(
      SocketRpcTransport<kReadBufferSize>::kAsClient,
      "localhost",
      server_port,
      client_ingress);
  auto client_thread = Thread(thread::stl::Options(), client);

  client.WaitUntilConnected();
  server->WaitUntilConnected();

  std::vector<std::byte> sent_by_server;

  // First server: send half of the data, then shut it down. The sender is
  // scoped so it never outlives the transport it references.
  {
    SocketSender server1_sender(*server);
    server1_sender.Send(kWriteSize);
    server->Stop();
    const auto server1_sent = server1_sender.sent();
    sent_by_server.insert(
        sent_by_server.end(), server1_sent.begin(), server1_sent.end());
  }
  server_thread.join();
  server = nullptr;

  // Second server on the same port; the client must reconnect to it.
  server = std::make_unique<SocketRpcTransport<kReadBufferSize>>(
      SocketRpcTransport<kReadBufferSize>::kAsServer,
      server_port,
      server_ingress);
  server_thread = Thread(thread::stl::Options(), *server);

  client.WaitUntilConnected();
  server->WaitUntilConnected();

  {
    SocketSender server2_sender(*server);
    server2_sender.Send(kWriteSize);
    client_ingress.Wait();

    server->Stop();
    const auto server2_sent = server2_sender.sent();
    sent_by_server.insert(
        sent_by_server.end(), server2_sent.begin(), server2_sent.end());
  }
  server_thread.join();

  client.Stop();
  client_thread.join();
  server = nullptr;

  // The four-iterator std::equal reports a length mismatch as a failure
  // instead of reading past the end of the shorter vector.
  const auto received_by_client = client_ingress.received();
  EXPECT_EQ(received_by_client.size(), 2 * kWriteSize);
  EXPECT_TRUE(std::equal(received_by_client.begin(),
                         received_by_client.end(),
                         sent_by_server.begin(),
                         sent_by_server.end()));
}
344
345 } // namespace
346 } // namespace pw::rpc
347