xref: /aosp_15_r20/external/grpc-grpc/test/core/surface/channel_init_test.cc (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 // Copyright 2023 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/core/lib/surface/channel_init.h"
16 
17 #include <map>
18 #include <string>
19 
20 #include "absl/strings/string_view.h"
21 #include "gtest/gtest.h"
22 
23 #include "src/core/lib/channel/channel_stack.h"
24 #include "src/core/lib/channel/channel_stack_builder_impl.h"
25 #include "src/core/lib/surface/channel_stack_type.h"
26 #include "test/core/util/test_config.h"
27 
28 namespace grpc_core {
29 namespace {
30 
FilterNamed(const char * name)31 const grpc_channel_filter* FilterNamed(const char* name) {
32   static auto* filters =
33       new std::map<absl::string_view, const grpc_channel_filter*>;
34   auto it = filters->find(name);
35   if (it != filters->end()) return it->second;
36   return filters
37       ->emplace(name,
38                 new grpc_channel_filter{nullptr, nullptr, nullptr, nullptr, 0,
39                                         nullptr, nullptr, nullptr, 0, nullptr,
40                                         nullptr, nullptr, nullptr, name})
41       .first->second;
42 }
43 
GetFilterNames(const ChannelInit & init,grpc_channel_stack_type type,const ChannelArgs & args)44 std::vector<std::string> GetFilterNames(const ChannelInit& init,
45                                         grpc_channel_stack_type type,
46                                         const ChannelArgs& args) {
47   ChannelStackBuilderImpl b("test", type, args);
48   if (!init.CreateStack(&b)) return {};
49   std::vector<std::string> names;
50   for (auto f : b.stack()) {
51     names.push_back(f->name);
52   }
53   EXPECT_NE(names, std::vector<std::string>());
54   return names;
55 }
56 
TEST(ChannelInitTest, Empty) {
  // Only the terminal filter is registered, so it is the whole client stack.
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"terminator"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
64 
TEST(ChannelInitTest, OneClientFilter) {
  // "foo" is registered for the client stack only; both stack types share the
  // same terminal filter.
  ChannelInit::Builder builder;
  const grpc_channel_filter* terminator = FilterNamed("terminator");
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, terminator).Terminal();
  builder.RegisterFilter(GRPC_SERVER_CHANNEL, terminator).Terminal();
  const auto channel_init = builder.Build();

  const std::vector<std::string> expected_client = {"foo", "terminator"};
  const std::vector<std::string> expected_server = {"terminator"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected_client);
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_SERVER_CHANNEL, ChannelArgs()),
            expected_server);
}
76 
TEST(ChannelInitTest, DefaultLexicalOrdering) {
  // In the absence of other constraints, ChannelInit orders non-terminal
  // filters lexically by name so that every build produces the same stable
  // ordering. The terminal filter ("aaa") still goes last even though it
  // sorts first.
  ChannelInit::Builder b;
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"));
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  auto init = b.Build();
  EXPECT_EQ(GetFilterNames(init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            std::vector<std::string>({"bar", "baz", "foo", "aaa"}));
}
89 
TEST(ChannelInitTest, AfterConstraintsApply) {
  // Lexical order alone would put "bar" first; the After() constraint forces
  // it behind "foo". The terminal "aaa" stays last.
  ChannelInit::Builder builder;
  const grpc_channel_filter* foo = FilterNamed("foo");
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, foo);
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar")).After({foo});
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"baz", "foo", "bar", "aaa"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
101 
TEST(ChannelInitTest, BeforeConstraintsApply) {
  // The Before() constraint places "foo" ahead of "bar", overriding the
  // default lexical order. The terminal "aaa" stays last.
  ChannelInit::Builder builder;
  const grpc_channel_filter* bar = FilterNamed("bar");
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo")).Before({bar});
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, bar);
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"baz", "foo", "bar", "aaa"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
113 
TEST(ChannelInitTest, PredicatesCanFilter) {
  // "foo" is included unless channel arg "foo" is false (defaults to on).
  // "bar" is included only when channel arg "bar" is true (defaults to off).
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"))
      .IfChannelArg("foo", true);
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"))
      .IfChannelArg("bar", false);
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  auto channel_init = builder.Build();

  const auto client_filters = [&channel_init](const ChannelArgs& args) {
    return GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, args);
  };

  EXPECT_EQ(client_filters(ChannelArgs()),
            std::vector<std::string>({"foo", "aaa"}));
  EXPECT_EQ(client_filters(ChannelArgs().Set("foo", false)),
            std::vector<std::string>({"aaa"}));
  EXPECT_EQ(client_filters(ChannelArgs().Set("bar", true)),
            std::vector<std::string>({"bar", "foo", "aaa"}));
  EXPECT_EQ(client_filters(ChannelArgs().Set("bar", true).Set("foo", false)),
            std::vector<std::string>({"bar", "aaa"}));
}
134 
TEST(ChannelInitTest, CanAddTerminalFilter) {
  // The filter marked Terminal() ends the stack even though "bar" sorts
  // before "foo".
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar")).Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"foo", "bar"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
143 
TEST(ChannelInitTest, CanAddMultipleTerminalFilters) {
  // Two terminal filters are registered, each gated on its own channel arg.
  // A stack is produced only when exactly one of them is enabled; with none
  // or both enabled, GetFilterNames() yields an empty list.
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"))
      .Terminal()
      .IfChannelArg("bar", false);
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"))
      .Terminal()
      .IfChannelArg("baz", false);
  auto channel_init = builder.Build();

  const auto client_filters = [&channel_init](const ChannelArgs& args) {
    return GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, args);
  };
  const std::vector<std::string> no_stack;

  EXPECT_EQ(client_filters(ChannelArgs()), no_stack);
  EXPECT_EQ(client_filters(ChannelArgs().Set("bar", true)),
            std::vector<std::string>({"foo", "bar"}));
  EXPECT_EQ(client_filters(ChannelArgs().Set("baz", true)),
            std::vector<std::string>({"foo", "baz"}));
  EXPECT_EQ(client_filters(ChannelArgs().Set("bar", true).Set("baz", true)),
            no_stack);
}
166 
TEST(ChannelInitTest, CanAddBeforeAllOnce) {
  // BeforeAll() moves "foo" to the front, ahead of the lexically earlier
  // "bar" and "baz".
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo")).BeforeAll();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const std::vector<std::string> expected = {"foo", "bar", "baz", "aaa"};
  EXPECT_EQ(
      GetFilterNames(builder.Build(), GRPC_CLIENT_CHANNEL, ChannelArgs()),
      expected);
}
176 
TEST(ChannelInitDeathTest, CanAddBeforeAllTwice) {
  // Two BeforeAll() filters cannot both be first, so Build() is expected to
  // abort with an unresolvable-ordering error.
  GTEST_FLAG_SET(death_test_style, "threadsafe");
  ChannelInit::Builder builder;
  for (const char* first : {"foo", "bar"}) {
    builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed(first)).BeforeAll();
  }
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  EXPECT_DEATH_IF_SUPPORTED(builder.Build(),
                            "Unresolvable graph of channel filters");
}
186 
TEST(ChannelInitTest, CanPostProcessFilters) {
  // A registered post processor runs when a stack is created (not when the
  // ChannelInit is built) and may edit the ordered stack, here by appending
  // "bar" after the terminal filter.
  ChannelInit::Builder b;
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  b.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  int called_post_processor = 0;
  b.RegisterPostProcessor(
      GRPC_CLIENT_CHANNEL,
      ChannelInit::PostProcessorSlot::kXdsChannelStackModifier,
      // Named `stack_builder` so it does not shadow the outer builder `b`.
      [&called_post_processor](ChannelStackBuilder& stack_builder) {
        ++called_post_processor;
        stack_builder.mutable_stack()->push_back(FilterNamed("bar"));
      });
  auto init = b.Build();
  // Building the ChannelInit alone must not invoke the post processor.
  EXPECT_EQ(called_post_processor, 0);
  EXPECT_EQ(GetFilterNames(init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            std::vector<std::string>({"foo", "aaa", "bar"}));
  // Creating exactly one stack must invoke it exactly once.
  EXPECT_EQ(called_post_processor, 1);
}
204 
205 class TestFilter1 {
206  public:
TestFilter1(int * p)207   explicit TestFilter1(int* p) : p_(p) {}
208 
Create(const ChannelArgs & args,Empty)209   static absl::StatusOr<TestFilter1> Create(const ChannelArgs& args, Empty) {
210     EXPECT_EQ(args.GetInt("foo"), 1);
211     return TestFilter1(args.GetPointer<int>("p"));
212   }
213 
214   static const grpc_channel_filter kFilter;
215 
216   class Call {
217    public:
Call(TestFilter1 * filter)218     explicit Call(TestFilter1* filter) {
219       EXPECT_EQ(*filter->x_, 0);
220       *filter->x_ = 1;
221       ++*filter->p_;
222     }
223     static const NoInterceptor OnClientInitialMetadata;
224     static const NoInterceptor OnServerInitialMetadata;
225     static const NoInterceptor OnServerTrailingMetadata;
226     static const NoInterceptor OnClientToServerMessage;
227     static const NoInterceptor OnServerToClientMessage;
228     static const NoInterceptor OnFinalize;
229   };
230 
231  private:
232   std::unique_ptr<int> x_ = std::make_unique<int>(0);
233   int* const p_;
234 };
235 
236 const grpc_channel_filter TestFilter1::kFilter = {
237     nullptr, nullptr, nullptr, nullptr, 0,       nullptr, nullptr,
238     nullptr, 0,       nullptr, nullptr, nullptr, nullptr, "test_filter1"};
239 const NoInterceptor TestFilter1::Call::OnClientInitialMetadata;
240 const NoInterceptor TestFilter1::Call::OnServerInitialMetadata;
241 const NoInterceptor TestFilter1::Call::OnServerTrailingMetadata;
242 const NoInterceptor TestFilter1::Call::OnClientToServerMessage;
243 const NoInterceptor TestFilter1::Call::OnServerToClientMessage;
244 const NoInterceptor TestFilter1::Call::OnFinalize;
245 
TEST(ChannelInitTest, CanCreateFilterWithCall) {
  // Registers TestFilter1 by type, builds a stack segment from it, and checks
  // that installing the resulting call filter stack constructs exactly one
  // TestFilter1::Call (which increments `calls_created`).
  ChannelInit::Builder builder;
  builder.RegisterFilter<TestFilter1>(GRPC_CLIENT_CHANNEL);
  auto channel_init = builder.Build();
  int calls_created = 0;
  auto segment = channel_init.CreateStackSegment(
      GRPC_CLIENT_CHANNEL,
      ChannelArgs().Set("foo", 1).Set(
          "p", ChannelArgs::UnownedPointer(&calls_created)));
  ASSERT_TRUE(segment.ok()) << segment.status();

  CallFilters::StackBuilder stack_builder;
  segment->AddToCallFilterStack(stack_builder);
  // Replace the segment so it is destroyed before the call stack is built,
  // presumably to check the stack does not rely on the segment staying alive.
  segment = absl::CancelledError();
  auto call_stack = stack_builder.Build();

  {
    CallFilters call_filters(Arena::MakePooled<ClientMetadata>());
    call_filters.SetStack(std::move(call_stack));
  }
  EXPECT_EQ(calls_created, 1);
}
265 
266 }  // namespace
267 }  // namespace grpc_core
268 
main(int argc,char ** argv)269 int main(int argc, char** argv) {
270   grpc::testing::TestEnvironment env(&argc, argv);
271   ::testing::InitGoogleTest(&argc, argv);
272   grpc::testing::TestGrpcScope grpc_scope;
273   return RUN_ALL_TESTS();
274 }
275