1 //
2 //
3 // Copyright 2015 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include <grpc/support/port_platform.h>
20 
21 #include "src/core/tsi/ssl_transport_security.h"
22 
23 #include <limits.h>
24 #include <string.h>
25 
26 // TODO(jboeuf): refactor inet_ntop into a portability header.
27 // Note: for whoever reads this and tries to refactor it: this
28 // can't be in grpc, it has to be in gpr.
29 #ifdef GPR_WINDOWS
30 #include <ws2tcpip.h>
31 #else
32 #include <arpa/inet.h>
33 #include <sys/socket.h>
34 #endif
35 
36 #include <string>
37 
38 #include <openssl/bio.h>
39 #include <openssl/crypto.h>  // For OPENSSL_free
40 #include <openssl/engine.h>
41 #include <openssl/err.h>
42 #include <openssl/ssl.h>
43 #include <openssl/tls1.h>
44 #include <openssl/x509.h>
45 #include <openssl/x509v3.h>
46 
47 #include "absl/strings/match.h"
48 #include "absl/strings/str_cat.h"
49 #include "absl/strings/string_view.h"
50 
51 #include <grpc/grpc_security.h>
52 #include <grpc/support/alloc.h>
53 #include <grpc/support/log.h>
54 #include <grpc/support/string_util.h>
55 #include <grpc/support/sync.h>
56 #include <grpc/support/thd_id.h>
57 
58 #include "src/core/lib/gpr/useful.h"
59 #include "src/core/lib/gprpp/crash.h"
60 #include "src/core/tsi/ssl/key_logging/ssl_key_logging.h"
61 #include "src/core/tsi/ssl/session_cache/ssl_session_cache.h"
62 #include "src/core/tsi/ssl_transport_security_utils.h"
63 #include "src/core/tsi/ssl_types.h"
64 #include "src/core/tsi/transport_security.h"
65 
66 // --- Constants. ---
67 
// Upper and lower bounds on the maximum protected frame size (their use is not
// visible in this part of the file; presumably they clamp a caller-requested
// frame size — confirm). 16384 == 2^14, the TLS maximum record plaintext size.
68 #define TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND 16384
69 #define TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND 1024
// Initial size in bytes of the handshaker's outgoing-bytes buffer; the name
// suggests it may be grown later — confirm at the allocation site.
70 #define TSI_SSL_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE 1024
71 
72 // Putting a macro like this and littering the source file with #if is really
73 // bad practice.
74 // TODO(jboeuf): refactor all the #if / #endif in a separate module.
// ALPN support is on by default; a build may predefine this macro to override.
75 #ifndef TSI_OPENSSL_ALPN_SUPPORT
76 #define TSI_OPENSSL_ALPN_SUPPORT 1
77 #endif
78 
79 // TODO(jboeuf): I have not found a way to get this number dynamically from the
80 // SSL structure. This is what we would ultimately want though...
// Fixed estimate, in bytes, of per-record protection overhead (hedge: assumed
// to cover record header plus MAC/AEAD tag and padding — confirm).
81 #define TSI_SSL_MAX_PROTECTION_OVERHEAD 100
82 
// Short alias for the key logger type stored by the handshaker factories.
83 using TlsSessionKeyLogger = tsi::TlsSessionKeyLoggerCache::TlsSessionKeyLogger;
84 
85 // --- Structure definitions. ---
86 
// Wraps an OpenSSL certificate store (presumably holding root certificates,
// per the name; creation and ownership are outside this chunk).
87 struct tsi_ssl_root_certs_store {
88   X509_STORE* store;
89 };
90 
// State common to every handshaker factory: a vtable pointer and a reference
// count. The client and server factories embed it as their first member.
91 struct tsi_ssl_handshaker_factory {
92   const tsi_ssl_handshaker_factory_vtable* vtable;
93   gpr_refcount refcount;
94 };
95 
// Client-side handshaker factory state.
96 struct tsi_ssl_client_handshaker_factory {
  // Kept as the first member, presumably so this struct can be used through a
  // tsi_ssl_handshaker_factory pointer — confirm before reordering.
97   tsi_ssl_handshaker_factory base;
  // SSL_CTX that client connections are created from (hedge: a single context,
  // unlike the server factory's SNI array).
98   SSL_CTX* ssl_context;
  // ALPN protocol list buffer and its length in bytes.
99   unsigned char* alpn_protocol_list;
100   size_t alpn_protocol_list_length;
  // TLS session cache; presumably null when session resumption is off — confirm.
101   grpc_core::RefCountedPtr<tsi::SslSessionLRUCache> session_cache;
  // TLS session key logger; presumably null when key logging is off — confirm.
102   grpc_core::RefCountedPtr<TlsSessionKeyLogger> key_logger;
103 };
104 
// Server-side handshaker factory state.
105 struct tsi_ssl_server_handshaker_factory {
106   // Several contexts to support SNI.
107   // The tsi_peer array contains the subject names of the server certificates
108   // associated with the contexts at the same index.
109   tsi_ssl_handshaker_factory base;
  // ssl_contexts and ssl_context_x509_subject_names are parallel arrays,
  // presumably each of ssl_context_count elements.
110   SSL_CTX** ssl_contexts;
111   tsi_peer* ssl_context_x509_subject_names;
112   size_t ssl_context_count;
  // ALPN protocol list buffer and its length in bytes.
113   unsigned char* alpn_protocol_list;
114   size_t alpn_protocol_list_length;
  // TLS session key logger; presumably null when key logging is off — confirm.
115   grpc_core::RefCountedPtr<TlsSessionKeyLogger> key_logger;
116 };
117 
// Per-connection TLS handshake state.
118 struct tsi_ssl_handshaker {
  // Kept first, presumably so the struct is usable as a tsi_handshaker*.
119   tsi_handshaker base;
120   SSL* ssl;
  // BIO through which TLS bytes are exchanged with the transport (hedge:
  // presumably the network side of a BIO pair attached to `ssl`).
121   BIO* network_io;
  // Handshake status tracked across calls (its update sites are not in view).
121   
122   tsi_result result;
  // Scratch buffer, and its size in bytes, for handshake bytes to be sent.
123   unsigned char* outgoing_bytes_buffer;
124   size_t outgoing_bytes_buffer_size;
  // Factory that created this handshaker; the "_ref" suffix suggests a held
  // reference — confirm.
125   tsi_ssl_handshaker_factory* factory_ref;
126 };
// Result of a finished handshake: the SSL objects (presumably transferred from
// tsi_ssl_handshaker) plus bytes received after the handshake completed, per
// the field name — confirm.
127 struct tsi_ssl_handshaker_result {
128   tsi_handshaker_result base;
129   SSL* ssl;
130   BIO* network_io;
131   unsigned char* unused_bytes;
132   size_t unused_bytes_size;
133 };
// Record-protection state: SSL objects plus a staging buffer with its total
// size and current fill offset.
134 struct tsi_ssl_frame_protector {
135   tsi_frame_protector base;
136   SSL* ssl;
137   BIO* network_io;
138   unsigned char* buffer;
139   size_t buffer_size;
140   size_t buffer_offset;
141 };
142 // --- Library Initialization. ---
143 
// Once-flag, presumably passed to gpr_once_init so init_openssl() runs once.
144 static gpr_once g_init_openssl_once = GPR_ONCE_INIT;
// SSL_CTX ex_data index reserved by init_openssl(); -1 until then.
145 static int g_ssl_ctx_ex_factory_index = -1;
// Session ID context bytes ("grpc"); where they are applied is not in view.
146 static const unsigned char kSslSessionIdContext[] = {'g', 'r', 'p', 'c'};
// SSL ex_data index reserved by init_openssl(); -1 until then.
147 static int g_ssl_ex_verified_root_cert_index = -1;
148 #if !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE)
// Private keys starting with this prefix are loaded through an OpenSSL ENGINE
// (see ssl_ctx_use_private_key).
149 static const char kSslEnginePrefix[] = "engine:";
150 #endif
151 
152 #if OPENSSL_VERSION_NUMBER < 0x10100000
153 static gpr_mu* g_openssl_mutexes = nullptr;
154 static void openssl_locking_cb(int mode, int type, const char* file,
155                                int line) GRPC_UNUSED;
156 static unsigned long openssl_thread_id_cb(void) GRPC_UNUSED;
157 
openssl_locking_cb(int mode,int type,const char * file,int line)158 static void openssl_locking_cb(int mode, int type, const char* file, int line) {
159   if (mode & CRYPTO_LOCK) {
160     gpr_mu_lock(&g_openssl_mutexes[type]);
161   } else {
162     gpr_mu_unlock(&g_openssl_mutexes[type]);
163   }
164 }
165 
openssl_thread_id_cb(void)166 static unsigned long openssl_thread_id_cb(void) {
167   return static_cast<unsigned long>(gpr_thd_currentid());
168 }
169 #endif
170 
init_openssl(void)171 static void init_openssl(void) {
172 #if OPENSSL_VERSION_NUMBER >= 0x10100000
173   OPENSSL_init_ssl(0, nullptr);
174 #else
175   SSL_library_init();
176   SSL_load_error_strings();
177   OpenSSL_add_all_algorithms();
178 #endif
179 #if OPENSSL_VERSION_NUMBER < 0x10100000
180   if (!CRYPTO_get_locking_callback()) {
181     int num_locks = CRYPTO_num_locks();
182     GPR_ASSERT(num_locks > 0);
183     g_openssl_mutexes = static_cast<gpr_mu*>(
184         gpr_malloc(static_cast<size_t>(num_locks) * sizeof(gpr_mu)));
185     for (int i = 0; i < num_locks; i++) {
186       gpr_mu_init(&g_openssl_mutexes[i]);
187     }
188     CRYPTO_set_locking_callback(openssl_locking_cb);
189     CRYPTO_set_id_callback(openssl_thread_id_cb);
190   } else {
191     gpr_log(GPR_INFO, "OpenSSL callback has already been set.");
192   }
193 #endif
194   g_ssl_ctx_ex_factory_index =
195       SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
196   GPR_ASSERT(g_ssl_ctx_ex_factory_index != -1);
197 
198   g_ssl_ex_verified_root_cert_index =
199       SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
200   GPR_ASSERT(g_ssl_ex_verified_root_cert_index != -1);
201 }
202 
203 // --- Ssl utils. ---
204 
205 // TODO(jboeuf): Remove when we are past the debugging phase with this code.
ssl_log_where_info(const SSL * ssl,int where,int flag,const char * msg)206 static void ssl_log_where_info(const SSL* ssl, int where, int flag,
207                                const char* msg) {
208   if ((where & flag) && GRPC_TRACE_FLAG_ENABLED(tsi_tracing_enabled)) {
209     gpr_log(GPR_INFO, "%20.20s - %30.30s  - %5.10s", msg,
210             SSL_state_string_long(ssl), SSL_state_string(ssl));
211   }
212 }
213 
214 // Used for debugging. TODO(jboeuf): Remove when code is mature enough.
ssl_info_callback(const SSL * ssl,int where,int ret)215 static void ssl_info_callback(const SSL* ssl, int where, int ret) {
216   if (ret == 0) {
217     gpr_log(GPR_ERROR, "ssl_info_callback: error occurred.\n");
218     return;
219   }
220 
221   ssl_log_where_info(ssl, where, SSL_CB_LOOP, "LOOP");
222   ssl_log_where_info(ssl, where, SSL_CB_HANDSHAKE_START, "HANDSHAKE START");
223   ssl_log_where_info(ssl, where, SSL_CB_HANDSHAKE_DONE, "HANDSHAKE DONE");
224 }
225 
226 // Returns 1 if name looks like an IP address, 0 otherwise.
227 // This is a very rough heuristic, and only handles IPv6 in hexadecimal form.
looks_like_ip_address(absl::string_view name)228 static int looks_like_ip_address(absl::string_view name) {
229   size_t dot_count = 0;
230   size_t num_size = 0;
231   for (size_t i = 0; i < name.size(); ++i) {
232     if (name[i] == ':') {
233       // IPv6 Address in hexadecimal form, : is not allowed in DNS names.
234       return 1;
235     }
236     if (name[i] >= '0' && name[i] <= '9') {
237       if (num_size > 3) return 0;
238       num_size++;
239     } else if (name[i] == '.') {
240       if (dot_count > 3 || num_size == 0) return 0;
241       dot_count++;
242       num_size = 0;
243     } else {
244       return 0;
245     }
246   }
247   if (dot_count < 3 || num_size == 0) return 0;
248   return 1;
249 }
250 
251 // Gets the subject CN from an X509 cert.
ssl_get_x509_common_name(X509 * cert,unsigned char ** utf8,size_t * utf8_size)252 static tsi_result ssl_get_x509_common_name(X509* cert, unsigned char** utf8,
253                                            size_t* utf8_size) {
254   int common_name_index = -1;
255   X509_NAME_ENTRY* common_name_entry = nullptr;
256   ASN1_STRING* common_name_asn1 = nullptr;
257   X509_NAME* subject_name = X509_get_subject_name(cert);
258   int utf8_returned_size = 0;
259   if (subject_name == nullptr) {
260     gpr_log(GPR_INFO, "Could not get subject name from certificate.");
261     return TSI_NOT_FOUND;
262   }
263   common_name_index =
264       X509_NAME_get_index_by_NID(subject_name, NID_commonName, -1);
265   if (common_name_index == -1) {
266     gpr_log(GPR_INFO, "Could not get common name of subject from certificate.");
267     return TSI_NOT_FOUND;
268   }
269   common_name_entry = X509_NAME_get_entry(subject_name, common_name_index);
270   if (common_name_entry == nullptr) {
271     gpr_log(GPR_ERROR, "Could not get common name entry from certificate.");
272     return TSI_INTERNAL_ERROR;
273   }
274   common_name_asn1 = X509_NAME_ENTRY_get_data(common_name_entry);
275   if (common_name_asn1 == nullptr) {
276     gpr_log(GPR_ERROR,
277             "Could not get common name entry asn1 from certificate.");
278     return TSI_INTERNAL_ERROR;
279   }
280   utf8_returned_size = ASN1_STRING_to_UTF8(utf8, common_name_asn1);
281   if (utf8_returned_size < 0) {
282     gpr_log(GPR_ERROR, "Could not extract utf8 from asn1 string.");
283     return TSI_OUT_OF_RESOURCES;
284   }
285   *utf8_size = static_cast<size_t>(utf8_returned_size);
286   return TSI_OK;
287 }
288 
289 // Gets the subject CN of an X509 cert as a tsi_peer_property.
peer_property_from_x509_common_name(X509 * cert,tsi_peer_property * property)290 static tsi_result peer_property_from_x509_common_name(
291     X509* cert, tsi_peer_property* property) {
292   unsigned char* common_name;
293   size_t common_name_size;
294   tsi_result result =
295       ssl_get_x509_common_name(cert, &common_name, &common_name_size);
296   if (result != TSI_OK) {
297     if (result == TSI_NOT_FOUND) {
298       common_name = nullptr;
299       common_name_size = 0;
300     } else {
301       return result;
302     }
303   }
304   result = tsi_construct_string_peer_property(
305       TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY,
306       common_name == nullptr ? "" : reinterpret_cast<const char*>(common_name),
307       common_name_size, property);
308   OPENSSL_free(common_name);
309   return result;
310 }
311 
312 // Gets the subject of an X509 cert as a tsi_peer_property.
peer_property_from_x509_subject(X509 * cert,tsi_peer_property * property,bool is_verified_root_cert)313 static tsi_result peer_property_from_x509_subject(X509* cert,
314                                                   tsi_peer_property* property,
315                                                   bool is_verified_root_cert) {
316   X509_NAME* subject_name = X509_get_subject_name(cert);
317   if (subject_name == nullptr) {
318     gpr_log(GPR_INFO, "Could not get subject name from certificate.");
319     return TSI_NOT_FOUND;
320   }
321   BIO* bio = BIO_new(BIO_s_mem());
322   X509_NAME_print_ex(bio, subject_name, 0, XN_FLAG_RFC2253);
323   char* contents;
324   long len = BIO_get_mem_data(bio, &contents);
325   if (len < 0) {
326     gpr_log(GPR_ERROR, "Could not get subject entry from certificate.");
327     BIO_free(bio);
328     return TSI_INTERNAL_ERROR;
329   }
330   tsi_result result;
331   if (!is_verified_root_cert) {
332     result = tsi_construct_string_peer_property(
333         TSI_X509_SUBJECT_PEER_PROPERTY, contents, static_cast<size_t>(len),
334         property);
335   } else {
336     result = tsi_construct_string_peer_property(
337         TSI_X509_VERIFIED_ROOT_CERT_SUBECT_PEER_PROPERTY, contents,
338         static_cast<size_t>(len), property);
339   }
340   BIO_free(bio);
341   return result;
342 }
343 
344 // Gets the X509 cert in PEM format as a tsi_peer_property.
add_pem_certificate(X509 * cert,tsi_peer_property * property)345 static tsi_result add_pem_certificate(X509* cert, tsi_peer_property* property) {
346   BIO* bio = BIO_new(BIO_s_mem());
347   if (!PEM_write_bio_X509(bio, cert)) {
348     BIO_free(bio);
349     return TSI_INTERNAL_ERROR;
350   }
351   char* contents;
352   long len = BIO_get_mem_data(bio, &contents);
353   if (len <= 0) {
354     BIO_free(bio);
355     return TSI_INTERNAL_ERROR;
356   }
357   tsi_result result = tsi_construct_string_peer_property(
358       TSI_X509_PEM_CERT_PROPERTY, contents, static_cast<size_t>(len), property);
359   BIO_free(bio);
360   return result;
361 }
362 
363 // Gets the subject SANs from an X509 cert as a tsi_peer_property.
add_subject_alt_names_properties_to_peer(tsi_peer * peer,GENERAL_NAMES * subject_alt_names,size_t subject_alt_name_count,int * current_insert_index)364 static tsi_result add_subject_alt_names_properties_to_peer(
365     tsi_peer* peer, GENERAL_NAMES* subject_alt_names,
366     size_t subject_alt_name_count, int* current_insert_index) {
367   size_t i;
368   tsi_result result = TSI_OK;
369 
370   for (i = 0; i < subject_alt_name_count; i++) {
371     GENERAL_NAME* subject_alt_name =
372         sk_GENERAL_NAME_value(subject_alt_names, TSI_SIZE_AS_SIZE(i));
373     if (subject_alt_name->type == GEN_DNS ||
374         subject_alt_name->type == GEN_EMAIL ||
375         subject_alt_name->type == GEN_URI) {
376       unsigned char* name = nullptr;
377       int name_size;
378       std::string property_name;
379       if (subject_alt_name->type == GEN_DNS) {
380         name_size = ASN1_STRING_to_UTF8(&name, subject_alt_name->d.dNSName);
381         property_name = TSI_X509_DNS_PEER_PROPERTY;
382       } else if (subject_alt_name->type == GEN_EMAIL) {
383         name_size = ASN1_STRING_to_UTF8(&name, subject_alt_name->d.rfc822Name);
384         property_name = TSI_X509_EMAIL_PEER_PROPERTY;
385       } else {
386         name_size = ASN1_STRING_to_UTF8(
387             &name, subject_alt_name->d.uniformResourceIdentifier);
388         property_name = TSI_X509_URI_PEER_PROPERTY;
389       }
390       if (name_size < 0) {
391         gpr_log(GPR_ERROR, "Could not get utf8 from asn1 string.");
392         result = TSI_INTERNAL_ERROR;
393         break;
394       }
395       result = tsi_construct_string_peer_property(
396           TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY,
397           reinterpret_cast<const char*>(name), static_cast<size_t>(name_size),
398           &peer->properties[(*current_insert_index)++]);
399       if (result != TSI_OK) {
400         OPENSSL_free(name);
401         break;
402       }
403       result = tsi_construct_string_peer_property(
404           property_name.c_str(), reinterpret_cast<const char*>(name),
405           static_cast<size_t>(name_size),
406           &peer->properties[(*current_insert_index)++]);
407       OPENSSL_free(name);
408     } else if (subject_alt_name->type == GEN_IPADD) {
409       char ntop_buf[INET6_ADDRSTRLEN];
410       int af;
411 
412       if (subject_alt_name->d.iPAddress->length == 4) {
413         af = AF_INET;
414       } else if (subject_alt_name->d.iPAddress->length == 16) {
415         af = AF_INET6;
416       } else {
417         gpr_log(GPR_ERROR, "SAN IP Address contained invalid IP");
418         result = TSI_INTERNAL_ERROR;
419         break;
420       }
421       const char* name = inet_ntop(af, subject_alt_name->d.iPAddress->data,
422                                    ntop_buf, INET6_ADDRSTRLEN);
423       if (name == nullptr) {
424         gpr_log(GPR_ERROR, "Could not get IP string from asn1 octet.");
425         result = TSI_INTERNAL_ERROR;
426         break;
427       }
428 
429       result = tsi_construct_string_peer_property_from_cstring(
430           TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, name,
431           &peer->properties[(*current_insert_index)++]);
432       if (result != TSI_OK) break;
433       result = tsi_construct_string_peer_property_from_cstring(
434           TSI_X509_IP_PEER_PROPERTY, name,
435           &peer->properties[(*current_insert_index)++]);
436     } else {
437       result = tsi_construct_string_peer_property_from_cstring(
438           TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, "other types of SAN",
439           &peer->properties[(*current_insert_index)++]);
440     }
441     if (result != TSI_OK) break;
442   }
443   return result;
444 }
445 
446 // Gets information about the peer's X509 cert as a tsi_peer object.
peer_from_x509(X509 * cert,int include_certificate_type,tsi_peer * peer)447 static tsi_result peer_from_x509(X509* cert, int include_certificate_type,
448                                  tsi_peer* peer) {
449   // TODO(jboeuf): Maybe add more properties.
450   GENERAL_NAMES* subject_alt_names = static_cast<GENERAL_NAMES*>(
451       X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
452   int subject_alt_name_count =
453       (subject_alt_names != nullptr)
454           ? static_cast<int>(sk_GENERAL_NAME_num(subject_alt_names))
455           : 0;
456   size_t property_count;
457   tsi_result result;
458   GPR_ASSERT(subject_alt_name_count >= 0);
459   property_count = (include_certificate_type ? size_t{1} : 0) +
460                    3 /* subject, common name, certificate */ +
461                    static_cast<size_t>(subject_alt_name_count);
462   for (int i = 0; i < subject_alt_name_count; i++) {
463     GENERAL_NAME* subject_alt_name =
464         sk_GENERAL_NAME_value(subject_alt_names, TSI_SIZE_AS_SIZE(i));
465     // TODO(zhenlian): Clean up tsi_peer to avoid duplicate entries.
466     // URI, DNS, email and ip address SAN fields are plumbed to tsi_peer, in
467     // addition to all SAN fields (results in duplicate values). This code
468     // snippet updates property_count accordingly.
469     if (subject_alt_name->type == GEN_URI ||
470         subject_alt_name->type == GEN_DNS ||
471         subject_alt_name->type == GEN_EMAIL ||
472         subject_alt_name->type == GEN_IPADD) {
473       property_count += 1;
474     }
475   }
476   result = tsi_construct_peer(property_count, peer);
477   if (result != TSI_OK) return result;
478   int current_insert_index = 0;
479   do {
480     if (include_certificate_type) {
481       result = tsi_construct_string_peer_property_from_cstring(
482           TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_X509_CERTIFICATE_TYPE,
483           &peer->properties[current_insert_index++]);
484       if (result != TSI_OK) break;
485     }
486 
487     result = peer_property_from_x509_subject(
488         cert, &peer->properties[current_insert_index++],
489         /*is_verified_root_cert=*/false);
490     if (result != TSI_OK) break;
491 
492     result = peer_property_from_x509_common_name(
493         cert, &peer->properties[current_insert_index++]);
494     if (result != TSI_OK) break;
495 
496     result =
497         add_pem_certificate(cert, &peer->properties[current_insert_index++]);
498     if (result != TSI_OK) break;
499 
500     if (subject_alt_name_count != 0) {
501       result = add_subject_alt_names_properties_to_peer(
502           peer, subject_alt_names, static_cast<size_t>(subject_alt_name_count),
503           &current_insert_index);
504       if (result != TSI_OK) break;
505     }
506   } while (false);
507 
508   if (subject_alt_names != nullptr) {
509     sk_GENERAL_NAME_pop_free(subject_alt_names, GENERAL_NAME_free);
510   }
511   if (result != TSI_OK) tsi_peer_destruct(peer);
512 
513   GPR_ASSERT((int)peer->property_count == current_insert_index);
514   return result;
515 }
516 
517 // Loads an in-memory PEM certificate chain into the SSL context.
ssl_ctx_use_certificate_chain(SSL_CTX * context,const char * pem_cert_chain,size_t pem_cert_chain_size)518 static tsi_result ssl_ctx_use_certificate_chain(SSL_CTX* context,
519                                                 const char* pem_cert_chain,
520                                                 size_t pem_cert_chain_size) {
521   tsi_result result = TSI_OK;
522   X509* certificate = nullptr;
523   BIO* pem;
524   GPR_ASSERT(pem_cert_chain_size <= INT_MAX);
525   pem = BIO_new_mem_buf(pem_cert_chain, static_cast<int>(pem_cert_chain_size));
526   if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
527 
528   do {
529     certificate =
530         PEM_read_bio_X509_AUX(pem, nullptr, nullptr, const_cast<char*>(""));
531     if (certificate == nullptr) {
532       result = TSI_INVALID_ARGUMENT;
533       break;
534     }
535     if (!SSL_CTX_use_certificate(context, certificate)) {
536       result = TSI_INVALID_ARGUMENT;
537       break;
538     }
539     while (true) {
540       X509* certificate_authority =
541           PEM_read_bio_X509(pem, nullptr, nullptr, const_cast<char*>(""));
542       if (certificate_authority == nullptr) {
543         ERR_clear_error();
544         break;  // Done reading.
545       }
546       if (!SSL_CTX_add_extra_chain_cert(context, certificate_authority)) {
547         X509_free(certificate_authority);
548         result = TSI_INVALID_ARGUMENT;
549         break;
550       }
551       // We don't need to free certificate_authority as its ownership has been
552       // transferred to the context. That is not the case for certificate
553       // though.
554       //
555     }
556   } while (false);
557 
558   if (certificate != nullptr) X509_free(certificate);
559   BIO_free(pem);
560   return result;
561 }
562 
563 #if !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE)
ssl_ctx_use_engine_private_key(SSL_CTX * context,const char * pem_key,size_t pem_key_size)564 static tsi_result ssl_ctx_use_engine_private_key(SSL_CTX* context,
565                                                  const char* pem_key,
566                                                  size_t pem_key_size) {
567   tsi_result result = TSI_OK;
568   EVP_PKEY* private_key = nullptr;
569   ENGINE* engine = nullptr;
570   char* engine_name = nullptr;
571   // Parse key which is in following format engine:<engine_id>:<key_id>
572   do {
573     char* engine_start = (char*)pem_key + strlen(kSslEnginePrefix);
574     char* engine_end = (char*)strchr(engine_start, ':');
575     if (engine_end == nullptr) {
576       result = TSI_INVALID_ARGUMENT;
577       break;
578     }
579     char* key_id = engine_end + 1;
580     int engine_name_length = engine_end - engine_start;
581     if (engine_name_length == 0) {
582       result = TSI_INVALID_ARGUMENT;
583       break;
584     }
585     engine_name = static_cast<char*>(gpr_zalloc(engine_name_length + 1));
586     memcpy(engine_name, engine_start, engine_name_length);
587     gpr_log(GPR_DEBUG, "ENGINE key: %s", engine_name);
588     ENGINE_load_dynamic();
589     engine = ENGINE_by_id(engine_name);
590     if (engine == nullptr) {
591       // If not available at ENGINE_DIR, use dynamic to load from
592       // current working directory.
593       engine = ENGINE_by_id("dynamic");
594       if (engine == nullptr) {
595         gpr_log(GPR_ERROR, "Cannot load dynamic engine");
596         result = TSI_INVALID_ARGUMENT;
597         break;
598       }
599       if (!ENGINE_ctrl_cmd_string(engine, "ID", engine_name, 0) ||
600           !ENGINE_ctrl_cmd_string(engine, "DIR_LOAD", "2", 0) ||
601           !ENGINE_ctrl_cmd_string(engine, "DIR_ADD", ".", 0) ||
602           !ENGINE_ctrl_cmd_string(engine, "LIST_ADD", "1", 0) ||
603           !ENGINE_ctrl_cmd_string(engine, "LOAD", NULL, 0)) {
604         gpr_log(GPR_ERROR, "Cannot find engine");
605         result = TSI_INVALID_ARGUMENT;
606         break;
607       }
608     }
609     if (!ENGINE_set_default(engine, ENGINE_METHOD_ALL)) {
610       gpr_log(GPR_ERROR, "ENGINE_set_default with ENGINE_METHOD_ALL failed");
611       result = TSI_INVALID_ARGUMENT;
612       break;
613     }
614     if (!ENGINE_init(engine)) {
615       gpr_log(GPR_ERROR, "ENGINE_init failed");
616       result = TSI_INVALID_ARGUMENT;
617       break;
618     }
619     private_key = ENGINE_load_private_key(engine, key_id, 0, 0);
620     if (private_key == nullptr) {
621       gpr_log(GPR_ERROR, "ENGINE_load_private_key failed");
622       result = TSI_INVALID_ARGUMENT;
623       break;
624     }
625     if (!SSL_CTX_use_PrivateKey(context, private_key)) {
626       gpr_log(GPR_ERROR, "SSL_CTX_use_PrivateKey failed");
627       result = TSI_INVALID_ARGUMENT;
628       break;
629     }
630   } while (0);
631   if (engine != nullptr) ENGINE_free(engine);
632   if (private_key != nullptr) EVP_PKEY_free(private_key);
633   if (engine_name != nullptr) gpr_free(engine_name);
634   return result;
635 }
636 #endif  // !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE)
637 
ssl_ctx_use_pem_private_key(SSL_CTX * context,const char * pem_key,size_t pem_key_size)638 static tsi_result ssl_ctx_use_pem_private_key(SSL_CTX* context,
639                                               const char* pem_key,
640                                               size_t pem_key_size) {
641   tsi_result result = TSI_OK;
642   EVP_PKEY* private_key = nullptr;
643   BIO* pem;
644   GPR_ASSERT(pem_key_size <= INT_MAX);
645   pem = BIO_new_mem_buf(pem_key, static_cast<int>(pem_key_size));
646   if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
647   do {
648     private_key =
649         PEM_read_bio_PrivateKey(pem, nullptr, nullptr, const_cast<char*>(""));
650     if (private_key == nullptr) {
651       result = TSI_INVALID_ARGUMENT;
652       break;
653     }
654     if (!SSL_CTX_use_PrivateKey(context, private_key)) {
655       result = TSI_INVALID_ARGUMENT;
656       break;
657     }
658   } while (false);
659   if (private_key != nullptr) EVP_PKEY_free(private_key);
660   BIO_free(pem);
661   return result;
662 }
663 
664 // Loads an in-memory PEM private key into the SSL context.
ssl_ctx_use_private_key(SSL_CTX * context,const char * pem_key,size_t pem_key_size)665 static tsi_result ssl_ctx_use_private_key(SSL_CTX* context, const char* pem_key,
666                                           size_t pem_key_size) {
667 // BoringSSL does not have ENGINE support
668 #if !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE)
669   if (strncmp(pem_key, kSslEnginePrefix, strlen(kSslEnginePrefix)) == 0) {
670     return ssl_ctx_use_engine_private_key(context, pem_key, pem_key_size);
671   } else
672 #endif  // !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE)
673   {
674     return ssl_ctx_use_pem_private_key(context, pem_key, pem_key_size);
675   }
676 }
677 
678 // Loads in-memory PEM verification certs into the SSL context and optionally
679 // returns the verification cert names (root_names can be NULL).
x509_store_load_certs(X509_STORE * cert_store,const char * pem_roots,size_t pem_roots_size,STACK_OF (X509_NAME)** root_names)680 static tsi_result x509_store_load_certs(X509_STORE* cert_store,
681                                         const char* pem_roots,
682                                         size_t pem_roots_size,
683                                         STACK_OF(X509_NAME) * *root_names) {
684   tsi_result result = TSI_OK;
685   size_t num_roots = 0;
686   X509* root = nullptr;
687   X509_NAME* root_name = nullptr;
688   BIO* pem;
689   GPR_ASSERT(pem_roots_size <= INT_MAX);
690   pem = BIO_new_mem_buf(pem_roots, static_cast<int>(pem_roots_size));
691   if (cert_store == nullptr) return TSI_INVALID_ARGUMENT;
692   if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
693   if (root_names != nullptr) {
694     *root_names = sk_X509_NAME_new_null();
695     if (*root_names == nullptr) return TSI_OUT_OF_RESOURCES;
696   }
697 
698   while (true) {
699     root = PEM_read_bio_X509_AUX(pem, nullptr, nullptr, const_cast<char*>(""));
700     if (root == nullptr) {
701       ERR_clear_error();
702       break;  // We're at the end of stream.
703     }
704     if (root_names != nullptr) {
705       root_name = X509_get_subject_name(root);
706       if (root_name == nullptr) {
707         gpr_log(GPR_ERROR, "Could not get name from root certificate.");
708         result = TSI_INVALID_ARGUMENT;
709         break;
710       }
711       root_name = X509_NAME_dup(root_name);
712       if (root_name == nullptr) {
713         result = TSI_OUT_OF_RESOURCES;
714         break;
715       }
716       sk_X509_NAME_push(*root_names, root_name);
717       root_name = nullptr;
718     }
719     ERR_clear_error();
720     if (!X509_STORE_add_cert(cert_store, root)) {
721       unsigned long error = ERR_get_error();
722       if (ERR_GET_LIB(error) != ERR_LIB_X509 ||
723           ERR_GET_REASON(error) != X509_R_CERT_ALREADY_IN_HASH_TABLE) {
724         gpr_log(GPR_ERROR, "Could not add root certificate to ssl context.");
725         result = TSI_INTERNAL_ERROR;
726         break;
727       }
728     }
729     X509_free(root);
730     num_roots++;
731   }
732   if (num_roots == 0) {
733     gpr_log(GPR_ERROR, "Could not load any root certificate.");
734     result = TSI_INVALID_ARGUMENT;
735   }
736 
737   if (result != TSI_OK) {
738     if (root != nullptr) X509_free(root);
739     if (root_names != nullptr) {
740       sk_X509_NAME_pop_free(*root_names, X509_NAME_free);
741       *root_names = nullptr;
742       if (root_name != nullptr) X509_NAME_free(root_name);
743     }
744   }
745   BIO_free(pem);
746   return result;
747 }
748 
ssl_ctx_load_verification_certs(SSL_CTX * context,const char * pem_roots,size_t pem_roots_size,STACK_OF (X509_NAME)** root_name)749 static tsi_result ssl_ctx_load_verification_certs(SSL_CTX* context,
750                                                   const char* pem_roots,
751                                                   size_t pem_roots_size,
752                                                   STACK_OF(X509_NAME) *
753                                                       *root_name) {
754   X509_STORE* cert_store = SSL_CTX_get_cert_store(context);
755   X509_STORE_set_flags(cert_store,
756                        X509_V_FLAG_PARTIAL_CHAIN | X509_V_FLAG_TRUSTED_FIRST);
757   return x509_store_load_certs(cert_store, pem_roots, pem_roots_size,
758                                root_name);
759 }
760 
761 // Populates the SSL context with a private key and a cert chain, and sets the
762 // cipher list and the ephemeral ECDH key.
populate_ssl_context(SSL_CTX * context,const tsi_ssl_pem_key_cert_pair * key_cert_pair,const char * cipher_list)763 static tsi_result populate_ssl_context(
764     SSL_CTX* context, const tsi_ssl_pem_key_cert_pair* key_cert_pair,
765     const char* cipher_list) {
766   tsi_result result = TSI_OK;
767   if (key_cert_pair != nullptr) {
768     if (key_cert_pair->cert_chain != nullptr) {
769       result = ssl_ctx_use_certificate_chain(context, key_cert_pair->cert_chain,
770                                              strlen(key_cert_pair->cert_chain));
771       if (result != TSI_OK) {
772         gpr_log(GPR_ERROR, "Invalid cert chain file.");
773         return result;
774       }
775     }
776     if (key_cert_pair->private_key != nullptr) {
777       result = ssl_ctx_use_private_key(context, key_cert_pair->private_key,
778                                        strlen(key_cert_pair->private_key));
779       if (result != TSI_OK || !SSL_CTX_check_private_key(context)) {
780         gpr_log(GPR_ERROR, "Invalid private key.");
781         return result != TSI_OK ? result : TSI_INVALID_ARGUMENT;
782       }
783     }
784   }
785   if ((cipher_list != nullptr) &&
786       !SSL_CTX_set_cipher_list(context, cipher_list)) {
787     gpr_log(GPR_ERROR, "Invalid cipher list: %s.", cipher_list);
788     return TSI_INVALID_ARGUMENT;
789   }
790   {
791     EC_KEY* ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
792     if (!SSL_CTX_set_tmp_ecdh(context, ecdh)) {
793       gpr_log(GPR_ERROR, "Could not set ephemeral ECDH key.");
794       EC_KEY_free(ecdh);
795       return TSI_INTERNAL_ERROR;
796     }
797     SSL_CTX_set_options(context, SSL_OP_SINGLE_ECDH_USE);
798     EC_KEY_free(ecdh);
799   }
800   return TSI_OK;
801 }
802 
803 // Extracts the CN and the SANs from an X509 cert as a peer object.
tsi_ssl_extract_x509_subject_names_from_pem_cert(const char * pem_cert,tsi_peer * peer)804 tsi_result tsi_ssl_extract_x509_subject_names_from_pem_cert(
805     const char* pem_cert, tsi_peer* peer) {
806   tsi_result result = TSI_OK;
807   X509* cert = nullptr;
808   BIO* pem;
809   pem = BIO_new_mem_buf(pem_cert, static_cast<int>(strlen(pem_cert)));
810   if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
811 
812   cert = PEM_read_bio_X509(pem, nullptr, nullptr, const_cast<char*>(""));
813   if (cert == nullptr) {
814     gpr_log(GPR_ERROR, "Invalid certificate");
815     result = TSI_INVALID_ARGUMENT;
816   } else {
817     result = peer_from_x509(cert, 0, peer);
818   }
819   if (cert != nullptr) X509_free(cert);
820   BIO_free(pem);
821   return result;
822 }
823 
824 // Builds the alpn protocol name list according to rfc 7301.
build_alpn_protocol_name_list(const char ** alpn_protocols,uint16_t num_alpn_protocols,unsigned char ** protocol_name_list,size_t * protocol_name_list_length)825 static tsi_result build_alpn_protocol_name_list(
826     const char** alpn_protocols, uint16_t num_alpn_protocols,
827     unsigned char** protocol_name_list, size_t* protocol_name_list_length) {
828   uint16_t i;
829   unsigned char* current;
830   *protocol_name_list = nullptr;
831   *protocol_name_list_length = 0;
832   if (num_alpn_protocols == 0) return TSI_INVALID_ARGUMENT;
833   for (i = 0; i < num_alpn_protocols; i++) {
834     size_t length =
835         alpn_protocols[i] == nullptr ? 0 : strlen(alpn_protocols[i]);
836     if (length == 0 || length > 255) {
837       gpr_log(GPR_ERROR, "Invalid protocol name length: %d.",
838               static_cast<int>(length));
839       return TSI_INVALID_ARGUMENT;
840     }
841     *protocol_name_list_length += length + 1;
842   }
843   *protocol_name_list =
844       static_cast<unsigned char*>(gpr_malloc(*protocol_name_list_length));
845   if (*protocol_name_list == nullptr) return TSI_OUT_OF_RESOURCES;
846   current = *protocol_name_list;
847   for (i = 0; i < num_alpn_protocols; i++) {
848     size_t length = strlen(alpn_protocols[i]);
849     *(current++) = static_cast<uint8_t>(length);  // max checked above.
850     memcpy(current, alpn_protocols[i], length);
851     current += length;
852   }
853   // Safety check.
854   if ((current < *protocol_name_list) ||
855       (static_cast<uintptr_t>(current - *protocol_name_list) !=
856        *protocol_name_list_length)) {
857     return TSI_INTERNAL_ERROR;
858   }
859   return TSI_OK;
860 }
861 
862 // This callback is invoked when the CRL has been verified and will soft-fail
863 // errors in verification depending on certain error types.
verify_cb(int ok,X509_STORE_CTX * ctx)864 static int verify_cb(int ok, X509_STORE_CTX* ctx) {
865   int cert_error = X509_STORE_CTX_get_error(ctx);
866   if (cert_error == X509_V_ERR_UNABLE_TO_GET_CRL) {
867     gpr_log(
868         GPR_INFO,
869         "Certificate verification failed to get CRL files. Ignoring error.");
870     return 1;
871   }
872   if (cert_error != 0) {
873     gpr_log(GPR_ERROR, "Certificate verify failed with code %d", cert_error);
874   }
875   return ok;
876 }
877 
878 // The verification callback is used for clients that don't really care about
879 // the server's certificate, but we need to pull it anyway, in case a higher
880 // layer wants to look at it. In this case the verification may fail, but
881 // we don't really care.
NullVerifyCallback(int,X509_STORE_CTX *)882 static int NullVerifyCallback(int /*preverify_ok*/, X509_STORE_CTX* /*ctx*/) {
883   return 1;
884 }
885 
RootCertExtractCallback(int preverify_ok,X509_STORE_CTX * ctx)886 static int RootCertExtractCallback(int preverify_ok, X509_STORE_CTX* ctx) {
887   if (ctx == nullptr) {
888     return preverify_ok;
889   }
890 
891   // There's a case where this function is set in SSL_CTX_set_verify and a CRL
892   // related callback is set with X509_STORE_set_verify_cb. They overlap and
893   // this will take precedence, thus we need to ensure the CRL related callback
894   // is still called
895   X509_VERIFY_PARAM* param = X509_STORE_CTX_get0_param(ctx);
896   auto flags = X509_VERIFY_PARAM_get_flags(param);
897   if (flags & X509_V_FLAG_CRL_CHECK) {
898     preverify_ok = verify_cb(preverify_ok, ctx);
899   }
900 
901   // If preverify_ok == 0, verification failed. We shouldn't expect to have a
902   // verified chain, so there is no need to attempt to extract the root cert
903   // from it
904   if (preverify_ok == 0) {
905     return preverify_ok;
906   }
907 
908   // If we're here, verification was successful
909   // Get the verified chain from the X509_STORE_CTX and put it on the SSL object
910   // so that we have access to it when populating the tsi_peer
911 #if OPENSSL_VERSION_NUMBER >= 0x10100000
912   STACK_OF(X509)* chain = X509_STORE_CTX_get0_chain(ctx);
913 #else
914   STACK_OF(X509)* chain = X509_STORE_CTX_get_chain(ctx);
915 #endif
916 
917   if (chain == nullptr) {
918     return preverify_ok;
919   }
920 
921   // The root cert is the last in the chain
922   size_t chain_length = sk_X509_num(chain);
923   if (chain_length == 0) {
924     return preverify_ok;
925   }
926   X509* root_cert = sk_X509_value(chain, chain_length - 1);
927   if (root_cert == nullptr) {
928     return preverify_ok;
929   }
930 
931   SSL* ssl = static_cast<SSL*>(
932       X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
933   if (ssl == nullptr) {
934     return preverify_ok;
935   }
936   int success =
937       SSL_set_ex_data(ssl, g_ssl_ex_verified_root_cert_index, root_cert);
938   if (success == 0) {
939     gpr_log(GPR_INFO, "Could not set verified root cert in SSL's ex_data");
940   }
941   return preverify_ok;
942 }
943 
944 // Sets the min and max TLS version of |ssl_context| to |min_tls_version| and
945 // |max_tls_version|, respectively. Calling this method is a no-op when using
946 // OpenSSL versions < 1.1.
tsi_set_min_and_max_tls_versions(SSL_CTX * ssl_context,tsi_tls_version min_tls_version,tsi_tls_version max_tls_version)947 static tsi_result tsi_set_min_and_max_tls_versions(
948     SSL_CTX* ssl_context, tsi_tls_version min_tls_version,
949     tsi_tls_version max_tls_version) {
950   if (ssl_context == nullptr) {
951     gpr_log(GPR_INFO,
952             "Invalid nullptr argument to |tsi_set_min_and_max_tls_versions|.");
953     return TSI_INVALID_ARGUMENT;
954   }
955 #if OPENSSL_VERSION_NUMBER >= 0x10100000
956   // Set the min TLS version of the SSL context if using OpenSSL version
957   // >= 1.1.0. This OpenSSL version is required because the
958   // |SSL_CTX_set_min_proto_version| and |SSL_CTX_set_max_proto_version| APIs
959   // only exist in this version range.
960   switch (min_tls_version) {
961     case tsi_tls_version::TSI_TLS1_2:
962       SSL_CTX_set_min_proto_version(ssl_context, TLS1_2_VERSION);
963       break;
964 #if defined(TLS1_3_VERSION)
965     // If the library does not support TLS 1.3 and the caller requests a minimum
966     // of TLS 1.3, then return an error because the caller's request cannot be
967     // satisfied.
968     case tsi_tls_version::TSI_TLS1_3:
969       SSL_CTX_set_min_proto_version(ssl_context, TLS1_3_VERSION);
970       break;
971 #endif
972     default:
973       gpr_log(GPR_INFO, "TLS version is not supported.");
974       return TSI_FAILED_PRECONDITION;
975   }
976 
977   // Set the max TLS version of the SSL context.
978   switch (max_tls_version) {
979     case tsi_tls_version::TSI_TLS1_2:
980       SSL_CTX_set_max_proto_version(ssl_context, TLS1_2_VERSION);
981       break;
982     case tsi_tls_version::TSI_TLS1_3:
983 #if defined(TLS1_3_VERSION)
984       SSL_CTX_set_max_proto_version(ssl_context, TLS1_3_VERSION);
985 #else
986       // If the library does not support TLS 1.3, then set the max TLS version
987       // to TLS 1.2 instead.
988       SSL_CTX_set_max_proto_version(ssl_context, TLS1_2_VERSION);
989 #endif
990       break;
991     default:
992       gpr_log(GPR_INFO, "TLS version is not supported.");
993       return TSI_FAILED_PRECONDITION;
994   }
995 #endif
996   return TSI_OK;
997 }
998 
999 // --- tsi_ssl_root_certs_store methods implementation. ---
1000 
tsi_ssl_root_certs_store_create(const char * pem_roots)1001 tsi_ssl_root_certs_store* tsi_ssl_root_certs_store_create(
1002     const char* pem_roots) {
1003   if (pem_roots == nullptr) {
1004     gpr_log(GPR_ERROR, "The root certificates are empty.");
1005     return nullptr;
1006   }
1007   tsi_ssl_root_certs_store* root_store = static_cast<tsi_ssl_root_certs_store*>(
1008       gpr_zalloc(sizeof(tsi_ssl_root_certs_store)));
1009   if (root_store == nullptr) {
1010     gpr_log(GPR_ERROR, "Could not allocate buffer for ssl_root_certs_store.");
1011     return nullptr;
1012   }
1013   root_store->store = X509_STORE_new();
1014   if (root_store->store == nullptr) {
1015     gpr_log(GPR_ERROR, "Could not allocate buffer for X509_STORE.");
1016     gpr_free(root_store);
1017     return nullptr;
1018   }
1019   tsi_result result = x509_store_load_certs(root_store->store, pem_roots,
1020                                             strlen(pem_roots), nullptr);
1021   if (result != TSI_OK) {
1022     gpr_log(GPR_ERROR, "Could not load root certificates.");
1023     X509_STORE_free(root_store->store);
1024     gpr_free(root_store);
1025     return nullptr;
1026   }
1027   return root_store;
1028 }
1029 
tsi_ssl_root_certs_store_destroy(tsi_ssl_root_certs_store * self)1030 void tsi_ssl_root_certs_store_destroy(tsi_ssl_root_certs_store* self) {
1031   if (self == nullptr) return;
1032   X509_STORE_free(self->store);
1033   gpr_free(self);
1034 }
1035 
1036 // --- tsi_ssl_session_cache methods implementation. ---
1037 
tsi_ssl_session_cache_create_lru(size_t capacity)1038 tsi_ssl_session_cache* tsi_ssl_session_cache_create_lru(size_t capacity) {
1039   // Pointer will be dereferenced by unref call.
1040   return tsi::SslSessionLRUCache::Create(capacity).release()->c_ptr();
1041 }
1042 
tsi_ssl_session_cache_ref(tsi_ssl_session_cache * cache)1043 void tsi_ssl_session_cache_ref(tsi_ssl_session_cache* cache) {
1044   // Pointer will be dereferenced by unref call.
1045   tsi::SslSessionLRUCache::FromC(cache)->Ref().release();
1046 }
1047 
tsi_ssl_session_cache_unref(tsi_ssl_session_cache * cache)1048 void tsi_ssl_session_cache_unref(tsi_ssl_session_cache* cache) {
1049   tsi::SslSessionLRUCache::FromC(cache)->Unref();
1050 }
1051 
1052 // --- tsi_frame_protector methods implementation. ---
1053 
ssl_protector_protect(tsi_frame_protector * self,const unsigned char * unprotected_bytes,size_t * unprotected_bytes_size,unsigned char * protected_output_frames,size_t * protected_output_frames_size)1054 static tsi_result ssl_protector_protect(tsi_frame_protector* self,
1055                                         const unsigned char* unprotected_bytes,
1056                                         size_t* unprotected_bytes_size,
1057                                         unsigned char* protected_output_frames,
1058                                         size_t* protected_output_frames_size) {
1059   tsi_ssl_frame_protector* impl =
1060       reinterpret_cast<tsi_ssl_frame_protector*>(self);
1061 
1062   return grpc_core::SslProtectorProtect(
1063       unprotected_bytes, impl->buffer_size, impl->buffer_offset, impl->buffer,
1064       impl->ssl, impl->network_io, unprotected_bytes_size,
1065       protected_output_frames, protected_output_frames_size);
1066 }
1067 
ssl_protector_protect_flush(tsi_frame_protector * self,unsigned char * protected_output_frames,size_t * protected_output_frames_size,size_t * still_pending_size)1068 static tsi_result ssl_protector_protect_flush(
1069     tsi_frame_protector* self, unsigned char* protected_output_frames,
1070     size_t* protected_output_frames_size, size_t* still_pending_size) {
1071   tsi_ssl_frame_protector* impl =
1072       reinterpret_cast<tsi_ssl_frame_protector*>(self);
1073   return grpc_core::SslProtectorProtectFlush(
1074       impl->buffer_offset, impl->buffer, impl->ssl, impl->network_io,
1075       protected_output_frames, protected_output_frames_size,
1076       still_pending_size);
1077 }
1078 
ssl_protector_unprotect(tsi_frame_protector * self,const unsigned char * protected_frames_bytes,size_t * protected_frames_bytes_size,unsigned char * unprotected_bytes,size_t * unprotected_bytes_size)1079 static tsi_result ssl_protector_unprotect(
1080     tsi_frame_protector* self, const unsigned char* protected_frames_bytes,
1081     size_t* protected_frames_bytes_size, unsigned char* unprotected_bytes,
1082     size_t* unprotected_bytes_size) {
1083   tsi_ssl_frame_protector* impl =
1084       reinterpret_cast<tsi_ssl_frame_protector*>(self);
1085   return grpc_core::SslProtectorUnprotect(
1086       protected_frames_bytes, impl->ssl, impl->network_io,
1087       protected_frames_bytes_size, unprotected_bytes, unprotected_bytes_size);
1088 }
1089 
ssl_protector_destroy(tsi_frame_protector * self)1090 static void ssl_protector_destroy(tsi_frame_protector* self) {
1091   tsi_ssl_frame_protector* impl =
1092       reinterpret_cast<tsi_ssl_frame_protector*>(self);
1093   if (impl->buffer != nullptr) gpr_free(impl->buffer);
1094   if (impl->ssl != nullptr) SSL_free(impl->ssl);
1095   if (impl->network_io != nullptr) BIO_free(impl->network_io);
1096   gpr_free(self);
1097 }
1098 
1099 static const tsi_frame_protector_vtable frame_protector_vtable = {
1100     ssl_protector_protect,
1101     ssl_protector_protect_flush,
1102     ssl_protector_unprotect,
1103     ssl_protector_destroy,
1104 };
1105 
1106 // --- tsi_server_handshaker_factory methods implementation. ---
1107 
tsi_ssl_handshaker_factory_destroy(tsi_ssl_handshaker_factory * factory)1108 static void tsi_ssl_handshaker_factory_destroy(
1109     tsi_ssl_handshaker_factory* factory) {
1110   if (factory == nullptr) return;
1111 
1112   if (factory->vtable != nullptr && factory->vtable->destroy != nullptr) {
1113     factory->vtable->destroy(factory);
1114   }
1115   // Note, we don't free(self) here because this object is always directly
1116   // embedded in another object. If tsi_ssl_handshaker_factory_init allocates
1117   // any memory, it should be free'd here.
1118 }
1119 
tsi_ssl_handshaker_factory_ref(tsi_ssl_handshaker_factory * factory)1120 static tsi_ssl_handshaker_factory* tsi_ssl_handshaker_factory_ref(
1121     tsi_ssl_handshaker_factory* factory) {
1122   if (factory == nullptr) return nullptr;
1123   gpr_refn(&factory->refcount, 1);
1124   return factory;
1125 }
1126 
tsi_ssl_handshaker_factory_unref(tsi_ssl_handshaker_factory * factory)1127 static void tsi_ssl_handshaker_factory_unref(
1128     tsi_ssl_handshaker_factory* factory) {
1129   if (factory == nullptr) return;
1130 
1131   if (gpr_unref(&factory->refcount)) {
1132     tsi_ssl_handshaker_factory_destroy(factory);
1133   }
1134 }
1135 
1136 static tsi_ssl_handshaker_factory_vtable handshaker_factory_vtable = {nullptr};
1137 
1138 // Initializes a tsi_ssl_handshaker_factory object. Caller is responsible for
1139 // allocating memory for the factory.
tsi_ssl_handshaker_factory_init(tsi_ssl_handshaker_factory * factory)1140 static void tsi_ssl_handshaker_factory_init(
1141     tsi_ssl_handshaker_factory* factory) {
1142   GPR_ASSERT(factory != nullptr);
1143 
1144   factory->vtable = &handshaker_factory_vtable;
1145   gpr_ref_init(&factory->refcount, 1);
1146 }
1147 
1148 // Gets the X509 cert chain in PEM format as a tsi_peer_property.
tsi_ssl_get_cert_chain_contents(STACK_OF (X509)* peer_chain,tsi_peer_property * property)1149 tsi_result tsi_ssl_get_cert_chain_contents(STACK_OF(X509) * peer_chain,
1150                                            tsi_peer_property* property) {
1151   BIO* bio = BIO_new(BIO_s_mem());
1152   const auto peer_chain_len = sk_X509_num(peer_chain);
1153   for (auto i = decltype(peer_chain_len){0}; i < peer_chain_len; i++) {
1154     if (!PEM_write_bio_X509(bio, sk_X509_value(peer_chain, i))) {
1155       BIO_free(bio);
1156       return TSI_INTERNAL_ERROR;
1157     }
1158   }
1159   char* contents;
1160   long len = BIO_get_mem_data(bio, &contents);
1161   if (len <= 0) {
1162     BIO_free(bio);
1163     return TSI_INTERNAL_ERROR;
1164   }
1165   tsi_result result = tsi_construct_string_peer_property(
1166       TSI_X509_PEM_CERT_CHAIN_PROPERTY, contents, static_cast<size_t>(len),
1167       property);
1168   BIO_free(bio);
1169   return result;
1170 }
1171 
1172 // --- tsi_handshaker_result methods implementation. ---
ssl_handshaker_result_extract_peer(const tsi_handshaker_result * self,tsi_peer * peer)1173 static tsi_result ssl_handshaker_result_extract_peer(
1174     const tsi_handshaker_result* self, tsi_peer* peer) {
1175   tsi_result result = TSI_OK;
1176   const unsigned char* alpn_selected = nullptr;
1177   unsigned int alpn_selected_len;
1178   const tsi_ssl_handshaker_result* impl =
1179       reinterpret_cast<const tsi_ssl_handshaker_result*>(self);
1180   X509* peer_cert = SSL_get_peer_certificate(impl->ssl);
1181   if (peer_cert != nullptr) {
1182     result = peer_from_x509(peer_cert, 1, peer);
1183     X509_free(peer_cert);
1184     if (result != TSI_OK) return result;
1185   }
1186 #if TSI_OPENSSL_ALPN_SUPPORT
1187   SSL_get0_alpn_selected(impl->ssl, &alpn_selected, &alpn_selected_len);
1188 #endif  // TSI_OPENSSL_ALPN_SUPPORT
1189   if (alpn_selected == nullptr) {
1190     // Try npn.
1191     SSL_get0_next_proto_negotiated(impl->ssl, &alpn_selected,
1192                                    &alpn_selected_len);
1193   }
1194   // When called on the client side, the stack also contains the
1195   // peer's certificate; When called on the server side,
1196   // the peer's certificate is not present in the stack
1197   STACK_OF(X509)* peer_chain = SSL_get_peer_cert_chain(impl->ssl);
1198 
1199   X509* verified_root_cert = static_cast<X509*>(
1200       SSL_get_ex_data(impl->ssl, g_ssl_ex_verified_root_cert_index));
1201   // 1 is for session reused property.
1202   size_t new_property_count = peer->property_count + 3;
1203   if (alpn_selected != nullptr) new_property_count++;
1204   if (peer_chain != nullptr) new_property_count++;
1205   if (verified_root_cert != nullptr) new_property_count++;
1206   tsi_peer_property* new_properties = static_cast<tsi_peer_property*>(
1207       gpr_zalloc(sizeof(*new_properties) * new_property_count));
1208   for (size_t i = 0; i < peer->property_count; i++) {
1209     new_properties[i] = peer->properties[i];
1210   }
1211   if (peer->properties != nullptr) gpr_free(peer->properties);
1212   peer->properties = new_properties;
1213   // Add peer chain if available
1214   if (peer_chain != nullptr) {
1215     result = tsi_ssl_get_cert_chain_contents(
1216         peer_chain, &peer->properties[peer->property_count]);
1217     if (result == TSI_OK) peer->property_count++;
1218   }
1219   if (alpn_selected != nullptr) {
1220     result = tsi_construct_string_peer_property(
1221         TSI_SSL_ALPN_SELECTED_PROTOCOL,
1222         reinterpret_cast<const char*>(alpn_selected), alpn_selected_len,
1223         &peer->properties[peer->property_count]);
1224     if (result != TSI_OK) return result;
1225     peer->property_count++;
1226   }
1227   // Add security_level peer property.
1228   result = tsi_construct_string_peer_property_from_cstring(
1229       TSI_SECURITY_LEVEL_PEER_PROPERTY,
1230       tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY),
1231       &peer->properties[peer->property_count]);
1232   if (result != TSI_OK) return result;
1233   peer->property_count++;
1234 
1235   const char* session_reused = SSL_session_reused(impl->ssl) ? "true" : "false";
1236   result = tsi_construct_string_peer_property_from_cstring(
1237       TSI_SSL_SESSION_REUSED_PEER_PROPERTY, session_reused,
1238       &peer->properties[peer->property_count]);
1239   if (result != TSI_OK) return result;
1240   peer->property_count++;
1241 
1242   if (verified_root_cert != nullptr) {
1243     result = peer_property_from_x509_subject(
1244         verified_root_cert, &peer->properties[peer->property_count], true);
1245     if (result != TSI_OK) {
1246       gpr_log(GPR_DEBUG,
1247               "Problem extracting subject from verified_root_cert. result: %d",
1248               static_cast<int>(result));
1249     }
1250     peer->property_count++;
1251   }
1252 
1253   return result;
1254 }
1255 
ssl_handshaker_result_get_frame_protector_type(const tsi_handshaker_result *,tsi_frame_protector_type * frame_protector_type)1256 static tsi_result ssl_handshaker_result_get_frame_protector_type(
1257     const tsi_handshaker_result* /*self*/,
1258     tsi_frame_protector_type* frame_protector_type) {
1259   *frame_protector_type = TSI_FRAME_PROTECTOR_NORMAL;
1260   return TSI_OK;
1261 }
1262 
ssl_handshaker_result_create_frame_protector(const tsi_handshaker_result * self,size_t * max_output_protected_frame_size,tsi_frame_protector ** protector)1263 static tsi_result ssl_handshaker_result_create_frame_protector(
1264     const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
1265     tsi_frame_protector** protector) {
1266   size_t actual_max_output_protected_frame_size =
1267       TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND;
1268   tsi_ssl_handshaker_result* impl =
1269       reinterpret_cast<tsi_ssl_handshaker_result*>(
1270           const_cast<tsi_handshaker_result*>(self));
1271   tsi_ssl_frame_protector* protector_impl =
1272       static_cast<tsi_ssl_frame_protector*>(
1273           gpr_zalloc(sizeof(*protector_impl)));
1274 
1275   if (max_output_protected_frame_size != nullptr) {
1276     if (*max_output_protected_frame_size >
1277         TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND) {
1278       *max_output_protected_frame_size =
1279           TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND;
1280     } else if (*max_output_protected_frame_size <
1281                TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND) {
1282       *max_output_protected_frame_size =
1283           TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND;
1284     }
1285     actual_max_output_protected_frame_size = *max_output_protected_frame_size;
1286   }
1287   protector_impl->buffer_size =
1288       actual_max_output_protected_frame_size - TSI_SSL_MAX_PROTECTION_OVERHEAD;
1289   protector_impl->buffer =
1290       static_cast<unsigned char*>(gpr_malloc(protector_impl->buffer_size));
1291   if (protector_impl->buffer == nullptr) {
1292     gpr_log(GPR_ERROR,
1293             "Could not allocated buffer for tsi_ssl_frame_protector.");
1294     gpr_free(protector_impl);
1295     return TSI_INTERNAL_ERROR;
1296   }
1297 
1298   // Transfer ownership of ssl and network_io to the frame protector.
1299   protector_impl->ssl = impl->ssl;
1300   impl->ssl = nullptr;
1301   protector_impl->network_io = impl->network_io;
1302   impl->network_io = nullptr;
1303   protector_impl->base.vtable = &frame_protector_vtable;
1304   *protector = &protector_impl->base;
1305   return TSI_OK;
1306 }
1307 
ssl_handshaker_result_get_unused_bytes(const tsi_handshaker_result * self,const unsigned char ** bytes,size_t * bytes_size)1308 static tsi_result ssl_handshaker_result_get_unused_bytes(
1309     const tsi_handshaker_result* self, const unsigned char** bytes,
1310     size_t* bytes_size) {
1311   const tsi_ssl_handshaker_result* impl =
1312       reinterpret_cast<const tsi_ssl_handshaker_result*>(self);
1313   *bytes_size = impl->unused_bytes_size;
1314   *bytes = impl->unused_bytes;
1315   return TSI_OK;
1316 }
1317 
ssl_handshaker_result_destroy(tsi_handshaker_result * self)1318 static void ssl_handshaker_result_destroy(tsi_handshaker_result* self) {
1319   tsi_ssl_handshaker_result* impl =
1320       reinterpret_cast<tsi_ssl_handshaker_result*>(self);
1321   SSL_free(impl->ssl);
1322   BIO_free(impl->network_io);
1323   gpr_free(impl->unused_bytes);
1324   gpr_free(impl);
1325 }
1326 
1327 static const tsi_handshaker_result_vtable handshaker_result_vtable = {
1328     ssl_handshaker_result_extract_peer,
1329     ssl_handshaker_result_get_frame_protector_type,
1330     nullptr,  // create_zero_copy_grpc_protector
1331     ssl_handshaker_result_create_frame_protector,
1332     ssl_handshaker_result_get_unused_bytes,
1333     ssl_handshaker_result_destroy,
1334 };
1335 
ssl_handshaker_result_create(tsi_ssl_handshaker * handshaker,unsigned char * unused_bytes,size_t unused_bytes_size,tsi_handshaker_result ** handshaker_result,std::string * error)1336 static tsi_result ssl_handshaker_result_create(
1337     tsi_ssl_handshaker* handshaker, unsigned char* unused_bytes,
1338     size_t unused_bytes_size, tsi_handshaker_result** handshaker_result,
1339     std::string* error) {
1340   if (handshaker == nullptr || handshaker_result == nullptr ||
1341       (unused_bytes_size > 0 && unused_bytes == nullptr)) {
1342     if (error != nullptr) *error = "invalid argument";
1343     return TSI_INVALID_ARGUMENT;
1344   }
1345   tsi_ssl_handshaker_result* result =
1346       grpc_core::Zalloc<tsi_ssl_handshaker_result>();
1347   result->base.vtable = &handshaker_result_vtable;
1348   // Transfer ownership of ssl and network_io to the handshaker result.
1349   result->ssl = handshaker->ssl;
1350   handshaker->ssl = nullptr;
1351   result->network_io = handshaker->network_io;
1352   handshaker->network_io = nullptr;
1353   // Transfer ownership of |unused_bytes| to the handshaker result.
1354   result->unused_bytes = unused_bytes;
1355   result->unused_bytes_size = unused_bytes_size;
1356   *handshaker_result = &result->base;
1357   return TSI_OK;
1358 }
1359 
1360 // --- tsi_handshaker methods implementation. ---
1361 
ssl_handshaker_get_bytes_to_send_to_peer(tsi_ssl_handshaker * impl,unsigned char * bytes,size_t * bytes_size,std::string * error)1362 static tsi_result ssl_handshaker_get_bytes_to_send_to_peer(
1363     tsi_ssl_handshaker* impl, unsigned char* bytes, size_t* bytes_size,
1364     std::string* error) {
1365   int bytes_read_from_ssl = 0;
1366   if (bytes == nullptr || bytes_size == nullptr || *bytes_size > INT_MAX) {
1367     if (error != nullptr) *error = "invalid argument";
1368     return TSI_INVALID_ARGUMENT;
1369   }
1370   GPR_ASSERT(*bytes_size <= INT_MAX);
1371   bytes_read_from_ssl =
1372       BIO_read(impl->network_io, bytes, static_cast<int>(*bytes_size));
1373   if (bytes_read_from_ssl < 0) {
1374     *bytes_size = 0;
1375     if (!BIO_should_retry(impl->network_io)) {
1376       if (error != nullptr) *error = "error reading from BIO";
1377       impl->result = TSI_INTERNAL_ERROR;
1378       return impl->result;
1379     } else {
1380       return TSI_OK;
1381     }
1382   }
1383   *bytes_size = static_cast<size_t>(bytes_read_from_ssl);
1384   return BIO_pending(impl->network_io) == 0 ? TSI_OK : TSI_INCOMPLETE_DATA;
1385 }
1386 
ssl_handshaker_get_result(tsi_ssl_handshaker * impl)1387 static tsi_result ssl_handshaker_get_result(tsi_ssl_handshaker* impl) {
1388   if ((impl->result == TSI_HANDSHAKE_IN_PROGRESS) &&
1389       SSL_is_init_finished(impl->ssl)) {
1390     impl->result = TSI_OK;
1391   }
1392   return impl->result;
1393 }
1394 
ssl_handshaker_do_handshake(tsi_ssl_handshaker * impl,std::string * error)1395 static tsi_result ssl_handshaker_do_handshake(tsi_ssl_handshaker* impl,
1396                                               std::string* error) {
1397   if (ssl_handshaker_get_result(impl) != TSI_HANDSHAKE_IN_PROGRESS) {
1398     impl->result = TSI_OK;
1399     return impl->result;
1400   } else {
1401     ERR_clear_error();
1402     // Get ready to get some bytes from SSL.
1403     int ssl_result = SSL_do_handshake(impl->ssl);
1404     ssl_result = SSL_get_error(impl->ssl, ssl_result);
1405     switch (ssl_result) {
1406       case SSL_ERROR_WANT_READ:
1407         if (BIO_pending(impl->network_io) == 0) {
1408           // We need more data.
1409           return TSI_INCOMPLETE_DATA;
1410         } else {
1411           return TSI_OK;
1412         }
1413       case SSL_ERROR_NONE:
1414         return TSI_OK;
1415       case SSL_ERROR_WANT_WRITE:
1416         return TSI_DRAIN_BUFFER;
1417       default: {
1418         char err_str[256];
1419         ERR_error_string_n(ERR_get_error(), err_str, sizeof(err_str));
1420         gpr_log(GPR_ERROR, "Handshake failed with fatal error %s: %s.",
1421                 grpc_core::SslErrorString(ssl_result), err_str);
1422         if (error != nullptr) {
1423           *error = absl::StrCat(grpc_core::SslErrorString(ssl_result), ": ",
1424                                 err_str);
1425         }
1426         impl->result = TSI_PROTOCOL_FAILURE;
1427         return impl->result;
1428       }
1429     }
1430   }
1431 }
1432 
ssl_handshaker_process_bytes_from_peer(tsi_ssl_handshaker * impl,const unsigned char * bytes,size_t * bytes_size,std::string * error)1433 static tsi_result ssl_handshaker_process_bytes_from_peer(
1434     tsi_ssl_handshaker* impl, const unsigned char* bytes, size_t* bytes_size,
1435     std::string* error) {
1436   int bytes_written_into_ssl_size = 0;
1437   if (bytes == nullptr || bytes_size == nullptr || *bytes_size > INT_MAX) {
1438     if (error != nullptr) *error = "invalid argument";
1439     return TSI_INVALID_ARGUMENT;
1440   }
1441   GPR_ASSERT(*bytes_size <= INT_MAX);
1442   bytes_written_into_ssl_size =
1443       BIO_write(impl->network_io, bytes, static_cast<int>(*bytes_size));
1444   if (bytes_written_into_ssl_size < 0) {
1445     gpr_log(GPR_ERROR, "Could not write to memory BIO.");
1446     if (error != nullptr) *error = "could not write to memory BIO";
1447     impl->result = TSI_INTERNAL_ERROR;
1448     return impl->result;
1449   }
1450   *bytes_size = static_cast<size_t>(bytes_written_into_ssl_size);
1451   return ssl_handshaker_do_handshake(impl, error);
1452 }
1453 
ssl_handshaker_destroy(tsi_handshaker * self)1454 static void ssl_handshaker_destroy(tsi_handshaker* self) {
1455   tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
1456   SSL_free(impl->ssl);
1457   BIO_free(impl->network_io);
1458   gpr_free(impl->outgoing_bytes_buffer);
1459   tsi_ssl_handshaker_factory_unref(impl->factory_ref);
1460   gpr_free(impl);
1461 }
1462 
1463 // Removes the bytes remaining in |impl->SSL|'s read BIO and writes them to
1464 // |bytes_remaining|.
ssl_bytes_remaining(tsi_ssl_handshaker * impl,unsigned char ** bytes_remaining,size_t * bytes_remaining_size,std::string * error)1465 static tsi_result ssl_bytes_remaining(tsi_ssl_handshaker* impl,
1466                                       unsigned char** bytes_remaining,
1467                                       size_t* bytes_remaining_size,
1468                                       std::string* error) {
1469   if (impl == nullptr || bytes_remaining == nullptr ||
1470       bytes_remaining_size == nullptr) {
1471     if (error != nullptr) *error = "invalid argument";
1472     return TSI_INVALID_ARGUMENT;
1473   }
1474   // Atempt to read all of the bytes in SSL's read BIO. These bytes should
1475   // contain application data records that were appended to a handshake record
1476   // containing the ClientFinished or ServerFinished message.
1477   size_t bytes_in_ssl = BIO_pending(SSL_get_rbio(impl->ssl));
1478   if (bytes_in_ssl == 0) return TSI_OK;
1479   *bytes_remaining = static_cast<uint8_t*>(gpr_malloc(bytes_in_ssl));
1480   int bytes_read = BIO_read(SSL_get_rbio(impl->ssl), *bytes_remaining,
1481                             static_cast<int>(bytes_in_ssl));
1482   // If an unexpected number of bytes were read, return an error status and free
1483   // all of the bytes that were read.
1484   if (bytes_read < 0 || static_cast<size_t>(bytes_read) != bytes_in_ssl) {
1485     gpr_log(GPR_ERROR,
1486             "Failed to read the expected number of bytes from SSL object.");
1487     gpr_free(*bytes_remaining);
1488     *bytes_remaining = nullptr;
1489     if (error != nullptr) {
1490       *error = "Failed to read the expected number of bytes from SSL object.";
1491     }
1492     return TSI_INTERNAL_ERROR;
1493   }
1494   *bytes_remaining_size = static_cast<size_t>(bytes_read);
1495   return TSI_OK;
1496 }
1497 
1498 // Write handshake data received from SSL to an unbound output buffer.
1499 // By doing that, we drain SSL bio buffer used to hold handshake data.
1500 // This API needs to be repeatedly called until all handshake data are
1501 // received from SSL.
ssl_handshaker_write_output_buffer(tsi_handshaker * self,size_t * bytes_written,std::string * error)1502 static tsi_result ssl_handshaker_write_output_buffer(tsi_handshaker* self,
1503                                                      size_t* bytes_written,
1504                                                      std::string* error) {
1505   tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
1506   tsi_result status = TSI_OK;
1507   size_t offset = *bytes_written;
1508   do {
1509     size_t to_send_size = impl->outgoing_bytes_buffer_size - offset;
1510     status = ssl_handshaker_get_bytes_to_send_to_peer(
1511         impl, impl->outgoing_bytes_buffer + offset, &to_send_size, error);
1512     offset += to_send_size;
1513     if (status == TSI_INCOMPLETE_DATA) {
1514       impl->outgoing_bytes_buffer_size *= 2;
1515       impl->outgoing_bytes_buffer = static_cast<unsigned char*>(gpr_realloc(
1516           impl->outgoing_bytes_buffer, impl->outgoing_bytes_buffer_size));
1517     }
1518   } while (status == TSI_INCOMPLETE_DATA);
1519   *bytes_written = offset;
1520   return status;
1521 }
1522 
ssl_handshaker_next(tsi_handshaker * self,const unsigned char * received_bytes,size_t received_bytes_size,const unsigned char ** bytes_to_send,size_t * bytes_to_send_size,tsi_handshaker_result ** handshaker_result,tsi_handshaker_on_next_done_cb,void *,std::string * error)1523 static tsi_result ssl_handshaker_next(tsi_handshaker* self,
1524                                       const unsigned char* received_bytes,
1525                                       size_t received_bytes_size,
1526                                       const unsigned char** bytes_to_send,
1527                                       size_t* bytes_to_send_size,
1528                                       tsi_handshaker_result** handshaker_result,
1529                                       tsi_handshaker_on_next_done_cb /*cb*/,
1530                                       void* /*user_data*/, std::string* error) {
1531   // Input sanity check.
1532   if ((received_bytes_size > 0 && received_bytes == nullptr) ||
1533       bytes_to_send == nullptr || bytes_to_send_size == nullptr ||
1534       handshaker_result == nullptr) {
1535     if (error != nullptr) *error = "invalid argument";
1536     return TSI_INVALID_ARGUMENT;
1537   }
1538   // If there are received bytes, process them first.
1539   tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
1540   tsi_result status = TSI_OK;
1541   size_t bytes_consumed = received_bytes_size;
1542   size_t bytes_written = 0;
1543   if (received_bytes_size > 0) {
1544     status = ssl_handshaker_process_bytes_from_peer(impl, received_bytes,
1545                                                     &bytes_consumed, error);
1546     while (status == TSI_DRAIN_BUFFER) {
1547       status = ssl_handshaker_write_output_buffer(self, &bytes_written, error);
1548       if (status != TSI_OK) return status;
1549       status = ssl_handshaker_do_handshake(impl, error);
1550     }
1551   }
1552   if (status != TSI_OK) return status;
1553   // Get bytes to send to the peer, if available.
1554   status = ssl_handshaker_write_output_buffer(self, &bytes_written, error);
1555   if (status != TSI_OK) return status;
1556   *bytes_to_send = impl->outgoing_bytes_buffer;
1557   *bytes_to_send_size = bytes_written;
1558   // If handshake completes, create tsi_handshaker_result.
1559   if (ssl_handshaker_get_result(impl) == TSI_HANDSHAKE_IN_PROGRESS) {
1560     *handshaker_result = nullptr;
1561   } else {
1562     // Any bytes that remain in |impl->ssl|'s read BIO after the handshake is
1563     // complete must be extracted and set to the unused bytes of the handshaker
1564     // result. This indicates to the gRPC stack that there are bytes from the
1565     // peer that must be processed.
1566     unsigned char* unused_bytes = nullptr;
1567     size_t unused_bytes_size = 0;
1568     status =
1569         ssl_bytes_remaining(impl, &unused_bytes, &unused_bytes_size, error);
1570     if (status != TSI_OK) return status;
1571     if (unused_bytes_size > received_bytes_size) {
1572       gpr_log(GPR_ERROR, "More unused bytes than received bytes.");
1573       gpr_free(unused_bytes);
1574       if (error != nullptr) *error = "More unused bytes than received bytes.";
1575       return TSI_INTERNAL_ERROR;
1576     }
1577     status = ssl_handshaker_result_create(impl, unused_bytes, unused_bytes_size,
1578                                           handshaker_result, error);
1579     if (status == TSI_OK) {
1580       // Indicates that the handshake has completed and that a handshaker_result
1581       // has been created.
1582       self->handshaker_result_created = true;
1583     }
1584   }
1585   return status;
1586 }
1587 
1588 static const tsi_handshaker_vtable handshaker_vtable = {
1589     nullptr,  // get_bytes_to_send_to_peer -- deprecated
1590     nullptr,  // process_bytes_from_peer   -- deprecated
1591     nullptr,  // get_result                -- deprecated
1592     nullptr,  // extract_peer              -- deprecated
1593     nullptr,  // create_frame_protector    -- deprecated
1594     ssl_handshaker_destroy,
1595     ssl_handshaker_next,
1596     nullptr,  // shutdown
1597 };
1598 
1599 // --- tsi_ssl_handshaker_factory common methods. ---
1600 
tsi_ssl_handshaker_resume_session(SSL * ssl,tsi::SslSessionLRUCache * session_cache)1601 static void tsi_ssl_handshaker_resume_session(
1602     SSL* ssl, tsi::SslSessionLRUCache* session_cache) {
1603   const char* server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1604   if (server_name == nullptr) {
1605     return;
1606   }
1607   tsi::SslSessionPtr session = session_cache->Get(server_name);
1608   if (session != nullptr) {
1609     // SSL_set_session internally increments reference counter.
1610     SSL_set_session(ssl, session.get());
1611   }
1612 }
1613 
create_tsi_ssl_handshaker(SSL_CTX * ctx,int is_client,const char * server_name_indication,size_t network_bio_buf_size,size_t ssl_bio_buf_size,tsi_ssl_handshaker_factory * factory,tsi_handshaker ** handshaker)1614 static tsi_result create_tsi_ssl_handshaker(SSL_CTX* ctx, int is_client,
1615                                             const char* server_name_indication,
1616                                             size_t network_bio_buf_size,
1617                                             size_t ssl_bio_buf_size,
1618                                             tsi_ssl_handshaker_factory* factory,
1619                                             tsi_handshaker** handshaker) {
1620   SSL* ssl = SSL_new(ctx);
1621   BIO* network_io = nullptr;
1622   BIO* ssl_io = nullptr;
1623   tsi_ssl_handshaker* impl = nullptr;
1624   *handshaker = nullptr;
1625   if (ctx == nullptr) {
1626     gpr_log(GPR_ERROR, "SSL Context is null. Should never happen.");
1627     return TSI_INTERNAL_ERROR;
1628   }
1629   if (ssl == nullptr) {
1630     return TSI_OUT_OF_RESOURCES;
1631   }
1632   SSL_set_info_callback(ssl, ssl_info_callback);
1633 
1634   if (!BIO_new_bio_pair(&network_io, network_bio_buf_size, &ssl_io,
1635                         ssl_bio_buf_size)) {
1636     gpr_log(GPR_ERROR, "BIO_new_bio_pair failed.");
1637     SSL_free(ssl);
1638     return TSI_OUT_OF_RESOURCES;
1639   }
1640   SSL_set_bio(ssl, ssl_io, ssl_io);
1641   if (is_client) {
1642     int ssl_result;
1643     SSL_set_connect_state(ssl);
1644     if (server_name_indication != nullptr) {
1645       if (!SSL_set_tlsext_host_name(ssl, server_name_indication)) {
1646         gpr_log(GPR_ERROR, "Invalid server name indication %s.",
1647                 server_name_indication);
1648         SSL_free(ssl);
1649         BIO_free(network_io);
1650         return TSI_INTERNAL_ERROR;
1651       }
1652     }
1653     tsi_ssl_client_handshaker_factory* client_factory =
1654         reinterpret_cast<tsi_ssl_client_handshaker_factory*>(factory);
1655     if (client_factory->session_cache != nullptr) {
1656       tsi_ssl_handshaker_resume_session(ssl,
1657                                         client_factory->session_cache.get());
1658     }
1659     ERR_clear_error();
1660     ssl_result = SSL_do_handshake(ssl);
1661     ssl_result = SSL_get_error(ssl, ssl_result);
1662     if (ssl_result != SSL_ERROR_WANT_READ) {
1663       gpr_log(GPR_ERROR,
1664               "Unexpected error received from first SSL_do_handshake call: %s",
1665               grpc_core::SslErrorString(ssl_result));
1666       SSL_free(ssl);
1667       BIO_free(network_io);
1668       return TSI_INTERNAL_ERROR;
1669     }
1670   } else {
1671     SSL_set_accept_state(ssl);
1672   }
1673 
1674   impl = grpc_core::Zalloc<tsi_ssl_handshaker>();
1675   impl->ssl = ssl;
1676   impl->network_io = network_io;
1677   impl->result = TSI_HANDSHAKE_IN_PROGRESS;
1678   impl->outgoing_bytes_buffer_size =
1679       TSI_SSL_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE;
1680   impl->outgoing_bytes_buffer =
1681       static_cast<unsigned char*>(gpr_zalloc(impl->outgoing_bytes_buffer_size));
1682   impl->base.vtable = &handshaker_vtable;
1683   impl->factory_ref = tsi_ssl_handshaker_factory_ref(factory);
1684   *handshaker = &impl->base;
1685   return TSI_OK;
1686 }
1687 
// Picks the first protocol in |client_list| that also appears in
// |server_list|. Both lists use the ALPN/NPN wire format: a sequence of
// entries, each a one-byte length followed by that many bytes.
//
// On a match, |*out| points at the protocol bytes inside |server_list| and
// |*outlen| holds their length, and SSL_TLSEXT_ERR_OK is returned. Otherwise
// returns SSL_TLSEXT_ERR_NOACK. One of the lists comes from the TLS peer, so
// every entry is bounds-checked. A length byte that overruns its list ends
// the scan of that list instead of reading past the buffer.
static int select_protocol_list(const unsigned char** out,
                                unsigned char* outlen,
                                const unsigned char* client_list,
                                size_t client_list_len,
                                const unsigned char* server_list,
                                size_t server_list_len) {
  size_t client_pos = 0;
  while (client_pos < client_list_len) {
    const unsigned char client_len = client_list[client_pos++];
    // client_pos <= client_list_len here, so the subtraction cannot wrap.
    if (client_len > client_list_len - client_pos) break;  // Malformed entry.
    const unsigned char* client_proto = client_list + client_pos;
    size_t server_pos = 0;
    while (server_pos < server_list_len) {
      const unsigned char server_len = server_list[server_pos++];
      if (server_len > server_list_len - server_pos) break;  // Malformed entry.
      const unsigned char* server_proto = server_list + server_pos;
      if (client_len == server_len &&
          memcmp(client_proto, server_proto, server_len) == 0) {
        *out = server_proto;
        *outlen = server_len;
        return SSL_TLSEXT_ERR_OK;
      }
      server_pos += server_len;
    }
    client_pos += client_len;
  }
  return SSL_TLSEXT_ERR_NOACK;
}
1715 
1716 // --- tsi_ssl_client_handshaker_factory methods implementation. ---
1717 
tsi_ssl_client_handshaker_factory_create_handshaker(tsi_ssl_client_handshaker_factory * factory,const char * server_name_indication,size_t network_bio_buf_size,size_t ssl_bio_buf_size,tsi_handshaker ** handshaker)1718 tsi_result tsi_ssl_client_handshaker_factory_create_handshaker(
1719     tsi_ssl_client_handshaker_factory* factory,
1720     const char* server_name_indication, size_t network_bio_buf_size,
1721     size_t ssl_bio_buf_size, tsi_handshaker** handshaker) {
1722   return create_tsi_ssl_handshaker(
1723       factory->ssl_context, 1, server_name_indication, network_bio_buf_size,
1724       ssl_bio_buf_size, &factory->base, handshaker);
1725 }
1726 
tsi_ssl_client_handshaker_factory_unref(tsi_ssl_client_handshaker_factory * factory)1727 void tsi_ssl_client_handshaker_factory_unref(
1728     tsi_ssl_client_handshaker_factory* factory) {
1729   if (factory == nullptr) return;
1730   tsi_ssl_handshaker_factory_unref(&factory->base);
1731 }
1732 
tsi_ssl_client_handshaker_factory_destroy(tsi_ssl_handshaker_factory * factory)1733 static void tsi_ssl_client_handshaker_factory_destroy(
1734     tsi_ssl_handshaker_factory* factory) {
1735   if (factory == nullptr) return;
1736   tsi_ssl_client_handshaker_factory* self =
1737       reinterpret_cast<tsi_ssl_client_handshaker_factory*>(factory);
1738   if (self->ssl_context != nullptr) SSL_CTX_free(self->ssl_context);
1739   if (self->alpn_protocol_list != nullptr) gpr_free(self->alpn_protocol_list);
1740   self->session_cache.reset();
1741   self->key_logger.reset();
1742   gpr_free(self);
1743 }
1744 
client_handshaker_factory_npn_callback(SSL *,unsigned char ** out,unsigned char * outlen,const unsigned char * in,unsigned int inlen,void * arg)1745 static int client_handshaker_factory_npn_callback(
1746     SSL* /*ssl*/, unsigned char** out, unsigned char* outlen,
1747     const unsigned char* in, unsigned int inlen, void* arg) {
1748   tsi_ssl_client_handshaker_factory* factory =
1749       static_cast<tsi_ssl_client_handshaker_factory*>(arg);
1750   return select_protocol_list(const_cast<const unsigned char**>(out), outlen,
1751                               factory->alpn_protocol_list,
1752                               factory->alpn_protocol_list_length, in, inlen);
1753 }
1754 
1755 // --- tsi_ssl_server_handshaker_factory methods implementation. ---
1756 
tsi_ssl_server_handshaker_factory_create_handshaker(tsi_ssl_server_handshaker_factory * factory,size_t network_bio_buf_size,size_t ssl_bio_buf_size,tsi_handshaker ** handshaker)1757 tsi_result tsi_ssl_server_handshaker_factory_create_handshaker(
1758     tsi_ssl_server_handshaker_factory* factory, size_t network_bio_buf_size,
1759     size_t ssl_bio_buf_size, tsi_handshaker** handshaker) {
1760   if (factory->ssl_context_count == 0) return TSI_INVALID_ARGUMENT;
1761   // Create the handshaker with the first context. We will switch if needed
1762   // because of SNI in ssl_server_handshaker_factory_servername_callback.
1763   return create_tsi_ssl_handshaker(factory->ssl_contexts[0], 0, nullptr,
1764                                    network_bio_buf_size, ssl_bio_buf_size,
1765                                    &factory->base, handshaker);
1766 }
1767 
tsi_ssl_server_handshaker_factory_unref(tsi_ssl_server_handshaker_factory * factory)1768 void tsi_ssl_server_handshaker_factory_unref(
1769     tsi_ssl_server_handshaker_factory* factory) {
1770   if (factory == nullptr) return;
1771   tsi_ssl_handshaker_factory_unref(&factory->base);
1772 }
1773 
tsi_ssl_server_handshaker_factory_destroy(tsi_ssl_handshaker_factory * factory)1774 static void tsi_ssl_server_handshaker_factory_destroy(
1775     tsi_ssl_handshaker_factory* factory) {
1776   if (factory == nullptr) return;
1777   tsi_ssl_server_handshaker_factory* self =
1778       reinterpret_cast<tsi_ssl_server_handshaker_factory*>(factory);
1779   size_t i;
1780   for (i = 0; i < self->ssl_context_count; i++) {
1781     if (self->ssl_contexts[i] != nullptr) {
1782       SSL_CTX_free(self->ssl_contexts[i]);
1783       tsi_peer_destruct(&self->ssl_context_x509_subject_names[i]);
1784     }
1785   }
1786   if (self->ssl_contexts != nullptr) gpr_free(self->ssl_contexts);
1787   if (self->ssl_context_x509_subject_names != nullptr) {
1788     gpr_free(self->ssl_context_x509_subject_names);
1789   }
1790   if (self->alpn_protocol_list != nullptr) gpr_free(self->alpn_protocol_list);
1791   self->key_logger.reset();
1792   gpr_free(self);
1793 }
1794 
does_entry_match_name(absl::string_view entry,absl::string_view name)1795 static int does_entry_match_name(absl::string_view entry,
1796                                  absl::string_view name) {
1797   if (entry.empty()) return 0;
1798 
1799   // Take care of '.' terminations.
1800   if (name.back() == '.') {
1801     name.remove_suffix(1);
1802   }
1803   if (entry.back() == '.') {
1804     entry.remove_suffix(1);
1805     if (entry.empty()) return 0;
1806   }
1807 
1808   if (absl::EqualsIgnoreCase(name, entry)) {
1809     return 1;  // Perfect match.
1810   }
1811   if (entry.front() != '*') return 0;
1812 
1813   // Wildchar subdomain matching.
1814   if (entry.size() < 3 || entry[1] != '.') {  // At least *.x
1815     gpr_log(GPR_ERROR, "Invalid wildchar entry.");
1816     return 0;
1817   }
1818   size_t name_subdomain_pos = name.find('.');
1819   if (name_subdomain_pos == absl::string_view::npos) return 0;
1820   if (name_subdomain_pos >= name.size() - 2) return 0;
1821   absl::string_view name_subdomain =
1822       name.substr(name_subdomain_pos + 1);  // Starts after the dot.
1823   entry.remove_prefix(2);                   // Remove *.
1824   size_t dot = name_subdomain.find('.');
1825   if (dot == absl::string_view::npos || dot == name_subdomain.size() - 1) {
1826     gpr_log(GPR_ERROR, "Invalid toplevel subdomain: %s",
1827             std::string(name_subdomain).c_str());
1828     return 0;
1829   }
1830   if (name_subdomain.back() == '.') {
1831     name_subdomain.remove_suffix(1);
1832   }
1833   return !entry.empty() && absl::EqualsIgnoreCase(name_subdomain, entry);
1834 }
1835 
ssl_server_handshaker_factory_servername_callback(SSL * ssl,int *,void * arg)1836 static int ssl_server_handshaker_factory_servername_callback(SSL* ssl,
1837                                                              int* /*ap*/,
1838                                                              void* arg) {
1839   tsi_ssl_server_handshaker_factory* impl =
1840       static_cast<tsi_ssl_server_handshaker_factory*>(arg);
1841   size_t i = 0;
1842   const char* servername = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1843   if (servername == nullptr || strlen(servername) == 0) {
1844     return SSL_TLSEXT_ERR_NOACK;
1845   }
1846 
1847   for (i = 0; i < impl->ssl_context_count; i++) {
1848     if (tsi_ssl_peer_matches_name(&impl->ssl_context_x509_subject_names[i],
1849                                   servername)) {
1850       SSL_set_SSL_CTX(ssl, impl->ssl_contexts[i]);
1851       return SSL_TLSEXT_ERR_OK;
1852     }
1853   }
1854   gpr_log(GPR_ERROR, "No match found for server name: %s.", servername);
1855   return SSL_TLSEXT_ERR_NOACK;
1856 }
1857 
#if TSI_OPENSSL_ALPN_SUPPORT
// ALPN selection callback for server contexts. |in| is the client's offered
// list, and the first offered protocol that the factory also supports is
// chosen. |*out| points into the factory's own list. |arg| is the server
// factory.
static int server_handshaker_factory_alpn_callback(
    SSL* /*ssl*/, const unsigned char** out, unsigned char* outlen,
    const unsigned char* in, unsigned int inlen, void* arg) {
  const auto* server = static_cast<tsi_ssl_server_handshaker_factory*>(arg);
  const unsigned char* supported = server->alpn_protocol_list;
  const auto supported_len = server->alpn_protocol_list_length;
  return select_protocol_list(out, outlen, in, inlen, supported,
                              supported_len);
}
#endif  // TSI_OPENSSL_ALPN_SUPPORT
1869 
server_handshaker_factory_npn_advertised_callback(SSL *,const unsigned char ** out,unsigned int * outlen,void * arg)1870 static int server_handshaker_factory_npn_advertised_callback(
1871     SSL* /*ssl*/, const unsigned char** out, unsigned int* outlen, void* arg) {
1872   tsi_ssl_server_handshaker_factory* factory =
1873       static_cast<tsi_ssl_server_handshaker_factory*>(arg);
1874   *out = factory->alpn_protocol_list;
1875   GPR_ASSERT(factory->alpn_protocol_list_length <= UINT_MAX);
1876   *outlen = static_cast<unsigned int>(factory->alpn_protocol_list_length);
1877   return SSL_TLSEXT_ERR_OK;
1878 }
1879 
1880 /// This callback is called when new \a session is established and ready to
1881 /// be cached. This session can be reused for new connections to similar
1882 /// servers at later point of time.
1883 /// It's intended to be used with SSL_CTX_sess_set_new_cb function.
1884 ///
1885 /// It returns 1 if callback takes ownership over \a session and 0 otherwise.
server_handshaker_factory_new_session_callback(SSL * ssl,SSL_SESSION * session)1886 static int server_handshaker_factory_new_session_callback(
1887     SSL* ssl, SSL_SESSION* session) {
1888   SSL_CTX* ssl_context = SSL_get_SSL_CTX(ssl);
1889   if (ssl_context == nullptr) {
1890     return 0;
1891   }
1892   void* arg = SSL_CTX_get_ex_data(ssl_context, g_ssl_ctx_ex_factory_index);
1893   tsi_ssl_client_handshaker_factory* factory =
1894       static_cast<tsi_ssl_client_handshaker_factory*>(arg);
1895   const char* server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1896   if (server_name == nullptr) {
1897     return 0;
1898   }
1899   factory->session_cache->Put(server_name, tsi::SslSessionPtr(session));
1900   // Return 1 to indicate transferred ownership over the given session.
1901   return 1;
1902 }
1903 
1904 /// This callback is invoked at client or server when ssl/tls handshakes
1905 /// complete and keylogging is enabled.
1906 template <typename T>
ssl_keylogging_callback(const SSL * ssl,const char * info)1907 static void ssl_keylogging_callback(const SSL* ssl, const char* info) {
1908   SSL_CTX* ssl_context = SSL_get_SSL_CTX(ssl);
1909   GPR_ASSERT(ssl_context != nullptr);
1910   void* arg = SSL_CTX_get_ex_data(ssl_context, g_ssl_ctx_ex_factory_index);
1911   T* factory = static_cast<T*>(arg);
1912   factory->key_logger->LogSessionKeys(ssl_context, info);
1913 }
1914 
1915 // --- tsi_ssl_handshaker_factory constructors. ---
1916 
1917 static tsi_ssl_handshaker_factory_vtable client_handshaker_factory_vtable = {
1918     tsi_ssl_client_handshaker_factory_destroy};
1919 
tsi_create_ssl_client_handshaker_factory(const tsi_ssl_pem_key_cert_pair * pem_key_cert_pair,const char * pem_root_certs,const char * cipher_suites,const char ** alpn_protocols,uint16_t num_alpn_protocols,tsi_ssl_client_handshaker_factory ** factory)1920 tsi_result tsi_create_ssl_client_handshaker_factory(
1921     const tsi_ssl_pem_key_cert_pair* pem_key_cert_pair,
1922     const char* pem_root_certs, const char* cipher_suites,
1923     const char** alpn_protocols, uint16_t num_alpn_protocols,
1924     tsi_ssl_client_handshaker_factory** factory) {
1925   tsi_ssl_client_handshaker_options options;
1926   options.pem_key_cert_pair = pem_key_cert_pair;
1927   options.pem_root_certs = pem_root_certs;
1928   options.cipher_suites = cipher_suites;
1929   options.alpn_protocols = alpn_protocols;
1930   options.num_alpn_protocols = num_alpn_protocols;
1931   return tsi_create_ssl_client_handshaker_factory_with_options(&options,
1932                                                                factory);
1933 }
1934 
tsi_create_ssl_client_handshaker_factory_with_options(const tsi_ssl_client_handshaker_options * options,tsi_ssl_client_handshaker_factory ** factory)1935 tsi_result tsi_create_ssl_client_handshaker_factory_with_options(
1936     const tsi_ssl_client_handshaker_options* options,
1937     tsi_ssl_client_handshaker_factory** factory) {
1938   SSL_CTX* ssl_context = nullptr;
1939   tsi_ssl_client_handshaker_factory* impl = nullptr;
1940   tsi_result result = TSI_OK;
1941 
1942   gpr_once_init(&g_init_openssl_once, init_openssl);
1943 
1944   if (factory == nullptr) return TSI_INVALID_ARGUMENT;
1945   *factory = nullptr;
1946   if (options->pem_root_certs == nullptr && options->root_store == nullptr) {
1947     return TSI_INVALID_ARGUMENT;
1948   }
1949 
1950 #if OPENSSL_VERSION_NUMBER >= 0x10100000
1951   ssl_context = SSL_CTX_new(TLS_method());
1952 #else
1953   ssl_context = SSL_CTX_new(TLSv1_2_method());
1954 #endif
1955   if (ssl_context == nullptr) {
1956     grpc_core::LogSslErrorStack();
1957     gpr_log(GPR_ERROR, "Could not create ssl context.");
1958     return TSI_INVALID_ARGUMENT;
1959   }
1960 
1961   result = tsi_set_min_and_max_tls_versions(
1962       ssl_context, options->min_tls_version, options->max_tls_version);
1963   if (result != TSI_OK) return result;
1964 
1965   impl = static_cast<tsi_ssl_client_handshaker_factory*>(
1966       gpr_zalloc(sizeof(*impl)));
1967   tsi_ssl_handshaker_factory_init(&impl->base);
1968   impl->base.vtable = &client_handshaker_factory_vtable;
1969   impl->ssl_context = ssl_context;
1970   if (options->session_cache != nullptr) {
1971     // Unref is called manually on factory destruction.
1972     impl->session_cache =
1973         reinterpret_cast<tsi::SslSessionLRUCache*>(options->session_cache)
1974             ->Ref();
1975     SSL_CTX_sess_set_new_cb(ssl_context,
1976                             server_handshaker_factory_new_session_callback);
1977     SSL_CTX_set_session_cache_mode(ssl_context, SSL_SESS_CACHE_CLIENT);
1978   }
1979 
1980 #if OPENSSL_VERSION_NUMBER >= 0x10101000 && !defined(LIBRESSL_VERSION_NUMBER)
1981   if (options->key_logger != nullptr) {
1982     impl->key_logger = options->key_logger->Ref();
1983     // SSL_CTX_set_keylog_callback is set here to register callback
1984     // when ssl/tls handshakes complete.
1985     SSL_CTX_set_keylog_callback(
1986         ssl_context,
1987         ssl_keylogging_callback<tsi_ssl_client_handshaker_factory>);
1988   }
1989 #endif
1990 
1991   if (options->session_cache != nullptr || options->key_logger != nullptr) {
1992     // Need to set factory at g_ssl_ctx_ex_factory_index
1993     SSL_CTX_set_ex_data(ssl_context, g_ssl_ctx_ex_factory_index, impl);
1994   }
1995 
1996   do {
1997     result = populate_ssl_context(ssl_context, options->pem_key_cert_pair,
1998                                   options->cipher_suites);
1999     if (result != TSI_OK) break;
2000 
2001 #if OPENSSL_VERSION_NUMBER >= 0x10100000
2002     // X509_STORE_up_ref is only available since OpenSSL 1.1.
2003     if (options->root_store != nullptr) {
2004       X509_STORE_up_ref(options->root_store->store);
2005       SSL_CTX_set_cert_store(ssl_context, options->root_store->store);
2006     }
2007 #endif
2008     if (OPENSSL_VERSION_NUMBER < 0x10100000 || options->root_store == nullptr) {
2009       result = ssl_ctx_load_verification_certs(
2010           ssl_context, options->pem_root_certs, strlen(options->pem_root_certs),
2011           nullptr);
2012       if (result != TSI_OK) {
2013         gpr_log(GPR_ERROR, "Cannot load server root certificates.");
2014         break;
2015       }
2016     }
2017 
2018     if (options->num_alpn_protocols != 0) {
2019       result = build_alpn_protocol_name_list(
2020           options->alpn_protocols, options->num_alpn_protocols,
2021           &impl->alpn_protocol_list, &impl->alpn_protocol_list_length);
2022       if (result != TSI_OK) {
2023         gpr_log(GPR_ERROR, "Building alpn list failed with error %s.",
2024                 tsi_result_to_string(result));
2025         break;
2026       }
2027 #if TSI_OPENSSL_ALPN_SUPPORT
2028       GPR_ASSERT(impl->alpn_protocol_list_length < UINT_MAX);
2029       if (SSL_CTX_set_alpn_protos(
2030               ssl_context, impl->alpn_protocol_list,
2031               static_cast<unsigned int>(impl->alpn_protocol_list_length))) {
2032         gpr_log(GPR_ERROR, "Could not set alpn protocol list to context.");
2033         result = TSI_INVALID_ARGUMENT;
2034         break;
2035       }
2036 #endif  // TSI_OPENSSL_ALPN_SUPPORT
2037       SSL_CTX_set_next_proto_select_cb(
2038           ssl_context, client_handshaker_factory_npn_callback, impl);
2039     }
2040   } while (false);
2041   if (result != TSI_OK) {
2042     tsi_ssl_handshaker_factory_unref(&impl->base);
2043     return result;
2044   }
2045   if (options->skip_server_certificate_verification) {
2046     SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, NullVerifyCallback);
2047   } else {
2048     SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, RootCertExtractCallback);
2049   }
2050 
2051 #if OPENSSL_VERSION_NUMBER >= 0x10100000
2052   if (options->crl_directory != nullptr &&
2053       strcmp(options->crl_directory, "") != 0) {
2054     gpr_log(GPR_INFO, "enabling client CRL checking with path: %s",
2055             options->crl_directory);
2056     X509_STORE* cert_store = SSL_CTX_get_cert_store(ssl_context);
2057     X509_STORE_set_verify_cb(cert_store, verify_cb);
2058     if (!X509_STORE_load_locations(cert_store, nullptr,
2059                                    options->crl_directory)) {
2060       gpr_log(GPR_ERROR, "Failed to load CRL File from directory.");
2061     } else {
2062       X509_VERIFY_PARAM* param = X509_STORE_get0_param(cert_store);
2063       X509_VERIFY_PARAM_set_flags(
2064           param, X509_V_FLAG_CRL_CHECK | X509_V_FLAG_CRL_CHECK_ALL);
2065       gpr_log(GPR_INFO, "enabled client side CRL checking.");
2066     }
2067   }
2068 #endif
2069 
2070   *factory = impl;
2071   return TSI_OK;
2072 }
2073 
2074 static tsi_ssl_handshaker_factory_vtable server_handshaker_factory_vtable = {
2075     tsi_ssl_server_handshaker_factory_destroy};
2076 
tsi_create_ssl_server_handshaker_factory(const tsi_ssl_pem_key_cert_pair * pem_key_cert_pairs,size_t num_key_cert_pairs,const char * pem_client_root_certs,int force_client_auth,const char * cipher_suites,const char ** alpn_protocols,uint16_t num_alpn_protocols,tsi_ssl_server_handshaker_factory ** factory)2077 tsi_result tsi_create_ssl_server_handshaker_factory(
2078     const tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs,
2079     size_t num_key_cert_pairs, const char* pem_client_root_certs,
2080     int force_client_auth, const char* cipher_suites,
2081     const char** alpn_protocols, uint16_t num_alpn_protocols,
2082     tsi_ssl_server_handshaker_factory** factory) {
2083   return tsi_create_ssl_server_handshaker_factory_ex(
2084       pem_key_cert_pairs, num_key_cert_pairs, pem_client_root_certs,
2085       force_client_auth ? TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY
2086                         : TSI_DONT_REQUEST_CLIENT_CERTIFICATE,
2087       cipher_suites, alpn_protocols, num_alpn_protocols, factory);
2088 }
2089 
tsi_create_ssl_server_handshaker_factory_ex(const tsi_ssl_pem_key_cert_pair * pem_key_cert_pairs,size_t num_key_cert_pairs,const char * pem_client_root_certs,tsi_client_certificate_request_type client_certificate_request,const char * cipher_suites,const char ** alpn_protocols,uint16_t num_alpn_protocols,tsi_ssl_server_handshaker_factory ** factory)2090 tsi_result tsi_create_ssl_server_handshaker_factory_ex(
2091     const tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs,
2092     size_t num_key_cert_pairs, const char* pem_client_root_certs,
2093     tsi_client_certificate_request_type client_certificate_request,
2094     const char* cipher_suites, const char** alpn_protocols,
2095     uint16_t num_alpn_protocols, tsi_ssl_server_handshaker_factory** factory) {
2096   tsi_ssl_server_handshaker_options options;
2097   options.pem_key_cert_pairs = pem_key_cert_pairs;
2098   options.num_key_cert_pairs = num_key_cert_pairs;
2099   options.pem_client_root_certs = pem_client_root_certs;
2100   options.client_certificate_request = client_certificate_request;
2101   options.cipher_suites = cipher_suites;
2102   options.alpn_protocols = alpn_protocols;
2103   options.num_alpn_protocols = num_alpn_protocols;
2104   return tsi_create_ssl_server_handshaker_factory_with_options(&options,
2105                                                                factory);
2106 }
2107 
tsi_create_ssl_server_handshaker_factory_with_options(const tsi_ssl_server_handshaker_options * options,tsi_ssl_server_handshaker_factory ** factory)2108 tsi_result tsi_create_ssl_server_handshaker_factory_with_options(
2109     const tsi_ssl_server_handshaker_options* options,
2110     tsi_ssl_server_handshaker_factory** factory) {
2111   tsi_ssl_server_handshaker_factory* impl = nullptr;
2112   tsi_result result = TSI_OK;
2113   size_t i = 0;
2114 
2115   gpr_once_init(&g_init_openssl_once, init_openssl);
2116 
2117   if (factory == nullptr) return TSI_INVALID_ARGUMENT;
2118   *factory = nullptr;
2119   if (options->num_key_cert_pairs == 0 ||
2120       options->pem_key_cert_pairs == nullptr) {
2121     return TSI_INVALID_ARGUMENT;
2122   }
2123 
2124   impl = static_cast<tsi_ssl_server_handshaker_factory*>(
2125       gpr_zalloc(sizeof(*impl)));
2126   tsi_ssl_handshaker_factory_init(&impl->base);
2127   impl->base.vtable = &server_handshaker_factory_vtable;
2128 
2129   impl->ssl_contexts = static_cast<SSL_CTX**>(
2130       gpr_zalloc(options->num_key_cert_pairs * sizeof(SSL_CTX*)));
2131   impl->ssl_context_x509_subject_names = static_cast<tsi_peer*>(
2132       gpr_zalloc(options->num_key_cert_pairs * sizeof(tsi_peer)));
2133   if (impl->ssl_contexts == nullptr ||
2134       impl->ssl_context_x509_subject_names == nullptr) {
2135     tsi_ssl_handshaker_factory_unref(&impl->base);
2136     return TSI_OUT_OF_RESOURCES;
2137   }
2138   impl->ssl_context_count = options->num_key_cert_pairs;
2139 
2140   if (options->num_alpn_protocols > 0) {
2141     result = build_alpn_protocol_name_list(
2142         options->alpn_protocols, options->num_alpn_protocols,
2143         &impl->alpn_protocol_list, &impl->alpn_protocol_list_length);
2144     if (result != TSI_OK) {
2145       tsi_ssl_handshaker_factory_unref(&impl->base);
2146       return result;
2147     }
2148   }
2149 
2150   if (options->key_logger != nullptr) {
2151     impl->key_logger = options->key_logger->Ref();
2152   }
2153 
2154   for (i = 0; i < options->num_key_cert_pairs; i++) {
2155     do {
2156 #if OPENSSL_VERSION_NUMBER >= 0x10100000
2157       impl->ssl_contexts[i] = SSL_CTX_new(TLS_method());
2158 #else
2159       impl->ssl_contexts[i] = SSL_CTX_new(TLSv1_2_method());
2160 #endif
2161       if (impl->ssl_contexts[i] == nullptr) {
2162         grpc_core::LogSslErrorStack();
2163         gpr_log(GPR_ERROR, "Could not create ssl context.");
2164         result = TSI_OUT_OF_RESOURCES;
2165         break;
2166       }
2167 
2168       result = tsi_set_min_and_max_tls_versions(impl->ssl_contexts[i],
2169                                                 options->min_tls_version,
2170                                                 options->max_tls_version);
2171       if (result != TSI_OK) return result;
2172 
2173       result = populate_ssl_context(impl->ssl_contexts[i],
2174                                     &options->pem_key_cert_pairs[i],
2175                                     options->cipher_suites);
2176       if (result != TSI_OK) break;
2177 
2178       // TODO(elessar): Provide ability to disable session ticket keys.
2179 
2180       // Allow client cache sessions (it's needed for OpenSSL only).
2181       int set_sid_ctx_result = SSL_CTX_set_session_id_context(
2182           impl->ssl_contexts[i], kSslSessionIdContext,
2183           GPR_ARRAY_SIZE(kSslSessionIdContext));
2184       if (set_sid_ctx_result == 0) {
2185         gpr_log(GPR_ERROR, "Failed to set session id context.");
2186         result = TSI_INTERNAL_ERROR;
2187         break;
2188       }
2189 
2190       if (options->session_ticket_key != nullptr) {
2191         if (SSL_CTX_set_tlsext_ticket_keys(
2192                 impl->ssl_contexts[i],
2193                 const_cast<char*>(options->session_ticket_key),
2194                 options->session_ticket_key_size) == 0) {
2195           gpr_log(GPR_ERROR, "Invalid STEK size.");
2196           result = TSI_INVALID_ARGUMENT;
2197           break;
2198         }
2199       }
2200 
2201       if (options->pem_client_root_certs != nullptr) {
2202         STACK_OF(X509_NAME)* root_names = nullptr;
2203         result = ssl_ctx_load_verification_certs(
2204             impl->ssl_contexts[i], options->pem_client_root_certs,
2205             strlen(options->pem_client_root_certs),
2206             options->send_client_ca_list ? &root_names : nullptr);
2207         if (result != TSI_OK) {
2208           gpr_log(GPR_ERROR, "Invalid verification certs.");
2209           break;
2210         }
2211         if (options->send_client_ca_list) {
2212           SSL_CTX_set_client_CA_list(impl->ssl_contexts[i], root_names);
2213         }
2214       }
2215       switch (options->client_certificate_request) {
2216         case TSI_DONT_REQUEST_CLIENT_CERTIFICATE:
2217           SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_NONE, nullptr);
2218           break;
2219         case TSI_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
2220           SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER,
2221                              NullVerifyCallback);
2222           break;
2223         case TSI_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY:
2224           SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER,
2225                              RootCertExtractCallback);
2226           break;
2227         case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
2228           SSL_CTX_set_verify(impl->ssl_contexts[i],
2229                              SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
2230                              NullVerifyCallback);
2231           break;
2232         case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY:
2233           SSL_CTX_set_verify(impl->ssl_contexts[i],
2234                              SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
2235                              RootCertExtractCallback);
2236           break;
2237       }
2238 
2239 #if OPENSSL_VERSION_NUMBER >= 0x10100000
2240       if (options->crl_directory != nullptr &&
2241           strcmp(options->crl_directory, "") != 0) {
2242         gpr_log(GPR_INFO, "enabling server CRL checking with path %s",
2243                 options->crl_directory);
2244         X509_STORE* cert_store = SSL_CTX_get_cert_store(impl->ssl_contexts[i]);
2245         X509_STORE_set_verify_cb(cert_store, verify_cb);
2246         if (!X509_STORE_load_locations(cert_store, nullptr,
2247                                        options->crl_directory)) {
2248           gpr_log(GPR_ERROR, "Failed to load CRL File from directory.");
2249         } else {
2250           X509_VERIFY_PARAM* param = X509_STORE_get0_param(cert_store);
2251           X509_VERIFY_PARAM_set_flags(
2252               param, X509_V_FLAG_CRL_CHECK | X509_V_FLAG_CRL_CHECK_ALL);
2253           gpr_log(GPR_INFO, "enabled server CRL checking.");
2254         }
2255       }
2256 #endif
2257 
2258       result = tsi_ssl_extract_x509_subject_names_from_pem_cert(
2259           options->pem_key_cert_pairs[i].cert_chain,
2260           &impl->ssl_context_x509_subject_names[i]);
2261       if (result != TSI_OK) break;
2262 
2263       SSL_CTX_set_tlsext_servername_callback(
2264           impl->ssl_contexts[i],
2265           ssl_server_handshaker_factory_servername_callback);
2266       SSL_CTX_set_tlsext_servername_arg(impl->ssl_contexts[i], impl);
2267 #if TSI_OPENSSL_ALPN_SUPPORT
2268       SSL_CTX_set_alpn_select_cb(impl->ssl_contexts[i],
2269                                  server_handshaker_factory_alpn_callback, impl);
2270 #endif  // TSI_OPENSSL_ALPN_SUPPORT
2271       SSL_CTX_set_next_protos_advertised_cb(
2272           impl->ssl_contexts[i],
2273           server_handshaker_factory_npn_advertised_callback, impl);
2274 
2275 #if OPENSSL_VERSION_NUMBER >= 0x10101000 && !defined(LIBRESSL_VERSION_NUMBER)
2276       // Register factory at index
2277       if (options->key_logger != nullptr) {
2278         // Need to set factory at g_ssl_ctx_ex_factory_index
2279         SSL_CTX_set_ex_data(impl->ssl_contexts[i], g_ssl_ctx_ex_factory_index,
2280                             impl);
2281         // SSL_CTX_set_keylog_callback is set here to register callback
2282         // when ssl/tls handshakes complete.
2283         SSL_CTX_set_keylog_callback(
2284             impl->ssl_contexts[i],
2285             ssl_keylogging_callback<tsi_ssl_server_handshaker_factory>);
2286       }
2287 #endif
2288     } while (false);
2289 
2290     if (result != TSI_OK) {
2291       tsi_ssl_handshaker_factory_unref(&impl->base);
2292       return result;
2293     }
2294   }
2295 
2296   *factory = impl;
2297   return TSI_OK;
2298 }
2299 
2300 // --- tsi_ssl utils. ---
2301 
tsi_ssl_peer_matches_name(const tsi_peer * peer,absl::string_view name)2302 int tsi_ssl_peer_matches_name(const tsi_peer* peer, absl::string_view name) {
2303   size_t i = 0;
2304   size_t san_count = 0;
2305   const tsi_peer_property* cn_property = nullptr;
2306   int like_ip = looks_like_ip_address(name);
2307 
2308   // Check the SAN first.
2309   for (i = 0; i < peer->property_count; i++) {
2310     const tsi_peer_property* property = &peer->properties[i];
2311     if (property->name == nullptr) continue;
2312     if (strcmp(property->name,
2313                TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY) == 0) {
2314       san_count++;
2315 
2316       absl::string_view entry(property->value.data, property->value.length);
2317       if (!like_ip && does_entry_match_name(entry, name)) {
2318         return 1;
2319       } else if (like_ip && name == entry) {
2320         // IP Addresses are exact matches only.
2321         return 1;
2322       }
2323     } else if (strcmp(property->name,
2324                       TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY) == 0) {
2325       cn_property = property;
2326     }
2327   }
2328 
2329   // If there's no SAN, try the CN, but only if its not like an IP Address
2330   if (san_count == 0 && cn_property != nullptr && !like_ip) {
2331     if (does_entry_match_name(absl::string_view(cn_property->value.data,
2332                                                 cn_property->value.length),
2333                               name)) {
2334       return 1;
2335     }
2336   }
2337 
2338   return 0;  // Not found.
2339 }
2340 
2341 // --- Testing support. ---
tsi_ssl_handshaker_factory_swap_vtable(tsi_ssl_handshaker_factory * factory,tsi_ssl_handshaker_factory_vtable * new_vtable)2342 const tsi_ssl_handshaker_factory_vtable* tsi_ssl_handshaker_factory_swap_vtable(
2343     tsi_ssl_handshaker_factory* factory,
2344     tsi_ssl_handshaker_factory_vtable* new_vtable) {
2345   GPR_ASSERT(factory != nullptr);
2346   GPR_ASSERT(factory->vtable != nullptr);
2347 
2348   const tsi_ssl_handshaker_factory_vtable* orig_vtable = factory->vtable;
2349   factory->vtable = new_vtable;
2350   return orig_vtable;
2351 }
2352