1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package os_test
6
7import (
8	"bytes"
9	"internal/poll"
10	"io"
11	"math/rand"
12	"net"
13	. "os"
14	"strconv"
15	"syscall"
16	"testing"
17	"time"
18)
19
20func TestSendFile(t *testing.T) {
21	sizes := []int{
22		1,
23		42,
24		1025,
25		syscall.Getpagesize() + 1,
26		32769,
27	}
28	t.Run("sendfile-to-unix", func(t *testing.T) {
29		for _, size := range sizes {
30			t.Run(strconv.Itoa(size), func(t *testing.T) {
31				testSendFile(t, "unix", int64(size))
32			})
33		}
34	})
35	t.Run("sendfile-to-tcp", func(t *testing.T) {
36		for _, size := range sizes {
37			t.Run(strconv.Itoa(size), func(t *testing.T) {
38				testSendFile(t, "tcp", int64(size))
39			})
40		}
41	})
42}
43
44func testSendFile(t *testing.T, proto string, size int64) {
45	dst, src, recv, data, hook := newSendFileTest(t, proto, size)
46
47	// Now call WriteTo (through io.Copy), which will hopefully call poll.SendFile
48	n, err := io.Copy(dst, src)
49	if err != nil {
50		t.Fatalf("io.Copy error: %v", err)
51	}
52
53	// We should have called poll.Splice with the right file descriptor arguments.
54	if n > 0 && !hook.called {
55		t.Fatal("expected to called poll.SendFile")
56	}
57	if hook.called && hook.srcfd != int(src.Fd()) {
58		t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
59	}
60	sc, ok := dst.(syscall.Conn)
61	if !ok {
62		t.Fatalf("destination is not a syscall.Conn")
63	}
64	rc, err := sc.SyscallConn()
65	if err != nil {
66		t.Fatalf("destination SyscallConn error: %v", err)
67	}
68	if err = rc.Control(func(fd uintptr) {
69		if hook.called && hook.dstfd != int(fd) {
70			t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, int(fd))
71		}
72	}); err != nil {
73		t.Fatalf("destination Conn Control error: %v", err)
74	}
75
76	// Verify the data size and content.
77	dataSize := len(data)
78	dstData := make([]byte, dataSize)
79	m, err := io.ReadFull(recv, dstData)
80	if err != nil {
81		t.Fatalf("server Conn Read error: %v", err)
82	}
83	if n != int64(dataSize) {
84		t.Fatalf("data length mismatch for io.Copy, got %d, want %d", n, dataSize)
85	}
86	if m != dataSize {
87		t.Fatalf("data length mismatch for net.Conn.Read, got %d, want %d", m, dataSize)
88	}
89	if !bytes.Equal(dstData, data) {
90		t.Errorf("data mismatch, got %s, want %s", dstData, data)
91	}
92}
93
94// newSendFileTest initializes a new test for sendfile.
95//
96// It creates source file and destination sockets, and populates the source file
97// with random data of the specified size. It also hooks package os' call
98// to poll.Sendfile and returns the hook so it can be inspected.
99func newSendFileTest(t *testing.T, proto string, size int64) (net.Conn, *File, net.Conn, []byte, *sendFileHook) {
100	t.Helper()
101
102	hook := hookSendFile(t)
103
104	client, server := createSocketPair(t, proto)
105	tempFile, data := createTempFile(t, size)
106
107	return client, tempFile, server, data, hook
108}
109
110func hookSendFile(t *testing.T) *sendFileHook {
111	h := new(sendFileHook)
112	orig := poll.TestHookDidSendFile
113	t.Cleanup(func() {
114		poll.TestHookDidSendFile = orig
115	})
116	poll.TestHookDidSendFile = func(dstFD *poll.FD, src int, written int64, err error, handled bool) {
117		h.called = true
118		h.dstfd = dstFD.Sysfd
119		h.srcfd = src
120		h.written = written
121		h.err = err
122		h.handled = handled
123	}
124	return h
125}
126
// sendFileHook holds what hookSendFile captured from the most recent
// invocation of poll.TestHookDidSendFile.
type sendFileHook struct {
	// called is set to true once the hook has run.
	called bool
	// dstfd and srcfd are the destination and source file descriptors the
	// hook was invoked with.
	dstfd, srcfd int

	// written, handled, and err are copies of the hook arguments of the
	// same names.
	written int64
	handled bool
	err     error
}
136
// createTempFile creates a file inside t.TempDir(), writes size bytes of
// pseudo-random data to it, syncs it, and rewinds it to the beginning.
//
// It returns the open file together with the bytes that were written. The
// file is closed by a cleanup function registered on t, and the seed used to
// generate the data is written to the test log.
func createTempFile(t *testing.T, size int64) (*File, []byte) {
	file, err := CreateTemp(t.TempDir(), "writeto-sendfile-to-socket")
	if err != nil {
		t.Fatalf("failed to create temporary file: %v", err)
	}
	t.Cleanup(func() { file.Close() })

	// Seed from the wall clock and log the seed so the payload of a failing
	// run can be regenerated.
	seed := time.Now().Unix()
	t.Logf("random data seed: %d\n", seed)
	payload := make([]byte, size)
	rand.New(rand.NewSource(seed)).Read(payload)

	_, err = file.Write(payload)
	if err != nil {
		t.Fatalf("failed to create and feed the file: %v", err)
	}
	err = file.Sync()
	if err != nil {
		t.Fatalf("failed to save the file: %v", err)
	}
	// Rewind so that the caller reads the payload from the start.
	_, err = file.Seek(0, io.SeekStart)
	if err != nil {
		t.Fatalf("failed to rewind the file: %v", err)
	}

	return file, payload
}
163