xref: /aosp_15_r20/external/openscreen/discovery/mdns/mdns_querier.cc (revision 3f982cf4871df8771c9d4abe6e9a6f8d829b2736)
1*3f982cf4SFabien Sanglard // Copyright 2019 The Chromium Authors. All rights reserved.
2*3f982cf4SFabien Sanglard // Use of this source code is governed by a BSD-style license that can be
3*3f982cf4SFabien Sanglard // found in the LICENSE file.
4*3f982cf4SFabien Sanglard 
5*3f982cf4SFabien Sanglard #include "discovery/mdns/mdns_querier.h"
6*3f982cf4SFabien Sanglard 
7*3f982cf4SFabien Sanglard #include <algorithm>
8*3f982cf4SFabien Sanglard #include <array>
9*3f982cf4SFabien Sanglard #include <bitset>
10*3f982cf4SFabien Sanglard #include <memory>
11*3f982cf4SFabien Sanglard #include <unordered_set>
12*3f982cf4SFabien Sanglard #include <utility>
13*3f982cf4SFabien Sanglard #include <vector>
14*3f982cf4SFabien Sanglard 
15*3f982cf4SFabien Sanglard #include "discovery/common/config.h"
16*3f982cf4SFabien Sanglard #include "discovery/common/reporting_client.h"
17*3f982cf4SFabien Sanglard #include "discovery/mdns/mdns_random.h"
18*3f982cf4SFabien Sanglard #include "discovery/mdns/mdns_receiver.h"
19*3f982cf4SFabien Sanglard #include "discovery/mdns/mdns_sender.h"
20*3f982cf4SFabien Sanglard #include "discovery/mdns/public/mdns_constants.h"
21*3f982cf4SFabien Sanglard 
22*3f982cf4SFabien Sanglard namespace openscreen {
23*3f982cf4SFabien Sanglard namespace discovery {
24*3f982cf4SFabien Sanglard namespace {
25*3f982cf4SFabien Sanglard 
26*3f982cf4SFabien Sanglard constexpr std::array<DnsType, 5> kTranslatedNsecAnyQueryTypes = {
27*3f982cf4SFabien Sanglard     DnsType::kA, DnsType::kPTR, DnsType::kTXT, DnsType::kAAAA, DnsType::kSRV};
28*3f982cf4SFabien Sanglard 
IsNegativeResponseFor(const MdnsRecord & record,DnsType type)29*3f982cf4SFabien Sanglard bool IsNegativeResponseFor(const MdnsRecord& record, DnsType type) {
30*3f982cf4SFabien Sanglard   if (record.dns_type() != DnsType::kNSEC) {
31*3f982cf4SFabien Sanglard     return false;
32*3f982cf4SFabien Sanglard   }
33*3f982cf4SFabien Sanglard 
34*3f982cf4SFabien Sanglard   const NsecRecordRdata& nsec = absl::get<NsecRecordRdata>(record.rdata());
35*3f982cf4SFabien Sanglard 
36*3f982cf4SFabien Sanglard   // RFC 6762 section 6.1, the NSEC bit must NOT be set in the received NSEC
37*3f982cf4SFabien Sanglard   // record to indicate this is an mDNS NSEC record rather than a traditional
38*3f982cf4SFabien Sanglard   // DNS NSEC record.
39*3f982cf4SFabien Sanglard   if (std::find(nsec.types().begin(), nsec.types().end(), DnsType::kNSEC) !=
40*3f982cf4SFabien Sanglard       nsec.types().end()) {
41*3f982cf4SFabien Sanglard     return false;
42*3f982cf4SFabien Sanglard   }
43*3f982cf4SFabien Sanglard 
44*3f982cf4SFabien Sanglard   return std::find_if(nsec.types().begin(), nsec.types().end(),
45*3f982cf4SFabien Sanglard                       [type](DnsType stored_type) {
46*3f982cf4SFabien Sanglard                         return stored_type == type ||
47*3f982cf4SFabien Sanglard                                stored_type == DnsType::kANY;
48*3f982cf4SFabien Sanglard                       }) != nsec.types().end();
49*3f982cf4SFabien Sanglard }
50*3f982cf4SFabien Sanglard 
51*3f982cf4SFabien Sanglard struct HashDnsType {
operator ()openscreen::discovery::__anone816dc120111::HashDnsType52*3f982cf4SFabien Sanglard   inline size_t operator()(DnsType type) const {
53*3f982cf4SFabien Sanglard     return static_cast<size_t>(type);
54*3f982cf4SFabien Sanglard   }
55*3f982cf4SFabien Sanglard };
56*3f982cf4SFabien Sanglard 
57*3f982cf4SFabien Sanglard // Helper used for sorting MDNS records. This function guarantees the following:
58*3f982cf4SFabien Sanglard // - All MdnsRecords with the same name appear adjacent to each-other.
59*3f982cf4SFabien Sanglard // - An NSEC record with a given name appears before all other records with the
60*3f982cf4SFabien Sanglard //   same name.
CompareRecordByNameAndType(const MdnsRecord & first,const MdnsRecord & second)61*3f982cf4SFabien Sanglard bool CompareRecordByNameAndType(const MdnsRecord& first,
62*3f982cf4SFabien Sanglard                                 const MdnsRecord& second) {
63*3f982cf4SFabien Sanglard   if (first.name() != second.name()) {
64*3f982cf4SFabien Sanglard     return first.name() < second.name();
65*3f982cf4SFabien Sanglard   }
66*3f982cf4SFabien Sanglard 
67*3f982cf4SFabien Sanglard   if ((first.dns_type() == DnsType::kNSEC) !=
68*3f982cf4SFabien Sanglard       (second.dns_type() == DnsType::kNSEC)) {
69*3f982cf4SFabien Sanglard     return first.dns_type() == DnsType::kNSEC;
70*3f982cf4SFabien Sanglard   }
71*3f982cf4SFabien Sanglard 
72*3f982cf4SFabien Sanglard   return first < second;
73*3f982cf4SFabien Sanglard }
74*3f982cf4SFabien Sanglard 
75*3f982cf4SFabien Sanglard class DnsTypeBitSet {
76*3f982cf4SFabien Sanglard  public:
77*3f982cf4SFabien Sanglard   // Returns whether any types are currently stored in this data structure.
IsEmpty()78*3f982cf4SFabien Sanglard   bool IsEmpty() { return !elements_.any(); }
79*3f982cf4SFabien Sanglard 
80*3f982cf4SFabien Sanglard   // Attempts to insert the given type into this data structure. Returns
81*3f982cf4SFabien Sanglard   // true iff the type was not already present.
Insert(DnsType type)82*3f982cf4SFabien Sanglard   bool Insert(DnsType type) {
83*3f982cf4SFabien Sanglard     uint16_t bit = (type == DnsType::kANY) ? 0 : static_cast<uint16_t>(type);
84*3f982cf4SFabien Sanglard     bool was_set = elements_.test(bit);
85*3f982cf4SFabien Sanglard     elements_.set(bit);
86*3f982cf4SFabien Sanglard     return !was_set;
87*3f982cf4SFabien Sanglard   }
88*3f982cf4SFabien Sanglard 
89*3f982cf4SFabien Sanglard   // Iterates over all members of the provided container, inserting each
90*3f982cf4SFabien Sanglard   // DnsType contained within to this instance. Returns true iff any element
91*3f982cf4SFabien Sanglard   // inserted was not already present in this instance.
92*3f982cf4SFabien Sanglard   template <typename Container>
Insert(const Container & container)93*3f982cf4SFabien Sanglard   bool Insert(const Container& container) {
94*3f982cf4SFabien Sanglard     bool has_element_been_inserted = false;
95*3f982cf4SFabien Sanglard     for (DnsType type : container) {
96*3f982cf4SFabien Sanglard       has_element_been_inserted |= Insert(type);
97*3f982cf4SFabien Sanglard     }
98*3f982cf4SFabien Sanglard     return has_element_been_inserted;
99*3f982cf4SFabien Sanglard   }
100*3f982cf4SFabien Sanglard 
101*3f982cf4SFabien Sanglard   // Attempts to remove the given type from this data structure. Returns true
102*3f982cf4SFabien Sanglard   // iff the type was present prior to this call.
Remove(DnsType type)103*3f982cf4SFabien Sanglard   bool Remove(DnsType type) {
104*3f982cf4SFabien Sanglard     if (IsEmpty()) {
105*3f982cf4SFabien Sanglard       return false;
106*3f982cf4SFabien Sanglard     } else if (type == DnsType::kANY) {
107*3f982cf4SFabien Sanglard       elements_.reset();
108*3f982cf4SFabien Sanglard       return true;
109*3f982cf4SFabien Sanglard     }
110*3f982cf4SFabien Sanglard 
111*3f982cf4SFabien Sanglard     uint16_t bit = static_cast<uint16_t>(type);
112*3f982cf4SFabien Sanglard     bool was_set = elements_.test(bit);
113*3f982cf4SFabien Sanglard     elements_.reset(bit);
114*3f982cf4SFabien Sanglard     return was_set;
115*3f982cf4SFabien Sanglard   }
116*3f982cf4SFabien Sanglard 
117*3f982cf4SFabien Sanglard   // Returns the DnsTypes currently stored in this data structure.
GetTypes() const118*3f982cf4SFabien Sanglard   std::vector<DnsType> GetTypes() const {
119*3f982cf4SFabien Sanglard     if (elements_.test(0)) {
120*3f982cf4SFabien Sanglard       return {DnsType::kANY};
121*3f982cf4SFabien Sanglard     }
122*3f982cf4SFabien Sanglard 
123*3f982cf4SFabien Sanglard     std::vector<DnsType> types;
124*3f982cf4SFabien Sanglard     for (DnsType type : kSupportedDnsTypes) {
125*3f982cf4SFabien Sanglard       if (type == DnsType::kANY) {
126*3f982cf4SFabien Sanglard         continue;
127*3f982cf4SFabien Sanglard       }
128*3f982cf4SFabien Sanglard 
129*3f982cf4SFabien Sanglard       uint16_t cast_int = static_cast<uint16_t>(type);
130*3f982cf4SFabien Sanglard       if (elements_.test(cast_int)) {
131*3f982cf4SFabien Sanglard         types.push_back(type);
132*3f982cf4SFabien Sanglard       }
133*3f982cf4SFabien Sanglard     }
134*3f982cf4SFabien Sanglard     return types;
135*3f982cf4SFabien Sanglard   }
136*3f982cf4SFabien Sanglard 
137*3f982cf4SFabien Sanglard  private:
138*3f982cf4SFabien Sanglard   std::bitset<64> elements_;
139*3f982cf4SFabien Sanglard };
140*3f982cf4SFabien Sanglard 
141*3f982cf4SFabien Sanglard // Modifies |records| such that no NSEC record signifies the nonexistance of a
142*3f982cf4SFabien Sanglard // record which is also present in the same message. Order of the input vector
143*3f982cf4SFabien Sanglard // is NOT preserved.
144*3f982cf4SFabien Sanglard // NOTE: |records| is not of type MdnsRecord::ConstRef because the members must
145*3f982cf4SFabien Sanglard // be modified.
146*3f982cf4SFabien Sanglard // TODO(b/170353378): Break this logic into a separate processing module between
147*3f982cf4SFabien Sanglard // the MdnsReader and the MdnsQuerier.
RemoveInvalidNsecFlags(std::vector<MdnsRecord> * records)148*3f982cf4SFabien Sanglard void RemoveInvalidNsecFlags(std::vector<MdnsRecord>* records) {
149*3f982cf4SFabien Sanglard   // Sort the records so NSEC records are first so that only one iteration
150*3f982cf4SFabien Sanglard   // through all records is needed.
151*3f982cf4SFabien Sanglard   std::sort(records->begin(), records->end(), CompareRecordByNameAndType);
152*3f982cf4SFabien Sanglard 
153*3f982cf4SFabien Sanglard   // The set of NSEC records that need to be removed from |records|. This can't
154*3f982cf4SFabien Sanglard   // be done as part of the below loop because it would invalidate the iterator
155*3f982cf4SFabien Sanglard   // that's still being used.
156*3f982cf4SFabien Sanglard   std::vector<std::vector<MdnsRecord>::iterator> nsecs_to_delete;
157*3f982cf4SFabien Sanglard 
158*3f982cf4SFabien Sanglard   // Process all elements.
159*3f982cf4SFabien Sanglard   for (auto it = records->begin(); it != records->end();) {
160*3f982cf4SFabien Sanglard     if (it->dns_type() != DnsType::kNSEC) {
161*3f982cf4SFabien Sanglard       it++;
162*3f982cf4SFabien Sanglard       continue;
163*3f982cf4SFabien Sanglard     }
164*3f982cf4SFabien Sanglard 
165*3f982cf4SFabien Sanglard     // Track whether the current NSEC record in the input vector has been
166*3f982cf4SFabien Sanglard     // modified by some step of this algorithm, be that merging with another
167*3f982cf4SFabien Sanglard     // record, removing a DnsType, or any other modification.
168*3f982cf4SFabien Sanglard     bool has_changed = false;
169*3f982cf4SFabien Sanglard 
170*3f982cf4SFabien Sanglard     // The types for the new record to create, if |has_changed|.
171*3f982cf4SFabien Sanglard     const NsecRecordRdata& nsec_rdata = absl::get<NsecRecordRdata>(it->rdata());
172*3f982cf4SFabien Sanglard     DnsTypeBitSet types;
173*3f982cf4SFabien Sanglard     for (DnsType type : nsec_rdata.types()) {
174*3f982cf4SFabien Sanglard       types.Insert(type);
175*3f982cf4SFabien Sanglard     }
176*3f982cf4SFabien Sanglard     auto nsec = it;
177*3f982cf4SFabien Sanglard     it++;
178*3f982cf4SFabien Sanglard 
179*3f982cf4SFabien Sanglard     // Combine multiple NSECs to simplify the following code. This probably
180*3f982cf4SFabien Sanglard     // won't happen, but the RFC doesn't exclude the possibility, so account for
181*3f982cf4SFabien Sanglard     // it. Define the TTL of this new NSEC record created by this merge process
182*3f982cf4SFabien Sanglard     // to be the minimum of all merged NSEC records.
183*3f982cf4SFabien Sanglard     std::chrono::seconds new_ttl = nsec->ttl();
184*3f982cf4SFabien Sanglard     while (it != records->end() && it->name() == nsec->name() &&
185*3f982cf4SFabien Sanglard            it->dns_type() == DnsType::kNSEC) {
186*3f982cf4SFabien Sanglard       has_changed |=
187*3f982cf4SFabien Sanglard           types.Insert(absl::get<NsecRecordRdata>(it->rdata()).types());
188*3f982cf4SFabien Sanglard       new_ttl = std::min(new_ttl, it->ttl());
189*3f982cf4SFabien Sanglard       it = records->erase(it);
190*3f982cf4SFabien Sanglard     }
191*3f982cf4SFabien Sanglard 
192*3f982cf4SFabien Sanglard     // Remove any types associated with a known record type.
193*3f982cf4SFabien Sanglard     for (; it != records->end() && it->name() == nsec->name(); it++) {
194*3f982cf4SFabien Sanglard       OSP_DCHECK(it->dns_type() != DnsType::kNSEC);
195*3f982cf4SFabien Sanglard       has_changed |= types.Remove(it->dns_type());
196*3f982cf4SFabien Sanglard     }
197*3f982cf4SFabien Sanglard 
198*3f982cf4SFabien Sanglard     // Modify the stored NSEC record, if needed.
199*3f982cf4SFabien Sanglard     if (has_changed && types.IsEmpty()) {
200*3f982cf4SFabien Sanglard       nsecs_to_delete.push_back(nsec);
201*3f982cf4SFabien Sanglard     } else if (has_changed) {
202*3f982cf4SFabien Sanglard       NsecRecordRdata new_rdata(nsec_rdata.next_domain_name(),
203*3f982cf4SFabien Sanglard                                 types.GetTypes());
204*3f982cf4SFabien Sanglard       *nsec = MdnsRecord(nsec->name(), nsec->dns_type(), nsec->dns_class(),
205*3f982cf4SFabien Sanglard                          nsec->record_type(), new_ttl, std::move(new_rdata));
206*3f982cf4SFabien Sanglard     }
207*3f982cf4SFabien Sanglard   }
208*3f982cf4SFabien Sanglard 
209*3f982cf4SFabien Sanglard   // Erase invalid NSEC records. Go backwards to avoid invalidating the
210*3f982cf4SFabien Sanglard   // remaining iterators.
211*3f982cf4SFabien Sanglard   for (auto erase_it = nsecs_to_delete.rbegin();
212*3f982cf4SFabien Sanglard        erase_it != nsecs_to_delete.rend(); erase_it++) {
213*3f982cf4SFabien Sanglard     records->erase(*erase_it);
214*3f982cf4SFabien Sanglard   }
215*3f982cf4SFabien Sanglard }
216*3f982cf4SFabien Sanglard 
217*3f982cf4SFabien Sanglard }  // namespace
218*3f982cf4SFabien Sanglard 
RecordTrackerLruCache(MdnsQuerier * querier,MdnsSender * sender,MdnsRandom * random_delay,TaskRunner * task_runner,ClockNowFunctionPtr now_function,ReportingClient * reporting_client,const Config & config)219*3f982cf4SFabien Sanglard MdnsQuerier::RecordTrackerLruCache::RecordTrackerLruCache(
220*3f982cf4SFabien Sanglard     MdnsQuerier* querier,
221*3f982cf4SFabien Sanglard     MdnsSender* sender,
222*3f982cf4SFabien Sanglard     MdnsRandom* random_delay,
223*3f982cf4SFabien Sanglard     TaskRunner* task_runner,
224*3f982cf4SFabien Sanglard     ClockNowFunctionPtr now_function,
225*3f982cf4SFabien Sanglard     ReportingClient* reporting_client,
226*3f982cf4SFabien Sanglard     const Config& config)
227*3f982cf4SFabien Sanglard     : querier_(querier),
228*3f982cf4SFabien Sanglard       sender_(sender),
229*3f982cf4SFabien Sanglard       random_delay_(random_delay),
230*3f982cf4SFabien Sanglard       task_runner_(task_runner),
231*3f982cf4SFabien Sanglard       now_function_(now_function),
232*3f982cf4SFabien Sanglard       reporting_client_(reporting_client),
233*3f982cf4SFabien Sanglard       config_(config) {
234*3f982cf4SFabien Sanglard   OSP_DCHECK(sender_);
235*3f982cf4SFabien Sanglard   OSP_DCHECK(random_delay_);
236*3f982cf4SFabien Sanglard   OSP_DCHECK(task_runner_);
237*3f982cf4SFabien Sanglard   OSP_DCHECK(reporting_client_);
238*3f982cf4SFabien Sanglard   OSP_DCHECK_GT(config_.querier_max_records_cached, 0);
239*3f982cf4SFabien Sanglard }
240*3f982cf4SFabien Sanglard 
241*3f982cf4SFabien Sanglard std::vector<std::reference_wrapper<const MdnsRecordTracker>>
Find(const DomainName & name)242*3f982cf4SFabien Sanglard MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name) {
243*3f982cf4SFabien Sanglard   return Find(name, DnsType::kANY, DnsClass::kANY);
244*3f982cf4SFabien Sanglard }
245*3f982cf4SFabien Sanglard 
246*3f982cf4SFabien Sanglard std::vector<std::reference_wrapper<const MdnsRecordTracker>>
Find(const DomainName & name,DnsType dns_type,DnsClass dns_class)247*3f982cf4SFabien Sanglard MdnsQuerier::RecordTrackerLruCache::Find(const DomainName& name,
248*3f982cf4SFabien Sanglard                                          DnsType dns_type,
249*3f982cf4SFabien Sanglard                                          DnsClass dns_class) {
250*3f982cf4SFabien Sanglard   std::vector<RecordTrackerConstRef> results;
251*3f982cf4SFabien Sanglard   auto pair = records_.equal_range(name);
252*3f982cf4SFabien Sanglard   for (auto it = pair.first; it != pair.second; it++) {
253*3f982cf4SFabien Sanglard     const MdnsRecordTracker& tracker = *it->second;
254*3f982cf4SFabien Sanglard     if ((dns_type == DnsType::kANY || dns_type == tracker.dns_type()) &&
255*3f982cf4SFabien Sanglard         (dns_class == DnsClass::kANY || dns_class == tracker.dns_class())) {
256*3f982cf4SFabien Sanglard       results.push_back(std::cref(tracker));
257*3f982cf4SFabien Sanglard     }
258*3f982cf4SFabien Sanglard   }
259*3f982cf4SFabien Sanglard 
260*3f982cf4SFabien Sanglard   return results;
261*3f982cf4SFabien Sanglard }
262*3f982cf4SFabien Sanglard 
Erase(const DomainName & domain,TrackerApplicableCheck check)263*3f982cf4SFabien Sanglard int MdnsQuerier::RecordTrackerLruCache::Erase(const DomainName& domain,
264*3f982cf4SFabien Sanglard                                               TrackerApplicableCheck check) {
265*3f982cf4SFabien Sanglard   auto pair = records_.equal_range(domain);
266*3f982cf4SFabien Sanglard   int count = 0;
267*3f982cf4SFabien Sanglard   for (RecordMap::iterator it = pair.first; it != pair.second;) {
268*3f982cf4SFabien Sanglard     if (check(*it->second)) {
269*3f982cf4SFabien Sanglard       lru_order_.erase(it->second);
270*3f982cf4SFabien Sanglard       it = records_.erase(it);
271*3f982cf4SFabien Sanglard       count++;
272*3f982cf4SFabien Sanglard     } else {
273*3f982cf4SFabien Sanglard       it++;
274*3f982cf4SFabien Sanglard     }
275*3f982cf4SFabien Sanglard   }
276*3f982cf4SFabien Sanglard 
277*3f982cf4SFabien Sanglard   return count;
278*3f982cf4SFabien Sanglard }
279*3f982cf4SFabien Sanglard 
ExpireSoon(const DomainName & domain,TrackerApplicableCheck check)280*3f982cf4SFabien Sanglard int MdnsQuerier::RecordTrackerLruCache::ExpireSoon(
281*3f982cf4SFabien Sanglard     const DomainName& domain,
282*3f982cf4SFabien Sanglard     TrackerApplicableCheck check) {
283*3f982cf4SFabien Sanglard   auto pair = records_.equal_range(domain);
284*3f982cf4SFabien Sanglard   int count = 0;
285*3f982cf4SFabien Sanglard   for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
286*3f982cf4SFabien Sanglard     if (check(*it->second)) {
287*3f982cf4SFabien Sanglard       MoveToEnd(it);
288*3f982cf4SFabien Sanglard       it->second->ExpireSoon();
289*3f982cf4SFabien Sanglard       count++;
290*3f982cf4SFabien Sanglard     }
291*3f982cf4SFabien Sanglard   }
292*3f982cf4SFabien Sanglard 
293*3f982cf4SFabien Sanglard   return count;
294*3f982cf4SFabien Sanglard }
295*3f982cf4SFabien Sanglard 
Update(const MdnsRecord & record,TrackerApplicableCheck check)296*3f982cf4SFabien Sanglard int MdnsQuerier::RecordTrackerLruCache::Update(const MdnsRecord& record,
297*3f982cf4SFabien Sanglard                                                TrackerApplicableCheck check) {
298*3f982cf4SFabien Sanglard   return Update(record, check, [](const MdnsRecordTracker& t) {});
299*3f982cf4SFabien Sanglard }
300*3f982cf4SFabien Sanglard 
Update(const MdnsRecord & record,TrackerApplicableCheck check,TrackerChangeCallback on_rdata_update)301*3f982cf4SFabien Sanglard int MdnsQuerier::RecordTrackerLruCache::Update(
302*3f982cf4SFabien Sanglard     const MdnsRecord& record,
303*3f982cf4SFabien Sanglard     TrackerApplicableCheck check,
304*3f982cf4SFabien Sanglard     TrackerChangeCallback on_rdata_update) {
305*3f982cf4SFabien Sanglard   auto pair = records_.equal_range(record.name());
306*3f982cf4SFabien Sanglard   int count = 0;
307*3f982cf4SFabien Sanglard   for (RecordMap::iterator it = pair.first; it != pair.second; it++) {
308*3f982cf4SFabien Sanglard     if (check(*it->second)) {
309*3f982cf4SFabien Sanglard       auto result = it->second->Update(record);
310*3f982cf4SFabien Sanglard 
311*3f982cf4SFabien Sanglard       if (result.is_error()) {
312*3f982cf4SFabien Sanglard         reporting_client_->OnRecoverableError(
313*3f982cf4SFabien Sanglard             Error(Error::Code::kUpdateReceivedRecordFailure,
314*3f982cf4SFabien Sanglard                   result.error().ToString()));
315*3f982cf4SFabien Sanglard         continue;
316*3f982cf4SFabien Sanglard       }
317*3f982cf4SFabien Sanglard 
318*3f982cf4SFabien Sanglard       count++;
319*3f982cf4SFabien Sanglard       if (result.value() == MdnsRecordTracker::UpdateType::kGoodbye) {
320*3f982cf4SFabien Sanglard         it->second->ExpireSoon();
321*3f982cf4SFabien Sanglard         MoveToEnd(it);
322*3f982cf4SFabien Sanglard       } else {
323*3f982cf4SFabien Sanglard         MoveToBeginning(it);
324*3f982cf4SFabien Sanglard         if (result.value() == MdnsRecordTracker::UpdateType::kRdata) {
325*3f982cf4SFabien Sanglard           on_rdata_update(*it->second);
326*3f982cf4SFabien Sanglard         }
327*3f982cf4SFabien Sanglard       }
328*3f982cf4SFabien Sanglard     }
329*3f982cf4SFabien Sanglard   }
330*3f982cf4SFabien Sanglard 
331*3f982cf4SFabien Sanglard   return count;
332*3f982cf4SFabien Sanglard }
333*3f982cf4SFabien Sanglard 
StartTracking(MdnsRecord record,DnsType dns_type)334*3f982cf4SFabien Sanglard const MdnsRecordTracker& MdnsQuerier::RecordTrackerLruCache::StartTracking(
335*3f982cf4SFabien Sanglard     MdnsRecord record,
336*3f982cf4SFabien Sanglard     DnsType dns_type) {
337*3f982cf4SFabien Sanglard   auto expiration_callback = [this](const MdnsRecordTracker* tracker,
338*3f982cf4SFabien Sanglard                                     const MdnsRecord& record) {
339*3f982cf4SFabien Sanglard     querier_->OnRecordExpired(tracker, record);
340*3f982cf4SFabien Sanglard   };
341*3f982cf4SFabien Sanglard 
342*3f982cf4SFabien Sanglard   while (lru_order_.size() >=
343*3f982cf4SFabien Sanglard          static_cast<size_t>(config_.querier_max_records_cached)) {
344*3f982cf4SFabien Sanglard     // This call erases one of the tracked records.
345*3f982cf4SFabien Sanglard     OSP_DVLOG << "Maximum cacheable record count exceeded ("
346*3f982cf4SFabien Sanglard               << config_.querier_max_records_cached << ")";
347*3f982cf4SFabien Sanglard     lru_order_.back().ExpireNow();
348*3f982cf4SFabien Sanglard   }
349*3f982cf4SFabien Sanglard 
350*3f982cf4SFabien Sanglard   auto name = record.name();
351*3f982cf4SFabien Sanglard   lru_order_.emplace_front(std::move(record), dns_type, sender_, task_runner_,
352*3f982cf4SFabien Sanglard                            now_function_, random_delay_,
353*3f982cf4SFabien Sanglard                            std::move(expiration_callback));
354*3f982cf4SFabien Sanglard   records_.emplace(std::move(name), lru_order_.begin());
355*3f982cf4SFabien Sanglard 
356*3f982cf4SFabien Sanglard   return lru_order_.front();
357*3f982cf4SFabien Sanglard }
358*3f982cf4SFabien Sanglard 
MoveToBeginning(MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it)359*3f982cf4SFabien Sanglard void MdnsQuerier::RecordTrackerLruCache::MoveToBeginning(
360*3f982cf4SFabien Sanglard     MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
361*3f982cf4SFabien Sanglard   lru_order_.splice(lru_order_.begin(), lru_order_, it->second);
362*3f982cf4SFabien Sanglard   it->second = lru_order_.begin();
363*3f982cf4SFabien Sanglard }
364*3f982cf4SFabien Sanglard 
MoveToEnd(MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it)365*3f982cf4SFabien Sanglard void MdnsQuerier::RecordTrackerLruCache::MoveToEnd(
366*3f982cf4SFabien Sanglard     MdnsQuerier::RecordTrackerLruCache::RecordMap::iterator it) {
367*3f982cf4SFabien Sanglard   lru_order_.splice(lru_order_.end(), lru_order_, it->second);
368*3f982cf4SFabien Sanglard   it->second = --lru_order_.end();
369*3f982cf4SFabien Sanglard }
370*3f982cf4SFabien Sanglard 
MdnsQuerier(MdnsSender * sender,MdnsReceiver * receiver,TaskRunner * task_runner,ClockNowFunctionPtr now_function,MdnsRandom * random_delay,ReportingClient * reporting_client,Config config)371*3f982cf4SFabien Sanglard MdnsQuerier::MdnsQuerier(MdnsSender* sender,
372*3f982cf4SFabien Sanglard                          MdnsReceiver* receiver,
373*3f982cf4SFabien Sanglard                          TaskRunner* task_runner,
374*3f982cf4SFabien Sanglard                          ClockNowFunctionPtr now_function,
375*3f982cf4SFabien Sanglard                          MdnsRandom* random_delay,
376*3f982cf4SFabien Sanglard                          ReportingClient* reporting_client,
377*3f982cf4SFabien Sanglard                          Config config)
378*3f982cf4SFabien Sanglard     : sender_(sender),
379*3f982cf4SFabien Sanglard       receiver_(receiver),
380*3f982cf4SFabien Sanglard       task_runner_(task_runner),
381*3f982cf4SFabien Sanglard       now_function_(now_function),
382*3f982cf4SFabien Sanglard       random_delay_(random_delay),
383*3f982cf4SFabien Sanglard       reporting_client_(reporting_client),
384*3f982cf4SFabien Sanglard       config_(std::move(config)),
385*3f982cf4SFabien Sanglard       records_(this,
386*3f982cf4SFabien Sanglard                sender_,
387*3f982cf4SFabien Sanglard                random_delay_,
388*3f982cf4SFabien Sanglard                task_runner_,
389*3f982cf4SFabien Sanglard                now_function_,
390*3f982cf4SFabien Sanglard                reporting_client_,
391*3f982cf4SFabien Sanglard                config_) {
392*3f982cf4SFabien Sanglard   OSP_DCHECK(sender_);
393*3f982cf4SFabien Sanglard   OSP_DCHECK(receiver_);
394*3f982cf4SFabien Sanglard   OSP_DCHECK(task_runner_);
395*3f982cf4SFabien Sanglard   OSP_DCHECK(now_function_);
396*3f982cf4SFabien Sanglard   OSP_DCHECK(random_delay_);
397*3f982cf4SFabien Sanglard   OSP_DCHECK(reporting_client_);
398*3f982cf4SFabien Sanglard 
399*3f982cf4SFabien Sanglard   receiver_->AddResponseCallback(this);
400*3f982cf4SFabien Sanglard }
401*3f982cf4SFabien Sanglard 
~MdnsQuerier()402*3f982cf4SFabien Sanglard MdnsQuerier::~MdnsQuerier() {
403*3f982cf4SFabien Sanglard   receiver_->RemoveResponseCallback(this);
404*3f982cf4SFabien Sanglard }
405*3f982cf4SFabien Sanglard 
// NOTE: The code below generally uses range loops instead of std::find_if, for
// better readability, brevity and homogeneity. Using std::find_if results in a
// few more lines of code, and readability suffers from the extra lambdas.
409*3f982cf4SFabien Sanglard 
StartQuery(const DomainName & name,DnsType dns_type,DnsClass dns_class,MdnsRecordChangedCallback * callback)410*3f982cf4SFabien Sanglard void MdnsQuerier::StartQuery(const DomainName& name,
411*3f982cf4SFabien Sanglard                              DnsType dns_type,
412*3f982cf4SFabien Sanglard                              DnsClass dns_class,
413*3f982cf4SFabien Sanglard                              MdnsRecordChangedCallback* callback) {
414*3f982cf4SFabien Sanglard   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
415*3f982cf4SFabien Sanglard   OSP_DCHECK(callback);
416*3f982cf4SFabien Sanglard   OSP_DCHECK(CanBeQueried(dns_type));
417*3f982cf4SFabien Sanglard 
418*3f982cf4SFabien Sanglard   // Add a new callback if haven't seen it before
419*3f982cf4SFabien Sanglard   auto callbacks_it = callbacks_.equal_range(name);
420*3f982cf4SFabien Sanglard   for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
421*3f982cf4SFabien Sanglard     const CallbackInfo& callback_info = entry->second;
422*3f982cf4SFabien Sanglard     if (dns_type == callback_info.dns_type &&
423*3f982cf4SFabien Sanglard         dns_class == callback_info.dns_class &&
424*3f982cf4SFabien Sanglard         callback == callback_info.callback) {
425*3f982cf4SFabien Sanglard       // Already have this callback
426*3f982cf4SFabien Sanglard       return;
427*3f982cf4SFabien Sanglard     }
428*3f982cf4SFabien Sanglard   }
429*3f982cf4SFabien Sanglard   callbacks_.emplace(name, CallbackInfo{callback, dns_type, dns_class});
430*3f982cf4SFabien Sanglard 
431*3f982cf4SFabien Sanglard   // Notify the new callback with previously cached records.
432*3f982cf4SFabien Sanglard   // NOTE: In the future, could allow callers to fetch cached records after
433*3f982cf4SFabien Sanglard   // adding a callback, for example to prime the UI.
434*3f982cf4SFabien Sanglard   std::vector<PendingQueryChange> pending_changes;
435*3f982cf4SFabien Sanglard   const std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
436*3f982cf4SFabien Sanglard       records_.Find(name, dns_type, dns_class);
437*3f982cf4SFabien Sanglard   for (const MdnsRecordTracker& tracker : trackers) {
438*3f982cf4SFabien Sanglard     if (!tracker.is_negative_response()) {
439*3f982cf4SFabien Sanglard       MdnsRecord stored_record(name, tracker.dns_type(), tracker.dns_class(),
440*3f982cf4SFabien Sanglard                                tracker.record_type(), tracker.ttl(),
441*3f982cf4SFabien Sanglard                                tracker.rdata());
442*3f982cf4SFabien Sanglard       std::vector<PendingQueryChange> new_changes = callback->OnRecordChanged(
443*3f982cf4SFabien Sanglard           std::move(stored_record), RecordChangedEvent::kCreated);
444*3f982cf4SFabien Sanglard       pending_changes.insert(pending_changes.end(), new_changes.begin(),
445*3f982cf4SFabien Sanglard                              new_changes.end());
446*3f982cf4SFabien Sanglard     }
447*3f982cf4SFabien Sanglard   }
448*3f982cf4SFabien Sanglard 
449*3f982cf4SFabien Sanglard   // Add a new question if haven't seen it before
450*3f982cf4SFabien Sanglard   auto questions_it = questions_.equal_range(name);
451*3f982cf4SFabien Sanglard   const bool is_question_already_tracked =
452*3f982cf4SFabien Sanglard       std::find_if(questions_it.first, questions_it.second,
453*3f982cf4SFabien Sanglard                    [dns_type, dns_class](const auto& entry) {
454*3f982cf4SFabien Sanglard                      const MdnsQuestion& tracked_question =
455*3f982cf4SFabien Sanglard                          entry.second->question();
456*3f982cf4SFabien Sanglard                      return dns_type == tracked_question.dns_type() &&
457*3f982cf4SFabien Sanglard                             dns_class == tracked_question.dns_class();
458*3f982cf4SFabien Sanglard                    }) != questions_it.second;
459*3f982cf4SFabien Sanglard   if (!is_question_already_tracked) {
460*3f982cf4SFabien Sanglard     AddQuestion(
461*3f982cf4SFabien Sanglard         MdnsQuestion(name, dns_type, dns_class, ResponseType::kMulticast));
462*3f982cf4SFabien Sanglard   }
463*3f982cf4SFabien Sanglard 
464*3f982cf4SFabien Sanglard   // Apply any pending changes from the OnRecordChanged() callbacks.
465*3f982cf4SFabien Sanglard   ApplyPendingChanges(std::move(pending_changes));
466*3f982cf4SFabien Sanglard }
467*3f982cf4SFabien Sanglard 
StopQuery(const DomainName & name,DnsType dns_type,DnsClass dns_class,MdnsRecordChangedCallback * callback)468*3f982cf4SFabien Sanglard void MdnsQuerier::StopQuery(const DomainName& name,
469*3f982cf4SFabien Sanglard                             DnsType dns_type,
470*3f982cf4SFabien Sanglard                             DnsClass dns_class,
471*3f982cf4SFabien Sanglard                             MdnsRecordChangedCallback* callback) {
472*3f982cf4SFabien Sanglard   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
473*3f982cf4SFabien Sanglard   OSP_DCHECK(callback);
474*3f982cf4SFabien Sanglard 
475*3f982cf4SFabien Sanglard   if (!CanBeQueried(dns_type)) {
476*3f982cf4SFabien Sanglard     return;
477*3f982cf4SFabien Sanglard   }
478*3f982cf4SFabien Sanglard 
479*3f982cf4SFabien Sanglard   // Find and remove the callback.
480*3f982cf4SFabien Sanglard   int callbacks_for_key = 0;
481*3f982cf4SFabien Sanglard   auto callbacks_it = callbacks_.equal_range(name);
482*3f982cf4SFabien Sanglard   for (auto entry = callbacks_it.first; entry != callbacks_it.second;) {
483*3f982cf4SFabien Sanglard     const CallbackInfo& callback_info = entry->second;
484*3f982cf4SFabien Sanglard     if (dns_type == callback_info.dns_type &&
485*3f982cf4SFabien Sanglard         dns_class == callback_info.dns_class) {
486*3f982cf4SFabien Sanglard       if (callback == callback_info.callback) {
487*3f982cf4SFabien Sanglard         entry = callbacks_.erase(entry);
488*3f982cf4SFabien Sanglard       } else {
489*3f982cf4SFabien Sanglard         ++callbacks_for_key;
490*3f982cf4SFabien Sanglard         ++entry;
491*3f982cf4SFabien Sanglard       }
492*3f982cf4SFabien Sanglard     }
493*3f982cf4SFabien Sanglard   }
494*3f982cf4SFabien Sanglard 
495*3f982cf4SFabien Sanglard   // Exit if there are still callbacks registered for DomainName + DnsType +
496*3f982cf4SFabien Sanglard   // DnsClass
497*3f982cf4SFabien Sanglard   if (callbacks_for_key > 0) {
498*3f982cf4SFabien Sanglard     return;
499*3f982cf4SFabien Sanglard   }
500*3f982cf4SFabien Sanglard 
501*3f982cf4SFabien Sanglard   // Find and delete a question that does not have any associated callbacks
502*3f982cf4SFabien Sanglard   auto questions_it = questions_.equal_range(name);
503*3f982cf4SFabien Sanglard   for (auto entry = questions_it.first; entry != questions_it.second; ++entry) {
504*3f982cf4SFabien Sanglard     const MdnsQuestion& tracked_question = entry->second->question();
505*3f982cf4SFabien Sanglard     if (dns_type == tracked_question.dns_type() &&
506*3f982cf4SFabien Sanglard         dns_class == tracked_question.dns_class()) {
507*3f982cf4SFabien Sanglard       questions_.erase(entry);
508*3f982cf4SFabien Sanglard       return;
509*3f982cf4SFabien Sanglard     }
510*3f982cf4SFabien Sanglard   }
511*3f982cf4SFabien Sanglard }
512*3f982cf4SFabien Sanglard 
ReinitializeQueries(const DomainName & name)513*3f982cf4SFabien Sanglard void MdnsQuerier::ReinitializeQueries(const DomainName& name) {
514*3f982cf4SFabien Sanglard   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
515*3f982cf4SFabien Sanglard 
516*3f982cf4SFabien Sanglard   // Get the ongoing queries and their callbacks.
517*3f982cf4SFabien Sanglard   std::vector<CallbackInfo> callbacks;
518*3f982cf4SFabien Sanglard   auto its = callbacks_.equal_range(name);
519*3f982cf4SFabien Sanglard   for (auto it = its.first; it != its.second; it++) {
520*3f982cf4SFabien Sanglard     callbacks.push_back(std::move(it->second));
521*3f982cf4SFabien Sanglard   }
522*3f982cf4SFabien Sanglard   callbacks_.erase(name);
523*3f982cf4SFabien Sanglard 
524*3f982cf4SFabien Sanglard   // Remove all known questions and answers.
525*3f982cf4SFabien Sanglard   questions_.erase(name);
526*3f982cf4SFabien Sanglard   records_.Erase(name, [](const MdnsRecordTracker& tracker) { return true; });
527*3f982cf4SFabien Sanglard 
528*3f982cf4SFabien Sanglard   // Restart the queries.
529*3f982cf4SFabien Sanglard   for (const auto& cb : callbacks) {
530*3f982cf4SFabien Sanglard     StartQuery(name, cb.dns_type, cb.dns_class, cb.callback);
531*3f982cf4SFabien Sanglard   }
532*3f982cf4SFabien Sanglard }
533*3f982cf4SFabien Sanglard 
OnMessageReceived(const MdnsMessage & message)534*3f982cf4SFabien Sanglard void MdnsQuerier::OnMessageReceived(const MdnsMessage& message) {
535*3f982cf4SFabien Sanglard   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
536*3f982cf4SFabien Sanglard   OSP_DCHECK(message.type() == MessageType::Response);
537*3f982cf4SFabien Sanglard 
538*3f982cf4SFabien Sanglard   OSP_DVLOG << "Received mDNS Response message with "
539*3f982cf4SFabien Sanglard             << message.answers().size() << " answers and "
540*3f982cf4SFabien Sanglard             << message.additional_records().size()
541*3f982cf4SFabien Sanglard             << " additional records. Processing...";
542*3f982cf4SFabien Sanglard 
543*3f982cf4SFabien Sanglard   std::vector<MdnsRecord> records_to_process;
544*3f982cf4SFabien Sanglard 
545*3f982cf4SFabien Sanglard   // Add any records that are relevant for this querier.
546*3f982cf4SFabien Sanglard   bool found_relevant_records = false;
547*3f982cf4SFabien Sanglard   for (const MdnsRecord& record : message.answers()) {
548*3f982cf4SFabien Sanglard     if (ShouldAnswerRecordBeProcessed(record)) {
549*3f982cf4SFabien Sanglard       records_to_process.push_back(record);
550*3f982cf4SFabien Sanglard       found_relevant_records = true;
551*3f982cf4SFabien Sanglard     }
552*3f982cf4SFabien Sanglard   }
553*3f982cf4SFabien Sanglard 
554*3f982cf4SFabien Sanglard   // If any of the message's answers are relevant, add all additional records.
555*3f982cf4SFabien Sanglard   // Else, since the message has already been received and parsed, use any
556*3f982cf4SFabien Sanglard   // individual records relevant to this querier to update the cache.
557*3f982cf4SFabien Sanglard   for (const MdnsRecord& record : message.additional_records()) {
558*3f982cf4SFabien Sanglard     if (found_relevant_records || ShouldAnswerRecordBeProcessed(record)) {
559*3f982cf4SFabien Sanglard       records_to_process.push_back(record);
560*3f982cf4SFabien Sanglard     }
561*3f982cf4SFabien Sanglard   }
562*3f982cf4SFabien Sanglard 
563*3f982cf4SFabien Sanglard   // Drop NSEC records associated with a non-NSEC record of the same type.
564*3f982cf4SFabien Sanglard   RemoveInvalidNsecFlags(&records_to_process);
565*3f982cf4SFabien Sanglard 
566*3f982cf4SFabien Sanglard   // Process all remaining records.
567*3f982cf4SFabien Sanglard   for (const MdnsRecord& record_to_process : records_to_process) {
568*3f982cf4SFabien Sanglard     ProcessRecord(record_to_process);
569*3f982cf4SFabien Sanglard   }
570*3f982cf4SFabien Sanglard 
571*3f982cf4SFabien Sanglard   OSP_DVLOG << "\tmDNS Response processed (" << records_to_process.size()
572*3f982cf4SFabien Sanglard             << " records accepted)!";
573*3f982cf4SFabien Sanglard 
574*3f982cf4SFabien Sanglard   // TODO(crbug.com/openscreen/83): Check authority records.
575*3f982cf4SFabien Sanglard }
576*3f982cf4SFabien Sanglard 
ShouldAnswerRecordBeProcessed(const MdnsRecord & answer)577*3f982cf4SFabien Sanglard bool MdnsQuerier::ShouldAnswerRecordBeProcessed(const MdnsRecord& answer) {
578*3f982cf4SFabien Sanglard   // First, accept the record if it's associated with an ongoing question.
579*3f982cf4SFabien Sanglard   const auto questions_range = questions_.equal_range(answer.name());
580*3f982cf4SFabien Sanglard   const auto it = std::find_if(
581*3f982cf4SFabien Sanglard       questions_range.first, questions_range.second,
582*3f982cf4SFabien Sanglard       [&answer](const auto& pair) {
583*3f982cf4SFabien Sanglard         return (pair.second->question().dns_type() == DnsType::kANY ||
584*3f982cf4SFabien Sanglard                 IsNegativeResponseFor(answer,
585*3f982cf4SFabien Sanglard                                       pair.second->question().dns_type()) ||
586*3f982cf4SFabien Sanglard                 pair.second->question().dns_type() == answer.dns_type()) &&
587*3f982cf4SFabien Sanglard                (pair.second->question().dns_class() == DnsClass::kANY ||
588*3f982cf4SFabien Sanglard                 pair.second->question().dns_class() == answer.dns_class());
589*3f982cf4SFabien Sanglard       });
590*3f982cf4SFabien Sanglard   if (it != questions_range.second) {
591*3f982cf4SFabien Sanglard     return true;
592*3f982cf4SFabien Sanglard   }
593*3f982cf4SFabien Sanglard 
594*3f982cf4SFabien Sanglard   // If not, check if it corresponds to an already existing record. This is
595*3f982cf4SFabien Sanglard   // required because records which are already stored may either have been
596*3f982cf4SFabien Sanglard   // received in an additional records section, or are associated with a query
597*3f982cf4SFabien Sanglard   // which is no longer active.
598*3f982cf4SFabien Sanglard   std::vector<DnsType> types{answer.dns_type()};
599*3f982cf4SFabien Sanglard   if (answer.dns_type() == DnsType::kNSEC) {
600*3f982cf4SFabien Sanglard     const auto& nsec_rdata = absl::get<NsecRecordRdata>(answer.rdata());
601*3f982cf4SFabien Sanglard     types = nsec_rdata.types();
602*3f982cf4SFabien Sanglard   }
603*3f982cf4SFabien Sanglard 
604*3f982cf4SFabien Sanglard   for (DnsType type : types) {
605*3f982cf4SFabien Sanglard     std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
606*3f982cf4SFabien Sanglard         records_.Find(answer.name(), type, answer.dns_class());
607*3f982cf4SFabien Sanglard     if (!trackers.empty()) {
608*3f982cf4SFabien Sanglard       return true;
609*3f982cf4SFabien Sanglard     }
610*3f982cf4SFabien Sanglard   }
611*3f982cf4SFabien Sanglard 
612*3f982cf4SFabien Sanglard   // In all other cases, the record isn't relevant. Drop it.
613*3f982cf4SFabien Sanglard   return false;
614*3f982cf4SFabien Sanglard }
615*3f982cf4SFabien Sanglard 
OnRecordExpired(const MdnsRecordTracker * tracker,const MdnsRecord & record)616*3f982cf4SFabien Sanglard void MdnsQuerier::OnRecordExpired(const MdnsRecordTracker* tracker,
617*3f982cf4SFabien Sanglard                                   const MdnsRecord& record) {
618*3f982cf4SFabien Sanglard   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
619*3f982cf4SFabien Sanglard 
620*3f982cf4SFabien Sanglard   if (!tracker->is_negative_response()) {
621*3f982cf4SFabien Sanglard     ProcessCallbacks(record, RecordChangedEvent::kExpired);
622*3f982cf4SFabien Sanglard   }
623*3f982cf4SFabien Sanglard 
624*3f982cf4SFabien Sanglard   records_.Erase(record.name(), [tracker](const MdnsRecordTracker& it_tracker) {
625*3f982cf4SFabien Sanglard     return tracker == &it_tracker;
626*3f982cf4SFabien Sanglard   });
627*3f982cf4SFabien Sanglard }
628*3f982cf4SFabien Sanglard 
ProcessRecord(const MdnsRecord & record)629*3f982cf4SFabien Sanglard void MdnsQuerier::ProcessRecord(const MdnsRecord& record) {
630*3f982cf4SFabien Sanglard   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
631*3f982cf4SFabien Sanglard 
632*3f982cf4SFabien Sanglard   // Skip all records that can't be processed.
633*3f982cf4SFabien Sanglard   if (!CanBeProcessed(record.dns_type())) {
634*3f982cf4SFabien Sanglard     return;
635*3f982cf4SFabien Sanglard   }
636*3f982cf4SFabien Sanglard 
637*3f982cf4SFabien Sanglard   // Ignore NSEC records if the embedder has configured us to do so.
638*3f982cf4SFabien Sanglard   if (config_.ignore_nsec_responses && record.dns_type() == DnsType::kNSEC) {
639*3f982cf4SFabien Sanglard     return;
640*3f982cf4SFabien Sanglard   }
641*3f982cf4SFabien Sanglard 
642*3f982cf4SFabien Sanglard   // Get the types which the received record is associated with. In most cases
643*3f982cf4SFabien Sanglard   // this will only be the type of the provided record, but in the case of
644*3f982cf4SFabien Sanglard   // NSEC records this will be all records which the record dictates the
645*3f982cf4SFabien Sanglard   // nonexistence of.
646*3f982cf4SFabien Sanglard   std::vector<DnsType> types;
647*3f982cf4SFabien Sanglard   int types_count = 0;
648*3f982cf4SFabien Sanglard   const DnsType* types_ptr = nullptr;
649*3f982cf4SFabien Sanglard   if (record.dns_type() == DnsType::kNSEC) {
650*3f982cf4SFabien Sanglard     const auto& nsec_rdata = absl::get<NsecRecordRdata>(record.rdata());
651*3f982cf4SFabien Sanglard     if (std::find(nsec_rdata.types().begin(), nsec_rdata.types().end(),
652*3f982cf4SFabien Sanglard                   DnsType::kANY) != nsec_rdata.types().end()) {
653*3f982cf4SFabien Sanglard       types_ptr = kTranslatedNsecAnyQueryTypes.data();
654*3f982cf4SFabien Sanglard       types_count = kTranslatedNsecAnyQueryTypes.size();
655*3f982cf4SFabien Sanglard     } else {
656*3f982cf4SFabien Sanglard       types_ptr = nsec_rdata.types().data();
657*3f982cf4SFabien Sanglard       types_count = nsec_rdata.types().size();
658*3f982cf4SFabien Sanglard     }
659*3f982cf4SFabien Sanglard   } else {
660*3f982cf4SFabien Sanglard     types.push_back(record.dns_type());
661*3f982cf4SFabien Sanglard     types_ptr = types.data();
662*3f982cf4SFabien Sanglard     types_count = types.size();
663*3f982cf4SFabien Sanglard   }
664*3f982cf4SFabien Sanglard 
665*3f982cf4SFabien Sanglard   // Apply the update for each type that the record is associated with.
666*3f982cf4SFabien Sanglard   for (int i = 0; i < types_count; ++i) {
667*3f982cf4SFabien Sanglard     DnsType dns_type = types_ptr[i];
668*3f982cf4SFabien Sanglard     switch (record.record_type()) {
669*3f982cf4SFabien Sanglard       case RecordType::kShared: {
670*3f982cf4SFabien Sanglard         ProcessSharedRecord(record, dns_type);
671*3f982cf4SFabien Sanglard         break;
672*3f982cf4SFabien Sanglard       }
673*3f982cf4SFabien Sanglard       case RecordType::kUnique: {
674*3f982cf4SFabien Sanglard         ProcessUniqueRecord(record, dns_type);
675*3f982cf4SFabien Sanglard         break;
676*3f982cf4SFabien Sanglard       }
677*3f982cf4SFabien Sanglard     }
678*3f982cf4SFabien Sanglard   }
679*3f982cf4SFabien Sanglard }
680*3f982cf4SFabien Sanglard 
ProcessSharedRecord(const MdnsRecord & record,DnsType dns_type)681*3f982cf4SFabien Sanglard void MdnsQuerier::ProcessSharedRecord(const MdnsRecord& record,
682*3f982cf4SFabien Sanglard                                       DnsType dns_type) {
683*3f982cf4SFabien Sanglard   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
684*3f982cf4SFabien Sanglard   OSP_DCHECK(record.record_type() == RecordType::kShared);
685*3f982cf4SFabien Sanglard 
686*3f982cf4SFabien Sanglard   // By design, NSEC records are never shared records.
687*3f982cf4SFabien Sanglard   if (record.dns_type() == DnsType::kNSEC) {
688*3f982cf4SFabien Sanglard     return;
689*3f982cf4SFabien Sanglard   }
690*3f982cf4SFabien Sanglard 
691*3f982cf4SFabien Sanglard   // For any records updated, this host already has this shared record. Since
692*3f982cf4SFabien Sanglard   // the RDATA matches, this is only a TTL update.
693*3f982cf4SFabien Sanglard   auto check = [&record](const MdnsRecordTracker& tracker) {
694*3f982cf4SFabien Sanglard     return record.dns_type() == tracker.dns_type() &&
695*3f982cf4SFabien Sanglard            record.dns_class() == tracker.dns_class() &&
696*3f982cf4SFabien Sanglard            record.rdata() == tracker.rdata();
697*3f982cf4SFabien Sanglard   };
698*3f982cf4SFabien Sanglard   auto updated_count = records_.Update(record, std::move(check));
699*3f982cf4SFabien Sanglard 
700*3f982cf4SFabien Sanglard   if (!updated_count) {
701*3f982cf4SFabien Sanglard     // Have never before seen this shared record, insert a new one.
702*3f982cf4SFabien Sanglard     AddRecord(record, dns_type);
703*3f982cf4SFabien Sanglard     ProcessCallbacks(record, RecordChangedEvent::kCreated);
704*3f982cf4SFabien Sanglard   }
705*3f982cf4SFabien Sanglard }
706*3f982cf4SFabien Sanglard 
ProcessUniqueRecord(const MdnsRecord & record,DnsType dns_type)707*3f982cf4SFabien Sanglard void MdnsQuerier::ProcessUniqueRecord(const MdnsRecord& record,
708*3f982cf4SFabien Sanglard                                       DnsType dns_type) {
709*3f982cf4SFabien Sanglard   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
710*3f982cf4SFabien Sanglard   OSP_DCHECK(record.record_type() == RecordType::kUnique);
711*3f982cf4SFabien Sanglard 
712*3f982cf4SFabien Sanglard   std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
713*3f982cf4SFabien Sanglard       records_.Find(record.name(), dns_type, record.dns_class());
714*3f982cf4SFabien Sanglard   size_t num_records_for_key = trackers.size();
715*3f982cf4SFabien Sanglard 
716*3f982cf4SFabien Sanglard   // Have not seen any records with this key before. This case is expected the
717*3f982cf4SFabien Sanglard   // first time a record is received.
718*3f982cf4SFabien Sanglard   if (num_records_for_key == size_t{0}) {
719*3f982cf4SFabien Sanglard     const bool will_exist = record.dns_type() != DnsType::kNSEC;
720*3f982cf4SFabien Sanglard     AddRecord(record, dns_type);
721*3f982cf4SFabien Sanglard     if (will_exist) {
722*3f982cf4SFabien Sanglard       ProcessCallbacks(record, RecordChangedEvent::kCreated);
723*3f982cf4SFabien Sanglard     }
724*3f982cf4SFabien Sanglard   } else if (num_records_for_key == size_t{1}) {
725*3f982cf4SFabien Sanglard     // There is exactly one tracker associated with this key. This is the
726*3f982cf4SFabien Sanglard     // expected case when a record matching this one has already been seen.
727*3f982cf4SFabien Sanglard     ProcessSinglyTrackedUniqueRecord(record, trackers[0]);
728*3f982cf4SFabien Sanglard   } else {
729*3f982cf4SFabien Sanglard     // Multiple records with the same key.
730*3f982cf4SFabien Sanglard     ProcessMultiTrackedUniqueRecord(record, dns_type);
731*3f982cf4SFabien Sanglard   }
732*3f982cf4SFabien Sanglard }
733*3f982cf4SFabien Sanglard 
ProcessSinglyTrackedUniqueRecord(const MdnsRecord & record,const MdnsRecordTracker & tracker)734*3f982cf4SFabien Sanglard void MdnsQuerier::ProcessSinglyTrackedUniqueRecord(
735*3f982cf4SFabien Sanglard     const MdnsRecord& record,
736*3f982cf4SFabien Sanglard     const MdnsRecordTracker& tracker) {
737*3f982cf4SFabien Sanglard   const bool existed_previously = !tracker.is_negative_response();
738*3f982cf4SFabien Sanglard   const bool will_exist = record.dns_type() != DnsType::kNSEC;
739*3f982cf4SFabien Sanglard 
740*3f982cf4SFabien Sanglard   // Calculate the callback to call on record update success while the old
741*3f982cf4SFabien Sanglard   // record still exists.
742*3f982cf4SFabien Sanglard   MdnsRecord record_for_callback = record;
743*3f982cf4SFabien Sanglard   if (existed_previously && !will_exist) {
744*3f982cf4SFabien Sanglard     record_for_callback =
745*3f982cf4SFabien Sanglard         MdnsRecord(record.name(), tracker.dns_type(), tracker.dns_class(),
746*3f982cf4SFabien Sanglard                    tracker.record_type(), tracker.ttl(), tracker.rdata());
747*3f982cf4SFabien Sanglard   }
748*3f982cf4SFabien Sanglard 
749*3f982cf4SFabien Sanglard   auto on_rdata_change = [this, r = std::move(record_for_callback),
750*3f982cf4SFabien Sanglard                           existed_previously,
751*3f982cf4SFabien Sanglard                           will_exist](const MdnsRecordTracker& tracker) {
752*3f982cf4SFabien Sanglard     // If RDATA on the record is different, notify that the record has
753*3f982cf4SFabien Sanglard     // been updated.
754*3f982cf4SFabien Sanglard     if (existed_previously && will_exist) {
755*3f982cf4SFabien Sanglard       ProcessCallbacks(r, RecordChangedEvent::kUpdated);
756*3f982cf4SFabien Sanglard     } else if (existed_previously) {
757*3f982cf4SFabien Sanglard       // Do not expire the tracker, because it still holds an NSEC record.
758*3f982cf4SFabien Sanglard       ProcessCallbacks(r, RecordChangedEvent::kExpired);
759*3f982cf4SFabien Sanglard     } else if (will_exist) {
760*3f982cf4SFabien Sanglard       ProcessCallbacks(r, RecordChangedEvent::kCreated);
761*3f982cf4SFabien Sanglard     }
762*3f982cf4SFabien Sanglard   };
763*3f982cf4SFabien Sanglard 
764*3f982cf4SFabien Sanglard   int updated_count = records_.Update(
765*3f982cf4SFabien Sanglard       record, [&tracker](const MdnsRecordTracker& t) { return &tracker == &t; },
766*3f982cf4SFabien Sanglard       std::move(on_rdata_change));
767*3f982cf4SFabien Sanglard   OSP_DCHECK_EQ(updated_count, 1);
768*3f982cf4SFabien Sanglard }
769*3f982cf4SFabien Sanglard 
ProcessMultiTrackedUniqueRecord(const MdnsRecord & record,DnsType dns_type)770*3f982cf4SFabien Sanglard void MdnsQuerier::ProcessMultiTrackedUniqueRecord(const MdnsRecord& record,
771*3f982cf4SFabien Sanglard                                                   DnsType dns_type) {
772*3f982cf4SFabien Sanglard   auto update_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
773*3f982cf4SFabien Sanglard     return tracker.dns_type() == dns_type &&
774*3f982cf4SFabien Sanglard            tracker.dns_class() == record.dns_class() &&
775*3f982cf4SFabien Sanglard            tracker.rdata() == record.rdata();
776*3f982cf4SFabien Sanglard   };
777*3f982cf4SFabien Sanglard   int update_count = records_.Update(
778*3f982cf4SFabien Sanglard       record, std::move(update_check),
779*3f982cf4SFabien Sanglard       [](const MdnsRecordTracker& tracker) { OSP_NOTREACHED(); });
780*3f982cf4SFabien Sanglard   OSP_DCHECK_LE(update_count, 1);
781*3f982cf4SFabien Sanglard 
782*3f982cf4SFabien Sanglard   auto expire_check = [&record, dns_type](const MdnsRecordTracker& tracker) {
783*3f982cf4SFabien Sanglard     return tracker.dns_type() == dns_type &&
784*3f982cf4SFabien Sanglard            tracker.dns_class() == record.dns_class() &&
785*3f982cf4SFabien Sanglard            tracker.rdata() != record.rdata();
786*3f982cf4SFabien Sanglard   };
787*3f982cf4SFabien Sanglard   int expire_count =
788*3f982cf4SFabien Sanglard       records_.ExpireSoon(record.name(), std::move(expire_check));
789*3f982cf4SFabien Sanglard   OSP_DCHECK_GE(expire_count, 1);
790*3f982cf4SFabien Sanglard 
791*3f982cf4SFabien Sanglard   // Did not find an existing record to update.
792*3f982cf4SFabien Sanglard   if (!update_count && !expire_count) {
793*3f982cf4SFabien Sanglard     AddRecord(record, dns_type);
794*3f982cf4SFabien Sanglard     if (record.dns_type() != DnsType::kNSEC) {
795*3f982cf4SFabien Sanglard       ProcessCallbacks(record, RecordChangedEvent::kCreated);
796*3f982cf4SFabien Sanglard     }
797*3f982cf4SFabien Sanglard   }
798*3f982cf4SFabien Sanglard }
799*3f982cf4SFabien Sanglard 
ProcessCallbacks(const MdnsRecord & record,RecordChangedEvent event)800*3f982cf4SFabien Sanglard void MdnsQuerier::ProcessCallbacks(const MdnsRecord& record,
801*3f982cf4SFabien Sanglard                                    RecordChangedEvent event) {
802*3f982cf4SFabien Sanglard   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
803*3f982cf4SFabien Sanglard 
804*3f982cf4SFabien Sanglard   std::vector<PendingQueryChange> pending_changes;
805*3f982cf4SFabien Sanglard   auto callbacks_it = callbacks_.equal_range(record.name());
806*3f982cf4SFabien Sanglard   for (auto entry = callbacks_it.first; entry != callbacks_it.second; ++entry) {
807*3f982cf4SFabien Sanglard     const CallbackInfo& callback_info = entry->second;
808*3f982cf4SFabien Sanglard     if ((callback_info.dns_type == DnsType::kANY ||
809*3f982cf4SFabien Sanglard          record.dns_type() == callback_info.dns_type) &&
810*3f982cf4SFabien Sanglard         (callback_info.dns_class == DnsClass::kANY ||
811*3f982cf4SFabien Sanglard          record.dns_class() == callback_info.dns_class)) {
812*3f982cf4SFabien Sanglard       std::vector<PendingQueryChange> new_changes =
813*3f982cf4SFabien Sanglard           callback_info.callback->OnRecordChanged(record, event);
814*3f982cf4SFabien Sanglard       pending_changes.insert(pending_changes.end(), new_changes.begin(),
815*3f982cf4SFabien Sanglard                              new_changes.end());
816*3f982cf4SFabien Sanglard     }
817*3f982cf4SFabien Sanglard   }
818*3f982cf4SFabien Sanglard 
819*3f982cf4SFabien Sanglard   ApplyPendingChanges(std::move(pending_changes));
820*3f982cf4SFabien Sanglard }
821*3f982cf4SFabien Sanglard 
AddQuestion(const MdnsQuestion & question)822*3f982cf4SFabien Sanglard void MdnsQuerier::AddQuestion(const MdnsQuestion& question) {
823*3f982cf4SFabien Sanglard   auto tracker = std::make_unique<MdnsQuestionTracker>(
824*3f982cf4SFabien Sanglard       question, sender_, task_runner_, now_function_, random_delay_, config_);
825*3f982cf4SFabien Sanglard   MdnsQuestionTracker* ptr = tracker.get();
826*3f982cf4SFabien Sanglard   questions_.emplace(question.name(), std::move(tracker));
827*3f982cf4SFabien Sanglard 
828*3f982cf4SFabien Sanglard   // Let all records associated with this question know that there is a new
829*3f982cf4SFabien Sanglard   // query that can be used for their refresh.
830*3f982cf4SFabien Sanglard   std::vector<RecordTrackerLruCache::RecordTrackerConstRef> trackers =
831*3f982cf4SFabien Sanglard       records_.Find(question.name(), question.dns_type(), question.dns_class());
832*3f982cf4SFabien Sanglard   for (const MdnsRecordTracker& tracker : trackers) {
833*3f982cf4SFabien Sanglard     // NOTE: When the pointed to object is deleted, its dtor removes itself
834*3f982cf4SFabien Sanglard     // from all associated records.
835*3f982cf4SFabien Sanglard     ptr->AddAssociatedRecord(&tracker);
836*3f982cf4SFabien Sanglard   }
837*3f982cf4SFabien Sanglard }
838*3f982cf4SFabien Sanglard 
AddRecord(const MdnsRecord & record,DnsType type)839*3f982cf4SFabien Sanglard void MdnsQuerier::AddRecord(const MdnsRecord& record, DnsType type) {
840*3f982cf4SFabien Sanglard   // Add the new record.
841*3f982cf4SFabien Sanglard   const auto& tracker = records_.StartTracking(record, type);
842*3f982cf4SFabien Sanglard 
843*3f982cf4SFabien Sanglard   // Let all questions associated with this record know that there is a new
844*3f982cf4SFabien Sanglard   // record that answers them (for known answer suppression).
845*3f982cf4SFabien Sanglard   auto query_it = questions_.equal_range(record.name());
846*3f982cf4SFabien Sanglard   for (auto entry = query_it.first; entry != query_it.second; ++entry) {
847*3f982cf4SFabien Sanglard     const MdnsQuestion& query = entry->second->question();
848*3f982cf4SFabien Sanglard     const bool is_relevant_type =
849*3f982cf4SFabien Sanglard         type == DnsType::kANY || type == query.dns_type();
850*3f982cf4SFabien Sanglard     const bool is_relevant_class = record.dns_class() == DnsClass::kANY ||
851*3f982cf4SFabien Sanglard                                    record.dns_class() == query.dns_class();
852*3f982cf4SFabien Sanglard     if (is_relevant_type && is_relevant_class) {
853*3f982cf4SFabien Sanglard       // NOTE: When the pointed to object is deleted, its dtor removes itself
854*3f982cf4SFabien Sanglard       // from all associated queries.
855*3f982cf4SFabien Sanglard       entry->second->AddAssociatedRecord(&tracker);
856*3f982cf4SFabien Sanglard     }
857*3f982cf4SFabien Sanglard   }
858*3f982cf4SFabien Sanglard }
859*3f982cf4SFabien Sanglard 
ApplyPendingChanges(std::vector<PendingQueryChange> pending_changes)860*3f982cf4SFabien Sanglard void MdnsQuerier::ApplyPendingChanges(
861*3f982cf4SFabien Sanglard     std::vector<PendingQueryChange> pending_changes) {
862*3f982cf4SFabien Sanglard   for (auto& pending_change : pending_changes) {
863*3f982cf4SFabien Sanglard     switch (pending_change.change_type) {
864*3f982cf4SFabien Sanglard       case PendingQueryChange::kStartQuery:
865*3f982cf4SFabien Sanglard         StartQuery(std::move(pending_change.name), pending_change.dns_type,
866*3f982cf4SFabien Sanglard                    pending_change.dns_class, pending_change.callback);
867*3f982cf4SFabien Sanglard         break;
868*3f982cf4SFabien Sanglard       case PendingQueryChange::kStopQuery:
869*3f982cf4SFabien Sanglard         StopQuery(std::move(pending_change.name), pending_change.dns_type,
870*3f982cf4SFabien Sanglard                   pending_change.dns_class, pending_change.callback);
871*3f982cf4SFabien Sanglard         break;
872*3f982cf4SFabien Sanglard     }
873*3f982cf4SFabien Sanglard   }
874*3f982cf4SFabien Sanglard }
875*3f982cf4SFabien Sanglard 
876*3f982cf4SFabien Sanglard }  // namespace discovery
877*3f982cf4SFabien Sanglard }  // namespace openscreen
878