xref: /aosp_15_r20/external/grpc-grpc/test/core/promise/party_test.cc (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 // Copyright 2023 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/core/lib/promise/party.h"
16 
17 #include <stdio.h>
18 
19 #include <algorithm>
20 #include <atomic>
21 #include <memory>
22 #include <thread>
23 #include <vector>
24 
25 #include "absl/base/thread_annotations.h"
26 #include "gtest/gtest.h"
27 
28 #include <grpc/event_engine/event_engine.h>
29 #include <grpc/event_engine/memory_allocator.h>
30 #include <grpc/grpc.h>
31 
32 #include "src/core/lib/event_engine/default_event_engine.h"
33 #include "src/core/lib/gprpp/notification.h"
34 #include "src/core/lib/gprpp/ref_counted_ptr.h"
35 #include "src/core/lib/gprpp/sync.h"
36 #include "src/core/lib/gprpp/time.h"
37 #include "src/core/lib/iomgr/exec_ctx.h"
38 #include "src/core/lib/promise/context.h"
39 #include "src/core/lib/promise/inter_activity_latch.h"
40 #include "src/core/lib/promise/poll.h"
41 #include "src/core/lib/promise/seq.h"
42 #include "src/core/lib/promise/sleep.h"
43 #include "src/core/lib/resource_quota/memory_quota.h"
44 #include "src/core/lib/resource_quota/resource_quota.h"
45 
46 namespace grpc_core {
47 
48 ///////////////////////////////////////////////////////////////////////////////
49 // PartySyncTest
50 
51 template <typename T>
52 class PartySyncTest : public ::testing::Test {};
53 
54 // PartySyncUsingMutex isn't working on Mac, but we don't use it for anything
55 // right now so that's fine.
56 #ifdef GPR_APPLE
57 using PartySyncTypes = ::testing::Types<PartySyncUsingAtomics>;
58 #else
59 using PartySyncTypes =
60     ::testing::Types<PartySyncUsingAtomics, PartySyncUsingMutex>;
61 #endif
62 TYPED_TEST_SUITE(PartySyncTest, PartySyncTypes);
63 
// Constructing and immediately destroying a sync object holding one ref.
TYPED_TEST(PartySyncTest, NoOp) {
  TypeParam sync(1);
}
65 
// Two threads concurrently take and release refs on one sync object.
// None of those Unref() calls may report that the count reached zero. Only
// the final Unref() of the constructor's initial ref returns true.
TYPED_TEST(PartySyncTest, RefAndUnref) {
  constexpr int kHalf = 1000000;
  constexpr int kTotal = 2 * kHalf;
  Notification half_way;
  TypeParam sync(1);
  std::thread worker([&] {
    // Take half of this thread's refs before letting the main thread start,
    // then keep going so the two threads overlap.
    for (int n = 0; n < kHalf; ++n) sync.IncrementRefCount();
    half_way.Notify();
    for (int n = 0; n < kHalf; ++n) sync.IncrementRefCount();
    for (int n = 0; n < kTotal; ++n) EXPECT_FALSE(sync.Unref());
  });
  half_way.WaitForNotification();
  for (int n = 0; n < kTotal; ++n) sync.IncrementRefCount();
  for (int n = 0; n < kTotal; ++n) EXPECT_FALSE(sync.Unref());
  worker.join();
  // Only the initial ref from the constructor remains.
  EXPECT_TRUE(sync.Unref());
}
91 
// Several threads repeatedly add a single participant to one shared sync.
// When AddParticipantsAndRef() returns true, the calling thread must run the
// party. It drains every pending participant and must observe its own. Every
// thread then spins until its participant's `done` flag has been set by
// whichever thread ran it.
TYPED_TEST(PartySyncTest, AddAndRemoveParticipant) {
  constexpr int kNumThreads = 8;
  constexpr int kIterationsPerThread = 100000;
  TypeParam sync(1);
  std::vector<std::thread> threads;
  // Slot index -> completion flag of the participant occupying that slot.
  std::atomic<std::atomic<bool>*> participants[party_detail::kMaxParticipants] =
      {};
  threads.reserve(kNumThreads);
  for (int t = 0; t < kNumThreads; t++) {
    threads.emplace_back([&] {
      for (int iter = 0; iter < kIterationsPerThread; iter++) {
        auto done = std::make_unique<std::atomic<bool>>(false);
        int my_slot = -1;
        bool run = sync.AddParticipantsAndRef(1, [&](size_t* idxs) {
          my_slot = static_cast<int>(idxs[0]);
          participants[my_slot].store(done.get(), std::memory_order_release);
        });
        EXPECT_NE(my_slot, -1);
        if (run) {
          bool run_any = false;
          bool run_me = false;
          EXPECT_FALSE(sync.RunParty([&](int slot) {
            run_any = true;
            // Claim the participant so no other runner processes it again.
            std::atomic<bool>* participant =
                participants[slot].exchange(nullptr, std::memory_order_acquire);
            if (participant == done.get()) run_me = true;
            if (participant == nullptr) {
              gpr_log(GPR_ERROR,
                      "Participant was null (spurious wakeup observed)");
              return false;
            }
            participant->store(true, std::memory_order_release);
            return true;
          }));
          EXPECT_TRUE(run_any);
          EXPECT_TRUE(run_me);
        }
        EXPECT_FALSE(sync.Unref());
        // Wait until our participant has been run (by us or another thread).
        while (!done->load(std::memory_order_acquire)) {
        }
      }
    });
  }
  for (auto& thread : threads) {
    thread.join();
  }
  EXPECT_TRUE(sync.Unref());
}
138 
// Like AddAndRemoveParticipant, but each iteration registers two participants
// in one call. The two slots handed back must be distinct and increasing. A
// thread that is told to run the party must run both of its own participants.
TYPED_TEST(PartySyncTest, AddAndRemoveTwoParticipants) {
  // Single source of truth for the thread count. The original reserved 8
  // slots but spawned only 4 threads.
  constexpr int kNumThreads = 4;
  constexpr int kIterationsPerThread = 100000;
  TypeParam sync(1);
  std::vector<std::thread> threads;
  // Slot index -> countdown shared by the two participants added together.
  std::atomic<std::atomic<int>*> participants[party_detail::kMaxParticipants] =
      {};
  threads.reserve(kNumThreads);
  for (int t = 0; t < kNumThreads; t++) {
    threads.emplace_back([&] {
      for (int iter = 0; iter < kIterationsPerThread; iter++) {
        // Decremented once as each of this iteration's participants is run.
        auto done = std::make_unique<std::atomic<int>>(2);
        int my_slots[2] = {-1, -1};
        bool run = sync.AddParticipantsAndRef(2, [&](size_t* idxs) {
          for (int k = 0; k < 2; k++) {
            my_slots[k] = static_cast<int>(idxs[k]);
            participants[my_slots[k]].store(done.get(),
                                            std::memory_order_release);
          }
        });
        EXPECT_NE(my_slots[0], -1);
        EXPECT_NE(my_slots[1], -1);
        EXPECT_GT(my_slots[1], my_slots[0]);
        if (run) {
          bool run_any = false;
          int run_me = 0;
          EXPECT_FALSE(sync.RunParty([&](int slot) {
            run_any = true;
            // Claim the participant so no other runner processes it again.
            std::atomic<int>* participant =
                participants[slot].exchange(nullptr, std::memory_order_acquire);
            if (participant == done.get()) run_me++;
            if (participant == nullptr) {
              gpr_log(GPR_ERROR,
                      "Participant was null (spurious wakeup observed)");
              return false;
            }
            participant->fetch_sub(1, std::memory_order_release);
            return true;
          }));
          EXPECT_TRUE(run_any);
          EXPECT_EQ(run_me, 2);
        }
        EXPECT_FALSE(sync.Unref());
        // Wait until both of our participants have been run.
        while (done->load(std::memory_order_acquire) != 0) {
        }
      }
    });
  }
  for (auto& thread : threads) {
    thread.join();
  }
  EXPECT_TRUE(sync.Unref());
}
189 
// Races RunParty() (which repolls its single participant nine times before
// finishing) against two concurrent Unref() calls on a fresh sync object.
// Exactly one of the three paths must report that the object should be
// destroyed. A histogram of which path won is printed to stderr.
TYPED_TEST(PartySyncTest, UnrefWhileRunning) {
  constexpr int kTrials = 100;
  // Index 0: RunParty, 1: async (other-thread) Unref, 2: sync Unref.
  std::atomic<int> delete_paths_taken[3] = {{0}, {0}, {0}};
  std::vector<std::thread> trials;
  trials.reserve(kTrials);
  for (int trial = 0; trial < kTrials; trial++) {
    trials.emplace_back([&delete_paths_taken] {
      TypeParam sync(1);
      // Path that observed the final unref; -1 until one does. This is atomic
      // so that a broken sync letting two paths claim deletion is reported as
      // a failure rather than becoming a data race on a plain int.
      std::atomic<int> delete_path{-1};
      auto claim_delete = [&delete_path](int path) {
        const int previous = delete_path.exchange(path);
        EXPECT_EQ(previous, -1) << "more than one path observed the final unref";
      };
      EXPECT_TRUE(sync.AddParticipantsAndRef(
          1, [](size_t* slots) { EXPECT_EQ(slots[0], 0u); }));
      std::thread run_party([&] {
        if (sync.RunParty([&sync, n = 0](int slot) mutable {
              EXPECT_EQ(slot, 0);
              ++n;
              if (n < 10) {
                sync.ForceImmediateRepoll(1);
                return false;
              }
              return true;
            })) {
          claim_delete(0);
        }
      });
      std::thread unref([&] {
        if (sync.Unref()) claim_delete(1);
      });
      if (sync.Unref()) claim_delete(2);
      run_party.join();
      unref.join();
      const int path = delete_path.load();
      // Fatal checks: a non-fatal EXPECT would fall through and index
      // delete_paths_taken[-1] (out of bounds) when no path claimed deletion.
      ASSERT_GE(path, 0);
      ASSERT_LT(path, 3);
      delete_paths_taken[path].fetch_add(1, std::memory_order_relaxed);
    });
  }
  for (auto& trial : trials) {
    trial.join();
  }
  fprintf(stderr, "DELETE_PATHS: RunParty:%d AsyncUnref:%d SyncUnref:%d\n",
          delete_paths_taken[0].load(), delete_paths_taken[1].load(),
          delete_paths_taken[2].load());
}
230 
231 ///////////////////////////////////////////////////////////////////////////////
232 // PartyTest
233 
234 class TestParty final : public Party {
235  public:
TestParty()236   TestParty() : Party(1) {}
~TestParty()237   ~TestParty() override {}
DebugTag() const238   std::string DebugTag() const override { return "TestParty"; }
239 
240   using Party::IncrementRefCount;
241   using Party::Unref;
242 
RunParty()243   bool RunParty() override {
244     promise_detail::Context<grpc_event_engine::experimental::EventEngine>
245         ee_ctx(ee_.get());
246     return Party::RunParty();
247   }
248 
PartyOver()249   void PartyOver() override {
250     {
251       promise_detail::Context<grpc_event_engine::experimental::EventEngine>
252           ee_ctx(ee_.get());
253       CancelRemainingParticipants();
254     }
255     delete this;
256   }
257 
258  private:
event_engine() const259   grpc_event_engine::experimental::EventEngine* event_engine() const final {
260     return ee_.get();
261   }
262 
263   std::shared_ptr<grpc_event_engine::experimental::EventEngine> ee_ =
264       grpc_event_engine::experimental::GetDefaultEventEngine();
265 };
266 
267 class PartyTest : public ::testing::Test {
268  protected:
269 };
270 
// Creating a party and dropping the only reference must be safe.
TEST_F(PartyTest, Noop) {
  RefCountedPtr<TestParty> party = MakeRefCounted<TestParty>();
}
272 
// A spawned promise forces an immediate repoll on every poll and resolves
// with 42 on its tenth poll. It must always run inside the TestParty activity
// and hand its result to the on-complete callback.
TEST_F(PartyTest, CanSpawnAndRun) {
  auto party = MakeRefCounted<TestParty>();
  Notification done;
  party->Spawn(
      "TestSpawn",
      [remaining = 10]() mutable -> Poll<int> {
        auto* activity = GetContext<Activity>();
        EXPECT_EQ(activity->DebugTag(), "TestParty");
        EXPECT_GT(remaining, 0);
        activity->ForceImmediateRepoll();
        --remaining;
        if (remaining != 0) return Pending{};
        return 42;
      },
      [&done](int result) {
        EXPECT_EQ(result, 42);
        done.Notify();
      });
  done.WaitForNotification();
}
292 
// A party1 task spawns a waitable task on party2 and waits for it. The party2
// task blocks on the `done` latch. A second party1 task sets the latch, which
// must let the party2 task finish and so complete the first party1 task.
TEST_F(PartyTest, CanSpawnWaitableAndRun) {
  auto party1 = MakeRefCounted<TestParty>();
  auto party2 = MakeRefCounted<TestParty>();
  Notification main_finished;
  InterActivityLatch<void> done;
  party1->Spawn(
      "party1_main",
      [&party2, &done]() {
        return party2->SpawnWaitable("party2_main",
                                     [&done]() { return done.Wait(); });
      },
      [&main_finished](Empty) { main_finished.Notify(); });
  // Nothing has set the latch yet, so party1_main cannot have finished.
  ASSERT_FALSE(main_finished.HasBeenNotified());
  party1->Spawn(
      "party1_notify_latch",
      [&done]() {
        done.Set();
        return Empty{};
      },
      [](Empty) {});
  main_finished.WaitForNotification();
}
317 
// A running promise spawns a second promise onto its own party. The outer
// promise resolves immediately with 1234. The inner one forces a repoll each
// time and resolves with 42 on its tenth poll. Both completions must fire.
TEST_F(PartyTest, CanSpawnFromSpawn) {
  auto party = MakeRefCounted<TestParty>();
  Notification outer_done;
  Notification inner_done;
  party->Spawn(
      "TestSpawn",
      [party, &inner_done]() -> Poll<int> {
        EXPECT_EQ(GetContext<Activity>()->DebugTag(), "TestParty");
        party->Spawn(
            "TestSpawnInner",
            [remaining = 10]() mutable -> Poll<int> {
              auto* activity = GetContext<Activity>();
              EXPECT_EQ(activity->DebugTag(), "TestParty");
              activity->ForceImmediateRepoll();
              --remaining;
              if (remaining != 0) return Pending{};
              return 42;
            },
            [&inner_done](int result) {
              EXPECT_EQ(result, 42);
              inner_done.Notify();
            });
        return 1234;
      },
      [&outer_done](int result) {
        EXPECT_EQ(result, 1234);
        outer_done.Notify();
      });
  outer_done.WaitForNotification();
  inner_done.WaitForNotification();
}
348 
// On every poll the promise stores a fresh owning waker and signals which poll
// it is on. The test thread waits for each poll and then wakes the promise. On
// the tenth poll the promise resolves with 42.
TEST_F(PartyTest, CanWakeupWithOwningWaker) {
  auto party = MakeRefCounted<TestParty>();
  Notification polled[10];
  Notification complete;
  Waker waker;
  party->Spawn(
      "TestSpawn",
      [poll_count = 0, &waker, &polled]() mutable -> Poll<int> {
        EXPECT_EQ(GetContext<Activity>()->DebugTag(), "TestParty");
        waker = GetContext<Activity>()->MakeOwningWaker();
        polled[poll_count++].Notify();
        if (poll_count == 10) return 42;
        return Pending{};
      },
      [&complete](int result) {
        EXPECT_EQ(result, 42);
        complete.Notify();
      });
  for (Notification& poll_seen : polled) {
    poll_seen.WaitForNotification();
    waker.Wakeup();
  }
  complete.WaitForNotification();
}
374 
// Same shape as CanWakeupWithOwningWaker, but with a non-owning waker. Before
// each wakeup the test checks that the promise has not advanced to its next
// poll on its own.
TEST_F(PartyTest, CanWakeupWithNonOwningWaker) {
  auto party = MakeRefCounted<TestParty>();
  Notification polled[10];
  Notification complete;
  Waker waker;
  party->Spawn(
      "TestSpawn",
      [poll_count = 0, &waker, &polled]() mutable -> Poll<int> {
        EXPECT_EQ(GetContext<Activity>()->DebugTag(), "TestParty");
        waker = GetContext<Activity>()->MakeNonOwningWaker();
        polled[poll_count++].Notify();
        if (poll_count == 10) return 42;
        return Pending{};
      },
      [&complete](int result) {
        EXPECT_EQ(result, 42);
        complete.Notify();
      });
  for (int poll = 0; poll + 1 < 10; poll++) {
    polled[poll].WaitForNotification();
    EXPECT_FALSE(polled[poll + 1].HasBeenNotified());
    waker.Wakeup();
  }
  complete.WaitForNotification();
}
401 
// A promise that never resolves hands out a non-owning waker and the test then
// drops its reference to the party. The waker must still be wakeable
// afterwards, and waking it leaves it unwakeable.
TEST_F(PartyTest, CanWakeupWithNonOwningWakerAfterOrphaning) {
  auto party = MakeRefCounted<TestParty>();
  Notification waker_ready;
  Waker waker;
  party->Spawn(
      "TestSpawn",
      [&waker, &waker_ready]() mutable -> Poll<int> {
        // Must only be polled once: nobody wakes it before the party is gone.
        EXPECT_FALSE(waker_ready.HasBeenNotified());
        auto* activity = GetContext<Activity>();
        EXPECT_EQ(activity->DebugTag(), "TestParty");
        waker = activity->MakeNonOwningWaker();
        waker_ready.Notify();
        return Pending{};
      },
      [](int) { Crash("unreachable"); });
  waker_ready.WaitForNotification();
  party.reset();
  EXPECT_FALSE(waker.is_unwakeable());
  waker.Wakeup();
  EXPECT_TRUE(waker.is_unwakeable());
}
422 
// A non-owning waker obtained from a never-resolving promise must be safe to
// destroy without waking, after the test has dropped its party reference.
TEST_F(PartyTest, CanDropNonOwningWakeAfterOrphaning) {
  auto party = MakeRefCounted<TestParty>();
  Notification waker_ready;
  std::unique_ptr<Waker> waker;
  party->Spawn(
      "TestSpawn",
      [&waker, &waker_ready]() mutable -> Poll<int> {
        // Must only be polled once: nobody wakes it before the party is gone.
        EXPECT_FALSE(waker_ready.HasBeenNotified());
        auto* activity = GetContext<Activity>();
        EXPECT_EQ(activity->DebugTag(), "TestParty");
        waker = std::make_unique<Waker>(activity->MakeNonOwningWaker());
        waker_ready.Notify();
        return Pending{};
      },
      [](int) { Crash("unreachable"); });
  waker_ready.WaitForNotification();
  party.reset();
  EXPECT_NE(waker, nullptr);
  waker.reset();
}
443 
// Like CanWakeupWithNonOwningWakerAfterOrphaning, but checks that the waker is
// wakeable before the party reference is dropped. Waking after the drop
// leaves the waker unwakeable.
TEST_F(PartyTest, CanWakeupNonOwningOrphanedWakerWithNoEffect) {
  auto party = MakeRefCounted<TestParty>();
  Notification waker_ready;
  Waker waker;
  party->Spawn(
      "TestSpawn",
      [&waker, &waker_ready]() mutable -> Poll<int> {
        // Must only be polled once: nobody wakes it before the party is gone.
        EXPECT_FALSE(waker_ready.HasBeenNotified());
        auto* activity = GetContext<Activity>();
        EXPECT_EQ(activity->DebugTag(), "TestParty");
        waker = activity->MakeNonOwningWaker();
        waker_ready.Notify();
        return Pending{};
      },
      [](int) { Crash("unreachable"); });
  waker_ready.WaitForNotification();
  EXPECT_FALSE(waker.is_unwakeable());
  party.reset();
  waker.Wakeup();
  EXPECT_TRUE(waker.is_unwakeable());
}
464 
// Promises queued through a BulkSpawner must not run while the spawner is
// alive. Once it goes out of scope, both must complete.
TEST_F(PartyTest, CanBulkSpawn) {
  auto party = MakeRefCounted<TestParty>();
  Notification first_done;
  Notification second_done;
  {
    Party::BulkSpawner spawner(party.get());
    spawner.Spawn(
        "spawn1", []() { return Empty{}; },
        [&first_done](Empty) { first_done.Notify(); });
    spawner.Spawn(
        "spawn2", []() { return Empty{}; },
        [&second_done](Empty) { second_done.Notify(); });
    // Keep checking for a while that nothing ran early.
    for (int check = 0; check < 5000; check++) {
      EXPECT_FALSE(first_done.HasBeenNotified());
      EXPECT_FALSE(second_done.HasBeenNotified());
    }
  }
  first_done.WaitForNotification();
  second_done.WaitForNotification();
}
483 
// Checks that AfterCurrentPoll() defers the rest of a promise until every
// other participant in the same poll has run.
TEST_F(PartyTest, AfterCurrentPollWorks) {
  auto party = MakeRefCounted<TestParty>();
  Notification final_done;
  int step = 0;
  {
    Party::BulkSpawner spawner(party.get());
    // spawn_final is scheduled and polled first, but AfterCurrentPoll()
    // suspends it. spawn1..spawn3 then run in order; their EXPECT_EQ checks
    // prove that order. When the poll completes, spawn_final is woken and
    // must observe the final step.
    spawner.Spawn(
        "spawn_final",
        [&step, &party]() {
          return Seq(party->AfterCurrentPoll(), [&step]() {
            EXPECT_EQ(step, 3);
            return Empty{};
          });
        },
        [&final_done](Empty) { final_done.Notify(); });
    spawner.Spawn(
        "spawn1",
        [&step]() {
          EXPECT_EQ(step, 0);
          step = 1;
          return Empty{};
        },
        [](Empty) {});
    spawner.Spawn(
        "spawn2",
        [&step]() {
          EXPECT_EQ(step, 1);
          step = 2;
          return Empty{};
        },
        [](Empty) {});
    spawner.Spawn(
        "spawn3",
        [&step]() {
          EXPECT_EQ(step, 2);
          step = 3;
          return Empty{};
        },
        [](Empty) {});
  }
  final_done.WaitForNotification();
}
531 
// Eight threads share one party. Each spawns 100 promises one at a time; each
// promise sleeps ~10ms and then resolves to 42.
TEST_F(PartyTest, ThreadStressTest) {
  constexpr int kThreads = 8;
  auto party = MakeRefCounted<TestParty>();
  std::vector<std::thread> workers;
  workers.reserve(kThreads);
  for (int t = 0; t < kThreads; t++) {
    workers.emplace_back([party]() {
      constexpr int kSpawnsPerThread = 100;
      for (int spawn = 0; spawn < kSpawnsPerThread; spawn++) {
        ExecCtx exec_ctx;  // needed for Sleep
        Notification promise_complete;
        party->Spawn("TestSpawn",
                     Seq(Sleep(Timestamp::Now() + Duration::Milliseconds(10)),
                         []() -> Poll<int> { return 42; }),
                     [&promise_complete](int result) {
                       EXPECT_EQ(result, 42);
                       promise_complete.Notify();
                     });
        promise_complete.WaitForNotification();
      }
    });
  }
  for (std::thread& worker : workers) worker.join();
}
556 
557 class PromiseNotification {
558  public:
PromiseNotification(bool owning_waker)559   explicit PromiseNotification(bool owning_waker)
560       : owning_waker_(owning_waker) {}
561 
Wait()562   auto Wait() {
563     return [this]() -> Poll<int> {
564       MutexLock lock(&mu_);
565       if (done_) return 42;
566       if (!polled_) {
567         if (owning_waker_) {
568           waker_ = GetContext<Activity>()->MakeOwningWaker();
569         } else {
570           waker_ = GetContext<Activity>()->MakeNonOwningWaker();
571         }
572         polled_ = true;
573       }
574       return Pending{};
575     };
576   }
577 
Notify()578   void Notify() {
579     Waker waker;
580     {
581       MutexLock lock(&mu_);
582       done_ = true;
583       waker = std::move(waker_);
584     }
585     waker.Wakeup();
586   }
587 
NotifyUnderLock()588   void NotifyUnderLock() {
589     MutexLock lock(&mu_);
590     done_ = true;
591     waker_.WakeupAsync();
592   }
593 
594  private:
595   Mutex mu_;
596   const bool owning_waker_;
597   bool done_ ABSL_GUARDED_BY(mu_) = false;
598   bool polled_ ABSL_GUARDED_BY(mu_) = false;
599   Waker waker_ ABSL_GUARDED_BY(mu_);
600 };
601 
// Stress test: each promise first blocks on a PromiseNotification (owning
// waker) that the spawning thread signals with Notify(). It then sleeps ~10ms
// and resolves to 42. Eight threads do this 100 times each on one party.
TEST_F(PartyTest, ThreadStressTestWithOwningWaker) {
  constexpr int kThreads = 8;
  auto party = MakeRefCounted<TestParty>();
  std::vector<std::thread> workers;
  workers.reserve(kThreads);
  for (int t = 0; t < kThreads; t++) {
    workers.emplace_back([party]() {
      constexpr int kSpawnsPerThread = 100;
      for (int spawn = 0; spawn < kSpawnsPerThread; spawn++) {
        ExecCtx exec_ctx;  // needed for Sleep
        PromiseNotification promise_start(true);
        Notification promise_complete;
        party->Spawn("TestSpawn",
                     Seq(promise_start.Wait(),
                         Sleep(Timestamp::Now() + Duration::Milliseconds(10)),
                         []() -> Poll<int> { return 42; }),
                     [&promise_complete](int result) {
                       EXPECT_EQ(result, 42);
                       promise_complete.Notify();
                     });
        promise_start.Notify();
        promise_complete.WaitForNotification();
      }
    });
  }
  for (std::thread& worker : workers) worker.join();
}
629 
// Same as ThreadStressTestWithOwningWaker, but the start signal is delivered
// with NotifyUnderLock(), i.e. the wakeup is requested while the
// notification's mutex is held.
TEST_F(PartyTest, ThreadStressTestWithOwningWakerHoldingLock) {
  constexpr int kThreads = 8;
  auto party = MakeRefCounted<TestParty>();
  std::vector<std::thread> workers;
  workers.reserve(kThreads);
  for (int t = 0; t < kThreads; t++) {
    workers.emplace_back([party]() {
      constexpr int kSpawnsPerThread = 100;
      for (int spawn = 0; spawn < kSpawnsPerThread; spawn++) {
        ExecCtx exec_ctx;  // needed for Sleep
        PromiseNotification promise_start(true);
        Notification promise_complete;
        party->Spawn("TestSpawn",
                     Seq(promise_start.Wait(),
                         Sleep(Timestamp::Now() + Duration::Milliseconds(10)),
                         []() -> Poll<int> { return 42; }),
                     [&promise_complete](int result) {
                       EXPECT_EQ(result, 42);
                       promise_complete.Notify();
                     });
        promise_start.NotifyUnderLock();
        promise_complete.WaitForNotification();
      }
    });
  }
  for (std::thread& worker : workers) worker.join();
}
657 
// Same as ThreadStressTestWithOwningWaker, but the PromiseNotification wakes
// the promise through a non-owning waker.
TEST_F(PartyTest, ThreadStressTestWithNonOwningWaker) {
  constexpr int kThreads = 8;
  auto party = MakeRefCounted<TestParty>();
  std::vector<std::thread> workers;
  workers.reserve(kThreads);
  for (int t = 0; t < kThreads; t++) {
    workers.emplace_back([party]() {
      constexpr int kSpawnsPerThread = 100;
      for (int spawn = 0; spawn < kSpawnsPerThread; spawn++) {
        ExecCtx exec_ctx;  // needed for Sleep
        PromiseNotification promise_start(false);
        Notification promise_complete;
        party->Spawn("TestSpawn",
                     Seq(promise_start.Wait(),
                         Sleep(Timestamp::Now() + Duration::Milliseconds(10)),
                         []() -> Poll<int> { return 42; }),
                     [&promise_complete](int result) {
                       EXPECT_EQ(result, 42);
                       promise_complete.Notify();
                     });
        promise_start.Notify();
        promise_complete.WaitForNotification();
      }
    });
  }
  for (std::thread& worker : workers) worker.join();
}
685 
// High-iteration variant without Sleep. Each promise just waits for its start
// signal (owning waker) and resolves to 42. Eight threads do this 10000 times
// each on one party.
TEST_F(PartyTest, ThreadStressTestWithOwningWakerNoSleep) {
  constexpr int kThreads = 8;
  auto party = MakeRefCounted<TestParty>();
  std::vector<std::thread> workers;
  workers.reserve(kThreads);
  for (int t = 0; t < kThreads; t++) {
    workers.emplace_back([party]() {
      constexpr int kSpawnsPerThread = 10000;
      for (int spawn = 0; spawn < kSpawnsPerThread; spawn++) {
        PromiseNotification promise_start(true);
        Notification promise_complete;
        party->Spawn(
            "TestSpawn",
            Seq(promise_start.Wait(), []() -> Poll<int> { return 42; }),
            [&promise_complete](int result) {
              EXPECT_EQ(result, 42);
              promise_complete.Notify();
            });
        promise_start.Notify();
        promise_complete.WaitForNotification();
      }
    });
  }
  for (std::thread& worker : workers) worker.join();
}
711 
// Same as ThreadStressTestWithOwningWakerNoSleep, but the start signal wakes
// the promise through a non-owning waker.
TEST_F(PartyTest, ThreadStressTestWithNonOwningWakerNoSleep) {
  constexpr int kThreads = 8;
  auto party = MakeRefCounted<TestParty>();
  std::vector<std::thread> workers;
  workers.reserve(kThreads);
  for (int t = 0; t < kThreads; t++) {
    workers.emplace_back([party]() {
      constexpr int kSpawnsPerThread = 10000;
      for (int spawn = 0; spawn < kSpawnsPerThread; spawn++) {
        PromiseNotification promise_start(false);
        Notification promise_complete;
        party->Spawn(
            "TestSpawn",
            Seq(promise_start.Wait(), []() -> Poll<int> { return 42; }),
            [&promise_complete](int result) {
              EXPECT_EQ(result, 42);
              promise_complete.Notify();
            });
        promise_start.Notify();
        promise_complete.WaitForNotification();
      }
    });
  }
  for (std::thread& worker : workers) worker.join();
}
737 
// Each outer promise does the following, in order:
//   1. Spawns an inner promise onto the same party.
//   2. Sleeps ~10ms.
//   3. Releases the inner promise via `inner_start`.
//   4. Waits for the inner promise's completion callback to fire
//      `inner_complete`.
//   5. Resolves to 42.
// Eight threads run this 100 times each against one shared party.
TEST_F(PartyTest, ThreadStressTestWithInnerSpawn) {
  constexpr int kThreads = 8;
  auto party = MakeRefCounted<TestParty>();
  std::vector<std::thread> workers;
  workers.reserve(kThreads);
  for (int t = 0; t < kThreads; t++) {
    workers.emplace_back([party]() {
      constexpr int kSpawnsPerThread = 100;
      for (int spawn = 0; spawn < kSpawnsPerThread; spawn++) {
        ExecCtx exec_ctx;  // needed for Sleep
        PromiseNotification inner_start(true);
        PromiseNotification inner_complete(false);
        Notification promise_complete;
        party->Spawn(
            "TestSpawn",
            Seq(
                [party, &inner_start, &inner_complete]() -> Poll<int> {
                  party->Spawn("TestSpawnInner",
                               Seq(inner_start.Wait(), []() { return 0; }),
                               [&inner_complete](int value) {
                                 EXPECT_EQ(value, 0);
                                 inner_complete.Notify();
                               });
                  return 0;
                },
                Sleep(Timestamp::Now() + Duration::Milliseconds(10)),
                [&inner_start]() {
                  inner_start.Notify();
                  return 0;
                },
                inner_complete.Wait(), []() -> Poll<int> { return 42; }),
            [&promise_complete](int result) {
              EXPECT_EQ(result, 42);
              promise_complete.Notify();
            });
        promise_complete.WaitForNotification();
      }
    });
  }
  for (std::thread& worker : workers) worker.join();
}
779 
// Party1's promise spawns work onto party2 and party3 from inside its own poll.
// The EXPECTs pin the step order:
//   - p1's body moves step 0 -> 1 -> 2.
//   - p1's completion sets 3.
//   - p2's body sees 3 (once p3 has started) and sets 4.
//   - p2's completion sets 5.
//   - p3's body sees 5 (once p2 is done) and sets 6.
//   - p3's completion sets 7.
// p2 and p3 each block until the other has started, so they cannot both run
// sequentially on one thread.
TEST_F(PartyTest, NestedWakeup) {
  auto party1 = MakeRefCounted<TestParty>();
  auto party2 = MakeRefCounted<TestParty>();
  auto party3 = MakeRefCounted<TestParty>();
  int step = 0;
  Notification p2_started;
  Notification p2_done;
  Notification p3_started;
  Notification all_done;
  party1->Spawn(
      "p1",
      [&]() {
        EXPECT_EQ(step, 0);
        step = 1;
        party2->Spawn(
            "p2",
            [&]() {
              p2_started.Notify();
              p3_started.WaitForNotification();
              EXPECT_EQ(step, 3);
              step = 4;
              return Empty{};
            },
            [&](Empty) {
              EXPECT_EQ(step, 4);
              step = 5;
              p2_done.Notify();
            });
        party3->Spawn(
            "p3",
            [&]() {
              p2_started.WaitForNotification();
              p3_started.Notify();
              p2_done.WaitForNotification();
              EXPECT_EQ(step, 5);
              step = 6;
              return Empty{};
            },
            [&](Empty) {
              EXPECT_EQ(step, 6);
              step = 7;
              all_done.Notify();
            });
        EXPECT_EQ(step, 1);
        step = 2;
        return Empty{};
      },
      [&](Empty) {
        EXPECT_EQ(step, 2);
        step = 3;
      });
  all_done.WaitForNotification();
}
833 
834 }  // namespace grpc_core
835 
main(int argc,char ** argv)836 int main(int argc, char** argv) {
837   ::testing::InitGoogleTest(&argc, argv);
838   grpc_init();
839   int r = RUN_ALL_TESTS();
840   grpc_shutdown();
841   return r;
842 }
843