xref: /aosp_15_r20/external/cronet/net/http/transport_security_persister.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/http/transport_security_persister.h"
6 
7 #include <algorithm>
8 #include <cstdint>
9 #include <memory>
10 #include <optional>
11 #include <utility>
12 #include <vector>
13 
14 #include "base/base64.h"
15 #include "base/feature_list.h"
16 #include "base/files/file_path.h"
17 #include "base/files/file_util.h"
18 #include "base/functional/bind.h"
19 #include "base/functional/callback.h"
20 #include "base/json/json_reader.h"
21 #include "base/json/json_writer.h"
22 #include "base/location.h"
23 #include "base/task/sequenced_task_runner.h"
24 #include "base/task/single_thread_task_runner.h"
25 #include "base/values.h"
26 #include "net/base/features.h"
27 #include "net/base/network_anonymization_key.h"
28 #include "net/cert/x509_certificate.h"
29 #include "net/http/transport_security_state.h"
30 
31 namespace net {
32 
33 namespace {
34 
// Passed as the histogram-suffix argument when constructing |writer_|.
constexpr char kHistogramSuffix[] = "TransportSecurityPersister";
36 
37 // This function converts the binary hashes to a base64 string which we can
38 // include in a JSON file.
HashedDomainToExternalString(const TransportSecurityState::HashedHost & hashed)39 std::string HashedDomainToExternalString(
40     const TransportSecurityState::HashedHost& hashed) {
41   return base::Base64Encode(hashed);
42 }
43 
44 // This inverts |HashedDomainToExternalString|, above. It turns an external
45 // string (from a JSON file) into an internal (binary) array.
ExternalStringToHashedDomain(const std::string & external)46 std::optional<TransportSecurityState::HashedHost> ExternalStringToHashedDomain(
47     const std::string& external) {
48   TransportSecurityState::HashedHost out;
49   std::optional<std::vector<uint8_t>> hashed = base::Base64Decode(external);
50   if (!hashed.has_value() || hashed.value().size() != out.size()) {
51     return std::nullopt;
52   }
53 
54   std::copy_n(hashed.value().begin(), out.size(), out.begin());
55   return out;
56 }
57 
// Version 2 of the on-disk format consists of a single JSON object. Its
// top-level dictionary holds an integer "version" entry and an "sts" entry.
// The "sts" entry is an unordered list of dictionaries, each representing the
// cached STS data for a single host.
//
// Files written by older code may also contain a legacy "expect_ct" entry.
// That entry is never written by the current code. When it is found on load,
// a rewrite is scheduled, and the rewrite drops it from disk.

// Stored in serialized dictionary values to distinguish incompatible versions.
// Version 1 is distinguished by the lack of an integer version value.
constexpr char kVersionKey[] = "version";
constexpr int kCurrentVersionValue = 2;

// Key in the top-level serialized dictionary for the list of STS entries.
constexpr char kSTSKey[] = "sts";
// Legacy key for the former list of Expect-CT entries. It is only looked for
// on read (to trigger a cleanup rewrite) and is never written.
constexpr char kExpectCTKey[] = "expect_ct";

// Hostname entry, used in serialized STS dictionaries. Value is produced by
// passing hashed hostname strings to HashedDomainToExternalString().
constexpr char kHostname[] = "host";

// Key values in serialized STS entries.
constexpr char kStsIncludeSubdomains[] = "sts_include_subdomains";
constexpr char kStsObserved[] = "sts_observed";
constexpr char kExpiry[] = "expiry";
constexpr char kMode[] = "mode";

// Values for "mode" used in serialized STS entries.
constexpr char kForceHTTPS[] = "force-https";
constexpr char kDefault[] = "default";
86 
LoadState(const base::FilePath & path)87 std::string LoadState(const base::FilePath& path) {
88   std::string result;
89   if (!base::ReadFileToString(path, &result)) {
90     return "";
91   }
92   return result;
93 }
94 
95 // Serializes STS data from |state| to a Value.
SerializeSTSData(const TransportSecurityState * state)96 base::Value::List SerializeSTSData(const TransportSecurityState* state) {
97   base::Value::List sts_list;
98 
99   TransportSecurityState::STSStateIterator sts_iterator(*state);
100   for (; sts_iterator.HasNext(); sts_iterator.Advance()) {
101     const TransportSecurityState::STSState& sts_state =
102         sts_iterator.domain_state();
103 
104     base::Value::Dict serialized;
105     serialized.Set(kHostname,
106                    HashedDomainToExternalString(sts_iterator.hostname()));
107     serialized.Set(kStsIncludeSubdomains, sts_state.include_subdomains);
108     serialized.Set(kStsObserved,
109                    sts_state.last_observed.InSecondsFSinceUnixEpoch());
110     serialized.Set(kExpiry, sts_state.expiry.InSecondsFSinceUnixEpoch());
111 
112     switch (sts_state.upgrade_mode) {
113       case TransportSecurityState::STSState::MODE_FORCE_HTTPS:
114         serialized.Set(kMode, kForceHTTPS);
115         break;
116       case TransportSecurityState::STSState::MODE_DEFAULT:
117         serialized.Set(kMode, kDefault);
118         break;
119     }
120 
121     sts_list.Append(std::move(serialized));
122   }
123   return sts_list;
124 }
125 
126 // Deserializes STS data from a Value created by the above method.
DeserializeSTSData(const base::Value & sts_list,TransportSecurityState * state)127 void DeserializeSTSData(const base::Value& sts_list,
128                         TransportSecurityState* state) {
129   if (!sts_list.is_list())
130     return;
131 
132   base::Time current_time(base::Time::Now());
133 
134   for (const base::Value& sts_entry : sts_list.GetList()) {
135     const base::Value::Dict* sts_dict = sts_entry.GetIfDict();
136     if (!sts_dict)
137       continue;
138 
139     const std::string* hostname = sts_dict->FindString(kHostname);
140     std::optional<bool> sts_include_subdomains =
141         sts_dict->FindBool(kStsIncludeSubdomains);
142     std::optional<double> sts_observed = sts_dict->FindDouble(kStsObserved);
143     std::optional<double> expiry = sts_dict->FindDouble(kExpiry);
144     const std::string* mode = sts_dict->FindString(kMode);
145 
146     if (!hostname || !sts_include_subdomains.has_value() ||
147         !sts_observed.has_value() || !expiry.has_value() || !mode) {
148       continue;
149     }
150 
151     TransportSecurityState::STSState sts_state;
152     sts_state.include_subdomains = *sts_include_subdomains;
153     sts_state.last_observed =
154         base::Time::FromSecondsSinceUnixEpoch(*sts_observed);
155     sts_state.expiry = base::Time::FromSecondsSinceUnixEpoch(*expiry);
156 
157     if (*mode == kForceHTTPS) {
158       sts_state.upgrade_mode =
159           TransportSecurityState::STSState::MODE_FORCE_HTTPS;
160     } else if (*mode == kDefault) {
161       sts_state.upgrade_mode = TransportSecurityState::STSState::MODE_DEFAULT;
162     } else {
163       continue;
164     }
165 
166     if (sts_state.expiry < current_time || !sts_state.ShouldUpgradeToSSL())
167       continue;
168 
169     std::optional<TransportSecurityState::HashedHost> hashed =
170         ExternalStringToHashedDomain(*hostname);
171     if (!hashed.has_value())
172       continue;
173 
174     state->AddOrUpdateEnabledSTSHosts(hashed.value(), sts_state);
175   }
176 }
177 
OnWriteFinishedTask(scoped_refptr<base::SequencedTaskRunner> task_runner,base::OnceClosure callback,bool result)178 void OnWriteFinishedTask(scoped_refptr<base::SequencedTaskRunner> task_runner,
179                          base::OnceClosure callback,
180                          bool result) {
181   task_runner->PostTask(FROM_HERE, std::move(callback));
182 }
183 
184 }  // namespace
185 
TransportSecurityPersister(TransportSecurityState * state,const scoped_refptr<base::SequencedTaskRunner> & background_runner,const base::FilePath & data_path)186 TransportSecurityPersister::TransportSecurityPersister(
187     TransportSecurityState* state,
188     const scoped_refptr<base::SequencedTaskRunner>& background_runner,
189     const base::FilePath& data_path)
190     : transport_security_state_(state),
191       writer_(data_path, background_runner, kHistogramSuffix),
192       foreground_runner_(base::SingleThreadTaskRunner::GetCurrentDefault()),
193       background_runner_(background_runner) {
194   transport_security_state_->SetDelegate(this);
195 
196   background_runner_->PostTaskAndReplyWithResult(
197       FROM_HERE, base::BindOnce(&LoadState, writer_.path()),
198       base::BindOnce(&TransportSecurityPersister::CompleteLoad,
199                      weak_ptr_factory_.GetWeakPtr()));
200 }
201 
~TransportSecurityPersister()202 TransportSecurityPersister::~TransportSecurityPersister() {
203   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
204 
205   if (writer_.HasPendingWrite())
206     writer_.DoScheduledWrite();
207 
208   transport_security_state_->SetDelegate(nullptr);
209 }
210 
StateIsDirty(TransportSecurityState * state)211 void TransportSecurityPersister::StateIsDirty(TransportSecurityState* state) {
212   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
213   DCHECK_EQ(transport_security_state_, state);
214 
215   writer_.ScheduleWrite(this);
216 }
217 
WriteNow(TransportSecurityState * state,base::OnceClosure callback)218 void TransportSecurityPersister::WriteNow(TransportSecurityState* state,
219                                           base::OnceClosure callback) {
220   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
221   DCHECK_EQ(transport_security_state_, state);
222 
223   writer_.RegisterOnNextWriteCallbacks(
224       base::OnceClosure(),
225       base::BindOnce(
226           &OnWriteFinishedTask, foreground_runner_,
227           base::BindOnce(&TransportSecurityPersister::OnWriteFinished,
228                          weak_ptr_factory_.GetWeakPtr(), std::move(callback))));
229   std::optional<std::string> data = SerializeData();
230   if (data) {
231     writer_.WriteNow(std::move(data).value());
232   } else {
233     writer_.WriteNow(std::string());
234   }
235 }
236 
OnWriteFinished(base::OnceClosure callback)237 void TransportSecurityPersister::OnWriteFinished(base::OnceClosure callback) {
238   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
239   std::move(callback).Run();
240 }
241 
SerializeData()242 std::optional<std::string> TransportSecurityPersister::SerializeData() {
243   CHECK(foreground_runner_->RunsTasksInCurrentSequence());
244 
245   base::Value::Dict toplevel;
246   toplevel.Set(kVersionKey, kCurrentVersionValue);
247   toplevel.Set(kSTSKey, SerializeSTSData(transport_security_state_));
248 
249   std::string output;
250   if (!base::JSONWriter::Write(toplevel, &output)) {
251     return std::nullopt;
252   }
253   return output;
254 }
255 
LoadEntries(const std::string & serialized)256 void TransportSecurityPersister::LoadEntries(const std::string& serialized) {
257   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
258 
259   transport_security_state_->ClearDynamicData();
260   bool contains_legacy_expect_ct_data = false;
261   Deserialize(serialized, transport_security_state_,
262               contains_legacy_expect_ct_data);
263   if (contains_legacy_expect_ct_data) {
264     StateIsDirty(transport_security_state_);
265   }
266 }
267 
Deserialize(const std::string & serialized,TransportSecurityState * state,bool & contains_legacy_expect_ct_data)268 void TransportSecurityPersister::Deserialize(
269     const std::string& serialized,
270     TransportSecurityState* state,
271     bool& contains_legacy_expect_ct_data) {
272   std::optional<base::Value> value = base::JSONReader::Read(serialized);
273   if (!value || !value->is_dict())
274     return;
275 
276   base::Value::Dict& dict = value->GetDict();
277   std::optional<int> version = dict.FindInt(kVersionKey);
278 
279   // Stop if the data is out of date (or in the previous format that didn't have
280   // a version number).
281   if (!version || *version != kCurrentVersionValue)
282     return;
283 
284   base::Value* sts_value = dict.Find(kSTSKey);
285   if (sts_value)
286     DeserializeSTSData(*sts_value, state);
287 
288   // If an Expect-CT key is found on deserialization, record this so that a
289   // write can be scheduled to clear it from disk.
290   contains_legacy_expect_ct_data = !!dict.Find(kExpectCTKey);
291 }
292 
CompleteLoad(const std::string & state)293 void TransportSecurityPersister::CompleteLoad(const std::string& state) {
294   DCHECK(foreground_runner_->RunsTasksInCurrentSequence());
295 
296   if (state.empty())
297     return;
298 
299   LoadEntries(state);
300 }
301 
302 }  // namespace net
303