1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_bluetooth_sapphire/internal/host/l2cap/fake_signaling_channel.h"
16
17 #include "pw_bluetooth_sapphire/internal/host/testing/test_helpers.h"
18 #include "pw_unit_test/framework.h"
19
20 namespace bt::l2cap::internal::testing {
21 namespace {
22
23 // These classes bind the response that the request handlers are expected to
24 // send back. These also serve as the actual Responder implementation that the
25 // request handler under test will see. These roles may need to be decoupled if
26 // request handlers have to be tested for multiple responses to each request.
27 class Expecter : public SignalingChannel::Responder {
28 public:
Send(const ByteBuffer & rsp_payload)29 void Send(const ByteBuffer& rsp_payload) override {
30 ADD_FAILURE() << "Unexpected local response " << rsp_payload.AsString();
31 }
32
RejectNotUnderstood()33 void RejectNotUnderstood() override {
34 ADD_FAILURE() << "Unexpected local rejection, \"Not Understood\"";
35 }
36
RejectInvalidChannelId(ChannelId local_cid,ChannelId remote_cid)37 void RejectInvalidChannelId(ChannelId local_cid,
38 ChannelId remote_cid) override {
39 ADD_FAILURE() << bt_lib_cpp_string::StringPrintf(
40 "Unexpected local rejection, \"Invalid Channel ID\" local: %#.4x "
41 "remote: %#.4x",
42 local_cid,
43 remote_cid);
44 }
45
called() const46 bool called() const { return called_; }
47
48 protected:
set_called(bool called)49 void set_called(bool called) { called_ = called; }
50
51 private:
52 bool called_ = false;
53 };
54
55 class ResponseExpecter : public Expecter {
56 public:
ResponseExpecter(const ByteBuffer & expected_rsp)57 explicit ResponseExpecter(const ByteBuffer& expected_rsp)
58 : expected_rsp_(expected_rsp) {}
59
Send(const ByteBuffer & rsp_payload)60 void Send(const ByteBuffer& rsp_payload) override {
61 set_called(true);
62 EXPECT_TRUE(ContainersEqual(expected_rsp_, rsp_payload));
63 }
64
65 private:
66 const ByteBuffer& expected_rsp_;
67 };
68
69 class RejectNotUnderstoodExpecter : public Expecter {
70 public:
RejectNotUnderstood()71 void RejectNotUnderstood() override { set_called(true); }
72 };
73
74 class RejectInvalidChannelIdExpecter : public Expecter {
75 public:
RejectInvalidChannelIdExpecter(ChannelId local_cid,ChannelId remote_cid)76 RejectInvalidChannelIdExpecter(ChannelId local_cid, ChannelId remote_cid)
77 : local_cid_(local_cid), remote_cid_(remote_cid) {}
78
RejectInvalidChannelId(ChannelId local_cid,ChannelId remote_cid)79 void RejectInvalidChannelId(ChannelId local_cid,
80 ChannelId remote_cid) override {
81 set_called(true);
82 EXPECT_EQ(local_cid_, local_cid);
83 EXPECT_EQ(remote_cid_, remote_cid);
84 }
85
86 private:
87 const ChannelId local_cid_;
88 const ChannelId remote_cid_;
89 };
90
91 } // namespace
92
FakeSignalingChannel(pw::async::Dispatcher & pw_dispatcher)93 FakeSignalingChannel::FakeSignalingChannel(pw::async::Dispatcher& pw_dispatcher)
94 : heap_dispatcher_(pw_dispatcher) {}
95
~FakeSignalingChannel()96 FakeSignalingChannel::~FakeSignalingChannel() {
97 // Add a test failure for each expected request that wasn't received
98 for (size_t i = expected_transaction_index_; i < transactions_.size(); i++) {
99 ADD_FAILURE_AT(transactions_[i].file, transactions_[i].line)
100 << "Outbound request [" << i << "] expected "
101 << transactions_[i].responses.size() << " responses";
102 }
103 }
104
SendRequest(CommandCode req_code,const ByteBuffer & payload,SignalingChannel::ResponseHandler cb)105 bool FakeSignalingChannel::SendRequest(CommandCode req_code,
106 const ByteBuffer& payload,
107 SignalingChannel::ResponseHandler cb) {
108 if (expected_transaction_index_ >= transactions_.size()) {
109 ADD_FAILURE() << "Received unexpected outbound command after handling "
110 << transactions_.size();
111 return false;
112 }
113
114 Transaction& transaction = transactions_[expected_transaction_index_];
115 ::testing::ScopedTrace trace(
116 transaction.file, transaction.line, "Outbound request expected here");
117 EXPECT_EQ(transaction.request_code, req_code);
118 EXPECT_TRUE(ContainersEqual(transaction.req_payload, payload));
119 EXPECT_TRUE(cb);
120 transaction.response_callback = std::move(cb);
121
122 // Simulate the remote's response(s)
123 (void)heap_dispatcher_.Post(
124 [this, index = expected_transaction_index_](pw::async::Context /*ctx*/,
125 pw::Status status) {
126 if (status.ok()) {
127 Transaction& transaction = transactions_[index];
128 transaction.responses_handled =
129 TriggerResponses(transaction, transaction.responses);
130 }
131 });
132
133 expected_transaction_index_++;
134 return (transaction.request_code == req_code);
135 }
136
ReceiveResponses(TransactionId id,const std::vector<FakeSignalingChannel::Response> & responses)137 void FakeSignalingChannel::ReceiveResponses(
138 TransactionId id,
139 const std::vector<FakeSignalingChannel::Response>& responses) {
140 if (id >= transactions_.size()) {
141 FAIL() << "Can't trigger response to unknown outbound request " << id;
142 }
143
144 const Transaction& transaction = transactions_[id];
145 {
146 ::testing::ScopedTrace trace(
147 transaction.file, transaction.line, "Outbound request expected here");
148 ASSERT_TRUE(transaction.response_callback)
149 << "Can't trigger responses for outbound request that hasn't been sent";
150 EXPECT_EQ(transaction.responses.size(), transaction.responses_handled)
151 << "Not all original simulated responses have been handled";
152 }
153 TriggerResponses(transaction, responses);
154 }
155
ServeRequest(CommandCode req_code,SignalingChannel::RequestDelegate cb)156 void FakeSignalingChannel::ServeRequest(CommandCode req_code,
157 SignalingChannel::RequestDelegate cb) {
158 request_handlers_[req_code] = std::move(cb);
159 }
160
AddOutbound(const char * file,int line,CommandCode req_code,BufferView req_payload,std::vector<FakeSignalingChannel::Response> responses)161 FakeSignalingChannel::TransactionId FakeSignalingChannel::AddOutbound(
162 const char* file,
163 int line,
164 CommandCode req_code,
165 BufferView req_payload,
166 std::vector<FakeSignalingChannel::Response> responses) {
167 transactions_.push_back(Transaction{file,
168 line,
169 req_code,
170 std::move(req_payload),
171 std::move(responses),
172 nullptr});
173 return transactions_.size() - 1;
174 }
175
ReceiveExpect(CommandCode req_code,const ByteBuffer & req_payload,const ByteBuffer & rsp_payload)176 void FakeSignalingChannel::ReceiveExpect(CommandCode req_code,
177 const ByteBuffer& req_payload,
178 const ByteBuffer& rsp_payload) {
179 ResponseExpecter expecter(rsp_payload);
180 ReceiveExpectInternal(req_code, req_payload, &expecter);
181 }
182
ReceiveExpectRejectNotUnderstood(CommandCode req_code,const ByteBuffer & req_payload)183 void FakeSignalingChannel::ReceiveExpectRejectNotUnderstood(
184 CommandCode req_code, const ByteBuffer& req_payload) {
185 RejectNotUnderstoodExpecter expecter;
186 ReceiveExpectInternal(req_code, req_payload, &expecter);
187 }
188
ReceiveExpectRejectInvalidChannelId(CommandCode req_code,const ByteBuffer & req_payload,ChannelId local_cid,ChannelId remote_cid)189 void FakeSignalingChannel::ReceiveExpectRejectInvalidChannelId(
190 CommandCode req_code,
191 const ByteBuffer& req_payload,
192 ChannelId local_cid,
193 ChannelId remote_cid) {
194 RejectInvalidChannelIdExpecter expecter(local_cid, remote_cid);
195 ReceiveExpectInternal(req_code, req_payload, &expecter);
196 }
197
TriggerResponses(const FakeSignalingChannel::Transaction & transaction,const std::vector<FakeSignalingChannel::Response> & responses)198 size_t FakeSignalingChannel::TriggerResponses(
199 const FakeSignalingChannel::Transaction& transaction,
200 const std::vector<FakeSignalingChannel::Response>& responses) {
201 ::testing::ScopedTrace trace(
202 transaction.file, transaction.line, "Outbound request expected here");
203 size_t responses_handled = 0;
204 for (auto& [status, payload] : responses) {
205 responses_handled++;
206 if (transaction.response_callback(status, payload) ==
207 ResponseHandlerAction::kCompleteOutboundTransaction ||
208 ::testing::Test::HasFatalFailure()) {
209 break;
210 }
211 }
212
213 EXPECT_EQ(responses.size(), responses_handled)
214 << bt_lib_cpp_string::StringPrintf(
215 "Outbound command (code %d, at %zu) handled fewer responses than "
216 "expected",
217 transaction.request_code,
218 transaction.responses_handled);
219
220 return responses_handled;
221 }
222
223 // Test evaluator for inbound requests with type-erased, bound expected requests
ReceiveExpectInternal(CommandCode req_code,const ByteBuffer & req_payload,Responder * fake_responder)224 void FakeSignalingChannel::ReceiveExpectInternal(CommandCode req_code,
225 const ByteBuffer& req_payload,
226 Responder* fake_responder) {
227 auto iter = request_handlers_.find(req_code);
228 ASSERT_NE(request_handlers_.end(), iter);
229
230 // Invoke delegate assigned for this request type
231 iter->second(req_payload, fake_responder);
232 EXPECT_TRUE(static_cast<Expecter*>(fake_responder)->called());
233 }
234
235 } // namespace bt::l2cap::internal::testing
236