xref: /aosp_15_r20/external/sandboxed-api/sandboxed_api/sandbox2/comms.cc (revision ec63e07ab9515d95e79c211197c445ef84cefa6a)
1*ec63e07aSXin Li // Copyright 2019 Google LLC
2*ec63e07aSXin Li //
3*ec63e07aSXin Li // Licensed under the Apache License, Version 2.0 (the "License");
4*ec63e07aSXin Li // you may not use this file except in compliance with the License.
5*ec63e07aSXin Li // You may obtain a copy of the License at
6*ec63e07aSXin Li //
7*ec63e07aSXin Li //     https://www.apache.org/licenses/LICENSE-2.0
8*ec63e07aSXin Li //
9*ec63e07aSXin Li // Unless required by applicable law or agreed to in writing, software
10*ec63e07aSXin Li // distributed under the License is distributed on an "AS IS" BASIS,
11*ec63e07aSXin Li // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*ec63e07aSXin Li // See the License for the specific language governing permissions and
13*ec63e07aSXin Li // limitations under the License.
14*ec63e07aSXin Li 
15*ec63e07aSXin Li // Implementation of sandbox2::Comms class.
16*ec63e07aSXin Li //
17*ec63e07aSXin Li // Warning: This class is not multi-thread safe (for callers). It uses a single
18*ec63e07aSXin Li // communications channel (an AF_UNIX socket), so it requires exactly one sender
19*ec63e07aSXin Li // and one receiver. If you plan to use it from many threads, provide external
20*ec63e07aSXin Li // exclusive locking.
21*ec63e07aSXin Li 
22*ec63e07aSXin Li #include "sandboxed_api/sandbox2/comms.h"
23*ec63e07aSXin Li 
24*ec63e07aSXin Li #include <sys/socket.h>
25*ec63e07aSXin Li #include <sys/uio.h>
26*ec63e07aSXin Li #include <sys/un.h>
27*ec63e07aSXin Li #include <syscall.h>
28*ec63e07aSXin Li #include <unistd.h>
29*ec63e07aSXin Li 
30*ec63e07aSXin Li #include <atomic>
31*ec63e07aSXin Li #include <cerrno>
32*ec63e07aSXin Li #include <cstdint>
33*ec63e07aSXin Li #include <cstdlib>
34*ec63e07aSXin Li #include <cstring>
35*ec63e07aSXin Li #include <functional>
36*ec63e07aSXin Li #include <memory>
37*ec63e07aSXin Li #include <string>
38*ec63e07aSXin Li #include <utility>
39*ec63e07aSXin Li #include <vector>
40*ec63e07aSXin Li 
41*ec63e07aSXin Li #include "absl/base/dynamic_annotations.h"
42*ec63e07aSXin Li #include "absl/status/status.h"
43*ec63e07aSXin Li #include "absl/status/statusor.h"
44*ec63e07aSXin Li #include "absl/strings/numbers.h"
45*ec63e07aSXin Li #include "absl/strings/str_format.h"
46*ec63e07aSXin Li #include "absl/strings/string_view.h"
47*ec63e07aSXin Li #include "google/protobuf/message_lite.h"
48*ec63e07aSXin Li #include "sandboxed_api/sandbox2/util.h"
49*ec63e07aSXin Li #include "sandboxed_api/util/fileops.h"
50*ec63e07aSXin Li #include "sandboxed_api/util/raw_logging.h"
51*ec63e07aSXin Li #include "sandboxed_api/util/status.h"
52*ec63e07aSXin Li #include "sandboxed_api/util/status.pb.h"
53*ec63e07aSXin Li #include "sandboxed_api/util/status_macros.h"
54*ec63e07aSXin Li 
55*ec63e07aSXin Li namespace sandbox2 {
56*ec63e07aSXin Li 
// Scope marker placed around calls that may block (accept, connect, read,
// write, recvmsg, sendmsg, ...). This implementation performs no work; it only
// gives blocking call sites a common annotation.
class PotentiallyBlockingRegion {
 public:
  // Deliberately user-provided instead of `= default`: an empty but
  // non-trivial destructor keeps compilers from flagging scoped instances as
  // unused variables.
  ~PotentiallyBlockingRegion() {}
};
63*ec63e07aSXin Li 
64*ec63e07aSXin Li namespace {
65*ec63e07aSXin Li 
66*ec63e07aSXin Li using sapi::file_util::fileops::FDCloser;
67*ec63e07aSXin Li 
// Returns true when `saved_errno` should tear down the connection, i.e. when it
// is not one of the transient/non-fatal values listed below.
bool IsFatalError(int saved_errno) {
  // errno values after which callers keep the connection open.
  static constexpr int kNonFatalErrnos[] = {EAGAIN, EWOULDBLOCK, EFAULT,
                                            EINTR,  EINVAL,      ENOMEM};
  for (const int non_fatal : kNonFatalErrnos) {
    if (saved_errno == non_fatal) {
      return false;
    }
  }
  return true;
}
73*ec63e07aSXin Li 
GetDefaultCommsFd()74*ec63e07aSXin Li int GetDefaultCommsFd() {
75*ec63e07aSXin Li   if (const char* var = getenv(Comms::kSandbox2CommsFDEnvVar); var) {
76*ec63e07aSXin Li     int fd;
77*ec63e07aSXin Li     SAPI_RAW_CHECK(absl::SimpleAtoi(var, &fd), "cannot parse comms fd var");
78*ec63e07aSXin Li     unsetenv(Comms::kSandbox2CommsFDEnvVar);
79*ec63e07aSXin Li     return fd;
80*ec63e07aSXin Li   }
81*ec63e07aSXin Li   return Comms::kSandbox2ClientCommsFD;
82*ec63e07aSXin Li }
83*ec63e07aSXin Li 
CreateSockaddrUn(const std::string & socket_name,bool abstract_uds,sockaddr_un * sun)84*ec63e07aSXin Li socklen_t CreateSockaddrUn(const std::string& socket_name, bool abstract_uds,
85*ec63e07aSXin Li                            sockaddr_un* sun) {
86*ec63e07aSXin Li   sun->sun_family = AF_UNIX;
87*ec63e07aSXin Li   bzero(sun->sun_path, sizeof(sun->sun_path));
88*ec63e07aSXin Li   socklen_t slen = sizeof(sun->sun_family) + strlen(socket_name.c_str());
89*ec63e07aSXin Li   if (abstract_uds) {
90*ec63e07aSXin Li     // Create an 'abstract socket address' by specifying a leading null byte.
91*ec63e07aSXin Li     // The remainder of the path is used as a unique name, but no file is
92*ec63e07aSXin Li     // created on the filesystem. No need to NUL-terminate the string. See `man
93*ec63e07aSXin Li     // 7 unix` for further explanation.
94*ec63e07aSXin Li     strncpy(&sun->sun_path[1], socket_name.c_str(), sizeof(sun->sun_path) - 1);
95*ec63e07aSXin Li     // Len is complicated - it's essentially size of the path, plus initial
96*ec63e07aSXin Li     // NUL-byte, minus size of the sun.sun_family.
97*ec63e07aSXin Li     slen++;
98*ec63e07aSXin Li   } else {
99*ec63e07aSXin Li     // Create the socket address as it was passed from the constructor.
100*ec63e07aSXin Li     strncpy(&sun->sun_path[0], socket_name.c_str(), sizeof(sun->sun_path));
101*ec63e07aSXin Li   }
102*ec63e07aSXin Li 
103*ec63e07aSXin Li   // This takes care of the socket address overflow.
104*ec63e07aSXin Li   if (slen > sizeof(sockaddr_un)) {
105*ec63e07aSXin Li     SAPI_RAW_LOG(ERROR, "Socket address is too long, will be truncated");
106*ec63e07aSXin Li     slen = sizeof(sockaddr_un);
107*ec63e07aSXin Li   }
108*ec63e07aSXin Li   return slen;
109*ec63e07aSXin Li }
110*ec63e07aSXin Li }  // namespace
111*ec63e07aSXin Li 
Comms(int fd,absl::string_view name)112*ec63e07aSXin Li Comms::Comms(int fd, absl::string_view name) : connection_fd_(fd) {
113*ec63e07aSXin Li   // Generate a unique and meaningful socket name for this FD.
114*ec63e07aSXin Li   // Note: getpid()/gettid() are non-blocking syscalls.
115*ec63e07aSXin Li   if (name.empty()) {
116*ec63e07aSXin Li     name_ = absl::StrFormat("sandbox2::Comms:FD=%d/PID=%d/TID=%ld", fd,
117*ec63e07aSXin Li                             getpid(), syscall(__NR_gettid));
118*ec63e07aSXin Li   } else {
119*ec63e07aSXin Li     name_ = std::string(name);
120*ec63e07aSXin Li   }
121*ec63e07aSXin Li 
122*ec63e07aSXin Li   // File descriptor is already connected.
123*ec63e07aSXin Li   state_ = State::kConnected;
124*ec63e07aSXin Li }
125*ec63e07aSXin Li 
Comms(Comms::DefaultConnectionTag)126*ec63e07aSXin Li Comms::Comms(Comms::DefaultConnectionTag) : Comms(GetDefaultCommsFd()) {}
127*ec63e07aSXin Li 
~Comms()128*ec63e07aSXin Li Comms::~Comms() { Terminate(); }
129*ec63e07aSXin Li 
GetConnectionFD() const130*ec63e07aSXin Li int Comms::GetConnectionFD() const {
131*ec63e07aSXin Li   return connection_fd_.get();
132*ec63e07aSXin Li }
133*ec63e07aSXin Li 
Create(absl::string_view socket_name,bool abstract_uds)134*ec63e07aSXin Li absl::StatusOr<ListeningComms> ListeningComms::Create(
135*ec63e07aSXin Li     absl::string_view socket_name, bool abstract_uds) {
136*ec63e07aSXin Li   ListeningComms comms(std::string(socket_name), abstract_uds);
137*ec63e07aSXin Li   SAPI_RETURN_IF_ERROR(comms.Listen());
138*ec63e07aSXin Li   return comms;
139*ec63e07aSXin Li }
140*ec63e07aSXin Li 
Listen()141*ec63e07aSXin Li absl::Status ListeningComms::Listen() {
142*ec63e07aSXin Li   bind_fd_ = FDCloser(socket(AF_UNIX, SOCK_STREAM, 0));  // Non-blocking
143*ec63e07aSXin Li   if (bind_fd_.get() == -1) {
144*ec63e07aSXin Li     return absl::ErrnoToStatus(errno, "socket(AF_UNIX) failed");
145*ec63e07aSXin Li   }
146*ec63e07aSXin Li 
147*ec63e07aSXin Li   sockaddr_un sus;
148*ec63e07aSXin Li   socklen_t slen = CreateSockaddrUn(socket_name_, abstract_uds_, &sus);
149*ec63e07aSXin Li   // bind() is non-blocking.
150*ec63e07aSXin Li   if (bind(bind_fd_.get(), reinterpret_cast<sockaddr*>(&sus), slen) == -1) {
151*ec63e07aSXin Li     return absl::ErrnoToStatus(errno, "bind failed");
152*ec63e07aSXin Li   }
153*ec63e07aSXin Li 
154*ec63e07aSXin Li   // listen() non-blocking.
155*ec63e07aSXin Li   if (listen(bind_fd_.get(), 0) == -1) {
156*ec63e07aSXin Li     return absl::ErrnoToStatus(errno, "listen failed");
157*ec63e07aSXin Li   }
158*ec63e07aSXin Li 
159*ec63e07aSXin Li   SAPI_RAW_VLOG(1, "Listening at: %s", socket_name_.c_str());
160*ec63e07aSXin Li   return absl::OkStatus();
161*ec63e07aSXin Li }
162*ec63e07aSXin Li 
Accept()163*ec63e07aSXin Li absl::StatusOr<Comms> ListeningComms::Accept() {
164*ec63e07aSXin Li   sockaddr_un suc;
165*ec63e07aSXin Li   socklen_t len = sizeof(suc);
166*ec63e07aSXin Li   int connection_fd;
167*ec63e07aSXin Li   {
168*ec63e07aSXin Li     PotentiallyBlockingRegion region;
169*ec63e07aSXin Li     connection_fd = TEMP_FAILURE_RETRY(
170*ec63e07aSXin Li         accept(bind_fd_.get(), reinterpret_cast<sockaddr*>(&suc), &len));
171*ec63e07aSXin Li   }
172*ec63e07aSXin Li   if (connection_fd == -1) {
173*ec63e07aSXin Li     return absl::ErrnoToStatus(errno, "accept failed");
174*ec63e07aSXin Li   }
175*ec63e07aSXin Li   SAPI_RAW_VLOG(1, "Accepted connection at: %s, fd: %d", socket_name_.c_str(),
176*ec63e07aSXin Li                 connection_fd);
177*ec63e07aSXin Li   return Comms(connection_fd, socket_name_);
178*ec63e07aSXin Li }
179*ec63e07aSXin Li 
Connect(const std::string & socket_name,bool abstract_uds)180*ec63e07aSXin Li absl::StatusOr<Comms> Comms::Connect(const std::string& socket_name,
181*ec63e07aSXin Li                                      bool abstract_uds) {
182*ec63e07aSXin Li   FDCloser connection_fd(socket(AF_UNIX, SOCK_STREAM, 0));  // Non-blocking
183*ec63e07aSXin Li   if (connection_fd.get() == -1) {
184*ec63e07aSXin Li     return absl::ErrnoToStatus(errno, "socket(AF_UNIX)");
185*ec63e07aSXin Li   }
186*ec63e07aSXin Li 
187*ec63e07aSXin Li   sockaddr_un suc;
188*ec63e07aSXin Li   socklen_t slen = CreateSockaddrUn(socket_name, abstract_uds, &suc);
189*ec63e07aSXin Li   int ret;
190*ec63e07aSXin Li   {
191*ec63e07aSXin Li     PotentiallyBlockingRegion region;
192*ec63e07aSXin Li     ret = TEMP_FAILURE_RETRY(
193*ec63e07aSXin Li         connect(connection_fd.get(), reinterpret_cast<sockaddr*>(&suc), slen));
194*ec63e07aSXin Li   }
195*ec63e07aSXin Li   if (ret == -1) {
196*ec63e07aSXin Li     return absl::ErrnoToStatus(errno, "connect(connection_fd)");
197*ec63e07aSXin Li   }
198*ec63e07aSXin Li 
199*ec63e07aSXin Li   SAPI_RAW_VLOG(1, "Connected to: %s, fd: %d", socket_name.c_str(),
200*ec63e07aSXin Li                 connection_fd.get());
201*ec63e07aSXin Li   return Comms(connection_fd.Release(), socket_name);
202*ec63e07aSXin Li }
203*ec63e07aSXin Li 
Terminate()204*ec63e07aSXin Li void Comms::Terminate() {
205*ec63e07aSXin Li   state_ = State::kTerminated;
206*ec63e07aSXin Li 
207*ec63e07aSXin Li   connection_fd_.Close();
208*ec63e07aSXin Li   listening_comms_.reset();
209*ec63e07aSXin Li }
210*ec63e07aSXin Li 
SendTLV(uint32_t tag,size_t length,const void * value)211*ec63e07aSXin Li bool Comms::SendTLV(uint32_t tag, size_t length, const void* value) {
212*ec63e07aSXin Li   if (length > GetMaxMsgSize()) {
213*ec63e07aSXin Li     SAPI_RAW_LOG(ERROR, "Maximum TLV message size exceeded: (%zu > %zu)",
214*ec63e07aSXin Li                  length, GetMaxMsgSize());
215*ec63e07aSXin Li     return false;
216*ec63e07aSXin Li   }
217*ec63e07aSXin Li   if (length > kWarnMsgSize) {
218*ec63e07aSXin Li     // TODO(cblichmann): Use LOG_FIRST_N once Abseil logging is released.
219*ec63e07aSXin Li     static std::atomic<int> times_warned = 0;
220*ec63e07aSXin Li     if (times_warned.fetch_add(1, std::memory_order_relaxed) < 10) {
221*ec63e07aSXin Li       SAPI_RAW_LOG(
222*ec63e07aSXin Li           WARNING,
223*ec63e07aSXin Li           "TLV message of size %zu detected. Please consider switching "
224*ec63e07aSXin Li           "to Buffer API instead.",
225*ec63e07aSXin Li           length);
226*ec63e07aSXin Li     }
227*ec63e07aSXin Li   }
228*ec63e07aSXin Li 
229*ec63e07aSXin Li   SAPI_RAW_VLOG(3, "Sending a TLV message, tag: 0x%08x, length: %zu", tag,
230*ec63e07aSXin Li                 length);
231*ec63e07aSXin Li 
232*ec63e07aSXin Li   // To maintain consistency with `RecvTL()`, we wrap `tag` and `length` in a TL
233*ec63e07aSXin Li   // struct.
234*ec63e07aSXin Li   const InternalTLV tl = {
235*ec63e07aSXin Li       .tag = tag,
236*ec63e07aSXin Li       .len = length,
237*ec63e07aSXin Li   };
238*ec63e07aSXin Li 
239*ec63e07aSXin Li   if (length + sizeof(tl) > kSendTLVTempBufferSize) {
240*ec63e07aSXin Li     if (!Send(&tl, sizeof(tl))) {
241*ec63e07aSXin Li       return false;
242*ec63e07aSXin Li     }
243*ec63e07aSXin Li     return Send(value, length);
244*ec63e07aSXin Li   }
245*ec63e07aSXin Li   uint8_t tlv[kSendTLVTempBufferSize];
246*ec63e07aSXin Li   memcpy(tlv, &tl, sizeof(tl));
247*ec63e07aSXin Li   memcpy(reinterpret_cast<uint8_t*>(tlv) + sizeof(tl), value, length);
248*ec63e07aSXin Li 
249*ec63e07aSXin Li   return Send(&tlv, sizeof(tl) + length);
250*ec63e07aSXin Li }
251*ec63e07aSXin Li 
RecvString(std::string * v)252*ec63e07aSXin Li bool Comms::RecvString(std::string* v) {
253*ec63e07aSXin Li   uint32_t tag;
254*ec63e07aSXin Li   if (!RecvTLV(&tag, v)) {
255*ec63e07aSXin Li     return false;
256*ec63e07aSXin Li   }
257*ec63e07aSXin Li 
258*ec63e07aSXin Li   if (tag != kTagString) {
259*ec63e07aSXin Li     SAPI_RAW_LOG(ERROR, "Expected (kTagString == 0x%x), got: 0x%x", kTagString,
260*ec63e07aSXin Li                  tag);
261*ec63e07aSXin Li     return false;
262*ec63e07aSXin Li   }
263*ec63e07aSXin Li   return true;
264*ec63e07aSXin Li }
265*ec63e07aSXin Li 
SendString(const std::string & v)266*ec63e07aSXin Li bool Comms::SendString(const std::string& v) {
267*ec63e07aSXin Li   return SendTLV(kTagString, v.length(), v.c_str());
268*ec63e07aSXin Li }
269*ec63e07aSXin Li 
RecvBytes(std::vector<uint8_t> * buffer)270*ec63e07aSXin Li bool Comms::RecvBytes(std::vector<uint8_t>* buffer) {
271*ec63e07aSXin Li   uint32_t tag;
272*ec63e07aSXin Li   if (!RecvTLV(&tag, buffer)) {
273*ec63e07aSXin Li     return false;
274*ec63e07aSXin Li   }
275*ec63e07aSXin Li   if (tag != kTagBytes) {
276*ec63e07aSXin Li     buffer->clear();
277*ec63e07aSXin Li     SAPI_RAW_LOG(ERROR, "Expected (kTagBytes == 0x%x), got: 0x%u", kTagBytes,
278*ec63e07aSXin Li                  tag);
279*ec63e07aSXin Li     return false;
280*ec63e07aSXin Li   }
281*ec63e07aSXin Li   return true;
282*ec63e07aSXin Li }
283*ec63e07aSXin Li 
SendBytes(const uint8_t * v,size_t len)284*ec63e07aSXin Li bool Comms::SendBytes(const uint8_t* v, size_t len) {
285*ec63e07aSXin Li   return SendTLV(kTagBytes, len, v);
286*ec63e07aSXin Li }
287*ec63e07aSXin Li 
SendBytes(const std::vector<uint8_t> & buffer)288*ec63e07aSXin Li bool Comms::SendBytes(const std::vector<uint8_t>& buffer) {
289*ec63e07aSXin Li   return SendBytes(buffer.data(), buffer.size());
290*ec63e07aSXin Li }
291*ec63e07aSXin Li 
RecvCreds(pid_t * pid,uid_t * uid,gid_t * gid)292*ec63e07aSXin Li bool Comms::RecvCreds(pid_t* pid, uid_t* uid, gid_t* gid) {
293*ec63e07aSXin Li   ucred uc;
294*ec63e07aSXin Li   socklen_t sls = sizeof(uc);
295*ec63e07aSXin Li   int rc;
296*ec63e07aSXin Li   {
297*ec63e07aSXin Li     // Not completely sure if getsockopt() can block on SO_PEERCRED, but let's
298*ec63e07aSXin Li     // play it safe.
299*ec63e07aSXin Li     PotentiallyBlockingRegion region;
300*ec63e07aSXin Li     rc = getsockopt(GetConnectionFD(), SOL_SOCKET, SO_PEERCRED, &uc, &sls);
301*ec63e07aSXin Li   }
302*ec63e07aSXin Li   if (rc == -1) {
303*ec63e07aSXin Li     SAPI_RAW_PLOG(ERROR, "getsockopt(SO_PEERCRED)");
304*ec63e07aSXin Li     return false;
305*ec63e07aSXin Li   }
306*ec63e07aSXin Li   *pid = uc.pid;
307*ec63e07aSXin Li   *uid = uc.uid;
308*ec63e07aSXin Li   *gid = uc.gid;
309*ec63e07aSXin Li 
310*ec63e07aSXin Li   SAPI_RAW_VLOG(2, "Received credentials from PID/UID/GID: %d/%u/%u", *pid,
311*ec63e07aSXin Li                 *uid, *gid);
312*ec63e07aSXin Li   return true;
313*ec63e07aSXin Li }
314*ec63e07aSXin Li 
RecvFD(int * fd)315*ec63e07aSXin Li bool Comms::RecvFD(int* fd) {
316*ec63e07aSXin Li   char fd_msg[8192];
317*ec63e07aSXin Li   cmsghdr* cmsg = reinterpret_cast<cmsghdr*>(fd_msg);
318*ec63e07aSXin Li 
319*ec63e07aSXin Li   InternalTLV tlv;
320*ec63e07aSXin Li   iovec iov = {.iov_base = &tlv, .iov_len = sizeof(tlv)};
321*ec63e07aSXin Li 
322*ec63e07aSXin Li   msghdr msg = {
323*ec63e07aSXin Li       .msg_name = nullptr,
324*ec63e07aSXin Li       .msg_namelen = 0,
325*ec63e07aSXin Li       .msg_iov = &iov,
326*ec63e07aSXin Li       .msg_iovlen = 1,
327*ec63e07aSXin Li       .msg_control = cmsg,
328*ec63e07aSXin Li       .msg_controllen = sizeof(fd_msg),
329*ec63e07aSXin Li       .msg_flags = 0,
330*ec63e07aSXin Li   };
331*ec63e07aSXin Li 
332*ec63e07aSXin Li   const auto op = [&msg](int fd) -> ssize_t {
333*ec63e07aSXin Li     PotentiallyBlockingRegion region;
334*ec63e07aSXin Li     // Use syscall, otherwise we would need to allow socketcall() on PPC.
335*ec63e07aSXin Li     return TEMP_FAILURE_RETRY(
336*ec63e07aSXin Li         util::Syscall(__NR_recvmsg, fd, reinterpret_cast<uintptr_t>(&msg), 0));
337*ec63e07aSXin Li   };
338*ec63e07aSXin Li   ssize_t len;
339*ec63e07aSXin Li   len = op(connection_fd_.get());
340*ec63e07aSXin Li   if (len < 0) {
341*ec63e07aSXin Li     if (IsFatalError(errno)) {
342*ec63e07aSXin Li       Terminate();
343*ec63e07aSXin Li     }
344*ec63e07aSXin Li     SAPI_RAW_PLOG(ERROR, "recvmsg(SCM_RIGHTS)");
345*ec63e07aSXin Li     return false;
346*ec63e07aSXin Li   }
347*ec63e07aSXin Li   if (len == 0) {
348*ec63e07aSXin Li     Terminate();
349*ec63e07aSXin Li     SAPI_RAW_VLOG(1, "RecvFD: end-point terminated the connection.");
350*ec63e07aSXin Li     return false;
351*ec63e07aSXin Li   }
352*ec63e07aSXin Li   if (len != sizeof(tlv)) {
353*ec63e07aSXin Li     SAPI_RAW_LOG(ERROR, "Expected size: %zu, got %zd", sizeof(tlv), len);
354*ec63e07aSXin Li     return false;
355*ec63e07aSXin Li   }
356*ec63e07aSXin Li   // At this point, we know that op() has been called successfully, therefore
357*ec63e07aSXin Li   // msg struct has been fully populated. Apparently MSAN is not aware of
358*ec63e07aSXin Li   // syscall(__NR_recvmsg) semantics so we need to suppress the error (here and
359*ec63e07aSXin Li   // everywhere below).
360*ec63e07aSXin Li   ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&tlv, sizeof(tlv));
361*ec63e07aSXin Li 
362*ec63e07aSXin Li   if (tlv.tag != kTagFd) {
363*ec63e07aSXin Li     SAPI_RAW_LOG(ERROR, "Expected (kTagFD: 0x%x), got: 0x%x", kTagFd, tlv.tag);
364*ec63e07aSXin Li     return false;
365*ec63e07aSXin Li   }
366*ec63e07aSXin Li 
367*ec63e07aSXin Li   cmsg = CMSG_FIRSTHDR(&msg);
368*ec63e07aSXin Li   ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(cmsg, sizeof(cmsghdr));
369*ec63e07aSXin Li   while (cmsg) {
370*ec63e07aSXin Li     if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
371*ec63e07aSXin Li       if (cmsg->cmsg_len != CMSG_LEN(sizeof(int))) {
372*ec63e07aSXin Li         SAPI_RAW_VLOG(1,
373*ec63e07aSXin Li                       "recvmsg(SCM_RIGHTS): cmsg->cmsg_len != "
374*ec63e07aSXin Li                       "CMSG_LEN(sizeof(int)), skipping");
375*ec63e07aSXin Li         continue;
376*ec63e07aSXin Li       }
377*ec63e07aSXin Li       int* fds = reinterpret_cast<int*>(CMSG_DATA(cmsg));
378*ec63e07aSXin Li       *fd = fds[0];
379*ec63e07aSXin Li       ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(fd, sizeof(int));
380*ec63e07aSXin Li       return true;
381*ec63e07aSXin Li     }
382*ec63e07aSXin Li     cmsg = CMSG_NXTHDR(&msg, cmsg);
383*ec63e07aSXin Li   }
384*ec63e07aSXin Li   SAPI_RAW_LOG(ERROR,
385*ec63e07aSXin Li                "Haven't received the SCM_RIGHTS message, process is probably "
386*ec63e07aSXin Li                "out of free file descriptors");
387*ec63e07aSXin Li   return false;
388*ec63e07aSXin Li }
389*ec63e07aSXin Li 
SendFD(int fd)390*ec63e07aSXin Li bool Comms::SendFD(int fd) {
391*ec63e07aSXin Li   char fd_msg[CMSG_SPACE(sizeof(int))] = {0};
392*ec63e07aSXin Li   cmsghdr* cmsg = reinterpret_cast<cmsghdr*>(fd_msg);
393*ec63e07aSXin Li   cmsg->cmsg_level = SOL_SOCKET;
394*ec63e07aSXin Li   cmsg->cmsg_type = SCM_RIGHTS;
395*ec63e07aSXin Li   cmsg->cmsg_len = CMSG_LEN(sizeof(int));
396*ec63e07aSXin Li 
397*ec63e07aSXin Li   int* fds = reinterpret_cast<int*>(CMSG_DATA(cmsg));
398*ec63e07aSXin Li   fds[0] = fd;
399*ec63e07aSXin Li 
400*ec63e07aSXin Li   InternalTLV tlv = {kTagFd, 0};
401*ec63e07aSXin Li 
402*ec63e07aSXin Li   iovec iov;
403*ec63e07aSXin Li   iov.iov_base = &tlv;
404*ec63e07aSXin Li   iov.iov_len = sizeof(tlv);
405*ec63e07aSXin Li 
406*ec63e07aSXin Li   msghdr msg;
407*ec63e07aSXin Li   msg.msg_name = nullptr;
408*ec63e07aSXin Li   msg.msg_namelen = 0;
409*ec63e07aSXin Li   msg.msg_iov = &iov;
410*ec63e07aSXin Li   msg.msg_iovlen = 1;
411*ec63e07aSXin Li   msg.msg_control = cmsg;
412*ec63e07aSXin Li   msg.msg_controllen = sizeof(fd_msg);
413*ec63e07aSXin Li   msg.msg_flags = 0;
414*ec63e07aSXin Li 
415*ec63e07aSXin Li   const auto op = [&msg](int fd) -> ssize_t {
416*ec63e07aSXin Li     PotentiallyBlockingRegion region;
417*ec63e07aSXin Li     // Use syscall, otherwise we would need to whitelist socketcall() on PPC.
418*ec63e07aSXin Li     return TEMP_FAILURE_RETRY(
419*ec63e07aSXin Li         util::Syscall(__NR_sendmsg, fd, reinterpret_cast<uintptr_t>(&msg), 0));
420*ec63e07aSXin Li   };
421*ec63e07aSXin Li   ssize_t len;
422*ec63e07aSXin Li   len = op(connection_fd_.get());
423*ec63e07aSXin Li   if (len == -1 && errno == EPIPE) {
424*ec63e07aSXin Li     Terminate();
425*ec63e07aSXin Li     SAPI_RAW_LOG(ERROR, "sendmsg(SCM_RIGHTS): Peer disconnected");
426*ec63e07aSXin Li     return false;
427*ec63e07aSXin Li   }
428*ec63e07aSXin Li   if (len < 0) {
429*ec63e07aSXin Li     if (IsFatalError(errno)) {
430*ec63e07aSXin Li       Terminate();
431*ec63e07aSXin Li     }
432*ec63e07aSXin Li     SAPI_RAW_PLOG(ERROR, "sendmsg(SCM_RIGHTS)");
433*ec63e07aSXin Li     return false;
434*ec63e07aSXin Li   }
435*ec63e07aSXin Li   if (len != sizeof(tlv)) {
436*ec63e07aSXin Li     SAPI_RAW_LOG(ERROR, "Expected to send %zu bytes, sent %zd", sizeof(tlv),
437*ec63e07aSXin Li                  len);
438*ec63e07aSXin Li     return false;
439*ec63e07aSXin Li   }
440*ec63e07aSXin Li   return true;
441*ec63e07aSXin Li }
442*ec63e07aSXin Li 
RecvProtoBuf(google::protobuf::MessageLite * message)443*ec63e07aSXin Li bool Comms::RecvProtoBuf(google::protobuf::MessageLite* message) {
444*ec63e07aSXin Li   uint32_t tag;
445*ec63e07aSXin Li   std::vector<uint8_t> bytes;
446*ec63e07aSXin Li   if (!RecvTLV(&tag, &bytes)) {
447*ec63e07aSXin Li     if (IsConnected()) {
448*ec63e07aSXin Li       SAPI_RAW_PLOG(ERROR, "RecvProtoBuf failed for (%s)", name_);
449*ec63e07aSXin Li     } else {
450*ec63e07aSXin Li       Terminate();
451*ec63e07aSXin Li       SAPI_RAW_VLOG(2, "Connection terminated (%s)", name_.c_str());
452*ec63e07aSXin Li     }
453*ec63e07aSXin Li     return false;
454*ec63e07aSXin Li   }
455*ec63e07aSXin Li 
456*ec63e07aSXin Li   if (tag != kTagProto2) {
457*ec63e07aSXin Li     SAPI_RAW_LOG(ERROR, "Expected tag: 0x%x, got: 0x%u", kTagProto2, tag);
458*ec63e07aSXin Li     return false;
459*ec63e07aSXin Li   }
460*ec63e07aSXin Li   return message->ParseFromArray(bytes.data(), bytes.size());
461*ec63e07aSXin Li }
462*ec63e07aSXin Li 
SendProtoBuf(const google::protobuf::MessageLite & message)463*ec63e07aSXin Li bool Comms::SendProtoBuf(const google::protobuf::MessageLite& message) {
464*ec63e07aSXin Li   std::string str;
465*ec63e07aSXin Li   if (!message.SerializeToString(&str)) {
466*ec63e07aSXin Li     SAPI_RAW_LOG(ERROR, "Couldn't serialize the ProtoBuf");
467*ec63e07aSXin Li     return false;
468*ec63e07aSXin Li   }
469*ec63e07aSXin Li 
470*ec63e07aSXin Li   return SendTLV(kTagProto2, str.length(),
471*ec63e07aSXin Li                  reinterpret_cast<const uint8_t*>(str.data()));
472*ec63e07aSXin Li }
473*ec63e07aSXin Li 
474*ec63e07aSXin Li // *****************************************************************************
475*ec63e07aSXin Li // All methods below are private, for internal use only.
476*ec63e07aSXin Li // *****************************************************************************
477*ec63e07aSXin Li 
Send(const void * data,size_t len)478*ec63e07aSXin Li bool Comms::Send(const void* data, size_t len) {
479*ec63e07aSXin Li   size_t total_sent = 0;
480*ec63e07aSXin Li   const char* bytes = reinterpret_cast<const char*>(data);
481*ec63e07aSXin Li   const auto op = [bytes, len, &total_sent](int fd) -> ssize_t {
482*ec63e07aSXin Li     PotentiallyBlockingRegion region;
483*ec63e07aSXin Li     return TEMP_FAILURE_RETRY(write(fd, &bytes[total_sent], len - total_sent));
484*ec63e07aSXin Li   };
485*ec63e07aSXin Li   while (total_sent < len) {
486*ec63e07aSXin Li     ssize_t s;
487*ec63e07aSXin Li       s = op(connection_fd_.get());
488*ec63e07aSXin Li     if (s == -1 && errno == EPIPE) {
489*ec63e07aSXin Li       Terminate();
490*ec63e07aSXin Li       // We do not expect the other end to disappear.
491*ec63e07aSXin Li       SAPI_RAW_LOG(ERROR, "Send: end-point terminated the connection");
492*ec63e07aSXin Li       return false;
493*ec63e07aSXin Li     }
494*ec63e07aSXin Li     if (s == -1) {
495*ec63e07aSXin Li       SAPI_RAW_PLOG(ERROR, "write");
496*ec63e07aSXin Li       if (IsFatalError(errno)) {
497*ec63e07aSXin Li         Terminate();
498*ec63e07aSXin Li       }
499*ec63e07aSXin Li       return false;
500*ec63e07aSXin Li     }
501*ec63e07aSXin Li     if (s == 0) {
502*ec63e07aSXin Li       SAPI_RAW_LOG(ERROR,
503*ec63e07aSXin Li                    "Couldn't write more bytes, wrote: %zu, requested: %zu",
504*ec63e07aSXin Li                    total_sent, len);
505*ec63e07aSXin Li       return false;
506*ec63e07aSXin Li     }
507*ec63e07aSXin Li     total_sent += s;
508*ec63e07aSXin Li   }
509*ec63e07aSXin Li   return true;
510*ec63e07aSXin Li }
511*ec63e07aSXin Li 
Recv(void * data,size_t len)512*ec63e07aSXin Li bool Comms::Recv(void* data, size_t len) {
513*ec63e07aSXin Li   size_t total_recv = 0;
514*ec63e07aSXin Li   char* bytes = reinterpret_cast<char*>(data);
515*ec63e07aSXin Li   const auto op = [bytes, len, &total_recv](int fd) -> ssize_t {
516*ec63e07aSXin Li     PotentiallyBlockingRegion region;
517*ec63e07aSXin Li     return TEMP_FAILURE_RETRY(read(fd, &bytes[total_recv], len - total_recv));
518*ec63e07aSXin Li   };
519*ec63e07aSXin Li   while (total_recv < len) {
520*ec63e07aSXin Li     ssize_t s;
521*ec63e07aSXin Li       s = op(connection_fd_.get());
522*ec63e07aSXin Li     if (s == -1) {
523*ec63e07aSXin Li       SAPI_RAW_PLOG(ERROR, "read");
524*ec63e07aSXin Li       if (IsFatalError(errno)) {
525*ec63e07aSXin Li         Terminate();
526*ec63e07aSXin Li       }
527*ec63e07aSXin Li       return false;
528*ec63e07aSXin Li     }
529*ec63e07aSXin Li     if (s == 0) {
530*ec63e07aSXin Li       Terminate();
531*ec63e07aSXin Li       // The other end might have finished its work.
532*ec63e07aSXin Li       SAPI_RAW_VLOG(2, "Recv: end-point terminated the connection.");
533*ec63e07aSXin Li       return false;
534*ec63e07aSXin Li     }
535*ec63e07aSXin Li     total_recv += s;
536*ec63e07aSXin Li   }
537*ec63e07aSXin Li   return true;
538*ec63e07aSXin Li }
539*ec63e07aSXin Li 
540*ec63e07aSXin Li // Internal helper method (low level).
RecvTL(uint32_t * tag,size_t * length)541*ec63e07aSXin Li bool Comms::RecvTL(uint32_t* tag, size_t* length) {
542*ec63e07aSXin Li   InternalTLV tl;
543*ec63e07aSXin Li   if (!Recv(reinterpret_cast<uint8_t*>(&tl), sizeof(tl))) {
544*ec63e07aSXin Li     SAPI_RAW_VLOG(2, "RecvTL: Can't read tag and length");
545*ec63e07aSXin Li     return false;
546*ec63e07aSXin Li   }
547*ec63e07aSXin Li   *tag = tl.tag;
548*ec63e07aSXin Li   *length = tl.len;
549*ec63e07aSXin Li   if (*length > GetMaxMsgSize()) {
550*ec63e07aSXin Li     SAPI_RAW_LOG(ERROR, "Maximum TLV message size exceeded: (%zu > %zd)",
551*ec63e07aSXin Li                  *length, GetMaxMsgSize());
552*ec63e07aSXin Li     return false;
553*ec63e07aSXin Li   }
554*ec63e07aSXin Li   if (*length > kWarnMsgSize) {
555*ec63e07aSXin Li     static std::atomic<int> times_warned = 0;
556*ec63e07aSXin Li     if (times_warned.fetch_add(1, std::memory_order_relaxed) < 10) {
557*ec63e07aSXin Li       SAPI_RAW_LOG(
558*ec63e07aSXin Li           WARNING,
559*ec63e07aSXin Li           "TLV message of size: %zu detected. Please consider switching to "
560*ec63e07aSXin Li           "Buffer API instead.",
561*ec63e07aSXin Li           *length);
562*ec63e07aSXin Li     }
563*ec63e07aSXin Li   }
564*ec63e07aSXin Li   return true;
565*ec63e07aSXin Li }
566*ec63e07aSXin Li 
RecvTLV(uint32_t * tag,std::vector<uint8_t> * value)567*ec63e07aSXin Li bool Comms::RecvTLV(uint32_t* tag, std::vector<uint8_t>* value) {
568*ec63e07aSXin Li   return RecvTLVGeneric(tag, value);
569*ec63e07aSXin Li }
570*ec63e07aSXin Li 
RecvTLV(uint32_t * tag,std::string * value)571*ec63e07aSXin Li bool Comms::RecvTLV(uint32_t* tag, std::string* value) {
572*ec63e07aSXin Li   return RecvTLVGeneric(tag, value);
573*ec63e07aSXin Li }
574*ec63e07aSXin Li 
575*ec63e07aSXin Li template <typename T>
RecvTLVGeneric(uint32_t * tag,T * value)576*ec63e07aSXin Li bool Comms::RecvTLVGeneric(uint32_t* tag, T* value) {
577*ec63e07aSXin Li   size_t length;
578*ec63e07aSXin Li   if (!RecvTL(tag, &length)) {
579*ec63e07aSXin Li     return false;
580*ec63e07aSXin Li   }
581*ec63e07aSXin Li 
582*ec63e07aSXin Li   value->resize(length);
583*ec63e07aSXin Li   return length == 0 || Recv(reinterpret_cast<uint8_t*>(value->data()), length);
584*ec63e07aSXin Li }
585*ec63e07aSXin Li 
RecvTLV(uint32_t * tag,size_t * length,void * buffer,size_t buffer_size)586*ec63e07aSXin Li bool Comms::RecvTLV(uint32_t* tag, size_t* length, void* buffer,
587*ec63e07aSXin Li                     size_t buffer_size) {
588*ec63e07aSXin Li   if (!RecvTL(tag, length)) {
589*ec63e07aSXin Li     return false;
590*ec63e07aSXin Li   }
591*ec63e07aSXin Li 
592*ec63e07aSXin Li   if (*length == 0) {
593*ec63e07aSXin Li     return true;
594*ec63e07aSXin Li   }
595*ec63e07aSXin Li 
596*ec63e07aSXin Li   if (*length > buffer_size) {
597*ec63e07aSXin Li     SAPI_RAW_LOG(ERROR, "Buffer size too small (0x%zx > 0x%zx)", *length,
598*ec63e07aSXin Li                  buffer_size);
599*ec63e07aSXin Li     return false;
600*ec63e07aSXin Li   }
601*ec63e07aSXin Li 
602*ec63e07aSXin Li   return Recv(reinterpret_cast<uint8_t*>(buffer), *length);
603*ec63e07aSXin Li }
604*ec63e07aSXin Li 
RecvInt(void * buffer,size_t len,uint32_t tag)605*ec63e07aSXin Li bool Comms::RecvInt(void* buffer, size_t len, uint32_t tag) {
606*ec63e07aSXin Li   uint32_t received_tag;
607*ec63e07aSXin Li   size_t received_length;
608*ec63e07aSXin Li   if (!RecvTLV(&received_tag, &received_length, buffer, len)) {
609*ec63e07aSXin Li     return false;
610*ec63e07aSXin Li   }
611*ec63e07aSXin Li 
612*ec63e07aSXin Li   if (received_tag != tag) {
613*ec63e07aSXin Li     SAPI_RAW_LOG(ERROR, "Expected tag: 0x%08x, got: 0x%x", tag, received_tag);
614*ec63e07aSXin Li     return false;
615*ec63e07aSXin Li   }
616*ec63e07aSXin Li   if (received_length != len) {
617*ec63e07aSXin Li     SAPI_RAW_LOG(ERROR, "Expected length: %zu, got: %zu", len, received_length);
618*ec63e07aSXin Li     return false;
619*ec63e07aSXin Li   }
620*ec63e07aSXin Li   return true;
621*ec63e07aSXin Li }
622*ec63e07aSXin Li 
RecvStatus(absl::Status * status)623*ec63e07aSXin Li bool Comms::RecvStatus(absl::Status* status) {
624*ec63e07aSXin Li   sapi::StatusProto proto;
625*ec63e07aSXin Li   if (!RecvProtoBuf(&proto)) {
626*ec63e07aSXin Li     return false;
627*ec63e07aSXin Li   }
628*ec63e07aSXin Li   *status = sapi::MakeStatusFromProto(proto);
629*ec63e07aSXin Li   return true;
630*ec63e07aSXin Li }
631*ec63e07aSXin Li 
SendStatus(const absl::Status & status)632*ec63e07aSXin Li bool Comms::SendStatus(const absl::Status& status) {
633*ec63e07aSXin Li   sapi::StatusProto proto;
634*ec63e07aSXin Li   sapi::SaveStatusToProto(status, &proto);
635*ec63e07aSXin Li   return SendProtoBuf(proto);
636*ec63e07aSXin Li }
637*ec63e07aSXin Li 
MoveToAnotherFd()638*ec63e07aSXin Li void Comms::MoveToAnotherFd() {
639*ec63e07aSXin Li   SAPI_RAW_CHECK(connection_fd_.get() != -1,
640*ec63e07aSXin Li                  "Cannot move comms fd as it's not connected");
641*ec63e07aSXin Li   FDCloser new_fd(dup(connection_fd_.get()));
642*ec63e07aSXin Li   SAPI_RAW_CHECK(new_fd.get() != -1, "Failed to move comms to another fd");
643*ec63e07aSXin Li   connection_fd_.Swap(new_fd);
644*ec63e07aSXin Li }
645*ec63e07aSXin Li 
646*ec63e07aSXin Li }  // namespace sandbox2
647