xref: /aosp_15_r20/external/grpc-grpc/test/cpp/end2end/client_interceptors_end2end_test.cc (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 //
2 //
3 // Copyright 2018 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include <memory>
20 #include <vector>
21 
22 #include <gtest/gtest.h>
23 
24 #include "absl/memory/memory.h"
25 
26 #include <grpcpp/channel.h>
27 #include <grpcpp/client_context.h>
28 #include <grpcpp/create_channel.h>
29 #include <grpcpp/create_channel_posix.h>
30 #include <grpcpp/generic/generic_stub.h>
31 #include <grpcpp/impl/proto_utils.h>
32 #include <grpcpp/server.h>
33 #include <grpcpp/server_builder.h>
34 #include <grpcpp/server_context.h>
35 #include <grpcpp/server_posix.h>
36 #include <grpcpp/support/client_interceptor.h>
37 
38 #include "src/core/lib/iomgr/port.h"
39 #include "src/proto/grpc/testing/echo.grpc.pb.h"
40 #include "test/core/util/port.h"
41 #include "test/core/util/test_config.h"
42 #include "test/cpp/end2end/interceptors_util.h"
43 #include "test/cpp/end2end/test_service_impl.h"
44 #include "test/cpp/util/byte_buffer_proto_helper.h"
45 #include "test/cpp/util/string_ref_helper.h"
46 
47 #ifdef GRPC_POSIX_SOCKET
48 #include <fcntl.h>
49 
50 #include "src/core/lib/iomgr/socket_utils_posix.h"
51 #endif  // GRPC_POSIX_SOCKET
52 
53 namespace grpc {
54 namespace testing {
55 namespace {
56 
// The flavour of RPC a parameterized test issues: which client API is used
// (synchronous vs. completion-queue based async) and which streaming shape.
enum class RPCType {
  kSyncUnary,               // sync API, unary
  kSyncClientStreaming,     // sync API, client-side streaming
  kSyncServerStreaming,     // sync API, server-side streaming
  kSyncBidiStreaming,       // sync API, bidirectional streaming
  kAsyncCQUnary,            // async completion-queue API, unary
  kAsyncCQClientStreaming,  // async completion-queue API, client streaming
  kAsyncCQServerStreaming,  // async completion-queue API, server streaming
  kAsyncCQBidiStreaming,    // async completion-queue API, bidi streaming
};
67 
// How the client channel used by a parameterized test is connected to the
// in-process server.
enum class ChannelType {
  kHttpChannel,  // regular channel to a "localhost:<port>" address
  kFdChannel,    // channel created directly from one end of a socketpair
};
72 
73 // Hijacks Echo RPC and fills in the expected values
74 class HijackingInterceptor : public experimental::Interceptor {
75  public:
HijackingInterceptor(experimental::ClientRpcInfo * info)76   explicit HijackingInterceptor(experimental::ClientRpcInfo* info) {
77     info_ = info;
78     // Make sure it is the right method
79     EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
80     EXPECT_EQ(info->suffix_for_stats(), nullptr);
81     EXPECT_EQ(info->type(), experimental::ClientRpcInfo::Type::UNARY);
82   }
83 
Intercept(experimental::InterceptorBatchMethods * methods)84   void Intercept(experimental::InterceptorBatchMethods* methods) override {
85     bool hijack = false;
86     if (methods->QueryInterceptionHookPoint(
87             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
88       auto* map = methods->GetSendInitialMetadata();
89       // Check that we can see the test metadata
90       ASSERT_EQ(map->size(), 1);
91       auto iterator = map->begin();
92       EXPECT_EQ("testkey", iterator->first);
93       EXPECT_EQ("testvalue", iterator->second);
94       hijack = true;
95     }
96     if (methods->QueryInterceptionHookPoint(
97             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
98       EchoRequest req;
99       auto* buffer = methods->GetSerializedSendMessage();
100       auto copied_buffer = *buffer;
101       EXPECT_TRUE(
102           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
103               .ok());
104       EXPECT_EQ(req.message(), "Hello");
105     }
106     if (methods->QueryInterceptionHookPoint(
107             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
108       // Got nothing to do here for now
109     }
110     if (methods->QueryInterceptionHookPoint(
111             experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
112       auto* map = methods->GetRecvInitialMetadata();
113       // Got nothing better to do here for now
114       EXPECT_EQ(map->size(), 0);
115     }
116     if (methods->QueryInterceptionHookPoint(
117             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
118       EchoResponse* resp =
119           static_cast<EchoResponse*>(methods->GetRecvMessage());
120       // Check that we got the hijacked message, and re-insert the expected
121       // message
122       EXPECT_EQ(resp->message(), "Hello1");
123       resp->set_message("Hello");
124     }
125     if (methods->QueryInterceptionHookPoint(
126             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
127       auto* map = methods->GetRecvTrailingMetadata();
128       bool found = false;
129       // Check that we received the metadata as an echo
130       for (const auto& pair : *map) {
131         found = pair.first.starts_with("testkey") &&
132                 pair.second.starts_with("testvalue");
133         if (found) break;
134       }
135       EXPECT_EQ(found, true);
136       auto* status = methods->GetRecvStatus();
137       EXPECT_EQ(status->ok(), true);
138     }
139     if (methods->QueryInterceptionHookPoint(
140             experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
141       auto* map = methods->GetRecvInitialMetadata();
142       // Got nothing better to do here at the moment
143       EXPECT_EQ(map->size(), 0);
144     }
145     if (methods->QueryInterceptionHookPoint(
146             experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
147       // Insert a different message than expected
148       EchoResponse* resp =
149           static_cast<EchoResponse*>(methods->GetRecvMessage());
150       resp->set_message("Hello1");
151     }
152     if (methods->QueryInterceptionHookPoint(
153             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
154       auto* map = methods->GetRecvTrailingMetadata();
155       // insert the metadata that we want
156       EXPECT_EQ(map->size(), 0);
157       map->insert(std::make_pair("testkey", "testvalue"));
158       auto* status = methods->GetRecvStatus();
159       *status = Status(StatusCode::OK, "");
160     }
161     if (hijack) {
162       methods->Hijack();
163     } else {
164       methods->Proceed();
165     }
166   }
167 
168  private:
169   experimental::ClientRpcInfo* info_;
170 };
171 
172 class HijackingInterceptorFactory
173     : public experimental::ClientInterceptorFactoryInterface {
174  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)175   experimental::Interceptor* CreateClientInterceptor(
176       experimental::ClientRpcInfo* info) override {
177     return new HijackingInterceptor(info);
178   }
179 };
180 
181 class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor {
182  public:
HijackingInterceptorMakesAnotherCall(experimental::ClientRpcInfo * info)183   explicit HijackingInterceptorMakesAnotherCall(
184       experimental::ClientRpcInfo* info) {
185     info_ = info;
186     // Make sure it is the right method
187     EXPECT_EQ(strcmp("/grpc.testing.EchoTestService/Echo", info->method()), 0);
188     EXPECT_EQ(strcmp("TestSuffixForStats", info->suffix_for_stats()), 0);
189   }
190 
Intercept(experimental::InterceptorBatchMethods * methods)191   void Intercept(experimental::InterceptorBatchMethods* methods) override {
192     if (methods->QueryInterceptionHookPoint(
193             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
194       auto* map = methods->GetSendInitialMetadata();
195       // Check that we can see the test metadata
196       ASSERT_EQ(map->size(), 1);
197       auto iterator = map->begin();
198       EXPECT_EQ("testkey", iterator->first);
199       EXPECT_EQ("testvalue", iterator->second);
200       // Make a copy of the map
201       metadata_map_ = *map;
202     }
203     if (methods->QueryInterceptionHookPoint(
204             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
205       EchoRequest req;
206       auto* buffer = methods->GetSerializedSendMessage();
207       auto copied_buffer = *buffer;
208       EXPECT_TRUE(
209           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
210               .ok());
211       EXPECT_EQ(req.message(), "Hello");
212       req_ = req;
213       stub_ = grpc::testing::EchoTestService::NewStub(
214           methods->GetInterceptedChannel());
215       ctx_.AddMetadata(metadata_map_.begin()->first,
216                        metadata_map_.begin()->second);
217       stub_->async()->Echo(&ctx_, &req_, &resp_, [this, methods](Status s) {
218         EXPECT_EQ(s.ok(), true);
219         EXPECT_EQ(resp_.message(), "Hello");
220         methods->Hijack();
221       });
222       // This is a Unary RPC and we have got nothing interesting to do in the
223       // PRE_SEND_CLOSE interception hook point for this interceptor, so let's
224       // return here. (We do not want to call methods->Proceed(). When the new
225       // RPC returns, we will call methods->Hijack() instead.)
226       return;
227     }
228     if (methods->QueryInterceptionHookPoint(
229             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
230       // Got nothing to do here for now
231     }
232     if (methods->QueryInterceptionHookPoint(
233             experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
234       auto* map = methods->GetRecvInitialMetadata();
235       // Got nothing better to do here for now
236       EXPECT_EQ(map->size(), 0);
237     }
238     if (methods->QueryInterceptionHookPoint(
239             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
240       EchoResponse* resp =
241           static_cast<EchoResponse*>(methods->GetRecvMessage());
242       // Check that we got the hijacked message, and re-insert the expected
243       // message
244       EXPECT_EQ(resp->message(), "Hello");
245     }
246     if (methods->QueryInterceptionHookPoint(
247             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
248       auto* map = methods->GetRecvTrailingMetadata();
249       bool found = false;
250       // Check that we received the metadata as an echo
251       for (const auto& pair : *map) {
252         found = pair.first.starts_with("testkey") &&
253                 pair.second.starts_with("testvalue");
254         if (found) break;
255       }
256       EXPECT_EQ(found, true);
257       auto* status = methods->GetRecvStatus();
258       EXPECT_EQ(status->ok(), true);
259     }
260     if (methods->QueryInterceptionHookPoint(
261             experimental::InterceptionHookPoints::PRE_RECV_INITIAL_METADATA)) {
262       auto* map = methods->GetRecvInitialMetadata();
263       // Got nothing better to do here at the moment
264       EXPECT_EQ(map->size(), 0);
265     }
266     if (methods->QueryInterceptionHookPoint(
267             experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
268       // Insert a different message than expected
269       EchoResponse* resp =
270           static_cast<EchoResponse*>(methods->GetRecvMessage());
271       resp->set_message(resp_.message());
272     }
273     if (methods->QueryInterceptionHookPoint(
274             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
275       auto* map = methods->GetRecvTrailingMetadata();
276       // insert the metadata that we want
277       EXPECT_EQ(map->size(), 0);
278       map->insert(std::make_pair("testkey", "testvalue"));
279       auto* status = methods->GetRecvStatus();
280       *status = Status(StatusCode::OK, "");
281     }
282 
283     methods->Proceed();
284   }
285 
286  private:
287   experimental::ClientRpcInfo* info_;
288   std::multimap<std::string, std::string> metadata_map_;
289   ClientContext ctx_;
290   EchoRequest req_;
291   EchoResponse resp_;
292   std::unique_ptr<grpc::testing::EchoTestService::Stub> stub_;
293 };
294 
295 class HijackingInterceptorMakesAnotherCallFactory
296     : public experimental::ClientInterceptorFactoryInterface {
297  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)298   experimental::Interceptor* CreateClientInterceptor(
299       experimental::ClientRpcInfo* info) override {
300     return new HijackingInterceptorMakesAnotherCall(info);
301   }
302 };
303 
304 class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
305  public:
BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)306   explicit BidiStreamingRpcHijackingInterceptor(
307       experimental::ClientRpcInfo* info) {
308     info_ = info;
309     EXPECT_EQ(info->suffix_for_stats(), nullptr);
310   }
311 
Intercept(experimental::InterceptorBatchMethods * methods)312   void Intercept(experimental::InterceptorBatchMethods* methods) override {
313     bool hijack = false;
314     if (methods->QueryInterceptionHookPoint(
315             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
316       CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue");
317       hijack = true;
318     }
319     if (methods->QueryInterceptionHookPoint(
320             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
321       EchoRequest req;
322       auto* buffer = methods->GetSerializedSendMessage();
323       auto copied_buffer = *buffer;
324       EXPECT_TRUE(
325           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
326               .ok());
327       EXPECT_EQ(req.message().find("Hello"), 0u);
328       msg = req.message();
329     }
330     if (methods->QueryInterceptionHookPoint(
331             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
332       // Got nothing to do here for now
333     }
334     if (methods->QueryInterceptionHookPoint(
335             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
336       CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey",
337                     "testvalue");
338       auto* status = methods->GetRecvStatus();
339       EXPECT_EQ(status->ok(), true);
340     }
341     if (methods->QueryInterceptionHookPoint(
342             experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
343       EchoResponse* resp =
344           static_cast<EchoResponse*>(methods->GetRecvMessage());
345       resp->set_message(msg);
346     }
347     if (methods->QueryInterceptionHookPoint(
348             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
349       EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage())
350                     ->message()
351                     .find("Hello"),
352                 0u);
353     }
354     if (methods->QueryInterceptionHookPoint(
355             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
356       auto* map = methods->GetRecvTrailingMetadata();
357       // insert the metadata that we want
358       EXPECT_EQ(map->size(), 0);
359       map->insert(std::make_pair("testkey", "testvalue"));
360       auto* status = methods->GetRecvStatus();
361       *status = Status(StatusCode::OK, "");
362     }
363     if (hijack) {
364       methods->Hijack();
365     } else {
366       methods->Proceed();
367     }
368   }
369 
370  private:
371   experimental::ClientRpcInfo* info_;
372   std::string msg;
373 };
374 
375 class ClientStreamingRpcHijackingInterceptor
376     : public experimental::Interceptor {
377  public:
ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)378   explicit ClientStreamingRpcHijackingInterceptor(
379       experimental::ClientRpcInfo* info) {
380     info_ = info;
381     EXPECT_EQ(
382         strcmp("/grpc.testing.EchoTestService/RequestStream", info->method()),
383         0);
384     EXPECT_EQ(strcmp("TestSuffixForStats", info->suffix_for_stats()), 0);
385   }
Intercept(experimental::InterceptorBatchMethods * methods)386   void Intercept(experimental::InterceptorBatchMethods* methods) override {
387     bool hijack = false;
388     if (methods->QueryInterceptionHookPoint(
389             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
390       hijack = true;
391     }
392     if (methods->QueryInterceptionHookPoint(
393             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
394       if (++count_ > 10) {
395         methods->FailHijackedSendMessage();
396       }
397     }
398     if (methods->QueryInterceptionHookPoint(
399             experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) {
400       EXPECT_FALSE(got_failed_send_);
401       got_failed_send_ = !methods->GetSendMessageStatus();
402     }
403     if (methods->QueryInterceptionHookPoint(
404             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
405       auto* status = methods->GetRecvStatus();
406       *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages");
407     }
408     if (hijack) {
409       methods->Hijack();
410     } else {
411       methods->Proceed();
412     }
413   }
414 
GotFailedSend()415   static bool GotFailedSend() { return got_failed_send_; }
416 
417  private:
418   experimental::ClientRpcInfo* info_;
419   int count_ = 0;
420   static bool got_failed_send_;
421 };
422 
423 bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false;
424 
425 class ClientStreamingRpcHijackingInterceptorFactory
426     : public experimental::ClientInterceptorFactoryInterface {
427  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)428   experimental::Interceptor* CreateClientInterceptor(
429       experimental::ClientRpcInfo* info) override {
430     return new ClientStreamingRpcHijackingInterceptor(info);
431   }
432 };
433 
434 class ServerStreamingRpcHijackingInterceptor
435     : public experimental::Interceptor {
436  public:
ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo * info)437   explicit ServerStreamingRpcHijackingInterceptor(
438       experimental::ClientRpcInfo* info) {
439     info_ = info;
440     got_failed_message_ = false;
441     EXPECT_EQ(info->suffix_for_stats(), nullptr);
442   }
443 
Intercept(experimental::InterceptorBatchMethods * methods)444   void Intercept(experimental::InterceptorBatchMethods* methods) override {
445     bool hijack = false;
446     if (methods->QueryInterceptionHookPoint(
447             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
448       auto* map = methods->GetSendInitialMetadata();
449       // Check that we can see the test metadata
450       ASSERT_EQ(map->size(), 1);
451       auto iterator = map->begin();
452       EXPECT_EQ("testkey", iterator->first);
453       EXPECT_EQ("testvalue", iterator->second);
454       hijack = true;
455     }
456     if (methods->QueryInterceptionHookPoint(
457             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
458       EchoRequest req;
459       auto* buffer = methods->GetSerializedSendMessage();
460       auto copied_buffer = *buffer;
461       EXPECT_TRUE(
462           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
463               .ok());
464       EXPECT_EQ(req.message(), "Hello");
465     }
466     if (methods->QueryInterceptionHookPoint(
467             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
468       // Got nothing to do here for now
469     }
470     if (methods->QueryInterceptionHookPoint(
471             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
472       auto* map = methods->GetRecvTrailingMetadata();
473       bool found = false;
474       // Check that we received the metadata as an echo
475       for (const auto& pair : *map) {
476         found = pair.first.starts_with("testkey") &&
477                 pair.second.starts_with("testvalue");
478         if (found) break;
479       }
480       EXPECT_EQ(found, true);
481       auto* status = methods->GetRecvStatus();
482       EXPECT_EQ(status->ok(), true);
483     }
484     if (methods->QueryInterceptionHookPoint(
485             experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
486       if (++count_ > 10) {
487         methods->FailHijackedRecvMessage();
488       }
489       EchoResponse* resp =
490           static_cast<EchoResponse*>(methods->GetRecvMessage());
491       resp->set_message("Hello");
492     }
493     if (methods->QueryInterceptionHookPoint(
494             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
495       // Only the last message will be a failure
496       EXPECT_FALSE(got_failed_message_);
497       got_failed_message_ = methods->GetRecvMessage() == nullptr;
498     }
499     if (methods->QueryInterceptionHookPoint(
500             experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
501       auto* map = methods->GetRecvTrailingMetadata();
502       // insert the metadata that we want
503       EXPECT_EQ(map->size(), 0);
504       map->insert(std::make_pair("testkey", "testvalue"));
505       auto* status = methods->GetRecvStatus();
506       *status = Status(StatusCode::OK, "");
507     }
508     if (hijack) {
509       methods->Hijack();
510     } else {
511       methods->Proceed();
512     }
513   }
514 
GotFailedMessage()515   static bool GotFailedMessage() { return got_failed_message_; }
516 
517  private:
518   experimental::ClientRpcInfo* info_;
519   static bool got_failed_message_;
520   int count_ = 0;
521 };
522 
523 bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false;
524 
525 class ServerStreamingRpcHijackingInterceptorFactory
526     : public experimental::ClientInterceptorFactoryInterface {
527  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)528   experimental::Interceptor* CreateClientInterceptor(
529       experimental::ClientRpcInfo* info) override {
530     return new ServerStreamingRpcHijackingInterceptor(info);
531   }
532 };
533 
534 class BidiStreamingRpcHijackingInterceptorFactory
535     : public experimental::ClientInterceptorFactoryInterface {
536  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)537   experimental::Interceptor* CreateClientInterceptor(
538       experimental::ClientRpcInfo* info) override {
539     return new BidiStreamingRpcHijackingInterceptor(info);
540   }
541 };
542 
543 // The logging interceptor is for testing purposes only. It is used to verify
544 // that all the appropriate hook points are invoked for an RPC. The counts are
545 // reset each time a new object of LoggingInterceptor is created, so only a
546 // single RPC should be made on the channel before calling the Verify methods.
547 class LoggingInterceptor : public experimental::Interceptor {
548  public:
LoggingInterceptor(experimental::ClientRpcInfo *)549   explicit LoggingInterceptor(experimental::ClientRpcInfo* /*info*/) {
550     pre_send_initial_metadata_ = false;
551     pre_send_message_count_ = 0;
552     pre_send_close_ = false;
553     post_recv_initial_metadata_ = false;
554     post_recv_message_count_ = 0;
555     post_recv_status_ = false;
556   }
557 
Intercept(experimental::InterceptorBatchMethods * methods)558   void Intercept(experimental::InterceptorBatchMethods* methods) override {
559     if (methods->QueryInterceptionHookPoint(
560             experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
561       auto* map = methods->GetSendInitialMetadata();
562       // Check that we can see the test metadata
563       ASSERT_EQ(map->size(), 1);
564       auto iterator = map->begin();
565       EXPECT_EQ("testkey", iterator->first);
566       EXPECT_EQ("testvalue", iterator->second);
567       ASSERT_FALSE(pre_send_initial_metadata_);
568       pre_send_initial_metadata_ = true;
569     }
570     if (methods->QueryInterceptionHookPoint(
571             experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
572       EchoRequest req;
573       auto* send_msg = methods->GetSendMessage();
574       if (send_msg == nullptr) {
575         // We did not get the non-serialized form of the message. Get the
576         // serialized form.
577         auto* buffer = methods->GetSerializedSendMessage();
578         auto copied_buffer = *buffer;
579         EchoRequest req;
580         EXPECT_TRUE(
581             SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
582                 .ok());
583         EXPECT_EQ(req.message(), "Hello");
584       } else {
585         EXPECT_EQ(
586             static_cast<const EchoRequest*>(send_msg)->message().find("Hello"),
587             0u);
588       }
589       auto* buffer = methods->GetSerializedSendMessage();
590       auto copied_buffer = *buffer;
591       EXPECT_TRUE(
592           SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
593               .ok());
594       EXPECT_TRUE(req.message().find("Hello") == 0u);
595       pre_send_message_count_++;
596     }
597     if (methods->QueryInterceptionHookPoint(
598             experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
599       // Got nothing to do here for now
600       pre_send_close_ = true;
601     }
602     if (methods->QueryInterceptionHookPoint(
603             experimental::InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) {
604       auto* map = methods->GetRecvInitialMetadata();
605       // Got nothing better to do here for now
606       EXPECT_EQ(map->size(), 0);
607       post_recv_initial_metadata_ = true;
608     }
609     if (methods->QueryInterceptionHookPoint(
610             experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
611       EchoResponse* resp =
612           static_cast<EchoResponse*>(methods->GetRecvMessage());
613       if (resp != nullptr) {
614         EXPECT_TRUE(resp->message().find("Hello") == 0u);
615         post_recv_message_count_++;
616       }
617     }
618     if (methods->QueryInterceptionHookPoint(
619             experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
620       auto* map = methods->GetRecvTrailingMetadata();
621       bool found = false;
622       // Check that we received the metadata as an echo
623       for (const auto& pair : *map) {
624         found = pair.first.starts_with("testkey") &&
625                 pair.second.starts_with("testvalue");
626         if (found) break;
627       }
628       EXPECT_EQ(found, true);
629       auto* status = methods->GetRecvStatus();
630       EXPECT_EQ(status->ok(), true);
631       post_recv_status_ = true;
632     }
633     methods->Proceed();
634   }
635 
VerifyCall(RPCType type)636   static void VerifyCall(RPCType type) {
637     switch (type) {
638       case RPCType::kSyncUnary:
639       case RPCType::kAsyncCQUnary:
640         VerifyUnaryCall();
641         break;
642       case RPCType::kSyncClientStreaming:
643       case RPCType::kAsyncCQClientStreaming:
644         VerifyClientStreamingCall();
645         break;
646       case RPCType::kSyncServerStreaming:
647       case RPCType::kAsyncCQServerStreaming:
648         VerifyServerStreamingCall();
649         break;
650       case RPCType::kSyncBidiStreaming:
651       case RPCType::kAsyncCQBidiStreaming:
652         VerifyBidiStreamingCall();
653         break;
654     }
655   }
656 
VerifyCallCommon()657   static void VerifyCallCommon() {
658     EXPECT_TRUE(pre_send_initial_metadata_);
659     EXPECT_TRUE(pre_send_close_);
660     EXPECT_TRUE(post_recv_initial_metadata_);
661     EXPECT_TRUE(post_recv_status_);
662   }
663 
VerifyUnaryCall()664   static void VerifyUnaryCall() {
665     VerifyCallCommon();
666     EXPECT_EQ(pre_send_message_count_, 1);
667     EXPECT_EQ(post_recv_message_count_, 1);
668   }
669 
VerifyClientStreamingCall()670   static void VerifyClientStreamingCall() {
671     VerifyCallCommon();
672     EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
673     EXPECT_EQ(post_recv_message_count_, 1);
674   }
675 
VerifyServerStreamingCall()676   static void VerifyServerStreamingCall() {
677     VerifyCallCommon();
678     EXPECT_EQ(pre_send_message_count_, 1);
679     EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
680   }
681 
VerifyBidiStreamingCall()682   static void VerifyBidiStreamingCall() {
683     VerifyCallCommon();
684     EXPECT_EQ(pre_send_message_count_, kNumStreamingMessages);
685     EXPECT_EQ(post_recv_message_count_, kNumStreamingMessages);
686   }
687 
688  private:
689   static bool pre_send_initial_metadata_;
690   static int pre_send_message_count_;
691   static bool pre_send_close_;
692   static bool post_recv_initial_metadata_;
693   static int post_recv_message_count_;
694   static bool post_recv_status_;
695 };
696 
697 bool LoggingInterceptor::pre_send_initial_metadata_;
698 int LoggingInterceptor::pre_send_message_count_;
699 bool LoggingInterceptor::pre_send_close_;
700 bool LoggingInterceptor::post_recv_initial_metadata_;
701 int LoggingInterceptor::post_recv_message_count_;
702 bool LoggingInterceptor::post_recv_status_;
703 
704 class LoggingInterceptorFactory
705     : public experimental::ClientInterceptorFactoryInterface {
706  public:
CreateClientInterceptor(experimental::ClientRpcInfo * info)707   experimental::Interceptor* CreateClientInterceptor(
708       experimental::ClientRpcInfo* info) override {
709     return new LoggingInterceptor(info);
710   }
711 };
712 
713 class TestScenario {
714  public:
TestScenario(const ChannelType & channel_type,const RPCType & rpc_type)715   explicit TestScenario(const ChannelType& channel_type,
716                         const RPCType& rpc_type)
717       : channel_type_(channel_type), rpc_type_(rpc_type) {}
718 
channel_type() const719   ChannelType channel_type() const { return channel_type_; }
720 
rpc_type() const721   RPCType rpc_type() const { return rpc_type_; }
722 
723  private:
724   const ChannelType channel_type_;
725   const RPCType rpc_type_;
726 };
727 
CreateTestScenarios()728 std::vector<TestScenario> CreateTestScenarios() {
729   std::vector<TestScenario> scenarios;
730   std::vector<RPCType> rpc_types;
731   rpc_types.emplace_back(RPCType::kSyncUnary);
732   rpc_types.emplace_back(RPCType::kSyncClientStreaming);
733   rpc_types.emplace_back(RPCType::kSyncServerStreaming);
734   rpc_types.emplace_back(RPCType::kSyncBidiStreaming);
735   rpc_types.emplace_back(RPCType::kAsyncCQUnary);
736   rpc_types.emplace_back(RPCType::kAsyncCQServerStreaming);
737   for (const auto& rpc_type : rpc_types) {
738     scenarios.emplace_back(ChannelType::kHttpChannel, rpc_type);
739 // TODO(yashykt): Maybe add support for non-posix sockets too
740 #ifdef GRPC_POSIX_SOCKET
741     scenarios.emplace_back(ChannelType::kFdChannel, rpc_type);
742 #endif  // GRPC_POSIX_SOCKET
743   }
744   return scenarios;
745 }
746 
747 class ParameterizedClientInterceptorsEnd2endTest
748     : public ::testing::TestWithParam<TestScenario> {
749  protected:
ParameterizedClientInterceptorsEnd2endTest()750   ParameterizedClientInterceptorsEnd2endTest() {
751     ServerBuilder builder;
752     builder.RegisterService(&service_);
753     if (GetParam().channel_type() == ChannelType::kHttpChannel) {
754       int port = grpc_pick_unused_port_or_die();
755       server_address_ = "localhost:" + std::to_string(port);
756       builder.AddListeningPort(server_address_, InsecureServerCredentials());
757       server_ = builder.BuildAndStart();
758     }
759 #ifdef GRPC_POSIX_SOCKET
760     else if (GetParam().channel_type() == ChannelType::kFdChannel) {
761       int flags;
762       GPR_ASSERT(socketpair(AF_UNIX, SOCK_STREAM, 0, sv_) == 0);
763       flags = fcntl(sv_[0], F_GETFL, 0);
764       GPR_ASSERT(fcntl(sv_[0], F_SETFL, flags | O_NONBLOCK) == 0);
765       flags = fcntl(sv_[1], F_GETFL, 0);
766       GPR_ASSERT(fcntl(sv_[1], F_SETFL, flags | O_NONBLOCK) == 0);
767       GPR_ASSERT(grpc_set_socket_no_sigpipe_if_possible(sv_[0]) ==
768                  absl::OkStatus());
769       GPR_ASSERT(grpc_set_socket_no_sigpipe_if_possible(sv_[1]) ==
770                  absl::OkStatus());
771       server_ = builder.BuildAndStart();
772       AddInsecureChannelFromFd(server_.get(), sv_[1]);
773     }
774 #endif  // GRPC_POSIX_SOCKET
775   }
776 
~ParameterizedClientInterceptorsEnd2endTest()777   ~ParameterizedClientInterceptorsEnd2endTest() override {
778     server_->Shutdown(grpc_timeout_milliseconds_to_deadline(0));
779   }
780 
CreateClientChannel(std::vector<std::unique_ptr<grpc::experimental::ClientInterceptorFactoryInterface>> creators)781   std::shared_ptr<grpc::Channel> CreateClientChannel(
782       std::vector<std::unique_ptr<
783           grpc::experimental::ClientInterceptorFactoryInterface>>
784           creators) {
785     if (GetParam().channel_type() == ChannelType::kHttpChannel) {
786       return experimental::CreateCustomChannelWithInterceptors(
787           server_address_, InsecureChannelCredentials(), ChannelArguments(),
788           std::move(creators));
789     }
790 #ifdef GRPC_POSIX_SOCKET
791     else if (GetParam().channel_type() == ChannelType::kFdChannel) {
792       return experimental::CreateCustomInsecureChannelWithInterceptorsFromFd(
793           "", sv_[0], ChannelArguments(), std::move(creators));
794     }
795 #endif  // GRPC_POSIX_SOCKET
796     return nullptr;
797   }
798 
SendRPC(const std::shared_ptr<Channel> & channel)799   void SendRPC(const std::shared_ptr<Channel>& channel) {
800     switch (GetParam().rpc_type()) {
801       case RPCType::kSyncUnary:
802         MakeCall(channel);
803         break;
804       case RPCType::kSyncClientStreaming:
805         MakeClientStreamingCall(channel);
806         break;
807       case RPCType::kSyncServerStreaming:
808         MakeServerStreamingCall(channel);
809         break;
810       case RPCType::kSyncBidiStreaming:
811         MakeBidiStreamingCall(channel);
812         break;
813       case RPCType::kAsyncCQUnary:
814         MakeAsyncCQCall(channel);
815         break;
816       case RPCType::kAsyncCQClientStreaming:
817         // TODO(yashykt) : Fill this out
818         break;
819       case RPCType::kAsyncCQServerStreaming:
820         MakeAsyncCQServerStreamingCall(channel);
821         break;
822       case RPCType::kAsyncCQBidiStreaming:
823         // TODO(yashykt) : Fill this out
824         break;
825     }
826   }
827 
828   std::string server_address_;
829   int sv_[2];
830   EchoTestServiceStreamingImpl service_;
831   std::unique_ptr<Server> server_;
832 };
833 
// Sends the scenario's RPC through one logging interceptor followed by 20
// phony interceptors, then asks LoggingInterceptor to verify the call for the
// scenario's RPC type and checks the phony interceptor run count.
TEST_P(ParameterizedClientInterceptorsEnd2endTest,
       ClientInterceptorLoggingTest) {
  constexpr int kNumPhonyInterceptors = 20;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  // One logging factory plus the phony factories.
  creators.reserve(1 + kNumPhonyInterceptors);
  creators.push_back(std::make_unique<LoggingInterceptorFactory>());
  // Add 20 phony interceptors
  for (int i = 0; i < kNumPhonyInterceptors; ++i) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto channel = CreateClientChannel(std::move(creators));
  SendRPC(channel);
  LoggingInterceptor::VerifyCall(GetParam().rpc_type());
  // Make sure all 20 phony interceptors were run
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumPhonyInterceptors);
}
851 
852 INSTANTIATE_TEST_SUITE_P(ParameterizedClientInterceptorsEnd2end,
853                          ParameterizedClientInterceptorsEnd2endTest,
854                          ::testing::ValuesIn(CreateTestScenarios()));
855 
856 class ClientInterceptorsEnd2endTest
857     : public ::testing::TestWithParam<TestScenario> {
858  protected:
ClientInterceptorsEnd2endTest()859   ClientInterceptorsEnd2endTest() {
860     int port = grpc_pick_unused_port_or_die();
861 
862     ServerBuilder builder;
863     server_address_ = "localhost:" + std::to_string(port);
864     builder.AddListeningPort(server_address_, InsecureServerCredentials());
865     builder.RegisterService(&service_);
866     server_ = builder.BuildAndStart();
867   }
868 
~ClientInterceptorsEnd2endTest()869   ~ClientInterceptorsEnd2endTest() override { server_->Shutdown(); }
870 
871   std::string server_address_;
872   TestServiceImpl service_;
873   std::unique_ptr<Server> server_;
874 };
875 
// Issues a unary call over a channel created with null credentials while a
// hijacking interceptor is installed. Going by the test name, the null
// credentials are expected to yield a lame channel.
TEST_F(ClientInterceptorsEnd2endTest,
       LameChannelClientInterceptorHijackingTest) {
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;
  std::vector<FactoryPtr> factories;
  factories.emplace_back(std::make_unique<HijackingInterceptorFactory>());

  ChannelArguments channel_args;
  auto lame_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, nullptr, channel_args, std::move(factories));
  MakeCall(lame_channel);
}
886 
// Installs 20 phony interceptors, a hijacking interceptor, and 20 more phony
// interceptors, makes a unary call, and expects the phony run count to be 20
// (i.e. only one of the two groups ran).
TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorHijackingTest) {
  constexpr int kNumPhonyPerSide = 20;
  ChannelArguments args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  // Reserve room for every factory added below: both phony groups plus the
  // hijacking factory.
  creators.reserve(2 * kNumPhonyPerSide + 1);
  // Add 20 phony interceptors before hijacking interceptor
  for (int i = 0; i < kNumPhonyPerSide; ++i) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  creators.push_back(std::make_unique<HijackingInterceptorFactory>());
  // Add 20 phony interceptors after hijacking interceptor
  for (int i = 0; i < kNumPhonyPerSide; ++i) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), args, std::move(creators));
  MakeCall(channel);
  // Make sure only 20 phony interceptors were run
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumPhonyPerSide);
}
908 
// Installs a logging interceptor ahead of a hijacking interceptor, makes a
// unary call, and has LoggingInterceptor verify the unary call it observed.
TEST_F(ClientInterceptorsEnd2endTest, ClientInterceptorLogThenHijackTest) {
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;
  std::vector<FactoryPtr> factories;
  factories.emplace_back(std::make_unique<LoggingInterceptorFactory>());
  factories.emplace_back(std::make_unique<HijackingInterceptorFactory>());

  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeCall(intercepted_channel);
  LoggingInterceptor::VerifyUnaryCall();
}
920 
// Installs 5 phony interceptors, a HijackingInterceptorMakesAnotherCallFactory
// and 7 more phony interceptors on an in-process channel, makes a unary call,
// and expects all 12 phony interceptors to have run.
TEST_F(ClientInterceptorsEnd2endTest,
       ClientInterceptorHijackingMakesAnotherCallTest) {
  constexpr int kNumPhonyBeforeHijack = 5;
  constexpr int kNumPhonyAfterHijack = 7;
  ChannelArguments args;
  PhonyInterceptor::Reset();
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  // Reserve room for every factory added below, not just the first group.
  creators.reserve(kNumPhonyBeforeHijack + 1 + kNumPhonyAfterHijack);
  // Add 5 phony interceptors before hijacking interceptor
  for (int i = 0; i < kNumPhonyBeforeHijack; ++i) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  creators.push_back(
      std::make_unique<HijackingInterceptorMakesAnotherCallFactory>());
  // Add 7 phony interceptors after hijacking interceptor
  for (int i = 0; i < kNumPhonyAfterHijack; ++i) {
    creators.push_back(std::make_unique<PhonyInterceptorFactory>());
  }
  auto channel = server_->experimental().InProcessChannelWithInterceptors(
      args, std::move(creators));

  MakeCall(channel, StubOptions("TestSuffixForStats"));
  // Make sure all interceptors were run once, since the hijacking interceptor
  // makes an RPC on the intercepted channel
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(),
            kNumPhonyBeforeHijack + kNumPhonyAfterHijack);
}
947 
948 class ClientInterceptorsCallbackEnd2endTest : public ::testing::Test {
949  protected:
ClientInterceptorsCallbackEnd2endTest()950   ClientInterceptorsCallbackEnd2endTest() {
951     int port = grpc_pick_unused_port_or_die();
952 
953     ServerBuilder builder;
954     server_address_ = "localhost:" + std::to_string(port);
955     builder.AddListeningPort(server_address_, InsecureServerCredentials());
956     builder.RegisterService(&service_);
957     server_ = builder.BuildAndStart();
958   }
959 
~ClientInterceptorsCallbackEnd2endTest()960   ~ClientInterceptorsCallbackEnd2endTest() override { server_->Shutdown(); }
961 
962   std::string server_address_;
963   TestServiceImpl service_;
964   std::unique_ptr<Server> server_;
965 };
966 
// Makes a callback-API call over an in-process channel carrying one logging
// interceptor and 20 phony interceptors, then verifies the logged unary call
// and the phony interceptor run count.
TEST_F(ClientInterceptorsCallbackEnd2endTest,
       ClientInterceptorLoggingTestWithCallback) {
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;
  constexpr int kNumPhonyInterceptors = 20;

  PhonyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.emplace_back(std::make_unique<LoggingInterceptorFactory>());
  // Append the 20 phony interceptor factories.
  for (int added = 0; added != kNumPhonyInterceptors; ++added) {
    factories.emplace_back(std::make_unique<PhonyInterceptorFactory>());
  }

  ChannelArguments channel_args;
  auto in_process_channel =
      server_->experimental().InProcessChannelWithInterceptors(
          channel_args, std::move(factories));
  MakeCallbackCall(in_process_channel);
  LoggingInterceptor::VerifyUnaryCall();
  // Every one of the 20 phony interceptors must have run.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumPhonyInterceptors);
}
985 
// Makes a callback-API call over an in-process channel carrying a logging
// interceptor, 20 phony interceptors and finally a hijacking interceptor,
// then verifies the logged unary call and the phony interceptor run count.
TEST_F(ClientInterceptorsCallbackEnd2endTest,
       ClientInterceptorHijackingTestWithCallback) {
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;
  constexpr int kNumPhonyInterceptors = 20;

  PhonyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.emplace_back(std::make_unique<LoggingInterceptorFactory>());
  // Append the 20 phony interceptor factories.
  for (int added = 0; added != kNumPhonyInterceptors; ++added) {
    factories.emplace_back(std::make_unique<PhonyInterceptorFactory>());
  }
  // The hijacking interceptor goes last in the chain.
  factories.emplace_back(std::make_unique<HijackingInterceptorFactory>());

  ChannelArguments channel_args;
  auto in_process_channel =
      server_->experimental().InProcessChannelWithInterceptors(
          channel_args, std::move(factories));
  MakeCallbackCall(in_process_channel);
  LoggingInterceptor::VerifyUnaryCall();
  // Every one of the 20 phony interceptors must have run.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumPhonyInterceptors);
}
1005 
// Interleaves 20 phony interceptor factories with 20 NullInterceptorFactory
// instances behind a logging interceptor, makes a callback-API call over an
// in-process channel, and checks that the logged call verifies and that the
// phony interceptor run count is still 20.
TEST_F(ClientInterceptorsCallbackEnd2endTest,
       ClientInterceptorFactoryAllowsNullptrReturn) {
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;
  constexpr int kNumPairs = 20;

  PhonyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.emplace_back(std::make_unique<LoggingInterceptorFactory>());
  // Each round adds one phony factory followed by one null factory.
  for (int pair = 0; pair != kNumPairs; ++pair) {
    factories.emplace_back(std::make_unique<PhonyInterceptorFactory>());
    factories.emplace_back(std::make_unique<NullInterceptorFactory>());
  }

  ChannelArguments channel_args;
  auto in_process_channel =
      server_->experimental().InProcessChannelWithInterceptors(
          channel_args, std::move(factories));
  MakeCallbackCall(in_process_channel);
  LoggingInterceptor::VerifyUnaryCall();
  // Every one of the 20 phony interceptors must have run.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumPairs);
}
1025 
1026 class ClientInterceptorsStreamingEnd2endTest : public ::testing::Test {
1027  protected:
ClientInterceptorsStreamingEnd2endTest()1028   ClientInterceptorsStreamingEnd2endTest() {
1029     int port = grpc_pick_unused_port_or_die();
1030 
1031     ServerBuilder builder;
1032     server_address_ = "localhost:" + std::to_string(port);
1033     builder.AddListeningPort(server_address_, InsecureServerCredentials());
1034     builder.RegisterService(&service_);
1035     server_ = builder.BuildAndStart();
1036   }
1037 
~ClientInterceptorsStreamingEnd2endTest()1038   ~ClientInterceptorsStreamingEnd2endTest() override { server_->Shutdown(); }
1039 
1040   std::string server_address_;
1041   EchoTestServiceStreamingImpl service_;
1042   std::unique_ptr<Server> server_;
1043 };
1044 
// Makes a client-streaming call through one logging interceptor and 20 phony
// interceptors, then verifies the logged client-streaming call and the phony
// interceptor run count.
TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingTest) {
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;
  constexpr int kNumPhonyInterceptors = 20;

  PhonyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.emplace_back(std::make_unique<LoggingInterceptorFactory>());
  // Append the 20 phony interceptor factories.
  for (int added = 0; added != kNumPhonyInterceptors; ++added) {
    factories.emplace_back(std::make_unique<PhonyInterceptorFactory>());
  }

  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeClientStreamingCall(intercepted_channel);
  LoggingInterceptor::VerifyClientStreamingCall();
  // Every one of the 20 phony interceptors must have run.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumPhonyInterceptors);
}
1062 
// Makes a server-streaming call through one logging interceptor and 20 phony
// interceptors, then verifies the logged server-streaming call and the phony
// interceptor run count.
TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;
  constexpr int kNumPhonyInterceptors = 20;

  PhonyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.emplace_back(std::make_unique<LoggingInterceptorFactory>());
  // Append the 20 phony interceptor factories.
  for (int added = 0; added != kNumPhonyInterceptors; ++added) {
    factories.emplace_back(std::make_unique<PhonyInterceptorFactory>());
  }

  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeServerStreamingCall(intercepted_channel);
  LoggingInterceptor::VerifyServerStreamingCall();
  // Every one of the 20 phony interceptors must have run.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumPhonyInterceptors);
}
1080 
// Streams requests through a ClientStreamingRpcHijackingInterceptor: the first
// 10 writes must succeed, an 11th write is attempted, and the call must then
// finish with a non-OK status while the interceptor reports a failed send.
TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
  ChannelArguments args;
  std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
      creators;
  creators.push_back(
      std::make_unique<ClientStreamingRpcHijackingInterceptorFactory>());
  auto channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), args, std::move(creators));

  auto stub = grpc::testing::EchoTestService::NewStub(
      channel, StubOptions("TestSuffixForStats"));
  ClientContext ctx;
  EchoRequest req;
  EchoResponse resp;
  req.mutable_param()->set_echo_metadata(true);
  req.set_message("Hello");
  auto writer = stub->RequestStream(&ctx, &resp);
  for (int i = 0; i < 10; i++) {
    EXPECT_TRUE(writer->Write(req));
  }
  // The interceptor will reject the 11th message
  writer->Write(req);
  Status s = writer->Finish();
  EXPECT_FALSE(s.ok());
  EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
}
1109 
// Makes a server-streaming call through a
// ServerStreamingRpcHijackingInterceptor and expects the interceptor to report
// that it saw a failed message.
TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;

  PhonyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.emplace_back(
      std::make_unique<ServerStreamingRpcHijackingInterceptorFactory>());

  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeServerStreamingCall(intercepted_channel);
  EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
}
1122 
// Same setup as the sync server-streaming hijacking test, but the call is
// made with MakeAsyncCQServerStreamingCall. Expects the hijacking interceptor
// to report that it saw a failed message.
TEST_F(ClientInterceptorsStreamingEnd2endTest,
       AsyncCQServerStreamingHijackingTest) {
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;

  PhonyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.emplace_back(
      std::make_unique<ServerStreamingRpcHijackingInterceptorFactory>());

  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeAsyncCQServerStreamingCall(intercepted_channel);
  EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
}
1136 
// Makes a bidi-streaming call through a BidiStreamingRpcHijackingInterceptor.
// This test body adds no expectations of its own beyond the call helper.
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;

  PhonyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.emplace_back(
      std::make_unique<BidiStreamingRpcHijackingInterceptorFactory>());

  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeBidiStreamingCall(intercepted_channel);
}
1148 
// Makes a bidi-streaming call through one logging interceptor and 20 phony
// interceptors, then verifies the logged bidi-streaming call and the phony
// interceptor run count.
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;
  constexpr int kNumPhonyInterceptors = 20;

  PhonyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.emplace_back(std::make_unique<LoggingInterceptorFactory>());
  // Append the 20 phony interceptor factories.
  for (int added = 0; added != kNumPhonyInterceptors; ++added) {
    factories.emplace_back(std::make_unique<PhonyInterceptorFactory>());
  }

  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeBidiStreamingCall(intercepted_channel);
  LoggingInterceptor::VerifyBidiStreamingCall();
  // Every one of the 20 phony interceptors must have run.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumPhonyInterceptors);
}
1166 
1167 class ClientGlobalInterceptorEnd2endTest : public ::testing::Test {
1168  protected:
ClientGlobalInterceptorEnd2endTest()1169   ClientGlobalInterceptorEnd2endTest() {
1170     int port = grpc_pick_unused_port_or_die();
1171 
1172     ServerBuilder builder;
1173     server_address_ = "localhost:" + std::to_string(port);
1174     builder.AddListeningPort(server_address_, InsecureServerCredentials());
1175     builder.RegisterService(&service_);
1176     server_ = builder.BuildAndStart();
1177   }
1178 
~ClientGlobalInterceptorEnd2endTest()1179   ~ClientGlobalInterceptorEnd2endTest() override { server_->Shutdown(); }
1180 
1181   std::string server_address_;
1182   TestServiceImpl service_;
1183   std::unique_ptr<Server> server_;
1184 };
1185 
// Registers a phony interceptor factory as the global client interceptor
// factory, adds 20 per-channel phony interceptors, makes a unary call, and
// expects 21 phony interceptor runs in total.
TEST_F(ClientGlobalInterceptorEnd2endTest, PhonyGlobalInterceptor) {
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;
  constexpr int kNumChannelInterceptors = 20;

  // Ideally a global interceptor is registered only once per process. For
  // testing it should be fine to swap the registered global interceptor while
  // no gRPC operations are in flight.
  PhonyInterceptorFactory global_factory;
  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);

  PhonyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.reserve(20);
  // Append the 20 per-channel phony interceptor factories.
  for (int added = 0; added != kNumChannelInterceptors; ++added) {
    factories.emplace_back(std::make_unique<PhonyInterceptorFactory>());
  }

  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeCall(intercepted_channel);
  // The 20 per-channel phony interceptors plus the global one must have run.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumChannelInterceptors + 1);
  experimental::TestOnlyResetGlobalClientInterceptorFactory();
}
1208 
// Registers a logging interceptor factory as the global client interceptor
// factory, adds 20 per-channel phony interceptors, makes a unary call, and
// verifies both the logged unary call and the phony interceptor run count.
TEST_F(ClientGlobalInterceptorEnd2endTest, LoggingGlobalInterceptor) {
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;
  constexpr int kNumChannelInterceptors = 20;

  // Ideally a global interceptor is registered only once per process. For
  // testing it should be fine to swap the registered global interceptor while
  // no gRPC operations are in flight.
  LoggingInterceptorFactory global_factory;
  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);

  PhonyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.reserve(20);
  // Append the 20 per-channel phony interceptor factories.
  for (int added = 0; added != kNumChannelInterceptors; ++added) {
    factories.emplace_back(std::make_unique<PhonyInterceptorFactory>());
  }

  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeCall(intercepted_channel);
  LoggingInterceptor::VerifyUnaryCall();
  // Every one of the 20 phony interceptors must have run.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumChannelInterceptors);
  experimental::TestOnlyResetGlobalClientInterceptorFactory();
}
1232 
// Registers a hijacking interceptor factory as the global client interceptor
// factory, adds 20 per-channel phony interceptors, makes a unary call, and
// expects all 20 phony interceptors to have run.
TEST_F(ClientGlobalInterceptorEnd2endTest, HijackingGlobalInterceptor) {
  using FactoryPtr =
      std::unique_ptr<experimental::ClientInterceptorFactoryInterface>;
  constexpr int kNumChannelInterceptors = 20;

  // Ideally a global interceptor is registered only once per process. For
  // testing it should be fine to swap the registered global interceptor while
  // no gRPC operations are in flight.
  HijackingInterceptorFactory global_factory;
  experimental::RegisterGlobalClientInterceptorFactory(&global_factory);

  PhonyInterceptor::Reset();
  std::vector<FactoryPtr> factories;
  factories.reserve(20);
  // Append the 20 per-channel phony interceptor factories.
  for (int added = 0; added != kNumChannelInterceptors; ++added) {
    factories.emplace_back(std::make_unique<PhonyInterceptorFactory>());
  }

  ChannelArguments channel_args;
  auto intercepted_channel = experimental::CreateCustomChannelWithInterceptors(
      server_address_, InsecureChannelCredentials(), channel_args,
      std::move(factories));
  MakeCall(intercepted_channel);
  // Every one of the 20 phony interceptors must have run.
  EXPECT_EQ(PhonyInterceptor::GetNumTimesRun(), kNumChannelInterceptors);
  experimental::TestOnlyResetGlobalClientInterceptorFactory();
}
1255 
1256 }  // namespace
1257 }  // namespace testing
1258 }  // namespace grpc
1259 
main(int argc,char ** argv)1260 int main(int argc, char** argv) {
1261   grpc::testing::TestEnvironment env(&argc, argv);
1262   ::testing::InitGoogleTest(&argc, argv);
1263   int ret = RUN_ALL_TESTS();
1264   // Make sure that gRPC shuts down cleanly
1265   GPR_ASSERT(grpc_wait_until_shutdown(10));
1266   return ret;
1267 }
1268