1 // Copyright 2024 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <grpc/support/port_platform.h>
16
17 #include "src/core/lib/transport/call_filters.h"
18
19 #include "src/core/lib/gprpp/crash.h"
20 #include "src/core/lib/transport/metadata.h"
21
22 namespace grpc_core {
23
namespace {
// Returns the address `amt` bytes past `base`. Used in this file to locate a
// filter's slice of a type-erased call-data buffer from its byte offset.
void* Offset(void* base, size_t amt) {
  auto* const bytes = static_cast<char*>(base);
  return bytes + amt;
}
}  // namespace
27
28 namespace filters_detail {
29
30 template <typename T>
~OperationExecutor()31 OperationExecutor<T>::~OperationExecutor() {
32 if (promise_data_ != nullptr) {
33 ops_->early_destroy(promise_data_);
34 gpr_free_aligned(promise_data_);
35 }
36 }
37
38 template <typename T>
Start(const Layout<FallibleOperator<T>> * layout,T input,void * call_data)39 Poll<ResultOr<T>> OperationExecutor<T>::Start(
40 const Layout<FallibleOperator<T>>* layout, T input, void* call_data) {
41 ops_ = layout->ops.data();
42 end_ops_ = ops_ + layout->ops.size();
43 if (layout->promise_size == 0) {
44 // No call state ==> instantaneously ready
45 auto r = InitStep(std::move(input), call_data);
46 GPR_ASSERT(r.ready());
47 return r;
48 }
49 promise_data_ =
50 gpr_malloc_aligned(layout->promise_size, layout->promise_alignment);
51 return InitStep(std::move(input), call_data);
52 }
53
54 template <typename T>
InitStep(T input,void * call_data)55 Poll<ResultOr<T>> OperationExecutor<T>::InitStep(T input, void* call_data) {
56 while (true) {
57 if (ops_ == end_ops_) {
58 return ResultOr<T>{std::move(input), nullptr};
59 }
60 auto p =
61 ops_->promise_init(promise_data_, Offset(call_data, ops_->call_offset),
62 ops_->channel_data, std::move(input));
63 if (auto* r = p.value_if_ready()) {
64 if (r->ok == nullptr) return std::move(*r);
65 input = std::move(r->ok);
66 ++ops_;
67 continue;
68 }
69 return Pending{};
70 }
71 }
72
73 template <typename T>
Step(void * call_data)74 Poll<ResultOr<T>> OperationExecutor<T>::Step(void* call_data) {
75 GPR_DEBUG_ASSERT(promise_data_ != nullptr);
76 auto p = ContinueStep(call_data);
77 if (p.ready()) {
78 gpr_free_aligned(promise_data_);
79 promise_data_ = nullptr;
80 }
81 return p;
82 }
83
84 template <typename T>
ContinueStep(void * call_data)85 Poll<ResultOr<T>> OperationExecutor<T>::ContinueStep(void* call_data) {
86 auto p = ops_->poll(promise_data_);
87 if (auto* r = p.value_if_ready()) {
88 if (r->ok == nullptr) return std::move(*r);
89 ++ops_;
90 return InitStep(std::move(r->ok), call_data);
91 }
92 return Pending{};
93 }
94
95 template <typename T>
~InfallibleOperationExecutor()96 InfallibleOperationExecutor<T>::~InfallibleOperationExecutor() {
97 if (promise_data_ != nullptr) {
98 ops_->early_destroy(promise_data_);
99 gpr_free_aligned(promise_data_);
100 }
101 }
102
103 template <typename T>
Start(const Layout<InfallibleOperator<T>> * layout,T input,void * call_data)104 Poll<T> InfallibleOperationExecutor<T>::Start(
105 const Layout<InfallibleOperator<T>>* layout, T input, void* call_data) {
106 ops_ = layout->ops.data();
107 end_ops_ = ops_ + layout->ops.size();
108 if (layout->promise_size == 0) {
109 // No call state ==> instantaneously ready
110 auto r = InitStep(std::move(input), call_data);
111 GPR_ASSERT(r.ready());
112 return r;
113 }
114 promise_data_ =
115 gpr_malloc_aligned(layout->promise_size, layout->promise_alignment);
116 return InitStep(std::move(input), call_data);
117 }
118
119 template <typename T>
InitStep(T input,void * call_data)120 Poll<T> InfallibleOperationExecutor<T>::InitStep(T input, void* call_data) {
121 while (true) {
122 if (ops_ == end_ops_) {
123 return input;
124 }
125 auto p =
126 ops_->promise_init(promise_data_, Offset(call_data, ops_->call_offset),
127 ops_->channel_data, std::move(input));
128 if (auto* r = p.value_if_ready()) {
129 input = std::move(*r);
130 ++ops_;
131 continue;
132 }
133 return Pending{};
134 }
135 }
136
137 template <typename T>
Step(void * call_data)138 Poll<T> InfallibleOperationExecutor<T>::Step(void* call_data) {
139 GPR_DEBUG_ASSERT(promise_data_ != nullptr);
140 auto p = ContinueStep(call_data);
141 if (p.ready()) {
142 gpr_free_aligned(promise_data_);
143 promise_data_ = nullptr;
144 }
145 return p;
146 }
147
148 template <typename T>
ContinueStep(void * call_data)149 Poll<T> InfallibleOperationExecutor<T>::ContinueStep(void* call_data) {
150 auto p = ops_->poll(promise_data_);
151 if (auto* r = p.value_if_ready()) {
152 ++ops_;
153 return InitStep(std::move(*r), call_data);
154 }
155 return Pending{};
156 }
157
158 // Explicit instantiations of some types used in filters.h
159 // We'll need to add ServerMetadataHandle to this when it becomes different
160 // to ClientMetadataHandle
161 template class OperationExecutor<ClientMetadataHandle>;
162 template class OperationExecutor<MessageHandle>;
163 template class InfallibleOperationExecutor<ServerMetadataHandle>;
164 } // namespace filters_detail
165
166 ///////////////////////////////////////////////////////////////////////////////
167 // CallFilters
168
CallFilters(ClientMetadataHandle client_initial_metadata)169 CallFilters::CallFilters(ClientMetadataHandle client_initial_metadata)
170 : stack_(nullptr),
171 call_data_(nullptr),
172 client_initial_metadata_(std::move(client_initial_metadata)) {}
173
~CallFilters()174 CallFilters::~CallFilters() {
175 if (call_data_ != nullptr) {
176 for (const auto& destructor : stack_->data_.filter_destructor) {
177 destructor.call_destroy(Offset(call_data_, destructor.call_offset));
178 }
179 gpr_free_aligned(call_data_);
180 }
181 }
182
SetStack(RefCountedPtr<Stack> stack)183 void CallFilters::SetStack(RefCountedPtr<Stack> stack) {
184 GPR_ASSERT(call_data_ == nullptr);
185 stack_ = std::move(stack);
186 call_data_ = gpr_malloc_aligned(stack_->data_.call_data_size,
187 stack_->data_.call_data_alignment);
188 for (const auto& constructor : stack_->data_.filter_constructor) {
189 constructor.call_init(Offset(call_data_, constructor.call_offset),
190 constructor.channel_data);
191 }
192 client_initial_metadata_state_.Start();
193 client_to_server_message_state_.Start();
194 server_initial_metadata_state_.Start();
195 server_to_client_message_state_.Start();
196 }
197
Finalize(const grpc_call_final_info * final_info)198 void CallFilters::Finalize(const grpc_call_final_info* final_info) {
199 for (auto& finalizer : stack_->data_.finalizers) {
200 finalizer.final(Offset(call_data_, finalizer.call_offset),
201 finalizer.channel_data, final_info);
202 }
203 }
204
CancelDueToFailedPipeOperation(SourceLocation but_where)205 void CallFilters::CancelDueToFailedPipeOperation(SourceLocation but_where) {
206 // We expect something cancelled before now
207 if (server_trailing_metadata_ == nullptr) return;
208 if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_promise_primitives)) {
209 gpr_log(but_where.file(), but_where.line(), GPR_LOG_SEVERITY_DEBUG,
210 "Cancelling due to failed pipe operation: %s",
211 DebugString().c_str());
212 }
213 server_trailing_metadata_ =
214 ServerMetadataFromStatus(absl::CancelledError("Failed pipe operation"));
215 server_trailing_metadata_waiter_.Wake();
216 }
217
PushServerTrailingMetadata(ServerMetadataHandle md)218 void CallFilters::PushServerTrailingMetadata(ServerMetadataHandle md) {
219 if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_promise_primitives)) {
220 gpr_log(GPR_DEBUG, "%s Push server trailing metadata: %s into %s",
221 GetContext<Activity>()->DebugTag().c_str(),
222 md->DebugString().c_str(), DebugString().c_str());
223 }
224 GPR_ASSERT(md != nullptr);
225 if (server_trailing_metadata_ != nullptr) return;
226 server_trailing_metadata_ = std::move(md);
227 client_initial_metadata_state_.CloseWithError();
228 server_initial_metadata_state_.CloseSending();
229 client_to_server_message_state_.CloseWithError();
230 server_to_client_message_state_.CloseWithError();
231 server_trailing_metadata_waiter_.Wake();
232 }
233
DebugString() const234 std::string CallFilters::DebugString() const {
235 std::vector<std::string> components = {
236 absl::StrFormat("this:%p", this),
237 absl::StrCat("client_initial_metadata:",
238 client_initial_metadata_state_.DebugString()),
239 ServerInitialMetadataPromises::DebugString("server_initial_metadata",
240 this),
241 ClientToServerMessagePromises::DebugString("client_to_server_message",
242 this),
243 ServerToClientMessagePromises::DebugString("server_to_client_message",
244 this),
245 absl::StrCat("server_trailing_metadata:",
246 server_trailing_metadata_ == nullptr
247 ? "not-set"
248 : server_trailing_metadata_->DebugString())};
249 return absl::StrCat("CallFilters{", absl::StrJoin(components, ", "), "}");
250 };
251
252 ///////////////////////////////////////////////////////////////////////////////
253 // CallFilters::Stack
254
~Stack()255 CallFilters::Stack::~Stack() {
256 for (auto& destructor : data_.channel_data_destructors) {
257 destructor.destroy(destructor.channel_data);
258 }
259 }
260
261 ///////////////////////////////////////////////////////////////////////////////
262 // CallFilters::StackBuilder
263
~StackBuilder()264 CallFilters::StackBuilder::~StackBuilder() {
265 for (auto& destructor : data_.channel_data_destructors) {
266 destructor.destroy(destructor.channel_data);
267 }
268 }
269
Build()270 RefCountedPtr<CallFilters::Stack> CallFilters::StackBuilder::Build() {
271 if (data_.call_data_size % data_.call_data_alignment != 0) {
272 data_.call_data_size += data_.call_data_alignment -
273 data_.call_data_size % data_.call_data_alignment;
274 }
275 // server -> client needs to be reversed so that we can iterate all stacks
276 // in the same order
277 data_.server_initial_metadata.Reverse();
278 data_.server_to_client_messages.Reverse();
279 data_.server_trailing_metadata.Reverse();
280 return RefCountedPtr<Stack>(new Stack(std::move(data_)));
281 }
282
283 ///////////////////////////////////////////////////////////////////////////////
284 // CallFilters::PipeState
285
Start()286 void filters_detail::PipeState::Start() {
287 GPR_DEBUG_ASSERT(!started_);
288 started_ = true;
289 wait_recv_.Wake();
290 }
291
CloseWithError()292 void filters_detail::PipeState::CloseWithError() {
293 if (state_ == ValueState::kClosed) return;
294 state_ = ValueState::kError;
295 wait_recv_.Wake();
296 wait_send_.Wake();
297 }
298
PollClosed()299 Poll<bool> filters_detail::PipeState::PollClosed() {
300 switch (state_) {
301 case ValueState::kIdle:
302 case ValueState::kWaiting:
303 case ValueState::kQueued:
304 case ValueState::kReady:
305 case ValueState::kProcessing:
306 return wait_recv_.pending();
307 case ValueState::kClosed:
308 return false;
309 case ValueState::kError:
310 return true;
311 }
312 GPR_UNREACHABLE_CODE(return Pending{});
313 }
314
CloseSending()315 void filters_detail::PipeState::CloseSending() {
316 switch (state_) {
317 case ValueState::kIdle:
318 state_ = ValueState::kClosed;
319 break;
320 case ValueState::kWaiting:
321 state_ = ValueState::kClosed;
322 wait_recv_.Wake();
323 break;
324 case ValueState::kClosed:
325 case ValueState::kError:
326 break;
327 case ValueState::kQueued:
328 case ValueState::kReady:
329 case ValueState::kProcessing:
330 Crash("Only one push allowed to be outstanding");
331 break;
332 }
333 }
334
BeginPush()335 void filters_detail::PipeState::BeginPush() {
336 switch (state_) {
337 case ValueState::kIdle:
338 state_ = ValueState::kQueued;
339 break;
340 case ValueState::kWaiting:
341 state_ = ValueState::kReady;
342 wait_recv_.Wake();
343 break;
344 case ValueState::kClosed:
345 case ValueState::kError:
346 break;
347 case ValueState::kQueued:
348 case ValueState::kReady:
349 case ValueState::kProcessing:
350 Crash("Only one push allowed to be outstanding");
351 break;
352 }
353 }
354
DropPush()355 void filters_detail::PipeState::DropPush() {
356 switch (state_) {
357 case ValueState::kQueued:
358 case ValueState::kReady:
359 case ValueState::kProcessing:
360 case ValueState::kWaiting:
361 state_ = ValueState::kError;
362 wait_recv_.Wake();
363 break;
364 case ValueState::kIdle:
365 case ValueState::kClosed:
366 case ValueState::kError:
367 break;
368 }
369 }
370
DropPull()371 void filters_detail::PipeState::DropPull() {
372 switch (state_) {
373 case ValueState::kQueued:
374 case ValueState::kReady:
375 case ValueState::kProcessing:
376 case ValueState::kWaiting:
377 state_ = ValueState::kError;
378 wait_send_.Wake();
379 break;
380 case ValueState::kIdle:
381 case ValueState::kClosed:
382 case ValueState::kError:
383 break;
384 }
385 }
386
PollPush()387 Poll<StatusFlag> filters_detail::PipeState::PollPush() {
388 switch (state_) {
389 case ValueState::kIdle:
390 // Read completed and new read started => we see waiting here
391 case ValueState::kWaiting:
392 case ValueState::kClosed:
393 return Success{};
394 case ValueState::kQueued:
395 case ValueState::kReady:
396 case ValueState::kProcessing:
397 return wait_send_.pending();
398 case ValueState::kError:
399 return Failure{};
400 }
401 GPR_UNREACHABLE_CODE(return Pending{});
402 }
403
PollPull()404 Poll<ValueOrFailure<bool>> filters_detail::PipeState::PollPull() {
405 switch (state_) {
406 case ValueState::kWaiting:
407 return wait_recv_.pending();
408 case ValueState::kIdle:
409 state_ = ValueState::kWaiting;
410 return wait_recv_.pending();
411 case ValueState::kReady:
412 case ValueState::kQueued:
413 if (!started_) return wait_recv_.pending();
414 state_ = ValueState::kProcessing;
415 return true;
416 case ValueState::kProcessing:
417 Crash("Only one pull allowed to be outstanding");
418 case ValueState::kClosed:
419 return false;
420 case ValueState::kError:
421 return Failure{};
422 }
423 GPR_UNREACHABLE_CODE(return Pending{});
424 }
425
AckPull()426 void filters_detail::PipeState::AckPull() {
427 switch (state_) {
428 case ValueState::kProcessing:
429 state_ = ValueState::kIdle;
430 wait_send_.Wake();
431 break;
432 case ValueState::kWaiting:
433 case ValueState::kIdle:
434 case ValueState::kQueued:
435 case ValueState::kReady:
436 case ValueState::kClosed:
437 Crash("AckPullValue called in invalid state");
438 case ValueState::kError:
439 break;
440 }
441 }
442
DebugString() const443 std::string filters_detail::PipeState::DebugString() const {
444 const char* state_str = "<<invalid-value>>";
445 switch (state_) {
446 case ValueState::kIdle:
447 state_str = "Idle";
448 break;
449 case ValueState::kWaiting:
450 state_str = "Waiting";
451 break;
452 case ValueState::kQueued:
453 state_str = "Queued";
454 break;
455 case ValueState::kReady:
456 state_str = "Ready";
457 break;
458 case ValueState::kProcessing:
459 state_str = "Processing";
460 break;
461 case ValueState::kClosed:
462 state_str = "Closed";
463 break;
464 case ValueState::kError:
465 state_str = "Error";
466 break;
467 }
468 return absl::StrCat(state_str, started_ ? "" : " (not started)");
469 }
470
471 } // namespace grpc_core
472