1 /* SPDX-License-Identifier: MIT */
2 /*
3 * Test MSG_WAITALL for recv/recvmsg and include normal sync versions just
4 * for comparison.
5 */
6 #include <errno.h>
7 #include <stdio.h>
8 #include <stdlib.h>
9 #include <string.h>
10 #include <unistd.h>
11 #include <fcntl.h>
12 #include <arpa/inet.h>
13 #include <sys/types.h>
14 #include <sys/socket.h>
15 #include <pthread.h>
16
17 #include "liburing.h"
18 #include "helpers.h"
19
/* Number of ints the sender transmits and the receiver expects back. */
enum { MAX_MSG = 128 };

/* First TCP port handed out; test() takes a fresh one (port++) per round. */
static int port = 31200;
23
/*
 * Per-round state shared between the sending thread (test()/do_send()) and
 * the receiving thread started on recv_fn().
 */
struct recv_data {
	/*
	 * One-shot "receiver is listening" signal: test() locks it before the
	 * receiver thread starts, the receiver unlocks it after listen()
	 * succeeds (or on its setup-failure paths), and do_send() blocks on it
	 * before calling connect().
	 * NOTE(review): POSIX leaves unlocking a default mutex from a thread
	 * that does not own it undefined; a semaphore or barrier would be the
	 * well-defined primitive for this handshake.
	 */
	pthread_mutex_t mutex;
	int use_recvmsg;	/* non-zero: recvmsg() instead of recv() */
	int use_sync;		/* non-zero: blocking syscalls instead of io_uring */
	int port;		/* TCP port the receiver binds and the sender targets */
};
30
get_conn_sock(struct recv_data * rd,int * sockout)31 static int get_conn_sock(struct recv_data *rd, int *sockout)
32 {
33 struct sockaddr_in saddr;
34 int sockfd, ret, val;
35
36 memset(&saddr, 0, sizeof(saddr));
37 saddr.sin_family = AF_INET;
38 saddr.sin_addr.s_addr = htonl(INADDR_ANY);
39 saddr.sin_port = htons(rd->port);
40
41 sockfd = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
42 if (sockfd < 0) {
43 perror("socket");
44 goto err;
45 }
46
47 val = 1;
48 setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val));
49 setsockopt(sockfd, SOL_SOCKET, SO_REUSEPORT, &val, sizeof(val));
50
51 ret = bind(sockfd, (struct sockaddr *)&saddr, sizeof(saddr));
52 if (ret < 0) {
53 perror("bind");
54 goto err;
55 }
56
57 ret = listen(sockfd, 16);
58 if (ret < 0) {
59 perror("listen");
60 goto err;
61 }
62
63 pthread_mutex_unlock(&rd->mutex);
64
65 ret = accept(sockfd, NULL, NULL);
66 if (ret < 0) {
67 perror("accept");
68 return -1;
69 }
70
71 *sockout = sockfd;
72 return ret;
73 err:
74 pthread_mutex_unlock(&rd->mutex);
75 return -1;
76 }
77
recv_prep(struct io_uring * ring,struct iovec * iov,int * sock,struct recv_data * rd)78 static int recv_prep(struct io_uring *ring, struct iovec *iov, int *sock,
79 struct recv_data *rd)
80 {
81 struct io_uring_sqe *sqe;
82 struct msghdr msg = { };
83 int sockfd, sockout = -1, ret;
84
85 sockfd = get_conn_sock(rd, &sockout);
86 if (sockfd < 0)
87 goto err;
88
89 sqe = io_uring_get_sqe(ring);
90 if (!rd->use_recvmsg) {
91 io_uring_prep_recv(sqe, sockfd, iov->iov_base, iov->iov_len,
92 MSG_WAITALL);
93 } else {
94 msg.msg_namelen = sizeof(struct sockaddr_in);
95 msg.msg_iov = iov;
96 msg.msg_iovlen = 1;
97 io_uring_prep_recvmsg(sqe, sockfd, &msg, MSG_WAITALL);
98 }
99
100 sqe->user_data = 2;
101
102 ret = io_uring_submit(ring);
103 if (ret <= 0) {
104 fprintf(stderr, "submit failed: %d\n", ret);
105 goto err;
106 }
107
108 *sock = sockfd;
109 return 0;
110 err:
111 if (sockout != -1) {
112 shutdown(sockout, SHUT_RDWR);
113 close(sockout);
114 }
115 if (sockfd != -1) {
116 shutdown(sockfd, SHUT_RDWR);
117 close(sockfd);
118 }
119 return 1;
120 }
121
do_recv(struct io_uring * ring)122 static int do_recv(struct io_uring *ring)
123 {
124 struct io_uring_cqe *cqe;
125 int ret;
126
127 ret = io_uring_wait_cqe(ring, &cqe);
128 if (ret) {
129 fprintf(stdout, "wait_cqe: %d\n", ret);
130 goto err;
131 }
132 if (cqe->res == -EINVAL) {
133 fprintf(stdout, "recv not supported, skipping\n");
134 return 0;
135 }
136 if (cqe->res < 0) {
137 fprintf(stderr, "failed cqe: %d\n", cqe->res);
138 goto err;
139 }
140 if (cqe->res != MAX_MSG * sizeof(int)) {
141 fprintf(stderr, "got wrong length: %d\n", cqe->res);
142 goto err;
143 }
144
145 io_uring_cqe_seen(ring, cqe);
146 return 0;
147 err:
148 return 1;
149 }
150
recv_sync(struct recv_data * rd)151 static int recv_sync(struct recv_data *rd)
152 {
153 int buf[MAX_MSG];
154 struct iovec iov = {
155 .iov_base = buf,
156 .iov_len = sizeof(buf),
157 };
158 int i, ret, sockfd, sockout = -1;
159
160 sockfd = get_conn_sock(rd, &sockout);
161
162 if (rd->use_recvmsg) {
163 struct msghdr msg = { };
164
165 msg.msg_namelen = sizeof(struct sockaddr_in);
166 msg.msg_iov = &iov;
167 msg.msg_iovlen = 1;
168 ret = recvmsg(sockfd, &msg, MSG_WAITALL);
169 } else {
170 ret = recv(sockfd, buf, sizeof(buf), MSG_WAITALL);
171 }
172
173 if (ret < 0) {
174 perror("receive");
175 goto err;
176 }
177
178 if (ret != sizeof(buf)) {
179 ret = -1;
180 goto err;
181 }
182
183 for (i = 0; i < MAX_MSG; i++) {
184 if (buf[i] != i)
185 goto err;
186 }
187 ret = 0;
188 err:
189 shutdown(sockout, SHUT_RDWR);
190 shutdown(sockfd, SHUT_RDWR);
191 close(sockout);
192 close(sockfd);
193 return ret;
194 }
195
recv_uring(struct recv_data * rd)196 static int recv_uring(struct recv_data *rd)
197 {
198 int buf[MAX_MSG];
199 struct iovec iov = {
200 .iov_base = buf,
201 .iov_len = sizeof(buf),
202 };
203 struct io_uring_params p = { };
204 struct io_uring ring;
205 int ret, sock = -1, sockout = -1;
206
207 ret = t_create_ring_params(1, &ring, &p);
208 if (ret == T_SETUP_SKIP) {
209 pthread_mutex_unlock(&rd->mutex);
210 ret = 0;
211 goto err;
212 } else if (ret < 0) {
213 pthread_mutex_unlock(&rd->mutex);
214 goto err;
215 }
216
217 sock = recv_prep(&ring, &iov, &sockout, rd);
218 if (ret) {
219 fprintf(stderr, "recv_prep failed: %d\n", ret);
220 goto err;
221 }
222 ret = do_recv(&ring);
223 if (!ret) {
224 int i;
225
226 for (i = 0; i < MAX_MSG; i++) {
227 if (buf[i] != i) {
228 fprintf(stderr, "found %d at %d\n", buf[i], i);
229 ret = 1;
230 break;
231 }
232 }
233 }
234
235 shutdown(sockout, SHUT_RDWR);
236 shutdown(sock, SHUT_RDWR);
237 close(sock);
238 close(sockout);
239 io_uring_queue_exit(&ring);
240 err:
241 if (sock != -1) {
242 shutdown(sock, SHUT_RDWR);
243 close(sock);
244 }
245 if (sockout != -1) {
246 shutdown(sockout, SHUT_RDWR);
247 close(sockout);
248 }
249 return ret;
250 }
251
recv_fn(void * data)252 static void *recv_fn(void *data)
253 {
254 struct recv_data *rd = data;
255
256 if (rd->use_sync)
257 return (void *) (uintptr_t) recv_sync(rd);
258
259 return (void *) (uintptr_t) recv_uring(rd);
260 }
261
do_send(struct recv_data * rd)262 static int do_send(struct recv_data *rd)
263 {
264 struct sockaddr_in saddr;
265 struct io_uring ring;
266 struct io_uring_cqe *cqe;
267 struct io_uring_sqe *sqe;
268 int sockfd, ret, i;
269 struct iovec iov;
270 int *buf;
271
272 ret = io_uring_queue_init(2, &ring, 0);
273 if (ret) {
274 fprintf(stderr, "queue init failed: %d\n", ret);
275 return 1;
276 }
277
278 buf = malloc(MAX_MSG * sizeof(int));
279 for (i = 0; i < MAX_MSG; i++)
280 buf[i] = i;
281
282 memset(&saddr, 0, sizeof(saddr));
283 saddr.sin_family = AF_INET;
284 saddr.sin_port = htons(rd->port);
285 inet_pton(AF_INET, "127.0.0.1", &saddr.sin_addr);
286
287 sockfd = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
288 if (sockfd < 0) {
289 perror("socket");
290 return 1;
291 }
292
293 pthread_mutex_lock(&rd->mutex);
294
295 ret = connect(sockfd, (struct sockaddr *)&saddr, sizeof(saddr));
296 if (ret < 0) {
297 perror("connect");
298 return 1;
299 }
300
301 iov.iov_base = buf;
302 iov.iov_len = MAX_MSG * sizeof(int) / 2;
303 for (i = 0; i < 2; i++) {
304 sqe = io_uring_get_sqe(&ring);
305 io_uring_prep_send(sqe, sockfd, iov.iov_base, iov.iov_len, 0);
306 sqe->user_data = 1;
307
308 ret = io_uring_submit(&ring);
309 if (ret <= 0) {
310 fprintf(stderr, "submit failed: %d\n", ret);
311 goto err;
312 }
313 usleep(10000);
314 iov.iov_base += iov.iov_len;
315 }
316
317 for (i = 0; i < 2; i++) {
318 ret = io_uring_wait_cqe(&ring, &cqe);
319 if (cqe->res == -EINVAL) {
320 fprintf(stdout, "send not supported, skipping\n");
321 close(sockfd);
322 return 0;
323 }
324 if (cqe->res != iov.iov_len) {
325 fprintf(stderr, "failed cqe: %d\n", cqe->res);
326 goto err;
327 }
328 io_uring_cqe_seen(&ring, cqe);
329 }
330
331 shutdown(sockfd, SHUT_RDWR);
332 close(sockfd);
333 return 0;
334 err:
335 shutdown(sockfd, SHUT_RDWR);
336 close(sockfd);
337 return 1;
338 }
339
test(int use_recvmsg,int use_sync)340 static int test(int use_recvmsg, int use_sync)
341 {
342 pthread_mutexattr_t attr;
343 pthread_t recv_thread;
344 struct recv_data rd;
345 int ret;
346 void *retval;
347
348 pthread_mutexattr_init(&attr);
349 pthread_mutexattr_setpshared(&attr, 1);
350 pthread_mutex_init(&rd.mutex, &attr);
351 pthread_mutex_lock(&rd.mutex);
352 rd.use_recvmsg = use_recvmsg;
353 rd.use_sync = use_sync;
354 rd.port = port++;
355
356 ret = pthread_create(&recv_thread, NULL, recv_fn, &rd);
357 if (ret) {
358 fprintf(stderr, "Thread create failed: %d\n", ret);
359 pthread_mutex_unlock(&rd.mutex);
360 return 1;
361 }
362
363 do_send(&rd);
364 pthread_join(recv_thread, &retval);
365 return (intptr_t)retval;
366 }
367
/*
 * Run the four combinations: io_uring and blocking syscalls, each with
 * recv and recvmsg. Stops at the first failing combination and returns
 * its result. Exits 0 immediately if any command-line argument is given.
 */
int main(int argc, char *argv[])
{
	static const struct {
		int use_recvmsg;
		int use_sync;
		const char *fail_msg;
	} cases[] = {
		{ 0, 0, "test recv failed\n" },
		{ 1, 0, "test recvmsg failed\n" },
		{ 0, 1, "test sync recv failed\n" },
		{ 1, 1, "test sync recvmsg failed\n" },
	};
	size_t idx;

	(void) argv;
	if (argc > 1)
		return 0;

	for (idx = 0; idx < sizeof(cases) / sizeof(cases[0]); idx++) {
		int status = test(cases[idx].use_recvmsg, cases[idx].use_sync);

		if (status) {
			fputs(cases[idx].fail_msg, stderr);
			return status;
		}
	}

	return 0;
}
401