// Copyright 2012 The Chromium Authors // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/dns/dns_test_util.h" #include #include #include #include #include #include #include "base/check.h" #include "base/containers/span.h" #include "base/functional/bind.h" #include "base/location.h" #include "base/numerics/byte_conversions.h" #include "base/numerics/safe_conversions.h" #include "base/ranges/algorithm.h" #include "base/strings/strcat.h" #include "base/sys_byteorder.h" #include "base/task/single_thread_task_runner.h" #include "base/test/test_timeouts.h" #include "base/threading/thread_restrictions.h" #include "base/time/time.h" #include "base/types/optional_util.h" #include "net/base/io_buffer.h" #include "net/base/ip_address.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/dns/address_sorter.h" #include "net/dns/dns_hosts.h" #include "net/dns/dns_names_util.h" #include "net/dns/dns_query.h" #include "net/dns/dns_session.h" #include "net/dns/mock_host_resolver.h" #include "net/dns/public/dns_over_https_server_config.h" #include "net/dns/resolve_context.h" #include "testing/gmock/include/gmock/gmock-matchers.h" #include "testing/gtest/include/gtest/gtest.h" #include "url/scheme_host_port.h" namespace net { namespace { const uint8_t kMalformedResponseHeader[] = { // Header 0x00, 0x14, // Arbitrary ID 0x81, 0x80, // Standard query response, RA, no error 0x00, 0x01, // 1 question 0x00, 0x01, // 1 RR (answers) 0x00, 0x00, // 0 authority RRs 0x00, 0x00, // 0 additional RRs }; // Create a response containing a valid question (as would normally be validated // in DnsTransaction) but completely missing a header-declared answer. DnsResponse CreateMalformedResponse(std::string hostname, uint16_t type) { std::optional> dns_name = dns_names_util::DottedNameToNetwork(hostname); CHECK(dns_name.has_value()); DnsQuery query(/*id=*/0x14, dns_name.value(), type); // Build response to simulate the barebones validation DnsResponse applies to // responses received from the network. auto buffer = base::MakeRefCounted( sizeof(kMalformedResponseHeader) + query.question().size()); memcpy(buffer->data(), kMalformedResponseHeader, sizeof(kMalformedResponseHeader)); memcpy(buffer->data() + sizeof(kMalformedResponseHeader), query.question().data(), query.question().size()); DnsResponse response(buffer, buffer->size()); CHECK(response.InitParseWithoutQuery(buffer->size())); return response; } class MockAddressSorter : public AddressSorter { public: ~MockAddressSorter() override = default; void Sort(const std::vector& endpoints, CallbackType callback) const override { // Do nothing. std::move(callback).Run(true, endpoints); } }; } // namespace DnsConfig CreateValidDnsConfig() { IPAddress dns_ip(192, 168, 1, 0); DnsConfig config; config.nameservers.emplace_back(dns_ip, dns_protocol::kDefaultPort); config.doh_config = *DnsOverHttpsConfig::FromString("https://dns.example.com/"); config.secure_dns_mode = SecureDnsMode::kOff; EXPECT_TRUE(config.IsValid()); return config; } DnsResourceRecord BuildTestDnsRecord(std::string name, uint16_t type, std::string rdata, base::TimeDelta ttl) { DCHECK(!name.empty()); DnsResourceRecord record; record.name = std::move(name); record.type = type; record.klass = dns_protocol::kClassIN; record.ttl = ttl.InSeconds(); if (!rdata.empty()) record.SetOwnedRdata(std::move(rdata)); return record; } DnsResourceRecord BuildTestCnameRecord(std::string name, std::string_view canonical_name, base::TimeDelta ttl) { DCHECK(!name.empty()); DCHECK(!canonical_name.empty()); std::optional> rdata = dns_names_util::DottedNameToNetwork(canonical_name); CHECK(rdata.has_value()); return BuildTestDnsRecord( std::move(name), dns_protocol::kTypeCNAME, std::string(reinterpret_cast(rdata.value().data()), rdata.value().size()), ttl); } DnsResourceRecord BuildTestAddressRecord(std::string name, const IPAddress& ip, base::TimeDelta ttl) { DCHECK(!name.empty()); DCHECK(ip.IsValid()); return BuildTestDnsRecord( std::move(name), ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA, net::IPAddressToPackedString(ip), ttl); } DnsResourceRecord BuildTestTextRecord(std::string name, std::vector text_strings, base::TimeDelta ttl) { DCHECK(!text_strings.empty()); std::string rdata; for (const std::string& text_string : text_strings) { DCHECK(!text_string.empty()); rdata += base::checked_cast(text_string.size()); rdata += text_string; } return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeTXT, std::move(rdata), ttl); } DnsResourceRecord BuildTestHttpsAliasRecord(std::string name, std::string_view alias_name, base::TimeDelta ttl) { DCHECK(!name.empty()); std::string rdata("\000\000", 2); std::optional> alias_domain = dns_names_util::DottedNameToNetwork(alias_name); CHECK(alias_domain.has_value()); rdata.append(reinterpret_cast(alias_domain.value().data()), alias_domain.value().size()); return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeHttps, std::move(rdata), ttl); } std::pair BuildTestHttpsServiceAlpnParam( const std::vector& alpns) { std::string param_value; for (const std::string& alpn : alpns) { CHECK(!alpn.empty()); param_value.append( 1, static_cast(base::checked_cast(alpn.size()))); param_value.append(alpn); } return std::pair(dns_protocol::kHttpsServiceParamKeyAlpn, std::move(param_value)); } std::pair BuildTestHttpsServiceEchConfigParam( base::span ech_config_list) { return std::pair( dns_protocol::kHttpsServiceParamKeyEchConfig, std::string(reinterpret_cast(ech_config_list.data()), ech_config_list.size())); } std::pair BuildTestHttpsServiceMandatoryParam( std::vector param_key_list) { base::ranges::sort(param_key_list); std::string value; for (uint16_t param_key : param_key_list) { std::array num_buffer = base::U16ToBigEndian(param_key); value.append(num_buffer.begin(), num_buffer.end()); } return std::pair(dns_protocol::kHttpsServiceParamKeyMandatory, std::move(value)); } std::pair BuildTestHttpsServicePortParam(uint16_t port) { std::array buffer = base::U16ToBigEndian(port); return std::pair(dns_protocol::kHttpsServiceParamKeyPort, std::string(buffer.begin(), buffer.end())); } DnsResourceRecord BuildTestHttpsServiceRecord( std::string name, uint16_t priority, std::string_view service_name, const std::map& params, base::TimeDelta ttl) { DCHECK(!name.empty()); DCHECK_NE(priority, 0); std::string rdata; { std::array buf = base::U16ToBigEndian(priority); rdata.append(buf.begin(), buf.end()); } std::optional> service_domain; if (service_name == ".") { // HTTPS records have special behavior for `service_name == "."` (that it // will be treated as if the service name is the same as the record owner // name), so allow such inputs despite normally being disallowed for // Chrome-encoded DNS names. service_domain = std::vector{0}; } else { service_domain = dns_names_util::DottedNameToNetwork(service_name); } CHECK(service_domain.has_value()); rdata.append(reinterpret_cast(service_domain.value().data()), service_domain.value().size()); for (auto& param : params) { { std::array buf = base::U16ToBigEndian(param.first); rdata.append(buf.begin(), buf.end()); } { std::array buf = base::U16ToBigEndian( base::checked_cast(param.second.size())); rdata.append(buf.begin(), buf.end()); } rdata.append(param.second); } return BuildTestDnsRecord(std::move(name), dns_protocol::kTypeHttps, std::move(rdata), ttl); } DnsResponse BuildTestDnsResponse( std::string name, uint16_t type, const std::vector& answers, const std::vector& authority, const std::vector& additional, uint8_t rcode) { DCHECK(!name.empty()); std::optional> dns_name = dns_names_util::DottedNameToNetwork(name); CHECK(dns_name.has_value()); std::optional query(std::in_place, 0, dns_name.value(), type); return DnsResponse(0, true /* is_authoritative */, answers, authority /* authority_records */, additional /* additional_records */, query, rcode, false /* validate_records */); } DnsResponse BuildTestDnsAddressResponse(std::string name, const IPAddress& ip, std::string answer_name) { DCHECK(ip.IsValid()); if (answer_name.empty()) answer_name = name; std::vector answers = { BuildTestAddressRecord(std::move(answer_name), ip)}; return BuildTestDnsResponse( std::move(name), ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA, answers); } DnsResponse BuildTestDnsAddressResponseWithCname(std::string name, const IPAddress& ip, std::string cannonname, std::string answer_name) { DCHECK(ip.IsValid()); DCHECK(!cannonname.empty()); if (answer_name.empty()) answer_name = name; std::optional> cname_rdata = dns_names_util::DottedNameToNetwork(cannonname); CHECK(cname_rdata.has_value()); std::vector answers = { BuildTestDnsRecord( std::move(answer_name), dns_protocol::kTypeCNAME, std::string(reinterpret_cast(cname_rdata.value().data()), cname_rdata.value().size())), BuildTestAddressRecord(std::move(cannonname), ip)}; return BuildTestDnsResponse( std::move(name), ip.IsIPv4() ? dns_protocol::kTypeA : dns_protocol::kTypeAAAA, answers); } DnsResponse BuildTestDnsTextResponse( std::string name, std::vector> text_records, std::string answer_name) { if (answer_name.empty()) answer_name = name; std::vector answers; for (std::vector& text_record : text_records) { answers.push_back(BuildTestTextRecord(answer_name, std::move(text_record))); } return BuildTestDnsResponse(std::move(name), dns_protocol::kTypeTXT, answers); } DnsResponse BuildTestDnsPointerResponse(std::string name, std::vector pointer_names, std::string answer_name) { if (answer_name.empty()) answer_name = name; std::vector answers; for (std::string& pointer_name : pointer_names) { std::optional> rdata = dns_names_util::DottedNameToNetwork(pointer_name); CHECK(rdata.has_value()); answers.push_back(BuildTestDnsRecord( answer_name, dns_protocol::kTypePTR, std::string(reinterpret_cast(rdata.value().data()), rdata.value().size()))); } return BuildTestDnsResponse(std::move(name), dns_protocol::kTypePTR, answers); } DnsResponse BuildTestDnsServiceResponse( std::string name, std::vector service_records, std::string answer_name) { if (answer_name.empty()) answer_name = name; std::vector answers; for (TestServiceRecord& service_record : service_records) { std::string rdata; { std::array buf = base::U16ToBigEndian(service_record.priority); rdata.append(buf.begin(), buf.end()); } { std::array buf = base::U16ToBigEndian(service_record.weight); rdata.append(buf.begin(), buf.end()); } { std::array buf = base::U16ToBigEndian(service_record.port); rdata.append(buf.begin(), buf.end()); } std::optional> dns_name = dns_names_util::DottedNameToNetwork(service_record.target); CHECK(dns_name.has_value()); rdata.append(reinterpret_cast(dns_name.value().data()), dns_name.value().size()); answers.push_back(BuildTestDnsRecord(answer_name, dns_protocol::kTypeSRV, std::move(rdata), base::Hours(5))); } return BuildTestDnsResponse(std::move(name), dns_protocol::kTypeSRV, answers); } MockDnsClientRule::Result::Result(ResultType type, std::optional response, std::optional net_error) : type(type), response(std::move(response)), net_error(net_error) {} MockDnsClientRule::Result::Result(DnsResponse response) : type(ResultType::kOk), response(std::move(response)), net_error(std::nullopt) {} MockDnsClientRule::Result::Result(Result&&) = default; MockDnsClientRule::Result& MockDnsClientRule::Result::operator=(Result&&) = default; MockDnsClientRule::Result::~Result() = default; MockDnsClientRule::MockDnsClientRule(const std::string& prefix, uint16_t qtype, bool secure, Result result, bool delay, URLRequestContext* context) : result(std::move(result)), prefix(prefix), qtype(qtype), secure(secure), delay(delay), context(context) {} MockDnsClientRule::MockDnsClientRule(MockDnsClientRule&& rule) = default; // A DnsTransaction which uses MockDnsClientRuleList to determine the response. class MockDnsTransactionFactory::MockTransaction : public DnsTransaction, public base::SupportsWeakPtr { public: MockTransaction(const MockDnsClientRuleList& rules, std::string hostname, uint16_t qtype, bool secure, bool force_doh_server_available, SecureDnsMode secure_dns_mode, ResolveContext* resolve_context, bool fast_timeout) : hostname_(std::move(hostname)), qtype_(qtype) { // Do not allow matching any rules if transaction is secure and no DoH // servers are available. if (!secure || force_doh_server_available || resolve_context->NumAvailableDohServers( resolve_context->current_session_for_testing()) > 0) { // Find the relevant rule which matches |qtype|, |secure|, prefix of // |hostname_|, and |url_request_context| (iff the rule context is not // null). for (const auto& rule : rules) { const std::string& prefix = rule.prefix; if ((rule.qtype == qtype) && (rule.secure == secure) && (hostname_.size() >= prefix.size()) && (hostname_.compare(0, prefix.size(), prefix) == 0) && (!rule.context || rule.context == resolve_context->url_request_context())) { const MockDnsClientRule::Result* result = &rule.result; result_ = MockDnsClientRule::Result(result->type); result_.net_error = result->net_error; delayed_ = rule.delay; // Generate a DnsResponse when not provided with the rule. std::vector authority_records; std::optional> dns_name = dns_names_util::DottedNameToNetwork(hostname_); CHECK(dns_name.has_value()); std::optional query(std::in_place, /*id=*/22, dns_name.value(), qtype_); switch (result->type) { case MockDnsClientRule::ResultType::kNoDomain: case MockDnsClientRule::ResultType::kEmpty: DCHECK(!result->response); // Not expected to be provided. authority_records = {BuildTestDnsRecord( hostname_, dns_protocol::kTypeSOA, "fake rdata")}; result_.response = DnsResponse( 22 /* id */, false /* is_authoritative */, std::vector() /* answers */, authority_records, std::vector() /* additional_records */, query, result->type == MockDnsClientRule::ResultType::kNoDomain ? dns_protocol::kRcodeNXDOMAIN : 0); break; case MockDnsClientRule::ResultType::kFail: if (result->response) SetResponse(result); break; case MockDnsClientRule::ResultType::kTimeout: DCHECK(!result->response); // Not expected to be provided. break; case MockDnsClientRule::ResultType::kSlow: if (!fast_timeout) SetResponse(result); break; case MockDnsClientRule::ResultType::kOk: SetResponse(result); break; case MockDnsClientRule::ResultType::kMalformed: DCHECK(!result->response); // Not expected to be provided. result_.response = CreateMalformedResponse(hostname_, qtype_); break; case MockDnsClientRule::ResultType::kUnexpected: if (!delayed_) { // Assume a delayed kUnexpected transaction is only an issue if // allowed to complete. ADD_FAILURE() << "Unexpected DNS transaction created for hostname " << hostname_; } break; } break; } } } } const std::string& GetHostname() const override { return hostname_; } uint16_t GetType() const override { return qtype_; } void Start(ResponseCallback callback) override { CHECK(!callback.is_null()); CHECK(callback_.is_null()); EXPECT_FALSE(started_); callback_ = std::move(callback); started_ = true; if (delayed_) return; // Using WeakPtr to cleanly cancel when transaction is destroyed. base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( FROM_HERE, base::BindOnce(&MockTransaction::Finish, AsWeakPtr())); } void FinishDelayedTransaction() { EXPECT_TRUE(delayed_); delayed_ = false; Finish(); } bool delayed() const { return delayed_; } private: void SetResponse(const MockDnsClientRule::Result* result) { if (result->response) { // Copy response in case |result| is destroyed before the transaction // completes. auto buffer_copy = base::MakeRefCounted( result->response->io_buffer_size()); memcpy(buffer_copy->data(), result->response->io_buffer()->data(), result->response->io_buffer_size()); result_.response = DnsResponse(std::move(buffer_copy), result->response->io_buffer_size()); CHECK(result_.response->InitParseWithoutQuery( result->response->io_buffer_size())); } else { // Generated response only available for address types. DCHECK(qtype_ == dns_protocol::kTypeA || qtype_ == dns_protocol::kTypeAAAA); result_.response = BuildTestDnsAddressResponse( hostname_, qtype_ == dns_protocol::kTypeA ? IPAddress::IPv4Localhost() : IPAddress::IPv6Localhost()); } } void Finish() { switch (result_.type) { case MockDnsClientRule::ResultType::kNoDomain: case MockDnsClientRule::ResultType::kFail: { int error = result_.net_error.value_or(ERR_NAME_NOT_RESOLVED); DCHECK_NE(error, OK); std::move(callback_).Run(error, base::OptionalToPtr(result_.response)); break; } case MockDnsClientRule::ResultType::kEmpty: case MockDnsClientRule::ResultType::kOk: case MockDnsClientRule::ResultType::kMalformed: DCHECK(!result_.net_error.has_value()); std::move(callback_).Run(OK, base::OptionalToPtr(result_.response)); break; case MockDnsClientRule::ResultType::kTimeout: DCHECK(!result_.net_error.has_value()); std::move(callback_).Run(ERR_DNS_TIMED_OUT, /*response=*/nullptr); break; case MockDnsClientRule::ResultType::kSlow: if (result_.response) { std::move(callback_).Run( result_.net_error.value_or(OK), result_.response ? &result_.response.value() : nullptr); } else { DCHECK(!result_.net_error.has_value()); std::move(callback_).Run(ERR_DNS_TIMED_OUT, /*response=*/nullptr); } break; case MockDnsClientRule::ResultType::kUnexpected: ADD_FAILURE() << "Unexpected DNS transaction completed for hostname " << hostname_; break; } } void SetRequestPriority(RequestPriority priority) override {} MockDnsClientRule::Result result_{MockDnsClientRule::ResultType::kFail}; const std::string hostname_; const uint16_t qtype_; ResponseCallback callback_; bool started_ = false; bool delayed_ = false; }; class MockDnsTransactionFactory::MockDohProbeRunner : public DnsProbeRunner { public: explicit MockDohProbeRunner(base::WeakPtr factory) : factory_(std::move(factory)) {} ~MockDohProbeRunner() override { if (factory_) factory_->running_doh_probe_runners_.erase(this); } void Start(bool network_change) override { DCHECK(factory_); factory_->running_doh_probe_runners_.insert(this); } base::TimeDelta GetDelayUntilNextProbeForTest( size_t doh_server_index) const override { NOTREACHED(); return base::TimeDelta(); } private: base::WeakPtr factory_; }; MockDnsTransactionFactory::MockDnsTransactionFactory( MockDnsClientRuleList rules) : rules_(std::move(rules)) {} MockDnsTransactionFactory::~MockDnsTransactionFactory() = default; std::unique_ptr MockDnsTransactionFactory::CreateTransaction( std::string hostname, uint16_t qtype, const NetLogWithSource&, bool secure, SecureDnsMode secure_dns_mode, ResolveContext* resolve_context, bool fast_timeout) { std::unique_ptr transaction = std::make_unique(rules_, std::move(hostname), qtype, secure, force_doh_server_available_, secure_dns_mode, resolve_context, fast_timeout); if (transaction->delayed()) delayed_transactions_.push_back(transaction->AsWeakPtr()); return transaction; } std::unique_ptr MockDnsTransactionFactory::CreateDohProbeRunner( ResolveContext* resolve_context) { return std::make_unique(weak_ptr_factory_.GetWeakPtr()); } void MockDnsTransactionFactory::AddEDNSOption( std::unique_ptr opt) {} SecureDnsMode MockDnsTransactionFactory::GetSecureDnsModeForTest() { return SecureDnsMode::kAutomatic; } void MockDnsTransactionFactory::CompleteDelayedTransactions() { DelayedTransactionList old_delayed_transactions; old_delayed_transactions.swap(delayed_transactions_); for (auto& old_delayed_transaction : old_delayed_transactions) { if (old_delayed_transaction.get()) old_delayed_transaction->FinishDelayedTransaction(); } } bool MockDnsTransactionFactory::CompleteOneDelayedTransactionOfType( DnsQueryType type) { for (base::WeakPtr& t : delayed_transactions_) { if (t && t->GetType() == DnsQueryTypeToQtype(type)) { t->FinishDelayedTransaction(); t.reset(); return true; } } return false; } MockDnsClient::MockDnsClient(DnsConfig config, MockDnsClientRuleList rules) : config_(std::move(config)), factory_(std::make_unique(std::move(rules))), address_sorter_(std::make_unique()) { effective_config_ = BuildEffectiveConfig(); session_ = BuildSession(); } MockDnsClient::~MockDnsClient() = default; bool MockDnsClient::CanUseSecureDnsTransactions() const { const DnsConfig* config = GetEffectiveConfig(); return config && config->IsValid() && !config->doh_config.servers().empty(); } bool MockDnsClient::CanUseInsecureDnsTransactions() const { const DnsConfig* config = GetEffectiveConfig(); return config && config->IsValid() && insecure_enabled_ && !config->dns_over_tls_active; } bool MockDnsClient::CanQueryAdditionalTypesViaInsecureDns() const { DCHECK(CanUseInsecureDnsTransactions()); return additional_types_enabled_; } void MockDnsClient::SetInsecureEnabled(bool enabled, bool additional_types_enabled) { insecure_enabled_ = enabled; additional_types_enabled_ = additional_types_enabled; } bool MockDnsClient::FallbackFromSecureTransactionPreferred( ResolveContext* context) const { bool doh_server_available = force_doh_server_available_ || context->NumAvailableDohServers(session_.get()) > 0; return !CanUseSecureDnsTransactions() || !doh_server_available; } bool MockDnsClient::FallbackFromInsecureTransactionPreferred() const { return !CanUseInsecureDnsTransactions() || fallback_failures_ >= max_fallback_failures_; } bool MockDnsClient::SetSystemConfig(std::optional system_config) { if (ignore_system_config_changes_) return false; std::optional before = effective_config_; config_ = std::move(system_config); effective_config_ = BuildEffectiveConfig(); session_ = BuildSession(); return before != effective_config_; } bool MockDnsClient::SetConfigOverrides(DnsConfigOverrides config_overrides) { std::optional before = effective_config_; overrides_ = std::move(config_overrides); effective_config_ = BuildEffectiveConfig(); session_ = BuildSession(); return before != effective_config_; } void MockDnsClient::ReplaceCurrentSession() { // Noop if no current effective config. session_ = BuildSession(); } DnsSession* MockDnsClient::GetCurrentSession() { return session_.get(); } const DnsConfig* MockDnsClient::GetEffectiveConfig() const { return effective_config_.has_value() ? &effective_config_.value() : nullptr; } base::Value::Dict MockDnsClient::GetDnsConfigAsValueForNetLog() const { // This is just a stub implementation that never produces a meaningful value. return base::Value::Dict(); } const DnsHosts* MockDnsClient::GetHosts() const { const DnsConfig* config = GetEffectiveConfig(); if (!config) return nullptr; return &config->hosts; } DnsTransactionFactory* MockDnsClient::GetTransactionFactory() { return GetEffectiveConfig() ? factory_.get() : nullptr; } AddressSorter* MockDnsClient::GetAddressSorter() { return GetEffectiveConfig() ? address_sorter_.get() : nullptr; } void MockDnsClient::IncrementInsecureFallbackFailures() { ++fallback_failures_; } void MockDnsClient::ClearInsecureFallbackFailures() { fallback_failures_ = 0; } std::optional MockDnsClient::GetSystemConfigForTesting() const { return config_; } DnsConfigOverrides MockDnsClient::GetConfigOverridesForTesting() const { return overrides_; } void MockDnsClient::SetTransactionFactoryForTesting( std::unique_ptr factory) { NOTREACHED(); } void MockDnsClient::SetAddressSorterForTesting( std::unique_ptr address_sorter) { address_sorter_ = std::move(address_sorter); } std::optional> MockDnsClient::GetPresetAddrs( const url::SchemeHostPort& endpoint) const { EXPECT_THAT(preset_endpoint_, testing::Optional(endpoint)); return preset_addrs_; } void MockDnsClient::CompleteDelayedTransactions() { factory_->CompleteDelayedTransactions(); } bool MockDnsClient::CompleteOneDelayedTransactionOfType(DnsQueryType type) { return factory_->CompleteOneDelayedTransactionOfType(type); } void MockDnsClient::SetForceDohServerAvailable(bool available) { force_doh_server_available_ = available; factory_->set_force_doh_server_available(available); } std::optional MockDnsClient::BuildEffectiveConfig() { if (overrides_.OverridesEverything()) return overrides_.ApplyOverrides(DnsConfig()); if (!config_ || !config_.value().IsValid()) return std::nullopt; return overrides_.ApplyOverrides(config_.value()); } scoped_refptr MockDnsClient::BuildSession() { if (!effective_config_) return nullptr; // Session not expected to be used for anything that will actually require // random numbers. auto null_random_callback = base::BindRepeating([](int, int) -> int { base::ImmediateCrash(); }); return base::MakeRefCounted( effective_config_.value(), null_random_callback, nullptr /* net_log */); } MockHostResolverProc::MockHostResolverProc() : HostResolverProc(nullptr), requests_waiting_(&lock_), slots_available_(&lock_) {} MockHostResolverProc::~MockHostResolverProc() = default; bool MockHostResolverProc::WaitFor(unsigned count) { base::AutoLock lock(lock_); base::Time start_time = base::Time::Now(); while (num_requests_waiting_ < count) { requests_waiting_.TimedWait(TestTimeouts::action_timeout()); if (base::Time::Now() > start_time + TestTimeouts::action_timeout()) { return false; } } return true; } void MockHostResolverProc::SignalMultiple(unsigned count) { base::AutoLock lock(lock_); num_slots_available_ += count; slots_available_.Broadcast(); } void MockHostResolverProc::SignalAll() { base::AutoLock lock(lock_); num_slots_available_ = num_requests_waiting_; slots_available_.Broadcast(); } void MockHostResolverProc::AddRule(const std::string& hostname, AddressFamily family, const AddressList& result, HostResolverFlags flags) { base::AutoLock lock(lock_); rules_[ResolveKey(hostname, family, flags)] = result; } void MockHostResolverProc::AddRule(const std::string& hostname, AddressFamily family, const std::string& ip_list, HostResolverFlags flags, const std::string& canonical_name) { AddressList result; std::vector dns_aliases; if (canonical_name != "") { dns_aliases = {canonical_name}; } int rv = ParseAddressList(ip_list, &result.endpoints()); result.SetDnsAliases(dns_aliases); DCHECK_EQ(OK, rv); AddRule(hostname, family, result, flags); } void MockHostResolverProc::AddRuleForAllFamilies( const std::string& hostname, const std::string& ip_list, HostResolverFlags flags, const std::string& canonical_name) { AddressList result; std::vector dns_aliases; if (canonical_name != "") { dns_aliases = {canonical_name}; } int rv = ParseAddressList(ip_list, &result.endpoints()); result.SetDnsAliases(dns_aliases); DCHECK_EQ(OK, rv); AddRule(hostname, ADDRESS_FAMILY_UNSPECIFIED, result, flags); AddRule(hostname, ADDRESS_FAMILY_IPV4, result, flags); AddRule(hostname, ADDRESS_FAMILY_IPV6, result, flags); } int MockHostResolverProc::Resolve(const std::string& hostname, AddressFamily address_family, HostResolverFlags host_resolver_flags, AddressList* addrlist, int* os_error) { base::AutoLock lock(lock_); capture_list_.emplace_back(hostname, address_family, host_resolver_flags); ++num_requests_waiting_; requests_waiting_.Broadcast(); { base::ScopedAllowBaseSyncPrimitivesForTesting scoped_allow_base_sync_primitives; while (!num_slots_available_) { slots_available_.Wait(); } } DCHECK_GT(num_requests_waiting_, 0u); --num_slots_available_; --num_requests_waiting_; if (rules_.empty()) { int rv = ParseAddressList("127.0.0.1", &addrlist->endpoints()); DCHECK_EQ(OK, rv); return OK; } ResolveKey key(hostname, address_family, host_resolver_flags); if (rules_.count(key) == 0) { return ERR_NAME_NOT_RESOLVED; } *addrlist = rules_[key]; return OK; } MockHostResolverProc::CaptureList MockHostResolverProc::GetCaptureList() const { CaptureList copy; { base::AutoLock lock(lock_); copy = capture_list_; } return copy; } void MockHostResolverProc::ClearCaptureList() { base::AutoLock lock(lock_); capture_list_.clear(); } bool MockHostResolverProc::HasBlockedRequests() const { base::AutoLock lock(lock_); return num_requests_waiting_ > num_slots_available_; } } // namespace net