xref: /aosp_15_r20/external/grpc-grpc/src/core/lib/promise/observable.h (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 // Copyright 2024 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef GRPC_SRC_CORE_LIB_PROMISE_OBSERVABLE_H
16 #define GRPC_SRC_CORE_LIB_PROMISE_OBSERVABLE_H
17 
18 #include <grpc/support/port_platform.h>
19 
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/functional/any_invocable.h"
22 
23 #include "src/core/lib/gprpp/sync.h"
24 #include "src/core/lib/promise/activity.h"
25 #include "src/core/lib/promise/poll.h"
26 
27 namespace grpc_core {
28 
29 // Observable allows broadcasting a value to multiple interested observers.
30 template <typename T>
31 class Observable {
32  public:
33   // We need to assign a value initially.
Observable(T initial)34   explicit Observable(T initial)
35       : state_(MakeRefCounted<State>(std::move(initial))) {}
36 
37   // Update the value to something new. Awakes any waiters.
Set(T value)38   void Set(T value) { state_->Set(std::move(value)); }
39 
40   // Returns a promise that resolves to a T when is_acceptable returns true for
41   // that value.
42   // is_acceptable is any invocable that takes a `const T&` and returns a bool.
43   template <typename F>
NextWhen(F is_acceptable)44   auto NextWhen(F is_acceptable) {
45     return ObserverWhen<F>(state_, std::move(is_acceptable));
46   }
47 
48   // Returns a promise that resolves to a T when the value becomes != current.
Next(T current)49   auto Next(T current) {
50     return NextWhen([current = std::move(current)](const T& value) {
51       return value != current;
52     });
53   }
54 
55  private:
56   // Forward declaration so we can form pointers to Observer in State.
57   class Observer;
58 
59   // State keeps track of all observable state.
60   // It's a refcounted object so that promises reading the state are not tied
61   // to the lifetime of the Observable.
62   class State : public RefCounted<State> {
63    public:
State(T value)64     explicit State(T value) : value_(std::move(value)) {}
65 
66     // Update the value and wake all observers.
Set(T value)67     void Set(T value) {
68       MutexLock lock(&mu_);
69       std::swap(value_, value);
70       WakeAll();
71     }
72 
73     // Export our mutex so that Observer can use it.
mu()74     Mutex* mu() ABSL_LOCK_RETURNED(mu_) { return &mu_; }
75 
76     // Fetch a ref to the current value.
current()77     const T& current() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
78       return value_;
79     }
80 
81     // Remove an observer from the set (it no longer needs updates).
Remove(Observer * observer)82     void Remove(Observer* observer) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
83       observers_.erase(observer);
84     }
85 
86     // Add an observer to the set (it needs updates).
Add(Observer * observer)87     GRPC_MUST_USE_RESULT Waker Add(Observer* observer)
88         ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
89       observers_.insert(observer);
90       return GetContext<Activity>()->MakeNonOwningWaker();
91     }
92 
93    private:
94     // Wake all observers.
WakeAll()95     void WakeAll() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
96       for (auto* observer : observers_) {
97         observer->Wakeup();
98       }
99     }
100 
101     Mutex mu_;
102     // All observers that may need an update.
103     absl::flat_hash_set<Observer*> observers_ ABSL_GUARDED_BY(mu_);
104     // The current value.
105     T value_ ABSL_GUARDED_BY(mu_);
106   };
107 
108   // A promise that resolves to a T when ShouldReturn() returns true.
109   // Subclasses must implement ShouldReturn().
110   class Observer {
111    public:
Observer(RefCountedPtr<State> state)112     explicit Observer(RefCountedPtr<State> state) : state_(std::move(state)) {}
113 
~Observer()114     virtual ~Observer() {
115       // If we saw a pending at all then we *may* be in the set of observers.
116       // If not we're definitely not and we can avoid taking the lock at all.
117       if (!saw_pending_) return;
118       MutexLock lock(state_->mu());
119       auto w = std::move(waker_);
120       state_->Remove(this);
121     }
122 
123     Observer(const Observer&) = delete;
124     Observer& operator=(const Observer&) = delete;
Observer(Observer && other)125     Observer(Observer&& other) noexcept : state_(std::move(other.state_)) {
126       GPR_ASSERT(other.waker_.is_unwakeable());
127       GPR_ASSERT(!other.saw_pending_);
128     }
129     Observer& operator=(Observer&& other) noexcept = delete;
130 
Wakeup()131     void Wakeup() { waker_.WakeupAsync(); }
132 
133     virtual bool ShouldReturn(const T& current) = 0;
134 
operator()135     Poll<T> operator()() {
136       MutexLock lock(state_->mu());
137       // Check if the value has changed yet.
138       if (ShouldReturn(state_->current())) {
139         if (saw_pending_ && !waker_.is_unwakeable()) state_->Remove(this);
140         return state_->current();
141       }
142       // Record that we saw at least one pending and then register for wakeup.
143       saw_pending_ = true;
144       if (waker_.is_unwakeable()) waker_ = state_->Add(this);
145       return Pending{};
146     }
147 
148    private:
149     RefCountedPtr<State> state_;
150     Waker waker_;
151     bool saw_pending_ = false;
152   };
153 
154   // A promise that resolves to a T when is_acceptable returns true for
155   // the current value.
156   template <typename F>
157   class ObserverWhen : public Observer {
158    public:
ObserverWhen(RefCountedPtr<State> state,F is_acceptable)159     ObserverWhen(RefCountedPtr<State> state, F is_acceptable)
160         : Observer(std::move(state)),
161           is_acceptable_(std::move(is_acceptable)) {}
162 
ObserverWhen(ObserverWhen && other)163     ObserverWhen(ObserverWhen&& other) noexcept
164         : Observer(std::move(other)),
165           is_acceptable_(std::move(other.is_acceptable_)) {}
166 
ShouldReturn(const T & current)167     bool ShouldReturn(const T& current) override {
168       return is_acceptable_(current);
169     }
170 
171    private:
172     F is_acceptable_;
173   };
174 
175   RefCountedPtr<State> state_;
176 };
177 
178 }  // namespace grpc_core
179 
180 #endif  // GRPC_SRC_CORE_LIB_PROMISE_OBSERVABLE_H
181