1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package net
6
7import (
8	"bytes"
9	"context"
10	"crypto/sha256"
11	"encoding/hex"
12	"errors"
13	"fmt"
14	"internal/poll"
15	"io"
16	"os"
17	"runtime"
18	"strconv"
19	"sync"
20	"testing"
21	"time"
22)
23
// Fixture shared by the sendfile tests: a plain-text copy of Isaac Newton's
// "Opticks" from the repository's testdata directory.
const (
	// newton is the fixture's path, relative to this package's directory.
	newton = "../testdata/Isaac.Newton-Opticks.txt"

	// newtonLen is the fixture's size in bytes.
	newtonLen = 567198

	// newtonSHA256 is the hex-encoded SHA-256 digest of the fixture.
	newtonSHA256 = "d4a9ac22462b35e7821a4f2706c211093da678620a8f9997989ee7cf8d507bbd"
)
29
// expectSendfile runs f and verifies that internal/poll.SendFile is called
// exactly once while f runs, that it reports having handled the write, and
// that the destination FD it was given belongs to wantConn.
//
// wantConn is expected to be a *TCPConn; any other type is reported as a test
// error rather than causing a panic. That matters because callers invoke
// expectSendfile from non-test goroutines, where a panic would abort the
// whole test binary.
//
// On platforms where supportsSendfile is false, expectSendfile runs f but does
// not expect a call to SendFile.
//
// NOTE(review): the hook state below is not synchronized, so this assumes the
// SendFile call happens on the goroutine running f and that no other sendfile
// traffic in the package overlaps with it.
func expectSendfile(t *testing.T, wantConn Conn, f func()) {
	t.Helper()
	if !supportsSendfile {
		f()
		return
	}
	// Install a recording hook for the duration of f, restoring any previous
	// hook afterwards.
	orig := poll.TestHookDidSendFile
	defer func() {
		poll.TestHookDidSendFile = orig
	}()
	var (
		called     bool
		gotHandled bool
		gotFD      *poll.FD
	)
	poll.TestHookDidSendFile = func(dstFD *poll.FD, src int, written int64, err error, handled bool) {
		if called {
			t.Error("internal/poll.SendFile called multiple times, want one call")
		}
		called = true
		gotHandled = handled
		gotFD = dstFD
	}
	f()
	if !called {
		t.Error("internal/poll.SendFile was not called, want it to be")
		return
	}
	if !gotHandled {
		t.Error("internal/poll.SendFile did not handle the write, want it to")
		return
	}
	tc, ok := wantConn.(*TCPConn)
	if !ok {
		t.Errorf("expectSendfile: wantConn is %T, want *TCPConn", wantConn)
		return
	}
	if &tc.fd.pfd != gotFD {
		t.Error("internal/poll.SendFile called with unexpected FD")
	}
}
71
// TestSendfile has the server stream the entire newton fixture over an
// accepted TCP connection with io.Copy. The client checks that exactly
// newtonLen bytes arrive and that their SHA-256 digest equals newtonSHA256.
// Except on Windows, the server-side copy is also required to go through
// sendfile (see expectSendfile).
func TestSendfile(t *testing.T) {
	ln := newLocalListener(t, "tcp")
	defer ln.Close()

	errc := make(chan error, 1)

	// serve sends the fixture over conn, reporting failures on errc, and
	// closes errc once it has finished.
	serve := func(conn Conn) {
		defer close(errc)
		defer conn.Close()

		file, err := os.Open(newton)
		if err != nil {
			errc <- err
			return
		}
		defer file.Close()

		// io.Copy should pick the sendfile path when it is available.
		var sent int64
		copyFile := func() {
			sent, err = io.Copy(conn, file)
		}
		if runtime.GOOS == "windows" {
			// Windows is not using sendfile for some reason:
			// https://go.dev/issue/67042
			copyFile()
		} else {
			expectSendfile(t, conn, copyFile)
		}

		switch {
		case err != nil:
			errc <- err
		case sent != newtonLen:
			errc <- fmt.Errorf("sent %d bytes; expected %d", sent, newtonLen)
		}
	}

	go func(ln Listener) {
		// Wait for a connection.
		conn, err := ln.Accept()
		if err != nil {
			errc <- err
			close(errc)
			return
		}
		go serve(conn)
	}(ln)

	// Fetch the file from the listener and hash everything that arrives.
	c, err := Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()

	digest := sha256.New()
	received, err := io.Copy(digest, c)
	if err != nil {
		t.Error(err)
	}
	if received != newtonLen {
		t.Errorf("received %d bytes; expected %d", received, newtonLen)
	}
	if got := hex.EncodeToString(digest.Sum(nil)); got != newtonSHA256 {
		t.Error("retrieved data hash did not match")
	}

	// errc is closed by the server side, which ends this loop.
	for err := range errc {
		t.Error(err)
	}
}
148
// TestSendfileParts checks that a file can be sent in several small pieces.
// The server issues three consecutive 3-byte io.CopyN calls from the newton
// fixture, each expected to go through sendfile. The client must receive
// exactly the first nine bytes of the file, "Produced ".
func TestSendfileParts(t *testing.T) {
	ln := newLocalListener(t, "tcp")
	defer ln.Close()

	errc := make(chan error, 1)
	go func(ln Listener) {
		// Wait for a connection.
		conn, err := ln.Accept()
		if err != nil {
			errc <- err
			close(errc)
			return
		}

		go func() {
			defer close(errc)
			defer conn.Close()

			f, err := os.Open(newton)
			if err != nil {
				errc <- err
				return
			}
			defer f.Close()

			for i := 0; i < 3; i++ {
				// Return file data using io.CopyN, which should use
				// sendFile if available.
				expectSendfile(t, conn, func() {
					_, err = io.CopyN(conn, f, 3)
				})
				if err != nil {
					errc <- err
					return
				}
			}
		}()
	}(ln)

	c, err := Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()

	// Read until the server closes the connection. If the read fails, the
	// contents checked below cannot be trusted, so report the error instead
	// of silently dropping it.
	buf := new(bytes.Buffer)
	if _, err := buf.ReadFrom(c); err != nil {
		t.Error(err)
	}

	if want, have := "Produced ", buf.String(); have != want {
		t.Errorf("unexpected server reply %q, want %q", have, want)
	}

	// errc is closed once the server goroutine finishes, ending this loop.
	for err := range errc {
		t.Error(err)
	}
}
205
// TestSendfileSeeked checks that sendfile honors the file's current offset.
// The server seeks the newton fixture to seekTo and sends sendSize bytes
// with io.CopyN, which is expected to go through sendfile. The client must
// receive exactly those bytes of the file.
func TestSendfileSeeked(t *testing.T) {
	ln := newLocalListener(t, "tcp")
	defer ln.Close()

	const seekTo = 65 << 10
	const sendSize = 10 << 10

	errc := make(chan error, 1)
	go func(ln Listener) {
		// Wait for a connection.
		conn, err := ln.Accept()
		if err != nil {
			errc <- err
			close(errc)
			return
		}

		go func() {
			defer close(errc)
			defer conn.Close()

			f, err := os.Open(newton)
			if err != nil {
				errc <- err
				return
			}
			defer f.Close()
			if _, err := f.Seek(seekTo, io.SeekStart); err != nil {
				errc <- err
				return
			}

			expectSendfile(t, conn, func() {
				_, err = io.CopyN(conn, f, sendSize)
			})
			if err != nil {
				errc <- err
				return
			}
		}()
	}(ln)

	c, err := Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()

	// Read until the server closes the connection; report rather than
	// ignore a read failure.
	buf := new(bytes.Buffer)
	if _, err := buf.ReadFrom(c); err != nil {
		t.Error(err)
	}

	if buf.Len() != sendSize {
		t.Errorf("Got %d bytes; want %d", buf.Len(), sendSize)
	} else {
		// A length check alone would still pass if the data came from the
		// wrong offset, so compare against the file's contents as well.
		data, err := os.ReadFile(newton)
		switch {
		case err != nil:
			t.Error(err)
		case len(data) < seekTo+sendSize:
			t.Errorf("%s has %d bytes; want at least %d", newton, len(data), seekTo+sendSize)
		case !bytes.Equal(buf.Bytes(), data[seekTo:seekTo+sendSize]):
			t.Errorf("received data does not match %s at offset %d", newton, seekTo)
		}
	}

	// errc is closed once the server goroutine finishes, ending this loop.
	for err := range errc {
		t.Error(err)
	}
}
265
// TestSendfilePipe checks that copying from a pipe to a TCP connection does
// not leave the pipe in blocking mode. After one byte has been copied, a read
// on the pipe with a very short deadline must still time out.
func TestSendfilePipe(t *testing.T) {
	switch runtime.GOOS {
	case "plan9", "windows", "js", "wasip1":
		// These systems don't support deadlines on pipes.
		t.Skipf("skipping on %s", runtime.GOOS)
	}

	t.Parallel()

	ln := newLocalListener(t, "tcp")
	defer ln.Close()

	r, w, err := os.Pipe()
	if err != nil {
		t.Fatal(err)
	}
	defer w.Close()
	defer r.Close()

	// copied receives exactly one value from the accepting goroutine: true
	// once the byte has been copied, false if that goroutine failed. It is
	// buffered so that the single send never blocks.
	copied := make(chan bool, 1)

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		// Accept a connection and copy 1 byte from the read end of
		// the pipe to the connection. This will call into sendfile.
		defer wg.Done()
		ok := false
		// Signal failure too. Otherwise the main goroutine would wait on
		// copied forever whenever Accept or CopyN fails.
		defer func() {
			if !ok {
				copied <- false
			}
		}()
		conn, err := ln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		// The comment above states that this should call into sendfile,
		// but empirically it doesn't seem to do so at this time.
		// If it does, or does on some platforms, this CopyN should be wrapped
		// in expectSendfile.
		_, err = io.CopyN(conn, r, 1)
		if err != nil {
			t.Error(err)
			return
		}
		ok = true
		// Signal the main goroutine that we've copied the byte.
		copied <- true
	}()

	wg.Add(1)
	go func() {
		// Write 1 byte to the write end of the pipe.
		defer wg.Done()
		_, err := w.Write([]byte{'a'})
		if err != nil {
			t.Error(err)
		}
	}()

	wg.Add(1)
	go func() {
		// Connect to the server started two goroutines up and
		// discard any data that it writes.
		defer wg.Done()
		conn, err := Dial("tcp", ln.Addr().String())
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		io.Copy(io.Discard, conn)
	}()

	// Wait for the byte to be copied, meaning that sendfile has
	// been called on the pipe.
	if !<-copied {
		// The accepting goroutine has already reported its error. Close the
		// listener so a dialed-but-unaccepted connection is torn down and
		// the discarding goroutine can finish, then wait for the others.
		ln.Close()
		wg.Wait()
		return
	}

	// Set a very short deadline on the read end of the pipe.
	if err := r.SetDeadline(time.Now().Add(time.Microsecond)); err != nil {
		t.Fatal(err)
	}

	wg.Add(1)
	go func() {
		// Wait for much longer than the deadline and write a byte
		// to the pipe.
		defer wg.Done()
		time.Sleep(50 * time.Millisecond)
		w.Write([]byte{'b'})
	}()

	// If this read does not time out, the pipe was incorrectly
	// put into blocking mode.
	_, err = r.Read(make([]byte, 1))
	if err == nil {
		t.Error("Read did not time out")
	} else if !os.IsTimeout(err) {
		t.Errorf("got error %v, expected a time out", err)
	}

	wg.Wait()
}
366
// TestSendfileOnWriteTimeoutExceeded is a regression test for issue 43822.
// When a connection's write deadline has already passed, copying a file to it
// must fail with os.ErrDeadlineExceeded, and the peer must see a clean EOF
// without having received any data.
func TestSendfileOnWriteTimeoutExceeded(t *testing.T) {
	ln := newLocalListener(t, "tcp")
	defer ln.Close()

	// serve accepts one connection and tries to copy the newton fixture to
	// it after expiring the write deadline. It returns nil only if the copy
	// fails with os.ErrDeadlineExceeded.
	serve := func(ln Listener) error {
		conn, err := ln.Accept()
		if err != nil {
			return err
		}
		defer conn.Close()

		// Put the write deadline an hour in the past so that every write
		// is guaranteed to time out.
		if err := conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)); err != nil {
			return err
		}

		f, err := os.Open(newton)
		if err != nil {
			return err
		}
		defer f.Close()

		// Ideally this would go through sendfile. However, when this was
		// written, poll.SendFile on an FD past its deadline could report
		// the operation as unhandled, which triggers a non-sendfile retry,
		// so expectSendfile is deliberately not used here.
		_, err = io.Copy(conn, f)
		switch {
		case errors.Is(err, os.ErrDeadlineExceeded):
			return nil
		case err == nil:
			return fmt.Errorf("expected ErrDeadlineExceeded, but got nil")
		default:
			return err
		}
	}

	errc := make(chan error, 1)
	go func() {
		// Report serve's outcome (possibly nil) only after its deferred
		// closes have run, then close the channel.
		result := serve(ln)
		errc <- result
		close(errc)
	}()

	conn, err := Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	n, err := io.Copy(io.Discard, conn)
	switch {
	case err != nil:
		t.Fatalf("expected nil error, but got %v", err)
	case n != 0:
		t.Fatalf("expected receive zero, but got %d byte(s)", n)
	}

	if err := <-errc; err != nil {
		t.Fatal(err)
	}
}
430
// BenchmarkSendfileZeroBytes measures io.Copy from a file to a TCP connection
// while that file is still being appended to. One goroutine writes b.N
// single bytes to a temporary file, calling Sync every 1000 writes. The server
// opens its own read-only handle on the same file and calls io.Copy in a
// loop, so many of those calls presumably find no new data to send. The
// client reads until it has received all b.N bytes.
func BenchmarkSendfileZeroBytes(b *testing.B) {
	var (
		wg          sync.WaitGroup
		ctx, cancel = context.WithCancel(context.Background())
	)

	defer wg.Wait()
	// The server's copy loop below stops only on cancellation or a copy
	// error, and io.Copy at EOF returns a nil error. Deferring cancel means
	// it runs before the deferred wg.Wait above, so the loop also exits when
	// the benchmark bails out early through b.Fatalf. Without this, wg.Wait
	// would block forever on those paths.
	defer cancel()

	ln := newLocalListener(b, "tcp")
	defer ln.Close()

	tempFile, err := os.CreateTemp(b.TempDir(), "test.txt")
	if err != nil {
		b.Fatalf("failed to create temp file: %v", err)
	}
	defer tempFile.Close()

	fileName := tempFile.Name()

	dataSize := b.N
	wg.Add(1)
	go func(f *os.File) {
		defer wg.Done()

		for i := 0; i < dataSize; i++ {
			if _, err := f.Write([]byte{1}); err != nil {
				b.Errorf("failed to write: %v", err)
				return
			}
			if i%1000 == 0 {
				f.Sync()
			}
		}
	}(tempFile)

	b.ResetTimer()
	b.ReportAllocs()

	wg.Add(1)
	go func(ln Listener, fileName string) {
		defer wg.Done()

		conn, err := ln.Accept()
		if err != nil {
			b.Errorf("failed to accept: %v", err)
			return
		}
		defer conn.Close()

		f, err := os.OpenFile(fileName, os.O_RDONLY, 0660)
		if err != nil {
			b.Errorf("failed to open file: %v", err)
			return
		}
		defer f.Close()

		for {
			if ctx.Err() != nil {
				return
			}

			if _, err := io.Copy(conn, f); err != nil {
				b.Errorf("failed to copy: %v", err)
				return
			}
		}
	}(ln, fileName)

	conn, err := Dial("tcp", ln.Addr().String())
	if err != nil {
		b.Fatalf("failed to dial: %v", err)
	}
	defer conn.Close()

	n, err := io.CopyN(io.Discard, conn, int64(dataSize))
	if err != nil {
		b.Fatalf("failed to copy: %v", err)
	}
	if n != int64(dataSize) {
		b.Fatalf("expected %d copied bytes, but got %d", dataSize, n)
	}

	// Stop the server loop now, before the deferred Close calls run.
	cancel()
}
515
// BenchmarkSendFile measures file-to-socket copies over both TCP and Unix
// domain sockets. It runs one sub-benchmark per protocol, named
// "file-to-<proto>".
func BenchmarkSendFile(b *testing.B) {
	if runtime.GOOS == "windows" {
		// TODO(panjf2000): Windows has not yet implemented FileConn;
		// remove this when it's implemented in https://go.dev/issues/9503.
		b.Skipf("skipping on %s", runtime.GOOS)
	}

	for _, proto := range []string{"tcp", "unix"} {
		// b.Run blocks until the sub-benchmark finishes, so capturing
		// proto in the closure is safe.
		b.Run("file-to-"+proto, func(b *testing.B) { benchmarkSendFile(b, proto) })
	}
}
526
// benchmarkSendFile runs sendFileBench over proto for chunk sizes from 1 KiB
// to 1 MiB, doubling each step. Each sub-benchmark is named by its chunk
// size in bytes.
func benchmarkSendFile(b *testing.B, proto string) {
	for size := 1 << 10; size <= 1<<20; size <<= 1 {
		b.Run(strconv.Itoa(size), sendFileBench{proto: proto, chunkSize: size}.benchSendFile)
	}
}
537
538type sendFileBench struct {
539	proto     string
540	chunkSize int
541}
542
543func (bench sendFileBench) benchSendFile(b *testing.B) {
544	fileSize := b.N * bench.chunkSize
545	f := createTempFile(b, fileSize)
546
547	client, server := spawnTestSocketPair(b, bench.proto)
548	defer server.Close()
549
550	cleanUp, err := startTestSocketPeer(b, client, "r", bench.chunkSize, fileSize)
551	if err != nil {
552		client.Close()
553		b.Fatal(err)
554	}
555	defer cleanUp(b)
556
557	b.ReportAllocs()
558	b.SetBytes(int64(bench.chunkSize))
559	b.ResetTimer()
560
561	// Data go from file to socket via sendfile(2).
562	sent, err := io.Copy(server, f)
563	if err != nil {
564		b.Fatalf("failed to copy data with sendfile, error: %v", err)
565	}
566	if sent != int64(fileSize) {
567		b.Fatalf("bytes sent mismatch, got: %d, want: %d", sent, fileSize)
568	}
569}
570
571func createTempFile(b *testing.B, size int) *os.File {
572	f, err := os.CreateTemp(b.TempDir(), "sendfile-bench")
573	if err != nil {
574		b.Fatalf("failed to create temporary file: %v", err)
575	}
576	b.Cleanup(func() {
577		f.Close()
578	})
579
580	data := make([]byte, size)
581	if _, err := f.Write(data); err != nil {
582		b.Fatalf("failed to create and feed the file: %v", err)
583	}
584	if err := f.Sync(); err != nil {
585		b.Fatalf("failed to save the file: %v", err)
586	}
587	if _, err := f.Seek(0, io.SeekStart); err != nil {
588		b.Fatalf("failed to rewind the file: %v", err)
589	}
590
591	return f
592}
593