xref: /aosp_15_r20/external/pigweed/pw_bluetooth_sapphire/host/gatt/gatt.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_bluetooth_sapphire/internal/host/gatt/gatt.h"
16 
17 #include <lib/fit/defer.h>
18 
19 #include <unordered_map>
20 
21 #include "pw_bluetooth_sapphire/internal/host/common/assert.h"
22 #include "pw_bluetooth_sapphire/internal/host/common/log.h"
23 #include "pw_bluetooth_sapphire/internal/host/gatt/client.h"
24 #include "pw_bluetooth_sapphire/internal/host/gatt/connection.h"
25 #include "pw_bluetooth_sapphire/internal/host/gatt/generic_attribute_service.h"
26 #include "pw_bluetooth_sapphire/internal/host/gatt/remote_service.h"
27 #include "pw_bluetooth_sapphire/internal/host/gatt/server.h"
28 
29 namespace bt::gatt {
30 
GATT()31 GATT::GATT() : WeakSelf(this) {}
32 
33 namespace {
34 
35 class Impl final : public GATT {
36  public:
Impl()37   explicit Impl() {
38     local_services_ = std::make_unique<LocalServiceManager>();
39 
40     // Forwards Service Changed payloads to clients.
41     auto send_indication_callback = [this](IdType service_id,
42                                            IdType chrc_id,
43                                            PeerId peer_id,
44                                            BufferView value) {
45       auto iter = connections_.find(peer_id);
46       if (iter == connections_.end()) {
47         bt_log(WARN, "gatt", "peer not registered: %s", bt_str(peer_id));
48         return;
49       }
50       auto indication_cb = [](att::Result<> result) {
51         bt_log(TRACE,
52                "gatt",
53                "service changed indication complete: %s",
54                bt_str(result));
55       };
56       iter->second.server()->SendUpdate(
57           service_id, chrc_id, value.view(), std::move(indication_cb));
58     };
59 
60     // Spin up Generic Attribute as the first service.
61     gatt_service_ = std::make_unique<GenericAttributeService>(
62         local_services_->GetWeakPtr(), std::move(send_indication_callback));
63 
64     bt_log(DEBUG, "gatt", "initialized");
65   }
66 
~Impl()67   ~Impl() override {
68     bt_log(DEBUG, "gatt", "shutting down");
69 
70     connections_.clear();
71     gatt_service_ = nullptr;
72     local_services_ = nullptr;
73   }
74 
75   // GATT overrides:
76 
AddConnection(PeerId peer_id,std::unique_ptr<Client> client,Server::FactoryFunction server_factory)77   void AddConnection(PeerId peer_id,
78                      std::unique_ptr<Client> client,
79                      Server::FactoryFunction server_factory) override {
80     bt_log(DEBUG, "gatt", "add connection %s", bt_str(peer_id));
81 
82     auto iter = connections_.find(peer_id);
83     if (iter != connections_.end()) {
84       bt_log(WARN, "gatt", "peer is already registered: %s", bt_str(peer_id));
85       return;
86     }
87 
88     RemoteServiceWatcher service_watcher =
89         [this, peer_id](std::vector<att::Handle> removed,
90                         std::vector<RemoteService::WeakPtr> added,
91                         std::vector<RemoteService::WeakPtr> modified) {
92           OnServicesChanged(peer_id, removed, added, modified);
93         };
94     std::unique_ptr<Server> server =
95         server_factory(peer_id, local_services_->GetWeakPtr());
96     connections_.try_emplace(peer_id,
97                              std::move(client),
98                              std::move(server),
99                              std::move(service_watcher));
100 
101     if (retrieve_service_changed_ccc_callback_) {
102       auto optional_service_changed_ccc_data =
103           retrieve_service_changed_ccc_callback_(peer_id);
104       if (optional_service_changed_ccc_data && gatt_service_) {
105         gatt_service_->SetServiceChangedIndicationSubscription(
106             peer_id, optional_service_changed_ccc_data->indicate);
107       }
108     } else {
109       bt_log(WARN,
110              "gatt",
111              "Unable to retrieve service changed CCC: callback not set.");
112     }
113   }
114 
RemoveConnection(PeerId peer_id)115   void RemoveConnection(PeerId peer_id) override {
116     bt_log(DEBUG, "gatt", "remove connection: %s", bt_str(peer_id));
117     local_services_->DisconnectClient(peer_id);
118     connections_.erase(peer_id);
119   }
120 
RegisterPeerMtuListener(PeerMtuListener listener)121   PeerMtuListenerId RegisterPeerMtuListener(PeerMtuListener listener) override {
122     peer_mtu_listeners_.insert({next_mtu_listener_id_, std::move(listener)});
123     return next_mtu_listener_id_++;
124   }
125 
UnregisterPeerMtuListener(PeerMtuListenerId listener_id)126   bool UnregisterPeerMtuListener(PeerMtuListenerId listener_id) override {
127     return peer_mtu_listeners_.erase(listener_id) == 1;
128   }
129 
RegisterService(ServicePtr service,ServiceIdCallback callback,ReadHandler read_handler,WriteHandler write_handler,ClientConfigCallback ccc_callback)130   void RegisterService(ServicePtr service,
131                        ServiceIdCallback callback,
132                        ReadHandler read_handler,
133                        WriteHandler write_handler,
134                        ClientConfigCallback ccc_callback) override {
135     IdType id = local_services_->RegisterService(std::move(service),
136                                                  std::move(read_handler),
137                                                  std::move(write_handler),
138                                                  std::move(ccc_callback));
139     callback(id);
140   }
141 
UnregisterService(IdType service_id)142   void UnregisterService(IdType service_id) override {
143     local_services_->UnregisterService(service_id);
144   }
145 
SendUpdate(IdType service_id,IdType chrc_id,PeerId peer_id,::std::vector<uint8_t> value,IndicationCallback indicate_cb)146   void SendUpdate(IdType service_id,
147                   IdType chrc_id,
148                   PeerId peer_id,
149                   ::std::vector<uint8_t> value,
150                   IndicationCallback indicate_cb) override {
151     // There is nothing to do if the requested peer is not connected.
152     auto iter = connections_.find(peer_id);
153     if (iter == connections_.end()) {
154       bt_log(TRACE,
155              "gatt",
156              "cannot notify disconnected peer: %s",
157              bt_str(peer_id));
158       if (indicate_cb) {
159         indicate_cb(ToResult(HostError::kNotFound));
160       }
161       return;
162     }
163     iter->second.server()->SendUpdate(service_id,
164                                       chrc_id,
165                                       BufferView(value.data(), value.size()),
166                                       std::move(indicate_cb));
167   }
168 
UpdateConnectedPeers(IdType service_id,IdType chrc_id,::std::vector<uint8_t> value,IndicationCallback indicate_cb)169   void UpdateConnectedPeers(IdType service_id,
170                             IdType chrc_id,
171                             ::std::vector<uint8_t> value,
172                             IndicationCallback indicate_cb) override {
173     att::ResultFunction<> shared_peer_results_cb(nullptr);
174     if (indicate_cb) {
175       // This notifies indicate_cb with success when destroyed (if indicate_cb
176       // has not been invoked)
177       auto deferred_success =
178           fit::defer([outer_cb = indicate_cb.share()]() mutable {
179             if (outer_cb) {
180               outer_cb(fit::ok());
181             }
182           });
183       // This captures, but doesn't use, deferred_success. Because this is later
184       // |share|d for each peer's SendUpdate callback, deferred_success is
185       // stored in this refcounted memory. If any of the SendUpdate callbacks
186       // fail, the outer callback is notified of failure. But if all of the
187       // callbacks succeed, shared_peer_results_cb's captures will be destroyed,
188       // and deferred_success will then notify indicate_cb of success.
189       shared_peer_results_cb =
190           [deferred = std::move(deferred_success),
191            outer_cb = std::move(indicate_cb)](att::Result<> res) mutable {
192             if (outer_cb && res.is_error()) {
193               outer_cb(res);
194             }
195           };
196     }
197     for (auto& iter : connections_) {
198       // The `shared_peer_results_cb.share()` *does* propagate indication vs.
199       // notification-ness correctly - `fit::function(nullptr).share` just
200       // creates another null fit::function.
201       iter.second.server()->SendUpdate(service_id,
202                                        chrc_id,
203                                        BufferView(value.data(), value.size()),
204                                        shared_peer_results_cb.share());
205     }
206   }
207 
SetPersistServiceChangedCCCCallback(PersistServiceChangedCCCCallback callback)208   void SetPersistServiceChangedCCCCallback(
209       PersistServiceChangedCCCCallback callback) override {
210     gatt_service_->SetPersistServiceChangedCCCCallback(std::move(callback));
211   }
212 
SetRetrieveServiceChangedCCCCallback(RetrieveServiceChangedCCCCallback callback)213   void SetRetrieveServiceChangedCCCCallback(
214       RetrieveServiceChangedCCCCallback callback) override {
215     retrieve_service_changed_ccc_callback_ = std::move(callback);
216   }
217 
InitializeClient(PeerId peer_id,std::vector<UUID> services_to_discover)218   void InitializeClient(PeerId peer_id,
219                         std::vector<UUID> services_to_discover) override {
220     bt_log(TRACE, "gatt", "initialize client: %s", bt_str(peer_id));
221 
222     auto iter = connections_.find(peer_id);
223     if (iter == connections_.end()) {
224       bt_log(WARN, "gatt", "unknown peer: %s", bt_str(peer_id));
225       return;
226     }
227     auto mtu_cb = [this, peer_id](uint16_t mtu) {
228       for (auto& [_id, listener] : peer_mtu_listeners_) {
229         listener(peer_id, mtu);
230       }
231     };
232     iter->second.Initialize(std::move(services_to_discover), std::move(mtu_cb));
233   }
234 
RegisterRemoteServiceWatcherForPeer(PeerId peer_id,RemoteServiceWatcher watcher)235   RemoteServiceWatcherId RegisterRemoteServiceWatcherForPeer(
236       PeerId peer_id, RemoteServiceWatcher watcher) override {
237     PW_CHECK(watcher);
238 
239     RemoteServiceWatcherId id = next_watcher_id_++;
240     peer_remote_service_watchers_.emplace(
241         peer_id, std::make_pair(id, std::move(watcher)));
242     return id;
243   }
244 
UnregisterRemoteServiceWatcher(RemoteServiceWatcherId watcher_id)245   bool UnregisterRemoteServiceWatcher(
246       RemoteServiceWatcherId watcher_id) override {
247     for (auto it = peer_remote_service_watchers_.begin();
248          it != peer_remote_service_watchers_.end();) {
249       if (watcher_id == it->second.first) {
250         it = peer_remote_service_watchers_.erase(it);
251         return true;
252       }
253       it++;
254     }
255     return false;
256   }
257 
ListServices(PeerId peer_id,std::vector<UUID> uuids,ServiceListCallback callback)258   void ListServices(PeerId peer_id,
259                     std::vector<UUID> uuids,
260                     ServiceListCallback callback) override {
261     PW_CHECK(callback);
262     auto iter = connections_.find(peer_id);
263     if (iter == connections_.end()) {
264       // Connection not found.
265       callback(ToResult(HostError::kNotFound), ServiceList());
266       return;
267     }
268     iter->second.remote_service_manager()->ListServices(uuids,
269                                                         std::move(callback));
270   }
271 
FindService(PeerId peer_id,IdType service_id)272   RemoteService::WeakPtr FindService(PeerId peer_id,
273                                      IdType service_id) override {
274     auto iter = connections_.find(peer_id);
275     if (iter == connections_.end()) {
276       // Connection not found.
277       return RemoteService::WeakPtr();
278     }
279     return iter->second.remote_service_manager()->FindService(
280         static_cast<att::Handle>(service_id));
281   }
282 
283  private:
OnServicesChanged(PeerId peer_id,const std::vector<att::Handle> & removed,const std::vector<RemoteService::WeakPtr> & added,const std::vector<RemoteService::WeakPtr> & modified)284   void OnServicesChanged(PeerId peer_id,
285                          const std::vector<att::Handle>& removed,
286                          const std::vector<RemoteService::WeakPtr>& added,
287                          const std::vector<RemoteService::WeakPtr>& modified) {
288     auto peer_watcher_range =
289         peer_remote_service_watchers_.equal_range(peer_id);
290     for (auto it = peer_watcher_range.first; it != peer_watcher_range.second;
291          it++) {
292       TRACE_DURATION("bluetooth", "GATT::OnServiceChanged notify watcher");
293       it->second.second(removed, added, modified);
294     }
295   }
296 
297   // The registry containing all local GATT services. This represents a single
298   // ATT database.
299   std::unique_ptr<LocalServiceManager> local_services_;
300 
301   // Local GATT service (first in database) for clients to subscribe to service
302   // registration and removal.
303   std::unique_ptr<GenericAttributeService> gatt_service_;
304 
305   // Contains the state of all GATT profile connections and their services.
306   std::unordered_map<PeerId, internal::Connection> connections_;
307 
308   // Callback to fetch CCC for Service Changed indications from upper layers.
309   RetrieveServiceChangedCCCCallback retrieve_service_changed_ccc_callback_;
310 
311   RemoteServiceWatcherId next_watcher_id_ = 0u;
312   std::unordered_multimap<
313       PeerId,
314       std::pair<RemoteServiceWatcherId, RemoteServiceWatcher>>
315       peer_remote_service_watchers_;
316   PeerMtuListenerId next_mtu_listener_id_ = 0u;
317   std::unordered_map<PeerMtuListenerId, PeerMtuListener> peer_mtu_listeners_;
318 
319   BT_DISALLOW_COPY_AND_ASSIGN_ALLOW_MOVE(Impl);
320 };
321 }  // namespace
322 
323 // static
Create()324 std::unique_ptr<GATT> GATT::Create() { return std::make_unique<Impl>(); }
325 
326 }  // namespace bt::gatt
327