//
//
// Copyright 2015 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//

#include <grpc/support/port_platform.h>

#include "src/core/lib/channel/connected_channel.h"

#include <inttypes.h>

#include <functional>
#include <initializer_list>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"

#include <grpc/grpc.h>
#include <grpc/status.h>
#include <grpc/support/alloc.h>
#include <grpc/support/log.h>

#include "src/core/lib/channel/channel_args.h"
#include "src/core/lib/channel/channel_fwd.h"
#include "src/core/lib/channel/channel_stack.h"
#include "src/core/lib/debug/trace.h"
#include "src/core/lib/experiments/experiments.h"
#include "src/core/lib/gpr/alloc.h"
#include "src/core/lib/gprpp/debug_location.h"
#include "src/core/lib/gprpp/orphanable.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/gprpp/time.h"
#include "src/core/lib/iomgr/call_combiner.h"
#include "src/core/lib/iomgr/closure.h"
#include "src/core/lib/iomgr/error.h"
#include "src/core/lib/iomgr/polling_entity.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/arena_promise.h"
#include "src/core/lib/promise/context.h"
#include "src/core/lib/promise/detail/basic_seq.h"
#include "src/core/lib/promise/detail/status.h"
#include "src/core/lib/promise/for_each.h"
#include "src/core/lib/promise/if.h"
#include "src/core/lib/promise/latch.h"
#include "src/core/lib/promise/loop.h"
#include "src/core/lib/promise/map.h"
#include "src/core/lib/promise/party.h"
#include "src/core/lib/promise/pipe.h"
#include "src/core/lib/promise/poll.h"
#include "src/core/lib/promise/promise.h"
#include "src/core/lib/promise/race.h"
#include "src/core/lib/promise/seq.h"
#include "src/core/lib/promise/try_seq.h"
#include "src/core/lib/resource_quota/arena.h"
#include "src/core/lib/slice/slice.h"
#include "src/core/lib/slice/slice_buffer.h"
#include "src/core/lib/surface/call.h"
#include "src/core/lib/surface/call_trace.h"
#include "src/core/lib/surface/channel_stack_type.h"
#include "src/core/lib/transport/batch_builder.h"
#include "src/core/lib/transport/error_utils.h"
#include "src/core/lib/transport/metadata_batch.h"
#include "src/core/lib/transport/transport.h"
#include "src/core/lib/transport/transport_fwd.h"
#include "src/core/lib/transport/transport_impl.h"

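// Channel data for the connected channel: holds the transport that terminates
// this channel stack.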
typedef struct connected_channel_channel_data {
  grpc_transport* transport;
} channel_data;

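// State used to re-enter the call combiner when an intercepted closure runs:
// `closure` is handed to the transport, and when it fires it schedules
// `original_closure` on `call_combiner` with the given `reason`.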
struct callback_state {
  grpc_closure closure;
  grpc_closure* original_closure;
  grpc_core::CallCombiner* call_combiner;
  const char* reason;
};
typedef struct connected_channel_call_data {
  grpc_core::CallCombiner* call_combiner;
  // Closures used for returning results on the call combiner.
  callback_state on_complete[6];  // Max number of pending batches.
  callback_state recv_initial_metadata_ready;
  callback_state recv_message_ready;
  callback_state recv_trailing_metadata_ready;
} call_data;

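// Trampoline that re-enters the call combiner with the original closure.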
static void run_in_call_combiner(void* arg, grpc_error_handle error) {
  callback_state* state = static_cast<callback_state*>(arg);
  GRPC_CALL_COMBINER_START(state->call_combiner, state->original_closure, error,
                           state->reason);
}

static void run_cancel_in_call_combiner(void* arg, grpc_error_handle error) {
  run_in_call_combiner(arg, error);
  gpr_free(arg);
}

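// Replaces *original_closure with a wrapper that re-enters the call combiner.
// If free_when_done is true, the wrapper state is freed after it runs (used
// for the heap-allocated cancellation states below).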
static void intercept_callback(call_data* calld, callback_state* state,
                               bool free_when_done, const char* reason,
                               grpc_closure** original_closure) {
  state->original_closure = *original_closure;
  state->call_combiner = calld->call_combiner;
  state->reason = reason;
  *original_closure = GRPC_CLOSURE_INIT(
      &state->closure,
      free_when_done ? run_cancel_in_call_combiner : run_in_call_combiner,
      state, grpc_schedule_on_exec_ctx);
}

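// Returns the on_complete slot keyed off the first op present in this batch.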
static callback_state* get_state_for_batch(
    call_data* calld, grpc_transport_stream_op_batch* batch) {
  if (batch->send_initial_metadata) return &calld->on_complete[0];
  if (batch->send_message) return &calld->on_complete[1];
  if (batch->send_trailing_metadata) return &calld->on_complete[2];
  if (batch->recv_initial_metadata) return &calld->on_complete[3];
  if (batch->recv_message) return &calld->on_complete[4];
  if (batch->recv_trailing_metadata) return &calld->on_complete[5];
  GPR_UNREACHABLE_CODE(return nullptr);
}

// We perform a small hack to locate transport data alongside the connected
// channel data in call allocations, to allow everything to be pulled in minimal
// cache line requests
#define TRANSPORT_STREAM_FROM_CALL_DATA(calld) \
  ((grpc_stream*)(((char*)(calld)) +           \
                  GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(call_data))))
#define CALL_DATA_FROM_TRANSPORT_STREAM(transport_stream) \
  ((call_data*)(((char*)(transport_stream)) -             \
                GPR_ROUND_UP_TO_ALIGNMENT_SIZE(sizeof(call_data))))
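// The resulting per-call layout is, illustratively:
//   [ call_data | alignment padding | transport's grpc_stream ]
// so a single allocation serves both the filter element and the transport
// stream (see the call_stack_size adjustment in MakeConnectedFilter below).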

// Intercept a call operation and either push it directly up or translate it
// into transport stream operations
static void connected_channel_start_transport_stream_op_batch(
    grpc_call_element* elem, grpc_transport_stream_op_batch* batch) {
  call_data* calld = static_cast<call_data*>(elem->call_data);
  channel_data* chand = static_cast<channel_data*>(elem->channel_data);
  if (batch->recv_initial_metadata) {
    callback_state* state = &calld->recv_initial_metadata_ready;
    intercept_callback(
        calld, state, false, "recv_initial_metadata_ready",
        &batch->payload->recv_initial_metadata.recv_initial_metadata_ready);
  }
  if (batch->recv_message) {
    callback_state* state = &calld->recv_message_ready;
    intercept_callback(calld, state, false, "recv_message_ready",
                       &batch->payload->recv_message.recv_message_ready);
  }
  if (batch->recv_trailing_metadata) {
    callback_state* state = &calld->recv_trailing_metadata_ready;
    intercept_callback(
        calld, state, false, "recv_trailing_metadata_ready",
        &batch->payload->recv_trailing_metadata.recv_trailing_metadata_ready);
  }
  if (batch->cancel_stream) {
    // There can be more than one cancellation batch in flight at any
    // given time, so we can't just pick out a fixed index into
    // calld->on_complete like we can for the other ops.  However,
    // cancellation isn't in the fast path, so we just allocate a new
    // closure for each one.
    callback_state* state =
        static_cast<callback_state*>(gpr_malloc(sizeof(*state)));
    intercept_callback(calld, state, true, "on_complete (cancel_stream)",
                       &batch->on_complete);
  } else if (batch->on_complete != nullptr) {
    callback_state* state = get_state_for_batch(calld, batch);
    intercept_callback(calld, state, false, "on_complete", &batch->on_complete);
  }
  grpc_transport_perform_stream_op(
      chand->transport, TRANSPORT_STREAM_FROM_CALL_DATA(calld), batch);
  GRPC_CALL_COMBINER_STOP(calld->call_combiner, "passed batch to transport");
}

static void connected_channel_start_transport_op(grpc_channel_element* elem,
                                                 grpc_transport_op* op) {
  channel_data* chand = static_cast<channel_data*>(elem->channel_data);
  grpc_transport_perform_op(chand->transport, op);
}

// Constructor for call_data
static grpc_error_handle connected_channel_init_call_elem(
    grpc_call_element* elem, const grpc_call_element_args* args) {
  call_data* calld = static_cast<call_data*>(elem->call_data);
  channel_data* chand = static_cast<channel_data*>(elem->channel_data);
  calld->call_combiner = args->call_combiner;
  int r = grpc_transport_init_stream(
      chand->transport, TRANSPORT_STREAM_FROM_CALL_DATA(calld),
      &args->call_stack->refcount, args->server_transport_data, args->arena);
  return r == 0 ? absl::OkStatus()
                : GRPC_ERROR_CREATE("transport stream initialization failed");
}

static void set_pollset_or_pollset_set(grpc_call_element* elem,
                                       grpc_polling_entity* pollent) {
  call_data* calld = static_cast<call_data*>(elem->call_data);
  channel_data* chand = static_cast<channel_data*>(elem->channel_data);
  grpc_transport_set_pops(chand->transport,
                          TRANSPORT_STREAM_FROM_CALL_DATA(calld), pollent);
}

// Destructor for call_data
static void connected_channel_destroy_call_elem(
    grpc_call_element* elem, const grpc_call_final_info* /*final_info*/,
    grpc_closure* then_schedule_closure) {
  call_data* calld = static_cast<call_data*>(elem->call_data);
  channel_data* chand = static_cast<channel_data*>(elem->channel_data);
  grpc_transport_destroy_stream(chand->transport,
                                TRANSPORT_STREAM_FROM_CALL_DATA(calld),
                                then_schedule_closure);
}

// Constructor for channel_data
static grpc_error_handle connected_channel_init_channel_elem(
    grpc_channel_element* elem, grpc_channel_element_args* args) {
  channel_data* cd = static_cast<channel_data*>(elem->channel_data);
  GPR_ASSERT(args->is_last);
  cd->transport = args->channel_args.GetObject<grpc_transport>();
  return absl::OkStatus();
}

// Destructor for channel_data
static void connected_channel_destroy_channel_elem(grpc_channel_element* elem) {
  channel_data* cd = static_cast<channel_data*>(elem->channel_data);
  if (cd->transport) {
    grpc_transport_destroy(cd->transport);
  }
}

// No-op.
static void connected_channel_get_channel_info(
    grpc_channel_element* /*elem*/, const grpc_channel_info* /*channel_info*/) {
}

namespace grpc_core {
namespace {

#if defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL) || \
    defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL)
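// Wraps the transport's grpc_stream for the promise-based call codepaths
// below. Lifetime is driven by the transport's stream refcount: the final
// unref triggers BeginDestroy(), which destroys the stream and then this
// object in the call's context.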
class ConnectedChannelStream : public Orphanable {
 public:
  explicit ConnectedChannelStream(grpc_transport* transport)
      : transport_(transport), stream_(nullptr, StreamDeleter(this)) {
    GRPC_STREAM_REF_INIT(
        &stream_refcount_, 1,
        [](void* p, grpc_error_handle) {
          static_cast<ConnectedChannelStream*>(p)->BeginDestroy();
        },
        this, "ConnectedChannelStream");
  }

  grpc_transport* transport() { return transport_; }
  grpc_closure* stream_destroyed_closure() { return &stream_destroyed_; }

  BatchBuilder::Target batch_target() {
    return BatchBuilder::Target{transport_, stream_.get(), &stream_refcount_};
  }

  void IncrementRefCount(const char* reason = "smartptr") {
#ifndef NDEBUG
    grpc_stream_ref(&stream_refcount_, reason);
#else
    (void)reason;
    grpc_stream_ref(&stream_refcount_);
#endif
  }

  void Unref(const char* reason = "smartptr") {
#ifndef NDEBUG
    grpc_stream_unref(&stream_refcount_, reason);
#else
    (void)reason;
    grpc_stream_unref(&stream_refcount_);
#endif
  }

  RefCountedPtr<ConnectedChannelStream> InternalRef() {
    IncrementRefCount("smartptr");
    return RefCountedPtr<ConnectedChannelStream>(this);
  }

  void Orphan() final {
    bool finished = finished_.IsSet();
    if (grpc_call_trace.enabled()) {
      gpr_log(GPR_DEBUG, "%s[connected] Orphan stream, finished: %d",
              party_->DebugTag().c_str(), finished);
    }
    // If we hadn't already observed the stream to be finished, we need to
    // cancel it at the transport.
    if (!finished) {
      party_->Spawn(
          "finish",
          [self = InternalRef()]() {
            if (!self->finished_.IsSet()) {
              self->finished_.Set();
            }
            return Empty{};
          },
          [](Empty) {});
      GetContext<BatchBuilder>()->Cancel(batch_target(),
                                         absl::CancelledError());
    }
    Unref("orphan connected stream");
  }

  // Returns a promise that implements the receive message loop.
  auto RecvMessages(PipeSender<MessageHandle>* incoming_messages,
                    bool cancel_on_error);
  // Returns a promise that implements the send message loop.
  auto SendMessages(PipeReceiver<MessageHandle>* outgoing_messages);

  void SetStream(grpc_stream* stream) { stream_.reset(stream); }
  grpc_stream* stream() { return stream_.get(); }
  grpc_stream_refcount* stream_refcount() { return &stream_refcount_; }

  void set_finished() { finished_.Set(); }
  auto WaitFinished() { return finished_.Wait(); }

 private:
  class StreamDeleter {
   public:
    explicit StreamDeleter(ConnectedChannelStream* impl) : impl_(impl) {}
    void operator()(grpc_stream* stream) const {
      if (stream == nullptr) return;
      grpc_transport_destroy_stream(impl_->transport(), stream,
                                    impl_->stream_destroyed_closure());
    }

   private:
    ConnectedChannelStream* impl_;
  };
  using StreamPtr = std::unique_ptr<grpc_stream, StreamDeleter>;

  void StreamDestroyed() {
    call_context_->RunInContext([this] { this->~ConnectedChannelStream(); });
  }

  void BeginDestroy() {
    if (stream_ != nullptr) {
      stream_.reset();
    } else {
      StreamDestroyed();
    }
  }

  grpc_transport* const transport_;
  RefCountedPtr<CallContext> const call_context_{
      GetContext<CallContext>()->Ref()};
  grpc_closure stream_destroyed_ =
      MakeMemberClosure<ConnectedChannelStream,
                        &ConnectedChannelStream::StreamDestroyed>(
          this, DEBUG_LOCATION);
  grpc_stream_refcount stream_refcount_;
  StreamPtr stream_;
  Arena* arena_ = GetContext<Arena>();
  Party* const party_ = static_cast<Party*>(Activity::current());
  ExternallyObservableLatch<void> finished_;
};

auto ConnectedChannelStream::RecvMessages(
    PipeSender<MessageHandle>* incoming_messages, bool cancel_on_error) {
  return Loop([self = InternalRef(), cancel_on_error,
               incoming_messages = std::move(*incoming_messages)]() mutable {
    return Seq(
        GetContext<BatchBuilder>()->ReceiveMessage(self->batch_target()),
        [cancel_on_error, &incoming_messages](
            absl::StatusOr<absl::optional<MessageHandle>> status) mutable {
          bool has_message = status.ok() && status->has_value();
          auto publish_message = [&incoming_messages, &status]() {
            auto pending_message = std::move(**status);
            if (grpc_call_trace.enabled()) {
              gpr_log(GPR_INFO,
                      "%s[connected] RecvMessage: received payload of %" PRIdPTR
                      " bytes",
                      Activity::current()->DebugTag().c_str(),
                      pending_message->payload()->Length());
            }
            return Map(incoming_messages.Push(std::move(pending_message)),
                       [](bool ok) -> LoopCtl<absl::Status> {
                         if (!ok) {
                           if (grpc_call_trace.enabled()) {
                             gpr_log(GPR_INFO,
                                     "%s[connected] RecvMessage: failed to "
                                     "push message towards the application",
                                     Activity::current()->DebugTag().c_str());
                           }
                           return absl::OkStatus();
                         }
                         return Continue{};
                       });
          };
          auto publish_close = [cancel_on_error, &incoming_messages,
                                &status]() mutable {
            if (grpc_call_trace.enabled()) {
              gpr_log(GPR_INFO,
                      "%s[connected] RecvMessage: reached end of stream with "
                      "status:%s",
                      Activity::current()->DebugTag().c_str(),
                      status.status().ToString().c_str());
            }
            if (cancel_on_error && !status.ok()) {
              incoming_messages.CloseWithError();
            }
            return Immediate(LoopCtl<absl::Status>(status.status()));
          };
          return If(has_message, std::move(publish_message),
                    std::move(publish_close));
        });
  });
}

auto ConnectedChannelStream::SendMessages(
    PipeReceiver<MessageHandle>* outgoing_messages) {
  return ForEach(std::move(*outgoing_messages),
                 [self = InternalRef()](MessageHandle message) {
                   return GetContext<BatchBuilder>()->SendMessage(
                       self->batch_target(), std::move(message));
                 });
}
#endif  // defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL) ||
        // defined(GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL)

#ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL
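// Client call promise for transports that only speak the batch API: sends and
// receives are translated into transport stream op batches via BatchBuilder,
// with per-direction loops spawned on the call's Party.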
ArenaPromise<ServerMetadataHandle> MakeClientCallPromise(
    grpc_transport* transport, CallArgs call_args, NextPromiseFactory) {
  OrphanablePtr<ConnectedChannelStream> stream(
      GetContext<Arena>()->New<ConnectedChannelStream>(transport));
  stream->SetStream(static_cast<grpc_stream*>(
      GetContext<Arena>()->Alloc(transport->vtable->sizeof_stream)));
  grpc_transport_init_stream(transport, stream->stream(),
                             stream->stream_refcount(), nullptr,
                             GetContext<Arena>());
  auto* party = static_cast<Party*>(Activity::current());
  party->Spawn(
      "set_polling_entity", call_args.polling_entity->Wait(),
      [transport,
       stream = stream->InternalRef()](grpc_polling_entity polling_entity) {
        grpc_transport_set_pops(transport, stream->stream(), &polling_entity);
      });
  // Start a loop to send messages from client_to_server_messages to the
  // transport. When the pipe closes and the loop completes, send a trailing
  // metadata batch to close the stream.
  party->Spawn(
      "send_messages",
      TrySeq(stream->SendMessages(call_args.client_to_server_messages),
             [stream = stream->InternalRef()]() {
               return GetContext<BatchBuilder>()->SendClientTrailingMetadata(
                   stream->batch_target());
             }),
      [](absl::Status) {});
  // Start a promise to receive server initial metadata and then forward it up
  // through the receiving pipe.
  auto server_initial_metadata =
      GetContext<Arena>()->MakePooled<ServerMetadata>(GetContext<Arena>());
  party->Spawn(
      "recv_initial_metadata",
      TrySeq(GetContext<BatchBuilder>()->ReceiveServerInitialMetadata(
                 stream->batch_target()),
             [pipe = call_args.server_initial_metadata](
                 ServerMetadataHandle server_initial_metadata) {
               if (grpc_call_trace.enabled()) {
                 gpr_log(GPR_DEBUG,
                         "%s[connected] Publish server initial metadata: %s",
                         Activity::current()->DebugTag().c_str(),
                         server_initial_metadata->DebugString().c_str());
               }
               return Map(pipe->Push(std::move(server_initial_metadata)),
                          [](bool r) {
                            if (r) return absl::OkStatus();
                            return absl::CancelledError();
                          });
             }),
      [](absl::Status) {});

  // Build up the rest of the main call promise:

  // Create a promise that will send initial metadata and then signal completion
  // of that via the token.
  auto send_initial_metadata = Seq(
      GetContext<BatchBuilder>()->SendClientInitialMetadata(
          stream->batch_target(), std::move(call_args.client_initial_metadata)),
      [sent_initial_metadata_token =
           std::move(call_args.client_initial_metadata_outstanding)](
          absl::Status status) mutable {
        sent_initial_metadata_token.Complete(status.ok());
        return status;
      });
  // Create a promise that will receive server trailing metadata.
  // If this fails, we massage the error into metadata that we can report
  // upwards.
  auto server_trailing_metadata =
      GetContext<Arena>()->MakePooled<ServerMetadata>(GetContext<Arena>());
  auto recv_trailing_metadata =
      Map(GetContext<BatchBuilder>()->ReceiveServerTrailingMetadata(
              stream->batch_target()),
          [](absl::StatusOr<ServerMetadataHandle> status) mutable {
            if (!status.ok()) {
              auto server_trailing_metadata =
                  GetContext<Arena>()->MakePooled<ServerMetadata>(
                      GetContext<Arena>());
              grpc_status_code status_code = GRPC_STATUS_UNKNOWN;
              std::string message;
              grpc_error_get_status(status.status(), Timestamp::InfFuture(),
                                    &status_code, &message, nullptr, nullptr);
              server_trailing_metadata->Set(GrpcStatusMetadata(), status_code);
              server_trailing_metadata->Set(GrpcMessageMetadata(),
                                            Slice::FromCopiedString(message));
              return server_trailing_metadata;
            } else {
              return std::move(*status);
            }
          });
  // Finally the main call promise.
  // Concurrently: send initial metadata and receive messages, until BOTH
  // complete (or one fails).
  // Next: receive trailing metadata, and return that up the stack.
  auto recv_messages =
      stream->RecvMessages(call_args.server_to_client_messages, false);
  return Map(
      [send_initial_metadata = std::move(send_initial_metadata),
       recv_messages = std::move(recv_messages),
       recv_trailing_metadata = std::move(recv_trailing_metadata),
       done_send_initial_metadata = false, done_recv_messages = false,
       done_recv_trailing_metadata =
           false]() mutable -> Poll<ServerMetadataHandle> {
        if (!done_send_initial_metadata) {
          auto p = send_initial_metadata();
          if (auto* r = p.value_if_ready()) {
            done_send_initial_metadata = true;
            if (!r->ok()) return StatusCast<ServerMetadataHandle>(*r);
          }
        }
        if (!done_recv_messages) {
          auto p = recv_messages();
          if (auto* r = p.value_if_ready()) {
            // NOTE: ignore errors here, they'll be collected in the
            // recv_trailing_metadata.
            done_recv_messages = true;
          } else {
            return Pending{};
          }
        }
        if (!done_recv_trailing_metadata) {
          auto p = recv_trailing_metadata();
          if (auto* r = p.value_if_ready()) {
            done_recv_trailing_metadata = true;
            return std::move(*r);
          }
        }
        return Pending{};
      },
      [stream = std::move(stream)](ServerMetadataHandle result) {
        stream->set_finished();
        return result;
      });
}
#endif

#ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL
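// Server call promise for batch-based transports: receives client initial
// metadata, runs the rest of the call via next_promise_factory, and translates
// the resulting pipes into transport stream op batches.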
ArenaPromise<ServerMetadataHandle> MakeServerCallPromise(
    grpc_transport* transport, CallArgs,
    NextPromiseFactory next_promise_factory) {
  OrphanablePtr<ConnectedChannelStream> stream(
      GetContext<Arena>()->New<ConnectedChannelStream>(transport));

  stream->SetStream(static_cast<grpc_stream*>(
      GetContext<Arena>()->Alloc(transport->vtable->sizeof_stream)));
  grpc_transport_init_stream(
      transport, stream->stream(), stream->stream_refcount(),
      GetContext<CallContext>()->server_call_context()->server_stream_data(),
      GetContext<Arena>());
  auto* party = static_cast<Party*>(Activity::current());

  // Artifacts we need for the lifetime of the call.
  struct CallData {
    Pipe<MessageHandle> server_to_client;
    Pipe<MessageHandle> client_to_server;
    Pipe<ServerMetadataHandle> server_initial_metadata;
    Latch<ServerMetadataHandle> failure_latch;
    Latch<grpc_polling_entity> polling_entity_latch;
    bool sent_initial_metadata = false;
    bool sent_trailing_metadata = false;
  };
  auto* call_data = GetContext<Arena>()->ManagedNew<CallData>();

  party->Spawn(
      "set_polling_entity", call_data->polling_entity_latch.Wait(),
      [transport,
       stream = stream->InternalRef()](grpc_polling_entity polling_entity) {
        grpc_transport_set_pops(transport, stream->stream(), &polling_entity);
      });

  auto server_to_client_empty =
      call_data->server_to_client.receiver.AwaitEmpty();

  // Create a promise that will receive client initial metadata, and then run
  // the main stem of the call (calling next_promise_factory up through the
  // filters).
  // Race the main call with failure_latch, allowing us to forcefully complete
  // the call in the case of a failure.
  auto recv_initial_metadata_then_run_promise =
      TrySeq(GetContext<BatchBuilder>()->ReceiveClientInitialMetadata(
                 stream->batch_target()),
             [next_promise_factory = std::move(next_promise_factory),
              server_to_client_empty = std::move(server_to_client_empty),
              call_data](ClientMetadataHandle client_initial_metadata) {
               auto call_promise = next_promise_factory(CallArgs{
                   std::move(client_initial_metadata),
                   ClientInitialMetadataOutstandingToken::Empty(),
                   &call_data->polling_entity_latch,
                   &call_data->server_initial_metadata.sender,
                   &call_data->client_to_server.receiver,
                   &call_data->server_to_client.sender,
               });
               return Race(call_data->failure_latch.Wait(),
                           [call_promise = std::move(call_promise),
                            server_to_client_empty =
                                std::move(server_to_client_empty)]() mutable
                           -> Poll<ServerMetadataHandle> {
                             // TODO(ctiller): this is deeply weird and we need
                             // to clean this up.
                             //
                             // The following few lines check to ensure that
                             // there's no message currently pending in the
                             // outgoing message queue, and if (and only if)
                             // that's true decides to poll the main promise to
                             // see if there's a result.
                             //
                             // This essentially introduces a polling priority
                             // scheme that makes the current promise structure
                             // work out the way we want when talking to
                             // transports.
                             //
                             // The problem is that transports are going to need
                             // to replicate this structure when they convert to
                             // promises, and that becomes troubling as we'll be
                             // replicating weirdness throughout the stack.
                             //
                             // Instead we likely need to change the way we're
                             // composing promises through the stack.
                             //
                             // Proposed is to change filters from a promise
                             // that takes ClientInitialMetadata and returns
                             // ServerTrailingMetadata with three pipes for
                             // ServerInitialMetadata and
                             // ClientToServerMessages, ServerToClientMessages.
                             // Instead we'll have five pipes, moving
                             // ClientInitialMetadata and ServerTrailingMetadata
                             // to pipes that can be intercepted.
                             //
                             // The effect of this change will be to cripple the
                             // things that can be done in a filter (but cripple
                             // in line with what most filters actually do).
                             // We'll likely need to add a `CallContext::Cancel`
                             // to allow filters to cancel a request, but this
                             // would also have the advantage of centralizing
                             // our cancellation machinery which seems like an
                             // additional win - with the net effect that the
                             // shape of the call gets made explicit at the top
                             // & bottom of the stack.
                             //
                             // There's a small set of filters (retry, this one,
                             // lame client, client channel) that terminate
                             // stacks and need a richer set of semantics, but
                             // that ends up being fine because we can spawn
                             // tasks in parties to handle those edge cases, and
                             // keep the majority of filters simple: they just
                             // call InterceptAndMap on a handful of filters at
                             // call initialization time and then proceed to
                             // actually filter.
                             //
                             // So that's the plan, why isn't it enacted here?
                             //
                             // Well, the plan ends up being easy to implement
                             // in the promise based world (I did a prototype on
                             // a branch in an afternoon). It's heinous to
                             // implement in promise_based_filter, and that code
                             // is load bearing for us at the time of writing.
                             // It's not worth delaying promises for a further N
                             // months (N ~ 6) to make that change.
                             //
                             // Instead, we'll move forward with this, get
                             // promise_based_filter out of the picture, and
                             // then during the mop-up phase for promises tweak
                             // the compute structure to move to the magical
                             // five pipes (I'm reminded of an old Onion
                             // article), and end up in a good happy place.
                             if (server_to_client_empty().pending()) {
                               return Pending{};
                             }
                             return call_promise();
                           });
             });

  // Promise factory that accepts a ServerMetadataHandle, and sends it as the
  // trailing metadata for this call.
  auto send_trailing_metadata = [call_data, stream = stream->InternalRef()](
                                    ServerMetadataHandle
                                        server_trailing_metadata) {
    bool is_cancellation =
        server_trailing_metadata->get(GrpcCallWasCancelled()).value_or(false);
    return GetContext<BatchBuilder>()->SendServerTrailingMetadata(
        stream->batch_target(), std::move(server_trailing_metadata),
        is_cancellation ||
            !std::exchange(call_data->sent_initial_metadata, true));
  };

  // Runs the receive message loop, either until all the messages
  // are received or the server call is complete.
  party->Spawn(
      "recv_messages",
      Race(
          Map(stream->WaitFinished(), [](Empty) { return absl::OkStatus(); }),
          Map(stream->RecvMessages(&call_data->client_to_server.sender, true),
              [failure_latch = &call_data->failure_latch](absl::Status status) {
                if (!status.ok() && !failure_latch->is_set()) {
                  failure_latch->Set(ServerMetadataFromStatus(status));
                }
                return status;
              })),
      [](absl::Status) {});

  // Run a promise that will send initial metadata (if that pipe sends some).
  // And then run the send message loop until that completes.

  auto send_initial_metadata = Seq(
      Race(Map(stream->WaitFinished(),
               [](Empty) { return NextResult<ServerMetadataHandle>(true); }),
           call_data->server_initial_metadata.receiver.Next()),
      [call_data, stream = stream->InternalRef()](
          NextResult<ServerMetadataHandle> next_result) mutable {
        auto md = !call_data->sent_initial_metadata && next_result.has_value()
                      ? std::move(next_result.value())
                      : nullptr;
        if (md != nullptr) {
          call_data->sent_initial_metadata = true;
          auto* party = static_cast<Party*>(Activity::current());
          party->Spawn("connected/send_initial_metadata",
                       GetContext<BatchBuilder>()->SendServerInitialMetadata(
                           stream->batch_target(), std::move(md)),
                       [](absl::Status) {});
          return Immediate(absl::OkStatus());
        }
        return Immediate(absl::CancelledError());
      });
  party->Spawn(
      "send_initial_metadata_then_messages",
      Race(Map(stream->WaitFinished(), [](Empty) { return absl::OkStatus(); }),
           TrySeq(std::move(send_initial_metadata),
                  stream->SendMessages(&call_data->server_to_client.receiver))),
      [](absl::Status) {});

  // Spawn a job to fetch the "client trailing metadata" - if this succeeds,
  // the client has finished sending; otherwise it's a signal of cancellation
  // from the client, which we propagate via failure_latch.

  party->Spawn(
      "recv_trailing_metadata",
      Seq(GetContext<BatchBuilder>()->ReceiveClientTrailingMetadata(
              stream->batch_target()),
          [failure_latch = &call_data->failure_latch](
              absl::StatusOr<ClientMetadataHandle> status) mutable {
            if (grpc_call_trace.enabled()) {
              gpr_log(
                  GPR_DEBUG,
                  "%s[connected] Got trailing metadata; status=%s metadata=%s",
                  Activity::current()->DebugTag().c_str(),
                  status.status().ToString().c_str(),
                  status.ok() ? (*status)->DebugString().c_str() : "<none>");
            }
            ClientMetadataHandle trailing_metadata;
            if (status.ok()) {
              trailing_metadata = std::move(*status);
            } else {
              trailing_metadata =
                  GetContext<Arena>()->MakePooled<ClientMetadata>(
                      GetContext<Arena>());
              grpc_status_code status_code = GRPC_STATUS_UNKNOWN;
              std::string message;
              grpc_error_get_status(status.status(), Timestamp::InfFuture(),
                                    &status_code, &message, nullptr, nullptr);
              trailing_metadata->Set(GrpcStatusMetadata(), status_code);
              trailing_metadata->Set(GrpcMessageMetadata(),
                                     Slice::FromCopiedString(message));
            }
            if (trailing_metadata->get(GrpcStatusMetadata())
                    .value_or(GRPC_STATUS_UNKNOWN) != GRPC_STATUS_OK) {
              if (!failure_latch->is_set()) {
                failure_latch->Set(std::move(trailing_metadata));
              }
            }
            return Empty{};
          }),
      [](Empty) {});

  // Finally assemble the main call promise:
  // Receive initial metadata from the client and start the promise up the
  // filter stack.
  // Upon completion, send trailing metadata to the client and then return it
  // (allowing the call code to decide on what signalling to give the
  // application).

  struct CleanupPollingEntityLatch {
    void operator()(Latch<grpc_polling_entity>* latch) {
      if (!latch->is_set()) latch->Set(grpc_polling_entity());
    }
  };
  auto cleanup_polling_entity_latch =
      std::unique_ptr<Latch<grpc_polling_entity>, CleanupPollingEntityLatch>(
          &call_data->polling_entity_latch);
  struct CleanupSendInitialMetadata {
    void operator()(CallData* call_data) {
      call_data->server_initial_metadata.receiver.CloseWithError();
    }
  };
  auto cleanup_send_initial_metadata =
      std::unique_ptr<CallData, CleanupSendInitialMetadata>(call_data);

  return Map(
      Seq(std::move(recv_initial_metadata_then_run_promise),
          std::move(send_trailing_metadata)),
      [cleanup_polling_entity_latch = std::move(cleanup_polling_entity_latch),
       cleanup_send_initial_metadata = std::move(cleanup_send_initial_metadata),
       stream = std::move(stream)](ServerMetadataHandle md) {
        stream->set_finished();
        return md;
      });
}
#endif

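// Builds the connected filter vtable, parameterized on the function used to
// create promise-based calls (pass nullptr when promise-based calls are not
// supported in this configuration).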
template <ArenaPromise<ServerMetadataHandle> (*make_call_promise)(
    grpc_transport*, CallArgs, NextPromiseFactory)>
grpc_channel_filter MakeConnectedFilter() {
  // Create a vtable that contains both the legacy call methods (for filter
  // stack based calls) and the new promise based method for creating
  // promise based calls (the latter iff make_call_promise != nullptr). In
  // this way the filter can be inserted into either kind of channel stack,
  // and only if all the filters in the stack are promise based will the
  // call be promise based.
  auto make_call_wrapper = +[](grpc_channel_element* elem, CallArgs call_args,
                               NextPromiseFactory next) {
    grpc_transport* transport =
        static_cast<channel_data*>(elem->channel_data)->transport;
    return make_call_promise(transport, std::move(call_args), std::move(next));
  };
  return {
      connected_channel_start_transport_stream_op_batch,
      make_call_promise != nullptr ? make_call_wrapper : nullptr,
      connected_channel_start_transport_op,
      sizeof(call_data),
      connected_channel_init_call_elem,
      set_pollset_or_pollset_set,
      connected_channel_destroy_call_elem,
      sizeof(channel_data),
      connected_channel_init_channel_elem,
      +[](grpc_channel_stack* channel_stack, grpc_channel_element* elem) {
        // HACK(ctiller): increase call stack size for the channel to make
        // space for channel data. We need a cleaner (but performant) way to
        // do this, and I'm not sure what that is yet. This is only "safe"
        // because call stacks place no additional data after the last call
        // element, and the last call element MUST be the connected channel.
        channel_stack->call_stack_size += grpc_transport_stream_size(
            static_cast<channel_data*>(elem->channel_data)->transport);
      },
      connected_channel_destroy_channel_elem,
      connected_channel_get_channel_info,
      "connected",
  };
}

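// For transports that natively support promise-based calls, delegate promise
// creation directly to the transport's vtable.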
ArenaPromise<ServerMetadataHandle> MakeTransportCallPromise(
    grpc_transport* transport, CallArgs call_args, NextPromiseFactory) {
  return transport->vtable->make_call_promise(transport, std::move(call_args));
}

const grpc_channel_filter kPromiseBasedTransportFilter =
    MakeConnectedFilter<MakeTransportCallPromise>();

#ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_CLIENT_CALL
const grpc_channel_filter kClientEmulatedFilter =
    MakeConnectedFilter<MakeClientCallPromise>();
#else
const grpc_channel_filter kClientEmulatedFilter =
    MakeConnectedFilter<nullptr>();
#endif

#ifdef GRPC_EXPERIMENT_IS_INCLUDED_PROMISE_BASED_SERVER_CALL
const grpc_channel_filter kServerEmulatedFilter =
    MakeConnectedFilter<MakeServerCallPromise>();
#else
const grpc_channel_filter kServerEmulatedFilter =
    MakeConnectedFilter<nullptr>();
#endif

}  // namespace
}  // namespace grpc_core

bool grpc_add_connected_filter(grpc_core::ChannelStackBuilder* builder) {
  grpc_transport* t = builder->transport();
  GPR_ASSERT(t != nullptr);
  // Choose the right vtable for the connected filter.
  // We can't know promise based call or not here (that decision needs the
  // collaboration of all of the filters on the channel, and we don't want
  // ordering constraints on when we add filters).
  // What we can know is how we would create our promise if the call does end
  // up promise based (if indeed we can), and so that is the choice made here.
  if (t->vtable->make_call_promise != nullptr) {
    // Option 1, and our ideal: the transport supports promise based calls,
    // and so we simply use the transport directly.
    builder->AppendFilter(&grpc_core::kPromiseBasedTransportFilter);
  } else if (grpc_channel_stack_type_is_client(builder->channel_stack_type())) {
    // Option 2: the transport does not support promise based calls, but
    // we're on the client and so we have an implementation that we can use
    // to convert to batches.
    builder->AppendFilter(&grpc_core::kClientEmulatedFilter);
  } else {
    // Option 3: the transport does not support promise based calls, and
    // we're on the server so we use the server filter.
    builder->AppendFilter(&grpc_core::kServerEmulatedFilter);
  }
  return true;
}