1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_bluetooth_sapphire/internal/host/gatt/fake_client.h"
16
17 #include <unordered_set>
18
19 #include "pw_bluetooth_sapphire/internal/host/common/assert.h"
20 #include "pw_bluetooth_sapphire/internal/host/gatt/client.h"
21
22 namespace bt::gatt::testing {
23
FakeClient(pw::async::Dispatcher & pw_dispatcher)24 FakeClient::FakeClient(pw::async::Dispatcher& pw_dispatcher)
25 : heap_dispatcher_(pw_dispatcher), weak_self_(this), weak_fake_(this) {}
26
mtu() const27 uint16_t FakeClient::mtu() const {
28 // TODO(armansito): Return a configurable value.
29 return att::kLEMinMTU;
30 }
31
ExchangeMTU(MTUCallback callback)32 void FakeClient::ExchangeMTU(MTUCallback callback) {
33 (void)heap_dispatcher_.Post(
34 [mtu_status = exchange_mtu_status_,
35 mtu = server_mtu_,
36 callback = std::move(callback)](pw::async::Context /*ctx*/,
37 pw::Status status) mutable {
38 if (!status.ok()) {
39 return;
40 }
41
42 if (mtu_status.is_error()) {
43 callback(fit::error(mtu_status.error_value()));
44 } else {
45 callback(fit::ok(mtu));
46 }
47 });
48 }
49
DiscoverServices(ServiceKind kind,ServiceCallback svc_callback,att::ResultFunction<> status_callback)50 void FakeClient::DiscoverServices(ServiceKind kind,
51 ServiceCallback svc_callback,
52 att::ResultFunction<> status_callback) {
53 DiscoverServicesInRange(kind,
54 /*start=*/att::kHandleMin,
55 /*end=*/att::kHandleMax,
56 std::move(svc_callback),
57 std::move(status_callback));
58 }
59
DiscoverServicesInRange(ServiceKind kind,att::Handle start,att::Handle end,ServiceCallback svc_callback,att::ResultFunction<> status_callback)60 void FakeClient::DiscoverServicesInRange(
61 ServiceKind kind,
62 att::Handle start,
63 att::Handle end,
64 ServiceCallback svc_callback,
65 att::ResultFunction<> status_callback) {
66 DiscoverServicesWithUuidsInRange(kind,
67 start,
68 end,
69 std::move(svc_callback),
70 std::move(status_callback),
71 /*uuids=*/{});
72 }
73
DiscoverServicesWithUuids(ServiceKind kind,ServiceCallback svc_callback,att::ResultFunction<> status_callback,std::vector<UUID> uuids)74 void FakeClient::DiscoverServicesWithUuids(
75 ServiceKind kind,
76 ServiceCallback svc_callback,
77 att::ResultFunction<> status_callback,
78 std::vector<UUID> uuids) {
79 DiscoverServicesWithUuidsInRange(kind,
80 /*start=*/att::kHandleMin,
81 /*end=*/att::kHandleMax,
82 std::move(svc_callback),
83 std::move(status_callback),
84 std::move(uuids));
85 }
86
DiscoverServicesWithUuidsInRange(ServiceKind kind,att::Handle start,att::Handle end,ServiceCallback svc_callback,att::ResultFunction<> status_callback,std::vector<UUID> uuids)87 void FakeClient::DiscoverServicesWithUuidsInRange(
88 ServiceKind kind,
89 att::Handle start,
90 att::Handle end,
91 ServiceCallback svc_callback,
92 att::ResultFunction<> status_callback,
93 std::vector<UUID> uuids) {
94 att::Result<> callback_status = fit::ok();
95 if (discover_services_callback_) {
96 callback_status = discover_services_callback_(kind);
97 }
98
99 std::unordered_set<UUID> uuids_set(uuids.cbegin(), uuids.cend());
100
101 if (callback_status.is_ok()) {
102 for (const ServiceData& svc : services_) {
103 bool uuid_matches =
104 uuids.empty() || uuids_set.find(svc.type) != uuids_set.end();
105 if (svc.kind == kind && uuid_matches && svc.range_start >= start &&
106 svc.range_start <= end) {
107 (void)heap_dispatcher_.Post(
108 [svc, cb = svc_callback.share()](pw::async::Context /*ctx*/,
109 pw::Status status) {
110 if (status.ok()) {
111 cb(svc);
112 }
113 });
114 }
115 }
116 }
117
118 (void)heap_dispatcher_.Post(
119 [callback_status, cb = std::move(status_callback)](
120 pw::async::Context /*ctx*/, pw::Status status) {
121 if (status.ok()) {
122 cb(callback_status);
123 }
124 });
125 }
126
DiscoverCharacteristics(att::Handle range_start,att::Handle range_end,CharacteristicCallback chrc_callback,att::ResultFunction<> status_callback)127 void FakeClient::DiscoverCharacteristics(
128 att::Handle range_start,
129 att::Handle range_end,
130 CharacteristicCallback chrc_callback,
131 att::ResultFunction<> status_callback) {
132 last_chrc_discovery_start_handle_ = range_start;
133 last_chrc_discovery_end_handle_ = range_end;
134 chrc_discovery_count_++;
135
136 (void)heap_dispatcher_.Post(
137 [this,
138 range_start,
139 range_end,
140 chrc_callback = std::move(chrc_callback),
141 status_callback = std::move(status_callback)](pw::async::Context /*ctx*/,
142 pw::Status status) {
143 if (!status.ok()) {
144 return;
145 }
146 for (const auto& chrc : chrcs_) {
147 if (chrc.handle >= range_start && chrc.handle <= range_end) {
148 chrc_callback(chrc);
149 }
150 }
151 status_callback(chrc_discovery_status_);
152 });
153 }
154
DiscoverDescriptors(att::Handle range_start,att::Handle range_end,DescriptorCallback desc_callback,att::ResultFunction<> status_callback)155 void FakeClient::DiscoverDescriptors(att::Handle range_start,
156 att::Handle range_end,
157 DescriptorCallback desc_callback,
158 att::ResultFunction<> status_callback) {
159 last_desc_discovery_start_handle_ = range_start;
160 last_desc_discovery_end_handle_ = range_end;
161 desc_discovery_count_++;
162
163 att::Result<> discovery_status = fit::ok();
164 if (!desc_discovery_status_target_ ||
165 desc_discovery_count_ == desc_discovery_status_target_) {
166 discovery_status = desc_discovery_status_;
167 }
168
169 (void)heap_dispatcher_.Post(
170 [this,
171 discovery_status,
172 range_start,
173 range_end,
174 desc_callback = std::move(desc_callback),
175 status_callback = std::move(status_callback)](pw::async::Context /*ctx*/,
176 pw::Status status) {
177 if (!status.ok()) {
178 return;
179 }
180 for (const auto& desc : descs_) {
181 if (desc.handle >= range_start && desc.handle <= range_end) {
182 desc_callback(desc);
183 }
184 }
185 status_callback(discovery_status);
186 });
187 }
188
ReadRequest(att::Handle handle,ReadCallback callback)189 void FakeClient::ReadRequest(att::Handle handle, ReadCallback callback) {
190 if (read_request_callback_) {
191 read_request_callback_(handle, std::move(callback));
192 }
193 }
194
ReadByTypeRequest(const UUID & type,att::Handle start_handle,att::Handle end_handle,ReadByTypeCallback callback)195 void FakeClient::ReadByTypeRequest(const UUID& type,
196 att::Handle start_handle,
197 att::Handle end_handle,
198 ReadByTypeCallback callback) {
199 if (read_by_type_request_callback_) {
200 read_by_type_request_callback_(
201 type, start_handle, end_handle, std::move(callback));
202 }
203 }
204
ReadBlobRequest(att::Handle handle,uint16_t offset,ReadCallback callback)205 void FakeClient::ReadBlobRequest(att::Handle handle,
206 uint16_t offset,
207 ReadCallback callback) {
208 if (read_blob_request_callback_) {
209 read_blob_request_callback_(handle, offset, std::move(callback));
210 }
211 }
212
WriteRequest(att::Handle handle,const ByteBuffer & value,att::ResultFunction<> callback)213 void FakeClient::WriteRequest(att::Handle handle,
214 const ByteBuffer& value,
215 att::ResultFunction<> callback) {
216 if (write_request_callback_) {
217 write_request_callback_(handle, value, std::move(callback));
218 }
219 }
220
ExecutePrepareWrites(att::PrepareWriteQueue write_queue,ReliableMode reliable_mode,att::ResultFunction<> callback)221 void FakeClient::ExecutePrepareWrites(att::PrepareWriteQueue write_queue,
222 ReliableMode reliable_mode,
223 att::ResultFunction<> callback) {
224 if (execute_prepare_writes_callback_) {
225 execute_prepare_writes_callback_(
226 std::move(write_queue), reliable_mode, std::move(callback));
227 }
228 }
229
PrepareWriteRequest(att::Handle handle,uint16_t offset,const ByteBuffer & part_value,PrepareCallback callback)230 void FakeClient::PrepareWriteRequest(att::Handle handle,
231 uint16_t offset,
232 const ByteBuffer& part_value,
233 PrepareCallback callback) {
234 if (prepare_write_request_callback_) {
235 prepare_write_request_callback_(
236 handle, offset, part_value, std::move(callback));
237 }
238 }
ExecuteWriteRequest(att::ExecuteWriteFlag flag,att::ResultFunction<> callback)239 void FakeClient::ExecuteWriteRequest(att::ExecuteWriteFlag flag,
240 att::ResultFunction<> callback) {
241 if (execute_write_request_callback_) {
242 execute_write_request_callback_(flag, std::move(callback));
243 }
244 }
245
WriteWithoutResponse(att::Handle handle,const ByteBuffer & value,att::ResultFunction<> callback)246 void FakeClient::WriteWithoutResponse(att::Handle handle,
247 const ByteBuffer& value,
248 att::ResultFunction<> callback) {
249 if (write_without_rsp_callback_) {
250 write_without_rsp_callback_(handle, value, std::move(callback));
251 }
252 }
253
SendNotification(bool indicate,att::Handle handle,const ByteBuffer & value,bool maybe_truncated)254 void FakeClient::SendNotification(bool indicate,
255 att::Handle handle,
256 const ByteBuffer& value,
257 bool maybe_truncated) {
258 if (notification_callback_) {
259 notification_callback_(indicate, handle, value, maybe_truncated);
260 }
261 }
262
SetNotificationHandler(NotificationCallback callback)263 void FakeClient::SetNotificationHandler(NotificationCallback callback) {
264 notification_callback_ = std::move(callback);
265 }
266
267 } // namespace bt::gatt::testing
268