xref: /aosp_15_r20/external/openscreen/cast/test/cast_socket_e2e_test.cc (revision 3f982cf4871df8771c9d4abe6e9a6f8d829b2736)
1 // Copyright 2020 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <openssl/evp.h>
6 #include <openssl/mem.h>
7 
8 #include <atomic>
9 #include <chrono>
10 
11 #include "cast/common/certificate/cast_trust_store.h"
12 #include "cast/common/certificate/testing/test_helpers.h"
13 #include "cast/common/channel/connection_namespace_handler.h"
14 #include "cast/common/channel/message_util.h"
15 #include "cast/common/channel/virtual_connection_router.h"
16 #include "cast/common/public/cast_socket.h"
17 #include "cast/receiver/channel/device_auth_namespace_handler.h"
18 #include "cast/receiver/channel/static_credentials.h"
19 #include "cast/receiver/public/receiver_socket_factory.h"
20 #include "cast/sender/public/sender_socket_factory.h"
21 #include "gmock/gmock.h"
22 #include "gtest/gtest.h"
23 #include "platform/api/serial_delete_ptr.h"
24 #include "platform/api/tls_connection_factory.h"
25 #include "platform/base/tls_connect_options.h"
26 #include "platform/base/tls_credentials.h"
27 #include "platform/base/tls_listen_options.h"
28 #include "platform/impl/logging.h"
29 #include "platform/impl/network_interface.h"
30 #include "platform/impl/platform_client_posix.h"
31 #include "testing/util/task_util.h"
32 #include "util/crypto/certificate_utils.h"
33 #include "util/osp_logging.h"
34 
35 namespace openscreen {
36 namespace cast {
37 namespace {
38 
39 using ::testing::_;
40 using ::testing::StrictMock;
41 
42 constexpr char kLogDecorator[] = "--- ";
43 
44 }  // namespace
45 
46 class SenderSocketsClient : public SenderSocketFactory::Client,
47                             public VirtualConnectionRouter::SocketErrorHandler {
48  public:
SenderSocketsClient(VirtualConnectionRouter * router)49   explicit SenderSocketsClient(VirtualConnectionRouter* router)  // NOLINT
50       : router_(router) {}
51   virtual ~SenderSocketsClient() = default;
52 
socket() const53   CastSocket* socket() const { return socket_; }
54 
55   // SenderSocketFactory::Client overrides.
OnConnected(SenderSocketFactory * factory,const IPEndpoint & endpoint,std::unique_ptr<CastSocket> socket)56   void OnConnected(SenderSocketFactory* factory,
57                    const IPEndpoint& endpoint,
58                    std::unique_ptr<CastSocket> socket) {
59     OSP_CHECK(!socket_);
60     OSP_LOG_INFO << kLogDecorator
61                  << "Sender connected to endpoint: " << endpoint;
62     socket_ = socket.get();
63     router_->TakeSocket(this, std::move(socket));
64   }
65 
OnError(SenderSocketFactory * factory,const IPEndpoint & endpoint,Error error)66   void OnError(SenderSocketFactory* factory,
67                const IPEndpoint& endpoint,
68                Error error) override {
69     OSP_LOG_FATAL << error;
70   }
71 
72   // VirtualConnectionRouter::SocketErrorHandler overrides.
OnClose(CastSocket * socket)73   void OnClose(CastSocket* socket) override {
74     socket_ = nullptr;
75     OnCloseMock(socket);
76   }
OnError(CastSocket * socket,Error error)77   void OnError(CastSocket* socket, Error error) override {
78     socket_ = nullptr;
79     OnErrorMock(socket, std::move(error));
80   }
81 
82   MOCK_METHOD(void, OnCloseMock, (CastSocket * socket), ());
83   MOCK_METHOD(void, OnErrorMock, (CastSocket * socket, Error error), ());
84 
85  private:
86   VirtualConnectionRouter* const router_;
87   std::atomic<CastSocket*> socket_{nullptr};
88 };
89 
90 class ReceiverSocketsClient
91     : public ReceiverSocketFactory::Client,
92       public VirtualConnectionRouter::SocketErrorHandler {
93  public:
ReceiverSocketsClient(VirtualConnectionRouter * router)94   explicit ReceiverSocketsClient(VirtualConnectionRouter* router)
95       : router_(router) {}
96   virtual ~ReceiverSocketsClient() = default;
97 
endpoint() const98   const IPEndpoint& endpoint() const { return endpoint_; }
socket() const99   CastSocket* socket() const { return socket_; }
100 
101   // ReceiverSocketFactory::Client overrides.
OnConnected(ReceiverSocketFactory * factory,const IPEndpoint & endpoint,std::unique_ptr<CastSocket> socket)102   void OnConnected(ReceiverSocketFactory* factory,
103                    const IPEndpoint& endpoint,
104                    std::unique_ptr<CastSocket> socket) override {
105     OSP_CHECK(!socket_);
106     OSP_LOG_INFO << kLogDecorator
107                  << "Receiver got connection from endpoint: " << endpoint;
108     endpoint_ = endpoint;
109     socket_ = socket.get();
110     router_->TakeSocket(this, std::move(socket));
111   }
112 
OnError(ReceiverSocketFactory * factory,Error error)113   void OnError(ReceiverSocketFactory* factory, Error error) override {
114     OSP_LOG_FATAL << error;
115   }
116 
117   // VirtualConnectionRouter::SocketErrorHandler overrides.
OnClose(CastSocket * socket)118   void OnClose(CastSocket* socket) override {
119     socket_ = nullptr;
120     OnCloseMock(socket);
121   }
OnError(CastSocket * socket,Error error)122   void OnError(CastSocket* socket, Error error) override {
123     socket_ = nullptr;
124     OnErrorMock(socket, std::move(error));
125   }
126 
127   MOCK_METHOD(void, OnCloseMock, (CastSocket * socket), ());
128   MOCK_METHOD(void, OnErrorMock, (CastSocket * socket, Error error), ());
129 
130  private:
131   VirtualConnectionRouter* router_;
132   IPEndpoint endpoint_;
133   std::atomic<CastSocket*> socket_{nullptr};
134 };
135 
136 class CastSocketE2ETest : public ::testing::Test {
137  public:
SetUp()138   void SetUp() override {
139     PlatformClientPosix::Create(std::chrono::milliseconds(10));
140     task_runner_ = PlatformClientPosix::GetInstance()->GetTaskRunner();
141 
142     sender_router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_);
143     sender_client_ =
144         std::make_unique<StrictMock<SenderSocketsClient>>(sender_router_.get());
145     sender_factory_ = MakeSerialDelete<SenderSocketFactory>(
146         task_runner_, sender_client_.get(), task_runner_);
147     sender_tls_factory_ = SerialDeletePtr<TlsConnectionFactory>(
148         task_runner_,
149         TlsConnectionFactory::CreateFactory(sender_factory_.get(), task_runner_)
150             .release());
151     sender_factory_->set_factory(sender_tls_factory_.get());
152 
153     ErrorOr<GeneratedCredentials> creds =
154         GenerateCredentialsForTesting("Device ID");
155     ASSERT_TRUE(creds.is_value());
156     credentials_ = std::move(creds.value());
157 
158     CastTrustStore::CreateInstanceForTest(credentials_.root_cert_der);
159     auth_handler_ = MakeSerialDelete<DeviceAuthNamespaceHandler>(
160         task_runner_, credentials_.provider.get());
161     receiver_router_ = MakeSerialDelete<VirtualConnectionRouter>(task_runner_);
162     receiver_router_->AddHandlerForLocalId(kPlatformReceiverId,
163                                            auth_handler_.get());
164     receiver_client_ = std::make_unique<StrictMock<ReceiverSocketsClient>>(
165         receiver_router_.get());
166     receiver_factory_ = MakeSerialDelete<ReceiverSocketFactory>(
167         task_runner_, receiver_client_.get(), receiver_router_.get());
168 
169     receiver_tls_factory_ = SerialDeletePtr<TlsConnectionFactory>(
170         task_runner_, TlsConnectionFactory::CreateFactory(
171                           receiver_factory_.get(), task_runner_)
172                           .release());
173   }
174 
TearDown()175   void TearDown() override {
176     OSP_LOG_INFO << "Shutting down";
177     sender_router_.reset();
178     receiver_router_.reset();
179     receiver_tls_factory_.reset();
180     receiver_factory_.reset();
181     auth_handler_.reset();
182     sender_tls_factory_.reset();
183     sender_factory_.reset();
184     CastTrustStore::ResetInstance();
185     PlatformClientPosix::ShutDown();
186   }
187 
188  protected:
GetLoopbackV4Address()189   IPAddress GetLoopbackV4Address() {
190     absl::optional<InterfaceInfo> loopback = GetLoopbackInterfaceForTesting();
191     OSP_CHECK(loopback);
192     IPAddress address = loopback->GetIpAddressV4();
193     OSP_CHECK(address);
194     return address;
195   }
196 
GetLoopbackV6Address()197   IPAddress GetLoopbackV6Address() {
198     absl::optional<InterfaceInfo> loopback = GetLoopbackInterfaceForTesting();
199     OSP_CHECK(loopback);
200     IPAddress address = loopback->GetIpAddressV6();
201     return address;
202   }
203 
Connect(const IPAddress & address)204   void Connect(const IPAddress& address) {
205     uint16_t port = 65321;
206     OSP_LOG_INFO << kLogDecorator << "Starting socket factories";
207     task_runner_->PostTask([this, &address, port]() {
208       OSP_LOG_INFO << kLogDecorator << "Receiver TLS factory Listen()";
209       receiver_tls_factory_->SetListenCredentials(credentials_.tls_credentials);
210       receiver_tls_factory_->Listen(IPEndpoint{address, port},
211                                     TlsListenOptions{1u});
212     });
213 
214     task_runner_->PostTask([this, &address, port]() {
215       OSP_LOG_INFO << kLogDecorator << "Sender CastSocket factory Connect()";
216       sender_factory_->Connect(IPEndpoint{address, port},
217                                SenderSocketFactory::DeviceMediaPolicy::kNone,
218                                sender_router_.get());
219     });
220 
221     WaitForCondition([this]() { return sender_client_->socket(); });
222   }
223 
ConnectSocketsV4()224   void ConnectSocketsV4() {
225     OSP_LOG_INFO << "Getting loopback IPv4 address";
226     IPAddress loopback_address = GetLoopbackV4Address();
227     OSP_LOG_INFO << "Connecting CastSockets";
228     Connect(loopback_address);
229   }
230 
231   template <typename SocketClient, typename PeerSocketClient>
CloseSocketsFromOneEnd(VirtualConnectionRouter * router,SocketClient * client,PeerSocketClient * peer_client)232   void CloseSocketsFromOneEnd(VirtualConnectionRouter* router,
233                               SocketClient* client,
234                               PeerSocketClient* peer_client) {
235     // TODO(issuetracker.google.com/169967989): Would like to have a symmetric
236     // OnClose check.
237     EXPECT_CALL(*client, OnCloseMock(client->socket()));
238     EXPECT_CALL(*peer_client, OnErrorMock(peer_client->socket(), _))
239         .WillOnce([](CastSocket* socket, Error error) {
240           EXPECT_EQ(error.code(), Error::Code::kSocketClosedFailure);
241         });
242     int32_t id = client->socket()->socket_id();
243     std::atomic_bool did_run{false};
244     task_runner_->PostTask([id, router, &did_run]() {
245       router->CloseSocket(id);
246       did_run = true;
247     });
248     OSP_LOG_INFO << "Waiting for socket to close";
249     WaitForCondition([&did_run]() { return did_run.load(); });
250     EXPECT_FALSE(sender_client_->socket());
251     EXPECT_FALSE(receiver_client_->socket());
252   }
253 
254   TaskRunner* task_runner_;
255 
256   // NOTE: Sender components.
257   SerialDeletePtr<VirtualConnectionRouter> sender_router_;
258   std::unique_ptr<StrictMock<SenderSocketsClient>> sender_client_;
259   SerialDeletePtr<SenderSocketFactory> sender_factory_;
260   SerialDeletePtr<TlsConnectionFactory> sender_tls_factory_;
261 
262   // NOTE: Receiver components.
263   SerialDeletePtr<VirtualConnectionRouter> receiver_router_;
264   GeneratedCredentials credentials_;
265   SerialDeletePtr<DeviceAuthNamespaceHandler> auth_handler_;
266   std::unique_ptr<StrictMock<ReceiverSocketsClient>> receiver_client_;
267   SerialDeletePtr<ReceiverSocketFactory> receiver_factory_;
268   SerialDeletePtr<TlsConnectionFactory> receiver_tls_factory_;
269 };
270 
271 // These test the most basic setup of a complete CastSocket.  This means
272 // constructing both a SenderSocketFactory and ReceiverSocketFactory, making a
273 // TLS connection to a known port over the loopback device, and checking device
274 // authentication.
// Connects the sender and receiver over the IPv4 loopback address. The test
// passes once the fixture helper returns.
TEST_F(CastSocketE2ETest, ConnectV4) {
  this->ConnectSocketsV4();
}
278 
// Connects the sender and receiver over the IPv6 loopback address. When the
// loopback interface has no IPv6 address, the test is reported as skipped
// rather than as a pass, since nothing was exercised.
TEST_F(CastSocketE2ETest, ConnectV6) {
  OSP_LOG_INFO << "Getting loopback IPv6 address";
  IPAddress loopback_address = GetLoopbackV6Address();
  if (!loopback_address) {
    GTEST_SKIP() << "Test skipped due to missing IPv6 loopback address";
  }
  OSP_LOG_INFO << "Connecting CastSockets";
  Connect(loopback_address);
}
289 
// Establishes an IPv4 loopback connection and then closes it from the sender
// side; the receiver is the peer that observes the closure.
TEST_F(CastSocketE2ETest, SenderClose) {
  ConnectSocketsV4();

  VirtualConnectionRouter* const closing_router = sender_router_.get();
  auto* const closing_client = sender_client_.get();
  auto* const observing_peer = receiver_client_.get();
  CloseSocketsFromOneEnd(closing_router, closing_client, observing_peer);
}
296 
// Establishes an IPv4 loopback connection and then closes it from the receiver
// side; the sender is the peer that observes the closure.
TEST_F(CastSocketE2ETest, ReceiverClose) {
  ConnectSocketsV4();

  VirtualConnectionRouter* const closing_router = receiver_router_.get();
  auto* const closing_client = receiver_client_.get();
  auto* const observing_peer = sender_client_.get();
  CloseSocketsFromOneEnd(closing_router, closing_client, observing_peer);
}
303 
304 }  // namespace cast
305 }  // namespace openscreen
306