xref: /aosp_15_r20/external/pigweed/pw_rpc/pwpb/client_call_test.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2022 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include <optional>
16 
17 #include "pw_rpc/internal/test_utils.h"
18 #include "pw_rpc/pwpb/client_reader_writer.h"
19 #include "pw_rpc_pwpb_private/internal_test_utils.h"
20 #include "pw_rpc_test_protos/test.pwpb.h"
21 #include "pw_unit_test/framework.h"
22 
// Test messages in this file are built with designated initializers
// (e.g. `.value = 42`) that may not name every field, so
// -Wmissing-field-initializers is silenced until the matching
// PW_MODIFY_DIAGNOSTICS_POP() at the end of the file.
23 PW_MODIFY_DIAGNOSTICS_PUSH();
24 PW_MODIFY_DIAGNOSTIC(ignored, "-Wmissing-field-initializers");
25 
26 namespace pw::rpc {
27 namespace {
28 
29 using internal::ClientContextForTest;
30 using internal::pwpb::PacketType;
31 
32 namespace TestRequest = ::pw::rpc::test::pwpb::TestRequest;
33 namespace TestResponse = ::pw::rpc::test::pwpb::TestResponse;
34 namespace TestStreamResponse = ::pw::rpc::test::pwpb::TestStreamResponse;
35 
36 constexpr uint32_t kServiceId = 16;
37 constexpr uint32_t kUnaryMethodId = 111;
38 constexpr uint32_t kServerStreamingMethodId = 112;
39 
40 class FakeGeneratedServiceClient {
41  public:
TestUnaryRpc(Client & client,uint32_t channel_id,const TestRequest::Message & request,Function<void (const TestResponse::Message &,Status)> on_response,Function<void (Status)> on_error=nullptr)42   static PwpbUnaryReceiver<TestResponse::Message> TestUnaryRpc(
43       Client& client,
44       uint32_t channel_id,
45       const TestRequest::Message& request,
46       Function<void(const TestResponse::Message&, Status)> on_response,
47       Function<void(Status)> on_error = nullptr) {
48     return internal::PwpbUnaryResponseClientCall<TestResponse::Message>::Start<
49         PwpbUnaryReceiver<TestResponse::Message>>(
50         client,
51         channel_id,
52         kServiceId,
53         kUnaryMethodId,
54         internal::kPwpbMethodSerde<&TestRequest::kMessageFields,
55                                    &TestResponse::kMessageFields>,
56         std::move(on_response),
57         std::move(on_error),
58         request);
59   }
60 
TestAnotherUnaryRpc(Client & client,uint32_t channel_id,const TestRequest::Message & request,Function<void (const TestResponse::Message &,Status)> on_response,Function<void (Status)> on_error=nullptr)61   static PwpbUnaryReceiver<TestResponse::Message> TestAnotherUnaryRpc(
62       Client& client,
63       uint32_t channel_id,
64       const TestRequest::Message& request,
65       Function<void(const TestResponse::Message&, Status)> on_response,
66       Function<void(Status)> on_error = nullptr) {
67     return internal::PwpbUnaryResponseClientCall<TestResponse::Message>::Start<
68         PwpbUnaryReceiver<TestResponse::Message>>(
69         client,
70         channel_id,
71         kServiceId,
72         kUnaryMethodId,
73         internal::kPwpbMethodSerde<&TestRequest::kMessageFields,
74                                    &TestResponse::kMessageFields>,
75         std::move(on_response),
76         std::move(on_error),
77         request);
78   }
79 
TestServerStreamRpc(Client & client,uint32_t channel_id,const TestRequest::Message & request,Function<void (const TestStreamResponse::Message &)> on_response,Function<void (Status)> on_stream_end,Function<void (Status)> on_error=nullptr)80   static PwpbClientReader<TestStreamResponse::Message> TestServerStreamRpc(
81       Client& client,
82       uint32_t channel_id,
83       const TestRequest::Message& request,
84       Function<void(const TestStreamResponse::Message&)> on_response,
85       Function<void(Status)> on_stream_end,
86       Function<void(Status)> on_error = nullptr) {
87     return internal::PwpbStreamResponseClientCall<TestStreamResponse::Message>::
88         Start<PwpbClientReader<TestStreamResponse::Message>>(
89             client,
90             channel_id,
91             kServiceId,
92             kServerStreamingMethodId,
93             internal::kPwpbMethodSerde<&TestRequest::kMessageFields,
94                                        &TestStreamResponse::kMessageFields>,
95             std::move(on_response),
96             std::move(on_stream_end),
97             std::move(on_error),
98             request);
99   }
100 };
101 
// Starting a unary call must send exactly one request packet, addressed to
// the fake service/method on the test channel and carrying the encoded
// request.
TEST(PwpbClientCall, Unary_SendsRequestPacket) {
  ClientContextForTest context;

  auto call = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      {.integer = 123, .status_code = 0},
      nullptr);

  // Fatal check: last_packet() is only meaningful if a packet was sent, so
  // stop here rather than read from an empty output.
  ASSERT_EQ(context.output().total_packets(), 1u);
  auto packet = context.output().last_packet();
  EXPECT_EQ(packet.channel_id(), context.channel().id());
  EXPECT_EQ(packet.service_id(), kServiceId);
  EXPECT_EQ(packet.method_id(), kUnaryMethodId);

  PW_DECODE_PB(TestRequest, sent_proto, packet.payload());
  EXPECT_EQ(sent_proto.integer, 123);
}
120 
121 class UnaryClientCall : public ::testing::Test {
122  protected:
123   std::optional<Status> last_status_;
124   std::optional<Status> last_error_;
125   int responses_received_ = 0;
126   int last_response_value_ = 0;
127 };
128 
// A well-formed response reaches the response callback exactly once, along
// with the status the server sent.
TEST_F(UnaryClientCall, InvokesCallbackOnValidResponse) {
  ClientContextForTest ctx;

  auto receiver = FakeGeneratedServiceClient::TestUnaryRpc(
      ctx.client(),
      ctx.channel().id(),
      {.integer = 123, .status_code = 0},
      [this](const TestResponse::Message& msg, Status result) {
        last_response_value_ = msg.value;
        last_status_ = result;
        ++responses_received_;
      });

  PW_ENCODE_PB(TestResponse, reply, .value = 42);
  EXPECT_EQ(OkStatus(), ctx.SendResponse(OkStatus(), reply));

  ASSERT_EQ(responses_received_, 1);
  EXPECT_EQ(last_status_, OkStatus());
  EXPECT_EQ(last_response_value_, 42);
}
149 
// With no response callback installed, a valid response must be accepted
// without invoking anything, and it must still complete the call.
TEST_F(UnaryClientCall, DoesNothingOnNullCallback) {
  ClientContextForTest context;

  auto call = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      {.integer = 123, .status_code = 0},
      nullptr);

  PW_ENCODE_PB(TestResponse, response, .value = 42);
  EXPECT_EQ(OkStatus(), context.SendResponse(OkStatus(), response));

  // No callback exists, so the counter can never move; the meaningful check
  // is that the response was consumed and finished the unary call.
  EXPECT_FALSE(call.active());
  EXPECT_EQ(responses_received_, 0);
}
164 
// A response payload that cannot be decoded is reported to the error
// callback as DATA_LOSS, and the response callback is never invoked.
TEST_F(UnaryClientCall, InvokesErrorCallbackOnInvalidResponse) {
  ClientContextForTest ctx;

  auto receiver = FakeGeneratedServiceClient::TestUnaryRpc(
      ctx.client(),
      ctx.channel().id(),
      {.integer = 123, .status_code = 0},
      [this](const TestResponse::Message& msg, Status result) {
        last_response_value_ = msg.value;
        last_status_ = result;
        ++responses_received_;
      },
      [this](Status error) { last_error_ = error; });

  // Arbitrary bytes that this test expects to fail decoding.
  constexpr std::byte kGarbage[]{
      std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}};
  EXPECT_EQ(OkStatus(), ctx.SendResponse(OkStatus(), kGarbage));

  EXPECT_EQ(responses_received_, 0);
  ASSERT_TRUE(last_error_.has_value());
  EXPECT_EQ(last_error_, Status::DataLoss());
}
187 
// A SERVER_ERROR packet delivers its status to the error callback and does
// not count as a response.
TEST_F(UnaryClientCall, InvokesErrorCallbackOnServerError) {
  ClientContextForTest ctx;

  auto receiver = FakeGeneratedServiceClient::TestUnaryRpc(
      ctx.client(),
      ctx.channel().id(),
      {.integer = 123, .status_code = 0},
      [this](const TestResponse::Message& msg, Status result) {
        last_response_value_ = msg.value;
        last_status_ = result;
        ++responses_received_;
      },
      [this](Status error) { last_error_ = error; });

  const Status server_error = ctx.SendPacket(PacketType::SERVER_ERROR,
                                             Status::NotFound());
  EXPECT_EQ(OkStatus(), server_error);

  EXPECT_EQ(responses_received_, 0);
  EXPECT_EQ(last_error_, Status::NotFound());
}
208 
// An undecodable response with no error callback installed is tolerated:
// nothing crashes, and the response callback is not invoked.
TEST_F(UnaryClientCall, DoesNothingOnErrorWithoutCallback) {
  ClientContextForTest ctx;

  auto receiver = FakeGeneratedServiceClient::TestUnaryRpc(
      ctx.client(),
      ctx.channel().id(),
      {.integer = 123, .status_code = 0},
      [this](const TestResponse::Message& msg, Status result) {
        last_response_value_ = msg.value;
        last_status_ = result;
        ++responses_received_;
      });

  // Arbitrary bytes that this test expects to fail decoding.
  constexpr std::byte kGarbage[]{
      std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}};
  EXPECT_EQ(OkStatus(), ctx.SendResponse(OkStatus(), kGarbage));

  EXPECT_EQ(responses_received_, 0);
}
228 
// A unary call completes on its first response; later responses are still
// accepted by the client but never reach the callback.
TEST_F(UnaryClientCall, OnlyReceivesOneResponse) {
  ClientContextForTest ctx;

  auto receiver = FakeGeneratedServiceClient::TestUnaryRpc(
      ctx.client(),
      ctx.channel().id(),
      {.integer = 123, .status_code = 0},
      [this](const TestResponse::Message& msg, Status result) {
        last_response_value_ = msg.value;
        last_status_ = result;
        ++responses_received_;
      });

  // Only this first response should be observed.
  PW_ENCODE_PB(TestResponse, first, .value = 42);
  EXPECT_EQ(OkStatus(), ctx.SendResponse(Status::Unimplemented(), first));

  // These arrive after the call has completed.
  PW_ENCODE_PB(TestResponse, second, .value = 44);
  EXPECT_EQ(OkStatus(), ctx.SendResponse(Status::OutOfRange(), second));
  PW_ENCODE_PB(TestResponse, third, .value = 46);
  EXPECT_EQ(OkStatus(), ctx.SendResponse(Status::Internal(), third));

  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(last_status_, Status::Unimplemented());
  EXPECT_EQ(last_response_value_, 42);
}
253 
254 class ServerStreamingClientCall : public ::testing::Test {
255  protected:
256   std::optional<Status> stream_status_;
257   std::optional<Status> rpc_error_;
258   int responses_received_ = 0;
259   int last_response_number_ = 0;
260 };
261 
// Starting a server-streaming call must send exactly one request packet,
// addressed to the streaming method on the test channel and carrying the
// encoded request.
TEST_F(ServerStreamingClientCall, SendsRequestPacket) {
  ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> context;

  auto call = FakeGeneratedServiceClient::TestServerStreamRpc(
      context.client(),
      context.channel().id(),
      {.integer = 71, .status_code = 0},
      nullptr,
      nullptr);

  // Fatal check: last_packet() is only meaningful if a packet was sent, so
  // stop here rather than read from an empty output.
  ASSERT_EQ(context.output().total_packets(), 1u);
  auto packet = context.output().last_packet();
  EXPECT_EQ(packet.channel_id(), context.channel().id());
  EXPECT_EQ(packet.service_id(), kServiceId);
  EXPECT_EQ(packet.method_id(), kServerStreamingMethodId);

  PW_DECODE_PB(TestRequest, sent_proto, packet.payload());
  EXPECT_EQ(sent_proto.integer, 71);
}
281 
// Each well-formed stream message reaches the response callback while the
// call stays open.
TEST_F(ServerStreamingClientCall, InvokesCallbackOnValidResponse) {
  ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> ctx;

  auto reader = FakeGeneratedServiceClient::TestServerStreamRpc(
      ctx.client(),
      ctx.channel().id(),
      {.integer = 71, .status_code = 0},
      [this](const TestStreamResponse::Message& msg) {
        last_response_number_ = msg.number;
        ++responses_received_;
      },
      [this](Status status) { stream_status_ = status; });

  // Sends one stream message carrying `number`, then checks that the call is
  // still open and that the callback has now run `expected_count` times,
  // most recently with `number`.
  auto deliver_and_check = [&](uint32_t number, int expected_count) {
    PW_ENCODE_PB(TestStreamResponse, encoded, .chunk = {}, .number = number);
    EXPECT_EQ(OkStatus(), ctx.SendServerStream(encoded));
    EXPECT_TRUE(reader.active());
    EXPECT_EQ(responses_received_, expected_count);
    EXPECT_EQ(last_response_number_, static_cast<int>(number));
  };

  deliver_and_check(11u, 1);
  deliver_and_check(22u, 2);
  deliver_and_check(33u, 3);
}
313 
// When the server finishes the stream, the stream-end callback must run with
// the server's final status, the call must close, and any stream message
// arriving afterwards must be ignored.
TEST_F(ServerStreamingClientCall, InvokesStreamEndOnFinish) {
  ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> context;

  auto call = FakeGeneratedServiceClient::TestServerStreamRpc(
      context.client(),
      context.channel().id(),
      {.integer = 71, .status_code = 0},
      [this](const TestStreamResponse::Message& response) {
        ++responses_received_;
        last_response_number_ = response.number;
      },
      [this](Status status) { stream_status_ = status; });

  PW_ENCODE_PB(TestStreamResponse, r1, .chunk = {}, .number = 11u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r1));
  EXPECT_TRUE(call.active());

  PW_ENCODE_PB(TestStreamResponse, r2, .chunk = {}, .number = 22u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r2));
  EXPECT_TRUE(call.active());
  // The stream must not be reported as ended while it is still open.
  EXPECT_FALSE(stream_status_.has_value());

  // Close the stream. The call must be finished by the response itself, and
  // the stream-end callback (what this test is named after) must have run
  // with the server's status.
  EXPECT_EQ(OkStatus(), context.SendResponse(Status::NotFound()));
  EXPECT_FALSE(call.active());
  ASSERT_TRUE(stream_status_.has_value());
  EXPECT_EQ(stream_status_, Status::NotFound());

  // A stream message after the close must not reach the response callback.
  PW_ENCODE_PB(TestStreamResponse, r3, .chunk = {}, .number = 33u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r3));
  EXPECT_FALSE(call.active());

  EXPECT_EQ(responses_received_, 2);
}
344 
// A stream message that cannot be decoded ends the call and reports
// DATA_LOSS through the error callback, without reaching the response
// callback.
TEST_F(ServerStreamingClientCall, ParseErrorTerminatesCallWithDataLoss) {
  ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> ctx;

  auto reader = FakeGeneratedServiceClient::TestServerStreamRpc(
      ctx.client(),
      ctx.channel().id(),
      {.integer = 71, .status_code = 0},
      [this](const TestStreamResponse::Message& msg) {
        last_response_number_ = msg.number;
        ++responses_received_;
      },
      /*on_stream_end=*/nullptr,
      [this](Status failure) { rpc_error_ = failure; });

  // One good message first, so the call is known to be open and working.
  PW_ENCODE_PB(TestStreamResponse, good, .chunk = {}, .number = 11u);
  EXPECT_EQ(OkStatus(), ctx.SendServerStream(good));
  EXPECT_TRUE(reader.active());
  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(last_response_number_, 11);

  // Arbitrary bytes that this test expects to fail decoding.
  constexpr std::byte kGarbage[]{
      std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}};
  EXPECT_EQ(OkStatus(), ctx.SendServerStream(kGarbage));
  EXPECT_FALSE(reader.active());
  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(rpc_error_, Status::DataLoss());
}
372 
373 }  // namespace
374 }  // namespace pw::rpc
375 
376 PW_MODIFY_DIAGNOSTICS_POP();
377