1*3f982cf4SFabien Sanglard // Copyright 2019 The Chromium Authors. All rights reserved.
2*3f982cf4SFabien Sanglard // Use of this source code is governed by a BSD-style license that can be
3*3f982cf4SFabien Sanglard // found in the LICENSE file.
4*3f982cf4SFabien Sanglard
5*3f982cf4SFabien Sanglard #include "discovery/mdns/mdns_querier.h"
6*3f982cf4SFabien Sanglard
7*3f982cf4SFabien Sanglard #include <algorithm>
8*3f982cf4SFabien Sanglard #include <array>
9*3f982cf4SFabien Sanglard #include <bitset>
10*3f982cf4SFabien Sanglard #include <memory>
11*3f982cf4SFabien Sanglard #include <unordered_set>
12*3f982cf4SFabien Sanglard #include <utility>
13*3f982cf4SFabien Sanglard #include <vector>
14*3f982cf4SFabien Sanglard
15*3f982cf4SFabien Sanglard #include "discovery/common/config.h"
16*3f982cf4SFabien Sanglard #include "discovery/common/reporting_client.h"
17*3f982cf4SFabien Sanglard #include "discovery/mdns/mdns_random.h"
18*3f982cf4SFabien Sanglard #include "discovery/mdns/mdns_receiver.h"
19*3f982cf4SFabien Sanglard #include "discovery/mdns/mdns_sender.h"
20*3f982cf4SFabien Sanglard #include "discovery/mdns/public/mdns_constants.h"
21*3f982cf4SFabien Sanglard
22*3f982cf4SFabien Sanglard namespace openscreen {
23*3f982cf4SFabien Sanglard namespace discovery {
24*3f982cf4SFabien Sanglard namespace {
25*3f982cf4SFabien Sanglard
26*3f982cf4SFabien Sanglard constexpr std::array<DnsType, 5> kTranslatedNsecAnyQueryTypes = {
27*3f982cf4SFabien Sanglard DnsType::kA, DnsType::kPTR, DnsType::kTXT, DnsType::kAAAA, DnsType::kSRV};
28*3f982cf4SFabien Sanglard
IsNegativeResponseFor(const MdnsRecord & record,DnsType type)29*3f982cf4SFabien Sanglard bool IsNegativeResponseFor(const MdnsRecord& record, DnsType type) {
30*3f982cf4SFabien Sanglard if (record.dns_type() != DnsType::kNSEC) {
31*3f982cf4SFabien Sanglard return false;
32*3f982cf4SFabien Sanglard }
33*3f982cf4SFabien Sanglard
34*3f982cf4SFabien Sanglard const NsecRecordRdata& nsec = absl::get<NsecRecordRdata>(record.rdata());
35*3f982cf4SFabien Sanglard
36*3f982cf4SFabien Sanglard // RFC 6762 section 6.1, the NSEC bit must NOT be set in the received NSEC
37*3f982cf4SFabien Sanglard // record to indicate this is an mDNS NSEC record rather than a traditional
38*3f982cf4SFabien Sanglard // DNS NSEC record.
39*3f982cf4SFabien Sanglard if (std::find(nsec.types().begin(), nsec.types().end(), DnsType::kNSEC) !=
40*3f982cf4SFabien Sanglard nsec.types().end()) {
41*3f982cf4SFabien Sanglard return false;
42*3f982cf4SFabien Sanglard }
43*3f982cf4SFabien Sanglard
44*3f982cf4SFabien Sanglard return std::find_if(nsec.types().begin(), nsec.types().end(),
45*3f982cf4SFabien Sanglard [type](DnsType stored_type) {
46*3f982cf4SFabien Sanglard return stored_type == type ||
47*3f982cf4SFabien Sanglard stored_type == DnsType::kANY;
48*3f982cf4SFabien Sanglard }) != nsec.types().end();
49*3f982cf4SFabien Sanglard }
50*3f982cf4SFabien Sanglard
51*3f982cf4SFabien Sanglard struct HashDnsType {
operator ()openscreen::discovery::__anone816dc120111::HashDnsType52*3f982cf4SFabien Sanglard inline size_t operator()(DnsType type) const {
53*3f982cf4SFabien Sanglard return static_cast<size_t>(type);
54*3f982cf4SFabien Sanglard }
55*3f982cf4SFabien Sanglard };
56*3f982cf4SFabien Sanglard
57*3f982cf4SFabien Sanglard // Helper used for sorting MDNS records. This function guarantees the following:
58*3f982cf4SFabien Sanglard // - All MdnsRecords with the same name appear adjacent to each-other.
59*3f982cf4SFabien Sanglard // - An NSEC record with a given name appears before all other records with the
60*3f982cf4SFabien Sanglard // same name.
CompareRecordByNameAndType(const MdnsRecord & first,const MdnsRecord & second)61*3f982cf4SFabien Sanglard bool CompareRecordByNameAndType(const MdnsRecord& first,
62*3f982cf4SFabien Sanglard const MdnsRecord& second) {
63*3f982cf4SFabien Sanglard if (first.name() != second.name()) {
64*3f982cf4SFabien Sanglard return first.name() < second.name();
65*3f982cf4SFabien Sanglard }
66*3f982cf4SFabien Sanglard
67*3f982cf4SFabien Sanglard if ((first.dns_type() == DnsType::kNSEC) !=
68*3f982cf4SFabien Sanglard (second.dns_type() == DnsType::kNSEC)) {
69*3f982cf4SFabien Sanglard return first.dns_type() == DnsType::kNSEC;
70*3f982cf4SFabien Sanglard }
71*3f982cf4SFabien Sanglard
72*3f982cf4SFabien Sanglard return first < second;
73*3f982cf4SFabien Sanglard }
74*3f982cf4SFabien Sanglard
75*3f982cf4SFabien Sanglard class DnsTypeBitSet {
76*3f982cf4SFabien Sanglard public:
77*3f982cf4SFabien Sanglard // Returns whether any types are currently stored in this data structure.
IsEmpty()78*3f982cf4SFabien Sanglard bool IsEmpty() { return !elements_.any(); }
79*3f982cf4SFabien Sanglard
80*3f982cf4SFabien Sanglard // Attempts to insert the given type into this data structure. Returns
81*3f982cf4SFabien Sanglard // true iff the type was not already present.
Insert(DnsType type)82*3f982cf4SFabien Sanglard bool Insert(DnsType type) {
83*3f982cf4SFabien Sanglard uint16_t bit = (type == DnsType::kANY) ? 0 : static_cast<uint16_t>(type);
84*3f982cf4SFabien Sanglard bool was_set = elements_.test(bit);
85*3f982cf4SFabien Sanglard elements_.set(bit);
86*3f982cf4SFabien Sanglard return !was_set;
87*3f982cf4SFabien Sanglard }
88*3f982cf4SFabien Sanglard
89*3f982cf4SFabien Sanglard // Iterates over all members of the provided container, inserting each
90*3f982cf4SFabien Sanglard // DnsType contained within to this instance. Returns true iff any element
91*3f982cf4SFabien Sanglard // inserted was not already present in this instance.
92*3f982cf4SFabien Sanglard template <typename Container>
Insert(const Container & container)93*3f982cf4SFabien Sanglard bool Insert(const Container& container) {
94*3f982cf4SFabien Sanglard bool has_element_been_inserted = false;
95*3f982cf4SFabien Sanglard for (DnsType type : container) {
96*3f982cf4SFabien Sanglard has_element_been_inserted |= Insert(type);
97*3f982cf4SFabien Sanglard }
98*3f982cf4SFabien Sanglard return has_element_been_inserted;
99*3f982cf4SFabien Sanglard }
100*3f982cf4SFabien Sanglard
101*3f982cf4SFabien Sanglard // Attempts to remove the given type from this data structure. Returns true
102*3f982cf4SFabien Sanglard // iff the type was present prior to this call.
Remove(DnsType type)103*3f982cf4SFabien Sanglard bool Remove(DnsType type) {
104*3f982cf4SFabien Sanglard if (IsEmpty()) {
105*3f982cf4SFabien Sanglard return false;
106*3f982cf4SFabien Sanglard } else if (type == DnsType::kANY) {
107*3f982cf4SFabien Sanglard elements_.reset();
108*3f982cf4SFabien Sanglard return true;
109*3f982cf4SFabien Sanglard }
110*3f982cf4SFabien Sanglard
111*3f982cf4SFabien Sanglard uint16_t bit = static_cast<uint16_t>(type);
112*3f982cf4SFabien Sanglard bool was_set = elements_.test(bit);
113*3f982cf4SFabien Sanglard elements_.reset(bit);
114*3f982cf4SFabien Sanglard return was_set;
115*3f982cf4SFabien Sanglard }
116*3f982cf4SFabien Sanglard
117*3f982cf4SFabien Sanglard // Returns the DnsTypes currently stored in this data structure.
GetTypes() const118*3f982cf4SFabien Sanglard std::vector<DnsType> GetTypes() const {
119*3f982cf4SFabien Sanglard if (elements_.test(0)) {
120*3f982cf4SFabien Sanglard return {DnsType::kANY};
121*3f982cf4SFabien Sanglard }
122*3f982cf4SFabien Sanglard
123*3f982cf4SFabien Sanglard std::vector<DnsType> types;
124*3f982cf4SFabien Sanglard for (DnsType type : kSupportedDnsTypes) {
125*3f982cf4SFabien Sanglard if (type == DnsType::kANY) {
126*3f982cf4SFabien Sanglard continue;
127*3f982cf4SFabien Sanglard }
128*3f982cf4SFabien Sanglard
129*3f982cf4SFabien Sanglard uint16_t cast_int = static_cast<uint16_t>(type);
130*3f982cf4SFabien Sanglard if (elements_.test(cast_int)) {
131*3f982cf4SFabien Sanglard types.push_back(type);
132*3f982cf4SFabien Sanglard }
133*3f982cf4SFabien Sanglard }
134*3f982cf4SFabien Sanglard return types;
135*3f982cf4SFabien Sanglard }
136*3f982cf4SFabien Sanglard
137*3f982cf4SFabien Sanglard private:
138*3f982cf4SFabien Sanglard std::bitset<64> elements_;
139*3f982cf4SFabien Sanglard };
140*3f982cf4SFabien Sanglard
141*3f982cf4SFabien Sanglard // Modifies |records| such that no NSEC record signifies the nonexistance of a
142*3f982cf4SFabien Sanglard // record which is also present in the same message. Order of the input vector
143*3f982cf4SFabien Sanglard // is NOT preserved.
144*3f982cf4SFabien Sanglard // NOTE: |records| is not of type MdnsRecord::ConstRef because the members must
145*3f982cf4SFabien Sanglard // be modified.
146*3f982cf4SFabien Sanglard // TODO(b/170353378): Break this logic into a separate processing module between
147*3f982cf4SFabien Sanglard // the MdnsReader and the MdnsQuerier.
RemoveInvalidNsecFlags(std::vector<MdnsRecord> * records)148*3f982cf4SFabien Sanglard void RemoveInvalidNsecFlags(std::vector<MdnsRecord>* records) {
149*3f982cf4SFabien Sanglard // Sort the records so NSEC records are first so that only one iteration
150*3f982cf4SFabien Sanglard // through all records is needed.
151*3f982cf4SFabien Sanglard std::sort(records->begin(), records->end(), CompareRecordByNameAndType);
152*3f982cf4SFabien Sanglard
153*3f982cf4SFabien Sanglard // The set of NSEC records that need to be removed from |records|. This can't
154*3f982cf4SFabien Sanglard // be done as part of the below loop because it would invalidate the iterator
155*3f982cf4SFabien Sanglard // that's still being used.
156*3f982cf4SFabien Sanglard std::vector<std::vector<MdnsRecord>::iterator> nsecs_to_delete;
157*3f982cf4SFabien Sanglard
158*3f982cf4SFabien Sanglard // Process all elements.
159*3f982cf4SFabien Sanglard for (auto it = records->begin(); it != records->end();) {
160*3f982cf4SFabien Sanglard if (it->dns_type() != DnsType::kNSEC) {
161*3f982cf4SFabien Sanglard it++;
162*3f982cf4SFabien Sanglard continue;
163*3f982cf4SFabien Sanglard }
164*3f982cf4SFabien Sanglard
165*3f982cf4SFabien Sanglard // Track whether the current NSEC record in the input vector has been
166*3f982cf4SFabien Sanglard // modified by some step of this algorithm, be that merging with another
167*3f982cf4SFabien Sanglard // record, removing a DnsType, or any other modification.
168*3f982cf4SFabien Sanglard bool has_changed = false;
169*3f982cf4SFabien Sanglard
170*3f982cf4SFabien Sanglard // The types for the new record to create, if |has_changed|.
171*3f982cf4SFabien Sanglard const NsecRecordRdata& nsec_rdata = absl::get<NsecRecordRdata>(it->rdata());
172*3f982cf4SFabien Sanglard DnsTypeBitSet types;
173*3f982cf4SFabien Sanglard for (DnsType type : nsec_rdata.types()) {
174*3f982cf4SFabien Sanglard types.Insert(type);
175*3f982cf4SFabien Sanglard }
176*3f982cf4SFabien Sanglard auto nsec = it;
177*3f982cf4SFabien Sanglard it++;
178*3f982cf4SFabien Sanglard
179*3f982cf4SFabien Sanglard // Combine multiple NSECs to simplify the following code. This probably
180*3f982cf4SFabien Sanglard // won't happen, but the RFC doesn't exclude the possibility, so account for
181*3f982cf4SFabien Sanglard // it. Define the TTL of this new NSEC record created by this merge process
182*3f982cf4SFabien Sanglard // to be the minimum of all merged NSEC records.
183*3f982cf4SFabien Sanglard std::chrono::seconds new_ttl = nsec->ttl();
184*3f982cf4SFabien Sanglard while (it != records->end() && it->name() == nsec->name() &&
185*3f982cf4SFabien Sanglard it->dns_type() == DnsType::kNSEC) {
186*3f982cf4SFabien Sanglard has_changed |=
187*3f982cf4SFabien Sanglard types.Insert(absl::get<NsecRecordRdata>(it->rdata()).types());
188*3f982cf4SFabien Sanglard new_ttl = std::min(new_ttl, it->ttl());
189*3f982cf4SFabien Sanglard it = records->erase(it);
190*3f982cf4SFabien Sanglard }
191*3f982cf4SFabien Sanglard
192*3f982cf4SFabien Sanglard // Remove any types associated with a known record type.
193*3f982cf4SFabien Sanglard for (; it != records->end() && it->name() == nsec->name(); it++) {
194*3f982cf4SFabien Sanglard OSP_DCHECK(it->dns_type() != DnsType::kNSEC);
195*3f982cf4SFabien Sanglard has_changed |= types.Remove(it->dns_type());
196*3f982cf4SFabien Sanglard }
197*3f982cf4SFabien Sanglard
198*3f982cf4SFabien Sanglard // Modify the stored NSEC record, if needed.
199*3f982cf4SFabien Sanglard if (has_changed && types.IsEmpty()) {
200*3f982cf4SFabien Sanglard nsecs_to_delete.push_back(nsec);
201*3f982cf4SFabien Sanglard } else if (has_changed) {
202*3f982cf4SFabien Sanglard NsecRecordRdata new_rdata(nsec_rdata.next_domain_name(),
203*3f982cf4SFabien Sanglard types.GetTypes());
204*3f982cf4SFabien Sanglard *nsec = MdnsRecord(nsec->name(), nsec->dns_type(), nsec->dns_class(),
205*3f982cf4SFabien Sanglard nsec->record_type(), new_ttl, std::move(new_rdata));
206*3f982cf4SFabien Sanglard }
207*3f982cf4SFabien Sanglard }
208*3f982cf4SFabien Sanglard
209*3f982cf4SFabien Sanglard // Erase invalid NSEC records. Go backwards to avoid invalidating the
210*3f982cf4SFabien Sanglard // remaining iterators.
211*3f982cf4SFabien Sanglard for (auto erase_it = nsecs_to_delete.rbegin();
212*3f982cf4SFabien Sanglard erase_it != nsecs_to_delete.rend(); erase_it++) {
213*3f982cf4SFabien Sanglard records->erase(*erase_it);
214*3f982cf4SFabien Sanglard }
215*3f982cf4SFabien Sanglard }
216*3f982cf4SFabien Sanglard
217*3f982cf4SFabien Sanglard } // namespace
218*3f982cf4SFabien Sanglard
RecordTrackerLruCache(MdnsQuerier * querier,MdnsSender * sender,MdnsRandom * random_delay,TaskRunner * task_runner,ClockNowFunctionPtr now_function,ReportingClient * reporting_client,const Config & config)219*3f982cf4SFabien Sanglard MdnsQuerier::RecordTrackerLruCache::RecordTrackerLruCache(
220*3f982cf4SFabien Sanglard MdnsQuerier* querier,
221*3f982cf4SFabien Sanglard MdnsSender* sender,
222*3f982cf4SFabien Sanglard MdnsRandom* random_delay,
223*3f982cf4SFabien Sanglard TaskRunner* task_runner,
224*3f982cf4SFabien Sanglard ClockNowFunctionPtr now_function,
225*3f982cf4SFabien Sanglard ReportingClient* reporting_client,
226*3f982cf4SFabien Sanglard const Config& config)
227*3f982cf4SFabien Sanglard : querier_(querier),
228*3f982cf4SFabien Sanglard sender_(sender),
229*3f982cf4SFabien Sanglard random_delay_(random_delay),
230*3f982cf4SFabien Sanglard task_runner_(task_runner),
231*3f982cf4SFabien Sanglard now_function_(now_function),
232*3f982cf4SFabien Sanglard reporting_client_(reporting_client),
233*3f982cf4SFabien Sanglard config_(config) {
234*3f982cf4SFabien Sanglard OSP_DCHECK(sender_);
235*3f982cf4SFabien Sanglard OSP_DCHECK(random_delay_);
236*3f982cf4SFabien Sanglard OSP_DCHECK(task_runner_);
237*3f982cf4SFabien Sanglard OSP_DCHECK(reporting_client_);
238*3f982cf4SFabien Sanglard OSP_DCHECK_GT(config_.querier_max_records_cached, 0);
239*3f982cf4SFabien Sanglard }
240*3f982cf4SFabien Sanglard
241*3f982cf4SFabien Sanglard std::vector<std::reference_wrapper<const MdnsRecordTracker>>
Find(const DomainName & name)242*3f982cf4SFabien Sanglard MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name) {
243*3f982cf4SFabien Sanglard return Find(name, DnsType::kANY, DnsClass::kANY);
244*3f982cf4SFabien Sanglard }
245*3f982cf4SFabien Sanglard
246*3f982cf4SFabien Sanglard std::vector<std::reference_wrapper<const MdnsRecordTracker>>
Find(const DomainName & name,DnsType dns_type,DnsClass dns_class)247*3f982cf4SFabien Sanglard MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name,
248*3f982cf4SFabien Sanglard DnsType dns_type,
249*3f982cf4SFabien Sanglard DnsClass dns_class) {
250*3f982cf4SFabien Sanglard std::vector<RecordTrackerConstRef> results;
251*3f982cf4SFabien Sanglard auto pair = records_.equal_range(name);
252*3f982cf4SFabien Sanglard for (auto it = pair.first; it != pair.second; it++) {
253*3f982cf4SFabien Sanglard const MdnsRecordTracker& tracker = *it->second;
254*3f982cf4SFabien Sanglard if ((dns_type == DnsType::kANY || dns_type == tracker.dns_type()) &&
255*3f982cf4SFabien Sanglard (dns_class == DnsClass::kANY || dns_class == tracker.dns_class())) {
256*3f982cf4SFabien Sanglard results.push_back(std::cref(tracker));
257*3f982cf4SFabien Sanglard }
258*3f982cf4SFabien Sanglard }
259*3f982cf4SFabien Sanglard
260*3f982cf4SFabien Sanglard return results;
261*3f982cf4SFabien Sanglard }
262*3f982cf4SFabien Sanglard
Erase(const DomainName & domain,TrackerApplicableCheck check)263*3f982cf4SFabien Sanglard int MdnsQuerier::RecordTrackerLruCache::Erase(const DomainName& domain,
264*3f982cf4SFabien Sanglard TrackerApplicableCheck check) {
265*3f982cf4SFabien Sanglard auto pair = records_.equal_range(domain);
266*3f982cf4SFabien Sanglard int count = 0;
267*3f982cf4SFabien Sanglard for (RecordMap::iterator it = pair.first; it != pair.second;) {
268*3f982cf4SFabien Sanglard if (check(*it->second)) {
269*3f982cf4SFabien Sanglard lru_order_.erase(it->second);
270*3f982cf4SFabien Sanglard it = records_.erase(it);
271*3f982cf4SFabien Sanglard count++;
272*3f982cf4SFabien Sanglard } else {
273*3f982cf4SFabien Sanglard it++;
274*3f982cf4SFabien Sanglard }
275*3f982cf4SFabien Sanglard }
276*3f982cf4SFabien Sanglard
277*3f982cf4SFabien Sanglard return count;
278*3f982cf4SFabien Sanglard }
279*3f982cf4SFabien Sanglard
ExpireSoon(const DomainName & domain,TrackerApplicableCheck check)280*3f982cf4SFabien Sanglard int MdnsQuerier::RecordTrackerLruCache::ExpireSoon(
281*3f982cf4SFabien Sanglard const DomainName& domain,
282*3f982cf4SFabien Sanglard TrackerApplicableCheck check) {
283*3f982cf4SFabien Sanglard auto pair = records_.equal_range(domain);
284*3f982cf4SFabien Sanglard int count = 0;
285*3f982cf4SFabien Sanglard for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
286*3f982cf4SFabien Sanglard if (check(*it->second)) {
287*3f982cf4SFabien Sanglard MoveToEnd(it);
288*3f982cf4SFabien Sanglard it->second->ExpireSoon();
289*3f982cf4SFabien Sanglard count++;
290*3f982cf4SFabien Sanglard }
291*3f982cf4SFabien Sanglard }
292*3f982cf4SFabien Sanglard
293*3f982cf4SFabien Sanglard return count;
294*3f982cf4SFabien Sanglard }
295*3f982cf4SFabien Sanglard
Update(const MdnsRecord & record,TrackerApplicableCheck check)296*3f982cf4SFabien Sanglard int MdnsQuerier::RecordTrackerLruCache::Update(const MdnsRecord& record,
297*3f982cf4SFabien Sanglard TrackerApplicableCheck check) {
298*3f982cf4SFabien Sanglard return Update(record, check, [](const MdnsRecordTracker& t) {});
299*3f982cf4SFabien Sanglard }
300*3f982cf4SFabien Sanglard
Update(const MdnsRecord & record,TrackerApplicableCheck check,TrackerChangeCallback on_rdata_update)301*3f982cf4SFabien Sanglard int MdnsQuerier::RecordTrackerLruCache::Update(
302*3f982cf4SFabien Sanglard const MdnsRecord& record,
303*3f982cf4SFabien Sanglard TrackerApplicableCheck check,
304*3f982cf4SFabien Sanglard TrackerChangeCallback on_rdata_update) {
305*3f982cf4SFabien Sanglard auto pair = records_.equal_range(record.name());
306*3f982cf4SFabien Sanglard int count = 0;
307*3f982cf4SFabien Sanglard for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
308*3f982cf4SFabien Sanglard if (check(*it->second)) {
309*3f982cf4SFabien Sanglard auto result = it->second->Update(record);
310*3f982cf4SFabien Sanglard
311*3f982cf4SFabien Sanglard if (result.is_error()) {
312*3f982cf4SFabien Sanglard reporting_client_->OnRecoverableError(
313*3f982cf4SFabien Sanglard Error(Error::Code::kUpdateReceivedRecordFailure,
314*3f982cf4SFabien Sanglard result.error().ToString()));
315*3f982cf4SFabien Sanglard continue;
316*3f982cf4SFabien Sanglard }
317*3f982cf4SFabien Sanglard
318*3f982cf4SFabien Sanglard count++;
319*3f982cf4SFabien Sanglard if (result.value() == MdnsRecordTracker::UpdateType::kGoodbye) {
320*3f982cf4SFabien Sanglard it->second->ExpireSoon();
321*3f982cf4SFabien Sanglard MoveToEnd(it);
322*3f982cf4SFabien Sanglard } else {
323*3f982cf4SFabien Sanglard MoveToBeginning(it);
324*3f982cf4SFabien Sanglard if (result.value() == MdnsRecordTracker::UpdateType::kRdata) {
325*3f982cf4SFabien Sanglard on_rdata_update(*it->second);
326*3f982cf4SFabien Sanglard }
327*3f982cf4SFabien Sanglard }
328*3f982cf4SFabien Sanglard }
329*3f982cf4SFabien Sanglard }
330*3f982cf4SFabien Sanglard
331*3f982cf4SFabien Sanglard return count;
332*3f982cf4SFabien Sanglard }
333*3f982cf4SFabien Sanglard
StartTracking(MdnsRecord record,DnsType dns_type)334*3f982cf4SFabien Sanglard const MdnsRecordTracker& MdnsQuerier::RecordTrackerLruCache::StartTracking(
335*3f982cf4SFabien Sanglard MdnsRecord record,
336*3f982cf4SFabien Sanglard DnsType dns_type) {
337*3f982cf4SFabien Sanglard auto expiration_callback = [this](const MdnsRecordTracker* tracker,
338*3f982cf4SFabien Sanglard const MdnsRecord& record) {
339*3f982cf4SFabien Sanglard querier_->OnRecordExpired(tracker, record);
340*3f982cf4SFabien Sanglard };
341*3f982cf4SFabien Sanglard
342*3f982cf4SFabien Sanglard while (lru_order_.size() >=
343*3f982cf4SFabien Sanglard static_cast<size_t>(config_.querier_max_records_cached)) {
344*3f982cf4SFabien Sanglard // This call erases one of the tracked records.
345*3f982cf4SFabien Sanglard OSP_DVLOG << "Maximum cacheable record count exceeded ("
346*3f982cf4SFabien Sanglard << config_.querier_max_records_cached << ")";
347*3f982cf4SFabien Sanglard lru_order_.back().ExpireNow();
348*3f982cf4SFabien Sanglard }
349*3f982cf4SFabien Sanglard
350*3f982cf4SFabien Sanglard auto name = record.name();
351*3f982cf4SFabien Sanglard lru_order_.emplace_front(std::move(record), dns_type, sender_, task_runner_,
352*3f982cf4SFabien Sanglard now_function_, random_delay_,
353*3f982cf4SFabien Sanglard std::move(expiration_callback));
354*3f982cf4SFabien Sanglard records_.emplace(std::move(name), lru_order_.begin());
355*3f982cf4SFabien Sanglard
356*3f982cf4SFabien Sanglard return lru_order_.front();
357*3f982cf4SFabien Sanglard }
358*3f982cf4SFabien Sanglard
MoveToBeginning(MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it)359*3f982cf4SFabien Sanglard void MdnsQuerier::RecordTrackerLruCache::MoveToBeginning(
360*3f982cf4SFabien Sanglard MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
361*3f982cf4SFabien Sanglard lru_order_.splice(lru_order_.begin(), lru_order_, it->second);
362*3f982cf4SFabien Sanglard it->second = lru_order_.begin();
363*3f982cf4SFabien Sanglard }
364*3f982cf4SFabien Sanglard
MoveToEnd(MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it)365*3f982cf4SFabien Sanglard void MdnsQuerier::RecordTrackerLruCache::MoveToEnd(
366*3f982cf4SFabien Sanglard MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
367*3f982cf4SFabien Sanglard lru_order_.splice(lru_order_.end(), lru_order_, it->second);
368*3f982cf4SFabien Sanglard it->second = --lru_order_.end();
369*3f982cf4SFabien Sanglard }
370*3f982cf4SFabien Sanglard
MdnsQuerier(MdnsSender * sender,MdnsReceiver * receiver,TaskRunner * task_runner,ClockNowFunctionPtr now_function,MdnsRandom * random_delay,ReportingClient * reporting_client,Config config)371*3f982cf4SFabien Sanglard MdnsQuerier::MdnsQuerier(MdnsSender* sender,
372*3f982cf4SFabien Sanglard MdnsReceiver* receiver,
373*3f982cf4SFabien Sanglard TaskRunner* task_runner,
374*3f982cf4SFabien Sanglard ClockNowFunctionPtr now_function,
375*3f982cf4SFabien Sanglard MdnsRandom* random_delay,
376*3f982cf4SFabien Sanglard ReportingClient* reporting_client,
377*3f982cf4SFabien Sanglard Config config)
378*3f982cf4SFabien Sanglard : sender_(sender),
379*3f982cf4SFabien Sanglard receiver_(receiver),
380*3f982cf4SFabien Sanglard task_runner_(task_runner),
381*3f982cf4SFabien Sanglard now_function_(now_function),
382*3f982cf4SFabien Sanglard random_delay_(random_delay),
383*3f982cf4SFabien Sanglard reporting_client_(reporting_client),
384*3f982cf4SFabien Sanglard config_(std::move(config)),
385*3f982cf4SFabien Sanglard records_(this,
386*3f982cf4SFabien Sanglard sender_,
387*3f982cf4SFabien Sanglard random_delay_,
388*3f982cf4SFabien Sanglard task_runner_,
389*3f982cf4SFabien Sanglard now_function_,
390*3f982cf4SFabien Sanglard reporting_client_,
391*3f982cf4SFabien Sanglard config_) {
392*3f982cf4SFabien Sanglard OSP_DCHECK(sender_);
393*3f982cf4SFabien Sanglard OSP_DCHECK(receiver_);
394*3f982cf4SFabien Sanglard OSP_DCHECK(task_runner_);
395*3f982cf4SFabien Sanglard OSP_DCHECK(now_function_);
396*3f982cf4SFabien Sanglard OSP_DCHECK(random_delay_);
397*3f982cf4SFabien Sanglard OSP_DCHECK(reporting_client_);
398*3f982cf4SFabien Sanglard
399*3f982cf4SFabien Sanglard receiver_->AddResponseCallback(this);
400*3f982cf4SFabien Sanglard }
401*3f982cf4SFabien Sanglard
~MdnsQuerier()402*3f982cf4SFabien Sanglard MdnsQuerier::~MdnsQuerier() {
403*3f982cf4SFabien Sanglard receiver_->RemoveResponseCallback(this);
404*3f982cf4SFabien Sanglard }
405*3f982cf4SFabien Sanglard
// NOTE: The code below uses plain loops instead of std::find_if, for better
// readability, brevity and homogeneity. Using std::find_if results in a few
// more lines of code, and readability suffers from the extra lambdas.
409*3f982cf4SFabien Sanglard
StartQuery(const DomainName & name,DnsType dns_type,DnsClass dns_class,MdnsRecordChangedCallback * callback)410*3f982cf4SFabien Sanglard void MdnsQuerier::StartQuery(const DomainName& name,
411*3f982cf4SFabien Sanglard DnsType dns_type,
412*3f982cf4SFabien Sanglard DnsClass dns_class,
413*3f982cf4SFabien Sanglard MdnsRecordChangedCallback* callback) {
414*3f982cf4SFabien Sanglard OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
415*3f982cf4SFabien Sanglard OSP_DCHECK(callback);
416*3f982cf4SFabien Sanglard OSP_DCHECK(CanBeQueried(dns_type));
417*3f982cf4SFabien Sanglard
418*3f982cf4SFabien Sanglard // Add a new callback if haven't seen it before
419*3f982cf4SFabien Sanglard auto callbacks_it = callbacks_.equal_range(name);
420*3f982cf4SFabien Sanglard for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
421*3f982cf4SFabien Sanglard const CallbackInfo& callback_info = entry->second;
422*3f982cf4SFabien Sanglard if (dns_type == callback_info.dns_type &&
423*3f982cf4SFabien Sanglard dns_class == callback_info.dns_class &&
424*3f982cf4SFabien Sanglard callback == callback_info.callback) {
425*3f982cf4SFabien Sanglard // Already have this callback
426*3f982cf4SFabien Sanglard return;
427*3f982cf4SFabien Sanglard }
428*3f982cf4SFabien Sanglard }
429*3f982cf4SFabien Sanglard callbacks_.emplace(name, CallbackInfo{callback, dns_type, dns_class});
430*3f982cf4SFabien Sanglard
431*3f982cf4SFabien Sanglard // Notify the new callback with previously cached records.
432*3f982cf4SFabien Sanglard // NOTE: In the future, could allow callers to fetch cached records after
433*3f982cf4SFabien Sanglard // adding a callback, for example to prime the UI.
434*3f982cf4SFabien Sanglard std::vector<PendingQueryChange> pending_changes;
435*3f982cf4SFabien Sanglard const std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
436*3f982cf4SFabien Sanglard records_.Find(name, dns_type, dns_class);
437*3f982cf4SFabien Sanglard for (const MdnsRecordTracker& tracker : trackers) {
438*3f982cf4SFabien Sanglard if (!tracker.is_negative_response()) {
439*3f982cf4SFabien Sanglard MdnsRecord stored_record(name, tracker.dns_type(), tracker.dns_class(),
440*3f982cf4SFabien Sanglard tracker.record_type(), tracker.ttl(),
441*3f982cf4SFabien Sanglard tracker.rdata());
442*3f982cf4SFabien Sanglard std::vector<PendingQueryChange> new_changes = callback->OnRecordChanged(
443*3f982cf4SFabien Sanglard std::move(stored_record), RecordChangedEvent::kCreated);
444*3f982cf4SFabien Sanglard pending_changes.insert(pending_changes.end(), new_changes.begin(),
445*3f982cf4SFabien Sanglard new_changes.end());
446*3f982cf4SFabien Sanglard }
447*3f982cf4SFabien Sanglard }
448*3f982cf4SFabien Sanglard
449*3f982cf4SFabien Sanglard // Add a new question if haven't seen it before
450*3f982cf4SFabien Sanglard auto questions_it = questions_.equal_range(name);
451*3f982cf4SFabien Sanglard const bool is_question_already_tracked =
452*3f982cf4SFabien Sanglard std::find_if(questions_it.first, questions_it.second,
453*3f982cf4SFabien Sanglard [dns_type, dns_class](const auto& entry) {
454*3f982cf4SFabien Sanglard const MdnsQuestion& tracked_question =
455*3f982cf4SFabien Sanglard entry.second->question();
456*3f982cf4SFabien Sanglard return dns_type == tracked_question.dns_type() &&
457*3f982cf4SFabien Sanglard dns_class == tracked_question.dns_class();
458*3f982cf4SFabien Sanglard }) != questions_it.second;
459*3f982cf4SFabien Sanglard if (!is_question_already_tracked) {
460*3f982cf4SFabien Sanglard AddQuestion(
461*3f982cf4SFabien Sanglard MdnsQuestion(name, dns_type, dns_class, ResponseType::kMulticast));
462*3f982cf4SFabien Sanglard }
463*3f982cf4SFabien Sanglard
464*3f982cf4SFabien Sanglard // Apply any pending changes from the OnRecordChanged() callbacks.
465*3f982cf4SFabien Sanglard ApplyPendingChanges(std::move(pending_changes));
466*3f982cf4SFabien Sanglard }
467*3f982cf4SFabien Sanglard
StopQuery(const DomainName & name,DnsType dns_type,DnsClass dns_class,MdnsRecordChangedCallback * callback)468*3f982cf4SFabien Sanglard void MdnsQuerier::StopQuery(const DomainName& name,
469*3f982cf4SFabien Sanglard DnsType dns_type,
470*3f982cf4SFabien Sanglard DnsClass dns_class,
471*3f982cf4SFabien Sanglard MdnsRecordChangedCallback* callback) {
472*3f982cf4SFabien Sanglard OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
473*3f982cf4SFabien Sanglard OSP_DCHECK(callback);
474*3f982cf4SFabien Sanglard
475*3f982cf4SFabien Sanglard if (!CanBeQueried(dns_type)) {
476*3f982cf4SFabien Sanglard return;
477*3f982cf4SFabien Sanglard }
478*3f982cf4SFabien Sanglard
479*3f982cf4SFabien Sanglard // Find and remove the callback.
480*3f982cf4SFabien Sanglard int callbacks_for_key = 0;
481*3f982cf4SFabien Sanglard auto callbacks_it = callbacks_.equal_range(name);
482*3f982cf4SFabien Sanglard for (auto entry = callbacks_it.first; entry != callbacks_it.second;) {
483*3f982cf4SFabien Sanglard const CallbackInfo& callback_info = entry->second;
484*3f982cf4SFabien Sanglard if (dns_type == callback_info.dns_type &&
485*3f982cf4SFabien Sanglard dns_class == callback_info.dns_class) {
486*3f982cf4SFabien Sanglard if (callback == callback_info.callback) {
487*3f982cf4SFabien Sanglard entry = callbacks_.erase(entry);
488*3f982cf4SFabien Sanglard } else {
489*3f982cf4SFabien Sanglard ++callbacks_for_key;
490*3f982cf4SFabien Sanglard ++entry;
491*3f982cf4SFabien Sanglard }
492*3f982cf4SFabien Sanglard }
493*3f982cf4SFabien Sanglard }
494*3f982cf4SFabien Sanglard
495*3f982cf4SFabien Sanglard // Exit if there are still callbacks registered for DomainName + DnsType +
496*3f982cf4SFabien Sanglard // DnsClass
497*3f982cf4SFabien Sanglard if (callbacks_for_key > 0) {
498*3f982cf4SFabien Sanglard return;
499*3f982cf4SFabien Sanglard }
500*3f982cf4SFabien Sanglard
501*3f982cf4SFabien Sanglard // Find and delete a question that does not have any associated callbacks
502*3f982cf4SFabien Sanglard auto questions_it = questions_.equal_range(name);
503*3f982cf4SFabien Sanglard for (auto entry = questions_it.first; entry != questions_it.second; ++entry) {
504*3f982cf4SFabien Sanglard const MdnsQuestion& tracked_question = entry->second->question();
505*3f982cf4SFabien Sanglard if (dns_type == tracked_question.dns_type() &&
506*3f982cf4SFabien Sanglard dns_class == tracked_question.dns_class()) {
507*3f982cf4SFabien Sanglard questions_.erase(entry);
508*3f982cf4SFabien Sanglard return;
509*3f982cf4SFabien Sanglard }
510*3f982cf4SFabien Sanglard }
511*3f982cf4SFabien Sanglard }
512*3f982cf4SFabien Sanglard
ReinitializeQueries(const DomainName & name)513*3f982cf4SFabien Sanglard void MdnsQuerier::ReinitializeQueries(const DomainName& name) {
514*3f982cf4SFabien Sanglard OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
515*3f982cf4SFabien Sanglard
516*3f982cf4SFabien Sanglard // Get the ongoing queries and their callbacks.
517*3f982cf4SFabien Sanglard std::vector<CallbackInfo> callbacks;
518*3f982cf4SFabien Sanglard auto its = callbacks_.equal_range(name);
519*3f982cf4SFabien Sanglard for (auto it = its.first; it != its.second; it++) {
520*3f982cf4SFabien Sanglard callbacks.push_back(std::move(it->second));
521*3f982cf4SFabien Sanglard }
522*3f982cf4SFabien Sanglard callbacks_.erase(name);
523*3f982cf4SFabien Sanglard
524*3f982cf4SFabien Sanglard // Remove all known questions and answers.
525*3f982cf4SFabien Sanglard questions_.erase(name);
526*3f982cf4SFabien Sanglard records_.Erase(name, [](const MdnsRecordTracker& tracker) { return true; });
527*3f982cf4SFabien Sanglard
528*3f982cf4SFabien Sanglard // Restart the queries.
529*3f982cf4SFabien Sanglard for (const auto& cb : callbacks) {
530*3f982cf4SFabien Sanglard StartQuery(name, cb.dns_type, cb.dns_class, cb.callback);
531*3f982cf4SFabien Sanglard }
532*3f982cf4SFabien Sanglard }
533*3f982cf4SFabien Sanglard
OnMessageReceived(const MdnsMessage & message)534*3f982cf4SFabien Sanglard void MdnsQuerier::OnMessageReceived(const MdnsMessage& message) {
535*3f982cf4SFabien Sanglard OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
536*3f982cf4SFabien Sanglard OSP_DCHECK(message.type() == MessageType::Response);
537*3f982cf4SFabien Sanglard
538*3f982cf4SFabien Sanglard OSP_DVLOG << "Received mDNS Response message with "
539*3f982cf4SFabien Sanglard << message.answers().size() << " answers and "
540*3f982cf4SFabien Sanglard << message.additional_records().size()
541*3f982cf4SFabien Sanglard << " additional records. Processing...";
542*3f982cf4SFabien Sanglard
543*3f982cf4SFabien Sanglard std::vector<MdnsRecord> records_to_process;
544*3f982cf4SFabien Sanglard
545*3f982cf4SFabien Sanglard // Add any records that are relevant for this querier.
546*3f982cf4SFabien Sanglard bool found_relevant_records = false;
547*3f982cf4SFabien Sanglard for (const MdnsRecord& record : message.answers()) {
548*3f982cf4SFabien Sanglard if (ShouldAnswerRecordBeProcessed(record)) {
549*3f982cf4SFabien Sanglard records_to_process.push_back(record);
550*3f982cf4SFabien Sanglard found_relevant_records = true;
551*3f982cf4SFabien Sanglard }
552*3f982cf4SFabien Sanglard }
553*3f982cf4SFabien Sanglard
554*3f982cf4SFabien Sanglard // If any of the message's answers are relevant, add all additional records.
555*3f982cf4SFabien Sanglard // Else, since the message has already been received and parsed, use any
556*3f982cf4SFabien Sanglard // individual records relevant to this querier to update the cache.
557*3f982cf4SFabien Sanglard for (const MdnsRecord& record : message.additional_records()) {
558*3f982cf4SFabien Sanglard if (found_relevant_records || ShouldAnswerRecordBeProcessed(record)) {
559*3f982cf4SFabien Sanglard records_to_process.push_back(record);
560*3f982cf4SFabien Sanglard }
561*3f982cf4SFabien Sanglard }
562*3f982cf4SFabien Sanglard
563*3f982cf4SFabien Sanglard // Drop NSEC records associated with a non-NSEC record of the same type.
564*3f982cf4SFabien Sanglard RemoveInvalidNsecFlags(&records_to_process);
565*3f982cf4SFabien Sanglard
566*3f982cf4SFabien Sanglard // Process all remaining records.
567*3f982cf4SFabien Sanglard for (const MdnsRecord& record_to_process : records_to_process) {
568*3f982cf4SFabien Sanglard ProcessRecord(record_to_process);
569*3f982cf4SFabien Sanglard }
570*3f982cf4SFabien Sanglard
571*3f982cf4SFabien Sanglard OSP_DVLOG << "\tmDNS Response processed (" << records_to_process.size()
572*3f982cf4SFabien Sanglard << " records accepted)!";
573*3f982cf4SFabien Sanglard
574*3f982cf4SFabien Sanglard // TODO(crbug.com/openscreen/83): Check authority records.
575*3f982cf4SFabien Sanglard }
576*3f982cf4SFabien Sanglard
ShouldAnswerRecordBeProcessed(const MdnsRecord & answer)577*3f982cf4SFabien Sanglard bool MdnsQuerier::ShouldAnswerRecordBeProcessed(const MdnsRecord& answer) {
578*3f982cf4SFabien Sanglard // First, accept the record if it's associated with an ongoing question.
579*3f982cf4SFabien Sanglard const auto questions_range = questions_.equal_range(answer.name());
580*3f982cf4SFabien Sanglard const auto it = std::find_if(
581*3f982cf4SFabien Sanglard questions_range.first, questions_range.second,
582*3f982cf4SFabien Sanglard [&answer](const auto& pair) {
583*3f982cf4SFabien Sanglard return (pair.second->question().dns_type() == DnsType::kANY ||
584*3f982cf4SFabien Sanglard IsNegativeResponseFor(answer,
585*3f982cf4SFabien Sanglard pair.second->question().dns_type()) ||
586*3f982cf4SFabien Sanglard pair.second->question().dns_type() == answer.dns_type()) &&
587*3f982cf4SFabien Sanglard (pair.second->question().dns_class() == DnsClass::kANY ||
588*3f982cf4SFabien Sanglard pair.second->question().dns_class() == answer.dns_class());
589*3f982cf4SFabien Sanglard });
590*3f982cf4SFabien Sanglard if (it != questions_range.second) {
591*3f982cf4SFabien Sanglard return true;
592*3f982cf4SFabien Sanglard }
593*3f982cf4SFabien Sanglard
594*3f982cf4SFabien Sanglard // If not, check if it corresponds to an already existing record. This is
595*3f982cf4SFabien Sanglard // required because records which are already stored may either have been
596*3f982cf4SFabien Sanglard // received in an additional records section, or are associated with a query
597*3f982cf4SFabien Sanglard // which is no longer active.
598*3f982cf4SFabien Sanglard std::vector<DnsType> types{answer.dns_type()};
599*3f982cf4SFabien Sanglard if (answer.dns_type() == DnsType::kNSEC) {
600*3f982cf4SFabien Sanglard const auto& nsec_rdata = absl::get<NsecRecordRdata>(answer.rdata());
601*3f982cf4SFabien Sanglard types = nsec_rdata.types();
602*3f982cf4SFabien Sanglard }
603*3f982cf4SFabien Sanglard
604*3f982cf4SFabien Sanglard for (DnsType type : types) {
605*3f982cf4SFabien Sanglard std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
606*3f982cf4SFabien Sanglard records_.Find(answer.name(), type, answer.dns_class());
607*3f982cf4SFabien Sanglard if (!trackers.empty()) {
608*3f982cf4SFabien Sanglard return true;
609*3f982cf4SFabien Sanglard }
610*3f982cf4SFabien Sanglard }
611*3f982cf4SFabien Sanglard
612*3f982cf4SFabien Sanglard // In all other cases, the record isn't relevant. Drop it.
613*3f982cf4SFabien Sanglard return false;
614*3f982cf4SFabien Sanglard }
615*3f982cf4SFabien Sanglard
OnRecordExpired(const MdnsRecordTracker * tracker,const MdnsRecord & record)616*3f982cf4SFabien Sanglard void MdnsQuerier::OnRecordExpired(const MdnsRecordTracker* tracker,
617*3f982cf4SFabien Sanglard const MdnsRecord& record) {
618*3f982cf4SFabien Sanglard OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
619*3f982cf4SFabien Sanglard
620*3f982cf4SFabien Sanglard if (!tracker->is_negative_response()) {
621*3f982cf4SFabien Sanglard ProcessCallbacks(record, RecordChangedEvent::kExpired);
622*3f982cf4SFabien Sanglard }
623*3f982cf4SFabien Sanglard
624*3f982cf4SFabien Sanglard records_.Erase(record.name(), [tracker](const MdnsRecordTracker& it_tracker) {
625*3f982cf4SFabien Sanglard return tracker == &it_tracker;
626*3f982cf4SFabien Sanglard });
627*3f982cf4SFabien Sanglard }
628*3f982cf4SFabien Sanglard
ProcessRecord(const MdnsRecord & record)629*3f982cf4SFabien Sanglard void MdnsQuerier::ProcessRecord(const MdnsRecord& record) {
630*3f982cf4SFabien Sanglard OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
631*3f982cf4SFabien Sanglard
632*3f982cf4SFabien Sanglard // Skip all records that can't be processed.
633*3f982cf4SFabien Sanglard if (!CanBeProcessed(record.dns_type())) {
634*3f982cf4SFabien Sanglard return;
635*3f982cf4SFabien Sanglard }
636*3f982cf4SFabien Sanglard
637*3f982cf4SFabien Sanglard // Ignore NSEC records if the embedder has configured us to do so.
638*3f982cf4SFabien Sanglard if (config_.ignore_nsec_responses && record.dns_type() == DnsType::kNSEC) {
639*3f982cf4SFabien Sanglard return;
640*3f982cf4SFabien Sanglard }
641*3f982cf4SFabien Sanglard
642*3f982cf4SFabien Sanglard // Get the types which the received record is associated with. In most cases
643*3f982cf4SFabien Sanglard // this will only be the type of the provided record, but in the case of
644*3f982cf4SFabien Sanglard // NSEC records this will be all records which the record dictates the
645*3f982cf4SFabien Sanglard // nonexistence of.
646*3f982cf4SFabien Sanglard std::vector<DnsType> types;
647*3f982cf4SFabien Sanglard int types_count = 0;
648*3f982cf4SFabien Sanglard const DnsType* types_ptr = nullptr;
649*3f982cf4SFabien Sanglard if (record.dns_type() == DnsType::kNSEC) {
650*3f982cf4SFabien Sanglard const auto& nsec_rdata = absl::get<NsecRecordRdata>(record.rdata());
651*3f982cf4SFabien Sanglard if (std::find(nsec_rdata.types().begin(), nsec_rdata.types().end(),
652*3f982cf4SFabien Sanglard DnsType::kANY) != nsec_rdata.types().end()) {
653*3f982cf4SFabien Sanglard types_ptr = kTranslatedNsecAnyQueryTypes.data();
654*3f982cf4SFabien Sanglard types_count = kTranslatedNsecAnyQueryTypes.size();
655*3f982cf4SFabien Sanglard } else {
656*3f982cf4SFabien Sanglard types_ptr = nsec_rdata.types().data();
657*3f982cf4SFabien Sanglard types_count = nsec_rdata.types().size();
658*3f982cf4SFabien Sanglard }
659*3f982cf4SFabien Sanglard } else {
660*3f982cf4SFabien Sanglard types.push_back(record.dns_type());
661*3f982cf4SFabien Sanglard types_ptr = types.data();
662*3f982cf4SFabien Sanglard types_count = types.size();
663*3f982cf4SFabien Sanglard }
664*3f982cf4SFabien Sanglard
665*3f982cf4SFabien Sanglard // Apply the update for each type that the record is associated with.
666*3f982cf4SFabien Sanglard for (int i = 0; i < types_count; ++i) {
667*3f982cf4SFabien Sanglard DnsType dns_type = types_ptr[i];
668*3f982cf4SFabien Sanglard switch (record.record_type()) {
669*3f982cf4SFabien Sanglard case RecordType::kShared: {
670*3f982cf4SFabien Sanglard ProcessSharedRecord(record, dns_type);
671*3f982cf4SFabien Sanglard break;
672*3f982cf4SFabien Sanglard }
673*3f982cf4SFabien Sanglard case RecordType::kUnique: {
674*3f982cf4SFabien Sanglard ProcessUniqueRecord(record, dns_type);
675*3f982cf4SFabien Sanglard break;
676*3f982cf4SFabien Sanglard }
677*3f982cf4SFabien Sanglard }
678*3f982cf4SFabien Sanglard }
679*3f982cf4SFabien Sanglard }
680*3f982cf4SFabien Sanglard
ProcessSharedRecord(const MdnsRecord & record,DnsType dns_type)681*3f982cf4SFabien Sanglard void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record,
682*3f982cf4SFabien Sanglard DnsType dns_type) {
683*3f982cf4SFabien Sanglard OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
684*3f982cf4SFabien Sanglard OSP_DCHECK(record.record_type() == RecordType::kShared);
685*3f982cf4SFabien Sanglard
686*3f982cf4SFabien Sanglard // By design, NSEC records are never shared records.
687*3f982cf4SFabien Sanglard if (record.dns_type() == DnsType::kNSEC) {
688*3f982cf4SFabien Sanglard return;
689*3f982cf4SFabien Sanglard }
690*3f982cf4SFabien Sanglard
691*3f982cf4SFabien Sanglard // For any records updated, this host already has this shared record. Since
692*3f982cf4SFabien Sanglard // the RDATA matches, this is only a TTL update.
693*3f982cf4SFabien Sanglard auto check = [&record](const MdnsRecordTracker& tracker) {
694*3f982cf4SFabien Sanglard return record.dns_type() == tracker.dns_type() &&
695*3f982cf4SFabien Sanglard record.dns_class() == tracker.dns_class() &&
696*3f982cf4SFabien Sanglard record.rdata() == tracker.rdata();
697*3f982cf4SFabien Sanglard };
698*3f982cf4SFabien Sanglard auto updated_count = records_.Update(record, std::move(check));
699*3f982cf4SFabien Sanglard
700*3f982cf4SFabien Sanglard if (!updated_count) {
701*3f982cf4SFabien Sanglard // Have never before seen this shared record, insert a new one.
702*3f982cf4SFabien Sanglard AddRecord(record, dns_type);
703*3f982cf4SFabien Sanglard ProcessCallbacks(record, RecordChangedEvent::kCreated);
704*3f982cf4SFabien Sanglard }
705*3f982cf4SFabien Sanglard }
706*3f982cf4SFabien Sanglard
ProcessUniqueRecord(const MdnsRecord & record,DnsType dns_type)707*3f982cf4SFabien Sanglard void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record,
708*3f982cf4SFabien Sanglard DnsType dns_type) {
709*3f982cf4SFabien Sanglard OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
710*3f982cf4SFabien Sanglard OSP_DCHECK(record.record_type() == RecordType::kUnique);
711*3f982cf4SFabien Sanglard
712*3f982cf4SFabien Sanglard std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
713*3f982cf4SFabien Sanglard records_.Find(record.name(), dns_type, record.dns_class());
714*3f982cf4SFabien Sanglard size_t num_records_for_key = trackers.size();
715*3f982cf4SFabien Sanglard
716*3f982cf4SFabien Sanglard // Have not seen any records with this key before. This case is expected the
717*3f982cf4SFabien Sanglard // first time a record is received.
718*3f982cf4SFabien Sanglard if (num_records_for_key == size_t{0}) {
719*3f982cf4SFabien Sanglard const bool will_exist = record.dns_type() != DnsType::kNSEC;
720*3f982cf4SFabien Sanglard AddRecord(record, dns_type);
721*3f982cf4SFabien Sanglard if (will_exist) {
722*3f982cf4SFabien Sanglard ProcessCallbacks(record, RecordChangedEvent::kCreated);
723*3f982cf4SFabien Sanglard }
724*3f982cf4SFabien Sanglard } else if (num_records_for_key == size_t{1}) {
725*3f982cf4SFabien Sanglard // There is exactly one tracker associated with this key. This is the
726*3f982cf4SFabien Sanglard // expected case when a record matching this one has already been seen.
727*3f982cf4SFabien Sanglard ProcessSinglyTrackedUniqueRecord(record, trackers[0]);
728*3f982cf4SFabien Sanglard } else {
729*3f982cf4SFabien Sanglard // Multiple records with the same key.
730*3f982cf4SFabien Sanglard ProcessMultiTrackedUniqueRecord(record, dns_type);
731*3f982cf4SFabien Sanglard }
732*3f982cf4SFabien Sanglard }
733*3f982cf4SFabien Sanglard
ProcessSinglyTrackedUniqueRecord(const MdnsRecord & record,const MdnsRecordTracker & tracker)734*3f982cf4SFabien Sanglard void MdnsQuerier::ProcessSinglyTrackedUniqueRecord(
735*3f982cf4SFabien Sanglard const MdnsRecord& record,
736*3f982cf4SFabien Sanglard const MdnsRecordTracker& tracker) {
737*3f982cf4SFabien Sanglard const bool existed_previously = !tracker.is_negative_response();
738*3f982cf4SFabien Sanglard const bool will_exist = record.dns_type() != DnsType::kNSEC;
739*3f982cf4SFabien Sanglard
740*3f982cf4SFabien Sanglard // Calculate the callback to call on record update success while the old
741*3f982cf4SFabien Sanglard // record still exists.
742*3f982cf4SFabien Sanglard MdnsRecord record_for_callback = record;
743*3f982cf4SFabien Sanglard if (existed_previously && !will_exist) {
744*3f982cf4SFabien Sanglard record_for_callback =
745*3f982cf4SFabien Sanglard MdnsRecord(record.name(), tracker.dns_type(), tracker.dns_class(),
746*3f982cf4SFabien Sanglard tracker.record_type(), tracker.ttl(), tracker.rdata());
747*3f982cf4SFabien Sanglard }
748*3f982cf4SFabien Sanglard
749*3f982cf4SFabien Sanglard auto on_rdata_change = [this, r = std::move(record_for_callback),
750*3f982cf4SFabien Sanglard existed_previously,
751*3f982cf4SFabien Sanglard will_exist](const MdnsRecordTracker& tracker) {
752*3f982cf4SFabien Sanglard // If RDATA on the record is different, notify that the record has
753*3f982cf4SFabien Sanglard // been updated.
754*3f982cf4SFabien Sanglard if (existed_previously && will_exist) {
755*3f982cf4SFabien Sanglard ProcessCallbacks(r, RecordChangedEvent::kUpdated);
756*3f982cf4SFabien Sanglard } else if (existed_previously) {
757*3f982cf4SFabien Sanglard // Do not expire the tracker, because it still holds an NSEC record.
758*3f982cf4SFabien Sanglard ProcessCallbacks(r, RecordChangedEvent::kExpired);
759*3f982cf4SFabien Sanglard } else if (will_exist) {
760*3f982cf4SFabien Sanglard ProcessCallbacks(r, RecordChangedEvent::kCreated);
761*3f982cf4SFabien Sanglard }
762*3f982cf4SFabien Sanglard };
763*3f982cf4SFabien Sanglard
764*3f982cf4SFabien Sanglard int updated_count = records_.Update(
765*3f982cf4SFabien Sanglard record, [&tracker](const MdnsRecordTracker& t) { return &tracker == &t; },
766*3f982cf4SFabien Sanglard std::move(on_rdata_change));
767*3f982cf4SFabien Sanglard OSP_DCHECK_EQ(updated_count, 1);
768*3f982cf4SFabien Sanglard }
769*3f982cf4SFabien Sanglard
ProcessMultiTrackedUniqueRecord(const MdnsRecord & record,DnsType dns_type)770*3f982cf4SFabien Sanglard void MdnsQuerier::ProcessMultiTrackedUniqueRecord(const MdnsRecord& record,
771*3f982cf4SFabien Sanglard DnsType dns_type) {
772*3f982cf4SFabien Sanglard auto update_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
773*3f982cf4SFabien Sanglard return tracker.dns_type() == dns_type &&
774*3f982cf4SFabien Sanglard tracker.dns_class() == record.dns_class() &&
775*3f982cf4SFabien Sanglard tracker.rdata() == record.rdata();
776*3f982cf4SFabien Sanglard };
777*3f982cf4SFabien Sanglard int update_count = records_.Update(
778*3f982cf4SFabien Sanglard record, std::move(update_check),
779*3f982cf4SFabien Sanglard [](const MdnsRecordTracker& tracker) { OSP_NOTREACHED(); });
780*3f982cf4SFabien Sanglard OSP_DCHECK_LE(update_count, 1);
781*3f982cf4SFabien Sanglard
782*3f982cf4SFabien Sanglard auto expire_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
783*3f982cf4SFabien Sanglard return tracker.dns_type() == dns_type &&
784*3f982cf4SFabien Sanglard tracker.dns_class() == record.dns_class() &&
785*3f982cf4SFabien Sanglard tracker.rdata() != record.rdata();
786*3f982cf4SFabien Sanglard };
787*3f982cf4SFabien Sanglard int expire_count =
788*3f982cf4SFabien Sanglard records_.ExpireSoon(record.name(), std::move(expire_check));
789*3f982cf4SFabien Sanglard OSP_DCHECK_GE(expire_count, 1);
790*3f982cf4SFabien Sanglard
791*3f982cf4SFabien Sanglard // Did not find an existing record to update.
792*3f982cf4SFabien Sanglard if (!update_count && !expire_count) {
793*3f982cf4SFabien Sanglard AddRecord(record, dns_type);
794*3f982cf4SFabien Sanglard if (record.dns_type() != DnsType::kNSEC) {
795*3f982cf4SFabien Sanglard ProcessCallbacks(record, RecordChangedEvent::kCreated);
796*3f982cf4SFabien Sanglard }
797*3f982cf4SFabien Sanglard }
798*3f982cf4SFabien Sanglard }
799*3f982cf4SFabien Sanglard
ProcessCallbacks(const MdnsRecord & record,RecordChangedEvent event)800*3f982cf4SFabien Sanglard void MdnsQuerier::ProcessCallbacks(const MdnsRecord& record,
801*3f982cf4SFabien Sanglard RecordChangedEvent event) {
802*3f982cf4SFabien Sanglard OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
803*3f982cf4SFabien Sanglard
804*3f982cf4SFabien Sanglard std::vector<PendingQueryChange> pending_changes;
805*3f982cf4SFabien Sanglard auto callbacks_it = callbacks_.equal_range(record.name());
806*3f982cf4SFabien Sanglard for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
807*3f982cf4SFabien Sanglard const CallbackInfo& callback_info = entry->second;
808*3f982cf4SFabien Sanglard if ((callback_info.dns_type == DnsType::kANY ||
809*3f982cf4SFabien Sanglard record.dns_type() == callback_info.dns_type) &&
810*3f982cf4SFabien Sanglard (callback_info.dns_class == DnsClass::kANY ||
811*3f982cf4SFabien Sanglard record.dns_class() == callback_info.dns_class)) {
812*3f982cf4SFabien Sanglard std::vector<PendingQueryChange> new_changes =
813*3f982cf4SFabien Sanglard callback_info.callback->OnRecordChanged(record, event);
814*3f982cf4SFabien Sanglard pending_changes.insert(pending_changes.end(), new_changes.begin(),
815*3f982cf4SFabien Sanglard new_changes.end());
816*3f982cf4SFabien Sanglard }
817*3f982cf4SFabien Sanglard }
818*3f982cf4SFabien Sanglard
819*3f982cf4SFabien Sanglard ApplyPendingChanges(std::move(pending_changes));
820*3f982cf4SFabien Sanglard }
821*3f982cf4SFabien Sanglard
AddQuestion(const MdnsQuestion & question)822*3f982cf4SFabien Sanglard void MdnsQuerier::AddQuestion(const MdnsQuestion& question) {
823*3f982cf4SFabien Sanglard auto tracker = std::make_unique<MdnsQuestionTracker>(
824*3f982cf4SFabien Sanglard question, sender_, task_runner_, now_function_, random_delay_, config_);
825*3f982cf4SFabien Sanglard MdnsQuestionTracker* ptr = tracker.get();
826*3f982cf4SFabien Sanglard questions_.emplace(question.name(), std::move(tracker));
827*3f982cf4SFabien Sanglard
828*3f982cf4SFabien Sanglard // Let all records associated with this question know that there is a new
829*3f982cf4SFabien Sanglard // query that can be used for their refresh.
830*3f982cf4SFabien Sanglard std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
831*3f982cf4SFabien Sanglard records_.Find(question.name(), question.dns_type(), question.dns_class());
832*3f982cf4SFabien Sanglard for (const MdnsRecordTracker& tracker : trackers) {
833*3f982cf4SFabien Sanglard // NOTE: When the pointed to object is deleted, its dtor removes itself
834*3f982cf4SFabien Sanglard // from all associated records.
835*3f982cf4SFabien Sanglard ptr->AddAssociatedRecord(&tracker);
836*3f982cf4SFabien Sanglard }
837*3f982cf4SFabien Sanglard }
838*3f982cf4SFabien Sanglard
AddRecord(const MdnsRecord & record,DnsType type)839*3f982cf4SFabien Sanglard void MdnsQuerier::AddRecord(const MdnsRecord& record, DnsType type) {
840*3f982cf4SFabien Sanglard // Add the new record.
841*3f982cf4SFabien Sanglard const auto& tracker = records_.StartTracking(record, type);
842*3f982cf4SFabien Sanglard
843*3f982cf4SFabien Sanglard // Let all questions associated with this record know that there is a new
844*3f982cf4SFabien Sanglard // record that answers them (for known answer suppression).
845*3f982cf4SFabien Sanglard auto query_it = questions_.equal_range(record.name());
846*3f982cf4SFabien Sanglard for (auto entry = query_it.first; entry != query_it.second; ++entry) {
847*3f982cf4SFabien Sanglard const MdnsQuestion& query = entry->second->question();
848*3f982cf4SFabien Sanglard const bool is_relevant_type =
849*3f982cf4SFabien Sanglard type == DnsType::kANY || type == query.dns_type();
850*3f982cf4SFabien Sanglard const bool is_relevant_class = record.dns_class() == DnsClass::kANY ||
851*3f982cf4SFabien Sanglard record.dns_class() == query.dns_class();
852*3f982cf4SFabien Sanglard if (is_relevant_type && is_relevant_class) {
853*3f982cf4SFabien Sanglard // NOTE: When the pointed to object is deleted, its dtor removes itself
854*3f982cf4SFabien Sanglard // from all associated queries.
855*3f982cf4SFabien Sanglard entry->second->AddAssociatedRecord(&tracker);
856*3f982cf4SFabien Sanglard }
857*3f982cf4SFabien Sanglard }
858*3f982cf4SFabien Sanglard }
859*3f982cf4SFabien Sanglard
ApplyPendingChanges(std::vector<PendingQueryChange> pending_changes)860*3f982cf4SFabien Sanglard void MdnsQuerier::ApplyPendingChanges(
861*3f982cf4SFabien Sanglard std::vector<PendingQueryChange> pending_changes) {
862*3f982cf4SFabien Sanglard for (auto& pending_change : pending_changes) {
863*3f982cf4SFabien Sanglard switch (pending_change.change_type) {
864*3f982cf4SFabien Sanglard case PendingQueryChange::kStartQuery:
865*3f982cf4SFabien Sanglard StartQuery(std::move(pending_change.name), pending_change.dns_type,
866*3f982cf4SFabien Sanglard pending_change.dns_class, pending_change.callback);
867*3f982cf4SFabien Sanglard break;
868*3f982cf4SFabien Sanglard case PendingQueryChange::kStopQuery:
869*3f982cf4SFabien Sanglard StopQuery(std::move(pending_change.name), pending_change.dns_type,
870*3f982cf4SFabien Sanglard pending_change.dns_class, pending_change.callback);
871*3f982cf4SFabien Sanglard break;
872*3f982cf4SFabien Sanglard }
873*3f982cf4SFabien Sanglard }
874*3f982cf4SFabien Sanglard }
875*3f982cf4SFabien Sanglard
876*3f982cf4SFabien Sanglard } // namespace discovery
877*3f982cf4SFabien Sanglard } // namespace openscreen
878