xref: /aosp_15_r20/external/cronet/net/socket/ssl_client_socket.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/ssl_client_socket.h"
6 
7 #include <string>
8 
9 #include "base/containers/flat_tree.h"
10 #include "base/logging.h"
11 #include "base/observer_list.h"
12 #include "base/values.h"
13 #include "net/cert/x509_certificate_net_log_param.h"
14 #include "net/log/net_log.h"
15 #include "net/log/net_log_event_type.h"
16 #include "net/socket/ssl_client_socket_impl.h"
17 #include "net/socket/stream_socket.h"
18 #include "net/ssl/ssl_client_session_cache.h"
19 #include "net/ssl/ssl_key_logger.h"
20 
21 namespace net {
22 
23 namespace {
24 
25 // Returns true if |first_cert| and |second_cert| represent the same certificate
26 // (with the same chain), or if they're both NULL.
AreCertificatesEqual(const scoped_refptr<X509Certificate> & first_cert,const scoped_refptr<X509Certificate> & second_cert,bool include_chain=true)27 bool AreCertificatesEqual(const scoped_refptr<X509Certificate>& first_cert,
28                           const scoped_refptr<X509Certificate>& second_cert,
29                           bool include_chain = true) {
30   return (!first_cert && !second_cert) ||
31          (first_cert && second_cert &&
32           (include_chain
33                ? first_cert->EqualsIncludingChain(second_cert.get())
34                : first_cert->EqualsExcludingChain(second_cert.get())));
35 }
36 
37 // Returns a base::Value::Dict value NetLog parameter with the expected format
38 // for events of type CLEAR_CACHED_CLIENT_CERT.
NetLogClearCachedClientCertParams(const net::HostPortPair & host,const scoped_refptr<net::X509Certificate> & cert,bool is_cleared)39 base::Value::Dict NetLogClearCachedClientCertParams(
40     const net::HostPortPair& host,
41     const scoped_refptr<net::X509Certificate>& cert,
42     bool is_cleared) {
43   return base::Value::Dict()
44       .Set("host", host.ToString())
45       .Set("certificates", cert ? net::NetLogX509CertificateList(cert.get())
46                                 : base::Value(base::Value::List()))
47       .Set("is_cleared", is_cleared);
48 }
49 
50 // Returns a base::Value::Dict value NetLog parameter with the expected format
51 // for events of type CLEAR_MATCHING_CACHED_CLIENT_CERT.
NetLogClearMatchingCachedClientCertParams(const base::flat_set<net::HostPortPair> & hosts,const scoped_refptr<net::X509Certificate> & cert)52 base::Value::Dict NetLogClearMatchingCachedClientCertParams(
53     const base::flat_set<net::HostPortPair>& hosts,
54     const scoped_refptr<net::X509Certificate>& cert) {
55   base::Value::List hosts_values;
56   for (const auto& host : hosts) {
57     hosts_values.Append(host.ToString());
58   }
59 
60   return base::Value::Dict()
61       .Set("hosts", base::Value(std::move(hosts_values)))
62       .Set("certificates", cert ? net::NetLogX509CertificateList(cert.get())
63                                 : base::Value(base::Value::List()));
64 }
65 
66 }  // namespace
67 
// Out-of-line definition of the defaulted (compiler-generated) constructor.
68 SSLClientSocket::SSLClientSocket() = default;
69 
70 // static
SetSSLKeyLogger(std::unique_ptr<SSLKeyLogger> logger)71 void SSLClientSocket::SetSSLKeyLogger(std::unique_ptr<SSLKeyLogger> logger) {
72   SSLClientSocketImpl::SetSSLKeyLogger(std::move(logger));
73 }
74 
75 // static
SerializeNextProtos(const NextProtoVector & next_protos)76 std::vector<uint8_t> SSLClientSocket::SerializeNextProtos(
77     const NextProtoVector& next_protos) {
78   std::vector<uint8_t> wire_protos;
79   for (const NextProto next_proto : next_protos) {
80     const std::string proto = NextProtoToString(next_proto);
81     if (proto.size() > 255) {
82       LOG(WARNING) << "Ignoring overlong ALPN protocol: " << proto;
83       continue;
84     }
85     if (proto.size() == 0) {
86       LOG(WARNING) << "Ignoring empty ALPN protocol";
87       continue;
88     }
89     wire_protos.push_back(proto.size());
90     for (const char ch : proto) {
91       wire_protos.push_back(static_cast<uint8_t>(ch));
92     }
93   }
94 
95   return wire_protos;
96 }
97 
SSLClientContext(SSLConfigService * ssl_config_service,CertVerifier * cert_verifier,TransportSecurityState * transport_security_state,SSLClientSessionCache * ssl_client_session_cache,SCTAuditingDelegate * sct_auditing_delegate)98 SSLClientContext::SSLClientContext(
99     SSLConfigService* ssl_config_service,
100     CertVerifier* cert_verifier,
101     TransportSecurityState* transport_security_state,
102     SSLClientSessionCache* ssl_client_session_cache,
103     SCTAuditingDelegate* sct_auditing_delegate)
104     : ssl_config_service_(ssl_config_service),
105       cert_verifier_(cert_verifier),
106       transport_security_state_(transport_security_state),
107       ssl_client_session_cache_(ssl_client_session_cache),
108       sct_auditing_delegate_(sct_auditing_delegate) {
109   CHECK(cert_verifier_);
110   CHECK(transport_security_state_);
111 
112   if (ssl_config_service_) {
113     config_ = ssl_config_service_->GetSSLContextConfig();
114     ssl_config_service_->AddObserver(this);
115   }
116   cert_verifier_->AddObserver(this);
117   CertDatabase::GetInstance()->AddObserver(this);
118 }
119 
~SSLClientContext()120 SSLClientContext::~SSLClientContext() {
121   if (ssl_config_service_) {
122     ssl_config_service_->RemoveObserver(this);
123   }
124   cert_verifier_->RemoveObserver(this);
125   CertDatabase::GetInstance()->RemoveObserver(this);
126 }
127 
CreateSSLClientSocket(std::unique_ptr<StreamSocket> stream_socket,const HostPortPair & host_and_port,const SSLConfig & ssl_config)128 std::unique_ptr<SSLClientSocket> SSLClientContext::CreateSSLClientSocket(
129     std::unique_ptr<StreamSocket> stream_socket,
130     const HostPortPair& host_and_port,
131     const SSLConfig& ssl_config) {
132   return std::make_unique<SSLClientSocketImpl>(this, std::move(stream_socket),
133                                                host_and_port, ssl_config);
134 }
135 
GetClientCertificate(const HostPortPair & server,scoped_refptr<X509Certificate> * client_cert,scoped_refptr<SSLPrivateKey> * private_key)136 bool SSLClientContext::GetClientCertificate(
137     const HostPortPair& server,
138     scoped_refptr<X509Certificate>* client_cert,
139     scoped_refptr<SSLPrivateKey>* private_key) {
140   return ssl_client_auth_cache_.Lookup(server, client_cert, private_key);
141 }
142 
SetClientCertificate(const HostPortPair & server,scoped_refptr<X509Certificate> client_cert,scoped_refptr<SSLPrivateKey> private_key)143 void SSLClientContext::SetClientCertificate(
144     const HostPortPair& server,
145     scoped_refptr<X509Certificate> client_cert,
146     scoped_refptr<SSLPrivateKey> private_key) {
147   ssl_client_auth_cache_.Add(server, std::move(client_cert),
148                              std::move(private_key));
149 
150   if (ssl_client_session_cache_) {
151     // Session resumption bypasses client certificate negotiation, so flush all
152     // associated sessions when preferences change.
153     ssl_client_session_cache_->FlushForServers({server});
154   }
155   NotifySSLConfigForServersChanged({server});
156 }
157 
ClearClientCertificate(const HostPortPair & server)158 bool SSLClientContext::ClearClientCertificate(const HostPortPair& server) {
159   if (!ssl_client_auth_cache_.Remove(server)) {
160     return false;
161   }
162 
163   if (ssl_client_session_cache_) {
164     // Session resumption bypasses client certificate negotiation, so flush all
165     // associated sessions when preferences change.
166     ssl_client_session_cache_->FlushForServers({server});
167   }
168   NotifySSLConfigForServersChanged({server});
169   return true;
170 }
171 
AddObserver(Observer * observer)172 void SSLClientContext::AddObserver(Observer* observer) {
173   observers_.AddObserver(observer);
174 }
175 
RemoveObserver(Observer * observer)176 void SSLClientContext::RemoveObserver(Observer* observer) {
177   observers_.RemoveObserver(observer);
178 }
179 
OnSSLContextConfigChanged()180 void SSLClientContext::OnSSLContextConfigChanged() {
181   config_ = ssl_config_service_->GetSSLContextConfig();
182   if (ssl_client_session_cache_) {
183     ssl_client_session_cache_->Flush();
184   }
185   NotifySSLConfigChanged(SSLConfigChangeType::kSSLConfigChanged);
186 }
187 
OnCertVerifierChanged()188 void SSLClientContext::OnCertVerifierChanged() {
189   NotifySSLConfigChanged(SSLConfigChangeType::kCertVerifierChanged);
190 }
191 
OnTrustStoreChanged()192 void SSLClientContext::OnTrustStoreChanged() {
193   NotifySSLConfigChanged(SSLConfigChangeType::kCertDatabaseChanged);
194 }
195 
OnClientCertStoreChanged()196 void SSLClientContext::OnClientCertStoreChanged() {
197   base::flat_set<HostPortPair> servers =
198       ssl_client_auth_cache_.GetCachedServers();
199   ssl_client_auth_cache_.Clear();
200   if (ssl_client_session_cache_) {
201     ssl_client_session_cache_->FlushForServers(servers);
202   }
203   NotifySSLConfigForServersChanged(servers);
204 }
205 
ClearClientCertificateIfNeeded(const net::HostPortPair & host,const scoped_refptr<net::X509Certificate> & certificate)206 void SSLClientContext::ClearClientCertificateIfNeeded(
207     const net::HostPortPair& host,
208     const scoped_refptr<net::X509Certificate>& certificate) {
209   scoped_refptr<X509Certificate> cached_certificate;
210   scoped_refptr<SSLPrivateKey> cached_private_key;
211   if (!ssl_client_auth_cache_.Lookup(host, &cached_certificate,
212                                      &cached_private_key) ||
213       AreCertificatesEqual(cached_certificate, certificate)) {
214     // No cached client certificate preference for this host.
215     net::NetLog::Get()->AddGlobalEntry(
216         NetLogEventType::CLEAR_CACHED_CLIENT_CERT, [&]() {
217           return NetLogClearCachedClientCertParams(host, certificate,
218                                                    /*is_cleared=*/false);
219         });
220     return;
221   }
222 
223   net::NetLog::Get()->AddGlobalEntry(
224       NetLogEventType::CLEAR_CACHED_CLIENT_CERT, [&]() {
225         return NetLogClearCachedClientCertParams(host, certificate,
226                                                  /*is_cleared=*/true);
227       });
228 
229   ssl_client_auth_cache_.Remove(host);
230 
231   if (ssl_client_session_cache_) {
232     ssl_client_session_cache_->FlushForServers({host});
233   }
234 
235   NotifySSLConfigForServersChanged({host});
236 }
237 
ClearMatchingClientCertificate(const scoped_refptr<net::X509Certificate> & certificate)238 void SSLClientContext::ClearMatchingClientCertificate(
239     const scoped_refptr<net::X509Certificate>& certificate) {
240   CHECK(certificate);
241 
242   base::flat_set<HostPortPair> cleared_servers;
243   for (const auto& server : ssl_client_auth_cache_.GetCachedServers()) {
244     scoped_refptr<X509Certificate> cached_certificate;
245     scoped_refptr<SSLPrivateKey> cached_private_key;
246     if (ssl_client_auth_cache_.Lookup(server, &cached_certificate,
247                                       &cached_private_key) &&
248         AreCertificatesEqual(cached_certificate, certificate,
249                              /*include_chain=*/false)) {
250       cleared_servers.insert(cleared_servers.end(), server);
251     }
252   }
253 
254   net::NetLog::Get()->AddGlobalEntry(
255       NetLogEventType::CLEAR_MATCHING_CACHED_CLIENT_CERT, [&]() {
256         return NetLogClearMatchingCachedClientCertParams(cleared_servers,
257                                                          certificate);
258       });
259 
260   if (cleared_servers.empty()) {
261     return;
262   }
263 
264   for (const auto& server_to_clear : cleared_servers) {
265     ssl_client_auth_cache_.Remove(server_to_clear);
266   }
267 
268   if (ssl_client_session_cache_) {
269     ssl_client_session_cache_->FlushForServers(cleared_servers);
270   }
271 
272   NotifySSLConfigForServersChanged(cleared_servers);
273 }
274 
NotifySSLConfigChanged(SSLConfigChangeType change_type)275 void SSLClientContext::NotifySSLConfigChanged(SSLConfigChangeType change_type) {
276   for (Observer& observer : observers_) {
277     observer.OnSSLConfigChanged(change_type);
278   }
279 }
280 
NotifySSLConfigForServersChanged(const base::flat_set<HostPortPair> & servers)281 void SSLClientContext::NotifySSLConfigForServersChanged(
282     const base::flat_set<HostPortPair>& servers) {
283   for (Observer& observer : observers_) {
284     observer.OnSSLConfigForServersChanged(servers);
285   }
286 }
287 
288 }  // namespace net
289