1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "discovery/mdns/mdns_reader.h"
6
7 #include <algorithm>
8 #include <utility>
9
10 #include "discovery/common/config.h"
11 #include "discovery/mdns/public/mdns_constants.h"
12 #include "util/osp_logging.h"
13
14 namespace openscreen {
15 namespace discovery {
16 namespace {
17
TryParseDnsType(uint16_t to_parse,DnsType * type)18 bool TryParseDnsType(uint16_t to_parse, DnsType* type) {
19 auto it = std::find(kSupportedDnsTypes.begin(), kSupportedDnsTypes.end(),
20 static_cast<DnsType>(to_parse));
21 if (it == kSupportedDnsTypes.end()) {
22 return false;
23 }
24
25 *type = *it;
26 return true;
27 }
28
29 } // namespace
30
MdnsReader(const Config & config,const uint8_t * buffer,size_t length)31 MdnsReader::MdnsReader(const Config& config,
32 const uint8_t* buffer,
33 size_t length)
34 : BigEndianReader(buffer, length),
35 kMaximumAllowedRdataSize(
36 static_cast<size_t>(config.maximum_valid_rdata_size)) {
37 // TODO(rwkeane): Validate |maximum_valid_rdata_size| > MaxWireSize() for
38 // rdata types A, AAAA, SRV, PTR.
39 OSP_DCHECK_GT(config.maximum_valid_rdata_size, 0);
40 }
41
Read(TxtRecordRdata::Entry * out)42 bool MdnsReader::Read(TxtRecordRdata::Entry* out) {
43 Cursor cursor(this);
44 uint8_t entry_length;
45 if (!Read(&entry_length)) {
46 return false;
47 }
48 const uint8_t* entry_begin = current();
49 if (!Skip(entry_length)) {
50 return false;
51 }
52 out->reserve(entry_length);
53 out->insert(out->end(), entry_begin, entry_begin + entry_length);
54 cursor.Commit();
55 return true;
56 }
57
58 // RFC 1035: https://www.ietf.org/rfc/rfc1035.txt
59 // See section 4.1.4. Message compression.
Read(DomainName * out)60 bool MdnsReader::Read(DomainName* out) {
61 OSP_DCHECK(out);
62 const uint8_t* position = current();
63 // The number of bytes consumed reading from the starting position to either
64 // the first label pointer or the final termination byte, including the
65 // pointer or the termination byte. This is equal to the actual wire size of
66 // the DomainName accounting for compression.
67 size_t bytes_consumed = 0;
68 // The number of bytes that was processed when reading the DomainName,
69 // including all label pointers and direct labels. It is used to detect
70 // circular compression. The number of processed bytes cannot be possibly
71 // greater than the length of the buffer.
72 size_t bytes_processed = 0;
73 size_t domain_name_length = 0;
74 std::vector<absl::string_view> labels;
75 // If we are pointing before the beginning or past the end of the buffer, we
76 // hit a malformed pointer. If we have processed more bytes than there are in
77 // the buffer, we are in a circular compression loop.
78 while (position >= begin() && position < end() &&
79 bytes_processed <= length()) {
80 const uint8_t label_type = ReadBigEndian<uint8_t>(position);
81 if (IsTerminationLabel(label_type)) {
82 ErrorOr<DomainName> domain =
83 DomainName::TryCreate(labels.begin(), labels.end());
84 if (domain.is_error()) {
85 return false;
86 }
87 *out = std::move(domain.value());
88 if (!bytes_consumed) {
89 bytes_consumed = position + sizeof(uint8_t) - current();
90 }
91 return Skip(bytes_consumed);
92 } else if (IsPointerLabel(label_type)) {
93 if (position + sizeof(uint16_t) > end()) {
94 return false;
95 }
96 const uint16_t label_offset =
97 GetPointerLabelOffset(ReadBigEndian<uint16_t>(position));
98 if (!bytes_consumed) {
99 bytes_consumed = position + sizeof(uint16_t) - current();
100 }
101 bytes_processed += sizeof(uint16_t);
102 position = begin() + label_offset;
103 } else if (IsDirectLabel(label_type)) {
104 const uint8_t label_length = GetDirectLabelLength(label_type);
105 OSP_DCHECK_GT(label_length, 0);
106 bytes_processed += sizeof(uint8_t);
107 position += sizeof(uint8_t);
108 if (position + label_length >= end()) {
109 return false;
110 }
111 const absl::string_view label(reinterpret_cast<const char*>(position),
112 label_length);
113 domain_name_length += label_length + 1; // including the length byte
114 if (!IsValidDomainLabel(label) ||
115 domain_name_length > kMaxDomainNameLength) {
116 return false;
117 }
118 labels.push_back(label);
119 bytes_processed += label_length;
120 position += label_length;
121 } else {
122 return false;
123 }
124 }
125 return false;
126 }
127
Read(RawRecordRdata * out)128 bool MdnsReader::Read(RawRecordRdata* out) {
129 OSP_DCHECK(out);
130 Cursor cursor(this);
131 uint16_t record_length;
132 if (Read(&record_length)) {
133 if (record_length > kMaximumAllowedRdataSize) {
134 return false;
135 }
136
137 std::vector<uint8_t> buffer(record_length);
138 if (Read(buffer.size(), buffer.data())) {
139 ErrorOr<RawRecordRdata> rdata =
140 RawRecordRdata::TryCreate(std::move(buffer));
141 if (rdata.is_error()) {
142 return false;
143 }
144 *out = std::move(rdata.value());
145 cursor.Commit();
146 return true;
147 }
148 }
149 return false;
150 }
151
Read(SrvRecordRdata * out)152 bool MdnsReader::Read(SrvRecordRdata* out) {
153 OSP_DCHECK(out);
154 Cursor cursor(this);
155 uint16_t record_length;
156 uint16_t priority;
157 uint16_t weight;
158 uint16_t port;
159 DomainName target;
160 if (Read(&record_length) && Read(&priority) && Read(&weight) && Read(&port) &&
161 Read(&target) &&
162 (cursor.delta() == sizeof(record_length) + record_length)) {
163 *out = SrvRecordRdata(priority, weight, port, std::move(target));
164 cursor.Commit();
165 return true;
166 }
167 return false;
168 }
169
Read(ARecordRdata * out)170 bool MdnsReader::Read(ARecordRdata* out) {
171 OSP_DCHECK(out);
172 Cursor cursor(this);
173 uint16_t record_length;
174 IPAddress address;
175 if (Read(&record_length) && (record_length == IPAddress::kV4Size) &&
176 Read(IPAddress::Version::kV4, &address)) {
177 *out = ARecordRdata(address);
178 cursor.Commit();
179 return true;
180 }
181 return false;
182 }
183
Read(AAAARecordRdata * out)184 bool MdnsReader::Read(AAAARecordRdata* out) {
185 OSP_DCHECK(out);
186 Cursor cursor(this);
187 uint16_t record_length;
188 IPAddress address;
189 if (Read(&record_length) && (record_length == IPAddress::kV6Size) &&
190 Read(IPAddress::Version::kV6, &address)) {
191 *out = AAAARecordRdata(address);
192 cursor.Commit();
193 return true;
194 }
195 return false;
196 }
197
Read(PtrRecordRdata * out)198 bool MdnsReader::Read(PtrRecordRdata* out) {
199 OSP_DCHECK(out);
200 Cursor cursor(this);
201 uint16_t record_length;
202 DomainName ptr_domain;
203 if (Read(&record_length) && Read(&ptr_domain) &&
204 (cursor.delta() == sizeof(record_length) + record_length)) {
205 *out = PtrRecordRdata(std::move(ptr_domain));
206 cursor.Commit();
207 return true;
208 }
209 return false;
210 }
211
Read(TxtRecordRdata * out)212 bool MdnsReader::Read(TxtRecordRdata* out) {
213 OSP_DCHECK(out);
214 Cursor cursor(this);
215 uint16_t record_length;
216 if (!Read(&record_length)) {
217 return false;
218 }
219 if (record_length > kMaximumAllowedRdataSize) {
220 return false;
221 }
222 std::vector<TxtRecordRdata::Entry> texts;
223 while (cursor.delta() < sizeof(record_length) + record_length) {
224 TxtRecordRdata::Entry entry;
225 if (!Read(&entry)) {
226 return false;
227 }
228 OSP_DCHECK(entry.size() <= kTXTMaxEntrySize);
229 if (!entry.empty()) {
230 texts.emplace_back(entry);
231 }
232 }
233 if (cursor.delta() != sizeof(record_length) + record_length) {
234 return false;
235 }
236 ErrorOr<TxtRecordRdata> rdata = TxtRecordRdata::TryCreate(std::move(texts));
237 if (rdata.is_error()) {
238 return false;
239 }
240 *out = std::move(rdata.value());
241 cursor.Commit();
242 return true;
243 }
244
Read(NsecRecordRdata * out)245 bool MdnsReader::Read(NsecRecordRdata* out) {
246 OSP_DCHECK(out);
247 Cursor cursor(this);
248
249 const uint8_t* start_position = current();
250 uint16_t record_length;
251 DomainName next_record_name;
252 if (!Read(&record_length) || !Read(&next_record_name)) {
253 return false;
254 }
255 if (record_length > kMaximumAllowedRdataSize) {
256 return false;
257 }
258
259 // Calculate the next record name length. This may not be equal to the length
260 // of |next_record_name| due to domain name compression.
261 const int encoded_next_name_length =
262 current() - start_position - sizeof(record_length);
263 const int remaining_length = record_length - encoded_next_name_length;
264 if (remaining_length <= 0) {
265 // This means either the length is invalid or the NSEC record has no
266 // associated types.
267 return false;
268 }
269
270 std::vector<DnsType> types;
271 if (Read(&types, remaining_length)) {
272 *out = NsecRecordRdata(std::move(next_record_name), std::move(types));
273 cursor.Commit();
274 return true;
275 }
276
277 return false;
278 }
279
Read(MdnsRecord * out)280 bool MdnsReader::Read(MdnsRecord* out) {
281 OSP_DCHECK(out);
282 Cursor cursor(this);
283 DomainName name;
284 uint16_t type;
285 uint16_t rrclass;
286 uint32_t ttl;
287 Rdata rdata;
288 if (Read(&name) && Read(&type) && Read(&rrclass) && Read(&ttl) &&
289 Read(static_cast<DnsType>(type), &rdata)) {
290 ErrorOr<MdnsRecord> record = MdnsRecord::TryCreate(
291 std::move(name), static_cast<DnsType>(type), GetDnsClass(rrclass),
292 GetRecordType(rrclass), std::chrono::seconds(ttl), std::move(rdata));
293 if (record.is_error()) {
294 return false;
295 }
296 *out = std::move(record.value());
297
298 cursor.Commit();
299 return true;
300 }
301 return false;
302 }
303
Read(MdnsQuestion * out)304 bool MdnsReader::Read(MdnsQuestion* out) {
305 OSP_DCHECK(out);
306 Cursor cursor(this);
307 DomainName name;
308 uint16_t type;
309 uint16_t rrclass;
310 if (Read(&name) && Read(&type) && Read(&rrclass)) {
311 ErrorOr<MdnsQuestion> question =
312 MdnsQuestion::TryCreate(std::move(name), static_cast<DnsType>(type),
313 GetDnsClass(rrclass), GetResponseType(rrclass));
314 if (question.is_error()) {
315 return false;
316 }
317 *out = std::move(question.value());
318
319 cursor.Commit();
320 return true;
321 }
322 return false;
323 }
324
Read()325 ErrorOr<MdnsMessage> MdnsReader::Read() {
326 MdnsMessage out;
327 Cursor cursor(this);
328 Header header;
329 std::vector<MdnsQuestion> questions;
330 std::vector<MdnsRecord> answers;
331 std::vector<MdnsRecord> authority_records;
332 std::vector<MdnsRecord> additional_records;
333 if (Read(&header) && Read(header.question_count, &questions) &&
334 Read(header.answer_count, &answers) &&
335 Read(header.authority_record_count, &authority_records) &&
336 Read(header.additional_record_count, &additional_records)) {
337 if (!IsValidFlagsSection(header.flags)) {
338 return Error::Code::kMdnsNonConformingFailure;
339 }
340
341 ErrorOr<MdnsMessage> message = MdnsMessage::TryCreate(
342 header.id, GetMessageType(header.flags), questions, answers,
343 authority_records, additional_records);
344 if (message.is_error()) {
345 return std::move(message.error());
346 }
347 out = std::move(message.value());
348
349 if (IsMessageTruncated(header.flags)) {
350 out.set_truncated();
351 }
352
353 cursor.Commit();
354 return out;
355 }
356 return Error::Code::kMdnsReadFailure;
357 }
358
Read(IPAddress::Version version,IPAddress * out)359 bool MdnsReader::Read(IPAddress::Version version, IPAddress* out) {
360 OSP_DCHECK(out);
361 size_t ipaddress_size = (version == IPAddress::Version::kV6)
362 ? IPAddress::kV6Size
363 : IPAddress::kV4Size;
364 const uint8_t* const address_bytes = current();
365 if (Skip(ipaddress_size)) {
366 *out = IPAddress(version, address_bytes);
367 return true;
368 }
369 return false;
370 }
371
Read(DnsType type,Rdata * out)372 bool MdnsReader::Read(DnsType type, Rdata* out) {
373 OSP_DCHECK(out);
374 switch (type) {
375 case DnsType::kSRV:
376 return Read<SrvRecordRdata>(out);
377 case DnsType::kA:
378 return Read<ARecordRdata>(out);
379 case DnsType::kAAAA:
380 return Read<AAAARecordRdata>(out);
381 case DnsType::kPTR:
382 return Read<PtrRecordRdata>(out);
383 case DnsType::kTXT:
384 return Read<TxtRecordRdata>(out);
385 case DnsType::kNSEC:
386 return Read<NsecRecordRdata>(out);
387 default:
388 OSP_DCHECK(std::find(kSupportedDnsTypes.begin(), kSupportedDnsTypes.end(),
389 type) == kSupportedDnsTypes.end());
390 return Read<RawRecordRdata>(out);
391 }
392 }
393
Read(Header * out)394 bool MdnsReader::Read(Header* out) {
395 OSP_DCHECK(out);
396 Cursor cursor(this);
397 if (Read(&out->id) && Read(&out->flags) && Read(&out->question_count) &&
398 Read(&out->answer_count) && Read(&out->authority_record_count) &&
399 Read(&out->additional_record_count)) {
400 cursor.Commit();
401 return true;
402 }
403 return false;
404 }
405
Read(std::vector<DnsType> * out,int remaining_size)406 bool MdnsReader::Read(std::vector<DnsType>* out, int remaining_size) {
407 OSP_DCHECK(out);
408 Cursor cursor(this);
409
410 // Continue reading bitmaps until the entire input is read. If we have gone
411 // past the end of the record, it's malformed input so fail.
412 *out = std::vector<DnsType>();
413 int processed_bytes = 0;
414 while (processed_bytes < remaining_size) {
415 NsecBitMapField bitmap;
416 if (!Read(&bitmap)) {
417 return false;
418 }
419
420 processed_bytes += bitmap.bitmap_length + 2;
421 if (processed_bytes > remaining_size) {
422 return false;
423 }
424
425 // The ith bit of the bitmap represents DnsType with value i, shifted
426 // a multiple of 0x100 according to the window.
427 for (int32_t i = 0; i < bitmap.bitmap_length * 8; i++) {
428 int current_byte = i / 8;
429 uint8_t bitmask = 0x80 >> i % 8;
430
431 // If this bit flag represents a type we support, add it to the vector.
432 // Else, we won't be able to use it later on in the code anyway, so drop
433 // it.
434 DnsType type;
435 uint16_t type_index = i | (bitmap.window_block << 8);
436 if ((bitmap.bitmap[current_byte] & bitmask) &&
437 TryParseDnsType(type_index, &type)) {
438 out->push_back(type);
439 }
440 }
441 }
442
443 cursor.Commit();
444 return true;
445 }
446
Read(NsecBitMapField * out)447 bool MdnsReader::Read(NsecBitMapField* out) {
448 OSP_DCHECK(out);
449 Cursor cursor(this);
450
451 // Read the window and bitmap length, then one byte for each byte called out
452 // by the length.
453 if (Read(&out->window_block) && Read(&out->bitmap_length)) {
454 if (out->bitmap_length == 0 || out->bitmap_length > 32) {
455 return false;
456 }
457
458 out->bitmap = current();
459 if (!Skip(out->bitmap_length)) {
460 return false;
461 }
462 cursor.Commit();
463 return true;
464 }
465
466 return false;
467 }
468
469 } // namespace discovery
470 } // namespace openscreen
471