1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include <grpc/support/port_platform.h>
20 
21 #include "src/core/tsi/alts/frame_protector/alts_frame_protector.h"
22 
23 #include <stdio.h>
24 #include <stdlib.h>
25 
26 #include <algorithm>
27 
28 #include <grpc/support/alloc.h>
29 #include <grpc/support/log.h>
30 
31 #include "src/core/lib/gpr/useful.h"
32 #include "src/core/lib/gprpp/crash.h"
33 #include "src/core/lib/gprpp/memory.h"
34 #include "src/core/tsi/alts/crypt/gsec.h"
35 #include "src/core/tsi/alts/frame_protector/alts_crypter.h"
36 #include "src/core/tsi/alts/frame_protector/frame_handler.h"
37 #include "src/core/tsi/transport_security.h"
38 
// Frame sizes in bytes. A caller-supplied maximum protected frame size is
// clamped into [kMinFrameLength, kMaxFrameLength] when a frame protector is
// created; kDefaultFrameLength is used when the caller supplies none.
constexpr size_t kMinFrameLength = 1024;
constexpr size_t kDefaultFrameLength = 16 * 1024;
constexpr size_t kMaxFrameLength = 1024 * 1024;

// Limit k on number of frames such that at most 2^(8 * k) frames can be sent.
// The rekey limit is the overflow size passed to the ALTS crypters when the
// protector is created with is_rekey set; the other limit is used otherwise.
constexpr size_t kAltsRecordProtocolRekeyFrameLimit = 8;
constexpr size_t kAltsRecordProtocolFrameLimit = 5;
46 
// Main struct for alts_frame_protector.
struct alts_frame_protector {
  // Base TSI object. It is the first member because the functions in this
  // file cast a tsi_frame_protector* back to an alts_frame_protector*.
  tsi_frame_protector base;
  // Crypter used by seal() on the protect path.
  alts_crypter* seal_crypter;
  // Crypter used by unseal() on the unprotect path.
  alts_crypter* unseal_crypter;
  // Frame writer that emits the sealed frame held in in_place_protect_buffer.
  alts_frame_writer* writer;
  // Frame reader that deframes incoming bytes into in_place_unprotect_buffer.
  alts_frame_reader* reader;
  // Buffer in which outgoing payload is accumulated and then sealed in place.
  unsigned char* in_place_protect_buffer;
  // Buffer in which an incoming frame is collected and then unsealed in place.
  // It is replaced by a larger allocation when a frame does not fit.
  unsigned char* in_place_unprotect_buffer;
  // Number of bytes currently held in in_place_protect_buffer. After seal()
  // it is the size reported by the crypter for the sealed data.
  size_t in_place_protect_bytes_buffered;
  // Number of bytes of the current unsealed frame already copied out to the
  // caller of alts_unprotect().
  size_t in_place_unprotect_bytes_processed;
  // Size in bytes of in_place_protect_buffer.
  size_t max_protected_frame_size;
  // Current size in bytes of in_place_unprotect_buffer; updated whenever that
  // buffer is reallocated.
  size_t max_unprotected_frame_size;
  // Value returned by alts_crypter_num_overhead_bytes() for the seal crypter.
  size_t overhead_length;
  // Not read or written by the code shown here.
  size_t counter_overflow;
};
63 
seal(alts_frame_protector * impl)64 static tsi_result seal(alts_frame_protector* impl) {
65   char* error_details = nullptr;
66   size_t output_size = 0;
67   grpc_status_code status = alts_crypter_process_in_place(
68       impl->seal_crypter, impl->in_place_protect_buffer,
69       impl->max_protected_frame_size, impl->in_place_protect_bytes_buffered,
70       &output_size, &error_details);
71   impl->in_place_protect_bytes_buffered = output_size;
72   if (status != GRPC_STATUS_OK) {
73     gpr_log(GPR_ERROR, "%s", error_details);
74     gpr_free(error_details);
75     return TSI_INTERNAL_ERROR;
76   }
77   return TSI_OK;
78 }
79 
max_encrypted_payload_bytes(alts_frame_protector * impl)80 static size_t max_encrypted_payload_bytes(alts_frame_protector* impl) {
81   return impl->max_protected_frame_size - kFrameHeaderSize;
82 }
83 
alts_protect_flush(tsi_frame_protector * self,unsigned char * protected_output_frames,size_t * protected_output_frames_size,size_t * still_pending_size)84 static tsi_result alts_protect_flush(tsi_frame_protector* self,
85                                      unsigned char* protected_output_frames,
86                                      size_t* protected_output_frames_size,
87                                      size_t* still_pending_size) {
88   if (self == nullptr || protected_output_frames == nullptr ||
89       protected_output_frames_size == nullptr ||
90       still_pending_size == nullptr) {
91     gpr_log(GPR_ERROR, "Invalid nullptr arguments to alts_protect_flush().");
92     return TSI_INVALID_ARGUMENT;
93   }
94   alts_frame_protector* impl = reinterpret_cast<alts_frame_protector*>(self);
95   ///
96   /// If there's nothing to flush (i.e., in_place_protect_buffer is empty),
97   /// we're done.
98   ///
99   if (impl->in_place_protect_bytes_buffered == 0) {
100     *protected_output_frames_size = 0;
101     *still_pending_size = 0;
102     return TSI_OK;
103   }
104   ///
105   /// If a new frame can start being processed, we encrypt the payload and reset
106   /// the frame writer to point to in_place_protect_buffer that holds the newly
107   /// sealed frame.
108   ///
109   if (alts_is_frame_writer_done(impl->writer)) {
110     tsi_result result = seal(impl);
111     if (result != TSI_OK) {
112       return result;
113     }
114     if (!alts_reset_frame_writer(impl->writer, impl->in_place_protect_buffer,
115                                  impl->in_place_protect_bytes_buffered)) {
116       gpr_log(GPR_ERROR, "Couldn't reset frame writer.");
117       return TSI_INTERNAL_ERROR;
118     }
119   }
120   ///
121   /// Write the sealed frame as much as possible to protected_output_frames.
122   /// It's possible a frame will not be written out completely by a single flush
123   ///(i.e., still_pending_size != 0), in which case the flush should be called
124   /// iteratively until a complete frame has been written out.
125   ///
126   size_t written_frame_bytes = *protected_output_frames_size;
127   if (!alts_write_frame_bytes(impl->writer, protected_output_frames,
128                               &written_frame_bytes)) {
129     gpr_log(GPR_ERROR, "Couldn't write frame bytes.");
130     return TSI_INTERNAL_ERROR;
131   }
132   *protected_output_frames_size = written_frame_bytes;
133   *still_pending_size = alts_get_num_writer_bytes_remaining(impl->writer);
134   ///
135   /// If the current frame has been finished processing (i.e., sealed and
136   /// written out completely), we empty in_place_protect_buffer.
137   ///
138   if (alts_is_frame_writer_done(impl->writer)) {
139     impl->in_place_protect_bytes_buffered = 0;
140   }
141   return TSI_OK;
142 }
143 
alts_protect(tsi_frame_protector * self,const unsigned char * unprotected_bytes,size_t * unprotected_bytes_size,unsigned char * protected_output_frames,size_t * protected_output_frames_size)144 static tsi_result alts_protect(tsi_frame_protector* self,
145                                const unsigned char* unprotected_bytes,
146                                size_t* unprotected_bytes_size,
147                                unsigned char* protected_output_frames,
148                                size_t* protected_output_frames_size) {
149   if (self == nullptr || unprotected_bytes == nullptr ||
150       unprotected_bytes_size == nullptr || protected_output_frames == nullptr ||
151       protected_output_frames_size == nullptr) {
152     gpr_log(GPR_ERROR, "Invalid nullptr arguments to alts_protect().");
153     return TSI_INVALID_ARGUMENT;
154   }
155   alts_frame_protector* impl = reinterpret_cast<alts_frame_protector*>(self);
156 
157   ///
158   /// If more payload can be buffered, we buffer it as much as possible to
159   /// in_place_protect_buffer.
160   ///
161   if (impl->in_place_protect_bytes_buffered + impl->overhead_length <
162       max_encrypted_payload_bytes(impl)) {
163     size_t bytes_to_buffer = std::min(
164         *unprotected_bytes_size, max_encrypted_payload_bytes(impl) -
165                                      impl->in_place_protect_bytes_buffered -
166                                      impl->overhead_length);
167     *unprotected_bytes_size = bytes_to_buffer;
168     if (bytes_to_buffer > 0) {
169       memcpy(
170           impl->in_place_protect_buffer + impl->in_place_protect_bytes_buffered,
171           unprotected_bytes, bytes_to_buffer);
172       impl->in_place_protect_bytes_buffered += bytes_to_buffer;
173     }
174   } else {
175     *unprotected_bytes_size = 0;
176   }
177   ///
178   /// If a full frame has been buffered, we output it. If the first condition
179   /// holds, then there exists an unencrypted full frame. If the second
180   /// condition holds, then there exists a full frame that has already been
181   /// encrypted.
182   ///
183   if (max_encrypted_payload_bytes(impl) ==
184           impl->in_place_protect_bytes_buffered + impl->overhead_length ||
185       max_encrypted_payload_bytes(impl) ==
186           impl->in_place_protect_bytes_buffered) {
187     size_t still_pending_size = 0;
188     return alts_protect_flush(self, protected_output_frames,
189                               protected_output_frames_size,
190                               &still_pending_size);
191   } else {
192     *protected_output_frames_size = 0;
193     return TSI_OK;
194   }
195 }
196 
unseal(alts_frame_protector * impl)197 static tsi_result unseal(alts_frame_protector* impl) {
198   char* error_details = nullptr;
199   size_t output_size = 0;
200   grpc_status_code status = alts_crypter_process_in_place(
201       impl->unseal_crypter, impl->in_place_unprotect_buffer,
202       impl->max_unprotected_frame_size,
203       alts_get_output_bytes_read(impl->reader), &output_size, &error_details);
204   if (status != GRPC_STATUS_OK) {
205     gpr_log(GPR_ERROR, "%s", error_details);
206     gpr_free(error_details);
207     return TSI_DATA_CORRUPTED;
208   }
209   return TSI_OK;
210 }
211 
ensure_buffer_size(alts_frame_protector * impl)212 static void ensure_buffer_size(alts_frame_protector* impl) {
213   if (!alts_has_read_frame_length(impl->reader)) {
214     return;
215   }
216   size_t buffer_space_remaining = impl->max_unprotected_frame_size -
217                                   alts_get_output_bytes_read(impl->reader);
218   ///
219   /// Check if we need to resize in_place_unprotect_buffer in order to hold
220   /// remaining bytes of a full frame.
221   ///
222   if (buffer_space_remaining < alts_get_reader_bytes_remaining(impl->reader)) {
223     size_t buffer_len = alts_get_output_bytes_read(impl->reader) +
224                         alts_get_reader_bytes_remaining(impl->reader);
225     unsigned char* buffer = static_cast<unsigned char*>(gpr_malloc(buffer_len));
226     memcpy(buffer, impl->in_place_unprotect_buffer,
227            alts_get_output_bytes_read(impl->reader));
228     impl->max_unprotected_frame_size = buffer_len;
229     gpr_free(impl->in_place_unprotect_buffer);
230     impl->in_place_unprotect_buffer = buffer;
231     alts_reset_reader_output_buffer(
232         impl->reader, buffer + alts_get_output_bytes_read(impl->reader));
233   }
234 }
235 
alts_unprotect(tsi_frame_protector * self,const unsigned char * protected_frames_bytes,size_t * protected_frames_bytes_size,unsigned char * unprotected_bytes,size_t * unprotected_bytes_size)236 static tsi_result alts_unprotect(tsi_frame_protector* self,
237                                  const unsigned char* protected_frames_bytes,
238                                  size_t* protected_frames_bytes_size,
239                                  unsigned char* unprotected_bytes,
240                                  size_t* unprotected_bytes_size) {
241   if (self == nullptr || protected_frames_bytes == nullptr ||
242       protected_frames_bytes_size == nullptr || unprotected_bytes == nullptr ||
243       unprotected_bytes_size == nullptr) {
244     gpr_log(GPR_ERROR, "Invalid nullptr arguments to alts_unprotect().");
245     return TSI_INVALID_ARGUMENT;
246   }
247   alts_frame_protector* impl = reinterpret_cast<alts_frame_protector*>(self);
248   ///
249   /// If a new frame can start being processed, we reset the frame reader to
250   /// point to in_place_unprotect_buffer that will be used to hold deframed
251   /// result.
252   ///
253   if (alts_is_frame_reader_done(impl->reader) &&
254       ((alts_get_output_buffer(impl->reader) == nullptr) ||
255        (alts_get_output_bytes_read(impl->reader) ==
256         impl->in_place_unprotect_bytes_processed + impl->overhead_length))) {
257     if (!alts_reset_frame_reader(impl->reader,
258                                  impl->in_place_unprotect_buffer)) {
259       gpr_log(GPR_ERROR, "Couldn't reset frame reader.");
260       return TSI_INTERNAL_ERROR;
261     }
262     impl->in_place_unprotect_bytes_processed = 0;
263   }
264   ///
265   /// If a full frame has not yet been read, we read more bytes from
266   /// protected_frames_bytes until a full frame has been read. We also need to
267   /// make sure in_place_unprotect_buffer is large enough to hold a complete
268   /// frame.
269   ///
270   if (!alts_is_frame_reader_done(impl->reader)) {
271     ensure_buffer_size(impl);
272     *protected_frames_bytes_size =
273         std::min(impl->max_unprotected_frame_size -
274                      alts_get_output_bytes_read(impl->reader),
275                  *protected_frames_bytes_size);
276     size_t read_frames_bytes_size = *protected_frames_bytes_size;
277     if (!alts_read_frame_bytes(impl->reader, protected_frames_bytes,
278                                &read_frames_bytes_size)) {
279       gpr_log(GPR_ERROR, "Failed to process frame.");
280       return TSI_INTERNAL_ERROR;
281     }
282     *protected_frames_bytes_size = read_frames_bytes_size;
283   } else {
284     *protected_frames_bytes_size = 0;
285   }
286   ///
287   /// If a full frame has been read, we unseal it, and write out the
288   /// deframed result to unprotected_bytes.
289   ///
290   if (alts_is_frame_reader_done(impl->reader)) {
291     if (impl->in_place_unprotect_bytes_processed == 0) {
292       tsi_result result = unseal(impl);
293       if (result != TSI_OK) {
294         return result;
295       }
296     }
297     size_t bytes_to_write = std::min(
298         *unprotected_bytes_size, alts_get_output_bytes_read(impl->reader) -
299                                      impl->in_place_unprotect_bytes_processed -
300                                      impl->overhead_length);
301     if (bytes_to_write > 0) {
302       memcpy(unprotected_bytes,
303              impl->in_place_unprotect_buffer +
304                  impl->in_place_unprotect_bytes_processed,
305              bytes_to_write);
306     }
307     *unprotected_bytes_size = bytes_to_write;
308     impl->in_place_unprotect_bytes_processed += bytes_to_write;
309     return TSI_OK;
310   } else {
311     *unprotected_bytes_size = 0;
312     return TSI_OK;
313   }
314 }
315 
alts_destroy(tsi_frame_protector * self)316 static void alts_destroy(tsi_frame_protector* self) {
317   alts_frame_protector* impl = reinterpret_cast<alts_frame_protector*>(self);
318   if (impl != nullptr) {
319     alts_crypter_destroy(impl->seal_crypter);
320     alts_crypter_destroy(impl->unseal_crypter);
321     gpr_free(impl->in_place_protect_buffer);
322     gpr_free(impl->in_place_unprotect_buffer);
323     alts_destroy_frame_writer(impl->writer);
324     alts_destroy_frame_reader(impl->reader);
325     gpr_free(impl);
326   }
327 }
328 
// Function table installed as base.vtable on every ALTS frame protector
// created in this file. The entries are, in order: protect, protect-flush,
// unprotect and destroy.
static const tsi_frame_protector_vtable alts_frame_protector_vtable = {
    alts_protect, alts_protect_flush, alts_unprotect, alts_destroy};
331 
create_alts_crypters(const uint8_t * key,size_t key_size,bool is_client,bool is_rekey,alts_frame_protector * impl,char ** error_details)332 static grpc_status_code create_alts_crypters(const uint8_t* key,
333                                              size_t key_size, bool is_client,
334                                              bool is_rekey,
335                                              alts_frame_protector* impl,
336                                              char** error_details) {
337   grpc_status_code status;
338   gsec_aead_crypter* aead_crypter_seal = nullptr;
339   gsec_aead_crypter* aead_crypter_unseal = nullptr;
340   status = gsec_aes_gcm_aead_crypter_create(key, key_size, kAesGcmNonceLength,
341                                             kAesGcmTagLength, is_rekey,
342                                             &aead_crypter_seal, error_details);
343   if (status != GRPC_STATUS_OK) {
344     return status;
345   }
346   status = gsec_aes_gcm_aead_crypter_create(
347       key, key_size, kAesGcmNonceLength, kAesGcmTagLength, is_rekey,
348       &aead_crypter_unseal, error_details);
349   if (status != GRPC_STATUS_OK) {
350     return status;
351   }
352   size_t overflow_size = is_rekey ? kAltsRecordProtocolRekeyFrameLimit
353                                   : kAltsRecordProtocolFrameLimit;
354   status = alts_seal_crypter_create(aead_crypter_seal, is_client, overflow_size,
355                                     &impl->seal_crypter, error_details);
356   if (status != GRPC_STATUS_OK) {
357     return status;
358   }
359   status =
360       alts_unseal_crypter_create(aead_crypter_unseal, is_client, overflow_size,
361                                  &impl->unseal_crypter, error_details);
362   return status;
363 }
364 
alts_create_frame_protector(const uint8_t * key,size_t key_size,bool is_client,bool is_rekey,size_t * max_protected_frame_size,tsi_frame_protector ** self)365 tsi_result alts_create_frame_protector(const uint8_t* key, size_t key_size,
366                                        bool is_client, bool is_rekey,
367                                        size_t* max_protected_frame_size,
368                                        tsi_frame_protector** self) {
369   if (key == nullptr || self == nullptr) {
370     gpr_log(GPR_ERROR,
371             "Invalid nullptr arguments to alts_create_frame_protector().");
372     return TSI_INTERNAL_ERROR;
373   }
374   char* error_details = nullptr;
375   alts_frame_protector* impl = grpc_core::Zalloc<alts_frame_protector>();
376   grpc_status_code status = create_alts_crypters(
377       key, key_size, is_client, is_rekey, impl, &error_details);
378   if (status != GRPC_STATUS_OK) {
379     gpr_log(GPR_ERROR, "Failed to create ALTS crypters, %s.", error_details);
380     gpr_free(error_details);
381     return TSI_INTERNAL_ERROR;
382   }
383   ///
384   /// Set maximum frame size to be used by a frame protector. If it is nullptr,
385   /// a default frame size will be used. Otherwise, the provided frame size will
386   /// be adjusted (if not falling into a valid frame range) and used.
387   ///
388   size_t max_protected_frame_size_to_set = kDefaultFrameLength;
389   if (max_protected_frame_size != nullptr) {
390     *max_protected_frame_size =
391         std::min(*max_protected_frame_size, kMaxFrameLength);
392     *max_protected_frame_size =
393         std::max(*max_protected_frame_size, kMinFrameLength);
394     max_protected_frame_size_to_set = *max_protected_frame_size;
395   }
396   impl->max_protected_frame_size = max_protected_frame_size_to_set;
397   impl->max_unprotected_frame_size = max_protected_frame_size_to_set;
398   impl->in_place_protect_bytes_buffered = 0;
399   impl->in_place_unprotect_bytes_processed = 0;
400   impl->in_place_protect_buffer = static_cast<unsigned char*>(
401       gpr_malloc(sizeof(unsigned char) * max_protected_frame_size_to_set));
402   impl->in_place_unprotect_buffer = static_cast<unsigned char*>(
403       gpr_malloc(sizeof(unsigned char) * max_protected_frame_size_to_set));
404   impl->overhead_length = alts_crypter_num_overhead_bytes(impl->seal_crypter);
405   impl->writer = alts_create_frame_writer();
406   impl->reader = alts_create_frame_reader();
407   impl->base.vtable = &alts_frame_protector_vtable;
408   *self = &impl->base;
409   return TSI_OK;
410 }
411