xref: /aosp_15_r20/external/openscreen/cast/common/channel/cast_socket_unittest.cc (revision 3f982cf4871df8771c9d4abe6e9a6f8d829b2736)
1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "cast/common/public/cast_socket.h"
6 
7 #include "cast/common/channel/message_framer.h"
8 #include "cast/common/channel/proto/cast_channel.pb.h"
9 #include "cast/common/channel/testing/fake_cast_socket.h"
10 #include "gmock/gmock.h"
11 #include "gtest/gtest.h"
12 
13 namespace openscreen {
14 namespace cast {
15 
16 using ::cast::channel::CastMessage;
17 
18 namespace {
19 
20 using ::testing::_;
21 using ::testing::Invoke;
22 using ::testing::Return;
23 
24 class CastSocketTest : public ::testing::Test {
25  public:
SetUp()26   void SetUp() override {
27     message_.set_protocol_version(CastMessage::CASTV2_1_0);
28     message_.set_source_id("source");
29     message_.set_destination_id("destination");
30     message_.set_namespace_("namespace");
31     message_.set_payload_type(CastMessage::STRING);
32     message_.set_payload_utf8("payload");
33     ErrorOr<std::vector<uint8_t>> serialized_or_error =
34         message_serialization::Serialize(message_);
35     ASSERT_TRUE(serialized_or_error);
36     frame_serial_ = std::move(serialized_or_error.value());
37   }
38 
39  protected:
connection()40   MockTlsConnection& connection() { return *fake_socket_.connection; }
mock_client()41   MockCastSocketClient& mock_client() { return fake_socket_.mock_client; }
socket()42   CastSocket& socket() { return fake_socket_.socket; }
43 
44   FakeCastSocket fake_socket_;
45   CastMessage message_;
46   std::vector<uint8_t> frame_serial_;
47 };
48 
49 }  // namespace
50 
// Sending the reference message must hand the connection exactly the bytes
// that were cached in `frame_serial_` by SetUp().
TEST_F(CastSocketTest, SendMessage) {
  // Stand-in for the connection's Send(): checks the outgoing bytes and
  // reports success.
  const auto verify_and_accept = [this](const void* data, size_t len) {
    const auto* bytes = static_cast<const uint8_t*>(data);
    EXPECT_EQ(frame_serial_, std::vector<uint8_t>(bytes, bytes + len));
    return true;
  };

  EXPECT_CALL(connection(), Send(_, _)).WillOnce(Invoke(verify_and_accept));
  ASSERT_TRUE(socket().Send(message_).ok());
}
62 
// Three sends are accepted by the connection; on the fourth the connection's
// Send() returns false and the socket must report Error::Code::kAgain.
TEST_F(CastSocketTest, SendMessageEventuallyBlocks) {
  constexpr int kAcceptedSends = 3;

  // Stand-in for the connection's Send(): checks the outgoing bytes and
  // reports success.
  const auto verify_and_accept = [this](const void* data, size_t len) {
    const auto* bytes = static_cast<const uint8_t*>(data);
    EXPECT_EQ(frame_serial_, std::vector<uint8_t>(bytes, bytes + len));
    return true;
  };

  // RetiresOnSaturation() takes this expectation out of play after the third
  // call, so it does not interfere with the one registered further down.
  EXPECT_CALL(connection(), Send(_, _))
      .Times(kAcceptedSends)
      .WillRepeatedly(Invoke(verify_and_accept))
      .RetiresOnSaturation();
  for (int attempt = 0; attempt < kAcceptedSends; ++attempt) {
    ASSERT_TRUE(socket().Send(message_).ok());
  }

  // The connection now refuses the write.
  EXPECT_CALL(connection(), Send(_, _)).WillOnce(Return(false));
  ASSERT_EQ(socket().Send(message_).code(), Error::Code::kAgain);
}
81 
// Feeding one whole serialized frame to the connection in a single read must
// produce one OnMessage() call carrying a message equal to the reference one.
TEST_F(CastSocketTest, ReadCompleteMessage) {
  // Messages are compared through their serialized string form.
  const auto expect_reference_message = [this](CastSocket* /*socket*/,
                                               CastMessage message) {
    EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString());
  };

  EXPECT_CALL(mock_client(), OnMessage(_, _))
      .WillOnce(Invoke(expect_reference_message));

  // Hand the connection its own copy of the full frame.
  connection().OnRead(std::vector<uint8_t>(frame_serial_));
}
90 
// Exercises reads whose boundaries do not line up with frame boundaries:
//   1. the first 10 bytes of a frame         -> no message yet
//   2. the remainder of that frame           -> one message
//   3. a full frame plus 10 bytes of another -> one message
//   4. the remainder of that second frame    -> one message
TEST_F(CastSocketTest, ReadChunkedMessage) {
  constexpr size_t kSplitOffset = 10;
  const size_t frame_size = frame_serial_.size();

  // Messages are compared through their serialized string form.
  const auto expect_reference_message = [this](CastSocket* /*socket*/,
                                               CastMessage message) {
    EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString());
  };

  // --- A single frame delivered in two pieces. ---
  const uint8_t* const frame_begin = frame_serial_.data();
  const uint8_t* const frame_split = frame_begin + kSplitOffset;
  const uint8_t* const frame_end = frame_begin + frame_size;

  EXPECT_CALL(mock_client(), OnMessage(_, _)).Times(0);
  connection().OnRead(std::vector<uint8_t>(frame_begin, frame_split));

  EXPECT_CALL(mock_client(), OnMessage(_, _))
      .WillOnce(Invoke(expect_reference_message));
  connection().OnRead(std::vector<uint8_t>(frame_split, frame_end));

  // --- Two back-to-back frames, split partway into the second one. ---
  std::vector<uint8_t> two_frames(frame_serial_);
  two_frames.insert(two_frames.end(), frame_serial_.begin(),
                    frame_serial_.end());
  const uint8_t* const pair_begin = two_frames.data();
  const uint8_t* const pair_split = pair_begin + frame_size + kSplitOffset;
  const uint8_t* const pair_end = pair_begin + two_frames.size();

  EXPECT_CALL(mock_client(), OnMessage(_, _))
      .WillOnce(Invoke(expect_reference_message));
  connection().OnRead(std::vector<uint8_t>(pair_begin, pair_split));

  EXPECT_CALL(mock_client(), OnMessage(_, _))
      .WillOnce(Invoke(expect_reference_message));
  connection().OnRead(std::vector<uint8_t>(pair_split, pair_end));
}
123 
// A single read containing two different serialized messages back to back
// must produce two OnMessage() calls, in the order the frames were written.
TEST_F(CastSocketTest, ReadMultipleMessagesPerBlock) {
  // Build a second message that differs from the reference one in every
  // string field.
  CastMessage other_message;
  other_message.set_source_id("alt-source");
  other_message.set_destination_id("alt-destination");
  other_message.set_namespace_("alt-namespace");
  other_message.set_protocol_version(CastMessage::CASTV2_1_0);
  other_message.set_payload_type(CastMessage::STRING);
  other_message.set_payload_utf8("alternate payload");

  ErrorOr<std::vector<uint8_t>> other_framed =
      message_serialization::Serialize(other_message);
  ASSERT_TRUE(other_framed);
  const std::vector<uint8_t> other_frame = std::move(other_framed.value());

  // Reference frame immediately followed by the alternate frame.
  std::vector<uint8_t> both_frames;
  both_frames.reserve(frame_serial_.size() + other_frame.size());
  both_frames.assign(frame_serial_.begin(), frame_serial_.end());
  both_frames.insert(both_frames.end(), other_frame.begin(),
                     other_frame.end());

  // Messages are compared through their serialized string form.
  const auto expect_reference_message = [this](CastSocket* /*socket*/,
                                               CastMessage message) {
    EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString());
  };
  const auto expect_other_message = [other_message](CastSocket* /*socket*/,
                                                    CastMessage message) {
    EXPECT_EQ(other_message.SerializeAsString(), message.SerializeAsString());
  };

  EXPECT_CALL(mock_client(), OnMessage(_, _))
      .WillOnce(Invoke(expect_reference_message))
      .WillOnce(Invoke(expect_other_message));
  connection().OnRead(std::move(both_frames));
}
151 
// Checks the two bytes returned by GetSanitizedIpAddress() for the fixture's
// default socket and for a socket built with an IPv6 endpoint.
TEST_F(CastSocketTest, SanitizedAddress) {
  const std::array<uint8_t, 2> default_bytes =
      socket().GetSanitizedIpAddress();
  EXPECT_EQ(default_bytes[0], 1u);
  EXPECT_EQ(default_bytes[1], 9u);

  // NOTE(review): which constructor argument is the local endpoint and which
  // is the remote one is not visible from this file; confirm against
  // FakeCastSocket.
  const IPEndpoint v4_endpoint{{1, 2, 3, 4}, 1025};
  const IPEndpoint v6_endpoint{
      {0x1819, 0x1a1b, 0x1c1d, 0x1e1f, 0x207b, 0x7c7d, 0x7e7f, 0x8081}, 4321};
  FakeCastSocket v6_socket(v4_endpoint, v6_endpoint);

  // 128 and 129 are 0x80 and 0x81, i.e. the two bytes of the final hextet
  // (0x8081) of the IPv6 endpoint above.
  const std::array<uint8_t, 2> v6_bytes =
      v6_socket.socket.GetSanitizedIpAddress();
  EXPECT_EQ(v6_bytes[0], 128);
  EXPECT_EQ(v6_bytes[1], 129);
}
165 
166 }  // namespace cast
167 }  // namespace openscreen
168