xref: /aosp_15_r20/external/libbrillo/brillo/streams/tls_stream.cc (revision 1a96fba65179ea7d3f56207137718607415c5953)
1 // Copyright 2015 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <brillo/streams/tls_stream.h>
6 
7 #include <algorithm>
8 #include <limits>
9 #include <string>
10 #include <utility>
11 #include <vector>
12 
13 #include <openssl/err.h>
14 #include <openssl/ssl.h>
15 
16 #include <base/bind.h>
17 #include <base/memory/weak_ptr.h>
18 #include <brillo/message_loops/message_loop.h>
19 #include <brillo/secure_blob.h>
20 #include <brillo/streams/openssl_stream_bio.h>
21 #include <brillo/streams/stream_utils.h>
22 #include <brillo/strings/string_utils.h>
23 
24 namespace {
25 
26 // SSL info callback which is called by OpenSSL when we enable logging level of
27 // at least 3. This logs the information about the internal TLS handshake.
TlsInfoCallback(const SSL *,int where,int ret)28 void TlsInfoCallback(const SSL* /* ssl */, int where, int ret) {
29   std::string reason;
30   std::vector<std::string> info;
31   if (where & SSL_CB_LOOP)
32     info.push_back("loop");
33   if (where & SSL_CB_EXIT)
34     info.push_back("exit");
35   if (where & SSL_CB_READ)
36     info.push_back("read");
37   if (where & SSL_CB_WRITE)
38     info.push_back("write");
39   if (where & SSL_CB_ALERT) {
40     info.push_back("alert");
41     reason = ", reason: ";
42     reason += SSL_alert_type_string_long(ret);
43     reason += "/";
44     reason += SSL_alert_desc_string_long(ret);
45   }
46   if (where & SSL_CB_HANDSHAKE_START)
47     info.push_back("handshake_start");
48   if (where & SSL_CB_HANDSHAKE_DONE)
49     info.push_back("handshake_done");
50 
51   VLOG(3) << "TLS progress info: " << brillo::string_utils::Join(",", info)
52           << ", with status: " << ret << reason;
53 }
54 
55 // Static variable to store the index of TlsStream private data in SSL context
56 // used to store custom data for OnCertVerifyResults().
57 int ssl_ctx_private_data_index = -1;
58 
59 // Default trusted certificate store location.
60 const char kCACertificatePath[] =
61 #ifdef __ANDROID__
62     "/system/etc/security/cacerts_google";
63 #else
64     "/usr/share/chromeos-ca-certificates";
65 #endif
66 
67 }  // anonymous namespace
68 
69 namespace brillo {
70 
// TODO(crbug.com/984789): Remove once support for OpenSSL <1.1 is dropped.
// OpenSSL 1.1+ provides TLS_client_method() itself; on older releases map it to
// the TLS 1.2 client method.
#if !(OPENSSL_VERSION_NUMBER >= 0x10100000L)
#define TLS_client_method() TLSv1_2_client_method()
#endif
75 
76 // Helper implementation of TLS stream used to hide most of OpenSSL inner
77 // workings from the users of brillo::TlsStream.
78 class TlsStream::TlsStreamImpl {
79  public:
80   TlsStreamImpl();
81   ~TlsStreamImpl();
82 
83   bool Init(StreamPtr socket,
84             const std::string& host,
85             const base::Closure& success_callback,
86             const Stream::ErrorCallback& error_callback,
87             ErrorPtr* error);
88 
89   bool ReadNonBlocking(void* buffer,
90                        size_t size_to_read,
91                        size_t* size_read,
92                        bool* end_of_stream,
93                        ErrorPtr* error);
94 
95   bool WriteNonBlocking(const void* buffer,
96                         size_t size_to_write,
97                         size_t* size_written,
98                         ErrorPtr* error);
99 
100   bool Flush(ErrorPtr* error);
101   bool Close(ErrorPtr* error);
102   bool WaitForData(AccessMode mode,
103                    const base::Callback<void(AccessMode)>& callback,
104                    ErrorPtr* error);
105   bool WaitForDataBlocking(AccessMode in_mode,
106                            base::TimeDelta timeout,
107                            AccessMode* out_mode,
108                            ErrorPtr* error);
109   void CancelPendingAsyncOperations();
110 
111  private:
112   bool ReportError(ErrorPtr* error,
113                    const base::Location& location,
114                    const std::string& message);
115   void DoHandshake(const base::Closure& success_callback,
116                    const Stream::ErrorCallback& error_callback);
117   void RetryHandshake(const base::Closure& success_callback,
118                       const Stream::ErrorCallback& error_callback,
119                       Stream::AccessMode mode);
120 
121   int OnCertVerifyResults(int ok, X509_STORE_CTX* ctx);
122   static int OnCertVerifyResultsStatic(int ok, X509_STORE_CTX* ctx);
123 
124   StreamPtr socket_;
125   std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> ctx_{nullptr, SSL_CTX_free};
126   std::unique_ptr<SSL, decltype(&SSL_free)> ssl_{nullptr, SSL_free};
127   BIO* stream_bio_{nullptr};
128   bool need_more_read_{false};
129   bool need_more_write_{false};
130 
131   base::WeakPtrFactory<TlsStreamImpl> weak_ptr_factory_{this};
132   DISALLOW_COPY_AND_ASSIGN(TlsStreamImpl);
133 };
134 
TlsStreamImpl()135 TlsStream::TlsStreamImpl::TlsStreamImpl() {
136   SSL_load_error_strings();
137   SSL_library_init();
138   if (ssl_ctx_private_data_index < 0) {
139     ssl_ctx_private_data_index =
140         SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
141   }
142 }
143 
~TlsStreamImpl()144 TlsStream::TlsStreamImpl::~TlsStreamImpl() {
145   ssl_.reset();
146   ctx_.reset();
147 }
148 
ReadNonBlocking(void * buffer,size_t size_to_read,size_t * size_read,bool * end_of_stream,ErrorPtr * error)149 bool TlsStream::TlsStreamImpl::ReadNonBlocking(void* buffer,
150                                                size_t size_to_read,
151                                                size_t* size_read,
152                                                bool* end_of_stream,
153                                                ErrorPtr* error) {
154   const size_t max_int = std::numeric_limits<int>::max();
155   int size_int = static_cast<int>(std::min(size_to_read, max_int));
156   int ret = SSL_read(ssl_.get(), buffer, size_int);
157   if (ret > 0) {
158     *size_read = static_cast<size_t>(ret);
159     if (end_of_stream)
160       *end_of_stream = false;
161     return true;
162   }
163 
164   int err = SSL_get_error(ssl_.get(), ret);
165   if (err == SSL_ERROR_ZERO_RETURN) {
166     *size_read = 0;
167     if (end_of_stream)
168       *end_of_stream = true;
169     return true;
170   }
171 
172   if (err == SSL_ERROR_WANT_READ) {
173     need_more_read_ = true;
174   } else if (err == SSL_ERROR_WANT_WRITE) {
175     // Writes might be required for SSL_read() because of possible TLS
176     // re-negotiations which can happen at any time.
177     need_more_write_ = true;
178   } else {
179     return ReportError(error, FROM_HERE, "Error reading from TLS socket");
180   }
181   *size_read = 0;
182   if (end_of_stream)
183     *end_of_stream = false;
184   return true;
185 }
186 
WriteNonBlocking(const void * buffer,size_t size_to_write,size_t * size_written,ErrorPtr * error)187 bool TlsStream::TlsStreamImpl::WriteNonBlocking(const void* buffer,
188                                                 size_t size_to_write,
189                                                 size_t* size_written,
190                                                 ErrorPtr* error) {
191   const size_t max_int = std::numeric_limits<int>::max();
192   int size_int = static_cast<int>(std::min(size_to_write, max_int));
193   int ret = SSL_write(ssl_.get(), buffer, size_int);
194   if (ret > 0) {
195     *size_written = static_cast<size_t>(ret);
196     return true;
197   }
198 
199   int err = SSL_get_error(ssl_.get(), ret);
200   if (err == SSL_ERROR_WANT_READ) {
201     // Reads might be required for SSL_write() because of possible TLS
202     // re-negotiations which can happen at any time.
203     need_more_read_ = true;
204   } else if (err == SSL_ERROR_WANT_WRITE) {
205     need_more_write_ = true;
206   } else {
207     return ReportError(error, FROM_HERE, "Error writing to TLS socket");
208   }
209   *size_written = 0;
210   return true;
211 }
212 
Flush(ErrorPtr * error)213 bool TlsStream::TlsStreamImpl::Flush(ErrorPtr* error) {
214   return socket_->FlushBlocking(error);
215 }
216 
Close(ErrorPtr * error)217 bool TlsStream::TlsStreamImpl::Close(ErrorPtr* error) {
218   // 2 seconds should be plenty here.
219   const base::TimeDelta kTimeout = base::TimeDelta::FromSeconds(2);
220   // The retry count of 4 below is just arbitrary, to ensure we don't get stuck
221   // here forever. We should rarely need to repeat SSL_shutdown anyway.
222   for (int retry_count = 0; retry_count < 4; retry_count++) {
223     int ret = SSL_shutdown(ssl_.get());
224     // We really don't care for bi-directional shutdown here.
225     // Just make sure we only send the "close notify" alert to the remote peer.
226     if (ret >= 0)
227       break;
228 
229     int err = SSL_get_error(ssl_.get(), ret);
230     if (err == SSL_ERROR_WANT_READ) {
231       if (!socket_->WaitForDataBlocking(AccessMode::READ, kTimeout, nullptr,
232                                         error)) {
233         break;
234       }
235     } else if (err == SSL_ERROR_WANT_WRITE) {
236       if (!socket_->WaitForDataBlocking(AccessMode::WRITE, kTimeout, nullptr,
237                                         error)) {
238         break;
239       }
240     } else {
241       LOG(ERROR) << "SSL_shutdown returned error #" << err;
242       ReportError(error, FROM_HERE, "Failed to shut down TLS socket");
243       break;
244     }
245   }
246   return socket_->CloseBlocking(error);
247 }
248 
WaitForData(AccessMode mode,const base::Callback<void (AccessMode)> & callback,ErrorPtr * error)249 bool TlsStream::TlsStreamImpl::WaitForData(
250     AccessMode mode,
251     const base::Callback<void(AccessMode)>& callback,
252     ErrorPtr* error) {
253   bool is_read = stream_utils::IsReadAccessMode(mode);
254   bool is_write = stream_utils::IsWriteAccessMode(mode);
255   is_read |= need_more_read_;
256   is_write |= need_more_write_;
257   need_more_read_ = false;
258   need_more_write_ = false;
259   if (is_read && SSL_pending(ssl_.get()) > 0) {
260     callback.Run(AccessMode::READ);
261     return true;
262   }
263   mode = stream_utils::MakeAccessMode(is_read, is_write);
264   return socket_->WaitForData(mode, callback, error);
265 }
266 
WaitForDataBlocking(AccessMode in_mode,base::TimeDelta timeout,AccessMode * out_mode,ErrorPtr * error)267 bool TlsStream::TlsStreamImpl::WaitForDataBlocking(AccessMode in_mode,
268                                                    base::TimeDelta timeout,
269                                                    AccessMode* out_mode,
270                                                    ErrorPtr* error) {
271   bool is_read = stream_utils::IsReadAccessMode(in_mode);
272   bool is_write = stream_utils::IsWriteAccessMode(in_mode);
273   is_read |= need_more_read_;
274   is_write |= need_more_write_;
275   need_more_read_ = need_more_write_ = false;
276   if (is_read && SSL_pending(ssl_.get()) > 0) {
277     if (out_mode)
278       *out_mode = AccessMode::READ;
279     return true;
280   }
281   in_mode = stream_utils::MakeAccessMode(is_read, is_write);
282   return socket_->WaitForDataBlocking(in_mode, timeout, out_mode, error);
283 }
284 
CancelPendingAsyncOperations()285 void TlsStream::TlsStreamImpl::CancelPendingAsyncOperations() {
286   socket_->CancelPendingAsyncOperations();
287   weak_ptr_factory_.InvalidateWeakPtrs();
288 }
289 
ReportError(ErrorPtr * error,const base::Location & location,const std::string & message)290 bool TlsStream::TlsStreamImpl::ReportError(
291     ErrorPtr* error,
292     const base::Location& location,
293     const std::string& message) {
294   const char* file = nullptr;
295   int line = 0;
296   const char* data = 0;
297   int flags = 0;
298   while (auto errnum = ERR_get_error_line_data(&file, &line, &data, &flags)) {
299     char buf[256];
300     ERR_error_string_n(errnum, buf, sizeof(buf));
301     base::Location ssl_location{"Unknown", file, line, nullptr};
302     std::string ssl_message = buf;
303     if (flags & ERR_TXT_STRING) {
304       ssl_message += ": ";
305       ssl_message += data;
306     }
307     Error::AddTo(error, ssl_location, "openssl", std::to_string(errnum),
308                  ssl_message);
309   }
310   Error::AddTo(error, location, "tls_stream", "failed", message);
311   return false;
312 }
313 
OnCertVerifyResults(int ok,X509_STORE_CTX * ctx)314 int TlsStream::TlsStreamImpl::OnCertVerifyResults(int ok, X509_STORE_CTX* ctx) {
315   // OpenSSL already performs a comprehensive check of the certificate chain
316   // (using X509_verify_cert() function) and calls back with the result of its
317   // verification.
318   // |ok| is set to 1 if the verification passed and 0 if an error was detected.
319   // Here we can perform some additional checks if we need to, or simply log
320   // the issues found.
321 
322   // For now, just log an error if it occurred.
323   if (!ok) {
324     LOG(ERROR) << "Server certificate validation failed: "
325                << X509_verify_cert_error_string(X509_STORE_CTX_get_error(ctx));
326   }
327   return ok;
328 }
329 
OnCertVerifyResultsStatic(int ok,X509_STORE_CTX * ctx)330 int TlsStream::TlsStreamImpl::OnCertVerifyResultsStatic(int ok,
331                                                         X509_STORE_CTX* ctx) {
332   // Obtain the pointer to the instance of TlsStream::TlsStreamImpl from the
333   // SSL CTX object referenced by |ctx|.
334   SSL* ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(
335       ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
336   SSL_CTX* ssl_ctx = ssl ? SSL_get_SSL_CTX(ssl) : nullptr;
337   TlsStream::TlsStreamImpl* self = nullptr;
338   if (ssl_ctx) {
339     self = static_cast<TlsStream::TlsStreamImpl*>(SSL_CTX_get_ex_data(
340         ssl_ctx, ssl_ctx_private_data_index));
341   }
342   return self ? self->OnCertVerifyResults(ok, ctx) : ok;
343 }
344 
Init(StreamPtr socket,const std::string & host,const base::Closure & success_callback,const Stream::ErrorCallback & error_callback,ErrorPtr * error)345 bool TlsStream::TlsStreamImpl::Init(StreamPtr socket,
346                                     const std::string& host,
347                                     const base::Closure& success_callback,
348                                     const Stream::ErrorCallback& error_callback,
349                                     ErrorPtr* error) {
350   ctx_.reset(SSL_CTX_new(TLS_client_method()));
351   if (!ctx_)
352     return ReportError(error, FROM_HERE, "Cannot create SSL_CTX");
353 
354   // Top cipher suites supported by both Google GFEs and OpenSSL (in server
355   // preferred order).
356   int res = SSL_CTX_set_cipher_list(ctx_.get(),
357                                     "ECDHE-ECDSA-AES128-GCM-SHA256:"
358                                     "ECDHE-ECDSA-AES256-GCM-SHA384:"
359                                     "ECDHE-RSA-AES128-GCM-SHA256:"
360                                     "ECDHE-RSA-AES256-GCM-SHA384");
361   if (res != 1)
362     return ReportError(error, FROM_HERE, "Cannot set the cipher list");
363 
364   res = SSL_CTX_load_verify_locations(ctx_.get(), nullptr, kCACertificatePath);
365   if (res != 1) {
366     return ReportError(error, FROM_HERE,
367                        "Failed to specify trusted certificate location");
368   }
369 
370   // Store a pointer to "this" into SSL_CTX instance.
371   SSL_CTX_set_ex_data(ctx_.get(), ssl_ctx_private_data_index, this);
372 
373   // Ask OpenSSL to validate the server host from the certificate to match
374   // the expected host name we are given:
375   X509_VERIFY_PARAM* param = SSL_CTX_get0_param(ctx_.get());
376   X509_VERIFY_PARAM_set1_host(param, host.c_str(), host.size());
377 
378   SSL_CTX_set_verify(ctx_.get(), SSL_VERIFY_PEER,
379                      &TlsStreamImpl::OnCertVerifyResultsStatic);
380 
381   socket_ = std::move(socket);
382   ssl_.reset(SSL_new(ctx_.get()));
383 
384   // Enable TLS progress callback if VLOG level is >=3.
385   if (VLOG_IS_ON(3))
386     SSL_set_info_callback(ssl_.get(), TlsInfoCallback);
387 
388   stream_bio_ = BIO_new_stream(socket_.get());
389   SSL_set_bio(ssl_.get(), stream_bio_, stream_bio_);
390   SSL_set_connect_state(ssl_.get());
391 
392   // We might have no message loop (e.g. we are in unit tests).
393   if (MessageLoop::ThreadHasCurrent()) {
394     MessageLoop::current()->PostTask(
395         FROM_HERE,
396         base::BindOnce(&TlsStreamImpl::DoHandshake,
397                        weak_ptr_factory_.GetWeakPtr(),
398                        success_callback,
399                        error_callback));
400   } else {
401     DoHandshake(success_callback, error_callback);
402   }
403   return true;
404 }
405 
RetryHandshake(const base::Closure & success_callback,const Stream::ErrorCallback & error_callback,Stream::AccessMode)406 void TlsStream::TlsStreamImpl::RetryHandshake(
407     const base::Closure& success_callback,
408     const Stream::ErrorCallback& error_callback,
409     Stream::AccessMode /* mode */) {
410   VLOG(1) << "Retrying TLS handshake";
411   DoHandshake(success_callback, error_callback);
412 }
413 
DoHandshake(const base::Closure & success_callback,const Stream::ErrorCallback & error_callback)414 void TlsStream::TlsStreamImpl::DoHandshake(
415     const base::Closure& success_callback,
416     const Stream::ErrorCallback& error_callback) {
417   VLOG(1) << "Begin TLS handshake";
418   int res = SSL_do_handshake(ssl_.get());
419   if (res == 1) {
420     VLOG(1) << "Handshake successful";
421     success_callback.Run();
422     return;
423   }
424   ErrorPtr error;
425   int err = SSL_get_error(ssl_.get(), res);
426   if (err == SSL_ERROR_WANT_READ) {
427     VLOG(1) << "Waiting for read data...";
428     bool ok = socket_->WaitForData(
429         Stream::AccessMode::READ,
430         base::Bind(&TlsStreamImpl::RetryHandshake,
431                    weak_ptr_factory_.GetWeakPtr(),
432                    success_callback, error_callback),
433         &error);
434     if (ok)
435       return;
436   } else if (err == SSL_ERROR_WANT_WRITE) {
437     VLOG(1) << "Waiting for write data...";
438     bool ok = socket_->WaitForData(
439         Stream::AccessMode::WRITE,
440         base::Bind(&TlsStreamImpl::RetryHandshake,
441                    weak_ptr_factory_.GetWeakPtr(),
442                    success_callback, error_callback),
443         &error);
444     if (ok)
445       return;
446   } else {
447     ReportError(&error, FROM_HERE, "TLS handshake failed.");
448   }
449   error_callback.Run(error.get());
450 }
451 
452 /////////////////////////////////////////////////////////////////////////////
TlsStream(std::unique_ptr<TlsStreamImpl> impl)453 TlsStream::TlsStream(std::unique_ptr<TlsStreamImpl> impl)
454     : impl_{std::move(impl)} {}
455 
~TlsStream()456 TlsStream::~TlsStream() {
457   if (impl_) {
458     impl_->Close(nullptr);
459   }
460 }
461 
Connect(StreamPtr socket,const std::string & host,const base::Callback<void (StreamPtr)> & success_callback,const Stream::ErrorCallback & error_callback)462 void TlsStream::Connect(StreamPtr socket,
463                         const std::string& host,
464                         const base::Callback<void(StreamPtr)>& success_callback,
465                         const Stream::ErrorCallback& error_callback) {
466   std::unique_ptr<TlsStreamImpl> impl{new TlsStreamImpl};
467   std::unique_ptr<TlsStream> stream{new TlsStream{std::move(impl)}};
468 
469   TlsStreamImpl* pimpl = stream->impl_.get();
470   ErrorPtr error;
471   bool success = pimpl->Init(std::move(socket), host,
472                              base::Bind(success_callback,
473                                         base::Passed(std::move(stream))),
474                              error_callback, &error);
475 
476   if (!success)
477     error_callback.Run(error.get());
478 }
479 
IsOpen() const480 bool TlsStream::IsOpen() const {
481   return impl_ ? true : false;
482 }
483 
SetSizeBlocking(uint64_t,ErrorPtr * error)484 bool TlsStream::SetSizeBlocking(uint64_t /* size */, ErrorPtr* error) {
485   return stream_utils::ErrorOperationNotSupported(FROM_HERE, error);
486 }
487 
Seek(int64_t,Whence,uint64_t *,ErrorPtr * error)488 bool TlsStream::Seek(int64_t /* offset */,
489                      Whence /* whence */,
490                      uint64_t* /* new_position*/,
491                      ErrorPtr* error) {
492   return stream_utils::ErrorOperationNotSupported(FROM_HERE, error);
493 }
494 
ReadNonBlocking(void * buffer,size_t size_to_read,size_t * size_read,bool * end_of_stream,ErrorPtr * error)495 bool TlsStream::ReadNonBlocking(void* buffer,
496                                 size_t size_to_read,
497                                 size_t* size_read,
498                                 bool* end_of_stream,
499                                 ErrorPtr* error) {
500   if (!impl_)
501     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
502   return impl_->ReadNonBlocking(buffer, size_to_read, size_read, end_of_stream,
503                                 error);
504 }
505 
WriteNonBlocking(const void * buffer,size_t size_to_write,size_t * size_written,ErrorPtr * error)506 bool TlsStream::WriteNonBlocking(const void* buffer,
507                                  size_t size_to_write,
508                                  size_t* size_written,
509                                  ErrorPtr* error) {
510   if (!impl_)
511     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
512   return impl_->WriteNonBlocking(buffer, size_to_write, size_written, error);
513 }
514 
FlushBlocking(ErrorPtr * error)515 bool TlsStream::FlushBlocking(ErrorPtr* error) {
516   if (!impl_)
517     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
518   return impl_->Flush(error);
519 }
520 
CloseBlocking(ErrorPtr * error)521 bool TlsStream::CloseBlocking(ErrorPtr* error) {
522   if (impl_ && !impl_->Close(error))
523     return false;
524   impl_.reset();
525   return true;
526 }
527 
WaitForData(AccessMode mode,const base::Callback<void (AccessMode)> & callback,ErrorPtr * error)528 bool TlsStream::WaitForData(AccessMode mode,
529                             const base::Callback<void(AccessMode)>& callback,
530                             ErrorPtr* error) {
531   if (!impl_)
532     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
533   return impl_->WaitForData(mode, callback, error);
534 }
535 
WaitForDataBlocking(AccessMode in_mode,base::TimeDelta timeout,AccessMode * out_mode,ErrorPtr * error)536 bool TlsStream::WaitForDataBlocking(AccessMode in_mode,
537                                     base::TimeDelta timeout,
538                                     AccessMode* out_mode,
539                                     ErrorPtr* error) {
540   if (!impl_)
541     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
542   return impl_->WaitForDataBlocking(in_mode, timeout, out_mode, error);
543 }
544 
CancelPendingAsyncOperations()545 void TlsStream::CancelPendingAsyncOperations() {
546   if (impl_)
547     impl_->CancelPendingAsyncOperations();
548   Stream::CancelPendingAsyncOperations();
549 }
550 
551 }  // namespace brillo
552