xref: /aosp_15_r20/external/cronet/net/dns/dns_transaction.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/dns/dns_transaction.h"
6 
7 #include <cstdint>
8 #include <memory>
9 #include <optional>
10 #include <set>
11 #include <string>
12 #include <string_view>
13 #include <unordered_map>
14 #include <utility>
15 #include <vector>
16 
17 #include "base/base64url.h"
18 #include "base/containers/circular_deque.h"
19 #include "base/containers/span.h"
20 #include "base/functional/bind.h"
21 #include "base/functional/callback_helpers.h"
22 #include "base/location.h"
23 #include "base/memory/ptr_util.h"
24 #include "base/memory/raw_ptr.h"
25 #include "base/memory/ref_counted.h"
26 #include "base/memory/safe_ref.h"
27 #include "base/memory/weak_ptr.h"
28 #include "base/metrics/histogram_functions.h"
29 #include "base/metrics/histogram_macros.h"
30 #include "base/numerics/byte_conversions.h"
31 #include "base/rand_util.h"
32 #include "base/ranges/algorithm.h"
33 #include "base/strings/stringprintf.h"
34 #include "base/task/sequenced_task_runner.h"
35 #include "base/task/single_thread_task_runner.h"
36 #include "base/threading/thread_checker.h"
37 #include "base/timer/elapsed_timer.h"
38 #include "base/timer/timer.h"
39 #include "base/values.h"
40 #include "build/build_config.h"
41 #include "net/base/backoff_entry.h"
42 #include "net/base/completion_once_callback.h"
43 #include "net/base/elements_upload_data_stream.h"
44 #include "net/base/idempotency.h"
45 #include "net/base/io_buffer.h"
46 #include "net/base/ip_address.h"
47 #include "net/base/ip_endpoint.h"
48 #include "net/base/load_flags.h"
49 #include "net/base/net_errors.h"
50 #include "net/base/upload_bytes_element_reader.h"
51 #include "net/dns/dns_config.h"
52 #include "net/dns/dns_names_util.h"
53 #include "net/dns/dns_query.h"
54 #include "net/dns/dns_response.h"
55 #include "net/dns/dns_response_result_extractor.h"
56 #include "net/dns/dns_server_iterator.h"
57 #include "net/dns/dns_session.h"
58 #include "net/dns/dns_udp_tracker.h"
59 #include "net/dns/dns_util.h"
60 #include "net/dns/host_cache.h"
61 #include "net/dns/host_resolver_internal_result.h"
62 #include "net/dns/public/dns_over_https_config.h"
63 #include "net/dns/public/dns_over_https_server_config.h"
64 #include "net/dns/public/dns_protocol.h"
65 #include "net/dns/public/dns_query_type.h"
66 #include "net/dns/public/secure_dns_policy.h"
67 #include "net/dns/resolve_context.h"
68 #include "net/http/http_request_headers.h"
69 #include "net/log/net_log.h"
70 #include "net/log/net_log_capture_mode.h"
71 #include "net/log/net_log_event_type.h"
72 #include "net/log/net_log_source.h"
73 #include "net/log/net_log_values.h"
74 #include "net/log/net_log_with_source.h"
75 #include "net/socket/client_socket_factory.h"
76 #include "net/socket/datagram_client_socket.h"
77 #include "net/socket/stream_socket.h"
78 #include "net/third_party/uri_template/uri_template.h"
79 #include "net/traffic_annotation/network_traffic_annotation.h"
80 #include "net/url_request/url_request.h"
81 #include "net/url_request/url_request_context.h"
82 #include "net/url_request/url_request_context_builder.h"
83 #include "url/url_constants.h"
84 
85 namespace net {
86 
87 namespace {
88 
89 constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation =
90     net::DefineNetworkTrafficAnnotation("dns_transaction", R"(
91         semantics {
92           sender: "DNS Transaction"
93           description:
94             "DNS Transaction implements a stub DNS resolver as defined in RFC "
95             "1034."
96           trigger:
97             "Any network request that may require DNS resolution, including "
98             "navigations, connecting to a proxy server, detecting proxy "
99             "settings, getting proxy config, certificate checking, and more."
100           data:
101             "Domain name that needs resolution."
102           destination: OTHER
103           destination_other:
104             "The connection is made to a DNS server based on user's network "
105             "settings."
106         }
107         policy {
108           cookies_allowed: NO
109           setting:
110             "This feature cannot be disabled. Without DNS Transactions Chrome "
111             "cannot resolve host names."
112           policy_exception_justification:
113             "Essential for Chrome's navigation."
114         })");
115 
// MIME type of a DNS message carried over HTTPS. Used in this file as the
// request "Accept" value, the POST "Content-Type" value, and the type a DoH
// response must declare.
constexpr char kDnsOverHttpResponseContentType[] = "application/dns-message";

// Upper bound, in bytes, on the size of a DNS message received over DoH, per
// https://datatracker.ietf.org/doc/html/rfc8484#section-6
constexpr int64_t kDnsOverHttpResponseMaximumSize = 65535;
121 
122 // Count labels in the fully-qualified name in DNS format.
CountLabels(base::span<const uint8_t> name)123 int CountLabels(base::span<const uint8_t> name) {
124   size_t count = 0;
125   for (size_t i = 0; i < name.size() && name[i]; i += name[i] + 1)
126     ++count;
127   return count;
128 }
129 
IsIPLiteral(const std::string & hostname)130 bool IsIPLiteral(const std::string& hostname) {
131   IPAddress ip;
132   return ip.AssignFromIPLiteral(hostname);
133 }
134 
NetLogStartParams(const std::string & hostname,uint16_t qtype)135 base::Value::Dict NetLogStartParams(const std::string& hostname,
136                                     uint16_t qtype) {
137   base::Value::Dict dict;
138   dict.Set("hostname", hostname);
139   dict.Set("query_type", qtype);
140   return dict;
141 }
142 
143 // ----------------------------------------------------------------------------
144 
145 // A single asynchronous DNS exchange, which consists of sending out a
146 // DNS query, waiting for a response, and returning the response that it
147 // matches. Logging is done in the socket and in the outer DnsTransaction.
148 class DnsAttempt {
149  public:
DnsAttempt(size_t server_index)150   explicit DnsAttempt(size_t server_index) : server_index_(server_index) {}
151 
152   DnsAttempt(const DnsAttempt&) = delete;
153   DnsAttempt& operator=(const DnsAttempt&) = delete;
154 
155   virtual ~DnsAttempt() = default;
156   // Starts the attempt. Returns ERR_IO_PENDING if cannot complete synchronously
157   // and calls |callback| upon completion.
158   virtual int Start(CompletionOnceCallback callback) = 0;
159 
160   // Returns the query of this attempt.
161   virtual const DnsQuery* GetQuery() const = 0;
162 
163   // Returns the response or NULL if has not received a matching response from
164   // the server.
165   virtual const DnsResponse* GetResponse() const = 0;
166 
167   virtual base::Value GetRawResponseBufferForLog() const = 0;
168 
169   // Returns the net log bound to the source of the socket.
170   virtual const NetLogWithSource& GetSocketNetLog() const = 0;
171 
172   // Returns the index of the destination server within DnsConfig::nameservers
173   // (or DnsConfig::dns_over_https_servers for secure transactions).
server_index() const174   size_t server_index() const { return server_index_; }
175 
176   // Returns a Value representing the received response, along with a reference
177   // to the NetLog source source of the UDP socket used.  The request must have
178   // completed before this is called.
NetLogResponseParams(NetLogCaptureMode capture_mode) const179   base::Value::Dict NetLogResponseParams(NetLogCaptureMode capture_mode) const {
180     base::Value::Dict dict;
181 
182     if (GetResponse()) {
183       DCHECK(GetResponse()->IsValid());
184       dict.Set("rcode", GetResponse()->rcode());
185       dict.Set("answer_count", static_cast<int>(GetResponse()->answer_count()));
186       dict.Set("additional_answer_count",
187                static_cast<int>(GetResponse()->additional_answer_count()));
188     }
189 
190     GetSocketNetLog().source().AddToEventParameters(dict);
191 
192     if (capture_mode == NetLogCaptureMode::kEverything) {
193       dict.Set("response_buffer", GetRawResponseBufferForLog());
194     }
195 
196     return dict;
197   }
198 
199   // True if current attempt is pending (waiting for server response).
200   virtual bool IsPending() const = 0;
201 
202  private:
203   const size_t server_index_;
204 };
205 
206 class DnsUDPAttempt : public DnsAttempt {
207  public:
DnsUDPAttempt(size_t server_index,std::unique_ptr<DatagramClientSocket> socket,const IPEndPoint & server,std::unique_ptr<DnsQuery> query,DnsUdpTracker * udp_tracker)208   DnsUDPAttempt(size_t server_index,
209                 std::unique_ptr<DatagramClientSocket> socket,
210                 const IPEndPoint& server,
211                 std::unique_ptr<DnsQuery> query,
212                 DnsUdpTracker* udp_tracker)
213       : DnsAttempt(server_index),
214         socket_(std::move(socket)),
215         server_(server),
216         query_(std::move(query)),
217         udp_tracker_(udp_tracker) {}
218 
219   DnsUDPAttempt(const DnsUDPAttempt&) = delete;
220   DnsUDPAttempt& operator=(const DnsUDPAttempt&) = delete;
221 
222   // DnsAttempt methods.
223 
Start(CompletionOnceCallback callback)224   int Start(CompletionOnceCallback callback) override {
225     DCHECK_EQ(STATE_NONE, next_state_);
226     callback_ = std::move(callback);
227     start_time_ = base::TimeTicks::Now();
228     next_state_ = STATE_CONNECT_COMPLETE;
229 
230     int rv = socket_->ConnectAsync(
231         server_,
232         base::BindOnce(&DnsUDPAttempt::OnIOComplete, base::Unretained(this)));
233     if (rv == ERR_IO_PENDING) {
234       return rv;
235     }
236     return DoLoop(rv);
237   }
238 
GetQuery() const239   const DnsQuery* GetQuery() const override { return query_.get(); }
240 
GetResponse() const241   const DnsResponse* GetResponse() const override {
242     const DnsResponse* resp = response_.get();
243     return (resp != nullptr && resp->IsValid()) ? resp : nullptr;
244   }
245 
GetRawResponseBufferForLog() const246   base::Value GetRawResponseBufferForLog() const override {
247     if (!response_)
248       return base::Value();
249     return NetLogBinaryValue(response_->io_buffer()->data(), read_size_);
250   }
251 
GetSocketNetLog() const252   const NetLogWithSource& GetSocketNetLog() const override {
253     return socket_->NetLog();
254   }
255 
IsPending() const256   bool IsPending() const override { return next_state_ != STATE_NONE; }
257 
258  private:
259   enum State {
260     STATE_CONNECT_COMPLETE,
261     STATE_SEND_QUERY,
262     STATE_SEND_QUERY_COMPLETE,
263     STATE_READ_RESPONSE,
264     STATE_READ_RESPONSE_COMPLETE,
265     STATE_NONE,
266   };
267 
DoLoop(int result)268   int DoLoop(int result) {
269     CHECK_NE(STATE_NONE, next_state_);
270     int rv = result;
271     do {
272       State state = next_state_;
273       next_state_ = STATE_NONE;
274       switch (state) {
275         case STATE_CONNECT_COMPLETE:
276           rv = DoConnectComplete(rv);
277           break;
278         case STATE_SEND_QUERY:
279           rv = DoSendQuery(rv);
280           break;
281         case STATE_SEND_QUERY_COMPLETE:
282           rv = DoSendQueryComplete(rv);
283           break;
284         case STATE_READ_RESPONSE:
285           rv = DoReadResponse();
286           break;
287         case STATE_READ_RESPONSE_COMPLETE:
288           rv = DoReadResponseComplete(rv);
289           break;
290         default:
291           NOTREACHED();
292           break;
293       }
294     } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
295 
296     if (rv != ERR_IO_PENDING)
297       DCHECK_EQ(STATE_NONE, next_state_);
298 
299     return rv;
300   }
301 
DoConnectComplete(int rv)302   int DoConnectComplete(int rv) {
303     if (rv != OK) {
304       DVLOG(1) << "Failed to connect socket: " << rv;
305       udp_tracker_->RecordConnectionError(rv);
306       return ERR_CONNECTION_REFUSED;
307     }
308     next_state_ = STATE_SEND_QUERY;
309     IPEndPoint local_address;
310     if (socket_->GetLocalAddress(&local_address) == OK)
311       udp_tracker_->RecordQuery(local_address.port(), query_->id());
312     return OK;
313   }
314 
DoSendQuery(int rv)315   int DoSendQuery(int rv) {
316     DCHECK_NE(ERR_IO_PENDING, rv);
317     if (rv < 0)
318       return rv;
319     next_state_ = STATE_SEND_QUERY_COMPLETE;
320     return socket_->Write(
321         query_->io_buffer(), query_->io_buffer()->size(),
322         base::BindOnce(&DnsUDPAttempt::OnIOComplete, base::Unretained(this)),
323         kTrafficAnnotation);
324   }
325 
DoSendQueryComplete(int rv)326   int DoSendQueryComplete(int rv) {
327     DCHECK_NE(ERR_IO_PENDING, rv);
328     if (rv < 0)
329       return rv;
330 
331     // Writing to UDP should not result in a partial datagram.
332     if (rv != query_->io_buffer()->size())
333       return ERR_MSG_TOO_BIG;
334 
335     next_state_ = STATE_READ_RESPONSE;
336     return OK;
337   }
338 
DoReadResponse()339   int DoReadResponse() {
340     next_state_ = STATE_READ_RESPONSE_COMPLETE;
341     response_ = std::make_unique<DnsResponse>();
342     return socket_->Read(
343         response_->io_buffer(), response_->io_buffer_size(),
344         base::BindOnce(&DnsUDPAttempt::OnIOComplete, base::Unretained(this)));
345   }
346 
DoReadResponseComplete(int rv)347   int DoReadResponseComplete(int rv) {
348     DCHECK_NE(ERR_IO_PENDING, rv);
349     if (rv < 0)
350       return rv;
351     read_size_ = rv;
352 
353     bool parse_result = response_->InitParse(rv, *query_);
354     if (response_->id())
355       udp_tracker_->RecordResponseId(query_->id(), response_->id().value());
356 
357     if (!parse_result)
358       return ERR_DNS_MALFORMED_RESPONSE;
359     if (response_->flags() & dns_protocol::kFlagTC)
360       return ERR_DNS_SERVER_REQUIRES_TCP;
361     if (response_->rcode() == dns_protocol::kRcodeNXDOMAIN)
362       return ERR_NAME_NOT_RESOLVED;
363     if (response_->rcode() != dns_protocol::kRcodeNOERROR)
364       return ERR_DNS_SERVER_FAILED;
365 
366     return OK;
367   }
368 
OnIOComplete(int rv)369   void OnIOComplete(int rv) {
370     rv = DoLoop(rv);
371     if (rv != ERR_IO_PENDING)
372       std::move(callback_).Run(rv);
373   }
374 
375   State next_state_ = STATE_NONE;
376   base::TimeTicks start_time_;
377 
378   std::unique_ptr<DatagramClientSocket> socket_;
379   IPEndPoint server_;
380   std::unique_ptr<DnsQuery> query_;
381 
382   // Should be owned by the DnsSession, to which the transaction should own a
383   // reference.
384   const raw_ptr<DnsUdpTracker> udp_tracker_;
385 
386   std::unique_ptr<DnsResponse> response_;
387   int read_size_ = 0;
388 
389   CompletionOnceCallback callback_;
390 };
391 
392 class DnsHTTPAttempt : public DnsAttempt, public URLRequest::Delegate {
393  public:
DnsHTTPAttempt(size_t doh_server_index,std::unique_ptr<DnsQuery> query,const string & server_template,const GURL & gurl_without_parameters,bool use_post,URLRequestContext * url_request_context,const IsolationInfo & isolation_info,RequestPriority request_priority_,bool is_probe)394   DnsHTTPAttempt(size_t doh_server_index,
395                  std::unique_ptr<DnsQuery> query,
396                  const string& server_template,
397                  const GURL& gurl_without_parameters,
398                  bool use_post,
399                  URLRequestContext* url_request_context,
400                  const IsolationInfo& isolation_info,
401                  RequestPriority request_priority_,
402                  bool is_probe)
403       : DnsAttempt(doh_server_index),
404         query_(std::move(query)),
405         net_log_(NetLogWithSource::Make(NetLog::Get(),
406                                         NetLogSourceType::DNS_OVER_HTTPS)) {
407     GURL url;
408     if (use_post) {
409       // Set url for a POST request
410       url = gurl_without_parameters;
411     } else {
412       // Set url for a GET request
413       std::string url_string;
414       std::unordered_map<string, string> parameters;
415       std::string encoded_query;
416       base::Base64UrlEncode(std::string_view(query_->io_buffer()->data(),
417                                              query_->io_buffer()->size()),
418                             base::Base64UrlEncodePolicy::OMIT_PADDING,
419                             &encoded_query);
420       parameters.emplace("dns", encoded_query);
421       uri_template::Expand(server_template, parameters, &url_string);
422       url = GURL(url_string);
423     }
424 
425     net_log_.BeginEvent(NetLogEventType::DOH_URL_REQUEST, [&] {
426       if (is_probe) {
427         return NetLogStartParams("(probe)", query_->qtype());
428       }
429       std::optional<std::string> hostname =
430           dns_names_util::NetworkToDottedName(query_->qname());
431       DCHECK(hostname.has_value());
432       return NetLogStartParams(*hostname, query_->qtype());
433     });
434 
435     HttpRequestHeaders extra_request_headers;
436     extra_request_headers.SetHeader(HttpRequestHeaders::kAccept,
437                                     kDnsOverHttpResponseContentType);
438     // Send minimal request headers where possible.
439     extra_request_headers.SetHeader(HttpRequestHeaders::kAcceptLanguage, "*");
440     extra_request_headers.SetHeader(HttpRequestHeaders::kUserAgent, "Chrome");
441     extra_request_headers.SetHeader(HttpRequestHeaders::kAcceptEncoding,
442                                     "identity");
443 
444     DCHECK(url_request_context);
445     request_ = url_request_context->CreateRequest(
446         url, request_priority_, this,
447         net::DefineNetworkTrafficAnnotation("dns_over_https", R"(
448         semantics {
449           sender: "DNS over HTTPS"
450           description: "Domain name resolution over HTTPS"
451           trigger: "User enters a navigates to a domain or Chrome otherwise "
452                    "makes a connection to a domain whose IP address isn't cached"
453           data: "The domain name that is being requested"
454           destination: OTHER
455           destination_other: "The user configured DNS over HTTPS server, which"
456                              "may be dns.google.com"
457         }
458         policy {
459           cookies_allowed: NO
460           setting:
461             "You can configure this feature via that 'dns_over_https_servers' and"
462             "'dns_over_https.method' prefs. Empty lists imply this feature is"
463             "disabled"
464           policy_exception_justification: "Experimental feature that"
465                                           "is disabled by default"
466         }
467       )"),
468         /*is_for_websockets=*/false, net_log_.source());
469 
470     if (use_post) {
471       request_->set_method("POST");
472       request_->SetIdempotency(IDEMPOTENT);
473       std::unique_ptr<UploadElementReader> reader =
474           std::make_unique<UploadBytesElementReader>(
475               query_->io_buffer()->data(), query_->io_buffer()->size());
476       request_->set_upload(
477           ElementsUploadDataStream::CreateWithReader(std::move(reader), 0));
478       extra_request_headers.SetHeader(HttpRequestHeaders::kContentType,
479                                       kDnsOverHttpResponseContentType);
480     }
481 
482     request_->SetExtraRequestHeaders(extra_request_headers);
483     // Apply special policy to DNS lookups for for a DoH server hostname to
484     // avoid deadlock and enable the use of preconfigured IP addresses.
485     request_->SetSecureDnsPolicy(SecureDnsPolicy::kBootstrap);
486     request_->SetLoadFlags(request_->load_flags() | LOAD_DISABLE_CACHE |
487                            LOAD_BYPASS_PROXY);
488     request_->set_allow_credentials(false);
489     request_->set_isolation_info(isolation_info);
490   }
491 
492   DnsHTTPAttempt(const DnsHTTPAttempt&) = delete;
493   DnsHTTPAttempt& operator=(const DnsHTTPAttempt&) = delete;
494 
495   // DnsAttempt overrides.
496 
497   int Start(CompletionOnceCallback callback) override {
498     callback_ = std::move(callback);
499     // Start the request asynchronously to avoid reentrancy in
500     // the network stack.
501     base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
502         FROM_HERE, base::BindOnce(&DnsHTTPAttempt::StartAsync,
503                                   weak_factory_.GetWeakPtr()));
504     return ERR_IO_PENDING;
505   }
506 
507   const DnsQuery* GetQuery() const override { return query_.get(); }
508   const DnsResponse* GetResponse() const override {
509     const DnsResponse* resp = response_.get();
510     return (resp != nullptr && resp->IsValid()) ? resp : nullptr;
511   }
512   base::Value GetRawResponseBufferForLog() const override {
513     if (!response_)
514       return base::Value();
515 
516     return NetLogBinaryValue(response_->io_buffer()->data(),
517                              response_->io_buffer_size());
518   }
519   const NetLogWithSource& GetSocketNetLog() const override { return net_log_; }
520 
521   // URLRequest::Delegate overrides
522 
523   void OnResponseStarted(net::URLRequest* request, int net_error) override {
524     DCHECK_NE(net::ERR_IO_PENDING, net_error);
525     std::string content_type;
526     if (net_error != OK) {
527       // Update the error code if there was an issue resolving the secure
528       // server hostname.
529       if (IsHostnameResolutionError(net_error))
530         net_error = ERR_DNS_SECURE_RESOLVER_HOSTNAME_RESOLUTION_FAILED;
531       ResponseCompleted(net_error);
532       return;
533     }
534 
535     if (request_->GetResponseCode() != 200 ||
536         !request->response_headers()->GetMimeType(&content_type) ||
537         0 != content_type.compare(kDnsOverHttpResponseContentType)) {
538       ResponseCompleted(ERR_DNS_MALFORMED_RESPONSE);
539       return;
540     }
541 
542     buffer_ = base::MakeRefCounted<GrowableIOBuffer>();
543 
544     if (request->response_headers()->HasHeader(
545             HttpRequestHeaders::kContentLength)) {
546       if (request_->response_headers()->GetContentLength() >
547           kDnsOverHttpResponseMaximumSize) {
548         ResponseCompleted(ERR_DNS_MALFORMED_RESPONSE);
549         return;
550       }
551 
552       buffer_->SetCapacity(request_->response_headers()->GetContentLength() +
553                            1);
554     } else {
555       buffer_->SetCapacity(kDnsOverHttpResponseMaximumSize + 1);
556     }
557 
558     DCHECK(buffer_->data());
559     DCHECK_GT(buffer_->capacity(), 0);
560 
561     int bytes_read =
562         request_->Read(buffer_.get(), buffer_->RemainingCapacity());
563 
564     // If IO is pending, wait for the URLRequest to call OnReadCompleted.
565     if (bytes_read == net::ERR_IO_PENDING)
566       return;
567 
568     OnReadCompleted(request_.get(), bytes_read);
569   }
570 
571   void OnReceivedRedirect(URLRequest* request,
572                           const RedirectInfo& redirect_info,
573                           bool* defer_redirect) override {
574     // Section 5 of RFC 8484 states that scheme must be https.
575     if (!redirect_info.new_url.SchemeIs(url::kHttpsScheme)) {
576       request->Cancel();
577     }
578   }
579 
580   void OnReadCompleted(net::URLRequest* request, int bytes_read) override {
581     // bytes_read can be an error.
582     if (bytes_read < 0) {
583       ResponseCompleted(bytes_read);
584       return;
585     }
586 
587     DCHECK_GE(bytes_read, 0);
588 
589     if (bytes_read > 0) {
590       if (buffer_->offset() + bytes_read > kDnsOverHttpResponseMaximumSize) {
591         ResponseCompleted(ERR_DNS_MALFORMED_RESPONSE);
592         return;
593       }
594 
595       buffer_->set_offset(buffer_->offset() + bytes_read);
596 
597       if (buffer_->RemainingCapacity() == 0) {
598         buffer_->SetCapacity(buffer_->capacity() + 16384);  // Grow by 16kb.
599       }
600 
601       DCHECK(buffer_->data());
602       DCHECK_GT(buffer_->capacity(), 0);
603 
604       int read_result =
605           request_->Read(buffer_.get(), buffer_->RemainingCapacity());
606 
607       // If IO is pending, wait for the URLRequest to call OnReadCompleted.
608       if (read_result == net::ERR_IO_PENDING)
609         return;
610 
611       if (read_result <= 0) {
612         OnReadCompleted(request_.get(), read_result);
613       } else {
614         // Else, trigger OnReadCompleted asynchronously to avoid starving the IO
615         // thread in case the URLRequest can provide data synchronously.
616         base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
617             FROM_HERE, base::BindOnce(&DnsHTTPAttempt::OnReadCompleted,
618                                       weak_factory_.GetWeakPtr(),
619                                       request_.get(), read_result));
620       }
621     } else {
622       // URLRequest reported an EOF. Call ResponseCompleted.
623       DCHECK_EQ(0, bytes_read);
624       ResponseCompleted(net::OK);
625     }
626   }
627 
628   bool IsPending() const override { return !callback_.is_null(); }
629 
630  private:
631   void StartAsync() {
632     DCHECK(request_);
633     request_->Start();
634   }
635 
636   void ResponseCompleted(int net_error) {
637     request_.reset();
638     std::move(callback_).Run(CompleteResponse(net_error));
639   }
640 
641   int CompleteResponse(int net_error) {
642     net_log_.EndEventWithNetErrorCode(NetLogEventType::DOH_URL_REQUEST,
643                                       net_error);
644     DCHECK_NE(net::ERR_IO_PENDING, net_error);
645     if (net_error != OK) {
646       return net_error;
647     }
648     if (!buffer_.get() || 0 == buffer_->capacity())
649       return ERR_DNS_MALFORMED_RESPONSE;
650 
651     size_t size = buffer_->offset();
652     buffer_->set_offset(0);
653     if (size == 0u)
654       return ERR_DNS_MALFORMED_RESPONSE;
655     response_ = std::make_unique<DnsResponse>(buffer_, size);
656     if (!response_->InitParse(size, *query_))
657       return ERR_DNS_MALFORMED_RESPONSE;
658     if (response_->rcode() == dns_protocol::kRcodeNXDOMAIN)
659       return ERR_NAME_NOT_RESOLVED;
660     if (response_->rcode() != dns_protocol::kRcodeNOERROR)
661       return ERR_DNS_SERVER_FAILED;
662     return OK;
663   }
664 
665   scoped_refptr<GrowableIOBuffer> buffer_;
666   std::unique_ptr<DnsQuery> query_;
667   CompletionOnceCallback callback_;
668   std::unique_ptr<DnsResponse> response_;
669   std::unique_ptr<URLRequest> request_;
670   NetLogWithSource net_log_;
671 
672   base::WeakPtrFactory<DnsHTTPAttempt> weak_factory_{this};
673 };
674 
675 void ConstructDnsHTTPAttempt(DnsSession* session,
676                              size_t doh_server_index,
677                              base::span<const uint8_t> qname,
678                              uint16_t qtype,
679                              const OptRecordRdata* opt_rdata,
680                              std::vector<std::unique_ptr<DnsAttempt>>* attempts,
681                              URLRequestContext* url_request_context,
682                              const IsolationInfo& isolation_info,
683                              RequestPriority request_priority,
684                              bool is_probe) {
685   DCHECK(url_request_context);
686 
687   std::unique_ptr<DnsQuery> query;
688   if (attempts->empty()) {
689     query =
690         std::make_unique<DnsQuery>(/*id=*/0, qname, qtype, opt_rdata,
691                                    DnsQuery::PaddingStrategy::BLOCK_LENGTH_128);
692   } else {
693     query = std::make_unique<DnsQuery>(*attempts->at(0)->GetQuery());
694   }
695 
696   DCHECK_LT(doh_server_index, session->config().doh_config.servers().size());
697   const DnsOverHttpsServerConfig& doh_server =
698       session->config().doh_config.servers()[doh_server_index];
699   GURL gurl_without_parameters(
700       GetURLFromTemplateWithoutParameters(doh_server.server_template()));
701   attempts->push_back(std::make_unique<DnsHTTPAttempt>(
702       doh_server_index, std::move(query), doh_server.server_template(),
703       gurl_without_parameters, doh_server.use_post(), url_request_context,
704       isolation_info, request_priority, is_probe));
705 }
706 
707 class DnsTCPAttempt : public DnsAttempt {
708  public:
709   DnsTCPAttempt(size_t server_index,
710                 std::unique_ptr<StreamSocket> socket,
711                 std::unique_ptr<DnsQuery> query)
712       : DnsAttempt(server_index),
713         socket_(std::move(socket)),
714         query_(std::move(query)),
715         length_buffer_(
716             base::MakeRefCounted<IOBufferWithSize>(sizeof(uint16_t))) {}
717 
718   DnsTCPAttempt(const DnsTCPAttempt&) = delete;
719   DnsTCPAttempt& operator=(const DnsTCPAttempt&) = delete;
720 
721   // DnsAttempt:
722   int Start(CompletionOnceCallback callback) override {
723     DCHECK_EQ(STATE_NONE, next_state_);
724     callback_ = std::move(callback);
725     start_time_ = base::TimeTicks::Now();
726     next_state_ = STATE_CONNECT_COMPLETE;
727     int rv = socket_->Connect(
728         base::BindOnce(&DnsTCPAttempt::OnIOComplete, base::Unretained(this)));
729     if (rv == ERR_IO_PENDING) {
730       return rv;
731     }
732     return DoLoop(rv);
733   }
734 
735   const DnsQuery* GetQuery() const override { return query_.get(); }
736 
737   const DnsResponse* GetResponse() const override {
738     const DnsResponse* resp = response_.get();
739     return (resp != nullptr && resp->IsValid()) ? resp : nullptr;
740   }
741 
742   base::Value GetRawResponseBufferForLog() const override {
743     if (!response_)
744       return base::Value();
745 
746     return NetLogBinaryValue(response_->io_buffer()->data(),
747                              response_->io_buffer_size());
748   }
749 
750   const NetLogWithSource& GetSocketNetLog() const override {
751     return socket_->NetLog();
752   }
753 
754   bool IsPending() const override { return next_state_ != STATE_NONE; }
755 
756  private:
757   enum State {
758     STATE_CONNECT_COMPLETE,
759     STATE_SEND_LENGTH,
760     STATE_SEND_QUERY,
761     STATE_READ_LENGTH,
762     STATE_READ_LENGTH_COMPLETE,
763     STATE_READ_RESPONSE,
764     STATE_READ_RESPONSE_COMPLETE,
765     STATE_NONE,
766   };
767 
768   int DoLoop(int result) {
769     CHECK_NE(STATE_NONE, next_state_);
770     int rv = result;
771     do {
772       State state = next_state_;
773       next_state_ = STATE_NONE;
774       switch (state) {
775         case STATE_CONNECT_COMPLETE:
776           rv = DoConnectComplete(rv);
777           break;
778         case STATE_SEND_LENGTH:
779           rv = DoSendLength(rv);
780           break;
781         case STATE_SEND_QUERY:
782           rv = DoSendQuery(rv);
783           break;
784         case STATE_READ_LENGTH:
785           rv = DoReadLength(rv);
786           break;
787         case STATE_READ_LENGTH_COMPLETE:
788           rv = DoReadLengthComplete(rv);
789           break;
790         case STATE_READ_RESPONSE:
791           rv = DoReadResponse(rv);
792           break;
793         case STATE_READ_RESPONSE_COMPLETE:
794           rv = DoReadResponseComplete(rv);
795           break;
796         default:
797           NOTREACHED();
798           break;
799       }
800     } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
801 
802     if (rv != ERR_IO_PENDING)
803       DCHECK_EQ(STATE_NONE, next_state_);
804 
805     return rv;
806   }
807 
808   int DoConnectComplete(int rv) {
809     DCHECK_NE(ERR_IO_PENDING, rv);
810     if (rv < 0)
811       return rv;
812 
813     uint16_t query_size = static_cast<uint16_t>(query_->io_buffer()->size());
814     if (static_cast<int>(query_size) != query_->io_buffer()->size())
815       return ERR_FAILED;
816     base::as_writable_bytes(length_buffer_->span())
817         .copy_from(base::U16ToBigEndian(query_size));
818     buffer_ = base::MakeRefCounted<DrainableIOBuffer>(length_buffer_,
819                                                       length_buffer_->size());
820     next_state_ = STATE_SEND_LENGTH;
821     return OK;
822   }
823 
824   int DoSendLength(int rv) {
825     DCHECK_NE(ERR_IO_PENDING, rv);
826     if (rv < 0)
827       return rv;
828 
829     buffer_->DidConsume(rv);
830     if (buffer_->BytesRemaining() > 0) {
831       next_state_ = STATE_SEND_LENGTH;
832       return socket_->Write(
833           buffer_.get(), buffer_->BytesRemaining(),
834           base::BindOnce(&DnsTCPAttempt::OnIOComplete, base::Unretained(this)),
835           kTrafficAnnotation);
836     }
837     buffer_ = base::MakeRefCounted<DrainableIOBuffer>(
838         query_->io_buffer(), query_->io_buffer()->size());
839     next_state_ = STATE_SEND_QUERY;
840     return OK;
841   }
842 
843   int DoSendQuery(int rv) {
844     DCHECK_NE(ERR_IO_PENDING, rv);
845     if (rv < 0)
846       return rv;
847 
848     buffer_->DidConsume(rv);
849     if (buffer_->BytesRemaining() > 0) {
850       next_state_ = STATE_SEND_QUERY;
851       return socket_->Write(
852           buffer_.get(), buffer_->BytesRemaining(),
853           base::BindOnce(&DnsTCPAttempt::OnIOComplete, base::Unretained(this)),
854           kTrafficAnnotation);
855     }
856     buffer_ = base::MakeRefCounted<DrainableIOBuffer>(length_buffer_,
857                                                       length_buffer_->size());
858     next_state_ = STATE_READ_LENGTH;
859     return OK;
860   }
861 
862   int DoReadLength(int rv) {
863     DCHECK_EQ(OK, rv);
864 
865     next_state_ = STATE_READ_LENGTH_COMPLETE;
866     return ReadIntoBuffer();
867   }
868 
869   int DoReadLengthComplete(int rv) {
870     DCHECK_NE(ERR_IO_PENDING, rv);
871     if (rv < 0)
872       return rv;
873     if (rv == 0)
874       return ERR_CONNECTION_CLOSED;
875 
876     buffer_->DidConsume(rv);
877     if (buffer_->BytesRemaining() > 0) {
878       next_state_ = STATE_READ_LENGTH;
879       return OK;
880     }
881 
882     response_length_ = base::U16FromBigEndian(
883         base::as_bytes(length_buffer_->span().first<2u>()));
884     // Check if advertised response is too short. (Optimization only.)
885     if (response_length_ < query_->io_buffer()->size())
886       return ERR_DNS_MALFORMED_RESPONSE;
887     response_ = std::make_unique<DnsResponse>(response_length_);
888     buffer_ = base::MakeRefCounted<DrainableIOBuffer>(response_->io_buffer(),
889                                                       response_length_);
890     next_state_ = STATE_READ_RESPONSE;
891     return OK;
892   }
893 
894   int DoReadResponse(int rv) {
895     DCHECK_EQ(OK, rv);
896 
897     next_state_ = STATE_READ_RESPONSE_COMPLETE;
898     return ReadIntoBuffer();
899   }
900 
901   int DoReadResponseComplete(int rv) {
902     DCHECK_NE(ERR_IO_PENDING, rv);
903     if (rv < 0)
904       return rv;
905     if (rv == 0)
906       return ERR_CONNECTION_CLOSED;
907 
908     buffer_->DidConsume(rv);
909     if (buffer_->BytesRemaining() > 0) {
910       next_state_ = STATE_READ_RESPONSE;
911       return OK;
912     }
913     DCHECK_GT(buffer_->BytesConsumed(), 0);
914     if (!response_->InitParse(buffer_->BytesConsumed(), *query_))
915       return ERR_DNS_MALFORMED_RESPONSE;
916     if (response_->flags() & dns_protocol::kFlagTC)
917       return ERR_UNEXPECTED;
918     // TODO(szym): Frankly, none of these are expected.
919     if (response_->rcode() == dns_protocol::kRcodeNXDOMAIN)
920       return ERR_NAME_NOT_RESOLVED;
921     if (response_->rcode() != dns_protocol::kRcodeNOERROR)
922       return ERR_DNS_SERVER_FAILED;
923 
924     return OK;
925   }
926 
927   void OnIOComplete(int rv) {
928     rv = DoLoop(rv);
929     if (rv != ERR_IO_PENDING)
930       std::move(callback_).Run(rv);
931   }
932 
933   int ReadIntoBuffer() {
934     return socket_->Read(
935         buffer_.get(), buffer_->BytesRemaining(),
936         base::BindOnce(&DnsTCPAttempt::OnIOComplete, base::Unretained(this)));
937   }
938 
939   State next_state_ = STATE_NONE;
940   base::TimeTicks start_time_;
941 
942   std::unique_ptr<StreamSocket> socket_;
943   std::unique_ptr<DnsQuery> query_;
944   scoped_refptr<IOBufferWithSize> length_buffer_;
945   scoped_refptr<DrainableIOBuffer> buffer_;
946 
947   uint16_t response_length_ = 0;
948   std::unique_ptr<DnsResponse> response_;
949 
950   CompletionOnceCallback callback_;
951 };
952 
953 // ----------------------------------------------------------------------------
954 
955 const net::BackoffEntry::Policy kProbeBackoffPolicy = {
956     // Apply exponential backoff rules after the first error.
957     0,
958     // Begin with a 1s delay between probes.
959     1000,
960     // Increase the delay between consecutive probes by a factor of 1.5.
961     1.5,
962     // Fuzz the delay between consecutive probes between 80%-100% of the
963     // calculated time.
964     0.2,
965     // Cap the maximum delay between consecutive probes at 1 hour.
966     1000 * 60 * 60,
967     // Never expire entries.
968     -1,
969     // Do not apply an initial delay.
970     false,
971 };
972 
973 // Probe runner that continually sends test queries (with backoff) to DoH
974 // servers to determine availability.
975 //
976 // Expected to be contained in request classes owned externally to HostResolver,
977 // so no assumptions are made regarding cancellation compared to the DnsSession
978 // or ResolveContext. Instead, uses WeakPtrs to gracefully clean itself up and
979 // stop probing after session or context destruction.
980 class DnsOverHttpsProbeRunner : public DnsProbeRunner {
981  public:
982   DnsOverHttpsProbeRunner(base::WeakPtr<DnsSession> session,
983                           base::WeakPtr<ResolveContext> context)
984       : session_(session), context_(context) {
985     DCHECK(session_);
986     DCHECK(!session_->config().doh_config.servers().empty());
987     DCHECK(context_);
988 
989     std::optional<std::vector<uint8_t>> qname =
990         dns_names_util::DottedNameToNetwork(kDohProbeHostname);
991     DCHECK(qname.has_value());
992     formatted_probe_qname_ = std::move(qname).value();
993 
994     for (size_t i = 0; i < session_->config().doh_config.servers().size();
995          i++) {
996       probe_stats_list_.push_back(nullptr);
997     }
998   }
999 
1000   ~DnsOverHttpsProbeRunner() override = default;
1001 
1002   void Start(bool network_change) override {
1003     DCHECK(session_);
1004     DCHECK(context_);
1005 
1006     const auto& config = session_->config().doh_config;
1007     // Start probe sequences for any servers where it is not currently running.
1008     for (size_t i = 0; i < config.servers().size(); i++) {
1009       if (!probe_stats_list_[i]) {
1010         probe_stats_list_[i] = std::make_unique<ProbeStats>();
1011         ContinueProbe(i, probe_stats_list_[i]->weak_factory.GetWeakPtr(),
1012                       network_change,
1013                       base::TimeTicks::Now() /* sequence_start_time */);
1014       }
1015     }
1016   }
1017 
1018   base::TimeDelta GetDelayUntilNextProbeForTest(
1019       size_t doh_server_index) const override {
1020     if (doh_server_index >= probe_stats_list_.size() ||
1021         !probe_stats_list_[doh_server_index])
1022       return base::TimeDelta();
1023 
1024     return probe_stats_list_[doh_server_index]
1025         ->backoff_entry->GetTimeUntilRelease();
1026   }
1027 
1028  private:
1029   struct ProbeStats {
1030     ProbeStats()
1031         : backoff_entry(
1032               std::make_unique<net::BackoffEntry>(&kProbeBackoffPolicy)) {}
1033 
1034     std::unique_ptr<net::BackoffEntry> backoff_entry;
1035     std::vector<std::unique_ptr<DnsAttempt>> probe_attempts;
1036     base::WeakPtrFactory<ProbeStats> weak_factory{this};
1037   };
1038 
1039   void ContinueProbe(size_t doh_server_index,
1040                      base::WeakPtr<ProbeStats> probe_stats,
1041                      bool network_change,
1042                      base::TimeTicks sequence_start_time) {
1043     // If the DnsSession or ResolveContext has been destroyed, no reason to
1044     // continue probing.
1045     if (!session_ || !context_) {
1046       probe_stats_list_.clear();
1047       return;
1048     }
1049 
1050     // If the ProbeStats for which this probe was scheduled has been deleted,
1051     // don't continue to send probes.
1052     if (!probe_stats)
1053       return;
1054 
1055     // Cancel the probe sequence for this server if the server is already
1056     // available.
1057     if (context_->GetDohServerAvailability(doh_server_index, session_.get())) {
1058       probe_stats_list_[doh_server_index] = nullptr;
1059       return;
1060     }
1061 
1062     // Schedule a new probe assuming this one will fail. The newly scheduled
1063     // probe will not run if an earlier probe has already succeeded. Probes may
1064     // take awhile to fail, which is why we schedule the next one here rather
1065     // than on probe completion.
1066     DCHECK(probe_stats);
1067     DCHECK(probe_stats->backoff_entry);
1068     probe_stats->backoff_entry->InformOfRequest(false /* success */);
1069     base::SequencedTaskRunner::GetCurrentDefault()->PostDelayedTask(
1070         FROM_HERE,
1071         base::BindOnce(&DnsOverHttpsProbeRunner::ContinueProbe,
1072                        weak_ptr_factory_.GetWeakPtr(), doh_server_index,
1073                        probe_stats, network_change, sequence_start_time),
1074         probe_stats->backoff_entry->GetTimeUntilRelease());
1075 
1076     unsigned attempt_number = probe_stats->probe_attempts.size();
1077     ConstructDnsHTTPAttempt(
1078         session_.get(), doh_server_index, formatted_probe_qname_,
1079         dns_protocol::kTypeA, /*opt_rdata=*/nullptr,
1080         &probe_stats->probe_attempts, context_->url_request_context(),
1081         context_->isolation_info(), RequestPriority::DEFAULT_PRIORITY,
1082         /*is_probe=*/true);
1083 
1084     DnsAttempt* probe_attempt = probe_stats->probe_attempts.back().get();
1085     probe_attempt->Start(base::BindOnce(
1086         &DnsOverHttpsProbeRunner::ProbeComplete, weak_ptr_factory_.GetWeakPtr(),
1087         attempt_number, doh_server_index, std::move(probe_stats),
1088         network_change, sequence_start_time,
1089         base::TimeTicks::Now() /* query_start_time */));
1090   }
1091 
ProbeComplete(unsigned attempt_number,size_t doh_server_index,base::WeakPtr<ProbeStats> probe_stats,bool network_change,base::TimeTicks sequence_start_time,base::TimeTicks query_start_time,int rv)1092   void ProbeComplete(unsigned attempt_number,
1093                      size_t doh_server_index,
1094                      base::WeakPtr<ProbeStats> probe_stats,
1095                      bool network_change,
1096                      base::TimeTicks sequence_start_time,
1097                      base::TimeTicks query_start_time,
1098                      int rv) {
1099     bool success = false;
1100     while (probe_stats && session_ && context_) {
1101       if (rv != OK) {
1102         // The DoH probe queries don't go through the standard DnsAttempt path,
1103         // so the ServerStats have not been updated yet.
1104         context_->RecordServerFailure(doh_server_index, /*is_doh_server=*/true,
1105                                       rv, session_.get());
1106         break;
1107       }
1108       // Check that the response parses properly before considering it a
1109       // success.
1110       DCHECK_LT(attempt_number, probe_stats->probe_attempts.size());
1111       const DnsAttempt* attempt =
1112           probe_stats->probe_attempts[attempt_number].get();
1113       const DnsResponse* response = attempt->GetResponse();
1114       if (response) {
1115         DnsResponseResultExtractor extractor(*response);
1116         DnsResponseResultExtractor::ResultsOrError results =
1117             extractor.ExtractDnsResults(
1118                 DnsQueryType::A,
1119                 /*original_domain_name=*/kDohProbeHostname,
1120                 /*request_port=*/0);
1121 
1122         if (results.has_value()) {
1123           for (const auto& result : results.value()) {
1124             if (result->type() == HostResolverInternalResult::Type::kData &&
1125                 !result->AsData().endpoints().empty()) {
1126               context_->RecordServerSuccess(
1127                   doh_server_index, /*is_doh_server=*/true, session_.get());
1128               context_->RecordRtt(doh_server_index, /*is_doh_server=*/true,
1129                                   base::TimeTicks::Now() - query_start_time, rv,
1130                                   session_.get());
1131               success = true;
1132 
1133               // Do not delete the ProbeStats and cancel the probe sequence. It
1134               // will cancel itself on the next scheduled ContinueProbe() call
1135               // if the server is still available. This way, the backoff
1136               // schedule will be maintained if a server quickly becomes
1137               // unavailable again before that scheduled call.
1138               break;
1139             }
1140           }
1141         }
1142       }
1143       if (!success) {
1144         context_->RecordServerFailure(
1145             doh_server_index, /*is_doh_server=*/true,
1146             /*rv=*/ERR_DNS_SECURE_PROBE_RECORD_INVALID, session_.get());
1147       }
1148       break;
1149     }
1150 
1151     base::UmaHistogramLongTimes(
1152         base::JoinString({"Net.DNS.ProbeSequence",
1153                           network_change ? "NetworkChange" : "ConfigChange",
1154                           success ? "Success" : "Failure", "AttemptTime"},
1155                          "."),
1156         base::TimeTicks::Now() - sequence_start_time);
1157   }
1158 
1159   base::WeakPtr<DnsSession> session_;
1160   base::WeakPtr<ResolveContext> context_;
1161   std::vector<uint8_t> formatted_probe_qname_;
1162 
1163   // List of ProbeStats, one for each DoH server, indexed by the DoH server
1164   // config index.
1165   std::vector<std::unique_ptr<ProbeStats>> probe_stats_list_;
1166 
1167   base::WeakPtrFactory<DnsOverHttpsProbeRunner> weak_ptr_factory_{this};
1168 };
1169 
1170 // ----------------------------------------------------------------------------
1171 
1172 // Implements DnsTransaction. Configuration is supplied by DnsSession.
1173 // The suffix list is built according to the DnsConfig from the session.
1174 // The fallback period for each DnsUDPAttempt is given by
1175 // ResolveContext::NextClassicFallbackPeriod(). The first server to attempt on
1176 // each query is given by ResolveContext::NextFirstServerIndex, and the order is
1177 // round-robin afterwards. Each server is attempted DnsConfig::attempts times.
1178 class DnsTransactionImpl final : public DnsTransaction {
1179  public:
DnsTransactionImpl(DnsSession * session,std::string hostname,uint16_t qtype,const NetLogWithSource & parent_net_log,const OptRecordRdata * opt_rdata,bool secure,SecureDnsMode secure_dns_mode,ResolveContext * resolve_context,bool fast_timeout)1180   DnsTransactionImpl(DnsSession* session,
1181                      std::string hostname,
1182                      uint16_t qtype,
1183                      const NetLogWithSource& parent_net_log,
1184                      const OptRecordRdata* opt_rdata,
1185                      bool secure,
1186                      SecureDnsMode secure_dns_mode,
1187                      ResolveContext* resolve_context,
1188                      bool fast_timeout)
1189       : session_(session),
1190         hostname_(std::move(hostname)),
1191         qtype_(qtype),
1192         opt_rdata_(opt_rdata),
1193         secure_(secure),
1194         secure_dns_mode_(secure_dns_mode),
1195         fast_timeout_(fast_timeout),
1196         net_log_(NetLogWithSource::Make(NetLog::Get(),
1197                                         NetLogSourceType::DNS_TRANSACTION)),
1198         resolve_context_(resolve_context->AsSafeRef()) {
1199     DCHECK(session_.get());
1200     DCHECK(!hostname_.empty());
1201     DCHECK(!IsIPLiteral(hostname_));
1202     parent_net_log.AddEventReferencingSource(NetLogEventType::DNS_TRANSACTION,
1203                                              net_log_.source());
1204   }
1205 
1206   DnsTransactionImpl(const DnsTransactionImpl&) = delete;
1207   DnsTransactionImpl& operator=(const DnsTransactionImpl&) = delete;
1208 
~DnsTransactionImpl()1209   ~DnsTransactionImpl() override {
1210     DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1211     if (!callback_.is_null()) {
1212       net_log_.EndEventWithNetErrorCode(NetLogEventType::DNS_TRANSACTION,
1213                                         ERR_ABORTED);
1214     }  // otherwise logged in DoCallback or Start
1215   }
1216 
GetHostname() const1217   const std::string& GetHostname() const override {
1218     DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1219     return hostname_;
1220   }
1221 
GetType() const1222   uint16_t GetType() const override {
1223     DCHECK_CALLED_ON_VALID_THREAD(thread_checker_);
1224     return qtype_;
1225   }
1226 
Start(ResponseCallback callback)1227   void Start(ResponseCallback callback) override {
1228     DCHECK(!callback.is_null());
1229     DCHECK(callback_.is_null());
1230     DCHECK(attempts_.empty());
1231 
1232     callback_ = std::move(callback);
1233 
1234     net_log_.BeginEvent(NetLogEventType::DNS_TRANSACTION,
1235                         [&] { return NetLogStartParams(hostname_, qtype_); });
1236     time_from_start_ = std::make_unique<base::ElapsedTimer>();
1237     AttemptResult result(PrepareSearch(), nullptr);
1238     if (result.rv == OK) {
1239       qnames_initial_size_ = qnames_.size();
1240       result = ProcessAttemptResult(StartQuery());
1241     }
1242 
1243     // Must always return result asynchronously, to avoid reentrancy.
1244     if (result.rv != ERR_IO_PENDING) {
1245       // Clear all other non-completed attempts. They are no longer needed and
1246       // they may interfere with this posted result.
1247       ClearAttempts(result.attempt);
1248       base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
1249           FROM_HERE, base::BindOnce(&DnsTransactionImpl::DoCallback,
1250                                     weak_ptr_factory_.GetWeakPtr(), result));
1251     }
1252   }
1253 
SetRequestPriority(RequestPriority priority)1254   void SetRequestPriority(RequestPriority priority) override {
1255     request_priority_ = priority;
1256   }
1257 
1258  private:
1259   // Wrapper for the result of a DnsUDPAttempt.
1260   struct AttemptResult {
1261     AttemptResult() = default;
AttemptResultnet::__anond9c6f5e90111::DnsTransactionImpl::AttemptResult1262     AttemptResult(int rv, const DnsAttempt* attempt)
1263         : rv(rv), attempt(attempt) {}
1264 
1265     int rv;
1266     raw_ptr<const DnsAttempt, AcrossTasksDanglingUntriaged> attempt;
1267   };
1268 
1269   // Used in UMA (DNS.AttemptType). Do not renumber or remove values.
1270   enum class DnsAttemptType {
1271     kUdp = 0,
1272     kTcpLowEntropy = 1,
1273     kTcpTruncationRetry = 2,
1274     kHttp = 3,
1275     kMaxValue = kHttp,
1276   };
1277 
1278   // Prepares |qnames_| according to the DnsConfig.
PrepareSearch()1279   int PrepareSearch() {
1280     const DnsConfig& config = session_->config();
1281 
1282     std::optional<std::vector<uint8_t>> labeled_qname =
1283         dns_names_util::DottedNameToNetwork(
1284             hostname_,
1285             /*require_valid_internet_hostname=*/true);
1286     if (!labeled_qname.has_value())
1287       return ERR_INVALID_ARGUMENT;
1288 
1289     if (hostname_.back() == '.') {
1290       // It's a fully-qualified name, no suffix search.
1291       qnames_.push_back(std::move(labeled_qname).value());
1292       return OK;
1293     }
1294 
1295     int ndots = CountLabels(labeled_qname.value()) - 1;
1296 
1297     if (ndots > 0 && !config.append_to_multi_label_name) {
1298       qnames_.push_back(std::move(labeled_qname).value());
1299       return OK;
1300     }
1301 
1302     // Set true when `labeled_qname` is put on the list.
1303     bool had_qname = false;
1304 
1305     if (ndots >= config.ndots) {
1306       qnames_.push_back(labeled_qname.value());
1307       had_qname = true;
1308     }
1309 
1310     for (const auto& suffix : config.search) {
1311       std::optional<std::vector<uint8_t>> qname =
1312           dns_names_util::DottedNameToNetwork(
1313               hostname_ + "." + suffix,
1314               /*require_valid_internet_hostname=*/true);
1315       // Ignore invalid (too long) combinations.
1316       if (!qname.has_value())
1317         continue;
1318       if (qname.value().size() == labeled_qname.value().size()) {
1319         if (had_qname)
1320           continue;
1321         had_qname = true;
1322       }
1323       qnames_.push_back(std::move(qname).value());
1324     }
1325 
1326     if (ndots > 0 && !had_qname)
1327       qnames_.push_back(std::move(labeled_qname).value());
1328 
1329     return qnames_.empty() ? ERR_DNS_SEARCH_EMPTY : OK;
1330   }
1331 
DoCallback(AttemptResult result)1332   void DoCallback(AttemptResult result) {
1333     DCHECK_NE(ERR_IO_PENDING, result.rv);
1334 
1335     // TODO(mgersh): consider changing back to a DCHECK once
1336     // https://crbug.com/779589 is fixed.
1337     if (callback_.is_null())
1338       return;
1339 
1340     const DnsResponse* response =
1341         result.attempt ? result.attempt->GetResponse() : nullptr;
1342     CHECK(result.rv != OK || response != nullptr);
1343 
1344     timer_.Stop();
1345 
1346     net_log_.EndEventWithNetErrorCode(NetLogEventType::DNS_TRANSACTION,
1347                                       result.rv);
1348 
1349     std::move(callback_).Run(result.rv, response);
1350   }
1351 
RecordAttemptUma(DnsAttemptType attempt_type)1352   void RecordAttemptUma(DnsAttemptType attempt_type) {
1353     UMA_HISTOGRAM_ENUMERATION("Net.DNS.DnsTransaction.AttemptType",
1354                               attempt_type);
1355   }
1356 
MakeAttempt()1357   AttemptResult MakeAttempt() {
1358     DCHECK(MoreAttemptsAllowed());
1359 
1360     DnsConfig config = session_->config();
1361     if (secure_) {
1362       DCHECK(!config.doh_config.servers().empty());
1363       RecordAttemptUma(DnsAttemptType::kHttp);
1364       return MakeHTTPAttempt();
1365     }
1366 
1367     DCHECK_GT(config.nameservers.size(), 0u);
1368     return MakeClassicDnsAttempt();
1369   }
1370 
MakeClassicDnsAttempt()1371   AttemptResult MakeClassicDnsAttempt() {
1372     uint16_t id = session_->NextQueryId();
1373     std::unique_ptr<DnsQuery> query;
1374     if (attempts_.empty()) {
1375       query =
1376           std::make_unique<DnsQuery>(id, qnames_.front(), qtype_, opt_rdata_);
1377     } else {
1378       query = attempts_[0]->GetQuery()->CloneWithNewId(id);
1379     }
1380     DCHECK(dns_server_iterator_->AttemptAvailable());
1381     size_t server_index = dns_server_iterator_->GetNextAttemptIndex();
1382 
1383     size_t attempt_number = attempts_.size();
1384     AttemptResult result;
1385     if (session_->udp_tracker()->low_entropy()) {
1386       result = MakeTcpAttempt(server_index, std::move(query));
1387       RecordAttemptUma(DnsAttemptType::kTcpLowEntropy);
1388     } else {
1389       result = MakeUdpAttempt(server_index, std::move(query));
1390       RecordAttemptUma(DnsAttemptType::kUdp);
1391     }
1392 
1393     if (result.rv == ERR_IO_PENDING) {
1394       base::TimeDelta fallback_period =
1395           resolve_context_->NextClassicFallbackPeriod(
1396               server_index, attempt_number, session_.get());
1397       timer_.Start(FROM_HERE, fallback_period, this,
1398                    &DnsTransactionImpl::OnFallbackPeriodExpired);
1399     }
1400 
1401     return result;
1402   }
1403 
1404   // Makes another attempt at the current name, |qnames_.front()|, using the
1405   // next nameserver.
MakeUdpAttempt(size_t server_index,std::unique_ptr<DnsQuery> query)1406   AttemptResult MakeUdpAttempt(size_t server_index,
1407                                std::unique_ptr<DnsQuery> query) {
1408     DCHECK(!secure_);
1409     DCHECK(!session_->udp_tracker()->low_entropy());
1410 
1411     const DnsConfig& config = session_->config();
1412     DCHECK_LT(server_index, config.nameservers.size());
1413     size_t attempt_number = attempts_.size();
1414 
1415     std::unique_ptr<DatagramClientSocket> socket =
1416         resolve_context_->url_request_context()
1417             ->GetNetworkSessionContext()
1418             ->client_socket_factory->CreateDatagramClientSocket(
1419                 DatagramSocket::RANDOM_BIND, net_log_.net_log(),
1420                 net_log_.source());
1421 
1422     attempts_.push_back(std::make_unique<DnsUDPAttempt>(
1423         server_index, std::move(socket), config.nameservers[server_index],
1424         std::move(query), session_->udp_tracker()));
1425     ++attempts_count_;
1426 
1427     DnsAttempt* attempt = attempts_.back().get();
1428     net_log_.AddEventReferencingSource(NetLogEventType::DNS_TRANSACTION_ATTEMPT,
1429                                        attempt->GetSocketNetLog().source());
1430 
1431     int rv = attempt->Start(base::BindOnce(
1432         &DnsTransactionImpl::OnAttemptComplete, base::Unretained(this),
1433         attempt_number, true /* record_rtt */, base::TimeTicks::Now()));
1434     return AttemptResult(rv, attempt);
1435   }
1436 
MakeHTTPAttempt()1437   AttemptResult MakeHTTPAttempt() {
1438     DCHECK(secure_);
1439 
1440     size_t doh_server_index = dns_server_iterator_->GetNextAttemptIndex();
1441 
1442     unsigned attempt_number = attempts_.size();
1443     ConstructDnsHTTPAttempt(session_.get(), doh_server_index, qnames_.front(),
1444                             qtype_, opt_rdata_, &attempts_,
1445                             resolve_context_->url_request_context(),
1446                             resolve_context_->isolation_info(),
1447                             request_priority_, /*is_probe=*/false);
1448     ++attempts_count_;
1449     DnsAttempt* attempt = attempts_.back().get();
1450     // Associate this attempt with the DoH request in NetLog.
1451     net_log_.AddEventReferencingSource(
1452         NetLogEventType::DNS_TRANSACTION_HTTPS_ATTEMPT,
1453         attempt->GetSocketNetLog().source());
1454     attempt->GetSocketNetLog().AddEventReferencingSource(
1455         NetLogEventType::DNS_TRANSACTION_HTTPS_ATTEMPT, net_log_.source());
1456     int rv = attempt->Start(base::BindOnce(
1457         &DnsTransactionImpl::OnAttemptComplete, base::Unretained(this),
1458         attempt_number, true /* record_rtt */, base::TimeTicks::Now()));
1459     if (rv == ERR_IO_PENDING) {
1460       base::TimeDelta fallback_period = resolve_context_->NextDohFallbackPeriod(
1461           doh_server_index, session_.get());
1462       timer_.Start(FROM_HERE, fallback_period, this,
1463                    &DnsTransactionImpl::OnFallbackPeriodExpired);
1464     }
1465     return AttemptResult(rv, attempts_.back().get());
1466   }
1467 
RetryUdpAttemptAsTcp(const DnsAttempt * previous_attempt)1468   AttemptResult RetryUdpAttemptAsTcp(const DnsAttempt* previous_attempt) {
1469     DCHECK(previous_attempt);
1470     DCHECK(!had_tcp_retry_);
1471 
1472     // Only allow a single TCP retry per query.
1473     had_tcp_retry_ = true;
1474 
1475     size_t server_index = previous_attempt->server_index();
1476     // Use a new query ID instead of reusing the same one from the UDP attempt.
1477     // RFC5452, section 9.2 requires an unpredictable ID for all outgoing
1478     // queries, with no distinction made between queries made via TCP or UDP.
1479     std::unique_ptr<DnsQuery> query =
1480         previous_attempt->GetQuery()->CloneWithNewId(session_->NextQueryId());
1481 
1482     // Cancel all attempts that have not received a response, as they will
1483     // likely similarly require TCP retry.
1484     ClearAttempts(nullptr);
1485 
1486     AttemptResult result = MakeTcpAttempt(server_index, std::move(query));
1487     RecordAttemptUma(DnsAttemptType::kTcpTruncationRetry);
1488 
1489     if (result.rv == ERR_IO_PENDING) {
1490       // On TCP upgrade, use 2x the upgraded fallback period.
1491       base::TimeDelta fallback_period = timer_.GetCurrentDelay() * 2;
1492       timer_.Start(FROM_HERE, fallback_period, this,
1493                    &DnsTransactionImpl::OnFallbackPeriodExpired);
1494     }
1495 
1496     return result;
1497   }
1498 
MakeTcpAttempt(size_t server_index,std::unique_ptr<DnsQuery> query)1499   AttemptResult MakeTcpAttempt(size_t server_index,
1500                                std::unique_ptr<DnsQuery> query) {
1501     DCHECK(!secure_);
1502     const DnsConfig& config = session_->config();
1503     DCHECK_LT(server_index, config.nameservers.size());
1504 
1505     // TODO(https://crbug.com/1123197): Pass a non-null NetworkQualityEstimator.
1506     NetworkQualityEstimator* network_quality_estimator = nullptr;
1507 
1508     std::unique_ptr<StreamSocket> socket =
1509         resolve_context_->url_request_context()
1510             ->GetNetworkSessionContext()
1511             ->client_socket_factory->CreateTransportClientSocket(
1512                 AddressList(config.nameservers[server_index]), nullptr,
1513                 network_quality_estimator, net_log_.net_log(),
1514                 net_log_.source());
1515 
1516     unsigned attempt_number = attempts_.size();
1517 
1518     attempts_.push_back(std::make_unique<DnsTCPAttempt>(
1519         server_index, std::move(socket), std::move(query)));
1520     ++attempts_count_;
1521 
1522     DnsAttempt* attempt = attempts_.back().get();
1523     net_log_.AddEventReferencingSource(
1524         NetLogEventType::DNS_TRANSACTION_TCP_ATTEMPT,
1525         attempt->GetSocketNetLog().source());
1526 
1527     int rv = attempt->Start(base::BindOnce(
1528         &DnsTransactionImpl::OnAttemptComplete, base::Unretained(this),
1529         attempt_number, false /* record_rtt */, base::TimeTicks::Now()));
1530     return AttemptResult(rv, attempt);
1531   }
1532 
1533   // Begins query for the current name. Makes the first attempt.
StartQuery()1534   AttemptResult StartQuery() {
1535     std::optional<std::string> dotted_qname =
1536         dns_names_util::NetworkToDottedName(qnames_.front());
1537     net_log_.BeginEventWithStringParams(
1538         NetLogEventType::DNS_TRANSACTION_QUERY, "qname",
1539         dotted_qname.value_or("???MALFORMED_NAME???"));
1540 
1541     attempts_.clear();
1542     had_tcp_retry_ = false;
1543     if (secure_) {
1544       dns_server_iterator_ = resolve_context_->GetDohIterator(
1545           session_->config(), secure_dns_mode_, session_.get());
1546     } else {
1547       dns_server_iterator_ = resolve_context_->GetClassicDnsIterator(
1548           session_->config(), session_.get());
1549     }
1550     DCHECK(dns_server_iterator_);
1551     // Check for available server before starting as DoH servers might be
1552     // unavailable.
1553     if (!dns_server_iterator_->AttemptAvailable())
1554       return AttemptResult(ERR_BLOCKED_BY_CLIENT, nullptr);
1555 
1556     return MakeAttempt();
1557   }
1558 
OnAttemptComplete(unsigned attempt_number,bool record_rtt,base::TimeTicks start,int rv)1559   void OnAttemptComplete(unsigned attempt_number,
1560                          bool record_rtt,
1561                          base::TimeTicks start,
1562                          int rv) {
1563     DCHECK_LT(attempt_number, attempts_.size());
1564     const DnsAttempt* attempt = attempts_[attempt_number].get();
1565     if (record_rtt && attempt->GetResponse()) {
1566       resolve_context_->RecordRtt(
1567           attempt->server_index(), secure_ /* is_doh_server */,
1568           base::TimeTicks::Now() - start, rv, session_.get());
1569     }
1570     if (callback_.is_null())
1571       return;
1572     AttemptResult result = ProcessAttemptResult(AttemptResult(rv, attempt));
1573     if (result.rv != ERR_IO_PENDING)
1574       DoCallback(result);
1575   }
1576 
LogResponse(const DnsAttempt * attempt)1577   void LogResponse(const DnsAttempt* attempt) {
1578     if (attempt) {
1579       net_log_.AddEvent(NetLogEventType::DNS_TRANSACTION_RESPONSE,
1580                         [&](NetLogCaptureMode capture_mode) {
1581                           return attempt->NetLogResponseParams(capture_mode);
1582                         });
1583     }
1584   }
1585 
MoreAttemptsAllowed() const1586   bool MoreAttemptsAllowed() const {
1587     if (had_tcp_retry_)
1588       return false;
1589 
1590     return dns_server_iterator_->AttemptAvailable();
1591   }
1592 
1593   // Resolves the result of a DnsAttempt until a terminal result is reached
1594   // or it will complete asynchronously (ERR_IO_PENDING).
ProcessAttemptResult(AttemptResult result)1595   AttemptResult ProcessAttemptResult(AttemptResult result) {
1596     while (result.rv != ERR_IO_PENDING) {
1597       LogResponse(result.attempt);
1598 
1599       switch (result.rv) {
1600         case OK:
1601           resolve_context_->RecordServerSuccess(result.attempt->server_index(),
1602                                                 secure_ /* is_doh_server */,
1603                                                 session_.get());
1604           net_log_.EndEventWithNetErrorCode(
1605               NetLogEventType::DNS_TRANSACTION_QUERY, result.rv);
1606           DCHECK(result.attempt);
1607           DCHECK(result.attempt->GetResponse());
1608           return result;
1609         case ERR_NAME_NOT_RESOLVED:
1610           resolve_context_->RecordServerSuccess(result.attempt->server_index(),
1611                                                 secure_ /* is_doh_server */,
1612                                                 session_.get());
1613           net_log_.EndEventWithNetErrorCode(
1614               NetLogEventType::DNS_TRANSACTION_QUERY, result.rv);
1615           // Try next suffix. Check that qnames_ isn't already empty first,
1616           // which can happen when there are two attempts running at once.
1617           // TODO(mgersh): remove this workaround for https://crbug.com/774846
1618           // when https://crbug.com/779589 is fixed.
1619           if (!qnames_.empty())
1620             qnames_.pop_front();
1621           if (qnames_.empty()) {
1622             return result;
1623           } else {
1624             result = StartQuery();
1625           }
1626           break;
1627         case ERR_DNS_TIMED_OUT:
1628           timer_.Stop();
1629 
1630           if (result.attempt) {
1631             DCHECK(result.attempt == attempts_.back().get());
1632             resolve_context_->RecordServerFailure(
1633                 result.attempt->server_index(), secure_ /* is_doh_server */,
1634                 result.rv, session_.get());
1635           }
1636           if (MoreAttemptsAllowed()) {
1637             result = MakeAttempt();
1638             break;
1639           }
1640 
1641           if (!fast_timeout_ && AnyAttemptPending()) {
1642             StartTimeoutTimer();
1643             return AttemptResult(ERR_IO_PENDING, nullptr);
1644           }
1645 
1646           return result;
1647         case ERR_DNS_SERVER_REQUIRES_TCP:
1648           result = RetryUdpAttemptAsTcp(result.attempt);
1649           break;
1650         case ERR_BLOCKED_BY_CLIENT:
1651           net_log_.EndEventWithNetErrorCode(
1652               NetLogEventType::DNS_TRANSACTION_QUERY, result.rv);
1653           return result;
1654         default:
1655           // Server failure.
1656           DCHECK(result.attempt);
1657 
1658           // If attempt is not the most recent attempt, means this error is for
1659           // a previous attempt that already passed its fallback period and
1660           // continued attempting in parallel with new attempts (see the
1661           // ERR_DNS_TIMED_OUT case above). As the failure was already recorded
1662           // at fallback time and is no longer being waited on, ignore this
1663           // failure.
1664           if (result.attempt == attempts_.back().get()) {
1665             timer_.Stop();
1666             resolve_context_->RecordServerFailure(
1667                 result.attempt->server_index(), secure_ /* is_doh_server */,
1668                 result.rv, session_.get());
1669 
1670             if (MoreAttemptsAllowed()) {
1671               result = MakeAttempt();
1672               break;
1673             }
1674 
1675             if (fast_timeout_) {
1676               return result;
1677             }
1678 
1679             // No more attempts can be made, but there may be other attempts
1680             // still pending, so start the timeout timer.
1681             StartTimeoutTimer();
1682           }
1683 
1684           // If any attempts are still pending, continue to wait for them.
1685           if (AnyAttemptPending()) {
1686             DCHECK(timer_.IsRunning());
1687             return AttemptResult(ERR_IO_PENDING, nullptr);
1688           }
1689 
1690           return result;
1691       }
1692     }
1693     return result;
1694   }
1695 
1696   // Clears and cancels all pending attempts. If |leave_attempt| is not
1697   // null, that attempt is not cleared even if pending.
ClearAttempts(const DnsAttempt * leave_attempt)1698   void ClearAttempts(const DnsAttempt* leave_attempt) {
1699     for (auto it = attempts_.begin(); it != attempts_.end();) {
1700       if ((*it)->IsPending() && it->get() != leave_attempt) {
1701         it = attempts_.erase(it);
1702       } else {
1703         ++it;
1704       }
1705     }
1706   }
1707 
AnyAttemptPending()1708   bool AnyAttemptPending() {
1709     return base::ranges::any_of(attempts_,
1710                                 [](std::unique_ptr<DnsAttempt>& attempt) {
1711                                   return attempt->IsPending();
1712                                 });
1713   }
1714 
OnFallbackPeriodExpired()1715   void OnFallbackPeriodExpired() {
1716     if (callback_.is_null())
1717       return;
1718     DCHECK(!attempts_.empty());
1719     AttemptResult result = ProcessAttemptResult(
1720         AttemptResult(ERR_DNS_TIMED_OUT, attempts_.back().get()));
1721     if (result.rv != ERR_IO_PENDING)
1722       DoCallback(result);
1723   }
1724 
StartTimeoutTimer()1725   void StartTimeoutTimer() {
1726     DCHECK(!fast_timeout_);
1727     DCHECK(!timer_.IsRunning());
1728     DCHECK(!callback_.is_null());
1729 
1730     base::TimeDelta timeout;
1731     if (secure_) {
1732       timeout = resolve_context_->SecureTransactionTimeout(secure_dns_mode_,
1733                                                            session_.get());
1734     } else {
1735       timeout = resolve_context_->ClassicTransactionTimeout(session_.get());
1736     }
1737     timeout -= time_from_start_->Elapsed();
1738 
1739     timer_.Start(FROM_HERE, timeout, this, &DnsTransactionImpl::OnTimeout);
1740   }
1741 
OnTimeout()1742   void OnTimeout() {
1743     if (callback_.is_null())
1744       return;
1745     DoCallback(AttemptResult(ERR_DNS_TIMED_OUT, nullptr));
1746   }
1747 
1748   scoped_refptr<DnsSession> session_;
1749   std::string hostname_;
1750   uint16_t qtype_;
1751   raw_ptr<const OptRecordRdata, DanglingUntriaged> opt_rdata_;
1752   const bool secure_;
1753   const SecureDnsMode secure_dns_mode_;
1754   // Cleared in DoCallback.
1755   ResponseCallback callback_;
1756 
1757   // When true, transaction should time out immediately on expiration of the
1758   // last attempt fallback period rather than waiting the overall transaction
1759   // timeout period.
1760   const bool fast_timeout_;
1761 
1762   NetLogWithSource net_log_;
1763 
1764   // Search list of fully-qualified DNS names to query next (in DNS format).
1765   base::circular_deque<std::vector<uint8_t>> qnames_;
1766   size_t qnames_initial_size_ = 0;
1767 
1768   // List of attempts for the current name.
1769   std::vector<std::unique_ptr<DnsAttempt>> attempts_;
1770   // Count of attempts, not reset when |attempts_| vector is cleared.
1771   int attempts_count_ = 0;
1772 
1773   // Records when an attempt was retried via TCP due to a truncation error.
1774   bool had_tcp_retry_ = false;
1775 
1776   // Iterator to get the index of the DNS server for each search query.
1777   std::unique_ptr<DnsServerIterator> dns_server_iterator_;
1778 
1779   base::OneShotTimer timer_;
1780   std::unique_ptr<base::ElapsedTimer> time_from_start_;
1781 
1782   base::SafeRef<ResolveContext> resolve_context_;
1783   RequestPriority request_priority_ = DEFAULT_PRIORITY;
1784 
1785   THREAD_CHECKER(thread_checker_);
1786 
1787   base::WeakPtrFactory<DnsTransactionImpl> weak_ptr_factory_{this};
1788 };
1789 
1790 // ----------------------------------------------------------------------------
1791 
1792 // Implementation of DnsTransactionFactory that returns instances of
1793 // DnsTransactionImpl.
1794 class DnsTransactionFactoryImpl : public DnsTransactionFactory {
1795  public:
DnsTransactionFactoryImpl(DnsSession * session)1796   explicit DnsTransactionFactoryImpl(DnsSession* session) {
1797     session_ = session;
1798   }
1799 
CreateTransaction(std::string hostname,uint16_t qtype,const NetLogWithSource & net_log,bool secure,SecureDnsMode secure_dns_mode,ResolveContext * resolve_context,bool fast_timeout)1800   std::unique_ptr<DnsTransaction> CreateTransaction(
1801       std::string hostname,
1802       uint16_t qtype,
1803       const NetLogWithSource& net_log,
1804       bool secure,
1805       SecureDnsMode secure_dns_mode,
1806       ResolveContext* resolve_context,
1807       bool fast_timeout) override {
1808     return std::make_unique<DnsTransactionImpl>(
1809         session_.get(), std::move(hostname), qtype, net_log, opt_rdata_.get(),
1810         secure, secure_dns_mode, resolve_context, fast_timeout);
1811   }
1812 
CreateDohProbeRunner(ResolveContext * resolve_context)1813   std::unique_ptr<DnsProbeRunner> CreateDohProbeRunner(
1814       ResolveContext* resolve_context) override {
1815     // Start a timer that will emit metrics after a timeout to indicate whether
1816     // DoH auto-upgrade was successful for this session.
1817     resolve_context->StartDohAutoupgradeSuccessTimer(session_.get());
1818 
1819     return std::make_unique<DnsOverHttpsProbeRunner>(
1820         session_->GetWeakPtr(), resolve_context->GetWeakPtr());
1821   }
1822 
AddEDNSOption(std::unique_ptr<OptRecordRdata::Opt> opt)1823   void AddEDNSOption(std::unique_ptr<OptRecordRdata::Opt> opt) override {
1824     DCHECK(opt);
1825     if (opt_rdata_ == nullptr)
1826       opt_rdata_ = std::make_unique<OptRecordRdata>();
1827 
1828     opt_rdata_->AddOpt(std::move(opt));
1829   }
1830 
GetSecureDnsModeForTest()1831   SecureDnsMode GetSecureDnsModeForTest() override {
1832     return session_->config().secure_dns_mode;
1833   }
1834 
1835  private:
1836   scoped_refptr<DnsSession> session_;
1837   std::unique_ptr<OptRecordRdata> opt_rdata_;
1838 };
1839 
1840 }  // namespace
1841 
1842 DnsTransactionFactory::DnsTransactionFactory() = default;
1843 DnsTransactionFactory::~DnsTransactionFactory() = default;
1844 
1845 // static
CreateFactory(DnsSession * session)1846 std::unique_ptr<DnsTransactionFactory> DnsTransactionFactory::CreateFactory(
1847     DnsSession* session) {
1848   return std::make_unique<DnsTransactionFactoryImpl>(session);
1849 }
1850 
1851 }  // namespace net
1852