xref: /aosp_15_r20/external/virglrenderer/server/render_socket.c (revision bbecb9d118dfdb95f99bd754f8fa9be01f189df3)
1 /*
2  * Copyright 2021 Google LLC
3  * SPDX-License-Identifier: MIT
4  */
5 
6 #include "render_socket.h"
7 
8 #include <errno.h>
9 #include <sys/socket.h>
10 #include <sys/types.h>
11 #include <unistd.h>
12 
/* Upper bound on the number of file descriptors carried by a single message.
 * It sizes the on-stack control-message buffers used when sending and
 * receiving, and callers' fd counts are asserted against it.
 */
#define RENDER_SOCKET_MAX_FD_COUNT 8
14 
15 /* The socket pair between the server process and the client process is set up
16  * by the client process (or yet another process).  Because render_server_run
17  * does not poll yet, the fd is expected to be blocking.
18  *
19  * We also expect the fd to be always valid.  If the client process dies, the
20  * fd becomes invalid and is considered a fatal error.
21  *
22  * There is also a socket pair between each context worker and the client
23  * process.  The pair is set up by render_socket_pair here.
24  *
25  * The fd is also expected to be blocking.  When the client process closes its
26  * end of the socket pair, the context worker terminates.
27  */
/* Create a connected AF_UNIX SOCK_SEQPACKET socket pair with close-on-exec
 * set on both ends.
 *
 * On success the two descriptors are stored in out_fds[0] and out_fds[1] and
 * true is returned.  On failure the reason (including errno) is logged and
 * false is returned.
 */
bool
render_socket_pair(int out_fds[static 2])
{
   int ret = socketpair(AF_UNIX, SOCK_SEQPACKET | SOCK_CLOEXEC, 0, out_fds);
   if (ret) {
      /* include errno, like the recvmsg/sendmsg failure paths do */
      render_log("failed to create socket pair: %s", strerror(errno));
      return false;
   }

   return true;
}
39 
/* Report whether fd refers to a socket of type SOCK_SEQPACKET.
 *
 * Returns false when SO_TYPE cannot be queried (for example, fd is not a
 * socket) or when the socket has any other type.
 */
bool
render_socket_is_seqpacket(int fd)
{
   int sock_type;
   socklen_t sock_type_len = sizeof(sock_type);

   const bool queried =
      getsockopt(fd, SOL_SOCKET, SO_TYPE, &sock_type, &sock_type_len) == 0;

   /* short-circuit: sock_type is only read when getsockopt succeeded */
   return queried && sock_type == SOCK_SEQPACKET;
}
49 
50 void
render_socket_init(struct render_socket * socket,int fd)51 render_socket_init(struct render_socket *socket, int fd)
52 {
53    assert(fd >= 0);
54    *socket = (struct render_socket){
55       .fd = fd,
56    };
57 }
58 
59 void
render_socket_fini(struct render_socket * socket)60 render_socket_fini(struct render_socket *socket)
61 {
62    close(socket->fd);
63 }
64 
/* Locate the SCM_RIGHTS payload of a received message.
 *
 * Only the first control message is examined.  If there is none, if it is
 * not SOL_SOCKET/SCM_RIGHTS, or if its cmsg_len is shorter than a bare
 * header, *out_count is set to 0 and NULL is returned.  Otherwise *out_count
 * receives the number of ints in the payload and a pointer to them is
 * returned.
 */
static const int *
get_received_fds(const struct msghdr *msg, int *out_count)
{
   const struct cmsghdr *hdr = CMSG_FIRSTHDR(msg);
   const bool has_rights = hdr && hdr->cmsg_level == SOL_SOCKET &&
                           hdr->cmsg_type == SCM_RIGHTS &&
                           hdr->cmsg_len >= CMSG_LEN(0);

   if (unlikely(!has_rights)) {
      *out_count = 0;
      return NULL;
   }

   const size_t payload_size = hdr->cmsg_len - CMSG_LEN(0);
   *out_count = payload_size / sizeof(int);
   return (const int *)CMSG_DATA(hdr);
}
78 
79 static bool
render_socket_recvmsg(struct render_socket * socket,struct msghdr * msg,size_t * out_size)80 render_socket_recvmsg(struct render_socket *socket, struct msghdr *msg, size_t *out_size)
81 {
82    do {
83       const ssize_t s = recvmsg(socket->fd, msg, MSG_CMSG_CLOEXEC);
84       if (unlikely(s <= 0)) {
85          if (!s)
86             return false;
87 
88          if (errno == EAGAIN || errno == EINTR)
89             continue;
90 
91          render_log("failed to receive message: %s", strerror(errno));
92          return false;
93       }
94 
95       if (unlikely(msg->msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
96          render_log("failed to receive message: truncated");
97 
98          int fd_count;
99          const int *fds = get_received_fds(msg, &fd_count);
100          for (int i = 0; i < fd_count; i++)
101             close(fds[i]);
102 
103          return false;
104       }
105 
106       *out_size = s;
107       return true;
108    } while (true);
109 }
110 
111 static bool
render_socket_receive_request_internal(struct render_socket * socket,void * data,size_t max_size,size_t * out_size,int * fds,int max_fd_count,int * out_fd_count)112 render_socket_receive_request_internal(struct render_socket *socket,
113                                        void *data,
114                                        size_t max_size,
115                                        size_t *out_size,
116                                        int *fds,
117                                        int max_fd_count,
118                                        int *out_fd_count)
119 {
120    assert(data && max_size);
121    struct msghdr msg = {
122       .msg_iov =
123          &(struct iovec){
124             .iov_base = data,
125             .iov_len = max_size,
126          },
127       .msg_iovlen = 1,
128    };
129 
130    char cmsg_buf[CMSG_SPACE(sizeof(*fds) * RENDER_SOCKET_MAX_FD_COUNT)];
131    if (max_fd_count) {
132       assert(fds && max_fd_count <= RENDER_SOCKET_MAX_FD_COUNT);
133       msg.msg_control = cmsg_buf;
134       msg.msg_controllen = CMSG_SPACE(sizeof(*fds) * max_fd_count);
135 
136       struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
137       memset(cmsg, 0, sizeof(*cmsg));
138    }
139 
140    if (!render_socket_recvmsg(socket, &msg, out_size))
141       return false;
142 
143    if (max_fd_count) {
144       int received_fd_count;
145       const int *received_fds = get_received_fds(&msg, &received_fd_count);
146       assert(received_fd_count <= max_fd_count);
147 
148       memcpy(fds, received_fds, sizeof(*fds) * received_fd_count);
149       *out_fd_count = received_fd_count;
150    } else if (out_fd_count) {
151       *out_fd_count = 0;
152    }
153 
154    return true;
155 }
156 
/* Receive one request of at most max_size bytes into data and store its size
 * in *out_size.  No control buffer is offered, so no file descriptors can be
 * received.
 */
bool
render_socket_receive_request(struct render_socket *socket,
                              void *data,
                              size_t max_size,
                              size_t *out_size)
{
   const int no_fd_capacity = 0;
   return render_socket_receive_request_internal(socket, data, max_size, out_size,
                                                 NULL, no_fd_capacity, NULL);
}
166 
/* Receive one request of at most max_size bytes into data, together with up
 * to max_fd_count file descriptors written to fds.  The payload size goes to
 * *out_size and the descriptor count to *out_fd_count.
 */
bool
render_socket_receive_request_with_fds(struct render_socket *socket,
                                       void *data,
                                       size_t max_size,
                                       size_t *out_size,
                                       int *fds,
                                       int max_fd_count,
                                       int *out_fd_count)
{
   const bool ok = render_socket_receive_request_internal(
      socket, data, max_size, out_size, fds, max_fd_count, out_fd_count);
   return ok;
}
179 
/* Receive exactly size bytes into data as a single message.
 *
 * Returns false if the receive fails.  It also returns false, after logging,
 * if the message length differs from size.
 */
bool
render_socket_receive_data(struct render_socket *socket, void *data, size_t size)
{
   size_t actual_size;
   const bool received = render_socket_receive_request(socket, data, size, &actual_size);
   if (!received)
      return false;

   if (actual_size == size)
      return true;

   render_log("failed to receive data: expected %zu but received %zu", size,
              actual_size);
   return false;
}
195 
196 static bool
render_socket_sendmsg(struct render_socket * socket,const struct msghdr * msg)197 render_socket_sendmsg(struct render_socket *socket, const struct msghdr *msg)
198 {
199    do {
200       const ssize_t s = sendmsg(socket->fd, msg, MSG_NOSIGNAL);
201       if (unlikely(s < 0)) {
202          if (errno == EAGAIN || errno == EINTR)
203             continue;
204 
205          render_log("failed to send message: %s", strerror(errno));
206          return false;
207       }
208 
209       /* no partial send since the socket type is SOCK_SEQPACKET */
210       assert(msg->msg_iovlen == 1 && msg->msg_iov[0].iov_len == (size_t)s);
211       return true;
212    } while (true);
213 }
214 
215 static inline bool
render_socket_send_reply_internal(struct render_socket * socket,const void * data,size_t size,const int * fds,int fd_count)216 render_socket_send_reply_internal(struct render_socket *socket,
217                                   const void *data,
218                                   size_t size,
219                                   const int *fds,
220                                   int fd_count)
221 {
222    assert(data && size);
223    struct msghdr msg = {
224       .msg_iov =
225          &(struct iovec){
226             .iov_base = (void *)data,
227             .iov_len = size,
228          },
229       .msg_iovlen = 1,
230    };
231 
232    char cmsg_buf[CMSG_SPACE(sizeof(*fds) * RENDER_SOCKET_MAX_FD_COUNT)];
233    if (fd_count) {
234       assert(fds && fd_count <= RENDER_SOCKET_MAX_FD_COUNT);
235       msg.msg_control = cmsg_buf;
236       msg.msg_controllen = CMSG_SPACE(sizeof(*fds) * fd_count);
237 
238       struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
239       cmsg->cmsg_level = SOL_SOCKET;
240       cmsg->cmsg_type = SCM_RIGHTS;
241       cmsg->cmsg_len = CMSG_LEN(sizeof(*fds) * fd_count);
242       memcpy(CMSG_DATA(cmsg), fds, sizeof(*fds) * fd_count);
243    }
244 
245    return render_socket_sendmsg(socket, &msg);
246 }
247 
/* Send size bytes from data as one message carrying no file descriptors. */
bool
render_socket_send_reply(struct render_socket *socket, const void *data, size_t size)
{
   const int *const no_fds = NULL;
   const int no_fd_count = 0;
   return render_socket_send_reply_internal(socket, data, size, no_fds, no_fd_count);
}
253 
/* Send size bytes from data as one message, passing the fd_count descriptors
 * in fds alongside it.
 */
bool
render_socket_send_reply_with_fds(struct render_socket *socket,
                                  const void *data,
                                  size_t size,
                                  const int *fds,
                                  int fd_count)
{
   const bool ok =
      render_socket_send_reply_internal(socket, data, size, fds, fd_count);
   return ok;
}
263