1 // Copyright 2015 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include <brillo/streams/tls_stream.h>
6
7 #include <algorithm>
8 #include <limits>
9 #include <string>
10 #include <utility>
11 #include <vector>
12
13 #include <openssl/err.h>
14 #include <openssl/ssl.h>
15
16 #include <base/bind.h>
17 #include <base/memory/weak_ptr.h>
18 #include <brillo/message_loops/message_loop.h>
19 #include <brillo/secure_blob.h>
20 #include <brillo/streams/openssl_stream_bio.h>
21 #include <brillo/streams/stream_utils.h>
22 #include <brillo/strings/string_utils.h>
23
24 namespace {
25
26 // SSL info callback which is called by OpenSSL when we enable logging level of
27 // at least 3. This logs the information about the internal TLS handshake.
TlsInfoCallback(const SSL *,int where,int ret)28 void TlsInfoCallback(const SSL* /* ssl */, int where, int ret) {
29 std::string reason;
30 std::vector<std::string> info;
31 if (where & SSL_CB_LOOP)
32 info.push_back("loop");
33 if (where & SSL_CB_EXIT)
34 info.push_back("exit");
35 if (where & SSL_CB_READ)
36 info.push_back("read");
37 if (where & SSL_CB_WRITE)
38 info.push_back("write");
39 if (where & SSL_CB_ALERT) {
40 info.push_back("alert");
41 reason = ", reason: ";
42 reason += SSL_alert_type_string_long(ret);
43 reason += "/";
44 reason += SSL_alert_desc_string_long(ret);
45 }
46 if (where & SSL_CB_HANDSHAKE_START)
47 info.push_back("handshake_start");
48 if (where & SSL_CB_HANDSHAKE_DONE)
49 info.push_back("handshake_done");
50
51 VLOG(3) << "TLS progress info: " << brillo::string_utils::Join(",", info)
52 << ", with status: " << ret << reason;
53 }
54
// Index of the SSL_CTX ex_data slot in which Init() stores the owning
// TlsStreamImpl pointer, read back by OnCertVerifyResultsStatic(). Stays -1
// until the first TlsStreamImpl constructor allocates the slot.
int ssl_ctx_private_data_index{-1};

// Directory of trusted CA certificates passed to
// SSL_CTX_load_verify_locations(); the location is platform specific.
#ifdef __ANDROID__
const char kCACertificatePath[] = "/system/etc/security/cacerts_google";
#else
const char kCACertificatePath[] = "/usr/share/chromeos-ca-certificates";
#endif
66
67 } // anonymous namespace
68
69 namespace brillo {
70
// TODO(crbug.com/984789): Remove once support for OpenSSL <1.1 is dropped.
// OpenSSL releases before 1.1.0 do not provide TLS_client_method(); map it to
// the TLS 1.2 client method there so Init() can use a single spelling.
#if OPENSSL_VERSION_NUMBER < 0x10100000L
#define TLS_client_method() TLSv1_2_client_method()
#endif
75
// Helper implementation of TLS stream used to hide most of OpenSSL inner
// workings from the users of brillo::TlsStream.
// A TlsStreamImpl wraps an already-connected transport stream (|socket_|) in
// an OpenSSL client connection. TlsStream forwards all Stream operations here.
class TlsStream::TlsStreamImpl {
 public:
  TlsStreamImpl();
  ~TlsStreamImpl();

  // Sets up the SSL_CTX/SSL objects for |host| on top of |socket| and starts
  // the client handshake. Returns false (filling |error|) if setup fails; the
  // handshake outcome is reported via |success_callback| / |error_callback|.
  bool Init(StreamPtr socket,
            const std::string& host,
            const base::Closure& success_callback,
            const Stream::ErrorCallback& error_callback,
            ErrorPtr* error);

  // Reads decrypted data with SSL_read(). A would-block condition yields
  // success with |*size_read| == 0 and remembers which direction to wait for.
  bool ReadNonBlocking(void* buffer,
                       size_t size_to_read,
                       size_t* size_read,
                       bool* end_of_stream,
                       ErrorPtr* error);

  // Writes data with SSL_write(); would-block is handled like ReadNonBlocking.
  bool WriteNonBlocking(const void* buffer,
                        size_t size_to_write,
                        size_t* size_written,
                        ErrorPtr* error);

  // Flushes the underlying socket stream.
  bool Flush(ErrorPtr* error);
  // Sends the TLS "close notify" (best effort) and closes the socket stream.
  bool Close(ErrorPtr* error);
  // Wait helpers that also honor directions requested by OpenSSL during the
  // last read/write, and report readiness immediately when SSL_pending() > 0.
  bool WaitForData(AccessMode mode,
                   const base::Callback<void(AccessMode)>& callback,
                   ErrorPtr* error);
  bool WaitForDataBlocking(AccessMode in_mode,
                           base::TimeDelta timeout,
                           AccessMode* out_mode,
                           ErrorPtr* error);
  // Cancels socket waits and invalidates weak pointers bound to pending tasks.
  void CancelPendingAsyncOperations();

 private:
  // Moves queued OpenSSL errors plus |message| into |error|; returns false.
  bool ReportError(ErrorPtr* error,
                   const base::Location& location,
                   const std::string& message);
  // Runs one SSL_do_handshake() step, re-arming itself via RetryHandshake().
  void DoHandshake(const base::Closure& success_callback,
                   const Stream::ErrorCallback& error_callback);
  void RetryHandshake(const base::Closure& success_callback,
                      const Stream::ErrorCallback& error_callback,
                      Stream::AccessMode mode);

  // Certificate verification hook; the static trampoline locates the instance
  // through the SSL_CTX ex_data slot set in Init().
  int OnCertVerifyResults(int ok, X509_STORE_CTX* ctx);
  static int OnCertVerifyResultsStatic(int ok, X509_STORE_CTX* ctx);

  // Underlying transport; adopted in Init().
  StreamPtr socket_;
  std::unique_ptr<SSL_CTX, decltype(&SSL_CTX_free)> ctx_{nullptr, SSL_CTX_free};
  std::unique_ptr<SSL, decltype(&SSL_free)> ssl_{nullptr, SSL_free};
  // BIO adapter over |socket_|. It is never freed here; Init() hands it to
  // |ssl_| via SSL_set_bio().
  BIO* stream_bio_{nullptr};
  // Set when SSL_read()/SSL_write() reported SSL_ERROR_WANT_READ/WANT_WRITE;
  // consumed (and cleared) by WaitForData()/WaitForDataBlocking().
  bool need_more_read_{false};
  bool need_more_write_{false};

  // Used to bind handshake tasks; invalidated by CancelPendingAsyncOperations.
  base::WeakPtrFactory<TlsStreamImpl> weak_ptr_factory_{this};
  DISALLOW_COPY_AND_ASSIGN(TlsStreamImpl);
};
134
TlsStreamImpl()135 TlsStream::TlsStreamImpl::TlsStreamImpl() {
136 SSL_load_error_strings();
137 SSL_library_init();
138 if (ssl_ctx_private_data_index < 0) {
139 ssl_ctx_private_data_index =
140 SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
141 }
142 }
143
~TlsStreamImpl()144 TlsStream::TlsStreamImpl::~TlsStreamImpl() {
145 ssl_.reset();
146 ctx_.reset();
147 }
148
ReadNonBlocking(void * buffer,size_t size_to_read,size_t * size_read,bool * end_of_stream,ErrorPtr * error)149 bool TlsStream::TlsStreamImpl::ReadNonBlocking(void* buffer,
150 size_t size_to_read,
151 size_t* size_read,
152 bool* end_of_stream,
153 ErrorPtr* error) {
154 const size_t max_int = std::numeric_limits<int>::max();
155 int size_int = static_cast<int>(std::min(size_to_read, max_int));
156 int ret = SSL_read(ssl_.get(), buffer, size_int);
157 if (ret > 0) {
158 *size_read = static_cast<size_t>(ret);
159 if (end_of_stream)
160 *end_of_stream = false;
161 return true;
162 }
163
164 int err = SSL_get_error(ssl_.get(), ret);
165 if (err == SSL_ERROR_ZERO_RETURN) {
166 *size_read = 0;
167 if (end_of_stream)
168 *end_of_stream = true;
169 return true;
170 }
171
172 if (err == SSL_ERROR_WANT_READ) {
173 need_more_read_ = true;
174 } else if (err == SSL_ERROR_WANT_WRITE) {
175 // Writes might be required for SSL_read() because of possible TLS
176 // re-negotiations which can happen at any time.
177 need_more_write_ = true;
178 } else {
179 return ReportError(error, FROM_HERE, "Error reading from TLS socket");
180 }
181 *size_read = 0;
182 if (end_of_stream)
183 *end_of_stream = false;
184 return true;
185 }
186
WriteNonBlocking(const void * buffer,size_t size_to_write,size_t * size_written,ErrorPtr * error)187 bool TlsStream::TlsStreamImpl::WriteNonBlocking(const void* buffer,
188 size_t size_to_write,
189 size_t* size_written,
190 ErrorPtr* error) {
191 const size_t max_int = std::numeric_limits<int>::max();
192 int size_int = static_cast<int>(std::min(size_to_write, max_int));
193 int ret = SSL_write(ssl_.get(), buffer, size_int);
194 if (ret > 0) {
195 *size_written = static_cast<size_t>(ret);
196 return true;
197 }
198
199 int err = SSL_get_error(ssl_.get(), ret);
200 if (err == SSL_ERROR_WANT_READ) {
201 // Reads might be required for SSL_write() because of possible TLS
202 // re-negotiations which can happen at any time.
203 need_more_read_ = true;
204 } else if (err == SSL_ERROR_WANT_WRITE) {
205 need_more_write_ = true;
206 } else {
207 return ReportError(error, FROM_HERE, "Error writing to TLS socket");
208 }
209 *size_written = 0;
210 return true;
211 }
212
Flush(ErrorPtr * error)213 bool TlsStream::TlsStreamImpl::Flush(ErrorPtr* error) {
214 return socket_->FlushBlocking(error);
215 }
216
Close(ErrorPtr * error)217 bool TlsStream::TlsStreamImpl::Close(ErrorPtr* error) {
218 // 2 seconds should be plenty here.
219 const base::TimeDelta kTimeout = base::TimeDelta::FromSeconds(2);
220 // The retry count of 4 below is just arbitrary, to ensure we don't get stuck
221 // here forever. We should rarely need to repeat SSL_shutdown anyway.
222 for (int retry_count = 0; retry_count < 4; retry_count++) {
223 int ret = SSL_shutdown(ssl_.get());
224 // We really don't care for bi-directional shutdown here.
225 // Just make sure we only send the "close notify" alert to the remote peer.
226 if (ret >= 0)
227 break;
228
229 int err = SSL_get_error(ssl_.get(), ret);
230 if (err == SSL_ERROR_WANT_READ) {
231 if (!socket_->WaitForDataBlocking(AccessMode::READ, kTimeout, nullptr,
232 error)) {
233 break;
234 }
235 } else if (err == SSL_ERROR_WANT_WRITE) {
236 if (!socket_->WaitForDataBlocking(AccessMode::WRITE, kTimeout, nullptr,
237 error)) {
238 break;
239 }
240 } else {
241 LOG(ERROR) << "SSL_shutdown returned error #" << err;
242 ReportError(error, FROM_HERE, "Failed to shut down TLS socket");
243 break;
244 }
245 }
246 return socket_->CloseBlocking(error);
247 }
248
WaitForData(AccessMode mode,const base::Callback<void (AccessMode)> & callback,ErrorPtr * error)249 bool TlsStream::TlsStreamImpl::WaitForData(
250 AccessMode mode,
251 const base::Callback<void(AccessMode)>& callback,
252 ErrorPtr* error) {
253 bool is_read = stream_utils::IsReadAccessMode(mode);
254 bool is_write = stream_utils::IsWriteAccessMode(mode);
255 is_read |= need_more_read_;
256 is_write |= need_more_write_;
257 need_more_read_ = false;
258 need_more_write_ = false;
259 if (is_read && SSL_pending(ssl_.get()) > 0) {
260 callback.Run(AccessMode::READ);
261 return true;
262 }
263 mode = stream_utils::MakeAccessMode(is_read, is_write);
264 return socket_->WaitForData(mode, callback, error);
265 }
266
WaitForDataBlocking(AccessMode in_mode,base::TimeDelta timeout,AccessMode * out_mode,ErrorPtr * error)267 bool TlsStream::TlsStreamImpl::WaitForDataBlocking(AccessMode in_mode,
268 base::TimeDelta timeout,
269 AccessMode* out_mode,
270 ErrorPtr* error) {
271 bool is_read = stream_utils::IsReadAccessMode(in_mode);
272 bool is_write = stream_utils::IsWriteAccessMode(in_mode);
273 is_read |= need_more_read_;
274 is_write |= need_more_write_;
275 need_more_read_ = need_more_write_ = false;
276 if (is_read && SSL_pending(ssl_.get()) > 0) {
277 if (out_mode)
278 *out_mode = AccessMode::READ;
279 return true;
280 }
281 in_mode = stream_utils::MakeAccessMode(is_read, is_write);
282 return socket_->WaitForDataBlocking(in_mode, timeout, out_mode, error);
283 }
284
CancelPendingAsyncOperations()285 void TlsStream::TlsStreamImpl::CancelPendingAsyncOperations() {
286 socket_->CancelPendingAsyncOperations();
287 weak_ptr_factory_.InvalidateWeakPtrs();
288 }
289
ReportError(ErrorPtr * error,const base::Location & location,const std::string & message)290 bool TlsStream::TlsStreamImpl::ReportError(
291 ErrorPtr* error,
292 const base::Location& location,
293 const std::string& message) {
294 const char* file = nullptr;
295 int line = 0;
296 const char* data = 0;
297 int flags = 0;
298 while (auto errnum = ERR_get_error_line_data(&file, &line, &data, &flags)) {
299 char buf[256];
300 ERR_error_string_n(errnum, buf, sizeof(buf));
301 base::Location ssl_location{"Unknown", file, line, nullptr};
302 std::string ssl_message = buf;
303 if (flags & ERR_TXT_STRING) {
304 ssl_message += ": ";
305 ssl_message += data;
306 }
307 Error::AddTo(error, ssl_location, "openssl", std::to_string(errnum),
308 ssl_message);
309 }
310 Error::AddTo(error, location, "tls_stream", "failed", message);
311 return false;
312 }
313
OnCertVerifyResults(int ok,X509_STORE_CTX * ctx)314 int TlsStream::TlsStreamImpl::OnCertVerifyResults(int ok, X509_STORE_CTX* ctx) {
315 // OpenSSL already performs a comprehensive check of the certificate chain
316 // (using X509_verify_cert() function) and calls back with the result of its
317 // verification.
318 // |ok| is set to 1 if the verification passed and 0 if an error was detected.
319 // Here we can perform some additional checks if we need to, or simply log
320 // the issues found.
321
322 // For now, just log an error if it occurred.
323 if (!ok) {
324 LOG(ERROR) << "Server certificate validation failed: "
325 << X509_verify_cert_error_string(X509_STORE_CTX_get_error(ctx));
326 }
327 return ok;
328 }
329
OnCertVerifyResultsStatic(int ok,X509_STORE_CTX * ctx)330 int TlsStream::TlsStreamImpl::OnCertVerifyResultsStatic(int ok,
331 X509_STORE_CTX* ctx) {
332 // Obtain the pointer to the instance of TlsStream::TlsStreamImpl from the
333 // SSL CTX object referenced by |ctx|.
334 SSL* ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(
335 ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
336 SSL_CTX* ssl_ctx = ssl ? SSL_get_SSL_CTX(ssl) : nullptr;
337 TlsStream::TlsStreamImpl* self = nullptr;
338 if (ssl_ctx) {
339 self = static_cast<TlsStream::TlsStreamImpl*>(SSL_CTX_get_ex_data(
340 ssl_ctx, ssl_ctx_private_data_index));
341 }
342 return self ? self->OnCertVerifyResults(ok, ctx) : ok;
343 }
344
Init(StreamPtr socket,const std::string & host,const base::Closure & success_callback,const Stream::ErrorCallback & error_callback,ErrorPtr * error)345 bool TlsStream::TlsStreamImpl::Init(StreamPtr socket,
346 const std::string& host,
347 const base::Closure& success_callback,
348 const Stream::ErrorCallback& error_callback,
349 ErrorPtr* error) {
350 ctx_.reset(SSL_CTX_new(TLS_client_method()));
351 if (!ctx_)
352 return ReportError(error, FROM_HERE, "Cannot create SSL_CTX");
353
354 // Top cipher suites supported by both Google GFEs and OpenSSL (in server
355 // preferred order).
356 int res = SSL_CTX_set_cipher_list(ctx_.get(),
357 "ECDHE-ECDSA-AES128-GCM-SHA256:"
358 "ECDHE-ECDSA-AES256-GCM-SHA384:"
359 "ECDHE-RSA-AES128-GCM-SHA256:"
360 "ECDHE-RSA-AES256-GCM-SHA384");
361 if (res != 1)
362 return ReportError(error, FROM_HERE, "Cannot set the cipher list");
363
364 res = SSL_CTX_load_verify_locations(ctx_.get(), nullptr, kCACertificatePath);
365 if (res != 1) {
366 return ReportError(error, FROM_HERE,
367 "Failed to specify trusted certificate location");
368 }
369
370 // Store a pointer to "this" into SSL_CTX instance.
371 SSL_CTX_set_ex_data(ctx_.get(), ssl_ctx_private_data_index, this);
372
373 // Ask OpenSSL to validate the server host from the certificate to match
374 // the expected host name we are given:
375 X509_VERIFY_PARAM* param = SSL_CTX_get0_param(ctx_.get());
376 X509_VERIFY_PARAM_set1_host(param, host.c_str(), host.size());
377
378 SSL_CTX_set_verify(ctx_.get(), SSL_VERIFY_PEER,
379 &TlsStreamImpl::OnCertVerifyResultsStatic);
380
381 socket_ = std::move(socket);
382 ssl_.reset(SSL_new(ctx_.get()));
383
384 // Enable TLS progress callback if VLOG level is >=3.
385 if (VLOG_IS_ON(3))
386 SSL_set_info_callback(ssl_.get(), TlsInfoCallback);
387
388 stream_bio_ = BIO_new_stream(socket_.get());
389 SSL_set_bio(ssl_.get(), stream_bio_, stream_bio_);
390 SSL_set_connect_state(ssl_.get());
391
392 // We might have no message loop (e.g. we are in unit tests).
393 if (MessageLoop::ThreadHasCurrent()) {
394 MessageLoop::current()->PostTask(
395 FROM_HERE,
396 base::BindOnce(&TlsStreamImpl::DoHandshake,
397 weak_ptr_factory_.GetWeakPtr(),
398 success_callback,
399 error_callback));
400 } else {
401 DoHandshake(success_callback, error_callback);
402 }
403 return true;
404 }
405
RetryHandshake(const base::Closure & success_callback,const Stream::ErrorCallback & error_callback,Stream::AccessMode)406 void TlsStream::TlsStreamImpl::RetryHandshake(
407 const base::Closure& success_callback,
408 const Stream::ErrorCallback& error_callback,
409 Stream::AccessMode /* mode */) {
410 VLOG(1) << "Retrying TLS handshake";
411 DoHandshake(success_callback, error_callback);
412 }
413
DoHandshake(const base::Closure & success_callback,const Stream::ErrorCallback & error_callback)414 void TlsStream::TlsStreamImpl::DoHandshake(
415 const base::Closure& success_callback,
416 const Stream::ErrorCallback& error_callback) {
417 VLOG(1) << "Begin TLS handshake";
418 int res = SSL_do_handshake(ssl_.get());
419 if (res == 1) {
420 VLOG(1) << "Handshake successful";
421 success_callback.Run();
422 return;
423 }
424 ErrorPtr error;
425 int err = SSL_get_error(ssl_.get(), res);
426 if (err == SSL_ERROR_WANT_READ) {
427 VLOG(1) << "Waiting for read data...";
428 bool ok = socket_->WaitForData(
429 Stream::AccessMode::READ,
430 base::Bind(&TlsStreamImpl::RetryHandshake,
431 weak_ptr_factory_.GetWeakPtr(),
432 success_callback, error_callback),
433 &error);
434 if (ok)
435 return;
436 } else if (err == SSL_ERROR_WANT_WRITE) {
437 VLOG(1) << "Waiting for write data...";
438 bool ok = socket_->WaitForData(
439 Stream::AccessMode::WRITE,
440 base::Bind(&TlsStreamImpl::RetryHandshake,
441 weak_ptr_factory_.GetWeakPtr(),
442 success_callback, error_callback),
443 &error);
444 if (ok)
445 return;
446 } else {
447 ReportError(&error, FROM_HERE, "TLS handshake failed.");
448 }
449 error_callback.Run(error.get());
450 }
451
452 /////////////////////////////////////////////////////////////////////////////
TlsStream(std::unique_ptr<TlsStreamImpl> impl)453 TlsStream::TlsStream(std::unique_ptr<TlsStreamImpl> impl)
454 : impl_{std::move(impl)} {}
455
~TlsStream()456 TlsStream::~TlsStream() {
457 if (impl_) {
458 impl_->Close(nullptr);
459 }
460 }
461
Connect(StreamPtr socket,const std::string & host,const base::Callback<void (StreamPtr)> & success_callback,const Stream::ErrorCallback & error_callback)462 void TlsStream::Connect(StreamPtr socket,
463 const std::string& host,
464 const base::Callback<void(StreamPtr)>& success_callback,
465 const Stream::ErrorCallback& error_callback) {
466 std::unique_ptr<TlsStreamImpl> impl{new TlsStreamImpl};
467 std::unique_ptr<TlsStream> stream{new TlsStream{std::move(impl)}};
468
469 TlsStreamImpl* pimpl = stream->impl_.get();
470 ErrorPtr error;
471 bool success = pimpl->Init(std::move(socket), host,
472 base::Bind(success_callback,
473 base::Passed(std::move(stream))),
474 error_callback, &error);
475
476 if (!success)
477 error_callback.Run(error.get());
478 }
479
IsOpen() const480 bool TlsStream::IsOpen() const {
481 return impl_ ? true : false;
482 }
483
SetSizeBlocking(uint64_t,ErrorPtr * error)484 bool TlsStream::SetSizeBlocking(uint64_t /* size */, ErrorPtr* error) {
485 return stream_utils::ErrorOperationNotSupported(FROM_HERE, error);
486 }
487
Seek(int64_t,Whence,uint64_t *,ErrorPtr * error)488 bool TlsStream::Seek(int64_t /* offset */,
489 Whence /* whence */,
490 uint64_t* /* new_position*/,
491 ErrorPtr* error) {
492 return stream_utils::ErrorOperationNotSupported(FROM_HERE, error);
493 }
494
ReadNonBlocking(void * buffer,size_t size_to_read,size_t * size_read,bool * end_of_stream,ErrorPtr * error)495 bool TlsStream::ReadNonBlocking(void* buffer,
496 size_t size_to_read,
497 size_t* size_read,
498 bool* end_of_stream,
499 ErrorPtr* error) {
500 if (!impl_)
501 return stream_utils::ErrorStreamClosed(FROM_HERE, error);
502 return impl_->ReadNonBlocking(buffer, size_to_read, size_read, end_of_stream,
503 error);
504 }
505
WriteNonBlocking(const void * buffer,size_t size_to_write,size_t * size_written,ErrorPtr * error)506 bool TlsStream::WriteNonBlocking(const void* buffer,
507 size_t size_to_write,
508 size_t* size_written,
509 ErrorPtr* error) {
510 if (!impl_)
511 return stream_utils::ErrorStreamClosed(FROM_HERE, error);
512 return impl_->WriteNonBlocking(buffer, size_to_write, size_written, error);
513 }
514
FlushBlocking(ErrorPtr * error)515 bool TlsStream::FlushBlocking(ErrorPtr* error) {
516 if (!impl_)
517 return stream_utils::ErrorStreamClosed(FROM_HERE, error);
518 return impl_->Flush(error);
519 }
520
CloseBlocking(ErrorPtr * error)521 bool TlsStream::CloseBlocking(ErrorPtr* error) {
522 if (impl_ && !impl_->Close(error))
523 return false;
524 impl_.reset();
525 return true;
526 }
527
WaitForData(AccessMode mode,const base::Callback<void (AccessMode)> & callback,ErrorPtr * error)528 bool TlsStream::WaitForData(AccessMode mode,
529 const base::Callback<void(AccessMode)>& callback,
530 ErrorPtr* error) {
531 if (!impl_)
532 return stream_utils::ErrorStreamClosed(FROM_HERE, error);
533 return impl_->WaitForData(mode, callback, error);
534 }
535
WaitForDataBlocking(AccessMode in_mode,base::TimeDelta timeout,AccessMode * out_mode,ErrorPtr * error)536 bool TlsStream::WaitForDataBlocking(AccessMode in_mode,
537 base::TimeDelta timeout,
538 AccessMode* out_mode,
539 ErrorPtr* error) {
540 if (!impl_)
541 return stream_utils::ErrorStreamClosed(FROM_HERE, error);
542 return impl_->WaitForDataBlocking(in_mode, timeout, out_mode, error);
543 }
544
CancelPendingAsyncOperations()545 void TlsStream::CancelPendingAsyncOperations() {
546 if (impl_)
547 impl_->CancelPendingAsyncOperations();
548 Stream::CancelPendingAsyncOperations();
549 }
550
551 } // namespace brillo
552