1*1a96fba6SXin Li // Copyright 2015 The Chromium OS Authors. All rights reserved.
2*1a96fba6SXin Li // Use of this source code is governed by a BSD-style license that can be
3*1a96fba6SXin Li // found in the LICENSE file.
4*1a96fba6SXin Li
5*1a96fba6SXin Li #include <brillo/streams/tls_stream.h>
6*1a96fba6SXin Li
7*1a96fba6SXin Li #include <algorithm>
8*1a96fba6SXin Li #include <limits>
9*1a96fba6SXin Li #include <string>
10*1a96fba6SXin Li #include <utility>
11*1a96fba6SXin Li #include <vector>
12*1a96fba6SXin Li
13*1a96fba6SXin Li #include <openssl/err.h>
14*1a96fba6SXin Li #include <openssl/ssl.h>
15*1a96fba6SXin Li
16*1a96fba6SXin Li #include <base/bind.h>
17*1a96fba6SXin Li #include <base/memory/weak_ptr.h>
18*1a96fba6SXin Li #include <brillo/message_loops/message_loop.h>
19*1a96fba6SXin Li #include <brillo/secure_blob.h>
20*1a96fba6SXin Li #include <brillo/streams/openssl_stream_bio.h>
21*1a96fba6SXin Li #include <brillo/streams/stream_utils.h>
22*1a96fba6SXin Li #include <brillo/strings/string_utils.h>
23*1a96fba6SXin Li
24*1a96fba6SXin Li namespace {
25*1a96fba6SXin Li
26*1a96fba6SXin Li // SSL info callback which is called by OpenSSL when we enable logging level of
27*1a96fba6SXin Li // at least 3. This logs the information about the internal TLS handshake.
TlsInfoCallback(const SSL *,int where,int ret)28*1a96fba6SXin Li void TlsInfoCallback(const SSL* /* ssl */, int where, int ret) {
29*1a96fba6SXin Li std::string reason;
30*1a96fba6SXin Li std::vector<std::string> info;
31*1a96fba6SXin Li if (where & SSL_CB_LOOP)
32*1a96fba6SXin Li info.push_back("loop");
33*1a96fba6SXin Li if (where & SSL_CB_EXIT)
34*1a96fba6SXin Li info.push_back("exit");
35*1a96fba6SXin Li if (where & SSL_CB_READ)
36*1a96fba6SXin Li info.push_back("read");
37*1a96fba6SXin Li if (where & SSL_CB_WRITE)
38*1a96fba6SXin Li info.push_back("write");
39*1a96fba6SXin Li if (where & SSL_CB_ALERT) {
40*1a96fba6SXin Li info.push_back("alert");
41*1a96fba6SXin Li reason = ", reason: ";
42*1a96fba6SXin Li reason += SSL_alert_type_string_long(ret);
43*1a96fba6SXin Li reason += "/";
44*1a96fba6SXin Li reason += SSL_alert_desc_string_long(ret);
45*1a96fba6SXin Li }
46*1a96fba6SXin Li if (where & SSL_CB_HANDSHAKE_START)
47*1a96fba6SXin Li info.push_back("handshake_start");
48*1a96fba6SXin Li if (where & SSL_CB_HANDSHAKE_DONE)
49*1a96fba6SXin Li info.push_back("handshake_done");
50*1a96fba6SXin Li
51*1a96fba6SXin Li VLOG(3) << "TLS progress info: " << brillo::string_utils::Join(",", info)
52*1a96fba6SXin Li << ", with status: " << ret << reason;
53*1a96fba6SXin Li }
54*1a96fba6SXin Li
55*1a96fba6SXin Li // Static variable to store the index of TlsStream private data in SSL context
56*1a96fba6SXin Li // used to store custom data for OnCertVerifyResults().
57*1a96fba6SXin Li int ssl_ctx_private_data_index = -1;
58*1a96fba6SXin Li
59*1a96fba6SXin Li // Default trusted certificate store location.
60*1a96fba6SXin Li const char kCACertificatePath[] =
61*1a96fba6SXin Li #ifdef __ANDROID__
62*1a96fba6SXin Li "/system/etc/security/cacerts_google";
63*1a96fba6SXin Li #else
64*1a96fba6SXin Li "/usr/share/chromeos-ca-certificates";
65*1a96fba6SXin Li #endif
66*1a96fba6SXin Li
67*1a96fba6SXin Li } // anonymous namespace
68*1a96fba6SXin Li
69*1a96fba6SXin Li namespace brillo {
70*1a96fba6SXin Li
71*1a96fba6SXin Li // TODO(crbug.com/984789): Remove once support for OpenSSL <1.1 is dropped.
72*1a96fba6SXin Li #if OPENSSL_VERSION_NUMBER < 0x10100000L
73*1a96fba6SXin Li #define TLS_client_method() TLSv1_2_client_method()
74*1a96fba6SXin Li #endif
75*1a96fba6SXin Li
76*1a96fba6SXin Li // Helper implementation of TLS stream used to hide most of OpenSSL inner
77*1a96fba6SXin Li // workings from the users of brillo::TlsStream.
78*1a96fba6SXin Li class TlsStream::TlsStreamImpl {
79*1a96fba6SXin Li public:
80*1a96fba6SXin Li TlsStreamImpl();
81*1a96fba6SXin Li ~TlsStreamImpl();
82*1a96fba6SXin Li
83*1a96fba6SXin Li bool Init(StreamPtr socket,
84*1a96fba6SXin Li const std::string& host,
85*1a96fba6SXin Li const base::Closure& success_callback,
86*1a96fba6SXin Li const Stream::ErrorCallback& error_callback,
87*1a96fba6SXin Li ErrorPtr* error);
88*1a96fba6SXin Li
89*1a96fba6SXin Li bool ReadNonBlocking(void* buffer,
90*1a96fba6SXin Li size_t size_to_read,
91*1a96fba6SXin Li size_t* size_read,
92*1a96fba6SXin Li bool* end_of_stream,
93*1a96fba6SXin Li ErrorPtr* error);
94*1a96fba6SXin Li
95*1a96fba6SXin Li bool WriteNonBlocking(const void* buffer,
96*1a96fba6SXin Li size_t size_to_write,
97*1a96fba6SXin Li size_t* size_written,
98*1a96fba6SXin Li ErrorPtr* error);
99*1a96fba6SXin Li
100*1a96fba6SXin Li bool Flush(ErrorPtr* error);
101*1a96fba6SXin Li bool Close(ErrorPtr* error);
102*1a96fba6SXin Li bool WaitForData(AccessMode mode,
103*1a96fba6SXin Li const base::Callback<void(AccessMode)>& callback,
104*1a96fba6SXin Li ErrorPtr* error);
105*1a96fba6SXin Li bool WaitForDataBlocking(AccessMode in_mode,
106*1a96fba6SXin Li base::TimeDelta timeout,
107*1a96fba6SXin Li AccessMode* out_mode,
108*1a96fba6SXin Li ErrorPtr* error);
109*1a96fba6SXin Li void CancelPendingAsyncOperations();
110*1a96fba6SXin Li
111*1a96fba6SXin Li private:
112*1a96fba6SXin Li bool ReportError(ErrorPtr* error,
113*1a96fba6SXin Li const base::Location& location,
114*1a96fba6SXin Li const std::string& message);
115*1a96fba6SXin Li void DoHandshake(const base::Closure& success_callback,
116*1a96fba6SXin Li const Stream::ErrorCallback& error_callback);
117*1a96fba6SXin Li void RetryHandshake(const base::Closure& success_callback,
118*1a96fba6SXin Li const Stream::ErrorCallback& error_callback,
119*1a96fba6SXin Li Stream::AccessMode mode);
120*1a96fba6SXin Li
121*1a96fba6SXin Li int OnCertVerifyResults(int ok, X509_STORE_CTX* ctx);
122*1a96fba6SXin Li static int OnCertVerifyResultsStatic(int ok, X509_STORE_CTX* ctx);
123*1a96fba6SXin Li
124*1a96fba6SXin Li StreamPtr socket_;
125*1a96fba6SXin Li std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> ctx_{nullptr, SSL_CTX_free};
126*1a96fba6SXin Li std::unique_ptr<SSL, decltype(&SSL_free)> ssl_{nullptr, SSL_free};
127*1a96fba6SXin Li BIO* stream_bio_{nullptr};
128*1a96fba6SXin Li bool need_more_read_{false};
129*1a96fba6SXin Li bool need_more_write_{false};
130*1a96fba6SXin Li
131*1a96fba6SXin Li base::WeakPtrFactory<TlsStreamImpl> weak_ptr_factory_{this};
132*1a96fba6SXin Li DISALLOW_COPY_AND_ASSIGN(TlsStreamImpl);
133*1a96fba6SXin Li };
134*1a96fba6SXin Li
TlsStreamImpl()135*1a96fba6SXin Li TlsStream::TlsStreamImpl::TlsStreamImpl() {
136*1a96fba6SXin Li SSL_load_error_strings();
137*1a96fba6SXin Li SSL_library_init();
138*1a96fba6SXin Li if (ssl_ctx_private_data_index < 0) {
139*1a96fba6SXin Li ssl_ctx_private_data_index =
140*1a96fba6SXin Li SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
141*1a96fba6SXin Li }
142*1a96fba6SXin Li }
143*1a96fba6SXin Li
~TlsStreamImpl()144*1a96fba6SXin Li TlsStream::TlsStreamImpl::~TlsStreamImpl() {
145*1a96fba6SXin Li ssl_.reset();
146*1a96fba6SXin Li ctx_.reset();
147*1a96fba6SXin Li }
148*1a96fba6SXin Li
ReadNonBlocking(void * buffer,size_t size_to_read,size_t * size_read,bool * end_of_stream,ErrorPtr * error)149*1a96fba6SXin Li bool TlsStream::TlsStreamImpl::ReadNonBlocking(void* buffer,
150*1a96fba6SXin Li size_t size_to_read,
151*1a96fba6SXin Li size_t* size_read,
152*1a96fba6SXin Li bool* end_of_stream,
153*1a96fba6SXin Li ErrorPtr* error) {
154*1a96fba6SXin Li const size_t max_int = std::numeric_limits<int>::max();
155*1a96fba6SXin Li int size_int = static_cast<int>(std::min(size_to_read, max_int));
156*1a96fba6SXin Li int ret = SSL_read(ssl_.get(), buffer, size_int);
157*1a96fba6SXin Li if (ret > 0) {
158*1a96fba6SXin Li *size_read = static_cast<size_t>(ret);
159*1a96fba6SXin Li if (end_of_stream)
160*1a96fba6SXin Li *end_of_stream = false;
161*1a96fba6SXin Li return true;
162*1a96fba6SXin Li }
163*1a96fba6SXin Li
164*1a96fba6SXin Li int err = SSL_get_error(ssl_.get(), ret);
165*1a96fba6SXin Li if (err == SSL_ERROR_ZERO_RETURN) {
166*1a96fba6SXin Li *size_read = 0;
167*1a96fba6SXin Li if (end_of_stream)
168*1a96fba6SXin Li *end_of_stream = true;
169*1a96fba6SXin Li return true;
170*1a96fba6SXin Li }
171*1a96fba6SXin Li
172*1a96fba6SXin Li if (err == SSL_ERROR_WANT_READ) {
173*1a96fba6SXin Li need_more_read_ = true;
174*1a96fba6SXin Li } else if (err == SSL_ERROR_WANT_WRITE) {
175*1a96fba6SXin Li // Writes might be required for SSL_read() because of possible TLS
176*1a96fba6SXin Li // re-negotiations which can happen at any time.
177*1a96fba6SXin Li need_more_write_ = true;
178*1a96fba6SXin Li } else {
179*1a96fba6SXin Li return ReportError(error, FROM_HERE, "Error reading from TLS socket");
180*1a96fba6SXin Li }
181*1a96fba6SXin Li *size_read = 0;
182*1a96fba6SXin Li if (end_of_stream)
183*1a96fba6SXin Li *end_of_stream = false;
184*1a96fba6SXin Li return true;
185*1a96fba6SXin Li }
186*1a96fba6SXin Li
WriteNonBlocking(const void * buffer,size_t size_to_write,size_t * size_written,ErrorPtr * error)187*1a96fba6SXin Li bool TlsStream::TlsStreamImpl::WriteNonBlocking(const void* buffer,
188*1a96fba6SXin Li size_t size_to_write,
189*1a96fba6SXin Li size_t* size_written,
190*1a96fba6SXin Li ErrorPtr* error) {
191*1a96fba6SXin Li const size_t max_int = std::numeric_limits<int>::max();
192*1a96fba6SXin Li int size_int = static_cast<int>(std::min(size_to_write, max_int));
193*1a96fba6SXin Li int ret = SSL_write(ssl_.get(), buffer, size_int);
194*1a96fba6SXin Li if (ret > 0) {
195*1a96fba6SXin Li *size_written = static_cast<size_t>(ret);
196*1a96fba6SXin Li return true;
197*1a96fba6SXin Li }
198*1a96fba6SXin Li
199*1a96fba6SXin Li int err = SSL_get_error(ssl_.get(), ret);
200*1a96fba6SXin Li if (err == SSL_ERROR_WANT_READ) {
201*1a96fba6SXin Li // Reads might be required for SSL_write() because of possible TLS
202*1a96fba6SXin Li // re-negotiations which can happen at any time.
203*1a96fba6SXin Li need_more_read_ = true;
204*1a96fba6SXin Li } else if (err == SSL_ERROR_WANT_WRITE) {
205*1a96fba6SXin Li need_more_write_ = true;
206*1a96fba6SXin Li } else {
207*1a96fba6SXin Li return ReportError(error, FROM_HERE, "Error writing to TLS socket");
208*1a96fba6SXin Li }
209*1a96fba6SXin Li *size_written = 0;
210*1a96fba6SXin Li return true;
211*1a96fba6SXin Li }
212*1a96fba6SXin Li
Flush(ErrorPtr * error)213*1a96fba6SXin Li bool TlsStream::TlsStreamImpl::Flush(ErrorPtr* error) {
214*1a96fba6SXin Li return socket_->FlushBlocking(error);
215*1a96fba6SXin Li }
216*1a96fba6SXin Li
Close(ErrorPtr * error)217*1a96fba6SXin Li bool TlsStream::TlsStreamImpl::Close(ErrorPtr* error) {
218*1a96fba6SXin Li // 2 seconds should be plenty here.
219*1a96fba6SXin Li const base::TimeDelta kTimeout = base::TimeDelta::FromSeconds(2);
220*1a96fba6SXin Li // The retry count of 4 below is just arbitrary, to ensure we don't get stuck
221*1a96fba6SXin Li // here forever. We should rarely need to repeat SSL_shutdown anyway.
222*1a96fba6SXin Li for (int retry_count = 0; retry_count < 4; retry_count++) {
223*1a96fba6SXin Li int ret = SSL_shutdown(ssl_.get());
224*1a96fba6SXin Li // We really don't care for bi-directional shutdown here.
225*1a96fba6SXin Li // Just make sure we only send the "close notify" alert to the remote peer.
226*1a96fba6SXin Li if (ret >= 0)
227*1a96fba6SXin Li break;
228*1a96fba6SXin Li
229*1a96fba6SXin Li int err = SSL_get_error(ssl_.get(), ret);
230*1a96fba6SXin Li if (err == SSL_ERROR_WANT_READ) {
231*1a96fba6SXin Li if (!socket_->WaitForDataBlocking(AccessMode::READ, kTimeout, nullptr,
232*1a96fba6SXin Li error)) {
233*1a96fba6SXin Li break;
234*1a96fba6SXin Li }
235*1a96fba6SXin Li } else if (err == SSL_ERROR_WANT_WRITE) {
236*1a96fba6SXin Li if (!socket_->WaitForDataBlocking(AccessMode::WRITE, kTimeout, nullptr,
237*1a96fba6SXin Li error)) {
238*1a96fba6SXin Li break;
239*1a96fba6SXin Li }
240*1a96fba6SXin Li } else {
241*1a96fba6SXin Li LOG(ERROR) << "SSL_shutdown returned error #" << err;
242*1a96fba6SXin Li ReportError(error, FROM_HERE, "Failed to shut down TLS socket");
243*1a96fba6SXin Li break;
244*1a96fba6SXin Li }
245*1a96fba6SXin Li }
246*1a96fba6SXin Li return socket_->CloseBlocking(error);
247*1a96fba6SXin Li }
248*1a96fba6SXin Li
WaitForData(AccessMode mode,const base::Callback<void (AccessMode)> & callback,ErrorPtr * error)249*1a96fba6SXin Li bool TlsStream::TlsStreamImpl::WaitForData(
250*1a96fba6SXin Li AccessMode mode,
251*1a96fba6SXin Li const base::Callback<void(AccessMode)>& callback,
252*1a96fba6SXin Li ErrorPtr* error) {
253*1a96fba6SXin Li bool is_read = stream_utils::IsReadAccessMode(mode);
254*1a96fba6SXin Li bool is_write = stream_utils::IsWriteAccessMode(mode);
255*1a96fba6SXin Li is_read |= need_more_read_;
256*1a96fba6SXin Li is_write |= need_more_write_;
257*1a96fba6SXin Li need_more_read_ = false;
258*1a96fba6SXin Li need_more_write_ = false;
259*1a96fba6SXin Li if (is_read && SSL_pending(ssl_.get()) > 0) {
260*1a96fba6SXin Li callback.Run(AccessMode::READ);
261*1a96fba6SXin Li return true;
262*1a96fba6SXin Li }
263*1a96fba6SXin Li mode = stream_utils::MakeAccessMode(is_read, is_write);
264*1a96fba6SXin Li return socket_->WaitForData(mode, callback, error);
265*1a96fba6SXin Li }
266*1a96fba6SXin Li
WaitForDataBlocking(AccessMode in_mode,base::TimeDelta timeout,AccessMode * out_mode,ErrorPtr * error)267*1a96fba6SXin Li bool TlsStream::TlsStreamImpl::WaitForDataBlocking(AccessMode in_mode,
268*1a96fba6SXin Li base::TimeDelta timeout,
269*1a96fba6SXin Li AccessMode* out_mode,
270*1a96fba6SXin Li ErrorPtr* error) {
271*1a96fba6SXin Li bool is_read = stream_utils::IsReadAccessMode(in_mode);
272*1a96fba6SXin Li bool is_write = stream_utils::IsWriteAccessMode(in_mode);
273*1a96fba6SXin Li is_read |= need_more_read_;
274*1a96fba6SXin Li is_write |= need_more_write_;
275*1a96fba6SXin Li need_more_read_ = need_more_write_ = false;
276*1a96fba6SXin Li if (is_read && SSL_pending(ssl_.get()) > 0) {
277*1a96fba6SXin Li if (out_mode)
278*1a96fba6SXin Li *out_mode = AccessMode::READ;
279*1a96fba6SXin Li return true;
280*1a96fba6SXin Li }
281*1a96fba6SXin Li in_mode = stream_utils::MakeAccessMode(is_read, is_write);
282*1a96fba6SXin Li return socket_->WaitForDataBlocking(in_mode, timeout, out_mode, error);
283*1a96fba6SXin Li }
284*1a96fba6SXin Li
CancelPendingAsyncOperations()285*1a96fba6SXin Li void TlsStream::TlsStreamImpl::CancelPendingAsyncOperations() {
286*1a96fba6SXin Li socket_->CancelPendingAsyncOperations();
287*1a96fba6SXin Li weak_ptr_factory_.InvalidateWeakPtrs();
288*1a96fba6SXin Li }
289*1a96fba6SXin Li
ReportError(ErrorPtr * error,const base::Location & location,const std::string & message)290*1a96fba6SXin Li bool TlsStream::TlsStreamImpl::ReportError(
291*1a96fba6SXin Li ErrorPtr* error,
292*1a96fba6SXin Li const base::Location& location,
293*1a96fba6SXin Li const std::string& message) {
294*1a96fba6SXin Li const char* file = nullptr;
295*1a96fba6SXin Li int line = 0;
296*1a96fba6SXin Li const char* data = 0;
297*1a96fba6SXin Li int flags = 0;
298*1a96fba6SXin Li while (auto errnum = ERR_get_error_line_data(&file, &line, &data, &flags)) {
299*1a96fba6SXin Li char buf[256];
300*1a96fba6SXin Li ERR_error_string_n(errnum, buf, sizeof(buf));
301*1a96fba6SXin Li base::Location ssl_location{"Unknown", file, line, nullptr};
302*1a96fba6SXin Li std::string ssl_message = buf;
303*1a96fba6SXin Li if (flags & ERR_TXT_STRING) {
304*1a96fba6SXin Li ssl_message += ": ";
305*1a96fba6SXin Li ssl_message += data;
306*1a96fba6SXin Li }
307*1a96fba6SXin Li Error::AddTo(error, ssl_location, "openssl", std::to_string(errnum),
308*1a96fba6SXin Li ssl_message);
309*1a96fba6SXin Li }
310*1a96fba6SXin Li Error::AddTo(error, location, "tls_stream", "failed", message);
311*1a96fba6SXin Li return false;
312*1a96fba6SXin Li }
313*1a96fba6SXin Li
OnCertVerifyResults(int ok,X509_STORE_CTX * ctx)314*1a96fba6SXin Li int TlsStream::TlsStreamImpl::OnCertVerifyResults(int ok, X509_STORE_CTX* ctx) {
315*1a96fba6SXin Li // OpenSSL already performs a comprehensive check of the certificate chain
316*1a96fba6SXin Li // (using X509_verify_cert() function) and calls back with the result of its
317*1a96fba6SXin Li // verification.
318*1a96fba6SXin Li // |ok| is set to 1 if the verification passed and 0 if an error was detected.
319*1a96fba6SXin Li // Here we can perform some additional checks if we need to, or simply log
320*1a96fba6SXin Li // the issues found.
321*1a96fba6SXin Li
322*1a96fba6SXin Li // For now, just log an error if it occurred.
323*1a96fba6SXin Li if (!ok) {
324*1a96fba6SXin Li LOG(ERROR) << "Server certificate validation failed: "
325*1a96fba6SXin Li << X509_verify_cert_error_string(X509_STORE_CTX_get_error(ctx));
326*1a96fba6SXin Li }
327*1a96fba6SXin Li return ok;
328*1a96fba6SXin Li }
329*1a96fba6SXin Li
OnCertVerifyResultsStatic(int ok,X509_STORE_CTX * ctx)330*1a96fba6SXin Li int TlsStream::TlsStreamImpl::OnCertVerifyResultsStatic(int ok,
331*1a96fba6SXin Li X509_STORE_CTX* ctx) {
332*1a96fba6SXin Li // Obtain the pointer to the instance of TlsStream::TlsStreamImpl from the
333*1a96fba6SXin Li // SSL CTX object referenced by |ctx|.
334*1a96fba6SXin Li SSL* ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(
335*1a96fba6SXin Li ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
336*1a96fba6SXin Li SSL_CTX* ssl_ctx = ssl ? SSL_get_SSL_CTX(ssl) : nullptr;
337*1a96fba6SXin Li TlsStream::TlsStreamImpl* self = nullptr;
338*1a96fba6SXin Li if (ssl_ctx) {
339*1a96fba6SXin Li self = static_cast<TlsStream::TlsStreamImpl*>(SSL_CTX_get_ex_data(
340*1a96fba6SXin Li ssl_ctx, ssl_ctx_private_data_index));
341*1a96fba6SXin Li }
342*1a96fba6SXin Li return self ? self->OnCertVerifyResults(ok, ctx) : ok;
343*1a96fba6SXin Li }
344*1a96fba6SXin Li
Init(StreamPtr socket,const std::string & host,const base::Closure & success_callback,const Stream::ErrorCallback & error_callback,ErrorPtr * error)345*1a96fba6SXin Li bool TlsStream::TlsStreamImpl::Init(StreamPtr socket,
346*1a96fba6SXin Li const std::string& host,
347*1a96fba6SXin Li const base::Closure& success_callback,
348*1a96fba6SXin Li const Stream::ErrorCallback& error_callback,
349*1a96fba6SXin Li ErrorPtr* error) {
350*1a96fba6SXin Li ctx_.reset(SSL_CTX_new(TLS_client_method()));
351*1a96fba6SXin Li if (!ctx_)
352*1a96fba6SXin Li return ReportError(error, FROM_HERE, "Cannot create SSL_CTX");
353*1a96fba6SXin Li
354*1a96fba6SXin Li // Top cipher suites supported by both Google GFEs and OpenSSL (in server
355*1a96fba6SXin Li // preferred order).
356*1a96fba6SXin Li int res = SSL_CTX_set_cipher_list(ctx_.get(),
357*1a96fba6SXin Li "ECDHE-ECDSA-AES128-GCM-SHA256:"
358*1a96fba6SXin Li "ECDHE-ECDSA-AES256-GCM-SHA384:"
359*1a96fba6SXin Li "ECDHE-RSA-AES128-GCM-SHA256:"
360*1a96fba6SXin Li "ECDHE-RSA-AES256-GCM-SHA384");
361*1a96fba6SXin Li if (res != 1)
362*1a96fba6SXin Li return ReportError(error, FROM_HERE, "Cannot set the cipher list");
363*1a96fba6SXin Li
364*1a96fba6SXin Li res = SSL_CTX_load_verify_locations(ctx_.get(), nullptr, kCACertificatePath);
365*1a96fba6SXin Li if (res != 1) {
366*1a96fba6SXin Li return ReportError(error, FROM_HERE,
367*1a96fba6SXin Li "Failed to specify trusted certificate location");
368*1a96fba6SXin Li }
369*1a96fba6SXin Li
370*1a96fba6SXin Li // Store a pointer to "this" into SSL_CTX instance.
371*1a96fba6SXin Li SSL_CTX_set_ex_data(ctx_.get(), ssl_ctx_private_data_index, this);
372*1a96fba6SXin Li
373*1a96fba6SXin Li // Ask OpenSSL to validate the server host from the certificate to match
374*1a96fba6SXin Li // the expected host name we are given:
375*1a96fba6SXin Li X509_VERIFY_PARAM* param = SSL_CTX_get0_param(ctx_.get());
376*1a96fba6SXin Li X509_VERIFY_PARAM_set1_host(param, host.c_str(), host.size());
377*1a96fba6SXin Li
378*1a96fba6SXin Li SSL_CTX_set_verify(ctx_.get(), SSL_VERIFY_PEER,
379*1a96fba6SXin Li &TlsStreamImpl::OnCertVerifyResultsStatic);
380*1a96fba6SXin Li
381*1a96fba6SXin Li socket_ = std::move(socket);
382*1a96fba6SXin Li ssl_.reset(SSL_new(ctx_.get()));
383*1a96fba6SXin Li
384*1a96fba6SXin Li // Enable TLS progress callback if VLOG level is >=3.
385*1a96fba6SXin Li if (VLOG_IS_ON(3))
386*1a96fba6SXin Li SSL_set_info_callback(ssl_.get(), TlsInfoCallback);
387*1a96fba6SXin Li
388*1a96fba6SXin Li stream_bio_ = BIO_new_stream(socket_.get());
389*1a96fba6SXin Li SSL_set_bio(ssl_.get(), stream_bio_, stream_bio_);
390*1a96fba6SXin Li SSL_set_connect_state(ssl_.get());
391*1a96fba6SXin Li
392*1a96fba6SXin Li // We might have no message loop (e.g. we are in unit tests).
393*1a96fba6SXin Li if (MessageLoop::ThreadHasCurrent()) {
394*1a96fba6SXin Li MessageLoop::current()->PostTask(
395*1a96fba6SXin Li FROM_HERE,
396*1a96fba6SXin Li base::BindOnce(&TlsStreamImpl::DoHandshake,
397*1a96fba6SXin Li weak_ptr_factory_.GetWeakPtr(),
398*1a96fba6SXin Li success_callback,
399*1a96fba6SXin Li error_callback));
400*1a96fba6SXin Li } else {
401*1a96fba6SXin Li DoHandshake(success_callback, error_callback);
402*1a96fba6SXin Li }
403*1a96fba6SXin Li return true;
404*1a96fba6SXin Li }
405*1a96fba6SXin Li
RetryHandshake(const base::Closure & success_callback,const Stream::ErrorCallback & error_callback,Stream::AccessMode)406*1a96fba6SXin Li void TlsStream::TlsStreamImpl::RetryHandshake(
407*1a96fba6SXin Li const base::Closure& success_callback,
408*1a96fba6SXin Li const Stream::ErrorCallback& error_callback,
409*1a96fba6SXin Li Stream::AccessMode /* mode */) {
410*1a96fba6SXin Li VLOG(1) << "Retrying TLS handshake";
411*1a96fba6SXin Li DoHandshake(success_callback, error_callback);
412*1a96fba6SXin Li }
413*1a96fba6SXin Li
DoHandshake(const base::Closure & success_callback,const Stream::ErrorCallback & error_callback)414*1a96fba6SXin Li void TlsStream::TlsStreamImpl::DoHandshake(
415*1a96fba6SXin Li const base::Closure& success_callback,
416*1a96fba6SXin Li const Stream::ErrorCallback& error_callback) {
417*1a96fba6SXin Li VLOG(1) << "Begin TLS handshake";
418*1a96fba6SXin Li int res = SSL_do_handshake(ssl_.get());
419*1a96fba6SXin Li if (res == 1) {
420*1a96fba6SXin Li VLOG(1) << "Handshake successful";
421*1a96fba6SXin Li success_callback.Run();
422*1a96fba6SXin Li return;
423*1a96fba6SXin Li }
424*1a96fba6SXin Li ErrorPtr error;
425*1a96fba6SXin Li int err = SSL_get_error(ssl_.get(), res);
426*1a96fba6SXin Li if (err == SSL_ERROR_WANT_READ) {
427*1a96fba6SXin Li VLOG(1) << "Waiting for read data...";
428*1a96fba6SXin Li bool ok = socket_->WaitForData(
429*1a96fba6SXin Li Stream::AccessMode::READ,
430*1a96fba6SXin Li base::Bind(&TlsStreamImpl::RetryHandshake,
431*1a96fba6SXin Li weak_ptr_factory_.GetWeakPtr(),
432*1a96fba6SXin Li success_callback, error_callback),
433*1a96fba6SXin Li &error);
434*1a96fba6SXin Li if (ok)
435*1a96fba6SXin Li return;
436*1a96fba6SXin Li } else if (err == SSL_ERROR_WANT_WRITE) {
437*1a96fba6SXin Li VLOG(1) << "Waiting for write data...";
438*1a96fba6SXin Li bool ok = socket_->WaitForData(
439*1a96fba6SXin Li Stream::AccessMode::WRITE,
440*1a96fba6SXin Li base::Bind(&TlsStreamImpl::RetryHandshake,
441*1a96fba6SXin Li weak_ptr_factory_.GetWeakPtr(),
442*1a96fba6SXin Li success_callback, error_callback),
443*1a96fba6SXin Li &error);
444*1a96fba6SXin Li if (ok)
445*1a96fba6SXin Li return;
446*1a96fba6SXin Li } else {
447*1a96fba6SXin Li ReportError(&error, FROM_HERE, "TLS handshake failed.");
448*1a96fba6SXin Li }
449*1a96fba6SXin Li error_callback.Run(error.get());
450*1a96fba6SXin Li }
451*1a96fba6SXin Li
452*1a96fba6SXin Li /////////////////////////////////////////////////////////////////////////////
TlsStream(std::unique_ptr<TlsStreamImpl> impl)453*1a96fba6SXin Li TlsStream::TlsStream(std::unique_ptr<TlsStreamImpl> impl)
454*1a96fba6SXin Li : impl_{std::move(impl)} {}
455*1a96fba6SXin Li
~TlsStream()456*1a96fba6SXin Li TlsStream::~TlsStream() {
457*1a96fba6SXin Li if (impl_) {
458*1a96fba6SXin Li impl_->Close(nullptr);
459*1a96fba6SXin Li }
460*1a96fba6SXin Li }
461*1a96fba6SXin Li
Connect(StreamPtr socket,const std::string & host,const base::Callback<void (StreamPtr)> & success_callback,const Stream::ErrorCallback & error_callback)462*1a96fba6SXin Li void TlsStream::Connect(StreamPtr socket,
463*1a96fba6SXin Li const std::string& host,
464*1a96fba6SXin Li const base::Callback<void(StreamPtr)>& success_callback,
465*1a96fba6SXin Li const Stream::ErrorCallback& error_callback) {
466*1a96fba6SXin Li std::unique_ptr<TlsStreamImpl> impl{new TlsStreamImpl};
467*1a96fba6SXin Li std::unique_ptr<TlsStream> stream{new TlsStream{std::move(impl)}};
468*1a96fba6SXin Li
469*1a96fba6SXin Li TlsStreamImpl* pimpl = stream->impl_.get();
470*1a96fba6SXin Li ErrorPtr error;
471*1a96fba6SXin Li bool success = pimpl->Init(std::move(socket), host,
472*1a96fba6SXin Li base::Bind(success_callback,
473*1a96fba6SXin Li base::Passed(std::move(stream))),
474*1a96fba6SXin Li error_callback, &error);
475*1a96fba6SXin Li
476*1a96fba6SXin Li if (!success)
477*1a96fba6SXin Li error_callback.Run(error.get());
478*1a96fba6SXin Li }
479*1a96fba6SXin Li
IsOpen() const480*1a96fba6SXin Li bool TlsStream::IsOpen() const {
481*1a96fba6SXin Li return impl_ ? true : false;
482*1a96fba6SXin Li }
483*1a96fba6SXin Li
SetSizeBlocking(uint64_t,ErrorPtr * error)484*1a96fba6SXin Li bool TlsStream::SetSizeBlocking(uint64_t /* size */, ErrorPtr* error) {
485*1a96fba6SXin Li return stream_utils::ErrorOperationNotSupported(FROM_HERE, error);
486*1a96fba6SXin Li }
487*1a96fba6SXin Li
Seek(int64_t,Whence,uint64_t *,ErrorPtr * error)488*1a96fba6SXin Li bool TlsStream::Seek(int64_t /* offset */,
489*1a96fba6SXin Li Whence /* whence */,
490*1a96fba6SXin Li uint64_t* /* new_position*/,
491*1a96fba6SXin Li ErrorPtr* error) {
492*1a96fba6SXin Li return stream_utils::ErrorOperationNotSupported(FROM_HERE, error);
493*1a96fba6SXin Li }
494*1a96fba6SXin Li
ReadNonBlocking(void * buffer,size_t size_to_read,size_t * size_read,bool * end_of_stream,ErrorPtr * error)495*1a96fba6SXin Li bool TlsStream::ReadNonBlocking(void* buffer,
496*1a96fba6SXin Li size_t size_to_read,
497*1a96fba6SXin Li size_t* size_read,
498*1a96fba6SXin Li bool* end_of_stream,
499*1a96fba6SXin Li ErrorPtr* error) {
500*1a96fba6SXin Li if (!impl_)
501*1a96fba6SXin Li return stream_utils::ErrorStreamClosed(FROM_HERE, error);
502*1a96fba6SXin Li return impl_->ReadNonBlocking(buffer, size_to_read, size_read, end_of_stream,
503*1a96fba6SXin Li error);
504*1a96fba6SXin Li }
505*1a96fba6SXin Li
WriteNonBlocking(const void * buffer,size_t size_to_write,size_t * size_written,ErrorPtr * error)506*1a96fba6SXin Li bool TlsStream::WriteNonBlocking(const void* buffer,
507*1a96fba6SXin Li size_t size_to_write,
508*1a96fba6SXin Li size_t* size_written,
509*1a96fba6SXin Li ErrorPtr* error) {
510*1a96fba6SXin Li if (!impl_)
511*1a96fba6SXin Li return stream_utils::ErrorStreamClosed(FROM_HERE, error);
512*1a96fba6SXin Li return impl_->WriteNonBlocking(buffer, size_to_write, size_written, error);
513*1a96fba6SXin Li }
514*1a96fba6SXin Li
FlushBlocking(ErrorPtr * error)515*1a96fba6SXin Li bool TlsStream::FlushBlocking(ErrorPtr* error) {
516*1a96fba6SXin Li if (!impl_)
517*1a96fba6SXin Li return stream_utils::ErrorStreamClosed(FROM_HERE, error);
518*1a96fba6SXin Li return impl_->Flush(error);
519*1a96fba6SXin Li }
520*1a96fba6SXin Li
CloseBlocking(ErrorPtr * error)521*1a96fba6SXin Li bool TlsStream::CloseBlocking(ErrorPtr* error) {
522*1a96fba6SXin Li if (impl_ && !impl_->Close(error))
523*1a96fba6SXin Li return false;
524*1a96fba6SXin Li impl_.reset();
525*1a96fba6SXin Li return true;
526*1a96fba6SXin Li }
527*1a96fba6SXin Li
WaitForData(AccessMode mode,const base::Callback<void (AccessMode)> & callback,ErrorPtr * error)528*1a96fba6SXin Li bool TlsStream::WaitForData(AccessMode mode,
529*1a96fba6SXin Li const base::Callback<void(AccessMode)>& callback,
530*1a96fba6SXin Li ErrorPtr* error) {
531*1a96fba6SXin Li if (!impl_)
532*1a96fba6SXin Li return stream_utils::ErrorStreamClosed(FROM_HERE, error);
533*1a96fba6SXin Li return impl_->WaitForData(mode, callback, error);
534*1a96fba6SXin Li }
535*1a96fba6SXin Li
WaitForDataBlocking(AccessMode in_mode,base::TimeDelta timeout,AccessMode * out_mode,ErrorPtr * error)536*1a96fba6SXin Li bool TlsStream::WaitForDataBlocking(AccessMode in_mode,
537*1a96fba6SXin Li base::TimeDelta timeout,
538*1a96fba6SXin Li AccessMode* out_mode,
539*1a96fba6SXin Li ErrorPtr* error) {
540*1a96fba6SXin Li if (!impl_)
541*1a96fba6SXin Li return stream_utils::ErrorStreamClosed(FROM_HERE, error);
542*1a96fba6SXin Li return impl_->WaitForDataBlocking(in_mode, timeout, out_mode, error);
543*1a96fba6SXin Li }
544*1a96fba6SXin Li
CancelPendingAsyncOperations()545*1a96fba6SXin Li void TlsStream::CancelPendingAsyncOperations() {
546*1a96fba6SXin Li if (impl_)
547*1a96fba6SXin Li impl_->CancelPendingAsyncOperations();
548*1a96fba6SXin Li Stream::CancelPendingAsyncOperations();
549*1a96fba6SXin Li }
550*1a96fba6SXin Li
551*1a96fba6SXin Li } // namespace brillo
552