1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include <grpc/support/port_platform.h>
20 
21 #include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h"
22 
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 
27 #include "upb/upb.hpp"
28 
29 #include <grpc/grpc_security.h>
30 #include <grpc/support/alloc.h>
31 #include <grpc/support/log.h>
32 #include <grpc/support/string_util.h>
33 #include <grpc/support/sync.h>
34 #include <grpc/support/thd_id.h>
35 
36 #include "src/core/lib/gprpp/crash.h"
37 #include "src/core/lib/gprpp/memory.h"
38 #include "src/core/lib/gprpp/sync.h"
39 #include "src/core/lib/gprpp/thd.h"
40 #include "src/core/lib/iomgr/closure.h"
41 #include "src/core/lib/slice/slice_internal.h"
42 #include "src/core/lib/surface/channel.h"
43 #include "src/core/tsi/alts/frame_protector/alts_frame_protector.h"
44 #include "src/core/tsi/alts/handshaker/alts_handshaker_client.h"
45 #include "src/core/tsi/alts/handshaker/alts_shared_resource.h"
46 #include "src/core/tsi/alts/handshaker/alts_tsi_utils.h"
47 #include "src/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector.h"
48 
49 // Main struct for ALTS TSI handshaker.
50 struct alts_tsi_handshaker {
51   tsi_handshaker base;
52   grpc_slice target_name;
53   bool is_client;
54   bool has_sent_start_message = false;
55   bool has_created_handshaker_client = false;
56   char* handshaker_service_url;
57   grpc_pollset_set* interested_parties;
58   grpc_alts_credentials_options* options;
59   alts_handshaker_client_vtable* client_vtable_for_testing = nullptr;
60   grpc_channel* channel = nullptr;
61   bool use_dedicated_cq;
62   // mu synchronizes all fields below. Note these are the
63   // only fields that can be concurrently accessed (due to
64   // potential concurrency of tsi_handshaker_shutdown and
65   // tsi_handshaker_next).
66   grpc_core::Mutex mu;
67   alts_handshaker_client* client = nullptr;
68   // shutdown effectively follows base.handshake_shutdown,
69   // but is synchronized by the mutex of this object.
70   bool shutdown = false;
71   // Maximum frame size used by frame protector.
72   size_t max_frame_size;
73 };
74 
75 // Main struct for ALTS TSI handshaker result.
76 typedef struct alts_tsi_handshaker_result {
77   tsi_handshaker_result base;
78   char* peer_identity;
79   char* key_data;
80   unsigned char* unused_bytes;
81   size_t unused_bytes_size;
82   grpc_slice rpc_versions;
83   bool is_client;
84   grpc_slice serialized_context;
85   // Peer's maximum frame size.
86   size_t max_frame_size;
87 } alts_tsi_handshaker_result;
88 
handshaker_result_extract_peer(const tsi_handshaker_result * self,tsi_peer * peer)89 static tsi_result handshaker_result_extract_peer(
90     const tsi_handshaker_result* self, tsi_peer* peer) {
91   if (self == nullptr || peer == nullptr) {
92     gpr_log(GPR_ERROR, "Invalid argument to handshaker_result_extract_peer()");
93     return TSI_INVALID_ARGUMENT;
94   }
95   alts_tsi_handshaker_result* result =
96       reinterpret_cast<alts_tsi_handshaker_result*>(
97           const_cast<tsi_handshaker_result*>(self));
98   GPR_ASSERT(kTsiAltsNumOfPeerProperties == 5);
99   tsi_result ok = tsi_construct_peer(kTsiAltsNumOfPeerProperties, peer);
100   int index = 0;
101   if (ok != TSI_OK) {
102     gpr_log(GPR_ERROR, "Failed to construct tsi peer");
103     return ok;
104   }
105   GPR_ASSERT(&peer->properties[index] != nullptr);
106   ok = tsi_construct_string_peer_property_from_cstring(
107       TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE,
108       &peer->properties[index]);
109   if (ok != TSI_OK) {
110     tsi_peer_destruct(peer);
111     gpr_log(GPR_ERROR, "Failed to set tsi peer property");
112     return ok;
113   }
114   index++;
115   GPR_ASSERT(&peer->properties[index] != nullptr);
116   ok = tsi_construct_string_peer_property_from_cstring(
117       TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, result->peer_identity,
118       &peer->properties[index]);
119   if (ok != TSI_OK) {
120     tsi_peer_destruct(peer);
121     gpr_log(GPR_ERROR, "Failed to set tsi peer property");
122   }
123   index++;
124   GPR_ASSERT(&peer->properties[index] != nullptr);
125   ok = tsi_construct_string_peer_property(
126       TSI_ALTS_RPC_VERSIONS,
127       reinterpret_cast<char*>(GRPC_SLICE_START_PTR(result->rpc_versions)),
128       GRPC_SLICE_LENGTH(result->rpc_versions), &peer->properties[index]);
129   if (ok != TSI_OK) {
130     tsi_peer_destruct(peer);
131     gpr_log(GPR_ERROR, "Failed to set tsi peer property");
132   }
133   index++;
134   GPR_ASSERT(&peer->properties[index] != nullptr);
135   ok = tsi_construct_string_peer_property(
136       TSI_ALTS_CONTEXT,
137       reinterpret_cast<char*>(GRPC_SLICE_START_PTR(result->serialized_context)),
138       GRPC_SLICE_LENGTH(result->serialized_context), &peer->properties[index]);
139   if (ok != TSI_OK) {
140     tsi_peer_destruct(peer);
141     gpr_log(GPR_ERROR, "Failed to set tsi peer property");
142   }
143   index++;
144   GPR_ASSERT(&peer->properties[index] != nullptr);
145   ok = tsi_construct_string_peer_property_from_cstring(
146       TSI_SECURITY_LEVEL_PEER_PROPERTY,
147       tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY),
148       &peer->properties[index]);
149   if (ok != TSI_OK) {
150     tsi_peer_destruct(peer);
151     gpr_log(GPR_ERROR, "Failed to set tsi peer property");
152   }
153   GPR_ASSERT(++index == kTsiAltsNumOfPeerProperties);
154   return ok;
155 }
156 
handshaker_result_get_frame_protector_type(const tsi_handshaker_result *,tsi_frame_protector_type * frame_protector_type)157 static tsi_result handshaker_result_get_frame_protector_type(
158     const tsi_handshaker_result* /*self*/,
159     tsi_frame_protector_type* frame_protector_type) {
160   *frame_protector_type = TSI_FRAME_PROTECTOR_NORMAL_OR_ZERO_COPY;
161   return TSI_OK;
162 }
163 
handshaker_result_create_zero_copy_grpc_protector(const tsi_handshaker_result * self,size_t * max_output_protected_frame_size,tsi_zero_copy_grpc_protector ** protector)164 static tsi_result handshaker_result_create_zero_copy_grpc_protector(
165     const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
166     tsi_zero_copy_grpc_protector** protector) {
167   if (self == nullptr || protector == nullptr) {
168     gpr_log(GPR_ERROR,
169             "Invalid arguments to create_zero_copy_grpc_protector()");
170     return TSI_INVALID_ARGUMENT;
171   }
172   alts_tsi_handshaker_result* result =
173       reinterpret_cast<alts_tsi_handshaker_result*>(
174           const_cast<tsi_handshaker_result*>(self));
175 
176   // In case the peer does not send max frame size (e.g. peer is gRPC Go or
177   // peer uses an old binary), the negotiated frame size is set to
178   // kTsiAltsMinFrameSize (ignoring max_output_protected_frame_size value if
179   // present). Otherwise, it is based on peer and user specified max frame
180   // size (if present).
181   size_t max_frame_size = kTsiAltsMinFrameSize;
182   if (result->max_frame_size) {
183     size_t peer_max_frame_size = result->max_frame_size;
184     max_frame_size = std::min<size_t>(peer_max_frame_size,
185                                       max_output_protected_frame_size == nullptr
186                                           ? kTsiAltsMaxFrameSize
187                                           : *max_output_protected_frame_size);
188     max_frame_size = std::max<size_t>(max_frame_size, kTsiAltsMinFrameSize);
189   }
190   max_output_protected_frame_size = &max_frame_size;
191   gpr_log(GPR_DEBUG,
192           "After Frame Size Negotiation, maximum frame size used by frame "
193           "protector equals %zu",
194           *max_output_protected_frame_size);
195   tsi_result ok = alts_zero_copy_grpc_protector_create(
196       reinterpret_cast<const uint8_t*>(result->key_data),
197       kAltsAes128GcmRekeyKeyLength, /*is_rekey=*/true, result->is_client,
198       /*is_integrity_only=*/false, /*enable_extra_copy=*/false,
199       max_output_protected_frame_size, protector);
200   if (ok != TSI_OK) {
201     gpr_log(GPR_ERROR, "Failed to create zero-copy grpc protector");
202   }
203   return ok;
204 }
205 
handshaker_result_create_frame_protector(const tsi_handshaker_result * self,size_t * max_output_protected_frame_size,tsi_frame_protector ** protector)206 static tsi_result handshaker_result_create_frame_protector(
207     const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
208     tsi_frame_protector** protector) {
209   if (self == nullptr || protector == nullptr) {
210     gpr_log(GPR_ERROR,
211             "Invalid arguments to handshaker_result_create_frame_protector()");
212     return TSI_INVALID_ARGUMENT;
213   }
214   alts_tsi_handshaker_result* result =
215       reinterpret_cast<alts_tsi_handshaker_result*>(
216           const_cast<tsi_handshaker_result*>(self));
217   tsi_result ok = alts_create_frame_protector(
218       reinterpret_cast<const uint8_t*>(result->key_data),
219       kAltsAes128GcmRekeyKeyLength, result->is_client, /*is_rekey=*/true,
220       max_output_protected_frame_size, protector);
221   if (ok != TSI_OK) {
222     gpr_log(GPR_ERROR, "Failed to create frame protector");
223   }
224   return ok;
225 }
226 
handshaker_result_get_unused_bytes(const tsi_handshaker_result * self,const unsigned char ** bytes,size_t * bytes_size)227 static tsi_result handshaker_result_get_unused_bytes(
228     const tsi_handshaker_result* self, const unsigned char** bytes,
229     size_t* bytes_size) {
230   if (self == nullptr || bytes == nullptr || bytes_size == nullptr) {
231     gpr_log(GPR_ERROR,
232             "Invalid arguments to handshaker_result_get_unused_bytes()");
233     return TSI_INVALID_ARGUMENT;
234   }
235   alts_tsi_handshaker_result* result =
236       reinterpret_cast<alts_tsi_handshaker_result*>(
237           const_cast<tsi_handshaker_result*>(self));
238   *bytes = result->unused_bytes;
239   *bytes_size = result->unused_bytes_size;
240   return TSI_OK;
241 }
242 
handshaker_result_destroy(tsi_handshaker_result * self)243 static void handshaker_result_destroy(tsi_handshaker_result* self) {
244   if (self == nullptr) {
245     return;
246   }
247   alts_tsi_handshaker_result* result =
248       reinterpret_cast<alts_tsi_handshaker_result*>(
249           const_cast<tsi_handshaker_result*>(self));
250   gpr_free(result->peer_identity);
251   gpr_free(result->key_data);
252   gpr_free(result->unused_bytes);
253   grpc_core::CSliceUnref(result->rpc_versions);
254   grpc_core::CSliceUnref(result->serialized_context);
255   gpr_free(result);
256 }
257 
258 static const tsi_handshaker_result_vtable result_vtable = {
259     handshaker_result_extract_peer,
260     handshaker_result_get_frame_protector_type,
261     handshaker_result_create_zero_copy_grpc_protector,
262     handshaker_result_create_frame_protector,
263     handshaker_result_get_unused_bytes,
264     handshaker_result_destroy};
265 
alts_tsi_handshaker_result_create(grpc_gcp_HandshakerResp * resp,bool is_client,tsi_handshaker_result ** result)266 tsi_result alts_tsi_handshaker_result_create(grpc_gcp_HandshakerResp* resp,
267                                              bool is_client,
268                                              tsi_handshaker_result** result) {
269   if (result == nullptr || resp == nullptr) {
270     gpr_log(GPR_ERROR, "Invalid arguments to create_handshaker_result()");
271     return TSI_INVALID_ARGUMENT;
272   }
273   const grpc_gcp_HandshakerResult* hresult =
274       grpc_gcp_HandshakerResp_result(resp);
275   const grpc_gcp_Identity* identity =
276       grpc_gcp_HandshakerResult_peer_identity(hresult);
277   if (identity == nullptr) {
278     gpr_log(GPR_ERROR, "Invalid identity");
279     return TSI_FAILED_PRECONDITION;
280   }
281   upb_StringView peer_service_account =
282       grpc_gcp_Identity_service_account(identity);
283   if (peer_service_account.size == 0) {
284     gpr_log(GPR_ERROR, "Invalid peer service account");
285     return TSI_FAILED_PRECONDITION;
286   }
287   upb_StringView key_data = grpc_gcp_HandshakerResult_key_data(hresult);
288   if (key_data.size < kAltsAes128GcmRekeyKeyLength) {
289     gpr_log(GPR_ERROR, "Bad key length");
290     return TSI_FAILED_PRECONDITION;
291   }
292   const grpc_gcp_RpcProtocolVersions* peer_rpc_version =
293       grpc_gcp_HandshakerResult_peer_rpc_versions(hresult);
294   if (peer_rpc_version == nullptr) {
295     gpr_log(GPR_ERROR, "Peer does not set RPC protocol versions.");
296     return TSI_FAILED_PRECONDITION;
297   }
298   upb_StringView application_protocol =
299       grpc_gcp_HandshakerResult_application_protocol(hresult);
300   if (application_protocol.size == 0) {
301     gpr_log(GPR_ERROR, "Invalid application protocol");
302     return TSI_FAILED_PRECONDITION;
303   }
304   upb_StringView record_protocol =
305       grpc_gcp_HandshakerResult_record_protocol(hresult);
306   if (record_protocol.size == 0) {
307     gpr_log(GPR_ERROR, "Invalid record protocol");
308     return TSI_FAILED_PRECONDITION;
309   }
310   const grpc_gcp_Identity* local_identity =
311       grpc_gcp_HandshakerResult_local_identity(hresult);
312   if (local_identity == nullptr) {
313     gpr_log(GPR_ERROR, "Invalid local identity");
314     return TSI_FAILED_PRECONDITION;
315   }
316   upb_StringView local_service_account =
317       grpc_gcp_Identity_service_account(local_identity);
318   // We don't check if local service account is empty here
319   // because local identity could be empty in certain situations.
320   alts_tsi_handshaker_result* sresult =
321       grpc_core::Zalloc<alts_tsi_handshaker_result>();
322   sresult->key_data =
323       static_cast<char*>(gpr_zalloc(kAltsAes128GcmRekeyKeyLength));
324   memcpy(sresult->key_data, key_data.data, kAltsAes128GcmRekeyKeyLength);
325   sresult->peer_identity =
326       static_cast<char*>(gpr_zalloc(peer_service_account.size + 1));
327   memcpy(sresult->peer_identity, peer_service_account.data,
328          peer_service_account.size);
329   sresult->max_frame_size = grpc_gcp_HandshakerResult_max_frame_size(hresult);
330   upb::Arena rpc_versions_arena;
331   bool serialized = grpc_gcp_rpc_protocol_versions_encode(
332       peer_rpc_version, rpc_versions_arena.ptr(), &sresult->rpc_versions);
333   if (!serialized) {
334     gpr_log(GPR_ERROR, "Failed to serialize peer's RPC protocol versions.");
335     return TSI_FAILED_PRECONDITION;
336   }
337   upb::Arena context_arena;
338   grpc_gcp_AltsContext* context = grpc_gcp_AltsContext_new(context_arena.ptr());
339   grpc_gcp_AltsContext_set_application_protocol(context, application_protocol);
340   grpc_gcp_AltsContext_set_record_protocol(context, record_protocol);
341   // ALTS currently only supports the security level of 2,
342   // which is "grpc_gcp_INTEGRITY_AND_PRIVACY".
343   grpc_gcp_AltsContext_set_security_level(context, 2);
344   grpc_gcp_AltsContext_set_peer_service_account(context, peer_service_account);
345   grpc_gcp_AltsContext_set_local_service_account(context,
346                                                  local_service_account);
347   grpc_gcp_AltsContext_set_peer_rpc_versions(
348       context, const_cast<grpc_gcp_RpcProtocolVersions*>(peer_rpc_version));
349   grpc_gcp_Identity* peer_identity = const_cast<grpc_gcp_Identity*>(identity);
350   if (peer_identity == nullptr) {
351     gpr_log(GPR_ERROR, "Null peer identity in ALTS context.");
352     return TSI_FAILED_PRECONDITION;
353   }
354   if (grpc_gcp_Identity_attributes_size(identity) != 0) {
355     size_t iter = kUpb_Map_Begin;
356     grpc_gcp_Identity_AttributesEntry* peer_attributes_entry =
357         grpc_gcp_Identity_attributes_nextmutable(peer_identity, &iter);
358     while (peer_attributes_entry != nullptr) {
359       upb_StringView key = grpc_gcp_Identity_AttributesEntry_key(
360           const_cast<grpc_gcp_Identity_AttributesEntry*>(
361               peer_attributes_entry));
362       upb_StringView val = grpc_gcp_Identity_AttributesEntry_value(
363           const_cast<grpc_gcp_Identity_AttributesEntry*>(
364               peer_attributes_entry));
365       grpc_gcp_AltsContext_peer_attributes_set(context, key, val,
366                                                context_arena.ptr());
367       peer_attributes_entry =
368           grpc_gcp_Identity_attributes_nextmutable(peer_identity, &iter);
369     }
370   }
371   size_t serialized_ctx_length;
372   char* serialized_ctx = grpc_gcp_AltsContext_serialize(
373       context, context_arena.ptr(), &serialized_ctx_length);
374   if (serialized_ctx == nullptr) {
375     gpr_log(GPR_ERROR, "Failed to serialize peer's ALTS context.");
376     return TSI_FAILED_PRECONDITION;
377   }
378   sresult->serialized_context =
379       grpc_slice_from_copied_buffer(serialized_ctx, serialized_ctx_length);
380   sresult->is_client = is_client;
381   sresult->base.vtable = &result_vtable;
382   *result = &sresult->base;
383   return TSI_OK;
384 }
385 
386 // gRPC provided callback used when gRPC thread model is applied.
on_handshaker_service_resp_recv(void * arg,grpc_error_handle error)387 static void on_handshaker_service_resp_recv(void* arg,
388                                             grpc_error_handle error) {
389   alts_handshaker_client* client = static_cast<alts_handshaker_client*>(arg);
390   if (client == nullptr) {
391     gpr_log(GPR_ERROR, "ALTS handshaker client is nullptr");
392     return;
393   }
394   bool success = true;
395   if (!error.ok()) {
396     gpr_log(GPR_INFO,
397             "ALTS handshaker on_handshaker_service_resp_recv error: %s",
398             grpc_core::StatusToString(error).c_str());
399     success = false;
400   }
401   alts_handshaker_client_handle_response(client, success);
402 }
403 
404 // gRPC provided callback used when dedicatd CQ and thread are used.
405 // It serves to safely bring the control back to application.
on_handshaker_service_resp_recv_dedicated(void * arg,grpc_error_handle)406 static void on_handshaker_service_resp_recv_dedicated(
407     void* arg, grpc_error_handle /*error*/) {
408   alts_shared_resource_dedicated* resource =
409       grpc_alts_get_shared_resource_dedicated();
410   grpc_cq_end_op(
411       resource->cq, arg, absl::OkStatus(),
412       [](void* /*done_arg*/, grpc_cq_completion* /*storage*/) {}, nullptr,
413       &resource->storage);
414 }
415 
416 // Returns TSI_OK if and only if no error is encountered.
alts_tsi_handshaker_continue_handshaker_next(alts_tsi_handshaker * handshaker,const unsigned char * received_bytes,size_t received_bytes_size,tsi_handshaker_on_next_done_cb cb,void * user_data,std::string * error)417 static tsi_result alts_tsi_handshaker_continue_handshaker_next(
418     alts_tsi_handshaker* handshaker, const unsigned char* received_bytes,
419     size_t received_bytes_size, tsi_handshaker_on_next_done_cb cb,
420     void* user_data, std::string* error) {
421   if (!handshaker->has_created_handshaker_client) {
422     if (handshaker->channel == nullptr) {
423       grpc_alts_shared_resource_dedicated_start(
424           handshaker->handshaker_service_url);
425       handshaker->interested_parties =
426           grpc_alts_get_shared_resource_dedicated()->interested_parties;
427       GPR_ASSERT(handshaker->interested_parties != nullptr);
428     }
429     grpc_iomgr_cb_func grpc_cb = handshaker->channel == nullptr
430                                      ? on_handshaker_service_resp_recv_dedicated
431                                      : on_handshaker_service_resp_recv;
432     grpc_channel* channel =
433         handshaker->channel == nullptr
434             ? grpc_alts_get_shared_resource_dedicated()->channel
435             : handshaker->channel;
436     alts_handshaker_client* client = alts_grpc_handshaker_client_create(
437         handshaker, channel, handshaker->handshaker_service_url,
438         handshaker->interested_parties, handshaker->options,
439         handshaker->target_name, grpc_cb, cb, user_data,
440         handshaker->client_vtable_for_testing, handshaker->is_client,
441         handshaker->max_frame_size, error);
442     if (client == nullptr) {
443       gpr_log(GPR_ERROR, "Failed to create ALTS handshaker client");
444       if (error != nullptr) *error = "Failed to create ALTS handshaker client";
445       return TSI_FAILED_PRECONDITION;
446     }
447     {
448       grpc_core::MutexLock lock(&handshaker->mu);
449       GPR_ASSERT(handshaker->client == nullptr);
450       handshaker->client = client;
451       if (handshaker->shutdown) {
452         gpr_log(GPR_INFO, "TSI handshake shutdown");
453         if (error != nullptr) *error = "TSI handshaker shutdown";
454         return TSI_HANDSHAKE_SHUTDOWN;
455       }
456     }
457     handshaker->has_created_handshaker_client = true;
458   }
459   if (handshaker->channel == nullptr &&
460       handshaker->client_vtable_for_testing == nullptr) {
461     GPR_ASSERT(grpc_cq_begin_op(grpc_alts_get_shared_resource_dedicated()->cq,
462                                 handshaker->client));
463   }
464   grpc_slice slice = (received_bytes == nullptr || received_bytes_size == 0)
465                          ? grpc_empty_slice()
466                          : grpc_slice_from_copied_buffer(
467                                reinterpret_cast<const char*>(received_bytes),
468                                received_bytes_size);
469   tsi_result ok = TSI_OK;
470   if (!handshaker->has_sent_start_message) {
471     handshaker->has_sent_start_message = true;
472     ok = handshaker->is_client
473              ? alts_handshaker_client_start_client(handshaker->client)
474              : alts_handshaker_client_start_server(handshaker->client, &slice);
475     // It's unsafe for the current thread to access any state in handshaker
476     // at this point, since alts_handshaker_client_start_client/server
477     // have potentially just started an op batch on the handshake call.
478     // The completion callback for that batch is unsynchronized and so
479     // can invoke the TSI next API callback from any thread, at which point
480     // there is nothing taking ownership of this handshaker to prevent it
481     // from being destroyed.
482   } else {
483     ok = alts_handshaker_client_next(handshaker->client, &slice);
484   }
485   grpc_core::CSliceUnref(slice);
486   return ok;
487 }
488 
489 struct alts_tsi_handshaker_continue_handshaker_next_args {
490   alts_tsi_handshaker* handshaker;
491   std::unique_ptr<unsigned char> received_bytes;
492   size_t received_bytes_size;
493   tsi_handshaker_on_next_done_cb cb;
494   void* user_data;
495   grpc_closure closure;
496   std::string* error = nullptr;
497 };
498 
alts_tsi_handshaker_create_channel(void * arg,grpc_error_handle)499 static void alts_tsi_handshaker_create_channel(
500     void* arg, grpc_error_handle /* unused_error */) {
501   alts_tsi_handshaker_continue_handshaker_next_args* next_args =
502       static_cast<alts_tsi_handshaker_continue_handshaker_next_args*>(arg);
503   alts_tsi_handshaker* handshaker = next_args->handshaker;
504   GPR_ASSERT(handshaker->channel == nullptr);
505   grpc_channel_credentials* creds = grpc_insecure_credentials_create();
506   // Disable retries so that we quickly get a signal when the
507   // handshake server is not reachable.
508   grpc_arg disable_retries_arg = grpc_channel_arg_integer_create(
509       const_cast<char*>(GRPC_ARG_ENABLE_RETRIES), 0);
510   grpc_channel_args args = {1, &disable_retries_arg};
511   handshaker->channel = grpc_channel_create(
512       next_args->handshaker->handshaker_service_url, creds, &args);
513   grpc_channel_credentials_release(creds);
514   tsi_result continue_next_result =
515       alts_tsi_handshaker_continue_handshaker_next(
516           handshaker, next_args->received_bytes.get(),
517           next_args->received_bytes_size, next_args->cb, next_args->user_data,
518           next_args->error);
519   if (continue_next_result != TSI_OK) {
520     next_args->cb(continue_next_result, next_args->user_data, nullptr, 0,
521                   nullptr);
522   }
523   delete next_args;
524 }
525 
handshaker_next(tsi_handshaker * self,const unsigned char * received_bytes,size_t received_bytes_size,const unsigned char **,size_t *,tsi_handshaker_result **,tsi_handshaker_on_next_done_cb cb,void * user_data,std::string * error)526 static tsi_result handshaker_next(
527     tsi_handshaker* self, const unsigned char* received_bytes,
528     size_t received_bytes_size, const unsigned char** /*bytes_to_send*/,
529     size_t* /*bytes_to_send_size*/, tsi_handshaker_result** /*result*/,
530     tsi_handshaker_on_next_done_cb cb, void* user_data, std::string* error) {
531   if (self == nullptr || cb == nullptr) {
532     gpr_log(GPR_ERROR, "Invalid arguments to handshaker_next()");
533     if (error != nullptr) *error = "invalid argument";
534     return TSI_INVALID_ARGUMENT;
535   }
536   alts_tsi_handshaker* handshaker =
537       reinterpret_cast<alts_tsi_handshaker*>(self);
538   {
539     grpc_core::MutexLock lock(&handshaker->mu);
540     if (handshaker->shutdown) {
541       gpr_log(GPR_INFO, "TSI handshake shutdown");
542       if (error != nullptr) *error = "handshake shutdown";
543       return TSI_HANDSHAKE_SHUTDOWN;
544     }
545   }
546   if (handshaker->channel == nullptr && !handshaker->use_dedicated_cq) {
547     alts_tsi_handshaker_continue_handshaker_next_args* args =
548         new alts_tsi_handshaker_continue_handshaker_next_args();
549     args->handshaker = handshaker;
550     args->received_bytes = nullptr;
551     args->received_bytes_size = received_bytes_size;
552     args->error = error;
553     if (received_bytes_size > 0) {
554       args->received_bytes = std::unique_ptr<unsigned char>(
555           static_cast<unsigned char*>(gpr_zalloc(received_bytes_size)));
556       memcpy(args->received_bytes.get(), received_bytes, received_bytes_size);
557     }
558     args->cb = cb;
559     args->user_data = user_data;
560     GRPC_CLOSURE_INIT(&args->closure, alts_tsi_handshaker_create_channel, args,
561                       grpc_schedule_on_exec_ctx);
562     // We continue this handshaker_next call at the bottom of the ExecCtx just
563     // so that we can invoke grpc_channel_create at the bottom of the call
564     // stack. Doing so avoids potential lock cycles between g_init_mu and other
565     // mutexes within core that might be held on the current call stack
566     // (note that g_init_mu gets acquired during channel creation).
567     grpc_core::ExecCtx::Run(DEBUG_LOCATION, &args->closure, absl::OkStatus());
568   } else {
569     tsi_result ok = alts_tsi_handshaker_continue_handshaker_next(
570         handshaker, received_bytes, received_bytes_size, cb, user_data, error);
571     if (ok != TSI_OK) {
572       gpr_log(GPR_ERROR, "Failed to schedule ALTS handshaker requests");
573       return ok;
574     }
575   }
576   return TSI_ASYNC;
577 }
578 
579 //
580 // This API will be invoked by a non-gRPC application, and an ExecCtx needs
581 // to be explicitly created in order to invoke ALTS handshaker client API's
582 // that assumes the caller is inside gRPC core.
583 //
handshaker_next_dedicated(tsi_handshaker * self,const unsigned char * received_bytes,size_t received_bytes_size,const unsigned char ** bytes_to_send,size_t * bytes_to_send_size,tsi_handshaker_result ** result,tsi_handshaker_on_next_done_cb cb,void * user_data,std::string * error)584 static tsi_result handshaker_next_dedicated(
585     tsi_handshaker* self, const unsigned char* received_bytes,
586     size_t received_bytes_size, const unsigned char** bytes_to_send,
587     size_t* bytes_to_send_size, tsi_handshaker_result** result,
588     tsi_handshaker_on_next_done_cb cb, void* user_data, std::string* error) {
589   grpc_core::ExecCtx exec_ctx;
590   return handshaker_next(self, received_bytes, received_bytes_size,
591                          bytes_to_send, bytes_to_send_size, result, cb,
592                          user_data, error);
593 }
594 
handshaker_shutdown(tsi_handshaker * self)595 static void handshaker_shutdown(tsi_handshaker* self) {
596   GPR_ASSERT(self != nullptr);
597   alts_tsi_handshaker* handshaker =
598       reinterpret_cast<alts_tsi_handshaker*>(self);
599   grpc_core::MutexLock lock(&handshaker->mu);
600   if (handshaker->shutdown) {
601     return;
602   }
603   if (handshaker->client != nullptr) {
604     alts_handshaker_client_shutdown(handshaker->client);
605   }
606   handshaker->shutdown = true;
607 }
608 
handshaker_destroy(tsi_handshaker * self)609 static void handshaker_destroy(tsi_handshaker* self) {
610   if (self == nullptr) {
611     return;
612   }
613   alts_tsi_handshaker* handshaker =
614       reinterpret_cast<alts_tsi_handshaker*>(self);
615   alts_handshaker_client_destroy(handshaker->client);
616   grpc_core::CSliceUnref(handshaker->target_name);
617   grpc_alts_credentials_options_destroy(handshaker->options);
618   if (handshaker->channel != nullptr) {
619     grpc_channel_destroy_internal(handshaker->channel);
620   }
621   gpr_free(handshaker->handshaker_service_url);
622   delete handshaker;
623 }
624 
625 static const tsi_handshaker_vtable handshaker_vtable = {
626     nullptr,         nullptr,
627     nullptr,         nullptr,
628     nullptr,         handshaker_destroy,
629     handshaker_next, handshaker_shutdown};
630 
631 static const tsi_handshaker_vtable handshaker_vtable_dedicated = {
632     nullptr,
633     nullptr,
634     nullptr,
635     nullptr,
636     nullptr,
637     handshaker_destroy,
638     handshaker_next_dedicated,
639     handshaker_shutdown};
640 
alts_tsi_handshaker_has_shutdown(alts_tsi_handshaker * handshaker)641 bool alts_tsi_handshaker_has_shutdown(alts_tsi_handshaker* handshaker) {
642   GPR_ASSERT(handshaker != nullptr);
643   grpc_core::MutexLock lock(&handshaker->mu);
644   return handshaker->shutdown;
645 }
646 
alts_tsi_handshaker_create(const grpc_alts_credentials_options * options,const char * target_name,const char * handshaker_service_url,bool is_client,grpc_pollset_set * interested_parties,tsi_handshaker ** self,size_t user_specified_max_frame_size)647 tsi_result alts_tsi_handshaker_create(
648     const grpc_alts_credentials_options* options, const char* target_name,
649     const char* handshaker_service_url, bool is_client,
650     grpc_pollset_set* interested_parties, tsi_handshaker** self,
651     size_t user_specified_max_frame_size) {
652   if (handshaker_service_url == nullptr || self == nullptr ||
653       options == nullptr || (is_client && target_name == nullptr)) {
654     gpr_log(GPR_ERROR, "Invalid arguments to alts_tsi_handshaker_create()");
655     return TSI_INVALID_ARGUMENT;
656   }
657   bool use_dedicated_cq = interested_parties == nullptr;
658   alts_tsi_handshaker* handshaker = new alts_tsi_handshaker();
659   memset(&handshaker->base, 0, sizeof(handshaker->base));
660   handshaker->base.vtable =
661       use_dedicated_cq ? &handshaker_vtable_dedicated : &handshaker_vtable;
662   handshaker->target_name = target_name == nullptr
663                                 ? grpc_empty_slice()
664                                 : grpc_slice_from_static_string(target_name);
665   handshaker->is_client = is_client;
666   handshaker->handshaker_service_url = gpr_strdup(handshaker_service_url);
667   handshaker->interested_parties = interested_parties;
668   handshaker->options = grpc_alts_credentials_options_copy(options);
669   handshaker->use_dedicated_cq = use_dedicated_cq;
670   handshaker->max_frame_size = user_specified_max_frame_size != 0
671                                    ? user_specified_max_frame_size
672                                    : kTsiAltsMaxFrameSize;
673   *self = &handshaker->base;
674   return TSI_OK;
675 }
676 
alts_tsi_handshaker_result_set_unused_bytes(tsi_handshaker_result * result,grpc_slice * recv_bytes,size_t bytes_consumed)677 void alts_tsi_handshaker_result_set_unused_bytes(tsi_handshaker_result* result,
678                                                  grpc_slice* recv_bytes,
679                                                  size_t bytes_consumed) {
680   GPR_ASSERT(recv_bytes != nullptr && result != nullptr);
681   if (GRPC_SLICE_LENGTH(*recv_bytes) == bytes_consumed) {
682     return;
683   }
684   alts_tsi_handshaker_result* sresult =
685       reinterpret_cast<alts_tsi_handshaker_result*>(result);
686   sresult->unused_bytes_size = GRPC_SLICE_LENGTH(*recv_bytes) - bytes_consumed;
687   sresult->unused_bytes =
688       static_cast<unsigned char*>(gpr_zalloc(sresult->unused_bytes_size));
689   memcpy(sresult->unused_bytes,
690          GRPC_SLICE_START_PTR(*recv_bytes) + bytes_consumed,
691          sresult->unused_bytes_size);
692 }
693 
694 namespace grpc_core {
695 namespace internal {
696 
alts_tsi_handshaker_get_has_sent_start_message_for_testing(alts_tsi_handshaker * handshaker)697 bool alts_tsi_handshaker_get_has_sent_start_message_for_testing(
698     alts_tsi_handshaker* handshaker) {
699   GPR_ASSERT(handshaker != nullptr);
700   return handshaker->has_sent_start_message;
701 }
702 
alts_tsi_handshaker_set_client_vtable_for_testing(alts_tsi_handshaker * handshaker,alts_handshaker_client_vtable * vtable)703 void alts_tsi_handshaker_set_client_vtable_for_testing(
704     alts_tsi_handshaker* handshaker, alts_handshaker_client_vtable* vtable) {
705   GPR_ASSERT(handshaker != nullptr);
706   handshaker->client_vtable_for_testing = vtable;
707 }
708 
alts_tsi_handshaker_get_is_client_for_testing(alts_tsi_handshaker * handshaker)709 bool alts_tsi_handshaker_get_is_client_for_testing(
710     alts_tsi_handshaker* handshaker) {
711   GPR_ASSERT(handshaker != nullptr);
712   return handshaker->is_client;
713 }
714 
alts_tsi_handshaker_get_client_for_testing(alts_tsi_handshaker * handshaker)715 alts_handshaker_client* alts_tsi_handshaker_get_client_for_testing(
716     alts_tsi_handshaker* handshaker) {
717   return handshaker->client;
718 }
719 
720 }  // namespace internal
721 }  // namespace grpc_core
722