1 //
2 //
3 // Copyright 2015 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include "src/core/lib/security/transport/secure_endpoint.h"
20
21 #include <fcntl.h>
22 #include <sys/types.h>
23
24 #include <gtest/gtest.h>
25
26 #include <grpc/grpc.h>
27 #include <grpc/support/alloc.h>
28 #include <grpc/support/log.h>
29
30 #include "src/core/lib/gpr/useful.h"
31 #include "src/core/lib/gprpp/crash.h"
32 #include "src/core/lib/iomgr/endpoint_pair.h"
33 #include "src/core/lib/iomgr/iomgr.h"
34 #include "src/core/lib/slice/slice_internal.h"
35 #include "src/core/tsi/fake_transport_security.h"
36 #include "test/core/iomgr/endpoint_tests.h"
37 #include "test/core/util/test_config.h"
38
39 static gpr_mu* g_mu;
40 static grpc_pollset* g_pollset;
41
42 #define TSI_FAKE_FRAME_HEADER_SIZE 4
43
44 typedef struct intercept_endpoint {
45 grpc_endpoint base;
46 grpc_endpoint* wrapped_ep;
47 grpc_slice_buffer staging_buffer;
48 } intercept_endpoint;
49
me_read(grpc_endpoint * ep,grpc_slice_buffer * slices,grpc_closure * cb,bool urgent,int min_progress_size)50 static void me_read(grpc_endpoint* ep, grpc_slice_buffer* slices,
51 grpc_closure* cb, bool urgent, int min_progress_size) {
52 intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
53 grpc_endpoint_read(m->wrapped_ep, slices, cb, urgent, min_progress_size);
54 }
55
me_write(grpc_endpoint * ep,grpc_slice_buffer * slices,grpc_closure * cb,void * arg,int max_frame_size)56 static void me_write(grpc_endpoint* ep, grpc_slice_buffer* slices,
57 grpc_closure* cb, void* arg, int max_frame_size) {
58 intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
59 int remaining = slices->length;
60 while (remaining > 0) {
61 // Estimate the frame size of the next frame.
62 int next_frame_size =
63 tsi_fake_zero_copy_grpc_protector_next_frame_size(slices);
64 ASSERT_GT(next_frame_size, TSI_FAKE_FRAME_HEADER_SIZE);
65 // Ensure the protected data size does not exceed the max_frame_size.
66 ASSERT_LE(next_frame_size - TSI_FAKE_FRAME_HEADER_SIZE, max_frame_size);
67 // Move this frame into a staging buffer and repeat.
68 grpc_slice_buffer_move_first(slices, next_frame_size, &m->staging_buffer);
69 remaining -= next_frame_size;
70 }
71 grpc_slice_buffer_swap(&m->staging_buffer, slices);
72 grpc_endpoint_write(m->wrapped_ep, slices, cb, arg, max_frame_size);
73 }
74
me_add_to_pollset(grpc_endpoint *,grpc_pollset *)75 static void me_add_to_pollset(grpc_endpoint* /*ep*/,
76 grpc_pollset* /*pollset*/) {}
77
me_add_to_pollset_set(grpc_endpoint *,grpc_pollset_set *)78 static void me_add_to_pollset_set(grpc_endpoint* /*ep*/,
79 grpc_pollset_set* /*pollset*/) {}
80
me_delete_from_pollset_set(grpc_endpoint *,grpc_pollset_set *)81 static void me_delete_from_pollset_set(grpc_endpoint* /*ep*/,
82 grpc_pollset_set* /*pollset*/) {}
83
me_shutdown(grpc_endpoint * ep,grpc_error_handle why)84 static void me_shutdown(grpc_endpoint* ep, grpc_error_handle why) {
85 intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
86 grpc_endpoint_shutdown(m->wrapped_ep, why);
87 }
88
me_destroy(grpc_endpoint * ep)89 static void me_destroy(grpc_endpoint* ep) {
90 intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
91 grpc_endpoint_destroy(m->wrapped_ep);
92 grpc_slice_buffer_destroy(&m->staging_buffer);
93 gpr_free(m);
94 }
95
me_get_peer(grpc_endpoint *)96 static absl::string_view me_get_peer(grpc_endpoint* /*ep*/) {
97 return "fake:intercept-endpoint";
98 }
99
me_get_local_address(grpc_endpoint *)100 static absl::string_view me_get_local_address(grpc_endpoint* /*ep*/) {
101 return "fake:intercept-endpoint";
102 }
103
me_get_fd(grpc_endpoint *)104 static int me_get_fd(grpc_endpoint* /*ep*/) { return -1; }
105
me_can_track_err(grpc_endpoint *)106 static bool me_can_track_err(grpc_endpoint* /*ep*/) { return false; }
107
108 static const grpc_endpoint_vtable vtable = {me_read,
109 me_write,
110 me_add_to_pollset,
111 me_add_to_pollset_set,
112 me_delete_from_pollset_set,
113 me_shutdown,
114 me_destroy,
115 me_get_peer,
116 me_get_local_address,
117 me_get_fd,
118 me_can_track_err};
119
wrap_with_intercept_endpoint(grpc_endpoint * wrapped_ep)120 grpc_endpoint* wrap_with_intercept_endpoint(grpc_endpoint* wrapped_ep) {
121 intercept_endpoint* m =
122 static_cast<intercept_endpoint*>(gpr_malloc(sizeof(*m)));
123 m->base.vtable = &vtable;
124 m->wrapped_ep = wrapped_ep;
125 grpc_slice_buffer_init(&m->staging_buffer);
126 return &m->base;
127 }
128
secure_endpoint_create_fixture_tcp_socketpair(size_t slice_size,grpc_slice * leftover_slices,size_t leftover_nslices,bool use_zero_copy_protector)129 static grpc_endpoint_test_fixture secure_endpoint_create_fixture_tcp_socketpair(
130 size_t slice_size, grpc_slice* leftover_slices, size_t leftover_nslices,
131 bool use_zero_copy_protector) {
132 grpc_core::ExecCtx exec_ctx;
133 tsi_frame_protector* fake_read_protector =
134 tsi_create_fake_frame_protector(nullptr);
135 tsi_frame_protector* fake_write_protector =
136 tsi_create_fake_frame_protector(nullptr);
137 tsi_zero_copy_grpc_protector* fake_read_zero_copy_protector =
138 use_zero_copy_protector
139 ? tsi_create_fake_zero_copy_grpc_protector(nullptr)
140 : nullptr;
141 tsi_zero_copy_grpc_protector* fake_write_zero_copy_protector =
142 use_zero_copy_protector
143 ? tsi_create_fake_zero_copy_grpc_protector(nullptr)
144 : nullptr;
145 grpc_endpoint_test_fixture f;
146 grpc_endpoint_pair tcp;
147
148 grpc_arg a[2];
149 a[0].key = const_cast<char*>(GRPC_ARG_TCP_READ_CHUNK_SIZE);
150 a[0].type = GRPC_ARG_INTEGER;
151 a[0].value.integer = static_cast<int>(slice_size);
152 a[1].key = const_cast<char*>(GRPC_ARG_RESOURCE_QUOTA);
153 a[1].type = GRPC_ARG_POINTER;
154 a[1].value.pointer.p = grpc_resource_quota_create("test");
155 a[1].value.pointer.vtable = grpc_resource_quota_arg_vtable();
156 grpc_channel_args args = {GPR_ARRAY_SIZE(a), a};
157 tcp = grpc_iomgr_create_endpoint_pair("fixture", &args);
158 grpc_endpoint_add_to_pollset(tcp.client, g_pollset);
159 grpc_endpoint_add_to_pollset(tcp.server, g_pollset);
160
161 // TODO(vigneshbabu): Extend the intercept endpoint logic to cover non-zero
162 // copy based frame protectors as well.
163 if (use_zero_copy_protector && leftover_nslices == 0) {
164 tcp.client = wrap_with_intercept_endpoint(tcp.client);
165 tcp.server = wrap_with_intercept_endpoint(tcp.server);
166 }
167
168 if (leftover_nslices == 0) {
169 f.client_ep = grpc_secure_endpoint_create(fake_read_protector,
170 fake_read_zero_copy_protector,
171 tcp.client, nullptr, &args, 0);
172 } else {
173 unsigned i;
174 tsi_result result;
175 size_t still_pending_size;
176 size_t total_buffer_size = 8192;
177 size_t buffer_size = total_buffer_size;
178 uint8_t* encrypted_buffer = static_cast<uint8_t*>(gpr_malloc(buffer_size));
179 uint8_t* cur = encrypted_buffer;
180 grpc_slice encrypted_leftover;
181 for (i = 0; i < leftover_nslices; i++) {
182 grpc_slice plain = leftover_slices[i];
183 uint8_t* message_bytes = GRPC_SLICE_START_PTR(plain);
184 size_t message_size = GRPC_SLICE_LENGTH(plain);
185 while (message_size > 0) {
186 size_t protected_buffer_size_to_send = buffer_size;
187 size_t processed_message_size = message_size;
188 result = tsi_frame_protector_protect(
189 fake_write_protector, message_bytes, &processed_message_size, cur,
190 &protected_buffer_size_to_send);
191 EXPECT_EQ(result, TSI_OK);
192 message_bytes += processed_message_size;
193 message_size -= processed_message_size;
194 cur += protected_buffer_size_to_send;
195 EXPECT_GE(buffer_size, protected_buffer_size_to_send);
196 buffer_size -= protected_buffer_size_to_send;
197 }
198 grpc_slice_unref(plain);
199 }
200 do {
201 size_t protected_buffer_size_to_send = buffer_size;
202 result = tsi_frame_protector_protect_flush(fake_write_protector, cur,
203 &protected_buffer_size_to_send,
204 &still_pending_size);
205 EXPECT_EQ(result, TSI_OK);
206 cur += protected_buffer_size_to_send;
207 EXPECT_GE(buffer_size, protected_buffer_size_to_send);
208 buffer_size -= protected_buffer_size_to_send;
209 } while (still_pending_size > 0);
210 encrypted_leftover = grpc_slice_from_copied_buffer(
211 reinterpret_cast<const char*>(encrypted_buffer),
212 total_buffer_size - buffer_size);
213 f.client_ep = grpc_secure_endpoint_create(
214 fake_read_protector, fake_read_zero_copy_protector, tcp.client,
215 &encrypted_leftover, &args, 1);
216 grpc_slice_unref(encrypted_leftover);
217 gpr_free(encrypted_buffer);
218 }
219
220 f.server_ep = grpc_secure_endpoint_create(fake_write_protector,
221 fake_write_zero_copy_protector,
222 tcp.server, nullptr, &args, 0);
223 grpc_resource_quota_unref(
224 static_cast<grpc_resource_quota*>(a[1].value.pointer.p));
225 return f;
226 }
227
228 static grpc_endpoint_test_fixture
secure_endpoint_create_fixture_tcp_socketpair_noleftover(size_t slice_size)229 secure_endpoint_create_fixture_tcp_socketpair_noleftover(size_t slice_size) {
230 return secure_endpoint_create_fixture_tcp_socketpair(slice_size, nullptr, 0,
231 false);
232 }
233
234 static grpc_endpoint_test_fixture
secure_endpoint_create_fixture_tcp_socketpair_noleftover_zero_copy(size_t slice_size)235 secure_endpoint_create_fixture_tcp_socketpair_noleftover_zero_copy(
236 size_t slice_size) {
237 return secure_endpoint_create_fixture_tcp_socketpair(slice_size, nullptr, 0,
238 true);
239 }
240
241 static grpc_endpoint_test_fixture
secure_endpoint_create_fixture_tcp_socketpair_leftover(size_t slice_size)242 secure_endpoint_create_fixture_tcp_socketpair_leftover(size_t slice_size) {
243 grpc_slice s =
244 grpc_slice_from_copied_string("hello world 12345678900987654321");
245 return secure_endpoint_create_fixture_tcp_socketpair(slice_size, &s, 1,
246 false);
247 }
248
249 static grpc_endpoint_test_fixture
secure_endpoint_create_fixture_tcp_socketpair_leftover_zero_copy(size_t slice_size)250 secure_endpoint_create_fixture_tcp_socketpair_leftover_zero_copy(
251 size_t slice_size) {
252 grpc_slice s =
253 grpc_slice_from_copied_string("hello world 12345678900987654321");
254 return secure_endpoint_create_fixture_tcp_socketpair(slice_size, &s, 1, true);
255 }
256
// Per-config cleanup hook; these fixtures need no additional teardown.
static void clean_up() {
  // Intentionally empty.
}
258
259 static grpc_endpoint_test_config configs[] = {
260 {"secure_ep/tcp_socketpair",
261 secure_endpoint_create_fixture_tcp_socketpair_noleftover, clean_up},
262 {"secure_ep/tcp_socketpair_zero_copy",
263 secure_endpoint_create_fixture_tcp_socketpair_noleftover_zero_copy,
264 clean_up},
265 {"secure_ep/tcp_socketpair_leftover",
266 secure_endpoint_create_fixture_tcp_socketpair_leftover, clean_up},
267 {"secure_ep/tcp_socketpair_leftover_zero_copy",
268 secure_endpoint_create_fixture_tcp_socketpair_leftover_zero_copy,
269 clean_up},
270 };
271
inc_call_ctr(void * arg,grpc_error_handle)272 static void inc_call_ctr(void* arg, grpc_error_handle /*error*/) {
273 ++*static_cast<int*>(arg);
274 }
275
test_leftover(grpc_endpoint_test_config config,size_t slice_size)276 static void test_leftover(grpc_endpoint_test_config config, size_t slice_size) {
277 grpc_endpoint_test_fixture f = config.create_fixture(slice_size);
278 grpc_slice_buffer incoming;
279 grpc_slice s =
280 grpc_slice_from_copied_string("hello world 12345678900987654321");
281 grpc_core::ExecCtx exec_ctx;
282 int n = 0;
283 grpc_closure done_closure;
284 gpr_log(GPR_INFO, "Start test left over");
285
286 grpc_slice_buffer_init(&incoming);
287 GRPC_CLOSURE_INIT(&done_closure, inc_call_ctr, &n, grpc_schedule_on_exec_ctx);
288 grpc_endpoint_read(f.client_ep, &incoming, &done_closure, /*urgent=*/false,
289 /*min_progress_size=*/1);
290
291 grpc_core::ExecCtx::Get()->Flush();
292 ASSERT_EQ(n, 1);
293 ASSERT_EQ(incoming.count, 1);
294 ASSERT_TRUE(grpc_slice_eq(s, incoming.slices[0]));
295
296 grpc_endpoint_shutdown(f.client_ep, GRPC_ERROR_CREATE("test_leftover end"));
297 grpc_endpoint_shutdown(f.server_ep, GRPC_ERROR_CREATE("test_leftover end"));
298 grpc_endpoint_destroy(f.client_ep);
299 grpc_endpoint_destroy(f.server_ep);
300
301 grpc_slice_unref(s);
302 grpc_slice_buffer_destroy(&incoming);
303
304 clean_up();
305 }
306
destroy_pollset(void * p,grpc_error_handle)307 static void destroy_pollset(void* p, grpc_error_handle /*error*/) {
308 grpc_pollset_destroy(static_cast<grpc_pollset*>(p));
309 }
310
// Runs the generic endpoint suite on the no-leftover fixtures and the leftover
// check on the leftover fixtures, all sharing one pollset.
TEST(SecureEndpointTest, MainTest) {
  // Declared outside the ExecCtx scope so it is still alive when the ExecCtx
  // flushes the shutdown callback on scope exit.
  grpc_closure destroyed;
  grpc_init();

  {
    grpc_core::ExecCtx exec_ctx;
    g_pollset = static_cast<grpc_pollset*>(gpr_zalloc(grpc_pollset_size()));
    grpc_pollset_init(g_pollset, &g_mu);
    // configs[0] and configs[1]: the full generic endpoint test suite.
    for (size_t idx = 0; idx < 2; ++idx) {
      grpc_endpoint_tests(configs[idx], g_pollset, g_mu);
    }
    // configs[2] and configs[3]: leftover fixtures with a 1-byte read chunk.
    for (size_t idx = 2; idx < 4; ++idx) {
      test_leftover(configs[idx], 1);
    }
    GRPC_CLOSURE_INIT(&destroyed, destroy_pollset, g_pollset,
                      grpc_schedule_on_exec_ctx);
    grpc_pollset_shutdown(g_pollset, &destroyed);
  }

  grpc_shutdown();

  gpr_free(g_pollset);
}
332
main(int argc,char ** argv)333 int main(int argc, char** argv) {
334 grpc::testing::TestEnvironment env(&argc, argv);
335 ::testing::InitGoogleTest(&argc, argv);
336 return RUN_ALL_TESTS();
337 }
338