xref: /aosp_15_r20/external/webrtc/media/sctp/dcsctp_transport.h (revision d9f758449e529ab9291ac668be2861e7a55c2422)
1 /*
2  *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #ifndef MEDIA_SCTP_DCSCTP_TRANSPORT_H_
12 #define MEDIA_SCTP_DCSCTP_TRANSPORT_H_
13 
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>

#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "api/array_view.h"
#include "api/task_queue/task_queue_base.h"
#include "media/sctp/sctp_transport_internal.h"
#include "net/dcsctp/public/dcsctp_options.h"
#include "net/dcsctp/public/dcsctp_socket.h"
#include "net/dcsctp/public/dcsctp_socket_factory.h"
#include "net/dcsctp/public/types.h"
#include "net/dcsctp/timer/task_queue_timeout.h"
#include "p2p/base/packet_transport_internal.h"
#include "rtc_base/containers/flat_map.h"
#include "rtc_base/copy_on_write_buffer.h"
#include "rtc_base/random.h"
#include "rtc_base/third_party/sigslot/sigslot.h"
#include "rtc_base/thread.h"
#include "rtc_base/thread_annotations.h"
#include "system_wrappers/include/clock.h"
35 
36 namespace webrtc {
37 
38 class DcSctpTransport : public cricket::SctpTransportInternal,
39                         public dcsctp::DcSctpSocketCallbacks,
40                         public sigslot::has_slots<> {
41  public:
42   DcSctpTransport(rtc::Thread* network_thread,
43                   rtc::PacketTransportInternal* transport,
44                   Clock* clock);
45   DcSctpTransport(rtc::Thread* network_thread,
46                   rtc::PacketTransportInternal* transport,
47                   Clock* clock,
48                   std::unique_ptr<dcsctp::DcSctpSocketFactory> socket_factory);
49   ~DcSctpTransport() override;
50 
51   // cricket::SctpTransportInternal
52   void SetOnConnectedCallback(std::function<void()> callback) override;
53   void SetDataChannelSink(DataChannelSink* sink) override;
54   void SetDtlsTransport(rtc::PacketTransportInternal* transport) override;
55   bool Start(int local_sctp_port,
56              int remote_sctp_port,
57              int max_message_size) override;
58   bool OpenStream(int sid) override;
59   bool ResetStream(int sid) override;
60   bool SendData(int sid,
61                 const SendDataParams& params,
62                 const rtc::CopyOnWriteBuffer& payload,
63                 cricket::SendDataResult* result = nullptr) override;
64   bool ReadyToSendData() override;
65   int max_message_size() const override;
66   absl::optional<int> max_outbound_streams() const override;
67   absl::optional<int> max_inbound_streams() const override;
68   void set_debug_name_for_testing(const char* debug_name) override;
69 
70  private:
71   // dcsctp::DcSctpSocketCallbacks
72   dcsctp::SendPacketStatus SendPacketWithStatus(
73       rtc::ArrayView<const uint8_t> data) override;
74   std::unique_ptr<dcsctp::Timeout> CreateTimeout(
75       webrtc::TaskQueueBase::DelayPrecision precision) override;
76   dcsctp::TimeMs TimeMillis() override;
77   uint32_t GetRandomInt(uint32_t low, uint32_t high) override;
78   void OnTotalBufferedAmountLow() override;
79   void OnMessageReceived(dcsctp::DcSctpMessage message) override;
80   void OnError(dcsctp::ErrorKind error, absl::string_view message) override;
81   void OnAborted(dcsctp::ErrorKind error, absl::string_view message) override;
82   void OnConnected() override;
83   void OnClosed() override;
84   void OnConnectionRestarted() override;
85   void OnStreamsResetFailed(
86       rtc::ArrayView<const dcsctp::StreamID> outgoing_streams,
87       absl::string_view reason) override;
88   void OnStreamsResetPerformed(
89       rtc::ArrayView<const dcsctp::StreamID> outgoing_streams) override;
90   void OnIncomingStreamsReset(
91       rtc::ArrayView<const dcsctp::StreamID> incoming_streams) override;
92 
93   // Transport callbacks
94   void ConnectTransportSignals();
95   void DisconnectTransportSignals();
96   void OnTransportWritableState(rtc::PacketTransportInternal* transport);
97   void OnTransportReadPacket(rtc::PacketTransportInternal* transport,
98                              const char* data,
99                              size_t length,
100                              const int64_t& /* packet_time_us */,
101                              int flags);
102   void OnTransportClosed(rtc::PacketTransportInternal* transport);
103 
104   void MaybeConnectSocket();
105 
106   rtc::Thread* network_thread_;
107   rtc::PacketTransportInternal* transport_;
108   Clock* clock_;
109   Random random_;
110 
111   std::unique_ptr<dcsctp::DcSctpSocketFactory> socket_factory_;
112   dcsctp::TaskQueueTimeoutFactory task_queue_timeout_factory_;
113   std::unique_ptr<dcsctp::DcSctpSocketInterface> socket_;
114   std::string debug_name_ = "DcSctpTransport";
115   rtc::CopyOnWriteBuffer receive_buffer_;
116 
117   // Used to keep track of the state of data channels.
118   // Reset needs to happen both ways before signaling the transport
119   // is closed.
120   struct StreamState {
121     // True when the local connection has initiated the reset.
122     // If a connection receives a reset for a stream that isn't
123     // already being reset locally, it needs to fire the signal
124     // SignalClosingProcedureStartedRemotely.
125     bool closure_initiated = false;
126     // True when the local connection received OnIncomingStreamsReset
127     bool incoming_reset_done = false;
128     // True when the local connection received OnStreamsResetPerformed
129     bool outgoing_reset_done = false;
130   };
131 
132   // Map of all currently open or closing data channels
133   flat_map<dcsctp::StreamID, StreamState> stream_states_
134       RTC_GUARDED_BY(network_thread_);
135   bool ready_to_send_data_ = false;
136   std::function<void()> on_connected_callback_ RTC_GUARDED_BY(network_thread_);
137   DataChannelSink* data_channel_sink_ RTC_GUARDED_BY(network_thread_) = nullptr;
138 };
139 
140 }  // namespace webrtc
141 
142 #endif  // MEDIA_SCTP_DCSCTP_TRANSPORT_H_
143