xref: /aosp_15_r20/external/openscreen/discovery/dnssd/impl/querier_impl.cc (revision 3f982cf4871df8771c9d4abe6e9a6f8d829b2736)
1*3f982cf4SFabien Sanglard // Copyright 2019 The Chromium Authors. All rights reserved.
2*3f982cf4SFabien Sanglard // Use of this source code is governed by a BSD-style license that can be
3*3f982cf4SFabien Sanglard // found in the LICENSE file.
4*3f982cf4SFabien Sanglard 
5*3f982cf4SFabien Sanglard #include "discovery/dnssd/impl/querier_impl.h"
6*3f982cf4SFabien Sanglard 
7*3f982cf4SFabien Sanglard #include <algorithm>
8*3f982cf4SFabien Sanglard #include <string>
9*3f982cf4SFabien Sanglard #include <utility>
10*3f982cf4SFabien Sanglard #include <vector>
11*3f982cf4SFabien Sanglard 
12*3f982cf4SFabien Sanglard #include "discovery/common/reporting_client.h"
13*3f982cf4SFabien Sanglard #include "discovery/dnssd/impl/conversion_layer.h"
14*3f982cf4SFabien Sanglard #include "discovery/dnssd/impl/network_interface_config.h"
15*3f982cf4SFabien Sanglard #include "platform/api/task_runner.h"
16*3f982cf4SFabien Sanglard #include "util/osp_logging.h"
17*3f982cf4SFabien Sanglard 
18*3f982cf4SFabien Sanglard namespace openscreen {
19*3f982cf4SFabien Sanglard namespace discovery {
20*3f982cf4SFabien Sanglard namespace {
21*3f982cf4SFabien Sanglard 
// Domain passed to every ServiceKey built in this file.
constexpr char kLocalDomain[] = "local";
23*3f982cf4SFabien Sanglard 
24*3f982cf4SFabien Sanglard // Removes all error instances from the below records, and calls the log
25*3f982cf4SFabien Sanglard // function on all errors present in |new_endpoints|. Input vectors are expected
26*3f982cf4SFabien Sanglard // to be sorted in ascending order.
ProcessErrors(std::vector<ErrorOr<DnsSdInstanceEndpoint>> * old_endpoints,std::vector<ErrorOr<DnsSdInstanceEndpoint>> * new_endpoints,std::function<void (Error)> log)27*3f982cf4SFabien Sanglard void ProcessErrors(std::vector<ErrorOr<DnsSdInstanceEndpoint>>* old_endpoints,
28*3f982cf4SFabien Sanglard                    std::vector<ErrorOr<DnsSdInstanceEndpoint>>* new_endpoints,
29*3f982cf4SFabien Sanglard                    std::function<void(Error)> log) {
30*3f982cf4SFabien Sanglard   OSP_DCHECK(old_endpoints);
31*3f982cf4SFabien Sanglard   OSP_DCHECK(new_endpoints);
32*3f982cf4SFabien Sanglard 
33*3f982cf4SFabien Sanglard   auto old_it = old_endpoints->begin();
34*3f982cf4SFabien Sanglard   auto new_it = new_endpoints->begin();
35*3f982cf4SFabien Sanglard 
36*3f982cf4SFabien Sanglard   // Iterate across both vectors and log new errors in the process.
37*3f982cf4SFabien Sanglard   // NOTE: In sorted order, all errors will appear before all non-errors.
38*3f982cf4SFabien Sanglard   while (old_it != old_endpoints->end() && new_it != new_endpoints->end()) {
39*3f982cf4SFabien Sanglard     ErrorOr<DnsSdInstanceEndpoint>& old_ep = *old_it;
40*3f982cf4SFabien Sanglard     ErrorOr<DnsSdInstanceEndpoint>& new_ep = *new_it;
41*3f982cf4SFabien Sanglard 
42*3f982cf4SFabien Sanglard     if (new_ep.is_value()) {
43*3f982cf4SFabien Sanglard       break;
44*3f982cf4SFabien Sanglard     }
45*3f982cf4SFabien Sanglard 
46*3f982cf4SFabien Sanglard     // If they are equal, the element is in both |old_endpoints| and
47*3f982cf4SFabien Sanglard     // |new_endpoints|, so skip it in both vectors.
48*3f982cf4SFabien Sanglard     if (old_ep == new_ep) {
49*3f982cf4SFabien Sanglard       old_it++;
50*3f982cf4SFabien Sanglard       new_it++;
51*3f982cf4SFabien Sanglard       continue;
52*3f982cf4SFabien Sanglard     }
53*3f982cf4SFabien Sanglard 
54*3f982cf4SFabien Sanglard     // There's an error in |old_endpoints| not in |new_endpoints|, so skip it.
55*3f982cf4SFabien Sanglard     if (old_ep < new_ep) {
56*3f982cf4SFabien Sanglard       old_it++;
57*3f982cf4SFabien Sanglard       continue;
58*3f982cf4SFabien Sanglard     }
59*3f982cf4SFabien Sanglard 
60*3f982cf4SFabien Sanglard     // There's an error in |new_endpoints| not in |old_endpoints|, so it's a new
61*3f982cf4SFabien Sanglard     // error from the applied changes. Log it.
62*3f982cf4SFabien Sanglard     log(std::move(new_ep.error()));
63*3f982cf4SFabien Sanglard     new_it++;
64*3f982cf4SFabien Sanglard   }
65*3f982cf4SFabien Sanglard 
66*3f982cf4SFabien Sanglard   // Skip all remaining errors in the old vector.
67*3f982cf4SFabien Sanglard   for (; old_it != old_endpoints->end() && old_it->is_error(); old_it++) {
68*3f982cf4SFabien Sanglard   }
69*3f982cf4SFabien Sanglard 
70*3f982cf4SFabien Sanglard   // Log all errors remaining in the new vector.
71*3f982cf4SFabien Sanglard   for (; new_it != new_endpoints->end() && new_it->is_error(); new_it++) {
72*3f982cf4SFabien Sanglard     log(std::move(new_it->error()));
73*3f982cf4SFabien Sanglard   }
74*3f982cf4SFabien Sanglard 
75*3f982cf4SFabien Sanglard   // Erase errors.
76*3f982cf4SFabien Sanglard   old_endpoints->erase(old_endpoints->begin(), old_it);
77*3f982cf4SFabien Sanglard   new_endpoints->erase(new_endpoints->begin(), new_it);
78*3f982cf4SFabien Sanglard }
79*3f982cf4SFabien Sanglard 
80*3f982cf4SFabien Sanglard // Returns a vector containing the value of each ErrorOr<> instance provided.
81*3f982cf4SFabien Sanglard // All ErrorOr<> values are expected to be non-errors.
GetValues(std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints)82*3f982cf4SFabien Sanglard std::vector<DnsSdInstanceEndpoint> GetValues(
83*3f982cf4SFabien Sanglard     std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints) {
84*3f982cf4SFabien Sanglard   std::vector<DnsSdInstanceEndpoint> results;
85*3f982cf4SFabien Sanglard   results.reserve(endpoints.size());
86*3f982cf4SFabien Sanglard   for (ErrorOr<DnsSdInstanceEndpoint>& endpoint : endpoints) {
87*3f982cf4SFabien Sanglard     results.push_back(std::move(endpoint.value()));
88*3f982cf4SFabien Sanglard   }
89*3f982cf4SFabien Sanglard   return results;
90*3f982cf4SFabien Sanglard }
91*3f982cf4SFabien Sanglard 
IsEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint> & first,const absl::optional<DnsSdInstanceEndpoint> & second)92*3f982cf4SFabien Sanglard bool IsEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint>& first,
93*3f982cf4SFabien Sanglard                      const absl::optional<DnsSdInstanceEndpoint>& second) {
94*3f982cf4SFabien Sanglard   if (!first.has_value() || !second.has_value()) {
95*3f982cf4SFabien Sanglard     return !first.has_value() && !second.has_value();
96*3f982cf4SFabien Sanglard   }
97*3f982cf4SFabien Sanglard 
98*3f982cf4SFabien Sanglard   // In the remaining case, both |first| and |second| must be values.
99*3f982cf4SFabien Sanglard   const DnsSdInstanceEndpoint& a = first.value();
100*3f982cf4SFabien Sanglard   const DnsSdInstanceEndpoint& b = second.value();
101*3f982cf4SFabien Sanglard 
102*3f982cf4SFabien Sanglard   // All endpoints from this querier should have the same network interface
103*3f982cf4SFabien Sanglard   // because the querier is only associated with a single network interface.
104*3f982cf4SFabien Sanglard   OSP_DCHECK_EQ(a.network_interface(), b.network_interface());
105*3f982cf4SFabien Sanglard 
106*3f982cf4SFabien Sanglard   // Function returns true if first < second.
107*3f982cf4SFabien Sanglard   return a.instance_id() == b.instance_id() &&
108*3f982cf4SFabien Sanglard          a.service_id() == b.service_id() && a.domain_id() == b.domain_id();
109*3f982cf4SFabien Sanglard }
110*3f982cf4SFabien Sanglard 
IsNotEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint> & first,const absl::optional<DnsSdInstanceEndpoint> & second)111*3f982cf4SFabien Sanglard bool IsNotEqualOrUpdate(const absl::optional<DnsSdInstanceEndpoint>& first,
112*3f982cf4SFabien Sanglard                         const absl::optional<DnsSdInstanceEndpoint>& second) {
113*3f982cf4SFabien Sanglard   return !IsEqualOrUpdate(first, second);
114*3f982cf4SFabien Sanglard }
115*3f982cf4SFabien Sanglard 
116*3f982cf4SFabien Sanglard // Calculates the created, updated, and deleted elements using the provided
117*3f982cf4SFabien Sanglard // sets, appending these values to the provided vectors. Each of the input
118*3f982cf4SFabien Sanglard // vectors is expected to contain only elements such that
119*3f982cf4SFabien Sanglard // |element|.is_error() == false. Additionally, input vectors are expected to
120*3f982cf4SFabien Sanglard // be sorted in ascending order.
121*3f982cf4SFabien Sanglard //
122*3f982cf4SFabien Sanglard // NOTE: A lot of operations are used to do this, but each is only O(n) so the
123*3f982cf4SFabien Sanglard // resulting algorithm is still fast.
CalculateChangeSets(std::vector<DnsSdInstanceEndpoint> old_endpoints,std::vector<DnsSdInstanceEndpoint> new_endpoints,std::vector<DnsSdInstanceEndpoint> * created_out,std::vector<DnsSdInstanceEndpoint> * updated_out,std::vector<DnsSdInstanceEndpoint> * deleted_out)124*3f982cf4SFabien Sanglard void CalculateChangeSets(std::vector<DnsSdInstanceEndpoint> old_endpoints,
125*3f982cf4SFabien Sanglard                          std::vector<DnsSdInstanceEndpoint> new_endpoints,
126*3f982cf4SFabien Sanglard                          std::vector<DnsSdInstanceEndpoint>* created_out,
127*3f982cf4SFabien Sanglard                          std::vector<DnsSdInstanceEndpoint>* updated_out,
128*3f982cf4SFabien Sanglard                          std::vector<DnsSdInstanceEndpoint>* deleted_out) {
129*3f982cf4SFabien Sanglard   OSP_DCHECK(created_out);
130*3f982cf4SFabien Sanglard   OSP_DCHECK(updated_out);
131*3f982cf4SFabien Sanglard   OSP_DCHECK(deleted_out);
132*3f982cf4SFabien Sanglard 
133*3f982cf4SFabien Sanglard   // Use set difference with default operators to find the elements present in
134*3f982cf4SFabien Sanglard   // one list but not the others.
135*3f982cf4SFabien Sanglard   //
136*3f982cf4SFabien Sanglard   // NOTE: Because absl::optional<...> types are used here and below, calls to
137*3f982cf4SFabien Sanglard   // the ctor and dtor for empty elements are no-ops.
138*3f982cf4SFabien Sanglard   const int total_count = old_endpoints.size() + new_endpoints.size();
139*3f982cf4SFabien Sanglard 
140*3f982cf4SFabien Sanglard   // This is the set of elements that aren't in the old endpoints, meaning the
141*3f982cf4SFabien Sanglard   // old endpoint either didn't exist or had different TXT / Address / etc..
142*3f982cf4SFabien Sanglard   std::vector<absl::optional<DnsSdInstanceEndpoint>> created_or_updated(
143*3f982cf4SFabien Sanglard       total_count);
144*3f982cf4SFabien Sanglard   auto new_end = std::set_difference(new_endpoints.begin(), new_endpoints.end(),
145*3f982cf4SFabien Sanglard                                      old_endpoints.begin(), old_endpoints.end(),
146*3f982cf4SFabien Sanglard                                      created_or_updated.begin());
147*3f982cf4SFabien Sanglard   created_or_updated.erase(new_end, created_or_updated.end());
148*3f982cf4SFabien Sanglard 
149*3f982cf4SFabien Sanglard   // This is the set of elements that are only in the old endpoints, similar to
150*3f982cf4SFabien Sanglard   // the above.
151*3f982cf4SFabien Sanglard   std::vector<absl::optional<DnsSdInstanceEndpoint>> deleted_or_updated(
152*3f982cf4SFabien Sanglard       total_count);
153*3f982cf4SFabien Sanglard   new_end = std::set_difference(old_endpoints.begin(), old_endpoints.end(),
154*3f982cf4SFabien Sanglard                                 new_endpoints.begin(), new_endpoints.end(),
155*3f982cf4SFabien Sanglard                                 deleted_or_updated.begin());
156*3f982cf4SFabien Sanglard   deleted_or_updated.erase(new_end, deleted_or_updated.end());
157*3f982cf4SFabien Sanglard 
158*3f982cf4SFabien Sanglard   // Next, find the elements which were updated.
159*3f982cf4SFabien Sanglard   const size_t max_count =
160*3f982cf4SFabien Sanglard       std::max(created_or_updated.size(), deleted_or_updated.size());
161*3f982cf4SFabien Sanglard   std::vector<absl::optional<DnsSdInstanceEndpoint>> updated(max_count);
162*3f982cf4SFabien Sanglard   new_end = std::set_intersection(
163*3f982cf4SFabien Sanglard       created_or_updated.begin(), created_or_updated.end(),
164*3f982cf4SFabien Sanglard       deleted_or_updated.begin(), deleted_or_updated.end(), updated.begin(),
165*3f982cf4SFabien Sanglard       IsNotEqualOrUpdate);
166*3f982cf4SFabien Sanglard   updated.erase(new_end, updated.end());
167*3f982cf4SFabien Sanglard 
168*3f982cf4SFabien Sanglard   // Use the updated elements to find all created and deleted elements.
169*3f982cf4SFabien Sanglard   std::vector<absl::optional<DnsSdInstanceEndpoint>> created(
170*3f982cf4SFabien Sanglard       created_or_updated.size());
171*3f982cf4SFabien Sanglard   new_end = std::set_difference(
172*3f982cf4SFabien Sanglard       created_or_updated.begin(), created_or_updated.end(), updated.begin(),
173*3f982cf4SFabien Sanglard       updated.end(), created.begin(), IsNotEqualOrUpdate);
174*3f982cf4SFabien Sanglard   created.erase(new_end, created.end());
175*3f982cf4SFabien Sanglard 
176*3f982cf4SFabien Sanglard   std::vector<absl::optional<DnsSdInstanceEndpoint>> deleted(
177*3f982cf4SFabien Sanglard       deleted_or_updated.size());
178*3f982cf4SFabien Sanglard   new_end = std::set_difference(
179*3f982cf4SFabien Sanglard       deleted_or_updated.begin(), deleted_or_updated.end(), updated.begin(),
180*3f982cf4SFabien Sanglard       updated.end(), deleted.begin(), IsNotEqualOrUpdate);
181*3f982cf4SFabien Sanglard   deleted.erase(new_end, deleted.end());
182*3f982cf4SFabien Sanglard 
183*3f982cf4SFabien Sanglard   // Return the calculated elements back to the caller in the output variables.
184*3f982cf4SFabien Sanglard   created_out->reserve(created.size());
185*3f982cf4SFabien Sanglard   for (absl::optional<DnsSdInstanceEndpoint>& endpoint : created) {
186*3f982cf4SFabien Sanglard     OSP_DCHECK(endpoint.has_value());
187*3f982cf4SFabien Sanglard     created_out->push_back(std::move(endpoint.value()));
188*3f982cf4SFabien Sanglard   }
189*3f982cf4SFabien Sanglard 
190*3f982cf4SFabien Sanglard   updated_out->reserve(updated.size());
191*3f982cf4SFabien Sanglard   for (absl::optional<DnsSdInstanceEndpoint>& endpoint : updated) {
192*3f982cf4SFabien Sanglard     OSP_DCHECK(endpoint.has_value());
193*3f982cf4SFabien Sanglard     updated_out->push_back(std::move(endpoint.value()));
194*3f982cf4SFabien Sanglard   }
195*3f982cf4SFabien Sanglard 
196*3f982cf4SFabien Sanglard   deleted_out->reserve(deleted.size());
197*3f982cf4SFabien Sanglard   for (absl::optional<DnsSdInstanceEndpoint>& endpoint : deleted) {
198*3f982cf4SFabien Sanglard     OSP_DCHECK(endpoint.has_value());
199*3f982cf4SFabien Sanglard     deleted_out->push_back(std::move(endpoint.value()));
200*3f982cf4SFabien Sanglard   }
201*3f982cf4SFabien Sanglard }
202*3f982cf4SFabien Sanglard 
203*3f982cf4SFabien Sanglard }  // namespace
204*3f982cf4SFabien Sanglard 
QuerierImpl(MdnsService * mdns_querier,TaskRunner * task_runner,ReportingClient * reporting_client,const NetworkInterfaceConfig * network_config)205*3f982cf4SFabien Sanglard QuerierImpl::QuerierImpl(MdnsService* mdns_querier,
206*3f982cf4SFabien Sanglard                          TaskRunner* task_runner,
207*3f982cf4SFabien Sanglard                          ReportingClient* reporting_client,
208*3f982cf4SFabien Sanglard                          const NetworkInterfaceConfig* network_config)
209*3f982cf4SFabien Sanglard     : mdns_querier_(mdns_querier),
210*3f982cf4SFabien Sanglard       task_runner_(task_runner),
211*3f982cf4SFabien Sanglard       reporting_client_(reporting_client) {
212*3f982cf4SFabien Sanglard   OSP_DCHECK(mdns_querier_);
213*3f982cf4SFabien Sanglard   OSP_DCHECK(task_runner_);
214*3f982cf4SFabien Sanglard 
215*3f982cf4SFabien Sanglard   OSP_DCHECK(network_config);
216*3f982cf4SFabien Sanglard   graph_ = DnsDataGraph::Create(network_config->network_interface());
217*3f982cf4SFabien Sanglard }
218*3f982cf4SFabien Sanglard 
219*3f982cf4SFabien Sanglard QuerierImpl::~QuerierImpl() = default;
220*3f982cf4SFabien Sanglard 
StartQuery(const std::string & service,Callback * callback)221*3f982cf4SFabien Sanglard void QuerierImpl::StartQuery(const std::string& service, Callback* callback) {
222*3f982cf4SFabien Sanglard   OSP_DCHECK(callback);
223*3f982cf4SFabien Sanglard   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
224*3f982cf4SFabien Sanglard 
225*3f982cf4SFabien Sanglard   OSP_DVLOG << "Starting DNS-SD query for service '" << service << "'";
226*3f982cf4SFabien Sanglard 
227*3f982cf4SFabien Sanglard   // Start tracking the new callback
228*3f982cf4SFabien Sanglard   const ServiceKey key(service, kLocalDomain);
229*3f982cf4SFabien Sanglard   auto it = callback_map_.emplace(key, std::vector<Callback*>{}).first;
230*3f982cf4SFabien Sanglard   it->second.push_back(callback);
231*3f982cf4SFabien Sanglard 
232*3f982cf4SFabien Sanglard   const DomainName domain = key.GetName();
233*3f982cf4SFabien Sanglard 
234*3f982cf4SFabien Sanglard   // If the associated service isn't tracked yet, start tracking it and start
235*3f982cf4SFabien Sanglard   // queries for the relevant PTR records.
236*3f982cf4SFabien Sanglard   if (!graph_->IsTracked(domain)) {
237*3f982cf4SFabien Sanglard     std::function<void(const DomainName&)> mdns_query(
238*3f982cf4SFabien Sanglard         [this, &domain](const DomainName& changed_domain) {
239*3f982cf4SFabien Sanglard           OSP_DVLOG << "Starting mDNS query for '" << domain.ToString() << "'";
240*3f982cf4SFabien Sanglard           mdns_querier_->StartQuery(changed_domain, DnsType::kANY,
241*3f982cf4SFabien Sanglard                                     DnsClass::kANY, this);
242*3f982cf4SFabien Sanglard         });
243*3f982cf4SFabien Sanglard     graph_->StartTracking(domain, std::move(mdns_query));
244*3f982cf4SFabien Sanglard     return;
245*3f982cf4SFabien Sanglard   }
246*3f982cf4SFabien Sanglard 
247*3f982cf4SFabien Sanglard   // Else, it's already being tracked so fire creation callbacks for any already
248*3f982cf4SFabien Sanglard   // found service instances.
249*3f982cf4SFabien Sanglard   const std::vector<ErrorOr<DnsSdInstanceEndpoint>> endpoints =
250*3f982cf4SFabien Sanglard       graph_->CreateEndpoints(DnsDataGraph::DomainGroup::kPtr, domain);
251*3f982cf4SFabien Sanglard   for (const auto& endpoint : endpoints) {
252*3f982cf4SFabien Sanglard     if (endpoint.is_value()) {
253*3f982cf4SFabien Sanglard       callback->OnEndpointCreated(endpoint.value());
254*3f982cf4SFabien Sanglard     }
255*3f982cf4SFabien Sanglard   }
256*3f982cf4SFabien Sanglard }
257*3f982cf4SFabien Sanglard 
StopQuery(const std::string & service,Callback * callback)258*3f982cf4SFabien Sanglard void QuerierImpl::StopQuery(const std::string& service, Callback* callback) {
259*3f982cf4SFabien Sanglard   OSP_DCHECK(callback);
260*3f982cf4SFabien Sanglard   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
261*3f982cf4SFabien Sanglard 
262*3f982cf4SFabien Sanglard   OSP_DVLOG << "Stopping DNS-SD query for service '" << service << "'";
263*3f982cf4SFabien Sanglard 
264*3f982cf4SFabien Sanglard   ServiceKey key(service, kLocalDomain);
265*3f982cf4SFabien Sanglard   const auto callbacks_it = callback_map_.find(key);
266*3f982cf4SFabien Sanglard   if (callbacks_it == callback_map_.end()) {
267*3f982cf4SFabien Sanglard     return;
268*3f982cf4SFabien Sanglard   }
269*3f982cf4SFabien Sanglard 
270*3f982cf4SFabien Sanglard   std::vector<Callback*>& callbacks = callbacks_it->second;
271*3f982cf4SFabien Sanglard   const auto it = std::find(callbacks.begin(), callbacks.end(), callback);
272*3f982cf4SFabien Sanglard   if (it == callbacks.end()) {
273*3f982cf4SFabien Sanglard     return;
274*3f982cf4SFabien Sanglard   }
275*3f982cf4SFabien Sanglard 
276*3f982cf4SFabien Sanglard   callbacks.erase(it);
277*3f982cf4SFabien Sanglard   if (callbacks.empty()) {
278*3f982cf4SFabien Sanglard     callback_map_.erase(callbacks_it);
279*3f982cf4SFabien Sanglard 
280*3f982cf4SFabien Sanglard     ServiceKey key(service, kLocalDomain);
281*3f982cf4SFabien Sanglard     DomainName domain = key.GetName();
282*3f982cf4SFabien Sanglard 
283*3f982cf4SFabien Sanglard     std::function<void(const DomainName&)> stop_mdns_query(
284*3f982cf4SFabien Sanglard         [this](const DomainName& changed_domain) {
285*3f982cf4SFabien Sanglard           OSP_DVLOG << "Stopping mDNS query for '" << changed_domain.ToString()
286*3f982cf4SFabien Sanglard                     << "'";
287*3f982cf4SFabien Sanglard           mdns_querier_->StopQuery(changed_domain, DnsType::kANY,
288*3f982cf4SFabien Sanglard                                    DnsClass::kANY, this);
289*3f982cf4SFabien Sanglard         });
290*3f982cf4SFabien Sanglard     graph_->StopTracking(domain, std::move(stop_mdns_query));
291*3f982cf4SFabien Sanglard   }
292*3f982cf4SFabien Sanglard }
293*3f982cf4SFabien Sanglard 
IsQueryRunning(const std::string & service) const294*3f982cf4SFabien Sanglard bool QuerierImpl::IsQueryRunning(const std::string& service) const {
295*3f982cf4SFabien Sanglard   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
296*3f982cf4SFabien Sanglard   const ServiceKey key(service, kLocalDomain);
297*3f982cf4SFabien Sanglard   return graph_->IsTracked(key.GetName());
298*3f982cf4SFabien Sanglard }
299*3f982cf4SFabien Sanglard 
ReinitializeQueries(const std::string & service)300*3f982cf4SFabien Sanglard void QuerierImpl::ReinitializeQueries(const std::string& service) {
301*3f982cf4SFabien Sanglard   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
302*3f982cf4SFabien Sanglard 
303*3f982cf4SFabien Sanglard   OSP_DVLOG << "Re-initializing query for service '" << service << "'";
304*3f982cf4SFabien Sanglard 
305*3f982cf4SFabien Sanglard   const ServiceKey key(service, kLocalDomain);
306*3f982cf4SFabien Sanglard   const DomainName domain = key.GetName();
307*3f982cf4SFabien Sanglard 
308*3f982cf4SFabien Sanglard   std::function<void(const DomainName&)> start_callback(
309*3f982cf4SFabien Sanglard       [this](const DomainName& domain) {
310*3f982cf4SFabien Sanglard         mdns_querier_->StartQuery(domain, DnsType::kANY, DnsClass::kANY, this);
311*3f982cf4SFabien Sanglard       });
312*3f982cf4SFabien Sanglard   std::function<void(const DomainName&)> stop_callback(
313*3f982cf4SFabien Sanglard       [this](const DomainName& domain) {
314*3f982cf4SFabien Sanglard         mdns_querier_->StopQuery(domain, DnsType::kANY, DnsClass::kANY, this);
315*3f982cf4SFabien Sanglard       });
316*3f982cf4SFabien Sanglard   graph_->StopTracking(domain, std::move(stop_callback));
317*3f982cf4SFabien Sanglard 
318*3f982cf4SFabien Sanglard   // Restart top-level queries.
319*3f982cf4SFabien Sanglard   mdns_querier_->ReinitializeQueries(GetPtrQueryInfo(key).name);
320*3f982cf4SFabien Sanglard 
321*3f982cf4SFabien Sanglard   graph_->StartTracking(domain, std::move(start_callback));
322*3f982cf4SFabien Sanglard }
323*3f982cf4SFabien Sanglard 
OnRecordChanged(const MdnsRecord & record,RecordChangedEvent event)324*3f982cf4SFabien Sanglard std::vector<PendingQueryChange> QuerierImpl::OnRecordChanged(
325*3f982cf4SFabien Sanglard     const MdnsRecord& record,
326*3f982cf4SFabien Sanglard     RecordChangedEvent event) {
327*3f982cf4SFabien Sanglard   OSP_DCHECK(task_runner_->IsRunningOnTaskRunner());
328*3f982cf4SFabien Sanglard 
329*3f982cf4SFabien Sanglard   OSP_DVLOG << "Record " << record.ToString()
330*3f982cf4SFabien Sanglard             << " has received change of type '" << event << "'";
331*3f982cf4SFabien Sanglard 
332*3f982cf4SFabien Sanglard   std::function<void(Error)> log = [this](Error error) mutable {
333*3f982cf4SFabien Sanglard     reporting_client_->OnRecoverableError(
334*3f982cf4SFabien Sanglard         Error(Error::Code::kProcessReceivedRecordFailure));
335*3f982cf4SFabien Sanglard   };
336*3f982cf4SFabien Sanglard 
337*3f982cf4SFabien Sanglard   // Get the details to use for calling CreateEndpoints(). Special case PTR
338*3f982cf4SFabien Sanglard   // records to optimize performance.
339*3f982cf4SFabien Sanglard   const DomainName& create_endpoints_domain =
340*3f982cf4SFabien Sanglard       record.dns_type() != DnsType::kPTR
341*3f982cf4SFabien Sanglard           ? record.name()
342*3f982cf4SFabien Sanglard           : absl::get<PtrRecordRdata>(record.rdata()).ptr_domain();
343*3f982cf4SFabien Sanglard   const DnsDataGraph::DomainGroup create_endpoints_group =
344*3f982cf4SFabien Sanglard       record.dns_type() != DnsType::kPTR
345*3f982cf4SFabien Sanglard           ? DnsDataGraph::GetDomainGroup(record)
346*3f982cf4SFabien Sanglard           : DnsDataGraph::DomainGroup::kSrvAndTxt;
347*3f982cf4SFabien Sanglard 
348*3f982cf4SFabien Sanglard   // Get the current set of DnsSdInstanceEndpoints prior to this change. Special
349*3f982cf4SFabien Sanglard   // case PTR records to avoid iterating over unrelated child domains.
350*3f982cf4SFabien Sanglard   std::vector<ErrorOr<DnsSdInstanceEndpoint>> old_endpoints_or_errors =
351*3f982cf4SFabien Sanglard       graph_->CreateEndpoints(create_endpoints_group, create_endpoints_domain);
352*3f982cf4SFabien Sanglard 
353*3f982cf4SFabien Sanglard   // Apply the changes, creating a list of all pending changes that should be
354*3f982cf4SFabien Sanglard   // applied afterwards.
355*3f982cf4SFabien Sanglard   ErrorOr<std::vector<PendingQueryChange>> pending_changes_or_error =
356*3f982cf4SFabien Sanglard       ApplyRecordChanges(record, event);
357*3f982cf4SFabien Sanglard   if (pending_changes_or_error.is_error()) {
358*3f982cf4SFabien Sanglard     OSP_DVLOG << "Failed to apply changes for " << record.dns_type()
359*3f982cf4SFabien Sanglard               << " record change of type " << event << " with error "
360*3f982cf4SFabien Sanglard               << pending_changes_or_error.error();
361*3f982cf4SFabien Sanglard     log(std::move(pending_changes_or_error.error()));
362*3f982cf4SFabien Sanglard     return {};
363*3f982cf4SFabien Sanglard   }
364*3f982cf4SFabien Sanglard   std::vector<PendingQueryChange>& pending_changes =
365*3f982cf4SFabien Sanglard       pending_changes_or_error.value();
366*3f982cf4SFabien Sanglard 
367*3f982cf4SFabien Sanglard   // Get the new set of DnsSdInstanceEndpoints following this change.
368*3f982cf4SFabien Sanglard   std::vector<ErrorOr<DnsSdInstanceEndpoint>> new_endpoints_or_errors =
369*3f982cf4SFabien Sanglard       graph_->CreateEndpoints(create_endpoints_group, create_endpoints_domain);
370*3f982cf4SFabien Sanglard 
371*3f982cf4SFabien Sanglard   // Return early if the resulting sets are equal. This will frequently be the
372*3f982cf4SFabien Sanglard   // case, especially when both sets are empty.
373*3f982cf4SFabien Sanglard   std::sort(old_endpoints_or_errors.begin(), old_endpoints_or_errors.end());
374*3f982cf4SFabien Sanglard   std::sort(new_endpoints_or_errors.begin(), new_endpoints_or_errors.end());
375*3f982cf4SFabien Sanglard   if (old_endpoints_or_errors.size() == new_endpoints_or_errors.size() &&
376*3f982cf4SFabien Sanglard       std::equal(old_endpoints_or_errors.begin(), old_endpoints_or_errors.end(),
377*3f982cf4SFabien Sanglard                  new_endpoints_or_errors.begin())) {
378*3f982cf4SFabien Sanglard     return pending_changes;
379*3f982cf4SFabien Sanglard   }
380*3f982cf4SFabien Sanglard 
381*3f982cf4SFabien Sanglard   // Log all errors and erase them.
382*3f982cf4SFabien Sanglard   ProcessErrors(&old_endpoints_or_errors, &new_endpoints_or_errors,
383*3f982cf4SFabien Sanglard                 std::move(log));
384*3f982cf4SFabien Sanglard   const size_t old_endpoints_or_errors_count = old_endpoints_or_errors.size();
385*3f982cf4SFabien Sanglard   const size_t new_endpoints_or_errors_count = new_endpoints_or_errors.size();
386*3f982cf4SFabien Sanglard   std::vector<DnsSdInstanceEndpoint> old_endpoints =
387*3f982cf4SFabien Sanglard       GetValues(std::move(old_endpoints_or_errors));
388*3f982cf4SFabien Sanglard   std::vector<DnsSdInstanceEndpoint> new_endpoints =
389*3f982cf4SFabien Sanglard       GetValues(std::move(new_endpoints_or_errors));
390*3f982cf4SFabien Sanglard   OSP_DCHECK_EQ(old_endpoints.size(), old_endpoints_or_errors_count);
391*3f982cf4SFabien Sanglard   OSP_DCHECK_EQ(new_endpoints.size(), new_endpoints_or_errors_count);
392*3f982cf4SFabien Sanglard 
393*3f982cf4SFabien Sanglard   // Calculate the changes and call callbacks.
394*3f982cf4SFabien Sanglard   //
395*3f982cf4SFabien Sanglard   // NOTE: As the input sets are expected to be small, the generated sets will
396*3f982cf4SFabien Sanglard   // also be small.
397*3f982cf4SFabien Sanglard   std::vector<DnsSdInstanceEndpoint> created;
398*3f982cf4SFabien Sanglard   std::vector<DnsSdInstanceEndpoint> updated;
399*3f982cf4SFabien Sanglard   std::vector<DnsSdInstanceEndpoint> deleted;
400*3f982cf4SFabien Sanglard   CalculateChangeSets(std::move(old_endpoints), std::move(new_endpoints),
401*3f982cf4SFabien Sanglard                       &created, &updated, &deleted);
402*3f982cf4SFabien Sanglard 
403*3f982cf4SFabien Sanglard   InvokeChangeCallbacks(std::move(created), std::move(updated),
404*3f982cf4SFabien Sanglard                         std::move(deleted));
405*3f982cf4SFabien Sanglard   return pending_changes;
406*3f982cf4SFabien Sanglard }
407*3f982cf4SFabien Sanglard 
InvokeChangeCallbacks(std::vector<DnsSdInstanceEndpoint> created,std::vector<DnsSdInstanceEndpoint> updated,std::vector<DnsSdInstanceEndpoint> deleted)408*3f982cf4SFabien Sanglard void QuerierImpl::InvokeChangeCallbacks(
409*3f982cf4SFabien Sanglard     std::vector<DnsSdInstanceEndpoint> created,
410*3f982cf4SFabien Sanglard     std::vector<DnsSdInstanceEndpoint> updated,
411*3f982cf4SFabien Sanglard     std::vector<DnsSdInstanceEndpoint> deleted) {
412*3f982cf4SFabien Sanglard   // Find an endpoint and use it to create the key, or return if there is none.
413*3f982cf4SFabien Sanglard   DnsSdInstanceEndpoint* some_endpoint;
414*3f982cf4SFabien Sanglard   if (!created.empty()) {
415*3f982cf4SFabien Sanglard     some_endpoint = &created.front();
416*3f982cf4SFabien Sanglard   } else if (!updated.empty()) {
417*3f982cf4SFabien Sanglard     some_endpoint = &updated.front();
418*3f982cf4SFabien Sanglard   } else if (!deleted.empty()) {
419*3f982cf4SFabien Sanglard     some_endpoint = &deleted.front();
420*3f982cf4SFabien Sanglard   } else {
421*3f982cf4SFabien Sanglard     return;
422*3f982cf4SFabien Sanglard   }
423*3f982cf4SFabien Sanglard   ServiceKey key(some_endpoint->service_id(), some_endpoint->domain_id());
424*3f982cf4SFabien Sanglard 
425*3f982cf4SFabien Sanglard   // Find all callbacks.
426*3f982cf4SFabien Sanglard   auto it = callback_map_.find(key);
427*3f982cf4SFabien Sanglard   if (it == callback_map_.end()) {
428*3f982cf4SFabien Sanglard     return;
429*3f982cf4SFabien Sanglard   }
430*3f982cf4SFabien Sanglard 
431*3f982cf4SFabien Sanglard   // Call relevant callbacks.
432*3f982cf4SFabien Sanglard   std::vector<Callback*>& callbacks = it->second;
433*3f982cf4SFabien Sanglard   for (Callback* callback : callbacks) {
434*3f982cf4SFabien Sanglard     for (const DnsSdInstanceEndpoint& endpoint : created) {
435*3f982cf4SFabien Sanglard       callback->OnEndpointCreated(endpoint);
436*3f982cf4SFabien Sanglard     }
437*3f982cf4SFabien Sanglard     for (const DnsSdInstanceEndpoint& endpoint : updated) {
438*3f982cf4SFabien Sanglard       callback->OnEndpointUpdated(endpoint);
439*3f982cf4SFabien Sanglard     }
440*3f982cf4SFabien Sanglard     for (const DnsSdInstanceEndpoint& endpoint : deleted) {
441*3f982cf4SFabien Sanglard       callback->OnEndpointDeleted(endpoint);
442*3f982cf4SFabien Sanglard     }
443*3f982cf4SFabien Sanglard   }
444*3f982cf4SFabien Sanglard }
445*3f982cf4SFabien Sanglard 
ApplyRecordChanges(const MdnsRecord & record,RecordChangedEvent event)446*3f982cf4SFabien Sanglard ErrorOr<std::vector<PendingQueryChange>> QuerierImpl::ApplyRecordChanges(
447*3f982cf4SFabien Sanglard     const MdnsRecord& record,
448*3f982cf4SFabien Sanglard     RecordChangedEvent event) {
449*3f982cf4SFabien Sanglard   std::vector<PendingQueryChange> pending_changes;
450*3f982cf4SFabien Sanglard   std::function<void(DomainName)> creation_callback(
451*3f982cf4SFabien Sanglard       [this, &pending_changes](DomainName domain) mutable {
452*3f982cf4SFabien Sanglard         pending_changes.push_back({std::move(domain), DnsType::kANY,
453*3f982cf4SFabien Sanglard                                    DnsClass::kANY, this,
454*3f982cf4SFabien Sanglard                                    PendingQueryChange::kStartQuery});
455*3f982cf4SFabien Sanglard       });
456*3f982cf4SFabien Sanglard   std::function<void(DomainName)> deletion_callback(
457*3f982cf4SFabien Sanglard       [this, &pending_changes](DomainName domain) mutable {
458*3f982cf4SFabien Sanglard         pending_changes.push_back({std::move(domain), DnsType::kANY,
459*3f982cf4SFabien Sanglard                                    DnsClass::kANY, this,
460*3f982cf4SFabien Sanglard                                    PendingQueryChange::kStopQuery});
461*3f982cf4SFabien Sanglard       });
462*3f982cf4SFabien Sanglard   Error result =
463*3f982cf4SFabien Sanglard       graph_->ApplyDataRecordChange(record, event, std::move(creation_callback),
464*3f982cf4SFabien Sanglard                                     std::move(deletion_callback));
465*3f982cf4SFabien Sanglard   if (!result.ok()) {
466*3f982cf4SFabien Sanglard     return result;
467*3f982cf4SFabien Sanglard   }
468*3f982cf4SFabien Sanglard 
469*3f982cf4SFabien Sanglard   return pending_changes;
470*3f982cf4SFabien Sanglard }
471*3f982cf4SFabien Sanglard 
472*3f982cf4SFabien Sanglard }  // namespace discovery
473*3f982cf4SFabien Sanglard }  // namespace openscreen
474