xref: /aosp_15_r20/external/grpc-grpc/src/core/lib/transport/call_filters.cc (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 // Copyright 2024 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <grpc/support/port_platform.h>
16 
17 #include "src/core/lib/transport/call_filters.h"
18 
19 #include "src/core/lib/gprpp/crash.h"
20 #include "src/core/lib/transport/metadata.h"
21 
22 namespace grpc_core {
23 
namespace {
// Returns the address `amt` bytes past `base`; used to locate a filter's
// slice inside a contiguous per-call allocation.
void* Offset(void* base, size_t amt) {
  char* const bytes = static_cast<char*>(base);
  return bytes + amt;
}
}  // namespace
27 
28 namespace filters_detail {
29 
30 template <typename T>
~OperationExecutor()31 OperationExecutor<T>::~OperationExecutor() {
32   if (promise_data_ != nullptr) {
33     ops_->early_destroy(promise_data_);
34     gpr_free_aligned(promise_data_);
35   }
36 }
37 
38 template <typename T>
Start(const Layout<FallibleOperator<T>> * layout,T input,void * call_data)39 Poll<ResultOr<T>> OperationExecutor<T>::Start(
40     const Layout<FallibleOperator<T>>* layout, T input, void* call_data) {
41   ops_ = layout->ops.data();
42   end_ops_ = ops_ + layout->ops.size();
43   if (layout->promise_size == 0) {
44     // No call state ==> instantaneously ready
45     auto r = InitStep(std::move(input), call_data);
46     GPR_ASSERT(r.ready());
47     return r;
48   }
49   promise_data_ =
50       gpr_malloc_aligned(layout->promise_size, layout->promise_alignment);
51   return InitStep(std::move(input), call_data);
52 }
53 
54 template <typename T>
InitStep(T input,void * call_data)55 Poll<ResultOr<T>> OperationExecutor<T>::InitStep(T input, void* call_data) {
56   while (true) {
57     if (ops_ == end_ops_) {
58       return ResultOr<T>{std::move(input), nullptr};
59     }
60     auto p =
61         ops_->promise_init(promise_data_, Offset(call_data, ops_->call_offset),
62                            ops_->channel_data, std::move(input));
63     if (auto* r = p.value_if_ready()) {
64       if (r->ok == nullptr) return std::move(*r);
65       input = std::move(r->ok);
66       ++ops_;
67       continue;
68     }
69     return Pending{};
70   }
71 }
72 
73 template <typename T>
Step(void * call_data)74 Poll<ResultOr<T>> OperationExecutor<T>::Step(void* call_data) {
75   GPR_DEBUG_ASSERT(promise_data_ != nullptr);
76   auto p = ContinueStep(call_data);
77   if (p.ready()) {
78     gpr_free_aligned(promise_data_);
79     promise_data_ = nullptr;
80   }
81   return p;
82 }
83 
84 template <typename T>
ContinueStep(void * call_data)85 Poll<ResultOr<T>> OperationExecutor<T>::ContinueStep(void* call_data) {
86   auto p = ops_->poll(promise_data_);
87   if (auto* r = p.value_if_ready()) {
88     if (r->ok == nullptr) return std::move(*r);
89     ++ops_;
90     return InitStep(std::move(r->ok), call_data);
91   }
92   return Pending{};
93 }
94 
95 template <typename T>
~InfallibleOperationExecutor()96 InfallibleOperationExecutor<T>::~InfallibleOperationExecutor() {
97   if (promise_data_ != nullptr) {
98     ops_->early_destroy(promise_data_);
99     gpr_free_aligned(promise_data_);
100   }
101 }
102 
103 template <typename T>
Start(const Layout<InfallibleOperator<T>> * layout,T input,void * call_data)104 Poll<T> InfallibleOperationExecutor<T>::Start(
105     const Layout<InfallibleOperator<T>>* layout, T input, void* call_data) {
106   ops_ = layout->ops.data();
107   end_ops_ = ops_ + layout->ops.size();
108   if (layout->promise_size == 0) {
109     // No call state ==> instantaneously ready
110     auto r = InitStep(std::move(input), call_data);
111     GPR_ASSERT(r.ready());
112     return r;
113   }
114   promise_data_ =
115       gpr_malloc_aligned(layout->promise_size, layout->promise_alignment);
116   return InitStep(std::move(input), call_data);
117 }
118 
119 template <typename T>
InitStep(T input,void * call_data)120 Poll<T> InfallibleOperationExecutor<T>::InitStep(T input, void* call_data) {
121   while (true) {
122     if (ops_ == end_ops_) {
123       return input;
124     }
125     auto p =
126         ops_->promise_init(promise_data_, Offset(call_data, ops_->call_offset),
127                            ops_->channel_data, std::move(input));
128     if (auto* r = p.value_if_ready()) {
129       input = std::move(*r);
130       ++ops_;
131       continue;
132     }
133     return Pending{};
134   }
135 }
136 
137 template <typename T>
Step(void * call_data)138 Poll<T> InfallibleOperationExecutor<T>::Step(void* call_data) {
139   GPR_DEBUG_ASSERT(promise_data_ != nullptr);
140   auto p = ContinueStep(call_data);
141   if (p.ready()) {
142     gpr_free_aligned(promise_data_);
143     promise_data_ = nullptr;
144   }
145   return p;
146 }
147 
148 template <typename T>
ContinueStep(void * call_data)149 Poll<T> InfallibleOperationExecutor<T>::ContinueStep(void* call_data) {
150   auto p = ops_->poll(promise_data_);
151   if (auto* r = p.value_if_ready()) {
152     ++ops_;
153     return InitStep(std::move(*r), call_data);
154   }
155   return Pending{};
156 }
157 
158 // Explicit instantiations of some types used in filters.h
159 // We'll need to add ServerMetadataHandle to this when it becomes different
160 // to ClientMetadataHandle
161 template class OperationExecutor<ClientMetadataHandle>;
162 template class OperationExecutor<MessageHandle>;
163 template class InfallibleOperationExecutor<ServerMetadataHandle>;
164 }  // namespace filters_detail
165 
166 ///////////////////////////////////////////////////////////////////////////////
167 // CallFilters
168 
CallFilters(ClientMetadataHandle client_initial_metadata)169 CallFilters::CallFilters(ClientMetadataHandle client_initial_metadata)
170     : stack_(nullptr),
171       call_data_(nullptr),
172       client_initial_metadata_(std::move(client_initial_metadata)) {}
173 
~CallFilters()174 CallFilters::~CallFilters() {
175   if (call_data_ != nullptr) {
176     for (const auto& destructor : stack_->data_.filter_destructor) {
177       destructor.call_destroy(Offset(call_data_, destructor.call_offset));
178     }
179     gpr_free_aligned(call_data_);
180   }
181 }
182 
SetStack(RefCountedPtr<Stack> stack)183 void CallFilters::SetStack(RefCountedPtr<Stack> stack) {
184   GPR_ASSERT(call_data_ == nullptr);
185   stack_ = std::move(stack);
186   call_data_ = gpr_malloc_aligned(stack_->data_.call_data_size,
187                                   stack_->data_.call_data_alignment);
188   for (const auto& constructor : stack_->data_.filter_constructor) {
189     constructor.call_init(Offset(call_data_, constructor.call_offset),
190                           constructor.channel_data);
191   }
192   client_initial_metadata_state_.Start();
193   client_to_server_message_state_.Start();
194   server_initial_metadata_state_.Start();
195   server_to_client_message_state_.Start();
196 }
197 
Finalize(const grpc_call_final_info * final_info)198 void CallFilters::Finalize(const grpc_call_final_info* final_info) {
199   for (auto& finalizer : stack_->data_.finalizers) {
200     finalizer.final(Offset(call_data_, finalizer.call_offset),
201                     finalizer.channel_data, final_info);
202   }
203 }
204 
CancelDueToFailedPipeOperation(SourceLocation but_where)205 void CallFilters::CancelDueToFailedPipeOperation(SourceLocation but_where) {
206   // We expect something cancelled before now
207   if (server_trailing_metadata_ == nullptr) return;
208   if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_promise_primitives)) {
209     gpr_log(but_where.file(), but_where.line(), GPR_LOG_SEVERITY_DEBUG,
210             "Cancelling due to failed pipe operation: %s",
211             DebugString().c_str());
212   }
213   server_trailing_metadata_ =
214       ServerMetadataFromStatus(absl::CancelledError("Failed pipe operation"));
215   server_trailing_metadata_waiter_.Wake();
216 }
217 
PushServerTrailingMetadata(ServerMetadataHandle md)218 void CallFilters::PushServerTrailingMetadata(ServerMetadataHandle md) {
219   if (GRPC_TRACE_FLAG_ENABLED(grpc_trace_promise_primitives)) {
220     gpr_log(GPR_DEBUG, "%s Push server trailing metadata: %s into %s",
221             GetContext<Activity>()->DebugTag().c_str(),
222             md->DebugString().c_str(), DebugString().c_str());
223   }
224   GPR_ASSERT(md != nullptr);
225   if (server_trailing_metadata_ != nullptr) return;
226   server_trailing_metadata_ = std::move(md);
227   client_initial_metadata_state_.CloseWithError();
228   server_initial_metadata_state_.CloseSending();
229   client_to_server_message_state_.CloseWithError();
230   server_to_client_message_state_.CloseWithError();
231   server_trailing_metadata_waiter_.Wake();
232 }
233 
DebugString() const234 std::string CallFilters::DebugString() const {
235   std::vector<std::string> components = {
236       absl::StrFormat("this:%p", this),
237       absl::StrCat("client_initial_metadata:",
238                    client_initial_metadata_state_.DebugString()),
239       ServerInitialMetadataPromises::DebugString("server_initial_metadata",
240                                                  this),
241       ClientToServerMessagePromises::DebugString("client_to_server_message",
242                                                  this),
243       ServerToClientMessagePromises::DebugString("server_to_client_message",
244                                                  this),
245       absl::StrCat("server_trailing_metadata:",
246                    server_trailing_metadata_ == nullptr
247                        ? "not-set"
248                        : server_trailing_metadata_->DebugString())};
249   return absl::StrCat("CallFilters{", absl::StrJoin(components, ", "), "}");
250 };
251 
252 ///////////////////////////////////////////////////////////////////////////////
253 // CallFilters::Stack
254 
~Stack()255 CallFilters::Stack::~Stack() {
256   for (auto& destructor : data_.channel_data_destructors) {
257     destructor.destroy(destructor.channel_data);
258   }
259 }
260 
261 ///////////////////////////////////////////////////////////////////////////////
262 // CallFilters::StackBuilder
263 
~StackBuilder()264 CallFilters::StackBuilder::~StackBuilder() {
265   for (auto& destructor : data_.channel_data_destructors) {
266     destructor.destroy(destructor.channel_data);
267   }
268 }
269 
Build()270 RefCountedPtr<CallFilters::Stack> CallFilters::StackBuilder::Build() {
271   if (data_.call_data_size % data_.call_data_alignment != 0) {
272     data_.call_data_size += data_.call_data_alignment -
273                             data_.call_data_size % data_.call_data_alignment;
274   }
275   // server -> client needs to be reversed so that we can iterate all stacks
276   // in the same order
277   data_.server_initial_metadata.Reverse();
278   data_.server_to_client_messages.Reverse();
279   data_.server_trailing_metadata.Reverse();
280   return RefCountedPtr<Stack>(new Stack(std::move(data_)));
281 }
282 
283 ///////////////////////////////////////////////////////////////////////////////
// filters_detail::PipeState
285 
Start()286 void filters_detail::PipeState::Start() {
287   GPR_DEBUG_ASSERT(!started_);
288   started_ = true;
289   wait_recv_.Wake();
290 }
291 
CloseWithError()292 void filters_detail::PipeState::CloseWithError() {
293   if (state_ == ValueState::kClosed) return;
294   state_ = ValueState::kError;
295   wait_recv_.Wake();
296   wait_send_.Wake();
297 }
298 
PollClosed()299 Poll<bool> filters_detail::PipeState::PollClosed() {
300   switch (state_) {
301     case ValueState::kIdle:
302     case ValueState::kWaiting:
303     case ValueState::kQueued:
304     case ValueState::kReady:
305     case ValueState::kProcessing:
306       return wait_recv_.pending();
307     case ValueState::kClosed:
308       return false;
309     case ValueState::kError:
310       return true;
311   }
312   GPR_UNREACHABLE_CODE(return Pending{});
313 }
314 
CloseSending()315 void filters_detail::PipeState::CloseSending() {
316   switch (state_) {
317     case ValueState::kIdle:
318       state_ = ValueState::kClosed;
319       break;
320     case ValueState::kWaiting:
321       state_ = ValueState::kClosed;
322       wait_recv_.Wake();
323       break;
324     case ValueState::kClosed:
325     case ValueState::kError:
326       break;
327     case ValueState::kQueued:
328     case ValueState::kReady:
329     case ValueState::kProcessing:
330       Crash("Only one push allowed to be outstanding");
331       break;
332   }
333 }
334 
BeginPush()335 void filters_detail::PipeState::BeginPush() {
336   switch (state_) {
337     case ValueState::kIdle:
338       state_ = ValueState::kQueued;
339       break;
340     case ValueState::kWaiting:
341       state_ = ValueState::kReady;
342       wait_recv_.Wake();
343       break;
344     case ValueState::kClosed:
345     case ValueState::kError:
346       break;
347     case ValueState::kQueued:
348     case ValueState::kReady:
349     case ValueState::kProcessing:
350       Crash("Only one push allowed to be outstanding");
351       break;
352   }
353 }
354 
DropPush()355 void filters_detail::PipeState::DropPush() {
356   switch (state_) {
357     case ValueState::kQueued:
358     case ValueState::kReady:
359     case ValueState::kProcessing:
360     case ValueState::kWaiting:
361       state_ = ValueState::kError;
362       wait_recv_.Wake();
363       break;
364     case ValueState::kIdle:
365     case ValueState::kClosed:
366     case ValueState::kError:
367       break;
368   }
369 }
370 
DropPull()371 void filters_detail::PipeState::DropPull() {
372   switch (state_) {
373     case ValueState::kQueued:
374     case ValueState::kReady:
375     case ValueState::kProcessing:
376     case ValueState::kWaiting:
377       state_ = ValueState::kError;
378       wait_send_.Wake();
379       break;
380     case ValueState::kIdle:
381     case ValueState::kClosed:
382     case ValueState::kError:
383       break;
384   }
385 }
386 
PollPush()387 Poll<StatusFlag> filters_detail::PipeState::PollPush() {
388   switch (state_) {
389     case ValueState::kIdle:
390     // Read completed and new read started => we see waiting here
391     case ValueState::kWaiting:
392     case ValueState::kClosed:
393       return Success{};
394     case ValueState::kQueued:
395     case ValueState::kReady:
396     case ValueState::kProcessing:
397       return wait_send_.pending();
398     case ValueState::kError:
399       return Failure{};
400   }
401   GPR_UNREACHABLE_CODE(return Pending{});
402 }
403 
PollPull()404 Poll<ValueOrFailure<bool>> filters_detail::PipeState::PollPull() {
405   switch (state_) {
406     case ValueState::kWaiting:
407       return wait_recv_.pending();
408     case ValueState::kIdle:
409       state_ = ValueState::kWaiting;
410       return wait_recv_.pending();
411     case ValueState::kReady:
412     case ValueState::kQueued:
413       if (!started_) return wait_recv_.pending();
414       state_ = ValueState::kProcessing;
415       return true;
416     case ValueState::kProcessing:
417       Crash("Only one pull allowed to be outstanding");
418     case ValueState::kClosed:
419       return false;
420     case ValueState::kError:
421       return Failure{};
422   }
423   GPR_UNREACHABLE_CODE(return Pending{});
424 }
425 
AckPull()426 void filters_detail::PipeState::AckPull() {
427   switch (state_) {
428     case ValueState::kProcessing:
429       state_ = ValueState::kIdle;
430       wait_send_.Wake();
431       break;
432     case ValueState::kWaiting:
433     case ValueState::kIdle:
434     case ValueState::kQueued:
435     case ValueState::kReady:
436     case ValueState::kClosed:
437       Crash("AckPullValue called in invalid state");
438     case ValueState::kError:
439       break;
440   }
441 }
442 
DebugString() const443 std::string filters_detail::PipeState::DebugString() const {
444   const char* state_str = "<<invalid-value>>";
445   switch (state_) {
446     case ValueState::kIdle:
447       state_str = "Idle";
448       break;
449     case ValueState::kWaiting:
450       state_str = "Waiting";
451       break;
452     case ValueState::kQueued:
453       state_str = "Queued";
454       break;
455     case ValueState::kReady:
456       state_str = "Ready";
457       break;
458     case ValueState::kProcessing:
459       state_str = "Processing";
460       break;
461     case ValueState::kClosed:
462       state_str = "Closed";
463       break;
464     case ValueState::kError:
465       state_str = "Error";
466       break;
467   }
468   return absl::StrCat(state_str, started_ ? "" : " (not started)");
469 }
470 
471 }  // namespace grpc_core
472