1 // Copyright 2024 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15
#include <concepts>
#include <coroutine>
#include <cstddef>
#include <new>
#include <optional>
#include <type_traits>
#include <utility>
#include <variant>

#include "pw_allocator/allocator.h"
#include "pw_allocator/layout.h"
#include "pw_async2/dispatcher.h"
#include "pw_function/function.h"
#include "pw_log/log.h"
#include "pw_status/status.h"
#include "pw_status/try.h"
27
28 namespace pw::async2 {
29
30 // Forward-declare `Coro` so that it can be referenced by the promise type APIs.
31 template <std::constructible_from<pw::Status> T>
32 class Coro;
33
34 /// Context required for creating and executing coroutines.
35 class CoroContext {
36 public:
37 /// Creates a `CoroContext` which will allocate coroutine state using
38 /// `alloc`.
CoroContext(pw::allocator::Allocator & alloc)39 explicit CoroContext(pw::allocator::Allocator& alloc) : alloc_(alloc) {}
alloc()40 pw::allocator::Allocator& alloc() const { return alloc_; }
41
42 private:
43 pw::allocator::Allocator& alloc_;
44 };
45
46 // The internal coroutine API implementation details enabling `Coro<T>`.
47 //
48 // Users of `Coro<T>` need not concern themselves with these details, unless
49 // they think it sounds like fun ;)
50 namespace internal {
51
52 void LogCoroAllocationFailure(size_t requested_size);
53
// Storage for a to-be-produced value of type `T`, used by `OptionalOrDefault`
// when `T` is not default-constructible.
//
// The container starts out empty. A value is provided through `operator=` and
// later extracted (exactly once) through the conversion to `T`.
template <typename T>
class OptionalWrapper final {
 public:
  // Create an empty container for a to-be-provided value.
  OptionalWrapper() : value_() {}

  // Assign a value (anything assignable to `std::optional<T>`).
  template <typename U>
  OptionalWrapper& operator=(U&& value) {
    value_ = std::forward<U>(value);
    return *this;
  }

  // Retrieve the inner value, consuming it.
  //
  // The value is moved out rather than copied, which avoids an unnecessary
  // copy and allows move-only `T`. The container is emptied afterwards, so a
  // second retrieval fails the assertion instead of yielding a moved-from
  // value.
  //
  // This operation will fail if no value was assigned.
  operator T() {
    PW_ASSERT(value_.has_value());
    T value = std::move(*value_);
    value_.reset();
    return value;
  }

 private:
  std::optional<T> value_;
};
78
79 // A container for a to-be-produced value of type `T`.
80 //
81 // This is designed to allow avoiding the overhead of `std::optional` when
82 // `T` is default-initializable.
83 //
84 // Values of this type begin as either:
85 // - a default-initialized `T` if `T` is default-initializable or
86 // - `std::nullopt`
87 template <typename T>
88 using OptionalOrDefault =
89 std::conditional<std::is_default_constructible<T>::value,
90 T,
91 OptionalWrapper<T>>::type;
92
93 // A wrapper for `std::coroutine_handle` that assumes unique ownership of the
94 // underlying `PromiseType`.
95 //
96 // This type will `destroy()` the underlying promise in its destructor, or
97 // when `Release()` is called.
98 template <typename PromiseType>
99 class OwningCoroutineHandle final {
100 public:
101 // Construct a null (`!IsValid()`) handle.
OwningCoroutineHandle(std::nullptr_t)102 OwningCoroutineHandle(std::nullptr_t) : promise_handle_(nullptr) {}
103
104 /// Take ownership of `promise_handle`.
OwningCoroutineHandle(std::coroutine_handle<PromiseType> && promise_handle)105 OwningCoroutineHandle(std::coroutine_handle<PromiseType>&& promise_handle)
106 : promise_handle_(std::move(promise_handle)) {}
107
108 // Empty out `other` and transfers ownership of its `promise_handle`
109 // to `this`.
OwningCoroutineHandle(OwningCoroutineHandle && other)110 OwningCoroutineHandle(OwningCoroutineHandle&& other)
111 : promise_handle_(std::move(other.promise_handle_)) {
112 other.promise_handle_ = nullptr;
113 }
114
115 // Empty out `other` and transfers ownership of its `promise_handle`
116 // to `this`.
117 OwningCoroutineHandle& operator=(OwningCoroutineHandle&& other) {
118 Release();
119 promise_handle_ = std::move(other.promise_handle_);
120 other.promise_handle_ = nullptr;
121 return *this;
122 }
123
124 // `destroy()`s the underlying `promise_handle` if valid.
~OwningCoroutineHandle()125 ~OwningCoroutineHandle() { Release(); }
126
127 // Return whether or not this value contains a `promise_handle`.
128 //
129 // This will return `false` if this `OwningCoroutineHandle` was
130 // `nullptr`-initialized, moved from, or if `Release` was invoked.
IsValid()131 [[nodiscard]] bool IsValid() const {
132 return promise_handle_.address() != nullptr;
133 }
134
135 // Return a reference to the underlying `PromiseType`.
136 //
137 // Precondition: `IsValid()` must be `true`.
promise()138 [[nodiscard]] PromiseType& promise() const {
139 return promise_handle_.promise();
140 }
141
142 // Whether or not the underlying coroutine has completed.
143 //
144 // Precondition: `IsValid()` must be `true`.
done()145 [[nodiscard]] bool done() const { return promise_handle_.done(); }
146
147 // Resume the underlying coroutine.
148 //
149 // Precondition: `IsValid()` must be `true`, and `done()` must be
150 // `false`.
resume()151 void resume() { promise_handle_.resume(); }
152
153 // Invokes `destroy()` on the underlying promise and deallocates its
154 // associated storage.
Release()155 void Release() {
156 // DOCSTAG: [pw_async2-coro-release]
157 void* address = promise_handle_.address();
158 if (address != nullptr) {
159 pw::allocator::Deallocator& dealloc = promise_handle_.promise().dealloc_;
160 promise_handle_.destroy();
161 promise_handle_ = nullptr;
162 dealloc.Deallocate(address);
163 }
164 // DOCSTAG: [pw_async2-coro-release]
165 }
166
167 private:
168 std::coroutine_handle<PromiseType> promise_handle_;
169 };
170
171 // Forward-declare the wrapper type for values passed to `co_await`.
172 template <typename Pendable, typename PromiseType>
173 class Awaitable;
174
175 // A container for values passed in and out of the promise.
176 //
177 // The C++20 coroutine `resume()` function cannot accept arguments no return
178 // values, so instead coroutine inputs and outputs are funneled through this
179 // type. A pointer to the `InOut` object is stored in the `CoroPromiseType`
180 // so that the coroutine object can access it.
181 template <typename T>
182 struct InOut final {
183 // The `Context` passed into the coroutine via `Pend`.
184 Context* input_cx;
185
186 // The output assigned to by the coroutine if the coroutine is `done()`.
187 OptionalOrDefault<T>* output;
188 };
189
190 // Attempt to complete the current pendable value passed to `co_await`,
191 // storing its return value inside the `Awaitable` object so that it can
192 // be retrieved by the coroutine.
193 //
194 // Each `co_await` statement creates an `Awaitable` object whose `Pend`
195 // method must be completed before the coroutine's `resume()` function can
196 // be invoked.
197 //
198 // `sizeof(void*)` is used as the size since only one pointer capture is
199 // required in all cases.
200 using PendFillReturnValueFn = pw::Function<Poll<>(Context&), sizeof(void*)>;
201
202 // The `promise_type` of `Coro<T>`.
203 //
204 // To understand this type, it may be necessary to refer to the reference
205 // documentation for the C++20 coroutine API.
206 template <typename T>
207 class CoroPromiseType final {
208 public:
209 // Construct the `CoroPromiseType` using the arguments passed to a
210 // function returning `Coro<T>`.
211 //
212 // The first argument *must* be a `CoroContext`. The other
213 // arguments are unused, but must be accepted in order for this to compile.
214 template <typename... Args>
CoroPromiseType(CoroContext & cx,const Args &...)215 CoroPromiseType(CoroContext& cx, const Args&...)
216 : dealloc_(cx.alloc()), currently_pending_(nullptr), in_out_(nullptr) {}
217
218 // Method-receiver version.
219 template <typename MethodReceiver, typename... Args>
CoroPromiseType(const MethodReceiver &,CoroContext & cx,const Args &...)220 CoroPromiseType(const MethodReceiver&, CoroContext& cx, const Args&...)
221 : dealloc_(cx.alloc()), currently_pending_(nullptr), in_out_(nullptr) {}
222
223 // Get the `Coro<T>` after successfully allocating the coroutine space
224 // and constructing `this`.
225 Coro<T> get_return_object();
226
227 // Do not begin executing the `Coro<T>` until `resume()` has been invoked
228 // for the first time.
initial_suspend()229 std::suspend_always initial_suspend() { return {}; }
230
231 // Unconditionally suspend to prevent `destroy()` being invoked.
232 //
233 // The caller of `resume()` needs to first observe `done()` before the
234 // state can be destroyed.
235 //
236 // Setting this to suspend means that the caller is responsible for invoking
237 // `destroy()`.
final_suspend()238 std::suspend_always final_suspend() noexcept { return {}; }
239
240 // Store the `co_return` argument in the `InOut<T>` object provided by
241 // the `Pend` wrapper.
242 template <std::convertible_to<T> From>
return_value(From && value)243 void return_value(From&& value) {
244 *in_out_->output = std::forward<From>(value);
245 }
246
247 // Ignore exceptions in coroutines.
248 //
249 // Pigweed is not designed to be used with exceptions: `Result` or a
250 // similar type should be used to propagate errors.
unhandled_exception()251 void unhandled_exception() { PW_ASSERT(false); }
252
253 // Create an invalid (nullptr) `Coro<T>` if `operator new` below fails.
254 static Coro<T> get_return_object_on_allocation_failure();
255
256 // Allocate the space for both this `CoroPromiseType<T>` and the coroutine
257 // state.
258 //
259 // This override does not accept alignment.
260 template <typename... Args>
new(std::size_t size,CoroContext & coro_cx,const Args &...)261 static void* operator new(std::size_t size,
262 CoroContext& coro_cx,
263 const Args&...) noexcept {
264 return SharedNew(coro_cx, size, alignof(std::max_align_t));
265 }
266
267 // Allocate the space for both this `CoroPromiseType<T>` and the coroutine
268 // state.
269 //
270 // This override accepts alignment.
271 template <typename... Args>
new(std::size_t size,std::align_val_t align,CoroContext & coro_cx,const Args &...)272 static void* operator new(std::size_t size,
273 std::align_val_t align,
274 CoroContext& coro_cx,
275 const Args&...) noexcept {
276 return SharedNew(coro_cx, size, static_cast<size_t>(align));
277 }
278
279 // Method-receiver form.
280 //
281 // This override does not accept alignment.
282 template <typename MethodReceiver, typename... Args>
new(std::size_t size,const MethodReceiver &,CoroContext & coro_cx,const Args &...)283 static void* operator new(std::size_t size,
284 const MethodReceiver&,
285 CoroContext& coro_cx,
286 const Args&...) noexcept {
287 return SharedNew(coro_cx, size, alignof(std::max_align_t));
288 }
289
290 // Method-receiver form.
291 //
292 // This accepts alignment.
293 template <typename MethodReceiver, typename... Args>
new(std::size_t size,std::align_val_t align,const MethodReceiver &,CoroContext & coro_cx,const Args &...)294 static void* operator new(std::size_t size,
295 std::align_val_t align,
296 const MethodReceiver&,
297 CoroContext& coro_cx,
298 const Args&...) noexcept {
299 return SharedNew(coro_cx, size, static_cast<size_t>(align));
300 }
301
SharedNew(CoroContext & coro_cx,std::size_t size,std::size_t align)302 static void* SharedNew(CoroContext& coro_cx,
303 std::size_t size,
304 std::size_t align) noexcept {
305 auto ptr = coro_cx.alloc().Allocate(pw::allocator::Layout(size, align));
306 if (ptr == nullptr) {
307 internal::LogCoroAllocationFailure(size);
308 }
309 return ptr;
310 }
311
312 // Deallocate the space for both this `CoroPromiseType<T>` and the
313 // coroutine state.
314 //
315 // In reality, we do nothing here!!!
316 //
317 // Coroutines do not support `destroying_delete`, so we can't access
318 // `dealloc_` here, and therefore have no way to deallocate.
319 // Instead, deallocation is handled by `OwningCoroutineHandle<T>::Release`.
delete(void *)320 static void operator delete(void*) {}
321
322 // Handle a `co_await` call by accepting a type with a
323 // `Poll<U> Pend(Context&)` method, returning an `Awaitable` which will
324 // yield a `U` once complete.
325 template <typename Pendable>
326 requires(!std::is_reference_v<Pendable>)
await_transform(Pendable && pendable)327 Awaitable<Pendable, CoroPromiseType> await_transform(Pendable&& pendable) {
328 return pendable;
329 }
330
331 template <typename Pendable>
await_transform(Pendable & pendable)332 Awaitable<Pendable*, CoroPromiseType> await_transform(Pendable& pendable) {
333 return &pendable;
334 }
335
336 // Returns a reference to the `Context` passed in.
cx()337 Context& cx() { return *in_out_->input_cx; }
338
339 pw::allocator::Deallocator& dealloc_;
340 PendFillReturnValueFn currently_pending_;
341 InOut<T>* in_out_;
342 };
343
344 // The object created by invoking `co_await` in a `Coro<T>` function.
345 //
346 // This wraps a `Pendable` type and implements the awaitable interface
347 // expected by the standard coroutine API.
348 template <typename Pendable, typename PromiseType>
349 class Awaitable final {
350 public:
351 // The `OutputType` in `Poll<OutputType> Pendable::Pend(Context&)`.
352 using OutputType = std::remove_cvref_t<
353 decltype(std::declval<std::remove_pointer_t<Pendable>>()
354 .Pend(std::declval<Context&>())
355 .value())>;
356
Awaitable(Pendable && pendable)357 Awaitable(Pendable&& pendable) : state_(std::forward<Pendable>(pendable)) {}
358
359 // Confirms that `await_suspend` must be invoked.
await_ready()360 bool await_ready() { return false; }
361
362 // Returns whether or not the current coroutine should be suspended.
363 //
364 // This is invoked once as part of every `co_await` call after
365 // `await_ready` returns `false`.
366 //
367 // In the process, this method attempts to complete the inner `Pendable`
368 // before suspending this coroutine.
await_suspend(const std::coroutine_handle<PromiseType> & promise)369 bool await_suspend(const std::coroutine_handle<PromiseType>& promise) {
370 Context& cx = promise.promise().cx();
371 if (PendFillReturnValue(cx).IsPending()) {
372 /// The coroutine should suspend since the await-ed thing is pending.
373 promise.promise().currently_pending_ = [this](Context& lambda_cx) {
374 return PendFillReturnValue(lambda_cx);
375 };
376 return true;
377 }
378 return false;
379 }
380
381 // Returns `return_value`.
382 //
383 // This is automatically invoked by the language runtime when the promise's
384 // `resume()` method is called.
await_resume()385 OutputType&& await_resume() {
386 return std::move(std::get<OutputType>(state_));
387 }
388
PendableNoPtr()389 auto& PendableNoPtr() {
390 if constexpr (std::is_pointer_v<Pendable>) {
391 return *std::get<Pendable>(state_);
392 } else {
393 return std::get<Pendable>(state_);
394 }
395 }
396
397 // Attempts to complete the `Pendable` value, storing its return value
398 // upon completion.
399 //
400 // This method must return `Ready()` before the coroutine can be safely
401 // resumed, as otherwise the return value will not be available when
402 // `await_resume` is called to produce the result of `co_await`.
PendFillReturnValue(Context & cx)403 Poll<> PendFillReturnValue(Context& cx) {
404 Poll<OutputType> poll_res(PendableNoPtr().Pend(cx));
405 if (poll_res.IsPending()) {
406 return Pending();
407 }
408 state_ = std::move(*poll_res);
409 return Ready();
410 }
411
412 private:
413 std::variant<Pendable, OutputType> state_;
414 };
415
416 } // namespace internal
417
418 /// An asynchronous coroutine which implements the C++20 coroutine API.
419 ///
420 /// # Why coroutines?
421 /// Coroutines allow a series of asynchronous operations to be written as
422 /// straight line code. Rather than manually writing a state machine, users can
423 /// `co_await` any Pigweed asynchronous value (types with a
424 /// `Poll<T> Pend(Context&)` method).
425 ///
426 /// # Allocation
427 /// Pigweed's `Coro<T>` API supports checked, fallible, heap-free allocation.
428 /// The first argument to any coroutine function must be a
429 /// `CoroContext` (or a reference to one). This allows the
430 /// coroutine to allocate space for asynchronously-held stack variables using
431 /// the allocator member of the `CoroContext`.
432 ///
433 /// Failure to allocate coroutine "stack" space will result in the `Coro<T>`
434 /// returning `Status::Invalid()`.
435 ///
436 /// # Creating a coroutine function
437 /// To create a coroutine, a function must:
438 /// - Have an annotated return type of `Coro<T>` where `T` is some type
439 /// constructible from `pw::Status`, such as `pw::Status` or
440 /// `pw::Result<U>`.
441 /// - Use `co_return <value>` rather than `return <value>` for any
442 /// `return` statements. This also requires the use of `PW_CO_TRY` and
443 /// `PW_CO_TRY_ASSIGN` rather than `PW_TRY` and `PW_TRY_ASSIGN`.
444 /// - Accept a value convertible to `pw::allocator::Allocator&` as its first
445 /// argument. This allocator will be used to allocate storage for coroutine
446 /// stack variables held across a `co_await` point.
447 ///
448 /// # Using co_await
449 /// Inside a coroutine function, `co_await <expr>` can be used on any type
450 /// with a `Poll<T> Pend(Context&)` method. The result will be a value of
451 /// type `T`.
452 ///
453 /// # Example
454 /// @rst
455 /// .. literalinclude:: examples/basic.cc
456 /// :language: cpp
457 /// :linenos:
458 /// :start-after: [pw_async2-examples-basic-coro]
459 /// :end-before: [pw_async2-examples-basic-coro]
460 /// @endrst
461 template <std::constructible_from<pw::Status> T>
462 class Coro final {
463 public:
464 /// Creates an empty, invalid coroutine object.
Empty()465 static Coro Empty() {
466 return Coro(internal::OwningCoroutineHandle<promise_type>(nullptr));
467 }
468
469 /// Whether or not this `Coro<T>` is a valid coroutine.
470 ///
471 /// This will return `false` if coroutine state allocation failed or if
472 /// this `Coro<T>::Pend` method previously returned a `Ready` value.
IsValid()473 [[nodiscard]] bool IsValid() const { return promise_handle_.IsValid(); }
474
475 /// Attempt to complete this coroutine, returning the result if complete.
476 ///
477 /// Returns `Status::Internal()` if `!IsValid()`, which may occur if
478 /// coroutine state allocation fails.
Pend(Context & cx)479 Poll<T> Pend(Context& cx) {
480 if (!IsValid()) {
481 // This coroutine failed to allocate its internal state.
482 // (Or `Pend` is being erroniously invoked after previously completing.)
483 return Ready(Status::Internal());
484 }
485
486 // If an `Awaitable` value is currently being processed, it must be
487 // allowed to complete and store its return value before we can resume
488 // the coroutine.
489 if (promise_handle_.promise().currently_pending_ != nullptr &&
490 promise_handle_.promise().currently_pending_(cx).IsPending()) {
491 return Pending();
492 }
493 // DOCSTAG: [pw_async2-coro-resume]
494 // Create the arguments (and output storage) for the coroutine.
495 internal::InOut<T> in_out;
496 internal::OptionalOrDefault<T> return_value;
497 in_out.input_cx = &cx;
498 in_out.output = &return_value;
499 promise_handle_.promise().in_out_ = &in_out;
500
501 // Resume the coroutine, triggering `Awaitable::await_resume()` and the
502 // returning of the resulting value from `co_await`.
503 promise_handle_.resume();
504 if (!promise_handle_.done()) {
505 return Pending();
506 }
507
508 // Destroy the coroutine state: it has completed, and further calls to
509 // `resume` would result in undefined behavior.
510 promise_handle_.Release();
511
512 // When the coroutine completed in `resume()` above, it stored its
513 // `co_return` value into `return_value`. This retrieves that value.
514 return return_value;
515 // DOCSTAG: [pw_async2-coro-resume]
516 }
517
518 /// Used by the compiler in order to create a `Coro<T>` from a coroutine
519 /// function.
520 using promise_type = ::pw::async2::internal::CoroPromiseType<T>;
521
522 private:
523 // Allow `CoroPromiseType<T>::get_return_object()` and
524 // `CoroPromiseType<T>::get_retunr_object_on_allocation_failure()` to
525 // use the private constructor below.
526 friend promise_type;
527
528 /// Create a new `Coro<T>` using a (possibly null) handle.
Coro(internal::OwningCoroutineHandle<promise_type> && promise_handle)529 Coro(internal::OwningCoroutineHandle<promise_type>&& promise_handle)
530 : promise_handle_(std::move(promise_handle)) {}
531
532 internal::OwningCoroutineHandle<promise_type> promise_handle_;
533 };
534
535 // Implement the remaining internal pieces that require a definition of
536 // `Coro<T>`.
537 namespace internal {
538
539 template <typename T>
get_return_object()540 Coro<T> CoroPromiseType<T>::get_return_object() {
541 return internal::OwningCoroutineHandle<CoroPromiseType<T>>(
542 std::coroutine_handle<CoroPromiseType<T>>::from_promise(*this));
543 }
544
545 template <typename T>
get_return_object_on_allocation_failure()546 Coro<T> CoroPromiseType<T>::get_return_object_on_allocation_failure() {
547 return internal::OwningCoroutineHandle<CoroPromiseType<T>>(nullptr);
548 }
549
550 } // namespace internal
551 } // namespace pw::async2
552