// Copyright 2023 The Chromium Authors // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/dns/host_resolver_cache.h" #include #include #include #include #include #include #include #include "base/check_op.h" #include "base/numerics/safe_conversions.h" #include "base/time/clock.h" #include "base/time/time.h" #include "net/base/network_anonymization_key.h" #include "net/dns/host_resolver_internal_result.h" #include "net/dns/public/dns_query_type.h" #include "net/dns/public/host_resolver_source.h" #include "url/third_party/mozilla/url_parse.h" #include "url/url_canon.h" #include "url/url_canon_stdstring.h" namespace net { namespace { constexpr std::string_view kNakKey = "network_anonymization_key"; constexpr std::string_view kSourceKey = "source"; constexpr std::string_view kSecureKey = "secure"; constexpr std::string_view kResultKey = "result"; constexpr std::string_view kStalenessGenerationKey = "staleness_generation"; constexpr std::string_view kMaxEntriesKey = "max_entries"; constexpr std::string_view kEntriesKey = "entries"; } // namespace HostResolverCache::Key::~Key() = default; HostResolverCache::StaleLookupResult::StaleLookupResult( const HostResolverInternalResult& result, std::optional expired_by, bool stale_by_generation) : result(result), expired_by(expired_by), stale_by_generation(stale_by_generation) {} HostResolverCache::HostResolverCache(size_t max_results, const base::Clock& clock, const base::TickClock& tick_clock) : max_entries_(max_results), clock_(clock), tick_clock_(tick_clock) { DCHECK_GT(max_entries_, 0u); } HostResolverCache::~HostResolverCache() = default; HostResolverCache::HostResolverCache(HostResolverCache&&) = default; HostResolverCache& HostResolverCache::operator=(HostResolverCache&&) = default; const HostResolverInternalResult* HostResolverCache::Lookup( std::string_view domain_name, const NetworkAnonymizationKey& network_anonymization_key, DnsQueryType query_type, HostResolverSource source, std::optional secure) const { std::vector candidates = LookupInternal( domain_name, network_anonymization_key, query_type, source, secure); // Get the most secure, last-matching (which is first in the vector returned // by LookupInternal()) non-expired result. base::TimeTicks now_ticks = tick_clock_->NowTicks(); base::Time now = clock_->Now(); HostResolverInternalResult* most_secure_result = nullptr; for (const EntryMap::const_iterator& candidate : candidates) { DCHECK(candidate->second.result->timed_expiration().has_value()); if (candidate->second.IsStale(now, now_ticks, staleness_generation_)) { continue; } // If the candidate is secure, or all results are insecure, no need to check // any more. if (candidate->second.secure || !secure.value_or(true)) { return candidate->second.result.get(); } else if (most_secure_result == nullptr) { most_secure_result = candidate->second.result.get(); } } return most_secure_result; } std::optional HostResolverCache::LookupStale( std::string_view domain_name, const NetworkAnonymizationKey& network_anonymization_key, DnsQueryType query_type, HostResolverSource source, std::optional secure) const { std::vector candidates = LookupInternal( domain_name, network_anonymization_key, query_type, source, secure); // Get the least expired, most secure result. base::TimeTicks now_ticks = tick_clock_->NowTicks(); base::Time now = clock_->Now(); const Entry* best_match = nullptr; base::TimeDelta best_match_time_until_expiration; for (const EntryMap::const_iterator& candidate : candidates) { DCHECK(candidate->second.result->timed_expiration().has_value()); base::TimeDelta candidate_time_until_expiration = candidate->second.TimeUntilExpiration(now, now_ticks); if (!candidate->second.IsStale(now, now_ticks, staleness_generation_) && (candidate->second.secure || !secure.value_or(true))) { // If a non-stale candidate is secure, or all results are insecure, no // need to check any more. best_match = &candidate->second; best_match_time_until_expiration = candidate_time_until_expiration; break; } else if (best_match == nullptr || (!candidate->second.IsStale(now, now_ticks, staleness_generation_) && best_match->IsStale(now, now_ticks, staleness_generation_)) || candidate->second.staleness_generation > best_match->staleness_generation || (candidate->second.staleness_generation == best_match->staleness_generation && candidate_time_until_expiration > best_match_time_until_expiration) || (candidate->second.staleness_generation == best_match->staleness_generation && candidate_time_until_expiration == best_match_time_until_expiration && candidate->second.secure && !best_match->secure)) { best_match = &candidate->second; best_match_time_until_expiration = candidate_time_until_expiration; } } if (best_match == nullptr) { return std::nullopt; } else { std::optional expired_by; if (best_match_time_until_expiration.is_negative()) { expired_by = best_match_time_until_expiration.magnitude(); } return StaleLookupResult( *best_match->result, expired_by, best_match->staleness_generation != staleness_generation_); } } void HostResolverCache::Set( std::unique_ptr result, const NetworkAnonymizationKey& network_anonymization_key, HostResolverSource source, bool secure) { Set(std::move(result), network_anonymization_key, source, secure, /*replace_existing=*/true, staleness_generation_); } void HostResolverCache::MakeAllResultsStale() { ++staleness_generation_; } base::Value HostResolverCache::Serialize() const { // Do not serialize any entries without a persistable anonymization key // because it is required to store and restore entries with the correct // annonymization key. A non-persistable anonymization key is typically used // for short-lived contexts, and associated entries are not expected to be // useful after persistence to disk anyway. return SerializeEntries(/*serialize_staleness_generation=*/false, /*require_persistable_anonymization_key=*/true); } bool HostResolverCache::RestoreFromValue(const base::Value& value) { const base::Value::List* list = value.GetIfList(); if (!list) { return false; } for (const base::Value& list_value : *list) { // Simply stop on reaching max size rather than attempting to figure out if // any current entries should be evicted over the deserialized entries. if (entries_.size() == max_entries_) { return true; } const base::Value::Dict* dict = list_value.GetIfDict(); if (!dict) { return false; } const base::Value* anonymization_key_value = dict->Find(kNakKey); NetworkAnonymizationKey anonymization_key; if (!anonymization_key_value || !NetworkAnonymizationKey::FromValue(*anonymization_key_value, &anonymization_key)) { return false; } const base::Value* source_value = dict->Find(kSourceKey); std::optional source = source_value == nullptr ? std::nullopt : HostResolverSourceFromValue(*source_value); if (!source.has_value()) { return false; } std::optional secure = dict->FindBool(kSecureKey); if (!secure.has_value()) { return false; } const base::Value* result_value = dict->Find(kResultKey); std::unique_ptr result = result_value == nullptr ? nullptr : HostResolverInternalResult::FromValue(*result_value); if (!result || !result->timed_expiration().has_value()) { return false; } // `staleness_generation_ - 1` to make entry stale-by-generation. Set(std::move(result), anonymization_key, source.value(), secure.value(), /*replace_existing=*/false, staleness_generation_ - 1); } CHECK_LE(entries_.size(), max_entries_); return true; } base::Value HostResolverCache::SerializeForLogging() const { base::Value::Dict dict; dict.Set(kMaxEntriesKey, base::checked_cast(max_entries_)); dict.Set(kStalenessGenerationKey, staleness_generation_); // Include entries with non-persistable anonymization keys, so the log can // contain all entries. Restoring from this serialization is not supported. dict.Set(kEntriesKey, SerializeEntries(/*serialize_staleness_generation=*/true, /*require_persistable_anonymization_key=*/false)); return base::Value(std::move(dict)); } HostResolverCache::Entry::Entry( std::unique_ptr result, HostResolverSource source, bool secure, int staleness_generation) : result(std::move(result)), source(source), secure(secure), staleness_generation(staleness_generation) {} HostResolverCache::Entry::~Entry() = default; HostResolverCache::Entry::Entry(Entry&&) = default; HostResolverCache::Entry& HostResolverCache::Entry::operator=(Entry&&) = default; bool HostResolverCache::Entry::IsStale(base::Time now, base::TimeTicks now_ticks, int current_staleness_generation) const { return staleness_generation != current_staleness_generation || TimeUntilExpiration(now, now_ticks).is_negative(); } base::TimeDelta HostResolverCache::Entry::TimeUntilExpiration( base::Time now, base::TimeTicks now_ticks) const { if (result->expiration().has_value()) { return result->expiration().value() - now_ticks; } else { DCHECK(result->timed_expiration().has_value()); return result->timed_expiration().value() - now; } } std::vector HostResolverCache::LookupInternal( std::string_view domain_name, const NetworkAnonymizationKey& network_anonymization_key, DnsQueryType query_type, HostResolverSource source, std::optional secure) const { auto matches = std::vector(); if (entries_.empty()) { return matches; } std::string canonicalized; url::StdStringCanonOutput output(&canonicalized); url::CanonHostInfo host_info; url::CanonicalizeHostVerbose(domain_name.data(), url::Component(0, domain_name.size()), &output, &host_info); // For performance, when canonicalization can't canonicalize, minimize string // copies and just reuse the input StringPiece. This optimization prevents // easily reusing a MaybeCanoncalize util with similar code. std::string_view lookup_name = domain_name; if (host_info.family == url::CanonHostInfo::Family::NEUTRAL) { output.Complete(); lookup_name = canonicalized; } auto range = entries_.equal_range( KeyRef{lookup_name, raw_ref(network_anonymization_key)}); if (range.first == entries_.cend() || range.second == entries_.cbegin() || range.first == range.second) { return matches; } // Iterate in reverse order to return most-recently-added entry first. auto it = --range.second; while (true) { if ((query_type == DnsQueryType::UNSPECIFIED || it->second.result->query_type() == DnsQueryType::UNSPECIFIED || query_type == it->second.result->query_type()) && (source == HostResolverSource::ANY || source == it->second.source) && (!secure.has_value() || secure.value() == it->second.secure)) { matches.push_back(it); } if (it == range.first) { break; } --it; } return matches; } void HostResolverCache::Set( std::unique_ptr result, const NetworkAnonymizationKey& network_anonymization_key, HostResolverSource source, bool secure, bool replace_existing, int staleness_generation) { DCHECK(result); // Result must have at least a timed expiration to be a cacheable result. DCHECK(result->timed_expiration().has_value()); std::vector matches = LookupInternal(result->domain_name(), network_anonymization_key, result->query_type(), source, secure); if (!matches.empty() && !replace_existing) { // Matches already present that are not to be replaced. return; } for (const EntryMap::const_iterator& match : matches) { entries_.erase(match); } std::string domain_name = result->domain_name(); entries_.emplace( Key(std::move(domain_name), network_anonymization_key), Entry(std::move(result), source, secure, staleness_generation)); if (entries_.size() > max_entries_) { EvictEntries(); } } // Remove all stale entries, or if none stale, the soonest-to-expire, // least-secure entry. void HostResolverCache::EvictEntries() { base::TimeTicks now_ticks = tick_clock_->NowTicks(); base::Time now = clock_->Now(); bool stale_found = false; base::TimeDelta soonest_time_till_expriation = base::TimeDelta::Max(); std::optional best_for_removal; auto it = entries_.cbegin(); while (it != entries_.cend()) { if (it->second.IsStale(now, now_ticks, staleness_generation_)) { stale_found = true; it = entries_.erase(it); } else { base::TimeDelta time_till_expiration = it->second.TimeUntilExpiration(now, now_ticks); if (!best_for_removal.has_value() || time_till_expiration < soonest_time_till_expriation || (time_till_expiration == soonest_time_till_expriation && best_for_removal.value()->second.secure && !it->second.secure)) { soonest_time_till_expriation = time_till_expiration; best_for_removal = it; } ++it; } } if (!stale_found) { CHECK(best_for_removal.has_value()); entries_.erase(best_for_removal.value()); } CHECK_LE(entries_.size(), max_entries_); } base::Value HostResolverCache::SerializeEntries( bool serialize_staleness_generation, bool require_persistable_anonymization_key) const { base::Value::List list; for (const auto& [key, entry] : entries_) { base::Value::Dict dict; if (serialize_staleness_generation) { dict.Set(kStalenessGenerationKey, entry.staleness_generation); } base::Value anonymization_key_value; if (!key.network_anonymization_key.ToValue(&anonymization_key_value)) { if (require_persistable_anonymization_key) { continue; } else { // If the caller doesn't care about anonymization keys that can be // serialized and restored, construct a serialization just for the sake // of logging information. anonymization_key_value = base::Value("Non-persistable network anonymization key: " + key.network_anonymization_key.ToDebugString()); } } dict.Set(kNakKey, std::move(anonymization_key_value)); dict.Set(kSourceKey, ToValue(entry.source)); dict.Set(kSecureKey, entry.secure); dict.Set(kResultKey, entry.result->ToValue()); list.Append(std::move(dict)); } return base::Value(std::move(list)); } } // namespace net