1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include <grpc/support/port_platform.h>
20 
21 #include "src/core/tsi/alts/handshaker/alts_handshaker_client.h"
22 
23 #include <list>
24 
25 #include "absl/strings/numbers.h"
26 #include "upb/upb.hpp"
27 
28 #include <grpc/byte_buffer.h>
29 #include <grpc/support/alloc.h>
30 #include <grpc/support/log.h>
31 
32 #include "src/core/lib/gprpp/crash.h"
33 #include "src/core/lib/gprpp/env.h"
34 #include "src/core/lib/gprpp/sync.h"
35 #include "src/core/lib/slice/slice_internal.h"
36 #include "src/core/lib/surface/call.h"
37 #include "src/core/lib/surface/channel.h"
38 #include "src/core/tsi/alts/handshaker/alts_shared_resource.h"
39 #include "src/core/tsi/alts/handshaker/alts_tsi_handshaker_private.h"
40 #include "src/core/tsi/alts/handshaker/alts_tsi_utils.h"
41 
// Starting capacity, in bytes, for a handshaker client's outgoing-frame
// buffer (NOTE(review): the allocation site is outside this section; confirm).
#define TSI_ALTS_INITIAL_BUFFER_SIZE 256

// Size of the grpc_op arrays built by continue_make_grpc_call(). Its largest
// batch (send/recv initial metadata plus send/recv message) uses all four.
const int kHandshakerClientOpNum = 4;
// Name of the environment variable that presumably overrides the per-queue
// limit on concurrent handshakes (read by MaxNumberOfConcurrentHandshakes(),
// which is not visible here -- TODO confirm).
const char kMaxConcurrentStreamsEnvironmentVariable[] =
    "GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES";
47 
// Generic handshaker-client handle holding only the operation table. The
// gRPC-backed client (alts_grpc_handshaker_client) embeds this as its first
// member, `base`, and this file casts between the two types with
// reinterpret_cast.
struct alts_handshaker_client {
  const alts_handshaker_client_vtable* vtable;
};
51 
// Arguments saved for one invocation of the TSI next callback (client->cb).
// Built by handle_response_done() with grpc_core::Zalloc and freed with
// gpr_free() by maybe_complete_tsi_next() after the callback has run.
struct recv_message_result {
  // Status passed to the callback; a status other than TSI_OK terminates the
  // handshake.
  tsi_result status;
  // Bytes to send to the peer (may be nullptr). Not owned: in this file it
  // points into the client's reusable `buffer`.
  const unsigned char* bytes_to_send;
  size_t bytes_to_send_size;
  // Handshaker result; set only when the service reported a finished
  // handshake.
  tsi_handshaker_result* result;
};
58 
59 typedef struct alts_grpc_handshaker_client {
60   alts_handshaker_client base;
61   // One ref is held by the entity that created this handshaker_client, and
62   // another ref is held by the pending RECEIVE_STATUS_ON_CLIENT op.
63   gpr_refcount refs;
64   alts_tsi_handshaker* handshaker;
65   grpc_call* call;
66   // A pointer to a function handling the interaction with handshaker service.
67   // That is, it points to grpc_call_start_batch_and_execute when the handshaker
68   // client is used in a non-testing use case and points to a custom function
69   // that validates the data to be sent to handshaker service in a testing use
70   // case.
71   alts_grpc_caller grpc_caller;
72   // A gRPC closure to be scheduled when the response from handshaker service
73   // is received. It will be initialized with the injected grpc RPC callback.
74   grpc_closure on_handshaker_service_resp_recv;
75   // Buffers containing information to be sent (or received) to (or from) the
76   // handshaker service.
77   grpc_byte_buffer* send_buffer = nullptr;
78   grpc_byte_buffer* recv_buffer = nullptr;
79   // Used to inject a read failure from tests.
80   bool inject_read_failure = false;
81   // Initial metadata to be received from handshaker service.
82   grpc_metadata_array recv_initial_metadata;
83   // A callback function provided by an application to be invoked when response
84   // is received from handshaker service.
85   tsi_handshaker_on_next_done_cb cb;
86   void* user_data;
87   // ALTS credential options passed in from the caller.
88   grpc_alts_credentials_options* options;
89   // target name information to be passed to handshaker service for server
90   // authorization check.
91   grpc_slice target_name;
92   // boolean flag indicating if the handshaker client is used at client
93   // (is_client = true) or server (is_client = false) side.
94   bool is_client;
95   // a temporary store for data received from handshaker service used to extract
96   // unused data.
97   grpc_slice recv_bytes;
98   // a buffer containing data to be sent to the grpc client or server's peer.
99   unsigned char* buffer;
100   size_t buffer_size;
101   /// callback for receiving handshake call status
102   grpc_closure on_status_received;
103   /// gRPC status code of handshake call
104   grpc_status_code handshake_status_code = GRPC_STATUS_OK;
105   /// gRPC status details of handshake call
106   grpc_slice handshake_status_details;
107   // mu synchronizes all fields below including their internal fields.
108   grpc_core::Mutex mu;
109   // indicates if the handshaker call's RECV_STATUS_ON_CLIENT op is done.
110   bool receive_status_finished = false;
111   // if non-null, contains arguments to complete a TSI next callback.
112   recv_message_result* pending_recv_message_result = nullptr;
113   // Maximum frame size used by frame protector.
114   size_t max_frame_size;
115   // If non-null, will be populated with an error string upon error.
116   std::string* error;
117 } alts_grpc_handshaker_client;
118 
handshaker_client_send_buffer_destroy(alts_grpc_handshaker_client * client)119 static void handshaker_client_send_buffer_destroy(
120     alts_grpc_handshaker_client* client) {
121   GPR_ASSERT(client != nullptr);
122   grpc_byte_buffer_destroy(client->send_buffer);
123   client->send_buffer = nullptr;
124 }
125 
is_handshake_finished_properly(grpc_gcp_HandshakerResp * resp)126 static bool is_handshake_finished_properly(grpc_gcp_HandshakerResp* resp) {
127   GPR_ASSERT(resp != nullptr);
128   return grpc_gcp_HandshakerResp_result(resp) != nullptr;
129 }
130 
alts_grpc_handshaker_client_unref(alts_grpc_handshaker_client * client)131 static void alts_grpc_handshaker_client_unref(
132     alts_grpc_handshaker_client* client) {
133   if (gpr_unref(&client->refs)) {
134     if (client->base.vtable != nullptr &&
135         client->base.vtable->destruct != nullptr) {
136       client->base.vtable->destruct(&client->base);
137     }
138     grpc_byte_buffer_destroy(client->send_buffer);
139     grpc_byte_buffer_destroy(client->recv_buffer);
140     client->send_buffer = nullptr;
141     client->recv_buffer = nullptr;
142     grpc_metadata_array_destroy(&client->recv_initial_metadata);
143     grpc_core::CSliceUnref(client->recv_bytes);
144     grpc_core::CSliceUnref(client->target_name);
145     grpc_alts_credentials_options_destroy(client->options);
146     gpr_free(client->buffer);
147     grpc_core::CSliceUnref(client->handshake_status_details);
148     delete client;
149   }
150 }
151 
// Delivers a pending TSI next callback when it is safe to do so.
//
// Within this file it is called from two places:
//   - handle_response_done(), passing a freshly built
//     `pending_recv_message_result`;
//   - on_status_received(), passing `receive_status_finished` == true once
//     the RECV_STATUS_ON_CLIENT op has completed.
//
// Both inputs are merged into the client under `client->mu`. A pending result
// that is final (it carries a handshaker result or a non-TSI_OK status) is
// held back until the RECV_STATUS op has also finished. Any other pending
// result is delivered immediately. The callback itself runs after the lock is
// released, and the recv_message_result is then freed with gpr_free().
static void maybe_complete_tsi_next(
    alts_grpc_handshaker_client* client, bool receive_status_finished,
    recv_message_result* pending_recv_message_result) {
  recv_message_result* r;
  {
    grpc_core::MutexLock lock(&client->mu);
    // Once the status op is seen as finished, the flag stays set.
    client->receive_status_finished |= receive_status_finished;
    if (pending_recv_message_result != nullptr) {
      // At most one result may be waiting for delivery at a time.
      GPR_ASSERT(client->pending_recv_message_result == nullptr);
      client->pending_recv_message_result = pending_recv_message_result;
    }
    if (client->pending_recv_message_result == nullptr) {
      // Nothing to deliver yet (e.g. the status op finished first).
      return;
    }
    const bool have_final_result =
        client->pending_recv_message_result->result != nullptr ||
        client->pending_recv_message_result->status != TSI_OK;
    if (have_final_result && !client->receive_status_finished) {
      // If we've received the final message from the handshake
      // server, or we're about to invoke the TSI next callback
      // with a status other than TSI_OK (which terminates the
      // handshake), then first wait for the RECV_STATUS op to complete.
      return;
    }
    // Take the pending result out of the client so it is delivered once.
    r = client->pending_recv_message_result;
    client->pending_recv_message_result = nullptr;
  }
  // Invoke the application callback without holding `mu`.
  client->cb(r->status, client->user_data, r->bytes_to_send,
             r->bytes_to_send_size, r->result);
  gpr_free(r);
}
183 
handle_response_done(alts_grpc_handshaker_client * client,tsi_result status,std::string error,const unsigned char * bytes_to_send,size_t bytes_to_send_size,tsi_handshaker_result * result)184 static void handle_response_done(alts_grpc_handshaker_client* client,
185                                  tsi_result status, std::string error,
186                                  const unsigned char* bytes_to_send,
187                                  size_t bytes_to_send_size,
188                                  tsi_handshaker_result* result) {
189   if (client->error != nullptr) *client->error = std::move(error);
190   recv_message_result* p = grpc_core::Zalloc<recv_message_result>();
191   p->status = status;
192   p->bytes_to_send = bytes_to_send;
193   p->bytes_to_send_size = bytes_to_send_size;
194   p->result = result;
195   maybe_complete_tsi_next(client, false /* receive_status_finished */,
196                           p /* pending_recv_message_result */);
197 }
198 
alts_handshaker_client_handle_response(alts_handshaker_client * c,bool is_ok)199 void alts_handshaker_client_handle_response(alts_handshaker_client* c,
200                                             bool is_ok) {
201   GPR_ASSERT(c != nullptr);
202   alts_grpc_handshaker_client* client =
203       reinterpret_cast<alts_grpc_handshaker_client*>(c);
204   grpc_byte_buffer* recv_buffer = client->recv_buffer;
205   alts_tsi_handshaker* handshaker = client->handshaker;
206   // Invalid input check.
207   if (client->cb == nullptr) {
208     gpr_log(GPR_ERROR,
209             "client->cb is nullptr in alts_tsi_handshaker_handle_response()");
210     return;
211   }
212   if (handshaker == nullptr) {
213     gpr_log(GPR_ERROR,
214             "handshaker is nullptr in alts_tsi_handshaker_handle_response()");
215     handle_response_done(
216         client, TSI_INTERNAL_ERROR,
217         "handshaker is nullptr in alts_tsi_handshaker_handle_response()",
218         nullptr, 0, nullptr);
219     return;
220   }
221   // TSI handshake has been shutdown.
222   if (alts_tsi_handshaker_has_shutdown(handshaker)) {
223     gpr_log(GPR_INFO, "TSI handshake shutdown");
224     handle_response_done(client, TSI_HANDSHAKE_SHUTDOWN,
225                          "TSI handshake shutdown", nullptr, 0, nullptr);
226     return;
227   }
228   // Check for failed grpc read.
229   if (!is_ok || client->inject_read_failure) {
230     gpr_log(GPR_INFO, "read failed on grpc call to handshaker service");
231     handle_response_done(client, TSI_INTERNAL_ERROR,
232                          "read failed on grpc call to handshaker service",
233                          nullptr, 0, nullptr);
234     return;
235   }
236   if (recv_buffer == nullptr) {
237     gpr_log(GPR_ERROR,
238             "recv_buffer is nullptr in alts_tsi_handshaker_handle_response()");
239     handle_response_done(
240         client, TSI_INTERNAL_ERROR,
241         "recv_buffer is nullptr in alts_tsi_handshaker_handle_response()",
242         nullptr, 0, nullptr);
243     return;
244   }
245   upb::Arena arena;
246   grpc_gcp_HandshakerResp* resp =
247       alts_tsi_utils_deserialize_response(recv_buffer, arena.ptr());
248   grpc_byte_buffer_destroy(client->recv_buffer);
249   client->recv_buffer = nullptr;
250   // Invalid handshaker response check.
251   if (resp == nullptr) {
252     gpr_log(GPR_ERROR, "alts_tsi_utils_deserialize_response() failed");
253     handle_response_done(client, TSI_DATA_CORRUPTED,
254                          "alts_tsi_utils_deserialize_response() failed",
255                          nullptr, 0, nullptr);
256     return;
257   }
258   const grpc_gcp_HandshakerStatus* resp_status =
259       grpc_gcp_HandshakerResp_status(resp);
260   if (resp_status == nullptr) {
261     gpr_log(GPR_ERROR, "No status in HandshakerResp");
262     handle_response_done(client, TSI_DATA_CORRUPTED,
263                          "No status in HandshakerResp", nullptr, 0, nullptr);
264     return;
265   }
266   upb_StringView out_frames = grpc_gcp_HandshakerResp_out_frames(resp);
267   unsigned char* bytes_to_send = nullptr;
268   size_t bytes_to_send_size = 0;
269   if (out_frames.size > 0) {
270     bytes_to_send_size = out_frames.size;
271     while (bytes_to_send_size > client->buffer_size) {
272       client->buffer_size *= 2;
273       client->buffer = static_cast<unsigned char*>(
274           gpr_realloc(client->buffer, client->buffer_size));
275     }
276     memcpy(client->buffer, out_frames.data, bytes_to_send_size);
277     bytes_to_send = client->buffer;
278   }
279   tsi_handshaker_result* result = nullptr;
280   if (is_handshake_finished_properly(resp)) {
281     tsi_result status =
282         alts_tsi_handshaker_result_create(resp, client->is_client, &result);
283     if (status != TSI_OK) {
284       gpr_log(GPR_ERROR, "alts_tsi_handshaker_result_create() failed");
285       handle_response_done(client, status,
286                            "alts_tsi_handshaker_result_create() failed",
287                            nullptr, 0, nullptr);
288       return;
289     }
290     alts_tsi_handshaker_result_set_unused_bytes(
291         result, &client->recv_bytes,
292         grpc_gcp_HandshakerResp_bytes_consumed(resp));
293   }
294   grpc_status_code code = static_cast<grpc_status_code>(
295       grpc_gcp_HandshakerStatus_code(resp_status));
296   std::string error;
297   if (code != GRPC_STATUS_OK) {
298     upb_StringView details = grpc_gcp_HandshakerStatus_details(resp_status);
299     if (details.size > 0) {
300       error = absl::StrCat("Status ", code, " from handshaker service: ",
301                            absl::string_view(details.data, details.size));
302       gpr_log(GPR_ERROR, "%s", error.c_str());
303     }
304   }
305   // TODO(apolcyn): consider short ciruiting handle_response_done and
306   // invoking the TSI callback directly if we aren't done yet, if
307   // handle_response_done's allocation per message received causes
308   // a performance issue.
309   handle_response_done(client, alts_tsi_utils_convert_to_tsi_result(code),
310                        std::move(error), bytes_to_send, bytes_to_send_size,
311                        result);
312 }
313 
continue_make_grpc_call(alts_grpc_handshaker_client * client,bool is_start)314 static tsi_result continue_make_grpc_call(alts_grpc_handshaker_client* client,
315                                           bool is_start) {
316   GPR_ASSERT(client != nullptr);
317   grpc_op ops[kHandshakerClientOpNum];
318   memset(ops, 0, sizeof(ops));
319   grpc_op* op = ops;
320   if (is_start) {
321     op->op = GRPC_OP_RECV_STATUS_ON_CLIENT;
322     op->data.recv_status_on_client.trailing_metadata = nullptr;
323     op->data.recv_status_on_client.status = &client->handshake_status_code;
324     op->data.recv_status_on_client.status_details =
325         &client->handshake_status_details;
326     op->flags = 0;
327     op->reserved = nullptr;
328     op++;
329     GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
330     gpr_ref(&client->refs);
331     grpc_call_error call_error =
332         client->grpc_caller(client->call, ops, static_cast<size_t>(op - ops),
333                             &client->on_status_received);
334     // TODO(apolcyn): return the error here instead, as done for other ops?
335     GPR_ASSERT(call_error == GRPC_CALL_OK);
336     memset(ops, 0, sizeof(ops));
337     op = ops;
338     op->op = GRPC_OP_SEND_INITIAL_METADATA;
339     op->data.send_initial_metadata.count = 0;
340     op++;
341     GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
342     op->op = GRPC_OP_RECV_INITIAL_METADATA;
343     op->data.recv_initial_metadata.recv_initial_metadata =
344         &client->recv_initial_metadata;
345     op++;
346     GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
347   }
348   op->op = GRPC_OP_SEND_MESSAGE;
349   op->data.send_message.send_message = client->send_buffer;
350   op++;
351   GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
352   op->op = GRPC_OP_RECV_MESSAGE;
353   op->data.recv_message.recv_message = &client->recv_buffer;
354   op++;
355   GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
356   GPR_ASSERT(client->grpc_caller != nullptr);
357   if (client->grpc_caller(client->call, ops, static_cast<size_t>(op - ops),
358                           &client->on_handshaker_service_resp_recv) !=
359       GRPC_CALL_OK) {
360     gpr_log(GPR_ERROR, "Start batch operation failed");
361     return TSI_INTERNAL_ERROR;
362   }
363   return TSI_OK;
364 }
365 
366 // TODO(apolcyn): remove this global queue when we can safely rely
367 // on a MAX_CONCURRENT_STREAMS setting in the ALTS handshake server to
368 // limit the number of concurrent handshakes.
369 namespace {
370 
371 class HandshakeQueue {
372  public:
HandshakeQueue(size_t max_outstanding_handshakes)373   explicit HandshakeQueue(size_t max_outstanding_handshakes)
374       : max_outstanding_handshakes_(max_outstanding_handshakes) {}
375 
RequestHandshake(alts_grpc_handshaker_client * client)376   void RequestHandshake(alts_grpc_handshaker_client* client) {
377     {
378       grpc_core::MutexLock lock(&mu_);
379       if (outstanding_handshakes_ == max_outstanding_handshakes_) {
380         // Max number already running, add to queue.
381         queued_handshakes_.push_back(client);
382         return;
383       }
384       // Start the handshake immediately.
385       ++outstanding_handshakes_;
386     }
387     continue_make_grpc_call(client, true /* is_start */);
388   }
389 
HandshakeDone()390   void HandshakeDone() {
391     alts_grpc_handshaker_client* client = nullptr;
392     {
393       grpc_core::MutexLock lock(&mu_);
394       if (queued_handshakes_.empty()) {
395         // Nothing more in queue.  Decrement count and return immediately.
396         --outstanding_handshakes_;
397         return;
398       }
399       // Remove next entry from queue and start the handshake.
400       client = queued_handshakes_.front();
401       queued_handshakes_.pop_front();
402     }
403     continue_make_grpc_call(client, true /* is_start */);
404   }
405 
406  private:
407   grpc_core::Mutex mu_;
408   std::list<alts_grpc_handshaker_client*> queued_handshakes_;
409   size_t outstanding_handshakes_ = 0;
410   const size_t max_outstanding_handshakes_;
411 };
412 
413 gpr_once g_queued_handshakes_init = GPR_ONCE_INIT;
414 // Using separate queues for client and server handshakes is a
415 // hack that's mainly intended to satisfy the alts_concurrent_connectivity_test,
416 // which runs many concurrent handshakes where both endpoints
417 // are in the same process; this situation is problematic with a
418 // single queue because we have a high chance of using up all outstanding
419 // slots in the queue, such that there aren't any
420 // mutual client/server handshakes outstanding at the same time and
421 // able to make progress.
422 HandshakeQueue* g_client_handshake_queue;
423 HandshakeQueue* g_server_handshake_queue;
424 
DoHandshakeQueuesInit(void)425 void DoHandshakeQueuesInit(void) {
426   const size_t per_queue_max_outstanding_handshakes =
427       MaxNumberOfConcurrentHandshakes();
428   g_client_handshake_queue =
429       new HandshakeQueue(per_queue_max_outstanding_handshakes);
430   g_server_handshake_queue =
431       new HandshakeQueue(per_queue_max_outstanding_handshakes);
432 }
433 
RequestHandshake(alts_grpc_handshaker_client * client,bool is_client)434 void RequestHandshake(alts_grpc_handshaker_client* client, bool is_client) {
435   gpr_once_init(&g_queued_handshakes_init, DoHandshakeQueuesInit);
436   HandshakeQueue* queue =
437       is_client ? g_client_handshake_queue : g_server_handshake_queue;
438   queue->RequestHandshake(client);
439 }
440 
HandshakeDone(bool is_client)441 void HandshakeDone(bool is_client) {
442   HandshakeQueue* queue =
443       is_client ? g_client_handshake_queue : g_server_handshake_queue;
444   queue->HandshakeDone();
445 }
446 
447 };  // namespace
448 
449 ///
450 /// Populate grpc operation data with the fields of ALTS handshaker client and
451 /// make a grpc call.
452 ///
make_grpc_call(alts_handshaker_client * c,bool is_start)453 static tsi_result make_grpc_call(alts_handshaker_client* c, bool is_start) {
454   GPR_ASSERT(c != nullptr);
455   alts_grpc_handshaker_client* client =
456       reinterpret_cast<alts_grpc_handshaker_client*>(c);
457   if (is_start) {
458     RequestHandshake(client, client->is_client);
459     return TSI_OK;
460   } else {
461     return continue_make_grpc_call(client, is_start);
462   }
463 }
464 
on_status_received(void * arg,grpc_error_handle error)465 static void on_status_received(void* arg, grpc_error_handle error) {
466   alts_grpc_handshaker_client* client =
467       static_cast<alts_grpc_handshaker_client*>(arg);
468   if (client->handshake_status_code != GRPC_STATUS_OK) {
469     // TODO(apolcyn): consider overriding the handshake result's
470     // status from the final ALTS message with the status here.
471     char* status_details =
472         grpc_slice_to_c_string(client->handshake_status_details);
473     gpr_log(GPR_INFO,
474             "alts_grpc_handshaker_client:%p on_status_received "
475             "status:%d details:|%s| error:|%s|",
476             client, client->handshake_status_code, status_details,
477             grpc_core::StatusToString(error).c_str());
478     gpr_free(status_details);
479   }
480   maybe_complete_tsi_next(client, true /* receive_status_finished */,
481                           nullptr /* pending_recv_message_result */);
482   HandshakeDone(client->is_client);
483   alts_grpc_handshaker_client_unref(client);
484 }
485 
486 // Serializes a grpc_gcp_HandshakerReq message into a buffer and returns newly
487 // grpc_byte_buffer holding it.
get_serialized_handshaker_req(grpc_gcp_HandshakerReq * req,upb_Arena * arena)488 static grpc_byte_buffer* get_serialized_handshaker_req(
489     grpc_gcp_HandshakerReq* req, upb_Arena* arena) {
490   size_t buf_length;
491   char* buf = grpc_gcp_HandshakerReq_serialize(req, arena, &buf_length);
492   if (buf == nullptr) {
493     return nullptr;
494   }
495   grpc_slice slice = grpc_slice_from_copied_buffer(buf, buf_length);
496   grpc_byte_buffer* byte_buffer = grpc_raw_byte_buffer_create(&slice, 1);
497   grpc_core::CSliceUnref(slice);
498   return byte_buffer;
499 }
500 
501 // Create and populate a client_start handshaker request, then serialize it.
get_serialized_start_client(alts_handshaker_client * c)502 static grpc_byte_buffer* get_serialized_start_client(
503     alts_handshaker_client* c) {
504   GPR_ASSERT(c != nullptr);
505   alts_grpc_handshaker_client* client =
506       reinterpret_cast<alts_grpc_handshaker_client*>(c);
507   upb::Arena arena;
508   grpc_gcp_HandshakerReq* req = grpc_gcp_HandshakerReq_new(arena.ptr());
509   grpc_gcp_StartClientHandshakeReq* start_client =
510       grpc_gcp_HandshakerReq_mutable_client_start(req, arena.ptr());
511   grpc_gcp_StartClientHandshakeReq_set_handshake_security_protocol(
512       start_client, grpc_gcp_ALTS);
513   grpc_gcp_StartClientHandshakeReq_add_application_protocols(
514       start_client, upb_StringView_FromString(ALTS_APPLICATION_PROTOCOL),
515       arena.ptr());
516   grpc_gcp_StartClientHandshakeReq_add_record_protocols(
517       start_client, upb_StringView_FromString(ALTS_RECORD_PROTOCOL),
518       arena.ptr());
519   grpc_gcp_RpcProtocolVersions* client_version =
520       grpc_gcp_StartClientHandshakeReq_mutable_rpc_versions(start_client,
521                                                             arena.ptr());
522   grpc_gcp_RpcProtocolVersions_assign_from_struct(
523       client_version, arena.ptr(), &client->options->rpc_versions);
524   grpc_gcp_StartClientHandshakeReq_set_target_name(
525       start_client, upb_StringView_FromDataAndSize(
526                         reinterpret_cast<const char*>(
527                             GRPC_SLICE_START_PTR(client->target_name)),
528                         GRPC_SLICE_LENGTH(client->target_name)));
529   target_service_account* ptr =
530       (reinterpret_cast<grpc_alts_credentials_client_options*>(client->options))
531           ->target_account_list_head;
532   while (ptr != nullptr) {
533     grpc_gcp_Identity* target_identity =
534         grpc_gcp_StartClientHandshakeReq_add_target_identities(start_client,
535                                                                arena.ptr());
536     grpc_gcp_Identity_set_service_account(target_identity,
537                                           upb_StringView_FromString(ptr->data));
538     ptr = ptr->next;
539   }
540   grpc_gcp_StartClientHandshakeReq_set_max_frame_size(
541       start_client, static_cast<uint32_t>(client->max_frame_size));
542   return get_serialized_handshaker_req(req, arena.ptr());
543 }
544 
handshaker_client_start_client(alts_handshaker_client * c)545 static tsi_result handshaker_client_start_client(alts_handshaker_client* c) {
546   if (c == nullptr) {
547     gpr_log(GPR_ERROR, "client is nullptr in handshaker_client_start_client()");
548     return TSI_INVALID_ARGUMENT;
549   }
550   grpc_byte_buffer* buffer = get_serialized_start_client(c);
551   alts_grpc_handshaker_client* client =
552       reinterpret_cast<alts_grpc_handshaker_client*>(c);
553   if (buffer == nullptr) {
554     gpr_log(GPR_ERROR, "get_serialized_start_client() failed");
555     return TSI_INTERNAL_ERROR;
556   }
557   handshaker_client_send_buffer_destroy(client);
558   client->send_buffer = buffer;
559   tsi_result result = make_grpc_call(&client->base, true /* is_start */);
560   if (result != TSI_OK) {
561     gpr_log(GPR_ERROR, "make_grpc_call() failed");
562   }
563   return result;
564 }
565 
566 // Create and populate a start_server handshaker request, then serialize it.
get_serialized_start_server(alts_handshaker_client * c,grpc_slice * bytes_received)567 static grpc_byte_buffer* get_serialized_start_server(
568     alts_handshaker_client* c, grpc_slice* bytes_received) {
569   GPR_ASSERT(c != nullptr);
570   GPR_ASSERT(bytes_received != nullptr);
571   alts_grpc_handshaker_client* client =
572       reinterpret_cast<alts_grpc_handshaker_client*>(c);
573 
574   upb::Arena arena;
575   grpc_gcp_HandshakerReq* req = grpc_gcp_HandshakerReq_new(arena.ptr());
576 
577   grpc_gcp_StartServerHandshakeReq* start_server =
578       grpc_gcp_HandshakerReq_mutable_server_start(req, arena.ptr());
579   grpc_gcp_StartServerHandshakeReq_add_application_protocols(
580       start_server, upb_StringView_FromString(ALTS_APPLICATION_PROTOCOL),
581       arena.ptr());
582   grpc_gcp_ServerHandshakeParameters* value =
583       grpc_gcp_ServerHandshakeParameters_new(arena.ptr());
584   grpc_gcp_ServerHandshakeParameters_add_record_protocols(
585       value, upb_StringView_FromString(ALTS_RECORD_PROTOCOL), arena.ptr());
586   grpc_gcp_StartServerHandshakeReq_handshake_parameters_set(
587       start_server, grpc_gcp_ALTS, value, arena.ptr());
588   grpc_gcp_StartServerHandshakeReq_set_in_bytes(
589       start_server,
590       upb_StringView_FromDataAndSize(
591           reinterpret_cast<const char*>(GRPC_SLICE_START_PTR(*bytes_received)),
592           GRPC_SLICE_LENGTH(*bytes_received)));
593   grpc_gcp_RpcProtocolVersions* server_version =
594       grpc_gcp_StartServerHandshakeReq_mutable_rpc_versions(start_server,
595                                                             arena.ptr());
596   grpc_gcp_RpcProtocolVersions_assign_from_struct(
597       server_version, arena.ptr(), &client->options->rpc_versions);
598   grpc_gcp_StartServerHandshakeReq_set_max_frame_size(
599       start_server, static_cast<uint32_t>(client->max_frame_size));
600   return get_serialized_handshaker_req(req, arena.ptr());
601 }
602 
handshaker_client_start_server(alts_handshaker_client * c,grpc_slice * bytes_received)603 static tsi_result handshaker_client_start_server(alts_handshaker_client* c,
604                                                  grpc_slice* bytes_received) {
605   if (c == nullptr || bytes_received == nullptr) {
606     gpr_log(GPR_ERROR, "Invalid arguments to handshaker_client_start_server()");
607     return TSI_INVALID_ARGUMENT;
608   }
609   alts_grpc_handshaker_client* client =
610       reinterpret_cast<alts_grpc_handshaker_client*>(c);
611   grpc_byte_buffer* buffer = get_serialized_start_server(c, bytes_received);
612   if (buffer == nullptr) {
613     gpr_log(GPR_ERROR, "get_serialized_start_server() failed");
614     return TSI_INTERNAL_ERROR;
615   }
616   handshaker_client_send_buffer_destroy(client);
617   client->send_buffer = buffer;
618   tsi_result result = make_grpc_call(&client->base, true /* is_start */);
619   if (result != TSI_OK) {
620     gpr_log(GPR_ERROR, "make_grpc_call() failed");
621   }
622   return result;
623 }
624 
625 // Create and populate a next handshaker request, then serialize it.
get_serialized_next(grpc_slice * bytes_received)626 static grpc_byte_buffer* get_serialized_next(grpc_slice* bytes_received) {
627   GPR_ASSERT(bytes_received != nullptr);
628   upb::Arena arena;
629   grpc_gcp_HandshakerReq* req = grpc_gcp_HandshakerReq_new(arena.ptr());
630   grpc_gcp_NextHandshakeMessageReq* next =
631       grpc_gcp_HandshakerReq_mutable_next(req, arena.ptr());
632   grpc_gcp_NextHandshakeMessageReq_set_in_bytes(
633       next,
634       upb_StringView_FromDataAndSize(
635           reinterpret_cast<const char*> GRPC_SLICE_START_PTR(*bytes_received),
636           GRPC_SLICE_LENGTH(*bytes_received)));
637   return get_serialized_handshaker_req(req, arena.ptr());
638 }
639 
handshaker_client_next(alts_handshaker_client * c,grpc_slice * bytes_received)640 static tsi_result handshaker_client_next(alts_handshaker_client* c,
641                                          grpc_slice* bytes_received) {
642   if (c == nullptr || bytes_received == nullptr) {
643     gpr_log(GPR_ERROR, "Invalid arguments to handshaker_client_next()");
644     return TSI_INVALID_ARGUMENT;
645   }
646   alts_grpc_handshaker_client* client =
647       reinterpret_cast<alts_grpc_handshaker_client*>(c);
648   grpc_core::CSliceUnref(client->recv_bytes);
649   client->recv_bytes = grpc_core::CSliceRef(*bytes_received);
650   grpc_byte_buffer* buffer = get_serialized_next(bytes_received);
651   if (buffer == nullptr) {
652     gpr_log(GPR_ERROR, "get_serialized_next() failed");
653     return TSI_INTERNAL_ERROR;
654   }
655   handshaker_client_send_buffer_destroy(client);
656   client->send_buffer = buffer;
657   tsi_result result = make_grpc_call(&client->base, false /* is_start */);
658   if (result != TSI_OK) {
659     gpr_log(GPR_ERROR, "make_grpc_call() failed");
660   }
661   return result;
662 }
663 
handshaker_client_shutdown(alts_handshaker_client * c)664 static void handshaker_client_shutdown(alts_handshaker_client* c) {
665   GPR_ASSERT(c != nullptr);
666   alts_grpc_handshaker_client* client =
667       reinterpret_cast<alts_grpc_handshaker_client*>(c);
668   if (client->call != nullptr) {
669     grpc_call_cancel_internal(client->call);
670   }
671 }
672 
handshaker_call_unref(void * arg,grpc_error_handle)673 static void handshaker_call_unref(void* arg, grpc_error_handle /* error */) {
674   grpc_call* call = static_cast<grpc_call*>(arg);
675   grpc_call_unref(call);
676 }
677 
handshaker_client_destruct(alts_handshaker_client * c)678 static void handshaker_client_destruct(alts_handshaker_client* c) {
679   if (c == nullptr) {
680     return;
681   }
682   alts_grpc_handshaker_client* client =
683       reinterpret_cast<alts_grpc_handshaker_client*>(c);
684   if (client->call != nullptr) {
685     // Throw this grpc_call_unref over to the ExecCtx so that
686     // we invoke it at the bottom of the call stack and
687     // prevent lock inversion problems due to nested ExecCtx flushing.
688     // TODO(apolcyn): we could remove this indirection and call
689     // grpc_call_unref inline if there was an internal variant of
690     // grpc_call_unref that didn't need to flush an ExecCtx.
691     if (grpc_core::ExecCtx::Get() == nullptr) {
692       // Unref handshaker call if there is no exec_ctx, e.g., in the case of
693       // Envoy ALTS transport socket.
694       grpc_call_unref(client->call);
695     } else {
696       // Using existing exec_ctx to unref handshaker call.
697       grpc_core::ExecCtx::Run(
698           DEBUG_LOCATION,
699           GRPC_CLOSURE_CREATE(handshaker_call_unref, client->call,
700                               grpc_schedule_on_exec_ctx),
701           absl::OkStatus());
702     }
703   }
704 }
705 
// Default vtable installed by alts_grpc_handshaker_client_create() when the
// caller passes no vtable_for_testing. This is positional aggregate
// initialization, so the entries below must stay in the same order as the
// members of alts_handshaker_client_vtable.
static const alts_handshaker_client_vtable vtable = {
    handshaker_client_start_client, handshaker_client_start_server,
    handshaker_client_next, handshaker_client_shutdown,
    handshaker_client_destruct};
710 
alts_grpc_handshaker_client_create(alts_tsi_handshaker * handshaker,grpc_channel * channel,const char * handshaker_service_url,grpc_pollset_set * interested_parties,grpc_alts_credentials_options * options,const grpc_slice & target_name,grpc_iomgr_cb_func grpc_cb,tsi_handshaker_on_next_done_cb cb,void * user_data,alts_handshaker_client_vtable * vtable_for_testing,bool is_client,size_t max_frame_size,std::string * error)711 alts_handshaker_client* alts_grpc_handshaker_client_create(
712     alts_tsi_handshaker* handshaker, grpc_channel* channel,
713     const char* handshaker_service_url, grpc_pollset_set* interested_parties,
714     grpc_alts_credentials_options* options, const grpc_slice& target_name,
715     grpc_iomgr_cb_func grpc_cb, tsi_handshaker_on_next_done_cb cb,
716     void* user_data, alts_handshaker_client_vtable* vtable_for_testing,
717     bool is_client, size_t max_frame_size, std::string* error) {
718   if (channel == nullptr || handshaker_service_url == nullptr) {
719     gpr_log(GPR_ERROR, "Invalid arguments to alts_handshaker_client_create()");
720     return nullptr;
721   }
722   alts_grpc_handshaker_client* client = new alts_grpc_handshaker_client();
723   memset(&client->base, 0, sizeof(client->base));
724   client->base.vtable =
725       vtable_for_testing == nullptr ? &vtable : vtable_for_testing;
726   gpr_ref_init(&client->refs, 1);
727   client->handshaker = handshaker;
728   client->grpc_caller = grpc_call_start_batch_and_execute;
729   grpc_metadata_array_init(&client->recv_initial_metadata);
730   client->cb = cb;
731   client->user_data = user_data;
732   client->options = grpc_alts_credentials_options_copy(options);
733   client->target_name = grpc_slice_copy(target_name);
734   client->is_client = is_client;
735   client->recv_bytes = grpc_empty_slice();
736   client->buffer_size = TSI_ALTS_INITIAL_BUFFER_SIZE;
737   client->buffer = static_cast<unsigned char*>(gpr_zalloc(client->buffer_size));
738   client->handshake_status_details = grpc_empty_slice();
739   client->max_frame_size = max_frame_size;
740   client->error = error;
741   client->call =
742       strcmp(handshaker_service_url, ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING) ==
743               0
744           ? nullptr
745           : grpc_channel_create_pollset_set_call(
746                 channel, nullptr, GRPC_PROPAGATE_DEFAULTS, interested_parties,
747                 grpc_slice_from_static_string(ALTS_SERVICE_METHOD), nullptr,
748                 grpc_core::Timestamp::InfFuture(), nullptr);
749   GRPC_CLOSURE_INIT(&client->on_handshaker_service_resp_recv, grpc_cb, client,
750                     grpc_schedule_on_exec_ctx);
751   GRPC_CLOSURE_INIT(&client->on_status_received, on_status_received, client,
752                     grpc_schedule_on_exec_ctx);
753   return &client->base;
754 }
755 
756 namespace grpc_core {
757 namespace internal {
758 
alts_handshaker_client_set_grpc_caller_for_testing(alts_handshaker_client * c,alts_grpc_caller caller)759 void alts_handshaker_client_set_grpc_caller_for_testing(
760     alts_handshaker_client* c, alts_grpc_caller caller) {
761   GPR_ASSERT(c != nullptr && caller != nullptr);
762   alts_grpc_handshaker_client* client =
763       reinterpret_cast<alts_grpc_handshaker_client*>(c);
764   client->grpc_caller = caller;
765 }
766 
alts_handshaker_client_get_send_buffer_for_testing(alts_handshaker_client * c)767 grpc_byte_buffer* alts_handshaker_client_get_send_buffer_for_testing(
768     alts_handshaker_client* c) {
769   GPR_ASSERT(c != nullptr);
770   alts_grpc_handshaker_client* client =
771       reinterpret_cast<alts_grpc_handshaker_client*>(c);
772   return client->send_buffer;
773 }
774 
alts_handshaker_client_get_recv_buffer_addr_for_testing(alts_handshaker_client * c)775 grpc_byte_buffer** alts_handshaker_client_get_recv_buffer_addr_for_testing(
776     alts_handshaker_client* c) {
777   GPR_ASSERT(c != nullptr);
778   alts_grpc_handshaker_client* client =
779       reinterpret_cast<alts_grpc_handshaker_client*>(c);
780   return &client->recv_buffer;
781 }
782 
alts_handshaker_client_get_initial_metadata_for_testing(alts_handshaker_client * c)783 grpc_metadata_array* alts_handshaker_client_get_initial_metadata_for_testing(
784     alts_handshaker_client* c) {
785   GPR_ASSERT(c != nullptr);
786   alts_grpc_handshaker_client* client =
787       reinterpret_cast<alts_grpc_handshaker_client*>(c);
788   return &client->recv_initial_metadata;
789 }
790 
alts_handshaker_client_set_recv_bytes_for_testing(alts_handshaker_client * c,grpc_slice * recv_bytes)791 void alts_handshaker_client_set_recv_bytes_for_testing(
792     alts_handshaker_client* c, grpc_slice* recv_bytes) {
793   GPR_ASSERT(c != nullptr);
794   alts_grpc_handshaker_client* client =
795       reinterpret_cast<alts_grpc_handshaker_client*>(c);
796   client->recv_bytes = CSliceRef(*recv_bytes);
797 }
798 
alts_handshaker_client_set_fields_for_testing(alts_handshaker_client * c,alts_tsi_handshaker * handshaker,tsi_handshaker_on_next_done_cb cb,void * user_data,grpc_byte_buffer * recv_buffer,bool inject_read_failure)799 void alts_handshaker_client_set_fields_for_testing(
800     alts_handshaker_client* c, alts_tsi_handshaker* handshaker,
801     tsi_handshaker_on_next_done_cb cb, void* user_data,
802     grpc_byte_buffer* recv_buffer, bool inject_read_failure) {
803   GPR_ASSERT(c != nullptr);
804   alts_grpc_handshaker_client* client =
805       reinterpret_cast<alts_grpc_handshaker_client*>(c);
806   client->handshaker = handshaker;
807   client->cb = cb;
808   client->user_data = user_data;
809   client->recv_buffer = recv_buffer;
810   client->inject_read_failure = inject_read_failure;
811 }
812 
alts_handshaker_client_check_fields_for_testing(alts_handshaker_client * c,tsi_handshaker_on_next_done_cb cb,void * user_data,bool has_sent_start_message,grpc_slice * recv_bytes)813 void alts_handshaker_client_check_fields_for_testing(
814     alts_handshaker_client* c, tsi_handshaker_on_next_done_cb cb,
815     void* user_data, bool has_sent_start_message, grpc_slice* recv_bytes) {
816   GPR_ASSERT(c != nullptr);
817   alts_grpc_handshaker_client* client =
818       reinterpret_cast<alts_grpc_handshaker_client*>(c);
819   GPR_ASSERT(client->cb == cb);
820   GPR_ASSERT(client->user_data == user_data);
821   if (recv_bytes != nullptr) {
822     GPR_ASSERT(grpc_slice_cmp(client->recv_bytes, *recv_bytes) == 0);
823   }
824   GPR_ASSERT(alts_tsi_handshaker_get_has_sent_start_message_for_testing(
825                  client->handshaker) == has_sent_start_message);
826 }
827 
alts_handshaker_client_set_vtable_for_testing(alts_handshaker_client * c,alts_handshaker_client_vtable * vtable)828 void alts_handshaker_client_set_vtable_for_testing(
829     alts_handshaker_client* c, alts_handshaker_client_vtable* vtable) {
830   GPR_ASSERT(c != nullptr);
831   GPR_ASSERT(vtable != nullptr);
832   alts_grpc_handshaker_client* client =
833       reinterpret_cast<alts_grpc_handshaker_client*>(c);
834   client->base.vtable = vtable;
835 }
836 
alts_handshaker_client_get_handshaker_for_testing(alts_handshaker_client * c)837 alts_tsi_handshaker* alts_handshaker_client_get_handshaker_for_testing(
838     alts_handshaker_client* c) {
839   GPR_ASSERT(c != nullptr);
840   alts_grpc_handshaker_client* client =
841       reinterpret_cast<alts_grpc_handshaker_client*>(c);
842   return client->handshaker;
843 }
844 
alts_handshaker_client_set_cb_for_testing(alts_handshaker_client * c,tsi_handshaker_on_next_done_cb cb)845 void alts_handshaker_client_set_cb_for_testing(
846     alts_handshaker_client* c, tsi_handshaker_on_next_done_cb cb) {
847   GPR_ASSERT(c != nullptr);
848   alts_grpc_handshaker_client* client =
849       reinterpret_cast<alts_grpc_handshaker_client*>(c);
850   client->cb = cb;
851 }
852 
alts_handshaker_client_get_closure_for_testing(alts_handshaker_client * c)853 grpc_closure* alts_handshaker_client_get_closure_for_testing(
854     alts_handshaker_client* c) {
855   GPR_ASSERT(c != nullptr);
856   alts_grpc_handshaker_client* client =
857       reinterpret_cast<alts_grpc_handshaker_client*>(c);
858   return &client->on_handshaker_service_resp_recv;
859 }
860 
alts_handshaker_client_ref_for_testing(alts_handshaker_client * c)861 void alts_handshaker_client_ref_for_testing(alts_handshaker_client* c) {
862   alts_grpc_handshaker_client* client =
863       reinterpret_cast<alts_grpc_handshaker_client*>(c);
864   gpr_ref(&client->refs);
865 }
866 
alts_handshaker_client_on_status_received_for_testing(alts_handshaker_client * c,grpc_status_code status,grpc_error_handle error)867 void alts_handshaker_client_on_status_received_for_testing(
868     alts_handshaker_client* c, grpc_status_code status,
869     grpc_error_handle error) {
870   // We first make sure that the handshake queue has been initialized
871   // here because there are tests that use this API that mock out
872   // other parts of the alts_handshaker_client in such a way that the
873   // code path that would normally ensure that the handshake queue
874   // has been initialized isn't taken.
875   gpr_once_init(&g_queued_handshakes_init, DoHandshakeQueuesInit);
876   alts_grpc_handshaker_client* client =
877       reinterpret_cast<alts_grpc_handshaker_client*>(c);
878   client->handshake_status_code = status;
879   client->handshake_status_details = grpc_empty_slice();
880   Closure::Run(DEBUG_LOCATION, &client->on_status_received, error);
881 }
882 
883 }  // namespace internal
884 }  // namespace grpc_core
885 
alts_handshaker_client_start_client(alts_handshaker_client * client)886 tsi_result alts_handshaker_client_start_client(alts_handshaker_client* client) {
887   if (client != nullptr && client->vtable != nullptr &&
888       client->vtable->client_start != nullptr) {
889     return client->vtable->client_start(client);
890   }
891   gpr_log(GPR_ERROR,
892           "client or client->vtable has not been initialized properly");
893   return TSI_INVALID_ARGUMENT;
894 }
895 
alts_handshaker_client_start_server(alts_handshaker_client * client,grpc_slice * bytes_received)896 tsi_result alts_handshaker_client_start_server(alts_handshaker_client* client,
897                                                grpc_slice* bytes_received) {
898   if (client != nullptr && client->vtable != nullptr &&
899       client->vtable->server_start != nullptr) {
900     return client->vtable->server_start(client, bytes_received);
901   }
902   gpr_log(GPR_ERROR,
903           "client or client->vtable has not been initialized properly");
904   return TSI_INVALID_ARGUMENT;
905 }
906 
alts_handshaker_client_next(alts_handshaker_client * client,grpc_slice * bytes_received)907 tsi_result alts_handshaker_client_next(alts_handshaker_client* client,
908                                        grpc_slice* bytes_received) {
909   if (client != nullptr && client->vtable != nullptr &&
910       client->vtable->next != nullptr) {
911     return client->vtable->next(client, bytes_received);
912   }
913   gpr_log(GPR_ERROR,
914           "client or client->vtable has not been initialized properly");
915   return TSI_INVALID_ARGUMENT;
916 }
917 
alts_handshaker_client_shutdown(alts_handshaker_client * client)918 void alts_handshaker_client_shutdown(alts_handshaker_client* client) {
919   if (client != nullptr && client->vtable != nullptr &&
920       client->vtable->shutdown != nullptr) {
921     client->vtable->shutdown(client);
922   }
923 }
924 
alts_handshaker_client_destroy(alts_handshaker_client * c)925 void alts_handshaker_client_destroy(alts_handshaker_client* c) {
926   if (c != nullptr) {
927     alts_grpc_handshaker_client* client =
928         reinterpret_cast<alts_grpc_handshaker_client*>(c);
929     alts_grpc_handshaker_client_unref(client);
930   }
931 }
932 
MaxNumberOfConcurrentHandshakes()933 size_t MaxNumberOfConcurrentHandshakes() {
934   size_t max_concurrent_handshakes = 40;
935   absl::optional<std::string> env_var_max_concurrent_handshakes =
936       grpc_core::GetEnv(kMaxConcurrentStreamsEnvironmentVariable);
937   if (env_var_max_concurrent_handshakes.has_value()) {
938     size_t effective_max_concurrent_handshakes = 40;
939     if (absl::SimpleAtoi(*env_var_max_concurrent_handshakes,
940                          &effective_max_concurrent_handshakes)) {
941       max_concurrent_handshakes = effective_max_concurrent_handshakes;
942     }
943   }
944   return max_concurrent_handshakes;
945 }
946