1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_bluetooth_sapphire/internal/host/gatt/gatt.h"
16
17 #include <lib/fit/defer.h>
18
19 #include <unordered_map>
20
21 #include "pw_bluetooth_sapphire/internal/host/common/assert.h"
22 #include "pw_bluetooth_sapphire/internal/host/common/log.h"
23 #include "pw_bluetooth_sapphire/internal/host/gatt/client.h"
24 #include "pw_bluetooth_sapphire/internal/host/gatt/connection.h"
25 #include "pw_bluetooth_sapphire/internal/host/gatt/generic_attribute_service.h"
26 #include "pw_bluetooth_sapphire/internal/host/gatt/remote_service.h"
27 #include "pw_bluetooth_sapphire/internal/host/gatt/server.h"
28
29 namespace bt::gatt {
30
GATT()31 GATT::GATT() : WeakSelf(this) {}
32
33 namespace {
34
35 class Impl final : public GATT {
36 public:
Impl()37 explicit Impl() {
38 local_services_ = std::make_unique<LocalServiceManager>();
39
40 // Forwards Service Changed payloads to clients.
41 auto send_indication_callback = [this](IdType service_id,
42 IdType chrc_id,
43 PeerId peer_id,
44 BufferView value) {
45 auto iter = connections_.find(peer_id);
46 if (iter == connections_.end()) {
47 bt_log(WARN, "gatt", "peer not registered: %s", bt_str(peer_id));
48 return;
49 }
50 auto indication_cb = [](att::Result<> result) {
51 bt_log(TRACE,
52 "gatt",
53 "service changed indication complete: %s",
54 bt_str(result));
55 };
56 iter->second.server()->SendUpdate(
57 service_id, chrc_id, value.view(), std::move(indication_cb));
58 };
59
60 // Spin up Generic Attribute as the first service.
61 gatt_service_ = std::make_unique<GenericAttributeService>(
62 local_services_->GetWeakPtr(), std::move(send_indication_callback));
63
64 bt_log(DEBUG, "gatt", "initialized");
65 }
66
~Impl()67 ~Impl() override {
68 bt_log(DEBUG, "gatt", "shutting down");
69
70 connections_.clear();
71 gatt_service_ = nullptr;
72 local_services_ = nullptr;
73 }
74
75 // GATT overrides:
76
AddConnection(PeerId peer_id,std::unique_ptr<Client> client,Server::FactoryFunction server_factory)77 void AddConnection(PeerId peer_id,
78 std::unique_ptr<Client> client,
79 Server::FactoryFunction server_factory) override {
80 bt_log(DEBUG, "gatt", "add connection %s", bt_str(peer_id));
81
82 auto iter = connections_.find(peer_id);
83 if (iter != connections_.end()) {
84 bt_log(WARN, "gatt", "peer is already registered: %s", bt_str(peer_id));
85 return;
86 }
87
88 RemoteServiceWatcher service_watcher =
89 [this, peer_id](std::vector<att::Handle> removed,
90 std::vector<RemoteService::WeakPtr> added,
91 std::vector<RemoteService::WeakPtr> modified) {
92 OnServicesChanged(peer_id, removed, added, modified);
93 };
94 std::unique_ptr<Server> server =
95 server_factory(peer_id, local_services_->GetWeakPtr());
96 connections_.try_emplace(peer_id,
97 std::move(client),
98 std::move(server),
99 std::move(service_watcher));
100
101 if (retrieve_service_changed_ccc_callback_) {
102 auto optional_service_changed_ccc_data =
103 retrieve_service_changed_ccc_callback_(peer_id);
104 if (optional_service_changed_ccc_data && gatt_service_) {
105 gatt_service_->SetServiceChangedIndicationSubscription(
106 peer_id, optional_service_changed_ccc_data->indicate);
107 }
108 } else {
109 bt_log(WARN,
110 "gatt",
111 "Unable to retrieve service changed CCC: callback not set.");
112 }
113 }
114
RemoveConnection(PeerId peer_id)115 void RemoveConnection(PeerId peer_id) override {
116 bt_log(DEBUG, "gatt", "remove connection: %s", bt_str(peer_id));
117 local_services_->DisconnectClient(peer_id);
118 connections_.erase(peer_id);
119 }
120
RegisterPeerMtuListener(PeerMtuListener listener)121 PeerMtuListenerId RegisterPeerMtuListener(PeerMtuListener listener) override {
122 peer_mtu_listeners_.insert({next_mtu_listener_id_, std::move(listener)});
123 return next_mtu_listener_id_++;
124 }
125
UnregisterPeerMtuListener(PeerMtuListenerId listener_id)126 bool UnregisterPeerMtuListener(PeerMtuListenerId listener_id) override {
127 return peer_mtu_listeners_.erase(listener_id) == 1;
128 }
129
RegisterService(ServicePtr service,ServiceIdCallback callback,ReadHandler read_handler,WriteHandler write_handler,ClientConfigCallback ccc_callback)130 void RegisterService(ServicePtr service,
131 ServiceIdCallback callback,
132 ReadHandler read_handler,
133 WriteHandler write_handler,
134 ClientConfigCallback ccc_callback) override {
135 IdType id = local_services_->RegisterService(std::move(service),
136 std::move(read_handler),
137 std::move(write_handler),
138 std::move(ccc_callback));
139 callback(id);
140 }
141
UnregisterService(IdType service_id)142 void UnregisterService(IdType service_id) override {
143 local_services_->UnregisterService(service_id);
144 }
145
SendUpdate(IdType service_id,IdType chrc_id,PeerId peer_id,::std::vector<uint8_t> value,IndicationCallback indicate_cb)146 void SendUpdate(IdType service_id,
147 IdType chrc_id,
148 PeerId peer_id,
149 ::std::vector<uint8_t> value,
150 IndicationCallback indicate_cb) override {
151 // There is nothing to do if the requested peer is not connected.
152 auto iter = connections_.find(peer_id);
153 if (iter == connections_.end()) {
154 bt_log(TRACE,
155 "gatt",
156 "cannot notify disconnected peer: %s",
157 bt_str(peer_id));
158 if (indicate_cb) {
159 indicate_cb(ToResult(HostError::kNotFound));
160 }
161 return;
162 }
163 iter->second.server()->SendUpdate(service_id,
164 chrc_id,
165 BufferView(value.data(), value.size()),
166 std::move(indicate_cb));
167 }
168
UpdateConnectedPeers(IdType service_id,IdType chrc_id,::std::vector<uint8_t> value,IndicationCallback indicate_cb)169 void UpdateConnectedPeers(IdType service_id,
170 IdType chrc_id,
171 ::std::vector<uint8_t> value,
172 IndicationCallback indicate_cb) override {
173 att::ResultFunction<> shared_peer_results_cb(nullptr);
174 if (indicate_cb) {
175 // This notifies indicate_cb with success when destroyed (if indicate_cb
176 // has not been invoked)
177 auto deferred_success =
178 fit::defer([outer_cb = indicate_cb.share()]() mutable {
179 if (outer_cb) {
180 outer_cb(fit::ok());
181 }
182 });
183 // This captures, but doesn't use, deferred_success. Because this is later
184 // |share|d for each peer's SendUpdate callback, deferred_success is
185 // stored in this refcounted memory. If any of the SendUpdate callbacks
186 // fail, the outer callback is notified of failure. But if all of the
187 // callbacks succeed, shared_peer_results_cb's captures will be destroyed,
188 // and deferred_success will then notify indicate_cb of success.
189 shared_peer_results_cb =
190 [deferred = std::move(deferred_success),
191 outer_cb = std::move(indicate_cb)](att::Result<> res) mutable {
192 if (outer_cb && res.is_error()) {
193 outer_cb(res);
194 }
195 };
196 }
197 for (auto& iter : connections_) {
198 // The `shared_peer_results_cb.share()` *does* propagate indication vs.
199 // notification-ness correctly - `fit::function(nullptr).share` just
200 // creates another null fit::function.
201 iter.second.server()->SendUpdate(service_id,
202 chrc_id,
203 BufferView(value.data(), value.size()),
204 shared_peer_results_cb.share());
205 }
206 }
207
SetPersistServiceChangedCCCCallback(PersistServiceChangedCCCCallback callback)208 void SetPersistServiceChangedCCCCallback(
209 PersistServiceChangedCCCCallback callback) override {
210 gatt_service_->SetPersistServiceChangedCCCCallback(std::move(callback));
211 }
212
SetRetrieveServiceChangedCCCCallback(RetrieveServiceChangedCCCCallback callback)213 void SetRetrieveServiceChangedCCCCallback(
214 RetrieveServiceChangedCCCCallback callback) override {
215 retrieve_service_changed_ccc_callback_ = std::move(callback);
216 }
217
InitializeClient(PeerId peer_id,std::vector<UUID> services_to_discover)218 void InitializeClient(PeerId peer_id,
219 std::vector<UUID> services_to_discover) override {
220 bt_log(TRACE, "gatt", "initialize client: %s", bt_str(peer_id));
221
222 auto iter = connections_.find(peer_id);
223 if (iter == connections_.end()) {
224 bt_log(WARN, "gatt", "unknown peer: %s", bt_str(peer_id));
225 return;
226 }
227 auto mtu_cb = [this, peer_id](uint16_t mtu) {
228 for (auto& [_id, listener] : peer_mtu_listeners_) {
229 listener(peer_id, mtu);
230 }
231 };
232 iter->second.Initialize(std::move(services_to_discover), std::move(mtu_cb));
233 }
234
RegisterRemoteServiceWatcherForPeer(PeerId peer_id,RemoteServiceWatcher watcher)235 RemoteServiceWatcherId RegisterRemoteServiceWatcherForPeer(
236 PeerId peer_id, RemoteServiceWatcher watcher) override {
237 PW_CHECK(watcher);
238
239 RemoteServiceWatcherId id = next_watcher_id_++;
240 peer_remote_service_watchers_.emplace(
241 peer_id, std::make_pair(id, std::move(watcher)));
242 return id;
243 }
244
UnregisterRemoteServiceWatcher(RemoteServiceWatcherId watcher_id)245 bool UnregisterRemoteServiceWatcher(
246 RemoteServiceWatcherId watcher_id) override {
247 for (auto it = peer_remote_service_watchers_.begin();
248 it != peer_remote_service_watchers_.end();) {
249 if (watcher_id == it->second.first) {
250 it = peer_remote_service_watchers_.erase(it);
251 return true;
252 }
253 it++;
254 }
255 return false;
256 }
257
ListServices(PeerId peer_id,std::vector<UUID> uuids,ServiceListCallback callback)258 void ListServices(PeerId peer_id,
259 std::vector<UUID> uuids,
260 ServiceListCallback callback) override {
261 PW_CHECK(callback);
262 auto iter = connections_.find(peer_id);
263 if (iter == connections_.end()) {
264 // Connection not found.
265 callback(ToResult(HostError::kNotFound), ServiceList());
266 return;
267 }
268 iter->second.remote_service_manager()->ListServices(uuids,
269 std::move(callback));
270 }
271
FindService(PeerId peer_id,IdType service_id)272 RemoteService::WeakPtr FindService(PeerId peer_id,
273 IdType service_id) override {
274 auto iter = connections_.find(peer_id);
275 if (iter == connections_.end()) {
276 // Connection not found.
277 return RemoteService::WeakPtr();
278 }
279 return iter->second.remote_service_manager()->FindService(
280 static_cast<att::Handle>(service_id));
281 }
282
283 private:
OnServicesChanged(PeerId peer_id,const std::vector<att::Handle> & removed,const std::vector<RemoteService::WeakPtr> & added,const std::vector<RemoteService::WeakPtr> & modified)284 void OnServicesChanged(PeerId peer_id,
285 const std::vector<att::Handle>& removed,
286 const std::vector<RemoteService::WeakPtr>& added,
287 const std::vector<RemoteService::WeakPtr>& modified) {
288 auto peer_watcher_range =
289 peer_remote_service_watchers_.equal_range(peer_id);
290 for (auto it = peer_watcher_range.first; it != peer_watcher_range.second;
291 it++) {
292 TRACE_DURATION("bluetooth", "GATT::OnServiceChanged notify watcher");
293 it->second.second(removed, added, modified);
294 }
295 }
296
297 // The registry containing all local GATT services. This represents a single
298 // ATT database.
299 std::unique_ptr<LocalServiceManager> local_services_;
300
301 // Local GATT service (first in database) for clients to subscribe to service
302 // registration and removal.
303 std::unique_ptr<GenericAttributeService> gatt_service_;
304
305 // Contains the state of all GATT profile connections and their services.
306 std::unordered_map<PeerId, internal::Connection> connections_;
307
308 // Callback to fetch CCC for Service Changed indications from upper layers.
309 RetrieveServiceChangedCCCCallback retrieve_service_changed_ccc_callback_;
310
311 RemoteServiceWatcherId next_watcher_id_ = 0u;
312 std::unordered_multimap<
313 PeerId,
314 std::pair<RemoteServiceWatcherId, RemoteServiceWatcher>>
315 peer_remote_service_watchers_;
316 PeerMtuListenerId next_mtu_listener_id_ = 0u;
317 std::unordered_map<PeerMtuListenerId, PeerMtuListener> peer_mtu_listeners_;
318
319 BT_DISALLOW_COPY_AND_ASSIGN_ALLOW_MOVE(Impl);
320 };
321 } // namespace
322
323 // static
Create()324 std::unique_ptr<GATT> GATT::Create() { return std::make_unique<Impl>(); }
325
326 } // namespace bt::gatt
327