xref: /aosp_15_r20/external/openscreen/cast/sender/channel/sender_socket_factory.cc (revision 3f982cf4871df8771c9d4abe6e9a6f8d829b2736)
1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "cast/sender/public/sender_socket_factory.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "cast/common/channel/proto/cast_channel.pb.h"
#include "cast/sender/channel/cast_auth_util.h"
#include "cast/sender/channel/message_util.h"
#include "platform/base/tls_connect_options.h"
#include "util/crypto/certificate_utils.h"
#include "util/osp_logging.h"
13 
14 using ::cast::channel::CastMessage;
15 
16 namespace openscreen {
17 namespace cast {
18 
19 SenderSocketFactory::Client::~Client() = default;
20 
operator <(const std::unique_ptr<SenderSocketFactory::PendingAuth> & a,int b)21 bool operator<(const std::unique_ptr<SenderSocketFactory::PendingAuth>& a,
22                int b) {
23   return a && a->socket->socket_id() < b;
24 }
25 
operator <(int a,const std::unique_ptr<SenderSocketFactory::PendingAuth> & b)26 bool operator<(int a,
27                const std::unique_ptr<SenderSocketFactory::PendingAuth>& b) {
28   return b && a < b->socket->socket_id();
29 }
30 
SenderSocketFactory(Client * client,TaskRunner * task_runner)31 SenderSocketFactory::SenderSocketFactory(Client* client,
32                                          TaskRunner* task_runner)
33     : client_(client), task_runner_(task_runner) {
34   OSP_DCHECK(client);
35   OSP_DCHECK(task_runner);
36 }
37 
~SenderSocketFactory()38 SenderSocketFactory::~SenderSocketFactory() {
39   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
40 }
41 
set_factory(TlsConnectionFactory * factory)42 void SenderSocketFactory::set_factory(TlsConnectionFactory* factory) {
43   OSP_DCHECK(factory);
44   factory_ = factory;
45 }
46 
Connect(const IPEndpoint & endpoint,DeviceMediaPolicy media_policy,CastSocket::Client * client)47 void SenderSocketFactory::Connect(const IPEndpoint& endpoint,
48                                   DeviceMediaPolicy media_policy,
49                                   CastSocket::Client* client) {
50   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
51   OSP_DCHECK(client);
52   auto it = FindPendingConnection(endpoint);
53   if (it == pending_connections_.end()) {
54     pending_connections_.emplace_back(
55         PendingConnection{endpoint, media_policy, client});
56     factory_->Connect(endpoint, TlsConnectOptions{true});
57   }
58 }
59 
OnAccepted(TlsConnectionFactory * factory,std::vector<uint8_t> der_x509_peer_cert,std::unique_ptr<TlsConnection> connection)60 void SenderSocketFactory::OnAccepted(
61     TlsConnectionFactory* factory,
62     std::vector<uint8_t> der_x509_peer_cert,
63     std::unique_ptr<TlsConnection> connection) {
64   OSP_NOTREACHED();
65   OSP_LOG_FATAL << "This factory is connect-only";
66 }
67 
OnConnected(TlsConnectionFactory * factory,std::vector<uint8_t> der_x509_peer_cert,std::unique_ptr<TlsConnection> connection)68 void SenderSocketFactory::OnConnected(
69     TlsConnectionFactory* factory,
70     std::vector<uint8_t> der_x509_peer_cert,
71     std::unique_ptr<TlsConnection> connection) {
72   const IPEndpoint& endpoint = connection->GetRemoteEndpoint();
73   auto it = FindPendingConnection(endpoint);
74   if (it == pending_connections_.end()) {
75     OSP_DLOG_ERROR << "TLS connection succeeded for unknown endpoint: "
76                    << endpoint;
77     return;
78   }
79   DeviceMediaPolicy media_policy = it->media_policy;
80   CastSocket::Client* client = it->client;
81   pending_connections_.erase(it);
82 
83   ErrorOr<bssl::UniquePtr<X509>> peer_cert =
84       ImportCertificate(der_x509_peer_cert.data(), der_x509_peer_cert.size());
85   if (!peer_cert) {
86     client_->OnError(this, endpoint, peer_cert.error());
87     return;
88   }
89 
90   auto socket =
91       MakeSerialDelete<CastSocket>(task_runner_, std::move(connection), this);
92   pending_auth_.emplace_back(
93       new PendingAuth{endpoint, media_policy, std::move(socket), client,
94                       std::make_unique<AuthContext>(AuthContext::Create()),
95                       std::move(peer_cert.value())});
96   PendingAuth& pending = *pending_auth_.back();
97 
98   CastMessage auth_challenge =
99       CreateAuthChallengeMessage(*pending.auth_context);
100   Error error = pending.socket->Send(auth_challenge);
101   if (!error.ok()) {
102     pending_auth_.pop_back();
103     client_->OnError(this, endpoint, error);
104   }
105 }
106 
OnConnectionFailed(TlsConnectionFactory * factory,const IPEndpoint & remote_address)107 void SenderSocketFactory::OnConnectionFailed(TlsConnectionFactory* factory,
108                                              const IPEndpoint& remote_address) {
109   auto it = FindPendingConnection(remote_address);
110   if (it == pending_connections_.end()) {
111     return;
112   }
113   pending_connections_.erase(it);
114   client_->OnError(this, remote_address, Error::Code::kConnectionFailed);
115 }
116 
OnError(TlsConnectionFactory * factory,Error error)117 void SenderSocketFactory::OnError(TlsConnectionFactory* factory, Error error) {
118   std::vector<PendingConnection> connections;
119   pending_connections_.swap(connections);
120   for (const PendingConnection& pending : connections) {
121     client_->OnError(this, pending.endpoint, error);
122   }
123 }
124 
125 std::vector<SenderSocketFactory::PendingConnection>::iterator
FindPendingConnection(const IPEndpoint & endpoint)126 SenderSocketFactory::FindPendingConnection(const IPEndpoint& endpoint) {
127   return std::find_if(pending_connections_.begin(), pending_connections_.end(),
128                       [&endpoint](const PendingConnection& pending) {
129                         return pending.endpoint == endpoint;
130                       });
131 }
132 
OnError(CastSocket * socket,Error error)133 void SenderSocketFactory::OnError(CastSocket* socket, Error error) {
134   auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(),
135                          [id = socket->socket_id()](
136                              const std::unique_ptr<PendingAuth>& pending_auth) {
137                            return pending_auth->socket->socket_id() == id;
138                          });
139   if (it == pending_auth_.end()) {
140     OSP_DLOG_ERROR << "Got error for unknown pending socket";
141     return;
142   }
143   IPEndpoint endpoint = (*it)->endpoint;
144   pending_auth_.erase(it);
145   client_->OnError(this, endpoint, error);
146 }
147 
OnMessage(CastSocket * socket,CastMessage message)148 void SenderSocketFactory::OnMessage(CastSocket* socket, CastMessage message) {
149   auto it = std::find_if(pending_auth_.begin(), pending_auth_.end(),
150                          [id = socket->socket_id()](
151                              const std::unique_ptr<PendingAuth>& pending_auth) {
152                            return pending_auth->socket->socket_id() == id;
153                          });
154   if (it == pending_auth_.end()) {
155     OSP_DLOG_ERROR << "Got message for unknown pending socket";
156     return;
157   }
158 
159   std::unique_ptr<PendingAuth> pending = std::move(*it);
160   pending_auth_.erase(it);
161   if (!IsAuthMessage(message)) {
162     client_->OnError(this, pending->endpoint,
163                      Error::Code::kCastV2AuthenticationError);
164     return;
165   }
166 
167   ErrorOr<CastDeviceCertPolicy> policy_or_error = AuthenticateChallengeReply(
168       message, pending->peer_cert.get(), *pending->auth_context);
169   if (policy_or_error.is_error()) {
170     OSP_DLOG_WARN << "Authentication failed for " << pending->endpoint
171                   << " with error: " << policy_or_error.error();
172     client_->OnError(this, pending->endpoint, policy_or_error.error());
173     return;
174   }
175 
176   if (policy_or_error.value() == CastDeviceCertPolicy::kAudioOnly &&
177       pending->media_policy == DeviceMediaPolicy::kIncludesVideo) {
178     client_->OnError(this, pending->endpoint,
179                      Error::Code::kCastV2ChannelPolicyMismatch);
180     return;
181   }
182   pending->socket->set_audio_only(policy_or_error.value() ==
183                                   CastDeviceCertPolicy::kAudioOnly);
184 
185   pending->socket->SetClient(pending->client);
186   client_->OnConnected(this, pending->endpoint,
187                        std::unique_ptr<CastSocket>(pending->socket.release()));
188 }
189 
190 }  // namespace cast
191 }  // namespace openscreen
192