1 // Copyright 2019 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef CAST_SENDER_PUBLIC_SENDER_SOCKET_FACTORY_H_ 6 #define CAST_SENDER_PUBLIC_SENDER_SOCKET_FACTORY_H_ 7 8 #include <openssl/x509.h> 9 10 #include <memory> 11 #include <set> 12 #include <utility> 13 #include <vector> 14 15 #include "cast/common/public/cast_socket.h" 16 #include "platform/api/serial_delete_ptr.h" 17 #include "platform/api/task_runner.h" 18 #include "platform/api/tls_connection_factory.h" 19 #include "platform/base/ip_address.h" 20 21 namespace openscreen { 22 namespace cast { 23 24 class AuthContext; 25 26 class SenderSocketFactory final : public TlsConnectionFactory::Client, 27 public CastSocket::Client { 28 public: 29 class Client { 30 public: 31 virtual void OnConnected(SenderSocketFactory* factory, 32 const IPEndpoint& endpoint, 33 std::unique_ptr<CastSocket> socket) = 0; 34 virtual void OnError(SenderSocketFactory* factory, 35 const IPEndpoint& endpoint, 36 Error error) = 0; 37 38 protected: 39 virtual ~Client(); 40 }; 41 42 enum class DeviceMediaPolicy { 43 kNone = 0, 44 kAudioOnly, 45 kIncludesVideo, 46 }; 47 48 // |client| and |task_runner| must outlive |this|. 49 SenderSocketFactory(Client* client, TaskRunner* task_runner); 50 ~SenderSocketFactory(); 51 52 // |factory| cannot be nullptr and must outlive |this|. 53 void set_factory(TlsConnectionFactory* factory); 54 55 // Begins connecting to a Cast device at |endpoint|. If a successful 56 // connection is made, including device authentication, the new CastSocket 57 // will be passed to |client_|'s OnConnected method. The new CastSocket will 58 // have its client set to |client|. If any part of the connection process 59 // fails, |client_|'s OnError method is called instead. This includes if the 60 // device's media policy, as determined by authentication, is audio-only and 61 // |media_policy| is kIncludesVideo. 62 void Connect(const IPEndpoint& endpoint, 63 DeviceMediaPolicy media_policy, 64 CastSocket::Client* client); 65 66 // TlsConnectionFactory::Client overrides. 67 void OnAccepted(TlsConnectionFactory* factory, 68 std::vector<uint8_t> der_x509_peer_cert, 69 std::unique_ptr<TlsConnection> connection) override; 70 void OnConnected(TlsConnectionFactory* factory, 71 std::vector<uint8_t> der_x509_peer_cert, 72 std::unique_ptr<TlsConnection> connection) override; 73 void OnConnectionFailed(TlsConnectionFactory* factory, 74 const IPEndpoint& remote_address) override; 75 void OnError(TlsConnectionFactory* factory, Error error) override; 76 77 private: 78 struct PendingConnection { 79 IPEndpoint endpoint; 80 DeviceMediaPolicy media_policy; 81 CastSocket::Client* client; 82 }; 83 84 struct PendingAuth { 85 IPEndpoint endpoint; 86 DeviceMediaPolicy media_policy; 87 SerialDeletePtr<CastSocket> socket; 88 CastSocket::Client* client; 89 std::unique_ptr<AuthContext> auth_context; 90 bssl::UniquePtr<X509> peer_cert; 91 }; 92 93 friend bool operator<(const std::unique_ptr<PendingAuth>& a, int b); 94 friend bool operator<(int a, const std::unique_ptr<PendingAuth>& b); 95 96 std::vector<PendingConnection>::iterator FindPendingConnection( 97 const IPEndpoint& endpoint); 98 99 // CastSocket::Client overrides. 100 void OnError(CastSocket* socket, Error error) override; 101 void OnMessage(CastSocket* socket, 102 ::cast::channel::CastMessage message) override; 103 104 Client* const client_; 105 TaskRunner* const task_runner_; 106 TlsConnectionFactory* factory_ = nullptr; 107 std::vector<PendingConnection> pending_connections_; 108 std::vector<std::unique_ptr<PendingAuth>> pending_auth_; 109 }; 110 111 } // namespace cast 112 } // namespace openscreen 113 114 #endif // CAST_SENDER_PUBLIC_SENDER_SOCKET_FACTORY_H_ 115