1 //
2 //
3 // Copyright 2015 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include <grpc/support/port_platform.h>
20 
21 #include "src/core/lib/security/transport/security_handshaker.h"
22 
23 #include <limits.h>
24 #include <stdint.h>
25 #include <string.h>
26 
27 #include <algorithm>
28 #include <memory>
29 #include <string>
30 
31 #include "absl/base/attributes.h"
32 #include "absl/status/status.h"
33 #include "absl/strings/str_cat.h"
34 #include "absl/strings/string_view.h"
35 #include "absl/types/optional.h"
36 
37 #include <grpc/grpc_security.h>
38 #include <grpc/grpc_security_constants.h>
39 #include <grpc/slice.h>
40 #include <grpc/slice_buffer.h>
41 #include <grpc/support/alloc.h>
42 #include <grpc/support/log.h>
43 
44 #include "src/core/lib/channel/channel_args.h"
45 #include "src/core/lib/channel/channelz.h"
46 #include "src/core/lib/config/core_configuration.h"
47 #include "src/core/lib/debug/stats.h"
48 #include "src/core/lib/debug/stats_data.h"
49 #include "src/core/lib/gprpp/debug_location.h"
50 #include "src/core/lib/gprpp/ref_counted_ptr.h"
51 #include "src/core/lib/gprpp/status_helper.h"
52 #include "src/core/lib/gprpp/sync.h"
53 #include "src/core/lib/gprpp/unique_type_name.h"
54 #include "src/core/lib/iomgr/closure.h"
55 #include "src/core/lib/iomgr/endpoint.h"
56 #include "src/core/lib/iomgr/error.h"
57 #include "src/core/lib/iomgr/exec_ctx.h"
58 #include "src/core/lib/iomgr/iomgr_fwd.h"
59 #include "src/core/lib/iomgr/tcp_server.h"
60 #include "src/core/lib/security/context/security_context.h"
61 #include "src/core/lib/security/transport/secure_endpoint.h"
62 #include "src/core/lib/security/transport/tsi_error.h"
63 #include "src/core/lib/slice/slice.h"
64 #include "src/core/lib/slice/slice_internal.h"
65 #include "src/core/lib/transport/handshaker.h"
66 #include "src/core/lib/transport/handshaker_factory.h"
67 #include "src/core/lib/transport/handshaker_registry.h"
68 #include "src/core/tsi/transport_security_grpc.h"
69 
70 #define GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE 256
71 
72 namespace grpc_core {
73 
74 namespace {
75 
// Handshaker that drives a TSI handshake over the connection's endpoint.
// On success it wraps the endpoint in a secure endpoint (when TSI supplies a
// frame protector), forwards any unused handshake bytes, and attaches the
// auth context to the channel args before invoking on_handshake_done.
class SecurityHandshaker : public Handshaker {
 public:
  // Takes over `handshaker` (it is destroyed in the destructor) and holds a
  // ref on `connector` for the lifetime of this object.
  SecurityHandshaker(tsi_handshaker* handshaker,
                     grpc_security_connector* connector,
                     const ChannelArgs& args);
  ~SecurityHandshaker() override;
  void Shutdown(grpc_error_handle why) override;
  void DoHandshake(grpc_tcp_server_acceptor* acceptor,
                   grpc_closure* on_handshake_done,
                   HandshakerArgs* args) override;
  const char* name() const override { return "security"; }

 private:
  // Methods with a `Locked` suffix are called with `mu_` held.

  // Feeds the given peer bytes to tsi_handshaker_next().
  grpc_error_handle DoHandshakerNextLocked(const unsigned char* bytes_received,
                                           size_t bytes_received_size);

  // Acts on the outcome of tsi_handshaker_next() (sync or async).
  grpc_error_handle OnHandshakeNextDoneLocked(
      tsi_result result, const unsigned char* bytes_to_send,
      size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result);
  void HandshakeFailedLocked(grpc_error_handle error);
  void CleanupArgsForFailureLocked();

  // Endpoint read/write completion callbacks. The *Scheduler variants only
  // re-post the real callback onto the ExecCtx.
  static void OnHandshakeDataReceivedFromPeerFn(void* arg,
                                                grpc_error_handle error);
  static void OnHandshakeDataSentToPeerFn(void* arg, grpc_error_handle error);
  static void OnHandshakeDataReceivedFromPeerFnScheduler(
      void* arg, grpc_error_handle error);
  static void OnHandshakeDataSentToPeerFnScheduler(void* arg,
                                                   grpc_error_handle error);
  // Completion callback handed to tsi_handshaker_next() for async results.
  static void OnHandshakeNextDoneGrpcWrapper(
      tsi_result result, void* user_data, const unsigned char* bytes_to_send,
      size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result);
  static void OnPeerCheckedFn(void* arg, grpc_error_handle error);
  void OnPeerCheckedInner(grpc_error_handle error);
  // Copies and drains args_->read_buffer into handshake_buffer_; returns the
  // number of bytes copied.
  size_t MoveReadBufferIntoHandshakeBuffer();
  grpc_error_handle CheckPeerLocked();

  // State set at creation time.
  tsi_handshaker* handshaker_;
  RefCountedPtr<grpc_security_connector> connector_;

  // Serializes the handshake callbacks and Shutdown().
  Mutex mu_;

  bool is_shutdown_ = false;
  // Endpoint and read buffer to destroy after a shutdown.
  grpc_endpoint* endpoint_to_destroy_ = nullptr;
  grpc_slice_buffer* read_buffer_to_destroy_ = nullptr;

  // State saved while performing the handshake.
  HandshakerArgs* args_ = nullptr;
  grpc_closure* on_handshake_done_ = nullptr;

  // Must stay declared before handshake_buffer_: the constructor initializes
  // handshake_buffer_ using handshake_buffer_size_.
  size_t handshake_buffer_size_;
  // Flat copy of bytes read from the peer, passed to tsi_handshaker_next().
  unsigned char* handshake_buffer_;
  // Bytes currently being written to the peer.
  grpc_slice_buffer outgoing_;
  grpc_closure on_handshake_data_sent_to_peer_;
  grpc_closure on_handshake_data_received_from_peer_;
  grpc_closure on_peer_checked_;
  RefCountedPtr<grpc_auth_context> auth_context_;
  // Owned result produced by TSI once the handshake completes.
  tsi_handshaker_result* handshaker_result_ = nullptr;
  // 0 means no explicit max frame size is passed to the frame protector.
  size_t max_frame_size_ = 0;
  // Error detail filled in by tsi_handshaker_next().
  std::string tsi_handshake_error_;
};
139 
SecurityHandshaker(tsi_handshaker * handshaker,grpc_security_connector * connector,const ChannelArgs & args)140 SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker,
141                                        grpc_security_connector* connector,
142                                        const ChannelArgs& args)
143     : handshaker_(handshaker),
144       connector_(connector->Ref(DEBUG_LOCATION, "handshake")),
145       handshake_buffer_size_(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE),
146       handshake_buffer_(
147           static_cast<uint8_t*>(gpr_malloc(handshake_buffer_size_))),
148       max_frame_size_(
149           std::max(0, args.GetInt(GRPC_ARG_TSI_MAX_FRAME_SIZE).value_or(0))) {
150   grpc_slice_buffer_init(&outgoing_);
151   GRPC_CLOSURE_INIT(&on_peer_checked_, &SecurityHandshaker::OnPeerCheckedFn,
152                     this, grpc_schedule_on_exec_ctx);
153 }
154 
~SecurityHandshaker()155 SecurityHandshaker::~SecurityHandshaker() {
156   tsi_handshaker_destroy(handshaker_);
157   tsi_handshaker_result_destroy(handshaker_result_);
158   if (endpoint_to_destroy_ != nullptr) {
159     grpc_endpoint_destroy(endpoint_to_destroy_);
160   }
161   if (read_buffer_to_destroy_ != nullptr) {
162     grpc_slice_buffer_destroy(read_buffer_to_destroy_);
163     gpr_free(read_buffer_to_destroy_);
164   }
165   gpr_free(handshake_buffer_);
166   grpc_slice_buffer_destroy(&outgoing_);
167   auth_context_.reset(DEBUG_LOCATION, "handshake");
168   connector_.reset(DEBUG_LOCATION, "handshake");
169 }
170 
MoveReadBufferIntoHandshakeBuffer()171 size_t SecurityHandshaker::MoveReadBufferIntoHandshakeBuffer() {
172   size_t bytes_in_read_buffer = args_->read_buffer->length;
173   if (handshake_buffer_size_ < bytes_in_read_buffer) {
174     handshake_buffer_ = static_cast<uint8_t*>(
175         gpr_realloc(handshake_buffer_, bytes_in_read_buffer));
176     handshake_buffer_size_ = bytes_in_read_buffer;
177   }
178   size_t offset = 0;
179   while (args_->read_buffer->count > 0) {
180     grpc_slice* next_slice = grpc_slice_buffer_peek_first(args_->read_buffer);
181     memcpy(handshake_buffer_ + offset, GRPC_SLICE_START_PTR(*next_slice),
182            GRPC_SLICE_LENGTH(*next_slice));
183     offset += GRPC_SLICE_LENGTH(*next_slice);
184     grpc_slice_buffer_remove_first(args_->read_buffer);
185   }
186   return bytes_in_read_buffer;
187 }
188 
189 // Set args_ fields to NULL, saving the endpoint and read buffer for
190 // later destruction.
CleanupArgsForFailureLocked()191 void SecurityHandshaker::CleanupArgsForFailureLocked() {
192   endpoint_to_destroy_ = args_->endpoint;
193   args_->endpoint = nullptr;
194   read_buffer_to_destroy_ = args_->read_buffer;
195   args_->read_buffer = nullptr;
196   args_->args = ChannelArgs();
197 }
198 
// If the handshake failed or we're shutting down, clean up and invoke the
// callback with the error.
// If Shutdown() has not already run, this also shuts down the TSI handshaker
// and the endpoint and detaches the endpoint/read buffer from args_ (see
// CleanupArgsForFailureLocked()). An OK `error` is replaced with a
// "Handshaker shutdown" error. Every caller in this file holds mu_.
void SecurityHandshaker::HandshakeFailedLocked(grpc_error_handle error) {
  if (error.ok()) {
    // If we were shut down after the handshake succeeded but before an
    // endpoint callback was invoked, we need to generate our own error.
    error = GRPC_ERROR_CREATE("Handshaker shutdown");
  }
  gpr_log(GPR_DEBUG, "Security handshake failed: %s",
          StatusToString(error).c_str());
  if (!is_shutdown_) {
    tsi_handshaker_shutdown(handshaker_);
    // TODO(ctiller): It is currently necessary to shutdown endpoints
    // before destroying them, even if we know that there are no
    // pending read/write callbacks.  This should be fixed, at which
    // point this can be removed.
    grpc_endpoint_shutdown(args_->endpoint, error);
    // Not shut down yet, so this is a genuine failure (TSI error, read or
    // write failure, or peer-check failure). Clean up before invoking the
    // callback.
    CleanupArgsForFailureLocked();
    // Set shutdown to true so that subsequent calls to
    // Shutdown() do nothing.
    is_shutdown_ = true;
  }
  // Invoke callback.
  ExecCtx::Run(DEBUG_LOCATION, on_handshake_done_, error);
}
226 
227 namespace {
228 
229 RefCountedPtr<channelz::SocketNode::Security>
MakeChannelzSecurityFromAuthContext(grpc_auth_context * auth_context)230 MakeChannelzSecurityFromAuthContext(grpc_auth_context* auth_context) {
231   RefCountedPtr<channelz::SocketNode::Security> security =
232       MakeRefCounted<channelz::SocketNode::Security>();
233   // TODO(yashykt): Currently, we are assuming TLS by default and are only able
234   // to fill in the remote certificate but we should ideally be able to fill in
235   // other fields in
236   // https://github.com/grpc/grpc/blob/fcd43e90304862a823316b224ee733d17a8cfd90/src/proto/grpc/channelz/channelz.proto#L326
237   // from grpc_auth_context.
238   security->type = channelz::SocketNode::Security::ModelType::kTls;
239   security->tls = absl::make_optional<channelz::SocketNode::Security::Tls>();
240   grpc_auth_property_iterator it = grpc_auth_context_find_properties_by_name(
241       auth_context, GRPC_X509_PEM_CERT_PROPERTY_NAME);
242   const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it);
243   if (prop != nullptr) {
244     security->tls->remote_certificate =
245         std::string(prop->value, prop->value_length);
246   }
247   return security;
248 }
249 
250 }  // namespace
251 
// Continuation run once connector_->check_peer() reports its result.
// On success it builds the final HandshakerArgs:
// - wraps the endpoint in a secure endpoint when TSI supplies a frame
//   protector;
// - forwards any unused handshake bytes;
// - attaches the auth context (and, with a protector, channelz security info)
//   to the channel args;
// - invokes on_handshake_done_.
// Every failure is routed through HandshakeFailedLocked().
void SecurityHandshaker::OnPeerCheckedInner(grpc_error_handle error) {
  MutexLock lock(&mu_);
  if (!error.ok() || is_shutdown_) {
    HandshakeFailedLocked(error);
    return;
  }
  // Get unused bytes.
  // These are peer bytes that TSI did not consume; they are handed on below
  // (to the secure endpoint or the read buffer) so they are not lost.
  const unsigned char* unused_bytes = nullptr;
  size_t unused_bytes_size = 0;
  tsi_result result = tsi_handshaker_result_get_unused_bytes(
      handshaker_result_, &unused_bytes, &unused_bytes_size);
  if (result != TSI_OK) {
    HandshakeFailedLocked(grpc_set_tsi_error_result(
        GRPC_ERROR_CREATE(
            "TSI handshaker result does not provide unused bytes"),
        result));
    return;
  }
  // Check whether we need to wrap the endpoint.
  tsi_frame_protector_type frame_protector_type;
  result = tsi_handshaker_result_get_frame_protector_type(
      handshaker_result_, &frame_protector_type);
  if (result != TSI_OK) {
    HandshakeFailedLocked(grpc_set_tsi_error_result(
        GRPC_ERROR_CREATE("TSI handshaker result does not implement "
                          "get_frame_protector_type"),
        result));
    return;
  }
  tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr;
  tsi_frame_protector* protector = nullptr;
  switch (frame_protector_type) {
    case TSI_FRAME_PROTECTOR_ZERO_COPY:
      ABSL_FALLTHROUGH_INTENDED;
    case TSI_FRAME_PROTECTOR_NORMAL_OR_ZERO_COPY:
      // Create zero-copy frame protector.
      // A max_frame_size_ of 0 passes nullptr, i.e. no explicit frame size.
      result = tsi_handshaker_result_create_zero_copy_grpc_protector(
          handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_,
          &zero_copy_protector);
      if (result != TSI_OK) {
        HandshakeFailedLocked(grpc_set_tsi_error_result(
            GRPC_ERROR_CREATE("Zero-copy frame protector creation failed"),
            result));
        return;
      }
      break;
    case TSI_FRAME_PROTECTOR_NORMAL:
      // Create normal frame protector.
      result = tsi_handshaker_result_create_frame_protector(
          handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_,
          &protector);
      if (result != TSI_OK) {
        HandshakeFailedLocked(grpc_set_tsi_error_result(
            GRPC_ERROR_CREATE("Frame protector creation failed"), result));
        return;
      }
      break;
    case TSI_FRAME_PROTECTOR_NONE:
      break;
  }
  bool has_frame_protector =
      zero_copy_protector != nullptr || protector != nullptr;
  // If we have a frame protector, create a secure endpoint.
  if (has_frame_protector) {
    if (unused_bytes_size > 0) {
      // Pass the unused bytes as a single leftover slice. The local slice is
      // unreffed immediately afterwards (presumably the secure endpoint takes
      // its own ref -- NOTE(review): confirm in secure_endpoint).
      grpc_slice slice = grpc_slice_from_copied_buffer(
          reinterpret_cast<const char*>(unused_bytes), unused_bytes_size);
      args_->endpoint = grpc_secure_endpoint_create(
          protector, zero_copy_protector, args_->endpoint, &slice,
          args_->args.ToC().get(), 1);
      CSliceUnref(slice);
    } else {
      args_->endpoint = grpc_secure_endpoint_create(
          protector, zero_copy_protector, args_->endpoint, nullptr,
          args_->args.ToC().get(), 0);
    }
  } else if (unused_bytes_size > 0) {
    // Not wrapping the endpoint, so just pass along unused bytes.
    grpc_slice slice = grpc_slice_from_copied_buffer(
        reinterpret_cast<const char*>(unused_bytes), unused_bytes_size);
    grpc_slice_buffer_add(args_->read_buffer, slice);
  }
  // Done with handshaker result.
  tsi_handshaker_result_destroy(handshaker_result_);
  handshaker_result_ = nullptr;
  args_->args = args_->args.SetObject(auth_context_);
  // Add channelz channel args only if frame protector is created.
  if (has_frame_protector) {
    args_->args = args_->args.SetObject(
        MakeChannelzSecurityFromAuthContext(auth_context_.get()));
  }
  // Invoke callback.
  ExecCtx::Run(DEBUG_LOCATION, on_handshake_done_, absl::OkStatus());
  // Set shutdown to true so that subsequent calls to
  // Shutdown() do nothing.
  is_shutdown_ = true;
}
349 
OnPeerCheckedFn(void * arg,grpc_error_handle error)350 void SecurityHandshaker::OnPeerCheckedFn(void* arg, grpc_error_handle error) {
351   RefCountedPtr<SecurityHandshaker>(static_cast<SecurityHandshaker*>(arg))
352       ->OnPeerCheckedInner(error);
353 }
354 
CheckPeerLocked()355 grpc_error_handle SecurityHandshaker::CheckPeerLocked() {
356   tsi_peer peer;
357   tsi_result result =
358       tsi_handshaker_result_extract_peer(handshaker_result_, &peer);
359   if (result != TSI_OK) {
360     return grpc_set_tsi_error_result(
361         GRPC_ERROR_CREATE("Peer extraction failed"), result);
362   }
363   connector_->check_peer(peer, args_->endpoint, args_->args, &auth_context_,
364                          &on_peer_checked_);
365   grpc_auth_property_iterator it = grpc_auth_context_find_properties_by_name(
366       auth_context_.get(), GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME);
367   const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it);
368   if (!prop ||
369       !strcmp(tsi_security_level_to_string(TSI_SECURITY_NONE), prop->value)) {
370     global_stats().IncrementInsecureConnectionsCreated();
371   }
372   return absl::OkStatus();
373 }
374 
OnHandshakeNextDoneLocked(tsi_result result,const unsigned char * bytes_to_send,size_t bytes_to_send_size,tsi_handshaker_result * handshaker_result)375 grpc_error_handle SecurityHandshaker::OnHandshakeNextDoneLocked(
376     tsi_result result, const unsigned char* bytes_to_send,
377     size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) {
378   grpc_error_handle error;
379   // Handshaker was shutdown.
380   if (is_shutdown_) {
381     return GRPC_ERROR_CREATE("Handshaker shutdown");
382   }
383   // Read more if we need to.
384   if (result == TSI_INCOMPLETE_DATA) {
385     GPR_ASSERT(bytes_to_send_size == 0);
386     grpc_endpoint_read(
387         args_->endpoint, args_->read_buffer,
388         GRPC_CLOSURE_INIT(
389             &on_handshake_data_received_from_peer_,
390             &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler,
391             this, grpc_schedule_on_exec_ctx),
392         /*urgent=*/true, /*min_progress_size=*/1);
393     return error;
394   }
395   if (result != TSI_OK) {
396     auto* security_connector = args_->args.GetObject<grpc_security_connector>();
397     absl::string_view connector_type = "<unknown>";
398     if (security_connector != nullptr) {
399       connector_type = security_connector->type().name();
400     }
401     return grpc_set_tsi_error_result(
402         GRPC_ERROR_CREATE(absl::StrCat(
403             connector_type, " handshake failed",
404             (tsi_handshake_error_.empty() ? "" : ": "), tsi_handshake_error_)),
405         result);
406   }
407   // Update handshaker result.
408   if (handshaker_result != nullptr) {
409     GPR_ASSERT(handshaker_result_ == nullptr);
410     handshaker_result_ = handshaker_result;
411   }
412   if (bytes_to_send_size > 0) {
413     // Send data to peer, if needed.
414     grpc_slice to_send = grpc_slice_from_copied_buffer(
415         reinterpret_cast<const char*>(bytes_to_send), bytes_to_send_size);
416     grpc_slice_buffer_reset_and_unref(&outgoing_);
417     grpc_slice_buffer_add(&outgoing_, to_send);
418     grpc_endpoint_write(
419         args_->endpoint, &outgoing_,
420         GRPC_CLOSURE_INIT(
421             &on_handshake_data_sent_to_peer_,
422             &SecurityHandshaker::OnHandshakeDataSentToPeerFnScheduler, this,
423             grpc_schedule_on_exec_ctx),
424         nullptr, /*max_frame_size=*/INT_MAX);
425   } else if (handshaker_result == nullptr) {
426     // There is nothing to send, but need to read from peer.
427     grpc_endpoint_read(
428         args_->endpoint, args_->read_buffer,
429         GRPC_CLOSURE_INIT(
430             &on_handshake_data_received_from_peer_,
431             &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler,
432             this, grpc_schedule_on_exec_ctx),
433         /*urgent=*/true, /*min_progress_size=*/1);
434   } else {
435     // Handshake has finished, check peer and so on.
436     error = CheckPeerLocked();
437   }
438   return error;
439 }
440 
OnHandshakeNextDoneGrpcWrapper(tsi_result result,void * user_data,const unsigned char * bytes_to_send,size_t bytes_to_send_size,tsi_handshaker_result * handshaker_result)441 void SecurityHandshaker::OnHandshakeNextDoneGrpcWrapper(
442     tsi_result result, void* user_data, const unsigned char* bytes_to_send,
443     size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) {
444   RefCountedPtr<SecurityHandshaker> h(
445       static_cast<SecurityHandshaker*>(user_data));
446   MutexLock lock(&h->mu_);
447   grpc_error_handle error = h->OnHandshakeNextDoneLocked(
448       result, bytes_to_send, bytes_to_send_size, handshaker_result);
449   if (!error.ok()) {
450     h->HandshakeFailedLocked(error);
451   } else {
452     h.release();  // Avoid unref
453   }
454 }
455 
DoHandshakerNextLocked(const unsigned char * bytes_received,size_t bytes_received_size)456 grpc_error_handle SecurityHandshaker::DoHandshakerNextLocked(
457     const unsigned char* bytes_received, size_t bytes_received_size) {
458   // Invoke TSI handshaker.
459   const unsigned char* bytes_to_send = nullptr;
460   size_t bytes_to_send_size = 0;
461   tsi_handshaker_result* hs_result = nullptr;
462   tsi_result result = tsi_handshaker_next(
463       handshaker_, bytes_received, bytes_received_size, &bytes_to_send,
464       &bytes_to_send_size, &hs_result, &OnHandshakeNextDoneGrpcWrapper, this,
465       &tsi_handshake_error_);
466   if (result == TSI_ASYNC) {
467     // Handshaker operating asynchronously. Nothing else to do here;
468     // callback will be invoked in a TSI thread.
469     return absl::OkStatus();
470   }
471   // Handshaker returned synchronously. Invoke callback directly in
472   // this thread with our existing exec_ctx.
473   return OnHandshakeNextDoneLocked(result, bytes_to_send, bytes_to_send_size,
474                                    hs_result);
475 }
476 
477 // This callback might be run inline while we are still holding on to the mutex,
478 // so schedule OnHandshakeDataReceivedFromPeerFn on ExecCtx to avoid a deadlock.
OnHandshakeDataReceivedFromPeerFnScheduler(void * arg,grpc_error_handle error)479 void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler(
480     void* arg, grpc_error_handle error) {
481   SecurityHandshaker* h = static_cast<SecurityHandshaker*>(arg);
482   ExecCtx::Run(
483       DEBUG_LOCATION,
484       GRPC_CLOSURE_INIT(&h->on_handshake_data_received_from_peer_,
485                         &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn,
486                         h, grpc_schedule_on_exec_ctx),
487       error);
488 }
489 
OnHandshakeDataReceivedFromPeerFn(void * arg,grpc_error_handle error)490 void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn(
491     void* arg, grpc_error_handle error) {
492   RefCountedPtr<SecurityHandshaker> h(static_cast<SecurityHandshaker*>(arg));
493   MutexLock lock(&h->mu_);
494   if (!error.ok() || h->is_shutdown_) {
495     h->HandshakeFailedLocked(
496         GRPC_ERROR_CREATE_REFERENCING("Handshake read failed", &error, 1));
497     return;
498   }
499   // Copy all slices received.
500   size_t bytes_received_size = h->MoveReadBufferIntoHandshakeBuffer();
501   // Call TSI handshaker.
502   error = h->DoHandshakerNextLocked(h->handshake_buffer_, bytes_received_size);
503   if (!error.ok()) {
504     h->HandshakeFailedLocked(error);
505   } else {
506     h.release();  // Avoid unref
507   }
508 }
509 
510 // This callback might be run inline while we are still holding on to the mutex,
511 // so schedule OnHandshakeDataSentToPeerFn on ExecCtx to avoid a deadlock.
OnHandshakeDataSentToPeerFnScheduler(void * arg,grpc_error_handle error)512 void SecurityHandshaker::OnHandshakeDataSentToPeerFnScheduler(
513     void* arg, grpc_error_handle error) {
514   SecurityHandshaker* h = static_cast<SecurityHandshaker*>(arg);
515   ExecCtx::Run(
516       DEBUG_LOCATION,
517       GRPC_CLOSURE_INIT(&h->on_handshake_data_sent_to_peer_,
518                         &SecurityHandshaker::OnHandshakeDataSentToPeerFn, h,
519                         grpc_schedule_on_exec_ctx),
520       error);
521 }
522 
OnHandshakeDataSentToPeerFn(void * arg,grpc_error_handle error)523 void SecurityHandshaker::OnHandshakeDataSentToPeerFn(void* arg,
524                                                      grpc_error_handle error) {
525   RefCountedPtr<SecurityHandshaker> h(static_cast<SecurityHandshaker*>(arg));
526   MutexLock lock(&h->mu_);
527   if (!error.ok() || h->is_shutdown_) {
528     h->HandshakeFailedLocked(
529         GRPC_ERROR_CREATE_REFERENCING("Handshake write failed", &error, 1));
530     return;
531   }
532   // We may be done.
533   if (h->handshaker_result_ == nullptr) {
534     grpc_endpoint_read(
535         h->args_->endpoint, h->args_->read_buffer,
536         GRPC_CLOSURE_INIT(
537             &h->on_handshake_data_received_from_peer_,
538             &SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler,
539             h.get(), grpc_schedule_on_exec_ctx),
540         /*urgent=*/true, /*min_progress_size=*/1);
541   } else {
542     error = h->CheckPeerLocked();
543     if (!error.ok()) {
544       h->HandshakeFailedLocked(error);
545       return;
546     }
547   }
548   h.release();  // Avoid unref
549 }
550 
551 //
552 // public handshaker API
553 //
554 
Shutdown(grpc_error_handle why)555 void SecurityHandshaker::Shutdown(grpc_error_handle why) {
556   MutexLock lock(&mu_);
557   if (!is_shutdown_) {
558     is_shutdown_ = true;
559     connector_->cancel_check_peer(&on_peer_checked_, why);
560     tsi_handshaker_shutdown(handshaker_);
561     grpc_endpoint_shutdown(args_->endpoint, why);
562     CleanupArgsForFailureLocked();
563   }
564 }
565 
DoHandshake(grpc_tcp_server_acceptor *,grpc_closure * on_handshake_done,HandshakerArgs * args)566 void SecurityHandshaker::DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/,
567                                      grpc_closure* on_handshake_done,
568                                      HandshakerArgs* args) {
569   auto ref = Ref();
570   MutexLock lock(&mu_);
571   args_ = args;
572   on_handshake_done_ = on_handshake_done;
573   size_t bytes_received_size = MoveReadBufferIntoHandshakeBuffer();
574   grpc_error_handle error =
575       DoHandshakerNextLocked(handshake_buffer_, bytes_received_size);
576   if (!error.ok()) {
577     HandshakeFailedLocked(error);
578   } else {
579     ref.release();  // Avoid unref
580   }
581 }
582 
583 //
584 // FailHandshaker
585 //
586 
587 class FailHandshaker : public Handshaker {
588  public:
name() const589   const char* name() const override { return "security_fail"; }
Shutdown(grpc_error_handle)590   void Shutdown(grpc_error_handle /*why*/) override {}
DoHandshake(grpc_tcp_server_acceptor *,grpc_closure * on_handshake_done,HandshakerArgs * args)591   void DoHandshake(grpc_tcp_server_acceptor* /*acceptor*/,
592                    grpc_closure* on_handshake_done,
593                    HandshakerArgs* args) override {
594     grpc_error_handle error =
595         GRPC_ERROR_CREATE("Failed to create security handshaker");
596     grpc_endpoint_shutdown(args->endpoint, error);
597     grpc_endpoint_destroy(args->endpoint);
598     args->endpoint = nullptr;
599     args->args = ChannelArgs();
600     grpc_slice_buffer_destroy(args->read_buffer);
601     gpr_free(args->read_buffer);
602     args->read_buffer = nullptr;
603     ExecCtx::Run(DEBUG_LOCATION, on_handshake_done, error);
604   }
605 
606  private:
607   ~FailHandshaker() override = default;
608 };
609 
610 //
611 // handshaker factories
612 //
613 
614 class ClientSecurityHandshakerFactory : public HandshakerFactory {
615  public:
AddHandshakers(const ChannelArgs & args,grpc_pollset_set * interested_parties,HandshakeManager * handshake_mgr)616   void AddHandshakers(const ChannelArgs& args,
617                       grpc_pollset_set* interested_parties,
618                       HandshakeManager* handshake_mgr) override {
619     auto* security_connector =
620         args.GetObject<grpc_channel_security_connector>();
621     if (security_connector) {
622       security_connector->add_handshakers(args, interested_parties,
623                                           handshake_mgr);
624     }
625   }
Priority()626   HandshakerPriority Priority() override {
627     return HandshakerPriority::kSecurityHandshakers;
628   }
629   ~ClientSecurityHandshakerFactory() override = default;
630 };
631 
632 class ServerSecurityHandshakerFactory : public HandshakerFactory {
633  public:
AddHandshakers(const ChannelArgs & args,grpc_pollset_set * interested_parties,HandshakeManager * handshake_mgr)634   void AddHandshakers(const ChannelArgs& args,
635                       grpc_pollset_set* interested_parties,
636                       HandshakeManager* handshake_mgr) override {
637     auto* security_connector = args.GetObject<grpc_server_security_connector>();
638     if (security_connector) {
639       security_connector->add_handshakers(args, interested_parties,
640                                           handshake_mgr);
641     }
642   }
Priority()643   HandshakerPriority Priority() override {
644     return HandshakerPriority::kSecurityHandshakers;
645   }
646   ~ServerSecurityHandshakerFactory() override = default;
647 };
648 
649 }  // namespace
650 
651 //
652 // exported functions
653 //
654 
SecurityHandshakerCreate(tsi_handshaker * handshaker,grpc_security_connector * connector,const ChannelArgs & args)655 RefCountedPtr<Handshaker> SecurityHandshakerCreate(
656     tsi_handshaker* handshaker, grpc_security_connector* connector,
657     const ChannelArgs& args) {
658   // If no TSI handshaker was created, return a handshaker that always fails.
659   // Otherwise, return a real security handshaker.
660   if (handshaker == nullptr) {
661     return MakeRefCounted<FailHandshaker>();
662   } else {
663     return MakeRefCounted<SecurityHandshaker>(handshaker, connector, args);
664   }
665 }
666 
SecurityRegisterHandshakerFactories(CoreConfiguration::Builder * builder)667 void SecurityRegisterHandshakerFactories(CoreConfiguration::Builder* builder) {
668   builder->handshaker_registry()->RegisterHandshakerFactory(
669       HANDSHAKER_CLIENT, std::make_unique<ClientSecurityHandshakerFactory>());
670   builder->handshaker_registry()->RegisterHandshakerFactory(
671       HANDSHAKER_SERVER, std::make_unique<ServerSecurityHandshakerFactory>());
672 }
673 
674 }  // namespace grpc_core
675 
grpc_security_handshaker_create(tsi_handshaker * handshaker,grpc_security_connector * connector,const grpc_channel_args * args)676 grpc_handshaker* grpc_security_handshaker_create(
677     tsi_handshaker* handshaker, grpc_security_connector* connector,
678     const grpc_channel_args* args) {
679   return SecurityHandshakerCreate(handshaker, connector,
680                                   grpc_core::ChannelArgs::FromC(args))
681       .release();
682 }
683