1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "cast/common/public/cast_socket.h"
6
7 #include "cast/common/channel/message_framer.h"
8 #include "cast/common/channel/proto/cast_channel.pb.h"
9 #include "cast/common/channel/testing/fake_cast_socket.h"
10 #include "gmock/gmock.h"
11 #include "gtest/gtest.h"
12
13 namespace openscreen {
14 namespace cast {
15
16 using ::cast::channel::CastMessage;
17
18 namespace {
19
20 using ::testing::_;
21 using ::testing::Invoke;
22 using ::testing::Return;
23
24 class CastSocketTest : public ::testing::Test {
25 public:
SetUp()26 void SetUp() override {
27 message_.set_protocol_version(CastMessage::CASTV2_1_0);
28 message_.set_source_id("source");
29 message_.set_destination_id("destination");
30 message_.set_namespace_("namespace");
31 message_.set_payload_type(CastMessage::STRING);
32 message_.set_payload_utf8("payload");
33 ErrorOr<std::vector<uint8_t>> serialized_or_error =
34 message_serialization::Serialize(message_);
35 ASSERT_TRUE(serialized_or_error);
36 frame_serial_ = std::move(serialized_or_error.value());
37 }
38
39 protected:
connection()40 MockTlsConnection& connection() { return *fake_socket_.connection; }
mock_client()41 MockCastSocketClient& mock_client() { return fake_socket_.mock_client; }
socket()42 CastSocket& socket() { return fake_socket_.socket; }
43
44 FakeCastSocket fake_socket_;
45 CastMessage message_;
46 std::vector<uint8_t> frame_serial_;
47 };
48
49 } // namespace
50
TEST_F(CastSocketTest,SendMessage)51 TEST_F(CastSocketTest, SendMessage) {
52 EXPECT_CALL(connection(), Send(_, _))
53 .WillOnce(Invoke([this](const void* data, size_t len) {
54 EXPECT_EQ(
55 frame_serial_,
56 std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data),
57 reinterpret_cast<const uint8_t*>(data) + len));
58 return true;
59 }));
60 ASSERT_TRUE(socket().Send(message_).ok());
61 }
62
TEST_F(CastSocketTest,SendMessageEventuallyBlocks)63 TEST_F(CastSocketTest, SendMessageEventuallyBlocks) {
64 EXPECT_CALL(connection(), Send(_, _))
65 .Times(3)
66 .WillRepeatedly(Invoke([this](const void* data, size_t len) {
67 EXPECT_EQ(
68 frame_serial_,
69 std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data),
70 reinterpret_cast<const uint8_t*>(data) + len));
71 return true;
72 }))
73 .RetiresOnSaturation();
74 ASSERT_TRUE(socket().Send(message_).ok());
75 ASSERT_TRUE(socket().Send(message_).ok());
76 ASSERT_TRUE(socket().Send(message_).ok());
77
78 EXPECT_CALL(connection(), Send(_, _)).WillOnce(Return(false));
79 ASSERT_EQ(socket().Send(message_).code(), Error::Code::kAgain);
80 }
81
TEST_F(CastSocketTest,ReadCompleteMessage)82 TEST_F(CastSocketTest, ReadCompleteMessage) {
83 const uint8_t* data = frame_serial_.data();
84 EXPECT_CALL(mock_client(), OnMessage(_, _))
85 .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) {
86 EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString());
87 }));
88 connection().OnRead(std::vector<uint8_t>(data, data + frame_serial_.size()));
89 }
90
TEST_F(CastSocketTest,ReadChunkedMessage)91 TEST_F(CastSocketTest, ReadChunkedMessage) {
92 const uint8_t* data = frame_serial_.data();
93 EXPECT_CALL(mock_client(), OnMessage(_, _)).Times(0);
94 connection().OnRead(std::vector<uint8_t>(data, data + 10));
95
96 EXPECT_CALL(mock_client(), OnMessage(_, _))
97 .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) {
98 EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString());
99 }));
100 connection().OnRead(
101 std::vector<uint8_t>(data + 10, data + frame_serial_.size()));
102
103 std::vector<uint8_t> double_message;
104 double_message.insert(double_message.end(), frame_serial_.begin(),
105 frame_serial_.end());
106 double_message.insert(double_message.end(), frame_serial_.begin(),
107 frame_serial_.end());
108 data = double_message.data();
109 EXPECT_CALL(mock_client(), OnMessage(_, _))
110 .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) {
111 EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString());
112 }));
113 connection().OnRead(
114 std::vector<uint8_t>(data, data + frame_serial_.size() + 10));
115
116 EXPECT_CALL(mock_client(), OnMessage(_, _))
117 .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) {
118 EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString());
119 }));
120 connection().OnRead(std::vector<uint8_t>(data + frame_serial_.size() + 10,
121 data + double_message.size()));
122 }
123
TEST_F(CastSocketTest,ReadMultipleMessagesPerBlock)124 TEST_F(CastSocketTest, ReadMultipleMessagesPerBlock) {
125 CastMessage message2;
126 std::vector<uint8_t> frame_serial2;
127 message2.set_protocol_version(CastMessage::CASTV2_1_0);
128 message2.set_source_id("alt-source");
129 message2.set_destination_id("alt-destination");
130 message2.set_namespace_("alt-namespace");
131 message2.set_payload_type(CastMessage::STRING);
132 message2.set_payload_utf8("alternate payload");
133 ErrorOr<std::vector<uint8_t>> serialized_or_error =
134 message_serialization::Serialize(message2);
135 ASSERT_TRUE(serialized_or_error);
136 frame_serial2 = std::move(serialized_or_error.value());
137
138 std::vector<uint8_t> send_data;
139 send_data.reserve(frame_serial_.size() + frame_serial2.size());
140 send_data.insert(send_data.end(), frame_serial_.begin(), frame_serial_.end());
141 send_data.insert(send_data.end(), frame_serial2.begin(), frame_serial2.end());
142 EXPECT_CALL(mock_client(), OnMessage(_, _))
143 .WillOnce(Invoke([this](CastSocket* socket, CastMessage message) {
144 EXPECT_EQ(message_.SerializeAsString(), message.SerializeAsString());
145 }))
146 .WillOnce([message2](CastSocket* socket, CastMessage message) {
147 EXPECT_EQ(message2.SerializeAsString(), message.SerializeAsString());
148 });
149 connection().OnRead(std::move(send_data));
150 }
151
TEST_F(CastSocketTest,SanitizedAddress)152 TEST_F(CastSocketTest, SanitizedAddress) {
153 std::array<uint8_t, 2> result1 = socket().GetSanitizedIpAddress();
154 EXPECT_EQ(result1[0], 1u);
155 EXPECT_EQ(result1[1], 9u);
156
157 FakeCastSocket v6_socket(IPEndpoint{{1, 2, 3, 4}, 1025},
158 IPEndpoint{{0x1819, 0x1a1b, 0x1c1d, 0x1e1f, 0x207b,
159 0x7c7d, 0x7e7f, 0x8081},
160 4321});
161 std::array<uint8_t, 2> result2 = v6_socket.socket.GetSanitizedIpAddress();
162 EXPECT_EQ(result2[0], 128);
163 EXPECT_EQ(result2[1], 129);
164 }
165
166 } // namespace cast
167 } // namespace openscreen
168