// xref: /aosp_15_r20/external/grpc-grpc/test/core/security/secure_endpoint_test.cc (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
//
//
// Copyright 2015 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//
18 
#include "src/core/lib/security/transport/secure_endpoint.h"

#include <fcntl.h>
#include <sys/types.h>

#include <gtest/gtest.h>

#include <grpc/grpc.h>
#include <grpc/support/alloc.h>
#include <grpc/support/log.h>
#include <grpc/support/sync.h>

#include "src/core/lib/gpr/useful.h"
#include "src/core/lib/gprpp/crash.h"
#include "src/core/lib/iomgr/endpoint_pair.h"
#include "src/core/lib/iomgr/iomgr.h"
#include "src/core/lib/slice/slice_internal.h"
#include "src/core/tsi/fake_transport_security.h"
#include "test/core/iomgr/endpoint_tests.h"
#include "test/core/util/test_config.h"
38 
39 static gpr_mu* g_mu;
40 static grpc_pollset* g_pollset;
41 
42 #define TSI_FAKE_FRAME_HEADER_SIZE 4
43 
44 typedef struct intercept_endpoint {
45   grpc_endpoint base;
46   grpc_endpoint* wrapped_ep;
47   grpc_slice_buffer staging_buffer;
48 } intercept_endpoint;
49 
me_read(grpc_endpoint * ep,grpc_slice_buffer * slices,grpc_closure * cb,bool urgent,int min_progress_size)50 static void me_read(grpc_endpoint* ep, grpc_slice_buffer* slices,
51                     grpc_closure* cb, bool urgent, int min_progress_size) {
52   intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
53   grpc_endpoint_read(m->wrapped_ep, slices, cb, urgent, min_progress_size);
54 }
55 
me_write(grpc_endpoint * ep,grpc_slice_buffer * slices,grpc_closure * cb,void * arg,int max_frame_size)56 static void me_write(grpc_endpoint* ep, grpc_slice_buffer* slices,
57                      grpc_closure* cb, void* arg, int max_frame_size) {
58   intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
59   int remaining = slices->length;
60   while (remaining > 0) {
61     // Estimate the frame size of the next frame.
62     int next_frame_size =
63         tsi_fake_zero_copy_grpc_protector_next_frame_size(slices);
64     ASSERT_GT(next_frame_size, TSI_FAKE_FRAME_HEADER_SIZE);
65     // Ensure the protected data size does not exceed the max_frame_size.
66     ASSERT_LE(next_frame_size - TSI_FAKE_FRAME_HEADER_SIZE, max_frame_size);
67     // Move this frame into a staging buffer and repeat.
68     grpc_slice_buffer_move_first(slices, next_frame_size, &m->staging_buffer);
69     remaining -= next_frame_size;
70   }
71   grpc_slice_buffer_swap(&m->staging_buffer, slices);
72   grpc_endpoint_write(m->wrapped_ep, slices, cb, arg, max_frame_size);
73 }
74 
me_add_to_pollset(grpc_endpoint *,grpc_pollset *)75 static void me_add_to_pollset(grpc_endpoint* /*ep*/,
76                               grpc_pollset* /*pollset*/) {}
77 
me_add_to_pollset_set(grpc_endpoint *,grpc_pollset_set *)78 static void me_add_to_pollset_set(grpc_endpoint* /*ep*/,
79                                   grpc_pollset_set* /*pollset*/) {}
80 
me_delete_from_pollset_set(grpc_endpoint *,grpc_pollset_set *)81 static void me_delete_from_pollset_set(grpc_endpoint* /*ep*/,
82                                        grpc_pollset_set* /*pollset*/) {}
83 
me_shutdown(grpc_endpoint * ep,grpc_error_handle why)84 static void me_shutdown(grpc_endpoint* ep, grpc_error_handle why) {
85   intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
86   grpc_endpoint_shutdown(m->wrapped_ep, why);
87 }
88 
me_destroy(grpc_endpoint * ep)89 static void me_destroy(grpc_endpoint* ep) {
90   intercept_endpoint* m = reinterpret_cast<intercept_endpoint*>(ep);
91   grpc_endpoint_destroy(m->wrapped_ep);
92   grpc_slice_buffer_destroy(&m->staging_buffer);
93   gpr_free(m);
94 }
95 
me_get_peer(grpc_endpoint *)96 static absl::string_view me_get_peer(grpc_endpoint* /*ep*/) {
97   return "fake:intercept-endpoint";
98 }
99 
me_get_local_address(grpc_endpoint *)100 static absl::string_view me_get_local_address(grpc_endpoint* /*ep*/) {
101   return "fake:intercept-endpoint";
102 }
103 
me_get_fd(grpc_endpoint *)104 static int me_get_fd(grpc_endpoint* /*ep*/) { return -1; }
105 
me_can_track_err(grpc_endpoint *)106 static bool me_can_track_err(grpc_endpoint* /*ep*/) { return false; }
107 
108 static const grpc_endpoint_vtable vtable = {me_read,
109                                             me_write,
110                                             me_add_to_pollset,
111                                             me_add_to_pollset_set,
112                                             me_delete_from_pollset_set,
113                                             me_shutdown,
114                                             me_destroy,
115                                             me_get_peer,
116                                             me_get_local_address,
117                                             me_get_fd,
118                                             me_can_track_err};
119 
wrap_with_intercept_endpoint(grpc_endpoint * wrapped_ep)120 grpc_endpoint* wrap_with_intercept_endpoint(grpc_endpoint* wrapped_ep) {
121   intercept_endpoint* m =
122       static_cast<intercept_endpoint*>(gpr_malloc(sizeof(*m)));
123   m->base.vtable = &vtable;
124   m->wrapped_ep = wrapped_ep;
125   grpc_slice_buffer_init(&m->staging_buffer);
126   return &m->base;
127 }
128 
secure_endpoint_create_fixture_tcp_socketpair(size_t slice_size,grpc_slice * leftover_slices,size_t leftover_nslices,bool use_zero_copy_protector)129 static grpc_endpoint_test_fixture secure_endpoint_create_fixture_tcp_socketpair(
130     size_t slice_size, grpc_slice* leftover_slices, size_t leftover_nslices,
131     bool use_zero_copy_protector) {
132   grpc_core::ExecCtx exec_ctx;
133   tsi_frame_protector* fake_read_protector =
134       tsi_create_fake_frame_protector(nullptr);
135   tsi_frame_protector* fake_write_protector =
136       tsi_create_fake_frame_protector(nullptr);
137   tsi_zero_copy_grpc_protector* fake_read_zero_copy_protector =
138       use_zero_copy_protector
139           ? tsi_create_fake_zero_copy_grpc_protector(nullptr)
140           : nullptr;
141   tsi_zero_copy_grpc_protector* fake_write_zero_copy_protector =
142       use_zero_copy_protector
143           ? tsi_create_fake_zero_copy_grpc_protector(nullptr)
144           : nullptr;
145   grpc_endpoint_test_fixture f;
146   grpc_endpoint_pair tcp;
147 
148   grpc_arg a[2];
149   a[0].key = const_cast<char*>(GRPC_ARG_TCP_READ_CHUNK_SIZE);
150   a[0].type = GRPC_ARG_INTEGER;
151   a[0].value.integer = static_cast<int>(slice_size);
152   a[1].key = const_cast<char*>(GRPC_ARG_RESOURCE_QUOTA);
153   a[1].type = GRPC_ARG_POINTER;
154   a[1].value.pointer.p = grpc_resource_quota_create("test");
155   a[1].value.pointer.vtable = grpc_resource_quota_arg_vtable();
156   grpc_channel_args args = {GPR_ARRAY_SIZE(a), a};
157   tcp = grpc_iomgr_create_endpoint_pair("fixture", &args);
158   grpc_endpoint_add_to_pollset(tcp.client, g_pollset);
159   grpc_endpoint_add_to_pollset(tcp.server, g_pollset);
160 
161   // TODO(vigneshbabu): Extend the intercept endpoint logic to cover non-zero
162   // copy based frame protectors as well.
163   if (use_zero_copy_protector && leftover_nslices == 0) {
164     tcp.client = wrap_with_intercept_endpoint(tcp.client);
165     tcp.server = wrap_with_intercept_endpoint(tcp.server);
166   }
167 
168   if (leftover_nslices == 0) {
169     f.client_ep = grpc_secure_endpoint_create(fake_read_protector,
170                                               fake_read_zero_copy_protector,
171                                               tcp.client, nullptr, &args, 0);
172   } else {
173     unsigned i;
174     tsi_result result;
175     size_t still_pending_size;
176     size_t total_buffer_size = 8192;
177     size_t buffer_size = total_buffer_size;
178     uint8_t* encrypted_buffer = static_cast<uint8_t*>(gpr_malloc(buffer_size));
179     uint8_t* cur = encrypted_buffer;
180     grpc_slice encrypted_leftover;
181     for (i = 0; i < leftover_nslices; i++) {
182       grpc_slice plain = leftover_slices[i];
183       uint8_t* message_bytes = GRPC_SLICE_START_PTR(plain);
184       size_t message_size = GRPC_SLICE_LENGTH(plain);
185       while (message_size > 0) {
186         size_t protected_buffer_size_to_send = buffer_size;
187         size_t processed_message_size = message_size;
188         result = tsi_frame_protector_protect(
189             fake_write_protector, message_bytes, &processed_message_size, cur,
190             &protected_buffer_size_to_send);
191         EXPECT_EQ(result, TSI_OK);
192         message_bytes += processed_message_size;
193         message_size -= processed_message_size;
194         cur += protected_buffer_size_to_send;
195         EXPECT_GE(buffer_size, protected_buffer_size_to_send);
196         buffer_size -= protected_buffer_size_to_send;
197       }
198       grpc_slice_unref(plain);
199     }
200     do {
201       size_t protected_buffer_size_to_send = buffer_size;
202       result = tsi_frame_protector_protect_flush(fake_write_protector, cur,
203                                                  &protected_buffer_size_to_send,
204                                                  &still_pending_size);
205       EXPECT_EQ(result, TSI_OK);
206       cur += protected_buffer_size_to_send;
207       EXPECT_GE(buffer_size, protected_buffer_size_to_send);
208       buffer_size -= protected_buffer_size_to_send;
209     } while (still_pending_size > 0);
210     encrypted_leftover = grpc_slice_from_copied_buffer(
211         reinterpret_cast<const char*>(encrypted_buffer),
212         total_buffer_size - buffer_size);
213     f.client_ep = grpc_secure_endpoint_create(
214         fake_read_protector, fake_read_zero_copy_protector, tcp.client,
215         &encrypted_leftover, &args, 1);
216     grpc_slice_unref(encrypted_leftover);
217     gpr_free(encrypted_buffer);
218   }
219 
220   f.server_ep = grpc_secure_endpoint_create(fake_write_protector,
221                                             fake_write_zero_copy_protector,
222                                             tcp.server, nullptr, &args, 0);
223   grpc_resource_quota_unref(
224       static_cast<grpc_resource_quota*>(a[1].value.pointer.p));
225   return f;
226 }
227 
228 static grpc_endpoint_test_fixture
secure_endpoint_create_fixture_tcp_socketpair_noleftover(size_t slice_size)229 secure_endpoint_create_fixture_tcp_socketpair_noleftover(size_t slice_size) {
230   return secure_endpoint_create_fixture_tcp_socketpair(slice_size, nullptr, 0,
231                                                        false);
232 }
233 
234 static grpc_endpoint_test_fixture
secure_endpoint_create_fixture_tcp_socketpair_noleftover_zero_copy(size_t slice_size)235 secure_endpoint_create_fixture_tcp_socketpair_noleftover_zero_copy(
236     size_t slice_size) {
237   return secure_endpoint_create_fixture_tcp_socketpair(slice_size, nullptr, 0,
238                                                        true);
239 }
240 
241 static grpc_endpoint_test_fixture
secure_endpoint_create_fixture_tcp_socketpair_leftover(size_t slice_size)242 secure_endpoint_create_fixture_tcp_socketpair_leftover(size_t slice_size) {
243   grpc_slice s =
244       grpc_slice_from_copied_string("hello world 12345678900987654321");
245   return secure_endpoint_create_fixture_tcp_socketpair(slice_size, &s, 1,
246                                                        false);
247 }
248 
249 static grpc_endpoint_test_fixture
secure_endpoint_create_fixture_tcp_socketpair_leftover_zero_copy(size_t slice_size)250 secure_endpoint_create_fixture_tcp_socketpair_leftover_zero_copy(
251     size_t slice_size) {
252   grpc_slice s =
253       grpc_slice_from_copied_string("hello world 12345678900987654321");
254   return secure_endpoint_create_fixture_tcp_socketpair(slice_size, &s, 1, true);
255 }
256 
// Per-config teardown hook; these fixtures need none, so it is empty.
static void clean_up() {
}
258 
259 static grpc_endpoint_test_config configs[] = {
260     {"secure_ep/tcp_socketpair",
261      secure_endpoint_create_fixture_tcp_socketpair_noleftover, clean_up},
262     {"secure_ep/tcp_socketpair_zero_copy",
263      secure_endpoint_create_fixture_tcp_socketpair_noleftover_zero_copy,
264      clean_up},
265     {"secure_ep/tcp_socketpair_leftover",
266      secure_endpoint_create_fixture_tcp_socketpair_leftover, clean_up},
267     {"secure_ep/tcp_socketpair_leftover_zero_copy",
268      secure_endpoint_create_fixture_tcp_socketpair_leftover_zero_copy,
269      clean_up},
270 };
271 
inc_call_ctr(void * arg,grpc_error_handle)272 static void inc_call_ctr(void* arg, grpc_error_handle /*error*/) {
273   ++*static_cast<int*>(arg);
274 }
275 
test_leftover(grpc_endpoint_test_config config,size_t slice_size)276 static void test_leftover(grpc_endpoint_test_config config, size_t slice_size) {
277   grpc_endpoint_test_fixture f = config.create_fixture(slice_size);
278   grpc_slice_buffer incoming;
279   grpc_slice s =
280       grpc_slice_from_copied_string("hello world 12345678900987654321");
281   grpc_core::ExecCtx exec_ctx;
282   int n = 0;
283   grpc_closure done_closure;
284   gpr_log(GPR_INFO, "Start test left over");
285 
286   grpc_slice_buffer_init(&incoming);
287   GRPC_CLOSURE_INIT(&done_closure, inc_call_ctr, &n, grpc_schedule_on_exec_ctx);
288   grpc_endpoint_read(f.client_ep, &incoming, &done_closure, /*urgent=*/false,
289                      /*min_progress_size=*/1);
290 
291   grpc_core::ExecCtx::Get()->Flush();
292   ASSERT_EQ(n, 1);
293   ASSERT_EQ(incoming.count, 1);
294   ASSERT_TRUE(grpc_slice_eq(s, incoming.slices[0]));
295 
296   grpc_endpoint_shutdown(f.client_ep, GRPC_ERROR_CREATE("test_leftover end"));
297   grpc_endpoint_shutdown(f.server_ep, GRPC_ERROR_CREATE("test_leftover end"));
298   grpc_endpoint_destroy(f.client_ep);
299   grpc_endpoint_destroy(f.server_ep);
300 
301   grpc_slice_unref(s);
302   grpc_slice_buffer_destroy(&incoming);
303 
304   clean_up();
305 }
306 
destroy_pollset(void * p,grpc_error_handle)307 static void destroy_pollset(void* p, grpc_error_handle /*error*/) {
308   grpc_pollset_destroy(static_cast<grpc_pollset*>(p));
309 }
310 
// Runs the generic endpoint test suite against the two no-leftover
// configurations and the leftover read test against the two leftover
// configurations, all sharing one pollset that is created and torn down here.
TEST(SecureEndpointTest, MainTest) {
  grpc_closure destroyed;
  grpc_init();

  {
    // Scoped so that closures scheduled below (including destroy_pollset) run
    // when exec_ctx goes out of scope, before grpc_shutdown().
    grpc_core::ExecCtx exec_ctx;
    g_pollset = static_cast<grpc_pollset*>(gpr_zalloc(grpc_pollset_size()));
    grpc_pollset_init(g_pollset, &g_mu);
    grpc_endpoint_tests(configs[0], g_pollset, g_mu);
    grpc_endpoint_tests(configs[1], g_pollset, g_mu);
    test_leftover(configs[2], 1);
    test_leftover(configs[3], 1);
    GRPC_CLOSURE_INIT(&destroyed, destroy_pollset, g_pollset,
                      grpc_schedule_on_exec_ctx);
    // The iomgr pollset API requires the pollset's mutex to be held while
    // grpc_pollset_shutdown() is called.
    gpr_mu_lock(g_mu);
    grpc_pollset_shutdown(g_pollset, &destroyed);
    gpr_mu_unlock(g_mu);
  }

  grpc_shutdown();

  // The pollset storage came from gpr_zalloc() above.
  gpr_free(g_pollset);
}
332 
main(int argc,char ** argv)333 int main(int argc, char** argv) {
334   grpc::testing::TestEnvironment env(&argc, argv);
335   ::testing::InitGoogleTest(&argc, argv);
336   return RUN_ALL_TESTS();
337 }
338