1// Copyright 2023 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package os_test 6 7import ( 8 "bytes" 9 "internal/poll" 10 "io" 11 "math/rand" 12 "net" 13 . "os" 14 "strconv" 15 "syscall" 16 "testing" 17 "time" 18) 19 20func TestSendFile(t *testing.T) { 21 sizes := []int{ 22 1, 23 42, 24 1025, 25 syscall.Getpagesize() + 1, 26 32769, 27 } 28 t.Run("sendfile-to-unix", func(t *testing.T) { 29 for _, size := range sizes { 30 t.Run(strconv.Itoa(size), func(t *testing.T) { 31 testSendFile(t, "unix", int64(size)) 32 }) 33 } 34 }) 35 t.Run("sendfile-to-tcp", func(t *testing.T) { 36 for _, size := range sizes { 37 t.Run(strconv.Itoa(size), func(t *testing.T) { 38 testSendFile(t, "tcp", int64(size)) 39 }) 40 } 41 }) 42} 43 44func testSendFile(t *testing.T, proto string, size int64) { 45 dst, src, recv, data, hook := newSendFileTest(t, proto, size) 46 47 // Now call WriteTo (through io.Copy), which will hopefully call poll.SendFile 48 n, err := io.Copy(dst, src) 49 if err != nil { 50 t.Fatalf("io.Copy error: %v", err) 51 } 52 53 // We should have called poll.Splice with the right file descriptor arguments. 54 if n > 0 && !hook.called { 55 t.Fatal("expected to called poll.SendFile") 56 } 57 if hook.called && hook.srcfd != int(src.Fd()) { 58 t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd()) 59 } 60 sc, ok := dst.(syscall.Conn) 61 if !ok { 62 t.Fatalf("destination is not a syscall.Conn") 63 } 64 rc, err := sc.SyscallConn() 65 if err != nil { 66 t.Fatalf("destination SyscallConn error: %v", err) 67 } 68 if err = rc.Control(func(fd uintptr) { 69 if hook.called && hook.dstfd != int(fd) { 70 t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, int(fd)) 71 } 72 }); err != nil { 73 t.Fatalf("destination Conn Control error: %v", err) 74 } 75 76 // Verify the data size and content. 77 dataSize := len(data) 78 dstData := make([]byte, dataSize) 79 m, err := io.ReadFull(recv, dstData) 80 if err != nil { 81 t.Fatalf("server Conn Read error: %v", err) 82 } 83 if n != int64(dataSize) { 84 t.Fatalf("data length mismatch for io.Copy, got %d, want %d", n, dataSize) 85 } 86 if m != dataSize { 87 t.Fatalf("data length mismatch for net.Conn.Read, got %d, want %d", m, dataSize) 88 } 89 if !bytes.Equal(dstData, data) { 90 t.Errorf("data mismatch, got %s, want %s", dstData, data) 91 } 92} 93 94// newSendFileTest initializes a new test for sendfile. 95// 96// It creates source file and destination sockets, and populates the source file 97// with random data of the specified size. It also hooks package os' call 98// to poll.Sendfile and returns the hook so it can be inspected. 99func newSendFileTest(t *testing.T, proto string, size int64) (net.Conn, *File, net.Conn, []byte, *sendFileHook) { 100 t.Helper() 101 102 hook := hookSendFile(t) 103 104 client, server := createSocketPair(t, proto) 105 tempFile, data := createTempFile(t, size) 106 107 return client, tempFile, server, data, hook 108} 109 110func hookSendFile(t *testing.T) *sendFileHook { 111 h := new(sendFileHook) 112 orig := poll.TestHookDidSendFile 113 t.Cleanup(func() { 114 poll.TestHookDidSendFile = orig 115 }) 116 poll.TestHookDidSendFile = func(dstFD *poll.FD, src int, written int64, err error, handled bool) { 117 h.called = true 118 h.dstfd = dstFD.Sysfd 119 h.srcfd = src 120 h.written = written 121 h.err = err 122 h.handled = handled 123 } 124 return h 125} 126 127type sendFileHook struct { 128 called bool 129 dstfd int 130 srcfd int 131 132 written int64 133 handled bool 134 err error 135} 136 137func createTempFile(t *testing.T, size int64) (*File, []byte) { 138 f, err := CreateTemp(t.TempDir(), "writeto-sendfile-to-socket") 139 if err != nil { 140 t.Fatalf("failed to create temporary file: %v", err) 141 } 142 t.Cleanup(func() { 143 f.Close() 144 }) 145 146 randSeed := time.Now().Unix() 147 t.Logf("random data seed: %d\n", randSeed) 148 prng := rand.New(rand.NewSource(randSeed)) 149 data := make([]byte, size) 150 prng.Read(data) 151 if _, err := f.Write(data); err != nil { 152 t.Fatalf("failed to create and feed the file: %v", err) 153 } 154 if err := f.Sync(); err != nil { 155 t.Fatalf("failed to save the file: %v", err) 156 } 157 if _, err := f.Seek(0, io.SeekStart); err != nil { 158 t.Fatalf("failed to rewind the file: %v", err) 159 } 160 161 return f, data 162} 163