xref: /aosp_15_r20/system/chre/util/system/message_router.cc (revision 84e339476a462649f82315436d70fd732297a399)
1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <inttypes.h>
18 #include <cstring>
19 #include <optional>
20 #include <utility>
21 
22 #include "chre/platform/log.h"
23 #include "chre/util/dynamic_vector.h"
24 #include "chre/util/lock_guard.h"
25 #include "chre/util/system/message_common.h"
26 #include "chre/util/system/message_router.h"
27 
28 namespace chre::message {
29 
MessageHub()30 MessageRouter::MessageHub::MessageHub()
31     : mRouter(nullptr), mHubId(MESSAGE_HUB_ID_INVALID) {}
32 
MessageHub(MessageRouter & router,MessageHubId id)33 MessageRouter::MessageHub::MessageHub(MessageRouter &router, MessageHubId id)
34     : mRouter(&router), mHubId(id) {}
35 
MessageHub(MessageHub && other)36 MessageRouter::MessageHub::MessageHub(MessageHub &&other)
37     : mRouter(other.mRouter), mHubId(other.mHubId) {
38   other.mRouter = nullptr;
39   other.mHubId = MESSAGE_HUB_ID_INVALID;
40 }
41 
operator =(MessageHub && other)42 MessageRouter::MessageHub &MessageRouter::MessageHub::operator=(
43     MessageHub &&other) {
44   mRouter = other.mRouter;
45   mHubId = other.mHubId;
46   other.mRouter = nullptr;
47   other.mHubId = MESSAGE_HUB_ID_INVALID;
48   return *this;
49 }
50 
openSession(EndpointId fromEndpointId,MessageHubId toMessageHubId,EndpointId toEndpointId)51 SessionId MessageRouter::MessageHub::openSession(EndpointId fromEndpointId,
52                                                  MessageHubId toMessageHubId,
53                                                  EndpointId toEndpointId) {
54   return mRouter == nullptr
55              ? SESSION_ID_INVALID
56              : mRouter->openSession(mHubId, fromEndpointId, toMessageHubId,
57                                     toEndpointId);
58 }
59 
closeSession(SessionId sessionId)60 bool MessageRouter::MessageHub::closeSession(SessionId sessionId) {
61   return mRouter == nullptr ? false : mRouter->closeSession(mHubId, sessionId);
62 }
63 
getSessionWithId(SessionId sessionId)64 std::optional<Session> MessageRouter::MessageHub::getSessionWithId(
65     SessionId sessionId) {
66   return mRouter == nullptr ? std::nullopt
67                             : mRouter->getSessionWithId(mHubId, sessionId);
68 }
69 
sendMessage(pw::UniquePtr<std::byte[]> && data,size_t length,uint32_t messageType,uint32_t messagePermissions,SessionId sessionId)70 bool MessageRouter::MessageHub::sendMessage(pw::UniquePtr<std::byte[]> &&data,
71                                             size_t length, uint32_t messageType,
72                                             uint32_t messagePermissions,
73                                             SessionId sessionId) {
74   return mRouter == nullptr
75              ? false
76              : mRouter->sendMessage(std::move(data), length, messageType,
77                                     messagePermissions, sessionId, mHubId);
78 }
79 
getId()80 MessageHubId MessageRouter::MessageHub::getId() {
81   return mHubId;
82 }
83 
84 std::optional<typename MessageRouter::MessageHub>
registerMessageHub(const char * name,MessageHubId id,MessageRouter::MessageRouter::MessageHubCallback & callback)85 MessageRouter::registerMessageHub(
86     const char *name, MessageHubId id,
87     MessageRouter::MessageRouter::MessageHubCallback &callback) {
88   LockGuard<Mutex> lock(mMutex);
89   if (mMessageHubs.full()) {
90     LOGE(
91         "Message hub '%s' not registered: maximum number of message hubs "
92         "reached",
93         name);
94     return std::nullopt;
95   }
96 
97   for (MessageHubRecord &messageHub : mMessageHubs) {
98     if (std::strcmp(messageHub.info.name, name) == 0 ||
99         messageHub.info.id == id) {
100       LOGE(
101           "Message hub '%s' not registered: hub with same name or ID already "
102           "exists",
103           name);
104       return std::nullopt;
105     }
106   }
107 
108   MessageHubRecord messageHubRecord = {
109       .info = {.id = id, .name = name},
110       .callback = &callback,
111   };
112   mMessageHubs.push_back(std::move(messageHubRecord));
113   return MessageHub(*this, id);
114 }
115 
forEachEndpointOfHub(MessageHubId messageHubId,const pw::Function<bool (const EndpointInfo &)> & function)116 bool MessageRouter::forEachEndpointOfHub(
117     MessageHubId messageHubId,
118     const pw::Function<bool(const EndpointInfo &)> &function) {
119   MessageRouter::MessageHubCallback *callback =
120       getCallbackFromMessageHubId(messageHubId);
121   if (callback == nullptr) {
122     LOGE("Failed to find message hub with ID %" PRIu64, messageHubId);
123     return false;
124   }
125 
126   callback->forEachEndpoint(function);
127   return true;
128 }
129 
forEachEndpoint(const pw::Function<void (const MessageHubInfo &,const EndpointInfo &)> & function)130 void MessageRouter::forEachEndpoint(
131     const pw::Function<void(const MessageHubInfo &, const EndpointInfo &)>
132         &function) {
133   LockGuard<Mutex> lock(mMutex);
134 
135   struct Context {
136     decltype(function) function;
137     MessageHubInfo &messageHubInfo;
138   };
139   for (MessageHubRecord &messageHubRecord : mMessageHubs) {
140     Context context = {
141         .function = function,
142         .messageHubInfo = messageHubRecord.info,
143     };
144 
145     messageHubRecord.callback->forEachEndpoint(
146         [&context](const EndpointInfo &endpointInfo) {
147           context.function(context.messageHubInfo, endpointInfo);
148           return false;
149         });
150   }
151 }
152 
getEndpointInfo(MessageHubId messageHubId,EndpointId endpointId)153 std::optional<EndpointInfo> MessageRouter::getEndpointInfo(
154     MessageHubId messageHubId, EndpointId endpointId) {
155   MessageRouter::MessageHubCallback *callback =
156       getCallbackFromMessageHubId(messageHubId);
157   if (callback == nullptr) {
158     LOGE("Failed to get endpoint info for message hub with ID %" PRIu64
159          " and endpoint ID %" PRIu64 ": hub not found",
160          messageHubId, endpointId);
161     return std::nullopt;
162   }
163 
164   return callback->getEndpointInfo(endpointId);
165 }
166 
forEachMessageHub(const pw::Function<bool (const MessageHubInfo &)> & function)167 void MessageRouter::forEachMessageHub(
168     const pw::Function<bool(const MessageHubInfo &)> &function) {
169   LockGuard<Mutex> lock(mMutex);
170   for (MessageHubRecord &messageHubRecord : mMessageHubs) {
171     function(messageHubRecord.info);
172   }
173 }
174 
unregisterMessageHub(MessageHubId fromMessageHubId)175 bool MessageRouter::unregisterMessageHub(MessageHubId fromMessageHubId) {
176   DynamicVector<std::pair<MessageHubCallback *, Session>> sessionsToDestroy;
177 
178   {
179     LockGuard<Mutex> lock(mMutex);
180 
181     bool success = false;
182     for (MessageHubRecord &messageHubRecord : mMessageHubs) {
183       if (messageHubRecord.info.id == fromMessageHubId) {
184         mMessageHubs.erase(&messageHubRecord);
185         success = true;
186         break;
187       }
188     }
189     if (!success) {
190       return false;
191     }
192 
193     for (size_t i = 0; i < mSessions.size();) {
194       Session &session = mSessions[i];
195       bool initiatorIsFromHub =
196           session.initiator.messageHubId == fromMessageHubId;
197       bool peerIsFromHub = session.peer.messageHubId == fromMessageHubId;
198 
199       if (initiatorIsFromHub || peerIsFromHub) {
200         MessageHubCallback *callback = getCallbackFromMessageHubIdLocked(
201             initiatorIsFromHub ? session.peer.messageHubId
202                                : session.initiator.messageHubId);
203         sessionsToDestroy.push_back(std::make_pair(callback, session));
204         mSessions.erase(&mSessions[i]);
205       } else {
206         ++i;
207       }
208     }
209   }
210 
211   for (auto [callback, session] : sessionsToDestroy) {
212     if (callback != nullptr) {
213       callback->onSessionClosed(session);
214     }
215   }
216   return true;
217 }
218 
openSession(MessageHubId fromMessageHubId,EndpointId fromEndpointId,MessageHubId toMessageHubId,EndpointId toEndpointId)219 SessionId MessageRouter::openSession(MessageHubId fromMessageHubId,
220                                      EndpointId fromEndpointId,
221                                      MessageHubId toMessageHubId,
222                                      EndpointId toEndpointId) {
223   if (fromMessageHubId == toMessageHubId) {
224     LOGE(
225         "Failed to open session: initiator and peer message hubs are the "
226         "same");
227     return SESSION_ID_INVALID;
228   }
229 
230   MessageRouter::MessageHubCallback *initiatorCallback =
231       getCallbackFromMessageHubId(fromMessageHubId);
232   MessageRouter::MessageHubCallback *peerCallback =
233       getCallbackFromMessageHubId(toMessageHubId);
234   if (initiatorCallback == nullptr || peerCallback == nullptr) {
235     LOGE("Failed to open session: initiator or peer message hub not found");
236     return SESSION_ID_INVALID;
237   }
238 
239   if (!checkIfEndpointExists(initiatorCallback, fromEndpointId)) {
240     LOGE("Failed to open session: endpoint with ID %" PRIu64
241          " not found in message hub with ID %" PRIu64,
242          fromEndpointId, fromMessageHubId);
243     return SESSION_ID_INVALID;
244   }
245 
246   if (!checkIfEndpointExists(peerCallback, toEndpointId)) {
247     LOGE("Failed to open session: endpoint with ID %" PRIu64
248          " not found in message hub with ID %" PRIu64,
249          toEndpointId, toMessageHubId);
250     return SESSION_ID_INVALID;
251   }
252 
253   {
254     LockGuard<Mutex> lock(mMutex);
255     if (mSessions.full()) {
256       LOGE("Failed to open session: maximum number of sessions reached");
257       return SESSION_ID_INVALID;
258     }
259 
260     Session insertSession = {
261         .sessionId = mNextSessionId,
262         .initiator = {.messageHubId = fromMessageHubId,
263                       .endpointId = fromEndpointId},
264         .peer = {.messageHubId = toMessageHubId, .endpointId = toEndpointId},
265     };
266 
267     for (Session &session : mSessions) {
268       if (session.isEquivalent(insertSession)) {
269         LOGD("Session with ID %" PRIu16 " already exists", session.sessionId);
270         return session.sessionId;
271       }
272     }
273 
274     mSessions.push_back(std::move(insertSession));
275     return mNextSessionId++;
276   }
277 }
278 
closeSession(MessageHubId fromMessageHubId,SessionId sessionId)279 bool MessageRouter::closeSession(MessageHubId fromMessageHubId,
280                                  SessionId sessionId) {
281   Session session;
282   MessageRouter::MessageHubCallback *initiatorCallback = nullptr;
283   MessageRouter::MessageHubCallback *peerCallback = nullptr;
284   {
285     LockGuard<Mutex> lock(mMutex);
286 
287     std::optional<size_t> index =
288         findSessionIndexLocked(fromMessageHubId, sessionId);
289     if (!index.has_value()) {
290       LOGE("Failed to close session with ID %" PRIu16 ": session not found",
291            sessionId);
292       return false;
293     }
294 
295     session = mSessions[*index];
296     initiatorCallback =
297         getCallbackFromMessageHubIdLocked(session.initiator.messageHubId);
298     peerCallback = getCallbackFromMessageHubIdLocked(session.peer.messageHubId);
299     mSessions.erase(&mSessions[*index]);
300   }
301 
302   if (initiatorCallback != nullptr) {
303     initiatorCallback->onSessionClosed(session);
304   }
305   if (peerCallback != nullptr) {
306     peerCallback->onSessionClosed(session);
307   }
308   return true;
309 }
310 
getSessionWithId(MessageHubId fromMessageHubId,SessionId sessionId)311 std::optional<Session> MessageRouter::getSessionWithId(
312     MessageHubId fromMessageHubId, SessionId sessionId) {
313   LockGuard<Mutex> lock(mMutex);
314 
315   std::optional<size_t> index =
316       findSessionIndexLocked(fromMessageHubId, sessionId);
317   return index.has_value() ? std::optional<Session>(mSessions[*index])
318                            : std::nullopt;
319 }
320 
sendMessage(pw::UniquePtr<std::byte[]> && data,size_t length,uint32_t messageType,uint32_t messagePermissions,SessionId sessionId,MessageHubId fromMessageHubId)321 bool MessageRouter::sendMessage(pw::UniquePtr<std::byte[]> &&data,
322                                 size_t length, uint32_t messageType,
323                                 uint32_t messagePermissions,
324                                 SessionId sessionId,
325                                 MessageHubId fromMessageHubId) {
326   MessageRouter::MessageHubCallback *receiverCallback = nullptr;
327   Session session;
328   {
329     LockGuard<Mutex> lock(mMutex);
330 
331     std::optional<size_t> index =
332         findSessionIndexLocked(fromMessageHubId, sessionId);
333     if (!index.has_value()) {
334       LOGE("Failed to send message: session with ID %" PRIu16 " not found",
335            sessionId);
336       return false;
337     }
338 
339     session = mSessions[*index];
340     receiverCallback = getCallbackFromMessageHubIdLocked(
341         session.initiator.messageHubId == fromMessageHubId
342             ? session.peer.messageHubId
343             : session.initiator.messageHubId);
344   }
345 
346   bool success = false;
347   if (receiverCallback != nullptr) {
348     success = receiverCallback->onMessageReceived(
349         std::move(data), length, messageType, messagePermissions, session,
350         session.initiator.messageHubId == fromMessageHubId);
351   }
352 
353   if (!success) {
354     closeSession(fromMessageHubId, sessionId);
355   }
356   return success;
357 }
358 
getMessageHubRecordLocked(MessageHubId messageHubId)359 const MessageRouter::MessageHubRecord *MessageRouter::getMessageHubRecordLocked(
360     MessageHubId messageHubId) {
361   for (MessageHubRecord &messageHubRecord : mMessageHubs) {
362     if (messageHubRecord.info.id == messageHubId) {
363       return &messageHubRecord;
364     }
365   }
366   return nullptr;
367 }
368 
findSessionIndexLocked(MessageHubId fromMessageHubId,SessionId sessionId)369 std::optional<size_t> MessageRouter::findSessionIndexLocked(
370     MessageHubId fromMessageHubId, SessionId sessionId) {
371   for (size_t i = 0; i < mSessions.size(); ++i) {
372     if (mSessions[i].sessionId == sessionId) {
373       if (mSessions[i].initiator.messageHubId == fromMessageHubId ||
374           mSessions[i].peer.messageHubId == fromMessageHubId) {
375         return i;
376       }
377 
378       LOGE("Hub mismatch for session with ID %" PRIu16
379            ": requesting hub ID %" PRIu64
380            " but session is between hubs %" PRIu64 " and %" PRIu64,
381            sessionId, fromMessageHubId, mSessions[i].initiator.messageHubId,
382            mSessions[i].peer.messageHubId);
383       break;
384     }
385   }
386   return std::nullopt;
387 }
388 
getCallbackFromMessageHubId(MessageHubId messageHubId)389 MessageRouter::MessageHubCallback *MessageRouter::getCallbackFromMessageHubId(
390     MessageHubId messageHubId) {
391   LockGuard<Mutex> lock(mMutex);
392   return getCallbackFromMessageHubIdLocked(messageHubId);
393 }
394 
395 MessageRouter::MessageHubCallback *
getCallbackFromMessageHubIdLocked(MessageHubId messageHubId)396 MessageRouter::getCallbackFromMessageHubIdLocked(MessageHubId messageHubId) {
397   const MessageHubRecord *messageHubRecord =
398       getMessageHubRecordLocked(messageHubId);
399   return messageHubRecord == nullptr ? nullptr : messageHubRecord->callback;
400 }
401 
checkIfEndpointExists(MessageRouter::MessageHubCallback * callback,EndpointId endpointId)402 bool MessageRouter::checkIfEndpointExists(
403     MessageRouter::MessageHubCallback *callback, EndpointId endpointId) {
404   struct EndpointContext {
405     EndpointId endpointId;
406     bool foundEndpoint = false;
407   };
408   EndpointContext context = {
409       .endpointId = endpointId,
410   };
411 
412   callback->forEachEndpoint([&context](const EndpointInfo &endpointInfo) {
413     if (context.endpointId == endpointInfo.id) {
414       context.foundEndpoint = true;
415       return true;
416     }
417     return false;
418   });
419   return context.foundEndpoint;
420 }
421 
422 }  // namespace chre::message
423