1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include <optional>
16
17 #include "pw_rpc/internal/test_utils.h"
18 #include "pw_rpc/nanopb/client_reader_writer.h"
19 #include "pw_rpc_nanopb_private/internal_test_utils.h"
20 #include "pw_rpc_test_protos/test.pb.h"
21 #include "pw_unit_test/framework.h"
22
23 PW_MODIFY_DIAGNOSTICS_PUSH();
24 PW_MODIFY_DIAGNOSTIC(ignored, "-Wmissing-field-initializers");
25
26 namespace pw::rpc {
27 namespace {
28
29 using internal::ClientContextForTest;
30
// Service and method identifiers that the fake client addresses its calls to.
constexpr uint32_t kServiceId{16};
constexpr uint32_t kUnaryMethodId{111};
constexpr uint32_t kServerStreamingMethodId{112};
34
35 class FakeGeneratedServiceClient {
36 public:
TestUnaryRpc(Client & client,uint32_t channel_id,const pw_rpc_test_TestRequest & request,Function<void (const pw_rpc_test_TestResponse &,Status)> on_response,Function<void (Status)> on_error=nullptr)37 static NanopbUnaryReceiver<pw_rpc_test_TestResponse> TestUnaryRpc(
38 Client& client,
39 uint32_t channel_id,
40 const pw_rpc_test_TestRequest& request,
41 Function<void(const pw_rpc_test_TestResponse&, Status)> on_response,
42 Function<void(Status)> on_error = nullptr) {
43 return pw::rpc::internal::
44 NanopbUnaryResponseClientCall<pw_rpc_test_TestResponse>::Start<
45 pw::rpc::NanopbUnaryReceiver<pw_rpc_test_TestResponse>>(
46 client,
47 channel_id,
48 kServiceId,
49 kUnaryMethodId,
50 internal::kNanopbMethodSerde<pw_rpc_test_TestRequest_fields,
51 pw_rpc_test_TestResponse_fields>,
52 std::move(on_response),
53 std::move(on_error),
54 request);
55 }
56
TestAnotherUnaryRpc(Client & client,uint32_t channel_id,const pw_rpc_test_TestRequest & request,Function<void (const pw_rpc_test_TestResponse &,Status)> on_response,Function<void (Status)> on_error=nullptr)57 static NanopbUnaryReceiver<pw_rpc_test_TestResponse> TestAnotherUnaryRpc(
58 Client& client,
59 uint32_t channel_id,
60 const pw_rpc_test_TestRequest& request,
61 Function<void(const pw_rpc_test_TestResponse&, Status)> on_response,
62 Function<void(Status)> on_error = nullptr) {
63 return pw::rpc::internal::
64 NanopbUnaryResponseClientCall<pw_rpc_test_TestResponse>::Start<
65 pw::rpc::NanopbUnaryReceiver<pw_rpc_test_TestResponse>>(
66 client,
67 channel_id,
68 kServiceId,
69 kUnaryMethodId,
70 internal::kNanopbMethodSerde<pw_rpc_test_TestRequest_fields,
71 pw_rpc_test_TestResponse_fields>,
72 std::move(on_response),
73 std::move(on_error),
74 request);
75 }
76
TestServerStreamRpc(Client & client,uint32_t channel_id,const pw_rpc_test_TestRequest & request,Function<void (const pw_rpc_test_TestStreamResponse &)> on_response,Function<void (Status)> on_stream_end,Function<void (Status)> on_error=nullptr)77 static NanopbClientReader<pw_rpc_test_TestStreamResponse> TestServerStreamRpc(
78 Client& client,
79 uint32_t channel_id,
80 const pw_rpc_test_TestRequest& request,
81 Function<void(const pw_rpc_test_TestStreamResponse&)> on_response,
82 Function<void(Status)> on_stream_end,
83 Function<void(Status)> on_error = nullptr) {
84 return pw::rpc::internal::
85 NanopbStreamResponseClientCall<pw_rpc_test_TestStreamResponse>::Start<
86 pw::rpc::NanopbClientReader<pw_rpc_test_TestStreamResponse>>(
87 client,
88 channel_id,
89 kServiceId,
90 kServerStreamingMethodId,
91 internal::kNanopbMethodSerde<pw_rpc_test_TestRequest_fields,
92 pw_rpc_test_TestStreamResponse_fields>,
93 std::move(on_response),
94 std::move(on_stream_end),
95 std::move(on_error),
96 request);
97 }
98 };
99
// Starting a unary call immediately sends one packet addressed to the test
// channel, kServiceId and kUnaryMethodId, whose payload is the encoded request.
TEST(NanopbClientCall, Unary_SendsRequestPacket) {
  ClientContextForTest context;

  auto call = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      {.integer = 123, .status_code = 0},
      nullptr);

  // Fatal check: every assertion below inspects last_packet(), which has no
  // meaningful value (and may not be safe to call) if nothing was sent.
  ASSERT_EQ(context.output().total_packets(), 1u);
  auto packet = context.output().last_packet();
  EXPECT_EQ(packet.channel_id(), context.channel().id());
  EXPECT_EQ(packet.service_id(), kServiceId);
  EXPECT_EQ(packet.method_id(), kUnaryMethodId);

  PW_DECODE_PB(pw_rpc_test_TestRequest, sent_proto, packet.payload());
  EXPECT_EQ(sent_proto.integer, 123);
}
118
119 class UnaryClientCall : public ::testing::Test {
120 protected:
121 std::optional<Status> last_status_;
122 std::optional<Status> last_error_;
123 int responses_received_ = 0;
124 int last_response_value_ = 0;
125 };
126
// A well-formed response is decoded and delivered, with its status, to the
// on_response callback exactly once.
TEST_F(UnaryClientCall, InvokesCallbackOnValidResponse) {
  ClientContextForTest ctx;

  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};
  auto rpc = FakeGeneratedServiceClient::TestUnaryRpc(
      ctx.client(),
      ctx.channel().id(),
      request,
      [this](const pw_rpc_test_TestResponse& resp, Status result) {
        responses_received_ += 1;
        last_status_ = result;
        last_response_value_ = resp.value;
      });

  PW_ENCODE_PB(pw_rpc_test_TestResponse, response, .value = 42);
  EXPECT_EQ(OkStatus(), ctx.SendResponse(OkStatus(), response));

  ASSERT_EQ(responses_received_, 1);
  EXPECT_EQ(last_status_, OkStatus());
  EXPECT_EQ(last_response_value_, 42);
}
147
// With no on_response callback registered, sending a response still succeeds
// and nothing is recorded.
TEST_F(UnaryClientCall, DoesNothingOnNullCallback) {
  ClientContextForTest ctx;

  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};
  auto rpc = FakeGeneratedServiceClient::TestUnaryRpc(
      ctx.client(), ctx.channel().id(), request, nullptr);

  PW_ENCODE_PB(pw_rpc_test_TestResponse, response, .value = 42);
  EXPECT_EQ(OkStatus(), ctx.SendResponse(OkStatus(), response));

  ASSERT_EQ(responses_received_, 0);
}
162
// A response payload that cannot be decoded is reported to on_error as
// DATA_LOSS, and on_response is never called.
TEST_F(UnaryClientCall, InvokesErrorCallbackOnInvalidResponse) {
  ClientContextForTest ctx;

  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};
  auto rpc = FakeGeneratedServiceClient::TestUnaryRpc(
      ctx.client(),
      ctx.channel().id(),
      request,
      [this](const pw_rpc_test_TestResponse& resp, Status result) {
        responses_received_ += 1;
        last_status_ = result;
        last_response_value_ = resp.value;
      },
      [this](Status error) { last_error_ = error; });

  // Bytes this test expects to fail decoding as a TestResponse.
  constexpr std::byte bad_payload[]{
      std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}};
  EXPECT_EQ(OkStatus(), ctx.SendResponse(OkStatus(), bad_payload));

  EXPECT_EQ(responses_received_, 0);
  ASSERT_TRUE(last_error_.has_value());
  EXPECT_EQ(last_error_, Status::DataLoss());
}
185
// A SERVER_ERROR packet delivers its status to on_error rather than to
// on_response.
TEST_F(UnaryClientCall, InvokesErrorCallbackOnServerError) {
  ClientContextForTest ctx;

  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};
  auto rpc = FakeGeneratedServiceClient::TestUnaryRpc(
      ctx.client(),
      ctx.channel().id(),
      request,
      [this](const pw_rpc_test_TestResponse& resp, Status result) {
        responses_received_ += 1;
        last_status_ = result;
        last_response_value_ = resp.value;
      },
      [this](Status error) { last_error_ = error; });

  EXPECT_EQ(OkStatus(),
            ctx.SendPacket(internal::pwpb::PacketType::SERVER_ERROR,
                           Status::NotFound()));

  EXPECT_EQ(responses_received_, 0);
  EXPECT_EQ(last_error_, Status::NotFound());
}
207
// With no on_error callback registered, an undecodable response is still
// accepted by the context and on_response is not called.
TEST_F(UnaryClientCall, DoesNothingOnErrorWithoutCallback) {
  ClientContextForTest ctx;

  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};
  auto rpc = FakeGeneratedServiceClient::TestUnaryRpc(
      ctx.client(),
      ctx.channel().id(),
      request,
      [this](const pw_rpc_test_TestResponse& resp, Status result) {
        responses_received_ += 1;
        last_status_ = result;
        last_response_value_ = resp.value;
      });

  // Bytes this test expects to fail decoding as a TestResponse.
  constexpr std::byte bad_payload[]{
      std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}};
  EXPECT_EQ(OkStatus(), ctx.SendResponse(OkStatus(), bad_payload));

  EXPECT_EQ(responses_received_, 0);
}
227
// Only the first of several responses reaches on_response; its status and
// value are the ones recorded.
TEST_F(UnaryClientCall, OnlyReceivesOneResponse) {
  ClientContextForTest ctx;

  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};
  auto rpc = FakeGeneratedServiceClient::TestUnaryRpc(
      ctx.client(),
      ctx.channel().id(),
      request,
      [this](const pw_rpc_test_TestResponse& resp, Status result) {
        responses_received_ += 1;
        last_status_ = result;
        last_response_value_ = resp.value;
      });

  PW_ENCODE_PB(pw_rpc_test_TestResponse, r1, .value = 42);
  EXPECT_EQ(OkStatus(), ctx.SendResponse(Status::Unimplemented(), r1));
  PW_ENCODE_PB(pw_rpc_test_TestResponse, r2, .value = 44);
  EXPECT_EQ(OkStatus(), ctx.SendResponse(Status::OutOfRange(), r2));
  PW_ENCODE_PB(pw_rpc_test_TestResponse, r3, .value = 46);
  EXPECT_EQ(OkStatus(), ctx.SendResponse(Status::Internal(), r3));

  // Everything recorded must come from r1.
  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(last_status_, Status::Unimplemented());
  EXPECT_EQ(last_response_value_, 42);
}
252
253 class ServerStreamingClientCall : public ::testing::Test {
254 protected:
255 std::optional<Status> stream_status_;
256 std::optional<Status> rpc_error_;
257 int responses_received_ = 0;
258 int last_response_number_ = 0;
259 };
260
// Starting a server-streaming call immediately sends one packet addressed to
// the test channel, kServiceId and kServerStreamingMethodId, whose payload is
// the encoded request.
TEST_F(ServerStreamingClientCall, SendsRequestPacket) {
  ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> context;

  auto call = FakeGeneratedServiceClient::TestServerStreamRpc(
      context.client(),
      context.channel().id(),
      {.integer = 71, .status_code = 0},
      nullptr,
      nullptr);

  // Fatal check: every assertion below inspects last_packet(), which has no
  // meaningful value (and may not be safe to call) if nothing was sent.
  ASSERT_EQ(context.output().total_packets(), 1u);
  auto packet = context.output().last_packet();
  EXPECT_EQ(packet.channel_id(), context.channel().id());
  EXPECT_EQ(packet.service_id(), kServiceId);
  EXPECT_EQ(packet.method_id(), kServerStreamingMethodId);

  PW_DECODE_PB(pw_rpc_test_TestRequest, sent_proto, packet.payload());
  EXPECT_EQ(sent_proto.integer, 71);
}
280
// Each well-formed stream message is decoded and passed to on_response in
// order, and the call stays active throughout.
TEST_F(ServerStreamingClientCall, InvokesCallbackOnValidResponse) {
  ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> ctx;

  const pw_rpc_test_TestRequest request{.integer = 71, .status_code = 0};
  auto reader = FakeGeneratedServiceClient::TestServerStreamRpc(
      ctx.client(),
      ctx.channel().id(),
      request,
      [this](const pw_rpc_test_TestStreamResponse& message) {
        responses_received_ += 1;
        last_response_number_ = message.number;
      },
      [this](Status result) { stream_status_ = result; });

  PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r1, .chunk = {}, .number = 11u);
  EXPECT_EQ(OkStatus(), ctx.SendServerStream(r1));
  EXPECT_TRUE(reader.active());
  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(last_response_number_, 11);

  PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r2, .chunk = {}, .number = 22u);
  EXPECT_EQ(OkStatus(), ctx.SendServerStream(r2));
  EXPECT_TRUE(reader.active());
  EXPECT_EQ(responses_received_, 2);
  EXPECT_EQ(last_response_number_, 22);

  PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r3, .chunk = {}, .number = 33u);
  EXPECT_EQ(OkStatus(), ctx.SendServerStream(r3));
  EXPECT_TRUE(reader.active());
  EXPECT_EQ(responses_received_, 3);
  EXPECT_EQ(last_response_number_, 33);
}
312
// Closing the stream with a response packet ends the call and invokes
// on_stream_end with that packet's status. Stream messages that arrive after
// completion are not delivered to on_response.
TEST_F(ServerStreamingClientCall, InvokesStreamEndOnFinish) {
  ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> context;

  auto call = FakeGeneratedServiceClient::TestServerStreamRpc(
      context.client(),
      context.channel().id(),
      {.integer = 71, .status_code = 0},
      [this](const pw_rpc_test_TestStreamResponse& response) {
        ++responses_received_;
        last_response_number_ = response.number;
      },
      [this](Status status) { stream_status_ = status; });

  PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r1, .chunk = {}, .number = 11u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r1));
  EXPECT_TRUE(call.active());

  PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r2, .chunk = {}, .number = 22u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r2));
  EXPECT_TRUE(call.active());

  // Stream messages alone must not signal the end of the stream.
  EXPECT_FALSE(stream_status_.has_value());

  // Close the stream. The call becomes inactive immediately, and the end
  // callback sees the status from the closing packet. Previously the test
  // never checked that on_stream_end actually ran.
  EXPECT_EQ(OkStatus(), context.SendResponse(Status::NotFound()));
  EXPECT_FALSE(call.active());
  ASSERT_TRUE(stream_status_.has_value());
  EXPECT_EQ(stream_status_, Status::NotFound());

  // A message arriving after completion is dropped.
  PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r3, .chunk = {}, .number = 33u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r3));
  EXPECT_FALSE(call.active());

  EXPECT_EQ(responses_received_, 2);
}
343
// A stream message that cannot be decoded terminates the call and reports
// DATA_LOSS through on_error. Messages received before it were already
// delivered.
TEST_F(ServerStreamingClientCall, ParseErrorTerminatesCallWithDataLoss) {
  ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> ctx;

  const pw_rpc_test_TestRequest request{.integer = 71, .status_code = 0};
  auto reader = FakeGeneratedServiceClient::TestServerStreamRpc(
      ctx.client(),
      ctx.channel().id(),
      request,
      [this](const pw_rpc_test_TestStreamResponse& message) {
        responses_received_ += 1;
        last_response_number_ = message.number;
      },
      /*on_stream_end=*/nullptr,
      [this](Status failure) { rpc_error_ = failure; });

  PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r1, .chunk = {}, .number = 11u);
  EXPECT_EQ(OkStatus(), ctx.SendServerStream(r1));
  EXPECT_TRUE(reader.active());
  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(last_response_number_, 11);

  // Bytes this test expects to fail decoding as a TestStreamResponse.
  constexpr std::byte bad_payload[]{
      std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}};
  EXPECT_EQ(OkStatus(), ctx.SendServerStream(bad_payload));
  EXPECT_FALSE(reader.active());
  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(rpc_error_, Status::DataLoss());
}
371
372 } // namespace
373 } // namespace pw::rpc
374
375 PW_MODIFY_DIAGNOSTICS_POP();
376