1 //
2 //
3 // Copyright 2015 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include <grpc/support/port_platform.h>
20 
21 #include "src/core/tsi/fake_transport_security.h"
22 
23 #include <stdlib.h>
24 #include <string.h>
25 
26 #include <grpc/support/alloc.h>
27 #include <grpc/support/log.h>
28 
29 #include "src/core/lib/gpr/useful.h"
30 #include "src/core/lib/gprpp/crash.h"
31 #include "src/core/lib/gprpp/memory.h"
32 #include "src/core/lib/slice/slice_internal.h"
33 #include "src/core/tsi/transport_security_grpc.h"
34 
// --- Constants. ---
// Length in bytes of the little-endian size field that prefixes every fake
// frame.
#define TSI_FAKE_FRAME_HEADER_SIZE 4
// Initial capacity of the buffer allocated for a frame being decoded.
#define TSI_FAKE_FRAME_INITIAL_ALLOCATED_SIZE 64
// NOTE(review): not referenced in this part of the file; presumably the frame
// size used when a protector is created without an explicit size — confirm.
#define TSI_FAKE_DEFAULT_FRAME_SIZE 16384
// NOTE(review): not referenced in this part of the file; presumably the initial
// size of tsi_fake_handshaker::outgoing_bytes_buffer — confirm.
#define TSI_FAKE_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE 256
40 
41 // --- Structure definitions. ---
42 
// A fake frame is encoded like this:
//   | size |     data    |
// where the size field is a 4-byte little-endian integer whose value is the
// length of the size field itself plus the length of the data.
struct tsi_fake_frame {
  // Buffer holding the encoded frame, header included.
  unsigned char* data;
  // Total encoded length of the frame, header included.
  size_t size;
  // Capacity of |data| in bytes.
  size_t allocated_size;
  // Number of frame bytes already filled in (while decoding) or already copied
  // out (while draining).
  size_t offset;
  // Nonzero once the frame is complete and its bytes are waiting to be copied
  // out (drained); zero while the frame is still being filled.
  int needs_draining;
};
// Messages of the fake handshake, in exchange order. CLIENT_* and SERVER_*
// messages alternate, so each side advances its next_message_to_send by 2
// after preparing a message.
typedef enum {
  TSI_FAKE_CLIENT_INIT = 0,
  TSI_FAKE_SERVER_INIT = 1,
  TSI_FAKE_CLIENT_FINISHED = 2,
  TSI_FAKE_SERVER_FINISHED = 3,
  // Number of real messages; also used as the "nothing left to send" state.
  TSI_FAKE_HANDSHAKE_MESSAGE_MAX = 4
} tsi_fake_handshake_message;
61 
// State of one side of a fake handshake.
struct tsi_fake_handshaker {
  // Must stay the first member: tsi_handshaker* is reinterpret_cast to this
  // type.
  tsi_handshaker base;
  int is_client;
  // Message this side will prepare next (TSI_FAKE_HANDSHAKE_MESSAGE_MAX once
  // it has nothing left to send).
  tsi_fake_handshake_message next_message_to_send;
  // Nonzero while this side is waiting for the peer's next message.
  int needs_incoming_message;
  // Frame being reassembled from the peer's bytes.
  tsi_fake_frame incoming_frame;
  // Frame holding the message being sent to the peer.
  tsi_fake_frame outgoing_frame;
  // Buffer freed by fake_handshaker_destroy. NOTE(review): presumably scratch
  // space for the bytes handed out by next() — confirm.
  unsigned char* outgoing_bytes_buffer;
  size_t outgoing_bytes_buffer_size;
  // Set to TSI_OK when the handshake completes. NOTE(review): its initial
  // (in-progress) value is set at creation, outside this view.
  tsi_result result;
};
// Frame protector that only adds/removes fake frame headers.
struct tsi_fake_frame_protector {
  // Must stay the first member: tsi_frame_protector* is reinterpret_cast to
  // this type.
  tsi_frame_protector base;
  // Frame being assembled from plaintext by protect().
  tsi_fake_frame protect_frame;
  // Frame being parsed from wire bytes by unprotect().
  tsi_fake_frame unprotect_frame;
  // Size written into the header of each new protected frame.
  size_t max_frame_size;
};
// Zero-copy flavor of the fake protector, operating on grpc_slice_buffers.
struct tsi_fake_zero_copy_grpc_protector {
  // Must stay the first member: tsi_zero_copy_grpc_protector* is
  // reinterpret_cast to this type.
  tsi_zero_copy_grpc_protector base;
  // Scratch buffer receiving the stripped header of each unprotected frame.
  grpc_slice_buffer header_sb;
  // Protected bytes received but not yet turned into complete frames.
  grpc_slice_buffer protected_sb;
  // Upper bound on the size of an emitted frame, header included.
  size_t max_frame_size;
  // Size of the frame currently being reassembled; 0 when no header has been
  // parsed yet.
  size_t parsed_frame_size;
};
86 // --- Utils. ---
87 
// Wire names of the handshake messages, indexed by tsi_fake_handshake_message.
static const char* tsi_fake_handshake_message_strings[] = {
    "CLIENT_INIT",      // TSI_FAKE_CLIENT_INIT
    "SERVER_INIT",      // TSI_FAKE_SERVER_INIT
    "CLIENT_FINISHED",  // TSI_FAKE_CLIENT_FINISHED
    "SERVER_FINISHED",  // TSI_FAKE_SERVER_FINISHED
};
90 
tsi_fake_handshake_message_to_string(int msg)91 static const char* tsi_fake_handshake_message_to_string(int msg) {
92   if (msg < 0 || msg >= TSI_FAKE_HANDSHAKE_MESSAGE_MAX) {
93     gpr_log(GPR_ERROR, "Invalid message %d", msg);
94     return "UNKNOWN";
95   }
96   return tsi_fake_handshake_message_strings[msg];
97 }
98 
tsi_fake_handshake_message_from_string(const char * msg_string,tsi_fake_handshake_message * msg,std::string * error)99 static tsi_result tsi_fake_handshake_message_from_string(
100     const char* msg_string, tsi_fake_handshake_message* msg,
101     std::string* error) {
102   for (int i = 0; i < TSI_FAKE_HANDSHAKE_MESSAGE_MAX; i++) {
103     if (strncmp(msg_string, tsi_fake_handshake_message_strings[i],
104                 strlen(tsi_fake_handshake_message_strings[i])) == 0) {
105       *msg = static_cast<tsi_fake_handshake_message>(i);
106       return TSI_OK;
107     }
108   }
109   gpr_log(GPR_ERROR, "Invalid handshake message.");
110   if (error != nullptr) *error = "invalid handshake message";
111   return TSI_DATA_CORRUPTED;
112 }
113 
// Reads a 32-bit unsigned integer stored in little-endian order at |buf|.
// Each byte is widened to uint32_t *before* shifting. Shifting the promoted
// (signed int) byte by 24 would overflow for bytes >= 0x80, which is
// undefined behavior.
static uint32_t load32_little_endian(const unsigned char* buf) {
  return static_cast<uint32_t>(buf[0]) |
         (static_cast<uint32_t>(buf[1]) << 8) |
         (static_cast<uint32_t>(buf[2]) << 16) |
         (static_cast<uint32_t>(buf[3]) << 24);
}
119 
// Writes |value| into buf[0..3] in little-endian order (least significant
// byte first).
static void store32_little_endian(uint32_t value, unsigned char* buf) {
  for (int byte_index = 0; byte_index < 4; ++byte_index) {
    buf[byte_index] =
        static_cast<unsigned char>((value >> (8 * byte_index)) & 0xFF);
  }
}
126 
read_frame_size(const grpc_slice_buffer * sb)127 static uint32_t read_frame_size(const grpc_slice_buffer* sb) {
128   GPR_ASSERT(sb != nullptr && sb->length >= TSI_FAKE_FRAME_HEADER_SIZE);
129   uint8_t frame_size_buffer[TSI_FAKE_FRAME_HEADER_SIZE];
130   uint8_t* buf = frame_size_buffer;
131   // Copies the first 4 bytes to a temporary buffer.
132   size_t remaining = TSI_FAKE_FRAME_HEADER_SIZE;
133   for (size_t i = 0; i < sb->count; i++) {
134     size_t slice_length = GRPC_SLICE_LENGTH(sb->slices[i]);
135     if (remaining <= slice_length) {
136       memcpy(buf, GRPC_SLICE_START_PTR(sb->slices[i]), remaining);
137       remaining = 0;
138       break;
139     } else {
140       memcpy(buf, GRPC_SLICE_START_PTR(sb->slices[i]), slice_length);
141       buf += slice_length;
142       remaining -= slice_length;
143     }
144   }
145   GPR_ASSERT(remaining == 0);
146   return load32_little_endian(frame_size_buffer);
147 }
148 
tsi_fake_zero_copy_grpc_protector_next_frame_size(const grpc_slice_buffer * protected_slices)149 uint32_t tsi_fake_zero_copy_grpc_protector_next_frame_size(
150     const grpc_slice_buffer* protected_slices) {
151   return read_frame_size(protected_slices);
152 }
153 
tsi_fake_frame_reset(tsi_fake_frame * frame,int needs_draining)154 static void tsi_fake_frame_reset(tsi_fake_frame* frame, int needs_draining) {
155   frame->offset = 0;
156   frame->needs_draining = needs_draining;
157   if (!needs_draining) frame->size = 0;
158 }
159 
160 // Checks if the frame's allocated size is at least frame->size, and reallocs
161 // more memory if necessary.
tsi_fake_frame_ensure_size(tsi_fake_frame * frame)162 static void tsi_fake_frame_ensure_size(tsi_fake_frame* frame) {
163   if (frame->data == nullptr) {
164     frame->allocated_size = frame->size;
165     frame->data =
166         static_cast<unsigned char*>(gpr_malloc(frame->allocated_size));
167   } else if (frame->size > frame->allocated_size) {
168     unsigned char* new_data =
169         static_cast<unsigned char*>(gpr_realloc(frame->data, frame->size));
170     frame->data = new_data;
171     frame->allocated_size = frame->size;
172   }
173 }
174 
175 // Decodes the serialized fake frame contained in incoming_bytes, and fills
176 // frame with the contents of the decoded frame.
177 // This method should not be called if frame->needs_framing is not 0.
tsi_fake_frame_decode(const unsigned char * incoming_bytes,size_t * incoming_bytes_size,tsi_fake_frame * frame,std::string * error)178 static tsi_result tsi_fake_frame_decode(const unsigned char* incoming_bytes,
179                                         size_t* incoming_bytes_size,
180                                         tsi_fake_frame* frame,
181                                         std::string* error) {
182   size_t available_size = *incoming_bytes_size;
183   size_t to_read_size = 0;
184   const unsigned char* bytes_cursor = incoming_bytes;
185 
186   if (frame->needs_draining) {
187     if (error != nullptr) *error = "fake handshaker frame needs draining";
188     return TSI_INTERNAL_ERROR;
189   }
190   if (frame->data == nullptr) {
191     frame->allocated_size = TSI_FAKE_FRAME_INITIAL_ALLOCATED_SIZE;
192     frame->data =
193         static_cast<unsigned char*>(gpr_malloc(frame->allocated_size));
194   }
195 
196   if (frame->offset < TSI_FAKE_FRAME_HEADER_SIZE) {
197     to_read_size = TSI_FAKE_FRAME_HEADER_SIZE - frame->offset;
198     if (to_read_size > available_size) {
199       // Just fill what we can and exit.
200       memcpy(frame->data + frame->offset, bytes_cursor, available_size);
201       bytes_cursor += available_size;
202       frame->offset += available_size;
203       *incoming_bytes_size = static_cast<size_t>(bytes_cursor - incoming_bytes);
204       return TSI_INCOMPLETE_DATA;
205     }
206     memcpy(frame->data + frame->offset, bytes_cursor, to_read_size);
207     bytes_cursor += to_read_size;
208     frame->offset += to_read_size;
209     available_size -= to_read_size;
210     frame->size = load32_little_endian(frame->data);
211     tsi_fake_frame_ensure_size(frame);
212   }
213 
214   to_read_size = frame->size - frame->offset;
215   if (to_read_size > available_size) {
216     memcpy(frame->data + frame->offset, bytes_cursor, available_size);
217     frame->offset += available_size;
218     bytes_cursor += available_size;
219     *incoming_bytes_size = static_cast<size_t>(bytes_cursor - incoming_bytes);
220     return TSI_INCOMPLETE_DATA;
221   }
222   memcpy(frame->data + frame->offset, bytes_cursor, to_read_size);
223   bytes_cursor += to_read_size;
224   *incoming_bytes_size = static_cast<size_t>(bytes_cursor - incoming_bytes);
225   tsi_fake_frame_reset(frame, 1 /* needs_draining */);
226   return TSI_OK;
227 }
228 
229 // Encodes a fake frame into its wire format and places the result in
230 // outgoing_bytes. outgoing_bytes_size indicates the size of the encoded frame.
231 // This method should not be called if frame->needs_framing is 0.
tsi_fake_frame_encode(unsigned char * outgoing_bytes,size_t * outgoing_bytes_size,tsi_fake_frame * frame,std::string * error)232 static tsi_result tsi_fake_frame_encode(unsigned char* outgoing_bytes,
233                                         size_t* outgoing_bytes_size,
234                                         tsi_fake_frame* frame,
235                                         std::string* error) {
236   size_t to_write_size = frame->size - frame->offset;
237   if (!frame->needs_draining) {
238     if (error != nullptr) *error = "fake frame needs draining";
239     return TSI_INTERNAL_ERROR;
240   }
241   if (*outgoing_bytes_size < to_write_size) {
242     memcpy(outgoing_bytes, frame->data + frame->offset, *outgoing_bytes_size);
243     frame->offset += *outgoing_bytes_size;
244     return TSI_INCOMPLETE_DATA;
245   }
246   memcpy(outgoing_bytes, frame->data + frame->offset, to_write_size);
247   *outgoing_bytes_size = to_write_size;
248   tsi_fake_frame_reset(frame, 0 /* needs_draining */);
249   return TSI_OK;
250 }
251 
252 // Sets the payload of a fake frame to contain the given data blob, where
253 // data_size indicates the size of data.
tsi_fake_frame_set_data(unsigned char * data,size_t data_size,tsi_fake_frame * frame)254 static void tsi_fake_frame_set_data(unsigned char* data, size_t data_size,
255                                     tsi_fake_frame* frame) {
256   frame->offset = 0;
257   frame->size = data_size + TSI_FAKE_FRAME_HEADER_SIZE;
258   tsi_fake_frame_ensure_size(frame);
259   store32_little_endian(static_cast<uint32_t>(frame->size), frame->data);
260   memcpy(frame->data + TSI_FAKE_FRAME_HEADER_SIZE, data, data_size);
261   tsi_fake_frame_reset(frame, 1 /* needs draining */);
262 }
263 
264 // Destroys the contents of a fake frame.
tsi_fake_frame_destruct(tsi_fake_frame * frame)265 static void tsi_fake_frame_destruct(tsi_fake_frame* frame) {
266   if (frame->data != nullptr) gpr_free(frame->data);
267 }
268 
269 // --- tsi_frame_protector methods implementation. ---
270 
// Packs |unprotected_bytes| into fake frames whose header announces
// impl->max_frame_size bytes, and copies as much framed output as fits into
// |protected_output_frames|. The payload bytes are copied unchanged.
//   *unprotected_bytes_size: in = bytes available, out = bytes consumed.
//   *protected_output_frames_size: in = output capacity, out = bytes written.
// Bytes that do not yet fill a frame stay buffered in impl->protect_frame
// until later calls complete it or fake_protector_protect_flush() emits it.
static tsi_result fake_protector_protect(tsi_frame_protector* self,
                                         const unsigned char* unprotected_bytes,
                                         size_t* unprotected_bytes_size,
                                         unsigned char* protected_output_frames,
                                         size_t* protected_output_frames_size) {
  tsi_result result = TSI_OK;
  tsi_fake_frame_protector* impl =
      reinterpret_cast<tsi_fake_frame_protector*>(self);
  unsigned char frame_header[TSI_FAKE_FRAME_HEADER_SIZE];
  tsi_fake_frame* frame = &impl->protect_frame;
  size_t saved_output_size = *protected_output_frames_size;
  size_t drained_size = 0;
  size_t* num_bytes_written = protected_output_frames_size;
  *num_bytes_written = 0;

  // Try to drain first.
  // A frame completed by an earlier call must be fully written out before any
  // new input is consumed; if it still does not fit, report zero bytes
  // consumed.
  if (frame->needs_draining) {
    drained_size = saved_output_size - *num_bytes_written;
    result = tsi_fake_frame_encode(protected_output_frames, &drained_size,
                                   frame, /*error=*/nullptr);
    *num_bytes_written += drained_size;
    protected_output_frames += drained_size;
    if (result != TSI_OK) {
      if (result == TSI_INCOMPLETE_DATA) {
        *unprotected_bytes_size = 0;
        result = TSI_OK;
      }
      return result;
    }
  }

  // Now process the unprotected_bytes.
  if (frame->needs_draining) return TSI_INTERNAL_ERROR;
  if (frame->size == 0) {
    // New frame, create a header.
    // The header always announces a full max_frame_size frame; protect_flush
    // rewrites it if the frame ends up shorter. Feeding only the header bytes
    // through decode is expected to return TSI_INCOMPLETE_DATA because the
    // payload is still missing.
    size_t written_in_frame_size = 0;
    store32_little_endian(static_cast<uint32_t>(impl->max_frame_size),
                          frame_header);
    written_in_frame_size = TSI_FAKE_FRAME_HEADER_SIZE;
    result = tsi_fake_frame_decode(frame_header, &written_in_frame_size, frame,
                                   /*error=*/nullptr);
    if (result != TSI_INCOMPLETE_DATA) {
      gpr_log(GPR_ERROR, "tsi_fake_frame_decode returned %s",
              tsi_result_to_string(result));
      return result;
    }
  }
  // Append the caller's bytes as frame payload; an unfinished frame is not an
  // error, it just stays buffered.
  result =
      tsi_fake_frame_decode(unprotected_bytes, unprotected_bytes_size, frame,
                            /*error=*/nullptr);
  if (result != TSI_OK) {
    if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
    return result;
  }

  // Try to drain again.
  // decode marks a completed frame as needing draining and resets its offset
  // to 0; anything else means the frame state is inconsistent.
  if (!frame->needs_draining) return TSI_INTERNAL_ERROR;
  if (frame->offset != 0) return TSI_INTERNAL_ERROR;
  drained_size = saved_output_size - *num_bytes_written;
  result = tsi_fake_frame_encode(protected_output_frames, &drained_size, frame,
                                 /*error=*/nullptr);
  *num_bytes_written += drained_size;
  if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
  return result;
}
336 
fake_protector_protect_flush(tsi_frame_protector * self,unsigned char * protected_output_frames,size_t * protected_output_frames_size,size_t * still_pending_size)337 static tsi_result fake_protector_protect_flush(
338     tsi_frame_protector* self, unsigned char* protected_output_frames,
339     size_t* protected_output_frames_size, size_t* still_pending_size) {
340   tsi_result result = TSI_OK;
341   tsi_fake_frame_protector* impl =
342       reinterpret_cast<tsi_fake_frame_protector*>(self);
343   tsi_fake_frame* frame = &impl->protect_frame;
344   if (!frame->needs_draining) {
345     // Create a short frame.
346     frame->size = frame->offset;
347     frame->offset = 0;
348     frame->needs_draining = 1;
349     store32_little_endian(static_cast<uint32_t>(frame->size),
350                           frame->data);  // Overwrite header.
351   }
352   result = tsi_fake_frame_encode(protected_output_frames,
353                                  protected_output_frames_size, frame,
354                                  /*error=*/nullptr);
355   if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
356   *still_pending_size = frame->size - frame->offset;
357   return result;
358 }
359 
// Extracts the payloads of the fake frames in |protected_frames_bytes| into
// |unprotected_bytes|, dropping each 4-byte header.
//   *protected_frames_bytes_size: in = bytes available, out = bytes consumed.
//   *unprotected_bytes_size: in = output capacity, out = bytes written.
// A decoded frame whose payload does not fit stays buffered in
// impl->unprotect_frame and is drained by subsequent calls.
static tsi_result fake_protector_unprotect(
    tsi_frame_protector* self, const unsigned char* protected_frames_bytes,
    size_t* protected_frames_bytes_size, unsigned char* unprotected_bytes,
    size_t* unprotected_bytes_size) {
  tsi_result result = TSI_OK;
  tsi_fake_frame_protector* impl =
      reinterpret_cast<tsi_fake_frame_protector*>(self);
  tsi_fake_frame* frame = &impl->unprotect_frame;
  size_t saved_output_size = *unprotected_bytes_size;
  size_t drained_size = 0;
  size_t* num_bytes_written = unprotected_bytes_size;
  *num_bytes_written = 0;

  // Try to drain first.
  // A frame decoded by an earlier call must be fully written out before new
  // input is consumed; if it still does not fit, report zero bytes consumed.
  if (frame->needs_draining) {
    // Go past the header if needed.
    // encode copies from frame->offset, so starting at the header size skips
    // the header and yields only the payload.
    if (frame->offset == 0) frame->offset = TSI_FAKE_FRAME_HEADER_SIZE;
    drained_size = saved_output_size - *num_bytes_written;
    result = tsi_fake_frame_encode(unprotected_bytes, &drained_size, frame,
                                   /*error=*/nullptr);
    unprotected_bytes += drained_size;
    *num_bytes_written += drained_size;
    if (result != TSI_OK) {
      if (result == TSI_INCOMPLETE_DATA) {
        *protected_frames_bytes_size = 0;
        result = TSI_OK;
      }
      return result;
    }
  }

  // Now process the protected_bytes.
  // A partial frame is not an error: its bytes stay buffered in the frame.
  if (frame->needs_draining) return TSI_INTERNAL_ERROR;
  result = tsi_fake_frame_decode(protected_frames_bytes,
                                 protected_frames_bytes_size, frame,
                                 /*error=*/nullptr);
  if (result != TSI_OK) {
    if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
    return result;
  }

  // Try to drain again.
  // decode marks a completed frame as needing draining and resets its offset
  // to 0; anything else means the frame state is inconsistent.
  if (!frame->needs_draining) return TSI_INTERNAL_ERROR;
  if (frame->offset != 0) return TSI_INTERNAL_ERROR;
  frame->offset = TSI_FAKE_FRAME_HEADER_SIZE;  // Go past the header.
  drained_size = saved_output_size - *num_bytes_written;
  result = tsi_fake_frame_encode(unprotected_bytes, &drained_size, frame,
                                 /*error=*/nullptr);
  *num_bytes_written += drained_size;
  if (result == TSI_INCOMPLETE_DATA) result = TSI_OK;
  return result;
}
412 
fake_protector_destroy(tsi_frame_protector * self)413 static void fake_protector_destroy(tsi_frame_protector* self) {
414   tsi_fake_frame_protector* impl =
415       reinterpret_cast<tsi_fake_frame_protector*>(self);
416   tsi_fake_frame_destruct(&impl->protect_frame);
417   tsi_fake_frame_destruct(&impl->unprotect_frame);
418   gpr_free(self);
419 }
420 
// Dispatch table for the fake (copying) frame protector. Entries are
// positional and must follow the member order of tsi_frame_protector_vtable.
static const tsi_frame_protector_vtable frame_protector_vtable = {
    fake_protector_protect,
    fake_protector_protect_flush,
    fake_protector_unprotect,
    fake_protector_destroy,
};
427 
428 // --- tsi_zero_copy_grpc_protector methods implementation. ---
429 
fake_zero_copy_grpc_protector_protect(tsi_zero_copy_grpc_protector * self,grpc_slice_buffer * unprotected_slices,grpc_slice_buffer * protected_slices)430 static tsi_result fake_zero_copy_grpc_protector_protect(
431     tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* unprotected_slices,
432     grpc_slice_buffer* protected_slices) {
433   if (self == nullptr || unprotected_slices == nullptr ||
434       protected_slices == nullptr) {
435     return TSI_INVALID_ARGUMENT;
436   }
437   tsi_fake_zero_copy_grpc_protector* impl =
438       reinterpret_cast<tsi_fake_zero_copy_grpc_protector*>(self);
439   // Protects each frame.
440   while (unprotected_slices->length > 0) {
441     size_t frame_length =
442         std::min(impl->max_frame_size,
443                  unprotected_slices->length + TSI_FAKE_FRAME_HEADER_SIZE);
444     grpc_slice slice = GRPC_SLICE_MALLOC(TSI_FAKE_FRAME_HEADER_SIZE);
445     store32_little_endian(static_cast<uint32_t>(frame_length),
446                           GRPC_SLICE_START_PTR(slice));
447     grpc_slice_buffer_add(protected_slices, slice);
448     size_t data_length = frame_length - TSI_FAKE_FRAME_HEADER_SIZE;
449     grpc_slice_buffer_move_first(unprotected_slices, data_length,
450                                  protected_slices);
451   }
452   return TSI_OK;
453 }
454 
// Moves every complete frame buffered in impl->protected_sb (bytes from
// earlier calls plus |protected_slices|) into |unprotected_slices|, dropping
// each frame's 4-byte header. A trailing partial frame stays buffered.
// If |min_progress_size| is non-null, it receives the number of additional
// bytes needed to complete the frame being reassembled, or 1 when no frame
// header has been parsed yet.
// Returns TSI_DATA_CORRUPTED if a header announces a size that is not larger
// than the header itself.
static tsi_result fake_zero_copy_grpc_protector_unprotect(
    tsi_zero_copy_grpc_protector* self, grpc_slice_buffer* protected_slices,
    grpc_slice_buffer* unprotected_slices, int* min_progress_size) {
  if (self == nullptr || unprotected_slices == nullptr ||
      protected_slices == nullptr) {
    return TSI_INVALID_ARGUMENT;
  }
  tsi_fake_zero_copy_grpc_protector* impl =
      reinterpret_cast<tsi_fake_zero_copy_grpc_protector*>(self);
  // Add the new bytes to those left over from previous calls.
  grpc_slice_buffer_move_into(protected_slices, &impl->protected_sb);
  // Unprotect each frame, if we get a full frame.
  while (impl->protected_sb.length >= TSI_FAKE_FRAME_HEADER_SIZE) {
    // parsed_frame_size == 0 means the header at the front has not been read
    // yet; a nonzero value is remembered across calls.
    if (impl->parsed_frame_size == 0) {
      impl->parsed_frame_size = read_frame_size(&impl->protected_sb);
      if (impl->parsed_frame_size <= 4) {
        gpr_log(GPR_ERROR, "Invalid frame size.");
        return TSI_DATA_CORRUPTED;
      }
    }
    // If we do not have a full frame, return with OK status.
    if (impl->protected_sb.length < impl->parsed_frame_size) break;
    // Strips header bytes.
    grpc_slice_buffer_move_first(&impl->protected_sb,
                                 TSI_FAKE_FRAME_HEADER_SIZE, &impl->header_sb);
    // Moves data to unprotected slices.
    grpc_slice_buffer_move_first(
        &impl->protected_sb,
        impl->parsed_frame_size - TSI_FAKE_FRAME_HEADER_SIZE,
        unprotected_slices);
    impl->parsed_frame_size = 0;
    grpc_slice_buffer_reset_and_unref(&impl->header_sb);
  }
  if (min_progress_size != nullptr) {
    if (impl->parsed_frame_size > TSI_FAKE_FRAME_HEADER_SIZE) {
      *min_progress_size = impl->parsed_frame_size - impl->protected_sb.length;
    } else {
      *min_progress_size = 1;
    }
  }
  return TSI_OK;
}
496 
fake_zero_copy_grpc_protector_destroy(tsi_zero_copy_grpc_protector * self)497 static void fake_zero_copy_grpc_protector_destroy(
498     tsi_zero_copy_grpc_protector* self) {
499   if (self == nullptr) return;
500   tsi_fake_zero_copy_grpc_protector* impl =
501       reinterpret_cast<tsi_fake_zero_copy_grpc_protector*>(self);
502   grpc_slice_buffer_destroy(&impl->header_sb);
503   grpc_slice_buffer_destroy(&impl->protected_sb);
504   gpr_free(impl);
505 }
506 
fake_zero_copy_grpc_protector_max_frame_size(tsi_zero_copy_grpc_protector * self,size_t * max_frame_size)507 static tsi_result fake_zero_copy_grpc_protector_max_frame_size(
508     tsi_zero_copy_grpc_protector* self, size_t* max_frame_size) {
509   if (self == nullptr || max_frame_size == nullptr) return TSI_INVALID_ARGUMENT;
510   tsi_fake_zero_copy_grpc_protector* impl =
511       reinterpret_cast<tsi_fake_zero_copy_grpc_protector*>(self);
512   *max_frame_size = impl->max_frame_size;
513   return TSI_OK;
514 }
515 
// Dispatch table for the fake zero-copy protector. Entries are positional and
// must follow the member order of tsi_zero_copy_grpc_protector_vtable.
static const tsi_zero_copy_grpc_protector_vtable
    zero_copy_grpc_protector_vtable = {
        fake_zero_copy_grpc_protector_protect,
        fake_zero_copy_grpc_protector_unprotect,
        fake_zero_copy_grpc_protector_destroy,
        fake_zero_copy_grpc_protector_max_frame_size,
};
523 
524 // --- tsi_handshaker_result methods implementation. ---
525 
// Result of a completed fake handshake.
struct fake_handshaker_result {
  // Must stay the first member: tsi_handshaker_result* is reinterpret_cast to
  // this type.
  tsi_handshaker_result base;
  // Owned copy of bytes received past the end of the handshake (may be null
  // when unused_bytes_size is 0); freed by fake_handshaker_result_destroy.
  unsigned char* unused_bytes;
  size_t unused_bytes_size;
};
531 
fake_handshaker_result_extract_peer(const tsi_handshaker_result *,tsi_peer * peer)532 static tsi_result fake_handshaker_result_extract_peer(
533     const tsi_handshaker_result* /*self*/, tsi_peer* peer) {
534   // Construct a tsi_peer with 1 property: certificate type, security_level.
535   tsi_result result = tsi_construct_peer(2, peer);
536   if (result != TSI_OK) return result;
537   result = tsi_construct_string_peer_property_from_cstring(
538       TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_FAKE_CERTIFICATE_TYPE,
539       &peer->properties[0]);
540   if (result != TSI_OK) tsi_peer_destruct(peer);
541   result = tsi_construct_string_peer_property_from_cstring(
542       TSI_SECURITY_LEVEL_PEER_PROPERTY,
543       tsi_security_level_to_string(TSI_SECURITY_NONE), &peer->properties[1]);
544   if (result != TSI_OK) tsi_peer_destruct(peer);
545   return result;
546 }
547 
fake_handshaker_result_get_frame_protector_type(const tsi_handshaker_result *,tsi_frame_protector_type * frame_protector_type)548 static tsi_result fake_handshaker_result_get_frame_protector_type(
549     const tsi_handshaker_result* /*self*/,
550     tsi_frame_protector_type* frame_protector_type) {
551   *frame_protector_type = TSI_FRAME_PROTECTOR_NORMAL_OR_ZERO_COPY;
552   return TSI_OK;
553 }
554 
fake_handshaker_result_create_zero_copy_grpc_protector(const tsi_handshaker_result *,size_t * max_output_protected_frame_size,tsi_zero_copy_grpc_protector ** protector)555 static tsi_result fake_handshaker_result_create_zero_copy_grpc_protector(
556     const tsi_handshaker_result* /*self*/,
557     size_t* max_output_protected_frame_size,
558     tsi_zero_copy_grpc_protector** protector) {
559   *protector =
560       tsi_create_fake_zero_copy_grpc_protector(max_output_protected_frame_size);
561   return TSI_OK;
562 }
563 
fake_handshaker_result_create_frame_protector(const tsi_handshaker_result *,size_t * max_output_protected_frame_size,tsi_frame_protector ** protector)564 static tsi_result fake_handshaker_result_create_frame_protector(
565     const tsi_handshaker_result* /*self*/,
566     size_t* max_output_protected_frame_size, tsi_frame_protector** protector) {
567   *protector = tsi_create_fake_frame_protector(max_output_protected_frame_size);
568   return TSI_OK;
569 }
570 
fake_handshaker_result_get_unused_bytes(const tsi_handshaker_result * self,const unsigned char ** bytes,size_t * bytes_size)571 static tsi_result fake_handshaker_result_get_unused_bytes(
572     const tsi_handshaker_result* self, const unsigned char** bytes,
573     size_t* bytes_size) {
574   fake_handshaker_result* result = reinterpret_cast<fake_handshaker_result*>(
575       const_cast<tsi_handshaker_result*>(self));
576   *bytes_size = result->unused_bytes_size;
577   *bytes = result->unused_bytes;
578   return TSI_OK;
579 }
580 
fake_handshaker_result_destroy(tsi_handshaker_result * self)581 static void fake_handshaker_result_destroy(tsi_handshaker_result* self) {
582   fake_handshaker_result* result =
583       reinterpret_cast<fake_handshaker_result*>(self);
584   gpr_free(result->unused_bytes);
585   gpr_free(self);
586 }
587 
// Dispatch table for fake handshaker results. Entries are positional and must
// follow the member order of tsi_handshaker_result_vtable.
static const tsi_handshaker_result_vtable handshaker_result_vtable = {
    fake_handshaker_result_extract_peer,
    fake_handshaker_result_get_frame_protector_type,
    fake_handshaker_result_create_zero_copy_grpc_protector,
    fake_handshaker_result_create_frame_protector,
    fake_handshaker_result_get_unused_bytes,
    fake_handshaker_result_destroy,
};
596 
fake_handshaker_result_create(const unsigned char * unused_bytes,size_t unused_bytes_size,tsi_handshaker_result ** handshaker_result,std::string * error)597 static tsi_result fake_handshaker_result_create(
598     const unsigned char* unused_bytes, size_t unused_bytes_size,
599     tsi_handshaker_result** handshaker_result, std::string* error) {
600   if ((unused_bytes_size > 0 && unused_bytes == nullptr) ||
601       handshaker_result == nullptr) {
602     if (error != nullptr) *error = "invalid argument";
603     return TSI_INVALID_ARGUMENT;
604   }
605   fake_handshaker_result* result = grpc_core::Zalloc<fake_handshaker_result>();
606   result->base.vtable = &handshaker_result_vtable;
607   if (unused_bytes_size > 0) {
608     result->unused_bytes =
609         static_cast<unsigned char*>(gpr_malloc(unused_bytes_size));
610     memcpy(result->unused_bytes, unused_bytes, unused_bytes_size);
611   }
612   result->unused_bytes_size = unused_bytes_size;
613   *handshaker_result = &result->base;
614   return TSI_OK;
615 }
616 
617 // --- tsi_handshaker methods implementation. ---
618 
fake_handshaker_get_bytes_to_send_to_peer(tsi_handshaker * self,unsigned char * bytes,size_t * bytes_size,std::string * error)619 static tsi_result fake_handshaker_get_bytes_to_send_to_peer(
620     tsi_handshaker* self, unsigned char* bytes, size_t* bytes_size,
621     std::string* error) {
622   tsi_fake_handshaker* impl = reinterpret_cast<tsi_fake_handshaker*>(self);
623   tsi_result result = TSI_OK;
624   if (impl->needs_incoming_message || impl->result == TSI_OK) {
625     *bytes_size = 0;
626     return TSI_OK;
627   }
628   if (!impl->outgoing_frame.needs_draining) {
629     tsi_fake_handshake_message next_message_to_send =
630         // NOLINTNEXTLINE(bugprone-misplaced-widening-cast)
631         static_cast<tsi_fake_handshake_message>(impl->next_message_to_send + 2);
632     const char* msg_string =
633         tsi_fake_handshake_message_to_string(impl->next_message_to_send);
634     tsi_fake_frame_set_data(
635         reinterpret_cast<unsigned char*>(const_cast<char*>(msg_string)),
636         strlen(msg_string), &impl->outgoing_frame);
637     if (next_message_to_send > TSI_FAKE_HANDSHAKE_MESSAGE_MAX) {
638       next_message_to_send = TSI_FAKE_HANDSHAKE_MESSAGE_MAX;
639     }
640     if (GRPC_TRACE_FLAG_ENABLED(tsi_tracing_enabled)) {
641       gpr_log(GPR_INFO, "%s prepared %s.",
642               impl->is_client ? "Client" : "Server",
643               tsi_fake_handshake_message_to_string(impl->next_message_to_send));
644     }
645     impl->next_message_to_send = next_message_to_send;
646   }
647   result =
648       tsi_fake_frame_encode(bytes, bytes_size, &impl->outgoing_frame, error);
649   if (result != TSI_OK) return result;
650   if (!impl->is_client &&
651       impl->next_message_to_send == TSI_FAKE_HANDSHAKE_MESSAGE_MAX) {
652     // We're done.
653     if (GRPC_TRACE_FLAG_ENABLED(tsi_tracing_enabled)) {
654       gpr_log(GPR_INFO, "Server is done.");
655     }
656     impl->result = TSI_OK;
657   } else {
658     impl->needs_incoming_message = 1;
659   }
660   return TSI_OK;
661 }
662 
fake_handshaker_process_bytes_from_peer(tsi_handshaker * self,const unsigned char * bytes,size_t * bytes_size,std::string * error)663 static tsi_result fake_handshaker_process_bytes_from_peer(
664     tsi_handshaker* self, const unsigned char* bytes, size_t* bytes_size,
665     std::string* error) {
666   tsi_result result = TSI_OK;
667   tsi_fake_handshaker* impl = reinterpret_cast<tsi_fake_handshaker*>(self);
668   tsi_fake_handshake_message expected_msg =
669       static_cast<tsi_fake_handshake_message>(impl->next_message_to_send - 1);
670   tsi_fake_handshake_message received_msg;
671 
672   if (!impl->needs_incoming_message || impl->result == TSI_OK) {
673     *bytes_size = 0;
674     return TSI_OK;
675   }
676   result =
677       tsi_fake_frame_decode(bytes, bytes_size, &impl->incoming_frame, error);
678   if (result != TSI_OK) return result;
679 
680   // We now have a complete frame.
681   result = tsi_fake_handshake_message_from_string(
682       reinterpret_cast<const char*>(impl->incoming_frame.data) +
683           TSI_FAKE_FRAME_HEADER_SIZE,
684       &received_msg, error);
685   if (result != TSI_OK) {
686     impl->result = result;
687     return result;
688   }
689   if (received_msg != expected_msg) {
690     gpr_log(GPR_ERROR, "Invalid received message (%s instead of %s)",
691             tsi_fake_handshake_message_to_string(received_msg),
692             tsi_fake_handshake_message_to_string(expected_msg));
693   }
694   if (GRPC_TRACE_FLAG_ENABLED(tsi_tracing_enabled)) {
695     gpr_log(GPR_INFO, "%s received %s.", impl->is_client ? "Client" : "Server",
696             tsi_fake_handshake_message_to_string(received_msg));
697   }
698   tsi_fake_frame_reset(&impl->incoming_frame, 0 /* needs_draining */);
699   impl->needs_incoming_message = 0;
700   if (impl->next_message_to_send == TSI_FAKE_HANDSHAKE_MESSAGE_MAX) {
701     // We're done.
702     if (GRPC_TRACE_FLAG_ENABLED(tsi_tracing_enabled)) {
703       gpr_log(GPR_INFO, "%s is done.", impl->is_client ? "Client" : "Server");
704     }
705     impl->result = TSI_OK;
706   }
707   return TSI_OK;
708 }
709 
fake_handshaker_get_result(tsi_handshaker * self)710 static tsi_result fake_handshaker_get_result(tsi_handshaker* self) {
711   tsi_fake_handshaker* impl = reinterpret_cast<tsi_fake_handshaker*>(self);
712   return impl->result;
713 }
714 
fake_handshaker_destroy(tsi_handshaker * self)715 static void fake_handshaker_destroy(tsi_handshaker* self) {
716   tsi_fake_handshaker* impl = reinterpret_cast<tsi_fake_handshaker*>(self);
717   tsi_fake_frame_destruct(&impl->incoming_frame);
718   tsi_fake_frame_destruct(&impl->outgoing_frame);
719   gpr_free(impl->outgoing_bytes_buffer);
720   gpr_free(self);
721 }
722 
fake_handshaker_next(tsi_handshaker * self,const unsigned char * received_bytes,size_t received_bytes_size,const unsigned char ** bytes_to_send,size_t * bytes_to_send_size,tsi_handshaker_result ** handshaker_result,tsi_handshaker_on_next_done_cb,void *,std::string * error)723 static tsi_result fake_handshaker_next(
724     tsi_handshaker* self, const unsigned char* received_bytes,
725     size_t received_bytes_size, const unsigned char** bytes_to_send,
726     size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result,
727     tsi_handshaker_on_next_done_cb /*cb*/, void* /*user_data*/,
728     std::string* error) {
729   // Sanity check the arguments.
730   if ((received_bytes_size > 0 && received_bytes == nullptr) ||
731       bytes_to_send == nullptr || bytes_to_send_size == nullptr ||
732       handshaker_result == nullptr) {
733     if (error != nullptr) *error = "invalid argument";
734     return TSI_INVALID_ARGUMENT;
735   }
736   tsi_fake_handshaker* handshaker =
737       reinterpret_cast<tsi_fake_handshaker*>(self);
738   tsi_result result = TSI_OK;
739 
740   // Decode and process a handshake frame from the peer.
741   size_t consumed_bytes_size = received_bytes_size;
742   if (received_bytes_size > 0) {
743     result = fake_handshaker_process_bytes_from_peer(
744         self, received_bytes, &consumed_bytes_size, error);
745     if (result != TSI_OK) return result;
746   }
747 
748   // Create a handshake message to send to the peer and encode it as a fake
749   // frame.
750   size_t offset = 0;
751   do {
752     size_t sent_bytes_size = handshaker->outgoing_bytes_buffer_size - offset;
753     result = fake_handshaker_get_bytes_to_send_to_peer(
754         self, handshaker->outgoing_bytes_buffer + offset, &sent_bytes_size,
755         error);
756     offset += sent_bytes_size;
757     if (result == TSI_INCOMPLETE_DATA) {
758       handshaker->outgoing_bytes_buffer_size *= 2;
759       handshaker->outgoing_bytes_buffer = static_cast<unsigned char*>(
760           gpr_realloc(handshaker->outgoing_bytes_buffer,
761                       handshaker->outgoing_bytes_buffer_size));
762     }
763   } while (result == TSI_INCOMPLETE_DATA);
764   if (result != TSI_OK) return result;
765   *bytes_to_send = handshaker->outgoing_bytes_buffer;
766   *bytes_to_send_size = offset;
767 
768   // Check if the handshake was completed.
769   if (fake_handshaker_get_result(self) == TSI_HANDSHAKE_IN_PROGRESS) {
770     *handshaker_result = nullptr;
771   } else {
772     // Calculate the unused bytes.
773     const unsigned char* unused_bytes = nullptr;
774     size_t unused_bytes_size = received_bytes_size - consumed_bytes_size;
775     if (unused_bytes_size > 0) {
776       unused_bytes = received_bytes + consumed_bytes_size;
777     }
778 
779     // Create a handshaker_result containing the unused bytes.
780     result = fake_handshaker_result_create(unused_bytes, unused_bytes_size,
781                                            handshaker_result, error);
782     if (result == TSI_OK) {
783       // Indicate that the handshake has completed and that a handshaker_result
784       // has been created.
785       self->handshaker_result_created = true;
786     }
787   }
788   return result;
789 }
790 
791 static const tsi_handshaker_vtable handshaker_vtable = {
792     nullptr,  // get_bytes_to_send_to_peer -- deprecated
793     nullptr,  // process_bytes_from_peer   -- deprecated
794     nullptr,  // get_result                -- deprecated
795     nullptr,  // extract_peer              -- deprecated
796     nullptr,  // create_frame_protector    -- deprecated
797     fake_handshaker_destroy,
798     fake_handshaker_next,
799     nullptr,  // shutdown
800 };
801 
tsi_create_fake_handshaker(int is_client)802 tsi_handshaker* tsi_create_fake_handshaker(int is_client) {
803   tsi_fake_handshaker* impl = grpc_core::Zalloc<tsi_fake_handshaker>();
804   impl->base.vtable = &handshaker_vtable;
805   impl->is_client = is_client;
806   impl->result = TSI_HANDSHAKE_IN_PROGRESS;
807   impl->outgoing_bytes_buffer_size =
808       TSI_FAKE_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE;
809   impl->outgoing_bytes_buffer =
810       static_cast<unsigned char*>(gpr_malloc(impl->outgoing_bytes_buffer_size));
811   if (is_client) {
812     impl->needs_incoming_message = 0;
813     impl->next_message_to_send = TSI_FAKE_CLIENT_INIT;
814   } else {
815     impl->needs_incoming_message = 1;
816     impl->next_message_to_send = TSI_FAKE_SERVER_INIT;
817   }
818   return &impl->base;
819 }
820 
tsi_create_fake_frame_protector(size_t * max_protected_frame_size)821 tsi_frame_protector* tsi_create_fake_frame_protector(
822     size_t* max_protected_frame_size) {
823   tsi_fake_frame_protector* impl =
824       grpc_core::Zalloc<tsi_fake_frame_protector>();
825   impl->max_frame_size = (max_protected_frame_size == nullptr)
826                              ? TSI_FAKE_DEFAULT_FRAME_SIZE
827                              : *max_protected_frame_size;
828   impl->base.vtable = &frame_protector_vtable;
829   return &impl->base;
830 }
831 
tsi_create_fake_zero_copy_grpc_protector(size_t * max_protected_frame_size)832 tsi_zero_copy_grpc_protector* tsi_create_fake_zero_copy_grpc_protector(
833     size_t* max_protected_frame_size) {
834   tsi_fake_zero_copy_grpc_protector* impl =
835       static_cast<tsi_fake_zero_copy_grpc_protector*>(
836           gpr_zalloc(sizeof(*impl)));
837   grpc_slice_buffer_init(&impl->header_sb);
838   grpc_slice_buffer_init(&impl->protected_sb);
839   impl->max_frame_size = (max_protected_frame_size == nullptr)
840                              ? TSI_FAKE_DEFAULT_FRAME_SIZE
841                              : *max_protected_frame_size;
842   impl->parsed_frame_size = 0;
843   impl->base.vtable = &zero_copy_grpc_protector_vtable;
844   return &impl->base;
845 }
846