1 //
2 //
3 // Copyright 2015 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include <grpc/support/port_platform.h>
20
21 #include "src/core/lib/channel/connected_channel.h"
22
23 #include <inttypes.h>
24
25 #include <functional>
26 #include <initializer_list>
27 #include <memory>
28 #include <string>
29 #include <type_traits>
30 #include <utility>
31
32 #include "absl/status/status.h"
33 #include "absl/status/statusor.h"
34 #include "absl/types/optional.h"
35 #include "absl/types/variant.h"
36
37 #include <grpc/grpc.h>
38 #include <grpc/status.h>
39 #include <grpc/support/alloc.h>
40 #include <grpc/support/log.h>
41
42 #include "src/core/lib/channel/channel_args.h"
43 #include "src/core/lib/channel/channel_fwd.h"
44 #include "src/core/lib/channel/channel_stack.h"
45 #include "src/core/lib/debug/trace.h"
46 #include "src/core/lib/experiments/experiments.h"
47 #include "src/core/lib/gpr/alloc.h"
48 #include "src/core/lib/gprpp/debug_location.h"
49 #include "src/core/lib/gprpp/orphanable.h"
50 #include "src/core/lib/gprpp/ref_counted_ptr.h"
51 #include "src/core/lib/gprpp/time.h"
52 #include "src/core/lib/iomgr/call_combiner.h"
53 #include "src/core/lib/iomgr/closure.h"
54 #include "src/core/lib/iomgr/error.h"
55 #include "src/core/lib/iomgr/polling_entity.h"
56 #include "src/core/lib/promise/activity.h"
57 #include "src/core/lib/promise/arena_promise.h"
58 #include "src/core/lib/promise/context.h"
59 #include "src/core/lib/promise/detail/basic_seq.h"
60 #include "src/core/lib/promise/detail/status.h"
61 #include "src/core/lib/promise/for_each.h"
62 #include "src/core/lib/promise/if.h"
63 #include "src/core/lib/promise/latch.h"
64 #include "src/core/lib/promise/loop.h"
65 #include "src/core/lib/promise/map.h"
66 #include "src/core/lib/promise/party.h"
67 #include "src/core/lib/promise/pipe.h"
68 #include "src/core/lib/promise/poll.h"
69 #include "src/core/lib/promise/promise.h"
70 #include "src/core/lib/promise/race.h"
71 #include "src/core/lib/promise/seq.h"
72 #include "src/core/lib/promise/try_seq.h"
73 #include "src/core/lib/resource_quota/arena.h"
74 #include "src/core/lib/slice/slice.h"
75 #include "src/core/lib/slice/slice_buffer.h"
76 #include "src/core/lib/surface/call.h"
77 #include "src/core/lib/surface/call_trace.h"
78 #include "src/core/lib/surface/channel_stack_type.h"
79 #include "src/core/lib/transport/batch_builder.h"
80 #include "src/core/lib/transport/error_utils.h"
81 #include "src/core/lib/transport/metadata_batch.h"
82 #include "src/core/lib/transport/transport.h"
83 #include "src/core/lib/transport/transport_fwd.h"
84 #include "src/core/lib/transport/transport_impl.h"
85
86 typedef struct connected_channel_channel_data {
87 grpc_transport* transport;
88 } channel_data;
89
// Bookkeeping used to bounce a transport completion back onto the call
// combiner before the caller's own closure runs (filled by
// intercept_callback(), consumed by run_in_call_combiner()).
struct callback_state {
  // Closure handed to the transport in place of original_closure.
  grpc_closure closure;
  // The caller's closure; started on call_combiner when `closure` runs.
  grpc_closure* original_closure;
  // Call combiner that original_closure must run under.
  grpc_core::CallCombiner* call_combiner;
  // Debug reason passed to GRPC_CALL_COMBINER_START.
  const char* reason;
};
96 typedef struct connected_channel_call_data {
97 grpc_core::CallCombiner* call_combiner;
98 // Closures used for returning results on the call combiner.
99 callback_state on_complete[6]; // Max number of pending batches.
100 callback_state recv_initial_metadata_ready;
101 callback_state recv_message_ready;
102 callback_state recv_trailing_metadata_ready;
103 } call_data;
104
run_in_call_combiner(void * arg,grpc_error_handle error)105 static void run_in_call_combiner(void* arg, grpc_error_handle error) {
106 callback_state* state = static_cast<callback_state*>(arg);
107 GRPC_CALL_COMBINER_START(state->call_combiner, state->original_closure, error,
108 state->reason);
109 }
110
run_cancel_in_call_combiner(void * arg,grpc_error_handle error)111 static void run_cancel_in_call_combiner(void* arg, grpc_error_handle error) {
112 run_in_call_combiner(arg, error);
113 gpr_free(arg);
114 }
115
intercept_callback(call_data * calld,callback_state * state,bool free_when_done,const char * reason,grpc_closure ** original_closure)116 static void intercept_callback(call_data* calld, callback_state* state,
117 bool free_when_done, const char* reason,
118 grpc_closure** original_closure) {
119 state->original_closure = *original_closure;
120 state->call_combiner = calld->call_combiner;
121 state->reason = reason;
122 *original_closure = GRPC_CLOSURE_INIT(
123 &state->closure,
124 free_when_done ? run_cancel_in_call_combiner : run_in_call_combiner,
125 state, grpc_schedule_on_exec_ctx);
126 }
127
get_state_for_batch(call_data * calld,grpc_transport_stream_op_batch * batch)128 static callback_state* get_state_for_batch(
129 call_data* calld, grpc_transport_stream_op_batch* batch) {
130 if (batch->send_initial_metadata) return &calld->on_complete[0];
131 if (batch->send_message) return &calld->on_complete[1];
132 if (batch->send_trailing_metadata) return &calld->on_complete[2];
133 if (batch->recv_initial_metadata) return &calld->on_complete[3];
134 if (batch->recv_message) return &calld->on_complete[4];
135 if (batch->recv_trailing_metadata) return &calld->on_complete[5];
136 GPR_UNREACHABLE_CODE(return nullptr);
137 }
138
// We perform a small hack to locate the transport's stream data alongside the
// connected channel's call_data in the call allocation, so that both can be
// pulled in with a minimal number of cache-line loads. The stream starts right
// after call_data, at an offset of sizeof(call_data) rounded up by
// GPR_ROUND_UP_TO_ALIGNMENT_SIZE; these macros convert between the two
// addresses.
#define TRANSPORT_STREAM_FROM_CALL_DATA(calld) \
  ((grpc_stream*)(((char*)(calld)) +          \
                  GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(call_data))))
#define CALL_DATA_FROM_TRANSPORT_STREAM(transport_stream) \
  ((call_data*)(((char*)(transport_stream)) -             \
                GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(call_data))))
148
149 // Intercept a call operation and either push it directly up or translate it
150 // into transport stream operations
connected_channel_start_transport_stream_op_batch(grpc_call_element * elem,grpc_transport_stream_op_batch * batch)151 static void connected_channel_start_transport_stream_op_batch(
152 grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
153 call_data* calld = static_cast<call_data*>(elem->call_data);
154 channel_data* chand = static_cast<channel_data*>(elem->channel_data);
155 if (batch->recv_initial_metadata) {
156 callback_state* state = &calld->recv_initial_metadata_ready;
157 intercept_callback(
158 calld, state, false, "recv_initial_metadata_ready",
159 &batch->payload->recv_initial_metadata.recv_initial_metadata_ready);
160 }
161 if (batch->recv_message) {
162 callback_state* state = &calld->recv_message_ready;
163 intercept_callback(calld, state, false, "recv_message_ready",
164 &batch->payload->recv_message.recv_message_ready);
165 }
166 if (batch->recv_trailing_metadata) {
167 callback_state* state = &calld->recv_trailing_metadata_ready;
168 intercept_callback(
169 calld, state, false, "recv_trailing_metadata_ready",
170 &batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready);
171 }
172 if (batch->cancel_stream) {
173 // There can be more than one cancellation batch in flight at any
174 // given time, so we can't just pick out a fixed index into
175 // calld->on_complete like we can for the other ops. However,
176 // cancellation isn't in the fast path, so we just allocate a new
177 // closure for each one.
178 callback_state* state =
179 static_cast<callback_state*>(gpr_malloc(sizeof(*state)));
180 intercept_callback(calld, state, true, "on_complete (cancel_stream)",
181 &batch->on_complete);
182 } else if (batch->on_complete != nullptr) {
183 callback_state* state = get_state_for_batch(calld, batch);
184 intercept_callback(calld, state, false, "on_complete", &batch->on_complete);
185 }
186 grpc_transport_perform_stream_op(
187 chand->transport, TRANSPORT_STREAM_FROM_CALL_DATA(calld), batch);
188 GRPC_CALL_COMBINER_STOP(calld->call_combiner, "passed batch to transport");
189 }
190
connected_channel_start_transport_op(grpc_channel_element * elem,grpc_transport_op * op)191 static void connected_channel_start_transport_op(grpc_channel_element* elem,
192 grpc_transport_op* op) {
193 channel_data* chand = static_cast<channel_data*>(elem->channel_data);
194 grpc_transport_perform_op(chand->transport, op);
195 }
196
197 // Constructor for call_data
connected_channel_init_call_elem(grpc_call_element * elem,const grpc_call_element_args * args)198 static grpc_error_handle connected_channel_init_call_elem(
199 grpc_call_element* elem, const grpc_call_element_args* args) {
200 call_data* calld = static_cast<call_data*>(elem->call_data);
201 channel_data* chand = static_cast<channel_data*>(elem->channel_data);
202 calld->call_combiner = args->call_combiner;
203 int r = grpc_transport_init_stream(
204 chand->transport, TRANSPORT_STREAM_FROM_CALL_DATA(calld),
205 &args->call_stack->refcount, args->server_transport_data, args->arena);
206 return r == 0 ? absl::OkStatus()
207 : GRPC_ERROR_CREATE("transport stream initialization failed");
208 }
209
set_pollset_or_pollset_set(grpc_call_element * elem,grpc_polling_entity * pollent)210 static void set_pollset_or_pollset_set(grpc_call_element* elem,
211 grpc_polling_entity* pollent) {
212 call_data* calld = static_cast<call_data*>(elem->call_data);
213 channel_data* chand = static_cast<channel_data*>(elem->channel_data);
214 grpc_transport_set_pops(chand->transport,
215 TRANSPORT_STREAM_FROM_CALL_DATA(calld), pollent);
216 }
217
218 // Destructor for call_data
connected_channel_destroy_call_elem(grpc_call_element * elem,const grpc_call_final_info *,grpc_closure * then_schedule_closure)219 static void connected_channel_destroy_call_elem(
220 grpc_call_element* elem, const grpc_call_final_info* /*final_info*/,
221 grpc_closure* then_schedule_closure) {
222 call_data* calld = static_cast<call_data*>(elem->call_data);
223 channel_data* chand = static_cast<channel_data*>(elem->channel_data);
224 grpc_transport_destroy_stream(chand->transport,
225 TRANSPORT_STREAM_FROM_CALL_DATA(calld),
226 then_schedule_closure);
227 }
228
229 // Constructor for channel_data
connected_channel_init_channel_elem(grpc_channel_element * elem,grpc_channel_element_args * args)230 static grpc_error_handle connected_channel_init_channel_elem(
231 grpc_channel_element* elem, grpc_channel_element_args* args) {
232 channel_data* cd = static_cast<channel_data*>(elem->channel_data);
233 GPR_ASSERT(args->is_last);
234 cd->transport = args->channel_args.GetObject<grpc_transport>();
235 return absl::OkStatus();
236 }
237
238 // Destructor for channel_data
connected_channel_destroy_channel_elem(grpc_channel_element * elem)239 static void connected_channel_destroy_channel_elem(grpc_channel_element* elem) {
240 channel_data* cd = static_cast<channel_data*>(elem->channel_data);
241 if (cd->transport) {
242 grpc_transport_destroy(cd->transport);
243 }
244 }
245
246 // No-op.
connected_channel_get_channel_info(grpc_channel_element *,const grpc_channel_info *)247 static void connected_channel_get_channel_info(
248 grpc_channel_element* /*elem*/, const grpc_channel_info* /*channel_info*/) {
249 }
250
251 namespace grpc_core {
252 namespace {
253
254 #if defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL) || \
255 defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL)
// Owns a transport stream for a promise-based call and exposes it to the
// BatchBuilder-based transport API.
//
// Lifetime: references are counted in stream_refcount_, which is also handed
// to the transport. When the refcount's destroy callback fires (presumably on
// the last unref — confirm against GRPC_STREAM_REF_INIT), BeginDestroy() runs.
// If a transport stream was set, it is destroyed via
// grpc_transport_destroy_stream() and stream_destroyed_ then runs the
// destructor inside the call context. Otherwise the destructor runs in the call
// context directly.
class ConnectedChannelStream : public Orphanable {
 public:
  explicit ConnectedChannelStream(grpc_transport* transport)
      : transport_(transport), stream_(nullptr, StreamDeleter(this)) {
    // Start with one ref (dropped by Orphan()); the destroy callback starts
    // teardown.
    GRPC_STREAM_REF_INIT(
        &stream_refcount_, 1,
        [](void* p, grpc_error_handle) {
          static_cast<ConnectedChannelStream*>(p)->BeginDestroy();
        },
        this, "ConnectedChannelStream");
  }

  grpc_transport* transport() { return transport_; }
  grpc_closure* stream_destroyed_closure() { return &stream_destroyed_; }

  // Transport, stream and refcount that BatchBuilder batches target.
  BatchBuilder::Target batch_target() {
    return BatchBuilder::Target{transport_, stream_.get(), &stream_refcount_};
  }

  // Adds a ref to stream_refcount_ (the reason is only used in debug builds).
  void IncrementRefCount(const char* reason = "smartptr") {
#ifndef NDEBUG
    grpc_stream_ref(&stream_refcount_, reason);
#else
    (void)reason;
    grpc_stream_ref(&stream_refcount_);
#endif
  }

  // Drops a ref from stream_refcount_ (the reason is only used in debug
  // builds).
  void Unref(const char* reason = "smartptr") {
#ifndef NDEBUG
    grpc_stream_unref(&stream_refcount_, reason);
#else
    (void)reason;
    grpc_stream_unref(&stream_refcount_);
#endif
  }

  // Takes a new ref and wraps it in a RefCountedPtr (released through Unref()).
  RefCountedPtr<ConnectedChannelStream> InternalRef() {
    IncrementRefCount("smartptr");
    return RefCountedPtr<ConnectedChannelStream>(this);
  }

  // Called when the owning OrphanablePtr lets go. If the stream was not yet
  // marked finished, spawns a task on party_ that marks it finished and
  // cancels the stream at the transport. Then drops the initial ref.
  void Orphan() final {
    bool finished = finished_.IsSet();
    if (grpc_call_trace.enabled()) {
      gpr_log(GPR_DEBUG, "%s[connected] Orphan stream, finished: %d",
              party_->DebugTag().c_str(), finished);
    }
    // If we hadn't already observed the stream to be finished, we need to
    // cancel it at the transport.
    if (!finished) {
      party_->Spawn(
          "finish",
          [self = InternalRef()]() {
            if (!self->finished_.IsSet()) {
              self->finished_.Set();
            }
            return Empty{};
          },
          [](Empty) {});
      GetContext<BatchBuilder>()->Cancel(batch_target(),
                                         absl::CancelledError());
    }
    Unref("orphan connected stream");
  }

  // Returns a promise that implements the receive message loop.
  auto RecvMessages(PipeSender<MessageHandle>* incoming_messages,
                    bool cancel_on_error);
  // Returns a promise that implements the send message loop.
  auto SendMessages(PipeReceiver<MessageHandle>* outgoing_messages);

  // Takes ownership of the transport stream storage; it is destroyed through
  // StreamDeleter when reset.
  void SetStream(grpc_stream* stream) { stream_.reset(stream); }
  grpc_stream* stream() { return stream_.get(); }
  grpc_stream_refcount* stream_refcount() { return &stream_refcount_; }

  // Marks the call finished; the promise from WaitFinished() resolves after.
  void set_finished() { finished_.Set(); }
  auto WaitFinished() { return finished_.Wait(); }

 private:
  // unique_ptr deleter that asks the transport to destroy the stream, with
  // stream_destroyed_closure() as the completion closure.
  class StreamDeleter {
   public:
    explicit StreamDeleter(ConnectedChannelStream* impl) : impl_(impl) {}
    void operator()(grpc_stream* stream) const {
      if (stream == nullptr) return;
      grpc_transport_destroy_stream(impl_->transport(), stream,
                                    impl_->stream_destroyed_closure());
    }

   private:
    ConnectedChannelStream* impl_;
  };
  using StreamPtr = std::unique_ptr<grpc_stream, StreamDeleter>;

  // Final teardown: runs the destructor in place inside the call context.
  // Storage is not released here (instances are created with Arena::New).
  void StreamDestroyed() {
    call_context_->RunInContext([this] { this->~ConnectedChannelStream(); });
  }

  // Destroy callback of stream_refcount_. With a transport stream present,
  // resetting stream_ destroys it and stream_destroyed_ finishes teardown;
  // otherwise teardown happens immediately.
  void BeginDestroy() {
    if (stream_ != nullptr) {
      stream_.reset();
    } else {
      StreamDestroyed();
    }
  }

  grpc_transport* const transport_;
  // Ref on the call context, used by StreamDestroyed().
  RefCountedPtr<CallContext> const call_context_{
      GetContext<CallContext>()->Ref()};
  // Completion closure for grpc_transport_destroy_stream().
  grpc_closure stream_destroyed_ =
      MakeMemberClosure<ConnectedChannelStream,
                        &ConnectedChannelStream::StreamDestroyed>(
          this, DEBUG_LOCATION);
  // Refcount shared with the transport (passed to init_stream and batches).
  grpc_stream_refcount stream_refcount_;
  StreamPtr stream_;
  // NOTE(review): arena_ is not referenced by any member of this class;
  // confirm whether it is still needed.
  Arena* arena_ = GetContext<Arena>();
  // Party current at construction; Orphan() spawns its cleanup task here.
  Party* const party_ = static_cast<Party*>(Activity::current());
  // Set once the call is known to be finished (set_finished() or Orphan()).
  ExternallyObservableLatch<void> finished_;
};
375
// Returns a promise implementing the receive-message loop for this stream.
//
// The PipeSender is moved out of `*incoming_messages` into the loop. Each
// iteration asks the transport (via BatchBuilder) for one message:
//   - if a message arrived, it is pushed into the pipe; the loop continues if
//     the push succeeds and ends with OkStatus if it fails;
//   - otherwise (end of stream, or an error) the loop ends with the receive
//     status, after first closing the pipe with an error when
//     `cancel_on_error` is set and the status is not OK.
auto ConnectedChannelStream::RecvMessages(
    PipeSender<MessageHandle>* incoming_messages, bool cancel_on_error) {
  return Loop([self = InternalRef(), cancel_on_error,
               incoming_messages = std::move(*incoming_messages)]() mutable {
    return Seq(
        GetContext<BatchBuilder>()->ReceiveMessage(self->batch_target()),
        [cancel_on_error, &incoming_messages](
            absl::StatusOr<absl::optional<MessageHandle>> status) mutable {
          bool has_message = status.ok() && status->has_value();
          // Both factories below capture `status` by reference; this assumes
          // If() invokes the selected factory before this lambda returns —
          // TODO(review) confirm against If()'s implementation.
          auto publish_message = [&incoming_messages, &status]() {
            auto pending_message = std::move(**status);
            // NOTE(review): PRIdPTR assumes Length() yields a pointer-sized
            // integer; confirm the type.
            if (grpc_call_trace.enabled()) {
              gpr_log(GPR_INFO,
                      "%s[connected] RecvMessage: received payload of %" PRIdPTR
                      " bytes",
                      Activity::current()->DebugTag().c_str(),
                      pending_message->payload()->Length());
            }
            // A failed push ends the loop with OkStatus; success continues.
            return Map(incoming_messages.Push(std::move(pending_message)),
                       [](bool ok) -> LoopCtl<absl::Status> {
                         if (!ok) {
                           if (grpc_call_trace.enabled()) {
                             gpr_log(GPR_INFO,
                                     "%s[connected] RecvMessage: failed to "
                                     "push message towards the application",
                                     Activity::current()->DebugTag().c_str());
                           }
                           return absl::OkStatus();
                         }
                         return Continue{};
                       });
          };
          auto publish_close = [cancel_on_error, &incoming_messages,
                                &status]() mutable {
            if (grpc_call_trace.enabled()) {
              gpr_log(GPR_INFO,
                      "%s[connected] RecvMessage: reached end of stream with "
                      "status:%s",
                      Activity::current()->DebugTag().c_str(),
                      status.status().ToString().c_str());
            }
            if (cancel_on_error && !status.ok()) {
              incoming_messages.CloseWithError();
            }
            // End the loop with the receive status (OK at end of stream).
            return Immediate(LoopCtl<absl::Status>(status.status()));
          };
          return If(has_message, std::move(publish_message),
                    std::move(publish_close));
        });
  });
}
427
SendMessages(PipeReceiver<MessageHandle> * outgoing_messages)428 auto ConnectedChannelStream::SendMessages(
429 PipeReceiver<MessageHandle>* outgoing_messages) {
430 return ForEach(std::move(*outgoing_messages),
431 [self = InternalRef()](MessageHandle message) {
432 return GetContext<BatchBuilder>()->SendMessage(
433 self->batch_target(), std::move(message));
434 });
435 }
436 #endif // defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL) ||
437 // defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL)
438
439 #ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL
MakeClientCallPromise(grpc_transport * transport,CallArgs call_args,NextPromiseFactory)440 ArenaPromise<ServerMetadataHandle> MakeClientCallPromise(
441 grpc_transport* transport, CallArgs call_args, NextPromiseFactory) {
442 OrphanablePtr<ConnectedChannelStream> stream(
443 GetContext<Arena>()->New<ConnectedChannelStream>(transport));
444 stream->SetStream(static_cast<grpc_stream*>(
445 GetContext<Arena>()->Alloc(transport->vtable->sizeof_stream)));
446 grpc_transport_init_stream(transport, stream->stream(),
447 stream->stream_refcount(), nullptr,
448 GetContext<Arena>());
449 auto* party = static_cast<Party*>(Activity::current());
450 party->Spawn(
451 "set_polling_entity", call_args.polling_entity->Wait(),
452 [transport,
453 stream = stream->InternalRef()](grpc_polling_entity polling_entity) {
454 grpc_transport_set_pops(transport, stream->stream(), &polling_entity);
455 });
456 // Start a loop to send messages from client_to_server_messages to the
457 // transport. When the pipe closes and the loop completes, send a trailing
458 // metadata batch to close the stream.
459 party->Spawn(
460 "send_messages",
461 TrySeq(stream->SendMessages(call_args.client_to_server_messages),
462 [stream = stream->InternalRef()]() {
463 return GetContext<BatchBuilder>()->SendClientTrailingMetadata(
464 stream->batch_target());
465 }),
466 [](absl::Status) {});
467 // Start a promise to receive server initial metadata and then forward it up
468 // through the receiving pipe.
469 auto server_initial_metadata =
470 GetContext<Arena>()->MakePooled<ServerMetadata>(GetContext<Arena>());
471 party->Spawn(
472 "recv_initial_metadata",
473 TrySeq(GetContext<BatchBuilder>()->ReceiveServerInitialMetadata(
474 stream->batch_target()),
475 [pipe = call_args.server_initial_metadata](
476 ServerMetadataHandle server_initial_metadata) {
477 if (grpc_call_trace.enabled()) {
478 gpr_log(GPR_DEBUG,
479 "%s[connected] Publish client initial metadata: %s",
480 Activity::current()->DebugTag().c_str(),
481 server_initial_metadata->DebugString().c_str());
482 }
483 return Map(pipe->Push(std::move(server_initial_metadata)),
484 [](bool r) {
485 if (r) return absl::OkStatus();
486 return absl::CancelledError();
487 });
488 }),
489 [](absl::Status) {});
490
491 // Build up the rest of the main call promise:
492
493 // Create a promise that will send initial metadata and then signal completion
494 // of that via the token.
495 auto send_initial_metadata = Seq(
496 GetContext<BatchBuilder>()->SendClientInitialMetadata(
497 stream->batch_target(), std::move(call_args.client_initial_metadata)),
498 [sent_initial_metadata_token =
499 std::move(call_args.client_initial_metadata_outstanding)](
500 absl::Status status) mutable {
501 sent_initial_metadata_token.Complete(status.ok());
502 return status;
503 });
504 // Create a promise that will receive server trailing metadata.
505 // If this fails, we massage the error into metadata that we can report
506 // upwards.
507 auto server_trailing_metadata =
508 GetContext<Arena>()->MakePooled<ServerMetadata>(GetContext<Arena>());
509 auto recv_trailing_metadata =
510 Map(GetContext<BatchBuilder>()->ReceiveServerTrailingMetadata(
511 stream->batch_target()),
512 [](absl::StatusOr<ServerMetadataHandle> status) mutable {
513 if (!status.ok()) {
514 auto server_trailing_metadata =
515 GetContext<Arena>()->MakePooled<ServerMetadata>(
516 GetContext<Arena>());
517 grpc_status_code status_code = GRPC_STATUS_UNKNOWN;
518 std::string message;
519 grpc_error_get_status(status.status(), Timestamp::InfFuture(),
520 &status_code, &message, nullptr, nullptr);
521 server_trailing_metadata->Set(GrpcStatusMetadata(), status_code);
522 server_trailing_metadata->Set(GrpcMessageMetadata(),
523 Slice::FromCopiedString(message));
524 return server_trailing_metadata;
525 } else {
526 return std::move(*status);
527 }
528 });
529 // Finally the main call promise.
530 // Concurrently: send initial metadata and receive messages, until BOTH
531 // complete (or one fails).
532 // Next: receive trailing metadata, and return that up the stack.
533 auto recv_messages =
534 stream->RecvMessages(call_args.server_to_client_messages, false);
535 return Map(
536 [send_initial_metadata = std::move(send_initial_metadata),
537 recv_messages = std::move(recv_messages),
538 recv_trailing_metadata = std::move(recv_trailing_metadata),
539 done_send_initial_metadata = false, done_recv_messages = false,
540 done_recv_trailing_metadata =
541 false]() mutable -> Poll<ServerMetadataHandle> {
542 if (!done_send_initial_metadata) {
543 auto p = send_initial_metadata();
544 if (auto* r = p.value_if_ready()) {
545 done_send_initial_metadata = true;
546 if (!r->ok()) return StatusCast<ServerMetadataHandle>(*r);
547 }
548 }
549 if (!done_recv_messages) {
550 auto p = recv_messages();
551 if (auto* r = p.value_if_ready()) {
552 // NOTE: ignore errors here, they'll be collected in the
553 // recv_trailing_metadata.
554 done_recv_messages = true;
555 } else {
556 return Pending{};
557 }
558 }
559 if (!done_recv_trailing_metadata) {
560 auto p = recv_trailing_metadata();
561 if (auto* r = p.value_if_ready()) {
562 done_recv_trailing_metadata = true;
563 return std::move(*r);
564 }
565 }
566 return Pending{};
567 },
568 [stream = std::move(stream)](ServerMetadataHandle result) {
569 stream->set_finished();
570 return result;
571 });
572 }
573 #endif
574
575 #ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL
MakeServerCallPromise(grpc_transport * transport,CallArgs,NextPromiseFactory next_promise_factory)576 ArenaPromise<ServerMetadataHandle> MakeServerCallPromise(
577 grpc_transport* transport, CallArgs,
578 NextPromiseFactory next_promise_factory) {
579 OrphanablePtr<ConnectedChannelStream> stream(
580 GetContext<Arena>()->New<ConnectedChannelStream>(transport));
581
582 stream->SetStream(static_cast<grpc_stream*>(
583 GetContext<Arena>()->Alloc(transport->vtable->sizeof_stream)));
584 grpc_transport_init_stream(
585 transport, stream->stream(), stream->stream_refcount(),
586 GetContext<CallContext>()->server_call_context()->server_stream_data(),
587 GetContext<Arena>());
588 auto* party = static_cast<Party*>(Activity::current());
589
590 // Arifacts we need for the lifetime of the call.
591 struct CallData {
592 Pipe<MessageHandle> server_to_client;
593 Pipe<MessageHandle> client_to_server;
594 Pipe<ServerMetadataHandle> server_initial_metadata;
595 Latch<ServerMetadataHandle> failure_latch;
596 Latch<grpc_polling_entity> polling_entity_latch;
597 bool sent_initial_metadata = false;
598 bool sent_trailing_metadata = false;
599 };
600 auto* call_data = GetContext<Arena>()->ManagedNew<CallData>();
601
602 party->Spawn(
603 "set_polling_entity", call_data->polling_entity_latch.Wait(),
604 [transport,
605 stream = stream->InternalRef()](grpc_polling_entity polling_entity) {
606 grpc_transport_set_pops(transport, stream->stream(), &polling_entity);
607 });
608
609 auto server_to_client_empty =
610 call_data->server_to_client.receiver.AwaitEmpty();
611
612 // Create a promise that will receive client initial metadata, and then run
613 // the main stem of the call (calling next_promise_factory up through the
614 // filters).
615 // Race the main call with failure_latch, allowing us to forcefully complete
616 // the call in the case of a failure.
617 auto recv_initial_metadata_then_run_promise =
618 TrySeq(GetContext<BatchBuilder>()->ReceiveClientInitialMetadata(
619 stream->batch_target()),
620 [next_promise_factory = std::move(next_promise_factory),
621 server_to_client_empty = std::move(server_to_client_empty),
622 call_data](ClientMetadataHandle client_initial_metadata) {
623 auto call_promise = next_promise_factory(CallArgs{
624 std::move(client_initial_metadata),
625 ClientInitialMetadataOutstandingToken::Empty(),
626 &call_data->polling_entity_latch,
627 &call_data->server_initial_metadata.sender,
628 &call_data->client_to_server.receiver,
629 &call_data->server_to_client.sender,
630 });
631 return Race(call_data->failure_latch.Wait(),
632 [call_promise = std::move(call_promise),
633 server_to_client_empty =
634 std::move(server_to_client_empty)]() mutable
635 -> Poll<ServerMetadataHandle> {
636 // TODO(ctiller): this is deeply weird and we need
637 // to clean this up.
638 //
639 // The following few lines check to ensure that
640 // there's no message currently pending in the
641 // outgoing message queue, and if (and only if)
642 // that's true decides to poll the main promise to
643 // see if there's a result.
644 //
645 // This essentially introduces a polling priority
646 // scheme that makes the current promise structure
647 // work out the way we want when talking to
648 // transports.
649 //
650 // The problem is that transports are going to need
651 // to replicate this structure when they convert to
652 // promises, and that becomes troubling as we'll be
653 // replicating weird throughout the stack.
654 //
655 // Instead we likely need to change the way we're
656 // composing promises through the stack.
657 //
658 // Proposed is to change filters from a promise
659 // that takes ClientInitialMetadata and returns
660 // ServerTrailingMetadata with three pipes for
661 // ServerInitialMetadata and
662 // ClientToServerMessages, ServerToClientMessages.
663 // Instead we'll have five pipes, moving
664 // ClientInitialMetadata and ServerTrailingMetadata
665 // to pipes that can be intercepted.
666 //
667 // The effect of this change will be to cripple the
668 // things that can be done in a filter (but cripple
669 // in line with what most filters actually do).
670 // We'll likely need to add a `CallContext::Cancel`
671 // to allow filters to cancel a request, but this
672 // would also have the advantage of centralizing
673 // our cancellation machinery which seems like an
674 // additional win - with the net effect that the
675 // shape of the call gets made explicit at the top
676 // & bottom of the stack.
677 //
678 // There's a small set of filters (retry, this one,
679 // lame client, clinet channel) that terminate
680 // stacks and need a richer set of semantics, but
681 // that ends up being fine because we can spawn
682 // tasks in parties to handle those edge cases, and
683 // keep the majority of filters simple: they just
684 // call InterceptAndMap on a handful of filters at
685 // call initialization time and then proceed to
686 // actually filter.
687 //
688 // So that's the plan, why isn't it enacted here?
689 //
690 // Well, the plan ends up being easy to implement
691 // in the promise based world (I did a prototype on
692 // a branch in an afternoon). It's heinous to
693 // implement in promise_based_filter, and that code
694 // is load bearing for us at the time of writing.
695 // It's not worth delaying promises for a further N
696 // months (N ~ 6) to make that change.
697 //
698 // Instead, we'll move forward with this, get
699 // promise_based_filter out of the picture, and
700 // then during the mop-up phase for promises tweak
701 // the compute structure to move to the magical
702 // five pipes (I'm reminded of an old Onion
703 // article), and end up in a good happy place.
704 if (server_to_client_empty().pending()) {
705 return Pending{};
706 }
707 return call_promise();
708 });
709 });
710
711 // Promise factory that accepts a ServerMetadataHandle, and sends it as the
712 // trailing metadata for this call.
713 auto send_trailing_metadata = [call_data, stream = stream->InternalRef()](
714 ServerMetadataHandle
715 server_trailing_metadata) {
716 bool is_cancellation =
717 server_trailing_metadata->get(GrpcCallWasCancelled()).value_or(false);
718 return GetContext<BatchBuilder>()->SendServerTrailingMetadata(
719 stream->batch_target(), std::move(server_trailing_metadata),
720 is_cancellation ||
721 !std::exchange(call_data->sent_initial_metadata, true));
722 };
723
724 // Runs the receive message loop, either until all the messages
725 // are received or the server call is complete.
726 party->Spawn(
727 "recv_messages",
728 Race(
729 Map(stream->WaitFinished(), [](Empty) { return absl::OkStatus(); }),
730 Map(stream->RecvMessages(&call_data->client_to_server.sender, true),
731 [failure_latch = &call_data->failure_latch](absl::Status status) {
732 if (!status.ok() && !failure_latch->is_set()) {
733 failure_latch->Set(ServerMetadataFromStatus(status));
734 }
735 return status;
736 })),
737 [](absl::Status) {});
738
739 // Run a promise that will send initial metadata (if that pipe sends some).
740 // And then run the send message loop until that completes.
741
742 auto send_initial_metadata = Seq(
743 Race(Map(stream->WaitFinished(),
744 [](Empty) { return NextResult<ServerMetadataHandle>(true); }),
745 call_data->server_initial_metadata.receiver.Next()),
746 [call_data, stream = stream->InternalRef()](
747 NextResult<ServerMetadataHandle> next_result) mutable {
748 auto md = !call_data->sent_initial_metadata && next_result.has_value()
749 ? std::move(next_result.value())
750 : nullptr;
751 if (md != nullptr) {
752 call_data->sent_initial_metadata = true;
753 auto* party = static_cast<Party*>(Activity::current());
754 party->Spawn("connected/send_initial_metadata",
755 GetContext<BatchBuilder>()->SendServerInitialMetadata(
756 stream->batch_target(), std::move(md)),
757 [](absl::Status) {});
758 return Immediate(absl::OkStatus());
759 }
760 return Immediate(absl::CancelledError());
761 });
762 party->Spawn(
763 "send_initial_metadata_then_messages",
764 Race(Map(stream->WaitFinished(), [](Empty) { return absl::OkStatus(); }),
765 TrySeq(std::move(send_initial_metadata),
766 stream->SendMessages(&call_data->server_to_client.receiver))),
767 [](absl::Status) {});
768
769 // Spawn a job to fetch the "client trailing metadata" - if this is OK then
770 // it's client done, otherwise it's a signal of cancellation from the client
771 // which we'll use failure_latch to signal.
772
773 party->Spawn(
774 "recv_trailing_metadata",
775 Seq(GetContext<BatchBuilder>()->ReceiveClientTrailingMetadata(
776 stream->batch_target()),
777 [failure_latch = &call_data->failure_latch](
778 absl::StatusOr<ClientMetadataHandle> status) mutable {
779 if (grpc_call_trace.enabled()) {
780 gpr_log(
781 GPR_DEBUG,
782 "%s[connected] Got trailing metadata; status=%s metadata=%s",
783 Activity::current()->DebugTag().c_str(),
784 status.status().ToString().c_str(),
785 status.ok() ? (*status)->DebugString().c_str() : "<none>");
786 }
787 ClientMetadataHandle trailing_metadata;
788 if (status.ok()) {
789 trailing_metadata = std::move(*status);
790 } else {
791 trailing_metadata =
792 GetContext<Arena>()->MakePooled<ClientMetadata>(
793 GetContext<Arena>());
794 grpc_status_code status_code = GRPC_STATUS_UNKNOWN;
795 std::string message;
796 grpc_error_get_status(status.status(), Timestamp::InfFuture(),
797 &status_code, &message, nullptr, nullptr);
798 trailing_metadata->Set(GrpcStatusMetadata(), status_code);
799 trailing_metadata->Set(GrpcMessageMetadata(),
800 Slice::FromCopiedString(message));
801 }
802 if (trailing_metadata->get(GrpcStatusMetadata())
803 .value_or(GRPC_STATUS_UNKNOWN) != GRPC_STATUS_OK) {
804 if (!failure_latch->is_set()) {
805 failure_latch->Set(std::move(trailing_metadata));
806 }
807 }
808 return Empty{};
809 }),
810 [](Empty) {});
811
812 // Finally assemble the main call promise:
813 // Receive initial metadata from the client and start the promise up the
814 // filter stack.
815 // Upon completion, send trailing metadata to the client and then return it
816 // (allowing the call code to decide on what signalling to give the
817 // application).
818
819 struct CleanupPollingEntityLatch {
820 void operator()(Latch<grpc_polling_entity>* latch) {
821 if (!latch->is_set()) latch->Set(grpc_polling_entity());
822 }
823 };
824 auto cleanup_polling_entity_latch =
825 std::unique_ptr<Latch<grpc_polling_entity>, CleanupPollingEntityLatch>(
826 &call_data->polling_entity_latch);
827 struct CleanupSendInitialMetadata {
828 void operator()(CallData* call_data) {
829 call_data->server_initial_metadata.receiver.CloseWithError();
830 }
831 };
832 auto cleanup_send_initial_metadata =
833 std::unique_ptr<CallData, CleanupSendInitialMetadata>(call_data);
834
835 return Map(
836 Seq(std::move(recv_initial_metadata_then_run_promise),
837 std::move(send_trailing_metadata)),
838 [cleanup_polling_entity_latch = std::move(cleanup_polling_entity_latch),
839 cleanup_send_initial_metadata = std::move(cleanup_send_initial_metadata),
840 stream = std::move(stream)](ServerMetadataHandle md) {
841 stream->set_finished();
842 return md;
843 });
844 }
845 #endif
846
847 template <ArenaPromise<ServerMetadataHandle> (*make_call_promise)(
848 grpc_transport*, CallArgs, NextPromiseFactory)>
MakeConnectedFilter()849 grpc_channel_filter MakeConnectedFilter() {
850 // Create a vtable that contains both the legacy call methods (for filter
851 // stack based calls) and the new promise based method for creating
852 // promise based calls (the latter iff make_call_promise != nullptr). In
853 // this way the filter can be inserted into either kind of channel stack,
854 // and only if all the filters in the stack are promise based will the
855 // call be promise based.
856 auto make_call_wrapper = +[](grpc_channel_element* elem, CallArgs call_args,
857 NextPromiseFactory next) {
858 grpc_transport* transport =
859 static_cast<channel_data*>(elem->channel_data)->transport;
860 return make_call_promise(transport, std::move(call_args), std::move(next));
861 };
862 return {
863 connected_channel_start_transport_stream_op_batch,
864 make_call_promise != nullptr ? make_call_wrapper : nullptr,
865 connected_channel_start_transport_op,
866 sizeof(call_data),
867 connected_channel_init_call_elem,
868 set_pollset_or_pollset_set,
869 connected_channel_destroy_call_elem,
870 sizeof(channel_data),
871 connected_channel_init_channel_elem,
872 +[](grpc_channel_stack* channel_stack, grpc_channel_element* elem) {
873 // HACK(ctiller): increase call stack size for the channel to make
874 // space for channel data. We need a cleaner (but performant) way to
875 // do this, and I'm not sure what that is yet. This is only "safe"
876 // because call stacks place no additional data after the last call
877 // element, and the last call element MUST be the connected channel.
878 channel_stack->call_stack_size += grpc_transport_stream_size(
879 static_cast<channel_data*>(elem->channel_data)->transport);
880 },
881 connected_channel_destroy_channel_elem,
882 connected_channel_get_channel_info,
883 "connected",
884 };
885 }
886
MakeTransportCallPromise(grpc_transport * transport,CallArgs call_args,NextPromiseFactory)887 ArenaPromise<ServerMetadataHandle> MakeTransportCallPromise(
888 grpc_transport* transport, CallArgs call_args, NextPromiseFactory) {
889 return transport->vtable->make_call_promise(transport, std::move(call_args));
890 }
891
892 const grpc_channel_filter kPromiseBasedTransportFilter =
893 MakeConnectedFilter<MakeTransportCallPromise>();
894
895 #ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL
896 const grpc_channel_filter kClientEmulatedFilter =
897 MakeConnectedFilter<MakeClientCallPromise>();
898 #else
899 const grpc_channel_filter kClientEmulatedFilter =
900 MakeConnectedFilter<nullptr>();
901 #endif
902
903 #ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL
904 const grpc_channel_filter kServerEmulatedFilter =
905 MakeConnectedFilter<MakeServerCallPromise>();
906 #else
907 const grpc_channel_filter kServerEmulatedFilter =
908 MakeConnectedFilter<nullptr>();
909 #endif
910
911 } // namespace
912 } // namespace grpc_core
913
grpc_add_connected_filter(grpc_core::ChannelStackBuilder * builder)914 bool grpc_add_connected_filter(grpc_core::ChannelStackBuilder* builder) {
915 grpc_transport* t = builder->transport();
916 GPR_ASSERT(t != nullptr);
917 // Choose the right vtable for the connected filter.
918 // We can't know promise based call or not here (that decision needs the
919 // collaboration of all of the filters on the channel, and we don't want
920 // ordering constraints on when we add filters).
921 // We can know if this results in a promise based call how we'll create
922 // our promise (if indeed we can), and so that is the choice made here.
923 if (t->vtable->make_call_promise != nullptr) {
924 // Option 1, and our ideal: the transport supports promise based calls,
925 // and so we simply use the transport directly.
926 builder->AppendFilter(&grpc_core::kPromiseBasedTransportFilter);
927 } else if (grpc_channel_stack_type_is_client(builder->channel_stack_type())) {
928 // Option 2: the transport does not support promise based calls, but
929 // we're on the client and so we have an implementation that we can use
930 // to convert to batches.
931 builder->AppendFilter(&grpc_core::kClientEmulatedFilter);
932 } else {
933 // Option 3: the transport does not support promise based calls, and
934 // we're on the server so we use the server filter.
935 builder->AppendFilter(&grpc_core::kServerEmulatedFilter);
936 }
937 return true;
938 }
939