1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include <memory>
20 #include <vector>
21
22 #include <gtest/gtest.h>
23
24 #include "absl/memory/memory.h"
25
26 #include <grpcpp/channel.h>
27 #include <grpcpp/client_context.h>
28 #include <grpcpp/create_channel.h>
29 #include <grpcpp/create_channel_posix.h>
30 #include <grpcpp/generic/generic_stub.h>
31 #include <grpcpp/impl/proto_utils.h>
32 #include <grpcpp/server.h>
33 #include <grpcpp/server_builder.h>
34 #include <grpcpp/server_context.h>
35 #include <grpcpp/server_posix.h>
36 #include <grpcpp/support/client_interceptor.h>
37
38 #include "src/core/lib/iomgr/port.h"
39 #include "src/proto/grpc/testing/echo.grpc.pb.h"
40 #include "test/core/util/port.h"
41 #include "test/core/util/test_config.h"
42 #include "test/cpp/end2end/interceptors_util.h"
43 #include "test/cpp/end2end/test_service_impl.h"
44 #include "test/cpp/util/byte_buffer_proto_helper.h"
45 #include "test/cpp/util/string_ref_helper.h"
46
47 #ifdef GRPC_POSIX_SOCKET
48 #include <fcntl.h>
49
50 #include "src/core/lib/iomgr/socket_utils_posix.h"
51 #endif // GRPC_POSIX_SOCKET
52
53 namespace grpc {
54 namespace testing {
55 namespace {
56
// RPC shapes a ParameterizedClientInterceptorsEnd2endTest scenario can drive:
// each of the four streaming modes, issued either through the synchronous stub
// or through the completion-queue async API.
enum class RPCType : int {
  // Synchronous stub calls.
  kSyncUnary,
  kSyncClientStreaming,
  kSyncServerStreaming,
  kSyncBidiStreaming,
  // Completion-queue based async calls.
  kAsyncCQUnary,
  kAsyncCQClientStreaming,
  kAsyncCQServerStreaming,
  kAsyncCQBidiStreaming,
};
67
// Transport the parameterized fixture uses to reach its server.
enum class ChannelType : int {
  kHttpChannel,  // Listening localhost TCP port.
  kFdChannel,    // Pre-connected Unix socketpair (GRPC_POSIX_SOCKET only).
};
72
73 // Hijacks Echo RPC and fills in the expected values
74 class HijackingInterceptor : public experimental::Interceptor {
75 public:
HijackingInterceptor(experimental::ClientRpcInfo * info)76 explicit HijackingInterceptor(experimental::ClientRpcInfo* info) {
77 info_ = info;
78 // Make sure it is the right method
79 EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
80 EXPECT_EQ(info->suffix_for_stats(), nullptr);
81 EXPECT_EQ(info->type(), experimental::ClientRpcInfo::Type::UNARY);
82 }
83
Intercept(experimental::InterceptorBatchMethods * methods)84 void Intercept(experimental::InterceptorBatchMethods* methods) override {
85 bool hijack = false;
86 if (methods->QueryInterceptionHookPoint(
87 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
88 auto* map = methods->GetSendInitialMetadata();
89 // Check that we can see the test metadata
90 ASSERT_EQ(map->size(), 1);
91 auto iterator = map->begin();
92 EXPECT_EQ("testkey", iterator->first);
93 EXPECT_EQ("testvalue", iterator->second);
94 hijack = true;
95 }
96 if (methods->QueryInterceptionHookPoint(
97 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
98 EchoRequest req;
99 auto* buffer = methods->GetSerializedSendMessage();
100 auto copied_buffer = *buffer;
101 EXPECT_TRUE(
102 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
103 .ok());
104 EXPECT_EQ(req.message(), "Hello");
105 }
106 if (methods->QueryInterceptionHookPoint(
107 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
108 // Got nothing to do here for now
109 }
110 if (methods->QueryInterceptionHookPoint(
111 experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
112 auto* map = methods->GetRecvInitialMetadata();
113 // Got nothing better to do here for now
114 EXPECT_EQ(map->size(), 0);
115 }
116 if (methods->QueryInterceptionHookPoint(
117 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
118 EchoResponse* resp =
119 static_cast<EchoResponse*>(methods->GetRecvMessage());
120 // Check that we got the hijacked message, and re-insert the expected
121 // message
122 EXPECT_EQ(resp->message(), "Hello1");
123 resp->set_message("Hello");
124 }
125 if (methods->QueryInterceptionHookPoint(
126 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
127 auto* map = methods->GetRecvTrailingMetadata();
128 bool found = false;
129 // Check that we received the metadata as an echo
130 for (const auto& pair : *map) {
131 found = pair.first.starts_with("testkey") &&
132 pair.second.starts_with("testvalue");
133 if (found) break;
134 }
135 EXPECT_EQ(found, true);
136 auto* status = methods->GetRecvStatus();
137 EXPECT_EQ(status->ok(), true);
138 }
139 if (methods->QueryInterceptionHookPoint(
140 experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
141 auto* map = methods->GetRecvInitialMetadata();
142 // Got nothing better to do here at the moment
143 EXPECT_EQ(map->size(), 0);
144 }
145 if (methods->QueryInterceptionHookPoint(
146 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
147 // Insert a different message than expected
148 EchoResponse* resp =
149 static_cast<EchoResponse*>(methods->GetRecvMessage());
150 resp->set_message("Hello1");
151 }
152 if (methods->QueryInterceptionHookPoint(
153 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
154 auto* map = methods->GetRecvTrailingMetadata();
155 // insert the metadata that we want
156 EXPECT_EQ(map->size(), 0);
157 map->insert(std::make_pair("testkey", "testvalue"));
158 auto* status = methods->GetRecvStatus();
159 *status = Status(StatusCode::OK, "");
160 }
161 if (hijack) {
162 methods->Hijack();
163 } else {
164 methods->Proceed();
165 }
166 }
167
168 private:
169 experimental::ClientRpcInfo* info_;
170 };
171
172 class HijackingInterceptorFactory
173 : public experimental::ClientInterceptorFactoryInterface {
174 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)175 experimental::Interceptor* CreateClientInterceptor(
176 experimental::ClientRpcInfo* info) override {
177 return new HijackingInterceptor(info);
178 }
179 };
180
181 class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor {
182 public:
HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo * info)183 explicit HijackingInterceptorMakesAnotherCall(
184 experimental::ClientRpcInfo* info) {
185 info_ = info;
186 // Make sure it is the right method
187 EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
188 EXPECT_EQ(strcmp("TestSuffixForStats", info->suffix_for_stats()), 0);
189 }
190
Intercept(experimental::InterceptorBatchMethods * methods)191 void Intercept(experimental::InterceptorBatchMethods* methods) override {
192 if (methods->QueryInterceptionHookPoint(
193 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
194 auto* map = methods->GetSendInitialMetadata();
195 // Check that we can see the test metadata
196 ASSERT_EQ(map->size(), 1);
197 auto iterator = map->begin();
198 EXPECT_EQ("testkey", iterator->first);
199 EXPECT_EQ("testvalue", iterator->second);
200 // Make a copy of the map
201 metadata_map_ = *map;
202 }
203 if (methods->QueryInterceptionHookPoint(
204 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
205 EchoRequest req;
206 auto* buffer = methods->GetSerializedSendMessage();
207 auto copied_buffer = *buffer;
208 EXPECT_TRUE(
209 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
210 .ok());
211 EXPECT_EQ(req.message(), "Hello");
212 req_ = req;
213 stub_ = grpc::testing::EchoTestService::NewStub(
214 methods->GetInterceptedChannel());
215 ctx_.AddMetadata(metadata_map_.begin()->first,
216 metadata_map_.begin()->second);
217 stub_->async()->Echo(&ctx_, &req_, &resp_, [this, methods](Status s) {
218 EXPECT_EQ(s.ok(), true);
219 EXPECT_EQ(resp_.message(), "Hello");
220 methods->Hijack();
221 });
222 // This is a Unary RPC and we have got nothing interesting to do in the
223 // PRE_SEND_CLOSE interception hook point for this interceptor, so let's
224 // return here. (We do not want to call methods->Proceed(). When the new
225 // RPC returns, we will call methods->Hijack() instead.)
226 return;
227 }
228 if (methods->QueryInterceptionHookPoint(
229 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
230 // Got nothing to do here for now
231 }
232 if (methods->QueryInterceptionHookPoint(
233 experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
234 auto* map = methods->GetRecvInitialMetadata();
235 // Got nothing better to do here for now
236 EXPECT_EQ(map->size(), 0);
237 }
238 if (methods->QueryInterceptionHookPoint(
239 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
240 EchoResponse* resp =
241 static_cast<EchoResponse*>(methods->GetRecvMessage());
242 // Check that we got the hijacked message, and re-insert the expected
243 // message
244 EXPECT_EQ(resp->message(), "Hello");
245 }
246 if (methods->QueryInterceptionHookPoint(
247 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
248 auto* map = methods->GetRecvTrailingMetadata();
249 bool found = false;
250 // Check that we received the metadata as an echo
251 for (const auto& pair : *map) {
252 found = pair.first.starts_with("testkey") &&
253 pair.second.starts_with("testvalue");
254 if (found) break;
255 }
256 EXPECT_EQ(found, true);
257 auto* status = methods->GetRecvStatus();
258 EXPECT_EQ(status->ok(), true);
259 }
260 if (methods->QueryInterceptionHookPoint(
261 experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
262 auto* map = methods->GetRecvInitialMetadata();
263 // Got nothing better to do here at the moment
264 EXPECT_EQ(map->size(), 0);
265 }
266 if (methods->QueryInterceptionHookPoint(
267 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
268 // Insert a different message than expected
269 EchoResponse* resp =
270 static_cast<EchoResponse*>(methods->GetRecvMessage());
271 resp->set_message(resp_.message());
272 }
273 if (methods->QueryInterceptionHookPoint(
274 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
275 auto* map = methods->GetRecvTrailingMetadata();
276 // insert the metadata that we want
277 EXPECT_EQ(map->size(), 0);
278 map->insert(std::make_pair("testkey", "testvalue"));
279 auto* status = methods->GetRecvStatus();
280 *status = Status(StatusCode::OK, "");
281 }
282
283 methods->Proceed();
284 }
285
286 private:
287 experimental::ClientRpcInfo* info_;
288 std::multimap<std::string, std::string> metadata_map_;
289 ClientContext ctx_;
290 EchoRequest req_;
291 EchoResponse resp_;
292 std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
293 };
294
295 class HijackingInterceptorMakesAnotherCallFactory
296 : public experimental::ClientInterceptorFactoryInterface {
297 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)298 experimental::Interceptor* CreateClientInterceptor(
299 experimental::ClientRpcInfo* info) override {
300 return new HijackingInterceptorMakesAnotherCall(info);
301 }
302 };
303
304 class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
305 public:
BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)306 explicit BidiStreamingRpcHijackingInterceptor(
307 experimental::ClientRpcInfo* info) {
308 info_ = info;
309 EXPECT_EQ(info->suffix_for_stats(), nullptr);
310 }
311
Intercept(experimental::InterceptorBatchMethods * methods)312 void Intercept(experimental::InterceptorBatchMethods* methods) override {
313 bool hijack = false;
314 if (methods->QueryInterceptionHookPoint(
315 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
316 CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue");
317 hijack = true;
318 }
319 if (methods->QueryInterceptionHookPoint(
320 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
321 EchoRequest req;
322 auto* buffer = methods->GetSerializedSendMessage();
323 auto copied_buffer = *buffer;
324 EXPECT_TRUE(
325 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
326 .ok());
327 EXPECT_EQ(req.message().find("Hello"), 0u);
328 msg = req.message();
329 }
330 if (methods->QueryInterceptionHookPoint(
331 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
332 // Got nothing to do here for now
333 }
334 if (methods->QueryInterceptionHookPoint(
335 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
336 CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey",
337 "testvalue");
338 auto* status = methods->GetRecvStatus();
339 EXPECT_EQ(status->ok(), true);
340 }
341 if (methods->QueryInterceptionHookPoint(
342 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
343 EchoResponse* resp =
344 static_cast<EchoResponse*>(methods->GetRecvMessage());
345 resp->set_message(msg);
346 }
347 if (methods->QueryInterceptionHookPoint(
348 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
349 EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage())
350 ->message()
351 .find("Hello"),
352 0u);
353 }
354 if (methods->QueryInterceptionHookPoint(
355 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
356 auto* map = methods->GetRecvTrailingMetadata();
357 // insert the metadata that we want
358 EXPECT_EQ(map->size(), 0);
359 map->insert(std::make_pair("testkey", "testvalue"));
360 auto* status = methods->GetRecvStatus();
361 *status = Status(StatusCode::OK, "");
362 }
363 if (hijack) {
364 methods->Hijack();
365 } else {
366 methods->Proceed();
367 }
368 }
369
370 private:
371 experimental::ClientRpcInfo* info_;
372 std::string msg;
373 };
374
375 class ClientStreamingRpcHijackingInterceptor
376 : public experimental::Interceptor {
377 public:
ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)378 explicit ClientStreamingRpcHijackingInterceptor(
379 experimental::ClientRpcInfo* info) {
380 info_ = info;
381 EXPECT_EQ(
382 strcmp("/grpc.testing.EchoTestService/RequestStream", info->method()),
383 0);
384 EXPECT_EQ(strcmp("TestSuffixForStats", info->suffix_for_stats()), 0);
385 }
Intercept(experimental::InterceptorBatchMethods * methods)386 void Intercept(experimental::InterceptorBatchMethods* methods) override {
387 bool hijack = false;
388 if (methods->QueryInterceptionHookPoint(
389 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
390 hijack = true;
391 }
392 if (methods->QueryInterceptionHookPoint(
393 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
394 if (++count_ > 10) {
395 methods->FailHijackedSendMessage();
396 }
397 }
398 if (methods->QueryInterceptionHookPoint(
399 experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) {
400 EXPECT_FALSE(got_failed_send_);
401 got_failed_send_ = !methods->GetSendMessageStatus();
402 }
403 if (methods->QueryInterceptionHookPoint(
404 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
405 auto* status = methods->GetRecvStatus();
406 *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages");
407 }
408 if (hijack) {
409 methods->Hijack();
410 } else {
411 methods->Proceed();
412 }
413 }
414
GotFailedSend()415 static bool GotFailedSend() { return got_failed_send_; }
416
417 private:
418 experimental::ClientRpcInfo* info_;
419 int count_ = 0;
420 static bool got_failed_send_;
421 };
422
423 bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false;
424
425 class ClientStreamingRpcHijackingInterceptorFactory
426 : public experimental::ClientInterceptorFactoryInterface {
427 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)428 experimental::Interceptor* CreateClientInterceptor(
429 experimental::ClientRpcInfo* info) override {
430 return new ClientStreamingRpcHijackingInterceptor(info);
431 }
432 };
433
434 class ServerStreamingRpcHijackingInterceptor
435 : public experimental::Interceptor {
436 public:
ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)437 explicit ServerStreamingRpcHijackingInterceptor(
438 experimental::ClientRpcInfo* info) {
439 info_ = info;
440 got_failed_message_ = false;
441 EXPECT_EQ(info->suffix_for_stats(), nullptr);
442 }
443
Intercept(experimental::InterceptorBatchMethods * methods)444 void Intercept(experimental::InterceptorBatchMethods* methods) override {
445 bool hijack = false;
446 if (methods->QueryInterceptionHookPoint(
447 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
448 auto* map = methods->GetSendInitialMetadata();
449 // Check that we can see the test metadata
450 ASSERT_EQ(map->size(), 1);
451 auto iterator = map->begin();
452 EXPECT_EQ("testkey", iterator->first);
453 EXPECT_EQ("testvalue", iterator->second);
454 hijack = true;
455 }
456 if (methods->QueryInterceptionHookPoint(
457 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
458 EchoRequest req;
459 auto* buffer = methods->GetSerializedSendMessage();
460 auto copied_buffer = *buffer;
461 EXPECT_TRUE(
462 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
463 .ok());
464 EXPECT_EQ(req.message(), "Hello");
465 }
466 if (methods->QueryInterceptionHookPoint(
467 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
468 // Got nothing to do here for now
469 }
470 if (methods->QueryInterceptionHookPoint(
471 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
472 auto* map = methods->GetRecvTrailingMetadata();
473 bool found = false;
474 // Check that we received the metadata as an echo
475 for (const auto& pair : *map) {
476 found = pair.first.starts_with("testkey") &&
477 pair.second.starts_with("testvalue");
478 if (found) break;
479 }
480 EXPECT_EQ(found, true);
481 auto* status = methods->GetRecvStatus();
482 EXPECT_EQ(status->ok(), true);
483 }
484 if (methods->QueryInterceptionHookPoint(
485 experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
486 if (++count_ > 10) {
487 methods->FailHijackedRecvMessage();
488 }
489 EchoResponse* resp =
490 static_cast<EchoResponse*>(methods->GetRecvMessage());
491 resp->set_message("Hello");
492 }
493 if (methods->QueryInterceptionHookPoint(
494 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
495 // Only the last message will be a failure
496 EXPECT_FALSE(got_failed_message_);
497 got_failed_message_ = methods->GetRecvMessage() == nullptr;
498 }
499 if (methods->QueryInterceptionHookPoint(
500 experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
501 auto* map = methods->GetRecvTrailingMetadata();
502 // insert the metadata that we want
503 EXPECT_EQ(map->size(), 0);
504 map->insert(std::make_pair("testkey", "testvalue"));
505 auto* status = methods->GetRecvStatus();
506 *status = Status(StatusCode::OK, "");
507 }
508 if (hijack) {
509 methods->Hijack();
510 } else {
511 methods->Proceed();
512 }
513 }
514
GotFailedMessage()515 static bool GotFailedMessage() { return got_failed_message_; }
516
517 private:
518 experimental::ClientRpcInfo* info_;
519 static bool got_failed_message_;
520 int count_ = 0;
521 };
522
523 bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false;
524
525 class ServerStreamingRpcHijackingInterceptorFactory
526 : public experimental::ClientInterceptorFactoryInterface {
527 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)528 experimental::Interceptor* CreateClientInterceptor(
529 experimental::ClientRpcInfo* info) override {
530 return new ServerStreamingRpcHijackingInterceptor(info);
531 }
532 };
533
534 class BidiStreamingRpcHijackingInterceptorFactory
535 : public experimental::ClientInterceptorFactoryInterface {
536 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)537 experimental::Interceptor* CreateClientInterceptor(
538 experimental::ClientRpcInfo* info) override {
539 return new BidiStreamingRpcHijackingInterceptor(info);
540 }
541 };
542
543 // The logging interceptor is for testing purposes only. It is used to verify
544 // that all the appropriate hook points are invoked for an RPC. The counts are
545 // reset each time a new object of LoggingInterceptor is created, so only a
546 // single RPC should be made on the channel before calling the Verify methods.
547 class LoggingInterceptor : public experimental::Interceptor {
548 public:
LoggingInterceptor(experimental::ClientRpcInfo *)549 explicit LoggingInterceptor(experimental::ClientRpcInfo* /*info*/) {
550 pre_send_initial_metadata_ = false;
551 pre_send_message_count_ = 0;
552 pre_send_close_ = false;
553 post_recv_initial_metadata_ = false;
554 post_recv_message_count_ = 0;
555 post_recv_status_ = false;
556 }
557
Intercept(experimental::InterceptorBatchMethods * methods)558 void Intercept(experimental::InterceptorBatchMethods* methods) override {
559 if (methods->QueryInterceptionHookPoint(
560 experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
561 auto* map = methods->GetSendInitialMetadata();
562 // Check that we can see the test metadata
563 ASSERT_EQ(map->size(), 1);
564 auto iterator = map->begin();
565 EXPECT_EQ("testkey", iterator->first);
566 EXPECT_EQ("testvalue", iterator->second);
567 ASSERT_FALSE(pre_send_initial_metadata_);
568 pre_send_initial_metadata_ = true;
569 }
570 if (methods->QueryInterceptionHookPoint(
571 experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
572 EchoRequest req;
573 auto* send_msg = methods->GetSendMessage();
574 if (send_msg == nullptr) {
575 // We did not get the non-serialized form of the message. Get the
576 // serialized form.
577 auto* buffer = methods->GetSerializedSendMessage();
578 auto copied_buffer = *buffer;
579 EchoRequest req;
580 EXPECT_TRUE(
581 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
582 .ok());
583 EXPECT_EQ(req.message(), "Hello");
584 } else {
585 EXPECT_EQ(
586 static_cast<const EchoRequest*>(send_msg)->message().find("Hello"),
587 0u);
588 }
589 auto* buffer = methods->GetSerializedSendMessage();
590 auto copied_buffer = *buffer;
591 EXPECT_TRUE(
592 SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
593 .ok());
594 EXPECT_TRUE(req.message().find("Hello") == 0u);
595 pre_send_message_count_++;
596 }
597 if (methods->QueryInterceptionHookPoint(
598 experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
599 // Got nothing to do here for now
600 pre_send_close_ = true;
601 }
602 if (methods->QueryInterceptionHookPoint(
603 experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
604 auto* map = methods->GetRecvInitialMetadata();
605 // Got nothing better to do here for now
606 EXPECT_EQ(map->size(), 0);
607 post_recv_initial_metadata_ = true;
608 }
609 if (methods->QueryInterceptionHookPoint(
610 experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
611 EchoResponse* resp =
612 static_cast<EchoResponse*>(methods->GetRecvMessage());
613 if (resp != nullptr) {
614 EXPECT_TRUE(resp->message().find("Hello") == 0u);
615 post_recv_message_count_++;
616 }
617 }
618 if (methods->QueryInterceptionHookPoint(
619 experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
620 auto* map = methods->GetRecvTrailingMetadata();
621 bool found = false;
622 // Check that we received the metadata as an echo
623 for (const auto& pair : *map) {
624 found = pair.first.starts_with("testkey") &&
625 pair.second.starts_with("testvalue");
626 if (found) break;
627 }
628 EXPECT_EQ(found, true);
629 auto* status = methods->GetRecvStatus();
630 EXPECT_EQ(status->ok(), true);
631 post_recv_status_ = true;
632 }
633 methods->Proceed();
634 }
635
VerifyCall(RPCType type)636 static void VerifyCall(RPCType type) {
637 switch (type) {
638 case RPCType::kSyncUnary:
639 case RPCType::kAsyncCQUnary:
640 VerifyUnaryCall();
641 break;
642 case RPCType::kSyncClientStreaming:
643 case RPCType::kAsyncCQClientStreaming:
644 VerifyClientStreamingCall();
645 break;
646 case RPCType::kSyncServerStreaming:
647 case RPCType::kAsyncCQServerStreaming:
648 VerifyServerStreamingCall();
649 break;
650 case RPCType::kSyncBidiStreaming:
651 case RPCType::kAsyncCQBidiStreaming:
652 VerifyBidiStreamingCall();
653 break;
654 }
655 }
656
VerifyCallCommon()657 static void VerifyCallCommon() {
658 EXPECT_TRUE(pre_send_initial_metadata_);
659 EXPECT_TRUE(pre_send_close_);
660 EXPECT_TRUE(post_recv_initial_metadata_);
661 EXPECT_TRUE(post_recv_status_);
662 }
663
VerifyUnaryCall()664 static void VerifyUnaryCall() {
665 VerifyCallCommon();
666 EXPECT_EQ(pre_send_message_count_, 1);
667 EXPECT_EQ(post_recv_message_count_, 1);
668 }
669
VerifyClientStreamingCall()670 static void VerifyClientStreamingCall() {
671 VerifyCallCommon();
672 EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
673 EXPECT_EQ(post_recv_message_count_, 1);
674 }
675
VerifyServerStreamingCall()676 static void VerifyServerStreamingCall() {
677 VerifyCallCommon();
678 EXPECT_EQ(pre_send_message_count_, 1);
679 EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
680 }
681
VerifyBidiStreamingCall()682 static void VerifyBidiStreamingCall() {
683 VerifyCallCommon();
684 EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
685 EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
686 }
687
688 private:
689 static bool pre_send_initial_metadata_;
690 static int pre_send_message_count_;
691 static bool pre_send_close_;
692 static bool post_recv_initial_metadata_;
693 static int post_recv_message_count_;
694 static bool post_recv_status_;
695 };
696
697 bool LoggingInterceptor::pre_send_initial_metadata_;
698 int LoggingInterceptor::pre_send_message_count_;
699 bool LoggingInterceptor::pre_send_close_;
700 bool LoggingInterceptor::post_recv_initial_metadata_;
701 int LoggingInterceptor::post_recv_message_count_;
702 bool LoggingInterceptor::post_recv_status_;
703
704 class LoggingInterceptorFactory
705 : public experimental::ClientInterceptorFactoryInterface {
706 public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)707 experimental::Interceptor* CreateClientInterceptor(
708 experimental::ClientRpcInfo* info) override {
709 return new LoggingInterceptor(info);
710 }
711 };
712
713 class TestScenario {
714 public:
TestScenario(const ChannelType & channel_type,const RPCType & rpc_type)715 explicit TestScenario(const ChannelType& channel_type,
716 const RPCType& rpc_type)
717 : channel_type_(channel_type), rpc_type_(rpc_type) {}
718
channel_type() const719 ChannelType channel_type() const { return channel_type_; }
720
rpc_type() const721 RPCType rpc_type() const { return rpc_type_; }
722
723 private:
724 const ChannelType channel_type_;
725 const RPCType rpc_type_;
726 };
727
CreateTestScenarios()728 std::vector<TestScenario> CreateTestScenarios() {
729 std::vector<TestScenario> scenarios;
730 std::vector<RPCType> rpc_types;
731 rpc_types.emplace_back(RPCType::kSyncUnary);
732 rpc_types.emplace_back(RPCType::kSyncClientStreaming);
733 rpc_types.emplace_back(RPCType::kSyncServerStreaming);
734 rpc_types.emplace_back(RPCType::kSyncBidiStreaming);
735 rpc_types.emplace_back(RPCType::kAsyncCQUnary);
736 rpc_types.emplace_back(RPCType::kAsyncCQServerStreaming);
737 for (const auto& rpc_type : rpc_types) {
738 scenarios.emplace_back(ChannelType::kHttpChannel, rpc_type);
739 // TODO(yashykt): Maybe add support for non-posix sockets too
740 #ifdef GRPC_POSIX_SOCKET
741 scenarios.emplace_back(ChannelType::kFdChannel, rpc_type);
742 #endif // GRPC_POSIX_SOCKET
743 }
744 return scenarios;
745 }
746
747 class ParameterizedClientInterceptorsEnd2endTest
748 : public ::testing::TestWithParam<TestScenario> {
749 protected:
ParameterizedClientInterceptorsEnd2endTest()750 ParameterizedClientInterceptorsEnd2endTest() {
751 ServerBuilder builder;
752 builder.RegisterService(&service_);
753 if (GetParam().channel_type() == ChannelType::kHttpChannel) {
754 int port = grpc_pick_unused_port_or_die();
755 server_address_ = "localhost:" + std::to_string(port);
756 builder.AddListeningPort(server_address_, InsecureServerCredentials());
757 server_ = builder.BuildAndStart();
758 }
759 #ifdef GRPC_POSIX_SOCKET
760 else if (GetParam().channel_type() == ChannelType::kFdChannel) {
761 int flags;
762 GPR_ASSERT(socketpair(AF_UNIX, SOCK_STREAM, 0, sv_) == 0);
763 flags = fcntl(sv_[0], F_GETFL, 0);
764 GPR_ASSERT(fcntl(sv_[0], F_SETFL, flags | O_NONBLOCK) == 0);
765 flags = fcntl(sv_[1], F_GETFL, 0);
766 GPR_ASSERT(fcntl(sv_[1], F_SETFL, flags | O_NONBLOCK) == 0);
767 GPR_ASSERT(grpc_set_socket_no_sigpipe_if_possible(sv_[0]) ==
768 absl::OkStatus());
769 GPR_ASSERT(grpc_set_socket_no_sigpipe_if_possible(sv_[1]) ==
770 absl::OkStatus());
771 server_ = builder.BuildAndStart();
772 AddInsecureChannelFromFd(server_.get(), sv_[1]);
773 }
774 #endif // GRPC_POSIX_SOCKET
775 }
776
~ParameterizedClientInterceptorsEnd2endTest()777 ~ParameterizedClientInterceptorsEnd2endTest() override {
778 server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0));
779 }
780
CreateClientChannel(std::vector<std::unique_ptr<grpc::experimental::ClientInterceptorFactoryInterface>> creators)781 std::shared_ptr<grpc::Channel> CreateClientChannel(
782 std::vector<std::unique_ptr<
783 grpc::experimental::ClientInterceptorFactoryInterface>>
784 creators) {
785 if (GetParam().channel_type() == ChannelType::kHttpChannel) {
786 return experimental::CreateCustomChannelWithInterceptors(
787 server_address_, InsecureChannelCredentials(), ChannelArguments(),
788 std::move(creators));
789 }
790 #ifdef GRPC_POSIX_SOCKET
791 else if (GetParam().channel_type() == ChannelType::kFdChannel) {
792 return experimental::CreateCustomInsecureChannelWithInterceptorsFromFd(
793 "", sv_[0], ChannelArguments(), std::move(creators));
794 }
795 #endif // GRPC_POSIX_SOCKET
796 return nullptr;
797 }
798
SendRPC(const std::shared_ptr<Channel> & channel)799 void SendRPC(const std::shared_ptr<Channel>& channel) {
800 switch (GetParam().rpc_type()) {
801 case RPCType::kSyncUnary:
802 MakeCall(channel);
803 break;
804 case RPCType::kSyncClientStreaming:
805 MakeClientStreamingCall(channel);
806 break;
807 case RPCType::kSyncServerStreaming:
808 MakeServerStreamingCall(channel);
809 break;
810 case RPCType::kSyncBidiStreaming:
811 MakeBidiStreamingCall(channel);
812 break;
813 case RPCType::kAsyncCQUnary:
814 MakeAsyncCQCall(channel);
815 break;
816 case RPCType::kAsyncCQClientStreaming:
817 // TODO(yashykt) : Fill this out
818 break;
819 case RPCType::kAsyncCQServerStreaming:
820 MakeAsyncCQServerStreamingCall(channel);
821 break;
822 case RPCType::kAsyncCQBidiStreaming:
823 // TODO(yashykt) : Fill this out
824 break;
825 }
826 }
827
828 std::string server_address_;
829 int sv_[2];
830 EchoTestServiceStreamingImpl service_;
831 std::unique_ptr<Server> server_;
832 };
833
// Installs a LoggingInterceptor followed by pass-through PhonyInterceptors,
// sends one RPC of the scenario's type, and checks that every expected hook
// point fired and every phony interceptor ran.
TEST_P(ParameterizedClientInterceptorsEnd2endTest,
       ClientInterceptorLoggingTest) {
  constexpr int kNumPhonyInterceptors = 20;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  creators.reserve(kNumPhonyInterceptors + 1);
  creators.push_back(std::make_unique<LoggingInterceptorFactory>());
  for (int i = 0; i < kNumPhonyInterceptors; i++) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto channel = CreateClientChannel(std::move(creators));
  SendRPC(channel);
  LoggingInterceptor::VerifyCall(GetParam().rpc_type());
  // Make sure all phony interceptors were run
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumPhonyInterceptors);
}
851
// Runs every ParameterizedClientInterceptorsEnd2endTest once per scenario
// returned by CreateTestScenarios(). Inside grpc::testing the gtest namespace
// must be spelled ::testing, because unqualified `testing` names
// grpc::testing.
INSTANTIATE_TEST_SUITE_P(ParameterizedClientInterceptorsEnd2end,
                         ParameterizedClientInterceptorsEnd2endTest,
                         ::testing::ValuesIn(CreateTestScenarios()));
855
856 class ClientInterceptorsEnd2endTest
857 : public ::testing::TestWithParam<TestScenario> {
858 protected:
ClientInterceptorsEnd2endTest()859 ClientInterceptorsEnd2endTest() {
860 int port = grpc_pick_unused_port_or_die();
861
862 ServerBuilder builder;
863 server_address_ = "localhost:" + std::to_string(port);
864 builder.AddListeningPort(server_address_, InsecureServerCredentials());
865 builder.RegisterService(&service_);
866 server_ = builder.BuildAndStart();
867 }
868
~ClientInterceptorsEnd2endTest()869 ~ClientInterceptorsEnd2endTest() override { server_->Shutdown(); }
870
871 std::string server_address_;
872 TestServiceImpl service_;
873 std::unique_ptr<Server> server_;
874 };
875
TEST_F(ClientInterceptorsEnd2endTest,
       LameChannelClientInterceptorHijackingTest) {
  // The channel is created with null credentials (a "lame" channel, per the
  // test name) and carries a single hijacking interceptor; the test only
  // requires that MakeCall completes through it.
  using FactoryList = std::vector<
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>;
  FactoryList factories;
  factories.push_back(std::make_unique<HijackingInterceptorFactory>());

  ChannelArguments channel_args;
  auto lame_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, /*creds=*/nullptr, channel_args, std::move(factories));
  MakeCall(lame_channel);
}
886
TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) {
  // A hijacking interceptor sits between two groups of phony interceptors.
  // Only the phony interceptors placed before the hijacker are expected to
  // run.
  constexpr int kPhonyBeforeHijack = 20;
  constexpr int kPhonyAfterHijack = 20;
  ChannelArguments args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  // Reserve room for every factory pushed below: both phony groups plus the
  // hijacker.
  creators.reserve(kPhonyBeforeHijack + 1 + kPhonyAfterHijack);
  // Add phony interceptors before the hijacking interceptor.
  for (int i = 0; i < kPhonyBeforeHijack; i++) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  creators.push_back(std::make_unique<HijackingInterceptorFactory>());
  // Add phony interceptors after the hijacking interceptor.
  for (int i = 0; i < kPhonyAfterHijack; i++) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), args, std::move(creators));
  MakeCall(channel);
  // Make sure only the interceptors ahead of the hijacker were run.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kPhonyBeforeHijack);
}
908
TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) {
  // Logging interceptor first, hijacking interceptor second; the logger is
  // then checked for a recorded unary call.
  using FactoryList = std::vector<
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>;
  FactoryList factories;
  factories.push_back(std::make_unique<LoggingInterceptorFactory>());
  factories.push_back(std::make_unique<HijackingInterceptorFactory>());

  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeCall(intercepted_channel);
  LoggingInterceptor::VerifyUnaryCall();
}
920
TEST_F(ClientInterceptorsEnd2endTest,
       ClientInterceptorHijackingMakesAnotherCallTest) {
  // The hijacking interceptor in the middle of the chain makes an RPC on the
  // intercepted channel, so every phony interceptor (before and after it) is
  // expected to run once.
  constexpr int kPhonyBeforeHijack = 5;
  constexpr int kPhonyAfterHijack = 7;
  ChannelArguments args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  // Reserve room for every factory pushed below: both phony groups plus the
  // hijacker.
  creators.reserve(kPhonyBeforeHijack + 1 + kPhonyAfterHijack);
  // Add phony interceptors before the hijacking interceptor.
  for (int i = 0; i < kPhonyBeforeHijack; i++) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  creators.push_back(
      std::make_unique<HijackingInterceptorMakesAnotherCallFactory>());
  // Add phony interceptors after the hijacking interceptor.
  for (int i = 0; i < kPhonyAfterHijack; i++) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto channel = server_->experimental().InProcessChannelWithInterceptors(
      args, std::move(creators));

  MakeCall(channel, StubOptions("TestSuffixForStats"));
  // Make sure all interceptors were run once, since the hijacking interceptor
  // makes an RPC on the intercepted channel.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(),
            kPhonyBeforeHijack + kPhonyAfterHijack);
}
947
948 class ClientInterceptorsCallbackEnd2endTest : public ::testing::Test {
949 protected:
ClientInterceptorsCallbackEnd2endTest()950 ClientInterceptorsCallbackEnd2endTest() {
951 int port = grpc_pick_unused_port_or_die();
952
953 ServerBuilder builder;
954 server_address_ = "localhost:" + std::to_string(port);
955 builder.AddListeningPort(server_address_, InsecureServerCredentials());
956 builder.RegisterService(&service_);
957 server_ = builder.BuildAndStart();
958 }
959
~ClientInterceptorsCallbackEnd2endTest()960 ~ClientInterceptorsCallbackEnd2endTest() override { server_->Shutdown(); }
961
962 std::string server_address_;
963 TestServiceImpl service_;
964 std::unique_ptr<Server> server_;
965 };
966
TEST_F(ClientInterceptorsCallbackEnd2endTest,
       ClientInterceptorLoggingTestWithCallback) {
  // Callback-API unary call over an in-process channel through one logging
  // interceptor followed by a batch of phony interceptors.
  constexpr int kPhonyCount = 20;
  using FactoryList = std::vector<
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>;

  PhonyInterceptor::Reset();
  FactoryList factories;
  factories.push_back(std::make_unique<LoggingInterceptorFactory>());
  for (int n = 0; n < kPhonyCount; ++n) {
    factories.push_back(std::make_unique<PhonyInterceptorFactory>());
  }

  ChannelArguments channel_args;
  auto in_process_channel =
      server_->experimental().InProcessChannelWithInterceptors(
          channel_args, std::move(factories));
  MakeCallbackCall(in_process_channel);
  LoggingInterceptor::VerifyUnaryCall();
  // The phony run counter should match the number of phony interceptors.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kPhonyCount);
}
985
TEST_F(ClientInterceptorsCallbackEnd2endTest,
       ClientInterceptorHijackingTestWithCallback) {
  // Logging interceptor, then phony interceptors, then a hijacker at the very
  // end of the chain; everything in front of the hijacker should still run.
  constexpr int kPhonyCount = 20;
  using FactoryList = std::vector<
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>;

  PhonyInterceptor::Reset();
  FactoryList factories;
  factories.push_back(std::make_unique<LoggingInterceptorFactory>());
  for (int n = 0; n < kPhonyCount; ++n) {
    factories.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  factories.push_back(std::make_unique<HijackingInterceptorFactory>());

  ChannelArguments channel_args;
  auto in_process_channel =
      server_->experimental().InProcessChannelWithInterceptors(
          channel_args, std::move(factories));
  MakeCallbackCall(in_process_channel);
  LoggingInterceptor::VerifyUnaryCall();
  // The phony run counter should match the number of phony interceptors.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kPhonyCount);
}
1005
TEST_F(ClientInterceptorsCallbackEnd2endTest,
       ClientInterceptorFactoryAllowsNullptrReturn) {
  // Interleaves phony factories with NullInterceptorFactory instances; the
  // call must still succeed and only the phony interceptors are counted.
  constexpr int kPairCount = 20;
  using FactoryList = std::vector<
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>;

  PhonyInterceptor::Reset();
  FactoryList factories;
  factories.push_back(std::make_unique<LoggingInterceptorFactory>());
  for (int n = 0; n < kPairCount; ++n) {
    // Each iteration adds one phony factory followed by one null factory.
    factories.push_back(std::make_unique<PhonyInterceptorFactory>());
    factories.push_back(std::make_unique<NullInterceptorFactory>());
  }

  ChannelArguments channel_args;
  auto in_process_channel =
      server_->experimental().InProcessChannelWithInterceptors(
          channel_args, std::move(factories));
  MakeCallbackCall(in_process_channel);
  LoggingInterceptor::VerifyUnaryCall();
  // One phony run per pair is expected.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kPairCount);
}
1025
1026 class ClientInterceptorsStreamingEnd2endTest : public ::testing::Test {
1027 protected:
ClientInterceptorsStreamingEnd2endTest()1028 ClientInterceptorsStreamingEnd2endTest() {
1029 int port = grpc_pick_unused_port_or_die();
1030
1031 ServerBuilder builder;
1032 server_address_ = "localhost:" + std::to_string(port);
1033 builder.AddListeningPort(server_address_, InsecureServerCredentials());
1034 builder.RegisterService(&service_);
1035 server_ = builder.BuildAndStart();
1036 }
1037
~ClientInterceptorsStreamingEnd2endTest()1038 ~ClientInterceptorsStreamingEnd2endTest() override { server_->Shutdown(); }
1039
1040 std::string server_address_;
1041 EchoTestServiceStreamingImpl service_;
1042 std::unique_ptr<Server> server_;
1043 };
1044
TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) {
  // Client-streaming call through a logging interceptor and a batch of phony
  // interceptors over an insecure TCP channel.
  constexpr int kPhonyCount = 20;
  using FactoryList = std::vector<
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>;

  PhonyInterceptor::Reset();
  FactoryList factories;
  factories.push_back(std::make_unique<LoggingInterceptorFactory>());
  for (int n = 0; n < kPhonyCount; ++n) {
    factories.push_back(std::make_unique<PhonyInterceptorFactory>());
  }

  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeClientStreamingCall(intercepted_channel);
  LoggingInterceptor::VerifyClientStreamingCall();
  // The phony run counter should match the number of phony interceptors.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kPhonyCount);
}
1062
TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
  // Server-streaming call through a logging interceptor and a batch of phony
  // interceptors over an insecure TCP channel.
  constexpr int kPhonyCount = 20;
  using FactoryList = std::vector<
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>;

  PhonyInterceptor::Reset();
  FactoryList factories;
  factories.push_back(std::make_unique<LoggingInterceptorFactory>());
  for (int n = 0; n < kPhonyCount; ++n) {
    factories.push_back(std::make_unique<PhonyInterceptorFactory>());
  }

  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeServerStreamingCall(intercepted_channel);
  LoggingInterceptor::VerifyServerStreamingCall();
  // The phony run counter should match the number of phony interceptors.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kPhonyCount);
}
1080
TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
  // The hijacking interceptor accepts a fixed number of writes and rejects
  // the next one; the stream must then finish with a non-OK status and the
  // interceptor must report the failed send.
  constexpr int kAcceptedWrites = 10;
  ChannelArguments args;
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  creators.push_back(
      std::make_unique<ClientStreamingRpcHijackingInterceptorFactory>());
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), args, std::move(creators));

  auto stub = grpc::testing::EchoTestService::NewStub(
      channel, StubOptions("TestSuffixForStats"));
  ClientContext ctx;
  EchoRequest req;
  EchoResponse resp;
  req.mutable_param()->set_echo_metadata(true);
  req.set_message("Hello");
  auto writer = stub->RequestStream(&ctx, &resp);
  for (int i = 0; i < kAcceptedWrites; i++) {
    EXPECT_TRUE(writer->Write(req));
  }
  // The interceptor will reject the 11th message. The failure is checked via
  // Finish() and GotFailedSend() below rather than via this return value.
  writer->Write(req);
  Status s = writer->Finish();
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
}
1109
TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
  // Sync server-streaming call through a ServerStreamingRpcHijackingInterceptor;
  // afterwards the interceptor must report that it saw a failed message.
  using FactoryList = std::vector<
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>;

  PhonyInterceptor::Reset();
  FactoryList factories;
  factories.push_back(
      std::make_unique<ServerStreamingRpcHijackingInterceptorFactory>());

  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeServerStreamingCall(intercepted_channel);
  EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
}
1122
TEST_F(ClientInterceptorsStreamingEnd2endTest,
       AsyncCQServerStreamingHijackingTest) {
  // Same as ServerStreamingHijackingTest, but the server-streaming call is
  // made through the completion-queue async API.
  using FactoryList = std::vector<
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>;

  PhonyInterceptor::Reset();
  FactoryList factories;
  factories.push_back(
      std::make_unique<ServerStreamingRpcHijackingInterceptorFactory>());

  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeAsyncCQServerStreamingCall(intercepted_channel);
  EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
}
1136
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
  // Bidi-streaming call through a BidiStreamingRpcHijackingInterceptor; the
  // checks, if any, live inside MakeBidiStreamingCall.
  using FactoryList = std::vector<
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>;

  PhonyInterceptor::Reset();
  FactoryList factories;
  factories.push_back(
      std::make_unique<BidiStreamingRpcHijackingInterceptorFactory>());

  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeBidiStreamingCall(intercepted_channel);
}
1148
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
  // Bidi-streaming call through a logging interceptor and a batch of phony
  // interceptors over an insecure TCP channel.
  constexpr int kPhonyCount = 20;
  using FactoryList = std::vector<
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>;

  PhonyInterceptor::Reset();
  FactoryList factories;
  factories.push_back(std::make_unique<LoggingInterceptorFactory>());
  for (int n = 0; n < kPhonyCount; ++n) {
    factories.push_back(std::make_unique<PhonyInterceptorFactory>());
  }

  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeBidiStreamingCall(intercepted_channel);
  LoggingInterceptor::VerifyBidiStreamingCall();
  // The phony run counter should match the number of phony interceptors.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kPhonyCount);
}
1166
1167 class ClientGlobalInterceptorEnd2endTest : public ::testing::Test {
1168 protected:
ClientGlobalInterceptorEnd2endTest()1169 ClientGlobalInterceptorEnd2endTest() {
1170 int port = grpc_pick_unused_port_or_die();
1171
1172 ServerBuilder builder;
1173 server_address_ = "localhost:" + std::to_string(port);
1174 builder.AddListeningPort(server_address_, InsecureServerCredentials());
1175 builder.RegisterService(&service_);
1176 server_ = builder.BuildAndStart();
1177 }
1178
~ClientGlobalInterceptorEnd2endTest()1179 ~ClientGlobalInterceptorEnd2endTest() override { server_->Shutdown(); }
1180
1181 std::string server_address_;
1182 TestServiceImpl service_;
1183 std::unique_ptr<Server> server_;
1184 };
1185
TEST_F(ClientGlobalInterceptorEnd2endTest, PhonyGlobalInterceptor) {
  // Registering a global interceptor factory is normally a once-per-process
  // operation. Swapping it here is acceptable because no gRPC operations are
  // in flight; the registration is reset at the end of the test.
  PhonyInterceptorFactory global_factory;
  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);

  constexpr int kPhonyCount = 20;
  using FactoryList = std::vector<
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>;

  ChannelArguments channel_args;
  PhonyInterceptor::Reset();
  FactoryList factories;
  factories.reserve(kPhonyCount);
  for (int n = 0; n < kPhonyCount; ++n) {
    factories.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeCall(intercepted_channel);
  // Per-channel phony interceptors plus the global phony interceptor.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kPhonyCount + 1);
  experimental::TestOnlyResetGlobalClientInterceptorFactory();
}
1208
TEST_F(ClientGlobalInterceptorEnd2endTest, LoggingGlobalInterceptor) {
  // Registering a global interceptor factory is normally a once-per-process
  // operation. Swapping it here is acceptable because no gRPC operations are
  // in flight; the registration is reset at the end of the test.
  LoggingInterceptorFactory global_factory;
  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);

  constexpr int kPhonyCount = 20;
  using FactoryList = std::vector<
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>;

  ChannelArguments channel_args;
  PhonyInterceptor::Reset();
  FactoryList factories;
  factories.reserve(kPhonyCount);
  for (int n = 0; n < kPhonyCount; ++n) {
    factories.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeCall(intercepted_channel);
  LoggingInterceptor::VerifyUnaryCall();
  // The phony run counter should match the number of per-channel phony
  // interceptors.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kPhonyCount);
  experimental::TestOnlyResetGlobalClientInterceptorFactory();
}
1232
TEST_F(ClientGlobalInterceptorEnd2endTest, HijackingGlobalInterceptor) {
  // Registering a global interceptor factory is normally a once-per-process
  // operation. Swapping it here is acceptable because no gRPC operations are
  // in flight; the registration is reset at the end of the test.
  HijackingInterceptorFactory global_factory;
  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);

  constexpr int kPhonyCount = 20;
  using FactoryList = std::vector<
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>;

  ChannelArguments channel_args;
  PhonyInterceptor::Reset();
  FactoryList factories;
  factories.reserve(kPhonyCount);
  for (int n = 0; n < kPhonyCount; ++n) {
    factories.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeCall(intercepted_channel);
  // All per-channel phony interceptors are expected to run despite the global
  // hijacker (presumably because global interceptors are placed after the
  // per-channel ones).
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kPhonyCount);
  experimental::TestOnlyResetGlobalClientInterceptorFactory();
}
1255
1256 } // namespace
1257 } // namespace testing
1258 } // namespace grpc
1259
main(int argc,char ** argv)1260 int main(int argc, char** argv) {
1261 grpc::testing::TestEnvironment env(&argc, argv);
1262 ::testing::InitGoogleTest(&argc, argv);
1263 int ret = RUN_ALL_TESTS();
1264 // Make sure that gRPC shuts down cleanly
1265 GPR_ASSERT(grpc_wait_until_shutdown(10));
1266 return ret;
1267 }
1268