xref: /aosp_15_r20/external/pigweed/pw_rpc/server_test.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_rpc/server.h"
16 
17 #include <array>
18 #include <cstdint>
19 
20 #include "pw_assert/check.h"
21 #include "pw_rpc/internal/call.h"
22 #include "pw_rpc/internal/method.h"
23 #include "pw_rpc/internal/packet.h"
24 #include "pw_rpc/internal/test_utils.h"
25 #include "pw_rpc/service.h"
26 #include "pw_rpc_private/fake_server_reader_writer.h"
27 #include "pw_rpc_private/test_method.h"
28 #include "pw_unit_test/framework.h"
29 
30 namespace pw::rpc {
31 
32 class ServerTestHelper {
33  public:
FindMethod(Server & server,uint32_t service_id,uint32_t method_id)34   static std::tuple<Service*, const internal::Method*> FindMethod(
35       Server& server, uint32_t service_id, uint32_t method_id) {
36     return server.FindMethod(service_id, method_id);
37   }
38 };
39 
40 namespace {
41 
42 using std::byte;
43 
44 using internal::Packet;
45 using internal::TestMethod;
46 using internal::TestMethodUnion;
47 using internal::pwpb::PacketType;
48 
49 class TestService : public Service {
50  public:
TestService(uint32_t service_id)51   TestService(uint32_t service_id)
52       : Service(service_id, methods_),
53         methods_{
54             TestMethod(100, MethodType::kBidirectionalStreaming),
55             TestMethod(200),
56         } {}
57 
method(uint32_t id)58   const TestMethod& method(uint32_t id) {
59     for (TestMethodUnion& method : methods_) {
60       if (method.method().id() == id) {
61         return method.test_method();
62       }
63     }
64 
65     PW_CRASH("Invalid method ID %u", static_cast<unsigned>(id));
66   }
67 
68  private:
69   std::array<TestMethodUnion, 2> methods_;
70 };
71 
72 class EmptyService : public Service {
73  public:
EmptyService()74   constexpr EmptyService() : Service(200, methods_) {}
75 
76  private:
77   static constexpr std::array<TestMethodUnion, 0> methods_ = {};
78 };
79 
// Call ID used by the fixtures and packet helpers unless a test supplies its
// own. Declared constexpr so it cannot be modified between tests.
constexpr uint32_t kDefaultCallId = 24601;
81 
82 class BasicServer : public ::testing::Test {
83  protected:
84   static constexpr byte kDefaultPayload[] = {
85       byte(0x82), byte(0x02), byte(0xff), byte(0xff)};
86 
BasicServer()87   BasicServer()
88       : channels_{
89             Channel::Create<1>(&output_),
90             Channel::Create<2>(&output_),
91             Channel(),  // available for assignment
92         },
93         server_(channels_),
94         service_1_(1),
95         service_42_(42) {
96     server_.RegisterService(service_1_, service_42_, empty_service_);
97   }
98 
EncodePacket(PacketType type,uint32_t channel_id,uint32_t service_id,uint32_t method_id,uint32_t call_id=kDefaultCallId)99   span<const byte> EncodePacket(PacketType type,
100                                 uint32_t channel_id,
101                                 uint32_t service_id,
102                                 uint32_t method_id,
103                                 uint32_t call_id = kDefaultCallId) {
104     return EncodePacketWithBody(type,
105                                 channel_id,
106                                 service_id,
107                                 method_id,
108                                 call_id,
109                                 kDefaultPayload,
110                                 OkStatus());
111   }
112 
EncodeCancel(uint32_t channel_id=1,uint32_t service_id=42,uint32_t method_id=100,uint32_t call_id=kDefaultCallId)113   span<const byte> EncodeCancel(uint32_t channel_id = 1,
114                                 uint32_t service_id = 42,
115                                 uint32_t method_id = 100,
116                                 uint32_t call_id = kDefaultCallId) {
117     return EncodePacketWithBody(PacketType::CLIENT_ERROR,
118                                 channel_id,
119                                 service_id,
120                                 method_id,
121                                 call_id,
122                                 {},
123                                 Status::Cancelled());
124   }
125 
126   template <typename T = ConstByteSpan>
PacketForRpc(PacketType type,Status status=OkStatus (),T && payload={},uint32_t call_id=kDefaultCallId)127   ConstByteSpan PacketForRpc(PacketType type,
128                              Status status = OkStatus(),
129                              T&& payload = {},
130                              uint32_t call_id = kDefaultCallId) {
131     return EncodePacketWithBody(
132         type, 1, 42, 100, call_id, as_bytes(span(payload)), status);
133   }
134 
135   RawFakeChannelOutput<2> output_;
136   std::array<Channel, 3> channels_;
137   Server server_;
138   TestService service_1_;
139   TestService service_42_;
140   EmptyService empty_service_;
141 
142  private:
143   byte request_buffer_[64];
144 
EncodePacketWithBody(PacketType type,uint32_t channel_id,uint32_t service_id,uint32_t method_id,uint32_t call_id,span<const byte> payload,Status status)145   span<const byte> EncodePacketWithBody(PacketType type,
146                                         uint32_t channel_id,
147                                         uint32_t service_id,
148                                         uint32_t method_id,
149                                         uint32_t call_id,
150                                         span<const byte> payload,
151                                         Status status) {
152     auto result =
153         Packet(
154             type, channel_id, service_id, method_id, call_id, payload, status)
155             .Encode(request_buffer_);
156     EXPECT_EQ(OkStatus(), result.status());
157     return result.value_or(ConstByteSpan());
158   }
159 };
160 
// A service the fixture never registered is reported as not registered; one
// the fixture did register is reported as registered.
TEST_F(BasicServer, IsServiceRegistered) {
  TestService never_registered(0);

  EXPECT_FALSE(server_.IsServiceRegistered(never_registered));
  EXPECT_TRUE(server_.IsServiceRegistered(service_1_));
}
166 
// A REQUEST for service 1 / method 100 on channel 1 reaches that method with
// the default payload intact.
TEST_F(BasicServer, ProcessPacket_ValidMethodInService1_InvokesMethod) {
  const ConstByteSpan request = EncodePacket(PacketType::REQUEST, 1, 1, 100);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  const TestMethod& invoked = service_1_.method(100);
  EXPECT_EQ(1u, invoked.last_channel_id());

  // Size is asserted first so the byte comparison never reads out of bounds.
  const ConstByteSpan received = invoked.last_request().payload();
  ASSERT_EQ(sizeof(kDefaultPayload), received.size());
  EXPECT_EQ(std::memcmp(kDefaultPayload, received.data(), received.size()), 0);
}
180 
// A REQUEST for service 42 / method 200 on channel 1 reaches that method with
// the default payload intact.
TEST_F(BasicServer, ProcessPacket_ValidMethodInService42_InvokesMethod) {
  const ConstByteSpan request = EncodePacket(PacketType::REQUEST, 1, 42, 200);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  const TestMethod& invoked = service_42_.method(200);
  EXPECT_EQ(1u, invoked.last_channel_id());

  // Size is asserted first so the byte comparison never reads out of bounds.
  const ConstByteSpan received = invoked.last_request().payload();
  ASSERT_EQ(sizeof(kDefaultPayload), received.size());
  EXPECT_EQ(std::memcmp(kDefaultPayload, received.data(), received.size()), 0);
}
194 
// After services 1 and 42 are unregistered, a REQUEST for service 1 is
// answered with a SERVER_ERROR / NOT_FOUND packet that echoes the request's
// channel, service, method, and call IDs.
TEST_F(BasicServer, UnregisterService_CannotCallMethod) {
  const uint32_t kCallId = 8675309;
  server_.UnregisterService(service_1_, service_42_);

  const ConstByteSpan request =
      EncodePacket(PacketType::REQUEST, 1, 1, 100, kCallId);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  internal::test::FakeChannelOutput& fake_output = output_;
  const Packet& reply = fake_output.last_packet();
  EXPECT_EQ(reply.type(), PacketType::SERVER_ERROR);
  EXPECT_EQ(reply.channel_id(), 1u);
  EXPECT_EQ(reply.service_id(), 1u);
  EXPECT_EQ(reply.method_id(), 100u);
  EXPECT_EQ(reply.call_id(), kCallId);
  EXPECT_EQ(reply.status(), Status::NotFound());
}
212 
// Unregistering service 42 repeatedly (within one call and across calls)
// leaves service 1 callable.
TEST_F(BasicServer, UnregisterService_AlreadyUnregistered_DoesNothing) {
  server_.UnregisterService(service_42_, service_42_, service_42_);
  server_.UnregisterService(service_42_);

  const ConstByteSpan request = EncodePacket(PacketType::REQUEST, 1, 1, 100);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  const TestMethod& invoked = service_1_.method(100);
  EXPECT_EQ(1u, invoked.last_channel_id());

  // Size is asserted first so the byte comparison never reads out of bounds.
  const ConstByteSpan received = invoked.last_request().payload();
  ASSERT_EQ(sizeof(kDefaultPayload), received.size());
  EXPECT_EQ(std::memcmp(kDefaultPayload, received.data(), received.size()), 0);
}
229 
// Packets whose channel, service, or method ID is 0 are each rejected with
// DATA_LOSS, and neither method of service 42 records a channel ID.
TEST_F(BasicServer, ProcessPacket_IncompletePacket_NothingIsInvoked) {
  // Channel ID 0.
  const Status no_channel =
      server_.ProcessPacket(EncodePacket(PacketType::REQUEST, 0, 42, 101));
  EXPECT_EQ(Status::DataLoss(), no_channel);

  // Service ID 0.
  const Status no_service =
      server_.ProcessPacket(EncodePacket(PacketType::REQUEST, 1, 0, 101));
  EXPECT_EQ(Status::DataLoss(), no_service);

  // Method ID 0.
  const Status no_method =
      server_.ProcessPacket(EncodePacket(PacketType::REQUEST, 1, 42, 0));
  EXPECT_EQ(Status::DataLoss(), no_method);

  EXPECT_EQ(0u, service_42_.method(100).last_channel_id());
  EXPECT_EQ(0u, service_42_.method(200).last_channel_id());
}
243 
// A request with channel ID 0 is rejected with DATA_LOSS and produces no
// outgoing packet.
TEST_F(BasicServer, ProcessPacket_NoChannel_SendsNothing) {
  const ConstByteSpan request = EncodePacket(PacketType::REQUEST, 0, 42, 101);
  EXPECT_EQ(Status::DataLoss(), server_.ProcessPacket(request));

  EXPECT_EQ(output_.total_packets(), 0u);
}
251 
// A request with service ID 0 is rejected with DATA_LOSS and produces no
// outgoing packet.
TEST_F(BasicServer, ProcessPacket_NoService_SendsNothing) {
  const ConstByteSpan request = EncodePacket(PacketType::REQUEST, 1, 0, 101);
  EXPECT_EQ(Status::DataLoss(), server_.ProcessPacket(request));

  EXPECT_EQ(output_.total_packets(), 0u);
}
259 
// A request with method ID 0 is rejected with DATA_LOSS and produces no
// outgoing packet.
TEST_F(BasicServer, ProcessPacket_NoMethod_SendsNothing) {
  const ConstByteSpan request = EncodePacket(PacketType::REQUEST, 1, 42, 0);
  EXPECT_EQ(Status::DataLoss(), server_.ProcessPacket(request));

  EXPECT_EQ(output_.total_packets(), 0u);
}
266 
// A request for method 101, which service 42 does not define, is processed
// with OK status but neither of the service's methods records a channel ID.
TEST_F(BasicServer, ProcessPacket_InvalidMethod_NothingIsInvoked) {
  const ConstByteSpan request = EncodePacket(PacketType::REQUEST, 1, 42, 101);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  EXPECT_EQ(0u, service_42_.method(100).last_channel_id());
  EXPECT_EQ(0u, service_42_.method(200).last_channel_id());
}
275 
// A CLIENT_ERROR naming method 101, which service 42 does not define, is
// accepted and produces no outgoing packet.
TEST_F(BasicServer, ProcessPacket_ClientErrorWithInvalidMethod_NoResponse) {
  const ConstByteSpan client_error =
      EncodePacket(PacketType::CLIENT_ERROR, 1, 42, 101);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(client_error));

  EXPECT_EQ(0u, output_.total_packets());
}
283 
// A request for a method ID that service 42 does not define is answered with
// a SERVER_ERROR / NOT_FOUND packet echoing the request's IDs.
TEST_F(BasicServer, ProcessPacket_InvalidMethod_SendsError) {
  const ConstByteSpan request = EncodePacket(PacketType::REQUEST, 1, 42, 27);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  internal::test::FakeChannelOutput& fake_output = output_;
  const Packet& reply = fake_output.last_packet();
  EXPECT_EQ(reply.type(), PacketType::SERVER_ERROR);
  EXPECT_EQ(reply.channel_id(), 1u);
  EXPECT_EQ(reply.service_id(), 42u);
  EXPECT_EQ(reply.method_id(), 27u);  // Service 42 has no method 27.
  EXPECT_EQ(reply.status(), Status::NotFound());
}
297 
// A request for a service ID that is not registered is answered with a
// SERVER_ERROR / NOT_FOUND packet echoing the request's IDs.
TEST_F(BasicServer, ProcessPacket_InvalidService_SendsError) {
  const ConstByteSpan request = EncodePacket(PacketType::REQUEST, 1, 43, 27);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  internal::test::FakeChannelOutput& fake_output = output_;
  const Packet& reply = fake_output.last_packet();
  EXPECT_EQ(reply.type(), PacketType::SERVER_ERROR);
  EXPECT_EQ(reply.channel_id(), 1u);
  EXPECT_EQ(reply.service_id(), 43u);  // No service 43 is registered.
  EXPECT_EQ(reply.method_id(), 27u);
  EXPECT_EQ(reply.status(), Status::NotFound());
}
311 
// A request on channel 99, which the fixture never opens, is rejected with
// UNAVAILABLE.
TEST_F(BasicServer, ProcessPacket_UnassignedChannel) {
  const ConstByteSpan request =
      EncodePacket(PacketType::REQUEST, /*channel_id=*/99, 42, 27);

  EXPECT_EQ(Status::Unavailable(), server_.ProcessPacket(request));
}
317 
// With every channel slot occupied, a CLIENT_ERROR on unknown channel 99 is
// rejected with UNAVAILABLE and produces no outgoing packet.
TEST_F(BasicServer, ProcessPacket_ClientErrorOnUnassignedChannel_NoResponse) {
  // Fill the one slot the fixture leaves unassigned.
  channels_[2] = Channel::Create<3>(&output_);

  const ConstByteSpan client_error =
      EncodePacket(PacketType::CLIENT_ERROR, /*channel_id=*/99, 42, 27);
  EXPECT_EQ(Status::Unavailable(), server_.ProcessPacket(client_error));

  EXPECT_EQ(0u, output_.total_packets());
}
327 
// The BasicServer fixture starts no calls, so this cancellation targets a
// method with no active call. It is accepted and produces no outgoing packet.
TEST_F(BasicServer, ProcessPacket_Cancel_MethodNotActive_SendsNothing) {
  const ConstByteSpan cancel = EncodeCancel(1, 42, 100);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  EXPECT_EQ(output_.total_packets(), 0u);
}
334 
GetChannel(internal::Endpoint & endpoint,uint32_t id)335 const internal::ChannelBase* GetChannel(internal::Endpoint& endpoint,
336                                         uint32_t id) {
337   internal::RpcLockGuard lock;
338   return endpoint.GetInternalChannel(id);
339 }
340 
// Closing open channel 2 succeeds, removes the channel from the server, and
// sends no packet.
TEST_F(BasicServer, CloseChannel_NoCalls) {
  EXPECT_NE(nullptr, GetChannel(server_, 2));

  const Status close_status = server_.CloseChannel(2);
  EXPECT_EQ(OkStatus(), close_status);

  EXPECT_EQ(nullptr, GetChannel(server_, 2));
  ASSERT_EQ(output_.total_packets(), 0u);
}
347 
// Closing a channel ID the server does not know about reports NOT_FOUND.
TEST_F(BasicServer, CloseChannel_UnknownChannel) {
  ASSERT_EQ(nullptr, GetChannel(server_, 13579));

  const Status close_status = server_.CloseChannel(13579);
  EXPECT_EQ(Status::NotFound(), close_status);
}
352 
// Closing a channel that has an active call ends the call, invokes its
// on_error callback with ABORTED, and sends no packet.
TEST_F(BasicServer, CloseChannel_PendingCall) {
  internal::Endpoint& endpoint = server_;

  EXPECT_NE(nullptr, GetChannel(server_, 1));
  EXPECT_EQ(endpoint.active_call_count(), 0u);

  // Ask the test method to keep its call alive in `pending_call`.
  internal::test::FakeServerReaderWriter pending_call;
  service_42_.method(100).keep_call_active(pending_call);

  const ConstByteSpan request = EncodePacket(PacketType::REQUEST, 1, 42, 100);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  Status reported_error;
  pending_call.set_on_error(
      [&reported_error](Status status) { reported_error = status; });

  ASSERT_TRUE(pending_call.active());
  EXPECT_EQ(endpoint.active_call_count(), 1u);

  EXPECT_EQ(OkStatus(), server_.CloseChannel(1));
  EXPECT_EQ(nullptr, GetChannel(server_, 1));

  EXPECT_EQ(endpoint.active_call_count(), 0u);

  // on_error must fire, but nothing is sent because the channel is closed.
  EXPECT_EQ(Status::Aborted(), reported_error);
  ASSERT_EQ(output_.total_packets(), 0u);
}
380 
// The same request is rejected with UNAVAILABLE before channel 9 is opened
// and answered with a RESPONSE packet once it has been opened.
TEST_F(BasicServer, OpenChannel_UnusedSlot) {
  const ConstByteSpan request = EncodePacket(PacketType::REQUEST, 9, 42, 100);
  EXPECT_EQ(Status::Unavailable(), server_.ProcessPacket(request));

  EXPECT_EQ(OkStatus(), server_.OpenChannel(9, output_));
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  internal::test::FakeChannelOutput& fake_output = output_;
  const Packet& reply = fake_output.last_packet();
  EXPECT_EQ(reply.type(), PacketType::RESPONSE);
  EXPECT_EQ(reply.channel_id(), 9u);
  EXPECT_EQ(reply.service_id(), 42u);
  EXPECT_EQ(reply.method_id(), 100u);
}
395 
// Opening a channel ID that is already open reports ALREADY_EXISTS.
TEST_F(BasicServer, OpenChannel_AlreadyExists) {
  ASSERT_NE(nullptr, GetChannel(server_, 1));

  const Status open_status = server_.OpenChannel(1, output_);
  EXPECT_EQ(Status::AlreadyExists(), open_status);
}
400 
// The first extra channel is always expected to open. A second one is
// expected to open only when PW_RPC_DYNAMIC_ALLOCATION is nonzero; otherwise
// RESOURCE_EXHAUSTED is expected.
TEST_F(BasicServer, OpenChannel_AdditionalSlot) {
  EXPECT_EQ(OkStatus(), server_.OpenChannel(3, output_));

  constexpr Status kSecondOpenExpected =
      PW_RPC_DYNAMIC_ALLOCATION == 0 ? Status::ResourceExhausted() : OkStatus();
  const Status second_open = server_.OpenChannel(19823, output_);
  EXPECT_EQ(kSecondOpenExpected, second_open);
}
408 
// Looking up service 1 / method 100 yields both a service and a method. The
// method's type is checked only when PW_RPC_METHOD_STORES_TYPE is enabled.
TEST_F(BasicServer, FindMethod_FoundOkOptionallyCheckType) {
  const auto [found_service, found_method] =
      ServerTestHelper::FindMethod(server_, 1, 100);
  ASSERT_TRUE(found_service != nullptr);
  ASSERT_TRUE(found_method != nullptr);
#if PW_RPC_METHOD_STORES_TYPE
  EXPECT_EQ(MethodType::kBidirectionalStreaming, found_method->type());
#endif
}
417 
// Covers the two lookup failures: an unknown service yields neither a service
// nor a method, and a known service with an unknown method yields the service
// but no method.
TEST_F(BasicServer, FindMethod_NotFound) {
  {
    // Service 2 is not registered.
    const auto [found_service, found_method] =
        ServerTestHelper::FindMethod(server_, 2, 100);
    ASSERT_TRUE(found_service == nullptr);
    ASSERT_TRUE(found_method == nullptr);
  }

  {
    // Service 1 is registered but has no method 101.
    const auto [found_service, found_method] =
        ServerTestHelper::FindMethod(server_, 1, 101);
    ASSERT_TRUE(found_service != nullptr);
    ASSERT_TRUE(found_method == nullptr);
  }
}
433 
434 class BidiMethod : public BasicServer {
435  protected:
BidiMethod()436   BidiMethod() {
437     internal::rpc_lock().lock();
438     internal::CallContext context(server_,
439                                   channels_[0].id(),
440                                   service_42_,
441                                   service_42_.method(100),
442                                   kDefaultCallId);
443     // A local temporary is required since the constructor requires a lock,
444     // but the *move* constructor takes out the lock.
445     internal::test::FakeServerReaderWriter responder_temp(
446         context.ClaimLocked());
447     internal::rpc_lock().unlock();
448     responder_ = std::move(responder_temp);
449     PW_CHECK(responder_.active());
450   }
451 
452   internal::test::FakeServerReaderWriter responder_;
453 };
454 
// A REQUEST that reuses the active call's ID cancels that call exactly once
// and then invokes the method.
TEST_F(BidiMethod, DuplicateCallId_CancelsExistingThenCallsAgain) {
  int cancel_count = 0;
  responder_.set_on_error([&cancel_count](Status status) {
    if (status.IsCancelled()) {
      ++cancel_count;
    }
  });

  const TestMethod& target = service_42_.method(100);
  ASSERT_EQ(target.invocations(), 0u);

  const ConstByteSpan request = PacketForRpc(PacketType::REQUEST);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  EXPECT_EQ(cancel_count, 1);
  EXPECT_EQ(target.invocations(), 1u);
}
472 
// A REQUEST for the same method under a different call ID does not cancel the
// call that is already active.
TEST_F(BidiMethod, DuplicateMethodDifferentCallId_NotCancelled) {
  int cancel_count = 0;
  responder_.set_on_error([&cancel_count](Status status) {
    if (status.IsCancelled()) {
      ++cancel_count;
    }
  });

  const uint32_t kSecondCallId = 1625;
  const ConstByteSpan request =
      PacketForRpc(PacketType::REQUEST, OkStatus(), {}, kSecondCallId);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  EXPECT_EQ(cancel_count, 0);
}
488 
span_as_cstr(ConstByteSpan span)489 const char* span_as_cstr(ConstByteSpan span) {
490   return reinterpret_cast<const char*>(span.data());
491 }
492 
// Two calls to the same method with different call IDs each receive only the
// CLIENT_STREAM payload addressed to their own call ID.
TEST_F(BidiMethod, DuplicateMethodDifferentCallIdEachCallGetsSeparateResponse) {
  const uint32_t kSecondCallId = 1625;

  // Claim a second call for the same method while holding the RPC lock.
  internal::rpc_lock().lock();
  internal::test::FakeServerReaderWriter responder_2(
      internal::CallContext(server_,
                            channels_[0].id(),
                            service_42_,
                            service_42_.method(100),
                            kSecondCallId)
          .ClaimLocked());
  internal::rpc_lock().unlock();

  // Each callback starts from a sentinel so a missed delivery is obvious.
  ConstByteSpan data_1 = as_bytes(span("data_1_unset"));
  responder_.set_on_next(
      [&data_1](ConstByteSpan payload) { data_1 = payload; });

  ConstByteSpan data_2 = as_bytes(span("data_2_unset"));
  responder_2.set_on_next(
      [&data_2](ConstByteSpan payload) { data_2 = payload; });

  // Declared as arrays (not pointers) so the same constant can be used both
  // as the payload sent and as the expected value. The array extent includes
  // the terminating '\0', which span_as_cstr relies on.
  constexpr char kMessage1[] = "hello_1";
  constexpr char kMessage2[] = "hello_2";

  EXPECT_EQ(
      OkStatus(),
      server_.ProcessPacket(PacketForRpc(
          PacketType::CLIENT_STREAM, OkStatus(), kMessage2, kSecondCallId)));

  EXPECT_STREQ(span_as_cstr(data_2), kMessage2);

  EXPECT_EQ(
      OkStatus(),
      server_.ProcessPacket(PacketForRpc(
          PacketType::CLIENT_STREAM, OkStatus(), kMessage1, kDefaultCallId)));

  EXPECT_STREQ(span_as_cstr(data_1), kMessage1);
}
531 
// A cancellation addressed to the active call leaves it inactive.
TEST_F(BidiMethod, Cancel_ClosesServerWriter) {
  const ConstByteSpan cancel = EncodeCancel();
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  EXPECT_FALSE(responder_.active());
}
537 
// A cancellation addressed to the active call produces no outgoing packet.
TEST_F(BidiMethod, Cancel_SendsNoResponse) {
  const ConstByteSpan cancel = EncodeCancel();
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  EXPECT_EQ(output_.total_packets(), 0u);
}
543 
// A CLIENT_ERROR addressed to the active call leaves it inactive and produces
// no outgoing packet.
TEST_F(BidiMethod, ClientError_ClosesServerWriterWithoutResponse) {
  const ConstByteSpan client_error = PacketForRpc(PacketType::CLIENT_ERROR);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(client_error));

  EXPECT_FALSE(responder_.active());
  EXPECT_EQ(output_.total_packets(), 0u);
}
551 
// The status carried by a CLIENT_ERROR packet is passed to the call's
// on_error callback.
TEST_F(BidiMethod, ClientError_CallsOnErrorCallback) {
  // Starts as UNKNOWN so the callback's write is distinguishable.
  Status reported = Status::Unknown();
  responder_.set_on_error([&reported](Status status) { reported = status; });

  const ConstByteSpan client_error =
      PacketForRpc(PacketType::CLIENT_ERROR, Status::Unauthenticated());
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(client_error));

  EXPECT_EQ(reported, Status::Unauthenticated());
}
562 
// A cancellation invokes the call's on_error callback with CANCELLED.
TEST_F(BidiMethod, Cancel_CallsOnErrorCallback) {
  // Starts as UNKNOWN so the callback's write is distinguishable.
  Status reported = Status::Unknown();
  responder_.set_on_error([&reported](Status status) { reported = status; });

  const ConstByteSpan cancel = EncodeCancel();
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  EXPECT_EQ(reported, Status::Cancelled());
}
570 
// A cancellation on channel 2 does not affect the call on the fixture's first
// channel: nothing is sent and the call stays active.
TEST_F(BidiMethod, Cancel_IncorrectChannel_SendsNothing) {
  const ConstByteSpan cancel = EncodeCancel(2, 42, 100);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_TRUE(responder_.active());
}
577 
// A cancellation naming service 43 does not affect the call on service 42:
// nothing is sent and the call stays active.
TEST_F(BidiMethod, Cancel_IncorrectService_SendsNothing) {
  const ConstByteSpan cancel = EncodeCancel(1, 43, 100);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_TRUE(responder_.active());
}
583 
// A cancellation naming method 101 does not affect the call on method 100:
// nothing is sent and the call stays active.
TEST_F(BidiMethod, Cancel_IncorrectMethod_SendsNothing) {
  const ConstByteSpan cancel = EncodeCancel(1, 42, 101);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_TRUE(responder_.active());
}
589 
// A CLIENT_STREAM packet for the active call hands its payload to on_next and
// produces no outgoing packet.
TEST_F(BidiMethod, ClientStream_CallsCallback) {
  // "?" is a sentinel that the callback is expected to overwrite.
  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  const ConstByteSpan stream_packet =
      PacketForRpc(PacketType::CLIENT_STREAM, {}, "hello");
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_STREQ(span_as_cstr(received), "hello");
}
601 
// A CLIENT_STREAM packet sent with internal::kOpenCallId is delivered to the
// active call's on_next callback and produces no outgoing packet.
TEST_F(BidiMethod, ClientStream_CallsCallbackOnCallWithOpenId) {
  // "?" is a sentinel that the callback is expected to overwrite.
  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  const ConstByteSpan stream_packet = PacketForRpc(
      PacketType::CLIENT_STREAM, {}, "hello", internal::kOpenCallId);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_STREQ(span_as_cstr(received), "hello");
}
614 
// A CLIENT_STREAM packet sent with internal::kLegacyOpenCallId is delivered
// to the active call's on_next callback and produces no outgoing packet.
TEST_F(BidiMethod, ClientStream_CallsCallbackOnCallWithLegacyOpenId) {
  // "?" is a sentinel that the callback is expected to overwrite.
  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  const ConstByteSpan stream_packet = PacketForRpc(
      PacketType::CLIENT_STREAM, {}, "hello", internal::kLegacyOpenCallId);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_STREQ(span_as_cstr(received), "hello");
}
628 
// When the server-side call was created with internal::kOpenCallId, a
// CLIENT_STREAM packet carrying a concrete call ID is delivered to it, and the
// call's ID afterwards equals the packet's ID.
TEST_F(BidiMethod, ClientStream_CallsOpenIdOnCallWithDifferentId) {
  const uint32_t kSecondCallId = 1625;

  // Replace the fixture's call with one claimed under the open call ID. The
  // claim happens under the RPC lock; the move happens after unlocking.
  internal::CallContext open_id_context(server_,
                                        channels_[0].id(),
                                        service_42_,
                                        service_42_.method(100),
                                        internal::kOpenCallId);
  internal::rpc_lock().lock();
  internal::test::FakeServerReaderWriter open_id_call(
      open_id_context.ClaimLocked());
  internal::rpc_lock().unlock();
  responder_ = std::move(open_id_call);

  // "?" is a sentinel that the callback is expected to overwrite.
  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  const ConstByteSpan stream_packet =
      PacketForRpc(PacketType::CLIENT_STREAM, {}, "hello", kSecondCallId);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_STREQ(span_as_cstr(received), "hello");

  // The call's ID is read while holding the RPC lock.
  internal::RpcLockGuard lock;
  EXPECT_EQ(responder_.as_server_call().id(), kSecondCallId);
}
655 
// When the server-side call was created with internal::kLegacyOpenCallId, a
// CLIENT_STREAM packet carrying a concrete call ID is delivered to it, and the
// call's ID afterwards equals the packet's ID.
TEST_F(BidiMethod, ClientStream_CallsLegacyOpenIdOnCallWithDifferentId) {
  const uint32_t kSecondCallId = 1625;

  // Replace the fixture's call with one claimed under the legacy open call
  // ID. The claim happens under the RPC lock; the move happens after
  // unlocking.
  internal::CallContext legacy_open_context(server_,
                                            channels_[0].id(),
                                            service_42_,
                                            service_42_.method(100),
                                            internal::kLegacyOpenCallId);
  internal::rpc_lock().lock();
  internal::test::FakeServerReaderWriter legacy_open_call(
      legacy_open_context.ClaimLocked());
  internal::rpc_lock().unlock();
  responder_ = std::move(legacy_open_call);

  // "?" is a sentinel that the callback is expected to overwrite.
  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  const ConstByteSpan stream_packet =
      PacketForRpc(PacketType::CLIENT_STREAM, {}, "hello", kSecondCallId);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_STREQ(span_as_cstr(received), "hello");

  // The call's ID is read while holding the RPC lock.
  internal::RpcLockGuard lock;
  EXPECT_EQ(responder_.as_server_call().id(), kSecondCallId);
}
682 
// Unregistering the service that owns the active call ends the call and
// invokes its on_error callback with ABORTED.
// NOTE: the test name misspells "Unregister"; it is left unchanged so that
// existing test filters keep matching.
TEST_F(BidiMethod, UnregsiterService_AbortsActiveCalls) {
  ASSERT_TRUE(responder_.active());

  Status reported = OkStatus();
  responder_.set_on_error([&reported](Status error) { reported = error; });

  server_.UnregisterService(service_42_);

  EXPECT_FALSE(responder_.active());
  EXPECT_EQ(Status::Aborted(), reported);
}
695 
// A CLIENT_REQUEST_COMPLETION packet sends nothing, and the completion
// callback runs exactly when PW_RPC_COMPLETION_REQUEST_CALLBACK is enabled.
TEST_F(BidiMethod, ClientRequestedCompletion_CallsCallback) {
  bool callback_ran = false;
#if PW_RPC_COMPLETION_REQUEST_CALLBACK
  // The setter is only used when the feature is compiled in.
  responder_.set_on_completion_requested(
      [&callback_ran]() { callback_ran = true; });
#endif
  const ConstByteSpan completion_request =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_EQ(callback_ran, PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
708 
// Same expectation as the test above, but the callback is installed through
// the *_if_enabled setter, which is called regardless of the feature macro.
TEST_F(BidiMethod, ClientRequestedCompletion_CallsCallbackIfEnabled) {
  bool callback_ran = false;
  responder_.set_on_completion_requested_if_enabled(
      [&callback_ran]() { callback_ran = true; });

  const ConstByteSpan completion_request =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_EQ(callback_ran, PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
721 
// Processing the same CLIENT_REQUEST_COMPLETION packet twice succeeds both
// times and produces no outgoing packet.
TEST_F(BidiMethod, ClientRequestedCompletion_ErrorWhenClosed) {
  const ConstByteSpan completion_request =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);

  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));

  ASSERT_EQ(output_.total_packets(), 0u);
}
729 
// Once the call has been cancelled, a CLIENT_REQUEST_COMPLETION packet is
// answered with a single SERVER_ERROR / FAILED_PRECONDITION packet.
TEST_F(BidiMethod, ClientRequestedCompletion_ErrorWhenAlreadyClosed) {
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(EncodeCancel()));
  EXPECT_FALSE(responder_.active());

  const ConstByteSpan completion_request =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));

  ASSERT_EQ(output_.total_packets(), 1u);
  internal::test::FakeChannelOutput& fake_output = output_;
  const Packet& reply = fake_output.last_packet();
  EXPECT_EQ(reply.type(), PacketType::SERVER_ERROR);
  EXPECT_EQ(reply.status(), Status::FailedPrecondition());
}
743 
744 class ServerStreamingMethod : public BasicServer {
745  protected:
ServerStreamingMethod()746   ServerStreamingMethod() {
747     internal::CallContext context(server_,
748                                   channels_[0].id(),
749                                   service_42_,
750                                   service_42_.method(100),
751                                   kDefaultCallId);
752     internal::rpc_lock().lock();
753     internal::test::FakeServerWriter responder_temp(context.ClaimLocked());
754     internal::rpc_lock().unlock();
755     responder_ = std::move(responder_temp);
756     PW_CHECK(responder_.active());
757   }
758 
759   internal::test::FakeServerWriter responder_;
760 };
761 
// A CLIENT_STREAM packet sent to the FakeServerWriter call is answered with a
// single SERVER_ERROR / INVALID_ARGUMENT packet.
TEST_F(ServerStreamingMethod, ClientStream_InvalidArgumentError) {
  const ConstByteSpan stream_packet = PacketForRpc(PacketType::CLIENT_STREAM);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  ASSERT_EQ(output_.total_packets(), 1u);
  internal::test::FakeChannelOutput& fake_output = output_;
  const Packet& reply = fake_output.last_packet();
  EXPECT_EQ(reply.type(), PacketType::SERVER_ERROR);
  EXPECT_EQ(reply.status(), Status::InvalidArgument());
}
772 
// A CLIENT_REQUEST_COMPLETION packet sends nothing, and the completion
// callback runs exactly when PW_RPC_COMPLETION_REQUEST_CALLBACK is enabled.
TEST_F(ServerStreamingMethod, ClientRequestedCompletion_CallsCallback) {
  bool callback_ran = false;
#if PW_RPC_COMPLETION_REQUEST_CALLBACK
  // The setter is only used when the feature is compiled in.
  responder_.set_on_completion_requested(
      [&callback_ran]() { callback_ran = true; });
#endif

  const ConstByteSpan completion_request =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_EQ(callback_ran, PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
786 
// Same expectation as the test above, but the callback is installed through
// the *_if_enabled setter, which is called regardless of the feature macro.
TEST_F(ServerStreamingMethod,
       ClientRequestedCompletion_CallsCallbackIfEnabled) {
  bool callback_ran = false;
  responder_.set_on_completion_requested_if_enabled(
      [&callback_ran]() { callback_ran = true; });

  const ConstByteSpan completion_request =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));

  EXPECT_EQ(output_.total_packets(), 0u);
  EXPECT_EQ(callback_ran, PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
800 
// Processing the same CLIENT_REQUEST_COMPLETION packet twice succeeds both
// times and produces no outgoing packet.
TEST_F(ServerStreamingMethod, ClientRequestedCompletion_ErrorWhenClosed) {
  const ConstByteSpan completion_request =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);

  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));

  ASSERT_EQ(output_.total_packets(), 0u);
}
808 
// Once the call has been cancelled, a CLIENT_REQUEST_COMPLETION packet is
// answered with a single SERVER_ERROR / FAILED_PRECONDITION packet.
TEST_F(ServerStreamingMethod,
       ClientRequestedCompletion_ErrorWhenAlreadyClosed) {
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(EncodeCancel()));
  EXPECT_FALSE(responder_.active());

  const ConstByteSpan completion_request =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion_request));

  ASSERT_EQ(output_.total_packets(), 1u);
  internal::test::FakeChannelOutput& fake_output = output_;
  const Packet& reply = fake_output.last_packet();
  EXPECT_EQ(reply.type(), PacketType::SERVER_ERROR);
  EXPECT_EQ(reply.status(), Status::FailedPrecondition());
}
823 
824 }  // namespace
825 }  // namespace pw::rpc
826