xref: /aosp_15_r20/external/webrtc/net/dcsctp/fuzzers/dcsctp_fuzzers.h (revision d9f758449e529ab9291ac668be2861e7a55c2422)
/*
 *  Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
10 #ifndef NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_
11 #define NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_
12 
#include <cstddef>
#include <cstdint>
#include <deque>
#include <iterator>
#include <memory>
#include <set>
#include <utility>
#include <vector>

#include "api/array_view.h"
#include "api/task_queue/task_queue_base.h"
#include "net/dcsctp/public/dcsctp_socket.h"
21 
22 namespace dcsctp {
23 namespace dcsctp_fuzzers {
24 
25 // A fake timeout used during fuzzing.
26 class FuzzerTimeout : public Timeout {
27  public:
FuzzerTimeout(std::set<TimeoutID> & active_timeouts)28   explicit FuzzerTimeout(std::set<TimeoutID>& active_timeouts)
29       : active_timeouts_(active_timeouts) {}
30 
Start(DurationMs duration_ms,TimeoutID timeout_id)31   void Start(DurationMs duration_ms, TimeoutID timeout_id) override {
32     // Start is only allowed to be called on stopped or expired timeouts.
33     if (timeout_id_.has_value()) {
34       // It has been started before, but maybe it expired. Ensure that it's not
35       // running at least.
36       RTC_DCHECK(active_timeouts_.find(*timeout_id_) == active_timeouts_.end());
37     }
38     timeout_id_ = timeout_id;
39     RTC_DCHECK(active_timeouts_.insert(timeout_id).second);
40   }
41 
Stop()42   void Stop() override {
43     // Stop is only allowed to be called on active timeouts. Not stopped or
44     // expired.
45     RTC_DCHECK(timeout_id_.has_value());
46     RTC_DCHECK(active_timeouts_.erase(*timeout_id_) == 1);
47     timeout_id_ = absl::nullopt;
48   }
49 
50   // A set of all active timeouts, managed by `FuzzerCallbacks`.
51   std::set<TimeoutID>& active_timeouts_;
52   // If present, the timout is active and will expire reported as `timeout_id`.
53   absl::optional<TimeoutID> timeout_id_;
54 };
55 
56 class FuzzerCallbacks : public DcSctpSocketCallbacks {
57  public:
58   static constexpr int kRandomValue = 42;
SendPacket(rtc::ArrayView<const uint8_t> data)59   void SendPacket(rtc::ArrayView<const uint8_t> data) override {
60     sent_packets_.emplace_back(std::vector<uint8_t>(data.begin(), data.end()));
61   }
CreateTimeout(webrtc::TaskQueueBase::DelayPrecision precision)62   std::unique_ptr<Timeout> CreateTimeout(
63       webrtc::TaskQueueBase::DelayPrecision precision) override {
64     // The fuzzer timeouts don't implement |precision|.
65     return std::make_unique<FuzzerTimeout>(active_timeouts_);
66   }
TimeMillis()67   TimeMs TimeMillis() override { return TimeMs(42); }
GetRandomInt(uint32_t low,uint32_t high)68   uint32_t GetRandomInt(uint32_t low, uint32_t high) override {
69     return kRandomValue;
70   }
OnMessageReceived(DcSctpMessage message)71   void OnMessageReceived(DcSctpMessage message) override {}
OnError(ErrorKind error,absl::string_view message)72   void OnError(ErrorKind error, absl::string_view message) override {}
OnAborted(ErrorKind error,absl::string_view message)73   void OnAborted(ErrorKind error, absl::string_view message) override {}
OnConnected()74   void OnConnected() override {}
OnClosed()75   void OnClosed() override {}
OnConnectionRestarted()76   void OnConnectionRestarted() override {}
OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams,absl::string_view reason)77   void OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams,
78                             absl::string_view reason) override {}
OnStreamsResetPerformed(rtc::ArrayView<const StreamID> outgoing_streams)79   void OnStreamsResetPerformed(
80       rtc::ArrayView<const StreamID> outgoing_streams) override {}
OnIncomingStreamsReset(rtc::ArrayView<const StreamID> incoming_streams)81   void OnIncomingStreamsReset(
82       rtc::ArrayView<const StreamID> incoming_streams) override {}
83 
ConsumeSentPacket()84   std::vector<uint8_t> ConsumeSentPacket() {
85     if (sent_packets_.empty()) {
86       return {};
87     }
88     std::vector<uint8_t> ret = sent_packets_.front();
89     sent_packets_.pop_front();
90     return ret;
91   }
92 
93   // Given an index among the active timeouts, will expire that one.
ExpireTimeout(size_t index)94   absl::optional<TimeoutID> ExpireTimeout(size_t index) {
95     if (index < active_timeouts_.size()) {
96       auto it = active_timeouts_.begin();
97       std::advance(it, index);
98       TimeoutID timeout_id = *it;
99       active_timeouts_.erase(it);
100       return timeout_id;
101     }
102     return absl::nullopt;
103   }
104 
105  private:
106   // Needs to be ordered, to allow fuzzers to expire timers.
107   std::set<TimeoutID> active_timeouts_;
108   std::deque<std::vector<uint8_t>> sent_packets_;
109 };
110 
111 // Given some fuzzing `data` will send packets to the socket as well as calling
112 // API methods.
113 void FuzzSocket(DcSctpSocketInterface& socket,
114                 FuzzerCallbacks& cb,
115                 rtc::ArrayView<const uint8_t> data);
116 
117 }  // namespace dcsctp_fuzzers
118 }  // namespace dcsctp
119 #endif  // NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_
120