1 //
2 // Copyright 2020 gRPC authors.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #ifndef GRPC_SRC_CORE_LIB_GPRPP_DUAL_REF_COUNTED_H
18 #define GRPC_SRC_CORE_LIB_GPRPP_DUAL_REF_COUNTED_H
19 
20 #include <grpc/support/port_platform.h>
21 
22 #include <atomic>
23 #include <cstdint>
24 
25 #include <grpc/support/log.h>
26 
27 #include "src/core/lib/gprpp/debug_location.h"
28 #include "src/core/lib/gprpp/orphanable.h"
29 #include "src/core/lib/gprpp/ref_counted_ptr.h"
30 
31 namespace grpc_core {
32 
33 // DualRefCounted is an interface for reference-counted objects with two
34 // classes of refs: strong refs (usually just called "refs") and weak refs.
35 // This supports cases where an object needs to start shutting down when
36 // all external callers are done with it (represented by strong refs) but
37 // cannot be destroyed until all internal callbacks are complete
38 // (represented by weak refs).
39 //
40 // Each class of refs can be incremented and decremented independently.
// By default, objects start with 1 strong ref and 0 weak refs at
// instantiation; the initial strong refcount can be overridden via the
// constructor.
42 // When the strong refcount reaches 0, the object's Orphan() method is called.
43 // When the weak refcount reaches 0, the object is destroyed.
44 //
// This is used via CRTP (curiously-recurring template pattern), e.g.:
//   class MyClass : public DualRefCounted<MyClass> { ... };
47 template <typename Child>
48 class DualRefCounted : public Orphanable {
49  public:
50   ~DualRefCounted() override = default;
51 
Ref()52   RefCountedPtr<Child> Ref() GRPC_MUST_USE_RESULT {
53     IncrementRefCount();
54     return RefCountedPtr<Child>(static_cast<Child*>(this));
55   }
56 
Ref(const DebugLocation & location,const char * reason)57   RefCountedPtr<Child> Ref(const DebugLocation& location,
58                            const char* reason) GRPC_MUST_USE_RESULT {
59     IncrementRefCount(location, reason);
60     return RefCountedPtr<Child>(static_cast<Child*>(this));
61   }
62 
Unref()63   void Unref() {
64     // Convert strong ref to weak ref.
65     const uint64_t prev_ref_pair =
66         refs_.fetch_add(MakeRefPair(-1, 1), std::memory_order_acq_rel);
67     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
68 #ifndef NDEBUG
69     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
70     if (trace_ != nullptr) {
71       gpr_log(GPR_INFO, "%s:%p unref %d -> %d, weak_ref %d -> %d", trace_, this,
72               strong_refs, strong_refs - 1, weak_refs, weak_refs + 1);
73     }
74     GPR_ASSERT(strong_refs > 0);
75 #endif
76     if (GPR_UNLIKELY(strong_refs == 1)) {
77       Orphan();
78     }
79     // Now drop the weak ref.
80     WeakUnref();
81   }
Unref(const DebugLocation & location,const char * reason)82   void Unref(const DebugLocation& location, const char* reason) {
83     const uint64_t prev_ref_pair =
84         refs_.fetch_add(MakeRefPair(-1, 1), std::memory_order_acq_rel);
85     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
86 #ifndef NDEBUG
87     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
88     if (trace_ != nullptr) {
89       gpr_log(GPR_INFO, "%s:%p %s:%d unref %d -> %d, weak_ref %d -> %d) %s",
90               trace_, this, location.file(), location.line(), strong_refs,
91               strong_refs - 1, weak_refs, weak_refs + 1, reason);
92     }
93     GPR_ASSERT(strong_refs > 0);
94 #else
95     // Avoid unused-parameter warnings for debug-only parameters
96     (void)location;
97     (void)reason;
98 #endif
99     if (GPR_UNLIKELY(strong_refs == 1)) {
100       Orphan();
101     }
102     // Now drop the weak ref.
103     WeakUnref(location, reason);
104   }
105 
RefIfNonZero()106   RefCountedPtr<Child> RefIfNonZero() GRPC_MUST_USE_RESULT {
107     uint64_t prev_ref_pair = refs_.load(std::memory_order_acquire);
108     do {
109       const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
110 #ifndef NDEBUG
111       const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
112       if (trace_ != nullptr) {
113         gpr_log(GPR_INFO, "%s:%p ref_if_non_zero %d -> %d (weak_refs=%d)",
114                 trace_, this, strong_refs, strong_refs + 1, weak_refs);
115       }
116 #endif
117       if (strong_refs == 0) return nullptr;
118     } while (!refs_.compare_exchange_weak(
119         prev_ref_pair, prev_ref_pair + MakeRefPair(1, 0),
120         std::memory_order_acq_rel, std::memory_order_acquire));
121     return RefCountedPtr<Child>(static_cast<Child*>(this));
122   }
123 
RefIfNonZero(const DebugLocation & location,const char * reason)124   RefCountedPtr<Child> RefIfNonZero(const DebugLocation& location,
125                                     const char* reason) GRPC_MUST_USE_RESULT {
126     uint64_t prev_ref_pair = refs_.load(std::memory_order_acquire);
127     do {
128       const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
129 #ifndef NDEBUG
130       const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
131       if (trace_ != nullptr) {
132         gpr_log(GPR_INFO,
133                 "%s:%p %s:%d ref_if_non_zero %d -> %d (weak_refs=%d) %s",
134                 trace_, this, location.file(), location.line(), strong_refs,
135                 strong_refs + 1, weak_refs, reason);
136       }
137 #else
138       // Avoid unused-parameter warnings for debug-only parameters
139       (void)location;
140       (void)reason;
141 #endif
142       if (strong_refs == 0) return nullptr;
143     } while (!refs_.compare_exchange_weak(
144         prev_ref_pair, prev_ref_pair + MakeRefPair(1, 0),
145         std::memory_order_acq_rel, std::memory_order_acquire));
146     return RefCountedPtr<Child>(static_cast<Child*>(this));
147   }
148 
WeakRef()149   WeakRefCountedPtr<Child> WeakRef() GRPC_MUST_USE_RESULT {
150     IncrementWeakRefCount();
151     return WeakRefCountedPtr<Child>(static_cast<Child*>(this));
152   }
153 
WeakRef(const DebugLocation & location,const char * reason)154   WeakRefCountedPtr<Child> WeakRef(const DebugLocation& location,
155                                    const char* reason) GRPC_MUST_USE_RESULT {
156     IncrementWeakRefCount(location, reason);
157     return WeakRefCountedPtr<Child>(static_cast<Child*>(this));
158   }
159 
WeakUnref()160   void WeakUnref() {
161 #ifndef NDEBUG
162     // Grab a copy of the trace flag before the atomic change, since we
163     // will no longer be holding a ref afterwards and therefore can't
164     // safely access it, since another thread might free us in the interim.
165     const char* trace = trace_;
166 #endif
167     const uint64_t prev_ref_pair =
168         refs_.fetch_sub(MakeRefPair(0, 1), std::memory_order_acq_rel);
169 #ifndef NDEBUG
170     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
171     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
172     if (trace != nullptr) {
173       gpr_log(GPR_INFO, "%s:%p weak_unref %d -> %d (refs=%d)", trace, this,
174               weak_refs, weak_refs - 1, strong_refs);
175     }
176     GPR_ASSERT(weak_refs > 0);
177 #endif
178     if (GPR_UNLIKELY(prev_ref_pair == MakeRefPair(0, 1))) {
179       delete static_cast<Child*>(this);
180     }
181   }
WeakUnref(const DebugLocation & location,const char * reason)182   void WeakUnref(const DebugLocation& location, const char* reason) {
183 #ifndef NDEBUG
184     // Grab a copy of the trace flag before the atomic change, since we
185     // will no longer be holding a ref afterwards and therefore can't
186     // safely access it, since another thread might free us in the interim.
187     const char* trace = trace_;
188 #endif
189     const uint64_t prev_ref_pair =
190         refs_.fetch_sub(MakeRefPair(0, 1), std::memory_order_acq_rel);
191 #ifndef NDEBUG
192     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
193     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
194     if (trace != nullptr) {
195       gpr_log(GPR_INFO, "%s:%p %s:%d weak_unref %d -> %d (refs=%d) %s", trace,
196               this, location.file(), location.line(), weak_refs, weak_refs - 1,
197               strong_refs, reason);
198     }
199     GPR_ASSERT(weak_refs > 0);
200 #else
201     // Avoid unused-parameter warnings for debug-only parameters
202     (void)location;
203     (void)reason;
204 #endif
205     if (GPR_UNLIKELY(prev_ref_pair == MakeRefPair(0, 1))) {
206       delete static_cast<Child*>(this);
207     }
208   }
209 
210   // Not copyable nor movable.
211   DualRefCounted(const DualRefCounted&) = delete;
212   DualRefCounted& operator=(const DualRefCounted&) = delete;
213 
214  protected:
215   // Note: Tracing is a no-op in non-debug builds.
216   explicit DualRefCounted(
217       const char*
218 #ifndef NDEBUG
219           // Leave unnamed if NDEBUG to avoid unused parameter warning
220           trace
221 #endif
222       = nullptr,
223       int32_t initial_refcount = 1)
224       :
225 #ifndef NDEBUG
trace_(trace)226         trace_(trace),
227 #endif
228         refs_(MakeRefPair(initial_refcount, 0)) {
229   }
230 
231  private:
232   // Allow RefCountedPtr<> to access IncrementRefCount().
233   template <typename T>
234   friend class RefCountedPtr;
235   // Allow WeakRefCountedPtr<> to access IncrementWeakRefCount().
236   template <typename T>
237   friend class WeakRefCountedPtr;
238 
239   // First 32 bits are strong refs, next 32 bits are weak refs.
MakeRefPair(uint32_t strong,uint32_t weak)240   static uint64_t MakeRefPair(uint32_t strong, uint32_t weak) {
241     return (static_cast<uint64_t>(strong) << 32) + static_cast<int64_t>(weak);
242   }
GetStrongRefs(uint64_t ref_pair)243   static uint32_t GetStrongRefs(uint64_t ref_pair) {
244     return static_cast<uint32_t>(ref_pair >> 32);
245   }
GetWeakRefs(uint64_t ref_pair)246   static uint32_t GetWeakRefs(uint64_t ref_pair) {
247     return static_cast<uint32_t>(ref_pair & 0xffffffffu);
248   }
249 
IncrementRefCount()250   void IncrementRefCount() {
251 #ifndef NDEBUG
252     const uint64_t prev_ref_pair =
253         refs_.fetch_add(MakeRefPair(1, 0), std::memory_order_relaxed);
254     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
255     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
256     GPR_ASSERT(strong_refs != 0);
257     if (trace_ != nullptr) {
258       gpr_log(GPR_INFO, "%s:%p ref %d -> %d; (weak_refs=%d)", trace_, this,
259               strong_refs, strong_refs + 1, weak_refs);
260     }
261 #else
262     refs_.fetch_add(MakeRefPair(1, 0), std::memory_order_relaxed);
263 #endif
264   }
IncrementRefCount(const DebugLocation & location,const char * reason)265   void IncrementRefCount(const DebugLocation& location, const char* reason) {
266 #ifndef NDEBUG
267     const uint64_t prev_ref_pair =
268         refs_.fetch_add(MakeRefPair(1, 0), std::memory_order_relaxed);
269     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
270     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
271     GPR_ASSERT(strong_refs != 0);
272     if (trace_ != nullptr) {
273       gpr_log(GPR_INFO, "%s:%p %s:%d ref %d -> %d (weak_refs=%d) %s", trace_,
274               this, location.file(), location.line(), strong_refs,
275               strong_refs + 1, weak_refs, reason);
276     }
277 #else
278     // Use conditionally-important parameters
279     (void)location;
280     (void)reason;
281     refs_.fetch_add(MakeRefPair(1, 0), std::memory_order_relaxed);
282 #endif
283   }
284 
IncrementWeakRefCount()285   void IncrementWeakRefCount() {
286 #ifndef NDEBUG
287     const uint64_t prev_ref_pair =
288         refs_.fetch_add(MakeRefPair(0, 1), std::memory_order_relaxed);
289     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
290     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
291     if (trace_ != nullptr) {
292       gpr_log(GPR_INFO, "%s:%p weak_ref %d -> %d; (refs=%d)", trace_, this,
293               weak_refs, weak_refs + 1, strong_refs);
294     }
295 #else
296     refs_.fetch_add(MakeRefPair(0, 1), std::memory_order_relaxed);
297 #endif
298   }
IncrementWeakRefCount(const DebugLocation & location,const char * reason)299   void IncrementWeakRefCount(const DebugLocation& location,
300                              const char* reason) {
301 #ifndef NDEBUG
302     const uint64_t prev_ref_pair =
303         refs_.fetch_add(MakeRefPair(0, 1), std::memory_order_relaxed);
304     const uint32_t strong_refs = GetStrongRefs(prev_ref_pair);
305     const uint32_t weak_refs = GetWeakRefs(prev_ref_pair);
306     if (trace_ != nullptr) {
307       gpr_log(GPR_INFO, "%s:%p %s:%d weak_ref %d -> %d (refs=%d) %s", trace_,
308               this, location.file(), location.line(), weak_refs, weak_refs + 1,
309               strong_refs, reason);
310     }
311 #else
312     // Use conditionally-important parameters
313     (void)location;
314     (void)reason;
315     refs_.fetch_add(MakeRefPair(0, 1), std::memory_order_relaxed);
316 #endif
317   }
318 
319 #ifndef NDEBUG
320   const char* trace_;
321 #endif
322   std::atomic<uint64_t> refs_{0};
323 };
324 
325 }  // namespace grpc_core
326 
327 #endif  // GRPC_SRC_CORE_LIB_GPRPP_DUAL_REF_COUNTED_H
328