1 /*
2 * Copyright 2021 Google LLC
3 * SPDX-License-Identifier: MIT
4 */
5
6 #include "render_socket.h"
7
8 #include <errno.h>
9 #include <sys/socket.h>
10 #include <sys/types.h>
11 #include <unistd.h>
12
/* Upper bound on the number of file descriptors a single message may carry;
 * it sizes the on-stack control-message buffers used for sending and
 * receiving.
 */
enum { RENDER_SOCKET_MAX_FD_COUNT = 8 };
14
15 /* The socket pair between the server process and the client process is set up
16 * by the client process (or yet another process). Because render_server_run
17 * does not poll yet, the fd is expected to be blocking.
18 *
19 * We also expect the fd to be always valid. If the client process dies, the
20 * fd becomes invalid and is considered a fatal error.
21 *
22 * There is also a socket pair between each context worker and the client
23 * process. The pair is set up by render_socket_pair here.
24 *
25 * The fd is also expected to be blocking. When the client process closes its
26 * end of the socket pair, the context worker terminates.
27 */
/* Creates a connected AF_UNIX SOCK_SEQPACKET socket pair with close-on-exec
 * set on both ends and stores the two fds in out_fds.  Returns false (after
 * logging the reason, including the errno description) when socketpair fails.
 */
bool
render_socket_pair(int out_fds[static 2])
{
   int ret = socketpair(AF_UNIX, SOCK_SEQPACKET | SOCK_CLOEXEC, 0, out_fds);
   if (ret) {
      /* report why, consistent with the other failure paths in this file */
      render_log("failed to create socket pair: %s", strerror(errno));
      return false;
   }

   return true;
}
39
/* Reports whether fd refers to a socket of type SOCK_SEQPACKET.  If the
 * SO_TYPE query fails (for instance because fd is not a socket), the answer
 * is false.
 */
bool
render_socket_is_seqpacket(int fd)
{
   int sock_type;
   socklen_t sock_type_len = sizeof(sock_type);

   const bool queried =
      getsockopt(fd, SOL_SOCKET, SO_TYPE, &sock_type, &sock_type_len) == 0;

   /* short-circuit: sock_type is only meaningful when the query succeeded */
   return queried && sock_type == SOCK_SEQPACKET;
}
49
50 void
render_socket_init(struct render_socket * socket,int fd)51 render_socket_init(struct render_socket *socket, int fd)
52 {
53 assert(fd >= 0);
54 *socket = (struct render_socket){
55 .fd = fd,
56 };
57 }
58
59 void
render_socket_fini(struct render_socket * socket)60 render_socket_fini(struct render_socket *socket)
61 {
62 close(socket->fd);
63 }
64
/* Locates the file descriptors carried by the first control message of msg.
 *
 * If that message is a SOL_SOCKET/SCM_RIGHTS message at least CMSG_LEN(0)
 * bytes long, returns a pointer to its int payload and stores the number of
 * ints in *out_count.  Otherwise (no control message, or a different one),
 * stores 0 in *out_count and returns NULL.
 */
static const int *
get_received_fds(const struct msghdr *msg, int *out_count)
{
   const struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg);
   const bool is_scm_rights = cmsg && cmsg->cmsg_level == SOL_SOCKET &&
                              cmsg->cmsg_type == SCM_RIGHTS &&
                              cmsg->cmsg_len >= CMSG_LEN(0);

   if (unlikely(!is_scm_rights)) {
      *out_count = 0;
      return NULL;
   }

   /* everything after the header is a packed array of ints */
   const size_t payload_len = cmsg->cmsg_len - CMSG_LEN(0);
   *out_count = payload_len / sizeof(int);
   return (const int *)CMSG_DATA(cmsg);
}
78
79 static bool
render_socket_recvmsg(struct render_socket * socket,struct msghdr * msg,size_t * out_size)80 render_socket_recvmsg(struct render_socket *socket, struct msghdr *msg, size_t *out_size)
81 {
82 do {
83 const ssize_t s = recvmsg(socket->fd, msg, MSG_CMSG_CLOEXEC);
84 if (unlikely(s <= 0)) {
85 if (!s)
86 return false;
87
88 if (errno == EAGAIN || errno == EINTR)
89 continue;
90
91 render_log("failed to receive message: %s", strerror(errno));
92 return false;
93 }
94
95 if (unlikely(msg->msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
96 render_log("failed to receive message: truncated");
97
98 int fd_count;
99 const int *fds = get_received_fds(msg, &fd_count);
100 for (int i = 0; i < fd_count; i++)
101 close(fds[i]);
102
103 return false;
104 }
105
106 *out_size = s;
107 return true;
108 } while (true);
109 }
110
111 static bool
render_socket_receive_request_internal(struct render_socket * socket,void * data,size_t max_size,size_t * out_size,int * fds,int max_fd_count,int * out_fd_count)112 render_socket_receive_request_internal(struct render_socket *socket,
113 void *data,
114 size_t max_size,
115 size_t *out_size,
116 int *fds,
117 int max_fd_count,
118 int *out_fd_count)
119 {
120 assert(data && max_size);
121 struct msghdr msg = {
122 .msg_iov =
123 &(struct iovec){
124 .iov_base = data,
125 .iov_len = max_size,
126 },
127 .msg_iovlen = 1,
128 };
129
130 char cmsg_buf[CMSG_SPACE(sizeof(*fds) * RENDER_SOCKET_MAX_FD_COUNT)];
131 if (max_fd_count) {
132 assert(fds && max_fd_count <= RENDER_SOCKET_MAX_FD_COUNT);
133 msg.msg_control = cmsg_buf;
134 msg.msg_controllen = CMSG_SPACE(sizeof(*fds) * max_fd_count);
135
136 struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
137 memset(cmsg, 0, sizeof(*cmsg));
138 }
139
140 if (!render_socket_recvmsg(socket, &msg, out_size))
141 return false;
142
143 if (max_fd_count) {
144 int received_fd_count;
145 const int *received_fds = get_received_fds(&msg, &received_fd_count);
146 assert(received_fd_count <= max_fd_count);
147
148 memcpy(fds, received_fds, sizeof(*fds) * received_fd_count);
149 *out_fd_count = received_fd_count;
150 } else if (out_fd_count) {
151 *out_fd_count = 0;
152 }
153
154 return true;
155 }
156
/* Receives one request of at most max_size bytes into data and stores its
 * size in *out_size.  This variant supplies no control buffer, so it does
 * not accept file descriptors.
 */
bool
render_socket_receive_request(struct render_socket *socket,
                              void *data,
                              size_t max_size,
                              size_t *out_size)
{
   int *const no_fds = NULL;
   const int no_fd_capacity = 0;

   return render_socket_receive_request_internal(socket, data, max_size, out_size,
                                                 no_fds, no_fd_capacity, NULL);
}
166
/* Receives one request of at most max_size bytes into data, plus up to
 * max_fd_count file descriptors stored in fds; the number received is
 * reported through *out_fd_count.
 */
bool
render_socket_receive_request_with_fds(struct render_socket *socket,
                                       void *data,
                                       size_t max_size,
                                       size_t *out_size,
                                       int *fds,
                                       int max_fd_count,
                                       int *out_fd_count)
{
   const bool ok = render_socket_receive_request_internal(
      socket, data, max_size, out_size, fds, max_fd_count, out_fd_count);
   return ok;
}
179
/* Receives a single message that must be exactly size bytes long into data.
 * A message of any other length is logged and reported as failure.
 */
bool
render_socket_receive_data(struct render_socket *socket, void *data, size_t size)
{
   size_t got;
   if (!render_socket_receive_request(socket, data, size, &got))
      return false;

   if (got == size)
      return true;

   render_log("failed to receive data: expected %zu but received %zu", size, got);
   return false;
}
195
196 static bool
render_socket_sendmsg(struct render_socket * socket,const struct msghdr * msg)197 render_socket_sendmsg(struct render_socket *socket, const struct msghdr *msg)
198 {
199 do {
200 const ssize_t s = sendmsg(socket->fd, msg, MSG_NOSIGNAL);
201 if (unlikely(s < 0)) {
202 if (errno == EAGAIN || errno == EINTR)
203 continue;
204
205 render_log("failed to send message: %s", strerror(errno));
206 return false;
207 }
208
209 /* no partial send since the socket type is SOCK_SEQPACKET */
210 assert(msg->msg_iovlen == 1 && msg->msg_iov[0].iov_len == (size_t)s);
211 return true;
212 } while (true);
213 }
214
215 static inline bool
render_socket_send_reply_internal(struct render_socket * socket,const void * data,size_t size,const int * fds,int fd_count)216 render_socket_send_reply_internal(struct render_socket *socket,
217 const void *data,
218 size_t size,
219 const int *fds,
220 int fd_count)
221 {
222 assert(data && size);
223 struct msghdr msg = {
224 .msg_iov =
225 &(struct iovec){
226 .iov_base = (void *)data,
227 .iov_len = size,
228 },
229 .msg_iovlen = 1,
230 };
231
232 char cmsg_buf[CMSG_SPACE(sizeof(*fds) * RENDER_SOCKET_MAX_FD_COUNT)];
233 if (fd_count) {
234 assert(fds && fd_count <= RENDER_SOCKET_MAX_FD_COUNT);
235 msg.msg_control = cmsg_buf;
236 msg.msg_controllen = CMSG_SPACE(sizeof(*fds) * fd_count);
237
238 struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
239 cmsg->cmsg_level = SOL_SOCKET;
240 cmsg->cmsg_type = SCM_RIGHTS;
241 cmsg->cmsg_len = CMSG_LEN(sizeof(*fds) * fd_count);
242 memcpy(CMSG_DATA(cmsg), fds, sizeof(*fds) * fd_count);
243 }
244
245 return render_socket_sendmsg(socket, &msg);
246 }
247
/* Sends size bytes from data as a single message with no file descriptors. */
bool
render_socket_send_reply(struct render_socket *socket, const void *data, size_t size)
{
   const int *const no_fds = NULL;
   const int no_fd_count = 0;

   return render_socket_send_reply_internal(socket, data, size, no_fds, no_fd_count);
}
253
/* Sends size bytes from data as a single message, attaching the fd_count
 * file descriptors in fds.
 */
bool
render_socket_send_reply_with_fds(struct render_socket *socket,
                                  const void *data,
                                  size_t size,
                                  const int *fds,
                                  int fd_count)
{
   const bool sent = render_socket_send_reply_internal(socket, data, size, fds, fd_count);
   return sent;
}
263