xref: /aosp_15_r20/external/pigweed/pw_async2/public/pw_async2/coro.h (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2024 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15 
#include <concepts>
#include <coroutine>
#include <cstddef>
#include <new>
#include <optional>
#include <type_traits>
#include <utility>
#include <variant>

#include "pw_allocator/allocator.h"
#include "pw_allocator/layout.h"
#include "pw_async2/dispatcher.h"
#include "pw_function/function.h"
#include "pw_log/log.h"
#include "pw_status/status.h"
#include "pw_status/try.h"
27 
28 namespace pw::async2 {
29 
30 // Forward-declare `Coro` so that it can be referenced by the promise type APIs.
31 template <std::constructible_from<pw::Status> T>
32 class Coro;
33 
34 /// Context required for creating and executing coroutines.
35 class CoroContext {
36  public:
37   /// Creates a `CoroContext` which will allocate coroutine state using
38   /// `alloc`.
CoroContext(pw::allocator::Allocator & alloc)39   explicit CoroContext(pw::allocator::Allocator& alloc) : alloc_(alloc) {}
alloc()40   pw::allocator::Allocator& alloc() const { return alloc_; }
41 
42  private:
43   pw::allocator::Allocator& alloc_;
44 };
45 
46 // The internal coroutine API implementation details enabling `Coro<T>`.
47 //
48 // Users of `Coro<T>` need not concern themselves with these details, unless
49 // they think it sounds like fun ;)
50 namespace internal {
51 
// Invoked by `CoroPromiseType<T>::SharedNew` when the allocator cannot
// provide `requested_size` bytes of coroutine state. Defined out of line;
// presumably emits a log message (the definition is not in this header).
void LogCoroAllocationFailure(size_t requested_size);
53 
// Holds a to-be-provided value of a type `T` that is not default-constructible.
//
// Used via `OptionalOrDefault<T>` as the output slot that a coroutine's
// `co_return` writes into.
template <typename T>
class OptionalWrapper final {
 public:
  // Create an empty container for a to-be-provided value.
  OptionalWrapper() : value_() {}

  // Assign (or replace) the contained value.
  template <typename U>
  OptionalWrapper& operator=(U&& value) {
    value_ = std::forward<U>(value);
    return *this;
  }

  // Retrieve the inner value.
  //
  // The value is moved out rather than copied so that move-only `T` is
  // supported. This conversion is meant to be performed once, when the
  // coroutine's result is handed back to the caller; afterwards the contained
  // value is in a moved-from state.
  //
  // This operation will fail (assert) if no value was assigned.
  operator T() {
    PW_ASSERT(value_.has_value());
    return std::move(*value_);
  }

 private:
  std::optional<T> value_;
};
78 
79 // A container for a to-be-produced value of type `T`.
80 //
81 // This is designed to allow avoiding the overhead of `std::optional` when
82 // `T` is default-initializable.
83 //
84 // Values of this type begin as either:
85 // - a default-initialized `T` if `T` is default-initializable or
86 // - `std::nullopt`
87 template <typename T>
88 using OptionalOrDefault =
89     std::conditional<std::is_default_constructible<T>::value,
90                      T,
91                      OptionalWrapper<T>>::type;
92 
93 // A wrapper for `std::coroutine_handle` that assumes unique ownership of the
94 // underlying `PromiseType`.
95 //
96 // This type will `destroy()` the underlying promise in its destructor, or
97 // when `Release()` is called.
98 template <typename PromiseType>
99 class OwningCoroutineHandle final {
100  public:
101   // Construct a null (`!IsValid()`) handle.
OwningCoroutineHandle(std::nullptr_t)102   OwningCoroutineHandle(std::nullptr_t) : promise_handle_(nullptr) {}
103 
104   /// Take ownership of `promise_handle`.
OwningCoroutineHandle(std::coroutine_handle<PromiseType> && promise_handle)105   OwningCoroutineHandle(std::coroutine_handle<PromiseType>&& promise_handle)
106       : promise_handle_(std::move(promise_handle)) {}
107 
108   // Empty out `other` and transfers ownership of its `promise_handle`
109   // to `this`.
OwningCoroutineHandle(OwningCoroutineHandle && other)110   OwningCoroutineHandle(OwningCoroutineHandle&& other)
111       : promise_handle_(std::move(other.promise_handle_)) {
112     other.promise_handle_ = nullptr;
113   }
114 
115   // Empty out `other` and transfers ownership of its `promise_handle`
116   // to `this`.
117   OwningCoroutineHandle& operator=(OwningCoroutineHandle&& other) {
118     Release();
119     promise_handle_ = std::move(other.promise_handle_);
120     other.promise_handle_ = nullptr;
121     return *this;
122   }
123 
124   // `destroy()`s the underlying `promise_handle` if valid.
~OwningCoroutineHandle()125   ~OwningCoroutineHandle() { Release(); }
126 
127   // Return whether or not this value contains a `promise_handle`.
128   //
129   // This will return `false` if this `OwningCoroutineHandle` was
130   // `nullptr`-initialized, moved from, or if `Release` was invoked.
IsValid()131   [[nodiscard]] bool IsValid() const {
132     return promise_handle_.address() != nullptr;
133   }
134 
135   // Return a reference to the underlying `PromiseType`.
136   //
137   // Precondition: `IsValid()` must be `true`.
promise()138   [[nodiscard]] PromiseType& promise() const {
139     return promise_handle_.promise();
140   }
141 
142   // Whether or not the underlying coroutine has completed.
143   //
144   // Precondition: `IsValid()` must be `true`.
done()145   [[nodiscard]] bool done() const { return promise_handle_.done(); }
146 
147   // Resume the underlying coroutine.
148   //
149   // Precondition: `IsValid()` must be `true`, and `done()` must be
150   // `false`.
resume()151   void resume() { promise_handle_.resume(); }
152 
153   // Invokes `destroy()` on the underlying promise and deallocates its
154   // associated storage.
Release()155   void Release() {
156     // DOCSTAG: [pw_async2-coro-release]
157     void* address = promise_handle_.address();
158     if (address != nullptr) {
159       pw::allocator::Deallocator& dealloc = promise_handle_.promise().dealloc_;
160       promise_handle_.destroy();
161       promise_handle_ = nullptr;
162       dealloc.Deallocate(address);
163     }
164     // DOCSTAG: [pw_async2-coro-release]
165   }
166 
167  private:
168   std::coroutine_handle<PromiseType> promise_handle_;
169 };
170 
// The type produced by `CoroPromiseType::await_transform` to wrap an operand
// of `co_await`. Declared here so the promise type can refer to it; the full
// definition comes after the promise type.
template <typename Pendable, typename PromiseType>
class Awaitable;
174 
175 // A container for values passed in and out of the promise.
176 //
177 // The C++20 coroutine `resume()` function cannot accept arguments no return
178 // values, so instead coroutine inputs and outputs are funneled through this
179 // type. A pointer to the `InOut` object is stored in the `CoroPromiseType`
180 // so that the coroutine object can access it.
181 template <typename T>
182 struct InOut final {
183   // The `Context` passed into the coroutine via `Pend`.
184   Context* input_cx;
185 
186   // The output assigned to by the coroutine if the coroutine is `done()`.
187   OptionalOrDefault<T>* output;
188 };
189 
190 // Attempt to complete the current pendable value passed to `co_await`,
191 // storing its return value inside the `Awaitable` object so that it can
192 // be retrieved by the coroutine.
193 //
194 // Each `co_await` statement creates an `Awaitable` object whose `Pend`
195 // method must be completed before the coroutine's `resume()` function can
196 // be invoked.
197 //
198 // `sizeof(void*)` is used as the size since only one pointer capture is
199 // required in all cases.
200 using PendFillReturnValueFn = pw::Function<Poll<>(Context&), sizeof(void*)>;
201 
202 // The `promise_type` of `Coro<T>`.
203 //
204 // To understand this type, it may be necessary to refer to the reference
205 // documentation for the C++20 coroutine API.
206 template <typename T>
207 class CoroPromiseType final {
208  public:
209   // Construct the `CoroPromiseType` using the arguments passed to a
210   // function returning `Coro<T>`.
211   //
212   // The first argument *must* be a `CoroContext`. The other
213   // arguments are unused, but must be accepted in order for this to compile.
214   template <typename... Args>
CoroPromiseType(CoroContext & cx,const Args &...)215   CoroPromiseType(CoroContext& cx, const Args&...)
216       : dealloc_(cx.alloc()), currently_pending_(nullptr), in_out_(nullptr) {}
217 
218   // Method-receiver version.
219   template <typename MethodReceiver, typename... Args>
CoroPromiseType(const MethodReceiver &,CoroContext & cx,const Args &...)220   CoroPromiseType(const MethodReceiver&, CoroContext& cx, const Args&...)
221       : dealloc_(cx.alloc()), currently_pending_(nullptr), in_out_(nullptr) {}
222 
223   // Get the `Coro<T>` after successfully allocating the coroutine space
224   // and constructing `this`.
225   Coro<T> get_return_object();
226 
227   // Do not begin executing the `Coro<T>` until `resume()` has been invoked
228   // for the first time.
initial_suspend()229   std::suspend_always initial_suspend() { return {}; }
230 
231   // Unconditionally suspend to prevent `destroy()` being invoked.
232   //
233   // The caller of `resume()` needs to first observe `done()` before the
234   // state can be destroyed.
235   //
236   // Setting this to suspend means that the caller is responsible for invoking
237   // `destroy()`.
final_suspend()238   std::suspend_always final_suspend() noexcept { return {}; }
239 
240   // Store the `co_return` argument in the `InOut<T>` object provided by
241   // the `Pend` wrapper.
242   template <std::convertible_to<T> From>
return_value(From && value)243   void return_value(From&& value) {
244     *in_out_->output = std::forward<From>(value);
245   }
246 
247   // Ignore exceptions in coroutines.
248   //
249   // Pigweed is not designed to be used with exceptions: `Result` or a
250   // similar type should be used to propagate errors.
unhandled_exception()251   void unhandled_exception() { PW_ASSERT(false); }
252 
253   // Create an invalid (nullptr) `Coro<T>` if `operator new` below fails.
254   static Coro<T> get_return_object_on_allocation_failure();
255 
256   // Allocate the space for both this `CoroPromiseType<T>` and the coroutine
257   // state.
258   //
259   // This override does not accept alignment.
260   template <typename... Args>
new(std::size_t size,CoroContext & coro_cx,const Args &...)261   static void* operator new(std::size_t size,
262                             CoroContext& coro_cx,
263                             const Args&...) noexcept {
264     return SharedNew(coro_cx, size, alignof(std::max_align_t));
265   }
266 
267   // Allocate the space for both this `CoroPromiseType<T>` and the coroutine
268   // state.
269   //
270   // This override accepts alignment.
271   template <typename... Args>
new(std::size_t size,std::align_val_t align,CoroContext & coro_cx,const Args &...)272   static void* operator new(std::size_t size,
273                             std::align_val_t align,
274                             CoroContext& coro_cx,
275                             const Args&...) noexcept {
276     return SharedNew(coro_cx, size, static_cast<size_t>(align));
277   }
278 
279   // Method-receiver form.
280   //
281   // This override does not accept alignment.
282   template <typename MethodReceiver, typename... Args>
new(std::size_t size,const MethodReceiver &,CoroContext & coro_cx,const Args &...)283   static void* operator new(std::size_t size,
284                             const MethodReceiver&,
285                             CoroContext& coro_cx,
286                             const Args&...) noexcept {
287     return SharedNew(coro_cx, size, alignof(std::max_align_t));
288   }
289 
290   // Method-receiver form.
291   //
292   // This accepts alignment.
293   template <typename MethodReceiver, typename... Args>
new(std::size_t size,std::align_val_t align,const MethodReceiver &,CoroContext & coro_cx,const Args &...)294   static void* operator new(std::size_t size,
295                             std::align_val_t align,
296                             const MethodReceiver&,
297                             CoroContext& coro_cx,
298                             const Args&...) noexcept {
299     return SharedNew(coro_cx, size, static_cast<size_t>(align));
300   }
301 
SharedNew(CoroContext & coro_cx,std::size_t size,std::size_t align)302   static void* SharedNew(CoroContext& coro_cx,
303                          std::size_t size,
304                          std::size_t align) noexcept {
305     auto ptr = coro_cx.alloc().Allocate(pw::allocator::Layout(size, align));
306     if (ptr == nullptr) {
307       internal::LogCoroAllocationFailure(size);
308     }
309     return ptr;
310   }
311 
312   // Deallocate the space for both this `CoroPromiseType<T>` and the
313   // coroutine state.
314   //
315   // In reality, we do nothing here!!!
316   //
317   // Coroutines do not support `destroying_delete`, so we can't access
318   // `dealloc_` here, and therefore have no way to deallocate.
319   // Instead, deallocation is handled by `OwningCoroutineHandle<T>::Release`.
delete(void *)320   static void operator delete(void*) {}
321 
322   // Handle a `co_await` call by accepting a type with a
323   // `Poll<U> Pend(Context&)` method, returning an `Awaitable` which will
324   // yield a `U` once complete.
325   template <typename Pendable>
326     requires(!std::is_reference_v<Pendable>)
await_transform(Pendable && pendable)327   Awaitable<Pendable, CoroPromiseType> await_transform(Pendable&& pendable) {
328     return pendable;
329   }
330 
331   template <typename Pendable>
await_transform(Pendable & pendable)332   Awaitable<Pendable*, CoroPromiseType> await_transform(Pendable& pendable) {
333     return &pendable;
334   }
335 
336   // Returns a reference to the `Context` passed in.
cx()337   Context& cx() { return *in_out_->input_cx; }
338 
339   pw::allocator::Deallocator& dealloc_;
340   PendFillReturnValueFn currently_pending_;
341   InOut<T>* in_out_;
342 };
343 
344 // The object created by invoking `co_await` in a `Coro<T>` function.
345 //
346 // This wraps a `Pendable` type and implements the awaitable interface
347 // expected by the standard coroutine API.
348 template <typename Pendable, typename PromiseType>
349 class Awaitable final {
350  public:
351   // The `OutputType` in `Poll<OutputType> Pendable::Pend(Context&)`.
352   using OutputType = std::remove_cvref_t<
353       decltype(std::declval<std::remove_pointer_t<Pendable>>()
354                    .Pend(std::declval<Context&>())
355                    .value())>;
356 
Awaitable(Pendable && pendable)357   Awaitable(Pendable&& pendable) : state_(std::forward<Pendable>(pendable)) {}
358 
359   // Confirms that `await_suspend` must be invoked.
await_ready()360   bool await_ready() { return false; }
361 
362   // Returns whether or not the current coroutine should be suspended.
363   //
364   // This is invoked once as part of every `co_await` call after
365   // `await_ready` returns `false`.
366   //
367   // In the process, this method attempts to complete the inner `Pendable`
368   // before suspending this coroutine.
await_suspend(const std::coroutine_handle<PromiseType> & promise)369   bool await_suspend(const std::coroutine_handle<PromiseType>& promise) {
370     Context& cx = promise.promise().cx();
371     if (PendFillReturnValue(cx).IsPending()) {
372       /// The coroutine should suspend since the await-ed thing is pending.
373       promise.promise().currently_pending_ = [this](Context& lambda_cx) {
374         return PendFillReturnValue(lambda_cx);
375       };
376       return true;
377     }
378     return false;
379   }
380 
381   // Returns `return_value`.
382   //
383   // This is automatically invoked by the language runtime when the promise's
384   // `resume()` method is called.
await_resume()385   OutputType&& await_resume() {
386     return std::move(std::get<OutputType>(state_));
387   }
388 
PendableNoPtr()389   auto& PendableNoPtr() {
390     if constexpr (std::is_pointer_v<Pendable>) {
391       return *std::get<Pendable>(state_);
392     } else {
393       return std::get<Pendable>(state_);
394     }
395   }
396 
397   // Attempts to complete the `Pendable` value, storing its return value
398   // upon completion.
399   //
400   // This method must return `Ready()` before the coroutine can be safely
401   // resumed, as otherwise the return value will not be available when
402   // `await_resume` is called to produce the result of `co_await`.
PendFillReturnValue(Context & cx)403   Poll<> PendFillReturnValue(Context& cx) {
404     Poll<OutputType> poll_res(PendableNoPtr().Pend(cx));
405     if (poll_res.IsPending()) {
406       return Pending();
407     }
408     state_ = std::move(*poll_res);
409     return Ready();
410   }
411 
412  private:
413   std::variant<Pendable, OutputType> state_;
414 };
415 
416 }  // namespace internal
417 
418 /// An asynchronous coroutine which implements the C++20 coroutine API.
419 ///
420 /// # Why coroutines?
421 /// Coroutines allow a series of asynchronous operations to be written as
422 /// straight line code. Rather than manually writing a state machine, users can
423 /// `co_await` any Pigweed asynchronous value (types with a
424 /// `Poll<T> Pend(Context&)` method).
425 ///
426 /// # Allocation
427 /// Pigweed's `Coro<T>` API supports checked, fallible, heap-free allocation.
428 /// The first argument to any coroutine function must be a
429 /// `CoroContext` (or a reference to one). This allows the
430 /// coroutine to allocate space for asynchronously-held stack variables using
431 /// the allocator member of the `CoroContext`.
432 ///
433 /// Failure to allocate coroutine "stack" space will result in the `Coro<T>`
434 /// returning `Status::Invalid()`.
435 ///
436 /// # Creating a coroutine function
437 /// To create a coroutine, a function must:
438 /// - Have an annotated return type of `Coro<T>` where `T` is some type
439 ///   constructible from `pw::Status`, such as `pw::Status` or
440 ///   `pw::Result<U>`.
441 /// - Use `co_return <value>` rather than `return <value>` for any
442 ///   `return` statements. This also requires the use of `PW_CO_TRY` and
443 ///   `PW_CO_TRY_ASSIGN` rather than `PW_TRY` and `PW_TRY_ASSIGN`.
444 /// - Accept a value convertible to `pw::allocator::Allocator&` as its first
445 ///   argument. This allocator will be used to allocate storage for coroutine
446 ///   stack variables held across a `co_await` point.
447 ///
448 /// # Using co_await
449 /// Inside a coroutine function, `co_await <expr>` can be used on any type
450 /// with a `Poll<T> Pend(Context&)` method. The result will be a value of
451 /// type `T`.
452 ///
453 /// # Example
454 /// @rst
455 /// .. literalinclude:: examples/basic.cc
456 ///    :language: cpp
457 ///    :linenos:
458 ///    :start-after: [pw_async2-examples-basic-coro]
459 ///    :end-before: [pw_async2-examples-basic-coro]
460 /// @endrst
461 template <std::constructible_from<pw::Status> T>
462 class Coro final {
463  public:
464   /// Creates an empty, invalid coroutine object.
Empty()465   static Coro Empty() {
466     return Coro(internal::OwningCoroutineHandle<promise_type>(nullptr));
467   }
468 
469   /// Whether or not this `Coro<T>` is a valid coroutine.
470   ///
471   /// This will return `false` if coroutine state allocation failed or if
472   /// this `Coro<T>::Pend` method previously returned a `Ready` value.
IsValid()473   [[nodiscard]] bool IsValid() const { return promise_handle_.IsValid(); }
474 
475   /// Attempt to complete this coroutine, returning the result if complete.
476   ///
477   /// Returns `Status::Internal()` if `!IsValid()`, which may occur if
478   /// coroutine state allocation fails.
Pend(Context & cx)479   Poll<T> Pend(Context& cx) {
480     if (!IsValid()) {
481       // This coroutine failed to allocate its internal state.
482       // (Or `Pend` is being erroniously invoked after previously completing.)
483       return Ready(Status::Internal());
484     }
485 
486     // If an `Awaitable` value is currently being processed, it must be
487     // allowed to complete and store its return value before we can resume
488     // the coroutine.
489     if (promise_handle_.promise().currently_pending_ != nullptr &&
490         promise_handle_.promise().currently_pending_(cx).IsPending()) {
491       return Pending();
492     }
493     // DOCSTAG: [pw_async2-coro-resume]
494     // Create the arguments (and output storage) for the coroutine.
495     internal::InOut<T> in_out;
496     internal::OptionalOrDefault<T> return_value;
497     in_out.input_cx = &cx;
498     in_out.output = &return_value;
499     promise_handle_.promise().in_out_ = &in_out;
500 
501     // Resume the coroutine, triggering `Awaitable::await_resume()` and the
502     // returning of the resulting value from `co_await`.
503     promise_handle_.resume();
504     if (!promise_handle_.done()) {
505       return Pending();
506     }
507 
508     // Destroy the coroutine state: it has completed, and further calls to
509     // `resume` would result in undefined behavior.
510     promise_handle_.Release();
511 
512     // When the coroutine completed in `resume()` above, it stored its
513     // `co_return` value into `return_value`. This retrieves that value.
514     return return_value;
515     // DOCSTAG: [pw_async2-coro-resume]
516   }
517 
518   /// Used by the compiler in order to create a `Coro<T>` from a coroutine
519   /// function.
520   using promise_type = ::pw::async2::internal::CoroPromiseType<T>;
521 
522  private:
523   // Allow `CoroPromiseType<T>::get_return_object()` and
524   // `CoroPromiseType<T>::get_retunr_object_on_allocation_failure()` to
525   // use the private constructor below.
526   friend promise_type;
527 
528   /// Create a new `Coro<T>` using a (possibly null) handle.
Coro(internal::OwningCoroutineHandle<promise_type> && promise_handle)529   Coro(internal::OwningCoroutineHandle<promise_type>&& promise_handle)
530       : promise_handle_(std::move(promise_handle)) {}
531 
532   internal::OwningCoroutineHandle<promise_type> promise_handle_;
533 };
534 
535 // Implement the remaining internal pieces that require a definition of
536 // `Coro<T>`.
537 namespace internal {
538 
539 template <typename T>
get_return_object()540 Coro<T> CoroPromiseType<T>::get_return_object() {
541   return internal::OwningCoroutineHandle<CoroPromiseType<T>>(
542       std::coroutine_handle<CoroPromiseType<T>>::from_promise(*this));
543 }
544 
545 template <typename T>
get_return_object_on_allocation_failure()546 Coro<T> CoroPromiseType<T>::get_return_object_on_allocation_failure() {
547   return internal::OwningCoroutineHandle<CoroPromiseType<T>>(nullptr);
548 }
549 
550 }  // namespace internal
551 }  // namespace pw::async2
552