1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <openssl/evp.h>
6 #include <openssl/mem.h>
7
8 #include <atomic>
9 #include <chrono>
10
11 #include "cast/common/certificate/cast_trust_store.h"
12 #include "cast/common/certificate/testing/test_helpers.h"
13 #include "cast/common/channel/connection_namespace_handler.h"
14 #include "cast/common/channel/message_util.h"
15 #include "cast/common/channel/virtual_connection_router.h"
16 #include "cast/common/public/cast_socket.h"
17 #include "cast/receiver/channel/device_auth_namespace_handler.h"
18 #include "cast/receiver/channel/static_credentials.h"
19 #include "cast/receiver/public/receiver_socket_factory.h"
20 #include "cast/sender/public/sender_socket_factory.h"
21 #include "gmock/gmock.h"
22 #include "gtest/gtest.h"
23 #include "platform/api/serial_delete_ptr.h"
24 #include "platform/api/tls_connection_factory.h"
25 #include "platform/base/tls_connect_options.h"
26 #include "platform/base/tls_credentials.h"
27 #include "platform/base/tls_listen_options.h"
28 #include "platform/impl/logging.h"
29 #include "platform/impl/network_interface.h"
30 #include "platform/impl/platform_client_posix.h"
31 #include "testing/util/task_util.h"
32 #include "util/crypto/certificate_utils.h"
33 #include "util/osp_logging.h"
34
35 namespace openscreen {
36 namespace cast {
37 namespace {
38
39 using ::testing::_;
40 using ::testing::StrictMock;
41
42 constexpr char kLogDecorator[] = "--- ";
43
44 } // namespace
45
46 class SenderSocketsClient : public SenderSocketFactory::Client,
47 public VirtualConnectionRouter::SocketErrorHandler {
48 public:
SenderSocketsClient(VirtualConnectionRouter * router)49 explicit SenderSocketsClient(VirtualConnectionRouter* router) // NOLINT
50 : router_(router) {}
51 virtual ~SenderSocketsClient() = default;
52
socket() const53 CastSocket* socket() const { return socket_; }
54
55 // SenderSocketFactory::Client overrides.
OnConnected(SenderSocketFactory * factory,const IPEndpoint & endpoint,std::unique_ptr<CastSocket> socket)56 void OnConnected(SenderSocketFactory* factory,
57 const IPEndpoint& endpoint,
58 std::unique_ptr<CastSocket> socket) {
59 OSP_CHECK(!socket_);
60 OSP_LOG_INFO << kLogDecorator
61 << "Sender connected to endpoint: " << endpoint;
62 socket_ = socket.get();
63 router_->TakeSocket(this, std::move(socket));
64 }
65
OnError(SenderSocketFactory * factory,const IPEndpoint & endpoint,Error error)66 void OnError(SenderSocketFactory* factory,
67 const IPEndpoint& endpoint,
68 Error error) override {
69 OSP_LOG_FATAL << error;
70 }
71
72 // VirtualConnectionRouter::SocketErrorHandler overrides.
OnClose(CastSocket * socket)73 void OnClose(CastSocket* socket) override {
74 socket_ = nullptr;
75 OnCloseMock(socket);
76 }
OnError(CastSocket * socket,Error error)77 void OnError(CastSocket* socket, Error error) override {
78 socket_ = nullptr;
79 OnErrorMock(socket, std::move(error));
80 }
81
82 MOCK_METHOD(void, OnCloseMock, (CastSocket * socket), ());
83 MOCK_METHOD(void, OnErrorMock, (CastSocket * socket, Error error), ());
84
85 private:
86 VirtualConnectionRouter* const router_;
87 std::atomic<CastSocket*> socket_{nullptr};
88 };
89
90 class ReceiverSocketsClient
91 : public ReceiverSocketFactory::Client,
92 public VirtualConnectionRouter::SocketErrorHandler {
93 public:
ReceiverSocketsClient(VirtualConnectionRouter * router)94 explicit ReceiverSocketsClient(VirtualConnectionRouter* router)
95 : router_(router) {}
96 virtual ~ReceiverSocketsClient() = default;
97
endpoint() const98 const IPEndpoint& endpoint() const { return endpoint_; }
socket() const99 CastSocket* socket() const { return socket_; }
100
101 // ReceiverSocketFactory::Client overrides.
OnConnected(ReceiverSocketFactory * factory,const IPEndpoint & endpoint,std::unique_ptr<CastSocket> socket)102 void OnConnected(ReceiverSocketFactory* factory,
103 const IPEndpoint& endpoint,
104 std::unique_ptr<CastSocket> socket) override {
105 OSP_CHECK(!socket_);
106 OSP_LOG_INFO << kLogDecorator
107 << "Receiver got connection from endpoint: " << endpoint;
108 endpoint_ = endpoint;
109 socket_ = socket.get();
110 router_->TakeSocket(this, std::move(socket));
111 }
112
OnError(ReceiverSocketFactory * factory,Error error)113 void OnError(ReceiverSocketFactory* factory, Error error) override {
114 OSP_LOG_FATAL << error;
115 }
116
117 // VirtualConnectionRouter::SocketErrorHandler overrides.
OnClose(CastSocket * socket)118 void OnClose(CastSocket* socket) override {
119 socket_ = nullptr;
120 OnCloseMock(socket);
121 }
OnError(CastSocket * socket,Error error)122 void OnError(CastSocket* socket, Error error) override {
123 socket_ = nullptr;
124 OnErrorMock(socket, std::move(error));
125 }
126
127 MOCK_METHOD(void, OnCloseMock, (CastSocket * socket), ());
128 MOCK_METHOD(void, OnErrorMock, (CastSocket * socket, Error error), ());
129
130 private:
131 VirtualConnectionRouter* router_;
132 IPEndpoint endpoint_;
133 std::atomic<CastSocket*> socket_{nullptr};
134 };
135
136 class CastSocketE2ETest : public ::testing::Test {
137 public:
SetUp()138 void SetUp() override {
139 PlatformClientPosix::Create(std::chrono::milliseconds(10));
140 task_runner_ = PlatformClientPosix::GetInstance()->GetTaskRunner();
141
142 sender_router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_);
143 sender_client_ =
144 std::make_unique<StrictMock<SenderSocketsClient>>(sender_router_.get());
145 sender_factory_ = MakeSerialDelete<SenderSocketFactory>(
146 task_runner_, sender_client_.get(), task_runner_);
147 sender_tls_factory_ = SerialDeletePtr<TlsConnectionFactory>(
148 task_runner_,
149 TlsConnectionFactory::CreateFactory(sender_factory_.get(), task_runner_)
150 .release());
151 sender_factory_->set_factory(sender_tls_factory_.get());
152
153 ErrorOr<GeneratedCredentials> creds =
154 GenerateCredentialsForTesting("Device ID");
155 ASSERT_TRUE(creds.is_value());
156 credentials_ = std::move(creds.value());
157
158 CastTrustStore::CreateInstanceForTest(credentials_.root_cert_der);
159 auth_handler_ = MakeSerialDelete<DeviceAuthNamespaceHandler>(
160 task_runner_, credentials_.provider.get());
161 receiver_router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_);
162 receiver_router_->AddHandlerForLocalId(kPlatformReceiverId,
163 auth_handler_.get());
164 receiver_client_ = std::make_unique<StrictMock<ReceiverSocketsClient>>(
165 receiver_router_.get());
166 receiver_factory_ = MakeSerialDelete<ReceiverSocketFactory>(
167 task_runner_, receiver_client_.get(), receiver_router_.get());
168
169 receiver_tls_factory_ = SerialDeletePtr<TlsConnectionFactory>(
170 task_runner_, TlsConnectionFactory::CreateFactory(
171 receiver_factory_.get(), task_runner_)
172 .release());
173 }
174
TearDown()175 void TearDown() override {
176 OSP_LOG_INFO << "Shutting down";
177 sender_router_.reset();
178 receiver_router_.reset();
179 receiver_tls_factory_.reset();
180 receiver_factory_.reset();
181 auth_handler_.reset();
182 sender_tls_factory_.reset();
183 sender_factory_.reset();
184 CastTrustStore::ResetInstance();
185 PlatformClientPosix::ShutDown();
186 }
187
188 protected:
GetLoopbackV4Address()189 IPAddress GetLoopbackV4Address() {
190 absl::optional<InterfaceInfo> loopback = GetLoopbackInterfaceForTesting();
191 OSP_CHECK(loopback);
192 IPAddress address = loopback->GetIpAddressV4();
193 OSP_CHECK(address);
194 return address;
195 }
196
GetLoopbackV6Address()197 IPAddress GetLoopbackV6Address() {
198 absl::optional<InterfaceInfo> loopback = GetLoopbackInterfaceForTesting();
199 OSP_CHECK(loopback);
200 IPAddress address = loopback->GetIpAddressV6();
201 return address;
202 }
203
Connect(const IPAddress & address)204 void Connect(const IPAddress& address) {
205 uint16_t port = 65321;
206 OSP_LOG_INFO << kLogDecorator << "Starting socket factories";
207 task_runner_->PostTask([this, &address, port]() {
208 OSP_LOG_INFO << kLogDecorator << "Receiver TLS factory Listen()";
209 receiver_tls_factory_->SetListenCredentials(credentials_.tls_credentials);
210 receiver_tls_factory_->Listen(IPEndpoint{address, port},
211 TlsListenOptions{1u});
212 });
213
214 task_runner_->PostTask([this, &address, port]() {
215 OSP_LOG_INFO << kLogDecorator << "Sender CastSocket factory Connect()";
216 sender_factory_->Connect(IPEndpoint{address, port},
217 SenderSocketFactory::DeviceMediaPolicy::kNone,
218 sender_router_.get());
219 });
220
221 WaitForCondition([this]() { return sender_client_->socket(); });
222 }
223
ConnectSocketsV4()224 void ConnectSocketsV4() {
225 OSP_LOG_INFO << "Getting loopback IPv4 address";
226 IPAddress loopback_address = GetLoopbackV4Address();
227 OSP_LOG_INFO << "Connecting CastSockets";
228 Connect(loopback_address);
229 }
230
231 template <typename SocketClient, typename PeerSocketClient>
CloseSocketsFromOneEnd(VirtualConnectionRouter * router,SocketClient * client,PeerSocketClient * peer_client)232 void CloseSocketsFromOneEnd(VirtualConnectionRouter* router,
233 SocketClient* client,
234 PeerSocketClient* peer_client) {
235 // TODO(issuetracker.google.com/169967989): Would like to have a symmetric
236 // OnClose check.
237 EXPECT_CALL(*client, OnCloseMock(client->socket()));
238 EXPECT_CALL(*peer_client, OnErrorMock(peer_client->socket(), _))
239 .WillOnce([](CastSocket* socket, Error error) {
240 EXPECT_EQ(error.code(), Error::Code::kSocketClosedFailure);
241 });
242 int32_t id = client->socket()->socket_id();
243 std::atomic_bool did_run{false};
244 task_runner_->PostTask([id, router, &did_run]() {
245 router->CloseSocket(id);
246 did_run = true;
247 });
248 OSP_LOG_INFO << "Waiting for socket to close";
249 WaitForCondition([&did_run]() { return did_run.load(); });
250 EXPECT_FALSE(sender_client_->socket());
251 EXPECT_FALSE(receiver_client_->socket());
252 }
253
254 TaskRunner* task_runner_;
255
256 // NOTE: Sender components.
257 SerialDeletePtr<VirtualConnectionRouter> sender_router_;
258 std::unique_ptr<StrictMock<SenderSocketsClient>> sender_client_;
259 SerialDeletePtr<SenderSocketFactory> sender_factory_;
260 SerialDeletePtr<TlsConnectionFactory> sender_tls_factory_;
261
262 // NOTE: Receiver components.
263 SerialDeletePtr<VirtualConnectionRouter> receiver_router_;
264 GeneratedCredentials credentials_;
265 SerialDeletePtr<DeviceAuthNamespaceHandler> auth_handler_;
266 std::unique_ptr<StrictMock<ReceiverSocketsClient>> receiver_client_;
267 SerialDeletePtr<ReceiverSocketFactory> receiver_factory_;
268 SerialDeletePtr<TlsConnectionFactory> receiver_tls_factory_;
269 };
270
271 // These test the most basic setup of a complete CastSocket. This means
272 // constructing both a SenderSocketFactory and ReceiverSocketFactory, making a
273 // TLS connection to a known port over the loopback device, and checking device
274 // authentication.
TEST_F(CastSocketE2ETest,ConnectV4)275 TEST_F(CastSocketE2ETest, ConnectV4) {
276 ConnectSocketsV4();
277 }
278
// Connects sender and receiver over the IPv6 loopback address. When the
// loopback interface has no IPv6 address the test only logs a warning and
// returns without connecting.
TEST_F(CastSocketE2ETest, ConnectV6) {
  OSP_LOG_INFO << "Getting loopback IPv6 address";
  IPAddress v6_loopback = GetLoopbackV6Address();
  if (!v6_loopback) {
    OSP_LOG_WARN << "Test skipped due to missing IPv6 loopback address";
    return;
  }

  OSP_LOG_INFO << "Connecting CastSockets";
  Connect(v6_loopback);
}
289
// Connects over IPv4 loopback, then closes the connection from the sender
// side; the receiver is the peer that must observe the closure.
TEST_F(CastSocketE2ETest, SenderClose) {
  ConnectSocketsV4();

  VirtualConnectionRouter* const closing_router = sender_router_.get();
  auto* const closing_client = sender_client_.get();
  auto* const observing_peer = receiver_client_.get();
  CloseSocketsFromOneEnd(closing_router, closing_client, observing_peer);
}
296
// Connects over IPv4 loopback, then closes the connection from the receiver
// side; the sender is the peer that must observe the closure.
TEST_F(CastSocketE2ETest, ReceiverClose) {
  ConnectSocketsV4();

  VirtualConnectionRouter* const closing_router = receiver_router_.get();
  auto* const closing_client = receiver_client_.get();
  auto* const observing_peer = sender_client_.get();
  CloseSocketsFromOneEnd(closing_router, closing_client, observing_peer);
}
303
304 } // namespace cast
305 } // namespace openscreen
306