xref: /aosp_15_r20/external/libbrillo/brillo/streams/tls_stream.cc (revision 1a96fba65179ea7d3f56207137718607415c5953)
1*1a96fba6SXin Li // Copyright 2015 The Chromium OS Authors. All rights reserved.
2*1a96fba6SXin Li // Use of this source code is governed by a BSD-style license that can be
3*1a96fba6SXin Li // found in the LICENSE file.
4*1a96fba6SXin Li 
5*1a96fba6SXin Li #include <brillo/streams/tls_stream.h>
6*1a96fba6SXin Li 
7*1a96fba6SXin Li #include <algorithm>
8*1a96fba6SXin Li #include <limits>
9*1a96fba6SXin Li #include <string>
10*1a96fba6SXin Li #include <utility>
11*1a96fba6SXin Li #include <vector>
12*1a96fba6SXin Li 
13*1a96fba6SXin Li #include <openssl/err.h>
14*1a96fba6SXin Li #include <openssl/ssl.h>
15*1a96fba6SXin Li 
16*1a96fba6SXin Li #include <base/bind.h>
17*1a96fba6SXin Li #include <base/memory/weak_ptr.h>
18*1a96fba6SXin Li #include <brillo/message_loops/message_loop.h>
19*1a96fba6SXin Li #include <brillo/secure_blob.h>
20*1a96fba6SXin Li #include <brillo/streams/openssl_stream_bio.h>
21*1a96fba6SXin Li #include <brillo/streams/stream_utils.h>
22*1a96fba6SXin Li #include <brillo/strings/string_utils.h>
23*1a96fba6SXin Li 
24*1a96fba6SXin Li namespace {
25*1a96fba6SXin Li 
26*1a96fba6SXin Li // SSL info callback which is called by OpenSSL when we enable logging level of
27*1a96fba6SXin Li // at least 3. This logs the information about the internal TLS handshake.
TlsInfoCallback(const SSL *,int where,int ret)28*1a96fba6SXin Li void TlsInfoCallback(const SSL* /* ssl */, int where, int ret) {
29*1a96fba6SXin Li   std::string reason;
30*1a96fba6SXin Li   std::vector<std::string> info;
31*1a96fba6SXin Li   if (where & SSL_CB_LOOP)
32*1a96fba6SXin Li     info.push_back("loop");
33*1a96fba6SXin Li   if (where & SSL_CB_EXIT)
34*1a96fba6SXin Li     info.push_back("exit");
35*1a96fba6SXin Li   if (where & SSL_CB_READ)
36*1a96fba6SXin Li     info.push_back("read");
37*1a96fba6SXin Li   if (where & SSL_CB_WRITE)
38*1a96fba6SXin Li     info.push_back("write");
39*1a96fba6SXin Li   if (where & SSL_CB_ALERT) {
40*1a96fba6SXin Li     info.push_back("alert");
41*1a96fba6SXin Li     reason = ", reason: ";
42*1a96fba6SXin Li     reason += SSL_alert_type_string_long(ret);
43*1a96fba6SXin Li     reason += "/";
44*1a96fba6SXin Li     reason += SSL_alert_desc_string_long(ret);
45*1a96fba6SXin Li   }
46*1a96fba6SXin Li   if (where & SSL_CB_HANDSHAKE_START)
47*1a96fba6SXin Li     info.push_back("handshake_start");
48*1a96fba6SXin Li   if (where & SSL_CB_HANDSHAKE_DONE)
49*1a96fba6SXin Li     info.push_back("handshake_done");
50*1a96fba6SXin Li 
51*1a96fba6SXin Li   VLOG(3) << "TLS progress info: " << brillo::string_utils::Join(",", info)
52*1a96fba6SXin Li           << ", with status: " << ret << reason;
53*1a96fba6SXin Li }
54*1a96fba6SXin Li 
// Index of the SSL_CTX "ex_data" slot holding the owning TlsStreamImpl*.
// Allocated by the TlsStreamImpl constructor and read back in
// OnCertVerifyResultsStatic(); a negative value means "not allocated".
int ssl_ctx_private_data_index{-1};

// Directory of trusted CA certificates, passed as the CApath argument of
// SSL_CTX_load_verify_locations() in TlsStreamImpl::Init().
constexpr char kCACertificatePath[] =
#ifdef __ANDROID__
    "/system/etc/security/cacerts_google";
#else
    "/usr/share/chromeos-ca-certificates";
#endif
66*1a96fba6SXin Li 
67*1a96fba6SXin Li }  // anonymous namespace
68*1a96fba6SXin Li 
69*1a96fba6SXin Li namespace brillo {
70*1a96fba6SXin Li 
71*1a96fba6SXin Li // TODO(crbug.com/984789): Remove once support for OpenSSL <1.1 is dropped.
72*1a96fba6SXin Li #if OPENSSL_VERSION_NUMBER < 0x10100000L
73*1a96fba6SXin Li #define TLS_client_method() TLSv1_2_client_method()
74*1a96fba6SXin Li #endif
75*1a96fba6SXin Li 
76*1a96fba6SXin Li // Helper implementation of TLS stream used to hide most of OpenSSL inner
77*1a96fba6SXin Li // workings from the users of brillo::TlsStream.
78*1a96fba6SXin Li class TlsStream::TlsStreamImpl {
79*1a96fba6SXin Li  public:
80*1a96fba6SXin Li   TlsStreamImpl();
81*1a96fba6SXin Li   ~TlsStreamImpl();
82*1a96fba6SXin Li 
83*1a96fba6SXin Li   bool Init(StreamPtr socket,
84*1a96fba6SXin Li             const std::string& host,
85*1a96fba6SXin Li             const base::Closure& success_callback,
86*1a96fba6SXin Li             const Stream::ErrorCallback& error_callback,
87*1a96fba6SXin Li             ErrorPtr* error);
88*1a96fba6SXin Li 
89*1a96fba6SXin Li   bool ReadNonBlocking(void* buffer,
90*1a96fba6SXin Li                        size_t size_to_read,
91*1a96fba6SXin Li                        size_t* size_read,
92*1a96fba6SXin Li                        bool* end_of_stream,
93*1a96fba6SXin Li                        ErrorPtr* error);
94*1a96fba6SXin Li 
95*1a96fba6SXin Li   bool WriteNonBlocking(const void* buffer,
96*1a96fba6SXin Li                         size_t size_to_write,
97*1a96fba6SXin Li                         size_t* size_written,
98*1a96fba6SXin Li                         ErrorPtr* error);
99*1a96fba6SXin Li 
100*1a96fba6SXin Li   bool Flush(ErrorPtr* error);
101*1a96fba6SXin Li   bool Close(ErrorPtr* error);
102*1a96fba6SXin Li   bool WaitForData(AccessMode mode,
103*1a96fba6SXin Li                    const base::Callback<void(AccessMode)>& callback,
104*1a96fba6SXin Li                    ErrorPtr* error);
105*1a96fba6SXin Li   bool WaitForDataBlocking(AccessMode in_mode,
106*1a96fba6SXin Li                            base::TimeDelta timeout,
107*1a96fba6SXin Li                            AccessMode* out_mode,
108*1a96fba6SXin Li                            ErrorPtr* error);
109*1a96fba6SXin Li   void CancelPendingAsyncOperations();
110*1a96fba6SXin Li 
111*1a96fba6SXin Li  private:
112*1a96fba6SXin Li   bool ReportError(ErrorPtr* error,
113*1a96fba6SXin Li                    const base::Location& location,
114*1a96fba6SXin Li                    const std::string& message);
115*1a96fba6SXin Li   void DoHandshake(const base::Closure& success_callback,
116*1a96fba6SXin Li                    const Stream::ErrorCallback& error_callback);
117*1a96fba6SXin Li   void RetryHandshake(const base::Closure& success_callback,
118*1a96fba6SXin Li                       const Stream::ErrorCallback& error_callback,
119*1a96fba6SXin Li                       Stream::AccessMode mode);
120*1a96fba6SXin Li 
121*1a96fba6SXin Li   int OnCertVerifyResults(int ok, X509_STORE_CTX* ctx);
122*1a96fba6SXin Li   static int OnCertVerifyResultsStatic(int ok, X509_STORE_CTX* ctx);
123*1a96fba6SXin Li 
124*1a96fba6SXin Li   StreamPtr socket_;
125*1a96fba6SXin Li   std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> ctx_{nullptr, SSL_CTX_free};
126*1a96fba6SXin Li   std::unique_ptr<SSL, decltype(&SSL_free)> ssl_{nullptr, SSL_free};
127*1a96fba6SXin Li   BIO* stream_bio_{nullptr};
128*1a96fba6SXin Li   bool need_more_read_{false};
129*1a96fba6SXin Li   bool need_more_write_{false};
130*1a96fba6SXin Li 
131*1a96fba6SXin Li   base::WeakPtrFactory<TlsStreamImpl> weak_ptr_factory_{this};
132*1a96fba6SXin Li   DISALLOW_COPY_AND_ASSIGN(TlsStreamImpl);
133*1a96fba6SXin Li };
134*1a96fba6SXin Li 
TlsStreamImpl()135*1a96fba6SXin Li TlsStream::TlsStreamImpl::TlsStreamImpl() {
136*1a96fba6SXin Li   SSL_load_error_strings();
137*1a96fba6SXin Li   SSL_library_init();
138*1a96fba6SXin Li   if (ssl_ctx_private_data_index < 0) {
139*1a96fba6SXin Li     ssl_ctx_private_data_index =
140*1a96fba6SXin Li         SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
141*1a96fba6SXin Li   }
142*1a96fba6SXin Li }
143*1a96fba6SXin Li 
~TlsStreamImpl()144*1a96fba6SXin Li TlsStream::TlsStreamImpl::~TlsStreamImpl() {
145*1a96fba6SXin Li   ssl_.reset();
146*1a96fba6SXin Li   ctx_.reset();
147*1a96fba6SXin Li }
148*1a96fba6SXin Li 
ReadNonBlocking(void * buffer,size_t size_to_read,size_t * size_read,bool * end_of_stream,ErrorPtr * error)149*1a96fba6SXin Li bool TlsStream::TlsStreamImpl::ReadNonBlocking(void* buffer,
150*1a96fba6SXin Li                                                size_t size_to_read,
151*1a96fba6SXin Li                                                size_t* size_read,
152*1a96fba6SXin Li                                                bool* end_of_stream,
153*1a96fba6SXin Li                                                ErrorPtr* error) {
154*1a96fba6SXin Li   const size_t max_int = std::numeric_limits<int>::max();
155*1a96fba6SXin Li   int size_int = static_cast<int>(std::min(size_to_read, max_int));
156*1a96fba6SXin Li   int ret = SSL_read(ssl_.get(), buffer, size_int);
157*1a96fba6SXin Li   if (ret > 0) {
158*1a96fba6SXin Li     *size_read = static_cast<size_t>(ret);
159*1a96fba6SXin Li     if (end_of_stream)
160*1a96fba6SXin Li       *end_of_stream = false;
161*1a96fba6SXin Li     return true;
162*1a96fba6SXin Li   }
163*1a96fba6SXin Li 
164*1a96fba6SXin Li   int err = SSL_get_error(ssl_.get(), ret);
165*1a96fba6SXin Li   if (err == SSL_ERROR_ZERO_RETURN) {
166*1a96fba6SXin Li     *size_read = 0;
167*1a96fba6SXin Li     if (end_of_stream)
168*1a96fba6SXin Li       *end_of_stream = true;
169*1a96fba6SXin Li     return true;
170*1a96fba6SXin Li   }
171*1a96fba6SXin Li 
172*1a96fba6SXin Li   if (err == SSL_ERROR_WANT_READ) {
173*1a96fba6SXin Li     need_more_read_ = true;
174*1a96fba6SXin Li   } else if (err == SSL_ERROR_WANT_WRITE) {
175*1a96fba6SXin Li     // Writes might be required for SSL_read() because of possible TLS
176*1a96fba6SXin Li     // re-negotiations which can happen at any time.
177*1a96fba6SXin Li     need_more_write_ = true;
178*1a96fba6SXin Li   } else {
179*1a96fba6SXin Li     return ReportError(error, FROM_HERE, "Error reading from TLS socket");
180*1a96fba6SXin Li   }
181*1a96fba6SXin Li   *size_read = 0;
182*1a96fba6SXin Li   if (end_of_stream)
183*1a96fba6SXin Li     *end_of_stream = false;
184*1a96fba6SXin Li   return true;
185*1a96fba6SXin Li }
186*1a96fba6SXin Li 
WriteNonBlocking(const void * buffer,size_t size_to_write,size_t * size_written,ErrorPtr * error)187*1a96fba6SXin Li bool TlsStream::TlsStreamImpl::WriteNonBlocking(const void* buffer,
188*1a96fba6SXin Li                                                 size_t size_to_write,
189*1a96fba6SXin Li                                                 size_t* size_written,
190*1a96fba6SXin Li                                                 ErrorPtr* error) {
191*1a96fba6SXin Li   const size_t max_int = std::numeric_limits<int>::max();
192*1a96fba6SXin Li   int size_int = static_cast<int>(std::min(size_to_write, max_int));
193*1a96fba6SXin Li   int ret = SSL_write(ssl_.get(), buffer, size_int);
194*1a96fba6SXin Li   if (ret > 0) {
195*1a96fba6SXin Li     *size_written = static_cast<size_t>(ret);
196*1a96fba6SXin Li     return true;
197*1a96fba6SXin Li   }
198*1a96fba6SXin Li 
199*1a96fba6SXin Li   int err = SSL_get_error(ssl_.get(), ret);
200*1a96fba6SXin Li   if (err == SSL_ERROR_WANT_READ) {
201*1a96fba6SXin Li     // Reads might be required for SSL_write() because of possible TLS
202*1a96fba6SXin Li     // re-negotiations which can happen at any time.
203*1a96fba6SXin Li     need_more_read_ = true;
204*1a96fba6SXin Li   } else if (err == SSL_ERROR_WANT_WRITE) {
205*1a96fba6SXin Li     need_more_write_ = true;
206*1a96fba6SXin Li   } else {
207*1a96fba6SXin Li     return ReportError(error, FROM_HERE, "Error writing to TLS socket");
208*1a96fba6SXin Li   }
209*1a96fba6SXin Li   *size_written = 0;
210*1a96fba6SXin Li   return true;
211*1a96fba6SXin Li }
212*1a96fba6SXin Li 
Flush(ErrorPtr * error)213*1a96fba6SXin Li bool TlsStream::TlsStreamImpl::Flush(ErrorPtr* error) {
214*1a96fba6SXin Li   return socket_->FlushBlocking(error);
215*1a96fba6SXin Li }
216*1a96fba6SXin Li 
Close(ErrorPtr * error)217*1a96fba6SXin Li bool TlsStream::TlsStreamImpl::Close(ErrorPtr* error) {
218*1a96fba6SXin Li   // 2 seconds should be plenty here.
219*1a96fba6SXin Li   const base::TimeDelta kTimeout = base::TimeDelta::FromSeconds(2);
220*1a96fba6SXin Li   // The retry count of 4 below is just arbitrary, to ensure we don't get stuck
221*1a96fba6SXin Li   // here forever. We should rarely need to repeat SSL_shutdown anyway.
222*1a96fba6SXin Li   for (int retry_count = 0; retry_count < 4; retry_count++) {
223*1a96fba6SXin Li     int ret = SSL_shutdown(ssl_.get());
224*1a96fba6SXin Li     // We really don't care for bi-directional shutdown here.
225*1a96fba6SXin Li     // Just make sure we only send the "close notify" alert to the remote peer.
226*1a96fba6SXin Li     if (ret >= 0)
227*1a96fba6SXin Li       break;
228*1a96fba6SXin Li 
229*1a96fba6SXin Li     int err = SSL_get_error(ssl_.get(), ret);
230*1a96fba6SXin Li     if (err == SSL_ERROR_WANT_READ) {
231*1a96fba6SXin Li       if (!socket_->WaitForDataBlocking(AccessMode::READ, kTimeout, nullptr,
232*1a96fba6SXin Li                                         error)) {
233*1a96fba6SXin Li         break;
234*1a96fba6SXin Li       }
235*1a96fba6SXin Li     } else if (err == SSL_ERROR_WANT_WRITE) {
236*1a96fba6SXin Li       if (!socket_->WaitForDataBlocking(AccessMode::WRITE, kTimeout, nullptr,
237*1a96fba6SXin Li                                         error)) {
238*1a96fba6SXin Li         break;
239*1a96fba6SXin Li       }
240*1a96fba6SXin Li     } else {
241*1a96fba6SXin Li       LOG(ERROR) << "SSL_shutdown returned error #" << err;
242*1a96fba6SXin Li       ReportError(error, FROM_HERE, "Failed to shut down TLS socket");
243*1a96fba6SXin Li       break;
244*1a96fba6SXin Li     }
245*1a96fba6SXin Li   }
246*1a96fba6SXin Li   return socket_->CloseBlocking(error);
247*1a96fba6SXin Li }
248*1a96fba6SXin Li 
WaitForData(AccessMode mode,const base::Callback<void (AccessMode)> & callback,ErrorPtr * error)249*1a96fba6SXin Li bool TlsStream::TlsStreamImpl::WaitForData(
250*1a96fba6SXin Li     AccessMode mode,
251*1a96fba6SXin Li     const base::Callback<void(AccessMode)>& callback,
252*1a96fba6SXin Li     ErrorPtr* error) {
253*1a96fba6SXin Li   bool is_read = stream_utils::IsReadAccessMode(mode);
254*1a96fba6SXin Li   bool is_write = stream_utils::IsWriteAccessMode(mode);
255*1a96fba6SXin Li   is_read |= need_more_read_;
256*1a96fba6SXin Li   is_write |= need_more_write_;
257*1a96fba6SXin Li   need_more_read_ = false;
258*1a96fba6SXin Li   need_more_write_ = false;
259*1a96fba6SXin Li   if (is_read && SSL_pending(ssl_.get()) > 0) {
260*1a96fba6SXin Li     callback.Run(AccessMode::READ);
261*1a96fba6SXin Li     return true;
262*1a96fba6SXin Li   }
263*1a96fba6SXin Li   mode = stream_utils::MakeAccessMode(is_read, is_write);
264*1a96fba6SXin Li   return socket_->WaitForData(mode, callback, error);
265*1a96fba6SXin Li }
266*1a96fba6SXin Li 
WaitForDataBlocking(AccessMode in_mode,base::TimeDelta timeout,AccessMode * out_mode,ErrorPtr * error)267*1a96fba6SXin Li bool TlsStream::TlsStreamImpl::WaitForDataBlocking(AccessMode in_mode,
268*1a96fba6SXin Li                                                    base::TimeDelta timeout,
269*1a96fba6SXin Li                                                    AccessMode* out_mode,
270*1a96fba6SXin Li                                                    ErrorPtr* error) {
271*1a96fba6SXin Li   bool is_read = stream_utils::IsReadAccessMode(in_mode);
272*1a96fba6SXin Li   bool is_write = stream_utils::IsWriteAccessMode(in_mode);
273*1a96fba6SXin Li   is_read |= need_more_read_;
274*1a96fba6SXin Li   is_write |= need_more_write_;
275*1a96fba6SXin Li   need_more_read_ = need_more_write_ = false;
276*1a96fba6SXin Li   if (is_read && SSL_pending(ssl_.get()) > 0) {
277*1a96fba6SXin Li     if (out_mode)
278*1a96fba6SXin Li       *out_mode = AccessMode::READ;
279*1a96fba6SXin Li     return true;
280*1a96fba6SXin Li   }
281*1a96fba6SXin Li   in_mode = stream_utils::MakeAccessMode(is_read, is_write);
282*1a96fba6SXin Li   return socket_->WaitForDataBlocking(in_mode, timeout, out_mode, error);
283*1a96fba6SXin Li }
284*1a96fba6SXin Li 
CancelPendingAsyncOperations()285*1a96fba6SXin Li void TlsStream::TlsStreamImpl::CancelPendingAsyncOperations() {
286*1a96fba6SXin Li   socket_->CancelPendingAsyncOperations();
287*1a96fba6SXin Li   weak_ptr_factory_.InvalidateWeakPtrs();
288*1a96fba6SXin Li }
289*1a96fba6SXin Li 
ReportError(ErrorPtr * error,const base::Location & location,const std::string & message)290*1a96fba6SXin Li bool TlsStream::TlsStreamImpl::ReportError(
291*1a96fba6SXin Li     ErrorPtr* error,
292*1a96fba6SXin Li     const base::Location& location,
293*1a96fba6SXin Li     const std::string& message) {
294*1a96fba6SXin Li   const char* file = nullptr;
295*1a96fba6SXin Li   int line = 0;
296*1a96fba6SXin Li   const char* data = 0;
297*1a96fba6SXin Li   int flags = 0;
298*1a96fba6SXin Li   while (auto errnum = ERR_get_error_line_data(&file, &line, &data, &flags)) {
299*1a96fba6SXin Li     char buf[256];
300*1a96fba6SXin Li     ERR_error_string_n(errnum, buf, sizeof(buf));
301*1a96fba6SXin Li     base::Location ssl_location{"Unknown", file, line, nullptr};
302*1a96fba6SXin Li     std::string ssl_message = buf;
303*1a96fba6SXin Li     if (flags & ERR_TXT_STRING) {
304*1a96fba6SXin Li       ssl_message += ": ";
305*1a96fba6SXin Li       ssl_message += data;
306*1a96fba6SXin Li     }
307*1a96fba6SXin Li     Error::AddTo(error, ssl_location, "openssl", std::to_string(errnum),
308*1a96fba6SXin Li                  ssl_message);
309*1a96fba6SXin Li   }
310*1a96fba6SXin Li   Error::AddTo(error, location, "tls_stream", "failed", message);
311*1a96fba6SXin Li   return false;
312*1a96fba6SXin Li }
313*1a96fba6SXin Li 
OnCertVerifyResults(int ok,X509_STORE_CTX * ctx)314*1a96fba6SXin Li int TlsStream::TlsStreamImpl::OnCertVerifyResults(int ok, X509_STORE_CTX* ctx) {
315*1a96fba6SXin Li   // OpenSSL already performs a comprehensive check of the certificate chain
316*1a96fba6SXin Li   // (using X509_verify_cert() function) and calls back with the result of its
317*1a96fba6SXin Li   // verification.
318*1a96fba6SXin Li   // |ok| is set to 1 if the verification passed and 0 if an error was detected.
319*1a96fba6SXin Li   // Here we can perform some additional checks if we need to, or simply log
320*1a96fba6SXin Li   // the issues found.
321*1a96fba6SXin Li 
322*1a96fba6SXin Li   // For now, just log an error if it occurred.
323*1a96fba6SXin Li   if (!ok) {
324*1a96fba6SXin Li     LOG(ERROR) << "Server certificate validation failed: "
325*1a96fba6SXin Li                << X509_verify_cert_error_string(X509_STORE_CTX_get_error(ctx));
326*1a96fba6SXin Li   }
327*1a96fba6SXin Li   return ok;
328*1a96fba6SXin Li }
329*1a96fba6SXin Li 
OnCertVerifyResultsStatic(int ok,X509_STORE_CTX * ctx)330*1a96fba6SXin Li int TlsStream::TlsStreamImpl::OnCertVerifyResultsStatic(int ok,
331*1a96fba6SXin Li                                                         X509_STORE_CTX* ctx) {
332*1a96fba6SXin Li   // Obtain the pointer to the instance of TlsStream::TlsStreamImpl from the
333*1a96fba6SXin Li   // SSL CTX object referenced by |ctx|.
334*1a96fba6SXin Li   SSL* ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(
335*1a96fba6SXin Li       ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
336*1a96fba6SXin Li   SSL_CTX* ssl_ctx = ssl ? SSL_get_SSL_CTX(ssl) : nullptr;
337*1a96fba6SXin Li   TlsStream::TlsStreamImpl* self = nullptr;
338*1a96fba6SXin Li   if (ssl_ctx) {
339*1a96fba6SXin Li     self = static_cast<TlsStream::TlsStreamImpl*>(SSL_CTX_get_ex_data(
340*1a96fba6SXin Li         ssl_ctx, ssl_ctx_private_data_index));
341*1a96fba6SXin Li   }
342*1a96fba6SXin Li   return self ? self->OnCertVerifyResults(ok, ctx) : ok;
343*1a96fba6SXin Li }
344*1a96fba6SXin Li 
Init(StreamPtr socket,const std::string & host,const base::Closure & success_callback,const Stream::ErrorCallback & error_callback,ErrorPtr * error)345*1a96fba6SXin Li bool TlsStream::TlsStreamImpl::Init(StreamPtr socket,
346*1a96fba6SXin Li                                     const std::string& host,
347*1a96fba6SXin Li                                     const base::Closure& success_callback,
348*1a96fba6SXin Li                                     const Stream::ErrorCallback& error_callback,
349*1a96fba6SXin Li                                     ErrorPtr* error) {
350*1a96fba6SXin Li   ctx_.reset(SSL_CTX_new(TLS_client_method()));
351*1a96fba6SXin Li   if (!ctx_)
352*1a96fba6SXin Li     return ReportError(error, FROM_HERE, "Cannot create SSL_CTX");
353*1a96fba6SXin Li 
354*1a96fba6SXin Li   // Top cipher suites supported by both Google GFEs and OpenSSL (in server
355*1a96fba6SXin Li   // preferred order).
356*1a96fba6SXin Li   int res = SSL_CTX_set_cipher_list(ctx_.get(),
357*1a96fba6SXin Li                                     "ECDHE-ECDSA-AES128-GCM-SHA256:"
358*1a96fba6SXin Li                                     "ECDHE-ECDSA-AES256-GCM-SHA384:"
359*1a96fba6SXin Li                                     "ECDHE-RSA-AES128-GCM-SHA256:"
360*1a96fba6SXin Li                                     "ECDHE-RSA-AES256-GCM-SHA384");
361*1a96fba6SXin Li   if (res != 1)
362*1a96fba6SXin Li     return ReportError(error, FROM_HERE, "Cannot set the cipher list");
363*1a96fba6SXin Li 
364*1a96fba6SXin Li   res = SSL_CTX_load_verify_locations(ctx_.get(), nullptr, kCACertificatePath);
365*1a96fba6SXin Li   if (res != 1) {
366*1a96fba6SXin Li     return ReportError(error, FROM_HERE,
367*1a96fba6SXin Li                        "Failed to specify trusted certificate location");
368*1a96fba6SXin Li   }
369*1a96fba6SXin Li 
370*1a96fba6SXin Li   // Store a pointer to "this" into SSL_CTX instance.
371*1a96fba6SXin Li   SSL_CTX_set_ex_data(ctx_.get(), ssl_ctx_private_data_index, this);
372*1a96fba6SXin Li 
373*1a96fba6SXin Li   // Ask OpenSSL to validate the server host from the certificate to match
374*1a96fba6SXin Li   // the expected host name we are given:
375*1a96fba6SXin Li   X509_VERIFY_PARAM* param = SSL_CTX_get0_param(ctx_.get());
376*1a96fba6SXin Li   X509_VERIFY_PARAM_set1_host(param, host.c_str(), host.size());
377*1a96fba6SXin Li 
378*1a96fba6SXin Li   SSL_CTX_set_verify(ctx_.get(), SSL_VERIFY_PEER,
379*1a96fba6SXin Li                      &TlsStreamImpl::OnCertVerifyResultsStatic);
380*1a96fba6SXin Li 
381*1a96fba6SXin Li   socket_ = std::move(socket);
382*1a96fba6SXin Li   ssl_.reset(SSL_new(ctx_.get()));
383*1a96fba6SXin Li 
384*1a96fba6SXin Li   // Enable TLS progress callback if VLOG level is >=3.
385*1a96fba6SXin Li   if (VLOG_IS_ON(3))
386*1a96fba6SXin Li     SSL_set_info_callback(ssl_.get(), TlsInfoCallback);
387*1a96fba6SXin Li 
388*1a96fba6SXin Li   stream_bio_ = BIO_new_stream(socket_.get());
389*1a96fba6SXin Li   SSL_set_bio(ssl_.get(), stream_bio_, stream_bio_);
390*1a96fba6SXin Li   SSL_set_connect_state(ssl_.get());
391*1a96fba6SXin Li 
392*1a96fba6SXin Li   // We might have no message loop (e.g. we are in unit tests).
393*1a96fba6SXin Li   if (MessageLoop::ThreadHasCurrent()) {
394*1a96fba6SXin Li     MessageLoop::current()->PostTask(
395*1a96fba6SXin Li         FROM_HERE,
396*1a96fba6SXin Li         base::BindOnce(&TlsStreamImpl::DoHandshake,
397*1a96fba6SXin Li                        weak_ptr_factory_.GetWeakPtr(),
398*1a96fba6SXin Li                        success_callback,
399*1a96fba6SXin Li                        error_callback));
400*1a96fba6SXin Li   } else {
401*1a96fba6SXin Li     DoHandshake(success_callback, error_callback);
402*1a96fba6SXin Li   }
403*1a96fba6SXin Li   return true;
404*1a96fba6SXin Li }
405*1a96fba6SXin Li 
RetryHandshake(const base::Closure & success_callback,const Stream::ErrorCallback & error_callback,Stream::AccessMode)406*1a96fba6SXin Li void TlsStream::TlsStreamImpl::RetryHandshake(
407*1a96fba6SXin Li     const base::Closure& success_callback,
408*1a96fba6SXin Li     const Stream::ErrorCallback& error_callback,
409*1a96fba6SXin Li     Stream::AccessMode /* mode */) {
410*1a96fba6SXin Li   VLOG(1) << "Retrying TLS handshake";
411*1a96fba6SXin Li   DoHandshake(success_callback, error_callback);
412*1a96fba6SXin Li }
413*1a96fba6SXin Li 
DoHandshake(const base::Closure & success_callback,const Stream::ErrorCallback & error_callback)414*1a96fba6SXin Li void TlsStream::TlsStreamImpl::DoHandshake(
415*1a96fba6SXin Li     const base::Closure& success_callback,
416*1a96fba6SXin Li     const Stream::ErrorCallback& error_callback) {
417*1a96fba6SXin Li   VLOG(1) << "Begin TLS handshake";
418*1a96fba6SXin Li   int res = SSL_do_handshake(ssl_.get());
419*1a96fba6SXin Li   if (res == 1) {
420*1a96fba6SXin Li     VLOG(1) << "Handshake successful";
421*1a96fba6SXin Li     success_callback.Run();
422*1a96fba6SXin Li     return;
423*1a96fba6SXin Li   }
424*1a96fba6SXin Li   ErrorPtr error;
425*1a96fba6SXin Li   int err = SSL_get_error(ssl_.get(), res);
426*1a96fba6SXin Li   if (err == SSL_ERROR_WANT_READ) {
427*1a96fba6SXin Li     VLOG(1) << "Waiting for read data...";
428*1a96fba6SXin Li     bool ok = socket_->WaitForData(
429*1a96fba6SXin Li         Stream::AccessMode::READ,
430*1a96fba6SXin Li         base::Bind(&TlsStreamImpl::RetryHandshake,
431*1a96fba6SXin Li                    weak_ptr_factory_.GetWeakPtr(),
432*1a96fba6SXin Li                    success_callback, error_callback),
433*1a96fba6SXin Li         &error);
434*1a96fba6SXin Li     if (ok)
435*1a96fba6SXin Li       return;
436*1a96fba6SXin Li   } else if (err == SSL_ERROR_WANT_WRITE) {
437*1a96fba6SXin Li     VLOG(1) << "Waiting for write data...";
438*1a96fba6SXin Li     bool ok = socket_->WaitForData(
439*1a96fba6SXin Li         Stream::AccessMode::WRITE,
440*1a96fba6SXin Li         base::Bind(&TlsStreamImpl::RetryHandshake,
441*1a96fba6SXin Li                    weak_ptr_factory_.GetWeakPtr(),
442*1a96fba6SXin Li                    success_callback, error_callback),
443*1a96fba6SXin Li         &error);
444*1a96fba6SXin Li     if (ok)
445*1a96fba6SXin Li       return;
446*1a96fba6SXin Li   } else {
447*1a96fba6SXin Li     ReportError(&error, FROM_HERE, "TLS handshake failed.");
448*1a96fba6SXin Li   }
449*1a96fba6SXin Li   error_callback.Run(error.get());
450*1a96fba6SXin Li }
451*1a96fba6SXin Li 
452*1a96fba6SXin Li /////////////////////////////////////////////////////////////////////////////
TlsStream(std::unique_ptr<TlsStreamImpl> impl)453*1a96fba6SXin Li TlsStream::TlsStream(std::unique_ptr<TlsStreamImpl> impl)
454*1a96fba6SXin Li     : impl_{std::move(impl)} {}
455*1a96fba6SXin Li 
~TlsStream()456*1a96fba6SXin Li TlsStream::~TlsStream() {
457*1a96fba6SXin Li   if (impl_) {
458*1a96fba6SXin Li     impl_->Close(nullptr);
459*1a96fba6SXin Li   }
460*1a96fba6SXin Li }
461*1a96fba6SXin Li 
Connect(StreamPtr socket,const std::string & host,const base::Callback<void (StreamPtr)> & success_callback,const Stream::ErrorCallback & error_callback)462*1a96fba6SXin Li void TlsStream::Connect(StreamPtr socket,
463*1a96fba6SXin Li                         const std::string& host,
464*1a96fba6SXin Li                         const base::Callback<void(StreamPtr)>& success_callback,
465*1a96fba6SXin Li                         const Stream::ErrorCallback& error_callback) {
466*1a96fba6SXin Li   std::unique_ptr<TlsStreamImpl> impl{new TlsStreamImpl};
467*1a96fba6SXin Li   std::unique_ptr<TlsStream> stream{new TlsStream{std::move(impl)}};
468*1a96fba6SXin Li 
469*1a96fba6SXin Li   TlsStreamImpl* pimpl = stream->impl_.get();
470*1a96fba6SXin Li   ErrorPtr error;
471*1a96fba6SXin Li   bool success = pimpl->Init(std::move(socket), host,
472*1a96fba6SXin Li                              base::Bind(success_callback,
473*1a96fba6SXin Li                                         base::Passed(std::move(stream))),
474*1a96fba6SXin Li                              error_callback, &error);
475*1a96fba6SXin Li 
476*1a96fba6SXin Li   if (!success)
477*1a96fba6SXin Li     error_callback.Run(error.get());
478*1a96fba6SXin Li }
479*1a96fba6SXin Li 
IsOpen() const480*1a96fba6SXin Li bool TlsStream::IsOpen() const {
481*1a96fba6SXin Li   return impl_ ? true : false;
482*1a96fba6SXin Li }
483*1a96fba6SXin Li 
SetSizeBlocking(uint64_t,ErrorPtr * error)484*1a96fba6SXin Li bool TlsStream::SetSizeBlocking(uint64_t /* size */, ErrorPtr* error) {
485*1a96fba6SXin Li   return stream_utils::ErrorOperationNotSupported(FROM_HERE, error);
486*1a96fba6SXin Li }
487*1a96fba6SXin Li 
Seek(int64_t,Whence,uint64_t *,ErrorPtr * error)488*1a96fba6SXin Li bool TlsStream::Seek(int64_t /* offset */,
489*1a96fba6SXin Li                      Whence /* whence */,
490*1a96fba6SXin Li                      uint64_t* /* new_position*/,
491*1a96fba6SXin Li                      ErrorPtr* error) {
492*1a96fba6SXin Li   return stream_utils::ErrorOperationNotSupported(FROM_HERE, error);
493*1a96fba6SXin Li }
494*1a96fba6SXin Li 
ReadNonBlocking(void * buffer,size_t size_to_read,size_t * size_read,bool * end_of_stream,ErrorPtr * error)495*1a96fba6SXin Li bool TlsStream::ReadNonBlocking(void* buffer,
496*1a96fba6SXin Li                                 size_t size_to_read,
497*1a96fba6SXin Li                                 size_t* size_read,
498*1a96fba6SXin Li                                 bool* end_of_stream,
499*1a96fba6SXin Li                                 ErrorPtr* error) {
500*1a96fba6SXin Li   if (!impl_)
501*1a96fba6SXin Li     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
502*1a96fba6SXin Li   return impl_->ReadNonBlocking(buffer, size_to_read, size_read, end_of_stream,
503*1a96fba6SXin Li                                 error);
504*1a96fba6SXin Li }
505*1a96fba6SXin Li 
WriteNonBlocking(const void * buffer,size_t size_to_write,size_t * size_written,ErrorPtr * error)506*1a96fba6SXin Li bool TlsStream::WriteNonBlocking(const void* buffer,
507*1a96fba6SXin Li                                  size_t size_to_write,
508*1a96fba6SXin Li                                  size_t* size_written,
509*1a96fba6SXin Li                                  ErrorPtr* error) {
510*1a96fba6SXin Li   if (!impl_)
511*1a96fba6SXin Li     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
512*1a96fba6SXin Li   return impl_->WriteNonBlocking(buffer, size_to_write, size_written, error);
513*1a96fba6SXin Li }
514*1a96fba6SXin Li 
FlushBlocking(ErrorPtr * error)515*1a96fba6SXin Li bool TlsStream::FlushBlocking(ErrorPtr* error) {
516*1a96fba6SXin Li   if (!impl_)
517*1a96fba6SXin Li     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
518*1a96fba6SXin Li   return impl_->Flush(error);
519*1a96fba6SXin Li }
520*1a96fba6SXin Li 
CloseBlocking(ErrorPtr * error)521*1a96fba6SXin Li bool TlsStream::CloseBlocking(ErrorPtr* error) {
522*1a96fba6SXin Li   if (impl_ && !impl_->Close(error))
523*1a96fba6SXin Li     return false;
524*1a96fba6SXin Li   impl_.reset();
525*1a96fba6SXin Li   return true;
526*1a96fba6SXin Li }
527*1a96fba6SXin Li 
WaitForData(AccessMode mode,const base::Callback<void (AccessMode)> & callback,ErrorPtr * error)528*1a96fba6SXin Li bool TlsStream::WaitForData(AccessMode mode,
529*1a96fba6SXin Li                             const base::Callback<void(AccessMode)>& callback,
530*1a96fba6SXin Li                             ErrorPtr* error) {
531*1a96fba6SXin Li   if (!impl_)
532*1a96fba6SXin Li     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
533*1a96fba6SXin Li   return impl_->WaitForData(mode, callback, error);
534*1a96fba6SXin Li }
535*1a96fba6SXin Li 
WaitForDataBlocking(AccessMode in_mode,base::TimeDelta timeout,AccessMode * out_mode,ErrorPtr * error)536*1a96fba6SXin Li bool TlsStream::WaitForDataBlocking(AccessMode in_mode,
537*1a96fba6SXin Li                                     base::TimeDelta timeout,
538*1a96fba6SXin Li                                     AccessMode* out_mode,
539*1a96fba6SXin Li                                     ErrorPtr* error) {
540*1a96fba6SXin Li   if (!impl_)
541*1a96fba6SXin Li     return stream_utils::ErrorStreamClosed(FROM_HERE, error);
542*1a96fba6SXin Li   return impl_->WaitForDataBlocking(in_mode, timeout, out_mode, error);
543*1a96fba6SXin Li }
544*1a96fba6SXin Li 
CancelPendingAsyncOperations()545*1a96fba6SXin Li void TlsStream::CancelPendingAsyncOperations() {
546*1a96fba6SXin Li   if (impl_)
547*1a96fba6SXin Li     impl_->CancelPendingAsyncOperations();
548*1a96fba6SXin Li   Stream::CancelPendingAsyncOperations();
549*1a96fba6SXin Li }
550*1a96fba6SXin Li 
551*1a96fba6SXin Li }  // namespace brillo
552