// Copyright 2022 The Pigweed Authors // // Licensed under the Apache License, Version 2.0 (the "License"); you may not // use this file except in compliance with the License. You may obtain a copy of // the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the // License for the specific language governing permissions and limitations under // the License. #include "pw_rpc/pwpb/internal/method.h" #include #include "pw_containers/algorithm.h" #include "pw_rpc/internal/lock.h" #include "pw_rpc/internal/method_impl_tester.h" #include "pw_rpc/internal/test_utils.h" #include "pw_rpc/pwpb/internal/method_union.h" #include "pw_rpc/service.h" #include "pw_rpc_pwpb_private/internal_test_utils.h" #include "pw_rpc_test_protos/test.pwpb.h" #include "pw_unit_test/framework.h" PW_MODIFY_DIAGNOSTICS_PUSH(); PW_MODIFY_DIAGNOSTIC(ignored, "-Wmissing-field-initializers"); namespace pw::rpc::internal { namespace { using std::byte; struct FakePb {}; // Create a fake service for use with the MethodImplTester. 
// Collection of handler signatures, one group per RPC kind. The bodies are
// trivial; the signatures are what the PwpbMethod::matches() static_asserts
// (and, presumably, the MethodImplTests check) inspect.
//
// The names suggest that *WrongArg, *VoidReturn, *BadReturn and *MissingArg
// are deliberately invalid shapes for negative checks. Confirm this against
// method_impl_tester.h.
//
// NOTE(review): in this copy the template argument lists of the
// PwpbUnaryResponder / PwpbServerWriter / PwpbServerReader /
// PwpbServerReaderWriter parameter types appear to be missing; restore them
// from upstream.
class TestPwpbService final : public Service {
 public:
  // Unary signatures

  Status Unary(const FakePb&, FakePb&) { return Status(); }

  static Status StaticUnary(const FakePb&, FakePb&) { return Status(); }

  void AsyncUnary(const FakePb&, PwpbUnaryResponder&) {}

  static void StaticAsyncUnary(const FakePb&, PwpbUnaryResponder&) {}

  // Request taken by non-const reference.
  Status UnaryWrongArg(FakePb&, FakePb&) { return Status(); }

  // Synchronous-unary shape, but returns void instead of Status.
  static void StaticUnaryVoidReturn(const FakePb&, FakePb&) {}

  // Server streaming signatures

  void ServerStreaming(const FakePb&, PwpbServerWriter&) {}

  static void StaticServerStreaming(const FakePb&, PwpbServerWriter&) {}

  int ServerStreamingBadReturn(const FakePb&, PwpbServerWriter&) { return 5; }

  // No request parameter.
  static void StaticServerStreamingMissingArg(PwpbServerWriter&) {}

  // Client streaming signatures

  void ClientStreaming(PwpbServerReader&) {}

  static void StaticClientStreaming(PwpbServerReader&) {}

  int ClientStreamingBadReturn(PwpbServerReader&) { return 0; }

  // No reader parameter.
  static void StaticClientStreamingMissingArg() {}

  // Bidirectional streaming signatures

  void BidirectionalStreaming(PwpbServerReaderWriter&) {}

  static void StaticBidirectionalStreaming(
      PwpbServerReaderWriter&) {}

  int BidirectionalStreamingBadReturn(PwpbServerReaderWriter&) { return 0; }

  // No reader/writer parameter.
  static void StaticBidirectionalStreamingMissingArg() {}
};

// Intentionally incomplete type. It is used as the mismatched request or
// response type in the matches() checks that follow.
struct WrongPb;

// Test matches() rejects incorrect request/response types.
// clang-format off static_assert(!PwpbMethod::template matches<&TestPwpbService::Unary, WrongPb, FakePb>()); static_assert(!PwpbMethod::template matches<&TestPwpbService::Unary, FakePb, WrongPb>()); static_assert(!PwpbMethod::template matches<&TestPwpbService::Unary, WrongPb, WrongPb>()); static_assert(!PwpbMethod::template matches<&TestPwpbService::StaticUnary, FakePb, WrongPb>()); static_assert(!PwpbMethod::template matches<&TestPwpbService::ServerStreaming, WrongPb, FakePb>()); static_assert(!PwpbMethod::template matches<&TestPwpbService::StaticServerStreaming, FakePb, WrongPb>()); static_assert(!PwpbMethod::template matches<&TestPwpbService::ClientStreaming, WrongPb, FakePb>()); static_assert(!PwpbMethod::template matches<&TestPwpbService::StaticClientStreaming, FakePb, WrongPb>()); static_assert(!PwpbMethod::template matches<&TestPwpbService::BidirectionalStreaming, WrongPb, FakePb>()); static_assert(!PwpbMethod::template matches<&TestPwpbService::StaticBidirectionalStreaming, FakePb, WrongPb>()); // clang-format on static_assert(MethodImplTests().Pass( MatchesTypes(), std::tuple(kPwpbMethodSerde))); template class FakeServiceBase : public Service { public: FakeServiceBase(uint32_t id) : Service(id, kMethods) {} static constexpr std::array kMethods = { PwpbMethod::SynchronousUnary<&Impl::DoNothing>( 10u, kPwpbMethodSerde<&pw::rpc::test::pwpb::Empty::kMessageFields, &pw::rpc::test::pwpb::Empty::kMessageFields>), PwpbMethod::AsynchronousUnary<&Impl::AddFive>( 11u, kPwpbMethodSerde<&pw::rpc::test::pwpb::TestRequest::kMessageFields, &pw::rpc::test::pwpb::TestResponse::kMessageFields>), PwpbMethod::ServerStreaming<&Impl::StartStream>( 12u, kPwpbMethodSerde<&pw::rpc::test::pwpb::TestRequest::kMessageFields, &pw::rpc::test::pwpb::TestResponse::kMessageFields>), PwpbMethod::ClientStreaming<&Impl::ClientStream>( 13u, kPwpbMethodSerde<&pw::rpc::test::pwpb::TestRequest::kMessageFields, &pw::rpc::test::pwpb::TestResponse::kMessageFields>), 
PwpbMethod::BidirectionalStreaming<&Impl::BidirectionalStream>( 14u, kPwpbMethodSerde<&pw::rpc::test::pwpb::TestRequest::kMessageFields, &pw::rpc::test::pwpb::TestResponse::kMessageFields>), }; }; class FakeService : public FakeServiceBase { public: FakeService(uint32_t id) : FakeServiceBase(id) {} Status DoNothing(const pw::rpc::test::pwpb::Empty::Message&, pw::rpc::test::pwpb::Empty::Message&) { return Status::Unknown(); } void AddFive(const pw::rpc::test::pwpb::TestRequest::Message& request, PwpbUnaryResponder& responder) { last_request = request; if (fail_to_encode_async_unary_response) { pw::rpc::test::pwpb::TestResponse::Message response = {}; response.repeated_field.SetEncoder( [](const pw::rpc::test::pwpb::TestResponse::StreamEncoder&) { return Status::Internal(); }); ASSERT_EQ(OkStatus(), responder.Finish(response, Status::NotFound())); } else { ASSERT_EQ( OkStatus(), responder.Finish({.value = static_cast(request.integer + 5)}, Status::Unauthenticated())); } } void StartStream( const pw::rpc::test::pwpb::TestRequest::Message& request, PwpbServerWriter& writer) { last_request = request; last_writer = std::move(writer); } void ClientStream( PwpbServerReader& reader) { last_reader = std::move(reader); } void BidirectionalStream( PwpbServerReaderWriter& reader_writer) { last_reader_writer = std::move(reader_writer); } bool fail_to_encode_async_unary_response = false; pw::rpc::test::pwpb::TestRequest::Message last_request; PwpbServerWriter last_writer; PwpbServerReader last_reader; PwpbServerReaderWriter last_reader_writer; }; constexpr const PwpbMethod& kSyncUnary = std::get<0>(FakeServiceBase::kMethods).pwpb_method(); constexpr const PwpbMethod& kAsyncUnary = std::get<1>(FakeServiceBase::kMethods).pwpb_method(); constexpr const PwpbMethod& kServerStream = std::get<2>(FakeServiceBase::kMethods).pwpb_method(); constexpr const PwpbMethod& kClientStream = std::get<3>(FakeServiceBase::kMethods).pwpb_method(); constexpr const PwpbMethod& kBidirectionalStream = 
std::get<4>(FakeServiceBase::kMethods).pwpb_method(); TEST(PwpbMethod, AsyncUnaryRpc_SendsResponse) { PW_ENCODE_PB(pw::rpc::test::pwpb::TestRequest, request, .integer = 123, .status_code = 0); ServerContextForTest context(kAsyncUnary); rpc_lock().lock(); kAsyncUnary.Invoke(context.get(), context.request(request)); const Packet& response = context.output().last_packet(); EXPECT_EQ(response.status(), Status::Unauthenticated()); // Field 1 (encoded as 1 << 3) with 128 as the value. constexpr std::byte expected[]{ std::byte{0x08}, std::byte{0x80}, std::byte{0x01}}; EXPECT_EQ(sizeof(expected), response.payload().size()); EXPECT_EQ(0, std::memcmp(expected, response.payload().data(), sizeof(expected))); EXPECT_EQ(123, context.service().last_request.integer); } TEST(PwpbMethod, SyncUnaryRpc_InvalidPayload_SendsError) { std::array bad_payload{byte{0xFF}, byte{0xAA}, byte{0xDD}}; ServerContextForTest context(kSyncUnary); rpc_lock().lock(); kSyncUnary.Invoke(context.get(), context.request(bad_payload)); const Packet& packet = context.output().last_packet(); EXPECT_EQ(pwpb::PacketType::SERVER_ERROR, packet.type()); EXPECT_EQ(Status::DataLoss(), packet.status()); EXPECT_EQ(context.service_id(), packet.service_id()); EXPECT_EQ(kSyncUnary.id(), packet.method_id()); } TEST(PwpbMethod, AsyncUnaryRpc_ResponseEncodingFails_SendsInternalError) { constexpr int64_t value = 0x7FFFFFFF'FFFFFF00ll; PW_ENCODE_PB(pw::rpc::test::pwpb::TestRequest, request, .integer = value, .status_code = 0); ServerContextForTest context(kAsyncUnary); context.service().fail_to_encode_async_unary_response = true; rpc_lock().lock(); kAsyncUnary.Invoke(context.get(), context.request(request)); const Packet& packet = context.output().last_packet(); EXPECT_EQ(pwpb::PacketType::SERVER_ERROR, packet.type()); EXPECT_EQ(Status::Internal(), packet.status()); EXPECT_EQ(context.service_id(), packet.service_id()); EXPECT_EQ(kAsyncUnary.id(), packet.method_id()); EXPECT_EQ(value, context.service().last_request.integer); } 
// Invoking a server-streaming method sends no packet by itself: StartStream
// only records the request and keeps the writer.
// NOTE(review): in the tests below, the template argument lists of
// ServerContextForTest, std::array and PwpbServerWriter appear to be missing
// in this copy; restore them from upstream.
TEST(PwpbMethod, ServerStreamingRpc_SendsNothingWhenInitiallyCalled) {
  PW_ENCODE_PB(pw::rpc::test::pwpb::TestRequest,
               request,
               .integer = 555,
               .status_code = 0);

  ServerContextForTest context(kServerStream);

  rpc_lock().lock();
  kServerStream.Invoke(context.get(), context.request(request));

  EXPECT_EQ(0u, context.output().total_packets());
  EXPECT_EQ(555, context.service().last_request.integer);
}

// Write() through the stored writer emits a packet whose payload equals an
// independently encoded TestResponse{.value = 100}.
TEST(PwpbMethod, ServerWriter_SendsResponse) {
  ServerContextForTest context(kServerStream);

  rpc_lock().lock();
  kServerStream.Invoke(context.get(), context.request({}));

  EXPECT_EQ(OkStatus(), context.service().last_writer.Write({.value = 100}));

  PW_ENCODE_PB(pw::rpc::test::pwpb::TestResponse, payload, .value = 100);
  std::array encoded_response = {};
  // NOTE(review): `encoded` is only checked for an OK status; the packet itself
  // is not compared. The payload comparison below uses `payload` directly.
  auto encoded = context.server_stream(payload).Encode(encoded_response);
  ASSERT_EQ(OkStatus(), encoded.status());

  ConstByteSpan sent_payload = context.output().last_packet().payload();
  EXPECT_TRUE(pw::containers::Equal(payload, sent_payload));
}

// After Finish(), further writes on the same writer fail with
// FAILED_PRECONDITION.
TEST(PwpbMethod, ServerWriter_WriteWhenClosed_ReturnsFailedPrecondition) {
  ServerContextForTest context(kServerStream);

  rpc_lock().lock();
  kServerStream.Invoke(context.get(), context.request({}));

  EXPECT_EQ(OkStatus(), context.service().last_writer.Finish());
  EXPECT_TRUE(context.service()
                  .last_writer.Write({.value = 100})
                  .IsFailedPrecondition());
}

// Moving a writer transfers the call. The new writer can write and finish;
// the moved-from writer's Write() and Finish() return FAILED_PRECONDITION.
TEST(PwpbMethod, ServerWriter_WriteAfterMoved_ReturnsFailedPrecondition) {
  ServerContextForTest context(kServerStream);

  rpc_lock().lock();
  kServerStream.Invoke(context.get(), context.request({}));

  PwpbServerWriter new_writer = std::move(context.service().last_writer);

  EXPECT_EQ(OkStatus(), new_writer.Write({.value = 100}));

  EXPECT_EQ(Status::FailedPrecondition(),
            context.service().last_writer.Write({.value = 100}));
  EXPECT_EQ(Status::FailedPrecondition(),
            context.service().last_writer.Finish());

  EXPECT_EQ(OkStatus(), new_writer.Finish());
}

// A streamed response whose repeated_field encoder returns INTERNAL makes
// Write() return INTERNAL. An empty response written first succeeds.
TEST(PwpbMethod, ServerStreamingRpc_ResponseEncodingFails_InternalError) {
  ServerContextForTest context(kServerStream);

  rpc_lock().lock();
  kServerStream.Invoke(context.get(), context.request({}));

  EXPECT_EQ(OkStatus(), context.service().last_writer.Write({}));

  pw::rpc::test::pwpb::TestResponse::Message response = {};
  response.repeated_field.SetEncoder(
      [](const pw::rpc::test::pwpb::TestResponse::StreamEncoder&) {
        return Status::Internal();
      });
  EXPECT_EQ(Status::Internal(), context.service().last_writer.Write(response));
}

// A client-stream packet processed by the server is decoded and delivered to
// the reader's on_next callback.
TEST(PwpbMethod, ServerReader_HandlesRequests) {
  ServerContextForTest context(kClientStream);

  rpc_lock().lock();
  kClientStream.Invoke(context.get(), context.request({}));

  pw::rpc::test::pwpb::TestRequest::Message request_struct{};
  context.service().last_reader.set_on_next(
      [&request_struct](const pw::rpc::test::pwpb::TestRequest::Message& req) {
        request_struct = req;
      });

  PW_ENCODE_PB(pw::rpc::test::pwpb::TestRequest,
               request,
               .integer = 1 << 30,
               .status_code = 9);
  std::array encoded_request = {};
  auto encoded = context.client_stream(request).Encode(encoded_request);
  ASSERT_EQ(OkStatus(), encoded.status());
  ASSERT_EQ(OkStatus(), context.server().ProcessPacket(*encoded));

  EXPECT_EQ(request_struct.integer, 1 << 30);
  EXPECT_EQ(request_struct.status_code, 9u);
}

// Same as ServerWriter_SendsResponse, but writing through the bidirectional
// reader/writer.
TEST(PwpbMethod, ServerReaderWriter_WritesResponses) {
  ServerContextForTest context(kBidirectionalStream);

  rpc_lock().lock();
  kBidirectionalStream.Invoke(context.get(), context.request({}));

  EXPECT_EQ(OkStatus(),
            context.service().last_reader_writer.Write({.value = 100}));

  PW_ENCODE_PB(pw::rpc::test::pwpb::TestResponse, payload, .value = 100);
  std::array encoded_response = {};
  auto encoded = context.server_stream(payload).Encode(encoded_response);
  ASSERT_EQ(OkStatus(), encoded.status());

  ConstByteSpan sent_payload = context.output().last_packet().payload();
  EXPECT_TRUE(pw::containers::Equal(payload, sent_payload));
}

// Same as ServerReader_HandlesRequests, but receiving through the
// bidirectional reader/writer.
TEST(PwpbMethod, ServerReaderWriter_HandlesRequests) {
  ServerContextForTest context(kBidirectionalStream);

  rpc_lock().lock();
  kBidirectionalStream.Invoke(context.get(), context.request({}));

  pw::rpc::test::pwpb::TestRequest::Message request_struct{};
  context.service().last_reader_writer.set_on_next(
      [&request_struct](const pw::rpc::test::pwpb::TestRequest::Message& req) {
        request_struct = req;
      });

  PW_ENCODE_PB(pw::rpc::test::pwpb::TestRequest,
               request,
               .integer = 1 << 29,
               .status_code = 8);
  std::array encoded_request = {};
  auto encoded = context.client_stream(request).Encode(encoded_request);
  ASSERT_EQ(OkStatus(), encoded.status());
  ASSERT_EQ(OkStatus(), context.server().ProcessPacket(*encoded));

  EXPECT_EQ(request_struct.integer, 1 << 29);
  EXPECT_EQ(request_struct.status_code, 8u);
}

}  // namespace
}  // namespace pw::rpc::internal

// Restores the diagnostics pushed at the top of the file.
PW_MODIFY_DIAGNOSTICS_POP();