xref: /aosp_15_r20/external/liburing/test/recv-msgall-stream.c (revision 25da2bea747f3a93b4c30fd9708b0618ef55a0e6)
1 /* SPDX-License-Identifier: MIT */
2 /*
3  * Test MSG_WAITALL for recv/recvmsg and include normal sync versions just
4  * for comparison.
5  */
6 #include <errno.h>
7 #include <stdio.h>
8 #include <stdlib.h>
9 #include <string.h>
10 #include <unistd.h>
11 #include <fcntl.h>
12 #include <arpa/inet.h>
13 #include <sys/types.h>
14 #include <sys/socket.h>
15 #include <pthread.h>
16 
17 #include "liburing.h"
18 #include "helpers.h"
19 
20 #define MAX_MSG	128
21 
22 static int port = 31200;
23 
24 struct recv_data {
25 	pthread_mutex_t mutex;
26 	int use_recvmsg;
27 	int use_sync;
28 	int port;
29 };
30 
get_conn_sock(struct recv_data * rd,int * sockout)31 static int get_conn_sock(struct recv_data *rd, int *sockout)
32 {
33 	struct sockaddr_in saddr;
34 	int sockfd, ret, val;
35 
36 	memset(&saddr, 0, sizeof(saddr));
37 	saddr.sin_family = AF_INET;
38 	saddr.sin_addr.s_addr = htonl(INADDR_ANY);
39 	saddr.sin_port = htons(rd->port);
40 
41 	sockfd = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
42 	if (sockfd < 0) {
43 		perror("socket");
44 		goto err;
45 	}
46 
47 	val = 1;
48 	setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val));
49 	setsockopt(sockfd, SOL_SOCKET, SO_REUSEPORT, &val, sizeof(val));
50 
51 	ret = bind(sockfd, (struct sockaddr *)&saddr, sizeof(saddr));
52 	if (ret < 0) {
53 		perror("bind");
54 		goto err;
55 	}
56 
57 	ret = listen(sockfd, 16);
58 	if (ret < 0) {
59 		perror("listen");
60 		goto err;
61 	}
62 
63 	pthread_mutex_unlock(&rd->mutex);
64 
65 	ret = accept(sockfd, NULL, NULL);
66 	if (ret < 0) {
67 		perror("accept");
68 		return -1;
69 	}
70 
71 	*sockout = sockfd;
72 	return ret;
73 err:
74 	pthread_mutex_unlock(&rd->mutex);
75 	return -1;
76 }
77 
recv_prep(struct io_uring * ring,struct iovec * iov,int * sock,struct recv_data * rd)78 static int recv_prep(struct io_uring *ring, struct iovec *iov, int *sock,
79 		     struct recv_data *rd)
80 {
81 	struct io_uring_sqe *sqe;
82 	struct msghdr msg = { };
83 	int sockfd, sockout = -1, ret;
84 
85 	sockfd = get_conn_sock(rd, &sockout);
86 	if (sockfd < 0)
87 		goto err;
88 
89 	sqe = io_uring_get_sqe(ring);
90 	if (!rd->use_recvmsg) {
91 		io_uring_prep_recv(sqe, sockfd, iov->iov_base, iov->iov_len,
92 					MSG_WAITALL);
93 	} else {
94 		msg.msg_namelen = sizeof(struct sockaddr_in);
95 		msg.msg_iov = iov;
96 		msg.msg_iovlen = 1;
97 		io_uring_prep_recvmsg(sqe, sockfd, &msg, MSG_WAITALL);
98 	}
99 
100 	sqe->user_data = 2;
101 
102 	ret = io_uring_submit(ring);
103 	if (ret <= 0) {
104 		fprintf(stderr, "submit failed: %d\n", ret);
105 		goto err;
106 	}
107 
108 	*sock = sockfd;
109 	return 0;
110 err:
111 	if (sockout != -1) {
112 		shutdown(sockout, SHUT_RDWR);
113 		close(sockout);
114 	}
115 	if (sockfd != -1) {
116 		shutdown(sockfd, SHUT_RDWR);
117 		close(sockfd);
118 	}
119 	return 1;
120 }
121 
do_recv(struct io_uring * ring)122 static int do_recv(struct io_uring *ring)
123 {
124 	struct io_uring_cqe *cqe;
125 	int ret;
126 
127 	ret = io_uring_wait_cqe(ring, &cqe);
128 	if (ret) {
129 		fprintf(stdout, "wait_cqe: %d\n", ret);
130 		goto err;
131 	}
132 	if (cqe->res == -EINVAL) {
133 		fprintf(stdout, "recv not supported, skipping\n");
134 		return 0;
135 	}
136 	if (cqe->res < 0) {
137 		fprintf(stderr, "failed cqe: %d\n", cqe->res);
138 		goto err;
139 	}
140 	if (cqe->res != MAX_MSG * sizeof(int)) {
141 		fprintf(stderr, "got wrong length: %d\n", cqe->res);
142 		goto err;
143 	}
144 
145 	io_uring_cqe_seen(ring, cqe);
146 	return 0;
147 err:
148 	return 1;
149 }
150 
recv_sync(struct recv_data * rd)151 static int recv_sync(struct recv_data *rd)
152 {
153 	int buf[MAX_MSG];
154 	struct iovec iov = {
155 		.iov_base = buf,
156 		.iov_len = sizeof(buf),
157 	};
158 	int i, ret, sockfd, sockout = -1;
159 
160 	sockfd = get_conn_sock(rd, &sockout);
161 
162 	if (rd->use_recvmsg) {
163 		struct msghdr msg = { };
164 
165 		msg.msg_namelen = sizeof(struct sockaddr_in);
166 		msg.msg_iov = &iov;
167 		msg.msg_iovlen = 1;
168 		ret = recvmsg(sockfd, &msg, MSG_WAITALL);
169 	} else {
170 		ret = recv(sockfd, buf, sizeof(buf), MSG_WAITALL);
171 	}
172 
173 	if (ret < 0) {
174 		perror("receive");
175 		goto err;
176 	}
177 
178 	if (ret != sizeof(buf)) {
179 		ret = -1;
180 		goto err;
181 	}
182 
183 	for (i = 0; i < MAX_MSG; i++) {
184 		if (buf[i] != i)
185 			goto err;
186 	}
187 	ret = 0;
188 err:
189 	shutdown(sockout, SHUT_RDWR);
190 	shutdown(sockfd, SHUT_RDWR);
191 	close(sockout);
192 	close(sockfd);
193 	return ret;
194 }
195 
recv_uring(struct recv_data * rd)196 static int recv_uring(struct recv_data *rd)
197 {
198 	int buf[MAX_MSG];
199 	struct iovec iov = {
200 		.iov_base = buf,
201 		.iov_len = sizeof(buf),
202 	};
203 	struct io_uring_params p = { };
204 	struct io_uring ring;
205 	int ret, sock = -1, sockout = -1;
206 
207 	ret = t_create_ring_params(1, &ring, &p);
208 	if (ret == T_SETUP_SKIP) {
209 		pthread_mutex_unlock(&rd->mutex);
210 		ret = 0;
211 		goto err;
212 	} else if (ret < 0) {
213 		pthread_mutex_unlock(&rd->mutex);
214 		goto err;
215 	}
216 
217 	sock = recv_prep(&ring, &iov, &sockout, rd);
218 	if (ret) {
219 		fprintf(stderr, "recv_prep failed: %d\n", ret);
220 		goto err;
221 	}
222 	ret = do_recv(&ring);
223 	if (!ret) {
224 		int i;
225 
226 		for (i = 0; i < MAX_MSG; i++) {
227 			if (buf[i] != i) {
228 				fprintf(stderr, "found %d at %d\n", buf[i], i);
229 				ret = 1;
230 				break;
231 			}
232 		}
233 	}
234 
235 	shutdown(sockout, SHUT_RDWR);
236 	shutdown(sock, SHUT_RDWR);
237 	close(sock);
238 	close(sockout);
239 	io_uring_queue_exit(&ring);
240 err:
241 	if (sock != -1) {
242 		shutdown(sock, SHUT_RDWR);
243 		close(sock);
244 	}
245 	if (sockout != -1) {
246 		shutdown(sockout, SHUT_RDWR);
247 		close(sockout);
248 	}
249 	return ret;
250 }
251 
recv_fn(void * data)252 static void *recv_fn(void *data)
253 {
254 	struct recv_data *rd = data;
255 
256 	if (rd->use_sync)
257 		return (void *) (uintptr_t) recv_sync(rd);
258 
259 	return (void *) (uintptr_t) recv_uring(rd);
260 }
261 
do_send(struct recv_data * rd)262 static int do_send(struct recv_data *rd)
263 {
264 	struct sockaddr_in saddr;
265 	struct io_uring ring;
266 	struct io_uring_cqe *cqe;
267 	struct io_uring_sqe *sqe;
268 	int sockfd, ret, i;
269 	struct iovec iov;
270 	int *buf;
271 
272 	ret = io_uring_queue_init(2, &ring, 0);
273 	if (ret) {
274 		fprintf(stderr, "queue init failed: %d\n", ret);
275 		return 1;
276 	}
277 
278 	buf = malloc(MAX_MSG * sizeof(int));
279 	for (i = 0; i < MAX_MSG; i++)
280 		buf[i] = i;
281 
282 	memset(&saddr, 0, sizeof(saddr));
283 	saddr.sin_family = AF_INET;
284 	saddr.sin_port = htons(rd->port);
285 	inet_pton(AF_INET, "127.0.0.1", &saddr.sin_addr);
286 
287 	sockfd = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, IPPROTO_TCP);
288 	if (sockfd < 0) {
289 		perror("socket");
290 		return 1;
291 	}
292 
293 	pthread_mutex_lock(&rd->mutex);
294 
295 	ret = connect(sockfd, (struct sockaddr *)&saddr, sizeof(saddr));
296 	if (ret < 0) {
297 		perror("connect");
298 		return 1;
299 	}
300 
301 	iov.iov_base = buf;
302 	iov.iov_len = MAX_MSG * sizeof(int) / 2;
303 	for (i = 0; i < 2; i++) {
304 		sqe = io_uring_get_sqe(&ring);
305 		io_uring_prep_send(sqe, sockfd, iov.iov_base, iov.iov_len, 0);
306 		sqe->user_data = 1;
307 
308 		ret = io_uring_submit(&ring);
309 		if (ret <= 0) {
310 			fprintf(stderr, "submit failed: %d\n", ret);
311 			goto err;
312 		}
313 		usleep(10000);
314 		iov.iov_base += iov.iov_len;
315 	}
316 
317 	for (i = 0; i < 2; i++) {
318 		ret = io_uring_wait_cqe(&ring, &cqe);
319 		if (cqe->res == -EINVAL) {
320 			fprintf(stdout, "send not supported, skipping\n");
321 			close(sockfd);
322 			return 0;
323 		}
324 		if (cqe->res != iov.iov_len) {
325 			fprintf(stderr, "failed cqe: %d\n", cqe->res);
326 			goto err;
327 		}
328 		io_uring_cqe_seen(&ring, cqe);
329 	}
330 
331 	shutdown(sockfd, SHUT_RDWR);
332 	close(sockfd);
333 	return 0;
334 err:
335 	shutdown(sockfd, SHUT_RDWR);
336 	close(sockfd);
337 	return 1;
338 }
339 
test(int use_recvmsg,int use_sync)340 static int test(int use_recvmsg, int use_sync)
341 {
342 	pthread_mutexattr_t attr;
343 	pthread_t recv_thread;
344 	struct recv_data rd;
345 	int ret;
346 	void *retval;
347 
348 	pthread_mutexattr_init(&attr);
349 	pthread_mutexattr_setpshared(&attr, 1);
350 	pthread_mutex_init(&rd.mutex, &attr);
351 	pthread_mutex_lock(&rd.mutex);
352 	rd.use_recvmsg = use_recvmsg;
353 	rd.use_sync = use_sync;
354 	rd.port = port++;
355 
356 	ret = pthread_create(&recv_thread, NULL, recv_fn, &rd);
357 	if (ret) {
358 		fprintf(stderr, "Thread create failed: %d\n", ret);
359 		pthread_mutex_unlock(&rd.mutex);
360 		return 1;
361 	}
362 
363 	do_send(&rd);
364 	pthread_join(recv_thread, &retval);
365 	return (intptr_t)retval;
366 }
367 
main(int argc,char * argv[])368 int main(int argc, char *argv[])
369 {
370 	int ret;
371 
372 	if (argc > 1)
373 		return 0;
374 
375 	ret = test(0, 0);
376 	if (ret) {
377 		fprintf(stderr, "test recv failed\n");
378 		return ret;
379 	}
380 
381 	ret = test(1, 0);
382 	if (ret) {
383 		fprintf(stderr, "test recvmsg failed\n");
384 		return ret;
385 	}
386 
387 	ret = test(0, 1);
388 	if (ret) {
389 		fprintf(stderr, "test sync recv failed\n");
390 		return ret;
391 	}
392 
393 	ret = test(1, 1);
394 	if (ret) {
395 		fprintf(stderr, "test sync recvmsg failed\n");
396 		return ret;
397 	}
398 
399 	return 0;
400 }
401