xref: /aosp_15_r20/external/openscreen/discovery/mdns/mdns_reader.cc (revision 3f982cf4871df8771c9d4abe6e9a6f8d829b2736)
1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "discovery/mdns/mdns_reader.h"
6 
7 #include <algorithm>
8 #include <utility>
9 
10 #include "discovery/common/config.h"
11 #include "discovery/mdns/public/mdns_constants.h"
12 #include "util/osp_logging.h"
13 
14 namespace openscreen {
15 namespace discovery {
16 namespace {
17 
TryParseDnsType(uint16_t to_parse,DnsType * type)18 bool TryParseDnsType(uint16_t to_parse, DnsType* type) {
19   auto it = std::find(kSupportedDnsTypes.begin(), kSupportedDnsTypes.end(),
20                       static_cast<DnsType>(to_parse));
21   if (it == kSupportedDnsTypes.end()) {
22     return false;
23   }
24 
25   *type = *it;
26   return true;
27 }
28 
29 }  // namespace
30 
MdnsReader(const Config & config,const uint8_t * buffer,size_t length)31 MdnsReader::MdnsReader(const Config& config,
32                        const uint8_t* buffer,
33                        size_t length)
34     : BigEndianReader(buffer, length),
35       kMaximumAllowedRdataSize(
36           static_cast<size_t>(config.maximum_valid_rdata_size)) {
37   // TODO(rwkeane): Validate |maximum_valid_rdata_size| > MaxWireSize() for
38   // rdata types A, AAAA, SRV, PTR.
39   OSP_DCHECK_GT(config.maximum_valid_rdata_size, 0);
40 }
41 
Read(TxtRecordRdata::Entry * out)42 bool MdnsReader::Read(TxtRecordRdata::Entry* out) {
43   Cursor cursor(this);
44   uint8_t entry_length;
45   if (!Read(&entry_length)) {
46     return false;
47   }
48   const uint8_t* entry_begin = current();
49   if (!Skip(entry_length)) {
50     return false;
51   }
52   out->reserve(entry_length);
53   out->insert(out->end(), entry_begin, entry_begin + entry_length);
54   cursor.Commit();
55   return true;
56 }
57 
58 // RFC 1035: https://www.ietf.org/rfc/rfc1035.txt
59 // See section 4.1.4. Message compression.
Read(DomainName * out)60 bool MdnsReader::Read(DomainName* out) {
61   OSP_DCHECK(out);
62   const uint8_t* position = current();
63   // The number of bytes consumed reading from the starting position to either
64   // the first label pointer or the final termination byte, including the
65   // pointer or the termination byte. This is equal to the actual wire size of
66   // the DomainName accounting for compression.
67   size_t bytes_consumed = 0;
68   // The number of bytes that was processed when reading the DomainName,
69   // including all label pointers and direct labels. It is used to detect
70   // circular compression. The number of processed bytes cannot be possibly
71   // greater than the length of the buffer.
72   size_t bytes_processed = 0;
73   size_t domain_name_length = 0;
74   std::vector<absl::string_view> labels;
75   // If we are pointing before the beginning or past the end of the buffer, we
76   // hit a malformed pointer. If we have processed more bytes than there are in
77   // the buffer, we are in a circular compression loop.
78   while (position >= begin() && position < end() &&
79          bytes_processed <= length()) {
80     const uint8_t label_type = ReadBigEndian<uint8_t>(position);
81     if (IsTerminationLabel(label_type)) {
82       ErrorOr<DomainName> domain =
83           DomainName::TryCreate(labels.begin(), labels.end());
84       if (domain.is_error()) {
85         return false;
86       }
87       *out = std::move(domain.value());
88       if (!bytes_consumed) {
89         bytes_consumed = position + sizeof(uint8_t) - current();
90       }
91       return Skip(bytes_consumed);
92     } else if (IsPointerLabel(label_type)) {
93       if (position + sizeof(uint16_t) > end()) {
94         return false;
95       }
96       const uint16_t label_offset =
97           GetPointerLabelOffset(ReadBigEndian<uint16_t>(position));
98       if (!bytes_consumed) {
99         bytes_consumed = position + sizeof(uint16_t) - current();
100       }
101       bytes_processed += sizeof(uint16_t);
102       position = begin() + label_offset;
103     } else if (IsDirectLabel(label_type)) {
104       const uint8_t label_length = GetDirectLabelLength(label_type);
105       OSP_DCHECK_GT(label_length, 0);
106       bytes_processed += sizeof(uint8_t);
107       position += sizeof(uint8_t);
108       if (position + label_length >= end()) {
109         return false;
110       }
111       const absl::string_view label(reinterpret_cast<const char*>(position),
112                                     label_length);
113       domain_name_length += label_length + 1;  // including the length byte
114       if (!IsValidDomainLabel(label) ||
115           domain_name_length > kMaxDomainNameLength) {
116         return false;
117       }
118       labels.push_back(label);
119       bytes_processed += label_length;
120       position += label_length;
121     } else {
122       return false;
123     }
124   }
125   return false;
126 }
127 
Read(RawRecordRdata * out)128 bool MdnsReader::Read(RawRecordRdata* out) {
129   OSP_DCHECK(out);
130   Cursor cursor(this);
131   uint16_t record_length;
132   if (Read(&record_length)) {
133     if (record_length > kMaximumAllowedRdataSize) {
134       return false;
135     }
136 
137     std::vector<uint8_t> buffer(record_length);
138     if (Read(buffer.size(), buffer.data())) {
139       ErrorOr<RawRecordRdata> rdata =
140           RawRecordRdata::TryCreate(std::move(buffer));
141       if (rdata.is_error()) {
142         return false;
143       }
144       *out = std::move(rdata.value());
145       cursor.Commit();
146       return true;
147     }
148   }
149   return false;
150 }
151 
Read(SrvRecordRdata * out)152 bool MdnsReader::Read(SrvRecordRdata* out) {
153   OSP_DCHECK(out);
154   Cursor cursor(this);
155   uint16_t record_length;
156   uint16_t priority;
157   uint16_t weight;
158   uint16_t port;
159   DomainName target;
160   if (Read(&record_length) && Read(&priority) && Read(&weight) && Read(&port) &&
161       Read(&target) &&
162       (cursor.delta() == sizeof(record_length) + record_length)) {
163     *out = SrvRecordRdata(priority, weight, port, std::move(target));
164     cursor.Commit();
165     return true;
166   }
167   return false;
168 }
169 
Read(ARecordRdata * out)170 bool MdnsReader::Read(ARecordRdata* out) {
171   OSP_DCHECK(out);
172   Cursor cursor(this);
173   uint16_t record_length;
174   IPAddress address;
175   if (Read(&record_length) && (record_length == IPAddress::kV4Size) &&
176       Read(IPAddress::Version::kV4, &address)) {
177     *out = ARecordRdata(address);
178     cursor.Commit();
179     return true;
180   }
181   return false;
182 }
183 
Read(AAAARecordRdata * out)184 bool MdnsReader::Read(AAAARecordRdata* out) {
185   OSP_DCHECK(out);
186   Cursor cursor(this);
187   uint16_t record_length;
188   IPAddress address;
189   if (Read(&record_length) && (record_length == IPAddress::kV6Size) &&
190       Read(IPAddress::Version::kV6, &address)) {
191     *out = AAAARecordRdata(address);
192     cursor.Commit();
193     return true;
194   }
195   return false;
196 }
197 
Read(PtrRecordRdata * out)198 bool MdnsReader::Read(PtrRecordRdata* out) {
199   OSP_DCHECK(out);
200   Cursor cursor(this);
201   uint16_t record_length;
202   DomainName ptr_domain;
203   if (Read(&record_length) && Read(&ptr_domain) &&
204       (cursor.delta() == sizeof(record_length) + record_length)) {
205     *out = PtrRecordRdata(std::move(ptr_domain));
206     cursor.Commit();
207     return true;
208   }
209   return false;
210 }
211 
Read(TxtRecordRdata * out)212 bool MdnsReader::Read(TxtRecordRdata* out) {
213   OSP_DCHECK(out);
214   Cursor cursor(this);
215   uint16_t record_length;
216   if (!Read(&record_length)) {
217     return false;
218   }
219   if (record_length > kMaximumAllowedRdataSize) {
220     return false;
221   }
222   std::vector<TxtRecordRdata::Entry> texts;
223   while (cursor.delta() < sizeof(record_length) + record_length) {
224     TxtRecordRdata::Entry entry;
225     if (!Read(&entry)) {
226       return false;
227     }
228     OSP_DCHECK(entry.size() <= kTXTMaxEntrySize);
229     if (!entry.empty()) {
230       texts.emplace_back(entry);
231     }
232   }
233   if (cursor.delta() != sizeof(record_length) + record_length) {
234     return false;
235   }
236   ErrorOr<TxtRecordRdata> rdata = TxtRecordRdata::TryCreate(std::move(texts));
237   if (rdata.is_error()) {
238     return false;
239   }
240   *out = std::move(rdata.value());
241   cursor.Commit();
242   return true;
243 }
244 
Read(NsecRecordRdata * out)245 bool MdnsReader::Read(NsecRecordRdata* out) {
246   OSP_DCHECK(out);
247   Cursor cursor(this);
248 
249   const uint8_t* start_position = current();
250   uint16_t record_length;
251   DomainName next_record_name;
252   if (!Read(&record_length) || !Read(&next_record_name)) {
253     return false;
254   }
255   if (record_length > kMaximumAllowedRdataSize) {
256     return false;
257   }
258 
259   // Calculate the next record name length. This may not be equal to the length
260   // of |next_record_name| due to domain name compression.
261   const int encoded_next_name_length =
262       current() - start_position - sizeof(record_length);
263   const int remaining_length = record_length - encoded_next_name_length;
264   if (remaining_length <= 0) {
265     // This means either the length is invalid or the NSEC record has no
266     // associated types.
267     return false;
268   }
269 
270   std::vector<DnsType> types;
271   if (Read(&types, remaining_length)) {
272     *out = NsecRecordRdata(std::move(next_record_name), std::move(types));
273     cursor.Commit();
274     return true;
275   }
276 
277   return false;
278 }
279 
Read(MdnsRecord * out)280 bool MdnsReader::Read(MdnsRecord* out) {
281   OSP_DCHECK(out);
282   Cursor cursor(this);
283   DomainName name;
284   uint16_t type;
285   uint16_t rrclass;
286   uint32_t ttl;
287   Rdata rdata;
288   if (Read(&name) && Read(&type) && Read(&rrclass) && Read(&ttl) &&
289       Read(static_cast<DnsType>(type), &rdata)) {
290     ErrorOr<MdnsRecord> record = MdnsRecord::TryCreate(
291         std::move(name), static_cast<DnsType>(type), GetDnsClass(rrclass),
292         GetRecordType(rrclass), std::chrono::seconds(ttl), std::move(rdata));
293     if (record.is_error()) {
294       return false;
295     }
296     *out = std::move(record.value());
297 
298     cursor.Commit();
299     return true;
300   }
301   return false;
302 }
303 
Read(MdnsQuestion * out)304 bool MdnsReader::Read(MdnsQuestion* out) {
305   OSP_DCHECK(out);
306   Cursor cursor(this);
307   DomainName name;
308   uint16_t type;
309   uint16_t rrclass;
310   if (Read(&name) && Read(&type) && Read(&rrclass)) {
311     ErrorOr<MdnsQuestion> question =
312         MdnsQuestion::TryCreate(std::move(name), static_cast<DnsType>(type),
313                                 GetDnsClass(rrclass), GetResponseType(rrclass));
314     if (question.is_error()) {
315       return false;
316     }
317     *out = std::move(question.value());
318 
319     cursor.Commit();
320     return true;
321   }
322   return false;
323 }
324 
Read()325 ErrorOr<MdnsMessage> MdnsReader::Read() {
326   MdnsMessage out;
327   Cursor cursor(this);
328   Header header;
329   std::vector<MdnsQuestion> questions;
330   std::vector<MdnsRecord> answers;
331   std::vector<MdnsRecord> authority_records;
332   std::vector<MdnsRecord> additional_records;
333   if (Read(&header) && Read(header.question_count, &questions) &&
334       Read(header.answer_count, &answers) &&
335       Read(header.authority_record_count, &authority_records) &&
336       Read(header.additional_record_count, &additional_records)) {
337     if (!IsValidFlagsSection(header.flags)) {
338       return Error::Code::kMdnsNonConformingFailure;
339     }
340 
341     ErrorOr<MdnsMessage> message = MdnsMessage::TryCreate(
342         header.id, GetMessageType(header.flags), questions, answers,
343         authority_records, additional_records);
344     if (message.is_error()) {
345       return std::move(message.error());
346     }
347     out = std::move(message.value());
348 
349     if (IsMessageTruncated(header.flags)) {
350       out.set_truncated();
351     }
352 
353     cursor.Commit();
354     return out;
355   }
356   return Error::Code::kMdnsReadFailure;
357 }
358 
Read(IPAddress::Version version,IPAddress * out)359 bool MdnsReader::Read(IPAddress::Version version, IPAddress* out) {
360   OSP_DCHECK(out);
361   size_t ipaddress_size = (version == IPAddress::Version::kV6)
362                               ? IPAddress::kV6Size
363                               : IPAddress::kV4Size;
364   const uint8_t* const address_bytes = current();
365   if (Skip(ipaddress_size)) {
366     *out = IPAddress(version, address_bytes);
367     return true;
368   }
369   return false;
370 }
371 
Read(DnsType type,Rdata * out)372 bool MdnsReader::Read(DnsType type, Rdata* out) {
373   OSP_DCHECK(out);
374   switch (type) {
375     case DnsType::kSRV:
376       return Read<SrvRecordRdata>(out);
377     case DnsType::kA:
378       return Read<ARecordRdata>(out);
379     case DnsType::kAAAA:
380       return Read<AAAARecordRdata>(out);
381     case DnsType::kPTR:
382       return Read<PtrRecordRdata>(out);
383     case DnsType::kTXT:
384       return Read<TxtRecordRdata>(out);
385     case DnsType::kNSEC:
386       return Read<NsecRecordRdata>(out);
387     default:
388       OSP_DCHECK(std::find(kSupportedDnsTypes.begin(), kSupportedDnsTypes.end(),
389                            type) == kSupportedDnsTypes.end());
390       return Read<RawRecordRdata>(out);
391   }
392 }
393 
Read(Header * out)394 bool MdnsReader::Read(Header* out) {
395   OSP_DCHECK(out);
396   Cursor cursor(this);
397   if (Read(&out->id) && Read(&out->flags) && Read(&out->question_count) &&
398       Read(&out->answer_count) && Read(&out->authority_record_count) &&
399       Read(&out->additional_record_count)) {
400     cursor.Commit();
401     return true;
402   }
403   return false;
404 }
405 
Read(std::vector<DnsType> * out,int remaining_size)406 bool MdnsReader::Read(std::vector<DnsType>* out, int remaining_size) {
407   OSP_DCHECK(out);
408   Cursor cursor(this);
409 
410   // Continue reading bitmaps until the entire input is read. If we have gone
411   // past the end of the record, it's malformed input so fail.
412   *out = std::vector<DnsType>();
413   int processed_bytes = 0;
414   while (processed_bytes < remaining_size) {
415     NsecBitMapField bitmap;
416     if (!Read(&bitmap)) {
417       return false;
418     }
419 
420     processed_bytes += bitmap.bitmap_length + 2;
421     if (processed_bytes > remaining_size) {
422       return false;
423     }
424 
425     // The ith bit of the bitmap represents DnsType with value i, shifted
426     // a multiple of 0x100 according to the window.
427     for (int32_t i = 0; i < bitmap.bitmap_length * 8; i++) {
428       int current_byte = i / 8;
429       uint8_t bitmask = 0x80 >> i % 8;
430 
431       // If this bit flag represents a type we support, add it to the vector.
432       // Else, we won't be able to use it later on in the code anyway, so drop
433       // it.
434       DnsType type;
435       uint16_t type_index = i | (bitmap.window_block << 8);
436       if ((bitmap.bitmap[current_byte] & bitmask) &&
437           TryParseDnsType(type_index, &type)) {
438         out->push_back(type);
439       }
440     }
441   }
442 
443   cursor.Commit();
444   return true;
445 }
446 
Read(NsecBitMapField * out)447 bool MdnsReader::Read(NsecBitMapField* out) {
448   OSP_DCHECK(out);
449   Cursor cursor(this);
450 
451   // Read the window and bitmap length, then one byte for each byte called out
452   // by the length.
453   if (Read(&out->window_block) && Read(&out->bitmap_length)) {
454     if (out->bitmap_length == 0 || out->bitmap_length > 32) {
455       return false;
456     }
457 
458     out->bitmap = current();
459     if (!Skip(out->bitmap_length)) {
460       return false;
461     }
462     cursor.Commit();
463     return true;
464   }
465 
466   return false;
467 }
468 
469 }  // namespace discovery
470 }  // namespace openscreen
471