// Copyright 2023 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef GRPC_SRC_CORE_LIB_PROMISE_PARTY_H
#define GRPC_SRC_CORE_LIB_PROMISE_PARTY_H

#include <grpc/support/port_platform.h>

#include <stddef.h>
#include <stdint.h>

#include <atomic>
#include <string>
#include <utility>

#include "absl/base/thread_annotations.h"
#include "absl/strings/string_view.h"

#include <grpc/event_engine/event_engine.h>
#include <grpc/support/log.h>

#include "src/core/lib/debug/trace.h"
#include "src/core/lib/gprpp/construct_destruct.h"
#include "src/core/lib/gprpp/crash.h"
#include "src/core/lib/gprpp/ref_counted.h"
#include "src/core/lib/gprpp/ref_counted_ptr.h"
#include "src/core/lib/gprpp/sync.h"
#include "src/core/lib/promise/activity.h"
#include "src/core/lib/promise/context.h"
#include "src/core/lib/promise/detail/promise_factory.h"
#include "src/core/lib/promise/trace.h"
#include "src/core/lib/resource_quota/arena.h"

// Two implementations of party synchronization are provided: one using a
// single atomic, the other using a mutex and a set of state variables.
// Originally only the atomic implementation existed, but we found some race
// conditions on Arm that were not reported by our default TSAN setup. The
// mutex implementation was added to see whether it would fix the problem, and
// it did. Later we found the underlying race condition, so there's no known
// reason to use the mutex version - however we keep it around as a
// just-in-case measure. There's a thought of fuzzing the two implementations
// against each other as a correctness check of both, but that's not
// implemented yet.

#define GRPC_PARTY_SYNC_USING_ATOMICS
// #define GRPC_PARTY_SYNC_USING_MUTEX

#if defined(GRPC_PARTY_SYNC_USING_ATOMICS) + \
        defined(GRPC_PARTY_SYNC_USING_MUTEX) != \
    1
#error Must define a party sync mechanism
#endif

namespace grpc_core {

namespace party_detail {

// Number of bits reserved for wakeups gives us the maximum number of
// participants.
static constexpr size_t kMaxParticipants = 16;

}  // namespace party_detail

class PartySyncUsingAtomics {
 public:
  explicit PartySyncUsingAtomics(size_t initial_refs)
      : state_(kOneRef * initial_refs) {}

  void IncrementRefCount() {
    state_.fetch_add(kOneRef, std::memory_order_relaxed);
  }
  GRPC_MUST_USE_RESULT bool RefIfNonZero();
  // Returns true if the ref count is now zero and the caller should call
  // PartyOver.
  GRPC_MUST_USE_RESULT bool Unref() {
    uint64_t prev_state = state_.fetch_sub(kOneRef, std::memory_order_acq_rel);
    if ((prev_state & kRefMask) == kOneRef) {
      return UnreffedLast();
    }
    return false;
  }
  void ForceImmediateRepoll(WakeupMask mask) {
    // Or in the bit for the currently polling participant.
    // Will be grabbed next round to force a repoll of this promise.
    state_.fetch_or(mask, std::memory_order_relaxed);
  }

  // Run the update loop: poll_one_participant is called with an integral
  // index for the participant that should be polled. It should return true if
  // the participant completed and should be removed from the allocated set.
  template <typename F>
  GRPC_MUST_USE_RESULT bool RunParty(F poll_one_participant) {
    uint64_t prev_state;
    do {
      // Grab the current state, and clear the wakeup bits & add flag.
      prev_state = state_.fetch_and(kRefMask | kLocked | kAllocatedMask,
                                    std::memory_order_acquire);
      GPR_ASSERT(prev_state & kLocked);
      if (prev_state & kDestroying) return true;
      // From the previous state, extract which participants we're to wake up.
      uint64_t wakeups = prev_state & kWakeupMask;
      // Now update prev_state to be what we want the CAS below to see.
      prev_state &= kRefMask | kLocked | kAllocatedMask;
      // For each wakeup bit...
      for (size_t i = 0; wakeups != 0; i++, wakeups >>= 1) {
        // If the bit is not set, skip.
        if ((wakeups & 1) == 0) continue;
        if (poll_one_participant(i)) {
          const uint64_t allocated_bit = (1u << i << kAllocatedShift);
          prev_state &= ~allocated_bit;
          state_.fetch_and(~allocated_bit, std::memory_order_release);
        }
      }
      // Try to CAS the state we expected to have (with no wakeups or adds)
      // back to unlocked (by masking in only the ref mask - sans locked bit).
      // If this succeeds then no wakeups or adds arrived in the meantime, and
      // we have successfully unlocked.
      // Otherwise, we need to loop again.
      // Note that if an owning waker is created or the weak CAS spuriously
      // fails we will also loop again, but in that case we'll see no wakeups
      // or adds and so will get back here fairly quickly.
      // TODO(ctiller): consider mitigations for the accidental wakeup on
      // owning waker creation case -- I currently expect this will be more
      // expensive than this quick loop.
    } while (!state_.compare_exchange_weak(
        prev_state, (prev_state & (kRefMask | kAllocatedMask)),
        std::memory_order_acq_rel, std::memory_order_acquire));
    return false;
  }
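  // Illustrative driver sequence (a sketch of the documented contract between
  // RunParty, AddParticipantsAndRef (below), and Unref; `sync` is a
  // hypothetical PartySyncUsingAtomics instance, not the actual Party code):
  //
  //   if (sync.AddParticipantsAndRef(count, store)) {
  //     // We own the lock: poll until no work remains.
  //     bool party_over = sync.RunParty(poll_one_participant);
  //     if (party_over) { /* begin teardown */ }
  //   }
  //   // Drop the ref taken by AddParticipantsAndRef.
  //   if (sync.Unref()) { /* last ref: call PartyOver */ }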
  // Add new participants to the party. Returns true if the caller should run
  // the party. store is called with an array of indices of the new
  // participants. Adds a ref that should be dropped by the caller after
  // RunParty has been called (if that was required).
  template <typename F>
  GRPC_MUST_USE_RESULT bool AddParticipantsAndRef(size_t count, F store) {
    uint64_t state = state_.load(std::memory_order_acquire);
    uint64_t allocated;

    size_t slots[party_detail::kMaxParticipants];

    // Find slots for each new participant, ordering them from the lowest
    // available slot upwards to ensure the same poll ordering as the order of
    // presentation to this function.
    WakeupMask wakeup_mask;
    do {
      wakeup_mask = 0;
      allocated = (state & kAllocatedMask) >> kAllocatedShift;
      size_t n = 0;
      for (size_t bit = 0; n < count && bit < party_detail::kMaxParticipants;
           bit++) {
        if (allocated & (1 << bit)) continue;
        wakeup_mask |= (1 << bit);
        slots[n++] = bit;
        allocated |= 1 << bit;
      }
      GPR_ASSERT(n == count);
      // Try to allocate these slots and take a ref (atomically).
      // The ref needs to be taken because once we store the participant it
      // could be spuriously woken up and unref the party.
    } while (!state_.compare_exchange_weak(
        state, (state | (allocated << kAllocatedShift)) + kOneRef,
        std::memory_order_acq_rel, std::memory_order_acquire));

    store(slots);

    // Now we need to wake up the party.
    state = state_.fetch_or(wakeup_mask | kLocked, std::memory_order_release);

    // If the party was already locked, we're done.
    return ((state & kLocked) == 0);
  }

  // Schedule a wakeup for the given participant.
  // Returns true if the caller should run the party.
  GRPC_MUST_USE_RESULT bool ScheduleWakeup(WakeupMask mask);

 private:
  bool UnreffedLast();

  // State bits:
  // The atomic state_ field is composed of the following:
  //   - 24 bits for ref counts
  //     1 is owned by the party prior to Orphan()
  //     All others are owned by owning wakers
  //   - 1 bit to indicate whether the party is locked
  //     The first thread to set this owns the party until it is unlocked
  //     That thread will run the main loop until no further work needs to
  //     be done.
  //   - 1 bit to indicate whether there are participants waiting to be added
  //   - 16 bits, one per participant, indicating which participants have been
  //     woken up and should be polled next time the main loop runs.

  // clang-format off
  // Bits used to store 16 bits of wakeups
  static constexpr uint64_t kWakeupMask    = 0x0000'0000'0000'ffff;
  // Bits used to store 16 bits of allocated participant slots.
  static constexpr uint64_t kAllocatedMask = 0x0000'0000'ffff'0000;
  // Bit indicating destruction has begun (refs went to zero)
  static constexpr uint64_t kDestroying    = 0x0000'0001'0000'0000;
  // Bit indicating locked or not
  static constexpr uint64_t kLocked        = 0x0000'0008'0000'0000;
  // Bits used to store 24 bits of ref counts
  static constexpr uint64_t kRefMask       = 0xffff'ff00'0000'0000;
  // clang-format on

  // Shift to get from a participant mask to an allocated mask.
  static constexpr size_t kAllocatedShift = 16;
  // How far to shift to get the refcount
  static constexpr size_t kRefShift = 40;
  // One ref count
  static constexpr uint64_t kOneRef = 1ull << kRefShift;
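  // Worked example (illustrative): a freshly constructed party holding one
  // ref and no participants has state_ == kOneRef == 0x0000'0100'0000'0000.
  // Once slot 0 has been allocated and a wakeup scheduled for it, the state
  // would read 0x0000'0100'0001'0001: one ref (kRefMask), slot 0 allocated
  // (kAllocatedMask), and a pending wakeup for participant 0 (kWakeupMask).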
  std::atomic<uint64_t> state_;
};

class PartySyncUsingMutex {
 public:
  explicit PartySyncUsingMutex(size_t initial_refs) : refs_(initial_refs) {}

  void IncrementRefCount() { refs_.Ref(); }
  GRPC_MUST_USE_RESULT bool RefIfNonZero() { return refs_.RefIfNonZero(); }
  GRPC_MUST_USE_RESULT bool Unref() { return refs_.Unref(); }
  void ForceImmediateRepoll(WakeupMask mask) {
    MutexLock lock(&mu_);
    wakeups_ |= mask;
  }

  template <typename F>
  GRPC_MUST_USE_RESULT bool RunParty(F poll_one_participant) {
    WakeupMask freed = 0;
    while (true) {
      ReleasableMutexLock lock(&mu_);
      GPR_ASSERT(locked_);
      allocated_ &= ~std::exchange(freed, 0);
      auto wakeup = std::exchange(wakeups_, 0);
      if (wakeup == 0) {
        locked_ = false;
        return false;
      }
      lock.Release();
      for (size_t i = 0; wakeup != 0; i++, wakeup >>= 1) {
        if ((wakeup & 1) == 0) continue;
        if (poll_one_participant(i)) freed |= 1 << i;
      }
    }
  }

  template <typename F>
  GRPC_MUST_USE_RESULT bool AddParticipantsAndRef(size_t count, F store) {
    IncrementRefCount();
    MutexLock lock(&mu_);
    size_t slots[party_detail::kMaxParticipants];
    WakeupMask wakeup_mask = 0;
    size_t n = 0;
    for (size_t bit = 0; n < count && bit < party_detail::kMaxParticipants;
         bit++) {
      if (allocated_ & (1 << bit)) continue;
      slots[n++] = bit;
      wakeup_mask |= 1 << bit;
      allocated_ |= 1 << bit;
    }
    GPR_ASSERT(n == count);
    store(slots);
    wakeups_ |= wakeup_mask;
    return !std::exchange(locked_, true);
  }

  GRPC_MUST_USE_RESULT bool ScheduleWakeup(WakeupMask mask);

 private:
  RefCount refs_;
  Mutex mu_;
  WakeupMask allocated_ ABSL_GUARDED_BY(mu_) = 0;
  WakeupMask wakeups_ ABSL_GUARDED_BY(mu_) = 0;
  bool locked_ ABSL_GUARDED_BY(mu_) = false;
};
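// Note: the two PartySync implementations above deliberately expose the same
// set of member functions. There is no shared base class; Party (below)
// selects one of them at compile time via the GRPC_PARTY_SYNC_USING_* macros,
// so the shared interface is purely duck-typed.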
// A Party is an Activity with multiple participant promises.
class Party : public Activity, private Wakeable {
 private:
  // Non-owning wakeup handle.
  class Handle;

  // One participant in the party.
  class Participant {
   public:
    explicit Participant(absl::string_view name) : name_(name) {}
    // Poll the participant. Return true if complete.
    // Participant should take care of its own deallocation in this case.
    virtual bool Poll() = 0;

    // Destroy the participant before finishing.
    virtual void Destroy() = 0;

    // Return a Handle instance for this participant.
    Wakeable* MakeNonOwningWakeable(Party* party);

    absl::string_view name() const { return name_; }

   protected:
    ~Participant();

   private:
    Handle* handle_ = nullptr;
    absl::string_view name_;
  };

 public:
  Party(const Party&) = delete;
  Party& operator=(const Party&) = delete;

  // Spawn one promise into the party.
  // The promise will be polled until it is resolved, or until the party is
  // shut down.
  // The on_complete callback will be called with the result of the promise if
  // it completes.
  // A maximum of sixteen promises can be spawned onto a party.
  template <typename Factory, typename OnComplete>
  void Spawn(absl::string_view name, Factory promise_factory,
             OnComplete on_complete);

  void Orphan() final { Crash("unused"); }

  // Activity implementation: not allowed to be overridden by derived types.
  void ForceImmediateRepoll(WakeupMask mask) final;
  WakeupMask CurrentParticipant() const final {
    GPR_DEBUG_ASSERT(currently_polling_ != kNotPolling);
    return 1u << currently_polling_;
  }
  Waker MakeOwningWaker() final;
  Waker MakeNonOwningWaker() final;
  std::string ActivityDebugTag(WakeupMask wakeup_mask) const final;

  void IncrementRefCount() { sync_.IncrementRefCount(); }
  void Unref() {
    if (sync_.Unref()) PartyIsOver();
  }
  RefCountedPtr<Party> Ref() {
    IncrementRefCount();
    return RefCountedPtr<Party>(this);
  }

  Arena* arena() const { return arena_; }

  class BulkSpawner {
   public:
    explicit BulkSpawner(Party* party) : party_(party) {}
    ~BulkSpawner() {
      party_->AddParticipants(participants_, num_participants_);
    }

    template <typename Factory, typename OnComplete>
    void Spawn(absl::string_view name, Factory promise_factory,
               OnComplete on_complete);

   private:
    Party* const party_;
    size_t num_participants_ = 0;
    Participant* participants_[party_detail::kMaxParticipants];
  };
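  // Usage sketch (illustrative only; `party`, the promise factories, and the
  // completion callbacks are hypothetical):
  //
  //   party->Spawn("reader", MakeReadLoop(), OnReadDone);
  //
  // or, to add several participants with a single party wakeup:
  //
  //   {
  //     Party::BulkSpawner spawner(party);
  //     spawner.Spawn("reader", MakeReadLoop(), OnReadDone);
  //     spawner.Spawn("writer", MakeWriteLoop(), OnWriteDone);
  //   }  // participants are added when the BulkSpawner is destroyed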
 protected:
  explicit Party(Arena* arena, size_t initial_refs)
      : sync_(initial_refs), arena_(arena) {}
  ~Party() override;

  // Main run loop. Must be locked.
  // Polls participants and drains the add queue until there is no work left
  // to be done.
  // Derived types will likely want to override this to set up their contexts
  // before polling.
  // Should not be called by derived types except as a tail call to the base
  // class RunParty when overriding this method to add custom context.
  // Returns true if the party is over.
  virtual bool RunParty() GRPC_MUST_USE_RESULT;

  bool RefIfNonZero() { return sync_.RefIfNonZero(); }

  // Destroy any remaining participants.
  // Should be called by derived types in response to PartyOver.
  // Needs to have normal context set up before calling.
  void CancelRemainingParticipants();

 private:
  // Concrete implementation of a participant for some promise & on_complete
  // type.
  template <typename SuppliedFactory, typename OnComplete>
  class ParticipantImpl final : public Participant {
    using Factory = promise_detail::OncePromiseFactory<void, SuppliedFactory>;
    using Promise = typename Factory::Promise;

   public:
    ParticipantImpl(absl::string_view name, SuppliedFactory promise_factory,
                    OnComplete on_complete)
        : Participant(name), on_complete_(std::move(on_complete)) {
      Construct(&factory_, std::move(promise_factory));
    }
    ~ParticipantImpl() {
      if (!started_) {
        Destruct(&factory_);
      } else {
        Destruct(&promise_);
      }
    }

    bool Poll() override {
      if (!started_) {
        auto p = factory_.Make();
        Destruct(&factory_);
        Construct(&promise_, std::move(p));
        started_ = true;
      }
      auto p = promise_();
      if (auto* r = p.value_if_ready()) {
        on_complete_(std::move(*r));
        GetContext<Arena>()->DeletePooled(this);
        return true;
      }
      return false;
    }

    void Destroy() override { GetContext<Arena>()->DeletePooled(this); }

   private:
    union {
      GPR_NO_UNIQUE_ADDRESS Factory factory_;
      GPR_NO_UNIQUE_ADDRESS Promise promise_;
    };
    GPR_NO_UNIQUE_ADDRESS OnComplete on_complete_;
    bool started_ = false;
  };
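  // Lifecycle of a ParticipantImpl (a summary of the code above): the factory
  // occupies the union until the first Poll(), which calls factory_.Make(),
  // destroys the factory, and constructs the resulting promise in the same
  // storage. Subsequent polls drive the promise; once it resolves,
  // on_complete_ runs with the result and the participant deletes itself from
  // the arena.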
  // Notification that the party has finished and this instance can be
  // deleted.
  // Derived types should arrange to call CancelRemainingParticipants during
  // this sequence.
  virtual void PartyOver() = 0;

  // Run the locked part of the party until it is unlocked.
  void RunLocked();
  // Called in response to Unref() hitting zero - ultimately calls PartyOver,
  // but needs to set some stuff up.
  // Here so it gets compiled out of line.
  void PartyIsOver();

  // Wakeable implementation
  void Wakeup(WakeupMask wakeup_mask) final;
  void WakeupAsync(WakeupMask wakeup_mask) final;
  void Drop(WakeupMask wakeup_mask) final;

  // Add a participant (backs Spawn, after type erasure to
  // ParticipantFactory).
  void AddParticipants(Participant** participant, size_t count);

  virtual grpc_event_engine::experimental::EventEngine* event_engine()
      const = 0;

  // Sentinel value for currently_polling_ when no participant is being
  // polled.
  static constexpr uint8_t kNotPolling = 255;

#ifdef GRPC_PARTY_SYNC_USING_ATOMICS
  PartySyncUsingAtomics sync_;
#elif defined(GRPC_PARTY_SYNC_USING_MUTEX)
  PartySyncUsingMutex sync_;
#else
#error No synchronization method defined
#endif

  Arena* const arena_;
  uint8_t currently_polling_ = kNotPolling;
  // All current participants, using a tagged format.
  // If the lower bit is unset, then this is a Participant*.
  // If the lower bit is set, then this is a ParticipantFactory*.
  std::atomic<Participant*> participants_[party_detail::kMaxParticipants] = {};
};

template <typename Factory, typename OnComplete>
void Party::BulkSpawner::Spawn(absl::string_view name, Factory promise_factory,
                               OnComplete on_complete) {
  if (grpc_trace_promise_primitives.enabled()) {
    gpr_log(GPR_DEBUG, "%s[bulk_spawn] On %p queue %s",
            party_->DebugTag().c_str(), this, std::string(name).c_str());
  }
  participants_[num_participants_++] =
      party_->arena_->NewPooled<ParticipantImpl<Factory, OnComplete>>(
          name, std::move(promise_factory), std::move(on_complete));
}

template <typename Factory, typename OnComplete>
void Party::Spawn(absl::string_view name, Factory promise_factory,
                  OnComplete on_complete) {
  BulkSpawner(this).Spawn(name, std::move(promise_factory),
                          std::move(on_complete));
}

}  // namespace grpc_core

#endif  // GRPC_SRC_CORE_LIB_PROMISE_PARTY_H