1 // Copyright 2022 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/dns/opt_record_rdata.h"
6
7 #include <algorithm>
8 #include <memory>
9 #include <numeric>
10 #include <string_view>
11 #include <utility>
12
13 #include "base/big_endian.h"
14 #include "base/check_is_test.h"
15 #include "base/containers/contains.h"
16 #include "base/containers/span.h"
17 #include "base/containers/span_reader.h"
18 #include "base/containers/span_writer.h"
19 #include "base/memory/ptr_util.h"
20 #include "base/numerics/safe_conversions.h"
21 #include "base/strings/string_util.h"
22 #include "base/types/optional_util.h"
23 #include "net/dns/public/dns_protocol.h"
24
25 namespace net {
26
27 namespace {
// Encodes an Extended DNS Error option payload: the 16-bit INFO-CODE in
// big-endian (network) byte order, immediately followed by the raw bytes of
// `extra_text`.
std::string SerializeEdeOpt(uint16_t info_code, std::string_view extra_text) {
  std::string serialized;
  serialized.reserve(sizeof(info_code) + extra_text.size());
  // High byte first, then low byte.
  serialized.push_back(static_cast<char>(info_code >> 8));
  serialized.push_back(static_cast<char>(info_code & 0xFF));
  serialized.append(extra_text);
  return serialized;
}
37 } // namespace
38
Opt(std::string data)39 OptRecordRdata::Opt::Opt(std::string data) : data_(std::move(data)) {}
40
operator ==(const OptRecordRdata::Opt & other) const41 bool OptRecordRdata::Opt::operator==(const OptRecordRdata::Opt& other) const {
42 return IsEqual(other);
43 }
44
operator !=(const OptRecordRdata::Opt & other) const45 bool OptRecordRdata::Opt::operator!=(const OptRecordRdata::Opt& other) const {
46 return !IsEqual(other);
47 }
48
IsEqual(const OptRecordRdata::Opt & other) const49 bool OptRecordRdata::Opt::IsEqual(const OptRecordRdata::Opt& other) const {
50 return GetCode() == other.GetCode() && data() == other.data();
51 }
52
// Builds an Extended DNS Error option from an INFO-CODE and EXTRA-TEXT.
// The serialized payload (see SerializeEdeOpt) is stored in the Opt base. The
// base is always initialized before any member, so `extra_text` is serialized
// before it is moved into `extra_text_`.
// CHECK-fails if `extra_text` is not valid UTF-8; callers handling untrusted
// input must validate first (as Create() does).
OptRecordRdata::EdeOpt::EdeOpt(uint16_t info_code, std::string extra_text)
    : Opt(SerializeEdeOpt(info_code, extra_text)),
      info_code_(info_code),
      extra_text_(std::move(extra_text)) {
  CHECK(base::IsStringUTF8(extra_text_));
}
59
// Defaulted destructor, defined out of line in this file.
OptRecordRdata::EdeOpt::~EdeOpt() = default;
61
Create(std::string data)62 std::unique_ptr<OptRecordRdata::EdeOpt> OptRecordRdata::EdeOpt::Create(
63 std::string data) {
64 uint16_t info_code;
65 auto edeReader = base::SpanReader(base::as_byte_span(data));
66
67 // size must be at least 2: info_code + optional extra_text
68 base::span<const uint8_t> extra_text;
69 if (!edeReader.ReadU16BigEndian(info_code) ||
70 !base::OptionalUnwrapTo(edeReader.Read(edeReader.remaining()),
71 extra_text)) {
72 return nullptr;
73 }
74
75 if (!base::IsStringUTF8(base::as_string_view(extra_text))) {
76 return nullptr;
77 }
78
79 return std::make_unique<EdeOpt>(
80 info_code, std::string(base::as_string_view(extra_text)));
81 }
82
GetCode() const83 uint16_t OptRecordRdata::EdeOpt::GetCode() const {
84 return EdeOpt::kOptCode;
85 }
86
87 OptRecordRdata::EdeOpt::EdeInfoCode
GetEnumFromInfoCode() const88 OptRecordRdata::EdeOpt::GetEnumFromInfoCode() const {
89 return GetEnumFromInfoCode(info_code_);
90 }
91
GetEnumFromInfoCode(uint16_t info_code)92 OptRecordRdata::EdeOpt::EdeInfoCode OptRecordRdata::EdeOpt::GetEnumFromInfoCode(
93 uint16_t info_code) {
94 switch (info_code) {
95 case 0:
96 return EdeInfoCode::kOtherError;
97 case 1:
98 return EdeInfoCode::kUnsupportedDnskeyAlgorithm;
99 case 2:
100 return EdeInfoCode::kUnsupportedDsDigestType;
101 case 3:
102 return EdeInfoCode::kStaleAnswer;
103 case 4:
104 return EdeInfoCode::kForgedAnswer;
105 case 5:
106 return EdeInfoCode::kDnssecIndeterminate;
107 case 6:
108 return EdeInfoCode::kDnssecBogus;
109 case 7:
110 return EdeInfoCode::kSignatureExpired;
111 case 8:
112 return EdeInfoCode::kSignatureNotYetValid;
113 case 9:
114 return EdeInfoCode::kDnskeyMissing;
115 case 10:
116 return EdeInfoCode::kRrsigsMissing;
117 case 11:
118 return EdeInfoCode::kNoZoneKeyBitSet;
119 case 12:
120 return EdeInfoCode::kNsecMissing;
121 case 13:
122 return EdeInfoCode::kCachedError;
123 case 14:
124 return EdeInfoCode::kNotReady;
125 case 15:
126 return EdeInfoCode::kBlocked;
127 case 16:
128 return EdeInfoCode::kCensored;
129 case 17:
130 return EdeInfoCode::kFiltered;
131 case 18:
132 return EdeInfoCode::kProhibited;
133 case 19:
134 return EdeInfoCode::kStaleNxdomainAnswer;
135 case 20:
136 return EdeInfoCode::kNotAuthoritative;
137 case 21:
138 return EdeInfoCode::kNotSupported;
139 case 22:
140 return EdeInfoCode::kNoReachableAuthority;
141 case 23:
142 return EdeInfoCode::kNetworkError;
143 case 24:
144 return EdeInfoCode::kInvalidData;
145 case 25:
146 return EdeInfoCode::kSignatureExpiredBeforeValid;
147 case 26:
148 return EdeInfoCode::kTooEarly;
149 case 27:
150 return EdeInfoCode::kUnsupportedNsec3IterationsValue;
151 default:
152 return EdeInfoCode::kUnrecognizedErrorCode;
153 }
154 }
155
PaddingOpt(std::string padding)156 OptRecordRdata::PaddingOpt::PaddingOpt(std::string padding)
157 : Opt(std::move(padding)) {}
158
PaddingOpt(uint16_t padding_len)159 OptRecordRdata::PaddingOpt::PaddingOpt(uint16_t padding_len)
160 : Opt(std::string(base::checked_cast<size_t>(padding_len), '\0')) {}
161
// Defaulted destructor, defined out of line in this file.
OptRecordRdata::PaddingOpt::~PaddingOpt() = default;
163
GetCode() const164 uint16_t OptRecordRdata::PaddingOpt::GetCode() const {
165 return PaddingOpt::kOptCode;
166 }
167
// Defaulted destructor, defined out of line in this file.
OptRecordRdata::UnknownOpt::~UnknownOpt() = default;
169
170 std::unique_ptr<OptRecordRdata::UnknownOpt>
CreateForTesting(uint16_t code,std::string data)171 OptRecordRdata::UnknownOpt::CreateForTesting(uint16_t code, std::string data) {
172 CHECK_IS_TEST();
173 return base::WrapUnique(
174 new OptRecordRdata::UnknownOpt(code, std::move(data)));
175 }
176
UnknownOpt(uint16_t code,std::string data)177 OptRecordRdata::UnknownOpt::UnknownOpt(uint16_t code, std::string data)
178 : Opt(std::move(data)), code_(code) {
179 CHECK(!base::Contains(kOptsWithDedicatedClasses, code));
180 }
181
GetCode() const182 uint16_t OptRecordRdata::UnknownOpt::GetCode() const {
183 return code_;
184 }
185
// Default-constructs an OPT RDATA holding no options; options can then be
// appended with AddOpt().
OptRecordRdata::OptRecordRdata() = default;

// Defaulted destructor, defined out of line in this file.
OptRecordRdata::~OptRecordRdata() = default;
189
operator ==(const OptRecordRdata & other) const190 bool OptRecordRdata::operator==(const OptRecordRdata& other) const {
191 return IsEqual(&other);
192 }
193
operator !=(const OptRecordRdata & other) const194 bool OptRecordRdata::operator!=(const OptRecordRdata& other) const {
195 return !IsEqual(&other);
196 }
197
198 // static
Create(std::string_view data)199 std::unique_ptr<OptRecordRdata> OptRecordRdata::Create(std::string_view data) {
200 auto rdata = std::make_unique<OptRecordRdata>();
201 rdata->buf_.assign(data.begin(), data.end());
202
203 auto reader = base::SpanReader(base::as_byte_span(data));
204 while (reader.remaining() > 0u) {
205 uint16_t opt_code, opt_data_size;
206 base::span<const uint8_t> opt_data;
207
208 if (!reader.ReadU16BigEndian(opt_code) ||
209 !reader.ReadU16BigEndian(opt_data_size) ||
210 !base::OptionalUnwrapTo(reader.Read(opt_data_size), opt_data)) {
211 return nullptr;
212 }
213
214 // After the Opt object has been parsed, parse the contents (the data)
215 // depending on the opt_code. The specific Opt subclasses all inherit from
216 // Opt. If an opt code does not have a matching Opt subclass, a simple Opt
217 // object will be created, and data won't be parsed.
218
219 std::unique_ptr<Opt> opt;
220
221 switch (opt_code) {
222 case dns_protocol::kEdnsPadding:
223 opt = std::make_unique<OptRecordRdata::PaddingOpt>(
224 std::string(base::as_string_view(opt_data)));
225 break;
226 case dns_protocol::kEdnsExtendedDnsError:
227 opt = OptRecordRdata::EdeOpt::Create(
228 std::string(base::as_string_view(opt_data)));
229 break;
230 default:
231 opt = base::WrapUnique(new OptRecordRdata::UnknownOpt(
232 opt_code, std::string(base::as_string_view(opt_data))));
233 break;
234 }
235
236 // Confirm that opt is not null, which would be the result of a failed
237 // parse.
238 if (!opt) {
239 return nullptr;
240 }
241
242 rdata->opts_.emplace(opt_code, std::move(opt));
243 }
244
245 return rdata;
246 }
247
Type() const248 uint16_t OptRecordRdata::Type() const {
249 return OptRecordRdata::kType;
250 }
251
IsEqual(const RecordRdata * other) const252 bool OptRecordRdata::IsEqual(const RecordRdata* other) const {
253 if (other->Type() != Type()) {
254 return false;
255 }
256 const OptRecordRdata* opt_other = static_cast<const OptRecordRdata*>(other);
257 return opt_other->buf_ == buf_;
258 }
259
AddOpt(std::unique_ptr<Opt> opt)260 void OptRecordRdata::AddOpt(std::unique_ptr<Opt> opt) {
261 std::string_view opt_data = opt->data();
262
263 // Resize buffer to accommodate new OPT.
264 const size_t orig_rdata_size = buf_.size();
265 buf_.resize(orig_rdata_size + Opt::kHeaderSize + opt_data.size());
266
267 // Start writing from the end of the existing rdata.
268 auto writer = base::SpanWriter(base::as_writable_byte_span(buf_));
269 CHECK(writer.Skip(orig_rdata_size));
270 bool success = writer.WriteU16BigEndian(opt->GetCode()) &&
271 writer.WriteU16BigEndian(opt_data.size()) &&
272 writer.Write(base::as_byte_span(opt_data));
273 DCHECK(success);
274
275 opts_.emplace(opt->GetCode(), std::move(opt));
276 }
277
ContainsOptCode(uint16_t opt_code) const278 bool OptRecordRdata::ContainsOptCode(uint16_t opt_code) const {
279 return base::Contains(opts_, opt_code);
280 }
281
GetOpts() const282 std::vector<const OptRecordRdata::Opt*> OptRecordRdata::GetOpts() const {
283 std::vector<const OptRecordRdata::Opt*> opts;
284 opts.reserve(OptCount());
285 for (const auto& elem : opts_) {
286 opts.push_back(elem.second.get());
287 }
288 return opts;
289 }
290
GetPaddingOpts() const291 std::vector<const OptRecordRdata::PaddingOpt*> OptRecordRdata::GetPaddingOpts()
292 const {
293 std::vector<const OptRecordRdata::PaddingOpt*> opts;
294 auto range = opts_.equal_range(dns_protocol::kEdnsPadding);
295 for (auto it = range.first; it != range.second; ++it) {
296 opts.push_back(static_cast<const PaddingOpt*>(it->second.get()));
297 }
298 return opts;
299 }
300
GetEdeOpts() const301 std::vector<const OptRecordRdata::EdeOpt*> OptRecordRdata::GetEdeOpts() const {
302 std::vector<const OptRecordRdata::EdeOpt*> opts;
303 auto range = opts_.equal_range(dns_protocol::kEdnsExtendedDnsError);
304 for (auto it = range.first; it != range.second; ++it) {
305 opts.push_back(static_cast<const EdeOpt*>(it->second.get()));
306 }
307 return opts;
308 }
309
310 } // namespace net
311