xref: /aosp_15_r20/frameworks/native/libs/binder/RpcServer.cpp (revision 38e8c45f13ce32b0dcecb25141ffecaf386fa17f)
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "RpcServer"
18 
19 #include <inttypes.h>
20 #include <netinet/tcp.h>
21 #include <poll.h>
22 #include <sys/socket.h>
23 #include <sys/un.h>
24 
25 #include <thread>
26 #include <vector>
27 
28 #include <binder/Functional.h>
29 #include <binder/Parcel.h>
30 #include <binder/RpcServer.h>
31 #include <binder/RpcTransportRaw.h>
32 #include <log/log.h>
33 
34 #include "BuildFlags.h"
35 #include "FdTrigger.h"
36 #include "OS.h"
37 #include "RpcSocketAddress.h"
38 #include "RpcState.h"
39 #include "RpcTransportUtils.h"
40 #include "RpcWireFormat.h"
41 #include "Utils.h"
42 
43 namespace android {
44 
45 constexpr size_t kSessionIdBytes = 32;
46 
47 using namespace android::binder::impl;
48 using android::binder::borrowed_fd;
49 using android::binder::unique_fd;
50 
// Constructs a server that wraps every accepted connection with the given
// transport context (raw sockets by default; TLS when a TLS factory was used).
RpcServer::RpcServer(std::unique_ptr<RpcTransportCtx> ctx) : mCtx(std::move(ctx)) {}
RpcServer::~RpcServer() {
    RpcMutexUniqueLock _l(mLock);
    // mShutdownTrigger is created in join() and cleared by shutdown(). If it is
    // still set here, the server was joined but never shut down, and server
    // threads may still reference this object.
    LOG_ALWAYS_FATAL_IF(mShutdownTrigger != nullptr, "Must call shutdown() before destructor");
}
56 
// Factory for RpcServer. Passing nullptr selects the default (non-TLS)
// transport. Returns nullptr if the factory cannot create a server context.
sp<RpcServer> RpcServer::make(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory) {
    // Default is without TLS.
    if (rpcTransportCtxFactory == nullptr)
        rpcTransportCtxFactory = binder::os::makeDefaultRpcTransportCtxFactory();
    auto ctx = rpcTransportCtxFactory->newServerCtx();
    if (ctx == nullptr) return nullptr;
    return sp<RpcServer>::make(std::move(ctx));
}
65 
// Serves connections that arrive as file descriptors passed over the given
// bootstrap socket (via recvmsg) rather than via accept() on a listening fd.
status_t RpcServer::setupUnixDomainSocketBootstrapServer(unique_fd bootstrapFd) {
    return setupExternalServer(std::move(bootstrapFd), &RpcServer::recvmsgSocketConnection);
}
69 
// Binds and listens on a Unix domain socket at the given filesystem path.
status_t RpcServer::setupUnixDomainServer(const char* path) {
    return setupSocketServer(UnixSocketAddress(path));
}
73 
// Binds and listens on a vsock (VM socket) address. If assignedPort is
// non-null, it receives the port the kernel actually bound (useful when the
// caller requested an auto-assigned port).
status_t RpcServer::setupVsockServer(unsigned bindCid, unsigned port, unsigned* assignedPort) {
    auto status = setupSocketServer(VsockSocketAddress(bindCid, port));
    if (status != OK) return status;

    if (assignedPort == nullptr) return OK;
    // Read back the bound address to learn the actual port.
    sockaddr_vm addr;
    socklen_t len = sizeof(addr);
    if (0 != getsockname(mServer.fd.get(), reinterpret_cast<sockaddr*>(&addr), &len)) {
        status = -errno;
        ALOGE("setupVsockServer: Failed to getsockname: %s", strerror(-status));
        return status;
    }

    // A size mismatch would mean the fd is not actually a vsock socket.
    LOG_ALWAYS_FATAL_IF(len != sizeof(addr), "Wrong socket type: len %zu vs len %zu",
                        static_cast<size_t>(len), sizeof(addr));
    *assignedPort = addr.svm_port;
    return OK;
}
92 
// Resolves `address` and tries each result until one can be bound and
// listened on. On success, *assignedPort (if non-null) receives the actual
// bound port — relevant when `port` was 0 (auto-assign).
status_t RpcServer::setupInetServer(const char* address, unsigned int port,
                                    unsigned int* assignedPort) {
    if (assignedPort != nullptr) *assignedPort = 0;
    auto aiStart = InetSocketAddress::getAddrInfo(address, port);
    if (aiStart == nullptr) return UNKNOWN_ERROR;
    for (auto ai = aiStart.get(); ai != nullptr; ai = ai->ai_next) {
        if (ai->ai_addr == nullptr) continue;
        InetSocketAddress socketAddress(ai->ai_addr, ai->ai_addrlen, address, port);
        // Try the next candidate address if this one cannot be set up.
        if (status_t status = setupSocketServer(socketAddress); status != OK) {
            continue;
        }

        // NOTE(review): only AF_INET is handled below; an AF_INET6 resolution
        // result would trip this fatal check — presumably getAddrInfo is
        // restricted to IPv4 here. TODO confirm.
        LOG_ALWAYS_FATAL_IF(socketAddress.addr()->sa_family != AF_INET, "expecting inet");
        sockaddr_in addr{};
        socklen_t len = sizeof(addr);
        if (0 != getsockname(mServer.fd.get(), reinterpret_cast<sockaddr*>(&addr), &len)) {
            int savedErrno = errno;
            ALOGE("Could not getsockname at %s: %s", socketAddress.toString().c_str(),
                  strerror(savedErrno));
            return -savedErrno;
        }
        LOG_ALWAYS_FATAL_IF(len != sizeof(addr), "Wrong socket type: len %zu vs len %zu",
                            static_cast<size_t>(len), sizeof(addr));
        unsigned int realPort = ntohs(addr.sin_port);
        // If a specific port was requested, the kernel must have honored it.
        LOG_ALWAYS_FATAL_IF(port != 0 && realPort != port,
                            "Requesting inet server on %s but it is set up on %u.",
                            socketAddress.toString().c_str(), realPort);

        if (assignedPort != nullptr) {
            *assignedPort = realPort;
        }

        return OK;
    }
    ALOGE("None of the socket address resolved for %s:%u can be set up as inet server.", address,
          port);
    return UNKNOWN_ERROR;
}
131 
setMaxThreads(size_t threads)132 void RpcServer::setMaxThreads(size_t threads) {
133     LOG_ALWAYS_FATAL_IF(threads <= 0, "RpcServer is useless without threads");
134     LOG_ALWAYS_FATAL_IF(mJoinThreadRunning, "Cannot set max threads while running");
135     mMaxThreads = threads;
136 }
137 
// Returns the configured maximum number of server threads.
size_t RpcServer::getMaxThreads() {
    return mMaxThreads;
}
141 
// Pins the RPC wire protocol version the server will offer. Returns false
// (and leaves the version unchanged) if the version is not valid.
bool RpcServer::setProtocolVersion(uint32_t version) {
    if (!RpcState::validateProtocolVersion(version)) {
        return false;
    }

    mProtocolVersion = version;
    return true;
}
150 
setSupportedFileDescriptorTransportModes(const std::vector<RpcSession::FileDescriptorTransportMode> & modes)151 void RpcServer::setSupportedFileDescriptorTransportModes(
152         const std::vector<RpcSession::FileDescriptorTransportMode>& modes) {
153     mSupportedFileDescriptorTransportModes.reset();
154     for (RpcSession::FileDescriptorTransportMode mode : modes) {
155         mSupportedFileDescriptorTransportModes.set(static_cast<size_t>(mode));
156     }
157 }
158 
// Installs a strong-ref root object, clearing any per-session factory. The
// weak ref is kept in sync so getRootObject() has a single lookup path.
void RpcServer::setRootObject(const sp<IBinder>& binder) {
    RpcMutexLockGuard _l(mLock);
    mRootObjectFactory = nullptr;
    mRootObjectWeak = mRootObject = binder;
}
164 
// Installs a weak-ref root object: the server will not keep the object alive.
// Clears the strong ref and any per-session factory.
void RpcServer::setRootObjectWeak(const wp<IBinder>& binder) {
    RpcMutexLockGuard _l(mLock);
    mRootObject.clear();
    mRootObjectFactory = nullptr;
    mRootObjectWeak = binder;
}
// Installs a factory that produces a distinct root object per session,
// given the session and the peer address bytes. Clears both direct root refs.
void RpcServer::setPerSessionRootObject(
        std::function<sp<IBinder>(wp<RpcSession> session, const void*, size_t)>&& makeObject) {
    RpcMutexLockGuard _l(mLock);
    mRootObject.clear();
    mRootObjectWeak.clear();
    mRootObjectFactory = std::move(makeObject);
}
178 
// Installs a predicate over the peer address bytes; connections for which it
// returns false are dropped in join(). Must be set before join() runs.
void RpcServer::setConnectionFilter(std::function<bool(const void*, size_t)>&& filter) {
    RpcMutexLockGuard _l(mLock);
    LOG_ALWAYS_FATAL_IF(mShutdownTrigger != nullptr, "Already joined");
    mConnectionFilter = std::move(filter);
}
184 
// Installs a callback that can tweak the listening socket (e.g. socket
// options) after socket() but before bind(). Must be set before setup.
void RpcServer::setServerSocketModifier(std::function<void(borrowed_fd)>&& modifier) {
    RpcMutexLockGuard _l(mLock);
    LOG_ALWAYS_FATAL_IF(mServer.fd.ok(), "Already started");
    mServerSocketModifier = std::move(modifier);
}
190 
// Returns the root object, promoting the weak ref (which mirrors the strong
// ref when one was set). Warns if a weak-only root has already been freed.
sp<IBinder> RpcServer::getRootObject() {
    RpcMutexLockGuard _l(mLock);
    bool hasWeak = mRootObjectWeak.unsafe_get();
    sp<IBinder> ret = mRootObjectWeak.promote();
    ALOGW_IF(hasWeak && ret == nullptr, "RpcServer root object is freed, returning nullptr");
    return ret;
}
198 
// Returns the transport context's certificate (empty for non-TLS transports,
// presumably — determined by the RpcTransportCtx implementation).
std::vector<uint8_t> RpcServer::getCertificate(RpcCertificateFormat format) {
    RpcMutexLockGuard _l(mLock);
    return mCtx->getCertificate(format);
}
203 
// Thread entry point for start(): holds a strong ref to the server for the
// lifetime of the join loop.
static void joinRpcServer(sp<RpcServer>&& thiz) {
    thiz->join();
}
207 
// Runs join() on a dedicated thread owned by this RpcServer (shutdown() will
// join it). In single-threaded builds, rpcJoinIfSingleThreaded runs it inline.
void RpcServer::start() {
    RpcMutexLockGuard _l(mLock);
    LOG_ALWAYS_FATAL_IF(mJoinThread.get(), "Already started!");
    mJoinThread =
            std::make_unique<RpcMaybeThread>(&joinRpcServer, sp<RpcServer>::fromExisting(this));
    rpcJoinIfSingleThreaded(*mJoinThread);
}
215 
// Default accept function: accept4() a client from the listening socket,
// with close-on-exec and non-blocking set atomically.
status_t RpcServer::acceptSocketConnection(const RpcServer& server, RpcTransportFd* out) {
    RpcTransportFd clientSocket(unique_fd(TEMP_FAILURE_RETRY(
            accept4(server.mServer.fd.get(), nullptr, nullptr, SOCK_CLOEXEC | SOCK_NONBLOCK))));
    if (!clientSocket.fd.ok()) {
        int savedErrno = errno;
        ALOGE("Could not accept4 socket: %s", strerror(savedErrno));
        return -savedErrno;
    }

    *out = std::move(clientSocket);
    return OK;
}
228 
// Accept function for the bootstrap-socket mode: reads one message from the
// bootstrap fd and extracts exactly one passed file descriptor, which becomes
// the client connection.
status_t RpcServer::recvmsgSocketConnection(const RpcServer& server, RpcTransportFd* out) {
    // One dummy payload byte; the real content is the ancillary fd.
    int zero = 0;
    iovec iov{&zero, sizeof(zero)};
    std::vector<std::variant<unique_fd, borrowed_fd>> fds;

    ssize_t num_bytes = binder::os::receiveMessageFromSocket(server.mServer, &iov, 1, &fds);
    if (num_bytes < 0) {
        int savedErrno = errno;
        ALOGE("Failed recvmsg: %s", strerror(savedErrno));
        return -savedErrno;
    }
    // Zero bytes means the peer closed the bootstrap socket.
    if (num_bytes == 0) {
        return DEAD_OBJECT;
    }
    if (fds.size() != 1) {
        ALOGE("Expected exactly one fd from recvmsg, got %zu", fds.size());
        return -EINVAL;
    }

    // NOTE(review): assumes receiveMessageFromSocket always yields owned
    // unique_fds here; a borrowed_fd would make std::get throw — confirm.
    unique_fd fd(std::move(std::get<unique_fd>(fds.back())));
    if (status_t res = binder::os::setNonBlocking(fd); res != OK) return res;

    *out = RpcTransportFd(std::move(fd));
    return OK;
}
254 
// Accept loop: blocks accepting clients until shutdown() triggers
// mShutdownTrigger. Each accepted connection is handed to establishConnection
// on its own (maybe-)thread. The caller's thread is consumed until shutdown.
void RpcServer::join() {

    {
        RpcMutexLockGuard _l(mLock);
        LOG_ALWAYS_FATAL_IF(!mServer.fd.ok(), "RpcServer must be setup to join.");
        LOG_ALWAYS_FATAL_IF(mAcceptFn == nullptr, "RpcServer must have an accept() function");
        LOG_ALWAYS_FATAL_IF(mShutdownTrigger != nullptr, "Already joined");
        mJoinThreadRunning = true;
        // The trigger doubles as the "currently joined" marker checked by
        // shutdown() and the destructor.
        mShutdownTrigger = FdTrigger::make();
        LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr, "Cannot create join signaler");
    }

    status_t status;
    // triggerablePoll returns OK when the server fd is readable, or an error
    // once the shutdown trigger fires.
    while ((status = mShutdownTrigger->triggerablePoll(mServer, POLLIN)) == OK) {
        std::array<uint8_t, kRpcAddressSize> addr;
        static_assert(addr.size() >= sizeof(sockaddr_storage), "kRpcAddressSize is too small");
        socklen_t addrLen = addr.size();

        RpcTransportFd clientSocket;
        if ((status = mAcceptFn(*this, &clientSocket)) != OK) {
            // DEAD_OBJECT means the accept source itself is gone; other
            // errors are per-connection and the loop keeps serving.
            if (status == DEAD_OBJECT)
                break;
            else
                continue;
        }

        LOG_RPC_DETAIL("accept on fd %d yields fd %d", mServer.fd.get(), clientSocket.fd.get());

        if (getpeername(clientSocket.fd.get(), reinterpret_cast<sockaddr*>(addr.data()),
                        &addrLen)) {
            ALOGE("Could not getpeername socket: %s", strerror(errno));
            continue;
        }

        // Optional application-level allow/deny based on the peer address.
        if (mConnectionFilter != nullptr && !mConnectionFilter(addr.data(), addrLen)) {
            ALOGE("Dropped client connection fd %d", clientSocket.fd.get());
            continue;
        }

        {
            RpcMutexLockGuard _l(mLock);
            RpcMaybeThread thread =
                    RpcMaybeThread(&RpcServer::establishConnection,
                                   sp<RpcServer>::fromExisting(this), std::move(clientSocket), addr,
                                   addrLen, RpcSession::join);

            // Track the thread by id so establishConnection can claim (and
            // later detach) its own std::thread object.
            auto& threadRef = mConnectingThreads[thread.get_id()];
            threadRef = std::move(thread);
            rpcJoinIfSingleThreaded(threadRef);
        }
    }
    LOG_RPC_DETAIL("RpcServer::join exiting with %s", statusToString(status).c_str());

    if constexpr (kEnableRpcThreads) {
        RpcMutexLockGuard _l(mLock);
        mJoinThreadRunning = false;
    } else {
        // Multi-threaded builds clear this in shutdown(), but we need it valid
        // so the loop above exits cleanly
        mShutdownTrigger = nullptr;
    }
    // Wake shutdown(), which waits on mJoinThreadRunning going false.
    mShutdownCv.notify_all();
}
318 
// Stops the accept loop and all sessions, then waits for every server-side
// thread to drain. Returns false if the server was never joined (or already
// shut down). Safe to call from any thread except a server thread itself.
bool RpcServer::shutdown() {
    RpcMutexUniqueLock _l(mLock);
    if (mShutdownTrigger == nullptr) {
        LOG_RPC_DETAIL("Cannot shutdown. No shutdown trigger installed (already shutdown, or not "
                       "joined yet?)");
        return false;
    }

    // Unblocks triggerablePoll in join().
    mShutdownTrigger->trigger();

    for (auto& [id, session] : mSessions) {
        (void)id;
        // server lock is a more general lock
        RpcMutexLockGuard _lSession(session->mMutex);
        session->mShutdownTrigger->trigger();
    }

    if constexpr (!kEnableRpcThreads) {
        // In single-threaded mode we're done here, everything else that
        // needs to happen should be at the end of RpcServer::join()
        return true;
    }

    // Wait (releasing mLock each iteration) until the join loop, all
    // in-flight connection handshakes, and all sessions have finished.
    while (mJoinThreadRunning || !mConnectingThreads.empty() || !mSessions.empty()) {
        if (std::cv_status::timeout == mShutdownCv.wait_for(_l, std::chrono::seconds(1))) {
            ALOGE("Waiting for RpcServer to shut down (1s w/o progress). Join thread running: %d, "
                  "Connecting threads: "
                  "%zu, Sessions: %zu. Is your server deadlocked?",
                  mJoinThreadRunning, mConnectingThreads.size(), mSessions.size());
        }
    }

    // At this point, we know join() is about to exit, but the thread that calls
    // join() may not have exited yet.
    // If RpcServer owns the join thread (aka start() is called), make sure the thread exits;
    // otherwise ~thread() may call std::terminate(), which may crash the process.
    // If RpcServer does not own the join thread (aka join() is called directly),
    // then the owner of RpcServer is responsible for cleaning up that thread.
    if (mJoinThread.get()) {
        mJoinThread->join();
        mJoinThread.reset();
    }

    // Close the listening socket; the server can be set up again afterwards.
    mServer = RpcTransportFd();

    LOG_RPC_DETAIL("Finished waiting on shutdown.");

    // Clearing the trigger marks the server as fully shut down (see dtor).
    mShutdownTrigger = nullptr;
    return true;
}
369 
listSessions()370 std::vector<sp<RpcSession>> RpcServer::listSessions() {
371     RpcMutexLockGuard _l(mLock);
372     std::vector<sp<RpcSession>> sessions;
373     for (auto& [id, session] : mSessions) {
374         (void)id;
375         sessions.push_back(session);
376     }
377     return sessions;
378 }
379 
// Number of accepted connections still performing their handshake in
// establishConnection (not yet attached to a session).
size_t RpcServer::numUninitializedSessions() {
    RpcMutexLockGuard _l(mLock);
    return mConnectingThreads.size();
}
384 
// Per-connection handshake, run on a thread created by join(). Wraps the raw
// fd in a transport, reads the connection header (and optional session ID),
// then either creates a new session or attaches the connection to an existing
// one, and finally hands the connection to joinFn to serve transactions.
void RpcServer::establishConnection(
        sp<RpcServer>&& server, RpcTransportFd clientFd, std::array<uint8_t, kRpcAddressSize> addr,
        size_t addrLen,
        std::function<void(sp<RpcSession>&&, RpcSession::PreJoinSetupResult&&)>&& joinFn) {
    // mShutdownTrigger can only be cleared once connection threads have joined.
    // It must be set before this thread is started
    LOG_ALWAYS_FATAL_IF(server->mShutdownTrigger == nullptr);
    LOG_ALWAYS_FATAL_IF(server->mCtx == nullptr);

    // Errors below do not return immediately: this thread must still reach the
    // locked section to remove itself from mConnectingThreads. `status` is
    // carried forward instead.
    status_t status = OK;

    int clientFdForLog = clientFd.fd.get();
    auto client = server->mCtx->newTransport(std::move(clientFd), server->mShutdownTrigger.get());
    if (client == nullptr) {
        ALOGE("Dropping accept4()-ed socket because sslAccept fails");
        status = DEAD_OBJECT;
        // still need to cleanup before we can return
    } else {
        LOG_RPC_DETAIL("Created RpcTransport %p for client fd %d", client.get(), clientFdForLog);
    }

    // Step 1: read the fixed-size connection header from the client.
    RpcConnectionHeader header;
    if (status == OK) {
        iovec iov{&header, sizeof(header)};
        status = client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1,
                                                std::nullopt, /*ancillaryFds=*/nullptr);
        if (status != OK) {
            ALOGE("Failed to read ID for client connecting to RPC server: %s",
                  statusToString(status).c_str());
            // still need to cleanup before we can return
        }
    }

    // Step 2: if the header announces a session ID, read it (must be exactly
    // kSessionIdBytes). An empty ID means the client wants a new session.
    std::vector<uint8_t> sessionId;
    if (status == OK) {
        if (header.sessionIdSize > 0) {
            if (header.sessionIdSize == kSessionIdBytes) {
                sessionId.resize(header.sessionIdSize);
                iovec iov{sessionId.data(), sessionId.size()};
                status = client->interruptableReadFully(server->mShutdownTrigger.get(), &iov, 1,
                                                        std::nullopt, /*ancillaryFds=*/nullptr);
                if (status != OK) {
                    ALOGE("Failed to read session ID for client connecting to RPC server: %s",
                          statusToString(status).c_str());
                    // still need to cleanup before we can return
                }
            } else {
                ALOGE("Malformed session ID. Expecting session ID of size %zu but got %" PRIu16,
                      kSessionIdBytes, header.sessionIdSize);
                status = BAD_VALUE;
            }
        }
    }

    bool incoming = false;
    uint32_t protocolVersion = 0;
    bool requestingNewSession = false;

    // Step 3: negotiate the protocol version (lowest of client's and ours)
    // and reply to new-session requests with the chosen version.
    if (status == OK) {
        incoming = header.options & RPC_CONNECTION_OPTION_INCOMING;
        protocolVersion = std::min(header.version,
                                   server->mProtocolVersion.value_or(RPC_WIRE_PROTOCOL_VERSION));
        requestingNewSession = sessionId.empty();

        if (requestingNewSession) {
            RpcNewSessionResponse response{
                    .version = protocolVersion,
            };

            iovec iov{&response, sizeof(response)};
            status = client->interruptableWriteFully(server->mShutdownTrigger.get(), &iov, 1,
                                                     std::nullopt, nullptr);
            if (status != OK) {
                ALOGE("Failed to send new session response: %s", statusToString(status).c_str());
                // still need to cleanup before we can return
            }
        }
    }

    // Step 4: under the server lock, claim this thread from mConnectingThreads
    // and attach the connection to a session.
    RpcMaybeThread thisThread;
    sp<RpcSession> session;
    {
        RpcMutexUniqueLock _l(server->mLock);

        auto threadId = server->mConnectingThreads.find(rpc_this_thread::get_id());
        LOG_ALWAYS_FATAL_IF(threadId == server->mConnectingThreads.end(),
                            "Must establish connection on owned thread");
        thisThread = std::move(threadId->second);
        // On any early return below, detach this thread, drop the lock, and
        // wake shutdown() (which counts connecting threads).
        auto detachGuardLambda = [&]() {
            thisThread.detach();
            _l.unlock();
            server->mShutdownCv.notify_all();
        };
        auto detachGuard = make_scope_guard(std::ref(detachGuardLambda));
        server->mConnectingThreads.erase(threadId);

        if (status != OK || server->mShutdownTrigger->isTriggered()) {
            return;
        }

        if (requestingNewSession) {
            if (incoming) {
                ALOGE("Cannot create a new session with an incoming connection, would leak");
                return;
            }

            // Uniquely identify session at the application layer. Even if a
            // client/server use the same certificates, if they create multiple
            // sessions, we still want to distinguish between them.
            sessionId.resize(kSessionIdBytes);
            size_t tries = 0;
            do {
                // don't block if there is some entropy issue
                if (tries++ > 5) {
                    ALOGE("Cannot find new address: %s",
                          HexString(sessionId.data(), sessionId.size()).c_str());
                    return;
                }

                auto status = binder::os::getRandomBytes(sessionId.data(), sessionId.size());
                if (status != OK) {
                    ALOGE("Failed to read random session ID: %s", statusToString(status).c_str());
                    return;
                }
            } while (server->mSessions.end() != server->mSessions.find(sessionId));

            session = sp<RpcSession>::make(nullptr);
            session->setMaxIncomingThreads(server->mMaxThreads);
            if (!session->setProtocolVersion(protocolVersion)) return;

            // The client's requested fd-transport mode must be one the server
            // enabled via setSupportedFileDescriptorTransportModes.
            if (header.fileDescriptorTransportMode <
                        server->mSupportedFileDescriptorTransportModes.size() &&
                server->mSupportedFileDescriptorTransportModes.test(
                        header.fileDescriptorTransportMode)) {
                session->setFileDescriptorTransportMode(
                        static_cast<RpcSession::FileDescriptorTransportMode>(
                                header.fileDescriptorTransportMode));
            } else {
                ALOGE("Rejecting connection: FileDescriptorTransportMode is not supported: %hhu",
                      header.fileDescriptorTransportMode);
                return;
            }

            // if null, falls back to server root
            sp<IBinder> sessionSpecificRoot;
            if (server->mRootObjectFactory != nullptr) {
                sessionSpecificRoot =
                        server->mRootObjectFactory(wp<RpcSession>(session), addr.data(), addrLen);
                if (sessionSpecificRoot == nullptr) {
                    ALOGE("Warning: server returned null from root object factory");
                }
            }

            if (!session->setForServer(server,
                                       sp<RpcServer::EventListener>::fromExisting(
                                               static_cast<RpcServer::EventListener*>(
                                                       server.get())),
                                       sessionId, sessionSpecificRoot)) {
                ALOGE("Failed to attach server to session");
                return;
            }

            server->mSessions[sessionId] = session;
        } else {
            // Existing session: the ID must match one we created earlier.
            auto it = server->mSessions.find(sessionId);
            if (it == server->mSessions.end()) {
                ALOGE("Cannot add thread, no record of session with ID %s",
                      HexString(sessionId.data(), sessionId.size()).c_str());
                return;
            }
            session = it->second;
        }

        // "incoming" connections carry server->client traffic; register and
        // finish — this thread does not serve transactions for them.
        if (incoming) {
            LOG_ALWAYS_FATAL_IF(OK != session->addOutgoingConnection(std::move(client), true),
                                "server state must already be initialized");
            return;
        }

        // Success path: the session takes ownership of this thread, so the
        // detach guard must not run.
        detachGuard.release();
        session->preJoinThreadOwnership(std::move(thisThread));
    }

    auto setupResult = session->preJoinSetup(std::move(client));

    // avoid strong cycle
    server = nullptr;

    joinFn(std::move(session), std::move(setupResult));
}
575 
setupSocketServer(const RpcSocketAddress & addr)576 status_t RpcServer::setupSocketServer(const RpcSocketAddress& addr) {
577     LOG_RPC_DETAIL("Setting up socket server %s", addr.toString().c_str());
578     LOG_ALWAYS_FATAL_IF(hasServer(), "Each RpcServer can only have one server.");
579 
580     unique_fd socket_fd(TEMP_FAILURE_RETRY(
581             socket(addr.addr()->sa_family, SOCK_STREAM | SOCK_CLOEXEC | SOCK_NONBLOCK, 0)));
582     if (!socket_fd.ok()) {
583         int savedErrno = errno;
584         ALOGE("Could not create socket at %s: %s", addr.toString().c_str(), strerror(savedErrno));
585         return -savedErrno;
586     }
587 
588     if (addr.addr()->sa_family == AF_INET || addr.addr()->sa_family == AF_INET6) {
589         int noDelay = 1;
590         int result =
591                 setsockopt(socket_fd.get(), IPPROTO_TCP, TCP_NODELAY, &noDelay, sizeof(noDelay));
592         if (result < 0) {
593             int savedErrno = errno;
594             ALOGE("Could not set TCP_NODELAY on  %s", strerror(savedErrno));
595             return -savedErrno;
596         }
597     }
598 
599     {
600         RpcMutexLockGuard _l(mLock);
601         if (mServerSocketModifier != nullptr) {
602             mServerSocketModifier(socket_fd);
603         }
604     }
605 
606     if (0 != TEMP_FAILURE_RETRY(bind(socket_fd.get(), addr.addr(), addr.addrSize()))) {
607         int savedErrno = errno;
608         ALOGE("Could not bind socket at %s: %s", addr.toString().c_str(), strerror(savedErrno));
609         return -savedErrno;
610     }
611 
612     return setupRawSocketServer(std::move(socket_fd));
613 }
614 
// Marks an already-bound socket as listening and installs it as this server's
// accept source (with the default accept4-based accept function).
status_t RpcServer::setupRawSocketServer(unique_fd socket_fd) {
    LOG_ALWAYS_FATAL_IF(!socket_fd.ok(), "Socket must be setup to listen.");

    // Right now, we create all threads at once, making accept4 slow. To avoid hanging the client,
    // the backlog is increased to a large number.
    // TODO(b/189955605): Once we create threads dynamically & lazily, the backlog can be reduced
    //  to 1.
    if (0 != TEMP_FAILURE_RETRY(listen(socket_fd.get(), 50 /*backlog*/))) {
        int savedErrno = errno;
        ALOGE("Could not listen initialized Unix socket: %s", strerror(savedErrno));
        return -savedErrno;
    }
    if (status_t status = setupExternalServer(std::move(socket_fd)); status != OK) {
        ALOGE("Another thread has set up server while calling setupSocketServer. Race?");
        return status;
    }
    return OK;
}
633 
// EventListener callback: a session's last incoming thread ended, so the
// session is removed from the server's session map.
void RpcServer::onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) {
    const std::vector<uint8_t>& id = session->mId;
    LOG_ALWAYS_FATAL_IF(id.empty(), "Server sessions must be initialized with ID");
    LOG_RPC_DETAIL("Dropping session with address %s", HexString(id.data(), id.size()).c_str());

    RpcMutexLockGuard _l(mLock);
    auto it = mSessions.find(id);
    // Both checks guard internal consistency: the session must be the one we
    // registered under this ID in establishConnection.
    LOG_ALWAYS_FATAL_IF(it == mSessions.end(), "Bad state, unknown session id %s",
                        HexString(id.data(), id.size()).c_str());
    LOG_ALWAYS_FATAL_IF(it->second != session, "Bad state, session has id mismatch %s",
                        HexString(id.data(), id.size()).c_str());
    (void)mSessions.erase(it);
}
647 
// EventListener callback: wakes shutdown(), which re-checks its drain
// condition whenever any session thread exits.
void RpcServer::onSessionIncomingThreadEnded() {
    mShutdownCv.notify_all();
}
651 
// True if a listening/accept fd is currently installed.
bool RpcServer::hasServer() {
    RpcMutexLockGuard _l(mLock);
    return mServer.fd.ok();
}
656 
// Transfers ownership of the server fd to the caller, leaving this server
// without an accept source.
unique_fd RpcServer::releaseServer() {
    RpcMutexLockGuard _l(mLock);
    return std::move(mServer.fd);
}
661 
// Installs an externally created server fd together with the accept strategy
// used by join() to obtain client connections from it. Fails if a server fd
// is already installed.
status_t RpcServer::setupExternalServer(
        unique_fd serverFd, std::function<status_t(const RpcServer&, RpcTransportFd*)>&& acceptFn) {
    RpcMutexLockGuard _l(mLock);
    if (mServer.fd.ok()) {
        ALOGE("Each RpcServer can only have one server.");
        return INVALID_OPERATION;
    }
    mServer = std::move(serverFd);
    mAcceptFn = std::move(acceptFn);
    return OK;
}
673 
// Convenience overload: external fd with the default accept4-based accept.
status_t RpcServer::setupExternalServer(unique_fd serverFd) {
    return setupExternalServer(std::move(serverFd), &RpcServer::acceptSocketConnection);
}
677 
// True if any session is processing a transaction, or if the server fd itself
// is not idle in poll (i.e. there is pending accept work).
bool RpcServer::hasActiveRequests() {
    RpcMutexLockGuard _l(mLock);
    for (const auto& [_, session] : mSessions) {
        if (session->hasActiveRequests()) {
            return true;
        }
    }
    return !mServer.isInPollingState();
}
687 
688 } // namespace android
689