1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #pragma once
16 #include "pw_bluetooth_sapphire/internal/host/common/device_address.h"
17 #include "pw_bluetooth_sapphire/internal/host/gap/adapter.h"
18 #include "pw_bluetooth_sapphire/internal/host/gap/gap.h"
19 #include "pw_bluetooth_sapphire/internal/host/hci/fake_local_address_delegate.h"
20 #include "pw_bluetooth_sapphire/internal/host/l2cap/fake_channel.h"
21 #include "pw_bluetooth_sapphire/internal/host/l2cap/types.h"
22 
23 namespace bt::gap::testing {
24 
25 // FakeAdapter is a fake implementation of Adapter that can be used in higher
26 // layer unit tests (e.g. FIDL tests).
27 class FakeAdapter final : public Adapter {
28  public:
29   explicit FakeAdapter(pw::async::Dispatcher& pw_dispatcher);
30   ~FakeAdapter() override = default;
31 
mutable_state()32   AdapterState& mutable_state() { return state_; }
33 
34   // Adapter overrides:
35 
identifier()36   AdapterId identifier() const override { return AdapterId(0); }
37 
38   bool Initialize(InitializeCallback callback,
39                   fit::closure transport_closed_callback) override;
40 
41   void ShutDown() override;
42 
IsInitializing()43   bool IsInitializing() const override {
44     return init_state_ == InitState::kInitializing;
45   }
46 
IsInitialized()47   bool IsInitialized() const override {
48     return init_state_ == InitState::kInitialized;
49   }
50 
state()51   const AdapterState& state() const override { return state_; }
52 
53   class FakeLowEnergy final : public LowEnergy {
54    public:
55     struct RegisteredAdvertisement {
56       AdvertisingData data;
57       AdvertisingData scan_response;
58       bool include_tx_power_level;
59       DeviceAddress::Type addr_type;
60       bool extended_pdu;
61       bool anonymous;
62       std::optional<ConnectableAdvertisingParameters> connectable;
63     };
64 
65     struct Connection {
66       PeerId peer_id;
67       LowEnergyConnectionOptions options;
68     };
69 
FakeLowEnergy(FakeAdapter * adapter)70     explicit FakeLowEnergy(FakeAdapter* adapter)
71         : adapter_(adapter), fake_address_delegate_(adapter->pw_dispatcher_) {}
72     ~FakeLowEnergy() override = default;
73 
74     const std::unordered_map<AdvertisementId, RegisteredAdvertisement>&
registered_advertisements()75     registered_advertisements() {
76       return advertisements_;
77     }
78 
connections()79     const std::unordered_map<PeerId, Connection>& connections() const {
80       return connections_;
81     }
82 
83     // Update the LE random address of the adapter.
84     void UpdateRandomAddress(DeviceAddress& address);
85 
86     // Overrides the result returned to StartAdvertising() callback.
87     void set_advertising_result(hci::Result<> result);
88 
89     // LowEnergy overrides:
90 
91     // If Connect is called multiple times, only the connection options of the
92     // last call will be reported in connections().
93     void Connect(PeerId peer_id,
94                  ConnectionResultCallback callback,
95                  LowEnergyConnectionOptions connection_options) override;
96 
97     bool Disconnect(PeerId peer_id) override;
98 
99     void OpenL2capChannel(PeerId peer_id,
100                           l2cap::Psm,
101                           l2cap::ChannelParameters,
102                           sm::SecurityLevel security_level,
103                           l2cap::ChannelCallback) override;
104 
Pair(PeerId,sm::SecurityLevel,sm::BondableMode,sm::ResultFunction<>)105     void Pair(PeerId,
106               sm::SecurityLevel,
107               sm::BondableMode,
108               sm::ResultFunction<>) override {}
109 
SetLESecurityMode(LESecurityMode)110     void SetLESecurityMode(LESecurityMode) override {}
111 
security_mode()112     LESecurityMode security_mode() const override {
113       return adapter_->le_security_mode_;
114     }
115 
116     void StartAdvertising(
117         AdvertisingData data,
118         AdvertisingData scan_rsp,
119         AdvertisingInterval interval,
120         bool extended_pdu,
121         bool anonymous,
122         bool include_tx_power_level,
123         std::optional<ConnectableAdvertisingParameters> connectable,
124         std::optional<DeviceAddress::Type> address_type,
125         AdvertisingStatusCallback status_callback) override;
126 
StartDiscovery(bool,SessionCallback)127     void StartDiscovery(bool, SessionCallback) override {}
128 
129     void EnablePrivacy(bool enabled) override;
130 
131     // Returns true if the privacy feature is currently enabled.
PrivacyEnabled()132     bool PrivacyEnabled() const override {
133       return fake_address_delegate_.privacy_enabled();
134     }
135     // Returns the current LE address.
CurrentAddress()136     const DeviceAddress CurrentAddress() const override {
137       return fake_address_delegate_.current_address();
138     }
139 
register_address_changed_callback(fit::closure callback)140     void register_address_changed_callback(fit::closure callback) override {
141       fake_address_delegate_.register_address_changed_callback(
142           std::move(callback));
143     }
144 
set_irk(const std::optional<UInt128> &)145     void set_irk(const std::optional<UInt128>&) override {}
146 
irk()147     std::optional<UInt128> irk() const override { return std::nullopt; }
148 
set_request_timeout_for_testing(pw::chrono::SystemClock::duration)149     void set_request_timeout_for_testing(
150         pw::chrono::SystemClock::duration) override {}
151 
set_scan_period_for_testing(pw::chrono::SystemClock::duration)152     void set_scan_period_for_testing(
153         pw::chrono::SystemClock::duration) override {}
154 
155    private:
156     FakeAdapter* adapter_;
157     AdvertisementId next_advertisement_id_ = AdvertisementId(1);
158     std::unordered_map<AdvertisementId, RegisteredAdvertisement>
159         advertisements_;
160     std::unordered_map<PeerId, Connection> connections_;
161     hci::FakeLocalAddressDelegate fake_address_delegate_;
162     l2cap::ChannelId next_channel_id_ = l2cap::kFirstDynamicChannelId;
163     std::unordered_map<l2cap::ChannelId,
164                        std::unique_ptr<l2cap::testing::FakeChannel>>
165         channels_;
166     std::optional<hci::Result<>> advertising_result_override_;
167   };
168 
le()169   LowEnergy* le() const override { return fake_le_.get(); }
fake_le()170   FakeLowEnergy* fake_le() const { return fake_le_.get(); }
171 
172   class FakeBrEdr final : public BrEdr {
173    public:
174     struct RegisteredService {
175       std::vector<sdp::ServiceRecord> records;
176       l2cap::ChannelParameters channel_params;
177       ServiceConnectCallback connect_callback;
178     };
179 
180     struct RegisteredSearch {
181       UUID uuid;
182       std::unordered_set<sdp::AttributeId> attributes;
183       SearchCallback callback;
184     };
185 
186     FakeBrEdr() = default;
187     ~FakeBrEdr() override;
188 
189     // Called with a reference to the l2cap::FakeChannel created when a channel
190     // is connected with Connect().
191     using ChannelCallback =
192         fit::function<void(l2cap::testing::FakeChannel::WeakPtr)>;
set_l2cap_channel_callback(ChannelCallback cb)193     void set_l2cap_channel_callback(ChannelCallback cb) {
194       channel_cb_ = std::move(cb);
195     }
196 
197     // Destroys the channel, invaliding all weak pointers. Returns true if the
198     // channel was successfully destroyed.
DestroyChannel(l2cap::ChannelId channel_id)199     bool DestroyChannel(l2cap::ChannelId channel_id) {
200       return channels_.erase(channel_id);
201     }
202 
203     // Notifies all registered searches associated with the provided |uuid| with
204     // the peer's service |attributes|.
205     void TriggerServiceFound(
206         PeerId peer_id,
207         UUID uuid,
208         std::map<sdp::AttributeId, sdp::DataElement> attributes);
209 
registered_services()210     const std::map<RegistrationHandle, RegisteredService>& registered_services()
211         const {
212       return registered_services_;
213     }
214 
registered_searches()215     const std::map<RegistrationHandle, RegisteredSearch>& registered_searches()
216         const {
217       return registered_searches_;
218     }
219 
220     // BrEdr overrides:
Connect(PeerId,ConnectResultCallback)221     [[nodiscard]] bool Connect(PeerId, ConnectResultCallback) override {
222       return false;
223     }
224 
Disconnect(PeerId,DisconnectReason)225     bool Disconnect(PeerId, DisconnectReason) override { return false; }
226 
227     void OpenL2capChannel(PeerId peer_id,
228                           l2cap::Psm psm,
229                           BrEdrSecurityRequirements security_requirements,
230                           l2cap::ChannelParameters params,
231                           l2cap::ChannelCallback cb) override;
232 
GetPeerId(hci_spec::ConnectionHandle)233     PeerId GetPeerId(hci_spec::ConnectionHandle) const override {
234       return PeerId();
235     }
236 
237     SearchId AddServiceSearch(const UUID& uuid,
238                               std::unordered_set<sdp::AttributeId> attributes,
239                               SearchCallback callback) override;
240 
RemoveServiceSearch(SearchId)241     bool RemoveServiceSearch(SearchId) override { return false; }
242 
Pair(PeerId,BrEdrSecurityRequirements,hci::ResultFunction<>)243     void Pair(PeerId,
244               BrEdrSecurityRequirements,
245               hci::ResultFunction<>) override {}
246 
SetBrEdrSecurityMode(BrEdrSecurityMode)247     void SetBrEdrSecurityMode(BrEdrSecurityMode) override {}
248 
security_mode()249     BrEdrSecurityMode security_mode() const override {
250       return BrEdrSecurityMode::Mode4;
251     }
252 
SetConnectable(bool,hci::ResultFunction<>)253     void SetConnectable(bool, hci::ResultFunction<>) override {}
254 
RequestDiscovery(DiscoveryCallback)255     void RequestDiscovery(DiscoveryCallback) override {}
256 
RequestDiscoverable(DiscoverableCallback)257     void RequestDiscoverable(DiscoverableCallback) override {}
258 
259     RegistrationHandle RegisterService(std::vector<sdp::ServiceRecord> records,
260                                        l2cap::ChannelParameters chan_params,
261                                        ServiceConnectCallback conn_cb) override;
262 
263     bool UnregisterService(RegistrationHandle handle) override;
264 
GetRegisteredServices(RegistrationHandle)265     std::vector<sdp::ServiceRecord> GetRegisteredServices(
266         RegistrationHandle) const override {
267       return {};
268     }
269 
OpenScoConnection(PeerId,const bt::StaticPacket<pw::bluetooth::emboss::SynchronousConnectionParametersWriter> &,sco::ScoConnectionManager::OpenConnectionCallback)270     std::optional<ScoRequestHandle> OpenScoConnection(
271         PeerId,
272         const bt::StaticPacket<
273             pw::bluetooth::emboss::SynchronousConnectionParametersWriter>&,
274         sco::ScoConnectionManager::OpenConnectionCallback) override {
275       return std::nullopt;
276     }
277 
AcceptScoConnection(PeerId,std::vector<bt::StaticPacket<pw::bluetooth::emboss::SynchronousConnectionParametersWriter>>,sco::ScoConnectionManager::AcceptConnectionCallback)278     std::optional<ScoRequestHandle> AcceptScoConnection(
279         PeerId,
280         std::vector<bt::StaticPacket<
281             pw::bluetooth::emboss::SynchronousConnectionParametersWriter>>,
282         sco::ScoConnectionManager::AcceptConnectionCallback) override {
283       return std::nullopt;
284     }
285 
286    private:
287     // Callback used by tests to get new channel refs.
288     ChannelCallback channel_cb_;
289     RegistrationHandle next_service_handle_ = 1;
290     RegistrationHandle next_search_handle_ = 1;
291     std::map<RegistrationHandle, RegisteredService> registered_services_;
292     std::map<RegistrationHandle, RegisteredSearch> registered_searches_;
293 
294     l2cap::ChannelId next_channel_id_ = l2cap::kFirstDynamicChannelId;
295     std::unordered_map<l2cap::ChannelId,
296                        std::unique_ptr<l2cap::testing::FakeChannel>>
297         channels_;
298   };
299 
bredr()300   BrEdr* bredr() const override { return fake_bredr_.get(); }
fake_bredr()301   FakeBrEdr* fake_bredr() const { return fake_bredr_.get(); }
302 
peer_cache()303   PeerCache* peer_cache() override { return &peer_cache_; }
304 
AddBondedPeer(BondingData)305   bool AddBondedPeer(BondingData) override { return true; }
306 
SetPairingDelegate(PairingDelegate::WeakPtr)307   void SetPairingDelegate(PairingDelegate::WeakPtr) override {}
308 
IsDiscoverable()309   bool IsDiscoverable() const override { return is_discoverable_; }
310 
IsDiscovering()311   bool IsDiscovering() const override { return is_discovering_; }
312 
313   void SetLocalName(std::string name, hci::ResultFunction<> callback) override;
314 
local_name()315   std::string local_name() const override { return local_name_; }
316 
317   void SetDeviceClass(DeviceClass dev_class,
318                       hci::ResultFunction<> callback) override;
319 
320   void GetSupportedDelayRange(
321       const bt::StaticPacket<pw::bluetooth::emboss::CodecIdWriter>& codec_id,
322       pw::bluetooth::emboss::LogicalTransportType logical_transport_type,
323       pw::bluetooth::emboss::DataPathDirection direction,
324       const std::optional<std::vector<uint8_t>>& codec_configuration,
325       GetSupportedDelayRangeCallback cb) override;
326 
set_auto_connect_callback(AutoConnectCallback)327   void set_auto_connect_callback(AutoConnectCallback) override {}
328 
AttachInspect(inspect::Node &,std::string)329   void AttachInspect(inspect::Node&, std::string) override {}
330 
AsWeakPtr()331   Adapter::WeakPtr AsWeakPtr() override { return weak_self_.GetWeakPtr(); }
332 
333  private:
334   enum InitState {
335     kNotInitialized = 0,
336     kInitializing,
337     kInitialized,
338   };
339 
340   InitState init_state_;
341   AdapterState state_;
342   std::unique_ptr<FakeLowEnergy> fake_le_;
343   std::unique_ptr<FakeBrEdr> fake_bredr_;
344   bool is_discoverable_ = true;
345   bool is_discovering_ = true;
346   std::string local_name_;
347   DeviceClass device_class_;
348   LESecurityMode le_security_mode_;
349 
350   pw::async::Dispatcher& pw_dispatcher_;
351   pw::async::HeapDispatcher heap_dispatcher_;
352   PeerCache peer_cache_;
353   WeakSelf<Adapter> weak_self_;
354 };
355 
356 }  // namespace bt::gap::testing
357