1 /*
2 * Copyright (C) 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <inttypes.h>
18 #include <cstring>
19 #include <optional>
20 #include <utility>
21
22 #include "chre/platform/log.h"
23 #include "chre/util/dynamic_vector.h"
24 #include "chre/util/lock_guard.h"
25 #include "chre/util/system/message_common.h"
26 #include "chre/util/system/message_router.h"
27
28 namespace chre::message {
29
MessageHub()30 MessageRouter::MessageHub::MessageHub()
31 : mRouter(nullptr), mHubId(MESSAGE_HUB_ID_INVALID) {}
32
MessageHub(MessageRouter & router,MessageHubId id)33 MessageRouter::MessageHub::MessageHub(MessageRouter &router, MessageHubId id)
34 : mRouter(&router), mHubId(id) {}
35
MessageHub(MessageHub && other)36 MessageRouter::MessageHub::MessageHub(MessageHub &&other)
37 : mRouter(other.mRouter), mHubId(other.mHubId) {
38 other.mRouter = nullptr;
39 other.mHubId = MESSAGE_HUB_ID_INVALID;
40 }
41
operator =(MessageHub && other)42 MessageRouter::MessageHub &MessageRouter::MessageHub::operator=(
43 MessageHub &&other) {
44 mRouter = other.mRouter;
45 mHubId = other.mHubId;
46 other.mRouter = nullptr;
47 other.mHubId = MESSAGE_HUB_ID_INVALID;
48 return *this;
49 }
50
openSession(EndpointId fromEndpointId,MessageHubId toMessageHubId,EndpointId toEndpointId)51 SessionId MessageRouter::MessageHub::openSession(EndpointId fromEndpointId,
52 MessageHubId toMessageHubId,
53 EndpointId toEndpointId) {
54 return mRouter == nullptr
55 ? SESSION_ID_INVALID
56 : mRouter->openSession(mHubId, fromEndpointId, toMessageHubId,
57 toEndpointId);
58 }
59
closeSession(SessionId sessionId)60 bool MessageRouter::MessageHub::closeSession(SessionId sessionId) {
61 return mRouter == nullptr ? false : mRouter->closeSession(mHubId, sessionId);
62 }
63
getSessionWithId(SessionId sessionId)64 std::optional<Session> MessageRouter::MessageHub::getSessionWithId(
65 SessionId sessionId) {
66 return mRouter == nullptr ? std::nullopt
67 : mRouter->getSessionWithId(mHubId, sessionId);
68 }
69
sendMessage(pw::UniquePtr<std::byte[]> && data,size_t length,uint32_t messageType,uint32_t messagePermissions,SessionId sessionId)70 bool MessageRouter::MessageHub::sendMessage(pw::UniquePtr<std::byte[]> &&data,
71 size_t length, uint32_t messageType,
72 uint32_t messagePermissions,
73 SessionId sessionId) {
74 return mRouter == nullptr
75 ? false
76 : mRouter->sendMessage(std::move(data), length, messageType,
77 messagePermissions, sessionId, mHubId);
78 }
79
getId()80 MessageHubId MessageRouter::MessageHub::getId() {
81 return mHubId;
82 }
83
84 std::optional<typename MessageRouter::MessageHub>
registerMessageHub(const char * name,MessageHubId id,MessageRouter::MessageRouter::MessageHubCallback & callback)85 MessageRouter::registerMessageHub(
86 const char *name, MessageHubId id,
87 MessageRouter::MessageRouter::MessageHubCallback &callback) {
88 LockGuard<Mutex> lock(mMutex);
89 if (mMessageHubs.full()) {
90 LOGE(
91 "Message hub '%s' not registered: maximum number of message hubs "
92 "reached",
93 name);
94 return std::nullopt;
95 }
96
97 for (MessageHubRecord &messageHub : mMessageHubs) {
98 if (std::strcmp(messageHub.info.name, name) == 0 ||
99 messageHub.info.id == id) {
100 LOGE(
101 "Message hub '%s' not registered: hub with same name or ID already "
102 "exists",
103 name);
104 return std::nullopt;
105 }
106 }
107
108 MessageHubRecord messageHubRecord = {
109 .info = {.id = id, .name = name},
110 .callback = &callback,
111 };
112 mMessageHubs.push_back(std::move(messageHubRecord));
113 return MessageHub(*this, id);
114 }
115
forEachEndpointOfHub(MessageHubId messageHubId,const pw::Function<bool (const EndpointInfo &)> & function)116 bool MessageRouter::forEachEndpointOfHub(
117 MessageHubId messageHubId,
118 const pw::Function<bool(const EndpointInfo &)> &function) {
119 MessageRouter::MessageHubCallback *callback =
120 getCallbackFromMessageHubId(messageHubId);
121 if (callback == nullptr) {
122 LOGE("Failed to find message hub with ID %" PRIu64, messageHubId);
123 return false;
124 }
125
126 callback->forEachEndpoint(function);
127 return true;
128 }
129
forEachEndpoint(const pw::Function<void (const MessageHubInfo &,const EndpointInfo &)> & function)130 void MessageRouter::forEachEndpoint(
131 const pw::Function<void(const MessageHubInfo &, const EndpointInfo &)>
132 &function) {
133 LockGuard<Mutex> lock(mMutex);
134
135 struct Context {
136 decltype(function) function;
137 MessageHubInfo &messageHubInfo;
138 };
139 for (MessageHubRecord &messageHubRecord : mMessageHubs) {
140 Context context = {
141 .function = function,
142 .messageHubInfo = messageHubRecord.info,
143 };
144
145 messageHubRecord.callback->forEachEndpoint(
146 [&context](const EndpointInfo &endpointInfo) {
147 context.function(context.messageHubInfo, endpointInfo);
148 return false;
149 });
150 }
151 }
152
getEndpointInfo(MessageHubId messageHubId,EndpointId endpointId)153 std::optional<EndpointInfo> MessageRouter::getEndpointInfo(
154 MessageHubId messageHubId, EndpointId endpointId) {
155 MessageRouter::MessageHubCallback *callback =
156 getCallbackFromMessageHubId(messageHubId);
157 if (callback == nullptr) {
158 LOGE("Failed to get endpoint info for message hub with ID %" PRIu64
159 " and endpoint ID %" PRIu64 ": hub not found",
160 messageHubId, endpointId);
161 return std::nullopt;
162 }
163
164 return callback->getEndpointInfo(endpointId);
165 }
166
forEachMessageHub(const pw::Function<bool (const MessageHubInfo &)> & function)167 void MessageRouter::forEachMessageHub(
168 const pw::Function<bool(const MessageHubInfo &)> &function) {
169 LockGuard<Mutex> lock(mMutex);
170 for (MessageHubRecord &messageHubRecord : mMessageHubs) {
171 function(messageHubRecord.info);
172 }
173 }
174
unregisterMessageHub(MessageHubId fromMessageHubId)175 bool MessageRouter::unregisterMessageHub(MessageHubId fromMessageHubId) {
176 DynamicVector<std::pair<MessageHubCallback *, Session>> sessionsToDestroy;
177
178 {
179 LockGuard<Mutex> lock(mMutex);
180
181 bool success = false;
182 for (MessageHubRecord &messageHubRecord : mMessageHubs) {
183 if (messageHubRecord.info.id == fromMessageHubId) {
184 mMessageHubs.erase(&messageHubRecord);
185 success = true;
186 break;
187 }
188 }
189 if (!success) {
190 return false;
191 }
192
193 for (size_t i = 0; i < mSessions.size();) {
194 Session &session = mSessions[i];
195 bool initiatorIsFromHub =
196 session.initiator.messageHubId == fromMessageHubId;
197 bool peerIsFromHub = session.peer.messageHubId == fromMessageHubId;
198
199 if (initiatorIsFromHub || peerIsFromHub) {
200 MessageHubCallback *callback = getCallbackFromMessageHubIdLocked(
201 initiatorIsFromHub ? session.peer.messageHubId
202 : session.initiator.messageHubId);
203 sessionsToDestroy.push_back(std::make_pair(callback, session));
204 mSessions.erase(&mSessions[i]);
205 } else {
206 ++i;
207 }
208 }
209 }
210
211 for (auto [callback, session] : sessionsToDestroy) {
212 if (callback != nullptr) {
213 callback->onSessionClosed(session);
214 }
215 }
216 return true;
217 }
218
openSession(MessageHubId fromMessageHubId,EndpointId fromEndpointId,MessageHubId toMessageHubId,EndpointId toEndpointId)219 SessionId MessageRouter::openSession(MessageHubId fromMessageHubId,
220 EndpointId fromEndpointId,
221 MessageHubId toMessageHubId,
222 EndpointId toEndpointId) {
223 if (fromMessageHubId == toMessageHubId) {
224 LOGE(
225 "Failed to open session: initiator and peer message hubs are the "
226 "same");
227 return SESSION_ID_INVALID;
228 }
229
230 MessageRouter::MessageHubCallback *initiatorCallback =
231 getCallbackFromMessageHubId(fromMessageHubId);
232 MessageRouter::MessageHubCallback *peerCallback =
233 getCallbackFromMessageHubId(toMessageHubId);
234 if (initiatorCallback == nullptr || peerCallback == nullptr) {
235 LOGE("Failed to open session: initiator or peer message hub not found");
236 return SESSION_ID_INVALID;
237 }
238
239 if (!checkIfEndpointExists(initiatorCallback, fromEndpointId)) {
240 LOGE("Failed to open session: endpoint with ID %" PRIu64
241 " not found in message hub with ID %" PRIu64,
242 fromEndpointId, fromMessageHubId);
243 return SESSION_ID_INVALID;
244 }
245
246 if (!checkIfEndpointExists(peerCallback, toEndpointId)) {
247 LOGE("Failed to open session: endpoint with ID %" PRIu64
248 " not found in message hub with ID %" PRIu64,
249 toEndpointId, toMessageHubId);
250 return SESSION_ID_INVALID;
251 }
252
253 {
254 LockGuard<Mutex> lock(mMutex);
255 if (mSessions.full()) {
256 LOGE("Failed to open session: maximum number of sessions reached");
257 return SESSION_ID_INVALID;
258 }
259
260 Session insertSession = {
261 .sessionId = mNextSessionId,
262 .initiator = {.messageHubId = fromMessageHubId,
263 .endpointId = fromEndpointId},
264 .peer = {.messageHubId = toMessageHubId, .endpointId = toEndpointId},
265 };
266
267 for (Session &session : mSessions) {
268 if (session.isEquivalent(insertSession)) {
269 LOGD("Session with ID %" PRIu16 " already exists", session.sessionId);
270 return session.sessionId;
271 }
272 }
273
274 mSessions.push_back(std::move(insertSession));
275 return mNextSessionId++;
276 }
277 }
278
closeSession(MessageHubId fromMessageHubId,SessionId sessionId)279 bool MessageRouter::closeSession(MessageHubId fromMessageHubId,
280 SessionId sessionId) {
281 Session session;
282 MessageRouter::MessageHubCallback *initiatorCallback = nullptr;
283 MessageRouter::MessageHubCallback *peerCallback = nullptr;
284 {
285 LockGuard<Mutex> lock(mMutex);
286
287 std::optional<size_t> index =
288 findSessionIndexLocked(fromMessageHubId, sessionId);
289 if (!index.has_value()) {
290 LOGE("Failed to close session with ID %" PRIu16 ": session not found",
291 sessionId);
292 return false;
293 }
294
295 session = mSessions[*index];
296 initiatorCallback =
297 getCallbackFromMessageHubIdLocked(session.initiator.messageHubId);
298 peerCallback = getCallbackFromMessageHubIdLocked(session.peer.messageHubId);
299 mSessions.erase(&mSessions[*index]);
300 }
301
302 if (initiatorCallback != nullptr) {
303 initiatorCallback->onSessionClosed(session);
304 }
305 if (peerCallback != nullptr) {
306 peerCallback->onSessionClosed(session);
307 }
308 return true;
309 }
310
getSessionWithId(MessageHubId fromMessageHubId,SessionId sessionId)311 std::optional<Session> MessageRouter::getSessionWithId(
312 MessageHubId fromMessageHubId, SessionId sessionId) {
313 LockGuard<Mutex> lock(mMutex);
314
315 std::optional<size_t> index =
316 findSessionIndexLocked(fromMessageHubId, sessionId);
317 return index.has_value() ? std::optional<Session>(mSessions[*index])
318 : std::nullopt;
319 }
320
sendMessage(pw::UniquePtr<std::byte[]> && data,size_t length,uint32_t messageType,uint32_t messagePermissions,SessionId sessionId,MessageHubId fromMessageHubId)321 bool MessageRouter::sendMessage(pw::UniquePtr<std::byte[]> &&data,
322 size_t length, uint32_t messageType,
323 uint32_t messagePermissions,
324 SessionId sessionId,
325 MessageHubId fromMessageHubId) {
326 MessageRouter::MessageHubCallback *receiverCallback = nullptr;
327 Session session;
328 {
329 LockGuard<Mutex> lock(mMutex);
330
331 std::optional<size_t> index =
332 findSessionIndexLocked(fromMessageHubId, sessionId);
333 if (!index.has_value()) {
334 LOGE("Failed to send message: session with ID %" PRIu16 " not found",
335 sessionId);
336 return false;
337 }
338
339 session = mSessions[*index];
340 receiverCallback = getCallbackFromMessageHubIdLocked(
341 session.initiator.messageHubId == fromMessageHubId
342 ? session.peer.messageHubId
343 : session.initiator.messageHubId);
344 }
345
346 bool success = false;
347 if (receiverCallback != nullptr) {
348 success = receiverCallback->onMessageReceived(
349 std::move(data), length, messageType, messagePermissions, session,
350 session.initiator.messageHubId == fromMessageHubId);
351 }
352
353 if (!success) {
354 closeSession(fromMessageHubId, sessionId);
355 }
356 return success;
357 }
358
getMessageHubRecordLocked(MessageHubId messageHubId)359 const MessageRouter::MessageHubRecord *MessageRouter::getMessageHubRecordLocked(
360 MessageHubId messageHubId) {
361 for (MessageHubRecord &messageHubRecord : mMessageHubs) {
362 if (messageHubRecord.info.id == messageHubId) {
363 return &messageHubRecord;
364 }
365 }
366 return nullptr;
367 }
368
findSessionIndexLocked(MessageHubId fromMessageHubId,SessionId sessionId)369 std::optional<size_t> MessageRouter::findSessionIndexLocked(
370 MessageHubId fromMessageHubId, SessionId sessionId) {
371 for (size_t i = 0; i < mSessions.size(); ++i) {
372 if (mSessions[i].sessionId == sessionId) {
373 if (mSessions[i].initiator.messageHubId == fromMessageHubId ||
374 mSessions[i].peer.messageHubId == fromMessageHubId) {
375 return i;
376 }
377
378 LOGE("Hub mismatch for session with ID %" PRIu16
379 ": requesting hub ID %" PRIu64
380 " but session is between hubs %" PRIu64 " and %" PRIu64,
381 sessionId, fromMessageHubId, mSessions[i].initiator.messageHubId,
382 mSessions[i].peer.messageHubId);
383 break;
384 }
385 }
386 return std::nullopt;
387 }
388
getCallbackFromMessageHubId(MessageHubId messageHubId)389 MessageRouter::MessageHubCallback *MessageRouter::getCallbackFromMessageHubId(
390 MessageHubId messageHubId) {
391 LockGuard<Mutex> lock(mMutex);
392 return getCallbackFromMessageHubIdLocked(messageHubId);
393 }
394
395 MessageRouter::MessageHubCallback *
getCallbackFromMessageHubIdLocked(MessageHubId messageHubId)396 MessageRouter::getCallbackFromMessageHubIdLocked(MessageHubId messageHubId) {
397 const MessageHubRecord *messageHubRecord =
398 getMessageHubRecordLocked(messageHubId);
399 return messageHubRecord == nullptr ? nullptr : messageHubRecord->callback;
400 }
401
checkIfEndpointExists(MessageRouter::MessageHubCallback * callback,EndpointId endpointId)402 bool MessageRouter::checkIfEndpointExists(
403 MessageRouter::MessageHubCallback *callback, EndpointId endpointId) {
404 struct EndpointContext {
405 EndpointId endpointId;
406 bool foundEndpoint = false;
407 };
408 EndpointContext context = {
409 .endpointId = endpointId,
410 };
411
412 callback->forEachEndpoint([&context](const EndpointInfo &endpointInfo) {
413 if (context.endpointId == endpointInfo.id) {
414 context.foundEndpoint = true;
415 return true;
416 }
417 return false;
418 });
419 return context.foundEndpoint;
420 }
421
422 } // namespace chre::message
423