1 /* 2 * Copyright (c) 2021 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 #ifndef NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_ 11 #define NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_ 12 13 #include <deque> 14 #include <memory> 15 #include <set> 16 #include <vector> 17 18 #include "api/array_view.h" 19 #include "api/task_queue/task_queue_base.h" 20 #include "net/dcsctp/public/dcsctp_socket.h" 21 22 namespace dcsctp { 23 namespace dcsctp_fuzzers { 24 25 // A fake timeout used during fuzzing. 26 class FuzzerTimeout : public Timeout { 27 public: FuzzerTimeout(std::set<TimeoutID> & active_timeouts)28 explicit FuzzerTimeout(std::set<TimeoutID>& active_timeouts) 29 : active_timeouts_(active_timeouts) {} 30 Start(DurationMs duration_ms,TimeoutID timeout_id)31 void Start(DurationMs duration_ms, TimeoutID timeout_id) override { 32 // Start is only allowed to be called on stopped or expired timeouts. 33 if (timeout_id_.has_value()) { 34 // It has been started before, but maybe it expired. Ensure that it's not 35 // running at least. 36 RTC_DCHECK(active_timeouts_.find(*timeout_id_) == active_timeouts_.end()); 37 } 38 timeout_id_ = timeout_id; 39 RTC_DCHECK(active_timeouts_.insert(timeout_id).second); 40 } 41 Stop()42 void Stop() override { 43 // Stop is only allowed to be called on active timeouts. Not stopped or 44 // expired. 45 RTC_DCHECK(timeout_id_.has_value()); 46 RTC_DCHECK(active_timeouts_.erase(*timeout_id_) == 1); 47 timeout_id_ = absl::nullopt; 48 } 49 50 // A set of all active timeouts, managed by `FuzzerCallbacks`. 51 std::set<TimeoutID>& active_timeouts_; 52 // If present, the timout is active and will expire reported as `timeout_id`. 53 absl::optional<TimeoutID> timeout_id_; 54 }; 55 56 class FuzzerCallbacks : public DcSctpSocketCallbacks { 57 public: 58 static constexpr int kRandomValue = 42; SendPacket(rtc::ArrayView<const uint8_t> data)59 void SendPacket(rtc::ArrayView<const uint8_t> data) override { 60 sent_packets_.emplace_back(std::vector<uint8_t>(data.begin(), data.end())); 61 } CreateTimeout(webrtc::TaskQueueBase::DelayPrecision precision)62 std::unique_ptr<Timeout> CreateTimeout( 63 webrtc::TaskQueueBase::DelayPrecision precision) override { 64 // The fuzzer timeouts don't implement |precision|. 65 return std::make_unique<FuzzerTimeout>(active_timeouts_); 66 } TimeMillis()67 TimeMs TimeMillis() override { return TimeMs(42); } GetRandomInt(uint32_t low,uint32_t high)68 uint32_t GetRandomInt(uint32_t low, uint32_t high) override { 69 return kRandomValue; 70 } OnMessageReceived(DcSctpMessage message)71 void OnMessageReceived(DcSctpMessage message) override {} OnError(ErrorKind error,absl::string_view message)72 void OnError(ErrorKind error, absl::string_view message) override {} OnAborted(ErrorKind error,absl::string_view message)73 void OnAborted(ErrorKind error, absl::string_view message) override {} OnConnected()74 void OnConnected() override {} OnClosed()75 void OnClosed() override {} OnConnectionRestarted()76 void OnConnectionRestarted() override {} OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams,absl::string_view reason)77 void OnStreamsResetFailed(rtc::ArrayView<const StreamID> outgoing_streams, 78 absl::string_view reason) override {} OnStreamsResetPerformed(rtc::ArrayView<const StreamID> outgoing_streams)79 void OnStreamsResetPerformed( 80 rtc::ArrayView<const StreamID> outgoing_streams) override {} OnIncomingStreamsReset(rtc::ArrayView<const StreamID> incoming_streams)81 void OnIncomingStreamsReset( 82 rtc::ArrayView<const StreamID> incoming_streams) override {} 83 ConsumeSentPacket()84 std::vector<uint8_t> ConsumeSentPacket() { 85 if (sent_packets_.empty()) { 86 return {}; 87 } 88 std::vector<uint8_t> ret = sent_packets_.front(); 89 sent_packets_.pop_front(); 90 return ret; 91 } 92 93 // Given an index among the active timeouts, will expire that one. ExpireTimeout(size_t index)94 absl::optional<TimeoutID> ExpireTimeout(size_t index) { 95 if (index < active_timeouts_.size()) { 96 auto it = active_timeouts_.begin(); 97 std::advance(it, index); 98 TimeoutID timeout_id = *it; 99 active_timeouts_.erase(it); 100 return timeout_id; 101 } 102 return absl::nullopt; 103 } 104 105 private: 106 // Needs to be ordered, to allow fuzzers to expire timers. 107 std::set<TimeoutID> active_timeouts_; 108 std::deque<std::vector<uint8_t>> sent_packets_; 109 }; 110 111 // Given some fuzzing `data` will send packets to the socket as well as calling 112 // API methods. 113 void FuzzSocket(DcSctpSocketInterface& socket, 114 FuzzerCallbacks& cb, 115 rtc::ArrayView<const uint8_t> data); 116 117 } // namespace dcsctp_fuzzers 118 } // namespace dcsctp 119 #endif // NET_DCSCTP_FUZZERS_DCSCTP_FUZZERS_H_ 120