1 // Copyright 2022 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include <optional>
16
17 #include "pw_rpc/internal/test_utils.h"
18 #include "pw_rpc/pwpb/client_reader_writer.h"
19 #include "pw_rpc_pwpb_private/internal_test_utils.h"
20 #include "pw_rpc_test_protos/test.pwpb.h"
21 #include "pw_unit_test/framework.h"
22
23 PW_MODIFY_DIAGNOSTICS_PUSH();
24 PW_MODIFY_DIAGNOSTIC(ignored, "-Wmissing-field-initializers");
25
26 namespace pw::rpc {
27 namespace {
28
29 using internal::ClientContextForTest;
30 using internal::pwpb::PacketType;
31
32 namespace TestRequest = ::pw::rpc::test::pwpb::TestRequest;
33 namespace TestResponse = ::pw::rpc::test::pwpb::TestResponse;
34 namespace TestStreamResponse = ::pw::rpc::test::pwpb::TestStreamResponse;
35
36 constexpr uint32_t kServiceId = 16;
37 constexpr uint32_t kUnaryMethodId = 111;
38 constexpr uint32_t kServerStreamingMethodId = 112;
39
40 class FakeGeneratedServiceClient {
41 public:
TestUnaryRpc(Client & client,uint32_t channel_id,const TestRequest::Message & request,Function<void (const TestResponse::Message &,Status)> on_response,Function<void (Status)> on_error=nullptr)42 static PwpbUnaryReceiver<TestResponse::Message> TestUnaryRpc(
43 Client& client,
44 uint32_t channel_id,
45 const TestRequest::Message& request,
46 Function<void(const TestResponse::Message&, Status)> on_response,
47 Function<void(Status)> on_error = nullptr) {
48 return internal::PwpbUnaryResponseClientCall<TestResponse::Message>::Start<
49 PwpbUnaryReceiver<TestResponse::Message>>(
50 client,
51 channel_id,
52 kServiceId,
53 kUnaryMethodId,
54 internal::kPwpbMethodSerde<&TestRequest::kMessageFields,
55 &TestResponse::kMessageFields>,
56 std::move(on_response),
57 std::move(on_error),
58 request);
59 }
60
TestAnotherUnaryRpc(Client & client,uint32_t channel_id,const TestRequest::Message & request,Function<void (const TestResponse::Message &,Status)> on_response,Function<void (Status)> on_error=nullptr)61 static PwpbUnaryReceiver<TestResponse::Message> TestAnotherUnaryRpc(
62 Client& client,
63 uint32_t channel_id,
64 const TestRequest::Message& request,
65 Function<void(const TestResponse::Message&, Status)> on_response,
66 Function<void(Status)> on_error = nullptr) {
67 return internal::PwpbUnaryResponseClientCall<TestResponse::Message>::Start<
68 PwpbUnaryReceiver<TestResponse::Message>>(
69 client,
70 channel_id,
71 kServiceId,
72 kUnaryMethodId,
73 internal::kPwpbMethodSerde<&TestRequest::kMessageFields,
74 &TestResponse::kMessageFields>,
75 std::move(on_response),
76 std::move(on_error),
77 request);
78 }
79
TestServerStreamRpc(Client & client,uint32_t channel_id,const TestRequest::Message & request,Function<void (const TestStreamResponse::Message &)> on_response,Function<void (Status)> on_stream_end,Function<void (Status)> on_error=nullptr)80 static PwpbClientReader<TestStreamResponse::Message> TestServerStreamRpc(
81 Client& client,
82 uint32_t channel_id,
83 const TestRequest::Message& request,
84 Function<void(const TestStreamResponse::Message&)> on_response,
85 Function<void(Status)> on_stream_end,
86 Function<void(Status)> on_error = nullptr) {
87 return internal::PwpbStreamResponseClientCall<TestStreamResponse::Message>::
88 Start<PwpbClientReader<TestStreamResponse::Message>>(
89 client,
90 channel_id,
91 kServiceId,
92 kServerStreamingMethodId,
93 internal::kPwpbMethodSerde<&TestRequest::kMessageFields,
94 &TestStreamResponse::kMessageFields>,
95 std::move(on_response),
96 std::move(on_stream_end),
97 std::move(on_error),
98 request);
99 }
100 };
101
// Starting a unary call should send exactly one packet, addressed to the fake
// service's unary method and carrying the encoded request.
TEST(PwpbClientCall, Unary_SendsRequestPacket) {
  ClientContextForTest context;

  auto call = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      {.integer = 123, .status_code = 0},
      nullptr);

  // ASSERT, not EXPECT: if no packet was sent, the test must stop here rather
  // than go on to read last_packet() from an empty output.
  ASSERT_EQ(context.output().total_packets(), 1u);
  auto packet = context.output().last_packet();
  EXPECT_EQ(packet.channel_id(), context.channel().id());
  EXPECT_EQ(packet.service_id(), kServiceId);
  EXPECT_EQ(packet.method_id(), kUnaryMethodId);

  PW_DECODE_PB(TestRequest, sent_proto, packet.payload());
  EXPECT_EQ(sent_proto.integer, 123);
}
120
121 class UnaryClientCall : public ::testing::Test {
122 protected:
123 std::optional<Status> last_status_;
124 std::optional<Status> last_error_;
125 int responses_received_ = 0;
126 int last_response_value_ = 0;
127 };
128
// A well-formed response must reach the response callback exactly once,
// carrying both the decoded value and the status sent by the server.
TEST_F(UnaryClientCall, InvokesCallbackOnValidResponse) {
  ClientContextForTest context;

  auto record_response = [this](const TestResponse::Message& response,
                                Status status) {
    ++responses_received_;
    last_status_ = status;
    last_response_value_ = response.value;
  };
  const TestRequest::Message request{.integer = 123, .status_code = 0};

  auto call = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      request,
      std::move(record_response));

  PW_ENCODE_PB(TestResponse, response, .value = 42);
  const auto send_result = context.SendResponse(OkStatus(), response);
  EXPECT_EQ(OkStatus(), send_result);

  ASSERT_EQ(responses_received_, 1);
  EXPECT_EQ(last_status_, OkStatus());
  EXPECT_EQ(last_response_value_, 42);
}
149
// With no response callback registered, a valid response must be accepted
// without anything touching the fixture's counters.
TEST_F(UnaryClientCall, DoesNothingOnNullCallback) {
  ClientContextForTest context;
  const TestRequest::Message request{.integer = 123, .status_code = 0};

  auto call = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(), context.channel().id(), request, nullptr);

  PW_ENCODE_PB(TestResponse, response, .value = 42);
  const auto send_result = context.SendResponse(OkStatus(), response);
  EXPECT_EQ(OkStatus(), send_result);

  ASSERT_EQ(responses_received_, 0);
}
164
// A response payload that cannot be decoded must be reported to the error
// callback as DATA_LOSS and must never reach the response callback.
TEST_F(UnaryClientCall, InvokesErrorCallbackOnInvalidResponse) {
  ClientContextForTest context;
  const TestRequest::Message request{.integer = 123, .status_code = 0};

  auto on_response = [this](const TestResponse::Message& response,
                            Status status) {
    ++responses_received_;
    last_status_ = status;
    last_response_value_ = response.value;
  };
  auto on_error = [this](Status status) { last_error_ = status; };

  auto call = FakeGeneratedServiceClient::TestUnaryRpc(context.client(),
                                                       context.channel().id(),
                                                       request,
                                                       std::move(on_response),
                                                       std::move(on_error));

  // Bytes that are not a valid TestResponse encoding.
  constexpr std::byte kGarbage[] = {
      std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}};
  EXPECT_EQ(OkStatus(), context.SendResponse(OkStatus(), kGarbage));

  EXPECT_EQ(responses_received_, 0);
  ASSERT_TRUE(last_error_.has_value());
  EXPECT_EQ(last_error_, Status::DataLoss());
}
187
// A SERVER_ERROR packet must deliver the server's status to the error callback
// and must not be treated as a response.
TEST_F(UnaryClientCall, InvokesErrorCallbackOnServerError) {
  ClientContextForTest context;

  auto on_response = [this](const TestResponse::Message& response,
                            Status status) {
    ++responses_received_;
    last_status_ = status;
    last_response_value_ = response.value;
  };
  auto on_error = [this](Status status) { last_error_ = status; };

  auto call = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      {.integer = 123, .status_code = 0},
      std::move(on_response),
      std::move(on_error));

  const auto send_result =
      context.SendPacket(PacketType::SERVER_ERROR, Status::NotFound());
  EXPECT_EQ(OkStatus(), send_result);

  EXPECT_EQ(responses_received_, 0);
  EXPECT_EQ(last_error_, Status::NotFound());
}
208
// With no error callback registered, an undecodable response must still be
// handled safely, and the response callback must not run.
TEST_F(UnaryClientCall, DoesNothingOnErrorWithoutCallback) {
  ClientContextForTest context;
  const TestRequest::Message request{.integer = 123, .status_code = 0};

  // Only the response callback is supplied; on_error keeps its nullptr default.
  auto call = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      request,
      [this](const TestResponse::Message& response, Status status) {
        ++responses_received_;
        last_status_ = status;
        last_response_value_ = response.value;
      });

  // Bytes that are not a valid TestResponse encoding.
  constexpr std::byte kUndecodable[] = {
      std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}};
  EXPECT_EQ(OkStatus(), context.SendResponse(OkStatus(), kUndecodable));

  EXPECT_EQ(responses_received_, 0);
}
228
// A unary call finishes on its first response; responses that arrive later
// for the same call must be ignored.
TEST_F(UnaryClientCall, OnlyReceivesOneResponse) {
  ClientContextForTest context;

  auto call = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      {.integer = 123, .status_code = 0},
      [this](const TestResponse::Message& response, Status status) {
        ++responses_received_;
        last_status_ = status;
        last_response_value_ = response.value;
      });

  // Encode all three responses first, then deliver them in order.
  PW_ENCODE_PB(TestResponse, r1, .value = 42);
  PW_ENCODE_PB(TestResponse, r2, .value = 44);
  PW_ENCODE_PB(TestResponse, r3, .value = 46);

  EXPECT_EQ(OkStatus(), context.SendResponse(Status::Unimplemented(), r1));
  EXPECT_EQ(OkStatus(), context.SendResponse(Status::OutOfRange(), r2));
  EXPECT_EQ(OkStatus(), context.SendResponse(Status::Internal(), r3));

  // Only the first response's value and status are recorded.
  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(last_status_, Status::Unimplemented());
  EXPECT_EQ(last_response_value_, 42);
}
253
254 class ServerStreamingClientCall : public ::testing::Test {
255 protected:
256 std::optional<Status> stream_status_;
257 std::optional<Status> rpc_error_;
258 int responses_received_ = 0;
259 int last_response_number_ = 0;
260 };
261
// Starting a server-streaming call should send exactly one packet, addressed to
// the streaming method and carrying the encoded request.
TEST_F(ServerStreamingClientCall, SendsRequestPacket) {
  ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> context;

  auto call = FakeGeneratedServiceClient::TestServerStreamRpc(
      context.client(),
      context.channel().id(),
      {.integer = 71, .status_code = 0},
      nullptr,
      nullptr);

  // ASSERT, not EXPECT: if nothing was sent, the test must stop here rather
  // than go on to read last_packet() from an empty output.
  ASSERT_EQ(context.output().total_packets(), 1u);
  auto packet = context.output().last_packet();
  EXPECT_EQ(packet.channel_id(), context.channel().id());
  EXPECT_EQ(packet.service_id(), kServiceId);
  EXPECT_EQ(packet.method_id(), kServerStreamingMethodId);

  PW_DECODE_PB(TestRequest, sent_proto, packet.payload());
  EXPECT_EQ(sent_proto.integer, 71);
}
281
// Each valid stream message must reach the per-message callback, in order,
// while the call stays active. The stream-end callback must not fire, because
// the server never closes the stream.
TEST_F(ServerStreamingClientCall, InvokesCallbackOnValidResponse) {
  ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> context;

  auto call = FakeGeneratedServiceClient::TestServerStreamRpc(
      context.client(),
      context.channel().id(),
      {.integer = 71, .status_code = 0},
      [this](const TestStreamResponse::Message& response) {
        ++responses_received_;
        last_response_number_ = response.number;
      },
      [this](Status status) { stream_status_ = status; });

  PW_ENCODE_PB(TestStreamResponse, r1, .chunk = {}, .number = 11u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r1));
  EXPECT_TRUE(call.active());
  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(last_response_number_, 11);

  PW_ENCODE_PB(TestStreamResponse, r2, .chunk = {}, .number = 22u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r2));
  EXPECT_TRUE(call.active());
  EXPECT_EQ(responses_received_, 2);
  EXPECT_EQ(last_response_number_, 22);

  PW_ENCODE_PB(TestStreamResponse, r3, .chunk = {}, .number = 33u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r3));
  EXPECT_TRUE(call.active());
  EXPECT_EQ(responses_received_, 3);
  EXPECT_EQ(last_response_number_, 33);

  // Only stream messages were sent, so the stream-end callback, although
  // registered, must not have run.
  EXPECT_FALSE(stream_status_.has_value());
}
313
// Closing the stream must invoke the stream-end callback with the server's
// status and deactivate the call. Stream messages that arrive after the close
// must not be delivered.
TEST_F(ServerStreamingClientCall, InvokesStreamEndOnFinish) {
  ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> context;

  auto call = FakeGeneratedServiceClient::TestServerStreamRpc(
      context.client(),
      context.channel().id(),
      {.integer = 71, .status_code = 0},
      [this](const TestStreamResponse::Message& response) {
        ++responses_received_;
        last_response_number_ = response.number;
      },
      [this](Status status) { stream_status_ = status; });

  PW_ENCODE_PB(TestStreamResponse, r1, .chunk = {}, .number = 11u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r1));
  EXPECT_TRUE(call.active());

  PW_ENCODE_PB(TestStreamResponse, r2, .chunk = {}, .number = 22u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r2));
  EXPECT_TRUE(call.active());
  // The stream is still open, so the stream-end callback has not run yet.
  EXPECT_FALSE(stream_status_.has_value());

  // Close the stream. The stream-end callback must receive the server's status.
  EXPECT_EQ(OkStatus(), context.SendResponse(Status::NotFound()));
  EXPECT_FALSE(call.active());
  EXPECT_EQ(stream_status_, Status::NotFound());

  // A stream message that arrives after the close must be dropped.
  PW_ENCODE_PB(TestStreamResponse, r3, .chunk = {}, .number = 33u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r3));
  EXPECT_FALSE(call.active());

  EXPECT_EQ(responses_received_, 2);
  EXPECT_EQ(last_response_number_, 22);
}
344
// A stream message that fails to decode must end the call and report DATA_LOSS
// to the error callback, without reaching the per-message callback.
TEST_F(ServerStreamingClientCall, ParseErrorTerminatesCallWithDataLoss) {
  ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> context;
  const TestRequest::Message request{.integer = 71, .status_code = 0};

  auto on_message = [this](const TestStreamResponse::Message& response) {
    ++responses_received_;
    last_response_number_ = response.number;
  };
  auto on_error = [this](Status error) { rpc_error_ = error; };

  // No stream-end callback; only message delivery and error handling are
  // under test here.
  auto call = FakeGeneratedServiceClient::TestServerStreamRpc(
      context.client(),
      context.channel().id(),
      request,
      std::move(on_message),
      nullptr,
      std::move(on_error));

  // One good message first, showing the call works before the bad one arrives.
  PW_ENCODE_PB(TestStreamResponse, good, .chunk = {}, .number = 11u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(good));
  EXPECT_TRUE(call.active());
  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(last_response_number_, 11);

  // Bytes that are not a valid TestStreamResponse encoding.
  constexpr std::byte kCorrupt[] = {
      std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}};
  EXPECT_EQ(OkStatus(), context.SendServerStream(kCorrupt));
  EXPECT_FALSE(call.active());
  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(rpc_error_, Status::DataLoss());
}
372
373 } // namespace
374 } // namespace pw::rpc
375
376 PW_MODIFY_DIAGNOSTICS_POP();
377