1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include <grpc/support/port_platform.h>
20 
21 #include "src/core/tsi/alts/frame_protector/frame_handler.h"
22 
23 #include <limits.h>
24 #include <stdint.h>
25 #include <string.h>
26 
27 #include <algorithm>
28 
29 #include <grpc/support/alloc.h>
30 #include <grpc/support/log.h>
31 
32 #include "src/core/lib/gprpp/crash.h"
33 #include "src/core/lib/gprpp/memory.h"
34 
// Decodes the four bytes starting at `buffer` as a little-endian uint32_t:
// buffer[0] is the least significant byte and buffer[3] the most significant.
static uint32_t load_32_le(const unsigned char* buffer) {
  uint32_t value = 0;
  // Walk from the most significant byte down, shifting the accumulator left
  // by one byte before merging in the next lower byte.
  for (int index = 3; index >= 0; --index) {
    value = (value << 8) | static_cast<uint32_t>(buffer[index]);
  }
  return value;
}
42 
// Encodes `value` into the four bytes starting at `buffer` in little-endian
// order: buffer[0] receives the least significant byte, buffer[3] the most.
static void store_32_le(uint32_t value, unsigned char* buffer) {
  for (int position = 0; position < 4; ++position) {
    buffer[position] =
        static_cast<unsigned char>((value >> (8 * position)) & 0xFF);
  }
}
50 
51 // Frame writer implementation.
alts_create_frame_writer()52 alts_frame_writer* alts_create_frame_writer() {
53   return grpc_core::Zalloc<alts_frame_writer>();
54 }
55 
alts_reset_frame_writer(alts_frame_writer * writer,const unsigned char * buffer,size_t length)56 bool alts_reset_frame_writer(alts_frame_writer* writer,
57                              const unsigned char* buffer, size_t length) {
58   if (buffer == nullptr) return false;
59   size_t max_input_size = SIZE_MAX - kFrameLengthFieldSize;
60   if (length > max_input_size) {
61     gpr_log(GPR_ERROR, "length must be at most %zu", max_input_size);
62     return false;
63   }
64   writer->input_buffer = buffer;
65   writer->input_size = length;
66   writer->input_bytes_written = 0;
67   writer->header_bytes_written = 0;
68   store_32_le(
69       static_cast<uint32_t>(writer->input_size + kFrameMessageTypeFieldSize),
70       writer->header_buffer);
71   store_32_le(kFrameMessageType, writer->header_buffer + kFrameLengthFieldSize);
72   return true;
73 }
74 
alts_write_frame_bytes(alts_frame_writer * writer,unsigned char * output,size_t * bytes_size)75 bool alts_write_frame_bytes(alts_frame_writer* writer, unsigned char* output,
76                             size_t* bytes_size) {
77   if (bytes_size == nullptr || output == nullptr) return false;
78   if (alts_is_frame_writer_done(writer)) {
79     *bytes_size = 0;
80     return true;
81   }
82   size_t bytes_written = 0;
83   // Write some header bytes, if needed.
84   if (writer->header_bytes_written != sizeof(writer->header_buffer)) {
85     size_t bytes_to_write =
86         std::min(*bytes_size,
87                  sizeof(writer->header_buffer) - writer->header_bytes_written);
88     memcpy(output, writer->header_buffer + writer->header_bytes_written,
89            bytes_to_write);
90     bytes_written += bytes_to_write;
91     *bytes_size -= bytes_to_write;
92     writer->header_bytes_written += bytes_to_write;
93     output += bytes_to_write;
94     if (writer->header_bytes_written != sizeof(writer->header_buffer)) {
95       *bytes_size = bytes_written;
96       return true;
97     }
98   }
99   // Write some non-header bytes.
100   size_t bytes_to_write =
101       std::min(writer->input_size - writer->input_bytes_written, *bytes_size);
102   memcpy(output, writer->input_buffer, bytes_to_write);
103   writer->input_buffer += bytes_to_write;
104   bytes_written += bytes_to_write;
105   writer->input_bytes_written += bytes_to_write;
106   *bytes_size = bytes_written;
107   return true;
108 }
109 
alts_is_frame_writer_done(alts_frame_writer * writer)110 bool alts_is_frame_writer_done(alts_frame_writer* writer) {
111   return writer->input_buffer == nullptr ||
112          writer->input_size == writer->input_bytes_written;
113 }
114 
alts_get_num_writer_bytes_remaining(alts_frame_writer * writer)115 size_t alts_get_num_writer_bytes_remaining(alts_frame_writer* writer) {
116   return (sizeof(writer->header_buffer) - writer->header_bytes_written) +
117          (writer->input_size - writer->input_bytes_written);
118 }
119 
alts_destroy_frame_writer(alts_frame_writer * writer)120 void alts_destroy_frame_writer(alts_frame_writer* writer) { gpr_free(writer); }
121 
122 // Frame reader implementation.
alts_create_frame_reader()123 alts_frame_reader* alts_create_frame_reader() {
124   alts_frame_reader* reader = grpc_core::Zalloc<alts_frame_reader>();
125   return reader;
126 }
127 
alts_is_frame_reader_done(alts_frame_reader * reader)128 bool alts_is_frame_reader_done(alts_frame_reader* reader) {
129   return reader->output_buffer == nullptr ||
130          (reader->header_bytes_read == sizeof(reader->header_buffer) &&
131           reader->bytes_remaining == 0);
132 }
133 
alts_has_read_frame_length(alts_frame_reader * reader)134 bool alts_has_read_frame_length(alts_frame_reader* reader) {
135   return sizeof(reader->header_buffer) == reader->header_bytes_read;
136 }
137 
alts_get_reader_bytes_remaining(alts_frame_reader * reader)138 size_t alts_get_reader_bytes_remaining(alts_frame_reader* reader) {
139   return alts_has_read_frame_length(reader) ? reader->bytes_remaining : 0;
140 }
141 
alts_reset_reader_output_buffer(alts_frame_reader * reader,unsigned char * buffer)142 void alts_reset_reader_output_buffer(alts_frame_reader* reader,
143                                      unsigned char* buffer) {
144   reader->output_buffer = buffer;
145 }
146 
alts_reset_frame_reader(alts_frame_reader * reader,unsigned char * buffer)147 bool alts_reset_frame_reader(alts_frame_reader* reader, unsigned char* buffer) {
148   if (buffer == nullptr) return false;
149   reader->output_buffer = buffer;
150   reader->bytes_remaining = 0;
151   reader->header_bytes_read = 0;
152   reader->output_bytes_read = 0;
153   return true;
154 }
155 
alts_read_frame_bytes(alts_frame_reader * reader,const unsigned char * bytes,size_t * bytes_size)156 bool alts_read_frame_bytes(alts_frame_reader* reader,
157                            const unsigned char* bytes, size_t* bytes_size) {
158   if (bytes_size == nullptr) return false;
159   if (bytes == nullptr) {
160     *bytes_size = 0;
161     return false;
162   }
163   if (alts_is_frame_reader_done(reader)) {
164     *bytes_size = 0;
165     return true;
166   }
167   size_t bytes_processed = 0;
168   // Process the header, if needed.
169   if (reader->header_bytes_read != sizeof(reader->header_buffer)) {
170     size_t bytes_to_write = std::min(
171         *bytes_size, sizeof(reader->header_buffer) - reader->header_bytes_read);
172     memcpy(reader->header_buffer + reader->header_bytes_read, bytes,
173            bytes_to_write);
174     reader->header_bytes_read += bytes_to_write;
175     bytes_processed += bytes_to_write;
176     bytes += bytes_to_write;
177     *bytes_size -= bytes_to_write;
178     if (reader->header_bytes_read != sizeof(reader->header_buffer)) {
179       *bytes_size = bytes_processed;
180       return true;
181     }
182     size_t frame_length = load_32_le(reader->header_buffer);
183     if (frame_length < kFrameMessageTypeFieldSize ||
184         frame_length > kFrameMaxSize) {
185       gpr_log(GPR_ERROR,
186               "Bad frame length (should be at least %zu, and at most %zu)",
187               kFrameMessageTypeFieldSize, kFrameMaxSize);
188       *bytes_size = 0;
189       return false;
190     }
191     size_t message_type =
192         load_32_le(reader->header_buffer + kFrameLengthFieldSize);
193     if (message_type != kFrameMessageType) {
194       gpr_log(GPR_ERROR, "Unsupported message type %zu (should be %zu)",
195               message_type, kFrameMessageType);
196       *bytes_size = 0;
197       return false;
198     }
199     reader->bytes_remaining = frame_length - kFrameMessageTypeFieldSize;
200   }
201   // Process the non-header bytes.
202   size_t bytes_to_write = std::min(*bytes_size, reader->bytes_remaining);
203   memcpy(reader->output_buffer, bytes, bytes_to_write);
204   reader->output_buffer += bytes_to_write;
205   bytes_processed += bytes_to_write;
206   reader->bytes_remaining -= bytes_to_write;
207   reader->output_bytes_read += bytes_to_write;
208   *bytes_size = bytes_processed;
209   return true;
210 }
211 
alts_get_output_bytes_read(alts_frame_reader * reader)212 size_t alts_get_output_bytes_read(alts_frame_reader* reader) {
213   return reader->output_bytes_read;
214 }
215 
alts_get_output_buffer(alts_frame_reader * reader)216 unsigned char* alts_get_output_buffer(alts_frame_reader* reader) {
217   return reader->output_buffer;
218 }
219 
alts_destroy_frame_reader(alts_frame_reader * reader)220 void alts_destroy_frame_reader(alts_frame_reader* reader) { gpr_free(reader); }
221