1 // Copyright 2023 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/core/lib/surface/channel_init.h"
16
17 #include <map>
18 #include <string>
19
20 #include "absl/strings/string_view.h"
21 #include "gtest/gtest.h"
22
23 #include "src/core/lib/channel/channel_stack.h"
24 #include "src/core/lib/channel/channel_stack_builder_impl.h"
25 #include "src/core/lib/surface/channel_stack_type.h"
26 #include "test/core/util/test_config.h"
27
28 namespace grpc_core {
29 namespace {
30
FilterNamed(const char * name)31 const grpc_channel_filter* FilterNamed(const char* name) {
32 static auto* filters =
33 new std::map<absl::string_view, const grpc_channel_filter*>;
34 auto it = filters->find(name);
35 if (it != filters->end()) return it->second;
36 return filters
37 ->emplace(name,
38 new grpc_channel_filter{nullptr, nullptr, nullptr, nullptr, 0,
39 nullptr, nullptr, nullptr, 0, nullptr,
40 nullptr, nullptr, nullptr, name})
41 .first->second;
42 }
43
GetFilterNames(const ChannelInit & init,grpc_channel_stack_type type,const ChannelArgs & args)44 std::vector<std::string> GetFilterNames(const ChannelInit& init,
45 grpc_channel_stack_type type,
46 const ChannelArgs& args) {
47 ChannelStackBuilderImpl b("test", type, args);
48 if (!init.CreateStack(&b)) return {};
49 std::vector<std::string> names;
50 for (auto f : b.stack()) {
51 names.push_back(f->name);
52 }
53 EXPECT_NE(names, std::vector<std::string>());
54 return names;
55 }
56
// With only a terminal filter registered, the client stack consists of just
// that filter.
TEST(ChannelInitTest, Empty) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"terminator"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
64
// A filter registered for the client channel type appears only in the client
// stack; the server stack holds nothing but its own terminator.
TEST(ChannelInitTest, OneClientFilter) {
  using Names = std::vector<std::string>;
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  builder.RegisterFilter(GRPC_SERVER_CHANNEL, FilterNamed("terminator"))
      .Terminal();
  const auto channel_init = builder.Build();
  const Names expected_client = {"foo", "terminator"};
  const Names expected_server = {"terminator"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected_client);
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_SERVER_CHANNEL, ChannelArgs()),
            expected_server);
}
76
// In the absence of other constraints ChannelInit orders filters lexically by
// name, so that the resulting stack is stable from one build to the next. The
// terminal filter ("aaa") still comes last despite sorting first by name.
TEST(ChannelInitTest, DefaultLexicalOrdering) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"bar", "baz", "foo", "aaa"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
89
// An After() constraint overrides the default name ordering: "bar" is placed
// behind "foo" even though it would otherwise sort ahead of it.
TEST(ChannelInitTest, AfterConstraintsApply) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"))
      .After({FilterNamed("foo")});
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"baz", "foo", "bar", "aaa"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
101
// A Before() constraint overrides the default name ordering: "foo" is placed
// ahead of "bar" even though it would otherwise sort behind it.
TEST(ChannelInitTest, BeforeConstraintsApply) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"))
      .Before({FilterNamed("bar")});
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"baz", "foo", "bar", "aaa"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
113
// IfChannelArg() makes a filter's presence depend on a boolean channel arg.
// Judging by the expectations below, the second argument is the value assumed
// when the arg is not set: "foo" is on by default, "bar" is off by default.
TEST(ChannelInitTest, PredicatesCanFilter) {
  using Names = std::vector<std::string>;
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"))
      .IfChannelArg("foo", true);
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"))
      .IfChannelArg("bar", false);
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const auto channel_init = builder.Build();
  // Resolves the client stack's filter names for a given set of channel args.
  const auto client_filters = [&channel_init](const ChannelArgs& args) {
    return GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, args);
  };
  // Defaults only.
  EXPECT_EQ(client_filters(ChannelArgs()), Names({"foo", "aaa"}));
  // "foo" explicitly disabled.
  EXPECT_EQ(client_filters(ChannelArgs().Set("foo", false)), Names({"aaa"}));
  // "bar" explicitly enabled.
  EXPECT_EQ(client_filters(ChannelArgs().Set("bar", true)),
            Names({"bar", "foo", "aaa"}));
  // Both defaults flipped.
  EXPECT_EQ(client_filters(ChannelArgs().Set("bar", true).Set("foo", false)),
            Names({"bar", "aaa"}));
}
134
// A filter marked Terminal() is placed at the end of the stack, after the
// regular filter, regardless of name order.
TEST(ChannelInitTest, CanAddTerminalFilter) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar")).Terminal();
  const auto channel_init = builder.Build();
  const std::vector<std::string> expected = {"foo", "bar"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
143
// Several terminal filters may be registered as long as predicates leave
// exactly one of them active. GetFilterNames() yields an empty vector when
// stack creation fails, which is what is expected here when either no
// terminal filter or more than one terminal filter is enabled.
TEST(ChannelInitTest, CanAddMultipleTerminalFilters) {
  using Names = std::vector<std::string>;
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"))
      .Terminal()
      .IfChannelArg("bar", false);
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"))
      .Terminal()
      .IfChannelArg("baz", false);
  const auto channel_init = builder.Build();
  const auto client_filters = [&channel_init](const ChannelArgs& args) {
    return GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, args);
  };
  // No terminal filter enabled.
  EXPECT_EQ(client_filters(ChannelArgs()), Names());
  // Exactly one terminal filter enabled.
  EXPECT_EQ(client_filters(ChannelArgs().Set("bar", true)),
            Names({"foo", "bar"}));
  EXPECT_EQ(client_filters(ChannelArgs().Set("baz", true)),
            Names({"foo", "baz"}));
  // Two terminal filters enabled at once.
  EXPECT_EQ(client_filters(ChannelArgs().Set("bar", true).Set("baz", true)),
            Names());
}
166
// A single BeforeAll() filter is moved to the front of the stack even though
// its name ("foo") would otherwise sort after "bar" and "baz".
TEST(ChannelInitTest, CanAddBeforeAllOnce) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo")).BeforeAll();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  const std::vector<std::string> expected = {"foo", "bar", "baz", "aaa"};
  EXPECT_EQ(
      GetFilterNames(builder.Build(), GRPC_CLIENT_CHANNEL, ChannelArgs()),
      expected);
}
176
// Despite the test name, registering two BeforeAll() filters is expected to
// be fatal: Build() should abort with an "unresolvable graph" diagnostic.
TEST(ChannelInitDeathTest, CanAddBeforeAllTwice) {
  GTEST_FLAG_SET(death_test_style, "threadsafe");
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo")).BeforeAll();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("bar")).BeforeAll();
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("baz"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  EXPECT_DEATH_IF_SUPPORTED(builder.Build(),
                            "Unresolvable graph of channel filters");
}
186
// A registered post-processor is not invoked by Build(); it takes effect when
// a stack is created, at which point it can mutate the filter list (here it
// appends "bar" after the terminal filter).
TEST(ChannelInitTest, CanPostProcessFilters) {
  ChannelInit::Builder builder;
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("foo"));
  builder.RegisterFilter(GRPC_CLIENT_CHANNEL, FilterNamed("aaa")).Terminal();
  int post_processor_calls = 0;
  builder.RegisterPostProcessor(
      GRPC_CLIENT_CHANNEL,
      ChannelInit::PostProcessorSlot::kXdsChannelStackModifier,
      [&post_processor_calls](ChannelStackBuilder& stack_builder) {
        ++post_processor_calls;
        stack_builder.mutable_stack()->push_back(FilterNamed("bar"));
      });
  const auto channel_init = builder.Build();
  // Building the ChannelInit alone must not run the post-processor.
  EXPECT_EQ(post_processor_calls, 0);
  const std::vector<std::string> expected = {"foo", "aaa", "bar"};
  EXPECT_EQ(GetFilterNames(channel_init, GRPC_CLIENT_CHANNEL, ChannelArgs()),
            expected);
}
204
205 class TestFilter1 {
206 public:
TestFilter1(int * p)207 explicit TestFilter1(int* p) : p_(p) {}
208
Create(const ChannelArgs & args,Empty)209 static absl::StatusOr<TestFilter1> Create(const ChannelArgs& args, Empty) {
210 EXPECT_EQ(args.GetInt("foo"), 1);
211 return TestFilter1(args.GetPointer<int>("p"));
212 }
213
214 static const grpc_channel_filter kFilter;
215
216 class Call {
217 public:
Call(TestFilter1 * filter)218 explicit Call(TestFilter1* filter) {
219 EXPECT_EQ(*filter->x_, 0);
220 *filter->x_ = 1;
221 ++*filter->p_;
222 }
223 static const NoInterceptor OnClientInitialMetadata;
224 static const NoInterceptor OnServerInitialMetadata;
225 static const NoInterceptor OnServerTrailingMetadata;
226 static const NoInterceptor OnClientToServerMessage;
227 static const NoInterceptor OnServerToClientMessage;
228 static const NoInterceptor OnFinalize;
229 };
230
231 private:
232 std::unique_ptr<int> x_ = std::make_unique<int>(0);
233 int* const p_;
234 };
235
236 const grpc_channel_filter TestFilter1::kFilter = {
237 nullptr, nullptr, nullptr, nullptr, 0, nullptr, nullptr,
238 nullptr, 0, nullptr, nullptr, nullptr, nullptr, "test_filter1"};
239 const NoInterceptor TestFilter1::Call::OnClientInitialMetadata;
240 const NoInterceptor TestFilter1::Call::OnServerInitialMetadata;
241 const NoInterceptor TestFilter1::Call::OnServerTrailingMetadata;
242 const NoInterceptor TestFilter1::Call::OnClientToServerMessage;
243 const NoInterceptor TestFilter1::Call::OnServerToClientMessage;
244 const NoInterceptor TestFilter1::Call::OnFinalize;
245
// Creates a stack segment containing TestFilter1, adds it to a call filter
// stack, drops the segment, and then runs one call through the stack. The
// counter handed over via channel args must show exactly one increment, i.e.
// TestFilter1::Call's constructor is expected to have run once.
TEST(ChannelInitTest, CanCreateFilterWithCall) {
  ChannelInit::Builder builder;
  builder.RegisterFilter<TestFilter1>(GRPC_CLIENT_CHANNEL);
  auto channel_init = builder.Build();
  int calls_created = 0;
  auto segment = channel_init.CreateStackSegment(
      GRPC_CLIENT_CHANNEL,
      ChannelArgs().Set("foo", 1).Set(
          "p", ChannelArgs::UnownedPointer(&calls_created)));
  ASSERT_TRUE(segment.ok()) << segment.status();
  CallFilters::StackBuilder call_stack_builder;
  segment->AddToCallFilterStack(call_stack_builder);
  // Overwrite the StatusOr with an error to force the segment to be destroyed
  // before the stack is built and used.
  segment = absl::CancelledError();
  auto call_stack = call_stack_builder.Build();
  {
    // Scoped so that the call filters are torn down before the final check.
    CallFilters filters(Arena::MakePooled<ClientMetadata>());
    filters.SetStack(std::move(call_stack));
  }
  EXPECT_EQ(calls_created, 1);
}
265
266 } // namespace
267 } // namespace grpc_core
268
main(int argc,char ** argv)269 int main(int argc, char** argv) {
270 grpc::testing::TestEnvironment env(&argc, argv);
271 ::testing::InitGoogleTest(&argc, argv);
272 grpc::testing::TestGrpcScope grpc_scope;
273 return RUN_ALL_TESTS();
274 }
275