1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_bluetooth_sapphire/internal/host/common/weak_self.h"
16
17 #include <pw_async/fake_dispatcher_fixture.h>
18 #include <pw_async/heap_dispatcher.h>
19
20 #include "pw_bluetooth_sapphire/internal/host/common/assert.h"
21 #include "pw_unit_test/framework.h"
22
23 namespace bt {
24 namespace {
25
26 using WeakSelfTest = pw::async::test::FakeDispatcherFixture;
27
28 class FunctionTester : public WeakSelf<FunctionTester> {
29 public:
FunctionTester(uint8_t testval,pw::async::Dispatcher & pw_dispatcher)30 explicit FunctionTester(uint8_t testval, pw::async::Dispatcher& pw_dispatcher)
31 : WeakSelf(this), value_(testval), heap_dispatcher_(pw_dispatcher) {}
32
callback_later_with_weak(fit::function<void (FunctionTester::WeakPtr)> cb)33 void callback_later_with_weak(
34 fit::function<void(FunctionTester::WeakPtr)> cb) {
35 auto weak = GetWeakPtr();
36 (void)heap_dispatcher_.Post(
37 [self = std::move(weak), cb = std::move(cb)](pw::async::Context /*ctx*/,
38 pw::Status status) {
39 if (status.ok()) {
40 cb(self);
41 }
42 });
43 }
44
value() const45 uint8_t value() const { return value_; }
46
47 private:
48 uint8_t value_;
49 pw::async::HeapDispatcher heap_dispatcher_;
50 };
51
// Weak pointers stop being alive once the object they came from is destroyed,
// including one delivered by a task that runs after the destruction.
TEST_F(WeakSelfTest, InvalidatingSelf) {
  FunctionTester::WeakPtr latest;
  bool invoked = false;

  // A default-constructed weak pointer is not alive.
  EXPECT_FALSE(latest.is_alive());

  // Records that the callback ran and keeps the weak pointer it received.
  auto record = [&](FunctionTester::WeakPtr delivered) {
    invoked = true;
    latest = delivered;
  };

  {
    FunctionTester target(0xBA, dispatcher());

    target.callback_later_with_weak(record);

    // Drain the dispatcher so the posted task delivers the weak pointer.
    RunUntilIdle();

    EXPECT_TRUE(invoked);
    EXPECT_TRUE(latest.is_alive());
    EXPECT_EQ(&target, &latest.get());
    EXPECT_EQ(0xBA, latest->value());

    // Queue a second delivery; `target` is destroyed before it gets to run.
    invoked = false;
    target.callback_later_with_weak(record);
  }

  // The queued task is still expected to run, delivering a dead pointer.
  RunUntilIdle();

  EXPECT_TRUE(invoked);
  EXPECT_FALSE(latest.is_alive());
  EXPECT_DEATH_IF_SUPPORTED(latest.get(), "destroyed");
}
90
// InvalidatePtrs() kills outstanding weak pointers while the object itself is
// still in scope.
TEST_F(WeakSelfTest, InvalidatePtrs) {
  FunctionTester::WeakPtr latest;
  bool invoked = false;

  // A default-constructed weak pointer is not alive.
  EXPECT_FALSE(latest.is_alive());

  // Records that the callback ran and keeps the weak pointer it received.
  auto record = [&](FunctionTester::WeakPtr delivered) {
    invoked = true;
    latest = delivered;
  };

  FunctionTester target(0xBA, dispatcher());

  target.callback_later_with_weak(record);

  // Drain the dispatcher so the posted task delivers the weak pointer.
  RunUntilIdle();

  EXPECT_TRUE(invoked);
  EXPECT_TRUE(latest.is_alive());
  EXPECT_EQ(&target, &latest.get());
  EXPECT_EQ(0xBA, latest->value());

  // Queue a second delivery, then invalidate every weak pointer before the
  // queued task runs.
  invoked = false;
  target.callback_later_with_weak(record);
  target.InvalidatePtrs();

  // The queued task is still expected to run, delivering a dead pointer.
  RunUntilIdle();

  EXPECT_TRUE(invoked);
  EXPECT_FALSE(latest.is_alive());
  EXPECT_DEATH_IF_SUPPORTED(latest.get(), "destroyed");
}
128
129 class StaticTester;
130
131 class OnlyTwoStaticManager {
132 public:
OnlyTwoStaticManager(StaticTester * self_ptr)133 explicit OnlyTwoStaticManager(StaticTester* self_ptr) : obj_ptr_(self_ptr) {}
~OnlyTwoStaticManager()134 ~OnlyTwoStaticManager() { InvalidateAll(); }
135
136 using RefType = RecyclingWeakRef;
137
GetWeakRef()138 std::optional<pw::IntrusivePtr<RefType>> GetWeakRef() {
139 for (auto& ptr : OnlyTwoStaticManager::pointers_) {
140 if (ptr.is_alive() && ptr.get() == obj_ptr_) {
141 // Already adopted, add another refptr pointing to it.
142 return pw::IntrusivePtr(&ptr);
143 }
144 }
145 for (auto& ptr : OnlyTwoStaticManager::pointers_) {
146 if (!ptr.is_in_use()) {
147 return ptr.alloc(obj_ptr_);
148 }
149 }
150 return std::nullopt;
151 }
152
InvalidateAll()153 void InvalidateAll() {
154 OnlyTwoStaticManager::pointers_[0].maybe_unset(obj_ptr_);
155 OnlyTwoStaticManager::pointers_[1].maybe_unset(obj_ptr_);
156 }
157
158 private:
159 StaticTester* obj_ptr_;
160 inline static RecyclingWeakRef pointers_[2];
161 };
162
163 class StaticTester : public WeakSelf<StaticTester, OnlyTwoStaticManager> {
164 public:
StaticTester(uint8_t testval)165 explicit StaticTester(uint8_t testval) : WeakSelf(this), value_(testval) {}
166
value() const167 uint8_t value() const { return value_; }
168
169 private:
170 uint8_t value_;
171 };
172
// With only two weak reference slots available, a slot can be reused once
// nothing refers to it any more.
TEST_F(WeakSelfTest, StaticRecyclingPointers) {
  // More objects may exist than there are weak reference slots.
  StaticTester first(1);
  StaticTester second(2);
  StaticTester third(3);

  // Any number of weak pointers may be taken from a single object.
  auto first_weak_a = first.GetWeakPtr();
  auto first_weak_b = first.GetWeakPtr();
  auto first_weak_c = first.GetWeakPtr();
  auto rebindable = first_weak_a;

  // Give another object some weak pointers as well.
  {
    {
      StaticTester shortlived(4);
      auto shortlived_weak_a = shortlived.GetWeakPtr();
      auto shortlived_weak_b = shortlived.GetWeakPtr();
      EXPECT_TRUE(rebindable.is_alive());
      StaticTester* previous_target = &rebindable.get();
      rebindable = shortlived_weak_a;
      EXPECT_TRUE(rebindable.is_alive());
      // After the assignment it refers to the short-lived object instead.
      EXPECT_NE(&rebindable.get(), previous_target);
      EXPECT_EQ(&rebindable.get(), &shortlived);
    }
    // `rebindable` outlived its target.
    EXPECT_FALSE(rebindable.is_alive());
    // Rebinding it leaves the second weak reference unused, recycling it.
    rebindable = first_weak_c;
  }

  // A second weak reference can still be obtained, now for the third object.
  auto third_weak = third.GetWeakPtr();
  auto third_weak_copy = third_weak;
  EXPECT_TRUE(third_weak_copy.is_alive());
}
210
// Requesting a weak pointer is expected to die when no weak reference slot
// can be handed out.
TEST_F(WeakSelfTest, StaticDeathWhenExhausted) {
  StaticTester first(1);
  StaticTester third(3);

  auto first_weak = first.GetWeakPtr();
  auto lingering = first_weak;
  {
    StaticTester second(2);

    // Rebind the copy to an object that is about to be destroyed.
    lingering = second.GetWeakPtr();

    EXPECT_TRUE(lingering.is_alive());
    EXPECT_TRUE(first_weak.is_alive());
  }

  // The target is gone, but `lingering` itself still exists.
  EXPECT_FALSE(lingering.is_alive());

  // `first_weak` and `lingering` are both still around, so a request from a
  // third object is expected to fail fatally.
  EXPECT_DEATH_IF_SUPPORTED(third.GetWeakPtr(), ".*");
}
230
231 class GetWeakRefTester;
232
233 class CountingWeakManager {
234 public:
CountingWeakManager(GetWeakRefTester * self_ptr)235 explicit CountingWeakManager(GetWeakRefTester* self_ptr)
236 : manager_(self_ptr) {}
237
238 using RefType = DynamicWeakManager<GetWeakRefTester>::RefType;
239
240 ~CountingWeakManager() = default;
241
GetWeakRef()242 std::optional<pw::IntrusivePtr<RefType>> GetWeakRef() {
243 // Make sure the weak ref doesn't accidentally get cleared after it's set.
244 if (count_get_weak_ref_ == 0) {
245 PW_CHECK(!manager_.HasWeakRef());
246 } else {
247 PW_CHECK(manager_.HasWeakRef());
248 }
249 count_get_weak_ref_++;
250 return manager_.GetWeakRef();
251 }
252
InvalidateAll()253 void InvalidateAll() { return manager_.InvalidateAll(); }
254
255 private:
256 size_t count_get_weak_ref_{0};
257 DynamicWeakManager<GetWeakRefTester> manager_;
258 };
259
260 class GetWeakRefTester
261 : public WeakSelf<GetWeakRefTester, CountingWeakManager> {
262 public:
GetWeakRefTester(uint8_t testval)263 explicit GetWeakRefTester(uint8_t testval)
264 : WeakSelf(this), value_(testval) {}
265
value() const266 uint8_t value() const { return value_; }
267
268 private:
269 uint8_t value_;
270 };
271
// Repeated GetWeakPtr() calls must not trip the PW_CHECKs in
// CountingWeakManager, including after earlier weak pointers have gone away.
TEST_F(WeakSelfTest, GetWeakRefNotMoved) {
  GetWeakRefTester test_val(1);
  {
    // The core of the test: `GetWeakPtr` must complete without any assertion
    // firing.
    auto scoped_a = test_val.GetWeakPtr();
    auto scoped_b = test_val.GetWeakPtr();

    EXPECT_TRUE(scoped_a.is_alive());
    EXPECT_TRUE(scoped_b.is_alive());
    EXPECT_EQ(&scoped_a.get(), &scoped_b.get());
  }

  // Same requests again, after the first pair of weak pointers is destroyed.
  auto later_a = test_val.GetWeakPtr();
  auto later_b = test_val.GetWeakPtr();

  EXPECT_TRUE(later_a.is_alive());
  EXPECT_TRUE(later_b.is_alive());
  EXPECT_EQ(&later_a.get(), &later_b.get());
}
292
// Minimal polymorphic base holding a single integer. It serves as the upcast
// target for weak pointers to ChildClass.
class BaseClass {
 public:
  BaseClass() = default;
  virtual ~BaseClass() = default;

  // Replaces the stored integer.
  void set_value(int value) { this->value_ = value; }

  // Returns the stored integer; 0 until set_value() is called.
  int value() const { return this->value_; }

 private:
  int value_{0};
};
305
306 class ChildClass : public BaseClass, public WeakSelf<ChildClass> {
307 public:
ChildClass()308 ChildClass() : BaseClass(), WeakSelf<ChildClass>(this) {}
309 };
310
// A WeakPtr to a derived type converts to a WeakPtr to its base, by copy and
// by move.
TEST_F(WeakSelfTest, Upcast) {
  ChildClass child;

  WeakPtr<ChildClass> weak_child = child.GetWeakPtr();
  weak_child->set_value(1);
  EXPECT_EQ(weak_child->value(), 1);

  // Converting copy: the source pointer stays alive.
  WeakPtr<BaseClass> weak_base_copied(weak_child);
  EXPECT_TRUE(weak_child.is_alive());
  weak_base_copied->set_value(2);
  EXPECT_EQ(weak_base_copied->value(), 2);

  // Converting move: the source pointer is expected to end up not alive.
  // Inspecting the moved-from pointer afterwards is deliberate.
  WeakPtr<BaseClass> weak_base_moved(std::move(weak_child));
  EXPECT_FALSE(weak_child.is_alive());
  weak_base_moved->set_value(3);
  EXPECT_EQ(weak_base_moved->value(), 3);
}
328
329 } // namespace
330 } // namespace bt
331