1 /*
2  * Copyright 2021 HIMSA II K/S - www.himsa.com.
3  * Represented by EHIMA - www.ehima.com
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 #include <bluetooth/log.h>
19 #include <stdio.h>
20 
21 #include <algorithm>
22 #include <cstddef>
23 #include <cstdint>
24 #include <cstring>
25 #include <functional>
26 #include <limits>
27 #include <list>
28 #include <map>
29 #include <mutex>
30 #include <ostream>
31 #include <sstream>
32 #include <unordered_set>
33 #include <utility>
34 #include <vector>
35 
36 #include "bta_groups.h"
37 #include "btif/include/btif_profile_storage.h"
38 #include "os/logging/log_adapter.h"
39 #include "stack/include/bt_types.h"
40 #include "types/bluetooth/uuid.h"
41 #include "types/raw_address.h"
42 
43 using bluetooth::Uuid;
44 
45 namespace bluetooth {
46 namespace groups {
47 
/* create_group() only hands out ids in the range [1, kMaxGroupId). */
static constexpr int kMaxGroupId = 0xEF;

class DeviceGroupsImpl;

/* The single manager object: allocated by DeviceGroups::Initialize() and
 * deleted by DeviceGroups::CleanUp() once the last client is gone. */
DeviceGroupsImpl* instance = nullptr;

/* Taken by Initialize(), CleanUp() and DebugDump() around |instance|. */
std::mutex instance_mutex;
52 
53 class DeviceGroup {
54 public:
DeviceGroup(int group_id,Uuid uuid)55   DeviceGroup(int group_id, Uuid uuid) : group_id_(group_id), group_uuid_(uuid) {}
Add(const RawAddress & addr)56   void Add(const RawAddress& addr) { devices_.insert(addr); }
Remove(const RawAddress & addr)57   void Remove(const RawAddress& addr) { devices_.erase(addr); }
Contains(const RawAddress & addr) const58   bool Contains(const RawAddress& addr) const { return devices_.count(addr) != 0; }
59 
ForEachDevice(std::function<void (const RawAddress &)> cb) const60   void ForEachDevice(std::function<void(const RawAddress&)> cb) const {
61     for (auto const& addr : devices_) {
62       cb(addr);
63     }
64   }
65 
Size(void) const66   int Size(void) const { return devices_.size(); }
GetGroupId(void) const67   int GetGroupId(void) const { return group_id_; }
GetUuid(void) const68   const Uuid& GetUuid(void) const { return group_uuid_; }
69 
70 private:
71   friend std::ostream& operator<<(std::ostream& out, const bluetooth::groups::DeviceGroup& value);
72   int group_id_;
73   Uuid group_uuid_;
74   std::unordered_set<RawAddress> devices_;
75 };
76 
77 class DeviceGroupsImpl : public DeviceGroups {
78   static constexpr uint8_t GROUP_STORAGE_CURRENT_LAYOUT_MAGIC = 0x10;
79   static constexpr size_t GROUP_STORAGE_HEADER_SZ =
80           sizeof(GROUP_STORAGE_CURRENT_LAYOUT_MAGIC) + sizeof(uint8_t); /* num_of_groups */
81   static constexpr size_t GROUP_STORAGE_ENTRY_SZ =
82           sizeof(uint8_t) /* group_id */ + Uuid::kNumBytes128;
83 
84 public:
DeviceGroupsImpl(DeviceGroupsCallbacks * callbacks)85   DeviceGroupsImpl(DeviceGroupsCallbacks* callbacks) {
86     AddCallbacks(callbacks);
87     btif_storage_load_bonded_groups();
88   }
89 
GetGroupId(const RawAddress & addr,Uuid uuid) const90   int GetGroupId(const RawAddress& addr, Uuid uuid) const override {
91     for (const auto& [id, g] : groups_) {
92       if ((g.Contains(addr)) && (uuid == g.GetUuid())) {
93         return id;
94       }
95     }
96     return kGroupUnknown;
97   }
98 
add_to_group(const RawAddress & addr,DeviceGroup * group)99   void add_to_group(const RawAddress& addr, DeviceGroup* group) {
100     group->Add(addr);
101 
102     bool first_device_in_group = (group->Size() == 1);
103 
104     for (auto c : callbacks_) {
105       if (first_device_in_group) {
106         c->OnGroupAdded(addr, group->GetUuid(), group->GetGroupId());
107       } else {
108         c->OnGroupMemberAdded(addr, group->GetGroupId());
109       }
110     }
111   }
112 
AddDevice(const RawAddress & addr,Uuid uuid,int group_id)113   int AddDevice(const RawAddress& addr, Uuid uuid, int group_id) override {
114     DeviceGroup* group = nullptr;
115 
116     if (group_id == kGroupUnknown) {
117       auto gid = GetGroupId(addr, uuid);
118       if (gid != kGroupUnknown) {
119         return gid;
120       }
121       group = create_group(uuid);
122     } else {
123       group = get_or_create_group_with_id(group_id, uuid);
124       if (!group) {
125         return kGroupUnknown;
126       }
127     }
128 
129     log::assert_that(group, "assert failed: group");
130 
131     if (group->Contains(addr)) {
132       log::error("device {} already in the group: {}", addr, group_id);
133       return group->GetGroupId();
134     }
135 
136     add_to_group(addr, group);
137 
138     btif_storage_add_groups(addr);
139     return group->GetGroupId();
140   }
141 
RemoveDevice(const RawAddress & addr,int group_id)142   void RemoveDevice(const RawAddress& addr, int group_id) override {
143     int num_of_groups_dev_belongs = 0;
144 
145     /* Remove from all the groups. Usually happens on unbond */
146     for (auto it = groups_.begin(); it != groups_.end();) {
147       auto& [id, g] = *it;
148       if (!g.Contains(addr)) {
149         ++it;
150         continue;
151       }
152 
153       num_of_groups_dev_belongs++;
154 
155       if ((group_id != bluetooth::groups::kGroupUnknown) && (group_id != id)) {
156         ++it;
157         continue;
158       }
159 
160       num_of_groups_dev_belongs--;
161 
162       g.Remove(addr);
163       for (auto c : callbacks_) {
164         c->OnGroupMemberRemoved(addr, id);
165       }
166 
167       if (g.Size() == 0) {
168         for (auto c : callbacks_) {
169           c->OnGroupRemoved(g.GetUuid(), g.GetGroupId());
170         }
171         it = groups_.erase(it);
172       } else {
173         ++it;
174       }
175     }
176 
177     btif_storage_remove_groups(addr);
178     if (num_of_groups_dev_belongs > 0) {
179       btif_storage_add_groups(addr);
180     }
181   }
182 
SerializeGroups(const RawAddress & addr,std::vector<uint8_t> & out) const183   bool SerializeGroups(const RawAddress& addr, std::vector<uint8_t>& out) const {
184     auto num_groups = std::count_if(groups_.begin(), groups_.end(), [&addr](auto& id_group_pair) {
185       return id_group_pair.second.Contains(addr);
186     });
187     if ((num_groups == 0) || (num_groups > std::numeric_limits<uint8_t>::max())) {
188       return false;
189     }
190 
191     out.resize(GROUP_STORAGE_HEADER_SZ + (num_groups * GROUP_STORAGE_ENTRY_SZ));
192     auto* ptr = out.data();
193 
194     /* header */
195     UINT8_TO_STREAM(ptr, GROUP_STORAGE_CURRENT_LAYOUT_MAGIC);
196     UINT8_TO_STREAM(ptr, num_groups);
197 
198     /* group entries */
199     for (const auto& [id, g] : groups_) {
200       if (g.Contains(addr)) {
201         UINT8_TO_STREAM(ptr, id);
202 
203         Uuid::UUID128Bit uuid128 = g.GetUuid().To128BitLE();
204         memcpy(ptr, uuid128.data(), Uuid::kNumBytes128);
205         ptr += Uuid::kNumBytes128;
206       }
207     }
208 
209     return true;
210   }
211 
DeserializeGroups(const RawAddress & addr,const std::vector<uint8_t> & in)212   void DeserializeGroups(const RawAddress& addr, const std::vector<uint8_t>& in) {
213     if (in.size() < GROUP_STORAGE_HEADER_SZ + GROUP_STORAGE_ENTRY_SZ) {
214       return;
215     }
216 
217     auto* ptr = in.data();
218 
219     uint8_t magic;
220     STREAM_TO_UINT8(magic, ptr);
221 
222     if (magic == GROUP_STORAGE_CURRENT_LAYOUT_MAGIC) {
223       uint8_t num_groups;
224       STREAM_TO_UINT8(num_groups, ptr);
225 
226       if (in.size() < GROUP_STORAGE_HEADER_SZ + (num_groups * GROUP_STORAGE_ENTRY_SZ)) {
227         log::error("Invalid persistent storage data");
228         return;
229       }
230 
231       /* group entries */
232       while (num_groups--) {
233         uint8_t id;
234         STREAM_TO_UINT8(id, ptr);
235 
236         Uuid::UUID128Bit uuid128;
237         STREAM_TO_ARRAY(uuid128.data(), ptr, (int)Uuid::kNumBytes128);
238 
239         auto* group = get_or_create_group_with_id(id, Uuid::From128BitLE(uuid128));
240         if (group) {
241           add_to_group(addr, group);
242         }
243 
244         for (auto c : callbacks_) {
245           c->OnGroupAddFromStorage(addr, Uuid::From128BitLE(uuid128), id);
246         }
247       }
248     }
249   }
250 
AddCallbacks(DeviceGroupsCallbacks * callbacks)251   void AddCallbacks(DeviceGroupsCallbacks* callbacks) {
252     callbacks_.push_back(std::move(callbacks));
253 
254     /* Notify new user about known groups */
255     for (const auto& [id, g] : groups_) {
256       auto group_uuid = g.GetUuid();
257       auto group_id = g.GetGroupId();
258       g.ForEachDevice([&](auto& dev) { callbacks->OnGroupAdded(dev, group_uuid, group_id); });
259     }
260   }
261 
Clear(DeviceGroupsCallbacks * callbacks)262   bool Clear(DeviceGroupsCallbacks* callbacks) {
263     auto it = find_if(callbacks_.begin(), callbacks_.end(),
264                       [callbacks](auto c) { return c == callbacks; });
265 
266     if (it != callbacks_.end()) {
267       callbacks_.erase(it);
268     }
269 
270     if (callbacks_.size() != 0) {
271       return false;
272     }
273     /* When all clients were unregistered */
274     groups_.clear();
275     return true;
276   }
277 
Dump(int fd)278   void Dump(int fd) {
279     std::stringstream stream;
280 
281     stream << "  Num. registered clients: " << callbacks_.size() << std::endl;
282     stream << "  Groups:\n";
283     for (const auto& kv_pair : groups_) {
284       stream << kv_pair.second << std::endl;
285     }
286 
287     dprintf(fd, "%s", stream.str().c_str());
288   }
289 
290 private:
find_device_group(int group_id)291   DeviceGroup* find_device_group(int group_id) {
292     return groups_.count(group_id) ? &groups_.at(group_id) : nullptr;
293   }
294 
get_or_create_group_with_id(int group_id,Uuid uuid)295   DeviceGroup* get_or_create_group_with_id(int group_id, Uuid uuid) {
296     auto group = find_device_group(group_id);
297     if (group) {
298       if (group->GetUuid() != uuid) {
299         log::error("group {} exists but for different uuid: {}, user request uuid: {}", group_id,
300                    group->GetUuid(), uuid);
301         return nullptr;
302       }
303 
304       log::info("group already exists: {}", group_id);
305       return group;
306     }
307 
308     DeviceGroup new_group(group_id, uuid);
309     groups_.insert({group_id, std::move(new_group)});
310 
311     return &groups_.at(group_id);
312   }
313 
create_group(Uuid & uuid)314   DeviceGroup* create_group(Uuid& uuid) {
315     /* Generate new group id and return empty group */
316     /* Find first free id */
317 
318     int group_id = -1;
319     for (int i = 1; i < kMaxGroupId; i++) {
320       if (groups_.count(i) == 0) {
321         group_id = i;
322         break;
323       }
324     }
325 
326     if (group_id < 0) {
327       log::error("too many groups");
328       return nullptr;
329     }
330 
331     DeviceGroup group(group_id, uuid);
332     groups_.insert({group_id, std::move(group)});
333 
334     return &groups_.at(group_id);
335   }
336 
337   std::map<int, DeviceGroup> groups_;
338   std::list<DeviceGroupsCallbacks*> callbacks_;
339 };
340 
Initialize(DeviceGroupsCallbacks * callbacks)341 void DeviceGroups::Initialize(DeviceGroupsCallbacks* callbacks) {
342   std::scoped_lock<std::mutex> lock(instance_mutex);
343   if (instance == nullptr) {
344     instance = new DeviceGroupsImpl(callbacks);
345     return;
346   }
347 
348   instance->AddCallbacks(callbacks);
349 }
350 
AddFromStorage(const RawAddress & addr,const std::vector<uint8_t> & in)351 void DeviceGroups::AddFromStorage(const RawAddress& addr, const std::vector<uint8_t>& in) {
352   if (!instance) {
353     log::error("Not initialized yet");
354     return;
355   }
356 
357   instance->DeserializeGroups(addr, in);
358 }
359 
GetForStorage(const RawAddress & addr,std::vector<uint8_t> & out)360 bool DeviceGroups::GetForStorage(const RawAddress& addr, std::vector<uint8_t>& out) {
361   if (!instance) {
362     log::error("Not initialized yet");
363     return false;
364   }
365 
366   return instance->SerializeGroups(addr, out);
367 }
368 
CleanUp(DeviceGroupsCallbacks * callbacks)369 void DeviceGroups::CleanUp(DeviceGroupsCallbacks* callbacks) {
370   std::scoped_lock<std::mutex> lock(instance_mutex);
371   if (!instance) {
372     return;
373   }
374 
375   if (instance->Clear(callbacks)) {
376     delete (instance);
377     instance = nullptr;
378   }
379 }
380 
operator <<(std::ostream & out,bluetooth::groups::DeviceGroup const & group)381 std::ostream& operator<<(std::ostream& out, bluetooth::groups::DeviceGroup const& group) {
382   out << "    == Group id: " << group.group_id_ << " == \n"
383       << "      Uuid: " << group.group_uuid_ << std::endl;
384   out << "      Devices:\n";
385   for (auto const& addr : group.devices_) {
386     out << "        " << ADDRESS_TO_LOGGABLE_STR(addr) << std::endl;
387   }
388   return out;
389 }
390 
DebugDump(int fd)391 void DeviceGroups::DebugDump(int fd) {
392   std::scoped_lock<std::mutex> lock(instance_mutex);
393   dprintf(fd, "Device Groups Manager:\n");
394   if (instance) {
395     instance->Dump(fd);
396   } else {
397     dprintf(fd, "  Not initialized \n");
398   }
399 }
400 
Get()401 DeviceGroups* DeviceGroups::Get() { return instance; }
402 
403 }  // namespace groups
404 }  // namespace bluetooth
405