xref: /aosp_15_r20/system/chre/host/common/socket_server.cc (revision 84e339476a462649f82315436d70fd732297a399)
1*84e33947SAndroid Build Coastguard Worker /*
2*84e33947SAndroid Build Coastguard Worker  * Copyright (C) 2017 The Android Open Source Project
3*84e33947SAndroid Build Coastguard Worker  *
4*84e33947SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*84e33947SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*84e33947SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*84e33947SAndroid Build Coastguard Worker  *
8*84e33947SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*84e33947SAndroid Build Coastguard Worker  *
10*84e33947SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*84e33947SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*84e33947SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*84e33947SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*84e33947SAndroid Build Coastguard Worker  * limitations under the License.
15*84e33947SAndroid Build Coastguard Worker  */
16*84e33947SAndroid Build Coastguard Worker 
17*84e33947SAndroid Build Coastguard Worker #include "chre_host/socket_server.h"
18*84e33947SAndroid Build Coastguard Worker 
19*84e33947SAndroid Build Coastguard Worker #include <poll.h>
20*84e33947SAndroid Build Coastguard Worker 
21*84e33947SAndroid Build Coastguard Worker #include <cassert>
22*84e33947SAndroid Build Coastguard Worker #include <cerrno>
23*84e33947SAndroid Build Coastguard Worker #include <cinttypes>
24*84e33947SAndroid Build Coastguard Worker #include <csignal>
25*84e33947SAndroid Build Coastguard Worker #include <cstdlib>
26*84e33947SAndroid Build Coastguard Worker #include <map>
27*84e33947SAndroid Build Coastguard Worker #include <mutex>
28*84e33947SAndroid Build Coastguard Worker 
29*84e33947SAndroid Build Coastguard Worker #include <cutils/sockets.h>
30*84e33947SAndroid Build Coastguard Worker 
31*84e33947SAndroid Build Coastguard Worker #include "chre_host/log.h"
32*84e33947SAndroid Build Coastguard Worker 
33*84e33947SAndroid Build Coastguard Worker namespace android {
34*84e33947SAndroid Build Coastguard Worker namespace chre {
35*84e33947SAndroid Build Coastguard Worker 
36*84e33947SAndroid Build Coastguard Worker std::atomic<bool> SocketServer::sSignalReceived(false);
37*84e33947SAndroid Build Coastguard Worker 
SocketServer()38*84e33947SAndroid Build Coastguard Worker SocketServer::SocketServer() {
39*84e33947SAndroid Build Coastguard Worker   // Initialize the socket fds field for all inactive client slots to -1, so
40*84e33947SAndroid Build Coastguard Worker   // poll skips over it, and we don't attempt to send on it
41*84e33947SAndroid Build Coastguard Worker   for (size_t i = 1; i <= kMaxActiveClients; i++) {
42*84e33947SAndroid Build Coastguard Worker     mPollFds[i].fd = -1;
43*84e33947SAndroid Build Coastguard Worker     mPollFds[i].events = POLLIN;
44*84e33947SAndroid Build Coastguard Worker   }
45*84e33947SAndroid Build Coastguard Worker }
46*84e33947SAndroid Build Coastguard Worker 
run(const char * socketName,bool allowSocketCreation,ClientMessageCallback clientMessageCallback)47*84e33947SAndroid Build Coastguard Worker void SocketServer::run(const char *socketName, bool allowSocketCreation,
48*84e33947SAndroid Build Coastguard Worker                        ClientMessageCallback clientMessageCallback) {
49*84e33947SAndroid Build Coastguard Worker   mClientMessageCallback = clientMessageCallback;
50*84e33947SAndroid Build Coastguard Worker 
51*84e33947SAndroid Build Coastguard Worker   mSockFd = android_get_control_socket(socketName);
52*84e33947SAndroid Build Coastguard Worker   if (mSockFd == INVALID_SOCKET && allowSocketCreation) {
53*84e33947SAndroid Build Coastguard Worker     LOGI("Didn't inherit socket, creating...");
54*84e33947SAndroid Build Coastguard Worker     mSockFd = socket_local_server(socketName, ANDROID_SOCKET_NAMESPACE_RESERVED,
55*84e33947SAndroid Build Coastguard Worker                                   SOCK_SEQPACKET);
56*84e33947SAndroid Build Coastguard Worker   }
57*84e33947SAndroid Build Coastguard Worker 
58*84e33947SAndroid Build Coastguard Worker   if (mSockFd == INVALID_SOCKET) {
59*84e33947SAndroid Build Coastguard Worker     LOGE("Couldn't get/create socket");
60*84e33947SAndroid Build Coastguard Worker   } else {
61*84e33947SAndroid Build Coastguard Worker     int ret = listen(mSockFd, kMaxPendingConnectionRequests);
62*84e33947SAndroid Build Coastguard Worker     if (ret < 0) {
63*84e33947SAndroid Build Coastguard Worker       LOG_ERROR("Couldn't listen on socket", errno);
64*84e33947SAndroid Build Coastguard Worker     } else {
65*84e33947SAndroid Build Coastguard Worker       serviceSocket();
66*84e33947SAndroid Build Coastguard Worker     }
67*84e33947SAndroid Build Coastguard Worker 
68*84e33947SAndroid Build Coastguard Worker     {
69*84e33947SAndroid Build Coastguard Worker       std::lock_guard<std::mutex> lock(mClientsMutex);
70*84e33947SAndroid Build Coastguard Worker       for (const auto &pair : mClients) {
71*84e33947SAndroid Build Coastguard Worker         int clientSocket = pair.first;
72*84e33947SAndroid Build Coastguard Worker         if (close(clientSocket) != 0) {
73*84e33947SAndroid Build Coastguard Worker           LOGI("Couldn't close client %" PRIu16 "'s socket: %s",
74*84e33947SAndroid Build Coastguard Worker                pair.second.clientId, strerror(errno));
75*84e33947SAndroid Build Coastguard Worker         }
76*84e33947SAndroid Build Coastguard Worker       }
77*84e33947SAndroid Build Coastguard Worker       mClients.clear();
78*84e33947SAndroid Build Coastguard Worker     }
79*84e33947SAndroid Build Coastguard Worker     close(mSockFd);
80*84e33947SAndroid Build Coastguard Worker   }
81*84e33947SAndroid Build Coastguard Worker }
82*84e33947SAndroid Build Coastguard Worker 
sendToAllClients(const void * data,size_t length)83*84e33947SAndroid Build Coastguard Worker void SocketServer::sendToAllClients(const void *data, size_t length) {
84*84e33947SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> lock(mClientsMutex);
85*84e33947SAndroid Build Coastguard Worker 
86*84e33947SAndroid Build Coastguard Worker   int deliveredCount = 0;
87*84e33947SAndroid Build Coastguard Worker   for (const auto &pair : mClients) {
88*84e33947SAndroid Build Coastguard Worker     int clientSocket = pair.first;
89*84e33947SAndroid Build Coastguard Worker     uint16_t clientId = pair.second.clientId;
90*84e33947SAndroid Build Coastguard Worker     if (sendToClientSocket(data, length, clientSocket, clientId)) {
91*84e33947SAndroid Build Coastguard Worker       deliveredCount++;
92*84e33947SAndroid Build Coastguard Worker     } else if (errno == EINTR) {
93*84e33947SAndroid Build Coastguard Worker       // Exit early if we were interrupted - we should only get this for
94*84e33947SAndroid Build Coastguard Worker       // SIGINT/SIGTERM, so we should exit quickly
95*84e33947SAndroid Build Coastguard Worker       break;
96*84e33947SAndroid Build Coastguard Worker     }
97*84e33947SAndroid Build Coastguard Worker   }
98*84e33947SAndroid Build Coastguard Worker 
99*84e33947SAndroid Build Coastguard Worker   if (deliveredCount == 0) {
100*84e33947SAndroid Build Coastguard Worker     LOGW("Got message but didn't deliver to any clients");
101*84e33947SAndroid Build Coastguard Worker   }
102*84e33947SAndroid Build Coastguard Worker }
103*84e33947SAndroid Build Coastguard Worker 
sendToClientById(const void * data,size_t length,uint16_t clientId)104*84e33947SAndroid Build Coastguard Worker bool SocketServer::sendToClientById(const void *data, size_t length,
105*84e33947SAndroid Build Coastguard Worker                                     uint16_t clientId) {
106*84e33947SAndroid Build Coastguard Worker   std::lock_guard<std::mutex> lock(mClientsMutex);
107*84e33947SAndroid Build Coastguard Worker 
108*84e33947SAndroid Build Coastguard Worker   bool sent = false;
109*84e33947SAndroid Build Coastguard Worker   for (const auto &pair : mClients) {
110*84e33947SAndroid Build Coastguard Worker     uint16_t thisClientId = pair.second.clientId;
111*84e33947SAndroid Build Coastguard Worker     if (thisClientId == clientId) {
112*84e33947SAndroid Build Coastguard Worker       int clientSocket = pair.first;
113*84e33947SAndroid Build Coastguard Worker       sent = sendToClientSocket(data, length, clientSocket, thisClientId);
114*84e33947SAndroid Build Coastguard Worker       break;
115*84e33947SAndroid Build Coastguard Worker     }
116*84e33947SAndroid Build Coastguard Worker   }
117*84e33947SAndroid Build Coastguard Worker 
118*84e33947SAndroid Build Coastguard Worker   return sent;
119*84e33947SAndroid Build Coastguard Worker }
120*84e33947SAndroid Build Coastguard Worker 
acceptClientConnection()121*84e33947SAndroid Build Coastguard Worker void SocketServer::acceptClientConnection() {
122*84e33947SAndroid Build Coastguard Worker   int clientSocket = accept(mSockFd, NULL, NULL);
123*84e33947SAndroid Build Coastguard Worker   if (clientSocket < 0) {
124*84e33947SAndroid Build Coastguard Worker     LOG_ERROR("Couldn't accept client connection", errno);
125*84e33947SAndroid Build Coastguard Worker   } else if (mClients.size() >= kMaxActiveClients) {
126*84e33947SAndroid Build Coastguard Worker     LOGW("Rejecting client request - maximum number of clients reached");
127*84e33947SAndroid Build Coastguard Worker     close(clientSocket);
128*84e33947SAndroid Build Coastguard Worker   } else {
129*84e33947SAndroid Build Coastguard Worker     ClientData clientData;
130*84e33947SAndroid Build Coastguard Worker     clientData.clientId = mNextClientId++;
131*84e33947SAndroid Build Coastguard Worker 
132*84e33947SAndroid Build Coastguard Worker     // We currently don't handle wraparound - if we're getting this many
133*84e33947SAndroid Build Coastguard Worker     // connects/disconnects, then something is wrong.
134*84e33947SAndroid Build Coastguard Worker     // TODO: can handle this properly by iterating over the existing clients to
135*84e33947SAndroid Build Coastguard Worker     // avoid a conflict.
136*84e33947SAndroid Build Coastguard Worker     if (clientData.clientId == 0) {
137*84e33947SAndroid Build Coastguard Worker       LOGE("Couldn't allocate client ID");
138*84e33947SAndroid Build Coastguard Worker       std::exit(-1);
139*84e33947SAndroid Build Coastguard Worker     }
140*84e33947SAndroid Build Coastguard Worker 
141*84e33947SAndroid Build Coastguard Worker     bool slotFound = false;
142*84e33947SAndroid Build Coastguard Worker     for (size_t i = 1; i <= kMaxActiveClients; i++) {
143*84e33947SAndroid Build Coastguard Worker       if (mPollFds[i].fd < 0) {
144*84e33947SAndroid Build Coastguard Worker         mPollFds[i].fd = clientSocket;
145*84e33947SAndroid Build Coastguard Worker         slotFound = true;
146*84e33947SAndroid Build Coastguard Worker         break;
147*84e33947SAndroid Build Coastguard Worker       }
148*84e33947SAndroid Build Coastguard Worker     }
149*84e33947SAndroid Build Coastguard Worker 
150*84e33947SAndroid Build Coastguard Worker     if (!slotFound) {
151*84e33947SAndroid Build Coastguard Worker       LOGE("Couldn't find slot for client!");
152*84e33947SAndroid Build Coastguard Worker       assert(slotFound);
153*84e33947SAndroid Build Coastguard Worker       close(clientSocket);
154*84e33947SAndroid Build Coastguard Worker     } else {
155*84e33947SAndroid Build Coastguard Worker       {
156*84e33947SAndroid Build Coastguard Worker         std::lock_guard<std::mutex> lock(mClientsMutex);
157*84e33947SAndroid Build Coastguard Worker         mClients[clientSocket] = clientData;
158*84e33947SAndroid Build Coastguard Worker       }
159*84e33947SAndroid Build Coastguard Worker       LOGI(
160*84e33947SAndroid Build Coastguard Worker           "Accepted new client connection (count %zu), assigned client ID "
161*84e33947SAndroid Build Coastguard Worker           "%" PRIu16,
162*84e33947SAndroid Build Coastguard Worker           mClients.size(), clientData.clientId);
163*84e33947SAndroid Build Coastguard Worker     }
164*84e33947SAndroid Build Coastguard Worker   }
165*84e33947SAndroid Build Coastguard Worker }
166*84e33947SAndroid Build Coastguard Worker 
handleClientData(int clientSocket)167*84e33947SAndroid Build Coastguard Worker void SocketServer::handleClientData(int clientSocket) {
168*84e33947SAndroid Build Coastguard Worker   const ClientData &clientData = mClients[clientSocket];
169*84e33947SAndroid Build Coastguard Worker   uint16_t clientId = clientData.clientId;
170*84e33947SAndroid Build Coastguard Worker 
171*84e33947SAndroid Build Coastguard Worker   ssize_t packetSize =
172*84e33947SAndroid Build Coastguard Worker       recv(clientSocket, mRecvBuffer.data(), mRecvBuffer.size(), MSG_DONTWAIT);
173*84e33947SAndroid Build Coastguard Worker   if (packetSize < 0) {
174*84e33947SAndroid Build Coastguard Worker     LOGE("Couldn't get packet from client %" PRIu16 ": %s", clientId,
175*84e33947SAndroid Build Coastguard Worker          strerror(errno));
176*84e33947SAndroid Build Coastguard Worker     if (ENOTCONN == errno) {
177*84e33947SAndroid Build Coastguard Worker       disconnectClient(clientSocket);
178*84e33947SAndroid Build Coastguard Worker     }
179*84e33947SAndroid Build Coastguard Worker   } else if (packetSize == 0) {
180*84e33947SAndroid Build Coastguard Worker     LOGI("Client %" PRIu16 " disconnected", clientId);
181*84e33947SAndroid Build Coastguard Worker     disconnectClient(clientSocket);
182*84e33947SAndroid Build Coastguard Worker   } else {
183*84e33947SAndroid Build Coastguard Worker     LOGV("Got %zd byte packet from client %" PRIu16, packetSize, clientId);
184*84e33947SAndroid Build Coastguard Worker     mClientMessageCallback(clientId, mRecvBuffer.data(), packetSize);
185*84e33947SAndroid Build Coastguard Worker   }
186*84e33947SAndroid Build Coastguard Worker }
187*84e33947SAndroid Build Coastguard Worker 
disconnectClient(int clientSocket)188*84e33947SAndroid Build Coastguard Worker void SocketServer::disconnectClient(int clientSocket) {
189*84e33947SAndroid Build Coastguard Worker   {
190*84e33947SAndroid Build Coastguard Worker     std::lock_guard<std::mutex> lock(mClientsMutex);
191*84e33947SAndroid Build Coastguard Worker     mClients.erase(clientSocket);
192*84e33947SAndroid Build Coastguard Worker   }
193*84e33947SAndroid Build Coastguard Worker   close(clientSocket);
194*84e33947SAndroid Build Coastguard Worker 
195*84e33947SAndroid Build Coastguard Worker   bool removed = false;
196*84e33947SAndroid Build Coastguard Worker   for (size_t i = 1; i <= kMaxActiveClients; i++) {
197*84e33947SAndroid Build Coastguard Worker     if (mPollFds[i].fd == clientSocket) {
198*84e33947SAndroid Build Coastguard Worker       mPollFds[i].fd = -1;
199*84e33947SAndroid Build Coastguard Worker       removed = true;
200*84e33947SAndroid Build Coastguard Worker       break;
201*84e33947SAndroid Build Coastguard Worker     }
202*84e33947SAndroid Build Coastguard Worker   }
203*84e33947SAndroid Build Coastguard Worker 
204*84e33947SAndroid Build Coastguard Worker   if (!removed) {
205*84e33947SAndroid Build Coastguard Worker     LOGE("Out of sync");
206*84e33947SAndroid Build Coastguard Worker     assert(removed);
207*84e33947SAndroid Build Coastguard Worker   }
208*84e33947SAndroid Build Coastguard Worker }
209*84e33947SAndroid Build Coastguard Worker 
sendToClientSocket(const void * data,size_t length,int clientSocket,uint16_t clientId)210*84e33947SAndroid Build Coastguard Worker bool SocketServer::sendToClientSocket(const void *data, size_t length,
211*84e33947SAndroid Build Coastguard Worker                                       int clientSocket, uint16_t clientId) {
212*84e33947SAndroid Build Coastguard Worker   errno = 0;
213*84e33947SAndroid Build Coastguard Worker   ssize_t bytesSent = send(clientSocket, data, length, 0);
214*84e33947SAndroid Build Coastguard Worker   if (bytesSent < 0) {
215*84e33947SAndroid Build Coastguard Worker     LOGE("Error sending packet of size %zu to client %" PRIu16 ": %s", length,
216*84e33947SAndroid Build Coastguard Worker          clientId, strerror(errno));
217*84e33947SAndroid Build Coastguard Worker   } else if (bytesSent == 0) {
218*84e33947SAndroid Build Coastguard Worker     LOGW("Client %" PRIu16 " disconnected before message could be delivered",
219*84e33947SAndroid Build Coastguard Worker          clientId);
220*84e33947SAndroid Build Coastguard Worker   } else {
221*84e33947SAndroid Build Coastguard Worker     LOGV("Delivered message of size %zu bytes to client %" PRIu16, length,
222*84e33947SAndroid Build Coastguard Worker          clientId);
223*84e33947SAndroid Build Coastguard Worker   }
224*84e33947SAndroid Build Coastguard Worker 
225*84e33947SAndroid Build Coastguard Worker   return (bytesSent > 0);
226*84e33947SAndroid Build Coastguard Worker }
227*84e33947SAndroid Build Coastguard Worker 
serviceSocket()228*84e33947SAndroid Build Coastguard Worker void SocketServer::serviceSocket() {
229*84e33947SAndroid Build Coastguard Worker   constexpr size_t kListenIndex = 0;
230*84e33947SAndroid Build Coastguard Worker   static_assert(kListenIndex == 0,
231*84e33947SAndroid Build Coastguard Worker                 "Code assumes that the first index is always the listen "
232*84e33947SAndroid Build Coastguard Worker                 "socket");
233*84e33947SAndroid Build Coastguard Worker 
234*84e33947SAndroid Build Coastguard Worker   mPollFds[kListenIndex].fd = mSockFd;
235*84e33947SAndroid Build Coastguard Worker   mPollFds[kListenIndex].events = POLLIN;
236*84e33947SAndroid Build Coastguard Worker 
237*84e33947SAndroid Build Coastguard Worker   // Signal mask used with ppoll() so we gracefully handle SIGINT and SIGTERM,
238*84e33947SAndroid Build Coastguard Worker   // and ignore other signals
239*84e33947SAndroid Build Coastguard Worker   sigset_t signalMask;
240*84e33947SAndroid Build Coastguard Worker   sigfillset(&signalMask);
241*84e33947SAndroid Build Coastguard Worker   sigdelset(&signalMask, SIGINT);
242*84e33947SAndroid Build Coastguard Worker   sigdelset(&signalMask, SIGTERM);
243*84e33947SAndroid Build Coastguard Worker 
244*84e33947SAndroid Build Coastguard Worker   LOGI("Ready to accept connections");
245*84e33947SAndroid Build Coastguard Worker   while (!sSignalReceived) {
246*84e33947SAndroid Build Coastguard Worker     int ret = ppoll(mPollFds, 1 + kMaxActiveClients, nullptr, &signalMask);
247*84e33947SAndroid Build Coastguard Worker     if (ret == -1) {
248*84e33947SAndroid Build Coastguard Worker       // Don't use TEMP_FAILURE_RETRY since our logic needs to check
249*84e33947SAndroid Build Coastguard Worker       // sSignalReceived to see if it should exit where as TEMP_FAILURE_RETRY
250*84e33947SAndroid Build Coastguard Worker       // is a tight retry loop around ppoll.
251*84e33947SAndroid Build Coastguard Worker       if (errno == EINTR) {
252*84e33947SAndroid Build Coastguard Worker         continue;
253*84e33947SAndroid Build Coastguard Worker       }
254*84e33947SAndroid Build Coastguard Worker       LOGI("Exiting poll loop: %s", strerror(errno));
255*84e33947SAndroid Build Coastguard Worker       break;
256*84e33947SAndroid Build Coastguard Worker     }
257*84e33947SAndroid Build Coastguard Worker 
258*84e33947SAndroid Build Coastguard Worker     if (mPollFds[kListenIndex].revents & POLLIN) {
259*84e33947SAndroid Build Coastguard Worker       acceptClientConnection();
260*84e33947SAndroid Build Coastguard Worker     }
261*84e33947SAndroid Build Coastguard Worker 
262*84e33947SAndroid Build Coastguard Worker     for (size_t i = 1; i <= kMaxActiveClients; i++) {
263*84e33947SAndroid Build Coastguard Worker       if (mPollFds[i].fd < 0) {
264*84e33947SAndroid Build Coastguard Worker         continue;
265*84e33947SAndroid Build Coastguard Worker       }
266*84e33947SAndroid Build Coastguard Worker 
267*84e33947SAndroid Build Coastguard Worker       if (mPollFds[i].revents & POLLIN) {
268*84e33947SAndroid Build Coastguard Worker         handleClientData(mPollFds[i].fd);
269*84e33947SAndroid Build Coastguard Worker       }
270*84e33947SAndroid Build Coastguard Worker     }
271*84e33947SAndroid Build Coastguard Worker   }
272*84e33947SAndroid Build Coastguard Worker }
273*84e33947SAndroid Build Coastguard Worker 
274*84e33947SAndroid Build Coastguard Worker }  // namespace chre
275*84e33947SAndroid Build Coastguard Worker }  // namespace android
276