1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_rpc/server.h"
16
#include <array>
#include <cstdint>
#include <cstring>
#include <tuple>
19
20 #include "pw_assert/check.h"
21 #include "pw_rpc/internal/call.h"
22 #include "pw_rpc/internal/method.h"
23 #include "pw_rpc/internal/packet.h"
24 #include "pw_rpc/internal/test_utils.h"
25 #include "pw_rpc/service.h"
26 #include "pw_rpc_private/fake_server_reader_writer.h"
27 #include "pw_rpc_private/test_method.h"
28 #include "pw_unit_test/framework.h"
29
30 namespace pw::rpc {
31
32 class ServerTestHelper {
33 public:
FindMethod(Server & server,uint32_t service_id,uint32_t method_id)34 static std::tuple<Service*, const internal::Method*> FindMethod(
35 Server& server, uint32_t service_id, uint32_t method_id) {
36 return server.FindMethod(service_id, method_id);
37 }
38 };
39
40 namespace {
41
42 using std::byte;
43
44 using internal::Packet;
45 using internal::TestMethod;
46 using internal::TestMethodUnion;
47 using internal::pwpb::PacketType;
48
49 class TestService : public Service {
50 public:
TestService(uint32_t service_id)51 TestService(uint32_t service_id)
52 : Service(service_id, methods_),
53 methods_{
54 TestMethod(100, MethodType::kBidirectionalStreaming),
55 TestMethod(200),
56 } {}
57
method(uint32_t id)58 const TestMethod& method(uint32_t id) {
59 for (TestMethodUnion& method : methods_) {
60 if (method.method().id() == id) {
61 return method.test_method();
62 }
63 }
64
65 PW_CRASH("Invalid method ID %u", static_cast<unsigned>(id));
66 }
67
68 private:
69 std::array<TestMethodUnion, 2> methods_;
70 };
71
72 class EmptyService : public Service {
73 public:
EmptyService()74 constexpr EmptyService() : Service(200, methods_) {}
75
76 private:
77 static constexpr std::array<TestMethodUnion, 0> methods_ = {};
78 };
79
// Call ID used by the packet helpers and fixtures unless a test overrides it.
// constexpr: it is a constant (and a default argument), not mutable state.
constexpr uint32_t kDefaultCallId = 24601;
81
82 class BasicServer : public ::testing::Test {
83 protected:
84 static constexpr byte kDefaultPayload[] = {
85 byte(0x82), byte(0x02), byte(0xff), byte(0xff)};
86
BasicServer()87 BasicServer()
88 : channels_{
89 Channel::Create<1>(&output_),
90 Channel::Create<2>(&output_),
91 Channel(), // available for assignment
92 },
93 server_(channels_),
94 service_1_(1),
95 service_42_(42) {
96 server_.RegisterService(service_1_, service_42_, empty_service_);
97 }
98
EncodePacket(PacketType type,uint32_t channel_id,uint32_t service_id,uint32_t method_id,uint32_t call_id=kDefaultCallId)99 span<const byte> EncodePacket(PacketType type,
100 uint32_t channel_id,
101 uint32_t service_id,
102 uint32_t method_id,
103 uint32_t call_id = kDefaultCallId) {
104 return EncodePacketWithBody(type,
105 channel_id,
106 service_id,
107 method_id,
108 call_id,
109 kDefaultPayload,
110 OkStatus());
111 }
112
EncodeCancel(uint32_t channel_id=1,uint32_t service_id=42,uint32_t method_id=100,uint32_t call_id=kDefaultCallId)113 span<const byte> EncodeCancel(uint32_t channel_id = 1,
114 uint32_t service_id = 42,
115 uint32_t method_id = 100,
116 uint32_t call_id = kDefaultCallId) {
117 return EncodePacketWithBody(PacketType::CLIENT_ERROR,
118 channel_id,
119 service_id,
120 method_id,
121 call_id,
122 {},
123 Status::Cancelled());
124 }
125
126 template <typename T = ConstByteSpan>
PacketForRpc(PacketType type,Status status=OkStatus (),T && payload={},uint32_t call_id=kDefaultCallId)127 ConstByteSpan PacketForRpc(PacketType type,
128 Status status = OkStatus(),
129 T&& payload = {},
130 uint32_t call_id = kDefaultCallId) {
131 return EncodePacketWithBody(
132 type, 1, 42, 100, call_id, as_bytes(span(payload)), status);
133 }
134
135 RawFakeChannelOutput<2> output_;
136 std::array<Channel, 3> channels_;
137 Server server_;
138 TestService service_1_;
139 TestService service_42_;
140 EmptyService empty_service_;
141
142 private:
143 byte request_buffer_[64];
144
EncodePacketWithBody(PacketType type,uint32_t channel_id,uint32_t service_id,uint32_t method_id,uint32_t call_id,span<const byte> payload,Status status)145 span<const byte> EncodePacketWithBody(PacketType type,
146 uint32_t channel_id,
147 uint32_t service_id,
148 uint32_t method_id,
149 uint32_t call_id,
150 span<const byte> payload,
151 Status status) {
152 auto result =
153 Packet(
154 type, channel_id, service_id, method_id, call_id, payload, status)
155 .Encode(request_buffer_);
156 EXPECT_EQ(OkStatus(), result.status());
157 return result.value_or(ConstByteSpan());
158 }
159 };
160
TEST_F(BasicServer, IsServiceRegistered) {
  // Constructed but never passed to RegisterService().
  TestService unregistered_service(0);
  EXPECT_FALSE(server_.IsServiceRegistered(unregistered_service));

  // Every service the fixture constructor registers should be reported.
  EXPECT_TRUE(server_.IsServiceRegistered(service_1_));
  EXPECT_TRUE(server_.IsServiceRegistered(service_42_));
  EXPECT_TRUE(server_.IsServiceRegistered(empty_service_));
}
166
TEST_F(BasicServer, ProcessPacket_ValidMethodInService1_InvokesMethod) {
  const span<const byte> request =
      EncodePacket(PacketType::REQUEST, 1, 1, 100);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  // Service 1 / method 100 should have run on channel 1 with the payload
  // delivered unchanged.
  const TestMethod& invoked = service_1_.method(100);
  EXPECT_EQ(1u, invoked.last_channel_id());

  const auto received = invoked.last_request().payload();
  ASSERT_EQ(sizeof(kDefaultPayload), received.size());
  EXPECT_EQ(0,
            std::memcmp(kDefaultPayload, received.data(), received.size()));
}
180
TEST_F(BasicServer, ProcessPacket_ValidMethodInService42_InvokesMethod) {
  const span<const byte> request =
      EncodePacket(PacketType::REQUEST, 1, 42, 200);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  // Service 42 / method 200 should have run on channel 1 with the payload
  // delivered unchanged.
  const TestMethod& invoked = service_42_.method(200);
  EXPECT_EQ(1u, invoked.last_channel_id());

  const auto received = invoked.last_request().payload();
  ASSERT_EQ(sizeof(kDefaultPayload), received.size());
  EXPECT_EQ(0,
            std::memcmp(kDefaultPayload, received.data(), received.size()));
}
194
TEST_F(BasicServer, UnregisterService_CannotCallMethod) {
  const uint32_t kCallId = 8675309;
  server_.UnregisterService(service_1_, service_42_);

  EXPECT_EQ(OkStatus(),
            server_.ProcessPacket(
                EncodePacket(PacketType::REQUEST, 1, 1, 100, kCallId)));

  // The method of an unregistered service must not run.
  EXPECT_EQ(0u, service_1_.method(100).last_channel_id());

  // Check the packet count before reading last_packet(), as other tests in
  // this file do, so a missing error reply fails cleanly.
  ASSERT_EQ(output_.total_packets(), 1u);
  const Packet& packet =
      static_cast<internal::test::FakeChannelOutput&>(output_).last_packet();
  EXPECT_EQ(packet.type(), PacketType::SERVER_ERROR);
  EXPECT_EQ(packet.channel_id(), 1u);
  EXPECT_EQ(packet.service_id(), 1u);
  EXPECT_EQ(packet.method_id(), 100u);
  EXPECT_EQ(packet.call_id(), kCallId);
  EXPECT_EQ(packet.status(), Status::NotFound());
}
212
TEST_F(BasicServer, UnregisterService_AlreadyUnregistered_DoesNothing) {
  // Unregistering the same service repeatedly, both within one call and
  // across calls, must be harmless.
  server_.UnregisterService(service_42_, service_42_, service_42_);
  server_.UnregisterService(service_42_);

  // service_1_ is unaffected and still handles requests.
  const span<const byte> request =
      EncodePacket(PacketType::REQUEST, 1, 1, 100);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  const TestMethod& invoked = service_1_.method(100);
  EXPECT_EQ(1u, invoked.last_channel_id());

  const auto received = invoked.last_request().payload();
  ASSERT_EQ(sizeof(kDefaultPayload), received.size());
  EXPECT_EQ(0,
            std::memcmp(kDefaultPayload, received.data(), received.size()));
}
229
TEST_F(BasicServer, ProcessPacket_IncompletePacket_NothingIsInvoked) {
  // A zero channel, service, or method ID each makes the packet incomplete.
  const Status kIncomplete = Status::DataLoss();
  EXPECT_EQ(kIncomplete,
            server_.ProcessPacket(EncodePacket(PacketType::REQUEST, 0, 42, 101)));
  EXPECT_EQ(kIncomplete,
            server_.ProcessPacket(EncodePacket(PacketType::REQUEST, 1, 0, 101)));
  EXPECT_EQ(kIncomplete,
            server_.ProcessPacket(EncodePacket(PacketType::REQUEST, 1, 42, 0)));

  // Neither of service 42's methods was reached.
  EXPECT_EQ(0u, service_42_.method(100).last_channel_id());
  EXPECT_EQ(0u, service_42_.method(200).last_channel_id());
}
243
TEST_F(BasicServer, ProcessPacket_NoChannel_SendsNothing) {
  // Channel ID 0 is rejected outright, with no reply.
  const span<const byte> request =
      EncodePacket(PacketType::REQUEST, 0, 42, 101);
  EXPECT_EQ(Status::DataLoss(), server_.ProcessPacket(request));

  EXPECT_EQ(0u, output_.total_packets());
}
251
TEST_F(BasicServer, ProcessPacket_NoService_SendsNothing) {
  // Service ID 0 is rejected outright, with no reply.
  const span<const byte> request =
      EncodePacket(PacketType::REQUEST, 1, 0, 101);
  EXPECT_EQ(Status::DataLoss(), server_.ProcessPacket(request));

  EXPECT_EQ(0u, output_.total_packets());
}
259
TEST_F(BasicServer, ProcessPacket_NoMethod_SendsNothing) {
  // Method ID 0 is rejected outright, with no reply.
  const span<const byte> request =
      EncodePacket(PacketType::REQUEST, 1, 42, 0);
  EXPECT_EQ(Status::DataLoss(), server_.ProcessPacket(request));

  EXPECT_EQ(0u, output_.total_packets());
}
266
TEST_F(BasicServer, ProcessPacket_InvalidMethod_NothingIsInvoked) {
  // Service 42 has no method 101; the packet is still processed OK.
  const span<const byte> request =
      EncodePacket(PacketType::REQUEST, 1, 42, 101);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  // ...but none of the real methods runs.
  EXPECT_EQ(0u, service_42_.method(100).last_channel_id());
  EXPECT_EQ(0u, service_42_.method(200).last_channel_id());
}
275
TEST_F(BasicServer, ProcessPacket_ClientErrorWithInvalidMethod_NoResponse) {
  // A client error aimed at a nonexistent method is accepted silently.
  const span<const byte> client_error =
      EncodePacket(PacketType::CLIENT_ERROR, 1, 42, 101);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(client_error));

  EXPECT_EQ(output_.total_packets(), 0u);
}
283
TEST_F(BasicServer, ProcessPacket_InvalidMethod_SendsError) {
  EXPECT_EQ(
      OkStatus(),
      server_.ProcessPacket(EncodePacket(PacketType::REQUEST, 1, 42, 27)));

  // Check the packet count before reading last_packet(), as other tests in
  // this file do, so a missing error reply fails cleanly.
  ASSERT_EQ(output_.total_packets(), 1u);
  const Packet& packet =
      static_cast<internal::test::FakeChannelOutput&>(output_).last_packet();
  EXPECT_EQ(packet.type(), PacketType::SERVER_ERROR);
  EXPECT_EQ(packet.channel_id(), 1u);
  EXPECT_EQ(packet.service_id(), 42u);
  EXPECT_EQ(packet.method_id(), 27u);  // No method ID 27
  // The error must be addressed to the request that caused it.
  EXPECT_EQ(packet.call_id(), kDefaultCallId);
  EXPECT_EQ(packet.status(), Status::NotFound());
}
297
TEST_F(BasicServer, ProcessPacket_InvalidService_SendsError) {
  EXPECT_EQ(
      OkStatus(),
      server_.ProcessPacket(EncodePacket(PacketType::REQUEST, 1, 43, 27)));

  // Check the packet count before reading last_packet(), as other tests in
  // this file do, so a missing error reply fails cleanly.
  ASSERT_EQ(output_.total_packets(), 1u);
  const Packet& packet =
      static_cast<internal::test::FakeChannelOutput&>(output_).last_packet();
  EXPECT_EQ(packet.type(), PacketType::SERVER_ERROR);
  EXPECT_EQ(packet.channel_id(), 1u);
  EXPECT_EQ(packet.service_id(), 43u);  // No service ID 43
  EXPECT_EQ(packet.method_id(), 27u);
  // The error must be addressed to the request that caused it.
  EXPECT_EQ(packet.call_id(), kDefaultCallId);
  EXPECT_EQ(packet.status(), Status::NotFound());
}
311
TEST_F(BasicServer, ProcessPacket_UnassignedChannel) {
  // No channel 99 was ever opened.
  constexpr uint32_t kUnknownChannel = 99;
  const span<const byte> request =
      EncodePacket(PacketType::REQUEST, kUnknownChannel, 42, 27);
  EXPECT_EQ(Status::Unavailable(), server_.ProcessPacket(request));
}
317
TEST_F(BasicServer, ProcessPacket_ClientErrorOnUnassignedChannel_NoResponse) {
  // Fill the fixture's spare slot so no channel slot remains free.
  channels_[2] = Channel::Create<3>(&output_);

  constexpr uint32_t kUnknownChannel = 99;
  const span<const byte> client_error =
      EncodePacket(PacketType::CLIENT_ERROR, kUnknownChannel, 42, 27);
  EXPECT_EQ(Status::Unavailable(), server_.ProcessPacket(client_error));

  EXPECT_EQ(output_.total_packets(), 0u);
}
327
TEST_F(BasicServer, ProcessPacket_Cancel_MethodNotActive_SendsNothing) {
  // No call is open for channel 1 / service 42 / method 100, so the cancel
  // has nothing to close. The server should accept it without replying.
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(EncodeCancel(1, 42, 100)));

  EXPECT_EQ(output_.total_packets(), 0u);
}
334
GetChannel(internal::Endpoint & endpoint,uint32_t id)335 const internal::ChannelBase* GetChannel(internal::Endpoint& endpoint,
336 uint32_t id) {
337 internal::RpcLockGuard lock;
338 return endpoint.GetInternalChannel(id);
339 }
340
TEST_F(BasicServer, CloseChannel_NoCalls) {
  constexpr uint32_t kChannelId = 2;
  EXPECT_NE(nullptr, GetChannel(server_, kChannelId));

  // With no calls open, closing the channel just removes it, silently.
  EXPECT_EQ(OkStatus(), server_.CloseChannel(kChannelId));
  EXPECT_EQ(nullptr, GetChannel(server_, kChannelId));
  ASSERT_EQ(0u, output_.total_packets());
}
347
TEST_F(BasicServer, CloseChannel_UnknownChannel) {
  constexpr uint32_t kNeverOpened = 13579;
  ASSERT_EQ(nullptr, GetChannel(server_, kNeverOpened));
  EXPECT_EQ(Status::NotFound(), server_.CloseChannel(kNeverOpened));
}
352
TEST_F(BasicServer, CloseChannel_PendingCall) {
  auto& endpoint = static_cast<internal::Endpoint&>(server_);

  EXPECT_NE(nullptr, GetChannel(server_, 1));
  EXPECT_EQ(endpoint.active_call_count(), 0u);

  // Ask method 42/100 to keep its call alive in `call` (keep_call_active).
  internal::test::FakeServerReaderWriter call;
  service_42_.method(100).keep_call_active(call);

  const span<const byte> request =
      EncodePacket(PacketType::REQUEST, 1, 42, 100);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  Status on_error_status;
  call.set_on_error([&on_error_status](Status error) {
    on_error_status = error;
  });

  ASSERT_TRUE(call.active());
  EXPECT_EQ(endpoint.active_call_count(), 1u);

  EXPECT_EQ(OkStatus(), server_.CloseChannel(1));
  EXPECT_EQ(nullptr, GetChannel(server_, 1));
  EXPECT_EQ(endpoint.active_call_count(), 0u);

  // The call is aborted through on_error; since its channel is gone, nothing
  // is transmitted.
  EXPECT_EQ(Status::Aborted(), on_error_status);
  ASSERT_EQ(0u, output_.total_packets());
}
380
TEST_F(BasicServer, OpenChannel_UnusedSlot) {
  constexpr uint32_t kNewChannel = 9;
  const span request = EncodePacket(PacketType::REQUEST, kNewChannel, 42, 100);

  // Rejected until the channel exists...
  EXPECT_EQ(Status::Unavailable(), server_.ProcessPacket(request));

  // ...and accepted once it has been opened in the free slot.
  EXPECT_EQ(OkStatus(), server_.OpenChannel(kNewChannel, output_));
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& packet = fake_output.last_packet();
  EXPECT_EQ(packet.type(), PacketType::RESPONSE);
  EXPECT_EQ(packet.channel_id(), 9u);
  EXPECT_EQ(packet.service_id(), 42u);
  EXPECT_EQ(packet.method_id(), 100u);
}
395
TEST_F(BasicServer, OpenChannel_AlreadyExists) {
  // Channel 1 is opened by the fixture, so a second open must fail.
  constexpr uint32_t kExistingChannel = 1;
  ASSERT_NE(nullptr, GetChannel(server_, kExistingChannel));
  EXPECT_EQ(Status::AlreadyExists(),
            server_.OpenChannel(kExistingChannel, output_));
}
400
TEST_F(BasicServer, OpenChannel_AdditionalSlot) {
  // Channel 3 takes the fixture's one spare slot.
  EXPECT_EQ(OkStatus(), server_.OpenChannel(3, output_));

  // Any further channel needs room that exists only with dynamic allocation.
#if PW_RPC_DYNAMIC_ALLOCATION
  constexpr Status kExpected = OkStatus();
#else
  constexpr Status kExpected = Status::ResourceExhausted();
#endif
  EXPECT_EQ(kExpected, server_.OpenChannel(19823, output_));
}
408
TEST_F(BasicServer, FindMethod_FoundOkOptionallyCheckType) {
  // Service 1 / method 100 is registered by the fixture.
  const auto [found_service, found_method] =
      ServerTestHelper::FindMethod(server_, 1, 100);
  ASSERT_TRUE(found_service != nullptr);
  ASSERT_TRUE(found_method != nullptr);

  // The method type can be checked only when methods record it.
#if PW_RPC_METHOD_STORES_TYPE
  EXPECT_EQ(MethodType::kBidirectionalStreaming, found_method->type());
#endif
}
417
TEST_F(BasicServer, FindMethod_NotFound) {
  // Service 2 is not registered: neither the service nor the method resolves.
  const auto [absent_service, method_of_absent_service] =
      ServerTestHelper::FindMethod(server_, 2, 100);
  ASSERT_TRUE(absent_service == nullptr);
  ASSERT_TRUE(method_of_absent_service == nullptr);

  // Service 1 exists but has no method 101.
  const auto [present_service, absent_method] =
      ServerTestHelper::FindMethod(server_, 1, 101);
  ASSERT_TRUE(present_service != nullptr);
  ASSERT_TRUE(absent_method == nullptr);
}
433
434 class BidiMethod : public BasicServer {
435 protected:
BidiMethod()436 BidiMethod() {
437 internal::rpc_lock().lock();
438 internal::CallContext context(server_,
439 channels_[0].id(),
440 service_42_,
441 service_42_.method(100),
442 kDefaultCallId);
443 // A local temporary is required since the constructor requires a lock,
444 // but the *move* constructor takes out the lock.
445 internal::test::FakeServerReaderWriter responder_temp(
446 context.ClaimLocked());
447 internal::rpc_lock().unlock();
448 responder_ = std::move(responder_temp);
449 PW_CHECK(responder_.active());
450 }
451
452 internal::test::FakeServerReaderWriter responder_;
453 };
454
TEST_F(BidiMethod, DuplicateCallId_CancelsExistingThenCallsAgain) {
  int cancel_count = 0;
  responder_.set_on_error([&cancel_count](Status error) {
    if (!error.IsCancelled()) {
      return;
    }
    ++cancel_count;
  });

  const TestMethod& bidi = service_42_.method(100);
  ASSERT_EQ(0u, bidi.invocations());

  // A new REQUEST reusing the open call's ID replaces that call.
  const ConstByteSpan request = PacketForRpc(PacketType::REQUEST);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  EXPECT_EQ(1, cancel_count);
  EXPECT_EQ(1u, bidi.invocations());
}
472
TEST_F(BidiMethod, DuplicateMethodDifferentCallId_NotCancelled) {
  int cancel_count = 0;
  responder_.set_on_error([&cancel_count](Status error) {
    if (!error.IsCancelled()) {
      return;
    }
    ++cancel_count;
  });

  // Same method with a new call ID: the existing call must survive.
  constexpr uint32_t kSecondCallId = 1625;
  const ConstByteSpan request =
      PacketForRpc(PacketType::REQUEST, OkStatus(), {}, kSecondCallId);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(request));

  EXPECT_EQ(0, cancel_count);
}
488
span_as_cstr(ConstByteSpan span)489 const char* span_as_cstr(ConstByteSpan span) {
490 return reinterpret_cast<const char*>(span.data());
491 }
492
TEST_F(BidiMethod, DuplicateMethodDifferentCallIdEachCallGetsSeparateResponse) {
  const uint32_t kSecondCallId = 1625;

  // Open a second call to the same method, alongside responder_, under a
  // different call ID.
  internal::rpc_lock().lock();
  internal::test::FakeServerReaderWriter responder_2(
      internal::CallContext(server_,
                            channels_[0].id(),
                            service_42_,
                            service_42_.method(100),
                            kSecondCallId)
          .ClaimLocked());
  internal::rpc_lock().unlock();

  ConstByteSpan data_1 = as_bytes(span("data_1_unset"));
  responder_.set_on_next(
      [&data_1](ConstByteSpan payload) { data_1 = payload; });

  ConstByteSpan data_2 = as_bytes(span("data_2_unset"));
  responder_2.set_on_next(
      [&data_2](ConstByteSpan payload) { data_2 = payload; });

  // These are arrays rather than pointers so PacketForRpc() can wrap them in
  // a span. The span includes the trailing '\0', which lets span_as_cstr()
  // read the payload back, and each literal is written only once.
  constexpr char kMessage1[] = "hello_1";
  constexpr char kMessage2[] = "hello_2";

  EXPECT_EQ(OkStatus(),
            server_.ProcessPacket(PacketForRpc(
                PacketType::CLIENT_STREAM, OkStatus(), kMessage2, kSecondCallId)));

  // Check before the next encode: received payloads view the shared request
  // buffer, which the next packet overwrites.
  EXPECT_STREQ(span_as_cstr(data_2), kMessage2);

  EXPECT_EQ(OkStatus(),
            server_.ProcessPacket(PacketForRpc(PacketType::CLIENT_STREAM,
                                               OkStatus(),
                                               kMessage1,
                                               kDefaultCallId)));

  EXPECT_STREQ(span_as_cstr(data_1), kMessage1);
}
531
TEST_F(BidiMethod, Cancel_ClosesServerWriter) {
  const span<const byte> cancel = EncodeCancel();
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  // Cancellation ends the open call.
  EXPECT_FALSE(responder_.active());
}
537
TEST_F(BidiMethod, Cancel_SendsNoResponse) {
  const span<const byte> cancel = EncodeCancel();
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  // The server does not answer a cancellation.
  EXPECT_EQ(0u, output_.total_packets());
}
543
TEST_F(BidiMethod, ClientError_ClosesServerWriterWithoutResponse) {
  const ConstByteSpan client_error = PacketForRpc(PacketType::CLIENT_ERROR);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(client_error));

  // The call ends without anything being sent back.
  EXPECT_FALSE(responder_.active());
  EXPECT_EQ(0u, output_.total_packets());
}
551
TEST_F(BidiMethod, ClientError_CallsOnErrorCallback) {
  // Start from a sentinel that differs from the status being sent.
  Status reported = Status::Unknown();
  responder_.set_on_error([&reported](Status error) { reported = error; });

  const ConstByteSpan client_error =
      PacketForRpc(PacketType::CLIENT_ERROR, Status::Unauthenticated());
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(client_error));

  // The client's status is forwarded to on_error unchanged.
  EXPECT_EQ(Status::Unauthenticated(), reported);
}
562
TEST_F(BidiMethod, Cancel_CallsOnErrorCallback) {
  Status reported = Status::Unknown();
  responder_.set_on_error([&reported](Status error) { reported = error; });

  const span<const byte> cancel = EncodeCancel();
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  EXPECT_EQ(Status::Cancelled(), reported);
}
570
TEST_F(BidiMethod, Cancel_IncorrectChannel_SendsNothing) {
  // The open call lives on channel 1; cancelling on channel 2 misses it.
  const span<const byte> cancel = EncodeCancel(2, 42, 100);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_TRUE(responder_.active());
}
577
TEST_F(BidiMethod, Cancel_IncorrectService_SendsNothing) {
  // Service 43 does not match the open call's service 42.
  const span<const byte> cancel = EncodeCancel(1, 43, 100);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_TRUE(responder_.active());
}
583
TEST_F(BidiMethod, Cancel_IncorrectMethod_SendsNothing) {
  // Method 101 does not match the open call's method 100.
  const span<const byte> cancel = EncodeCancel(1, 42, 101);
  EXPECT_EQ(OkStatus(), server_.ProcessPacket(cancel));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_TRUE(responder_.active());
}
589
TEST_F(BidiMethod, ClientStream_CallsCallback) {
  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  const ConstByteSpan stream_packet =
      PacketForRpc(PacketType::CLIENT_STREAM, {}, "hello");
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  // The payload reaches on_next and nothing is sent back.
  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_STREQ(span_as_cstr(received), "hello");
}
601
TEST_F(BidiMethod, ClientStream_CallsCallbackOnCallWithOpenId) {
  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  // A packet carrying kOpenCallId still reaches the call opened with
  // kDefaultCallId.
  const ConstByteSpan stream_packet = PacketForRpc(
      PacketType::CLIENT_STREAM, {}, "hello", internal::kOpenCallId);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_STREQ(span_as_cstr(received), "hello");
}
614
TEST_F(BidiMethod, ClientStream_CallsCallbackOnCallWithLegacyOpenId) {
  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  // A packet carrying kLegacyOpenCallId still reaches the call opened with
  // kDefaultCallId.
  const ConstByteSpan stream_packet = PacketForRpc(
      PacketType::CLIENT_STREAM, {}, "hello", internal::kLegacyOpenCallId);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_STREQ(span_as_cstr(received), "hello");
}
628
TEST_F(BidiMethod, ClientStream_CallsOpenIdOnCallWithDifferentId) {
  const uint32_t kSecondCallId = 1625;

  // Replace the fixture's call with one opened under kOpenCallId.
  internal::CallContext open_id_context(server_,
                                        channels_[0].id(),
                                        service_42_,
                                        service_42_.method(100),
                                        internal::kOpenCallId);
  internal::rpc_lock().lock();
  internal::test::FakeServerReaderWriter open_id_call(
      open_id_context.ClaimLocked());
  internal::rpc_lock().unlock();
  responder_ = std::move(open_id_call);

  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  const ConstByteSpan stream_packet =
      PacketForRpc(PacketType::CLIENT_STREAM, {}, "hello", kSecondCallId);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_STREQ(span_as_cstr(received), "hello");

  // The open-ID call takes on the ID of the first packet it receives.
  internal::RpcLockGuard lock;
  EXPECT_EQ(kSecondCallId, responder_.as_server_call().id());
}
655
TEST_F(BidiMethod, ClientStream_CallsLegacyOpenIdOnCallWithDifferentId) {
  const uint32_t kSecondCallId = 1625;

  // Replace the fixture's call with one opened under kLegacyOpenCallId.
  internal::CallContext legacy_context(server_,
                                       channels_[0].id(),
                                       service_42_,
                                       service_42_.method(100),
                                       internal::kLegacyOpenCallId);
  internal::rpc_lock().lock();
  internal::test::FakeServerReaderWriter legacy_call(
      legacy_context.ClaimLocked());
  internal::rpc_lock().unlock();
  responder_ = std::move(legacy_call);

  ConstByteSpan received = as_bytes(span("?"));
  responder_.set_on_next(
      [&received](ConstByteSpan payload) { received = payload; });

  const ConstByteSpan stream_packet =
      PacketForRpc(PacketType::CLIENT_STREAM, {}, "hello", kSecondCallId);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_STREQ(span_as_cstr(received), "hello");

  // The legacy open-ID call takes on the ID of the first packet it receives.
  internal::RpcLockGuard lock;
  EXPECT_EQ(kSecondCallId, responder_.as_server_call().id());
}
682
// Test name typo fixed ("Unregsiter" -> "Unregister"), matching the
// UnregisterService_* tests on BasicServer.
TEST_F(BidiMethod, UnregisterService_AbortsActiveCalls) {
  ASSERT_TRUE(responder_.active());

  Status on_error_status = OkStatus();
  responder_.set_on_error(
      [&on_error_status](Status status) { on_error_status = status; });

  // Removing the service must abort its open call.
  server_.UnregisterService(service_42_);

  EXPECT_FALSE(responder_.active());
  EXPECT_EQ(Status::Aborted(), on_error_status);
}
695
TEST_F(BidiMethod, ClientRequestedCompletion_CallsCallback) {
  bool callback_ran = false;
#if PW_RPC_COMPLETION_REQUEST_CALLBACK
  responder_.set_on_completion_requested(
      [&callback_ran]() { callback_ran = true; });
#endif
  const ConstByteSpan completion =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion));

  // Nothing is sent back. The callback fires only when that feature is
  // compiled in.
  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_EQ(callback_ran, PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
708
TEST_F(BidiMethod, ClientRequestedCompletion_CallsCallbackIfEnabled) {
  // The *_if_enabled setter compiles in every configuration.
  bool callback_ran = false;
  responder_.set_on_completion_requested_if_enabled(
      [&callback_ran]() { callback_ran = true; });

  const ConstByteSpan completion =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_EQ(callback_ran, PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
721
TEST_F(BidiMethod, ClientRequestedCompletion_ErrorWhenClosed) {
  // Sending the same completion request twice produces no reply at all.
  const auto completion = PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  for (int attempt = 0; attempt < 2; ++attempt) {
    ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion));
  }

  ASSERT_EQ(0u, output_.total_packets());
}
729
TEST_F(BidiMethod, ClientRequestedCompletion_ErrorWhenAlreadyClosed) {
  // Close the call first by cancelling it.
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(EncodeCancel()));
  EXPECT_FALSE(responder_.active());

  // A completion request for the closed call draws a FAILED_PRECONDITION
  // error.
  const auto completion = PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion));

  ASSERT_EQ(1u, output_.total_packets());
  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& packet = fake_output.last_packet();
  EXPECT_EQ(PacketType::SERVER_ERROR, packet.type());
  EXPECT_EQ(Status::FailedPrecondition(), packet.status());
}
743
744 class ServerStreamingMethod : public BasicServer {
745 protected:
ServerStreamingMethod()746 ServerStreamingMethod() {
747 internal::CallContext context(server_,
748 channels_[0].id(),
749 service_42_,
750 service_42_.method(100),
751 kDefaultCallId);
752 internal::rpc_lock().lock();
753 internal::test::FakeServerWriter responder_temp(context.ClaimLocked());
754 internal::rpc_lock().unlock();
755 responder_ = std::move(responder_temp);
756 PW_CHECK(responder_.active());
757 }
758
759 internal::test::FakeServerWriter responder_;
760 };
761
TEST_F(ServerStreamingMethod, ClientStream_InvalidArgumentError) {
  // The open call is a FakeServerWriter. A client stream packet sent to it is
  // rejected with INVALID_ARGUMENT (as asserted below).
  const ConstByteSpan stream_packet = PacketForRpc(PacketType::CLIENT_STREAM);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(stream_packet));

  ASSERT_EQ(1u, output_.total_packets());
  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& packet = fake_output.last_packet();
  EXPECT_EQ(PacketType::SERVER_ERROR, packet.type());
  EXPECT_EQ(Status::InvalidArgument(), packet.status());
}
772
TEST_F(ServerStreamingMethod, ClientRequestedCompletion_CallsCallback) {
  bool callback_ran = false;
#if PW_RPC_COMPLETION_REQUEST_CALLBACK
  responder_.set_on_completion_requested(
      [&callback_ran]() { callback_ran = true; });
#endif

  const ConstByteSpan completion =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion));

  // Nothing is sent back. The callback fires only when that feature is
  // compiled in.
  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_EQ(callback_ran, PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
786
TEST_F(ServerStreamingMethod,
       ClientRequestedCompletion_CallsCallbackIfEnabled) {
  // The *_if_enabled setter compiles in every configuration.
  bool callback_ran = false;
  responder_.set_on_completion_requested_if_enabled(
      [&callback_ran]() { callback_ran = true; });

  const ConstByteSpan completion =
      PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion));

  EXPECT_EQ(0u, output_.total_packets());
  EXPECT_EQ(callback_ran, PW_RPC_COMPLETION_REQUEST_CALLBACK);
}
800
TEST_F(ServerStreamingMethod, ClientRequestedCompletion_ErrorWhenClosed) {
  // Sending the same completion request twice produces no reply at all.
  const auto completion = PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  for (int attempt = 0; attempt < 2; ++attempt) {
    ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion));
  }

  ASSERT_EQ(0u, output_.total_packets());
}
808
TEST_F(ServerStreamingMethod,
       ClientRequestedCompletion_ErrorWhenAlreadyClosed) {
  // Close the call first by cancelling it.
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(EncodeCancel()));
  EXPECT_FALSE(responder_.active());

  // A completion request for the closed call draws a FAILED_PRECONDITION
  // error.
  const auto completion = PacketForRpc(PacketType::CLIENT_REQUEST_COMPLETION);
  ASSERT_EQ(OkStatus(), server_.ProcessPacket(completion));

  ASSERT_EQ(1u, output_.total_packets());
  auto& fake_output = static_cast<internal::test::FakeChannelOutput&>(output_);
  const Packet& packet = fake_output.last_packet();
  EXPECT_EQ(PacketType::SERVER_ERROR, packet.type());
  EXPECT_EQ(Status::FailedPrecondition(), packet.status());
}
823
824 } // namespace
825 } // namespace pw::rpc
826