xref: /aosp_15_r20/external/pigweed/pw_rpc/nanopb/client_call_test.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include <optional>
16 
17 #include "pw_rpc/internal/test_utils.h"
18 #include "pw_rpc/nanopb/client_reader_writer.h"
19 #include "pw_rpc_nanopb_private/internal_test_utils.h"
20 #include "pw_rpc_test_protos/test.pb.h"
21 #include "pw_unit_test/framework.h"
22 
23 PW_MODIFY_DIAGNOSTICS_PUSH();
24 PW_MODIFY_DIAGNOSTIC(ignored, "-Wmissing-field-initializers");
25 
26 namespace pw::rpc {
27 namespace {
28 
29 using internal::ClientContextForTest;
30 
31 constexpr uint32_t kServiceId = 16;
32 constexpr uint32_t kUnaryMethodId = 111;
33 constexpr uint32_t kServerStreamingMethodId = 112;
34 
35 class FakeGeneratedServiceClient {
36  public:
TestUnaryRpc(Client & client,uint32_t channel_id,const pw_rpc_test_TestRequest & request,Function<void (const pw_rpc_test_TestResponse &,Status)> on_response,Function<void (Status)> on_error=nullptr)37   static NanopbUnaryReceiver<pw_rpc_test_TestResponse> TestUnaryRpc(
38       Client& client,
39       uint32_t channel_id,
40       const pw_rpc_test_TestRequest& request,
41       Function<void(const pw_rpc_test_TestResponse&, Status)> on_response,
42       Function<void(Status)> on_error = nullptr) {
43     return pw::rpc::internal::
44         NanopbUnaryResponseClientCall<pw_rpc_test_TestResponse>::Start<
45             pw::rpc::NanopbUnaryReceiver<pw_rpc_test_TestResponse>>(
46             client,
47             channel_id,
48             kServiceId,
49             kUnaryMethodId,
50             internal::kNanopbMethodSerde<pw_rpc_test_TestRequest_fields,
51                                          pw_rpc_test_TestResponse_fields>,
52             std::move(on_response),
53             std::move(on_error),
54             request);
55   }
56 
TestAnotherUnaryRpc(Client & client,uint32_t channel_id,const pw_rpc_test_TestRequest & request,Function<void (const pw_rpc_test_TestResponse &,Status)> on_response,Function<void (Status)> on_error=nullptr)57   static NanopbUnaryReceiver<pw_rpc_test_TestResponse> TestAnotherUnaryRpc(
58       Client& client,
59       uint32_t channel_id,
60       const pw_rpc_test_TestRequest& request,
61       Function<void(const pw_rpc_test_TestResponse&, Status)> on_response,
62       Function<void(Status)> on_error = nullptr) {
63     return pw::rpc::internal::
64         NanopbUnaryResponseClientCall<pw_rpc_test_TestResponse>::Start<
65             pw::rpc::NanopbUnaryReceiver<pw_rpc_test_TestResponse>>(
66             client,
67             channel_id,
68             kServiceId,
69             kUnaryMethodId,
70             internal::kNanopbMethodSerde<pw_rpc_test_TestRequest_fields,
71                                          pw_rpc_test_TestResponse_fields>,
72             std::move(on_response),
73             std::move(on_error),
74             request);
75   }
76 
TestServerStreamRpc(Client & client,uint32_t channel_id,const pw_rpc_test_TestRequest & request,Function<void (const pw_rpc_test_TestStreamResponse &)> on_response,Function<void (Status)> on_stream_end,Function<void (Status)> on_error=nullptr)77   static NanopbClientReader<pw_rpc_test_TestStreamResponse> TestServerStreamRpc(
78       Client& client,
79       uint32_t channel_id,
80       const pw_rpc_test_TestRequest& request,
81       Function<void(const pw_rpc_test_TestStreamResponse&)> on_response,
82       Function<void(Status)> on_stream_end,
83       Function<void(Status)> on_error = nullptr) {
84     return pw::rpc::internal::
85         NanopbStreamResponseClientCall<pw_rpc_test_TestStreamResponse>::Start<
86             pw::rpc::NanopbClientReader<pw_rpc_test_TestStreamResponse>>(
87             client,
88             channel_id,
89             kServiceId,
90             kServerStreamingMethodId,
91             internal::kNanopbMethodSerde<pw_rpc_test_TestRequest_fields,
92                                          pw_rpc_test_TestStreamResponse_fields>,
93             std::move(on_response),
94             std::move(on_stream_end),
95             std::move(on_error),
96             request);
97   }
98 };
99 
// Starting a unary call must emit exactly one packet, addressed to the test
// channel/service/method, whose payload decodes back to the request.
TEST(NanopbClientCall, Unary_SendsRequestPacket) {
  ClientContextForTest context;
  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};

  // The call object is kept in scope for the rest of the test.
  auto call = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(), context.channel().id(), request, nullptr);

  EXPECT_EQ(context.output().total_packets(), 1u);

  auto sent_packet = context.output().last_packet();
  EXPECT_EQ(sent_packet.channel_id(), context.channel().id());
  EXPECT_EQ(sent_packet.service_id(), kServiceId);
  EXPECT_EQ(sent_packet.method_id(), kUnaryMethodId);

  PW_DECODE_PB(pw_rpc_test_TestRequest, decoded_request, sent_packet.payload());
  EXPECT_EQ(decoded_request.integer, 123);
}
118 
119 class UnaryClientCall : public ::testing::Test {
120  protected:
121   std::optional<Status> last_status_;
122   std::optional<Status> last_error_;
123   int responses_received_ = 0;
124   int last_response_value_ = 0;
125 };
126 
// A well-formed response must reach on_response exactly once, with the
// server's status and the decoded message.
TEST_F(UnaryClientCall, InvokesCallbackOnValidResponse) {
  ClientContextForTest context;
  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};

  auto call = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      request,
      [this](const pw_rpc_test_TestResponse& message, Status rpc_status) {
        last_response_value_ = message.value;
        last_status_ = rpc_status;
        ++responses_received_;
      });

  PW_ENCODE_PB(pw_rpc_test_TestResponse, encoded_response, .value = 42);
  EXPECT_EQ(OkStatus(), context.SendResponse(OkStatus(), encoded_response));

  ASSERT_EQ(responses_received_, 1);
  EXPECT_EQ(last_status_, OkStatus());
  EXPECT_EQ(last_response_value_, 42);
}
147 
// With no on_response callback installed, a valid response must still be
// accepted by the client and nothing is recorded.
TEST_F(UnaryClientCall, DoesNothingOnNullCallback) {
  ClientContextForTest context;
  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};

  auto call = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(), context.channel().id(), request, nullptr);

  PW_ENCODE_PB(pw_rpc_test_TestResponse, encoded_response, .value = 42);
  EXPECT_EQ(OkStatus(), context.SendResponse(OkStatus(), encoded_response));

  ASSERT_EQ(responses_received_, 0);
}
162 
// A response payload that fails to decode must skip on_response and report
// DATA_LOSS through on_error instead.
TEST_F(UnaryClientCall, InvokesErrorCallbackOnInvalidResponse) {
  ClientContextForTest context;
  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};

  auto call = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      request,
      [this](const pw_rpc_test_TestResponse& message, Status rpc_status) {
        last_response_value_ = message.value;
        last_status_ = rpc_status;
        ++responses_received_;
      },
      [this](Status error) { last_error_ = error; });

  // Bytes sent in place of an encoded TestResponse.
  constexpr std::byte kMalformedPayload[]{
      std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}};
  EXPECT_EQ(OkStatus(), context.SendResponse(OkStatus(), kMalformedPayload));

  EXPECT_EQ(responses_received_, 0);
  ASSERT_TRUE(last_error_.has_value());
  EXPECT_EQ(last_error_, Status::DataLoss());
}
185 
// A SERVER_ERROR packet must be routed to on_error with the packet's status
// and must not trigger on_response.
TEST_F(UnaryClientCall, InvokesErrorCallbackOnServerError) {
  ClientContextForTest context;
  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};

  auto call = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      request,
      [this](const pw_rpc_test_TestResponse& message, Status rpc_status) {
        last_response_value_ = message.value;
        last_status_ = rpc_status;
        ++responses_received_;
      },
      [this](Status error) { last_error_ = error; });

  const auto send_result = context.SendPacket(
      internal::pwpb::PacketType::SERVER_ERROR, Status::NotFound());
  EXPECT_EQ(OkStatus(), send_result);

  EXPECT_EQ(responses_received_, 0);
  EXPECT_EQ(last_error_, Status::NotFound());
}
207 
// When no on_error callback is supplied, an undecodable response must still be
// accepted by the client and must not reach on_response.
TEST_F(UnaryClientCall, DoesNothingOnErrorWithoutCallback) {
  ClientContextForTest context;
  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};

  auto call = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      request,
      [this](const pw_rpc_test_TestResponse& message, Status rpc_status) {
        last_response_value_ = message.value;
        last_status_ = rpc_status;
        ++responses_received_;
      });

  // Bytes sent in place of an encoded TestResponse.
  constexpr std::byte kMalformedPayload[]{
      std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}};
  EXPECT_EQ(OkStatus(), context.SendResponse(OkStatus(), kMalformedPayload));

  EXPECT_EQ(responses_received_, 0);
}
227 
// Only the first response to a unary call may reach on_response; responses
// sent afterwards must leave the recorded state untouched.
TEST_F(UnaryClientCall, OnlyReceivesOneResponse) {
  ClientContextForTest context;
  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};

  auto call = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      request,
      [this](const pw_rpc_test_TestResponse& message, Status rpc_status) {
        last_response_value_ = message.value;
        last_status_ = rpc_status;
        ++responses_received_;
      });

  // Encode all three payloads up front; each is sent with a distinct status.
  PW_ENCODE_PB(pw_rpc_test_TestResponse, first, .value = 42);
  PW_ENCODE_PB(pw_rpc_test_TestResponse, second, .value = 44);
  PW_ENCODE_PB(pw_rpc_test_TestResponse, third, .value = 46);

  EXPECT_EQ(OkStatus(), context.SendResponse(Status::Unimplemented(), first));
  EXPECT_EQ(OkStatus(), context.SendResponse(Status::OutOfRange(), second));
  EXPECT_EQ(OkStatus(), context.SendResponse(Status::Internal(), third));

  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(last_status_, Status::Unimplemented());
  EXPECT_EQ(last_response_value_, 42);
}
252 
253 class ServerStreamingClientCall : public ::testing::Test {
254  protected:
255   std::optional<Status> stream_status_;
256   std::optional<Status> rpc_error_;
257   int responses_received_ = 0;
258   int last_response_number_ = 0;
259 };
260 
// Starting a server-streaming call must emit exactly one packet, addressed to
// the streaming method, whose payload decodes back to the request.
TEST_F(ServerStreamingClientCall, SendsRequestPacket) {
  using StreamingContext =
      ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId>;
  StreamingContext context;
  const pw_rpc_test_TestRequest request{.integer = 71, .status_code = 0};

  auto call = FakeGeneratedServiceClient::TestServerStreamRpc(
      context.client(), context.channel().id(), request, nullptr, nullptr);

  EXPECT_EQ(context.output().total_packets(), 1u);

  auto sent_packet = context.output().last_packet();
  EXPECT_EQ(sent_packet.channel_id(), context.channel().id());
  EXPECT_EQ(sent_packet.service_id(), kServiceId);
  EXPECT_EQ(sent_packet.method_id(), kServerStreamingMethodId);

  PW_DECODE_PB(pw_rpc_test_TestRequest, decoded_request, sent_packet.payload());
  EXPECT_EQ(decoded_request.integer, 71);
}
280 
// Every valid stream message must be delivered to on_response in order, and
// the call must stay active while the stream is open.
TEST_F(ServerStreamingClientCall, InvokesCallbackOnValidResponse) {
  using StreamingContext =
      ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId>;
  StreamingContext context;
  const pw_rpc_test_TestRequest request{.integer = 71, .status_code = 0};

  auto call = FakeGeneratedServiceClient::TestServerStreamRpc(
      context.client(),
      context.channel().id(),
      request,
      [this](const pw_rpc_test_TestStreamResponse& message) {
        last_response_number_ = message.number;
        ++responses_received_;
      },
      [this](Status final_status) { stream_status_ = final_status; });

  // Encode the three stream messages up front, then send them one at a time.
  PW_ENCODE_PB(
      pw_rpc_test_TestStreamResponse, first, .chunk = {}, .number = 11u);
  PW_ENCODE_PB(
      pw_rpc_test_TestStreamResponse, second, .chunk = {}, .number = 22u);
  PW_ENCODE_PB(
      pw_rpc_test_TestStreamResponse, third, .chunk = {}, .number = 33u);

  EXPECT_EQ(OkStatus(), context.SendServerStream(first));
  EXPECT_TRUE(call.active());
  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(last_response_number_, 11);

  EXPECT_EQ(OkStatus(), context.SendServerStream(second));
  EXPECT_TRUE(call.active());
  EXPECT_EQ(responses_received_, 2);
  EXPECT_EQ(last_response_number_, 22);

  EXPECT_EQ(OkStatus(), context.SendServerStream(third));
  EXPECT_TRUE(call.active());
  EXPECT_EQ(responses_received_, 3);
  EXPECT_EQ(last_response_number_, 33);
}
312 
// The server's final response must end the stream: on_stream_end receives the
// final status, the call becomes inactive, and any stream packet that arrives
// afterwards is not delivered to on_response.
TEST_F(ServerStreamingClientCall, InvokesStreamEndOnFinish) {
  ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> context;

  auto call = FakeGeneratedServiceClient::TestServerStreamRpc(
      context.client(),
      context.channel().id(),
      {.integer = 71, .status_code = 0},
      [this](const pw_rpc_test_TestStreamResponse& response) {
        ++responses_received_;
        last_response_number_ = response.number;
      },
      [this](Status status) { stream_status_ = status; });

  PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r1, .chunk = {}, .number = 11u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r1));
  EXPECT_TRUE(call.active());

  PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r2, .chunk = {}, .number = 22u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r2));
  EXPECT_TRUE(call.active());

  // The stream is still open, so on_stream_end must not have run yet.
  EXPECT_FALSE(stream_status_.has_value());

  // Close the stream.
  EXPECT_EQ(OkStatus(), context.SendResponse(Status::NotFound()));

  // This is what the test name promises: the stream-end callback ran with the
  // status carried by the final response, and the call is no longer active.
  EXPECT_FALSE(call.active());
  EXPECT_EQ(stream_status_, Status::NotFound());

  // A stream packet that arrives after the call finished must be ignored.
  PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r3, .chunk = {}, .number = 33u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r3));
  EXPECT_FALSE(call.active());

  EXPECT_EQ(responses_received_, 2);
  EXPECT_EQ(last_response_number_, 22);
}
343 
// A stream payload that fails to decode must terminate the call: the call
// becomes inactive, on_response is skipped, and on_error receives DATA_LOSS.
TEST_F(ServerStreamingClientCall, ParseErrorTerminatesCallWithDataLoss) {
  using StreamingContext =
      ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId>;
  StreamingContext context;
  const pw_rpc_test_TestRequest request{.integer = 71, .status_code = 0};

  auto call = FakeGeneratedServiceClient::TestServerStreamRpc(
      context.client(),
      context.channel().id(),
      request,
      [this](const pw_rpc_test_TestStreamResponse& message) {
        last_response_number_ = message.number;
        ++responses_received_;
      },
      nullptr,
      [this](Status failure) { rpc_error_ = failure; });

  // First, a valid message is delivered normally.
  PW_ENCODE_PB(
      pw_rpc_test_TestStreamResponse, valid_message, .chunk = {}, .number = 11u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(valid_message));
  EXPECT_TRUE(call.active());
  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(last_response_number_, 11);

  // Then bytes are sent in place of an encoded TestStreamResponse.
  constexpr std::byte kMalformedPayload[]{
      std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}};
  EXPECT_EQ(OkStatus(), context.SendServerStream(kMalformedPayload));
  EXPECT_FALSE(call.active());
  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(rpc_error_, Status::DataLoss());
}
371 
372 }  // namespace
373 }  // namespace pw::rpc
374 
375 PW_MODIFY_DIAGNOSTICS_POP();
376