xref: /aosp_15_r20/external/pigweed/pw_rpc_transport/socket_rpc_transport_test.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_rpc_transport/socket_rpc_transport.h"
16 
17 #include <algorithm>
18 #include <random>
19 
20 #include "pw_bytes/span.h"
21 #include "pw_log/log.h"
22 #include "pw_rpc_transport/socket_rpc_transport.h"
23 #include "pw_span/span.h"
24 #include "pw_status/status.h"
25 #include "pw_stream/socket_stream.h"
26 #include "pw_sync/thread_notification.h"
27 #include "pw_thread/thread.h"
28 #include "pw_thread_stl/options.h"
29 #include "pw_unit_test/framework.h"
30 
31 namespace pw::rpc {
32 namespace {
33 
using namespace std::chrono_literals;

// Upper bound on the number of bytes in one frame produced by SocketSender.
constexpr size_t kMaxWriteSize{64};
// Read buffer size used as the template argument of every transport below.
constexpr size_t kReadBufferSize{64};
// Port 0 lets the kernel choose a free port; tests read the real port back
// from the server with port() once it is ready.
constexpr uint16_t kServerPort{0};
40 
41 class TestIngress : public RpcIngressHandler {
42  public:
TestIngress(size_t num_bytes_expected)43   explicit TestIngress(size_t num_bytes_expected)
44       : num_bytes_expected_(num_bytes_expected) {}
45 
ProcessIncomingData(ConstByteSpan buffer)46   Status ProcessIncomingData(ConstByteSpan buffer) override {
47     if (num_bytes_expected_ > 0) {
48       std::copy(buffer.begin(), buffer.end(), std::back_inserter(received_));
49       num_bytes_expected_ -= std::min(num_bytes_expected_, buffer.size());
50     }
51     if (num_bytes_expected_ == 0) {
52       done_.release();
53     }
54     return OkStatus();
55   }
56 
received() const57   std::vector<std::byte> received() const { return received_; }
Wait()58   void Wait() { done_.acquire(); }
59 
60  private:
61   size_t num_bytes_expected_ = 0;
62   sync::ThreadNotification done_;
63   std::vector<std::byte> received_;
64 };
65 
66 class SocketSender {
67  public:
SocketSender(SocketRpcTransport<kReadBufferSize> & transport)68   SocketSender(SocketRpcTransport<kReadBufferSize>& transport)
69       : transport_(transport) {
70     unsigned char c = 0;
71     for (auto& i : data_) {
72       i = std::byte{c++};
73     }
74     std::mt19937 rg{0x12345678};
75     std::shuffle(data_.begin(), data_.end(), rg);
76   }
77 
sent()78   std::vector<std::byte> sent() { return sent_; }
79 
MakeFrame(size_t max_size)80   RpcFrame MakeFrame(size_t max_size) {
81     std::mt19937 rg{0x12345678};
82     size_t offset = offset_dist_(rg);
83     size_t message_size = std::min(size_dist_(rg), max_size);
84     size_t header_size = message_size > 4 ? 4 : message_size;
85     size_t payload_size = message_size > 4 ? message_size - 4 : 0;
86 
87     return RpcFrame{.header = span(data_).subspan(offset, header_size),
88                     .payload = span(data_).subspan(offset, payload_size)};
89   }
90 
Send(size_t num_bytes)91   void Send(size_t num_bytes) {
92     size_t bytes_written = 0;
93     while (bytes_written < num_bytes) {
94       auto frame = MakeFrame(num_bytes - bytes_written);
95       std::copy(
96           frame.header.begin(), frame.header.end(), std::back_inserter(sent_));
97       std::copy(frame.payload.begin(),
98                 frame.payload.end(),
99                 std::back_inserter(sent_));
100 
101       // Tests below expect to see all data written to the socket to be received
102       // by the other end, so we keep retrying on any errors that could happen
103       // during reconnection: in reality it would be up to the higher level
104       // abstractions to do this depending on how they manage buffers etc. For
105       // the tests we just keep retrying indefinitely: if there is a
106       // non-transient problem then the test will eventually time out.
107       while (true) {
108         const auto send_status = transport_.Send(frame);
109         if (send_status.ok()) {
110           break;
111         }
112       }
113 
114       bytes_written += frame.header.size() + frame.payload.size();
115     }
116   }
117 
118  private:
119   SocketRpcTransport<kReadBufferSize>& transport_;
120   std::vector<std::byte> sent_;
121   std::array<std::byte, 256> data_{};
122   std::uniform_int_distribution<size_t> offset_dist_{0, 255};
123   std::uniform_int_distribution<size_t> size_dist_{1, kMaxWriteSize};
124 };
125 
126 class SocketSenderThreadCore : public SocketSender, public thread::ThreadCore {
127  public:
SocketSenderThreadCore(SocketRpcTransport<kReadBufferSize> & transport,size_t write_size)128   SocketSenderThreadCore(SocketRpcTransport<kReadBufferSize>& transport,
129                          size_t write_size)
130       : SocketSender(transport), write_size_(write_size) {}
131 
132  private:
Run()133   void Run() override { Send(write_size_); }
134   size_t write_size_;
135 };
136 
TEST(SocketRpcTransportTest, SendAndReceiveFramesOverSocketConnection) {
  // Bring up a server and a client, then have both ends stream kWriteSize
  // bytes of frames to each other concurrently. Each side must receive exactly
  // the bytes the other side sent, in order.
  constexpr size_t kWriteSize = 8192;

  TestIngress server_ingress(kWriteSize);
  TestIngress client_ingress(kWriteSize);

  auto server = SocketRpcTransport<kReadBufferSize>(
      SocketRpcTransport<kReadBufferSize>::kAsServer,
      kServerPort,
      server_ingress);
  auto server_thread = Thread(thread::stl::Options(), server);

  // kServerPort is 0 (kernel-chosen), so read the real port only after the
  // server reports ready.
  server.WaitUntilReady();
  auto server_port = server.port();

  auto client = SocketRpcTransport<kReadBufferSize>(
      SocketRpcTransport<kReadBufferSize>::kAsClient,
      "localhost",
      server_port,
      client_ingress);
  auto client_thread = Thread(thread::stl::Options(), client);

  client.WaitUntilConnected();
  server.WaitUntilConnected();

  SocketSenderThreadCore client_sender(client, kWriteSize);
  SocketSenderThreadCore server_sender(server, kWriteSize);

  auto client_sender_thread = Thread(thread::stl::Options(), client_sender);
  auto server_sender_thread = Thread(thread::stl::Options(), server_sender);

  client_sender_thread.join();
  server_sender_thread.join();

  // Block until each side has received kWriteSize bytes.
  server_ingress.Wait();
  client_ingress.Wait();

  server.Stop();
  client.Stop();

  server_thread.join();
  client_thread.join();

  // sent() returns a copy, so bind it once. Comparing both full ranges makes a
  // length mismatch fail the check instead of reading past the shorter vector.
  const auto received_by_server = server_ingress.received();
  const auto sent_by_client = client_sender.sent();
  EXPECT_EQ(received_by_server.size(), kWriteSize);
  EXPECT_TRUE(std::equal(received_by_server.begin(),
                         received_by_server.end(),
                         sent_by_client.begin(),
                         sent_by_client.end()));

  const auto received_by_client = client_ingress.received();
  const auto sent_by_server = server_sender.sent();
  EXPECT_EQ(received_by_client.size(), kWriteSize);
  EXPECT_TRUE(std::equal(received_by_client.begin(),
                         received_by_client.end(),
                         sent_by_server.begin(),
                         sent_by_server.end()));
}
192 
TEST(SocketRpcTransportTest,ServerReconnects)193 TEST(SocketRpcTransportTest, ServerReconnects) {
194   // Set up a server and a client that reconnects multiple times. The server
195   // must accept the new connection gracefully.
196   constexpr size_t kWriteSize = 8192;
197   std::vector<std::byte> received;
198 
199   TestIngress server_ingress(0);
200   auto server = SocketRpcTransport<kReadBufferSize>(
201       SocketRpcTransport<kReadBufferSize>::kAsServer,
202       kServerPort,
203       server_ingress);
204   auto server_thread = Thread(thread::stl::Options(), server);
205 
206   server.WaitUntilReady();
207   auto server_port = server.port();
208   SocketSender server_sender(server);
209 
210   {
211     TestIngress client_ingress(kWriteSize);
212     auto client = SocketRpcTransport<kReadBufferSize>(
213         SocketRpcTransport<kReadBufferSize>::kAsClient,
214         "localhost",
215         server_port,
216         client_ingress);
217     auto client_thread = Thread(thread::stl::Options(), client);
218 
219     client.WaitUntilConnected();
220     server.WaitUntilConnected();
221 
222     server_sender.Send(kWriteSize);
223     client_ingress.Wait();
224     auto client_received = client_ingress.received();
225     std::copy(client_received.begin(),
226               client_received.end(),
227               std::back_inserter(received));
228     EXPECT_EQ(received.size(), kWriteSize);
229 
230     // Stop the client but not the server: we're re-using the same server
231     // with a new client below.
232     client.Stop();
233     client_thread.join();
234   }
235 
236   // Reconnect to the server and keep sending frames.
237   {
238     TestIngress client_ingress(kWriteSize);
239     auto client = SocketRpcTransport<kReadBufferSize>(
240         SocketRpcTransport<kReadBufferSize>::kAsClient,
241         "localhost",
242         server_port,
243         client_ingress);
244     auto client_thread = Thread(thread::stl::Options(), client);
245 
246     client.WaitUntilConnected();
247     server.WaitUntilConnected();
248 
249     server_sender.Send(kWriteSize);
250     client_ingress.Wait();
251     auto client_received = client_ingress.received();
252     std::copy(client_received.begin(),
253               client_received.end(),
254               std::back_inserter(received));
255 
256     client.Stop();
257     client_thread.join();
258 
259     // This time stop the server as well.
260     SocketSender client_sender(client);
261     server.Stop();
262     server_thread.join();
263   }
264 
265   EXPECT_EQ(received.size(), 2 * kWriteSize);
266   EXPECT_EQ(server_sender.sent().size(), 2 * kWriteSize);
267   EXPECT_TRUE(std::equal(
268       received.begin(), received.end(), server_sender.sent().begin()));
269 }
270 
TEST(SocketRpcTransportTest,ClientReconnects)271 TEST(SocketRpcTransportTest, ClientReconnects) {
272   // Set up a server and a client, then recycle the server. The client must
273   // must reconnect gracefully.
274   constexpr size_t kWriteSize = 8192;
275   uint16_t server_port = 0;
276 
277   TestIngress server_ingress(0);
278   TestIngress client_ingress(2 * kWriteSize);
279 
280   auto server = std::make_unique<SocketRpcTransport<kReadBufferSize>>(
281       SocketRpcTransport<kReadBufferSize>::kAsServer,
282       kServerPort,
283       server_ingress);
284   auto server_thread = Thread(thread::stl::Options(), *server);
285 
286   server->WaitUntilReady();
287   server_port = server->port();
288 
289   auto client = SocketRpcTransport<kReadBufferSize>(
290       SocketRpcTransport<kReadBufferSize>::kAsClient,
291       "localhost",
292       server_port,
293       client_ingress);
294   auto client_thread = Thread(thread::stl::Options(), client);
295 
296   client.WaitUntilConnected();
297   server->WaitUntilConnected();
298 
299   SocketSender client_sender(client);
300   SocketSender server1_sender(*server);
301   std::vector<std::byte> sent_by_server;
302 
303   server1_sender.Send(kWriteSize);
304   server->Stop();
305   auto server1_sent = server1_sender.sent();
306   std::copy(server1_sent.begin(),
307             server1_sent.end(),
308             std::back_inserter(sent_by_server));
309 
310   server_thread.join();
311   server = nullptr;
312 
313   server = std::make_unique<SocketRpcTransport<kReadBufferSize>>(
314       SocketRpcTransport<kReadBufferSize>::kAsServer,
315       server_port,
316       server_ingress);
317   SocketSender server2_sender(*server);
318   server_thread = Thread(thread::stl::Options(), *server);
319 
320   client.WaitUntilConnected();
321   server->WaitUntilConnected();
322 
323   server2_sender.Send(kWriteSize);
324   client_ingress.Wait();
325 
326   server->Stop();
327   auto server2_sent = server2_sender.sent();
328   std::copy(server2_sent.begin(),
329             server2_sent.end(),
330             std::back_inserter(sent_by_server));
331 
332   server_thread.join();
333 
334   client.Stop();
335   client_thread.join();
336   server = nullptr;
337 
338   auto received_by_client = client_ingress.received();
339   EXPECT_EQ(received_by_client.size(), 2 * kWriteSize);
340   EXPECT_TRUE(std::equal(received_by_client.begin(),
341                          received_by_client.end(),
342                          sent_by_server.begin()));
343 }
344 
345 }  // namespace
346 }  // namespace pw::rpc
347