xref: /aosp_15_r20/external/cronet/net/socket/ssl_server_socket_impl.cc (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2012 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "net/socket/ssl_server_socket_impl.h"
6 
7 #include <memory>
8 #include <optional>
9 #include <string_view>
10 #include <utility>
11 
12 #include "base/functional/bind.h"
13 #include "base/functional/callback_helpers.h"
14 #include "base/logging.h"
15 #include "base/memory/raw_ptr.h"
16 #include "base/memory/weak_ptr.h"
17 #include "base/strings/string_util.h"
18 #include "crypto/openssl_util.h"
19 #include "crypto/rsa_private_key.h"
20 #include "net/base/completion_once_callback.h"
21 #include "net/base/net_errors.h"
22 #include "net/cert/cert_verify_result.h"
23 #include "net/cert/client_cert_verifier.h"
24 #include "net/cert/x509_util.h"
25 #include "net/log/net_log_event_type.h"
26 #include "net/log/net_log_with_source.h"
27 #include "net/socket/socket_bio_adapter.h"
28 #include "net/ssl/openssl_ssl_util.h"
29 #include "net/ssl/ssl_connection_status_flags.h"
30 #include "net/ssl/ssl_info.h"
31 #include "net/ssl/ssl_private_key.h"
32 #include "net/traffic_annotation/network_traffic_annotation.h"
33 #include "third_party/boringssl/src/include/openssl/bytestring.h"
34 #include "third_party/boringssl/src/include/openssl/err.h"
35 #include "third_party/boringssl/src/include/openssl/pool.h"
36 #include "third_party/boringssl/src/include/openssl/ssl.h"
37 
38 #define GotoState(s) next_handshake_state_ = s
39 
40 namespace net {
41 
42 namespace {
43 
// This constant can be any positive value (i.e. one that does not overlap
// with any value in the net::Error range, including net::OK).
46 const int kSSLServerSocketNoPendingResult = 1;
47 
48 }  // namespace
49 
// Server-side SSL socket tied to an SSLServerContextImpl. It owns a BoringSSL
// |SSL| object and the transport StreamSocket it runs over, and it receives
// transport readiness notifications as a SocketBIOAdapter::Delegate.
class SSLServerContextImpl::SocketImpl : public SSLServerSocket,
                                         public SocketBIOAdapter::Delegate {
 public:
  // Takes ownership of |socket|, the transport stream. |context| is kept as a
  // non-owning pointer.
  SocketImpl(SSLServerContextImpl* context,
             std::unique_ptr<StreamSocket> socket);

  SocketImpl(const SocketImpl&) = delete;
  SocketImpl& operator=(const SocketImpl&) = delete;

  // Calls SSL_shutdown() on the SSL object if one was created.
  ~SocketImpl() override;

  // SSLServerSocket interface.
  // Initializes the SSL object and runs the server handshake. Returns OK, a
  // net error, or ERR_IO_PENDING (in which case |callback| is run later).
  int Handshake(CompletionOnceCallback callback) override;

  // SSLSocket interface.
  // Fails with ERR_SOCKET_NOT_CONNECTED until the handshake has completed.
  int ExportKeyingMaterial(std::string_view label,
                           bool has_context,
                           std::string_view context,
                           unsigned char* out,
                           unsigned int outlen) override;

  // Socket interface (via StreamSocket).
  // Read() retains |buf| while pending; ReadIfReady() does not.
  int Read(IOBuffer* buf,
           int buf_len,
           CompletionOnceCallback callback) override;
  int ReadIfReady(IOBuffer* buf,
                  int buf_len,
                  CompletionOnceCallback callback) override;
  int CancelReadIfReady() override;
  int Write(IOBuffer* buf,
            int buf_len,
            CompletionOnceCallback callback,
            const NetworkTrafficAnnotationTag& traffic_annotation) override;
  // Both buffer-size setters forward to the transport socket.
  int SetReceiveBufferSize(int32_t size) override;
  int SetSendBufferSize(int32_t size) override;

  // StreamSocket implementation.
  // Connect() and ApplySocketTag() are not implemented for server sockets.
  int Connect(CompletionOnceCallback callback) override;
  void Disconnect() override;
  bool IsConnected() const override;
  bool IsConnectedAndIdle() const override;
  int GetPeerAddress(IPEndPoint* address) const override;
  int GetLocalAddress(IPEndPoint* address) const override;
  const NetLogWithSource& NetLog() const override;
  bool WasEverUsed() const override;
  NextProto GetNegotiatedProtocol() const override;
  std::optional<std::string_view> GetPeerApplicationSettings() const override;
  bool GetSSLInfo(SSLInfo* ssl_info) override;
  int64_t GetTotalReceivedBytes() const override;
  void ApplySocketTag(const SocketTag& tag) override;

  // Maps an SSL object back to the SocketImpl that owns it. Defined outside
  // this excerpt; presumably uses the app data registered in Init().
  static SocketImpl* FromSSL(SSL* ssl);

  // Client-certificate verification hooks (defined outside this excerpt).
  static ssl_verify_result_t CertVerifyCallback(SSL* ssl, uint8_t* out_alert);
  ssl_verify_result_t CertVerifyCallbackImpl(uint8_t* out_alert);

  // Hook table handed to BoringSSL when signing goes through an
  // SSLPrivateKey. The static functions below have the signatures it expects;
  // sign and complete forward to the member overloads, decrypt always fails.
  static const SSL_PRIVATE_KEY_METHOD kPrivateKeyMethod;
  static ssl_private_key_result_t PrivateKeySignCallback(SSL* ssl,
                                                         uint8_t* out,
                                                         size_t* out_len,
                                                         size_t max_out,
                                                         uint16_t algorithm,
                                                         const uint8_t* in,
                                                         size_t in_len);
  static ssl_private_key_result_t PrivateKeyDecryptCallback(SSL* ssl,
                                                            uint8_t* out,
                                                            size_t* out_len,
                                                            size_t max_out,
                                                            const uint8_t* in,
                                                            size_t in_len);
  static ssl_private_key_result_t PrivateKeyCompleteCallback(SSL* ssl,
                                                             uint8_t* out,
                                                             size_t* out_len,
                                                             size_t max_out);

  // Starts a signature on the context's SSLPrivateKey and asks BoringSSL to
  // retry.
  ssl_private_key_result_t PrivateKeySignCallback(uint8_t* out,
                                                  size_t* out_len,
                                                  size_t max_out,
                                                  uint16_t algorithm,
                                                  const uint8_t* in,
                                                  size_t in_len);
  // Hands the finished signature (or the failure) back to BoringSSL.
  ssl_private_key_result_t PrivateKeyCompleteCallback(uint8_t* out,
                                                      size_t* out_len,
                                                      size_t max_out);
  // Completion callback for the signing request; stores the outcome and
  // re-enters the handshake.
  void OnPrivateKeyComplete(Error error, const std::vector<uint8_t>& signature);

  // Selects the first protocol from the server's configured ALPN list that the
  // client also offered.
  static int ALPNSelectCallback(SSL* ssl,
                                const uint8_t** out,
                                uint8_t* out_len,
                                const uint8_t* in,
                                unsigned in_len,
                                void* arg);

  // Runs the test-only ClientHello hook from the server config, if one is set.
  static ssl_select_cert_result_t SelectCertificateCallback(
      const SSL_CLIENT_HELLO* client_hello);

  // SocketBIOAdapter::Delegate implementation.
  // Each resumes the handshake if one is in progress, otherwise the pending
  // read or write.
  void OnReadReady() override;
  void OnWriteReady() override;

 private:
  // States of the handshake state machine driven by DoHandshakeLoop().
  enum State {
    STATE_NONE,
    STATE_HANDSHAKE,
  };

  // Re-enters the handshake loop and, once it finishes, logs the result and
  // runs the handshake callback if one is pending.
  void OnHandshakeIOComplete(int result);

  // Thin wrappers over SSL_read()/SSL_write() that map failures to net errors.
  [[nodiscard]] int DoPayloadRead(IOBuffer* buf, int buf_len);
  [[nodiscard]] int DoPayloadWrite();

  [[nodiscard]] int DoHandshakeLoop(int last_io_result);
  [[nodiscard]] int DoHandshake();
  // Each runs (and thereby consumes) the corresponding pending user callback.
  void DoHandshakeCallback(int result);
  void DoReadCallback(int result);
  void DoWriteCallback(int result);

  // Creates and configures |ssl_|; called at the start of Handshake().
  [[nodiscard]] int Init();
  void ExtractClientCert();

  // Non-owning pointer to the context that supplies the SSL_CTX, certificate,
  // key and server config.
  raw_ptr<SSLServerContextImpl, DanglingUntriaged> context_;

  NetLogWithSource net_log_;

  // Pending user callbacks; each is non-null only while the corresponding
  // operation has returned ERR_IO_PENDING.
  CompletionOnceCallback user_handshake_callback_;
  CompletionOnceCallback user_read_callback_;
  CompletionOnceCallback user_write_callback_;

  // SSLPrivateKey signature.
  // |signature_result_| starts as kSSLServerSocketNoPendingResult, becomes
  // ERR_IO_PENDING while a signature is outstanding, then the reported error.
  int signature_result_;
  std::vector<uint8_t> signature_;

  // Used by Read function.
  // Set only while a Read() (not ReadIfReady()) is pending.
  scoped_refptr<IOBuffer> user_read_buf_;
  int user_read_buf_len_ = 0;

  // Used by Write function.
  scoped_refptr<IOBuffer> user_write_buf_;
  int user_write_buf_len_ = 0;

  // OpenSSL stuff
  bssl::UniquePtr<SSL> ssl_;

  // Whether we received any data in early data.
  bool early_data_received_ = false;

  // StreamSocket for sending and receiving data.
  std::unique_ptr<StreamSocket> transport_socket_;
  std::unique_ptr<SocketBIOAdapter> transport_adapter_;

  // Certificate for the client.
  scoped_refptr<X509Certificate> client_cert_;

  // Next state for DoHandshakeLoop(); STATE_NONE when no handshake step is
  // queued.
  State next_handshake_state_ = STATE_NONE;
  // Set once SSL_do_handshake() has succeeded.
  bool completed_handshake_ = false;

  NextProto negotiated_protocol_ = kProtoUnknown;

  // Supplies the weak pointer bound into the asynchronous signing callback.
  base::WeakPtrFactory<SocketImpl> weak_factory_{this};
};
210 
SocketImpl(SSLServerContextImpl * context,std::unique_ptr<StreamSocket> transport_socket)211 SSLServerContextImpl::SocketImpl::SocketImpl(
212     SSLServerContextImpl* context,
213     std::unique_ptr<StreamSocket> transport_socket)
214     : context_(context),
215       signature_result_(kSSLServerSocketNoPendingResult),
216       transport_socket_(std::move(transport_socket)) {}
217 
~SocketImpl()218 SSLServerContextImpl::SocketImpl::~SocketImpl() {
219   if (ssl_) {
220     // Calling SSL_shutdown prevents the session from being marked as
221     // unresumable.
222     SSL_shutdown(ssl_.get());
223     ssl_.reset();
224   }
225 }
226 
227 // static
// Private-key hook table passed to SetSSLChainAndKey() in Init() when the
// context has an SSLPrivateKey rather than an in-memory key. The entries are,
// in order, the sign, decrypt and complete callbacks.
const SSL_PRIVATE_KEY_METHOD
    SSLServerContextImpl::SocketImpl::kPrivateKeyMethod = {
        &SSLServerContextImpl::SocketImpl::PrivateKeySignCallback,
        &SSLServerContextImpl::SocketImpl::PrivateKeyDecryptCallback,
        &SSLServerContextImpl::SocketImpl::PrivateKeyCompleteCallback,
};
234 
235 // static
236 ssl_private_key_result_t
PrivateKeySignCallback(SSL * ssl,uint8_t * out,size_t * out_len,size_t max_out,uint16_t algorithm,const uint8_t * in,size_t in_len)237 SSLServerContextImpl::SocketImpl::PrivateKeySignCallback(SSL* ssl,
238                                                          uint8_t* out,
239                                                          size_t* out_len,
240                                                          size_t max_out,
241                                                          uint16_t algorithm,
242                                                          const uint8_t* in,
243                                                          size_t in_len) {
244   return FromSSL(ssl)->PrivateKeySignCallback(out, out_len, max_out, algorithm,
245                                               in, in_len);
246 }
247 
248 // static
249 ssl_private_key_result_t
PrivateKeyDecryptCallback(SSL * ssl,uint8_t * out,size_t * out_len,size_t max_out,const uint8_t * in,size_t in_len)250 SSLServerContextImpl::SocketImpl::PrivateKeyDecryptCallback(SSL* ssl,
251                                                             uint8_t* out,
252                                                             size_t* out_len,
253                                                             size_t max_out,
254                                                             const uint8_t* in,
255                                                             size_t in_len) {
256   // Decrypt is not supported.
257   return ssl_private_key_failure;
258 }
259 
260 // static
261 ssl_private_key_result_t
PrivateKeyCompleteCallback(SSL * ssl,uint8_t * out,size_t * out_len,size_t max_out)262 SSLServerContextImpl::SocketImpl::PrivateKeyCompleteCallback(SSL* ssl,
263                                                              uint8_t* out,
264                                                              size_t* out_len,
265                                                              size_t max_out) {
266   return FromSSL(ssl)->PrivateKeyCompleteCallback(out, out_len, max_out);
267 }
268 
269 ssl_private_key_result_t
PrivateKeySignCallback(uint8_t * out,size_t * out_len,size_t max_out,uint16_t algorithm,const uint8_t * in,size_t in_len)270 SSLServerContextImpl::SocketImpl::PrivateKeySignCallback(uint8_t* out,
271                                                          size_t* out_len,
272                                                          size_t max_out,
273                                                          uint16_t algorithm,
274                                                          const uint8_t* in,
275                                                          size_t in_len) {
276   DCHECK(context_);
277   DCHECK(context_->private_key_);
278   signature_result_ = ERR_IO_PENDING;
279   context_->private_key_->Sign(
280       algorithm, base::make_span(in, in_len),
281       base::BindOnce(&SSLServerContextImpl::SocketImpl::OnPrivateKeyComplete,
282                      weak_factory_.GetWeakPtr()));
283   return ssl_private_key_retry;
284 }
285 
286 ssl_private_key_result_t
PrivateKeyCompleteCallback(uint8_t * out,size_t * out_len,size_t max_out)287 SSLServerContextImpl::SocketImpl::PrivateKeyCompleteCallback(uint8_t* out,
288                                                              size_t* out_len,
289                                                              size_t max_out) {
290   if (signature_result_ == ERR_IO_PENDING)
291     return ssl_private_key_retry;
292   if (signature_result_ != OK) {
293     OpenSSLPutNetError(FROM_HERE, signature_result_);
294     return ssl_private_key_failure;
295   }
296   if (signature_.size() > max_out) {
297     OpenSSLPutNetError(FROM_HERE, ERR_SSL_CLIENT_AUTH_SIGNATURE_FAILED);
298     return ssl_private_key_failure;
299   }
300   memcpy(out, signature_.data(), signature_.size());
301   *out_len = signature_.size();
302   signature_.clear();
303   return ssl_private_key_success;
304 }
305 
OnPrivateKeyComplete(Error error,const std::vector<uint8_t> & signature)306 void SSLServerContextImpl::SocketImpl::OnPrivateKeyComplete(
307     Error error,
308     const std::vector<uint8_t>& signature) {
309   DCHECK_EQ(ERR_IO_PENDING, signature_result_);
310   DCHECK(signature_.empty());
311 
312   signature_result_ = error;
313   if (signature_result_ == OK)
314     signature_ = signature;
315   OnHandshakeIOComplete(ERR_IO_PENDING);
316 }
317 
318 // static
ALPNSelectCallback(SSL * ssl,const uint8_t ** out,uint8_t * out_len,const uint8_t * in,unsigned in_len,void * arg)319 int SSLServerContextImpl::SocketImpl::ALPNSelectCallback(SSL* ssl,
320                                                          const uint8_t** out,
321                                                          uint8_t* out_len,
322                                                          const uint8_t* in,
323                                                          unsigned in_len,
324                                                          void* arg) {
325   SSLServerContextImpl::SocketImpl* socket = FromSSL(ssl);
326 
327   // Iterate over the server protocols in preference order.
328   for (NextProto server_proto :
329        socket->context_->ssl_server_config_.alpn_protos) {
330     const char* server_proto_str = NextProtoToString(server_proto);
331 
332     // See if the client advertised the corresponding protocol.
333     CBS cbs;
334     CBS_init(&cbs, in, in_len);
335     while (CBS_len(&cbs) != 0) {
336       CBS client_proto;
337       if (!CBS_get_u8_length_prefixed(&cbs, &client_proto)) {
338         return SSL_TLSEXT_ERR_NOACK;
339       }
340       if (std::string_view(
341               reinterpret_cast<const char*>(CBS_data(&client_proto)),
342               CBS_len(&client_proto)) == server_proto_str) {
343         *out = CBS_data(&client_proto);
344         *out_len = CBS_len(&client_proto);
345 
346         const auto& application_settings =
347             socket->context_->ssl_server_config_.application_settings;
348         auto it = application_settings.find(server_proto);
349         if (it != application_settings.end()) {
350           const std::vector<uint8_t>& data = it->second;
351           SSL_add_application_settings(ssl, CBS_data(&client_proto),
352                                        CBS_len(&client_proto), data.data(),
353                                        data.size());
354         }
355         return SSL_TLSEXT_ERR_OK;
356       }
357     }
358   }
359   return SSL_TLSEXT_ERR_NOACK;
360 }
361 
362 ssl_select_cert_result_t
SelectCertificateCallback(const SSL_CLIENT_HELLO * client_hello)363 SSLServerContextImpl::SocketImpl::SelectCertificateCallback(
364     const SSL_CLIENT_HELLO* client_hello) {
365   SSLServerContextImpl::SocketImpl* socket = FromSSL(client_hello->ssl);
366   const SSLServerConfig& config = socket->context_->ssl_server_config_;
367   if (!config.client_hello_callback_for_testing.is_null() &&
368       !config.client_hello_callback_for_testing.Run(client_hello)) {
369     return ssl_select_cert_error;
370   }
371   return ssl_select_cert_success;
372 }
373 
Handshake(CompletionOnceCallback callback)374 int SSLServerContextImpl::SocketImpl::Handshake(
375     CompletionOnceCallback callback) {
376   net_log_.BeginEvent(NetLogEventType::SSL_SERVER_HANDSHAKE);
377 
378   // Set up new ssl object.
379   int rv = Init();
380   if (rv != OK) {
381     LOG(ERROR) << "Failed to initialize OpenSSL: rv=" << rv;
382     net_log_.EndEventWithNetErrorCode(NetLogEventType::SSL_SERVER_HANDSHAKE,
383                                       rv);
384     return rv;
385   }
386 
387   // Set SSL to server mode. Handshake happens in the loop below.
388   SSL_set_accept_state(ssl_.get());
389 
390   GotoState(STATE_HANDSHAKE);
391   rv = DoHandshakeLoop(OK);
392   if (rv == ERR_IO_PENDING) {
393     user_handshake_callback_ = std::move(callback);
394   } else {
395     net_log_.EndEventWithNetErrorCode(NetLogEventType::SSL_SERVER_HANDSHAKE,
396                                       rv);
397   }
398 
399   return rv > OK ? OK : rv;
400 }
401 
ExportKeyingMaterial(std::string_view label,bool has_context,std::string_view context,unsigned char * out,unsigned int outlen)402 int SSLServerContextImpl::SocketImpl::ExportKeyingMaterial(
403     std::string_view label,
404     bool has_context,
405     std::string_view context,
406     unsigned char* out,
407     unsigned int outlen) {
408   if (!IsConnected())
409     return ERR_SOCKET_NOT_CONNECTED;
410 
411   crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
412 
413   int rv = SSL_export_keying_material(
414       ssl_.get(), out, outlen, label.data(), label.size(),
415       reinterpret_cast<const unsigned char*>(context.data()), context.length(),
416       context.length() > 0);
417 
418   if (rv != 1) {
419     int ssl_error = SSL_get_error(ssl_.get(), rv);
420     LOG(ERROR) << "Failed to export keying material;"
421                << " returned " << rv << ", SSL error code " << ssl_error;
422     return MapOpenSSLError(ssl_error, err_tracer);
423   }
424   return OK;
425 }
426 
Read(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)427 int SSLServerContextImpl::SocketImpl::Read(IOBuffer* buf,
428                                            int buf_len,
429                                            CompletionOnceCallback callback) {
430   int rv = ReadIfReady(buf, buf_len, std::move(callback));
431   if (rv == ERR_IO_PENDING) {
432     user_read_buf_ = buf;
433     user_read_buf_len_ = buf_len;
434   }
435   return rv;
436 }
437 
ReadIfReady(IOBuffer * buf,int buf_len,CompletionOnceCallback callback)438 int SSLServerContextImpl::SocketImpl::ReadIfReady(
439     IOBuffer* buf,
440     int buf_len,
441     CompletionOnceCallback callback) {
442   DCHECK(user_read_callback_.is_null());
443   DCHECK(user_handshake_callback_.is_null());
444   DCHECK(!user_read_buf_);
445   DCHECK(!callback.is_null());
446   DCHECK(completed_handshake_);
447 
448   int rv = DoPayloadRead(buf, buf_len);
449 
450   if (rv == ERR_IO_PENDING) {
451     user_read_callback_ = std::move(callback);
452   }
453 
454   return rv;
455 }
456 
// Cancels a pending ReadIfReady() by dropping its stored callback. Must only
// be called while a ReadIfReady() (not a Read()) is pending. Always returns
// OK.
int SSLServerContextImpl::SocketImpl::CancelReadIfReady() {
  DCHECK(user_read_callback_);
  DCHECK(!user_read_buf_);

  // Cancel |user_read_callback_|, because caller does not expect the callback
  // to be invoked after they have canceled the ReadIfReady.
  //
  // We do not pass the signal on to |transport_socket_| or
  // |transport_adapter_|. When the transport read completes, it will signal
  // OnReadReady(), which will notice there is no read operation to progress
  // and skip it. Unlike with SSLClientSocket, SSL and transport reads are more
  // aligned, but this avoids making assumptions or breaking the
  // SocketBIOAdapter's state.
  user_read_callback_.Reset();
  return OK;
}
472 
Write(IOBuffer * buf,int buf_len,CompletionOnceCallback callback,const NetworkTrafficAnnotationTag & traffic_annotation)473 int SSLServerContextImpl::SocketImpl::Write(
474     IOBuffer* buf,
475     int buf_len,
476     CompletionOnceCallback callback,
477     const NetworkTrafficAnnotationTag& traffic_annotation) {
478   DCHECK(user_write_callback_.is_null());
479   DCHECK(!user_write_buf_);
480   DCHECK(!callback.is_null());
481 
482   user_write_buf_ = buf;
483   user_write_buf_len_ = buf_len;
484 
485   int rv = DoPayloadWrite();
486 
487   if (rv == ERR_IO_PENDING) {
488     user_write_callback_ = std::move(callback);
489   } else {
490     user_write_buf_ = nullptr;
491     user_write_buf_len_ = 0;
492   }
493   return rv;
494 }
495 
SetReceiveBufferSize(int32_t size)496 int SSLServerContextImpl::SocketImpl::SetReceiveBufferSize(int32_t size) {
497   return transport_socket_->SetReceiveBufferSize(size);
498 }
499 
SetSendBufferSize(int32_t size)500 int SSLServerContextImpl::SocketImpl::SetSendBufferSize(int32_t size) {
501   return transport_socket_->SetSendBufferSize(size);
502 }
503 
Connect(CompletionOnceCallback callback)504 int SSLServerContextImpl::SocketImpl::Connect(CompletionOnceCallback callback) {
505   NOTIMPLEMENTED();
506   return ERR_NOT_IMPLEMENTED;
507 }
508 
Disconnect()509 void SSLServerContextImpl::SocketImpl::Disconnect() {
510   transport_socket_->Disconnect();
511 }
512 
// Reports whether the TLS handshake has completed. The state of the transport
// socket is not consulted (see the TODO below).
bool SSLServerContextImpl::SocketImpl::IsConnected() const {
  // TODO(wtc): Find out if we should check transport_socket_->IsConnected()
  // as well.
  return completed_handshake_;
}
518 
IsConnectedAndIdle() const519 bool SSLServerContextImpl::SocketImpl::IsConnectedAndIdle() const {
520   return completed_handshake_ && transport_socket_->IsConnectedAndIdle();
521 }
522 
GetPeerAddress(IPEndPoint * address) const523 int SSLServerContextImpl::SocketImpl::GetPeerAddress(
524     IPEndPoint* address) const {
525   if (!IsConnected())
526     return ERR_SOCKET_NOT_CONNECTED;
527   return transport_socket_->GetPeerAddress(address);
528 }
529 
GetLocalAddress(IPEndPoint * address) const530 int SSLServerContextImpl::SocketImpl::GetLocalAddress(
531     IPEndPoint* address) const {
532   if (!IsConnected())
533     return ERR_SOCKET_NOT_CONNECTED;
534   return transport_socket_->GetLocalAddress(address);
535 }
536 
// Returns this socket's own NetLogWithSource, the one the handshake and
// read/write error events in this file are logged to.
const NetLogWithSource& SSLServerContextImpl::SocketImpl::NetLog() const {
  return net_log_;
}
540 
WasEverUsed() const541 bool SSLServerContextImpl::SocketImpl::WasEverUsed() const {
542   return transport_socket_->WasEverUsed();
543 }
544 
// Returns the protocol recorded by DoHandshake() from the ALPN result, or
// kProtoUnknown if none was recorded.
NextProto SSLServerContextImpl::SocketImpl::GetNegotiatedProtocol() const {
  return negotiated_protocol_;
}
548 
549 std::optional<std::string_view>
GetPeerApplicationSettings() const550 SSLServerContextImpl::SocketImpl::GetPeerApplicationSettings() const {
551   if (!SSL_has_application_settings(ssl_.get())) {
552     return std::nullopt;
553   }
554 
555   const uint8_t* out_data;
556   size_t out_len;
557   SSL_get0_peer_application_settings(ssl_.get(), &out_data, &out_len);
558   return std::string_view{reinterpret_cast<const char*>(out_data), out_len};
559 }
560 
GetSSLInfo(SSLInfo * ssl_info)561 bool SSLServerContextImpl::SocketImpl::GetSSLInfo(SSLInfo* ssl_info) {
562   ssl_info->Reset();
563   if (!completed_handshake_)
564     return false;
565 
566   ssl_info->cert = client_cert_;
567 
568   const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl_.get());
569   CHECK(cipher);
570 
571   SSLConnectionStatusSetCipherSuite(SSL_CIPHER_get_protocol_id(cipher),
572                                     &ssl_info->connection_status);
573   SSLConnectionStatusSetVersion(GetNetSSLVersion(ssl_.get()),
574                                 &ssl_info->connection_status);
575 
576   ssl_info->early_data_received = early_data_received_;
577   ssl_info->encrypted_client_hello = SSL_ech_accepted(ssl_.get());
578   ssl_info->handshake_type = SSL_session_reused(ssl_.get())
579                                  ? SSLInfo::HANDSHAKE_RESUME
580                                  : SSLInfo::HANDSHAKE_FULL;
581 
582   return true;
583 }
584 
GetTotalReceivedBytes() const585 int64_t SSLServerContextImpl::SocketImpl::GetTotalReceivedBytes() const {
586   return transport_socket_->GetTotalReceivedBytes();
587 }
588 
ApplySocketTag(const SocketTag & tag)589 void SSLServerContextImpl::SocketImpl::ApplySocketTag(const SocketTag& tag) {
590   NOTIMPLEMENTED();
591 }
592 
OnReadReady()593 void SSLServerContextImpl::SocketImpl::OnReadReady() {
594   if (next_handshake_state_ == STATE_HANDSHAKE) {
595     // In handshake phase. The parameter to OnHandshakeIOComplete is unused.
596     OnHandshakeIOComplete(OK);
597     return;
598   }
599 
600   // BoringSSL does not support renegotiation as a server, so the only other
601   // operation blocked on Read is DoPayloadRead.
602   if (!user_read_buf_) {
603     if (!user_read_callback_.is_null()) {
604       DoReadCallback(OK);
605     }
606     return;
607   }
608 
609   int rv = DoPayloadRead(user_read_buf_.get(), user_read_buf_len_);
610   if (rv != ERR_IO_PENDING)
611     DoReadCallback(rv);
612 }
613 
OnWriteReady()614 void SSLServerContextImpl::SocketImpl::OnWriteReady() {
615   if (next_handshake_state_ == STATE_HANDSHAKE) {
616     // In handshake phase. The parameter to OnHandshakeIOComplete is unused.
617     OnHandshakeIOComplete(OK);
618     return;
619   }
620 
621   // BoringSSL does not support renegotiation as a server, so the only other
622   // operation blocked on Read is DoPayloadWrite.
623   if (!user_write_buf_)
624     return;
625 
626   int rv = DoPayloadWrite();
627   if (rv != ERR_IO_PENDING)
628     DoWriteCallback(rv);
629 }
630 
OnHandshakeIOComplete(int result)631 void SSLServerContextImpl::SocketImpl::OnHandshakeIOComplete(int result) {
632   int rv = DoHandshakeLoop(result);
633   if (rv == ERR_IO_PENDING)
634     return;
635 
636   net_log_.EndEventWithNetErrorCode(NetLogEventType::SSL_SERVER_HANDSHAKE, rv);
637   if (!user_handshake_callback_.is_null())
638     DoHandshakeCallback(rv);
639 }
640 
DoPayloadRead(IOBuffer * buf,int buf_len)641 int SSLServerContextImpl::SocketImpl::DoPayloadRead(IOBuffer* buf,
642                                                     int buf_len) {
643   DCHECK(completed_handshake_);
644   DCHECK_EQ(STATE_NONE, next_handshake_state_);
645   DCHECK(buf);
646   DCHECK_GT(buf_len, 0);
647 
648   crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
649   int rv = SSL_read(ssl_.get(), buf->data(), buf_len);
650   if (rv >= 0) {
651     if (SSL_in_early_data(ssl_.get()))
652       early_data_received_ = true;
653     return rv;
654   }
655   int ssl_error = SSL_get_error(ssl_.get(), rv);
656   OpenSSLErrorInfo error_info;
657   int net_error =
658       MapOpenSSLErrorWithDetails(ssl_error, err_tracer, &error_info);
659   if (net_error != ERR_IO_PENDING) {
660     NetLogOpenSSLError(net_log_, NetLogEventType::SSL_READ_ERROR, net_error,
661                        ssl_error, error_info);
662   }
663   return net_error;
664 }
665 
DoPayloadWrite()666 int SSLServerContextImpl::SocketImpl::DoPayloadWrite() {
667   DCHECK(completed_handshake_);
668   DCHECK_EQ(STATE_NONE, next_handshake_state_);
669   DCHECK(user_write_buf_);
670 
671   crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
672   int rv = SSL_write(ssl_.get(), user_write_buf_->data(), user_write_buf_len_);
673   if (rv >= 0)
674     return rv;
675   int ssl_error = SSL_get_error(ssl_.get(), rv);
676   OpenSSLErrorInfo error_info;
677   int net_error =
678       MapOpenSSLErrorWithDetails(ssl_error, err_tracer, &error_info);
679   if (net_error != ERR_IO_PENDING) {
680     NetLogOpenSSLError(net_log_, NetLogEventType::SSL_WRITE_ERROR, net_error,
681                        ssl_error, error_info);
682   }
683   return net_error;
684 }
685 
DoHandshakeLoop(int last_io_result)686 int SSLServerContextImpl::SocketImpl::DoHandshakeLoop(int last_io_result) {
687   int rv = last_io_result;
688   do {
689     // Default to STATE_NONE for next state.
690     // (This is a quirk carried over from the windows
691     // implementation.  It makes reading the logs a bit harder.)
692     // State handlers can and often do call GotoState just
693     // to stay in the current state.
694     State state = next_handshake_state_;
695     GotoState(STATE_NONE);
696     switch (state) {
697       case STATE_HANDSHAKE:
698         rv = DoHandshake();
699         break;
700       case STATE_NONE:
701       default:
702         rv = ERR_UNEXPECTED;
703         LOG(DFATAL) << "unexpected state " << state;
704         break;
705     }
706   } while (rv != ERR_IO_PENDING && next_handshake_state_ != STATE_NONE);
707   return rv;
708 }
709 
DoHandshake()710 int SSLServerContextImpl::SocketImpl::DoHandshake() {
711   crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
712   int net_error = OK;
713   int rv = SSL_do_handshake(ssl_.get());
714   if (rv == 1) {
715     const STACK_OF(CRYPTO_BUFFER)* certs =
716         SSL_get0_peer_certificates(ssl_.get());
717     if (certs) {
718       client_cert_ = x509_util::CreateX509CertificateFromBuffers(certs);
719       if (!client_cert_)
720         return ERR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT;
721     }
722 
723     const uint8_t* alpn_proto = nullptr;
724     unsigned alpn_len = 0;
725     SSL_get0_alpn_selected(ssl_.get(), &alpn_proto, &alpn_len);
726     if (alpn_len > 0) {
727       std::string_view proto(reinterpret_cast<const char*>(alpn_proto),
728                              alpn_len);
729       negotiated_protocol_ = NextProtoFromString(proto);
730     }
731 
732     if (context_->ssl_server_config_.alert_after_handshake_for_testing) {
733       SSL_send_fatal_alert(ssl_.get(),
734                            context_->ssl_server_config_
735                                .alert_after_handshake_for_testing.value());
736       return ERR_FAILED;
737     }
738 
739     completed_handshake_ = true;
740   } else {
741     int ssl_error = SSL_get_error(ssl_.get(), rv);
742 
743     if (ssl_error == SSL_ERROR_WANT_PRIVATE_KEY_OPERATION) {
744       DCHECK(context_->private_key_);
745       GotoState(STATE_HANDSHAKE);
746       return ERR_IO_PENDING;
747     }
748 
749     OpenSSLErrorInfo error_info;
750     net_error = MapOpenSSLErrorWithDetails(ssl_error, err_tracer, &error_info);
751 
752     // SSL_R_CERTIFICATE_VERIFY_FAILED's mapping is different between client and
753     // server.
754     if (ERR_GET_LIB(error_info.error_code) == ERR_LIB_SSL &&
755         ERR_GET_REASON(error_info.error_code) ==
756             SSL_R_CERTIFICATE_VERIFY_FAILED) {
757       net_error = ERR_BAD_SSL_CLIENT_AUTH_CERT;
758     }
759 
760     // If not done, stay in this state
761     if (net_error == ERR_IO_PENDING) {
762       GotoState(STATE_HANDSHAKE);
763     } else {
764       LOG(ERROR) << "handshake failed; returned " << rv << ", SSL error code "
765                  << ssl_error << ", net_error " << net_error;
766       NetLogOpenSSLError(net_log_, NetLogEventType::SSL_HANDSHAKE_ERROR,
767                          net_error, ssl_error, error_info);
768     }
769   }
770   return net_error;
771 }
772 
DoHandshakeCallback(int rv)773 void SSLServerContextImpl::SocketImpl::DoHandshakeCallback(int rv) {
774   DCHECK_NE(rv, ERR_IO_PENDING);
775   std::move(user_handshake_callback_).Run(rv > OK ? OK : rv);
776 }
777 
DoReadCallback(int rv)778 void SSLServerContextImpl::SocketImpl::DoReadCallback(int rv) {
779   DCHECK(rv != ERR_IO_PENDING);
780   DCHECK(!user_read_callback_.is_null());
781 
782   user_read_buf_ = nullptr;
783   user_read_buf_len_ = 0;
784   std::move(user_read_callback_).Run(rv);
785 }
786 
DoWriteCallback(int rv)787 void SSLServerContextImpl::SocketImpl::DoWriteCallback(int rv) {
788   DCHECK(rv != ERR_IO_PENDING);
789   DCHECK(!user_write_callback_.is_null());
790 
791   user_write_buf_ = nullptr;
792   user_write_buf_len_ = 0;
793   std::move(user_write_callback_).Run(rv);
794 }
795 
Init()796 int SSLServerContextImpl::SocketImpl::Init() {
797   static const int kBufferSize = 17 * 1024;
798 
799   crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
800 
801   ssl_.reset(SSL_new(context_->ssl_ctx_.get()));
802   if (!ssl_ || !SSL_set_app_data(ssl_.get(), this)) {
803     return ERR_UNEXPECTED;
804   }
805 
806   SSL_set_shed_handshake_config(ssl_.get(), 1);
807 
808   // Set certificate and private key.
809   if (context_->pkey_) {
810     DCHECK(context_->cert_->cert_buffer());
811     if (!SetSSLChainAndKey(ssl_.get(), context_->cert_.get(),
812                            context_->pkey_.get(), nullptr)) {
813       return ERR_UNEXPECTED;
814     }
815   } else {
816     DCHECK(context_->private_key_);
817     if (!SetSSLChainAndKey(ssl_.get(), context_->cert_.get(), nullptr,
818                            &kPrivateKeyMethod)) {
819       return ERR_UNEXPECTED;
820     }
821     std::vector<uint16_t> preferences =
822         context_->private_key_->GetAlgorithmPreferences();
823     SSL_set_signing_algorithm_prefs(ssl_.get(), preferences.data(),
824                                     preferences.size());
825   }
826 
827   if (context_->ssl_server_config_.signature_algorithm_for_testing
828           .has_value()) {
829     uint16_t id = *context_->ssl_server_config_.signature_algorithm_for_testing;
830     CHECK(SSL_set_signing_algorithm_prefs(ssl_.get(), &id, 1));
831   }
832 
833   const std::vector<int>& curves =
834       context_->ssl_server_config_.curves_for_testing;
835   if (!curves.empty()) {
836     CHECK(SSL_set1_curves(ssl_.get(), curves.data(), curves.size()));
837   }
838 
839   transport_adapter_ = std::make_unique<SocketBIOAdapter>(
840       transport_socket_.get(), kBufferSize, kBufferSize, this);
841   BIO* transport_bio = transport_adapter_->bio();
842 
843   BIO_up_ref(transport_bio);  // SSL_set0_rbio takes ownership.
844   SSL_set0_rbio(ssl_.get(), transport_bio);
845 
846   BIO_up_ref(transport_bio);  // SSL_set0_wbio takes ownership.
847   SSL_set0_wbio(ssl_.get(), transport_bio);
848 
849   return OK;
850 }
851 
FromSSL(SSL * ssl)852 SSLServerContextImpl::SocketImpl* SSLServerContextImpl::SocketImpl::FromSSL(
853     SSL* ssl) {
854   SocketImpl* socket = reinterpret_cast<SocketImpl*>(SSL_get_app_data(ssl));
855   DCHECK(socket);
856   return socket;
857 }
858 
859 // static
CertVerifyCallback(SSL * ssl,uint8_t * out_alert)860 ssl_verify_result_t SSLServerContextImpl::SocketImpl::CertVerifyCallback(
861     SSL* ssl,
862     uint8_t* out_alert) {
863   return FromSSL(ssl)->CertVerifyCallbackImpl(out_alert);
864 }
865 
CertVerifyCallbackImpl(uint8_t * out_alert)866 ssl_verify_result_t SSLServerContextImpl::SocketImpl::CertVerifyCallbackImpl(
867     uint8_t* out_alert) {
868   ClientCertVerifier* verifier =
869       context_->ssl_server_config_.client_cert_verifier;
870   // If a verifier was not supplied, all certificates are accepted.
871   if (!verifier)
872     return ssl_verify_ok;
873 
874   scoped_refptr<X509Certificate> client_cert =
875       x509_util::CreateX509CertificateFromBuffers(
876           SSL_get0_peer_certificates(ssl_.get()));
877   if (!client_cert) {
878     *out_alert = SSL_AD_BAD_CERTIFICATE;
879     return ssl_verify_invalid;
880   }
881 
882   // TODO(davidben): Support asynchronous verifiers. http://crbug.com/347402
883   std::unique_ptr<ClientCertVerifier::Request> ignore_async;
884   int res = verifier->Verify(client_cert.get(), CompletionOnceCallback(),
885                              &ignore_async);
886   DCHECK_NE(res, ERR_IO_PENDING);
887 
888   if (res != OK) {
889     // TODO(davidben): Map from certificate verification failure to alert.
890     *out_alert = SSL_AD_CERTIFICATE_UNKNOWN;
891     return ssl_verify_invalid;
892   }
893   return ssl_verify_ok;
894 }
895 
CreateSSLServerContext(X509Certificate * certificate,EVP_PKEY * pkey,const SSLServerConfig & ssl_server_config)896 std::unique_ptr<SSLServerContext> CreateSSLServerContext(
897     X509Certificate* certificate,
898     EVP_PKEY* pkey,
899     const SSLServerConfig& ssl_server_config) {
900   return std::make_unique<SSLServerContextImpl>(certificate, pkey,
901                                                 ssl_server_config);
902 }
903 
CreateSSLServerContext(X509Certificate * certificate,const crypto::RSAPrivateKey & key,const SSLServerConfig & ssl_server_config)904 std::unique_ptr<SSLServerContext> CreateSSLServerContext(
905     X509Certificate* certificate,
906     const crypto::RSAPrivateKey& key,
907     const SSLServerConfig& ssl_server_config) {
908   return std::make_unique<SSLServerContextImpl>(certificate, key.key(),
909                                                 ssl_server_config);
910 }
911 
CreateSSLServerContext(X509Certificate * certificate,scoped_refptr<SSLPrivateKey> key,const SSLServerConfig & ssl_config)912 std::unique_ptr<SSLServerContext> CreateSSLServerContext(
913     X509Certificate* certificate,
914     scoped_refptr<SSLPrivateKey> key,
915     const SSLServerConfig& ssl_config) {
916   return std::make_unique<SSLServerContextImpl>(certificate, key, ssl_config);
917 }
918 
SSLServerContextImpl(X509Certificate * certificate,scoped_refptr<net::SSLPrivateKey> key,const SSLServerConfig & ssl_server_config)919 SSLServerContextImpl::SSLServerContextImpl(
920     X509Certificate* certificate,
921     scoped_refptr<net::SSLPrivateKey> key,
922     const SSLServerConfig& ssl_server_config)
923     : ssl_server_config_(ssl_server_config),
924       cert_(certificate),
925       private_key_(key) {
926   CHECK(private_key_);
927   Init();
928 }
929 
SSLServerContextImpl(X509Certificate * certificate,EVP_PKEY * pkey,const SSLServerConfig & ssl_server_config)930 SSLServerContextImpl::SSLServerContextImpl(
931     X509Certificate* certificate,
932     EVP_PKEY* pkey,
933     const SSLServerConfig& ssl_server_config)
934     : ssl_server_config_(ssl_server_config), cert_(certificate) {
935   CHECK(pkey);
936   pkey_ = bssl::UpRef(pkey);
937   Init();
938 }
939 
Init()940 void SSLServerContextImpl::Init() {
941   crypto::EnsureOpenSSLInit();
942   ssl_ctx_.reset(SSL_CTX_new(TLS_with_buffers_method()));
943   SSL_CTX_set_session_cache_mode(ssl_ctx_.get(), SSL_SESS_CACHE_SERVER);
944   uint8_t session_ctx_id = 0;
945   SSL_CTX_set_session_id_context(ssl_ctx_.get(), &session_ctx_id,
946                                  sizeof(session_ctx_id));
947   // Deduplicate all certificates minted from the SSL_CTX in memory.
948   SSL_CTX_set0_buffer_pool(ssl_ctx_.get(), x509_util::GetBufferPool());
949 
950   int verify_mode = 0;
951   switch (ssl_server_config_.client_cert_type) {
952     case SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT:
953       verify_mode |= SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
954       [[fallthrough]];
955     case SSLServerConfig::ClientCertType::OPTIONAL_CLIENT_CERT:
956       verify_mode |= SSL_VERIFY_PEER;
957       SSL_CTX_set_custom_verify(ssl_ctx_.get(), verify_mode,
958                                 SocketImpl::CertVerifyCallback);
959       break;
960     case SSLServerConfig::ClientCertType::NO_CLIENT_CERT:
961       break;
962   }
963 
964   SSL_CTX_set_early_data_enabled(ssl_ctx_.get(),
965                                  ssl_server_config_.early_data_enabled);
966   // TLS versions before TLS 1.2 are no longer supported.
967   CHECK_LE(TLS1_2_VERSION, ssl_server_config_.version_min);
968   CHECK_LE(TLS1_2_VERSION, ssl_server_config_.version_max);
969   CHECK(SSL_CTX_set_min_proto_version(ssl_ctx_.get(),
970                                       ssl_server_config_.version_min));
971   CHECK(SSL_CTX_set_max_proto_version(ssl_ctx_.get(),
972                                       ssl_server_config_.version_max));
973 
974   // OpenSSL defaults some options to on, others to off. To avoid ambiguity,
975   // set everything we care about to an absolute value.
976   SslSetClearMask options;
977   options.ConfigureFlag(SSL_OP_NO_COMPRESSION, true);
978 
979   SSL_CTX_set_options(ssl_ctx_.get(), options.set_mask);
980   SSL_CTX_clear_options(ssl_ctx_.get(), options.clear_mask);
981 
982   // Same as above, this time for the SSL mode.
983   SslSetClearMask mode;
984 
985   mode.ConfigureFlag(SSL_MODE_RELEASE_BUFFERS, true);
986 
987   SSL_CTX_set_mode(ssl_ctx_.get(), mode.set_mask);
988   SSL_CTX_clear_mode(ssl_ctx_.get(), mode.clear_mask);
989 
990   if (ssl_server_config_.cipher_suite_for_testing.has_value()) {
991     const SSL_CIPHER* cipher =
992         SSL_get_cipher_by_value(*ssl_server_config_.cipher_suite_for_testing);
993     CHECK(cipher);
994     CHECK(SSL_CTX_set_strict_cipher_list(ssl_ctx_.get(),
995                                          SSL_CIPHER_get_name(cipher)));
996   } else {
997     // Use BoringSSL defaults, but disable 3DES and HMAC-SHA1 ciphers in ECDSA.
998     // These are the remaining CBC-mode ECDSA ciphers.
999     std::string command("ALL:!aPSK:!ECDSA+SHA1:!3DES");
1000 
1001     // SSLPrivateKey only supports ECDHE-based ciphers because it lacks decrypt.
1002     if (ssl_server_config_.require_ecdhe || (!pkey_ && private_key_))
1003       command.append(":!kRSA");
1004 
1005     // Remove any disabled ciphers.
1006     for (uint16_t id : ssl_server_config_.disabled_cipher_suites) {
1007       const SSL_CIPHER* cipher = SSL_get_cipher_by_value(id);
1008       if (cipher) {
1009         command.append(":!");
1010         command.append(SSL_CIPHER_get_name(cipher));
1011       }
1012     }
1013 
1014     CHECK(SSL_CTX_set_strict_cipher_list(ssl_ctx_.get(), command.c_str()));
1015   }
1016 
1017   if (ssl_server_config_.client_cert_type !=
1018           SSLServerConfig::ClientCertType::NO_CLIENT_CERT &&
1019       !ssl_server_config_.cert_authorities.empty()) {
1020     bssl::UniquePtr<STACK_OF(CRYPTO_BUFFER)> stack(sk_CRYPTO_BUFFER_new_null());
1021     for (const auto& authority : ssl_server_config_.cert_authorities) {
1022       sk_CRYPTO_BUFFER_push(stack.get(),
1023                             x509_util::CreateCryptoBuffer(authority).release());
1024     }
1025     SSL_CTX_set0_client_CAs(ssl_ctx_.get(), stack.release());
1026   }
1027 
1028   SSL_CTX_set_alpn_select_cb(ssl_ctx_.get(), &SocketImpl::ALPNSelectCallback,
1029                              nullptr);
1030 
1031   if (!ssl_server_config_.ocsp_response.empty()) {
1032     SSL_CTX_set_ocsp_response(ssl_ctx_.get(),
1033                               ssl_server_config_.ocsp_response.data(),
1034                               ssl_server_config_.ocsp_response.size());
1035   }
1036 
1037   if (!ssl_server_config_.signed_cert_timestamp_list.empty()) {
1038     SSL_CTX_set_signed_cert_timestamp_list(
1039         ssl_ctx_.get(), ssl_server_config_.signed_cert_timestamp_list.data(),
1040         ssl_server_config_.signed_cert_timestamp_list.size());
1041   }
1042 
1043   if (ssl_server_config_.ech_keys) {
1044     CHECK(SSL_CTX_set1_ech_keys(ssl_ctx_.get(),
1045                                 ssl_server_config_.ech_keys.get()));
1046   }
1047 
1048   SSL_CTX_set_select_certificate_cb(ssl_ctx_.get(),
1049                                     &SocketImpl::SelectCertificateCallback);
1050 }
1051 
1052 SSLServerContextImpl::~SSLServerContextImpl() = default;
1053 
CreateSSLServerSocket(std::unique_ptr<StreamSocket> socket)1054 std::unique_ptr<SSLServerSocket> SSLServerContextImpl::CreateSSLServerSocket(
1055     std::unique_ptr<StreamSocket> socket) {
1056   return std::make_unique<SocketImpl>(this, std::move(socket));
1057 }
1058 
1059 }  // namespace net
1060