1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include <grpc/support/port_platform.h>
20
21 #include "src/core/tsi/alts/handshaker/alts_handshaker_client.h"
22
23 #include <list>
24
25 #include "absl/strings/numbers.h"
26 #include "upb/upb.hpp"
27
28 #include <grpc/byte_buffer.h>
29 #include <grpc/support/alloc.h>
30 #include <grpc/support/log.h>
31
32 #include "src/core/lib/gprpp/crash.h"
33 #include "src/core/lib/gprpp/env.h"
34 #include "src/core/lib/gprpp/sync.h"
35 #include "src/core/lib/slice/slice_internal.h"
36 #include "src/core/lib/surface/call.h"
37 #include "src/core/lib/surface/channel.h"
38 #include "src/core/tsi/alts/handshaker/alts_shared_resource.h"
39 #include "src/core/tsi/alts/handshaker/alts_tsi_handshaker_private.h"
40 #include "src/core/tsi/alts/handshaker/alts_tsi_utils.h"
41
// Initial capacity, in bytes, of the client-owned buffer that receives the
// out_frames of handshaker service responses; the buffer doubles on demand.
#define TSI_ALTS_INITIAL_BUFFER_SIZE 256

// Upper bound on the number of grpc_op entries placed in a single batch.
constexpr int kHandshakerClientOpNum = 4;
// Name of the environment variable that (presumably) overrides the limit on
// concurrent ALTS handshakes; its reader is defined elsewhere in the file.
constexpr char kMaxConcurrentStreamsEnvironmentVariable[] =
    "GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES";
47
// Base type shared by handshaker client implementations. A concrete client
// embeds this struct as its first member, which is what allows this file to
// reinterpret_cast between alts_handshaker_client* and the concrete type.
struct alts_handshaker_client {
  const alts_handshaker_client_vtable* vtable;
};

// Arguments captured from one handshaker service response, parked until they
// can be handed to the TSI next callback (see maybe_complete_tsi_next()).
struct recv_message_result {
  tsi_result status;
  // Either nullptr or a pointer into the client's own `buffer`; not owned here.
  const unsigned char* bytes_to_send;
  size_t bytes_to_send_size;
  // Non-null only when the response carried a completed handshake result.
  tsi_handshaker_result* result;
};
58
// gRPC-backed handshaker client: drives the handshake RPC to the ALTS
// handshaker service on behalf of one TSI handshaker.
typedef struct alts_grpc_handshaker_client {
  // Must stay the first member: the file converts between
  // alts_handshaker_client* and alts_grpc_handshaker_client* with
  // reinterpret_cast.
  alts_handshaker_client base;
  // One ref is held by the entity that created this handshaker_client, and
  // another ref is held by the pending RECEIVE_STATUS_ON_CLIENT op.
  gpr_refcount refs;
  // TSI handshaker served by this client; checked for shutdown whenever a
  // response from the handshaker service is processed.
  alts_tsi_handshaker* handshaker;
  // Handshake RPC to the handshaker service; nullptr when the client was
  // created with ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING.
  grpc_call* call;
  // A pointer to a function handling the interaction with handshaker service.
  // That is, it points to grpc_call_start_batch_and_execute when the handshaker
  // client is used in a non-testing use case and points to a custom function
  // that validates the data to be sent to handshaker service in a testing use
  // case.
  alts_grpc_caller grpc_caller;
  // A gRPC closure to be scheduled when the response from handshaker service
  // is received. It will be initialized with the injected grpc RPC callback.
  grpc_closure on_handshaker_service_resp_recv;
  // Buffers containing information to be sent (or received) to (or from) the
  // handshaker service.
  grpc_byte_buffer* send_buffer = nullptr;
  grpc_byte_buffer* recv_buffer = nullptr;
  // Used to inject a read failure from tests.
  bool inject_read_failure = false;
  // Initial metadata to be received from handshaker service.
  grpc_metadata_array recv_initial_metadata;
  // A callback function provided by an application to be invoked when response
  // is received from handshaker service.
  tsi_handshaker_on_next_done_cb cb;
  void* user_data;
  // Owned copy of the ALTS credential options passed in from the caller.
  grpc_alts_credentials_options* options;
  // Owned copy of the target name, passed to the handshaker service in the
  // client_start request for the server authorization check.
  grpc_slice target_name;
  // Whether the handshaker client is used at the client (is_client = true)
  // or server (is_client = false) side.
  bool is_client;
  // Ref to the bytes most recently received from the peer; used to attach the
  // unconsumed bytes to the final handshake result.
  grpc_slice recv_bytes;
  // Client-owned buffer holding the out_frames of the latest response, i.e.
  // the data to be sent to the grpc client or server's peer. Its capacity is
  // buffer_size and doubles on demand.
  unsigned char* buffer;
  size_t buffer_size;
  /// callback for receiving handshake call status
  grpc_closure on_status_received;
  /// gRPC status code of handshake call
  grpc_status_code handshake_status_code = GRPC_STATUS_OK;
  /// gRPC status details of handshake call
  grpc_slice handshake_status_details;
  // mu guards receive_status_finished and pending_recv_message_result (both
  // are only touched under it, in maybe_complete_tsi_next()).
  // NOTE(review): max_frame_size and error below are NOT accessed under mu
  // in this file, despite appearing after it.
  grpc_core::Mutex mu;
  // indicates if the handshaker call's RECV_STATUS_ON_CLIENT op is done.
  bool receive_status_finished = false;
  // if non-null, contains arguments to complete a TSI next callback; a result
  // that ends the handshake waits here until receive_status_finished is set.
  recv_message_result* pending_recv_message_result = nullptr;
  // Maximum frame size used by frame protector; also sent to the handshaker
  // service in the client_start/server_start requests.
  size_t max_frame_size;
  // If non-null, will be populated with an error string upon error (set, and
  // possibly cleared to empty, on every processed response). Not owned.
  std::string* error;
} alts_grpc_handshaker_client;
118
handshaker_client_send_buffer_destroy(alts_grpc_handshaker_client * client)119 static void handshaker_client_send_buffer_destroy(
120 alts_grpc_handshaker_client* client) {
121 GPR_ASSERT(client != nullptr);
122 grpc_byte_buffer_destroy(client->send_buffer);
123 client->send_buffer = nullptr;
124 }
125
is_handshake_finished_properly(grpc_gcp_HandshakerResp * resp)126 static bool is_handshake_finished_properly(grpc_gcp_HandshakerResp* resp) {
127 GPR_ASSERT(resp != nullptr);
128 return grpc_gcp_HandshakerResp_result(resp) != nullptr;
129 }
130
alts_grpc_handshaker_client_unref(alts_grpc_handshaker_client * client)131 static void alts_grpc_handshaker_client_unref(
132 alts_grpc_handshaker_client* client) {
133 if (gpr_unref(&client->refs)) {
134 if (client->base.vtable != nullptr &&
135 client->base.vtable->destruct != nullptr) {
136 client->base.vtable->destruct(&client->base);
137 }
138 grpc_byte_buffer_destroy(client->send_buffer);
139 grpc_byte_buffer_destroy(client->recv_buffer);
140 client->send_buffer = nullptr;
141 client->recv_buffer = nullptr;
142 grpc_metadata_array_destroy(&client->recv_initial_metadata);
143 grpc_core::CSliceUnref(client->recv_bytes);
144 grpc_core::CSliceUnref(client->target_name);
145 grpc_alts_credentials_options_destroy(client->options);
146 gpr_free(client->buffer);
147 grpc_core::CSliceUnref(client->handshake_status_details);
148 delete client;
149 }
150 }
151
maybe_complete_tsi_next(alts_grpc_handshaker_client * client,bool receive_status_finished,recv_message_result * pending_recv_message_result)152 static void maybe_complete_tsi_next(
153 alts_grpc_handshaker_client* client, bool receive_status_finished,
154 recv_message_result* pending_recv_message_result) {
155 recv_message_result* r;
156 {
157 grpc_core::MutexLock lock(&client->mu);
158 client->receive_status_finished |= receive_status_finished;
159 if (pending_recv_message_result != nullptr) {
160 GPR_ASSERT(client->pending_recv_message_result == nullptr);
161 client->pending_recv_message_result = pending_recv_message_result;
162 }
163 if (client->pending_recv_message_result == nullptr) {
164 return;
165 }
166 const bool have_final_result =
167 client->pending_recv_message_result->result != nullptr ||
168 client->pending_recv_message_result->status != TSI_OK;
169 if (have_final_result && !client->receive_status_finished) {
170 // If we've received the final message from the handshake
171 // server, or we're about to invoke the TSI next callback
172 // with a status other than TSI_OK (which terminates the
173 // handshake), then first wait for the RECV_STATUS op to complete.
174 return;
175 }
176 r = client->pending_recv_message_result;
177 client->pending_recv_message_result = nullptr;
178 }
179 client->cb(r->status, client->user_data, r->bytes_to_send,
180 r->bytes_to_send_size, r->result);
181 gpr_free(r);
182 }
183
handle_response_done(alts_grpc_handshaker_client * client,tsi_result status,std::string error,const unsigned char * bytes_to_send,size_t bytes_to_send_size,tsi_handshaker_result * result)184 static void handle_response_done(alts_grpc_handshaker_client* client,
185 tsi_result status, std::string error,
186 const unsigned char* bytes_to_send,
187 size_t bytes_to_send_size,
188 tsi_handshaker_result* result) {
189 if (client->error != nullptr) *client->error = std::move(error);
190 recv_message_result* p = grpc_core::Zalloc<recv_message_result>();
191 p->status = status;
192 p->bytes_to_send = bytes_to_send;
193 p->bytes_to_send_size = bytes_to_send_size;
194 p->result = result;
195 maybe_complete_tsi_next(client, false /* receive_status_finished */,
196 p /* pending_recv_message_result */);
197 }
198
// Processes the handshaker service's reply to the most recent message batch.
// The result is reported through handle_response_done(), which eventually
// drives the TSI next callback via maybe_complete_tsi_next().
//
// `c` must point at an alts_grpc_handshaker_client (it is reinterpret_cast).
// `is_ok == false` is treated as a failed read of the response.
//
// Apart from the `client->cb == nullptr` check, which only logs and returns,
// every path calls handle_response_done() exactly once with:
//   - TSI_INTERNAL_ERROR: no handshaker, failed (or test-injected) read, or
//     no recv_buffer;
//   - TSI_HANDSHAKE_SHUTDOWN: the TSI handshaker has been shut down;
//   - TSI_DATA_CORRUPTED: the response does not deserialize or has no status;
//   - the status of alts_tsi_handshaker_result_create() if it fails;
//   - otherwise the service's status code converted to a tsi_result, plus any
//     out_frames to forward to the peer and, if the handshake finished, the
//     newly created tsi_handshaker_result.
void alts_handshaker_client_handle_response(alts_handshaker_client* c,
                                            bool is_ok) {
  GPR_ASSERT(c != nullptr);
  alts_grpc_handshaker_client* client =
      reinterpret_cast<alts_grpc_handshaker_client*>(c);
  grpc_byte_buffer* recv_buffer = client->recv_buffer;
  alts_tsi_handshaker* handshaker = client->handshaker;
  // Invalid input check. Without a callback there is nobody to report to.
  if (client->cb == nullptr) {
    gpr_log(GPR_ERROR,
            "client->cb is nullptr in alts_tsi_handshaker_handle_response()");
    return;
  }
  if (handshaker == nullptr) {
    gpr_log(GPR_ERROR,
            "handshaker is nullptr in alts_tsi_handshaker_handle_response()");
    handle_response_done(
        client, TSI_INTERNAL_ERROR,
        "handshaker is nullptr in alts_tsi_handshaker_handle_response()",
        nullptr, 0, nullptr);
    return;
  }
  // TSI handshake has been shutdown.
  if (alts_tsi_handshaker_has_shutdown(handshaker)) {
    gpr_log(GPR_INFO, "TSI handshake shutdown");
    handle_response_done(client, TSI_HANDSHAKE_SHUTDOWN,
                         "TSI handshake shutdown", nullptr, 0, nullptr);
    return;
  }
  // Check for failed grpc read.
  if (!is_ok || client->inject_read_failure) {
    gpr_log(GPR_INFO, "read failed on grpc call to handshaker service");
    handle_response_done(client, TSI_INTERNAL_ERROR,
                         "read failed on grpc call to handshaker service",
                         nullptr, 0, nullptr);
    return;
  }
  if (recv_buffer == nullptr) {
    gpr_log(GPR_ERROR,
            "recv_buffer is nullptr in alts_tsi_handshaker_handle_response()");
    handle_response_done(
        client, TSI_INTERNAL_ERROR,
        "recv_buffer is nullptr in alts_tsi_handshaker_handle_response()",
        nullptr, 0, nullptr);
    return;
  }
  // `resp` is presumably allocated from `arena`, so it (and views into it such
  // as out_frames) must not be used after this function returns.
  upb::Arena arena;
  grpc_gcp_HandshakerResp* resp =
      alts_tsi_utils_deserialize_response(recv_buffer, arena.ptr());
  // The raw response is no longer needed once it has been parsed.
  grpc_byte_buffer_destroy(client->recv_buffer);
  client->recv_buffer = nullptr;
  // Invalid handshaker response check.
  if (resp == nullptr) {
    gpr_log(GPR_ERROR, "alts_tsi_utils_deserialize_response() failed");
    handle_response_done(client, TSI_DATA_CORRUPTED,
                         "alts_tsi_utils_deserialize_response() failed",
                         nullptr, 0, nullptr);
    return;
  }
  const grpc_gcp_HandshakerStatus* resp_status =
      grpc_gcp_HandshakerResp_status(resp);
  if (resp_status == nullptr) {
    gpr_log(GPR_ERROR, "No status in HandshakerResp");
    handle_response_done(client, TSI_DATA_CORRUPTED,
                         "No status in HandshakerResp", nullptr, 0, nullptr);
    return;
  }
  // Copy the frames destined for the peer into the client-owned buffer,
  // doubling its capacity until they fit, so they outlive `arena`.
  upb_StringView out_frames = grpc_gcp_HandshakerResp_out_frames(resp);
  unsigned char* bytes_to_send = nullptr;
  size_t bytes_to_send_size = 0;
  if (out_frames.size > 0) {
    bytes_to_send_size = out_frames.size;
    while (bytes_to_send_size > client->buffer_size) {
      client->buffer_size *= 2;
      client->buffer = static_cast<unsigned char*>(
          gpr_realloc(client->buffer, client->buffer_size));
    }
    memcpy(client->buffer, out_frames.data, bytes_to_send_size);
    bytes_to_send = client->buffer;
  }
  // Build the final handshake result if the service reports completion, and
  // attach the peer bytes the service did not consume.
  tsi_handshaker_result* result = nullptr;
  if (is_handshake_finished_properly(resp)) {
    tsi_result status =
        alts_tsi_handshaker_result_create(resp, client->is_client, &result);
    if (status != TSI_OK) {
      gpr_log(GPR_ERROR, "alts_tsi_handshaker_result_create() failed");
      handle_response_done(client, status,
                           "alts_tsi_handshaker_result_create() failed",
                           nullptr, 0, nullptr);
      return;
    }
    alts_tsi_handshaker_result_set_unused_bytes(
        result, &client->recv_bytes,
        grpc_gcp_HandshakerResp_bytes_consumed(resp));
  }
  grpc_status_code code = static_cast<grpc_status_code>(
      grpc_gcp_HandshakerStatus_code(resp_status));
  // Only a non-OK code with non-empty details yields an error message; a
  // non-OK code without details leaves `error` empty.
  std::string error;
  if (code != GRPC_STATUS_OK) {
    upb_StringView details = grpc_gcp_HandshakerStatus_details(resp_status);
    if (details.size > 0) {
      error = absl::StrCat("Status ", code, " from handshaker service: ",
                           absl::string_view(details.data, details.size));
      gpr_log(GPR_ERROR, "%s", error.c_str());
    }
  }
  // TODO(apolcyn): consider short circuiting handle_response_done and
  // invoking the TSI callback directly if we aren't done yet, if
  // handle_response_done's allocation per message received causes
  // a performance issue.
  handle_response_done(client, alts_tsi_utils_convert_to_tsi_result(code),
                       std::move(error), bytes_to_send, bytes_to_send_size,
                       result);
}
313
continue_make_grpc_call(alts_grpc_handshaker_client * client,bool is_start)314 static tsi_result continue_make_grpc_call(alts_grpc_handshaker_client* client,
315 bool is_start) {
316 GPR_ASSERT(client != nullptr);
317 grpc_op ops[kHandshakerClientOpNum];
318 memset(ops, 0, sizeof(ops));
319 grpc_op* op = ops;
320 if (is_start) {
321 op->op = GRPC_OP_RECV_STATUS_ON_CLIENT;
322 op->data.recv_status_on_client.trailing_metadata = nullptr;
323 op->data.recv_status_on_client.status = &client->handshake_status_code;
324 op->data.recv_status_on_client.status_details =
325 &client->handshake_status_details;
326 op->flags = 0;
327 op->reserved = nullptr;
328 op++;
329 GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
330 gpr_ref(&client->refs);
331 grpc_call_error call_error =
332 client->grpc_caller(client->call, ops, static_cast<size_t>(op - ops),
333 &client->on_status_received);
334 // TODO(apolcyn): return the error here instead, as done for other ops?
335 GPR_ASSERT(call_error == GRPC_CALL_OK);
336 memset(ops, 0, sizeof(ops));
337 op = ops;
338 op->op = GRPC_OP_SEND_INITIAL_METADATA;
339 op->data.send_initial_metadata.count = 0;
340 op++;
341 GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
342 op->op = GRPC_OP_RECV_INITIAL_METADATA;
343 op->data.recv_initial_metadata.recv_initial_metadata =
344 &client->recv_initial_metadata;
345 op++;
346 GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
347 }
348 op->op = GRPC_OP_SEND_MESSAGE;
349 op->data.send_message.send_message = client->send_buffer;
350 op++;
351 GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
352 op->op = GRPC_OP_RECV_MESSAGE;
353 op->data.recv_message.recv_message = &client->recv_buffer;
354 op++;
355 GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
356 GPR_ASSERT(client->grpc_caller != nullptr);
357 if (client->grpc_caller(client->call, ops, static_cast<size_t>(op - ops),
358 &client->on_handshaker_service_resp_recv) !=
359 GRPC_CALL_OK) {
360 gpr_log(GPR_ERROR, "Start batch operation failed");
361 return TSI_INTERNAL_ERROR;
362 }
363 return TSI_OK;
364 }
365
366 // TODO(apolcyn): remove this global queue when we can safely rely
367 // on a MAX_CONCURRENT_STREAMS setting in the ALTS handshake server to
368 // limit the number of concurrent handshakes.
369 namespace {
370
371 class HandshakeQueue {
372 public:
HandshakeQueue(size_t max_outstanding_handshakes)373 explicit HandshakeQueue(size_t max_outstanding_handshakes)
374 : max_outstanding_handshakes_(max_outstanding_handshakes) {}
375
RequestHandshake(alts_grpc_handshaker_client * client)376 void RequestHandshake(alts_grpc_handshaker_client* client) {
377 {
378 grpc_core::MutexLock lock(&mu_);
379 if (outstanding_handshakes_ == max_outstanding_handshakes_) {
380 // Max number already running, add to queue.
381 queued_handshakes_.push_back(client);
382 return;
383 }
384 // Start the handshake immediately.
385 ++outstanding_handshakes_;
386 }
387 continue_make_grpc_call(client, true /* is_start */);
388 }
389
HandshakeDone()390 void HandshakeDone() {
391 alts_grpc_handshaker_client* client = nullptr;
392 {
393 grpc_core::MutexLock lock(&mu_);
394 if (queued_handshakes_.empty()) {
395 // Nothing more in queue. Decrement count and return immediately.
396 --outstanding_handshakes_;
397 return;
398 }
399 // Remove next entry from queue and start the handshake.
400 client = queued_handshakes_.front();
401 queued_handshakes_.pop_front();
402 }
403 continue_make_grpc_call(client, true /* is_start */);
404 }
405
406 private:
407 grpc_core::Mutex mu_;
408 std::list<alts_grpc_handshaker_client*> queued_handshakes_;
409 size_t outstanding_handshakes_ = 0;
410 const size_t max_outstanding_handshakes_;
411 };
412
413 gpr_once g_queued_handshakes_init = GPR_ONCE_INIT;
414 // Using separate queues for client and server handshakes is a
415 // hack that's mainly intended to satisfy the alts_concurrent_connectivity_test,
416 // which runs many concurrent handshakes where both endpoints
417 // are in the same process; this situation is problematic with a
418 // single queue because we have a high chance of using up all outstanding
419 // slots in the queue, such that there aren't any
420 // mutual client/server handshakes outstanding at the same time and
421 // able to make progress.
422 HandshakeQueue* g_client_handshake_queue;
423 HandshakeQueue* g_server_handshake_queue;
424
DoHandshakeQueuesInit(void)425 void DoHandshakeQueuesInit(void) {
426 const size_t per_queue_max_outstanding_handshakes =
427 MaxNumberOfConcurrentHandshakes();
428 g_client_handshake_queue =
429 new HandshakeQueue(per_queue_max_outstanding_handshakes);
430 g_server_handshake_queue =
431 new HandshakeQueue(per_queue_max_outstanding_handshakes);
432 }
433
RequestHandshake(alts_grpc_handshaker_client * client,bool is_client)434 void RequestHandshake(alts_grpc_handshaker_client* client, bool is_client) {
435 gpr_once_init(&g_queued_handshakes_init, DoHandshakeQueuesInit);
436 HandshakeQueue* queue =
437 is_client ? g_client_handshake_queue : g_server_handshake_queue;
438 queue->RequestHandshake(client);
439 }
440
HandshakeDone(bool is_client)441 void HandshakeDone(bool is_client) {
442 HandshakeQueue* queue =
443 is_client ? g_client_handshake_queue : g_server_handshake_queue;
444 queue->HandshakeDone();
445 }
446
447 }; // namespace
448
449 ///
450 /// Populate grpc operation data with the fields of ALTS handshaker client and
451 /// make a grpc call.
452 ///
make_grpc_call(alts_handshaker_client * c,bool is_start)453 static tsi_result make_grpc_call(alts_handshaker_client* c, bool is_start) {
454 GPR_ASSERT(c != nullptr);
455 alts_grpc_handshaker_client* client =
456 reinterpret_cast<alts_grpc_handshaker_client*>(c);
457 if (is_start) {
458 RequestHandshake(client, client->is_client);
459 return TSI_OK;
460 } else {
461 return continue_make_grpc_call(client, is_start);
462 }
463 }
464
on_status_received(void * arg,grpc_error_handle error)465 static void on_status_received(void* arg, grpc_error_handle error) {
466 alts_grpc_handshaker_client* client =
467 static_cast<alts_grpc_handshaker_client*>(arg);
468 if (client->handshake_status_code != GRPC_STATUS_OK) {
469 // TODO(apolcyn): consider overriding the handshake result's
470 // status from the final ALTS message with the status here.
471 char* status_details =
472 grpc_slice_to_c_string(client->handshake_status_details);
473 gpr_log(GPR_INFO,
474 "alts_grpc_handshaker_client:%p on_status_received "
475 "status:%d details:|%s| error:|%s|",
476 client, client->handshake_status_code, status_details,
477 grpc_core::StatusToString(error).c_str());
478 gpr_free(status_details);
479 }
480 maybe_complete_tsi_next(client, true /* receive_status_finished */,
481 nullptr /* pending_recv_message_result */);
482 HandshakeDone(client->is_client);
483 alts_grpc_handshaker_client_unref(client);
484 }
485
486 // Serializes a grpc_gcp_HandshakerReq message into a buffer and returns newly
487 // grpc_byte_buffer holding it.
get_serialized_handshaker_req(grpc_gcp_HandshakerReq * req,upb_Arena * arena)488 static grpc_byte_buffer* get_serialized_handshaker_req(
489 grpc_gcp_HandshakerReq* req, upb_Arena* arena) {
490 size_t buf_length;
491 char* buf = grpc_gcp_HandshakerReq_serialize(req, arena, &buf_length);
492 if (buf == nullptr) {
493 return nullptr;
494 }
495 grpc_slice slice = grpc_slice_from_copied_buffer(buf, buf_length);
496 grpc_byte_buffer* byte_buffer = grpc_raw_byte_buffer_create(&slice, 1);
497 grpc_core::CSliceUnref(slice);
498 return byte_buffer;
499 }
500
501 // Create and populate a client_start handshaker request, then serialize it.
get_serialized_start_client(alts_handshaker_client * c)502 static grpc_byte_buffer* get_serialized_start_client(
503 alts_handshaker_client* c) {
504 GPR_ASSERT(c != nullptr);
505 alts_grpc_handshaker_client* client =
506 reinterpret_cast<alts_grpc_handshaker_client*>(c);
507 upb::Arena arena;
508 grpc_gcp_HandshakerReq* req = grpc_gcp_HandshakerReq_new(arena.ptr());
509 grpc_gcp_StartClientHandshakeReq* start_client =
510 grpc_gcp_HandshakerReq_mutable_client_start(req, arena.ptr());
511 grpc_gcp_StartClientHandshakeReq_set_handshake_security_protocol(
512 start_client, grpc_gcp_ALTS);
513 grpc_gcp_StartClientHandshakeReq_add_application_protocols(
514 start_client, upb_StringView_FromString(ALTS_APPLICATION_PROTOCOL),
515 arena.ptr());
516 grpc_gcp_StartClientHandshakeReq_add_record_protocols(
517 start_client, upb_StringView_FromString(ALTS_RECORD_PROTOCOL),
518 arena.ptr());
519 grpc_gcp_RpcProtocolVersions* client_version =
520 grpc_gcp_StartClientHandshakeReq_mutable_rpc_versions(start_client,
521 arena.ptr());
522 grpc_gcp_RpcProtocolVersions_assign_from_struct(
523 client_version, arena.ptr(), &client->options->rpc_versions);
524 grpc_gcp_StartClientHandshakeReq_set_target_name(
525 start_client, upb_StringView_FromDataAndSize(
526 reinterpret_cast<const char*>(
527 GRPC_SLICE_START_PTR(client->target_name)),
528 GRPC_SLICE_LENGTH(client->target_name)));
529 target_service_account* ptr =
530 (reinterpret_cast<grpc_alts_credentials_client_options*>(client->options))
531 ->target_account_list_head;
532 while (ptr != nullptr) {
533 grpc_gcp_Identity* target_identity =
534 grpc_gcp_StartClientHandshakeReq_add_target_identities(start_client,
535 arena.ptr());
536 grpc_gcp_Identity_set_service_account(target_identity,
537 upb_StringView_FromString(ptr->data));
538 ptr = ptr->next;
539 }
540 grpc_gcp_StartClientHandshakeReq_set_max_frame_size(
541 start_client, static_cast<uint32_t>(client->max_frame_size));
542 return get_serialized_handshaker_req(req, arena.ptr());
543 }
544
handshaker_client_start_client(alts_handshaker_client * c)545 static tsi_result handshaker_client_start_client(alts_handshaker_client* c) {
546 if (c == nullptr) {
547 gpr_log(GPR_ERROR, "client is nullptr in handshaker_client_start_client()");
548 return TSI_INVALID_ARGUMENT;
549 }
550 grpc_byte_buffer* buffer = get_serialized_start_client(c);
551 alts_grpc_handshaker_client* client =
552 reinterpret_cast<alts_grpc_handshaker_client*>(c);
553 if (buffer == nullptr) {
554 gpr_log(GPR_ERROR, "get_serialized_start_client() failed");
555 return TSI_INTERNAL_ERROR;
556 }
557 handshaker_client_send_buffer_destroy(client);
558 client->send_buffer = buffer;
559 tsi_result result = make_grpc_call(&client->base, true /* is_start */);
560 if (result != TSI_OK) {
561 gpr_log(GPR_ERROR, "make_grpc_call() failed");
562 }
563 return result;
564 }
565
566 // Create and populate a start_server handshaker request, then serialize it.
get_serialized_start_server(alts_handshaker_client * c,grpc_slice * bytes_received)567 static grpc_byte_buffer* get_serialized_start_server(
568 alts_handshaker_client* c, grpc_slice* bytes_received) {
569 GPR_ASSERT(c != nullptr);
570 GPR_ASSERT(bytes_received != nullptr);
571 alts_grpc_handshaker_client* client =
572 reinterpret_cast<alts_grpc_handshaker_client*>(c);
573
574 upb::Arena arena;
575 grpc_gcp_HandshakerReq* req = grpc_gcp_HandshakerReq_new(arena.ptr());
576
577 grpc_gcp_StartServerHandshakeReq* start_server =
578 grpc_gcp_HandshakerReq_mutable_server_start(req, arena.ptr());
579 grpc_gcp_StartServerHandshakeReq_add_application_protocols(
580 start_server, upb_StringView_FromString(ALTS_APPLICATION_PROTOCOL),
581 arena.ptr());
582 grpc_gcp_ServerHandshakeParameters* value =
583 grpc_gcp_ServerHandshakeParameters_new(arena.ptr());
584 grpc_gcp_ServerHandshakeParameters_add_record_protocols(
585 value, upb_StringView_FromString(ALTS_RECORD_PROTOCOL), arena.ptr());
586 grpc_gcp_StartServerHandshakeReq_handshake_parameters_set(
587 start_server, grpc_gcp_ALTS, value, arena.ptr());
588 grpc_gcp_StartServerHandshakeReq_set_in_bytes(
589 start_server,
590 upb_StringView_FromDataAndSize(
591 reinterpret_cast<const char*>(GRPC_SLICE_START_PTR(*bytes_received)),
592 GRPC_SLICE_LENGTH(*bytes_received)));
593 grpc_gcp_RpcProtocolVersions* server_version =
594 grpc_gcp_StartServerHandshakeReq_mutable_rpc_versions(start_server,
595 arena.ptr());
596 grpc_gcp_RpcProtocolVersions_assign_from_struct(
597 server_version, arena.ptr(), &client->options->rpc_versions);
598 grpc_gcp_StartServerHandshakeReq_set_max_frame_size(
599 start_server, static_cast<uint32_t>(client->max_frame_size));
600 return get_serialized_handshaker_req(req, arena.ptr());
601 }
602
handshaker_client_start_server(alts_handshaker_client * c,grpc_slice * bytes_received)603 static tsi_result handshaker_client_start_server(alts_handshaker_client* c,
604 grpc_slice* bytes_received) {
605 if (c == nullptr || bytes_received == nullptr) {
606 gpr_log(GPR_ERROR, "Invalid arguments to handshaker_client_start_server()");
607 return TSI_INVALID_ARGUMENT;
608 }
609 alts_grpc_handshaker_client* client =
610 reinterpret_cast<alts_grpc_handshaker_client*>(c);
611 grpc_byte_buffer* buffer = get_serialized_start_server(c, bytes_received);
612 if (buffer == nullptr) {
613 gpr_log(GPR_ERROR, "get_serialized_start_server() failed");
614 return TSI_INTERNAL_ERROR;
615 }
616 handshaker_client_send_buffer_destroy(client);
617 client->send_buffer = buffer;
618 tsi_result result = make_grpc_call(&client->base, true /* is_start */);
619 if (result != TSI_OK) {
620 gpr_log(GPR_ERROR, "make_grpc_call() failed");
621 }
622 return result;
623 }
624
625 // Create and populate a next handshaker request, then serialize it.
get_serialized_next(grpc_slice * bytes_received)626 static grpc_byte_buffer* get_serialized_next(grpc_slice* bytes_received) {
627 GPR_ASSERT(bytes_received != nullptr);
628 upb::Arena arena;
629 grpc_gcp_HandshakerReq* req = grpc_gcp_HandshakerReq_new(arena.ptr());
630 grpc_gcp_NextHandshakeMessageReq* next =
631 grpc_gcp_HandshakerReq_mutable_next(req, arena.ptr());
632 grpc_gcp_NextHandshakeMessageReq_set_in_bytes(
633 next,
634 upb_StringView_FromDataAndSize(
635 reinterpret_cast<const char*> GRPC_SLICE_START_PTR(*bytes_received),
636 GRPC_SLICE_LENGTH(*bytes_received)));
637 return get_serialized_handshaker_req(req, arena.ptr());
638 }
639
handshaker_client_next(alts_handshaker_client * c,grpc_slice * bytes_received)640 static tsi_result handshaker_client_next(alts_handshaker_client* c,
641 grpc_slice* bytes_received) {
642 if (c == nullptr || bytes_received == nullptr) {
643 gpr_log(GPR_ERROR, "Invalid arguments to handshaker_client_next()");
644 return TSI_INVALID_ARGUMENT;
645 }
646 alts_grpc_handshaker_client* client =
647 reinterpret_cast<alts_grpc_handshaker_client*>(c);
648 grpc_core::CSliceUnref(client->recv_bytes);
649 client->recv_bytes = grpc_core::CSliceRef(*bytes_received);
650 grpc_byte_buffer* buffer = get_serialized_next(bytes_received);
651 if (buffer == nullptr) {
652 gpr_log(GPR_ERROR, "get_serialized_next() failed");
653 return TSI_INTERNAL_ERROR;
654 }
655 handshaker_client_send_buffer_destroy(client);
656 client->send_buffer = buffer;
657 tsi_result result = make_grpc_call(&client->base, false /* is_start */);
658 if (result != TSI_OK) {
659 gpr_log(GPR_ERROR, "make_grpc_call() failed");
660 }
661 return result;
662 }
663
handshaker_client_shutdown(alts_handshaker_client * c)664 static void handshaker_client_shutdown(alts_handshaker_client* c) {
665 GPR_ASSERT(c != nullptr);
666 alts_grpc_handshaker_client* client =
667 reinterpret_cast<alts_grpc_handshaker_client*>(c);
668 if (client->call != nullptr) {
669 grpc_call_cancel_internal(client->call);
670 }
671 }
672
handshaker_call_unref(void * arg,grpc_error_handle)673 static void handshaker_call_unref(void* arg, grpc_error_handle /* error */) {
674 grpc_call* call = static_cast<grpc_call*>(arg);
675 grpc_call_unref(call);
676 }
677
handshaker_client_destruct(alts_handshaker_client * c)678 static void handshaker_client_destruct(alts_handshaker_client* c) {
679 if (c == nullptr) {
680 return;
681 }
682 alts_grpc_handshaker_client* client =
683 reinterpret_cast<alts_grpc_handshaker_client*>(c);
684 if (client->call != nullptr) {
685 // Throw this grpc_call_unref over to the ExecCtx so that
686 // we invoke it at the bottom of the call stack and
687 // prevent lock inversion problems due to nested ExecCtx flushing.
688 // TODO(apolcyn): we could remove this indirection and call
689 // grpc_call_unref inline if there was an internal variant of
690 // grpc_call_unref that didn't need to flush an ExecCtx.
691 if (grpc_core::ExecCtx::Get() == nullptr) {
692 // Unref handshaker call if there is no exec_ctx, e.g., in the case of
693 // Envoy ALTS transport socket.
694 grpc_call_unref(client->call);
695 } else {
696 // Using existing exec_ctx to unref handshaker call.
697 grpc_core::ExecCtx::Run(
698 DEBUG_LOCATION,
699 GRPC_CLOSURE_CREATE(handshaker_call_unref, client->call,
700 grpc_schedule_on_exec_ctx),
701 absl::OkStatus());
702 }
703 }
704 }
705
// Default vtable for clients built by alts_grpc_handshaker_client_create()
// (unless a test vtable is injected). The entries are positional (aggregate
// initialization), so their order must match the member order of
// alts_handshaker_client_vtable, which is declared elsewhere.
static const alts_handshaker_client_vtable vtable = {
    handshaker_client_start_client, handshaker_client_start_server,
    handshaker_client_next, handshaker_client_shutdown,
    handshaker_client_destruct};
710
alts_grpc_handshaker_client_create(alts_tsi_handshaker * handshaker,grpc_channel * channel,const char * handshaker_service_url,grpc_pollset_set * interested_parties,grpc_alts_credentials_options * options,const grpc_slice & target_name,grpc_iomgr_cb_func grpc_cb,tsi_handshaker_on_next_done_cb cb,void * user_data,alts_handshaker_client_vtable * vtable_for_testing,bool is_client,size_t max_frame_size,std::string * error)711 alts_handshaker_client* alts_grpc_handshaker_client_create(
712 alts_tsi_handshaker* handshaker, grpc_channel* channel,
713 const char* handshaker_service_url, grpc_pollset_set* interested_parties,
714 grpc_alts_credentials_options* options, const grpc_slice& target_name,
715 grpc_iomgr_cb_func grpc_cb, tsi_handshaker_on_next_done_cb cb,
716 void* user_data, alts_handshaker_client_vtable* vtable_for_testing,
717 bool is_client, size_t max_frame_size, std::string* error) {
718 if (channel == nullptr || handshaker_service_url == nullptr) {
719 gpr_log(GPR_ERROR, "Invalid arguments to alts_handshaker_client_create()");
720 return nullptr;
721 }
722 alts_grpc_handshaker_client* client = new alts_grpc_handshaker_client();
723 memset(&client->base, 0, sizeof(client->base));
724 client->base.vtable =
725 vtable_for_testing == nullptr ? &vtable : vtable_for_testing;
726 gpr_ref_init(&client->refs, 1);
727 client->handshaker = handshaker;
728 client->grpc_caller = grpc_call_start_batch_and_execute;
729 grpc_metadata_array_init(&client->recv_initial_metadata);
730 client->cb = cb;
731 client->user_data = user_data;
732 client->options = grpc_alts_credentials_options_copy(options);
733 client->target_name = grpc_slice_copy(target_name);
734 client->is_client = is_client;
735 client->recv_bytes = grpc_empty_slice();
736 client->buffer_size = TSI_ALTS_INITIAL_BUFFER_SIZE;
737 client->buffer = static_cast<unsigned char*>(gpr_zalloc(client->buffer_size));
738 client->handshake_status_details = grpc_empty_slice();
739 client->max_frame_size = max_frame_size;
740 client->error = error;
741 client->call =
742 strcmp(handshaker_service_url, ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING) ==
743 0
744 ? nullptr
745 : grpc_channel_create_pollset_set_call(
746 channel, nullptr, GRPC_PROPAGATE_DEFAULTS, interested_parties,
747 grpc_slice_from_static_string(ALTS_SERVICE_METHOD), nullptr,
748 grpc_core::Timestamp::InfFuture(), nullptr);
749 GRPC_CLOSURE_INIT(&client->on_handshaker_service_resp_recv, grpc_cb, client,
750 grpc_schedule_on_exec_ctx);
751 GRPC_CLOSURE_INIT(&client->on_status_received, on_status_received, client,
752 grpc_schedule_on_exec_ctx);
753 return &client->base;
754 }
755
756 namespace grpc_core {
757 namespace internal {
758
alts_handshaker_client_set_grpc_caller_for_testing(alts_handshaker_client * c,alts_grpc_caller caller)759 void alts_handshaker_client_set_grpc_caller_for_testing(
760 alts_handshaker_client* c, alts_grpc_caller caller) {
761 GPR_ASSERT(c != nullptr && caller != nullptr);
762 alts_grpc_handshaker_client* client =
763 reinterpret_cast<alts_grpc_handshaker_client*>(c);
764 client->grpc_caller = caller;
765 }
766
alts_handshaker_client_get_send_buffer_for_testing(alts_handshaker_client * c)767 grpc_byte_buffer* alts_handshaker_client_get_send_buffer_for_testing(
768 alts_handshaker_client* c) {
769 GPR_ASSERT(c != nullptr);
770 alts_grpc_handshaker_client* client =
771 reinterpret_cast<alts_grpc_handshaker_client*>(c);
772 return client->send_buffer;
773 }
774
alts_handshaker_client_get_recv_buffer_addr_for_testing(alts_handshaker_client * c)775 grpc_byte_buffer** alts_handshaker_client_get_recv_buffer_addr_for_testing(
776 alts_handshaker_client* c) {
777 GPR_ASSERT(c != nullptr);
778 alts_grpc_handshaker_client* client =
779 reinterpret_cast<alts_grpc_handshaker_client*>(c);
780 return &client->recv_buffer;
781 }
782
alts_handshaker_client_get_initial_metadata_for_testing(alts_handshaker_client * c)783 grpc_metadata_array* alts_handshaker_client_get_initial_metadata_for_testing(
784 alts_handshaker_client* c) {
785 GPR_ASSERT(c != nullptr);
786 alts_grpc_handshaker_client* client =
787 reinterpret_cast<alts_grpc_handshaker_client*>(c);
788 return &client->recv_initial_metadata;
789 }
790
alts_handshaker_client_set_recv_bytes_for_testing(alts_handshaker_client * c,grpc_slice * recv_bytes)791 void alts_handshaker_client_set_recv_bytes_for_testing(
792 alts_handshaker_client* c, grpc_slice* recv_bytes) {
793 GPR_ASSERT(c != nullptr);
794 alts_grpc_handshaker_client* client =
795 reinterpret_cast<alts_grpc_handshaker_client*>(c);
796 client->recv_bytes = CSliceRef(*recv_bytes);
797 }
798
alts_handshaker_client_set_fields_for_testing(alts_handshaker_client * c,alts_tsi_handshaker * handshaker,tsi_handshaker_on_next_done_cb cb,void * user_data,grpc_byte_buffer * recv_buffer,bool inject_read_failure)799 void alts_handshaker_client_set_fields_for_testing(
800 alts_handshaker_client* c, alts_tsi_handshaker* handshaker,
801 tsi_handshaker_on_next_done_cb cb, void* user_data,
802 grpc_byte_buffer* recv_buffer, bool inject_read_failure) {
803 GPR_ASSERT(c != nullptr);
804 alts_grpc_handshaker_client* client =
805 reinterpret_cast<alts_grpc_handshaker_client*>(c);
806 client->handshaker = handshaker;
807 client->cb = cb;
808 client->user_data = user_data;
809 client->recv_buffer = recv_buffer;
810 client->inject_read_failure = inject_read_failure;
811 }
812
alts_handshaker_client_check_fields_for_testing(alts_handshaker_client * c,tsi_handshaker_on_next_done_cb cb,void * user_data,bool has_sent_start_message,grpc_slice * recv_bytes)813 void alts_handshaker_client_check_fields_for_testing(
814 alts_handshaker_client* c, tsi_handshaker_on_next_done_cb cb,
815 void* user_data, bool has_sent_start_message, grpc_slice* recv_bytes) {
816 GPR_ASSERT(c != nullptr);
817 alts_grpc_handshaker_client* client =
818 reinterpret_cast<alts_grpc_handshaker_client*>(c);
819 GPR_ASSERT(client->cb == cb);
820 GPR_ASSERT(client->user_data == user_data);
821 if (recv_bytes != nullptr) {
822 GPR_ASSERT(grpc_slice_cmp(client->recv_bytes, *recv_bytes) == 0);
823 }
824 GPR_ASSERT(alts_tsi_handshaker_get_has_sent_start_message_for_testing(
825 client->handshaker) == has_sent_start_message);
826 }
827
alts_handshaker_client_set_vtable_for_testing(alts_handshaker_client * c,alts_handshaker_client_vtable * vtable)828 void alts_handshaker_client_set_vtable_for_testing(
829 alts_handshaker_client* c, alts_handshaker_client_vtable* vtable) {
830 GPR_ASSERT(c != nullptr);
831 GPR_ASSERT(vtable != nullptr);
832 alts_grpc_handshaker_client* client =
833 reinterpret_cast<alts_grpc_handshaker_client*>(c);
834 client->base.vtable = vtable;
835 }
836
alts_handshaker_client_get_handshaker_for_testing(alts_handshaker_client * c)837 alts_tsi_handshaker* alts_handshaker_client_get_handshaker_for_testing(
838 alts_handshaker_client* c) {
839 GPR_ASSERT(c != nullptr);
840 alts_grpc_handshaker_client* client =
841 reinterpret_cast<alts_grpc_handshaker_client*>(c);
842 return client->handshaker;
843 }
844
alts_handshaker_client_set_cb_for_testing(alts_handshaker_client * c,tsi_handshaker_on_next_done_cb cb)845 void alts_handshaker_client_set_cb_for_testing(
846 alts_handshaker_client* c, tsi_handshaker_on_next_done_cb cb) {
847 GPR_ASSERT(c != nullptr);
848 alts_grpc_handshaker_client* client =
849 reinterpret_cast<alts_grpc_handshaker_client*>(c);
850 client->cb = cb;
851 }
852
alts_handshaker_client_get_closure_for_testing(alts_handshaker_client * c)853 grpc_closure* alts_handshaker_client_get_closure_for_testing(
854 alts_handshaker_client* c) {
855 GPR_ASSERT(c != nullptr);
856 alts_grpc_handshaker_client* client =
857 reinterpret_cast<alts_grpc_handshaker_client*>(c);
858 return &client->on_handshaker_service_resp_recv;
859 }
860
alts_handshaker_client_ref_for_testing(alts_handshaker_client * c)861 void alts_handshaker_client_ref_for_testing(alts_handshaker_client* c) {
862 alts_grpc_handshaker_client* client =
863 reinterpret_cast<alts_grpc_handshaker_client*>(c);
864 gpr_ref(&client->refs);
865 }
866
alts_handshaker_client_on_status_received_for_testing(alts_handshaker_client * c,grpc_status_code status,grpc_error_handle error)867 void alts_handshaker_client_on_status_received_for_testing(
868 alts_handshaker_client* c, grpc_status_code status,
869 grpc_error_handle error) {
870 // We first make sure that the handshake queue has been initialized
871 // here because there are tests that use this API that mock out
872 // other parts of the alts_handshaker_client in such a way that the
873 // code path that would normally ensure that the handshake queue
874 // has been initialized isn't taken.
875 gpr_once_init(&g_queued_handshakes_init, DoHandshakeQueuesInit);
876 alts_grpc_handshaker_client* client =
877 reinterpret_cast<alts_grpc_handshaker_client*>(c);
878 client->handshake_status_code = status;
879 client->handshake_status_details = grpc_empty_slice();
880 Closure::Run(DEBUG_LOCATION, &client->on_status_received, error);
881 }
882
883 } // namespace internal
884 } // namespace grpc_core
885
alts_handshaker_client_start_client(alts_handshaker_client * client)886 tsi_result alts_handshaker_client_start_client(alts_handshaker_client* client) {
887 if (client != nullptr && client->vtable != nullptr &&
888 client->vtable->client_start != nullptr) {
889 return client->vtable->client_start(client);
890 }
891 gpr_log(GPR_ERROR,
892 "client or client->vtable has not been initialized properly");
893 return TSI_INVALID_ARGUMENT;
894 }
895
alts_handshaker_client_start_server(alts_handshaker_client * client,grpc_slice * bytes_received)896 tsi_result alts_handshaker_client_start_server(alts_handshaker_client* client,
897 grpc_slice* bytes_received) {
898 if (client != nullptr && client->vtable != nullptr &&
899 client->vtable->server_start != nullptr) {
900 return client->vtable->server_start(client, bytes_received);
901 }
902 gpr_log(GPR_ERROR,
903 "client or client->vtable has not been initialized properly");
904 return TSI_INVALID_ARGUMENT;
905 }
906
alts_handshaker_client_next(alts_handshaker_client * client,grpc_slice * bytes_received)907 tsi_result alts_handshaker_client_next(alts_handshaker_client* client,
908 grpc_slice* bytes_received) {
909 if (client != nullptr && client->vtable != nullptr &&
910 client->vtable->next != nullptr) {
911 return client->vtable->next(client, bytes_received);
912 }
913 gpr_log(GPR_ERROR,
914 "client or client->vtable has not been initialized properly");
915 return TSI_INVALID_ARGUMENT;
916 }
917
alts_handshaker_client_shutdown(alts_handshaker_client * client)918 void alts_handshaker_client_shutdown(alts_handshaker_client* client) {
919 if (client != nullptr && client->vtable != nullptr &&
920 client->vtable->shutdown != nullptr) {
921 client->vtable->shutdown(client);
922 }
923 }
924
alts_handshaker_client_destroy(alts_handshaker_client * c)925 void alts_handshaker_client_destroy(alts_handshaker_client* c) {
926 if (c != nullptr) {
927 alts_grpc_handshaker_client* client =
928 reinterpret_cast<alts_grpc_handshaker_client*>(c);
929 alts_grpc_handshaker_client_unref(client);
930 }
931 }
932
MaxNumberOfConcurrentHandshakes()933 size_t MaxNumberOfConcurrentHandshakes() {
934 size_t max_concurrent_handshakes = 40;
935 absl::optional<std::string> env_var_max_concurrent_handshakes =
936 grpc_core::GetEnv(kMaxConcurrentStreamsEnvironmentVariable);
937 if (env_var_max_concurrent_handshakes.has_value()) {
938 size_t effective_max_concurrent_handshakes = 40;
939 if (absl::SimpleAtoi(*env_var_max_concurrent_handshakes,
940 &effective_max_concurrent_handshakes)) {
941 max_concurrent_handshakes = effective_max_concurrent_handshakes;
942 }
943 }
944 return max_concurrent_handshakes;
945 }
946