// Copyright 2023 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef GRPC_SRC_CORE_LIB_PROMISE_PARTY_H
#define GRPC_SRC_CORE_LIB_PROMISE_PARTY_H

#include <grpc/support/port_platform.h>

#include <stddef.h>
#include <stdint.h>

#include <atomic>
#include <string>
#include <utility>

#include "absl/base/thread_annotations.h"
#include "absl/strings/string_view.h"

#include <grpc/event_engine/event_engine.h>
#include <grpc/support/log.h>

#include "src/core/lib/debug/trace.h"
#include "src/core/lib/gprpp/construct_destruct.h"
#include "src/core/lib/gprpp/crash.h"
#include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/gprpp/sync.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/context.h"
#include "src/core/lib/promise/detail/promise_factory.h"
#include "src/core/lib/promise/trace.h"
#include "src/core/lib/resource_quota/arena.h"

// Two implementations of party synchronization are provided: one using a
// single atomic, the other using a mutex and a set of state variables.
// The atomic implementation came first, but we found race conditions on Arm
// that our default TSAN configuration did not report. The mutex
// implementation was added to see whether it would fix the problem, and it
// did. We later tracked down the underlying race, so there's no known reason
// to use the mutex version - however we keep it around as a just-in-case
// measure.
// There's a thought of fuzzing the two implementations against each other as
// a correctness check of both, but that's not implemented yet.

#define GRPC_PARTY_SYNC_USING_ATOMICS
// #define GRPC_PARTY_SYNC_USING_MUTEX

#if defined(GRPC_PARTY_SYNC_USING_ATOMICS) +    \
        defined(GRPC_PARTY_SYNC_USING_MUTEX) != \
    1
#error Must define a party sync mechanism
#endif
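
// To experiment with the mutex-based implementation instead, swap which of
// the two macros above is defined (exactly one must be defined):
//   // #define GRPC_PARTY_SYNC_USING_ATOMICS
//   #define GRPC_PARTY_SYNC_USING_MUTEX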

namespace grpc_core {

namespace party_detail {

// Number of bits reserved for wakeups gives us the maximum number of
// participants.
static constexpr size_t kMaxParticipants = 16;

}  // namespace party_detail

class PartySyncUsingAtomics {
 public:
  explicit PartySyncUsingAtomics(size_t initial_refs)
      : state_(kOneRef * initial_refs) {}

  void IncrementRefCount() {
    state_.fetch_add(kOneRef, std::memory_order_relaxed);
  }
  GRPC_MUST_USE_RESULT bool RefIfNonZero();
  // Returns true if the ref count is now zero and the caller should call
  // PartyOver.
  GRPC_MUST_USE_RESULT bool Unref() {
    uint64_t prev_state = state_.fetch_sub(kOneRef, std::memory_order_acq_rel);
    if ((prev_state & kRefMask) == kOneRef) {
      return UnreffedLast();
    }
    return false;
  }
  void ForceImmediateRepoll(WakeupMask mask) {
    // OR in the bit for the currently polling participant.
    // It will be grabbed next round to force a repoll of this promise.
    state_.fetch_or(mask, std::memory_order_relaxed);
  }

  // Run the update loop: poll_one_participant is called with an integral index
  // for the participant that should be polled. It should return true if the
  // participant completed and should be removed from the allocated set.
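  //
  // A hedged caller sketch (PollParticipant is a hypothetical helper, not
  // part of this class):
  //   bool destroyed = sync.RunParty([&](size_t i) {
  //     // Poll participant i; return true once its promise resolves.
  //     return PollParticipant(i);
  //   });
  //   if (destroyed) { /* refs hit zero mid-run: finish destruction */ }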
  template <typename F>
  GRPC_MUST_USE_RESULT bool RunParty(F poll_one_participant) {
    uint64_t prev_state;
    do {
      // Grab the current state, and clear the wakeup bits.
      prev_state = state_.fetch_and(kRefMask | kLocked | kAllocatedMask,
                                    std::memory_order_acquire);
      GPR_ASSERT(prev_state & kLocked);
      if (prev_state & kDestroying) return true;
      // From the previous state, extract which participants need waking.
      uint64_t wakeups = prev_state & kWakeupMask;
      // Now update prev_state to be what we want the CAS to see below.
      prev_state &= kRefMask | kLocked | kAllocatedMask;
      // For each wakeup bit...
      for (size_t i = 0; wakeups != 0; i++, wakeups >>= 1) {
        // If the bit is not set, skip.
        if ((wakeups & 1) == 0) continue;
        if (poll_one_participant(i)) {
          const uint64_t allocated_bit = (1u << i << kAllocatedShift);
          prev_state &= ~allocated_bit;
          state_.fetch_and(~allocated_bit, std::memory_order_release);
        }
      }
      // Try to CAS the state we expected to have (with no wakeups or adds)
      // back to unlocked (by masking in only the ref mask - sans locked bit).
      // If this succeeds then no wakeups arrived, no participants were added,
      // and we have successfully unlocked.
      // Otherwise, we need to loop again.
      // Note that if an owning waker is created or the weak CAS spuriously
      // fails we will also loop again, but in that case we see no wakeups or
      // adds and so will get back here fairly quickly.
      // TODO(ctiller): consider mitigations for the accidental wakeup on
      // owning waker creation case -- I currently expect this will be more
      // expensive than this quick loop.
    } while (!state_.compare_exchange_weak(
        prev_state, (prev_state & (kRefMask | kAllocatedMask)),
        std::memory_order_acq_rel, std::memory_order_acquire));
    return false;
  }


  // Add new participants to the party. Returns true if the caller should run
  // the party. store is called with an array of indices of the new
  // participants. Adds a ref that should be dropped by the caller after
  // RunParty has been called (if that was required).
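  //
  // A hedged usage sketch (illustrative only; `participants` is a
  // hypothetical caller-owned table):
  //   Participant* parts[2] = {a, b};
  //   bool should_run = sync.AddParticipantsAndRef(2, [&](size_t* slots) {
  //     for (size_t i = 0; i < 2; i++) participants[slots[i]] = parts[i];
  //   });
  //   if (should_run) { /* run the party */ }
  //   /* later, drop the ref this call took: */
  //   if (sync.Unref()) { /* destroy */ }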
  template <typename F>
  GRPC_MUST_USE_RESULT bool AddParticipantsAndRef(size_t count, F store) {
    uint64_t state = state_.load(std::memory_order_acquire);
    uint64_t allocated;

    size_t slots[party_detail::kMaxParticipants];

    // Find slots for each new participant, ordering them from lowest available
    // slot upwards to ensure the same poll ordering as presentation ordering to
    // this function.
    WakeupMask wakeup_mask;
    do {
      wakeup_mask = 0;
      allocated = (state & kAllocatedMask) >> kAllocatedShift;
      size_t n = 0;
      for (size_t bit = 0; n < count && bit < party_detail::kMaxParticipants;
           bit++) {
        if (allocated & (1 << bit)) continue;
        wakeup_mask |= (1 << bit);
        slots[n++] = bit;
        allocated |= 1 << bit;
      }
      GPR_ASSERT(n == count);
      // Try to allocate these slots and take a ref (atomically).
      // The ref needs to be taken because once we store the participants they
      // could be spuriously woken up and unref the party.
    } while (!state_.compare_exchange_weak(
        state, (state | (allocated << kAllocatedShift)) + kOneRef,
        std::memory_order_acq_rel, std::memory_order_acquire));

    store(slots);

    // Now we need to wake up the party.
    state = state_.fetch_or(wakeup_mask | kLocked, std::memory_order_release);

    // If the party was already locked, the lock owner will see the new
    // wakeups and poll the participants; otherwise the caller should run the
    // party now.
    return ((state & kLocked) == 0);
  }


  // Schedule a wakeup for the given participant.
  // Returns true if the caller should run the party.
  GRPC_MUST_USE_RESULT bool ScheduleWakeup(WakeupMask mask);

 private:
  bool UnreffedLast();

  // State bits:
  // The atomic state_ field is composed of the following:
  //   - 24 bits for ref counts
  //     1 is owned by the party prior to Orphan()
  //     All others are owned by owning wakers
  //   - 1 bit to indicate whether the party is locked
  //     The first thread to set this owns the party until it is unlocked.
  //     That thread will run the main loop until no further work needs to
  //     be done.
  //   - 1 bit to indicate that destruction has begun (the ref count reached
  //     zero)
  //   - 16 bits, one per participant slot, indicating which slots have been
  //     allocated
  //   - 16 bits, one per participant, indicating which participants have been
  //     woken up and should be polled next time the main loop runs.

  // clang-format off
  // Bits used to store 16 bits of wakeups
  static constexpr uint64_t kWakeupMask    = 0x0000'0000'0000'ffff;
  // Bits used to store 16 bits of allocated participant slots.
  static constexpr uint64_t kAllocatedMask = 0x0000'0000'ffff'0000;
  // Bit indicating destruction has begun (refs went to zero)
  static constexpr uint64_t kDestroying    = 0x0000'0001'0000'0000;
  // Bit indicating locked or not
  static constexpr uint64_t kLocked        = 0x0000'0008'0000'0000;
  // Bits used to store 24 bits of ref counts
  static constexpr uint64_t kRefMask       = 0xffff'ff00'0000'0000;
  // clang-format on

  // Shift to get from a participant mask to an allocated mask.
  static constexpr size_t kAllocatedShift = 16;
  // How far to shift to get the refcount
  static constexpr size_t kRefShift = 40;
  // One ref count
  static constexpr uint64_t kOneRef = 1ull << kRefShift;
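
  // Worked example (illustrative arithmetic, not a state named anywhere in
  // the code): state_ == 0x0000'0108'0003'0001 decodes as
  //   refs       = 1               (0x01 at kRefShift, i.e. one kOneRef)
  //   locked     = true            (kLocked bit set)
  //   destroying = false           (kDestroying bit clear)
  //   allocated  = slots 0 and 1   (0x0003 << kAllocatedShift)
  //   wakeups    = slot 0          (bit 0 of kWakeupMask)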

  std::atomic<uint64_t> state_;
};

class PartySyncUsingMutex {
 public:
  explicit PartySyncUsingMutex(size_t initial_refs) : refs_(initial_refs) {}

  void IncrementRefCount() { refs_.Ref(); }
  GRPC_MUST_USE_RESULT bool RefIfNonZero() { return refs_.RefIfNonZero(); }
  GRPC_MUST_USE_RESULT bool Unref() { return refs_.Unref(); }
  void ForceImmediateRepoll(WakeupMask mask) {
    MutexLock lock(&mu_);
    wakeups_ |= mask;
  }
  template <typename F>
  GRPC_MUST_USE_RESULT bool RunParty(F poll_one_participant) {
    WakeupMask freed = 0;
    while (true) {
      ReleasableMutexLock lock(&mu_);
      GPR_ASSERT(locked_);
      allocated_ &= ~std::exchange(freed, 0);
      auto wakeup = std::exchange(wakeups_, 0);
      if (wakeup == 0) {
        locked_ = false;
        return false;
      }
      lock.Release();
      for (size_t i = 0; wakeup != 0; i++, wakeup >>= 1) {
        if ((wakeup & 1) == 0) continue;
        if (poll_one_participant(i)) freed |= 1 << i;
      }
    }
  }

  template <typename F>
  GRPC_MUST_USE_RESULT bool AddParticipantsAndRef(size_t count, F store) {
    IncrementRefCount();
    MutexLock lock(&mu_);
    size_t slots[party_detail::kMaxParticipants];
    WakeupMask wakeup_mask = 0;
    size_t n = 0;
    for (size_t bit = 0; n < count && bit < party_detail::kMaxParticipants;
         bit++) {
      if (allocated_ & (1 << bit)) continue;
      slots[n++] = bit;
      wakeup_mask |= 1 << bit;
      allocated_ |= 1 << bit;
    }
    GPR_ASSERT(n == count);
    store(slots);
    wakeups_ |= wakeup_mask;
    return !std::exchange(locked_, true);
  }

  GRPC_MUST_USE_RESULT bool ScheduleWakeup(WakeupMask mask);

 private:
  RefCount refs_;
  Mutex mu_;
  WakeupMask allocated_ ABSL_GUARDED_BY(mu_) = 0;
  WakeupMask wakeups_ ABSL_GUARDED_BY(mu_) = 0;
  bool locked_ ABSL_GUARDED_BY(mu_) = false;
};

// A Party is an Activity with multiple participant promises.
class Party : public Activity, private Wakeable {
 private:
  // Non-owning wakeup handle.
  class Handle;

  // One participant in the party.
  class Participant {
   public:
    explicit Participant(absl::string_view name) : name_(name) {}
    // Poll the participant. Return true if complete.
    // Participant should take care of its own deallocation in this case.
    virtual bool Poll() = 0;

    // Destroy the participant before finishing.
    virtual void Destroy() = 0;

    // Return a Handle instance for this participant.
    Wakeable* MakeNonOwningWakeable(Party* party);

    absl::string_view name() const { return name_; }

   protected:
    ~Participant();

   private:
    Handle* handle_ = nullptr;
    absl::string_view name_;
  };

 public:
  Party(const Party&) = delete;
  Party& operator=(const Party&) = delete;

  // Spawn one promise into the party.
  // The promise will be polled until it is resolved, or until the party is
  // shut down.
  // The on_complete callback will be called with the result of the promise if
  // it completes.
  // A maximum of sixteen promises can be spawned onto a party.
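  //
  // Illustrative sketch only (FooPromise and FooResult are hypothetical
  // stand-ins for any promise factory and compatible completion callback):
  //   party->Spawn(
  //       "foo",
  //       [] { return FooPromise(); },                 // promise factory
  //       [](FooResult result) { /* use result */ });  // on_complete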
  template <typename Factory, typename OnComplete>
  void Spawn(absl::string_view name, Factory promise_factory,
             OnComplete on_complete);

  void Orphan() final { Crash("unused"); }

  // Activity implementation: not allowed to be overridden by derived types.
  void ForceImmediateRepoll(WakeupMask mask) final;
  WakeupMask CurrentParticipant() const final {
    GPR_DEBUG_ASSERT(currently_polling_ != kNotPolling);
    return 1u << currently_polling_;
  }
  Waker MakeOwningWaker() final;
  Waker MakeNonOwningWaker() final;
  std::string ActivityDebugTag(WakeupMask wakeup_mask) const final;

  void IncrementRefCount() { sync_.IncrementRefCount(); }
  void Unref() {
    if (sync_.Unref()) PartyIsOver();
  }
  RefCountedPtr<Party> Ref() {
    IncrementRefCount();
    return RefCountedPtr<Party>(this);
  }

  Arena* arena() const { return arena_; }

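  // Spawns several promises as a batch: participants are buffered locally
  // and added to the party (with a single wakeup) when the BulkSpawner is
  // destroyed.
  //
  // Illustrative sketch (PromiseA/PromiseB and their result types are
  // hypothetical):
  //   {
  //     Party::BulkSpawner spawner(party);
  //     spawner.Spawn("a", [] { return PromiseA(); }, [](ResultA) {});
  //     spawner.Spawn("b", [] { return PromiseB(); }, [](ResultB) {});
  //   }  // <-- participants added & party woken here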
  class BulkSpawner {
   public:
    explicit BulkSpawner(Party* party) : party_(party) {}
    ~BulkSpawner() {
      party_->AddParticipants(participants_, num_participants_);
    }

    template <typename Factory, typename OnComplete>
    void Spawn(absl::string_view name, Factory promise_factory,
               OnComplete on_complete);

   private:
    Party* const party_;
    size_t num_participants_ = 0;
    Participant* participants_[party_detail::kMaxParticipants];
  };

 protected:
  explicit Party(Arena* arena, size_t initial_refs)
      : sync_(initial_refs), arena_(arena) {}
  ~Party() override;

  // Main run loop. Must be locked.
  // Polls participants and drains the add queue until there is no work left
  // to be done.
  // Derived types will likely want to override this to set up their
  // contexts before polling.
  // Should not be called by derived types except as a tail call to the base
  // class RunParty when overriding this method to add custom context.
  // Returns true if the party is over.
  virtual bool RunParty() GRPC_MUST_USE_RESULT;

  bool RefIfNonZero() { return sync_.RefIfNonZero(); }

  // Destroy any remaining participants.
  // Should be called by derived types in response to PartyOver.
  // Needs to have normal context setup before calling.
  void CancelRemainingParticipants();

 private:
  // Concrete implementation of a participant for some promise & oncomplete
  // type.
  template <typename SuppliedFactory, typename OnComplete>
  class ParticipantImpl final : public Participant {
    using Factory = promise_detail::OncePromiseFactory<void, SuppliedFactory>;
    using Promise = typename Factory::Promise;

   public:
    ParticipantImpl(absl::string_view name, SuppliedFactory promise_factory,
                    OnComplete on_complete)
        : Participant(name), on_complete_(std::move(on_complete)) {
      Construct(&factory_, std::move(promise_factory));
    }
    ~ParticipantImpl() {
      if (!started_) {
        Destruct(&factory_);
      } else {
        Destruct(&promise_);
      }
    }

    bool Poll() override {
      if (!started_) {
        auto p = factory_.Make();
        Destruct(&factory_);
        Construct(&promise_, std::move(p));
        started_ = true;
      }
      auto p = promise_();
      if (auto* r = p.value_if_ready()) {
        on_complete_(std::move(*r));
        GetContext<Arena>()->DeletePooled(this);
        return true;
      }
      return false;
    }

    void Destroy() override { GetContext<Arena>()->DeletePooled(this); }

   private:
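    // Before the first Poll() this union holds the promise factory; once
    // started_ is set it holds the promise made by that factory. Exactly one
    // member is live at a time, tracked by started_.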
    union {
      GPR_NO_UNIQUE_ADDRESS Factory factory_;
      GPR_NO_UNIQUE_ADDRESS Promise promise_;
    };
    GPR_NO_UNIQUE_ADDRESS OnComplete on_complete_;
    bool started_ = false;
  };

  // Notification that the party has finished and this instance can be deleted.
  // Derived types should arrange to call CancelRemainingParticipants during
  // this sequence.
  virtual void PartyOver() = 0;

  // Run the locked part of the party until it is unlocked.
  void RunLocked();
  // Called in response to Unref() hitting zero - ultimately calls PartyOver,
  // but needs to set some stuff up.
  // Here so it gets compiled out of line.
  void PartyIsOver();

  // Wakeable implementation
  void Wakeup(WakeupMask wakeup_mask) final;
  void WakeupAsync(WakeupMask wakeup_mask) final;
  void Drop(WakeupMask wakeup_mask) final;

  // Add a participant (backs Spawn, after type erasure to Participant).
  void AddParticipants(Participant** participant, size_t count);

  virtual grpc_event_engine::experimental::EventEngine* event_engine()
      const = 0;

  // Sentinel value for currently_polling_ when no participant is being polled.
  static constexpr uint8_t kNotPolling = 255;

#ifdef GRPC_PARTY_SYNC_USING_ATOMICS
  PartySyncUsingAtomics sync_;
#elif defined(GRPC_PARTY_SYNC_USING_MUTEX)
  PartySyncUsingMutex sync_;
#else
#error No synchronization method defined
#endif

  Arena* const arena_;
  uint8_t currently_polling_ = kNotPolling;
  // All current participants, using a tagged format.
  // If the lower bit is unset, then this is a Participant*.
  // If the lower bit is set, then this is a ParticipantFactory*.
  std::atomic<Participant*> participants_[party_detail::kMaxParticipants] = {};
};

template <typename Factory, typename OnComplete>
void Party::BulkSpawner::Spawn(absl::string_view name, Factory promise_factory,
                               OnComplete on_complete) {
  if (grpc_trace_promise_primitives.enabled()) {
    gpr_log(GPR_DEBUG, "%s[bulk_spawn] On %p queue %s",
            party_->DebugTag().c_str(), this, std::string(name).c_str());
  }
  participants_[num_participants_++] =
      party_->arena_->NewPooled<ParticipantImpl<Factory, OnComplete>>(
          name, std::move(promise_factory), std::move(on_complete));
}

template <typename Factory, typename OnComplete>
void Party::Spawn(absl::string_view name, Factory promise_factory,
                  OnComplete on_complete) {
  BulkSpawner(this).Spawn(name, std::move(promise_factory),
                          std::move(on_complete));
}

}  // namespace grpc_core

#endif  // GRPC_SRC_CORE_LIB_PROMISE_PARTY_H