// Copyright 2012 The Chromium Authors // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/dns/dns_response.h" #include #include #include #include #include #include #include #include #include "base/big_endian.h" #include "base/containers/span.h" #include "base/containers/span_reader.h" #include "base/containers/span_writer.h" #include "base/logging.h" #include "base/numerics/safe_conversions.h" #include "base/strings/string_util.h" #include "base/sys_byteorder.h" #include "base/types/optional_util.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/dns/dns_names_util.h" #include "net/dns/dns_query.h" #include "net/dns/dns_response_result_extractor.h" #include "net/dns/dns_util.h" #include "net/dns/public/dns_protocol.h" #include "net/dns/record_rdata.h" namespace net { namespace { const size_t kHeaderSize = sizeof(dns_protocol::Header); const uint8_t kRcodeMask = 0xf; } // namespace DnsResourceRecord::DnsResourceRecord() = default; DnsResourceRecord::DnsResourceRecord(const DnsResourceRecord& other) : name(other.name), type(other.type), klass(other.klass), ttl(other.ttl), owned_rdata(other.owned_rdata) { if (!owned_rdata.empty()) rdata = owned_rdata; else rdata = other.rdata; } DnsResourceRecord::DnsResourceRecord(DnsResourceRecord&& other) : name(std::move(other.name)), type(other.type), klass(other.klass), ttl(other.ttl), owned_rdata(std::move(other.owned_rdata)) { if (!owned_rdata.empty()) rdata = owned_rdata; else rdata = other.rdata; } DnsResourceRecord::~DnsResourceRecord() = default; DnsResourceRecord& DnsResourceRecord::operator=( const DnsResourceRecord& other) { name = other.name; type = other.type; klass = other.klass; ttl = other.ttl; owned_rdata = other.owned_rdata; if (!owned_rdata.empty()) rdata = owned_rdata; else rdata = other.rdata; return *this; } DnsResourceRecord& DnsResourceRecord::operator=(DnsResourceRecord&& other) { name = std::move(other.name); type = other.type; klass = other.klass; ttl = other.ttl; owned_rdata = std::move(other.owned_rdata); if (!owned_rdata.empty()) rdata = owned_rdata; else rdata = other.rdata; return *this; } void DnsResourceRecord::SetOwnedRdata(std::string value) { DCHECK(!value.empty()); owned_rdata = std::move(value); rdata = owned_rdata; DCHECK_EQ(owned_rdata.data(), rdata.data()); } size_t DnsResourceRecord::CalculateRecordSize() const { bool has_final_dot = name.back() == '.'; // Depending on if |name| in the dotted format has the final dot for the root // domain or not, the corresponding wire data in the DNS domain name format is // 1 byte (with dot) or 2 bytes larger in size. See RFC 1035, Section 3.1 and // DNSDomainFromDot. return name.size() + (has_final_dot ? 1 : 2) + net::dns_protocol::kResourceRecordSizeInBytesWithoutNameAndRData + (owned_rdata.empty() ? rdata.size() : owned_rdata.size()); } DnsRecordParser::DnsRecordParser() = default; DnsRecordParser::DnsRecordParser(base::span packet, size_t offset, size_t num_records) : packet_(packet), num_records_(num_records), cur_(offset) { CHECK_LE(offset, packet_.size()); } DnsRecordParser::DnsRecordParser(const void* packet, size_t length, size_t offset, size_t num_records) : DnsRecordParser( // TODO(crbug.com/40284755): This span construction can not be sound // here. This DnsRecordParser constructor should be removed. UNSAFE_BUFFERS( base::span(static_cast(packet), length)), offset, num_records) {} unsigned DnsRecordParser::ReadName(const void* const vpos, std::string* out) const { static const char kAbortMsg[] = "Abort parsing of noncompliant DNS record."; CHECK_LE(packet_.data(), vpos); CHECK_LE(vpos, packet_.last(0u).data()); const size_t initial_offset = // SAFETY: `vpos` points into the span, as verified by the CHECKs above, // so subtracting the data pointer is well-defined and gives an offset // into the span. // // TODO(danakj): Since we need an offset anyway, no unsafe pointer usage // would be required, and fewer CHECKs, if this function took an offset // instead of a pointer. UNSAFE_BUFFERS(static_cast(vpos) - packet_.data()); if (initial_offset == packet_.size()) { return 0; } size_t offset = initial_offset; // Count number of seen bytes to detect loops. unsigned seen = 0u; // Remember how many bytes were consumed before first jump. unsigned consumed = 0u; // The length of the encoded name (sum of label octets and label lengths). // For context, RFC 1034 states that the total number of octets representing a // domain name (the sum of all label octets and label lengths) is limited to // 255. RFC 1035 introduces message compression as a way to reduce packet size // on the wire, not to increase the maximum domain name length. unsigned encoded_name_len = 0u; if (out) { out->clear(); out->reserve(dns_protocol::kMaxCharNameLength); } for (;;) { // The first two bits of the length give the type of the length. It's // either a direct length or a pointer to the remainder of the name. switch (packet_[offset] & dns_protocol::kLabelMask) { case dns_protocol::kLabelPointer: { if (packet_.size() < sizeof(uint16_t) || offset > packet_.size() - sizeof(uint16_t)) { VLOG(1) << kAbortMsg << " Truncated or missing label pointer."; return 0; } if (consumed == 0u) { consumed = offset - initial_offset + sizeof(uint16_t); if (!out) { return consumed; // If name is not stored, that's all we need. } } seen += sizeof(uint16_t); // If seen the whole packet, then we must be in a loop. if (seen > packet_.size()) { VLOG(1) << kAbortMsg << " Detected loop in label pointers."; return 0; } uint16_t new_offset = base::U16FromBigEndian(packet_.subspan(offset).first<2u>()); offset = new_offset & dns_protocol::kOffsetMask; if (offset >= packet_.size()) { VLOG(1) << kAbortMsg << " Label pointer points outside packet."; return 0; } break; } case dns_protocol::kLabelDirect: { uint8_t label_len = packet_[offset]; ++offset; // Note: root domain (".") is NOT included. if (label_len == 0) { if (consumed == 0) { consumed = offset - initial_offset; } // else we set |consumed| before first jump return consumed; } // Add one octet for the length and |label_len| for the number of // following octets. encoded_name_len += 1 + label_len; if (encoded_name_len > dns_protocol::kMaxNameLength) { VLOG(1) << kAbortMsg << " Name is too long."; return 0; } if (label_len >= packet_.size() - offset) { VLOG(1) << kAbortMsg << " Truncated or missing label."; return 0; // Truncated or missing label. } if (out) { if (!out->empty()) out->append("."); // TODO(danakj): Use append_range() in C++23. auto range = packet_.subspan(offset, label_len); out->append(range.begin(), range.end()); CHECK_LE(out->size(), dns_protocol::kMaxCharNameLength); } offset += label_len; seen += 1 + label_len; break; } default: // unhandled label type VLOG(1) << kAbortMsg << " Unhandled label type."; return 0; } } } bool DnsRecordParser::ReadRecord(DnsResourceRecord* out) { CHECK(!packet_.empty()); // Disallow parsing any more than the claimed number of records. if (num_records_parsed_ >= num_records_) return false; size_t consumed = ReadName(packet_.subspan(cur_).data(), &out->name); if (!consumed) { return false; } auto reader = base::SpanReader(packet_.subspan(cur_ + consumed)); uint16_t rdlen; if (reader.ReadU16BigEndian(out->type) && reader.ReadU16BigEndian(out->klass) && reader.ReadU32BigEndian(out->ttl) && // reader.ReadU16BigEndian(rdlen) && base::OptionalUnwrapTo(reader.Read(rdlen), out->rdata, [](auto span) { return base::as_string_view(span); })) { cur_ += consumed + 2u + 2u + 4u + 2u + rdlen; ++num_records_parsed_; return true; } return false; } bool DnsRecordParser::ReadQuestion(std::string& out_dotted_qname, uint16_t& out_qtype) { size_t consumed = ReadName(packet_.subspan(cur_).data(), &out_dotted_qname); if (!consumed) return false; if (consumed + 2 * sizeof(uint16_t) > packet_.size() - cur_) { return false; } out_qtype = base::U16FromBigEndian( packet_.subspan(cur_ + consumed).first()); cur_ += consumed + 2 * sizeof(uint16_t); // QTYPE + QCLASS return true; } DnsResponse::DnsResponse( uint16_t id, bool is_authoritative, const std::vector& answers, const std::vector& authority_records, const std::vector& additional_records, const std::optional& query, uint8_t rcode, bool validate_records, bool validate_names_as_internet_hostnames) { bool has_query = query.has_value(); dns_protocol::Header header; header.id = id; bool success = true; if (has_query) { success &= (id == query.value().id()); DCHECK(success); // DnsQuery only supports a single question. header.qdcount = 1; } header.flags |= dns_protocol::kFlagResponse; if (is_authoritative) header.flags |= dns_protocol::kFlagAA; DCHECK_EQ(0, rcode & ~kRcodeMask); header.flags |= rcode; header.ancount = answers.size(); header.nscount = authority_records.size(); header.arcount = additional_records.size(); // Response starts with the header and the question section (if any). size_t response_size = has_query ? sizeof(header) + query.value().question_size() : sizeof(header); // Add the size of all answers and additional records. auto do_accumulation = [](size_t cur_size, const DnsResourceRecord& record) { return cur_size + record.CalculateRecordSize(); }; response_size = std::accumulate(answers.begin(), answers.end(), response_size, do_accumulation); response_size = std::accumulate(authority_records.begin(), authority_records.end(), response_size, do_accumulation); response_size = std::accumulate(additional_records.begin(), additional_records.end(), response_size, do_accumulation); auto io_buffer = base::MakeRefCounted(response_size); auto writer = base::SpanWriter(base::as_writable_bytes(io_buffer->span())); success &= WriteHeader(&writer, header); DCHECK(success); if (has_query) { success &= WriteQuestion(&writer, query.value()); DCHECK(success); } // Start the Answer section. for (const auto& answer : answers) { success &= WriteAnswer(&writer, answer, query, validate_records, validate_names_as_internet_hostnames); DCHECK(success); } // Start the Authority section. for (const auto& record : authority_records) { success &= WriteRecord(&writer, record, validate_records, validate_names_as_internet_hostnames); DCHECK(success); } // Start the Additional section. for (const auto& record : additional_records) { success &= WriteRecord(&writer, record, validate_records, validate_names_as_internet_hostnames); DCHECK(success); } if (!success) { return; } io_buffer_ = io_buffer; io_buffer_size_ = response_size; // Ensure we don't have any remaining uninitialized bytes in the buffer. DCHECK_EQ(writer.remaining(), 0u); std::ranges::fill(writer.remaining_span(), uint8_t{0}); if (has_query) InitParse(io_buffer_size_, query.value()); else InitParseWithoutQuery(io_buffer_size_); } DnsResponse::DnsResponse() : io_buffer_(base::MakeRefCounted( dns_protocol::kMaxUDPSize + 1)), io_buffer_size_(dns_protocol::kMaxUDPSize + 1) {} DnsResponse::DnsResponse(scoped_refptr buffer, size_t size) : io_buffer_(std::move(buffer)), io_buffer_size_(size) {} DnsResponse::DnsResponse(size_t length) : io_buffer_(base::MakeRefCounted(length)), io_buffer_size_(length) {} DnsResponse::DnsResponse(const void* data, size_t length, size_t answer_offset) : io_buffer_(base::MakeRefCounted(length)), io_buffer_size_(length), parser_(io_buffer_->data(), length, answer_offset, std::numeric_limits::max()) { DCHECK(data); std::copy(static_cast(data), static_cast(data) + length, io_buffer_->data()); } // static DnsResponse DnsResponse::CreateEmptyNoDataResponse( uint16_t id, bool is_authoritative, base::span qname, uint16_t qtype) { return DnsResponse(id, is_authoritative, /*answers=*/{}, /*authority_records=*/{}, /*additional_records=*/{}, DnsQuery(id, qname, qtype)); } DnsResponse::DnsResponse(DnsResponse&& other) = default; DnsResponse& DnsResponse::operator=(DnsResponse&& other) = default; DnsResponse::~DnsResponse() = default; bool DnsResponse::InitParse(size_t nbytes, const DnsQuery& query) { const std::string_view question = query.question(); // Response includes question, it should be at least that size. if (nbytes < kHeaderSize + question.size() || nbytes > io_buffer_size_) { return false; } // At this point, it has been validated that the response is at least large // enough to read the ID field. id_available_ = true; // Match the query id. DCHECK(id()); if (id().value() != query.id()) return false; // Not a response? if ((base::NetToHost16(header()->flags) & dns_protocol::kFlagResponse) == 0) return false; // Match question count. if (base::NetToHost16(header()->qdcount) != 1) return false; // Match the question section. if (question != std::string_view(io_buffer_->data() + kHeaderSize, question.size())) { return false; } std::optional dotted_qname = dns_names_util::NetworkToDottedName(query.qname()); if (!dotted_qname.has_value()) return false; dotted_qnames_.push_back(std::move(dotted_qname).value()); qtypes_.push_back(query.qtype()); size_t num_records = base::NetToHost16(header()->ancount) + base::NetToHost16(header()->nscount) + base::NetToHost16(header()->arcount); // Construct the parser. Only allow parsing up to `num_records` records. If // more records are present in the buffer, it's just garbage extra data after // the formal end of the response and should be ignored. parser_ = DnsRecordParser(io_buffer_->data(), nbytes, kHeaderSize + question.size(), num_records); return true; } bool DnsResponse::InitParseWithoutQuery(size_t nbytes) { if (nbytes < kHeaderSize || nbytes > io_buffer_size_) { return false; } id_available_ = true; // Not a response? if ((base::NetToHost16(header()->flags) & dns_protocol::kFlagResponse) == 0) return false; size_t num_records = base::NetToHost16(header()->ancount) + base::NetToHost16(header()->nscount) + base::NetToHost16(header()->arcount); // Only allow parsing up to `num_records` records. If more records are present // in the buffer, it's just garbage extra data after the formal end of the // response and should be ignored. parser_ = DnsRecordParser(io_buffer_->data(), nbytes, kHeaderSize, num_records); unsigned qdcount = base::NetToHost16(header()->qdcount); for (unsigned i = 0; i < qdcount; ++i) { std::string dotted_qname; uint16_t qtype; if (!parser_.ReadQuestion(dotted_qname, qtype)) { parser_ = DnsRecordParser(); // Make parser invalid again. return false; } dotted_qnames_.push_back(std::move(dotted_qname)); qtypes_.push_back(qtype); } return true; } std::optional DnsResponse::id() const { if (!id_available_) return std::nullopt; return base::NetToHost16(header()->id); } bool DnsResponse::IsValid() const { return parser_.IsValid(); } uint16_t DnsResponse::flags() const { DCHECK(parser_.IsValid()); return base::NetToHost16(header()->flags) & ~(kRcodeMask); } uint8_t DnsResponse::rcode() const { DCHECK(parser_.IsValid()); return base::NetToHost16(header()->flags) & kRcodeMask; } unsigned DnsResponse::question_count() const { DCHECK(parser_.IsValid()); return base::NetToHost16(header()->qdcount); } unsigned DnsResponse::answer_count() const { DCHECK(parser_.IsValid()); return base::NetToHost16(header()->ancount); } unsigned DnsResponse::authority_count() const { DCHECK(parser_.IsValid()); return base::NetToHost16(header()->nscount); } unsigned DnsResponse::additional_answer_count() const { DCHECK(parser_.IsValid()); return base::NetToHost16(header()->arcount); } uint16_t DnsResponse::GetSingleQType() const { DCHECK_EQ(qtypes().size(), 1u); return qtypes().front(); } std::string_view DnsResponse::GetSingleDottedName() const { DCHECK_EQ(dotted_qnames().size(), 1u); return dotted_qnames().front(); } DnsRecordParser DnsResponse::Parser() const { DCHECK(parser_.IsValid()); // Return a copy of the parser. return parser_; } const dns_protocol::Header* DnsResponse::header() const { return reinterpret_cast(io_buffer_->data()); } bool DnsResponse::WriteHeader(base::SpanWriter* writer, const dns_protocol::Header& header) { return writer->WriteU16BigEndian(header.id) && writer->WriteU16BigEndian(header.flags) && writer->WriteU16BigEndian(header.qdcount) && writer->WriteU16BigEndian(header.ancount) && writer->WriteU16BigEndian(header.nscount) && writer->WriteU16BigEndian(header.arcount); } bool DnsResponse::WriteQuestion(base::SpanWriter* writer, const DnsQuery& query) { return writer->Write(base::as_byte_span(query.question())); } bool DnsResponse::WriteRecord(base::SpanWriter* writer, const DnsResourceRecord& record, bool validate_record, bool validate_name_as_internet_hostname) { if (record.rdata != std::string_view(record.owned_rdata)) { VLOG(1) << "record.rdata should point to record.owned_rdata."; return false; } if (validate_record && !RecordRdata::HasValidSize(record.owned_rdata, record.type)) { VLOG(1) << "Invalid RDATA size for a record."; return false; } std::optional> domain_name = dns_names_util::DottedNameToNetwork(record.name, validate_name_as_internet_hostname); if (!domain_name.has_value()) { VLOG(1) << "Invalid dotted name (as " << (validate_name_as_internet_hostname ? "Internet hostname)." : "DNS name)."); return false; } return writer->Write(domain_name.value()) && writer->WriteU16BigEndian(record.type) && writer->WriteU16BigEndian(record.klass) && writer->WriteU32BigEndian(record.ttl) && writer->WriteU16BigEndian(record.owned_rdata.size()) && // Use the owned RDATA in the record to construct the response. writer->Write(base::as_byte_span(record.owned_rdata)); } bool DnsResponse::WriteAnswer(base::SpanWriter* writer, const DnsResourceRecord& answer, const std::optional& query, bool validate_record, bool validate_name_as_internet_hostname) { // Generally assumed to be a mistake if we write answers that don't match the // query type, except CNAME answers which can always be added. if (validate_record && query.has_value() && answer.type != query.value().qtype() && answer.type != dns_protocol::kTypeCNAME) { VLOG(1) << "Mismatched answer resource record type and qtype."; return false; } return WriteRecord(writer, answer, validate_record, validate_name_as_internet_hostname); } } // namespace net