1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
#include <array>
#include <chrono>
#include <condition_variable>
#include <list>
#include <map>
#include <mutex>
#include <optional>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
24 
25 #include <aidl/android/net/resolv/aidl/DohParamsParcel.h>
26 
27 #include <android-base/format.h>
28 #include <android-base/logging.h>
29 #include <android-base/result.h>
30 #include <android-base/thread_annotations.h>
31 #include <netdutils/BackoffSequence.h>
32 #include <netdutils/DumpWriter.h>
33 #include <netdutils/InternetAddresses.h>
34 #include <netdutils/Slice.h>
35 #include <stats.pb.h>
36 
37 #include "DnsTlsServer.h"
38 #include "LockedQueue.h"
39 #include "PrivateDnsValidationObserver.h"
40 #include "doh.h"
41 
42 namespace android {
43 namespace net {
44 
45 PrivateDnsModes convertEnumType(PrivateDnsMode mode);
46 
47 struct DohServerInfo {
48     std::string httpsTemplate;
49     Validation status;
50 
DohServerInfoDohServerInfo51     DohServerInfo(const std::string httpsTemplate, Validation status)
52         : httpsTemplate(httpsTemplate), status(status) {}
53 };
54 
55 struct PrivateDnsStatus {
56     PrivateDnsMode mode;
57 
58     // TODO: change the type to std::vector<DnsTlsServer>.
59     std::map<DnsTlsServer, Validation, AddressComparator> dotServersMap;
60 
61     std::map<netdutils::IPSockAddr, DohServerInfo> dohServersMap;
62 
validatedServersPrivateDnsStatus63     std::list<DnsTlsServer> validatedServers() const {
64         std::list<DnsTlsServer> servers;
65 
66         for (const auto& pair : dotServersMap) {
67             if (pair.second == Validation::success) {
68                 servers.push_back(pair.first);
69             }
70         }
71         return servers;
72     }
73 
hasValidatedDohServersPrivateDnsStatus74     bool hasValidatedDohServers() const {
75         for (const auto& [_, info] : dohServersMap) {
76             if (info.status == Validation::success) {
77                 return true;
78             }
79         }
80         return false;
81     }
82 };
83 
84 class PrivateDnsConfiguration {
85   private:
86     using DohParamsParcel = aidl::android::net::resolv::aidl::DohParamsParcel;
87 
88   public:
89     static constexpr int kDohQueryDefaultTimeoutMs = 30000;
90     static constexpr int kDohProbeDefaultTimeoutMs = 60000;
91 
92     // The default value for QUIC max_idle_timeout.
93     static constexpr int kDohIdleDefaultTimeoutMs = 55000;
94 
95     struct ServerIdentity {
96         const netdutils::IPSockAddr sockaddr;
97         const std::string provider;
98 
ServerIdentityServerIdentity99         explicit ServerIdentity(const DnsTlsServer& server)
100             : sockaddr(server.addr()), provider(server.provider()) {}
ServerIdentityServerIdentity101         ServerIdentity(const netdutils::IPSockAddr& addr, const std::string& host)
102             : sockaddr(addr), provider(host) {}
103 
104         bool operator<(const ServerIdentity& other) const {
105             return std::tie(sockaddr, provider) < std::tie(other.sockaddr, other.provider);
106         }
107         bool operator==(const ServerIdentity& other) const {
108             return std::tie(sockaddr, provider) == std::tie(other.sockaddr, other.provider);
109         }
110     };
111 
112     // The only instance of PrivateDnsConfiguration.
getInstance()113     static PrivateDnsConfiguration& getInstance() {
114         static PrivateDnsConfiguration instance;
115         return instance;
116     }
117 
118     int set(int32_t netId, uint32_t mark, const std::vector<std::string>& unencryptedServers,
119             const std::vector<std::string>& encryptedServers, const std::string& name,
120             const std::string& caCert, const std::optional<DohParamsParcel> dohParams)
121             EXCLUDES(mPrivateDnsLock);
122 
123     void initDoh() EXCLUDES(mPrivateDnsLock);
124 
125     PrivateDnsStatus getStatus(unsigned netId) const EXCLUDES(mPrivateDnsLock);
126     NetworkDnsServerSupportReported getStatusForMetrics(unsigned netId) const
127             EXCLUDES(mPrivateDnsLock);
128 
129     void clear(unsigned netId) EXCLUDES(mPrivateDnsLock);
130 
131     ssize_t dohQuery(unsigned netId, const netdutils::Slice query, const netdutils::Slice answer,
132                      uint64_t timeoutMs) EXCLUDES(mPrivateDnsLock);
133 
134     // Request the server to be revalidated on a connection tagged with |mark|.
135     // Returns a Result to indicate if the request is accepted.
136     base::Result<void> requestDotValidation(unsigned netId, const ServerIdentity& identity,
137                                             uint32_t mark) EXCLUDES(mPrivateDnsLock);
138 
139     void setObserver(PrivateDnsValidationObserver* observer);
140 
141     void dump(netdutils::DumpWriter& dw) const;
142 
143     void onDohStatusUpdate(uint32_t netId, bool success, const char* ipAddr, const char* host)
144             EXCLUDES(mPrivateDnsLock);
145 
146     base::Result<netdutils::IPSockAddr> getDohServer(unsigned netId) const
147             EXCLUDES(mPrivateDnsLock);
148 
149   private:
150     PrivateDnsConfiguration() = default;
151 
152     int setDot(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
153                const std::string& name, const std::string& caCert) REQUIRES(mPrivateDnsLock);
154 
155     void clearDot(int32_t netId) REQUIRES(mPrivateDnsLock);
156 
157     // For testing.
158     base::Result<DnsTlsServer*> getDotServer(const ServerIdentity& identity, unsigned netId)
159             EXCLUDES(mPrivateDnsLock);
160 
161     base::Result<DnsTlsServer*> getDotServerLocked(const ServerIdentity& identity, unsigned netId)
162             REQUIRES(mPrivateDnsLock);
163 
164     // TODO: change the return type to Result<PrivateDnsStatus>.
165     PrivateDnsStatus getStatusLocked(unsigned netId) const REQUIRES(mPrivateDnsLock);
166 
167     // Launchs a thread to run the validation for the DoT server |server| on the network |netId|.
168     // |isRevalidation| is true if this call is due to a revalidation request.
169     void startDotValidation(const ServerIdentity& identity, unsigned netId, bool isRevalidation)
170             REQUIRES(mPrivateDnsLock);
171 
172     bool recordDotValidation(const ServerIdentity& identity, unsigned netId, bool success,
173                              bool isRevalidation) EXCLUDES(mPrivateDnsLock);
174 
175     void sendPrivateDnsValidationEvent(const ServerIdentity& identity, unsigned netId,
176                                        bool success) const REQUIRES(mPrivateDnsLock);
177 
178     // Decide if a validation for |server| is needed. Note that servers that have failed
179     // multiple validation attempts but for which there is still a validating
180     // thread running are marked as being Validation::in_process.
181     bool needsValidation(const DnsTlsServer& server) const REQUIRES(mPrivateDnsLock);
182 
183     void updateServerState(const ServerIdentity& identity, Validation state, uint32_t netId)
184             REQUIRES(mPrivateDnsLock);
185 
186     void initDohLocked() REQUIRES(mPrivateDnsLock);
187     int setDoh(int32_t netId, uint32_t mark, const std::vector<std::string>& servers,
188                const std::string& name, const std::string& caCert,
189                const std::optional<DohParamsParcel> dohParams) REQUIRES(mPrivateDnsLock);
190     void clearDoh(unsigned netId) REQUIRES(mPrivateDnsLock);
191 
192     mutable std::mutex mPrivateDnsLock;
193     std::map<unsigned, PrivateDnsMode> mPrivateDnsModes GUARDED_BY(mPrivateDnsLock);
194 
195     // Contains all servers for a network, along with their current validation status.
196     // In case a server is removed due to a configuration change, it remains in this map,
197     // but is marked inactive.
198     // Any pending validation threads will continue running because we have no way to cancel them.
199     std::map<unsigned, std::map<ServerIdentity, DnsTlsServer>> mDotTracker
200             GUARDED_BY(mPrivateDnsLock);
201 
202     void notifyValidationStateUpdate(const netdutils::IPSockAddr& sockaddr, Validation validation,
203                                      uint32_t netId) const REQUIRES(mPrivateDnsLock);
204 
205     bool needReportEvent(uint32_t netId, ServerIdentity identity, bool success) const
206             REQUIRES(mPrivateDnsLock);
207 
208     // TODO: fix the reentrancy problem.
209     PrivateDnsValidationObserver* mObserver GUARDED_BY(mPrivateDnsLock);
210 
211     DohDispatcher* mDohDispatcher = nullptr;
212     std::condition_variable mCv;
213 
214     friend class PrivateDnsConfigurationTest;
215 
216     // It's not const because PrivateDnsConfigurationTest needs to override it.
217     // TODO: make it const by dependency injection.
218     netdutils::BackoffSequence<>::Builder mBackoffBuilder =
219             netdutils::BackoffSequence<>::Builder()
220                     .withInitialRetransmissionTime(std::chrono::seconds(60))
221                     .withMaximumRetransmissionTime(std::chrono::seconds(3600));
222 
223     struct DohIdentity {
224         std::string httpsTemplate;
225         std::string ipAddr;
226         std::string host;
227         Validation status;
228         bool operator<(const DohIdentity& other) const {
229             return std::tie(ipAddr, host) < std::tie(other.ipAddr, other.host);
230         }
231         bool operator==(const DohIdentity& other) const {
232             return std::tie(ipAddr, host) == std::tie(other.ipAddr, other.host);
233         }
234         bool operator<(const ServerIdentity& other) const {
235             std::string otherIp = other.sockaddr.ip().toString();
236             return std::tie(ipAddr, host) < std::tie(otherIp, other.provider);
237         }
238         bool operator==(const ServerIdentity& other) const {
239             std::string otherIp = other.sockaddr.ip().toString();
240             return std::tie(ipAddr, host) == std::tie(otherIp, other.provider);
241         }
242     };
243 
244     struct DohProviderEntry {
245         std::string provider;
246         std::set<std::string> ips;
247         std::string host;
248         std::string httpsTemplate;
249         bool requireRootPermission;
250 
getDohIdentityDohProviderEntry251         base::Result<DohIdentity> getDohIdentity(const std::vector<std::string>& sortedValidIps,
252                                                  const std::string& host) const {
253             // If the private DNS hostname is known, `sortedValidIps` are the IP addresses
254             // resolved from the hostname, and hostname verification will be performed during
255             // TLS handshake to ensure the validity of the server, so it's not necessary to
256             // check the IP address.
257             if (!host.empty()) {
258                 if (this->host != host) return Errorf("host {} not matched", host);
259                 if (!sortedValidIps.empty()) {
260                     const auto& ip = sortedValidIps[0];
261                     LOG(INFO) << fmt::format("getDohIdentity: {} {}", ip, host);
262                     return DohIdentity{httpsTemplate, ip, host, Validation::in_process};
263                 }
264             }
265             for (const auto& ip : sortedValidIps) {
266                 if (ips.find(ip) == ips.end()) continue;
267                 LOG(INFO) << fmt::format("getDohIdentity: {} {}", ip, host);
268                 return DohIdentity{httpsTemplate, ip, host, Validation::in_process};
269             }
270             return Errorf("server not matched");
271         };
272     };
273 
274     // TODO: Move below DoH relevant stuff into Rust implementation.
275     std::map<unsigned, DohIdentity> mDohTracker GUARDED_BY(mPrivateDnsLock);
276     std::array<DohProviderEntry, 5> mAvailableDoHProviders = {{
277             {"Google",
278              {"2001:4860:4860::8888", "2001:4860:4860::8844", "8.8.8.8", "8.8.4.4"},
279              "dns.google",
280              "https://dns.google/dns-query",
281              false},
282             {"Google DNS64",
283              {"2001:4860:4860::64", "2001:4860:4860::6464"},
284              "dns64.dns.google",
285              "https://dns64.dns.google/dns-query",
286              false},
287             {"Cloudflare",
288              {"2606:4700::6810:f8f9", "2606:4700::6810:f9f9", "104.16.248.249", "104.16.249.249"},
289              "cloudflare-dns.com",
290              "https://cloudflare-dns.com/dns-query",
291              false},
292 
293             // The DoH providers for testing only.
294             // Using ResolverTestProvider requires that the DnsResolver is configured by someone
295             // who has root permission, which should be run by tests only.
296             {"ResolverTestProvider",
297              {"127.0.0.3", "::1"},
298              "example.com",
299              "https://example.com/dns-query",
300              true},
301             {"AndroidTesting",
302              {"192.0.2.100"},
303              "dns.androidtesting.org",
304              "https://dns.androidtesting.org/dns-query",
305              false},
306     }};
307 
308     // Makes a DohIdentity if
309     //   1. `dohParams` has some valid value, or
310     //   2. `servers` and `name` match up `mAvailableDoHProviders`.
311     base::Result<DohIdentity> makeDohIdentity(const std::vector<std::string>& servers,
312                                               const std::string& name,
313                                               const std::optional<DohParamsParcel> dohParams) const
314             REQUIRES(mPrivateDnsLock);
315 
316     // For the metrics. Store the current DNS server list in the same order as what is passed
317     // in setResolverConfiguration().
318     std::map<unsigned, std::vector<std::string>> mUnorderedDnsTracker GUARDED_BY(mPrivateDnsLock);
319     std::map<unsigned, std::vector<std::string>> mUnorderedDotTracker GUARDED_BY(mPrivateDnsLock);
320     std::map<unsigned, std::vector<std::string>> mUnorderedDohTracker GUARDED_BY(mPrivateDnsLock);
321 
322     struct RecordEntry {
RecordEntryRecordEntry323         RecordEntry(uint32_t netId, const ServerIdentity& identity, Validation state)
324             : netId(netId), serverIdentity(identity), state(state) {}
325 
326         const uint32_t netId;
327         const ServerIdentity serverIdentity;
328         const Validation state;
329         const std::chrono::system_clock::time_point timestamp = std::chrono::system_clock::now();
330     };
331 
332     LockedRingBuffer<RecordEntry> mPrivateDnsLog{100};
333 };
334 
335 }  // namespace net
336 }  // namespace android
337