1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include <grpc/support/port_platform.h>
20
21 #include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h"
22
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26
27 #include "upb/upb.hpp"
28
29 #include <grpc/grpc_security.h>
30 #include <grpc/support/alloc.h>
31 #include <grpc/support/log.h>
32 #include <grpc/support/string_util.h>
33 #include <grpc/support/sync.h>
34 #include <grpc/support/thd_id.h>
35
36 #include "src/core/lib/gprpp/crash.h"
37 #include "src/core/lib/gprpp/memory.h"
38 #include "src/core/lib/gprpp/sync.h"
39 #include "src/core/lib/gprpp/thd.h"
40 #include "src/core/lib/iomgr/closure.h"
41 #include "src/core/lib/slice/slice_internal.h"
42 #include "src/core/lib/surface/channel.h"
43 #include "src/core/tsi/alts/frame_protector/alts_frame_protector.h"
44 #include "src/core/tsi/alts/handshaker/alts_handshaker_client.h"
45 #include "src/core/tsi/alts/handshaker/alts_shared_resource.h"
46 #include "src/core/tsi/alts/handshaker/alts_tsi_utils.h"
47 #include "src/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector.h"
48
// Main struct for ALTS TSI handshaker.
//
// Created by alts_tsi_handshaker_create() and freed by handshaker_destroy().
// Code in this file converts between tsi_handshaker* and
// alts_tsi_handshaker* with reinterpret_cast, so |base| must stay the first
// member.
struct alts_tsi_handshaker {
  tsi_handshaker base;
  // Target name handed to the handshaker client; an empty slice when the
  // caller passed no target name.
  grpc_slice target_name;
  bool is_client;
  // Set once the start message has been sent; later next() calls forward
  // received bytes with alts_handshaker_client_next() instead.
  bool has_sent_start_message = false;
  // Set once |client| has been created successfully.
  bool has_created_handshaker_client = false;
  // gpr_strdup'ed copy of the handshaker service address; freed in
  // handshaker_destroy().
  char* handshaker_service_url;
  // Caller-supplied pollset set. When |channel| is still null at client
  // creation time it is replaced by the shared dedicated resource's set.
  grpc_pollset_set* interested_parties;
  // Owned copy of the caller's options; destroyed in handshaker_destroy().
  grpc_alts_credentials_options* options;
  alts_handshaker_client_vtable* client_vtable_for_testing = nullptr;
  // Channel to the handshaker service, created lazily by
  // alts_tsi_handshaker_create_channel() when not in dedicated-CQ mode.
  // While null, the shared dedicated resource's channel is used instead.
  grpc_channel* channel = nullptr;
  // True when alts_tsi_handshaker_create() received no interested_parties.
  bool use_dedicated_cq;
  // mu synchronizes all fields below. Note these are the
  // only fields that can be concurrently accessed (due to
  // potential concurrency of tsi_handshaker_shutdown and
  // tsi_handshaker_next).
  grpc_core::Mutex mu;
  alts_handshaker_client* client = nullptr;
  // shutdown effectively follows base.handshake_shutdown,
  // but is synchronized by the mutex of this object.
  bool shutdown = false;
  // Maximum frame size used by frame protector. Set once at creation (the
  // user-specified value, or kTsiAltsMaxFrameSize when that is 0) and passed
  // to the handshaker client.
  size_t max_frame_size;
};
74
// Main struct for ALTS TSI handshaker result.
//
// Built by alts_tsi_handshaker_result_create() and freed by
// handshaker_result_destroy(). |base| must stay the first member because the
// vtable functions reinterpret_cast tsi_handshaker_result* to this type.
typedef struct alts_tsi_handshaker_result {
  tsi_handshaker_result base;
  // NUL-terminated heap copy of the peer's service account.
  char* peer_identity;
  // Heap copy of the first kAltsAes128GcmRekeyKeyLength bytes of the
  // negotiated key material.
  char* key_data;
  // Bytes received beyond the end of the handshake, set by
  // alts_tsi_handshaker_result_set_unused_bytes(); null when there are none.
  unsigned char* unused_bytes;
  size_t unused_bytes_size;
  // Serialized peer RPC protocol versions.
  grpc_slice rpc_versions;
  bool is_client;
  // Serialized grpc_gcp_AltsContext describing the peer.
  grpc_slice serialized_context;
  // Peer's maximum frame size (0 when the peer did not send one).
  size_t max_frame_size;
} alts_tsi_handshaker_result;
88
handshaker_result_extract_peer(const tsi_handshaker_result * self,tsi_peer * peer)89 static tsi_result handshaker_result_extract_peer(
90 const tsi_handshaker_result* self, tsi_peer* peer) {
91 if (self == nullptr || peer == nullptr) {
92 gpr_log(GPR_ERROR, "Invalid argument to handshaker_result_extract_peer()");
93 return TSI_INVALID_ARGUMENT;
94 }
95 alts_tsi_handshaker_result* result =
96 reinterpret_cast<alts_tsi_handshaker_result*>(
97 const_cast<tsi_handshaker_result*>(self));
98 GPR_ASSERT(kTsiAltsNumOfPeerProperties == 5);
99 tsi_result ok = tsi_construct_peer(kTsiAltsNumOfPeerProperties, peer);
100 int index = 0;
101 if (ok != TSI_OK) {
102 gpr_log(GPR_ERROR, "Failed to construct tsi peer");
103 return ok;
104 }
105 GPR_ASSERT(&peer->properties[index] != nullptr);
106 ok = tsi_construct_string_peer_property_from_cstring(
107 TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE,
108 &peer->properties[index]);
109 if (ok != TSI_OK) {
110 tsi_peer_destruct(peer);
111 gpr_log(GPR_ERROR, "Failed to set tsi peer property");
112 return ok;
113 }
114 index++;
115 GPR_ASSERT(&peer->properties[index] != nullptr);
116 ok = tsi_construct_string_peer_property_from_cstring(
117 TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, result->peer_identity,
118 &peer->properties[index]);
119 if (ok != TSI_OK) {
120 tsi_peer_destruct(peer);
121 gpr_log(GPR_ERROR, "Failed to set tsi peer property");
122 }
123 index++;
124 GPR_ASSERT(&peer->properties[index] != nullptr);
125 ok = tsi_construct_string_peer_property(
126 TSI_ALTS_RPC_VERSIONS,
127 reinterpret_cast<char*>(GRPC_SLICE_START_PTR(result->rpc_versions)),
128 GRPC_SLICE_LENGTH(result->rpc_versions), &peer->properties[index]);
129 if (ok != TSI_OK) {
130 tsi_peer_destruct(peer);
131 gpr_log(GPR_ERROR, "Failed to set tsi peer property");
132 }
133 index++;
134 GPR_ASSERT(&peer->properties[index] != nullptr);
135 ok = tsi_construct_string_peer_property(
136 TSI_ALTS_CONTEXT,
137 reinterpret_cast<char*>(GRPC_SLICE_START_PTR(result->serialized_context)),
138 GRPC_SLICE_LENGTH(result->serialized_context), &peer->properties[index]);
139 if (ok != TSI_OK) {
140 tsi_peer_destruct(peer);
141 gpr_log(GPR_ERROR, "Failed to set tsi peer property");
142 }
143 index++;
144 GPR_ASSERT(&peer->properties[index] != nullptr);
145 ok = tsi_construct_string_peer_property_from_cstring(
146 TSI_SECURITY_LEVEL_PEER_PROPERTY,
147 tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY),
148 &peer->properties[index]);
149 if (ok != TSI_OK) {
150 tsi_peer_destruct(peer);
151 gpr_log(GPR_ERROR, "Failed to set tsi peer property");
152 }
153 GPR_ASSERT(++index == kTsiAltsNumOfPeerProperties);
154 return ok;
155 }
156
handshaker_result_get_frame_protector_type(const tsi_handshaker_result *,tsi_frame_protector_type * frame_protector_type)157 static tsi_result handshaker_result_get_frame_protector_type(
158 const tsi_handshaker_result* /*self*/,
159 tsi_frame_protector_type* frame_protector_type) {
160 *frame_protector_type = TSI_FRAME_PROTECTOR_NORMAL_OR_ZERO_COPY;
161 return TSI_OK;
162 }
163
handshaker_result_create_zero_copy_grpc_protector(const tsi_handshaker_result * self,size_t * max_output_protected_frame_size,tsi_zero_copy_grpc_protector ** protector)164 static tsi_result handshaker_result_create_zero_copy_grpc_protector(
165 const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
166 tsi_zero_copy_grpc_protector** protector) {
167 if (self == nullptr || protector == nullptr) {
168 gpr_log(GPR_ERROR,
169 "Invalid arguments to create_zero_copy_grpc_protector()");
170 return TSI_INVALID_ARGUMENT;
171 }
172 alts_tsi_handshaker_result* result =
173 reinterpret_cast<alts_tsi_handshaker_result*>(
174 const_cast<tsi_handshaker_result*>(self));
175
176 // In case the peer does not send max frame size (e.g. peer is gRPC Go or
177 // peer uses an old binary), the negotiated frame size is set to
178 // kTsiAltsMinFrameSize (ignoring max_output_protected_frame_size value if
179 // present). Otherwise, it is based on peer and user specified max frame
180 // size (if present).
181 size_t max_frame_size = kTsiAltsMinFrameSize;
182 if (result->max_frame_size) {
183 size_t peer_max_frame_size = result->max_frame_size;
184 max_frame_size = std::min<size_t>(peer_max_frame_size,
185 max_output_protected_frame_size == nullptr
186 ? kTsiAltsMaxFrameSize
187 : *max_output_protected_frame_size);
188 max_frame_size = std::max<size_t>(max_frame_size, kTsiAltsMinFrameSize);
189 }
190 max_output_protected_frame_size = &max_frame_size;
191 gpr_log(GPR_DEBUG,
192 "After Frame Size Negotiation, maximum frame size used by frame "
193 "protector equals %zu",
194 *max_output_protected_frame_size);
195 tsi_result ok = alts_zero_copy_grpc_protector_create(
196 reinterpret_cast<const uint8_t*>(result->key_data),
197 kAltsAes128GcmRekeyKeyLength, /*is_rekey=*/true, result->is_client,
198 /*is_integrity_only=*/false, /*enable_extra_copy=*/false,
199 max_output_protected_frame_size, protector);
200 if (ok != TSI_OK) {
201 gpr_log(GPR_ERROR, "Failed to create zero-copy grpc protector");
202 }
203 return ok;
204 }
205
handshaker_result_create_frame_protector(const tsi_handshaker_result * self,size_t * max_output_protected_frame_size,tsi_frame_protector ** protector)206 static tsi_result handshaker_result_create_frame_protector(
207 const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
208 tsi_frame_protector** protector) {
209 if (self == nullptr || protector == nullptr) {
210 gpr_log(GPR_ERROR,
211 "Invalid arguments to handshaker_result_create_frame_protector()");
212 return TSI_INVALID_ARGUMENT;
213 }
214 alts_tsi_handshaker_result* result =
215 reinterpret_cast<alts_tsi_handshaker_result*>(
216 const_cast<tsi_handshaker_result*>(self));
217 tsi_result ok = alts_create_frame_protector(
218 reinterpret_cast<const uint8_t*>(result->key_data),
219 kAltsAes128GcmRekeyKeyLength, result->is_client, /*is_rekey=*/true,
220 max_output_protected_frame_size, protector);
221 if (ok != TSI_OK) {
222 gpr_log(GPR_ERROR, "Failed to create frame protector");
223 }
224 return ok;
225 }
226
handshaker_result_get_unused_bytes(const tsi_handshaker_result * self,const unsigned char ** bytes,size_t * bytes_size)227 static tsi_result handshaker_result_get_unused_bytes(
228 const tsi_handshaker_result* self, const unsigned char** bytes,
229 size_t* bytes_size) {
230 if (self == nullptr || bytes == nullptr || bytes_size == nullptr) {
231 gpr_log(GPR_ERROR,
232 "Invalid arguments to handshaker_result_get_unused_bytes()");
233 return TSI_INVALID_ARGUMENT;
234 }
235 alts_tsi_handshaker_result* result =
236 reinterpret_cast<alts_tsi_handshaker_result*>(
237 const_cast<tsi_handshaker_result*>(self));
238 *bytes = result->unused_bytes;
239 *bytes_size = result->unused_bytes_size;
240 return TSI_OK;
241 }
242
handshaker_result_destroy(tsi_handshaker_result * self)243 static void handshaker_result_destroy(tsi_handshaker_result* self) {
244 if (self == nullptr) {
245 return;
246 }
247 alts_tsi_handshaker_result* result =
248 reinterpret_cast<alts_tsi_handshaker_result*>(
249 const_cast<tsi_handshaker_result*>(self));
250 gpr_free(result->peer_identity);
251 gpr_free(result->key_data);
252 gpr_free(result->unused_bytes);
253 grpc_core::CSliceUnref(result->rpc_versions);
254 grpc_core::CSliceUnref(result->serialized_context);
255 gpr_free(result);
256 }
257
// Vtable installed on every ALTS handshaker result. Entries are positional
// and must follow the field order of tsi_handshaker_result_vtable (declared
// elsewhere): peer extraction, frame-protector type, zero-copy protector
// creation, frame-protector creation, unused-bytes access, destruction.
static const tsi_handshaker_result_vtable result_vtable = {
    handshaker_result_extract_peer,
    handshaker_result_get_frame_protector_type,
    handshaker_result_create_zero_copy_grpc_protector,
    handshaker_result_create_frame_protector,
    handshaker_result_get_unused_bytes,
    handshaker_result_destroy};
265
alts_tsi_handshaker_result_create(grpc_gcp_HandshakerResp * resp,bool is_client,tsi_handshaker_result ** result)266 tsi_result alts_tsi_handshaker_result_create(grpc_gcp_HandshakerResp* resp,
267 bool is_client,
268 tsi_handshaker_result** result) {
269 if (result == nullptr || resp == nullptr) {
270 gpr_log(GPR_ERROR, "Invalid arguments to create_handshaker_result()");
271 return TSI_INVALID_ARGUMENT;
272 }
273 const grpc_gcp_HandshakerResult* hresult =
274 grpc_gcp_HandshakerResp_result(resp);
275 const grpc_gcp_Identity* identity =
276 grpc_gcp_HandshakerResult_peer_identity(hresult);
277 if (identity == nullptr) {
278 gpr_log(GPR_ERROR, "Invalid identity");
279 return TSI_FAILED_PRECONDITION;
280 }
281 upb_StringView peer_service_account =
282 grpc_gcp_Identity_service_account(identity);
283 if (peer_service_account.size == 0) {
284 gpr_log(GPR_ERROR, "Invalid peer service account");
285 return TSI_FAILED_PRECONDITION;
286 }
287 upb_StringView key_data = grpc_gcp_HandshakerResult_key_data(hresult);
288 if (key_data.size < kAltsAes128GcmRekeyKeyLength) {
289 gpr_log(GPR_ERROR, "Bad key length");
290 return TSI_FAILED_PRECONDITION;
291 }
292 const grpc_gcp_RpcProtocolVersions* peer_rpc_version =
293 grpc_gcp_HandshakerResult_peer_rpc_versions(hresult);
294 if (peer_rpc_version == nullptr) {
295 gpr_log(GPR_ERROR, "Peer does not set RPC protocol versions.");
296 return TSI_FAILED_PRECONDITION;
297 }
298 upb_StringView application_protocol =
299 grpc_gcp_HandshakerResult_application_protocol(hresult);
300 if (application_protocol.size == 0) {
301 gpr_log(GPR_ERROR, "Invalid application protocol");
302 return TSI_FAILED_PRECONDITION;
303 }
304 upb_StringView record_protocol =
305 grpc_gcp_HandshakerResult_record_protocol(hresult);
306 if (record_protocol.size == 0) {
307 gpr_log(GPR_ERROR, "Invalid record protocol");
308 return TSI_FAILED_PRECONDITION;
309 }
310 const grpc_gcp_Identity* local_identity =
311 grpc_gcp_HandshakerResult_local_identity(hresult);
312 if (local_identity == nullptr) {
313 gpr_log(GPR_ERROR, "Invalid local identity");
314 return TSI_FAILED_PRECONDITION;
315 }
316 upb_StringView local_service_account =
317 grpc_gcp_Identity_service_account(local_identity);
318 // We don't check if local service account is empty here
319 // because local identity could be empty in certain situations.
320 alts_tsi_handshaker_result* sresult =
321 grpc_core::Zalloc<alts_tsi_handshaker_result>();
322 sresult->key_data =
323 static_cast<char*>(gpr_zalloc(kAltsAes128GcmRekeyKeyLength));
324 memcpy(sresult->key_data, key_data.data, kAltsAes128GcmRekeyKeyLength);
325 sresult->peer_identity =
326 static_cast<char*>(gpr_zalloc(peer_service_account.size + 1));
327 memcpy(sresult->peer_identity, peer_service_account.data,
328 peer_service_account.size);
329 sresult->max_frame_size = grpc_gcp_HandshakerResult_max_frame_size(hresult);
330 upb::Arena rpc_versions_arena;
331 bool serialized = grpc_gcp_rpc_protocol_versions_encode(
332 peer_rpc_version, rpc_versions_arena.ptr(), &sresult->rpc_versions);
333 if (!serialized) {
334 gpr_log(GPR_ERROR, "Failed to serialize peer's RPC protocol versions.");
335 return TSI_FAILED_PRECONDITION;
336 }
337 upb::Arena context_arena;
338 grpc_gcp_AltsContext* context = grpc_gcp_AltsContext_new(context_arena.ptr());
339 grpc_gcp_AltsContext_set_application_protocol(context, application_protocol);
340 grpc_gcp_AltsContext_set_record_protocol(context, record_protocol);
341 // ALTS currently only supports the security level of 2,
342 // which is "grpc_gcp_INTEGRITY_AND_PRIVACY".
343 grpc_gcp_AltsContext_set_security_level(context, 2);
344 grpc_gcp_AltsContext_set_peer_service_account(context, peer_service_account);
345 grpc_gcp_AltsContext_set_local_service_account(context,
346 local_service_account);
347 grpc_gcp_AltsContext_set_peer_rpc_versions(
348 context, const_cast<grpc_gcp_RpcProtocolVersions*>(peer_rpc_version));
349 grpc_gcp_Identity* peer_identity = const_cast<grpc_gcp_Identity*>(identity);
350 if (peer_identity == nullptr) {
351 gpr_log(GPR_ERROR, "Null peer identity in ALTS context.");
352 return TSI_FAILED_PRECONDITION;
353 }
354 if (grpc_gcp_Identity_attributes_size(identity) != 0) {
355 size_t iter = kUpb_Map_Begin;
356 grpc_gcp_Identity_AttributesEntry* peer_attributes_entry =
357 grpc_gcp_Identity_attributes_nextmutable(peer_identity, &iter);
358 while (peer_attributes_entry != nullptr) {
359 upb_StringView key = grpc_gcp_Identity_AttributesEntry_key(
360 const_cast<grpc_gcp_Identity_AttributesEntry*>(
361 peer_attributes_entry));
362 upb_StringView val = grpc_gcp_Identity_AttributesEntry_value(
363 const_cast<grpc_gcp_Identity_AttributesEntry*>(
364 peer_attributes_entry));
365 grpc_gcp_AltsContext_peer_attributes_set(context, key, val,
366 context_arena.ptr());
367 peer_attributes_entry =
368 grpc_gcp_Identity_attributes_nextmutable(peer_identity, &iter);
369 }
370 }
371 size_t serialized_ctx_length;
372 char* serialized_ctx = grpc_gcp_AltsContext_serialize(
373 context, context_arena.ptr(), &serialized_ctx_length);
374 if (serialized_ctx == nullptr) {
375 gpr_log(GPR_ERROR, "Failed to serialize peer's ALTS context.");
376 return TSI_FAILED_PRECONDITION;
377 }
378 sresult->serialized_context =
379 grpc_slice_from_copied_buffer(serialized_ctx, serialized_ctx_length);
380 sresult->is_client = is_client;
381 sresult->base.vtable = &result_vtable;
382 *result = &sresult->base;
383 return TSI_OK;
384 }
385
386 // gRPC provided callback used when gRPC thread model is applied.
on_handshaker_service_resp_recv(void * arg,grpc_error_handle error)387 static void on_handshaker_service_resp_recv(void* arg,
388 grpc_error_handle error) {
389 alts_handshaker_client* client = static_cast<alts_handshaker_client*>(arg);
390 if (client == nullptr) {
391 gpr_log(GPR_ERROR, "ALTS handshaker client is nullptr");
392 return;
393 }
394 bool success = true;
395 if (!error.ok()) {
396 gpr_log(GPR_INFO,
397 "ALTS handshaker on_handshaker_service_resp_recv error: %s",
398 grpc_core::StatusToString(error).c_str());
399 success = false;
400 }
401 alts_handshaker_client_handle_response(client, success);
402 }
403
404 // gRPC provided callback used when dedicatd CQ and thread are used.
405 // It serves to safely bring the control back to application.
on_handshaker_service_resp_recv_dedicated(void * arg,grpc_error_handle)406 static void on_handshaker_service_resp_recv_dedicated(
407 void* arg, grpc_error_handle /*error*/) {
408 alts_shared_resource_dedicated* resource =
409 grpc_alts_get_shared_resource_dedicated();
410 grpc_cq_end_op(
411 resource->cq, arg, absl::OkStatus(),
412 [](void* /*done_arg*/, grpc_cq_completion* /*storage*/) {}, nullptr,
413 &resource->storage);
414 }
415
// Performs the work of one handshaker_next() call once a channel to the
// handshaker service exists (or dedicated-CQ mode is in use).
//
// First call only: creates the alts_handshaker_client.
//   - If |handshaker->channel| is null, it starts the shared dedicated
//     resource and uses that resource's channel and pollset set, with
//     on_handshaker_service_resp_recv_dedicated as the gRPC callback.
//   - Otherwise it uses |handshaker->channel| with
//     on_handshaker_service_resp_recv.
//   - |cb| and |user_data| are handed to the client, which invokes them when
//     the service responds.
// Every call then sends the start message (first call) or forwards
// |received_bytes| with alts_handshaker_client_next().
//
// Returns TSI_OK if and only if no error is encountered. Some failure paths
// also write a short description to |*error| when |error| is non-null.
static tsi_result alts_tsi_handshaker_continue_handshaker_next(
    alts_tsi_handshaker* handshaker, const unsigned char* received_bytes,
    size_t received_bytes_size, tsi_handshaker_on_next_done_cb cb,
    void* user_data, std::string* error) {
  if (!handshaker->has_created_handshaker_client) {
    // No private channel: fall back to the process-wide dedicated resource.
    if (handshaker->channel == nullptr) {
      grpc_alts_shared_resource_dedicated_start(
          handshaker->handshaker_service_url);
      handshaker->interested_parties =
          grpc_alts_get_shared_resource_dedicated()->interested_parties;
      GPR_ASSERT(handshaker->interested_parties != nullptr);
    }
    grpc_iomgr_cb_func grpc_cb = handshaker->channel == nullptr
                                     ? on_handshaker_service_resp_recv_dedicated
                                     : on_handshaker_service_resp_recv;
    grpc_channel* channel =
        handshaker->channel == nullptr
            ? grpc_alts_get_shared_resource_dedicated()->channel
            : handshaker->channel;
    alts_handshaker_client* client = alts_grpc_handshaker_client_create(
        handshaker, channel, handshaker->handshaker_service_url,
        handshaker->interested_parties, handshaker->options,
        handshaker->target_name, grpc_cb, cb, user_data,
        handshaker->client_vtable_for_testing, handshaker->is_client,
        handshaker->max_frame_size, error);
    if (client == nullptr) {
      gpr_log(GPR_ERROR, "Failed to create ALTS handshaker client");
      if (error != nullptr) *error = "Failed to create ALTS handshaker client";
      return TSI_FAILED_PRECONDITION;
    }
    {
      // Publish the client under |mu| so that a concurrent
      // handshaker_shutdown() can see it and shut it down.
      grpc_core::MutexLock lock(&handshaker->mu);
      GPR_ASSERT(handshaker->client == nullptr);
      handshaker->client = client;
      // NOTE(review): on this early return |client| is already set but
      // has_created_handshaker_client stays false.
      if (handshaker->shutdown) {
        gpr_log(GPR_INFO, "TSI handshake shutdown");
        if (error != nullptr) *error = "TSI handshaker shutdown";
        return TSI_HANDSHAKE_SHUTDOWN;
      }
    }
    handshaker->has_created_handshaker_client = true;
  }
  // In dedicated-CQ mode (outside tests), register the client as a pending
  // operation on the shared completion queue; presumably matched by the
  // grpc_cq_end_op() in on_handshaker_service_resp_recv_dedicated().
  if (handshaker->channel == nullptr &&
      handshaker->client_vtable_for_testing == nullptr) {
    GPR_ASSERT(grpc_cq_begin_op(grpc_alts_get_shared_resource_dedicated()->cq,
                                handshaker->client));
  }
  // Copy the received bytes into a slice owned by this call; it is released
  // below once the client call returns.
  grpc_slice slice = (received_bytes == nullptr || received_bytes_size == 0)
                         ? grpc_empty_slice()
                         : grpc_slice_from_copied_buffer(
                               reinterpret_cast<const char*>(received_bytes),
                               received_bytes_size);
  tsi_result ok = TSI_OK;
  if (!handshaker->has_sent_start_message) {
    handshaker->has_sent_start_message = true;
    ok = handshaker->is_client
             ? alts_handshaker_client_start_client(handshaker->client)
             : alts_handshaker_client_start_server(handshaker->client, &slice);
    // It's unsafe for the current thread to access any state in handshaker
    // at this point, since alts_handshaker_client_start_client/server
    // have potentially just started an op batch on the handshake call.
    // The completion callback for that batch is unsynchronized and so
    // can invoke the TSI next API callback from any thread, at which point
    // there is nothing taking ownership of this handshaker to prevent it
    // from being destroyed.
  } else {
    ok = alts_handshaker_client_next(handshaker->client, &slice);
  }
  grpc_core::CSliceUnref(slice);
  return ok;
}
488
489 struct alts_tsi_handshaker_continue_handshaker_next_args {
490 alts_tsi_handshaker* handshaker;
491 std::unique_ptr<unsigned char> received_bytes;
492 size_t received_bytes_size;
493 tsi_handshaker_on_next_done_cb cb;
494 void* user_data;
495 grpc_closure closure;
496 std::string* error = nullptr;
497 };
498
alts_tsi_handshaker_create_channel(void * arg,grpc_error_handle)499 static void alts_tsi_handshaker_create_channel(
500 void* arg, grpc_error_handle /* unused_error */) {
501 alts_tsi_handshaker_continue_handshaker_next_args* next_args =
502 static_cast<alts_tsi_handshaker_continue_handshaker_next_args*>(arg);
503 alts_tsi_handshaker* handshaker = next_args->handshaker;
504 GPR_ASSERT(handshaker->channel == nullptr);
505 grpc_channel_credentials* creds = grpc_insecure_credentials_create();
506 // Disable retries so that we quickly get a signal when the
507 // handshake server is not reachable.
508 grpc_arg disable_retries_arg = grpc_channel_arg_integer_create(
509 const_cast<char*>(GRPC_ARG_ENABLE_RETRIES), 0);
510 grpc_channel_args args = {1, &disable_retries_arg};
511 handshaker->channel = grpc_channel_create(
512 next_args->handshaker->handshaker_service_url, creds, &args);
513 grpc_channel_credentials_release(creds);
514 tsi_result continue_next_result =
515 alts_tsi_handshaker_continue_handshaker_next(
516 handshaker, next_args->received_bytes.get(),
517 next_args->received_bytes_size, next_args->cb, next_args->user_data,
518 next_args->error);
519 if (continue_next_result != TSI_OK) {
520 next_args->cb(continue_next_result, next_args->user_data, nullptr, 0,
521 nullptr);
522 }
523 delete next_args;
524 }
525
handshaker_next(tsi_handshaker * self,const unsigned char * received_bytes,size_t received_bytes_size,const unsigned char **,size_t *,tsi_handshaker_result **,tsi_handshaker_on_next_done_cb cb,void * user_data,std::string * error)526 static tsi_result handshaker_next(
527 tsi_handshaker* self, const unsigned char* received_bytes,
528 size_t received_bytes_size, const unsigned char** /*bytes_to_send*/,
529 size_t* /*bytes_to_send_size*/, tsi_handshaker_result** /*result*/,
530 tsi_handshaker_on_next_done_cb cb, void* user_data, std::string* error) {
531 if (self == nullptr || cb == nullptr) {
532 gpr_log(GPR_ERROR, "Invalid arguments to handshaker_next()");
533 if (error != nullptr) *error = "invalid argument";
534 return TSI_INVALID_ARGUMENT;
535 }
536 alts_tsi_handshaker* handshaker =
537 reinterpret_cast<alts_tsi_handshaker*>(self);
538 {
539 grpc_core::MutexLock lock(&handshaker->mu);
540 if (handshaker->shutdown) {
541 gpr_log(GPR_INFO, "TSI handshake shutdown");
542 if (error != nullptr) *error = "handshake shutdown";
543 return TSI_HANDSHAKE_SHUTDOWN;
544 }
545 }
546 if (handshaker->channel == nullptr && !handshaker->use_dedicated_cq) {
547 alts_tsi_handshaker_continue_handshaker_next_args* args =
548 new alts_tsi_handshaker_continue_handshaker_next_args();
549 args->handshaker = handshaker;
550 args->received_bytes = nullptr;
551 args->received_bytes_size = received_bytes_size;
552 args->error = error;
553 if (received_bytes_size > 0) {
554 args->received_bytes = std::unique_ptr<unsigned char>(
555 static_cast<unsigned char*>(gpr_zalloc(received_bytes_size)));
556 memcpy(args->received_bytes.get(), received_bytes, received_bytes_size);
557 }
558 args->cb = cb;
559 args->user_data = user_data;
560 GRPC_CLOSURE_INIT(&args->closure, alts_tsi_handshaker_create_channel, args,
561 grpc_schedule_on_exec_ctx);
562 // We continue this handshaker_next call at the bottom of the ExecCtx just
563 // so that we can invoke grpc_channel_create at the bottom of the call
564 // stack. Doing so avoids potential lock cycles between g_init_mu and other
565 // mutexes within core that might be held on the current call stack
566 // (note that g_init_mu gets acquired during channel creation).
567 grpc_core::ExecCtx::Run(DEBUG_LOCATION, &args->closure, absl::OkStatus());
568 } else {
569 tsi_result ok = alts_tsi_handshaker_continue_handshaker_next(
570 handshaker, received_bytes, received_bytes_size, cb, user_data, error);
571 if (ok != TSI_OK) {
572 gpr_log(GPR_ERROR, "Failed to schedule ALTS handshaker requests");
573 return ok;
574 }
575 }
576 return TSI_ASYNC;
577 }
578
579 //
580 // This API will be invoked by a non-gRPC application, and an ExecCtx needs
581 // to be explicitly created in order to invoke ALTS handshaker client API's
582 // that assumes the caller is inside gRPC core.
583 //
handshaker_next_dedicated(tsi_handshaker * self,const unsigned char * received_bytes,size_t received_bytes_size,const unsigned char ** bytes_to_send,size_t * bytes_to_send_size,tsi_handshaker_result ** result,tsi_handshaker_on_next_done_cb cb,void * user_data,std::string * error)584 static tsi_result handshaker_next_dedicated(
585 tsi_handshaker* self, const unsigned char* received_bytes,
586 size_t received_bytes_size, const unsigned char** bytes_to_send,
587 size_t* bytes_to_send_size, tsi_handshaker_result** result,
588 tsi_handshaker_on_next_done_cb cb, void* user_data, std::string* error) {
589 grpc_core::ExecCtx exec_ctx;
590 return handshaker_next(self, received_bytes, received_bytes_size,
591 bytes_to_send, bytes_to_send_size, result, cb,
592 user_data, error);
593 }
594
handshaker_shutdown(tsi_handshaker * self)595 static void handshaker_shutdown(tsi_handshaker* self) {
596 GPR_ASSERT(self != nullptr);
597 alts_tsi_handshaker* handshaker =
598 reinterpret_cast<alts_tsi_handshaker*>(self);
599 grpc_core::MutexLock lock(&handshaker->mu);
600 if (handshaker->shutdown) {
601 return;
602 }
603 if (handshaker->client != nullptr) {
604 alts_handshaker_client_shutdown(handshaker->client);
605 }
606 handshaker->shutdown = true;
607 }
608
handshaker_destroy(tsi_handshaker * self)609 static void handshaker_destroy(tsi_handshaker* self) {
610 if (self == nullptr) {
611 return;
612 }
613 alts_tsi_handshaker* handshaker =
614 reinterpret_cast<alts_tsi_handshaker*>(self);
615 alts_handshaker_client_destroy(handshaker->client);
616 grpc_core::CSliceUnref(handshaker->target_name);
617 grpc_alts_credentials_options_destroy(handshaker->options);
618 if (handshaker->channel != nullptr) {
619 grpc_channel_destroy_internal(handshaker->channel);
620 }
621 gpr_free(handshaker->handshaker_service_url);
622 delete handshaker;
623 }
624
// Handshaker vtables. Entries are positional and must follow the field order
// of tsi_handshaker_vtable (declared elsewhere). Only the final three entries
// are populated, with the destroy, next, and shutdown implementations from
// this file; the remaining entries are null.
//
// handshaker_vtable is installed when the caller supplies interested_parties.
static const tsi_handshaker_vtable handshaker_vtable = {
    nullptr,         nullptr,
    nullptr,         nullptr,
    nullptr,         handshaker_destroy,
    handshaker_next, handshaker_shutdown};

// handshaker_vtable_dedicated is installed when no interested_parties are
// supplied (dedicated-CQ mode). It differs only in its next entry, which
// creates an ExecCtx before delegating to handshaker_next.
static const tsi_handshaker_vtable handshaker_vtable_dedicated = {
    nullptr,
    nullptr,
    nullptr,
    nullptr,
    nullptr,
    handshaker_destroy,
    handshaker_next_dedicated,
    handshaker_shutdown};
640
alts_tsi_handshaker_has_shutdown(alts_tsi_handshaker * handshaker)641 bool alts_tsi_handshaker_has_shutdown(alts_tsi_handshaker* handshaker) {
642 GPR_ASSERT(handshaker != nullptr);
643 grpc_core::MutexLock lock(&handshaker->mu);
644 return handshaker->shutdown;
645 }
646
alts_tsi_handshaker_create(const grpc_alts_credentials_options * options,const char * target_name,const char * handshaker_service_url,bool is_client,grpc_pollset_set * interested_parties,tsi_handshaker ** self,size_t user_specified_max_frame_size)647 tsi_result alts_tsi_handshaker_create(
648 const grpc_alts_credentials_options* options, const char* target_name,
649 const char* handshaker_service_url, bool is_client,
650 grpc_pollset_set* interested_parties, tsi_handshaker** self,
651 size_t user_specified_max_frame_size) {
652 if (handshaker_service_url == nullptr || self == nullptr ||
653 options == nullptr || (is_client && target_name == nullptr)) {
654 gpr_log(GPR_ERROR, "Invalid arguments to alts_tsi_handshaker_create()");
655 return TSI_INVALID_ARGUMENT;
656 }
657 bool use_dedicated_cq = interested_parties == nullptr;
658 alts_tsi_handshaker* handshaker = new alts_tsi_handshaker();
659 memset(&handshaker->base, 0, sizeof(handshaker->base));
660 handshaker->base.vtable =
661 use_dedicated_cq ? &handshaker_vtable_dedicated : &handshaker_vtable;
662 handshaker->target_name = target_name == nullptr
663 ? grpc_empty_slice()
664 : grpc_slice_from_static_string(target_name);
665 handshaker->is_client = is_client;
666 handshaker->handshaker_service_url = gpr_strdup(handshaker_service_url);
667 handshaker->interested_parties = interested_parties;
668 handshaker->options = grpc_alts_credentials_options_copy(options);
669 handshaker->use_dedicated_cq = use_dedicated_cq;
670 handshaker->max_frame_size = user_specified_max_frame_size != 0
671 ? user_specified_max_frame_size
672 : kTsiAltsMaxFrameSize;
673 *self = &handshaker->base;
674 return TSI_OK;
675 }
676
alts_tsi_handshaker_result_set_unused_bytes(tsi_handshaker_result * result,grpc_slice * recv_bytes,size_t bytes_consumed)677 void alts_tsi_handshaker_result_set_unused_bytes(tsi_handshaker_result* result,
678 grpc_slice* recv_bytes,
679 size_t bytes_consumed) {
680 GPR_ASSERT(recv_bytes != nullptr && result != nullptr);
681 if (GRPC_SLICE_LENGTH(*recv_bytes) == bytes_consumed) {
682 return;
683 }
684 alts_tsi_handshaker_result* sresult =
685 reinterpret_cast<alts_tsi_handshaker_result*>(result);
686 sresult->unused_bytes_size = GRPC_SLICE_LENGTH(*recv_bytes) - bytes_consumed;
687 sresult->unused_bytes =
688 static_cast<unsigned char*>(gpr_zalloc(sresult->unused_bytes_size));
689 memcpy(sresult->unused_bytes,
690 GRPC_SLICE_START_PTR(*recv_bytes) + bytes_consumed,
691 sresult->unused_bytes_size);
692 }
693
694 namespace grpc_core {
695 namespace internal {
696
alts_tsi_handshaker_get_has_sent_start_message_for_testing(alts_tsi_handshaker * handshaker)697 bool alts_tsi_handshaker_get_has_sent_start_message_for_testing(
698 alts_tsi_handshaker* handshaker) {
699 GPR_ASSERT(handshaker != nullptr);
700 return handshaker->has_sent_start_message;
701 }
702
alts_tsi_handshaker_set_client_vtable_for_testing(alts_tsi_handshaker * handshaker,alts_handshaker_client_vtable * vtable)703 void alts_tsi_handshaker_set_client_vtable_for_testing(
704 alts_tsi_handshaker* handshaker, alts_handshaker_client_vtable* vtable) {
705 GPR_ASSERT(handshaker != nullptr);
706 handshaker->client_vtable_for_testing = vtable;
707 }
708
alts_tsi_handshaker_get_is_client_for_testing(alts_tsi_handshaker * handshaker)709 bool alts_tsi_handshaker_get_is_client_for_testing(
710 alts_tsi_handshaker* handshaker) {
711 GPR_ASSERT(handshaker != nullptr);
712 return handshaker->is_client;
713 }
714
alts_tsi_handshaker_get_client_for_testing(alts_tsi_handshaker * handshaker)715 alts_handshaker_client* alts_tsi_handshaker_get_client_for_testing(
716 alts_tsi_handshaker* handshaker) {
717 return handshaker->client;
718 }
719
720 } // namespace internal
721 } // namespace grpc_core
722