1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
// Tests for transport.go.
//
// More tests are in clientserver_test.go (for things testing both client & server for both
// HTTP/1 and HTTP/2).
9
10package http_test
11
12import (
13	"bufio"
14	"bytes"
15	"compress/gzip"
16	"context"
17	"crypto/rand"
18	"crypto/tls"
19	"crypto/x509"
20	"encoding/binary"
21	"errors"
22	"fmt"
23	"go/token"
24	"internal/nettrace"
25	"io"
26	"log"
27	mrand "math/rand"
28	"net"
29	. "net/http"
30	"net/http/httptest"
31	"net/http/httptrace"
32	"net/http/httputil"
33	"net/http/internal/testcert"
34	"net/textproto"
35	"net/url"
36	"os"
37	"reflect"
38	"runtime"
39	"strconv"
40	"strings"
41	"sync"
42	"sync/atomic"
43	"testing"
44	"testing/iotest"
45	"time"
46
47	"golang.org/x/net/http/httpguts"
48)
49
// TODO: test 5 pipelined requests whose responses are 1) OK, 2) OK, 3) Connection: close,
// and then verify that the final 2 responses get errors back.
52
53// hostPortHandler writes back the client's "host:port".
54var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
55	if r.FormValue("close") == "true" {
56		w.Header().Set("Connection", "close")
57	}
58	w.Header().Set("X-Saw-Close", fmt.Sprint(r.Close))
59	w.Write([]byte(r.RemoteAddr))
60
61	// Include the address of the net.Conn in addition to the RemoteAddr,
62	// in case kernels reuse source ports quickly (see Issue 52450)
63	if c, ok := ResponseWriterConnForTesting(w); ok {
64		fmt.Fprintf(w, ", %T %p", c, c)
65	}
66})
67
68// testCloseConn is a net.Conn tracked by a testConnSet.
69type testCloseConn struct {
70	net.Conn
71	set *testConnSet
72}
73
74func (c *testCloseConn) Close() error {
75	c.set.remove(c)
76	return c.Conn.Close()
77}
78
79// testConnSet tracks a set of TCP connections and whether they've
80// been closed.
81type testConnSet struct {
82	t      *testing.T
83	mu     sync.Mutex // guards closed and list
84	closed map[net.Conn]bool
85	list   []net.Conn // in order created
86}
87
88func (tcs *testConnSet) insert(c net.Conn) {
89	tcs.mu.Lock()
90	defer tcs.mu.Unlock()
91	tcs.closed[c] = false
92	tcs.list = append(tcs.list, c)
93}
94
95func (tcs *testConnSet) remove(c net.Conn) {
96	tcs.mu.Lock()
97	defer tcs.mu.Unlock()
98	tcs.closed[c] = true
99}
100
101// some tests use this to manage raw tcp connections for later inspection
102func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) {
103	connSet := &testConnSet{
104		t:      t,
105		closed: make(map[net.Conn]bool),
106	}
107	dial := func(n, addr string) (net.Conn, error) {
108		c, err := net.Dial(n, addr)
109		if err != nil {
110			return nil, err
111		}
112		tc := &testCloseConn{c, connSet}
113		connSet.insert(tc)
114		return tc, nil
115	}
116	return connSet, dial
117}
118
119func (tcs *testConnSet) check(t *testing.T) {
120	tcs.mu.Lock()
121	defer tcs.mu.Unlock()
122	for i := 4; i >= 0; i-- {
123		for i, c := range tcs.list {
124			if tcs.closed[c] {
125				continue
126			}
127			if i != 0 {
128				// TODO(bcmills): What is the Sleep here doing, and why is this
129				// Unlock/Sleep/Lock cycle needed at all?
130				tcs.mu.Unlock()
131				time.Sleep(50 * time.Millisecond)
132				tcs.mu.Lock()
133				continue
134			}
135			t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
136		}
137	}
138}
139
func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) }

// testReuseRequest checks that the same *Request value can be sent with
// Client.Do twice in a row. Each round trip must succeed and its response
// body must close cleanly.
func testReuseRequest(t *testing.T, mode testMode) {
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Write([]byte("{}"))
	})
	ts := newClientServerTest(t, mode, handler).ts

	client := ts.Client()
	req, _ := NewRequest("GET", ts.URL, nil)
	for attempt := 0; attempt < 2; attempt++ {
		res, err := client.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		if err := res.Body.Close(); err != nil {
			t.Fatal(err)
		}
	}
}
166
// TestTransportKeepAlives makes two consecutive requests and compares their
// bodies, which echo the client's connection address (see hostPortHandler).
// With keep-alives enabled the bodies must match (the connection was reused);
// with DisableKeepAlives they must differ.
func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) }
func testTransportKeepAlives(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, hostPortHandler).ts

	c := ts.Client()
	for _, disableKeepAlive := range []bool{false, true} {
		c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive
		fetch := func(n int) string {
			res, err := c.Get(ts.URL)
			if err != nil {
				t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
			}
			// Always close the body, including on the ReadAll error path.
			// The body is read to EOF first, so this does not affect reuse.
			defer res.Body.Close()
			body, err := io.ReadAll(res.Body)
			if err != nil {
				t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
			}
			return string(body)
		}

		body1 := fetch(1)
		body2 := fetch(2)

		bodiesDiffer := body1 != body2
		if bodiesDiffer != disableKeepAlive {
			t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
				disableKeepAlive, bodiesDiffer, body1, body2)
		}
	}
}
198
199func TestTransportConnectionCloseOnResponse(t *testing.T) {
200	run(t, testTransportConnectionCloseOnResponse)
201}
202func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) {
203	ts := newClientServerTest(t, mode, hostPortHandler).ts
204
205	connSet, testDial := makeTestDial(t)
206
207	c := ts.Client()
208	tr := c.Transport.(*Transport)
209	tr.Dial = testDial
210
211	for _, connectionClose := range []bool{false, true} {
212		fetch := func(n int) string {
213			req := new(Request)
214			var err error
215			req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose))
216			if err != nil {
217				t.Fatalf("URL parse error: %v", err)
218			}
219			req.Method = "GET"
220			req.Proto = "HTTP/1.1"
221			req.ProtoMajor = 1
222			req.ProtoMinor = 1
223
224			res, err := c.Do(req)
225			if err != nil {
226				t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
227			}
228			defer res.Body.Close()
229			body, err := io.ReadAll(res.Body)
230			if err != nil {
231				t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
232			}
233			return string(body)
234		}
235
236		body1 := fetch(1)
237		body2 := fetch(2)
238		bodiesDiffer := body1 != body2
239		if bodiesDiffer != connectionClose {
240			t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
241				connectionClose, bodiesDiffer, body1, body2)
242		}
243
244		tr.CloseIdleConnections()
245	}
246
247	connSet.check(t)
248}
249
// TestTransportConnectionCloseOnRequest tests that the Transport doesn't reuse
// an underlying TCP connection after making an http.Request with Request.Close set.
//
// It tests the behavior by making an HTTP request to a server which
// describes the source connection it got (remote port number +
// address of its net.Conn).
func TestTransportConnectionCloseOnRequest(t *testing.T) {
	run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode})
}
func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, hostPortHandler).ts

	connSet, testDial := makeTestDial(t)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.Dial = testDial
	for _, reqClose := range []bool{false, true} {
		fetch := func(n int) string {
			req := new(Request)
			var err error
			req.URL, err = url.Parse(ts.URL)
			if err != nil {
				t.Fatalf("URL parse error: %v", err)
			}
			req.Method = "GET"
			req.Proto = "HTTP/1.1"
			req.ProtoMajor = 1
			req.ProtoMinor = 1
			req.Close = reqClose

			res, err := c.Do(req)
			if err != nil {
				t.Fatalf("error in Request.Close=%v, req #%d, Do: %v", reqClose, n, err)
			}
			// The body is read to EOF below; closing it as well avoids a
			// leak on the ReadAll error path.
			defer res.Body.Close()
			if got, want := res.Header.Get("X-Saw-Close"), fmt.Sprint(reqClose); got != want {
				t.Errorf("for Request.Close = %v; handler's X-Saw-Close was %v; want %v",
					reqClose, got, want)
			}
			body, err := io.ReadAll(res.Body)
			if err != nil {
				t.Fatalf("for Request.Close=%v, on request %v/2: ReadAll: %v", reqClose, n, err)
			}
			return string(body)
		}

		body1 := fetch(1)
		body2 := fetch(2)

		got := 1
		if body1 != body2 {
			got++
		}
		want := 1
		if reqClose {
			want = 2
		}
		if got != want {
			t.Errorf("for Request.Close=%v: server saw %v unique connections, wanted %v\n\nbodies were: %q and %q",
				reqClose, got, want, body1, body2)
		}

		tr.CloseIdleConnections()
	}

	connSet.check(t)
}
317
// TestTransportConnectionCloseOnRequestDisableKeepAlive checks that when the
// Transport's DisableKeepAlives is set, requests are sent with
// "Connection: close".
// HTTP/1-only (Connection: close doesn't exist in h2).
func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
	run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode})
}
func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, hostPortHandler).ts

	client := ts.Client()
	tr := client.Transport.(*Transport)
	tr.DisableKeepAlives = true

	res, err := client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	// hostPortHandler reports the server-side Request.Close in X-Saw-Close.
	if saw := res.Header.Get("X-Saw-Close"); saw != "true" {
		t.Errorf("handler didn't see Connection: close ")
	}
}
339
// TestTransportRespectRequestWantsClose checks that the Transport writes at
// most one "Connection: close" header, and writes it exactly when close was
// requested. Close can be requested via Transport.DisableKeepAlives, via
// Request.Close, or both.
func TestTransportRespectRequestWantsClose(t *testing.T) {
	run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode})
}
func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) {
	for _, disableKeepAlives := range []bool{false, true} {
		for _, reqClose := range []bool{false, true} {
			name := fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", disableKeepAlives, reqClose)
			t.Run(name, func(t *testing.T) {
				ts := newClientServerTest(t, mode, hostPortHandler).ts

				client := ts.Client()
				client.Transport.(*Transport).DisableKeepAlives = disableKeepAlives
				req, err := NewRequest("GET", ts.URL, nil)
				if err != nil {
					t.Fatal(err)
				}

				// Count every written Connection header that contains the
				// "close" token.
				closeHeaders := 0
				trace := &httptrace.ClientTrace{
					WroteHeaderField: func(key string, values []string) {
						if key == "Connection" && httpguts.HeaderValuesContainsToken(values, "close") {
							closeHeaders++
						}
					},
				}
				req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
				req.Close = reqClose

				res, err := client.Do(req)
				if err != nil {
					t.Fatal(err)
				}
				defer res.Body.Close()

				want := disableKeepAlives || reqClose
				if closeHeaders > 1 || (closeHeaders == 1) != want {
					t.Errorf("expecting want:%v, got 'Connection: close':%d", want, closeHeaders)
				}
			})
		}
	}
}
392
func TestTransportIdleCacheKeys(t *testing.T) {
	run(t, testTransportIdleCacheKeys, []testMode{http1Mode})
}

// testTransportIdleCacheKeys checks the Transport's idle-connection cache keys
// at three points:
//   - before any request: none;
//   - after a GET whose body was read: exactly one, "|http|" + server address;
//   - after CloseIdleConnections: none again.
func testTransportIdleCacheKeys(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, hostPortHandler).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
		t.Errorf("Before Get expected %d idle conn cache keys; got %d", e, g)
	}

	resp, err := c.Get(ts.URL)
	if err != nil {
		// resp is nil when err != nil; continuing would dereference it.
		t.Fatal(err)
	}
	defer resp.Body.Close()
	// The checks below expect reading the body to EOF to be enough for the
	// connection to be cached as idle.
	io.ReadAll(resp.Body)

	keys := tr.IdleConnKeysForTesting()
	if e, g := 1, len(keys); e != g {
		t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
	}

	if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
		t.Errorf("Expected idle cache key %q; got %q", e, keys[0])
	}

	tr.CloseIdleConnections()
	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
		t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
	}
}
425
// TestTransportReadToEndReusesConn checks that the Transport reuses a
// connection once the client has read a response body to the end, even
// though the body has not been closed yet. It covers both Content-Length
// and chunked responses.
func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) }
func testTransportReadToEndReusesConn(t *testing.T, mode testMode) {
	const msg = "foobar"

	// addrSeen counts requests per client address; it is reset for each path.
	var addrSeen map[string]int
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		addrSeen[r.RemoteAddr]++
		switch r.URL.Path {
		case "/chunked/":
			w.WriteHeader(200)
			w.(Flusher).Flush()
		default:
			w.Header().Set("Content-Length", strconv.Itoa(len(msg)))
			w.WriteHeader(200)
		}
		w.Write([]byte(msg))
	})).ts

	cases := []struct {
		path    string
		wantLen int
	}{
		{"/content-length/", len(msg)},
		{"/chunked/", -1},
	}
	for _, tc := range cases {
		addrSeen = make(map[string]int)
		for i := 0; i < 3; i++ {
			res, err := ts.Client().Get(ts.URL + tc.path)
			if err != nil {
				t.Errorf("Get %s: %v", tc.path, err)
				continue
			}
			// Close the body eventually, but not before the len(addrSeen)
			// check below: closing it inside the loop could let the
			// connection be reused for the wrong reason.
			defer res.Body.Close()

			if res.ContentLength != int64(tc.wantLen) {
				t.Errorf("%s res.ContentLength = %d; want %d", tc.path, res.ContentLength, tc.wantLen)
			}
			got, err := io.ReadAll(res.Body)
			if string(got) != msg || err != nil {
				t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", tc.path, string(got), err, msg)
			}
		}
		if n := len(addrSeen); n != 1 {
			t.Errorf("for %s, server saw %d distinct client addresses; want 1", tc.path, n)
		}
	}
}
474
func TestTransportMaxPerHostIdleConns(t *testing.T) {
	run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode})
}

// testTransportMaxPerHostIdleConns holds three concurrent requests open in the
// handler. It then releases them one at a time, checking the Transport's
// idle-connection count after each. The count must grow and then stay at
// MaxIdleConnsPerHost.
func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) {
	stop := make(chan struct{}) // closed when this test function returns
	defer close(stop)

	resch := make(chan string)
	gotReq := make(chan bool)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		gotReq <- true
		var msg string
		select {
		case <-stop:
			return
		case msg = <-resch:
		}
		if _, err := w.Write([]byte(msg)); err != nil {
			t.Errorf("Write: %v", err)
		}
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)
	const maxIdleConnsPerHost = 2
	tr.MaxIdleConnsPerHost = maxIdleConnsPerHost

	donech := make(chan bool)
	doReq := func() {
		// Report completion, unless the test function has already returned.
		defer func() {
			select {
			case <-stop:
			case donech <- t.Failed():
			}
		}()
		resp, err := c.Get(ts.URL)
		if err != nil {
			t.Error(err)
			return
		}
		if _, err := io.ReadAll(resp.Body); err != nil {
			t.Errorf("ReadAll: %v", err)
		}
	}

	// Start 3 outstanding requests and wait for the server to get each one.
	// Their responses hang until we write to resch.
	for i := 0; i < 3; i++ {
		go doReq()
		<-gotReq
	}

	if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
		t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g)
	}

	addr := ts.Listener.Addr().String()
	cacheKey := "|http|" + addr

	resch <- "res1"
	<-donech
	keys := tr.IdleConnKeysForTesting()
	if e, g := 1, len(keys); e != g {
		t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
	}
	if keys[0] != cacheKey {
		t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
	}
	if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
		t.Errorf("after first response, expected %d idle conns; got %d", e, g)
	}

	resch <- "res2"
	<-donech
	if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
		t.Errorf("after second response, idle conns = %d; want %d", g, w)
	}

	resch <- "res3"
	<-donech
	if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
		t.Errorf("after third response, idle conns = %d; want %d", g, w)
	}
}
563
func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
	run(t, testTransportMaxConnsPerHostIncludeDialInProgress)
}

// testTransportMaxConnsPerHostIncludeDialInProgress checks that a dial which is
// still in progress counts against MaxConnsPerHost. With MaxConnsPerHost = 1,
// a second request must not start dialing while the first request's dial is
// stalled.
func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		_, err := w.Write([]byte("foo"))
		if err != nil {
			// Handlers run on server goroutines, where t.Fatalf must not be
			// called; report with t.Errorf instead.
			t.Errorf("Write: %v", err)
		}
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)
	dialStarted := make(chan struct{})
	stallDial := make(chan struct{})
	tr.Dial = func(network, addr string) (net.Conn, error) {
		dialStarted <- struct{}{}
		<-stallDial
		return net.Dial(network, addr)
	}

	tr.DisableKeepAlives = true
	tr.MaxConnsPerHost = 1

	preDial := make(chan struct{})
	reqComplete := make(chan struct{})
	doReq := func(reqID string) {
		// Always signal completion, even on failure, so the test goroutine
		// never blocks forever waiting on reqComplete.
		defer func() { reqComplete <- struct{}{} }()

		req, _ := NewRequest("GET", ts.URL, nil)
		trace := &httptrace.ClientTrace{
			GetConn: func(hostPort string) {
				preDial <- struct{}{}
			},
		}
		req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
		resp, err := tr.RoundTrip(req)
		if err != nil {
			// resp is nil here; don't touch resp.Body.
			t.Errorf("unexpected error for request %s: %v", reqID, err)
			return
		}
		defer resp.Body.Close()
		if _, err := io.ReadAll(resp.Body); err != nil {
			t.Errorf("unexpected error for request %s: %v", reqID, err)
		}
	}
	// get req1 to dial-in-progress
	go doReq("req1")
	<-preDial
	<-dialStarted

	// get req2 to waiting on conns per host to go down below max
	go doReq("req2")
	<-preDial
	select {
	case <-dialStarted:
		t.Error("req2 dial started while req1 dial in progress")
		return
	default:
	}

	// let req1 complete
	stallDial <- struct{}{}
	<-reqComplete

	// let req2 complete
	<-dialStarted
	stallDial <- struct{}{}
	<-reqComplete
}
631
func TestTransportMaxConnsPerHost(t *testing.T) {
	run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode})
}

// testTransportMaxConnsPerHost runs in two rounds with MaxConnsPerHost = 1.
//
// Round 1 issues 10 concurrent requests and checks that exactly one
// connection was dialed, one new connection was handed out, and (for TLS
// servers) one handshake was performed.
//
// Round 2 closes the dialed connections, sends one more request, and checks
// that each counter grew by exactly one.
func testTransportMaxConnsPerHost(t *testing.T, mode testMode) {
	CondSkipHTTP2(t)

	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		_, err := w.Write([]byte("foo"))
		if err != nil {
			// Handlers run on server goroutines: t.Fatalf is not allowed.
			t.Errorf("Write: %v", err)
		}
	})

	ts := newClientServerTest(t, mode, h).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.MaxConnsPerHost = 1

	var (
		mu    sync.Mutex
		conns []net.Conn // successfully dialed connections; guarded by mu
	)
	var dialCnt, gotConnCnt, tlsHandshakeCnt atomic.Int32
	tr.Dial = func(network, addr string) (net.Conn, error) {
		dialCnt.Add(1)
		conn, err := net.Dial(network, addr)
		if err != nil {
			// Don't record a nil conn; it would be Closed below.
			return nil, err
		}
		mu.Lock()
		conns = append(conns, conn)
		mu.Unlock()
		return conn, nil
	}

	// doReq performs one traced GET and reads the whole body. It may run on
	// goroutines other than the test's, so it reports with t.Errorf.
	doReq := func() {
		trace := &httptrace.ClientTrace{
			GotConn: func(connInfo httptrace.GotConnInfo) {
				if !connInfo.Reused {
					gotConnCnt.Add(1)
				}
			},
			TLSHandshakeStart: func() {
				tlsHandshakeCnt.Add(1)
			},
		}
		req, _ := NewRequest("GET", ts.URL, nil)
		req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))

		resp, err := c.Do(req)
		if err != nil {
			t.Errorf("request failed: %v", err)
			return
		}
		defer resp.Body.Close()
		if _, err := io.ReadAll(resp.Body); err != nil {
			t.Errorf("read body failed: %v", err)
		}
	}

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			doReq()
		}()
	}
	wg.Wait()

	expected := int32(tr.MaxConnsPerHost)
	if n := dialCnt.Load(); n != expected {
		t.Errorf("round 1: too many dials: %d != %d", n, expected)
	}
	if n := gotConnCnt.Load(); n != expected {
		t.Errorf("round 1: too many get connections: %d != %d", n, expected)
	}
	if n := tlsHandshakeCnt.Load(); ts.TLS != nil && n != expected {
		t.Errorf("round 1: too many tls handshakes: %d != %d", n, expected)
	}

	if t.Failed() {
		t.FailNow()
	}

	mu.Lock()
	for _, conn := range conns {
		conn.Close()
	}
	conns = nil
	mu.Unlock()
	tr.CloseIdleConnections()

	doReq()
	expected++
	if n := dialCnt.Load(); n != expected {
		t.Errorf("round 2: too many dials: %d != %d", n, expected)
	}
	if n := gotConnCnt.Load(); n != expected {
		t.Errorf("round 2: too many get connections: %d != %d", n, expected)
	}
	if n := tlsHandshakeCnt.Load(); ts.TLS != nil && n != expected {
		t.Errorf("round 2: too many tls handshakes: %d != %d", n, expected)
	}
}
732
func TestTransportMaxConnsPerHostDialCancellation(t *testing.T) {
	run(t, testTransportMaxConnsPerHostDialCancellation,
		testNotParallel, // because test uses SetPendingDialHooks
		[]testMode{http1Mode, https1Mode, http2Mode},
	)
}

// testTransportMaxConnsPerHostDialCancellation checks two things with
// MaxConnsPerHost = 1:
//   - A request whose context is canceled while its dial is queued fails
//     with context.Canceled.
//   - A following request still succeeds.
func testTransportMaxConnsPerHostDialCancellation(t *testing.T, mode testMode) {
	CondSkipHTTP2(t)

	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		_, err := w.Write([]byte("foo"))
		if err != nil {
			// Handlers run on server goroutines: t.Fatalf is not allowed.
			t.Errorf("Write: %v", err)
		}
	})

	cst := newClientServerTest(t, mode, h)
	defer cst.close()
	ts := cst.ts
	c := ts.Client()
	tr := c.Transport.(*Transport)
	tr.MaxConnsPerHost = 1

	// This request is canceled when dial is queued, which preempts dialing.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	SetPendingDialHooks(cancel, nil)
	defer SetPendingDialHooks(nil, nil)

	req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
	resp, err := c.Do(req)
	if err == nil {
		// Unexpected success: don't leak the body before reporting.
		resp.Body.Close()
	}
	if !errors.Is(err, context.Canceled) {
		t.Errorf("expected error %v, got %v", context.Canceled, err)
	}

	// This request should succeed.
	SetPendingDialHooks(nil, nil)
	req, _ = NewRequest("GET", ts.URL, nil)
	resp, err = c.Do(req)
	if err != nil {
		t.Fatalf("request failed: %v", err)
	}
	defer resp.Body.Close()
	_, err = io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("read body failed: %v", err)
	}
}
782
func TestTransportRemovesDeadIdleConnections(t *testing.T) {
	run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode})
}

// testTransportRemovesDeadIdleConnections checks three things after the
// server's client connections are closed via ts.CloseClientConnections:
//   - the Transport eventually has no idle connection cache keys;
//   - the next request succeeds.
func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, r.RemoteAddr)
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	doReq := func(name string) {
		// Do a POST instead of a GET to prevent the Transport's
		// idempotent request retry logic from kicking in...
		res, err := c.Post(ts.URL, "", nil)
		if err != nil {
			t.Fatalf("%s: %v", name, err)
		}
		// Defer the Close before the status check so the body is not
		// leaked when that check fails.
		defer res.Body.Close()
		if res.StatusCode != 200 {
			t.Fatalf("%s: %v", name, res.Status)
		}
		slurp, err := io.ReadAll(res.Body)
		if err != nil {
			t.Fatalf("%s: %v", name, err)
		}
		t.Logf("%s: ok (%q)", name, slurp)
	}

	doReq("first")
	keys1 := tr.IdleConnKeysForTesting()

	ts.CloseClientConnections()

	var keys2 []string
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		keys2 = tr.IdleConnKeysForTesting()
		if len(keys2) != 0 {
			if d > 0 {
				t.Logf("Transport hasn't noticed idle connection's death in %v.\nbefore: %q\n after: %q\n", d, keys1, keys2)
			}
			return false
		}
		return true
	})

	doReq("second")
}
831
// Test that the Transport notices when a server hangs up on its
// unexpectedly (a keep-alive connection is closed).
func TestTransportServerClosingUnexpectedly(t *testing.T) {
	run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode})
}
func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, hostPortHandler).ts
	c := ts.Client()

	// fetch GETs ts.URL and returns the body. It makes up to retries extra
	// attempts, sleeping between them, before failing the test.
	fetch := func(n, retries int) string {
		condFatalf := func(format string, arg ...any) {
			if retries <= 0 {
				t.Fatalf(format, arg...)
			}
			t.Logf("retrying shortly after expected error: "+format, arg...)
			time.Sleep(time.Second / time.Duration(retries))
		}
		for retries >= 0 {
			retries--
			res, err := c.Get(ts.URL)
			if err != nil {
				condFatalf("error in req #%d, GET: %v", n, err)
				continue
			}
			body, err := io.ReadAll(res.Body)
			// Close on both the success and the retry path.
			res.Body.Close()
			if err != nil {
				condFatalf("error in req #%d, ReadAll: %v", n, err)
				continue
			}
			return string(body)
		}
		panic("unreachable")
	}

	body1 := fetch(1, 0)
	body2 := fetch(2, 0)

	// Close all the idle connections in a way that's similar to
	// the server hanging up on us. We don't use
	// httptest.Server.CloseClientConnections because it's
	// best-effort and stops blocking after 5 seconds. On a loaded
	// machine running many tests concurrently it's possible for
	// that method to be async and cause the body3 fetch below to
	// run on an old connection. This function is synchronous.
	ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))

	body3 := fetch(3, 5)

	if body1 != body2 {
		t.Errorf("expected body1 and body2 to be equal")
	}
	if body2 == body3 {
		t.Errorf("expected body2 and body3 to be different")
	}
}
888
// TestStressSurpriseServerCloses is a regression test for
// https://golang.org/issue/2616. It fails pretty reliably with a high
// GOMAXPROCS (e.g. 100).
func TestStressSurpriseServerCloses(t *testing.T) {
	run(t, testStressSurpriseServerCloses, []testMode{http1Mode})
}
func testStressSurpriseServerCloses(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping test in short mode")
	}
	// The handler writes a complete 5-byte response, then hijacks and closes
	// the connection out from under the client.
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		hdr := w.Header()
		hdr.Set("Content-Length", "5")
		hdr.Set("Content-Type", "text/plain")
		w.Write([]byte("Hello"))
		w.(Flusher).Flush()
		conn, buf, _ := w.(Hijacker).Hijack()
		buf.Flush()
		conn.Close()
	})).ts
	client := ts.Client()

	// Generate traffic from several goroutines; every request counts down
	// the WaitGroup whether or not it failed.
	// If these are too high, OS X exhausts its ephemeral ports
	// and hangs waiting for them to transition TCP states. That's
	// not what we want to test. TODO(bradfitz): use an io.Pipe
	// dialer for this test instead?
	const (
		numClients    = 20
		reqsPerClient = 25
	)
	var wg sync.WaitGroup
	wg.Add(numClients * reqsPerClient)
	worker := func() {
		for n := 0; n < reqsPerClient; n++ {
			// Errors are expected because the server hangs up after
			// inviting more requests, so the error itself is ignored.
			// When we win the race, still close the body.
			if res, err := client.Get(ts.URL); err == nil {
				res.Body.Close()
			}
			wg.Done()
		}
	}
	for n := 0; n < numClients; n++ {
		go worker()
	}

	// Make sure all the requests come back, one way or another.
	wg.Wait()
}
942
// TestTransportHeadResponses verifies that responses to HEAD requests carrying
// a Content-Length but no body are handled properly. Both the header and
// res.ContentLength must report the advertised length, and the body must read
// as empty.
func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) }
func testTransportHeadResponses(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "HEAD" {
			panic("expected HEAD; got " + r.Method)
		}
		w.Header().Set("Content-Length", "123")
		w.WriteHeader(200)
	})).ts
	c := ts.Client()

	for i := 0; i < 2; i++ {
		res, err := c.Head(ts.URL)
		if err != nil {
			t.Errorf("error on loop %d: %v", i, err)
			continue
		}
		if e, g := "123", res.Header.Get("Content-Length"); e != g {
			t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g)
		}
		if e, g := int64(123), res.ContentLength; e != g {
			t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g)
		}
		if all, err := io.ReadAll(res.Body); err != nil {
			t.Errorf("loop %d: Body ReadAll: %v", i, err)
		} else if len(all) != 0 {
			t.Errorf("Bogus body %q", all)
		}
		// Response bodies must always be closed.
		res.Body.Close()
	}
}
975
// TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding
// on responses to HEAD requests.
func TestTransportHeadChunkedResponse(t *testing.T) {
	run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel)
}

// testTransportHeadChunkedResponse sends two HEAD requests and checks that
// both responses report the same client ip:port. A response with a
// chunked Transfer-Encoding header must not have consumed the connection.
func testTransportHeadChunkedResponse(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method != "HEAD" {
			panic("expected HEAD; got " + r.Method)
		}
		w.Header().Set("Transfer-Encoding", "chunked") // client should ignore
		w.Header().Set("x-client-ipport", r.RemoteAddr)
		w.WriteHeader(200)
	})).ts
	c := ts.Client()

	// Ensure that we wait for the readLoop to complete before
	// calling Head again.
	didRead := make(chan bool)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	// Check each request's error before waiting on didRead: if the request
	// failed, the hook may never run and the receive would block forever.
	res1, err := c.Head(ts.URL)
	if err != nil {
		t.Fatalf("request 1 error: %v", err)
	}
	<-didRead

	res2, err := c.Head(ts.URL)
	if err != nil {
		t.Fatalf("request 2 error: %v", err)
	}
	<-didRead

	if v1, v2 := res1.Header.Get("x-client-ipport"), res2.Header.Get("x-client-ipport"); v1 != v2 {
		t.Errorf("ip/ports differed between head requests: %q vs %q", v1, v2)
	}
}
1015
// roundTripTests lists the Accept-Encoding scenarios used by testRoundTripGzip.
//   - accept is the header set on the outgoing request ("" means none).
//   - expectAccept is the value the handler should observe.
//   - compressed reports whether the body the caller reads is still
//     gzip-compressed.
var roundTripTests = []struct {
	accept       string
	expectAccept string
	compressed   bool
}{
	// No Accept-Encoding header: transparent compression is used.
	{accept: "", expectAccept: "gzip", compressed: false},
	// Any other Accept-Encoding passes through unmodified.
	{accept: "foo", expectAccept: "foo", compressed: false},
	// An explicit Accept-Encoding of gzip is passed through as well.
	{accept: "gzip", expectAccept: "gzip", compressed: true},
}
1028
// TestRoundTripGzip checks that Transport.RoundTrip handles each
// Accept-Encoding scenario in roundTripTests. It also checks that any
// modification the RoundTripper makes to the Request is cleaned up
// (RoundTrip must not mutate the caller's request).
func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) }
func testRoundTripGzip(t *testing.T, mode testMode) {
	const responseBody = "test response body"
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		accept := req.Header.Get("Accept-Encoding")
		if expect := req.FormValue("expect_accept"); accept != expect {
			t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
				req.FormValue("testnum"), accept, expect)
		}
		if accept == "gzip" {
			rw.Header().Set("Content-Encoding", "gzip")
			gz := gzip.NewWriter(rw)
			gz.Write([]byte(responseBody))
			gz.Close()
		} else {
			rw.Header().Set("Content-Encoding", accept)
			rw.Write([]byte(responseBody))
		}
	})).ts
	tr := ts.Client().Transport.(*Transport)

	for i, test := range roundTripTests {
		req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
		if test.accept != "" {
			req.Header.Set("Accept-Encoding", test.accept)
		}
		res, err := tr.RoundTrip(req)
		if err != nil {
			t.Errorf("%d. RoundTrip: %v", i, err)
			continue
		}
		var body []byte
		if test.compressed {
			var zr *gzip.Reader
			zr, err = gzip.NewReader(res.Body)
			if err != nil {
				res.Body.Close()
				t.Errorf("%d. gzip NewReader: %v", i, err)
				continue
			}
			body, err = io.ReadAll(zr)
			zr.Close()
		} else {
			body, err = io.ReadAll(res.Body)
		}
		// Close the response body on every path so the connection isn't leaked.
		res.Body.Close()
		if err != nil {
			t.Errorf("%d. Error: %q", i, err)
			continue
		}
		if g, e := string(body), responseBody; g != e {
			t.Errorf("%d. body = %q; want %q", i, g, e)
		}
		if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
			t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
		}
		if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
			t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
		}
	}
}
1091
1092func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) }
1093func testTransportGzip(t *testing.T, mode testMode) {
1094	if mode == http2Mode {
1095		t.Skip("https://go.dev/issue/56020")
1096	}
1097	const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
1098	const nRandBytes = 1024 * 1024
1099	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1100		if req.Method == "HEAD" {
1101			if g := req.Header.Get("Accept-Encoding"); g != "" {
1102				t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g)
1103			}
1104			return
1105		}
1106		if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e {
1107			t.Errorf("Accept-Encoding = %q, want %q", g, e)
1108		}
1109		rw.Header().Set("Content-Encoding", "gzip")
1110
1111		var w io.Writer = rw
1112		var buf bytes.Buffer
1113		if req.FormValue("chunked") == "0" {
1114			w = &buf
1115			defer io.Copy(rw, &buf)
1116			defer func() {
1117				rw.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
1118			}()
1119		}
1120		gz := gzip.NewWriter(w)
1121		gz.Write([]byte(testString))
1122		if req.FormValue("body") == "large" {
1123			io.CopyN(gz, rand.Reader, nRandBytes)
1124		}
1125		gz.Close()
1126	})).ts
1127	c := ts.Client()
1128
1129	for _, chunked := range []string{"1", "0"} {
1130		// First fetch something large, but only read some of it.
1131		res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked)
1132		if err != nil {
1133			t.Fatalf("large get: %v", err)
1134		}
1135		buf := make([]byte, len(testString))
1136		n, err := io.ReadFull(res.Body, buf)
1137		if err != nil {
1138			t.Fatalf("partial read of large response: size=%d, %v", n, err)
1139		}
1140		if e, g := testString, string(buf); e != g {
1141			t.Errorf("partial read got %q, expected %q", g, e)
1142		}
1143		res.Body.Close()
1144		// Read on the body, even though it's closed
1145		n, err = res.Body.Read(buf)
1146		if n != 0 || err == nil {
1147			t.Errorf("expected error post-closed large Read; got = %d, %v", n, err)
1148		}
1149
1150		// Then something small.
1151		res, err = c.Get(ts.URL + "/?chunked=" + chunked)
1152		if err != nil {
1153			t.Fatal(err)
1154		}
1155		body, err := io.ReadAll(res.Body)
1156		if err != nil {
1157			t.Fatal(err)
1158		}
1159		if g, e := string(body), testString; g != e {
1160			t.Fatalf("body = %q; want %q", g, e)
1161		}
1162		if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1163			t.Fatalf("Content-Encoding = %q; want %q", g, e)
1164		}
1165
1166		// Read on the body after it's been fully read:
1167		n, err = res.Body.Read(buf)
1168		if n != 0 || err == nil {
1169			t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err)
1170		}
1171		res.Body.Close()
1172		n, err = res.Body.Read(buf)
1173		if n != 0 || err == nil {
1174			t.Errorf("expected Read error after Close; got %d, %v", n, err)
1175		}
1176	}
1177
1178	// And a HEAD request too, because they're always weird.
1179	res, err := c.Head(ts.URL)
1180	if err != nil {
1181		t.Fatalf("Head: %v", err)
1182	}
1183	if res.StatusCode != 200 {
1184		t.Errorf("Head status=%d; want=200", res.StatusCode)
1185	}
1186}
1187
// A transport100ContinueTest exercises Transport behaviors when sending a
// request with an Expect: 100-continue header. The Transport side runs in a
// goroutine started by newTransport100ContinueTest. The test itself plays
// the server over conn/reader.
type transport100ContinueTest struct {
	t *testing.T

	// reqdone is closed when the RoundTrip goroutine finishes. resp and
	// respErr hold RoundTrip's results; read them only after reqdone is closed.
	reqdone chan struct{}
	resp    *Response
	respErr error

	// conn is the server side of the accepted connection. reader buffers
	// reads from it; the request headers have already been consumed.
	conn   net.Conn
	reader *bufio.Reader
}

// transport100ContinueTestBody is the request body the Transport is asked to send.
const transport100ContinueTestBody = "request body"
1202
1203// newTransport100ContinueTest creates a Transport and sends an Expect: 100-continue
1204// request on it.
1205func newTransport100ContinueTest(t *testing.T, timeout time.Duration) *transport100ContinueTest {
1206	ln := newLocalListener(t)
1207	defer ln.Close()
1208
1209	test := &transport100ContinueTest{
1210		t:       t,
1211		reqdone: make(chan struct{}),
1212	}
1213
1214	tr := &Transport{
1215		ExpectContinueTimeout: timeout,
1216	}
1217	go func() {
1218		defer close(test.reqdone)
1219		body := strings.NewReader(transport100ContinueTestBody)
1220		req, _ := NewRequest("PUT", "http://"+ln.Addr().String(), body)
1221		req.Header.Set("Expect", "100-continue")
1222		req.ContentLength = int64(len(transport100ContinueTestBody))
1223		test.resp, test.respErr = tr.RoundTrip(req)
1224		test.resp.Body.Close()
1225	}()
1226
1227	c, err := ln.Accept()
1228	if err != nil {
1229		t.Fatalf("Accept: %v", err)
1230	}
1231	t.Cleanup(func() {
1232		c.Close()
1233	})
1234	br := bufio.NewReader(c)
1235	_, err = ReadRequest(br)
1236	if err != nil {
1237		t.Fatalf("ReadRequest: %v", err)
1238	}
1239	test.conn = c
1240	test.reader = br
1241	t.Cleanup(func() {
1242		<-test.reqdone
1243		tr.CloseIdleConnections()
1244		got, _ := io.ReadAll(test.reader)
1245		if len(got) > 0 {
1246			t.Fatalf("Transport sent unexpected bytes: %q", got)
1247		}
1248	})
1249
1250	return test
1251}
1252
1253// respond sends response lines from the server to the transport.
1254func (test *transport100ContinueTest) respond(lines ...string) {
1255	for _, line := range lines {
1256		if _, err := test.conn.Write([]byte(line + "\r\n")); err != nil {
1257			test.t.Fatalf("Write: %v", err)
1258		}
1259	}
1260	if _, err := test.conn.Write([]byte("\r\n")); err != nil {
1261		test.t.Fatalf("Write: %v", err)
1262	}
1263}
1264
1265// wantBodySent ensures the transport has sent the request body to the server.
1266func (test *transport100ContinueTest) wantBodySent() {
1267	got, err := io.ReadAll(io.LimitReader(test.reader, int64(len(transport100ContinueTestBody))))
1268	if err != nil {
1269		test.t.Fatalf("unexpected error reading body: %v", err)
1270	}
1271	if got, want := string(got), transport100ContinueTestBody; got != want {
1272		test.t.Fatalf("unexpected body: got %q, want %q", got, want)
1273	}
1274}
1275
1276// wantRequestDone ensures the Transport.RoundTrip has completed with the expected status.
1277func (test *transport100ContinueTest) wantRequestDone(want int) {
1278	<-test.reqdone
1279	if test.respErr != nil {
1280		test.t.Fatalf("unexpected RoundTrip error: %v", test.respErr)
1281	}
1282	if got := test.resp.StatusCode; got != want {
1283		test.t.Fatalf("unexpected response code: got %v, want %v", got, want)
1284	}
1285}
1286
1287func TestTransportExpect100ContinueSent(t *testing.T) {
1288	test := newTransport100ContinueTest(t, 1*time.Hour)
1289	// Server sends a 100 Continue response, and the client sends the request body.
1290	test.respond("HTTP/1.1 100 Continue")
1291	test.wantBodySent()
1292	test.respond("HTTP/1.1 200", "Content-Length: 0")
1293	test.wantRequestDone(200)
1294}
1295
1296func TestTransportExpect100Continue200ResponseNoConnClose(t *testing.T) {
1297	test := newTransport100ContinueTest(t, 1*time.Hour)
1298	// No 100 Continue response, no Connection: close header.
1299	test.respond("HTTP/1.1 200", "Content-Length: 0")
1300	test.wantBodySent()
1301	test.wantRequestDone(200)
1302}
1303
1304func TestTransportExpect100Continue200ResponseWithConnClose(t *testing.T) {
1305	test := newTransport100ContinueTest(t, 1*time.Hour)
1306	// No 100 Continue response, Connection: close header set.
1307	test.respond("HTTP/1.1 200", "Connection: close", "Content-Length: 0")
1308	test.wantRequestDone(200)
1309}
1310
1311func TestTransportExpect100Continue500ResponseNoConnClose(t *testing.T) {
1312	test := newTransport100ContinueTest(t, 1*time.Hour)
1313	// No 100 Continue response, no Connection: close header.
1314	test.respond("HTTP/1.1 500", "Content-Length: 0")
1315	test.wantBodySent()
1316	test.wantRequestDone(500)
1317}
1318
1319func TestTransportExpect100Continue500ResponseTimeout(t *testing.T) {
1320	test := newTransport100ContinueTest(t, 5*time.Millisecond) // short timeout
1321	test.wantBodySent()                                        // after timeout
1322	test.respond("HTTP/1.1 200", "Content-Length: 0")
1323	test.wantRequestDone(200)
1324}
1325
1326func TestSOCKS5Proxy(t *testing.T) {
1327	run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode})
1328}
1329func testSOCKS5Proxy(t *testing.T, mode testMode) {
1330	ch := make(chan string, 1)
1331	l := newLocalListener(t)
1332	defer l.Close()
1333	defer close(ch)
1334	proxy := func(t *testing.T) {
1335		s, err := l.Accept()
1336		if err != nil {
1337			t.Errorf("socks5 proxy Accept(): %v", err)
1338			return
1339		}
1340		defer s.Close()
1341		var buf [22]byte
1342		if _, err := io.ReadFull(s, buf[:3]); err != nil {
1343			t.Errorf("socks5 proxy initial read: %v", err)
1344			return
1345		}
1346		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1347			t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want)
1348			return
1349		}
1350		if _, err := s.Write([]byte{5, 0}); err != nil {
1351			t.Errorf("socks5 proxy initial write: %v", err)
1352			return
1353		}
1354		if _, err := io.ReadFull(s, buf[:4]); err != nil {
1355			t.Errorf("socks5 proxy second read: %v", err)
1356			return
1357		}
1358		if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) {
1359			t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want)
1360			return
1361		}
1362		var ipLen int
1363		switch buf[3] {
1364		case 1:
1365			ipLen = net.IPv4len
1366		case 4:
1367			ipLen = net.IPv6len
1368		default:
1369			t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
1370			return
1371		}
1372		if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
1373			t.Errorf("socks5 proxy address read: %v", err)
1374			return
1375		}
1376		ip := net.IP(buf[4 : ipLen+4])
1377		port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6])
1378		copy(buf[:3], []byte{5, 0, 0})
1379		if _, err := s.Write(buf[:ipLen+6]); err != nil {
1380			t.Errorf("socks5 proxy connect write: %v", err)
1381			return
1382		}
1383		ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
1384
1385		// Implement proxying.
1386		targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
1387		targetConn, err := net.Dial("tcp", targetHost)
1388		if err != nil {
1389			t.Errorf("net.Dial failed")
1390			return
1391		}
1392		go io.Copy(targetConn, s)
1393		io.Copy(s, targetConn) // Wait for the client to close the socket.
1394		targetConn.Close()
1395	}
1396
1397	pu, err := url.Parse("socks5://" + l.Addr().String())
1398	if err != nil {
1399		t.Fatal(err)
1400	}
1401
1402	sentinelHeader := "X-Sentinel"
1403	sentinelValue := "12345"
1404	h := HandlerFunc(func(w ResponseWriter, r *Request) {
1405		w.Header().Set(sentinelHeader, sentinelValue)
1406	})
1407	for _, useTLS := range []bool{false, true} {
1408		t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
1409			ts := newClientServerTest(t, mode, h).ts
1410			go proxy(t)
1411			c := ts.Client()
1412			c.Transport.(*Transport).Proxy = ProxyURL(pu)
1413			r, err := c.Head(ts.URL)
1414			if err != nil {
1415				t.Fatal(err)
1416			}
1417			if r.Header.Get(sentinelHeader) != sentinelValue {
1418				t.Errorf("Failed to retrieve sentinel value")
1419			}
1420			got := <-ch
1421			ts.Close()
1422			tsu, err := url.Parse(ts.URL)
1423			if err != nil {
1424				t.Fatal(err)
1425			}
1426			want := "proxy for " + tsu.Host
1427			if got != want {
1428				t.Errorf("got %q, want %q", got, want)
1429			}
1430		})
1431	}
1432}
1433
1434func TestTransportProxy(t *testing.T) {
1435	defer afterTest(t)
1436	testCases := []struct{ siteMode, proxyMode testMode }{
1437		{http1Mode, http1Mode},
1438		{http1Mode, https1Mode},
1439		{https1Mode, http1Mode},
1440		{https1Mode, https1Mode},
1441	}
1442	for _, testCase := range testCases {
1443		siteMode := testCase.siteMode
1444		proxyMode := testCase.proxyMode
1445		t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) {
1446			siteCh := make(chan *Request, 1)
1447			h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1448				siteCh <- r
1449			})
1450			proxyCh := make(chan *Request, 1)
1451			h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1452				proxyCh <- r
1453				// Implement an entire CONNECT proxy
1454				if r.Method == "CONNECT" {
1455					hijacker, ok := w.(Hijacker)
1456					if !ok {
1457						t.Errorf("hijack not allowed")
1458						return
1459					}
1460					clientConn, _, err := hijacker.Hijack()
1461					if err != nil {
1462						t.Errorf("hijacking failed")
1463						return
1464					}
1465					res := &Response{
1466						StatusCode: StatusOK,
1467						Proto:      "HTTP/1.1",
1468						ProtoMajor: 1,
1469						ProtoMinor: 1,
1470						Header:     make(Header),
1471					}
1472
1473					targetConn, err := net.Dial("tcp", r.URL.Host)
1474					if err != nil {
1475						t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1476						return
1477					}
1478
1479					if err := res.Write(clientConn); err != nil {
1480						t.Errorf("Writing 200 OK failed: %v", err)
1481						return
1482					}
1483
1484					go io.Copy(targetConn, clientConn)
1485					go func() {
1486						io.Copy(clientConn, targetConn)
1487						targetConn.Close()
1488					}()
1489				}
1490			})
1491			ts := newClientServerTest(t, siteMode, h1).ts
1492			proxy := newClientServerTest(t, proxyMode, h2).ts
1493
1494			pu, err := url.Parse(proxy.URL)
1495			if err != nil {
1496				t.Fatal(err)
1497			}
1498
1499			// If neither server is HTTPS or both are, then c may be derived from either.
1500			// If only one server is HTTPS, c must be derived from that server in order
1501			// to ensure that it is configured to use the fake root CA from testcert.go.
1502			c := proxy.Client()
1503			if siteMode == https1Mode {
1504				c = ts.Client()
1505			}
1506
1507			c.Transport.(*Transport).Proxy = ProxyURL(pu)
1508			if _, err := c.Head(ts.URL); err != nil {
1509				t.Error(err)
1510			}
1511			got := <-proxyCh
1512			c.Transport.(*Transport).CloseIdleConnections()
1513			ts.Close()
1514			proxy.Close()
1515			if siteMode == https1Mode {
1516				// First message should be a CONNECT, asking for a socket to the real server,
1517				if got.Method != "CONNECT" {
1518					t.Errorf("Wrong method for secure proxying: %q", got.Method)
1519				}
1520				gotHost := got.URL.Host
1521				pu, err := url.Parse(ts.URL)
1522				if err != nil {
1523					t.Fatal("Invalid site URL")
1524				}
1525				if wantHost := pu.Host; gotHost != wantHost {
1526					t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
1527				}
1528
1529				// The next message on the channel should be from the site's server.
1530				next := <-siteCh
1531				if next.Method != "HEAD" {
1532					t.Errorf("Wrong method at destination: %s", next.Method)
1533				}
1534				if nextURL := next.URL.String(); nextURL != "/" {
1535					t.Errorf("Wrong URL at destination: %s", nextURL)
1536				}
1537			} else {
1538				if got.Method != "HEAD" {
1539					t.Errorf("Wrong method for destination: %q", got.Method)
1540				}
1541				gotURL := got.URL.String()
1542				wantURL := ts.URL + "/"
1543				if gotURL != wantURL {
1544					t.Errorf("Got URL %q, want %q", gotURL, wantURL)
1545				}
1546			}
1547		})
1548	}
1549}
1550
1551func TestOnProxyConnectResponse(t *testing.T) {
1552
1553	var tcases = []struct {
1554		proxyStatusCode int
1555		err             error
1556	}{
1557		{
1558			StatusOK,
1559			nil,
1560		},
1561		{
1562			StatusForbidden,
1563			errors.New("403"),
1564		},
1565	}
1566	for _, tcase := range tcases {
1567		h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
1568
1569		})
1570
1571		h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
1572			// Implement an entire CONNECT proxy
1573			if r.Method == "CONNECT" {
1574				if tcase.proxyStatusCode != StatusOK {
1575					w.WriteHeader(tcase.proxyStatusCode)
1576					return
1577				}
1578				hijacker, ok := w.(Hijacker)
1579				if !ok {
1580					t.Errorf("hijack not allowed")
1581					return
1582				}
1583				clientConn, _, err := hijacker.Hijack()
1584				if err != nil {
1585					t.Errorf("hijacking failed")
1586					return
1587				}
1588				res := &Response{
1589					StatusCode: StatusOK,
1590					Proto:      "HTTP/1.1",
1591					ProtoMajor: 1,
1592					ProtoMinor: 1,
1593					Header:     make(Header),
1594				}
1595
1596				targetConn, err := net.Dial("tcp", r.URL.Host)
1597				if err != nil {
1598					t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
1599					return
1600				}
1601
1602				if err := res.Write(clientConn); err != nil {
1603					t.Errorf("Writing 200 OK failed: %v", err)
1604					return
1605				}
1606
1607				go io.Copy(targetConn, clientConn)
1608				go func() {
1609					io.Copy(clientConn, targetConn)
1610					targetConn.Close()
1611				}()
1612			}
1613		})
1614		ts := newClientServerTest(t, https1Mode, h1).ts
1615		proxy := newClientServerTest(t, https1Mode, h2).ts
1616
1617		pu, err := url.Parse(proxy.URL)
1618		if err != nil {
1619			t.Fatal(err)
1620		}
1621
1622		c := proxy.Client()
1623
1624		var (
1625			dials  atomic.Int32
1626			closes atomic.Int32
1627		)
1628		c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
1629			conn, err := net.Dial(network, addr)
1630			if err != nil {
1631				return nil, err
1632			}
1633			dials.Add(1)
1634			return noteCloseConn{
1635				Conn: conn,
1636				closeFunc: func() {
1637					closes.Add(1)
1638				},
1639			}, nil
1640		}
1641
1642		c.Transport.(*Transport).Proxy = ProxyURL(pu)
1643		c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
1644			if proxyURL.String() != pu.String() {
1645				t.Errorf("proxy url got %s, want %s", proxyURL, pu)
1646			}
1647
1648			if "https://"+connectReq.URL.String() != ts.URL {
1649				t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
1650			}
1651			return tcase.err
1652		}
1653		wantCloses := int32(0)
1654		if _, err := c.Head(ts.URL); err != nil {
1655			wantCloses = 1
1656			if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
1657				t.Errorf("got %v, want %v", err, tcase.err)
1658			}
1659		} else {
1660			if tcase.err != nil {
1661				t.Errorf("got %v, want nil", err)
1662			}
1663		}
1664		if got, want := dials.Load(), int32(1); got != want {
1665			t.Errorf("got %v dials, want %v", got, want)
1666		}
1667		// #64804: If OnProxyConnectResponse returns an error, we should close the conn.
1668		if got, want := closes.Load(), wantCloses; got != want {
1669			t.Errorf("got %v closes, want %v", got, want)
1670		}
1671	}
1672}
1673
1674// Issue 28012: verify that the Transport closes its TCP connection to http proxies
1675// when they're slow to reply to HTTPS CONNECT responses.
1676func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
1677	cancelc := make(chan struct{})
1678	SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
1679		ctx, cancel := context.WithCancel(ctx)
1680		go func() {
1681			select {
1682			case <-cancelc:
1683			case <-ctx.Done():
1684			}
1685			cancel()
1686		}()
1687		return ctx, cancel
1688	})
1689
1690	defer afterTest(t)
1691
1692	ln := newLocalListener(t)
1693	defer ln.Close()
1694	listenerDone := make(chan struct{})
1695	go func() {
1696		defer close(listenerDone)
1697		c, err := ln.Accept()
1698		if err != nil {
1699			t.Errorf("Accept: %v", err)
1700			return
1701		}
1702		defer c.Close()
1703		// Read the CONNECT request
1704		br := bufio.NewReader(c)
1705		cr, err := ReadRequest(br)
1706		if err != nil {
1707			t.Errorf("proxy server failed to read CONNECT request")
1708			return
1709		}
1710		if cr.Method != "CONNECT" {
1711			t.Errorf("unexpected method %q", cr.Method)
1712			return
1713		}
1714
1715		// Now hang and never write a response; instead, cancel the request and wait
1716		// for the client to close.
1717		// (Prior to Issue 28012 being fixed, we never closed.)
1718		close(cancelc)
1719		var buf [1]byte
1720		_, err = br.Read(buf[:])
1721		if err != io.EOF {
1722			t.Errorf("proxy server Read err = %v; want EOF", err)
1723		}
1724		return
1725	}()
1726
1727	c := &Client{
1728		Transport: &Transport{
1729			Proxy: func(*Request) (*url.URL, error) {
1730				return url.Parse("http://" + ln.Addr().String())
1731			},
1732		},
1733	}
1734	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
1735	if err != nil {
1736		t.Fatal(err)
1737	}
1738	_, err = c.Do(req)
1739	if err == nil {
1740		t.Errorf("unexpected Get success")
1741	}
1742
1743	// Wait unconditionally for the listener goroutine to exit: this should never
1744	// hang, so if it does we want a full goroutine dump — and that's exactly what
1745	// the testing package will give us when the test run times out.
1746	<-listenerDone
1747}
1748
1749// Issue 16997: test transport dial preserves typed errors
1750func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
1751	defer afterTest(t)
1752
1753	var errDial = errors.New("some dial error")
1754
1755	tr := &Transport{
1756		Proxy: func(*Request) (*url.URL, error) {
1757			return url.Parse("http://proxy.fake.tld/")
1758		},
1759		Dial: func(string, string) (net.Conn, error) {
1760			return nil, errDial
1761		},
1762	}
1763	defer tr.CloseIdleConnections()
1764
1765	c := &Client{Transport: tr}
1766	req, _ := NewRequest("GET", "http://fake.tld", nil)
1767	res, err := c.Do(req)
1768	if err == nil {
1769		res.Body.Close()
1770		t.Fatal("wanted a non-nil error")
1771	}
1772
1773	uerr, ok := err.(*url.Error)
1774	if !ok {
1775		t.Fatalf("got %T, want *url.Error", err)
1776	}
1777	oe, ok := uerr.Err.(*net.OpError)
1778	if !ok {
1779		t.Fatalf("url.Error.Err =  %T; want *net.OpError", uerr.Err)
1780	}
1781	want := &net.OpError{
1782		Op:  "proxyconnect",
1783		Net: "tcp",
1784		Err: errDial, // original error, unwrapped.
1785	}
1786	if !reflect.DeepEqual(oe, want) {
1787		t.Errorf("Got error %#v; want %#v", oe, want)
1788	}
1789}
1790
1791// Issue 36431: calls to RoundTrip should not mutate t.ProxyConnectHeader.
1792//
1793// (A bug caused dialConn to instead write the per-request Proxy-Authorization
1794// header through to the shared Header instance, introducing a data race.)
1795func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) {
1796	run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader)
1797}
1798func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) {
1799	proxy := newClientServerTest(t, mode, NotFoundHandler()).ts
1800	defer proxy.Close()
1801	c := proxy.Client()
1802
1803	tr := c.Transport.(*Transport)
1804	tr.Proxy = func(*Request) (*url.URL, error) {
1805		u, _ := url.Parse(proxy.URL)
1806		u.User = url.UserPassword("aladdin", "opensesame")
1807		return u, nil
1808	}
1809	h := tr.ProxyConnectHeader
1810	if h == nil {
1811		h = make(Header)
1812	}
1813	tr.ProxyConnectHeader = h.Clone()
1814
1815	req, err := NewRequest("GET", "https://golang.fake.tld/", nil)
1816	if err != nil {
1817		t.Fatal(err)
1818	}
1819	_, err = c.Do(req)
1820	if err == nil {
1821		t.Errorf("unexpected Get success")
1822	}
1823
1824	if !reflect.DeepEqual(tr.ProxyConnectHeader, h) {
1825		t.Errorf("tr.ProxyConnectHeader = %v; want %v", tr.ProxyConnectHeader, h)
1826	}
1827}
1828
1829// TestTransportGzipRecursive sends a gzip quine and checks that the
1830// client gets the same value back. This is more cute than anything,
1831// but checks that we don't recurse forever, and checks that
1832// Content-Encoding is removed.
1833func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) }
1834func testTransportGzipRecursive(t *testing.T, mode testMode) {
1835	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1836		w.Header().Set("Content-Encoding", "gzip")
1837		w.Write(rgz)
1838	})).ts
1839
1840	c := ts.Client()
1841	res, err := c.Get(ts.URL)
1842	if err != nil {
1843		t.Fatal(err)
1844	}
1845	body, err := io.ReadAll(res.Body)
1846	if err != nil {
1847		t.Fatal(err)
1848	}
1849	if !bytes.Equal(body, rgz) {
1850		t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x",
1851			body, rgz)
1852	}
1853	if g, e := res.Header.Get("Content-Encoding"), ""; g != e {
1854		t.Fatalf("Content-Encoding = %q; want %q", g, e)
1855	}
1856}
1857
1858// golang.org/issue/7750: request fails when server replies with
1859// a short gzip body
1860func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) }
1861func testTransportGzipShort(t *testing.T, mode testMode) {
1862	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1863		w.Header().Set("Content-Encoding", "gzip")
1864		w.Write([]byte{0x1f, 0x8b})
1865	})).ts
1866
1867	c := ts.Client()
1868	res, err := c.Get(ts.URL)
1869	if err != nil {
1870		t.Fatal(err)
1871	}
1872	defer res.Body.Close()
1873	_, err = io.ReadAll(res.Body)
1874	if err == nil {
1875		t.Fatal("Expect an error from reading a body.")
1876	}
1877	if err != io.ErrUnexpectedEOF {
1878		t.Errorf("ReadAll error = %v; want io.ErrUnexpectedEOF", err)
1879	}
1880}
1881
// waitNumGoroutine polls runtime.NumGoroutine until it is at most nmax. It
// retries up to 10 times, sleeping 50ms and forcing a GC before each retry.
// It returns the last observed count, which may still exceed nmax.
func waitNumGoroutine(nmax int) int {
	n := runtime.NumGoroutine()
	for attempt := 0; attempt < 10; attempt++ {
		if n <= nmax {
			break
		}
		time.Sleep(50 * time.Millisecond)
		runtime.GC()
		n = runtime.NumGoroutine()
	}
	return n
}
1892
// TestTransportPersistConnLeak checks that the goroutines backing persistent
// connections shut down once the connections are no longer desired, i.e.
// after the requests finish and CloseIdleConnections is called.
func TestTransportPersistConnLeak(t *testing.T) {
	run(t, testTransportPersistConnLeak, testNotParallel)
}
func testTransportPersistConnLeak(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("flaky in HTTP/2")
	}
	// Not parallel: counts goroutines

	const numReq = 25
	// Both channels are buffered to numReq so no handler blocks on signalling.
	// unblockCh is closed (not sent on) below to release every handler at once.
	gotReqCh := make(chan bool, numReq)
	unblockCh := make(chan bool, numReq)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		gotReqCh <- true
		<-unblockCh
		w.Header().Set("Content-Length", "0")
		w.WriteHeader(204)
	})).ts
	c := ts.Client()
	tr := c.Transport.(*Transport)

	// Baseline count, taken after the server is up but before any client
	// goroutines start.
	n0 := runtime.NumGoroutine()

	didReqCh := make(chan bool, numReq)
	failed := make(chan bool, numReq)
	for i := 0; i < numReq; i++ {
		go func() {
			res, err := c.Get(ts.URL)
			// Signal before checking err so the wait loop below always
			// receives exactly numReq values.
			didReqCh <- true
			if err != nil {
				t.Logf("client fetch error: %v", err)
				failed <- true
				return
			}
			res.Body.Close()
		}()
	}

	// Wait for all goroutines to be stuck in the Handler.
	for i := 0; i < numReq; i++ {
		select {
		case <-gotReqCh:
			// ok
		case <-failed:
			// Not great but not what we are testing:
			// sometimes an overloaded system will fail to make all the connections.
		}
	}

	// Peak count, logged only for diagnostics.
	nhigh := runtime.NumGoroutine()

	// Tell all handlers to unblock and reply.
	close(unblockCh)

	// Wait for all HTTP clients to be done.
	for i := 0; i < numReq; i++ {
		<-didReqCh
	}

	tr.CloseIdleConnections()
	nfinal := waitNumGoroutine(n0 + 5)

	growth := nfinal - n0

	// We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
	// Previously we were leaking one per numReq.
	if int(growth) > 5 {
		t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
		t.Error("too many new goroutines")
	}
}
1965
1966// golang.org/issue/4531: Transport leaks goroutines when
1967// request.ContentLength is explicitly short
1968func TestTransportPersistConnLeakShortBody(t *testing.T) {
1969	run(t, testTransportPersistConnLeakShortBody, testNotParallel)
1970}
1971func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) {
1972	if mode == http2Mode {
1973		t.Skip("flaky in HTTP/2")
1974	}
1975
1976	// Not parallel: measures goroutines.
1977	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1978	})).ts
1979	c := ts.Client()
1980	tr := c.Transport.(*Transport)
1981
1982	n0 := runtime.NumGoroutine()
1983	body := []byte("Hello")
1984	for i := 0; i < 20; i++ {
1985		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
1986		if err != nil {
1987			t.Fatal(err)
1988		}
1989		req.ContentLength = int64(len(body) - 2) // explicitly short
1990		_, err = c.Do(req)
1991		if err == nil {
1992			t.Fatal("Expect an error from writing too long of a body.")
1993		}
1994	}
1995	nhigh := runtime.NumGoroutine()
1996	tr.CloseIdleConnections()
1997	nfinal := waitNumGoroutine(n0 + 5)
1998
1999	growth := nfinal - n0
2000
2001	// We expect 0 or 1 extra goroutine, empirically. Allow up to 5.
2002	// Previously we were leaking one per numReq.
2003	t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
2004	if int(growth) > 5 {
2005		t.Error("too many new goroutines")
2006	}
2007}
2008
// A countedConn wraps a net.Conn returned by countingDialer.DialContext. It
// adds no behavior of its own. The dialer registers a finalizer on each
// countedConn that decrements the dialer's live count once the wrapper is
// garbage collected.
type countedConn struct {
	net.Conn
}

// A countingDialer dials connections and counts the number that remain reachable.
// total counts every successful dial. live counts dialed connections whose
// countedConn wrapper has not yet been finalized. Both are guarded by mu.
type countingDialer struct {
	dialer      net.Dialer
	mu          sync.Mutex
	total, live int64
}

// DialContext dials with d.dialer and wraps the result in a countedConn. It
// increments both counters and registers a finalizer that decrements live
// when the wrapper becomes unreachable. Failed dials are not counted.
func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	conn, err := d.dialer.DialContext(ctx, network, address)
	if err != nil {
		return nil, err
	}

	counted := &countedConn{Conn: conn}

	d.mu.Lock()
	defer d.mu.Unlock()
	d.total++
	d.live++

	runtime.SetFinalizer(counted, d.decrement)
	return counted, nil
}

// decrement is the finalizer registered by DialContext. It decrements live.
func (d *countingDialer) decrement(*countedConn) {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.live--
}

// Read returns a consistent snapshot of the total and live counts.
func (d *countingDialer) Read() (total, live int64) {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.total, d.live
}
2050
2051func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
2052	run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode})
2053}
2054func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) {
2055	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2056		// Close every connection so that it cannot be kept alive.
2057		conn, _, err := w.(Hijacker).Hijack()
2058		if err != nil {
2059			t.Errorf("Hijack failed unexpectedly: %v", err)
2060			return
2061		}
2062		conn.Close()
2063	})).ts
2064
2065	var d countingDialer
2066	c := ts.Client()
2067	c.Transport.(*Transport).DialContext = d.DialContext
2068
2069	body := []byte("Hello")
2070	for i := 0; ; i++ {
2071		total, live := d.Read()
2072		if live < total {
2073			break
2074		}
2075		if i >= 1<<12 {
2076			t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
2077		}
2078
2079		req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
2080		if err != nil {
2081			t.Fatal(err)
2082		}
2083		_, err = c.Do(req)
2084		if err == nil {
2085			t.Fatal("expected broken connection")
2086		}
2087
2088		runtime.GC()
2089	}
2090}
2091
// countedContext wraps a context.Context in its own heap object so that a
// finalizer can observe when that particular wrapper becomes unreachable.
// All Context methods are promoted from the embedded value.
type countedContext struct {
	context.Context
}

// contextCounter counts how many contexts returned by Track have not yet
// been finalized. live is guarded by mu.
type contextCounter struct {
	mu   sync.Mutex
	live int64
}

// Track wraps ctx in a countedContext, increments the live count, and
// registers a finalizer that decrements it once the wrapper is collected.
func (cc *contextCounter) Track(ctx context.Context) context.Context {
	wrapped := &countedContext{Context: ctx}
	cc.mu.Lock()
	cc.live++
	runtime.SetFinalizer(wrapped, cc.decrement)
	cc.mu.Unlock()
	return wrapped
}

// decrement is the finalizer for countedContext values.
func (cc *contextCounter) decrement(*countedContext) {
	cc.mu.Lock()
	cc.live--
	cc.mu.Unlock()
}

// Read reports how many tracked contexts are still live.
func (cc *contextCounter) Read() (live int64) {
	cc.mu.Lock()
	live = cc.live
	cc.mu.Unlock()
	return live
}
2122
2123func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
2124	run(t, testTransportPersistConnContextLeakMaxConnsPerHost)
2125}
2126func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) {
2127	if mode == http2Mode {
2128		t.Skip("https://go.dev/issue/56021")
2129	}
2130
2131	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2132		runtime.Gosched()
2133		w.WriteHeader(StatusOK)
2134	})).ts
2135
2136	c := ts.Client()
2137	c.Transport.(*Transport).MaxConnsPerHost = 1
2138
2139	ctx := context.Background()
2140	body := []byte("Hello")
2141	doPosts := func(cc *contextCounter) {
2142		var wg sync.WaitGroup
2143		for n := 64; n > 0; n-- {
2144			wg.Add(1)
2145			go func() {
2146				defer wg.Done()
2147
2148				ctx := cc.Track(ctx)
2149				req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
2150				if err != nil {
2151					t.Error(err)
2152				}
2153
2154				_, err = c.Do(req.WithContext(ctx))
2155				if err != nil {
2156					t.Errorf("Do failed with error: %v", err)
2157				}
2158			}()
2159		}
2160		wg.Wait()
2161	}
2162
2163	var initialCC contextCounter
2164	doPosts(&initialCC)
2165
2166	// flushCC exists only to put pressure on the GC to finalize the initialCC
2167	// contexts: the flushCC allocations should eventually displace the initialCC
2168	// allocations.
2169	var flushCC contextCounter
2170	for i := 0; ; i++ {
2171		live := initialCC.Read()
2172		if live == 0 {
2173			break
2174		}
2175		if i >= 100 {
2176			t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
2177		}
2178		doPosts(&flushCC)
2179		runtime.GC()
2180	}
2181}
2182
2183// This used to crash; https://golang.org/issue/3266
2184func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) }
2185func testTransportIdleConnCrash(t *testing.T, mode testMode) {
2186	var tr *Transport
2187
2188	unblockCh := make(chan bool, 1)
2189	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2190		<-unblockCh
2191		tr.CloseIdleConnections()
2192	})).ts
2193	c := ts.Client()
2194	tr = c.Transport.(*Transport)
2195
2196	didreq := make(chan bool)
2197	go func() {
2198		res, err := c.Get(ts.URL)
2199		if err != nil {
2200			t.Error(err)
2201		} else {
2202			res.Body.Close() // returns idle conn
2203		}
2204		didreq <- true
2205	}()
2206	unblockCh <- true
2207	<-didreq
2208}
2209
2210// Test that the transport doesn't close the TCP connection early,
2211// before the response body has been read. This was a regression
2212// which sadly lacked a triggering test. The large response body made
2213// the old race easier to trigger.
2214func TestIssue3644(t *testing.T) { run(t, testIssue3644) }
2215func testIssue3644(t *testing.T, mode testMode) {
2216	const numFoos = 5000
2217	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2218		w.Header().Set("Connection", "close")
2219		for i := 0; i < numFoos; i++ {
2220			w.Write([]byte("foo "))
2221		}
2222	})).ts
2223	c := ts.Client()
2224	res, err := c.Get(ts.URL)
2225	if err != nil {
2226		t.Fatal(err)
2227	}
2228	defer res.Body.Close()
2229	bs, err := io.ReadAll(res.Body)
2230	if err != nil {
2231		t.Fatal(err)
2232	}
2233	if len(bs) != numFoos*len("foo ") {
2234		t.Errorf("unexpected response length")
2235	}
2236}
2237
2238// Test that a client receives a server's reply, even if the server doesn't read
2239// the entire request body.
2240func TestIssue3595(t *testing.T) {
2241	// Not parallel: modifies the global rstAvoidanceDelay.
2242	run(t, testIssue3595, testNotParallel)
2243}
2244func testIssue3595(t *testing.T, mode testMode) {
2245	runTimeSensitiveTest(t, []time.Duration{
2246		1 * time.Millisecond,
2247		5 * time.Millisecond,
2248		10 * time.Millisecond,
2249		50 * time.Millisecond,
2250		100 * time.Millisecond,
2251		500 * time.Millisecond,
2252		time.Second,
2253		5 * time.Second,
2254	}, func(t *testing.T, timeout time.Duration) error {
2255		SetRSTAvoidanceDelay(t, timeout)
2256		t.Logf("set RST avoidance delay to %v", timeout)
2257
2258		const deniedMsg = "sorry, denied."
2259		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2260			Error(w, deniedMsg, StatusUnauthorized)
2261		}))
2262		// We need to close cst explicitly here so that in-flight server
2263		// requests don't race with the call to SetRSTAvoidanceDelay for a retry.
2264		defer cst.close()
2265		ts := cst.ts
2266		c := ts.Client()
2267
2268		res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
2269		if err != nil {
2270			return fmt.Errorf("Post: %v", err)
2271		}
2272		got, err := io.ReadAll(res.Body)
2273		if err != nil {
2274			return fmt.Errorf("Body ReadAll: %v", err)
2275		}
2276		t.Logf("server response:\n%s", got)
2277		if !strings.Contains(string(got), deniedMsg) {
2278			// If we got an RST packet too early, we should have seen an error
2279			// from io.ReadAll, not a silently-truncated body.
2280			t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
2281		}
2282		return nil
2283	})
2284}
2285
2286// From https://golang.org/issue/4454 ,
2287// "client fails to handle requests with no body and chunked encoding"
2288func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) }
2289func testChunkedNoContent(t *testing.T, mode testMode) {
2290	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2291		w.WriteHeader(StatusNoContent)
2292	})).ts
2293
2294	c := ts.Client()
2295	for _, closeBody := range []bool{true, false} {
2296		const n = 4
2297		for i := 1; i <= n; i++ {
2298			res, err := c.Get(ts.URL)
2299			if err != nil {
2300				t.Errorf("closingBody=%v, req %d/%d: %v", closeBody, i, n, err)
2301			} else {
2302				if closeBody {
2303					res.Body.Close()
2304				}
2305			}
2306		}
2307	}
2308}
2309
// TestTransportConcurrency drives many concurrent GETs through a single
// Client over HTTP/1 and checks that each response echoes its request.
func TestTransportConcurrency(t *testing.T) {
	run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode})
}

// testTransportConcurrency starts maxProcs*2 worker goroutines. They pull
// request IDs from an unbuffered channel, GET "/?echo=<id>", and verify that
// the body equals <id>. GOMAXPROCS is raised to maxProcs for the duration.
func testTransportConcurrency(t *testing.T, mode testMode) {
	// Not parallel: uses global test hooks.
	maxProcs, numReqs := 16, 500
	if testing.Short() {
		maxProcs, numReqs = 4, 50
	}
	// GOMAXPROCS(maxProcs) returns the previous setting; the deferred call
	// restores it when the test returns.
	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%v", r.FormValue("echo"))
	})).ts

	// wg starts at numReqs: each worker calls Done once per request it
	// handles, whether or not the request succeeded. The dial hooks below
	// also add one count per pending dial.
	var wg sync.WaitGroup
	wg.Add(numReqs)

	// Due to the Transport's "socket late binding" (see
	// idleConnCh in transport.go), the numReqs HTTP requests
	// below can finish with a dial still outstanding. To keep
	// the leak checker happy, keep track of pending dials and
	// wait for them to finish (and be closed or returned to the
	// idle pool) before we close idle connections.
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	c := ts.Client()
	reqs := make(chan string)
	// Closing reqs on return ends every worker's range loop.
	defer close(reqs)

	for i := 0; i < maxProcs*2; i++ {
		go func() {
			for req := range reqs {
				res, err := c.Get(ts.URL + "/?echo=" + req)
				if err != nil {
					if runtime.GOOS == "netbsd" && strings.HasSuffix(err.Error(), ": connection reset by peer") {
						// https://go.dev/issue/52168: this test was observed to fail with
						// ECONNRESET errors in Dial on various netbsd builders.
						t.Logf("error on req %s: %v", req, err)
						t.Logf("(see https://go.dev/issue/52168)")
					} else {
						t.Errorf("error on req %s: %v", req, err)
					}
					// A failed request still counts as handled.
					wg.Done()
					continue
				}
				all, err := io.ReadAll(res.Body)
				if err != nil {
					t.Errorf("read error on req %s: %v", req, err)
				} else if string(all) != req {
					t.Errorf("body of req %s = %q; want %q", req, all, req)
				}
				res.Body.Close()
				wg.Done()
			}
		}()
	}
	// Feed request IDs; each send blocks until some worker is ready.
	for i := 0; i < numReqs; i++ {
		reqs <- fmt.Sprintf("request-%d", i)
	}
	wg.Wait()
}
2372
2373func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) }
2374func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) {
2375	mux := NewServeMux()
2376	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
2377		io.Copy(w, neverEnding('a'))
2378	})
2379	ts := newClientServerTest(t, mode, mux).ts
2380
2381	connc := make(chan net.Conn, 1)
2382	c := ts.Client()
2383	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2384		conn, err := net.Dial(n, addr)
2385		if err != nil {
2386			return nil, err
2387		}
2388		select {
2389		case connc <- conn:
2390		default:
2391		}
2392		return conn, nil
2393	}
2394
2395	res, err := c.Get(ts.URL + "/get")
2396	if err != nil {
2397		t.Fatalf("Error issuing GET: %v", err)
2398	}
2399	defer res.Body.Close()
2400
2401	conn := <-connc
2402	conn.SetDeadline(time.Now().Add(1 * time.Millisecond))
2403	_, err = io.Copy(io.Discard, res.Body)
2404	if err == nil {
2405		t.Errorf("Unexpected successful copy")
2406	}
2407}
2408
2409func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
2410	run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode})
2411}
2412func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) {
2413	const debug = false
2414	mux := NewServeMux()
2415	mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
2416		io.Copy(w, neverEnding('a'))
2417	})
2418	mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
2419		defer r.Body.Close()
2420		io.Copy(io.Discard, r.Body)
2421	})
2422	ts := newClientServerTest(t, mode, mux).ts
2423	timeout := 100 * time.Millisecond
2424
2425	c := ts.Client()
2426	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
2427		conn, err := net.Dial(n, addr)
2428		if err != nil {
2429			return nil, err
2430		}
2431		conn.SetDeadline(time.Now().Add(timeout))
2432		if debug {
2433			conn = NewLoggingConn("client", conn)
2434		}
2435		return conn, nil
2436	}
2437
2438	getFailed := false
2439	nRuns := 5
2440	if testing.Short() {
2441		nRuns = 1
2442	}
2443	for i := 0; i < nRuns; i++ {
2444		if debug {
2445			println("run", i+1, "of", nRuns)
2446		}
2447		sres, err := c.Get(ts.URL + "/get")
2448		if err != nil {
2449			if !getFailed {
2450				// Make the timeout longer, once.
2451				getFailed = true
2452				t.Logf("increasing timeout")
2453				i--
2454				timeout *= 10
2455				continue
2456			}
2457			t.Errorf("Error issuing GET: %v", err)
2458			break
2459		}
2460		req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
2461		_, err = c.Do(req)
2462		if err == nil {
2463			sres.Body.Close()
2464			t.Errorf("Unexpected successful PUT")
2465			break
2466		}
2467		sres.Body.Close()
2468	}
2469	if debug {
2470		println("tests complete; waiting for handlers to finish")
2471	}
2472	ts.Close()
2473}
2474
2475func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) }
2476func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
2477	if testing.Short() {
2478		t.Skip("skipping timeout test in -short mode")
2479	}
2480
2481	timeout := 2 * time.Millisecond
2482	retry := true
2483	for retry && !t.Failed() {
2484		var srvWG sync.WaitGroup
2485		inHandler := make(chan bool, 1)
2486		mux := NewServeMux()
2487		mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {
2488			inHandler <- true
2489			srvWG.Done()
2490		})
2491		mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
2492			inHandler <- true
2493			<-r.Context().Done()
2494			srvWG.Done()
2495		})
2496		ts := newClientServerTest(t, mode, mux).ts
2497
2498		c := ts.Client()
2499		c.Transport.(*Transport).ResponseHeaderTimeout = timeout
2500
2501		retry = false
2502		srvWG.Add(3)
2503		tests := []struct {
2504			path        string
2505			wantTimeout bool
2506		}{
2507			{path: "/fast"},
2508			{path: "/slow", wantTimeout: true},
2509			{path: "/fast"},
2510		}
2511		for i, tt := range tests {
2512			req, _ := NewRequest("GET", ts.URL+tt.path, nil)
2513			req = req.WithT(t)
2514			res, err := c.Do(req)
2515			<-inHandler
2516			if err != nil {
2517				uerr, ok := err.(*url.Error)
2518				if !ok {
2519					t.Errorf("error is not a url.Error; got: %#v", err)
2520					continue
2521				}
2522				nerr, ok := uerr.Err.(net.Error)
2523				if !ok {
2524					t.Errorf("error does not satisfy net.Error interface; got: %#v", err)
2525					continue
2526				}
2527				if !nerr.Timeout() {
2528					t.Errorf("want timeout error; got: %q", nerr)
2529					continue
2530				}
2531				if !tt.wantTimeout {
2532					if !retry {
2533						// The timeout may be set too short. Retry with a longer one.
2534						t.Logf("unexpected timeout for path %q after %v; retrying with longer timeout", tt.path, timeout)
2535						timeout *= 2
2536						retry = true
2537					}
2538				}
2539				if !strings.Contains(err.Error(), "timeout awaiting response headers") {
2540					t.Errorf("%d. unexpected error: %v", i, err)
2541				}
2542				continue
2543			}
2544			if tt.wantTimeout {
2545				t.Errorf(`no error for path %q; expected "timeout awaiting response headers"`, tt.path)
2546				continue
2547			}
2548			if res.StatusCode != 200 {
2549				t.Errorf("%d for path %q status = %d; want 200", i, tt.path, res.StatusCode)
2550			}
2551		}
2552
2553		srvWG.Wait()
2554		ts.Close()
2555	}
2556}
2557
// A cancelTest is a test of request cancellation.
//
// The runCancelTest* helpers each fill in these hooks for one cancellation
// mechanism (Transport.CancelRequest, Request.Cancel, or a request context).
// A single test body can therefore prepare, cancel, and check a request
// without knowing which mechanism is in use. mode is the client/server test
// mode the test runs under.
type cancelTest struct {
	mode     testMode
	newReq   func(req *Request) *Request       // prepare the request to cancel
	cancel   func(tr *Transport, req *Request) // cancel the request
	checkErr func(when string, err error)      // verify the expected error
}
2565
2566// runCancelTestTransport uses Transport.CancelRequest.
2567func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2568	t.Run("TransportCancel", func(t *testing.T) {
2569		f(t, cancelTest{
2570			mode: mode,
2571			newReq: func(req *Request) *Request {
2572				return req
2573			},
2574			cancel: func(tr *Transport, req *Request) {
2575				tr.CancelRequest(req)
2576			},
2577			checkErr: func(when string, err error) {
2578				if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
2579					t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
2580				}
2581			},
2582		})
2583	})
2584}
2585
2586// runCancelTestChannel uses Request.Cancel.
2587func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2588	var cancelOnce sync.Once
2589	cancelc := make(chan struct{})
2590	f(t, cancelTest{
2591		mode: mode,
2592		newReq: func(req *Request) *Request {
2593			req.Cancel = cancelc
2594			return req
2595		},
2596		cancel: func(tr *Transport, req *Request) {
2597			cancelOnce.Do(func() {
2598				close(cancelc)
2599			})
2600		},
2601		checkErr: func(when string, err error) {
2602			if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
2603				t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
2604			}
2605		},
2606	})
2607}
2608
2609// runCancelTestContext uses a request context.
2610func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
2611	ctx, cancel := context.WithCancel(context.Background())
2612	f(t, cancelTest{
2613		mode: mode,
2614		newReq: func(req *Request) *Request {
2615			return req.WithContext(ctx)
2616		},
2617		cancel: func(tr *Transport, req *Request) {
2618			cancel()
2619		},
2620		checkErr: func(when string, err error) {
2621			if !errors.Is(err, context.Canceled) {
2622				t.Errorf("%v error = %v, want context.Canceled", when, err)
2623			}
2624		},
2625	})
2626}
2627
2628func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) {
2629	run(t, func(t *testing.T, mode testMode) {
2630		if mode == http1Mode {
2631			t.Run("TransportCancel", func(t *testing.T) {
2632				runCancelTestTransport(t, mode, f)
2633			})
2634		}
2635		t.Run("RequestCancel", func(t *testing.T) {
2636			runCancelTestChannel(t, mode, f)
2637		})
2638		t.Run("ContextCancel", func(t *testing.T) {
2639			runCancelTestContext(t, mode, f)
2640		})
2641	}, opts...)
2642}
2643
2644func TestTransportCancelRequest(t *testing.T) {
2645	runCancelTest(t, testTransportCancelRequest)
2646}
2647func testTransportCancelRequest(t *testing.T, test cancelTest) {
2648	if testing.Short() {
2649		t.Skip("skipping test in -short mode")
2650	}
2651
2652	const msg = "Hello"
2653	unblockc := make(chan bool)
2654	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2655		io.WriteString(w, msg)
2656		w.(Flusher).Flush() // send headers and some body
2657		<-unblockc
2658	})).ts
2659	defer close(unblockc)
2660
2661	c := ts.Client()
2662	tr := c.Transport.(*Transport)
2663
2664	req, _ := NewRequest("GET", ts.URL, nil)
2665	req = test.newReq(req)
2666	res, err := c.Do(req)
2667	if err != nil {
2668		t.Fatal(err)
2669	}
2670	body := make([]byte, len(msg))
2671	n, _ := io.ReadFull(res.Body, body)
2672	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
2673		t.Errorf("Body = %q; want %q", body[:n], msg)
2674	}
2675	test.cancel(tr, req)
2676
2677	tail, err := io.ReadAll(res.Body)
2678	res.Body.Close()
2679	test.checkErr("Body.Read", err)
2680	if len(tail) > 0 {
2681		t.Errorf("Spurious bytes from Body.Read: %q", tail)
2682	}
2683
2684	// Verify no outstanding requests after readLoop/writeLoop
2685	// goroutines shut down.
2686	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2687		n := tr.NumPendingRequestsForTesting()
2688		if n > 0 {
2689			if d > 0 {
2690				t.Logf("pending requests = %d after %v (want 0)", n, d)
2691			}
2692			return false
2693		}
2694		return true
2695	})
2696}
2697
2698func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) {
2699	if testing.Short() {
2700		t.Skip("skipping test in -short mode")
2701	}
2702	unblockc := make(chan bool)
2703	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2704		<-unblockc
2705	})).ts
2706	defer close(unblockc)
2707
2708	c := ts.Client()
2709	tr := c.Transport.(*Transport)
2710
2711	donec := make(chan bool)
2712	req, _ := NewRequest("GET", ts.URL, body)
2713	req = test.newReq(req)
2714	go func() {
2715		defer close(donec)
2716		c.Do(req)
2717	}()
2718
2719	unblockc <- true
2720	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2721		test.cancel(tr, req)
2722		select {
2723		case <-donec:
2724			return true
2725		default:
2726			if d > 0 {
2727				t.Logf("Do of canceled request has not returned after %v", d)
2728			}
2729			return false
2730		}
2731	})
2732}
2733
2734func TestTransportCancelRequestInDo(t *testing.T) {
2735	runCancelTest(t, func(t *testing.T, test cancelTest) {
2736		testTransportCancelRequestInDo(t, test, nil)
2737	})
2738}
2739
2740func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
2741	runCancelTest(t, func(t *testing.T, test cancelTest) {
2742		testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0}))
2743	})
2744}
2745
func TestTransportCancelRequestInDial(t *testing.T) {
	runCancelTest(t, testTransportCancelRequestInDial)
}

// testTransportCancelRequestInDial cancels a request while its Transport is
// blocked inside Dial. It checks, via an ordered event log, that the dial
// started, the cancel happened, and only then did Client.Do return an error.
// Canceling a second time must not panic.
func testTransportCancelRequestInDial(t *testing.T, test cancelTest) {
	defer afterTest(t)
	if testing.Short() {
		t.Skip("skipping test in -short mode")
	}
	// eventLog records, in order, the dial, the cancel, and Do's completion.
	var logbuf strings.Builder
	eventLog := log.New(&logbuf, "", 0)

	// unblockDial is only closed on return, so Dial stays blocked for the
	// whole test; Do must return because of the cancel, not the dial.
	unblockDial := make(chan bool)
	defer close(unblockDial)

	// The send on inDial completes only once Dial has been entered and has
	// logged "dial: blocking". Receiving false (as from a closed channel)
	// would make Dial return an error instead of blocking.
	inDial := make(chan bool)
	tr := &Transport{
		Dial: func(network, addr string) (net.Conn, error) {
			eventLog.Println("dial: blocking")
			if !<-inDial {
				return nil, errors.New("main Test goroutine exited")
			}
			<-unblockDial
			return nil, errors.New("nope")
		},
	}
	cl := &Client{Transport: tr}
	gotres := make(chan bool)
	req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
	req = test.newReq(req)
	go func() {
		_, err := cl.Do(req)
		eventLog.Printf("Get error = %v", err != nil)
		test.checkErr("Get", err)
		gotres <- true
	}()

	// Wait until Dial is running.
	inDial <- true

	eventLog.Printf("canceling")
	test.cancel(tr, req)
	test.cancel(tr, req) // used to panic on second call to Transport.Cancel

	if d, ok := t.Deadline(); ok {
		// When the test's deadline is about to expire, log the pending events for
		// better debugging.
		timeout := time.Until(d) * 19 / 20 // Allow 5% for cleanup.
		timer := time.AfterFunc(timeout, func() {
			panic(fmt.Sprintf("hang in %s. events are: %s", t.Name(), logbuf.String()))
		})
		defer timer.Stop()
	}
	// Do must return even though Dial is still blocked.
	<-gotres

	got := logbuf.String()
	want := `dial: blocking
canceling
Get error = true
`
	if got != want {
		t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
	}
}
2808
2809// Issue 51354
2810func TestTransportCancelRequestWithBody(t *testing.T) {
2811	runCancelTest(t, testTransportCancelRequestWithBody)
2812}
2813func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) {
2814	if testing.Short() {
2815		t.Skip("skipping test in -short mode")
2816	}
2817
2818	const msg = "Hello"
2819	unblockc := make(chan struct{})
2820	ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2821		io.WriteString(w, msg)
2822		w.(Flusher).Flush() // send headers and some body
2823		<-unblockc
2824	})).ts
2825	defer close(unblockc)
2826
2827	c := ts.Client()
2828	tr := c.Transport.(*Transport)
2829
2830	req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
2831	req = test.newReq(req)
2832
2833	res, err := c.Do(req)
2834	if err != nil {
2835		t.Fatal(err)
2836	}
2837	body := make([]byte, len(msg))
2838	n, _ := io.ReadFull(res.Body, body)
2839	if n != len(body) || !bytes.Equal(body, []byte(msg)) {
2840		t.Errorf("Body = %q; want %q", body[:n], msg)
2841	}
2842	test.cancel(tr, req)
2843
2844	tail, err := io.ReadAll(res.Body)
2845	res.Body.Close()
2846	test.checkErr("Body.Read", err)
2847	if len(tail) > 0 {
2848		t.Errorf("Spurious bytes from Body.Read: %q", tail)
2849	}
2850
2851	// Verify no outstanding requests after readLoop/writeLoop
2852	// goroutines shut down.
2853	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
2854		n := tr.NumPendingRequestsForTesting()
2855		if n > 0 {
2856			if d > 0 {
2857				t.Logf("pending requests = %d after %v (want 0)", n, d)
2858			}
2859			return false
2860		}
2861		return true
2862	})
2863}
2864
2865func TestTransportCancelRequestBeforeDo(t *testing.T) {
2866	// We can't cancel a request that hasn't started using Transport.CancelRequest.
2867	run(t, func(t *testing.T, mode testMode) {
2868		t.Run("RequestCancel", func(t *testing.T) {
2869			runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo)
2870		})
2871		t.Run("ContextCancel", func(t *testing.T) {
2872			runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo)
2873		})
2874	})
2875}
2876func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) {
2877	unblockc := make(chan bool)
2878	cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2879		<-unblockc
2880	}))
2881	defer close(unblockc)
2882
2883	c := cst.ts.Client()
2884
2885	req, _ := NewRequest("GET", cst.ts.URL, nil)
2886	req = test.newReq(req)
2887	test.cancel(cst.tr, req)
2888
2889	_, err := c.Do(req)
2890	test.checkErr("Do", err)
2891}
2892
2893// Issue 11020. The returned error message should be errRequestCanceled
2894func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) {
2895	runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode})
2896}
2897func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) {
2898	defer afterTest(t)
2899
2900	serverConnCh := make(chan net.Conn, 1)
2901	tr := &Transport{
2902		Dial: func(network, addr string) (net.Conn, error) {
2903			cc, sc := net.Pipe()
2904			serverConnCh <- sc
2905			return cc, nil
2906		},
2907	}
2908	defer tr.CloseIdleConnections()
2909	errc := make(chan error, 1)
2910	req, _ := NewRequest("GET", "http://example.com/", nil)
2911	req = test.newReq(req)
2912	go func() {
2913		_, err := tr.RoundTrip(req)
2914		errc <- err
2915	}()
2916
2917	sc := <-serverConnCh
2918	verb := make([]byte, 3)
2919	if _, err := io.ReadFull(sc, verb); err != nil {
2920		t.Errorf("Error reading HTTP verb from server: %v", err)
2921	}
2922	if string(verb) != "GET" {
2923		t.Errorf("server received %q; want GET", verb)
2924	}
2925	defer sc.Close()
2926
2927	test.cancel(tr, req)
2928
2929	err := <-errc
2930	if err == nil {
2931		t.Fatalf("unexpected success from RoundTrip")
2932	}
2933	test.checkErr("RoundTrip", err)
2934}
2935
2936// golang.org/issue/3672 -- Client can't close HTTP stream
2937// Calling Close on a Response.Body used to just read until EOF.
2938// Now it actually closes the TCP connection.
2939func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) }
2940func testTransportCloseResponseBody(t *testing.T, mode testMode) {
2941	writeErr := make(chan error, 1)
2942	msg := []byte("young\n")
2943	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2944		for {
2945			_, err := w.Write(msg)
2946			if err != nil {
2947				writeErr <- err
2948				return
2949			}
2950			w.(Flusher).Flush()
2951		}
2952	})).ts
2953
2954	c := ts.Client()
2955	tr := c.Transport.(*Transport)
2956
2957	req, _ := NewRequest("GET", ts.URL, nil)
2958	defer tr.CancelRequest(req)
2959
2960	res, err := c.Do(req)
2961	if err != nil {
2962		t.Fatal(err)
2963	}
2964
2965	const repeats = 3
2966	buf := make([]byte, len(msg)*repeats)
2967	want := bytes.Repeat(msg, repeats)
2968
2969	_, err = io.ReadFull(res.Body, buf)
2970	if err != nil {
2971		t.Fatal(err)
2972	}
2973	if !bytes.Equal(buf, want) {
2974		t.Fatalf("read %q; want %q", buf, want)
2975	}
2976
2977	if err := res.Body.Close(); err != nil {
2978		t.Errorf("Close = %v", err)
2979	}
2980
2981	if err := <-writeErr; err == nil {
2982		t.Errorf("expected non-nil write error")
2983	}
2984}
2985
// fooProto is a RoundTripper for the custom "foo" scheme. It answers every
// request with a canned 200 response whose body echoes the request URL.
type fooProto struct{}

// RoundTrip returns the canned response without doing any network I/O.
func (fooProto) RoundTrip(req *Request) (*Response, error) {
	hdr := make(Header)
	text := "You wanted " + req.URL.String()
	return &Response{
		Status:     "200 OK",
		StatusCode: 200,
		Header:     hdr,
		Body:       io.NopCloser(strings.NewReader(text)),
	}, nil
}
2997
2998func TestTransportAltProto(t *testing.T) {
2999	defer afterTest(t)
3000	tr := &Transport{}
3001	c := &Client{Transport: tr}
3002	tr.RegisterProtocol("foo", fooProto{})
3003	res, err := c.Get("foo://bar.com/path")
3004	if err != nil {
3005		t.Fatal(err)
3006	}
3007	bodyb, err := io.ReadAll(res.Body)
3008	if err != nil {
3009		t.Fatal(err)
3010	}
3011	body := string(bodyb)
3012	if e := "You wanted foo://bar.com/path"; body != e {
3013		t.Errorf("got response %q, want %q", body, e)
3014	}
3015}
3016
3017func TestTransportNoHost(t *testing.T) {
3018	defer afterTest(t)
3019	tr := &Transport{}
3020	_, err := tr.RoundTrip(&Request{
3021		Header: make(Header),
3022		URL: &url.URL{
3023			Scheme: "http",
3024		},
3025	})
3026	want := "http: no Host in request URL"
3027	if got := fmt.Sprint(err); got != want {
3028		t.Errorf("error = %v; want %q", err, want)
3029	}
3030}
3031
// Issue 13311
// TestTransportEmptyMethod checks that a client request with an empty Method
// is written on the wire as GET.
func TestTransportEmptyMethod(t *testing.T) {
	req, _ := NewRequest("GET", "http://foo.com/", nil)
	// The Request docs say "For client requests an empty string means GET".
	req.Method = ""
	// DumpRequestOut renders the request by sending it through a Transport.
	dump, err := httputil.DumpRequestOut(req, false)
	if err != nil {
		t.Fatal(err)
	}
	if text := string(dump); !strings.Contains(text, "GET ") {
		t.Fatalf("expected substring 'GET '; got: %s", dump)
	}
}
3044
// TestTransportSocketLateBinding checks that a request waiting on a new Dial
// is served by an existing connection if that connection becomes idle first.
func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) }

// testTransportSocketLateBinding holds /foo open on the only allowed
// connection while /bar blocks in a second Dial. It then lets /foo finish and
// expects /bar to report the same remote address as /foo, i.e. /bar reused
// foo's connection.
func testTransportSocketLateBinding(t *testing.T, mode testMode) {
	mux := NewServeMux()
	// fooGate keeps the /foo handler running (after its headers are flushed)
	// until the test lets it finish.
	fooGate := make(chan bool, 1)
	mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) {
		w.Header().Set("foo-ipport", r.RemoteAddr)
		w.(Flusher).Flush()
		<-fooGate
	})
	mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) {
		w.Header().Set("bar-ipport", r.RemoteAddr)
	})
	ts := newClientServerTest(t, mode, mux).ts

	// dialGate admits dials: each true value lets one Dial proceed, and
	// closing it makes waiting dials fail. While a Dial waits for the gate,
	// it repeatedly offers a value on dialing so the test can observe it.
	dialGate := make(chan bool, 1)
	dialing := make(chan bool)
	c := ts.Client()
	c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) {
		for {
			select {
			case ok := <-dialGate:
				if !ok {
					return nil, errors.New("manually closed")
				}
				return net.Dial(n, addr)
			case dialing <- true:
			}
		}
	}
	defer close(dialGate)

	dialGate <- true // only allow one dial
	fooRes, err := c.Get(ts.URL + "/foo")
	if err != nil {
		t.Fatal(err)
	}
	fooAddr := fooRes.Header.Get("foo-ipport")
	if fooAddr == "" {
		t.Fatal("No addr on /foo request")
	}

	fooDone := make(chan struct{})
	go func() {
		// We know that the foo Dial completed and reached the handler because we
		// read its header. Wait for the bar request to block in Dial, then
		// let the foo response finish so we can use its connection for /bar.

		if mode == http2Mode {
			// In HTTP/2 mode, the second Dial won't happen because the protocol
			// multiplexes the streams by default. Just sleep for an arbitrary time;
			// the test should pass regardless of how far the bar request gets by this
			// point.
			select {
			case <-dialing:
				t.Errorf("unexpected second Dial in HTTP/2 mode")
			case <-time.After(10 * time.Millisecond):
			}
		} else {
			<-dialing
		}
		fooGate <- true
		io.Copy(io.Discard, fooRes.Body)
		fooRes.Body.Close()
		close(fooDone)
	}()
	// Deferred after close(dialGate), so this runs first: wait for the foo
	// goroutine before the gate is closed.
	defer func() {
		<-fooDone
	}()

	barRes, err := c.Get(ts.URL + "/bar")
	if err != nil {
		t.Fatal(err)
	}
	barAddr := barRes.Header.Get("bar-ipport")
	if barAddr != fooAddr {
		t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr)
	}
	barRes.Body.Close()
}
3124
3125// Issue 2184
3126func TestTransportReading100Continue(t *testing.T) {
3127	defer afterTest(t)
3128
3129	const numReqs = 5
3130	reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) }
3131	reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) }
3132
3133	send100Response := func(w *io.PipeWriter, r *io.PipeReader) {
3134		defer w.Close()
3135		defer r.Close()
3136		br := bufio.NewReader(r)
3137		n := 0
3138		for {
3139			n++
3140			req, err := ReadRequest(br)
3141			if err == io.EOF {
3142				return
3143			}
3144			if err != nil {
3145				t.Error(err)
3146				return
3147			}
3148			slurp, err := io.ReadAll(req.Body)
3149			if err != nil {
3150				t.Errorf("Server request body slurp: %v", err)
3151				return
3152			}
3153			id := req.Header.Get("Request-Id")
3154			resCode := req.Header.Get("X-Want-Response-Code")
3155			if resCode == "" {
3156				resCode = "100 Continue"
3157				if string(slurp) != reqBody(n) {
3158					t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n))
3159				}
3160			}
3161			body := fmt.Sprintf("Response number %d", n)
3162			v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s
3163Date: Thu, 28 Feb 2013 17:55:41 GMT
3164
3165HTTP/1.1 200 OK
3166Content-Type: text/html
3167Echo-Request-Id: %s
3168Content-Length: %d
3169
3170%s`, resCode, id, len(body), body), "\n", "\r\n", -1))
3171			w.Write(v)
3172			if id == reqID(numReqs) {
3173				return
3174			}
3175		}
3176
3177	}
3178
3179	tr := &Transport{
3180		Dial: func(n, addr string) (net.Conn, error) {
3181			sr, sw := io.Pipe() // server read/write
3182			cr, cw := io.Pipe() // client read/write
3183			conn := &rwTestConn{
3184				Reader: cr,
3185				Writer: sw,
3186				closeFunc: func() error {
3187					sw.Close()
3188					cw.Close()
3189					return nil
3190				},
3191			}
3192			go send100Response(cw, sr)
3193			return conn, nil
3194		},
3195		DisableKeepAlives: false,
3196	}
3197	defer tr.CloseIdleConnections()
3198	c := &Client{Transport: tr}
3199
3200	testResponse := func(req *Request, name string, wantCode int) {
3201		t.Helper()
3202		res, err := c.Do(req)
3203		if err != nil {
3204			t.Fatalf("%s: Do: %v", name, err)
3205		}
3206		if res.StatusCode != wantCode {
3207			t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode)
3208		}
3209		if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack {
3210			t.Errorf("%s: response id %q != request id %q", name, idBack, id)
3211		}
3212		_, err = io.ReadAll(res.Body)
3213		if err != nil {
3214			t.Fatalf("%s: Slurp error: %v", name, err)
3215		}
3216	}
3217
3218	// Few 100 responses, making sure we're not off-by-one.
3219	for i := 1; i <= numReqs; i++ {
3220		req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i)))
3221		req.Header.Set("Request-Id", reqID(i))
3222		testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
3223	}
3224}
3225
3226// Issue 17739: the HTTP client must ignore any unknown 1xx
3227// informational responses before the actual response.
3228func TestTransportIgnore1xxResponses(t *testing.T) {
3229	run(t, testTransportIgnore1xxResponses, []testMode{http1Mode})
3230}
3231func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
3232	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3233		conn, buf, _ := w.(Hijacker).Hijack()
3234		buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
3235		buf.Flush()
3236		conn.Close()
3237	}))
3238	cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
3239
3240	var got strings.Builder
3241
3242	req, _ := NewRequest("GET", cst.ts.URL, nil)
3243	req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
3244		Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
3245			fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
3246			return nil
3247		},
3248	}))
3249	res, err := cst.c.Do(req)
3250	if err != nil {
3251		t.Fatal(err)
3252	}
3253	defer res.Body.Close()
3254
3255	res.Write(&got)
3256	want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
3257	if got.String() != want {
3258		t.Errorf(" got: %q\nwant: %q\n", got.String(), want)
3259	}
3260}
3261
3262func TestTransportLimits1xxResponses(t *testing.T) {
3263	run(t, testTransportLimits1xxResponses, []testMode{http1Mode})
3264}
3265func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
3266	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3267		conn, buf, _ := w.(Hijacker).Hijack()
3268		for i := 0; i < 10; i++ {
3269			buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n"))
3270		}
3271		buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
3272		buf.Flush()
3273		conn.Close()
3274	}))
3275	cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
3276
3277	res, err := cst.c.Get(cst.ts.URL)
3278	if res != nil {
3279		defer res.Body.Close()
3280	}
3281	got := fmt.Sprint(err)
3282	wantSub := "too many 1xx informational responses"
3283	if !strings.Contains(got, wantSub) {
3284		t.Errorf("Get error = %v; want substring %q", err, wantSub)
3285	}
3286}
3287
3288// Issue 26161: the HTTP client must treat 101 responses
3289// as the final response.
3290func TestTransportTreat101Terminal(t *testing.T) {
3291	run(t, testTransportTreat101Terminal, []testMode{http1Mode})
3292}
3293func testTransportTreat101Terminal(t *testing.T, mode testMode) {
3294	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3295		conn, buf, _ := w.(Hijacker).Hijack()
3296		buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
3297		buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
3298		buf.Flush()
3299		conn.Close()
3300	}))
3301	res, err := cst.c.Get(cst.ts.URL)
3302	if err != nil {
3303		t.Fatal(err)
3304	}
3305	defer res.Body.Close()
3306	if res.StatusCode != StatusSwitchingProtocols {
3307		t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
3308	}
3309}
3310
// proxyFromEnvTest is one ProxyFromEnvironment scenario: the proxy-related
// environment to install, the request URL, and the expected result.
type proxyFromEnvTest struct {
	req string // URL to fetch; blank means "http://example.com"

	env      string // HTTP_PROXY
	httpsenv string // HTTPS_PROXY
	noenv    string // NO_PROXY
	reqmeth  string // REQUEST_METHOD

	want    string
	wanterr error
}

// String renders every non-empty environment setting of tt as a key="value"
// pair. It always ends with the request URL, defaulted when tt.req is blank.
// Pairs are separated by single spaces.
func (tt proxyFromEnvTest) String() string {
	reqURL := tt.req
	if reqURL == "" {
		reqURL = "http://example.com"
	}
	settings := []struct{ key, val string }{
		{"http_proxy", tt.env},
		{"https_proxy", tt.httpsenv},
		{"no_proxy", tt.noenv},
		{"request_method", tt.reqmeth},
	}
	parts := make([]string, 0, len(settings)+1)
	for _, s := range settings {
		if s.val != "" {
			parts = append(parts, fmt.Sprintf("%s=%q", s.key, s.val))
		}
	}
	parts = append(parts, fmt.Sprintf("req=%q", reqURL))
	return strings.Join(parts, " ")
}
3353
// proxyFromEnvTests is the scenario table used by TestProxyFromEnvironment and
// TestProxyFromEnvironmentLowerCase. testProxyForRequest compares want against
// fmt.Sprintf("%s", proxyURL), so a want of "<nil>" means no proxy is expected.
var proxyFromEnvTests = []proxyFromEnvTest{
	// Proxy values without a scheme get "http://"; explicit schemes
	// (https, socks5, socks5h) are kept as given.
	{env: "127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "cache.corp.example.com:1234", want: "http://cache.corp.example.com:1234"},
	{env: "cache.corp.example.com", want: "http://cache.corp.example.com"},
	{env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"},
	{env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"},
	{env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"},
	{env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"},
	{env: "socks5h://127.0.0.1", want: "socks5h://127.0.0.1"},

	// Don't use secure for http
	{req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"},
	// Use secure for https.
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"},
	{req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"},

	// Issue 16405: don't use HTTP_PROXY in a CGI environment,
	// where HTTP_PROXY can be attacker-controlled.
	{env: "http://10.1.2.3:8080", reqmeth: "POST",
		want:    "<nil>",
		wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},

	// No proxy configured at all.
	{want: "<nil>"},

	// NO_PROXY matching: "example.com" excludes the host and its subdomains.
	// ".example.com" and "ample.com" do not exclude example.com itself.
	{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
	{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
	{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
3384
3385func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
3386	t.Helper()
3387	reqURL := tt.req
3388	if reqURL == "" {
3389		reqURL = "http://example.com"
3390	}
3391	req, _ := NewRequest("GET", reqURL, nil)
3392	url, err := proxyForRequest(req)
3393	if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
3394		t.Errorf("%v: got error = %q, want %q", tt, g, e)
3395		return
3396	}
3397	if got := fmt.Sprintf("%s", url); got != tt.want {
3398		t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
3399	}
3400}
3401
3402func TestProxyFromEnvironment(t *testing.T) {
3403	ResetProxyEnv()
3404	defer ResetProxyEnv()
3405	for _, tt := range proxyFromEnvTests {
3406		testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3407			os.Setenv("HTTP_PROXY", tt.env)
3408			os.Setenv("HTTPS_PROXY", tt.httpsenv)
3409			os.Setenv("NO_PROXY", tt.noenv)
3410			os.Setenv("REQUEST_METHOD", tt.reqmeth)
3411			ResetCachedEnvironment()
3412			return ProxyFromEnvironment(req)
3413		})
3414	}
3415}
3416
3417func TestProxyFromEnvironmentLowerCase(t *testing.T) {
3418	ResetProxyEnv()
3419	defer ResetProxyEnv()
3420	for _, tt := range proxyFromEnvTests {
3421		testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
3422			os.Setenv("http_proxy", tt.env)
3423			os.Setenv("https_proxy", tt.httpsenv)
3424			os.Setenv("no_proxy", tt.noenv)
3425			os.Setenv("REQUEST_METHOD", tt.reqmeth)
3426			ResetCachedEnvironment()
3427			return ProxyFromEnvironment(req)
3428		})
3429	}
3430}
3431
// TestIdleConnChannelLeak sends requests to several distinct hosts, with and
// without keep-alives. After each phase it checks that the Transport's
// idle-conn wait map (IdleConnWaitMapSizeForTesting) is empty.
func TestIdleConnChannelLeak(t *testing.T) {
	run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel)
}
func testIdleConnChannelLeak(t *testing.T, mode testMode) {
	// Not parallel: uses global test hooks.
	var mu sync.Mutex
	// n counts handler invocations. NOTE(review): it is never read.
	var n int

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		mu.Lock()
		n++
		mu.Unlock()
	})).ts

	const nReqs = 5
	// didRead receives one value each time a readLoop reaches the hook
	// installed below; it is buffered so the hook never blocks.
	didRead := make(chan bool, nReqs)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	c := ts.Client()
	tr := c.Transport.(*Transport)
	// Route every fake host name to the test server, so each request below
	// uses a distinct host while still reaching the same handler.
	tr.Dial = func(netw, addr string) (net.Conn, error) {
		return net.Dial(netw, ts.Listener.Addr().String())
	}

	// First, without keep-alives.
	for _, disableKeep := range []bool{true, false} {
		tr.DisableKeepAlives = disableKeep
		for i := 0; i < nReqs; i++ {
			_, err := c.Get(fmt.Sprintf("http://foo-host-%d.tld/", i))
			if err != nil {
				t.Fatal(err)
			}
			// Note: no res.Body.Close is needed here, since the
			// response Content-Length is zero. Perhaps the test
			// should be more explicit and use a HEAD, but tests
			// elsewhere guarantee that zero byte responses generate
			// a "Content-Length: 0" instead of chunking.
		}

		// At this point, each of the 5 Transport.readLoop goroutines may
		// still be running: having noted that there are no response bodies
		// (see earlier comment), each goes on to call putIdleConn, which
		// decrements this count. Usually that happens quickly, which is
		// why this test seemed to work for ages. But it's still racy: we
		// have to wait for them to finish first. See Issue 10427.
		for i := 0; i < nReqs; i++ {
			<-didRead
		}

		if got := tr.IdleConnWaitMapSizeForTesting(); got != 0 {
			t.Fatalf("for DisableKeepAlives = %v, map size = %d; want 0", disableKeep, got)
		}
	}
}
3487
3488// Verify the status quo: that the Client.Post function coerces its
3489// body into a ReadCloser if it's a Closer, and that the Transport
3490// then closes it.
3491func TestTransportClosesRequestBody(t *testing.T) {
3492	run(t, testTransportClosesRequestBody, []testMode{http1Mode})
3493}
3494func testTransportClosesRequestBody(t *testing.T, mode testMode) {
3495	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3496		io.Copy(io.Discard, r.Body)
3497	})).ts
3498
3499	c := ts.Client()
3500
3501	closes := 0
3502
3503	res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
3504	if err != nil {
3505		t.Fatal(err)
3506	}
3507	res.Body.Close()
3508	if closes != 1 {
3509		t.Errorf("closes = %d; want 1", closes)
3510	}
3511}
3512
3513func TestTransportTLSHandshakeTimeout(t *testing.T) {
3514	defer afterTest(t)
3515	if testing.Short() {
3516		t.Skip("skipping in short mode")
3517	}
3518	ln := newLocalListener(t)
3519	defer ln.Close()
3520	testdonec := make(chan struct{})
3521	defer close(testdonec)
3522
3523	go func() {
3524		c, err := ln.Accept()
3525		if err != nil {
3526			t.Error(err)
3527			return
3528		}
3529		<-testdonec
3530		c.Close()
3531	}()
3532
3533	tr := &Transport{
3534		Dial: func(_, _ string) (net.Conn, error) {
3535			return net.Dial("tcp", ln.Addr().String())
3536		},
3537		TLSHandshakeTimeout: 250 * time.Millisecond,
3538	}
3539	cl := &Client{Transport: tr}
3540	_, err := cl.Get("https://dummy.tld/")
3541	if err == nil {
3542		t.Error("expected error")
3543		return
3544	}
3545	ue, ok := err.(*url.Error)
3546	if !ok {
3547		t.Errorf("expected url.Error; got %#v", err)
3548		return
3549	}
3550	ne, ok := ue.Err.(net.Error)
3551	if !ok {
3552		t.Errorf("expected net.Error; got %#v", err)
3553		return
3554	}
3555	if !ne.Timeout() {
3556		t.Errorf("expected timeout error; got %v", err)
3557	}
3558	if !strings.Contains(err.Error(), "handshake timeout") {
3559		t.Errorf("expected 'handshake timeout' in error; got %v", err)
3560	}
3561}
3562
3563// Trying to repro golang.org/issue/3514
3564func TestTLSServerClosesConnection(t *testing.T) {
3565	run(t, testTLSServerClosesConnection, []testMode{https1Mode})
3566}
3567func testTLSServerClosesConnection(t *testing.T, mode testMode) {
3568	closedc := make(chan bool, 1)
3569	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3570		if strings.Contains(r.URL.Path, "/keep-alive-then-die") {
3571			conn, _, _ := w.(Hijacker).Hijack()
3572			conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo"))
3573			conn.Close()
3574			closedc <- true
3575			return
3576		}
3577		fmt.Fprintf(w, "hello")
3578	})).ts
3579
3580	c := ts.Client()
3581	tr := c.Transport.(*Transport)
3582
3583	var nSuccess = 0
3584	var errs []error
3585	const trials = 20
3586	for i := 0; i < trials; i++ {
3587		tr.CloseIdleConnections()
3588		res, err := c.Get(ts.URL + "/keep-alive-then-die")
3589		if err != nil {
3590			t.Fatal(err)
3591		}
3592		<-closedc
3593		slurp, err := io.ReadAll(res.Body)
3594		if err != nil {
3595			t.Fatal(err)
3596		}
3597		if string(slurp) != "foo" {
3598			t.Errorf("Got %q, want foo", slurp)
3599		}
3600
3601		// Now try again and see if we successfully
3602		// pick a new connection.
3603		res, err = c.Get(ts.URL + "/")
3604		if err != nil {
3605			errs = append(errs, err)
3606			continue
3607		}
3608		slurp, err = io.ReadAll(res.Body)
3609		if err != nil {
3610			errs = append(errs, err)
3611			continue
3612		}
3613		nSuccess++
3614	}
3615	if nSuccess > 0 {
3616		t.Logf("successes = %d of %d", nSuccess, trials)
3617	} else {
3618		t.Errorf("All runs failed:")
3619	}
3620	for _, err := range errs {
3621		t.Logf("  err: %v", err)
3622	}
3623}
3624
// byteFromChanReader is an io.Reader that returns at most one byte per Read
// call, received from the channel. Each Read blocks until a byte is available.
// Once the channel is closed and drained, Read returns io.EOF.
type byteFromChanReader chan byte

func (c byteFromChanReader) Read(p []byte) (n int, err error) {
	// A zero-length read must not consume a byte from the channel.
	if len(p) > 0 {
		next, open := <-c
		if !open {
			return 0, io.EOF
		}
		p[0] = next
		n = 1
	}
	return n, nil
}
3641
// Verifies that the Transport doesn't reuse a connection in the case
// where the server replies before the request has been fully
// written. We still honor that reply (see TestIssue3595), but don't
// send future requests on the connection because it's then in a
// questionable state.
// golang.org/issue/7569
func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
	run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}, testNotParallel)
}
func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) {
	// Temporarily shorten MaxWriteWaitBeforeConnReuse and restore the old
	// value on return. This mutates package-level state, which is presumably
	// why the test is registered with testNotParallel.
	defer func(d time.Duration) {
		*MaxWriteWaitBeforeConnReuse = d
	}(*MaxWriteWaitBeforeConnReuse)
	*MaxWriteWaitBeforeConnReuse = 10 * time.Millisecond
	// sconn holds the hijacked server-side connection of the POST so the
	// test can close it on exit.
	var sconn struct {
		sync.Mutex
		c net.Conn
	}
	// getOkay is set once the GET succeeds; it only suppresses the log
	// line in closeConn.
	var getOkay bool
	// copying tracks the goroutines draining hijacked connections.
	var copying sync.WaitGroup
	closeConn := func() {
		sconn.Lock()
		defer sconn.Unlock()
		if sconn.c != nil {
			sconn.c.Close()
			sconn.c = nil
			if !getOkay {
				t.Logf("Closed server connection")
			}
		}
	}
	// On exit, close the hijacked conn so the draining io.Copy can return,
	// then wait for that goroutine.
	defer func() {
		closeConn()
		copying.Wait()
	}()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			io.WriteString(w, "bar")
			return
		}
		// POST: reply right away, without reading the request body, then read
		// and discard whatever the client sends afterwards in the background.
		conn, _, _ := w.(Hijacker).Hijack()
		sconn.Lock()
		sconn.c = conn
		sconn.Unlock()
		conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive

		copying.Add(1)
		go func() {
			io.Copy(io.Discard, conn)
			copying.Done()
		}()
	})).ts
	c := ts.Client()

	// The POST body is bodySize-1 bytes of 'x' followed by one byte read from
	// finalBit. Reading that byte blocks until the test sends it after the
	// GET, so the POST body is still unfinished when its response arrives.
	const bodySize = 256 << 10
	finalBit := make(byteFromChanReader, 1)
	req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
	req.ContentLength = bodySize
	res, err := c.Do(req)
	if err := wantBody(res, err, "foo"); err != nil {
		t.Errorf("POST response: %v", err)
	}

	res, err = c.Get(ts.URL)
	if err := wantBody(res, err, "bar"); err != nil {
		t.Errorf("GET response: %v", err)
		return
	}
	getOkay = true  // suppress test noise
	finalBit <- 'x' // unblock the writeloop of the first Post
	close(finalBit)
}
3715
3716// Tests that we don't leak Transport persistConn.readLoop goroutines
3717// when a server hangs up immediately after saying it would keep-alive.
3718func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) }
3719func testTransportIssue10457(t *testing.T, mode testMode) {
3720	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3721		// Send a response with no body, keep-alive
3722		// (implicit), and then lie and immediately close the
3723		// connection. This forces the Transport's readLoop to
3724		// immediately Peek an io.EOF and get to the point
3725		// that used to hang.
3726		conn, _, _ := w.(Hijacker).Hijack()
3727		conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive
3728		conn.Close()
3729	})).ts
3730	c := ts.Client()
3731
3732	res, err := c.Get(ts.URL)
3733	if err != nil {
3734		t.Fatalf("Get: %v", err)
3735	}
3736	defer res.Body.Close()
3737
3738	// Just a sanity check that we at least get the response. The real
3739	// test here is that the "defer afterTest" above doesn't find any
3740	// leaked goroutines.
3741	if got, want := res.Header.Get("Foo"), "Bar"; got != want {
3742		t.Errorf("Foo header = %q; want %q", got, want)
3743	}
3744}
3745
// closerFunc adapts a plain function to the io.Closer interface.
type closerFunc func() error

// Close invokes the underlying function and returns its error.
func (fn closerFunc) Close() error {
	err := fn()
	return err
}
3749
// writerFuncConn is a net.Conn whose Write is handled by the write callback.
// Every other method comes from the embedded Conn.
type writerFuncConn struct {
	net.Conn
	write func(p []byte) (n int, err error)
}

// Write passes p to the write callback and returns its results unchanged.
func (c writerFuncConn) Write(p []byte) (n int, err error) {
	n, err = c.write(p)
	return n, err
}
3756
3757// Issues 4677, 18241, and 17844. If we try to reuse a connection that the
3758// server is in the process of closing, we may end up successfully writing out
3759// our request (or a portion of our request) only to find a connection error
3760// when we try to read from (or finish writing to) the socket.
3761//
3762// NOTE: we resend a request only if:
3763//   - we reused a keep-alive connection
3764//   - we haven't yet received any header data
3765//   - either we wrote no bytes to the server, or the request is idempotent
3766//
3767// This automatically prevents an infinite resend loop because we'll run out of
3768// the cached keep-alive connections eventually.
3769func TestRetryRequestsOnError(t *testing.T) {
3770	run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode})
3771}
3772func testRetryRequestsOnError(t *testing.T, mode testMode) {
3773	newRequest := func(method, urlStr string, body io.Reader) *Request {
3774		req, err := NewRequest(method, urlStr, body)
3775		if err != nil {
3776			t.Fatal(err)
3777		}
3778		return req
3779	}
3780
3781	testCases := []struct {
3782		name       string
3783		failureN   int
3784		failureErr error
3785		// Note that we can't just re-use the Request object across calls to c.Do
3786		// because we need to rewind Body between calls.  (GetBody is only used to
3787		// rewind Body on failure and redirects, not just because it's done.)
3788		req       func() *Request
3789		reqString string
3790	}{
3791		{
3792			name: "IdempotentNoBodySomeWritten",
3793			// Believe that we've written some bytes to the server, so we know we're
3794			// not just in the "retry when no bytes sent" case".
3795			failureN: 1,
3796			// Use the specific error that shouldRetryRequest looks for with idempotent requests.
3797			failureErr: ExportErrServerClosedIdle,
3798			req: func() *Request {
3799				return newRequest("GET", "http://fake.golang", nil)
3800			},
3801			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3802		},
3803		{
3804			name: "IdempotentGetBodySomeWritten",
3805			// Believe that we've written some bytes to the server, so we know we're
3806			// not just in the "retry when no bytes sent" case".
3807			failureN: 1,
3808			// Use the specific error that shouldRetryRequest looks for with idempotent requests.
3809			failureErr: ExportErrServerClosedIdle,
3810			req: func() *Request {
3811				return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n"))
3812			},
3813			reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3814		},
3815		{
3816			name: "NothingWrittenNoBody",
3817			// It's key that we return 0 here -- that's what enables Transport to know
3818			// that nothing was written, even though this is a non-idempotent request.
3819			failureN:   0,
3820			failureErr: errors.New("second write fails"),
3821			req: func() *Request {
3822				return newRequest("DELETE", "http://fake.golang", nil)
3823			},
3824			reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`,
3825		},
3826		{
3827			name: "NothingWrittenGetBody",
3828			// It's key that we return 0 here -- that's what enables Transport to know
3829			// that nothing was written, even though this is a non-idempotent request.
3830			failureN:   0,
3831			failureErr: errors.New("second write fails"),
3832			// Note that NewRequest will set up GetBody for strings.Reader, which is
3833			// required for the retry to occur
3834			req: func() *Request {
3835				return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n"))
3836			},
3837			reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`,
3838		},
3839	}
3840
3841	for _, tc := range testCases {
3842		t.Run(tc.name, func(t *testing.T) {
3843			var (
3844				mu     sync.Mutex
3845				logbuf strings.Builder
3846			)
3847			logf := func(format string, args ...any) {
3848				mu.Lock()
3849				defer mu.Unlock()
3850				fmt.Fprintf(&logbuf, format, args...)
3851				logbuf.WriteByte('\n')
3852			}
3853
3854			ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3855				logf("Handler")
3856				w.Header().Set("X-Status", "ok")
3857			})).ts
3858
3859			var writeNumAtomic int32
3860			c := ts.Client()
3861			c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) {
3862				logf("Dial")
3863				c, err := net.Dial(network, ts.Listener.Addr().String())
3864				if err != nil {
3865					logf("Dial error: %v", err)
3866					return nil, err
3867				}
3868				return &writerFuncConn{
3869					Conn: c,
3870					write: func(p []byte) (n int, err error) {
3871						if atomic.AddInt32(&writeNumAtomic, 1) == 2 {
3872							logf("intentional write failure")
3873							return tc.failureN, tc.failureErr
3874						}
3875						logf("Write(%q)", p)
3876						return c.Write(p)
3877					},
3878				}, nil
3879			}
3880
3881			SetRoundTripRetried(func() {
3882				logf("Retried.")
3883			})
3884			defer SetRoundTripRetried(nil)
3885
3886			for i := 0; i < 3; i++ {
3887				t0 := time.Now()
3888				req := tc.req()
3889				res, err := c.Do(req)
3890				if err != nil {
3891					if time.Since(t0) < *MaxWriteWaitBeforeConnReuse/2 {
3892						mu.Lock()
3893						got := logbuf.String()
3894						mu.Unlock()
3895						t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
3896					}
3897					t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", *MaxWriteWaitBeforeConnReuse)
3898				}
3899				res.Body.Close()
3900				if res.Request != req {
3901					t.Errorf("Response.Request != original request; want identical Request")
3902				}
3903			}
3904
3905			mu.Lock()
3906			got := logbuf.String()
3907			mu.Unlock()
3908			want := fmt.Sprintf(`Dial
3909Write("%s")
3910Handler
3911intentional write failure
3912Retried.
3913Dial
3914Write("%s")
3915Handler
3916Write("%s")
3917Handler
3918`, tc.reqString, tc.reqString, tc.reqString)
3919			if got != want {
3920				t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want)
3921			}
3922		})
3923	}
3924}
3925
3926// Issue 6981
3927func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) }
3928func testTransportClosesBodyOnError(t *testing.T, mode testMode) {
3929	readBody := make(chan error, 1)
3930	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3931		_, err := io.ReadAll(r.Body)
3932		readBody <- err
3933	})).ts
3934	c := ts.Client()
3935	fakeErr := errors.New("fake error")
3936	didClose := make(chan bool, 1)
3937	req, _ := NewRequest("POST", ts.URL, struct {
3938		io.Reader
3939		io.Closer
3940	}{
3941		io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)),
3942		closerFunc(func() error {
3943			select {
3944			case didClose <- true:
3945			default:
3946			}
3947			return nil
3948		}),
3949	})
3950	res, err := c.Do(req)
3951	if res != nil {
3952		defer res.Body.Close()
3953	}
3954	if err == nil || !strings.Contains(err.Error(), fakeErr.Error()) {
3955		t.Fatalf("Do error = %v; want something containing %q", err, fakeErr.Error())
3956	}
3957	if err := <-readBody; err == nil {
3958		t.Errorf("Unexpected success reading request body from handler; want 'unexpected EOF reading trailer'")
3959	}
3960	select {
3961	case <-didClose:
3962	default:
3963		t.Errorf("didn't see Body.Close")
3964	}
3965}
3966
3967func TestTransportDialTLS(t *testing.T) {
3968	run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode})
3969}
3970func testTransportDialTLS(t *testing.T, mode testMode) {
3971	var mu sync.Mutex // guards following
3972	var gotReq, didDial bool
3973
3974	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3975		mu.Lock()
3976		gotReq = true
3977		mu.Unlock()
3978	})).ts
3979	c := ts.Client()
3980	c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) {
3981		mu.Lock()
3982		didDial = true
3983		mu.Unlock()
3984		c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
3985		if err != nil {
3986			return nil, err
3987		}
3988		return c, c.Handshake()
3989	}
3990
3991	res, err := c.Get(ts.URL)
3992	if err != nil {
3993		t.Fatal(err)
3994	}
3995	res.Body.Close()
3996	mu.Lock()
3997	if !gotReq {
3998		t.Error("didn't get request")
3999	}
4000	if !didDial {
4001		t.Error("didn't use dial hook")
4002	}
4003}
4004
4005func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) }
4006func testTransportDialContext(t *testing.T, mode testMode) {
4007	ctxKey := "some-key"
4008	ctxValue := "some-value"
4009	var (
4010		mu          sync.Mutex // guards following
4011		gotReq      bool
4012		gotCtxValue any
4013	)
4014
4015	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4016		mu.Lock()
4017		gotReq = true
4018		mu.Unlock()
4019	})).ts
4020	c := ts.Client()
4021	c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
4022		mu.Lock()
4023		gotCtxValue = ctx.Value(ctxKey)
4024		mu.Unlock()
4025		return net.Dial(netw, addr)
4026	}
4027
4028	req, err := NewRequest("GET", ts.URL, nil)
4029	if err != nil {
4030		t.Fatal(err)
4031	}
4032	ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
4033	res, err := c.Do(req.WithContext(ctx))
4034	if err != nil {
4035		t.Fatal(err)
4036	}
4037	res.Body.Close()
4038	mu.Lock()
4039	if !gotReq {
4040		t.Error("didn't get request")
4041	}
4042	if got, want := gotCtxValue, ctxValue; got != want {
4043		t.Errorf("got context with value %v, want %v", got, want)
4044	}
4045}
4046
4047func TestTransportDialTLSContext(t *testing.T) {
4048	run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode})
4049}
4050func testTransportDialTLSContext(t *testing.T, mode testMode) {
4051	ctxKey := "some-key"
4052	ctxValue := "some-value"
4053	var (
4054		mu          sync.Mutex // guards following
4055		gotReq      bool
4056		gotCtxValue any
4057	)
4058
4059	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4060		mu.Lock()
4061		gotReq = true
4062		mu.Unlock()
4063	})).ts
4064	c := ts.Client()
4065	c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
4066		mu.Lock()
4067		gotCtxValue = ctx.Value(ctxKey)
4068		mu.Unlock()
4069		c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
4070		if err != nil {
4071			return nil, err
4072		}
4073		return c, c.HandshakeContext(ctx)
4074	}
4075
4076	req, err := NewRequest("GET", ts.URL, nil)
4077	if err != nil {
4078		t.Fatal(err)
4079	}
4080	ctx := context.WithValue(context.Background(), ctxKey, ctxValue)
4081	res, err := c.Do(req.WithContext(ctx))
4082	if err != nil {
4083		t.Fatal(err)
4084	}
4085	res.Body.Close()
4086	mu.Lock()
4087	if !gotReq {
4088		t.Error("didn't get request")
4089	}
4090	if got, want := gotCtxValue, ctxValue; got != want {
4091		t.Errorf("got context with value %v, want %v", got, want)
4092	}
4093}
4094
// Test for issue 8755
// Ensure that if a proxy returns an error, it is exposed by RoundTrip.
// The check covers both that an error comes back and that it carries the
// proxy function's own error message.
func TestRoundTripReturnsProxyError(t *testing.T) {
	const proxyErrMsg = "errorMessage"
	badProxy := func(*Request) (*url.URL, error) {
		return nil, errors.New(proxyErrMsg)
	}

	tr := &Transport{Proxy: badProxy}

	req, err := NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatal(err)
	}

	_, err = tr.RoundTrip(req)
	if err == nil {
		t.Fatal("Expected proxy error to be returned by RoundTrip")
	}
	if !strings.Contains(err.Error(), proxyErrMsg) {
		t.Errorf("RoundTrip error = %v; want it to contain the proxy's error %q", err, proxyErrMsg)
	}
}
4112
4113// tests that putting an idle conn after a call to CloseIdleConns does return it
4114func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
4115	tr := &Transport{}
4116	wantIdle := func(when string, n int) bool {
4117		got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn
4118		if got == n {
4119			return true
4120		}
4121		t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4122		return false
4123	}
4124	wantIdle("start", 0)
4125	if !tr.PutIdleTestConn("http", "example.com") {
4126		t.Fatal("put failed")
4127	}
4128	if !tr.PutIdleTestConn("http", "example.com") {
4129		t.Fatal("second put failed")
4130	}
4131	wantIdle("after put", 2)
4132	tr.CloseIdleConnections()
4133	if !tr.IsIdleForTesting() {
4134		t.Error("should be idle after CloseIdleConnections")
4135	}
4136	wantIdle("after close idle", 0)
4137	if tr.PutIdleTestConn("http", "example.com") {
4138		t.Fatal("put didn't fail")
4139	}
4140	wantIdle("after second put", 0)
4141
4142	tr.QueueForIdleConnForTesting() // should toggle the transport out of idle mode
4143	if tr.IsIdleForTesting() {
4144		t.Error("shouldn't be idle after QueueForIdleConnForTesting")
4145	}
4146	if !tr.PutIdleTestConn("http", "example.com") {
4147		t.Fatal("after re-activation")
4148	}
4149	wantIdle("after final put", 1)
4150}
4151
4152// Test for issue 34282
4153// Ensure that getConn doesn't call the GotConn trace hook on an HTTP/2 idle conn
4154func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
4155	tr := &Transport{}
4156	wantIdle := func(when string, n int) bool {
4157		got := tr.IdleConnCountForTesting("https", "example.com:443") // key used by PutIdleTestConnH2
4158		if got == n {
4159			return true
4160		}
4161		t.Errorf("%s: idle conns = %d; want %d", when, got, n)
4162		return false
4163	}
4164	wantIdle("start", 0)
4165	alt := funcRoundTripper(func() {})
4166	if !tr.PutIdleTestConnH2("https", "example.com:443", alt) {
4167		t.Fatal("put failed")
4168	}
4169	wantIdle("after put", 1)
4170	ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
4171		GotConn: func(httptrace.GotConnInfo) {
4172			// tr.getConn should leave it for the HTTP/2 alt to call GotConn.
4173			t.Error("GotConn called")
4174		},
4175	})
4176	req, _ := NewRequestWithContext(ctx, MethodGet, "https://example.com", nil)
4177	_, err := tr.RoundTrip(req)
4178	if err != errFakeRoundTrip {
4179		t.Errorf("got error: %v; want %q", err, errFakeRoundTrip)
4180	}
4181	wantIdle("after round trip", 1)
4182}
4183
4184func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) {
4185	run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode})
4186}
4187func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) {
4188	if testing.Short() {
4189		t.Skip("skipping in short mode")
4190	}
4191
4192	timeout := 1 * time.Millisecond
4193	retry := true
4194	for retry {
4195		trFunc := func(tr *Transport) {
4196			tr.MaxConnsPerHost = 1
4197			tr.MaxIdleConnsPerHost = 1
4198			tr.IdleConnTimeout = timeout
4199		}
4200		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc)
4201
4202		retry = false
4203		tooShort := func(err error) bool {
4204			if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
4205				return false
4206			}
4207			if !retry {
4208				t.Helper()
4209				t.Logf("idle conn timeout %v may be too short; retrying with longer", timeout)
4210				timeout *= 2
4211				retry = true
4212				cst.close()
4213			}
4214			return true
4215		}
4216
4217		if _, err := cst.c.Get(cst.ts.URL); err != nil {
4218			if tooShort(err) {
4219				continue
4220			}
4221			t.Fatalf("got error: %s", err)
4222		}
4223
4224		time.Sleep(10 * timeout)
4225		if _, err := cst.c.Get(cst.ts.URL); err != nil {
4226			if tooShort(err) {
4227				continue
4228			}
4229			t.Fatalf("got error: %s", err)
4230		}
4231	}
4232}
4233
4234// This tests that a client requesting a content range won't also
4235// implicitly ask for gzip support. If they want that, they need to do it
4236// on their own.
4237// golang.org/issue/8923
4238func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) }
4239func testTransportRangeAndGzip(t *testing.T, mode testMode) {
4240	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4241		if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
4242			t.Error("Transport advertised gzip support in the Accept header")
4243		}
4244		if r.Header.Get("Range") == "" {
4245			t.Error("no Range in request")
4246		}
4247	})).ts
4248	c := ts.Client()
4249
4250	req, _ := NewRequest("GET", ts.URL, nil)
4251	req.Header.Set("Range", "bytes=7-11")
4252	res, err := c.Do(req)
4253	if err != nil {
4254		t.Fatal(err)
4255	}
4256	res.Body.Close()
4257}
4258
4259// Test for issue 10474
4260func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) }
4261func testTransportResponseCancelRace(t *testing.T, mode testMode) {
4262	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4263		// important that this response has a body.
4264		var b [1024]byte
4265		w.Write(b[:])
4266	})).ts
4267	tr := ts.Client().Transport.(*Transport)
4268
4269	req, err := NewRequest("GET", ts.URL, nil)
4270	if err != nil {
4271		t.Fatal(err)
4272	}
4273	res, err := tr.RoundTrip(req)
4274	if err != nil {
4275		t.Fatal(err)
4276	}
4277	// If we do an early close, Transport just throws the connection away and
4278	// doesn't reuse it. In order to trigger the bug, it has to reuse the connection
4279	// so read the body
4280	if _, err := io.Copy(io.Discard, res.Body); err != nil {
4281		t.Fatal(err)
4282	}
4283
4284	req2, err := NewRequest("GET", ts.URL, nil)
4285	if err != nil {
4286		t.Fatal(err)
4287	}
4288	tr.CancelRequest(req)
4289	res, err = tr.RoundTrip(req2)
4290	if err != nil {
4291		t.Fatal(err)
4292	}
4293	res.Body.Close()
4294}
4295
4296// Test for issue 19248: Content-Encoding's value is case insensitive.
4297func TestTransportContentEncodingCaseInsensitive(t *testing.T) {
4298	run(t, testTransportContentEncodingCaseInsensitive)
4299}
4300func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) {
4301	for _, ce := range []string{"gzip", "GZIP"} {
4302		ce := ce
4303		t.Run(ce, func(t *testing.T) {
4304			const encodedString = "Hello Gopher"
4305			ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4306				w.Header().Set("Content-Encoding", ce)
4307				gz := gzip.NewWriter(w)
4308				gz.Write([]byte(encodedString))
4309				gz.Close()
4310			})).ts
4311
4312			res, err := ts.Client().Get(ts.URL)
4313			if err != nil {
4314				t.Fatal(err)
4315			}
4316
4317			body, err := io.ReadAll(res.Body)
4318			res.Body.Close()
4319			if err != nil {
4320				t.Fatal(err)
4321			}
4322
4323			if string(body) != encodedString {
4324				t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body))
4325			}
4326		})
4327	}
4328}
4329
4330// https://go.dev/issue/49621
4331func TestConnClosedBeforeRequestIsWritten(t *testing.T) {
4332	run(t, testConnClosedBeforeRequestIsWritten, testNotParallel, []testMode{http1Mode})
4333}
4334func testConnClosedBeforeRequestIsWritten(t *testing.T, mode testMode) {
4335	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
4336		func(tr *Transport) {
4337			tr.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
4338				// Connection immediately returns errors.
4339				return &funcConn{
4340					read: func([]byte) (int, error) {
4341						return 0, errors.New("error")
4342					},
4343					write: func([]byte) (int, error) {
4344						return 0, errors.New("error")
4345					},
4346				}, nil
4347			}
4348		},
4349	).ts
4350	// Set a short delay in RoundTrip to give the persistConn time to notice
4351	// the connection is broken. We want to exercise the path where writeLoop exits
4352	// before it reads the request to send. If this delay is too short, we may instead
4353	// exercise the path where writeLoop accepts the request and then fails to write it.
4354	// That's fine, so long as we get the desired path often enough.
4355	SetEnterRoundTripHook(func() {
4356		time.Sleep(1 * time.Millisecond)
4357	})
4358	defer SetEnterRoundTripHook(nil)
4359	var closes int
4360	_, err := ts.Client().Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")})
4361	if err == nil {
4362		t.Fatalf("expected request to fail, but it did not")
4363	}
4364	if closes != 1 {
4365		t.Errorf("after RoundTrip, request body was closed %v times; want 1", closes)
4366	}
4367}
4368
// logWritesConn is a net.Conn test double.
//
// Each Write is recorded in writes (while holding mu) and then forwarded to w.
// The first Read blocks until a reader arrives on rch; that reader then serves
// this Read and every later one. Close does nothing. Any other net.Conn method
// reaches the nil embedded Conn and panics.
type logWritesConn struct {
	net.Conn // nil. crash on use.

	w io.Writer

	rch <-chan io.Reader
	r   io.Reader // nil until received by rch

	mu     sync.Mutex
	writes []string
}

// Write logs a copy of p and passes p through to c.w, all under c.mu.
func (c *logWritesConn) Write(p []byte) (n int, err error) {
	entry := string(p)
	c.mu.Lock()
	defer c.mu.Unlock()
	c.writes = append(c.writes, entry)
	return c.w.Write(p)
}

// Read serves from the reader taken off c.rch, waiting for it on first use.
func (c *logWritesConn) Read(p []byte) (n int, err error) {
	src := c.r
	if src == nil {
		src = <-c.rch
		c.r = src
	}
	return src.Read(p)
}

// Close is a no-op that always succeeds.
func (c *logWritesConn) Close() error { return nil }
4399
4400// Issue 6574
4401func TestTransportFlushesBodyChunks(t *testing.T) {
4402	defer afterTest(t)
4403	resBody := make(chan io.Reader, 1)
4404	connr, connw := io.Pipe() // connection pipe pair
4405	lw := &logWritesConn{
4406		rch: resBody,
4407		w:   connw,
4408	}
4409	tr := &Transport{
4410		Dial: func(network, addr string) (net.Conn, error) {
4411			return lw, nil
4412		},
4413	}
4414	bodyr, bodyw := io.Pipe() // body pipe pair
4415	go func() {
4416		defer bodyw.Close()
4417		for i := 0; i < 3; i++ {
4418			fmt.Fprintf(bodyw, "num%d\n", i)
4419		}
4420	}()
4421	resc := make(chan *Response)
4422	go func() {
4423		req, _ := NewRequest("POST", "http://localhost:8080", bodyr)
4424		req.Header.Set("User-Agent", "x") // known value for test
4425		res, err := tr.RoundTrip(req)
4426		if err != nil {
4427			t.Errorf("RoundTrip: %v", err)
4428			close(resc)
4429			return
4430		}
4431		resc <- res
4432
4433	}()
4434	// Fully consume the request before checking the Write log vs. want.
4435	req, err := ReadRequest(bufio.NewReader(connr))
4436	if err != nil {
4437		t.Fatal(err)
4438	}
4439	io.Copy(io.Discard, req.Body)
4440
4441	// Unblock the transport's roundTrip goroutine.
4442	resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
4443	res, ok := <-resc
4444	if !ok {
4445		return
4446	}
4447	defer res.Body.Close()
4448
4449	want := []string{
4450		"POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
4451		"5\r\nnum0\n\r\n",
4452		"5\r\nnum1\n\r\n",
4453		"5\r\nnum2\n\r\n",
4454		"0\r\n\r\n",
4455	}
4456	if !reflect.DeepEqual(lw.writes, want) {
4457		t.Errorf("Writes differed.\n Got: %q\nWant: %q\n", lw.writes, want)
4458	}
4459}
4460
4461// Issue 22088: flush Transport request headers if we're not sure the body won't block on read.
4462func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) }
4463func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
4464	gotReq := make(chan struct{})
4465	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4466		close(gotReq)
4467	}))
4468
4469	pr, pw := io.Pipe()
4470	req, err := NewRequest("POST", cst.ts.URL, pr)
4471	if err != nil {
4472		t.Fatal(err)
4473	}
4474	gotRes := make(chan struct{})
4475	go func() {
4476		defer close(gotRes)
4477		res, err := cst.tr.RoundTrip(req)
4478		if err != nil {
4479			t.Error(err)
4480			return
4481		}
4482		res.Body.Close()
4483	}()
4484
4485	<-gotReq
4486	pw.Close()
4487	<-gotRes
4488}
4489
// wgReadCloser is an io.ReadCloser whose first Close marks one unit of work
// done on wg. Any later Close leaves wg untouched and reports net.ErrClosed.
type wgReadCloser struct {
	io.Reader
	wg     *sync.WaitGroup
	closed bool
}

// Close signals wg the first time it is called and fails on every later call.
func (c *wgReadCloser) Close() error {
	if !c.closed {
		c.closed = true
		c.wg.Done()
		return nil
	}
	return net.ErrClosed
}
4504
4505// Issue 11745.
4506func TestTransportPrefersResponseOverWriteError(t *testing.T) {
4507	// Not parallel: modifies the global rstAvoidanceDelay.
4508	run(t, testTransportPrefersResponseOverWriteError, testNotParallel)
4509}
4510func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
4511	if testing.Short() {
4512		t.Skip("skipping in short mode")
4513	}
4514
4515	runTimeSensitiveTest(t, []time.Duration{
4516		1 * time.Millisecond,
4517		5 * time.Millisecond,
4518		10 * time.Millisecond,
4519		50 * time.Millisecond,
4520		100 * time.Millisecond,
4521		500 * time.Millisecond,
4522		time.Second,
4523		5 * time.Second,
4524	}, func(t *testing.T, timeout time.Duration) error {
4525		SetRSTAvoidanceDelay(t, timeout)
4526		t.Logf("set RST avoidance delay to %v", timeout)
4527
4528		const contentLengthLimit = 1024 * 1024 // 1MB
4529		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4530			if r.ContentLength >= contentLengthLimit {
4531				w.WriteHeader(StatusBadRequest)
4532				r.Body.Close()
4533				return
4534			}
4535			w.WriteHeader(StatusOK)
4536		}))
4537		// We need to close cst explicitly here so that in-flight server
4538		// requests don't race with the call to SetRSTAvoidanceDelay for a retry.
4539		defer cst.close()
4540		ts := cst.ts
4541		c := ts.Client()
4542
4543		count := 100
4544
4545		bigBody := strings.Repeat("a", contentLengthLimit*2)
4546		var wg sync.WaitGroup
4547		defer wg.Wait()
4548		getBody := func() (io.ReadCloser, error) {
4549			wg.Add(1)
4550			body := &wgReadCloser{
4551				Reader: strings.NewReader(bigBody),
4552				wg:     &wg,
4553			}
4554			return body, nil
4555		}
4556
4557		for i := 0; i < count; i++ {
4558			reqBody, _ := getBody()
4559			req, err := NewRequest("PUT", ts.URL, reqBody)
4560			if err != nil {
4561				reqBody.Close()
4562				t.Fatal(err)
4563			}
4564			req.ContentLength = int64(len(bigBody))
4565			req.GetBody = getBody
4566
4567			resp, err := c.Do(req)
4568			if err != nil {
4569				return fmt.Errorf("Do %d: %v", i, err)
4570			} else {
4571				resp.Body.Close()
4572				if resp.StatusCode != 400 {
4573					t.Errorf("Expected status code 400, got %v", resp.Status)
4574				}
4575			}
4576		}
4577		return nil
4578	})
4579}
4580
4581func TestTransportAutomaticHTTP2(t *testing.T) {
4582	testTransportAutoHTTP(t, &Transport{}, true)
4583}
4584
4585func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
4586	testTransportAutoHTTP(t, &Transport{
4587		ForceAttemptHTTP2: true,
4588		TLSClientConfig:   new(tls.Config),
4589	}, true)
4590}
4591
4592// golang.org/issue/14391: also check DefaultTransport
4593func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
4594	testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
4595}
4596
4597func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) {
4598	testTransportAutoHTTP(t, &Transport{
4599		TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper),
4600	}, false)
4601}
4602
4603func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) {
4604	testTransportAutoHTTP(t, &Transport{
4605		TLSClientConfig: new(tls.Config),
4606	}, false)
4607}
4608
4609func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) {
4610	testTransportAutoHTTP(t, &Transport{
4611		ExpectContinueTimeout: 1 * time.Second,
4612	}, true)
4613}
4614
4615func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
4616	var d net.Dialer
4617	testTransportAutoHTTP(t, &Transport{
4618		Dial: d.Dial,
4619	}, false)
4620}
4621
4622func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
4623	var d net.Dialer
4624	testTransportAutoHTTP(t, &Transport{
4625		DialContext: d.DialContext,
4626	}, false)
4627}
4628
4629func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
4630	testTransportAutoHTTP(t, &Transport{
4631		DialTLS: func(network, addr string) (net.Conn, error) {
4632			panic("unused")
4633		},
4634	}, false)
4635}
4636
4637func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) {
4638	CondSkipHTTP2(t)
4639	_, err := tr.RoundTrip(new(Request))
4640	if err == nil {
4641		t.Error("expected error from RoundTrip")
4642	}
4643	if reg := tr.TLSNextProto["h2"] != nil; reg != wantH2 {
4644		t.Errorf("HTTP/2 registered = %v; want %v", reg, wantH2)
4645	}
4646}
4647
4648// Issue 13633: there was a race where we returned bodyless responses
4649// to callers before recycling the persistent connection, which meant
4650// a client doing two subsequent requests could end up on different
4651// connections. It's somewhat harmless but enough tests assume it's
4652// not true in order to test other things that it's worth fixing.
4653// Plus it's nice to be consistent and not have timing-dependent
4654// behavior.
4655func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
4656	run(t, testTransportReuseConnEmptyResponseBody)
4657}
4658func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) {
4659	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4660		w.Header().Set("X-Addr", r.RemoteAddr)
4661		// Empty response body.
4662	}))
4663	n := 100
4664	if testing.Short() {
4665		n = 10
4666	}
4667	var firstAddr string
4668	for i := 0; i < n; i++ {
4669		res, err := cst.c.Get(cst.ts.URL)
4670		if err != nil {
4671			log.Fatal(err)
4672		}
4673		addr := res.Header.Get("X-Addr")
4674		if i == 0 {
4675			firstAddr = addr
4676		} else if addr != firstAddr {
4677			t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
4678		}
4679		res.Body.Close()
4680	}
4681}
4682
// Issue 13839
// TestNoCrashReturningTransportAltConn cancels a request while its TLS dial
// is in flight. The dial is only allowed to complete after Do has returned.
// The dialed conn negotiated the "foo" protocol, so the Transport passes it
// to TLSNextProto["foo"] even though no request is waiting for it anymore.
// The test checks three things: Do fails with errRequestCanceledConn, the
// "foo" RoundTripper is built but never used, and (per the test's name)
// handling this late alternate-protocol conn doesn't crash.
func TestNoCrashReturningTransportAltConn(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	ln := newLocalListener(t)
	defer ln.Close()

	// Track pending dials so the end of the test can wait until the late
	// dial has been fully handled.
	// NOTE(review): assumes these hooks run at the start and end of each
	// pending dial; confirm against SetPendingDialHooks.
	var wg sync.WaitGroup
	SetPendingDialHooks(func() { wg.Add(1) }, wg.Done)
	defer SetPendingDialHooks(nil, nil)

	// Server side: accept one TLS conn that offers only "foo", complete the
	// handshake, and hold the conn open until the test returns.
	testDone := make(chan struct{})
	defer close(testDone)
	go func() {
		tln := tls.NewListener(ln, &tls.Config{
			NextProtos:   []string{"foo"},
			Certificates: []tls.Certificate{cert},
		})
		sc, err := tln.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		if err := sc.(*tls.Conn).Handshake(); err != nil {
			t.Error(err)
			return
		}
		<-testDone
		sc.Close()
	}()

	addr := ln.Addr().String()

	// DialTLS below ignores this host and always dials ln. The request is
	// canceled through the Request.Cancel channel, which DialTLS closes.
	req, _ := NewRequest("GET", "https://fake.tld/", nil)
	cancel := make(chan struct{})
	req.Cancel = cancel

	// doReturned: sent after Do returns; DialTLS blocks on it.
	// madeRoundTripper: sent when the Transport builds the "foo" RoundTripper.
	doReturned := make(chan bool, 1)
	madeRoundTripper := make(chan bool, 1)

	tr := &Transport{
		DisableKeepAlives: true,
		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper {
				madeRoundTripper <- true
				return funcRoundTripper(func() {
					t.Error("foo RoundTripper should not be called")
				})
			},
		},
		Dial: func(_, _ string) (net.Conn, error) {
			panic("shouldn't be called")
		},
		DialTLS: func(_, _ string) (net.Conn, error) {
			tc, err := tls.Dial("tcp", addr, &tls.Config{
				InsecureSkipVerify: true,
				NextProtos:         []string{"foo"},
			})
			if err != nil {
				return nil, err
			}
			if err := tc.Handshake(); err != nil {
				return nil, err
			}
			// The conn is ready: cancel the request now, then hold the conn
			// back until Do has returned so that it arrives too late.
			close(cancel)
			<-doReturned
			return tc, nil
		},
	}
	c := &Client{Transport: tr}

	_, err = c.Do(req)
	if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
		t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
	}

	// Release the held-back dial, wait for the Transport to build the "foo"
	// RoundTripper for it, then wait for the pending dial to finish.
	doReturned <- true
	<-madeRoundTripper
	wg.Wait()
}
4765
4766func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) {
4767	run(t, func(t *testing.T, mode testMode) {
4768		testTransportReuseConnection_Gzip(t, mode, true)
4769	})
4770}
4771
4772func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
4773	run(t, func(t *testing.T, mode testMode) {
4774		testTransportReuseConnection_Gzip(t, mode, false)
4775	})
4776}
4777
4778// Make sure we re-use underlying TCP connection for gzipped responses too.
4779func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) {
4780	addr := make(chan string, 2)
4781	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4782		addr <- r.RemoteAddr
4783		w.Header().Set("Content-Encoding", "gzip")
4784		if chunked {
4785			w.(Flusher).Flush()
4786		}
4787		w.Write(rgz) // arbitrary gzip response
4788	})).ts
4789	c := ts.Client()
4790
4791	trace := &httptrace.ClientTrace{
4792		GetConn:      func(hostPort string) { t.Logf("GetConn(%q)", hostPort) },
4793		GotConn:      func(ci httptrace.GotConnInfo) { t.Logf("GotConn(%+v)", ci) },
4794		PutIdleConn:  func(err error) { t.Logf("PutIdleConn(%v)", err) },
4795		ConnectStart: func(network, addr string) { t.Logf("ConnectStart(%q, %q)", network, addr) },
4796		ConnectDone:  func(network, addr string, err error) { t.Logf("ConnectDone(%q, %q, %v)", network, addr, err) },
4797	}
4798	ctx := httptrace.WithClientTrace(context.Background(), trace)
4799
4800	for i := 0; i < 2; i++ {
4801		req, _ := NewRequest("GET", ts.URL, nil)
4802		req = req.WithContext(ctx)
4803		res, err := c.Do(req)
4804		if err != nil {
4805			t.Fatal(err)
4806		}
4807		buf := make([]byte, len(rgz))
4808		if n, err := io.ReadFull(res.Body, buf); err != nil {
4809			t.Errorf("%d. ReadFull = %v, %v", i, n, err)
4810		}
4811		// Note: no res.Body.Close call. It should work without it,
4812		// since the flate.Reader's internal buffering will hit EOF
4813		// and that should be sufficient.
4814	}
4815	a1, a2 := <-addr, <-addr
4816	if a1 != a2 {
4817		t.Fatalf("didn't reuse connection")
4818	}
4819}
4820
4821func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) }
4822func testTransportResponseHeaderLength(t *testing.T, mode testMode) {
4823	if mode == http2Mode {
4824		t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes")
4825	}
4826	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4827		if r.URL.Path == "/long" {
4828			w.Header().Set("Long", strings.Repeat("a", 1<<20))
4829		}
4830	})).ts
4831	c := ts.Client()
4832	c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10
4833
4834	if res, err := c.Get(ts.URL); err != nil {
4835		t.Fatal(err)
4836	} else {
4837		res.Body.Close()
4838	}
4839
4840	res, err := c.Get(ts.URL + "/long")
4841	if err == nil {
4842		defer res.Body.Close()
4843		var n int64
4844		for k, vv := range res.Header {
4845			for _, v := range vv {
4846				n += int64(len(k)) + int64(len(v))
4847			}
4848		}
4849		t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
4850	}
4851	if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
4852		t.Errorf("got error: %v; want %q", err, want)
4853	}
4854}
4855
4856func TestTransportEventTrace(t *testing.T) {
4857	run(t, func(t *testing.T, mode testMode) {
4858		testTransportEventTrace(t, mode, false)
4859	}, testNotParallel)
4860}
4861
4862// test a non-nil httptrace.ClientTrace but with all hooks set to zero.
4863func TestTransportEventTrace_NoHooks(t *testing.T) {
4864	run(t, func(t *testing.T, mode testMode) {
4865		testTransportEventTrace(t, mode, true)
4866	}, testNotParallel)
4867}
4868
// testTransportEventTrace sends a POST carrying "Expect: 100-continue" to the
// host name dns-is-faked.golang. A fake in-context resolver maps that name to
// the test server's IP. Every httptrace hook that fires is written to a log,
// and the test checks that the expected events appear in it.
// When noHooks is true, the trace is replaced by a zero ClientTrace and only
// the round trip itself is checked. Otherwise, a second (GET) request checks
// that GetConn fires again.
func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) {
	const resBody = "some body"
	// Signaled by the WroteRequest hook. Unless noHooks is set, the handler
	// waits for it before writing the response, so the WroteRequest event is
	// logged before the response is sent. Generously buffered so the hook
	// never blocks.
	gotWroteReqEvent := make(chan struct{}, 500)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.Method == "GET" {
			// Do nothing for the second request.
			return
		}
		if _, err := io.ReadAll(r.Body); err != nil {
			t.Error(err)
		}
		if !noHooks {
			<-gotWroteReqEvent
		}
		io.WriteString(w, resBody)
	}), func(tr *Transport) {
		// NOTE(review): presumably needed because requests go to
		// dns-is-faked.golang, a name the test certificate won't cover.
		if tr.TLSClientConfig != nil {
			tr.TLSClientConfig.InsecureSkipVerify = true
		}
	})
	defer cst.close()

	// With a non-zero ExpectContinueTimeout, the Transport waits for the
	// server's reply to "Expect: 100-continue" before sending the body.
	// That is what lets the Wait100Continue/Got100Continue events below occur.
	cst.tr.ExpectContinueTimeout = 1 * time.Second

	var mu sync.Mutex // guards buf
	var buf strings.Builder
	// logf appends one formatted line to buf, holding mu.
	logf := func(format string, args ...any) {
		mu.Lock()
		defer mu.Unlock()
		fmt.Fprintf(&buf, format, args...)
		buf.WriteByte('\n')
	}

	addrStr := cst.ts.Listener.Addr().String()
	ip, port, err := net.SplitHostPort(addrStr)
	if err != nil {
		t.Fatal(err)
	}

	// Install a fake DNS server.
	// Lookups done with ctx are answered by this function instead of a real
	// resolver. Only dns-is-faked.golang is expected, and it resolves to the
	// test server's IP.
	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
		if host != "dns-is-faked.golang" {
			t.Errorf("unexpected DNS host lookup for %q/%q", network, host)
			return nil, nil
		}
		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
	})

	body := "some body"
	req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
	req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
	trace := &httptrace.ClientTrace{
		GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
		GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
		GotFirstResponseByte: func() { logf("first response byte") },
		PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
		DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
		DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
		ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
		ConnectDone: func(network, addr string, err error) {
			if err != nil {
				t.Errorf("ConnectDone: %v", err)
			}
			logf("ConnectDone: connected to %s %s = %v", network, addr, err)
		},
		WroteHeaderField: func(key string, value []string) {
			logf("WroteHeaderField: %s: %v", key, value)
		},
		WroteHeaders: func() {
			logf("WroteHeaders")
		},
		Wait100Continue: func() { logf("Wait100Continue") },
		Got100Continue:  func() { logf("Got100Continue") },
		WroteRequest: func(e httptrace.WroteRequestInfo) {
			logf("WroteRequest: %+v", e)
			gotWroteReqEvent <- struct{}{}
		},
	}
	// TLS handshake hooks are installed (and checked below) only in HTTP/2 mode.
	if mode == http2Mode {
		trace.TLSHandshakeStart = func() { logf("tls handshake start") }
		trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
			logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
		}
	}
	if noHooks {
		// zero out all func pointers, trying to get some path to crash
		*trace = httptrace.ClientTrace{}
	}
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))

	req.Header.Set("Expect", "100-continue")
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	logf("got roundtrip.response")
	slurp, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	logf("consumed body")
	if string(slurp) != resBody || res.StatusCode != 200 {
		t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
	}
	res.Body.Close()

	if noHooks {
		// Done at this point. Just testing a full HTTP
		// requests can happen with a trace pointing to a zero
		// ClientTrace, full of nil func pointers.
		return
	}

	mu.Lock()
	got := buf.String()
	mu.Unlock()

	// wantOnce requires sub to appear exactly once in the log;
	// wantOnceOrMore requires it at least once.
	wantOnce := func(sub string) {
		if strings.Count(got, sub) != 1 {
			t.Errorf("expected substring %q exactly once in output.", sub)
		}
	}
	wantOnceOrMore := func(sub string) {
		if strings.Count(got, sub) == 0 {
			t.Errorf("expected substring %q at least once in output.", sub)
		}
	}
	wantOnce("Getting conn for dns-is-faked.golang:" + port)
	wantOnce("DNS start: {Host:dns-is-faked.golang}")
	wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
	wantOnce("got conn: {")
	wantOnceOrMore("Connecting to tcp " + addrStr)
	wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
	wantOnce("Reused:false WasIdle:false IdleTime:0s")
	wantOnce("first response byte")
	if mode == http2Mode {
		wantOnce("tls handshake start")
		wantOnce("tls handshake done")
	} else {
		wantOnce("PutIdleConn = <nil>")
		wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
		// TODO(meirf): issue 19761. Make these agnostic to h1/h2. (These are not h1 specific, but the
		// WroteHeaderField hook is not yet implemented in h2.)
		wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
		wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
		wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
		wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
	}
	wantOnce("WroteHeaders")
	wantOnce("Wait100Continue")
	wantOnce("Got100Continue")
	wantOnce("WroteRequest: {Err:<nil>}")
	if strings.Contains(got, " to udp ") {
		t.Errorf("should not see UDP (DNS) connections")
	}
	if t.Failed() {
		t.Errorf("Output:\n%s", got)
	}

	// And do a second request:
	// GetConn must fire for it as well, so the log should end up with
	// exactly two "Getting conn for" lines.
	req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
	res, err = cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	if res.StatusCode != 200 {
		t.Fatal(res.Status)
	}
	res.Body.Close()

	mu.Lock()
	got = buf.String()
	mu.Unlock()

	sub := "Getting conn for dns-is-faked.golang:"
	if gotn, want := strings.Count(got, sub), 2; gotn != want {
		t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
	}

}
5050
5051func TestTransportEventTraceTLSVerify(t *testing.T) {
5052	run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode})
5053}
5054func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) {
5055	var mu sync.Mutex
5056	var buf strings.Builder
5057	logf := func(format string, args ...any) {
5058		mu.Lock()
5059		defer mu.Unlock()
5060		fmt.Fprintf(&buf, format, args...)
5061		buf.WriteByte('\n')
5062	}
5063
5064	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5065		t.Error("Unexpected request")
5066	}), func(ts *httptest.Server) {
5067		ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
5068			logf("%s", p)
5069			return len(p), nil
5070		}), "", 0)
5071	}).ts
5072
5073	certpool := x509.NewCertPool()
5074	certpool.AddCert(ts.Certificate())
5075
5076	c := &Client{Transport: &Transport{
5077		TLSClientConfig: &tls.Config{
5078			ServerName: "dns-is-faked.golang",
5079			RootCAs:    certpool,
5080		},
5081	}}
5082
5083	trace := &httptrace.ClientTrace{
5084		TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
5085		TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5086			logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
5087		},
5088	}
5089
5090	req, _ := NewRequest("GET", ts.URL, nil)
5091	req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5092	_, err := c.Do(req)
5093	if err == nil {
5094		t.Error("Expected request to fail TLS verification")
5095	}
5096
5097	mu.Lock()
5098	got := buf.String()
5099	mu.Unlock()
5100
5101	wantOnce := func(sub string) {
5102		if strings.Count(got, sub) != 1 {
5103			t.Errorf("expected substring %q exactly once in output.", sub)
5104		}
5105	}
5106
5107	wantOnce("TLSHandshakeStart")
5108	wantOnce("TLSHandshakeDone")
5109	wantOnce("err = tls: failed to verify certificate: x509: certificate is valid for example.com")
5110
5111	if t.Failed() {
5112		t.Errorf("Output:\n%s", got)
5113	}
5114}
5115
// isDNSHijackedOnce guards the one-time probe that sets isDNSHijacked.
var isDNSHijackedOnce sync.Once

// isDNSHijacked reports whether the local resolver returned addresses for a
// name that should never resolve. It is only meaningful after
// isDNSHijackedOnce has run.
var isDNSHijacked bool
5120
5121func skipIfDNSHijacked(t *testing.T) {
5122	// Skip this test if the user is using a shady/ISP
5123	// DNS server hijacking queries.
5124	// See issues 16732, 16716.
5125	isDNSHijackedOnce.Do(func() {
5126		addrs, _ := net.LookupHost("dns-should-not-resolve.golang")
5127		isDNSHijacked = len(addrs) != 0
5128	})
5129	if isDNSHijacked {
5130		t.Skip("skipping; test requires non-hijacking DNS server")
5131	}
5132}
5133
5134func TestTransportEventTraceRealDNS(t *testing.T) {
5135	skipIfDNSHijacked(t)
5136	defer afterTest(t)
5137	tr := &Transport{}
5138	defer tr.CloseIdleConnections()
5139	c := &Client{Transport: tr}
5140
5141	var mu sync.Mutex // guards buf
5142	var buf strings.Builder
5143	logf := func(format string, args ...any) {
5144		mu.Lock()
5145		defer mu.Unlock()
5146		fmt.Fprintf(&buf, format, args...)
5147		buf.WriteByte('\n')
5148	}
5149
5150	req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
5151	trace := &httptrace.ClientTrace{
5152		DNSStart:     func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
5153		DNSDone:      func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
5154		ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
5155		ConnectDone:  func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
5156	}
5157	req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
5158
5159	resp, err := c.Do(req)
5160	if err == nil {
5161		resp.Body.Close()
5162		t.Fatal("expected error during DNS lookup")
5163	}
5164
5165	mu.Lock()
5166	got := buf.String()
5167	mu.Unlock()
5168
5169	wantSub := func(sub string) {
5170		if !strings.Contains(got, sub) {
5171			t.Errorf("expected substring %q in output.", sub)
5172		}
5173	}
5174	wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
5175	wantSub("DNSDone: {Addrs:[] Err:")
5176	if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
5177		t.Errorf("should not see Connect events")
5178	}
5179	if t.Failed() {
5180		t.Errorf("Output:\n%s", got)
5181	}
5182}
5183
// Issue 14353: port can only contain digits.
//
// The URL is rejected while parsing, so no network traffic is generated.
func TestTransportRejectsAlphaPort(t *testing.T) {
	const want = `invalid port ":123foo" after host`
	res, err := Get("http://dummy.tld:123foo/bar")
	switch ue := err.(type) {
	case nil:
		res.Body.Close()
		t.Fatal("unexpected success")
	case *url.Error:
		if got := ue.Err.Error(); got != want {
			t.Errorf("got error %q; want %q", got, want)
		}
	default:
		t.Fatalf("got %#v; want *url.Error", err)
	}
}
5201
5202// Test the httptrace.TLSHandshake{Start,Done} hooks with an https http1
5203// connections. The http2 test is done in TestTransportEventTrace_h2
5204func TestTLSHandshakeTrace(t *testing.T) {
5205	run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode})
5206}
5207func testTLSHandshakeTrace(t *testing.T, mode testMode) {
5208	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
5209
5210	var mu sync.Mutex
5211	var start, done bool
5212	trace := &httptrace.ClientTrace{
5213		TLSHandshakeStart: func() {
5214			mu.Lock()
5215			defer mu.Unlock()
5216			start = true
5217		},
5218		TLSHandshakeDone: func(s tls.ConnectionState, err error) {
5219			mu.Lock()
5220			defer mu.Unlock()
5221			done = true
5222			if err != nil {
5223				t.Fatal("Expected error to be nil but was:", err)
5224			}
5225		},
5226	}
5227
5228	c := ts.Client()
5229	req, err := NewRequest("GET", ts.URL, nil)
5230	if err != nil {
5231		t.Fatal("Unable to construct test request:", err)
5232	}
5233	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
5234
5235	r, err := c.Do(req)
5236	if err != nil {
5237		t.Fatal("Unexpected error making request:", err)
5238	}
5239	r.Body.Close()
5240	mu.Lock()
5241	defer mu.Unlock()
5242	if !start {
5243		t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
5244	}
5245	if !done {
5246		t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
5247	}
5248}
5249
5250func TestTransportMaxIdleConns(t *testing.T) {
5251	run(t, testTransportMaxIdleConns, []testMode{http1Mode})
5252}
5253func testTransportMaxIdleConns(t *testing.T, mode testMode) {
5254	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5255		// No body for convenience.
5256	})).ts
5257	c := ts.Client()
5258	tr := c.Transport.(*Transport)
5259	tr.MaxIdleConns = 4
5260
5261	ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
5262	if err != nil {
5263		t.Fatal(err)
5264	}
5265	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, _, host string) ([]net.IPAddr, error) {
5266		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
5267	})
5268
5269	hitHost := func(n int) {
5270		req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil)
5271		req = req.WithContext(ctx)
5272		res, err := c.Do(req)
5273		if err != nil {
5274			t.Fatal(err)
5275		}
5276		res.Body.Close()
5277	}
5278	for i := 0; i < 4; i++ {
5279		hitHost(i)
5280	}
5281	want := []string{
5282		"|http|host-0.dns-is-faked.golang:" + port,
5283		"|http|host-1.dns-is-faked.golang:" + port,
5284		"|http|host-2.dns-is-faked.golang:" + port,
5285		"|http|host-3.dns-is-faked.golang:" + port,
5286	}
5287	if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
5288		t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want)
5289	}
5290
5291	// Now hitting the 5th host should kick out the first host:
5292	hitHost(4)
5293	want = []string{
5294		"|http|host-1.dns-is-faked.golang:" + port,
5295		"|http|host-2.dns-is-faked.golang:" + port,
5296		"|http|host-3.dns-is-faked.golang:" + port,
5297		"|http|host-4.dns-is-faked.golang:" + port,
5298	}
5299	if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) {
5300		t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want)
5301	}
5302}
5303
5304func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) }
5305func testTransportIdleConnTimeout(t *testing.T, mode testMode) {
5306	if testing.Short() {
5307		t.Skip("skipping in short mode")
5308	}
5309
5310	timeout := 1 * time.Millisecond
5311timeoutLoop:
5312	for {
5313		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5314			// No body for convenience.
5315		}))
5316		tr := cst.tr
5317		tr.IdleConnTimeout = timeout
5318		defer tr.CloseIdleConnections()
5319		c := &Client{Transport: tr}
5320
5321		idleConns := func() []string {
5322			if mode == http2Mode {
5323				return tr.IdleConnStrsForTesting_h2()
5324			} else {
5325				return tr.IdleConnStrsForTesting()
5326			}
5327		}
5328
5329		var conn string
5330		doReq := func(n int) (timeoutOk bool) {
5331			req, _ := NewRequest("GET", cst.ts.URL, nil)
5332			req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
5333				PutIdleConn: func(err error) {
5334					if err != nil {
5335						t.Errorf("failed to keep idle conn: %v", err)
5336					}
5337				},
5338			}))
5339			res, err := c.Do(req)
5340			if err != nil {
5341				if strings.Contains(err.Error(), "use of closed network connection") {
5342					t.Logf("req %v: connection closed prematurely", n)
5343					return false
5344				}
5345			}
5346			res.Body.Close()
5347			conns := idleConns()
5348			if len(conns) != 1 {
5349				if len(conns) == 0 {
5350					t.Logf("req %v: no idle conns", n)
5351					return false
5352				}
5353				t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
5354			}
5355			if conn == "" {
5356				conn = conns[0]
5357			}
5358			if conn != conns[0] {
5359				t.Logf("req %v: cached connection changed; expected the same one throughout the test", n)
5360				return false
5361			}
5362			return true
5363		}
5364		for i := 0; i < 3; i++ {
5365			if !doReq(i) {
5366				t.Logf("idle conn timeout %v appears to be too short; retrying with longer", timeout)
5367				timeout *= 2
5368				cst.close()
5369				continue timeoutLoop
5370			}
5371			time.Sleep(timeout / 2)
5372		}
5373
5374		waitCondition(t, timeout/2, func(d time.Duration) bool {
5375			if got := idleConns(); len(got) != 0 {
5376				if d >= timeout*3/2 {
5377					t.Logf("after %v, idle conns = %q", d, got)
5378				}
5379				return false
5380			}
5381			return true
5382		})
5383		break
5384	}
5385}
5386
// Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an
// HTTP/2 connection was established but its caller no longer
// wanted it. (Assuming the connection cache was enabled, which it is
// by default.)
//
// This test reproduced the crash by setting the IdleConnTimeout low
// (to make the test reasonable) and then making a request which is
// canceled by the DialTLS hook, which then also waits to return the
// real connection until after the RoundTrip saw the error. Then we
// know the successful tls.Dial from DialTLS will need to go into the
// idle pool. Then we give it a bit of time to explode.
func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) }
func testIdleConnH2Crash(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// nothing
	}))

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// sawDoErr is signaled once Do has returned its error; buffered so the
	// send below never blocks. testDone releases the dialer if the test
	// exits early.
	sawDoErr := make(chan bool, 1)
	testDone := make(chan struct{})
	defer close(testDone)

	cst.tr.IdleConnTimeout = 5 * time.Millisecond
	cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
		c, err := tls.Dial(network, addr, &tls.Config{
			InsecureSkipVerify: true,
			NextProtos:         []string{"h2"},
		})
		if err != nil {
			t.Error(err)
			return nil, err
		}
		if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" {
			t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2")
			c.Close()
			return nil, errors.New("bogus")
		}

		// Cancel the request while this dial still holds the connection,
		// so the caller gives up before the conn is handed back.
		cancel()

		// Only return the (successful) conn after Do has reported its
		// error, forcing the unwanted conn into the idle pool.
		select {
		case <-sawDoErr:
		case <-testDone:
		}
		return c, nil
	}

	req, _ := NewRequest("GET", cst.ts.URL, nil)
	req = req.WithContext(ctx)
	res, err := cst.c.Do(req)
	if err == nil {
		res.Body.Close()
		t.Fatal("unexpected success")
	}
	sawDoErr <- true

	// Wait for the explosion.
	time.Sleep(cst.tr.IdleConnTimeout * 10)
}
5448
// funcConn is a net.Conn whose Read and Write delegate to the supplied
// functions. Close does nothing and always succeeds; every other net.Conn
// method comes from the embedded Conn (nil in the test below).
type funcConn struct {
	net.Conn
	read  func([]byte) (int, error)
	write func([]byte) (int, error)
}

func (fc funcConn) Read(buf []byte) (int, error) {
	return fc.read(buf)
}

func (fc funcConn) Write(buf []byte) (int, error) {
	return fc.write(buf)
}

func (fc funcConn) Close() error {
	return nil
}

// Issue 16465: Transport.RoundTrip should return the raw net.Conn.Read error
// from Peek back to the caller.
func TestTransportReturnsPeekError(t *testing.T) {
	errValue := errors.New("specific error value")

	// wrote is closed on the first Write, so Read (and therefore the
	// Transport's Peek) fails only after the request has been sent.
	wrote := make(chan struct{})
	var closeWrote sync.Once

	dial := func(network, addr string) (net.Conn, error) {
		return funcConn{
			read: func([]byte) (int, error) {
				<-wrote
				return 0, errValue
			},
			write: func(p []byte) (int, error) {
				closeWrote.Do(func() { close(wrote) })
				return len(p), nil
			},
		}, nil
	}
	tr := &Transport{Dial: dial}

	if _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil)); err != errValue {
		t.Errorf("error = %#v; want %v", err, errValue)
	}
}
5487
5488// Issue 13835: international domain names should work
5489func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) }
5490func testTransportIDNA(t *testing.T, mode testMode) {
5491	const uniDomain = "гофер.го"
5492	const punyDomain = "xn--c1ae0ajs.xn--c1aw"
5493
5494	var port string
5495	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5496		want := punyDomain + ":" + port
5497		if r.Host != want {
5498			t.Errorf("Host header = %q; want %q", r.Host, want)
5499		}
5500		if mode == http2Mode {
5501			if r.TLS == nil {
5502				t.Errorf("r.TLS == nil")
5503			} else if r.TLS.ServerName != punyDomain {
5504				t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
5505			}
5506		}
5507		w.Header().Set("Hit-Handler", "1")
5508	}), func(tr *Transport) {
5509		if tr.TLSClientConfig != nil {
5510			tr.TLSClientConfig.InsecureSkipVerify = true
5511		}
5512	})
5513
5514	ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
5515	if err != nil {
5516		t.Fatal(err)
5517	}
5518
5519	// Install a fake DNS server.
5520	ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, network, host string) ([]net.IPAddr, error) {
5521		if host != punyDomain {
5522			t.Errorf("got DNS host lookup for %q/%q; want %q", network, host, punyDomain)
5523			return nil, nil
5524		}
5525		return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
5526	})
5527
5528	req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
5529	trace := &httptrace.ClientTrace{
5530		GetConn: func(hostPort string) {
5531			want := net.JoinHostPort(punyDomain, port)
5532			if hostPort != want {
5533				t.Errorf("getting conn for %q; want %q", hostPort, want)
5534			}
5535		},
5536		DNSStart: func(e httptrace.DNSStartInfo) {
5537			if e.Host != punyDomain {
5538				t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
5539			}
5540		},
5541	}
5542	req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
5543
5544	res, err := cst.tr.RoundTrip(req)
5545	if err != nil {
5546		t.Fatal(err)
5547	}
5548	defer res.Body.Close()
5549	if res.Header.Get("Hit-Handler") != "1" {
5550		out, err := httputil.DumpResponse(res, true)
5551		if err != nil {
5552			t.Fatal(err)
5553		}
5554		t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
5555	}
5556}
5557
5558// Issue 13290: send User-Agent in proxy CONNECT
5559func TestTransportProxyConnectHeader(t *testing.T) {
5560	run(t, testTransportProxyConnectHeader, []testMode{http1Mode})
5561}
5562func testTransportProxyConnectHeader(t *testing.T, mode testMode) {
5563	reqc := make(chan *Request, 1)
5564	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5565		if r.Method != "CONNECT" {
5566			t.Errorf("method = %q; want CONNECT", r.Method)
5567		}
5568		reqc <- r
5569		c, _, err := w.(Hijacker).Hijack()
5570		if err != nil {
5571			t.Errorf("Hijack: %v", err)
5572			return
5573		}
5574		c.Close()
5575	})).ts
5576
5577	c := ts.Client()
5578	c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5579		return url.Parse(ts.URL)
5580	}
5581	c.Transport.(*Transport).ProxyConnectHeader = Header{
5582		"User-Agent": {"foo"},
5583		"Other":      {"bar"},
5584	}
5585
5586	res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
5587	if err == nil {
5588		res.Body.Close()
5589		t.Errorf("unexpected success")
5590	}
5591
5592	r := <-reqc
5593	if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
5594		t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5595	}
5596	if got, want := r.Header.Get("Other"), "bar"; got != want {
5597		t.Errorf("CONNECT request Other = %q; want %q", got, want)
5598	}
5599}
5600
5601func TestTransportProxyGetConnectHeader(t *testing.T) {
5602	run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode})
5603}
5604func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) {
5605	reqc := make(chan *Request, 1)
5606	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5607		if r.Method != "CONNECT" {
5608			t.Errorf("method = %q; want CONNECT", r.Method)
5609		}
5610		reqc <- r
5611		c, _, err := w.(Hijacker).Hijack()
5612		if err != nil {
5613			t.Errorf("Hijack: %v", err)
5614			return
5615		}
5616		c.Close()
5617	})).ts
5618
5619	c := ts.Client()
5620	c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) {
5621		return url.Parse(ts.URL)
5622	}
5623	// These should be ignored:
5624	c.Transport.(*Transport).ProxyConnectHeader = Header{
5625		"User-Agent": {"foo"},
5626		"Other":      {"bar"},
5627	}
5628	c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) {
5629		return Header{
5630			"User-Agent": {"foo2"},
5631			"Other":      {"bar2"},
5632		}, nil
5633	}
5634
5635	res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
5636	if err == nil {
5637		res.Body.Close()
5638		t.Errorf("unexpected success")
5639	}
5640
5641	r := <-reqc
5642	if got, want := r.Header.Get("User-Agent"), "foo2"; got != want {
5643		t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
5644	}
5645	if got, want := r.Header.Get("Other"), "bar2"; got != want {
5646		t.Errorf("CONNECT request Other = %q; want %q", got, want)
5647	}
5648}
5649
// errFakeRoundTrip is the error every funcRoundTripper returns.
var errFakeRoundTrip = errors.New("fake roundtrip")

// funcRoundTripper is a RoundTripper that runs the function and then fails
// with errFakeRoundTrip; it never produces a Response and ignores the request.
type funcRoundTripper func()

func (f funcRoundTripper) RoundTrip(_ *Request) (*Response, error) {
	f()
	return nil, errFakeRoundTrip
}
5658
// wantBody reports whether a (res, err) pair returned by a request produced
// exactly the body want. It returns err unchanged if it is non-nil;
// otherwise it reads and always closes res.Body, and returns a descriptive
// error for a read failure, a body mismatch, or a Close failure (checked in
// that order), or nil on success.
func wantBody(res *Response, err error, want string) error {
	if err != nil {
		return err
	}
	slurp, readErr := io.ReadAll(res.Body)
	// Close unconditionally so a failed read or mismatch doesn't leak the body.
	closeErr := res.Body.Close()
	if readErr != nil {
		return fmt.Errorf("error reading body: %w", readErr)
	}
	if string(slurp) != want {
		return fmt.Errorf("body = %q; want %q", slurp, want)
	}
	if closeErr != nil {
		return fmt.Errorf("body Close = %w", closeErr)
	}
	return nil
}
5675
// newLocalListener returns a TCP listener on an ephemeral loopback port,
// preferring IPv4 and falling back to IPv6. It fails t if neither works.
func newLocalListener(t *testing.T) net.Listener {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err == nil {
		return ln
	}
	// IPv4 loopback unavailable; try IPv6.
	if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil {
		t.Fatal(err)
	}
	return ln
}
5686
// countCloseReader is an io.ReadCloser that reads from the embedded Reader
// and increments *n on every Close. Close never fails.
type countCloseReader struct {
	n *int
	io.Reader
}

func (r countCloseReader) Close() error {
	*r.n++
	return nil
}
5696
// rgz is a gzip quine that uncompresses to itself.
//
// Layout (per the gzip format): the first ten bytes are the magic 1f 8b,
// CM=08 (deflate), FLG=08 (FNAME set), a zero MTIME, XFL and OS. They are
// followed by the NUL-terminated file name "recursive", the deflate data,
// and an 8-byte trailer: CRC-32 3d b1 20 85, then ISIZE fa 00 00 00 (250),
// which equals len(rgz) because the decompressed output is rgz itself.
var rgz = []byte{
	0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00,
	0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73,
	0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0,
	0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2,
	0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17,
	0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60,
	0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2,
	0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00,
	0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00,
	0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05,
	0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00,
	0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88,
	0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00,
	0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00,
	0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4,
	0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00,
	0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff,
	0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00,
	0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16,
	0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08,
	0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa,
	0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06,
	0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00,
	0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00,
	0x00, 0x00,
}
5732
5733// Ensure that a missing status doesn't make the server panic
5734// See Issue https://golang.org/issues/21701
5735func TestMissingStatusNoPanic(t *testing.T) {
5736	t.Parallel()
5737
5738	const want = "unknown status code"
5739
5740	ln := newLocalListener(t)
5741	addr := ln.Addr().String()
5742	done := make(chan bool)
5743	fullAddrURL := fmt.Sprintf("http://%s", addr)
5744	raw := "HTTP/1.1 400\r\n" +
5745		"Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
5746		"Content-Type: text/html; charset=utf-8\r\n" +
5747		"Content-Length: 10\r\n" +
5748		"Last-Modified: Wed, 30 Aug 2017 19:02:02 GMT\r\n" +
5749		"Vary: Accept-Encoding\r\n\r\n" +
5750		"Aloha Olaa"
5751
5752	go func() {
5753		defer close(done)
5754
5755		conn, _ := ln.Accept()
5756		if conn != nil {
5757			io.WriteString(conn, raw)
5758			io.ReadAll(conn)
5759			conn.Close()
5760		}
5761	}()
5762
5763	proxyURL, err := url.Parse(fullAddrURL)
5764	if err != nil {
5765		t.Fatalf("proxyURL: %v", err)
5766	}
5767
5768	tr := &Transport{Proxy: ProxyURL(proxyURL)}
5769
5770	req, _ := NewRequest("GET", "https://golang.org/", nil)
5771	res, err, panicked := doFetchCheckPanic(tr, req)
5772	if panicked {
5773		t.Error("panicked, expecting an error")
5774	}
5775	if res != nil && res.Body != nil {
5776		io.Copy(io.Discard, res.Body)
5777		res.Body.Close()
5778	}
5779
5780	if err == nil || !strings.Contains(err.Error(), want) {
5781		t.Errorf("got=%v want=%q", err, want)
5782	}
5783
5784	ln.Close()
5785	<-done
5786}
5787
// doFetchCheckPanic calls tr.RoundTrip(req), converting a panic into
// panicked == true (with nil res and err) instead of crashing the test.
func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, panicked bool) {
	// Named results let the deferred recover report the panic.
	defer func() {
		if recover() != nil {
			panicked = true
		}
	}()
	res, err = tr.RoundTrip(req)
	return res, err, false
}
5797
5798// Issue 22330: do not allow the response body to be read when the status code
5799// forbids a response body.
5800func TestNoBodyOnChunked304Response(t *testing.T) {
5801	run(t, testNoBodyOnChunked304Response, []testMode{http1Mode})
5802}
5803func testNoBodyOnChunked304Response(t *testing.T, mode testMode) {
5804	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5805		conn, buf, _ := w.(Hijacker).Hijack()
5806		buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"))
5807		buf.Flush()
5808		conn.Close()
5809	}))
5810
5811	// Our test server above is sending back bogus data after the
5812	// response (the "0\r\n\r\n" part), which causes the Transport
5813	// code to log spam. Disable keep-alives so we never even try
5814	// to reuse the connection.
5815	cst.tr.DisableKeepAlives = true
5816
5817	res, err := cst.c.Get(cst.ts.URL)
5818	if err != nil {
5819		t.Fatal(err)
5820	}
5821
5822	if res.Body != NoBody {
5823		t.Errorf("Unexpected body on 304 response")
5824	}
5825}
5826
// funcWriter adapts a plain function to the io.Writer interface.
type funcWriter func([]byte) (int, error)

// Write passes b to the function and returns its results unchanged.
func (fn funcWriter) Write(b []byte) (int, error) {
	return fn(b)
}
5830
// doneContext is a context.Context that always reports itself as done and
// returns err from Err. Deadline and Value come from the embedded Context.
type doneContext struct {
	context.Context
	err error
}

// Done returns a new, already-closed channel on every call.
func (d doneContext) Done() <-chan struct{} {
	ch := make(chan struct{})
	close(ch)
	return ch
}

func (d doneContext) Err() error {
	return d.err
}

// Issue 25852: Transport should check whether Context is done early.
func TestTransportCheckContextDoneEarly(t *testing.T) {
	wantErr := errors.New("some error")
	req, _ := NewRequest("GET", "http://fake.example/", nil)
	req = req.WithContext(doneContext{Context: context.Background(), err: wantErr})
	tr := &Transport{}
	if _, err := tr.RoundTrip(req); err != wantErr {
		t.Errorf("error = %v; want %v", err, wantErr)
	}
}
5855
// Issue 23399: verify that if a client request times out, the Transport's
// conn is closed so that it's not reused.
//
// This is the test variant that times out before the server replies with
// any response headers.
func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode})
}
// testClientTimeoutKillsConn_BeforeHeaders retries with a doubled client
// timeout whenever the timed-out request never reached the handler.
func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
	// Start with a tiny timeout; the tooSlow branch below doubles it.
	timeout := 1 * time.Millisecond
	for {
		// inHandler: handler -> test, "handler is running".
		// cancelHandler: closed by the test to abandon this attempt, so a
		// handler blocked on sending to inHandler can return.
		// handlerDone: handler -> test, "conn check finished".
		inHandler := make(chan bool)
		cancelHandler := make(chan struct{})
		handlerDone := make(chan bool)
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Block until the server considers the request finished
			// (presumably because the timed-out client went away).
			<-r.Context().Done()

			select {
			case <-cancelHandler:
				return
			case inHandler <- true:
			}
			defer func() { handlerDone <- true }()

			// Read from the conn until EOF to verify that it was correctly closed.
			conn, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Error(err)
				return
			}
			n, err := conn.Read([]byte{0})
			if n != 0 || err != io.EOF {
				t.Errorf("unexpected Read result: %v, %v", n, err)
			}
			conn.Close()
		}))

		cst.c.Timeout = timeout

		// The request is expected to fail because of the client timeout.
		_, err := cst.c.Get(cst.ts.URL)
		if err == nil {
			close(cancelHandler)
			t.Fatal("unexpected Get success")
		}

		tooSlow := time.NewTimer(timeout * 10)
		select {
		case <-tooSlow.C:
			// If we didn't get into the Handler, that probably means the builder was
			// just slow and the Get failed in that time but never made it to the
			// server. That's fine; we'll try again with a longer timeout.
			t.Logf("no handler seen in %v; retrying with longer timeout", timeout)
			close(cancelHandler)
			cst.close()
			timeout *= 2
			continue
		case <-inHandler:
			tooSlow.Stop()
			// Wait for the handler to finish checking that the conn was closed.
			<-handlerDone
		}
		break
	}
}
5919
// Issue 23399: verify that if a client request times out, the Transport's
// conn is closed so that it's not reused.
//
// This is the test variant that has the server send response headers
// first, and time out during the write of the response body.
func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
	run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
	// inHandler: handler -> test, "headers sent, about to hijack".
	// cancelHandler: closed by the test on early failure so the handler
	// can return instead of blocking on inHandler.
	// handlerDone: handler -> test, "conn check finished".
	inHandler := make(chan bool)
	cancelHandler := make(chan struct{})
	handlerDone := make(chan bool)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Promise 100 bytes but flush only the headers.
		w.Header().Set("Content-Length", "100")
		w.(Flusher).Flush()

		select {
		case <-cancelHandler:
			return
		case inHandler <- true:
		}
		defer func() { handlerDone <- true }()

		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		// Send only part of the promised body.
		conn.Write([]byte("foo"))

		n, err := conn.Read([]byte{0})
		// The error should be io.EOF or "read tcp
		// 127.0.0.1:35827->127.0.0.1:40290: read: connection
		// reset by peer" depending on timing. Really we just
		// care that it returns at all. But if it returns with
		// data, that's weird.
		if n != 0 || err == nil {
			t.Errorf("unexpected Read result: %v, %v", n, err)
		}
		conn.Close()
	}))

	// Set Timeout to something very long but non-zero to exercise
	// the codepaths that check for it. But rather than wait for it to fire
	// (which would make the test slow), we close the req.Cancel channel
	// instead, which happens to exercise the same code paths.
	cst.c.Timeout = 24 * time.Hour // just to be non-zero, not to hit it.
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	cancelReq := make(chan struct{})
	req.Cancel = cancelReq

	res, err := cst.c.Do(req)
	if err != nil {
		close(cancelHandler)
		t.Fatalf("Get error: %v", err)
	}

	// Cancel the request while the handler is still blocked on sending to the
	// inHandler channel. Then read it until it fails, to verify that the
	// connection is broken before the handler itself closes it.
	close(cancelReq)
	got, err := io.ReadAll(res.Body)
	if err == nil {
		t.Errorf("unexpected success; read %q, nil", got)
	}

	// Now unblock the handler and wait for it to complete.
	<-inHandler
	<-handlerDone
}
5990
5991func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
5992	run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode})
5993}
5994func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) {
5995	done := make(chan struct{})
5996	defer close(done)
5997	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5998		conn, _, err := w.(Hijacker).Hijack()
5999		if err != nil {
6000			t.Error(err)
6001			return
6002		}
6003		defer conn.Close()
6004		io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
6005		bs := bufio.NewScanner(conn)
6006		bs.Scan()
6007		fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
6008		<-done
6009	}))
6010
6011	req, _ := NewRequest("GET", cst.ts.URL, nil)
6012	req.Header.Set("Upgrade", "foo")
6013	req.Header.Set("Connection", "upgrade")
6014	res, err := cst.c.Do(req)
6015	if err != nil {
6016		t.Fatal(err)
6017	}
6018	if res.StatusCode != 101 {
6019		t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
6020	}
6021	rwc, ok := res.Body.(io.ReadWriteCloser)
6022	if !ok {
6023		t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
6024	}
6025	defer rwc.Close()
6026	bs := bufio.NewScanner(rwc)
6027	if !bs.Scan() {
6028		t.Fatalf("expected readable input")
6029	}
6030	if got, want := bs.Text(), "Some buffered data"; got != want {
6031		t.Errorf("read %q; want %q", got, want)
6032	}
6033	io.WriteString(rwc, "echo\n")
6034	if !bs.Scan() {
6035		t.Fatalf("expected another line")
6036	}
6037	if got, want := bs.Text(), "ECHO"; got != want {
6038		t.Errorf("read %q; want %q", got, want)
6039	}
6040}
6041
6042func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) }
6043func testTransportCONNECTBidi(t *testing.T, mode testMode) {
6044	const target = "backend:443"
6045	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6046		if r.Method != "CONNECT" {
6047			t.Errorf("unexpected method %q", r.Method)
6048			w.WriteHeader(500)
6049			return
6050		}
6051		if r.RequestURI != target {
6052			t.Errorf("unexpected CONNECT target %q", r.RequestURI)
6053			w.WriteHeader(500)
6054			return
6055		}
6056		nc, brw, err := w.(Hijacker).Hijack()
6057		if err != nil {
6058			t.Error(err)
6059			return
6060		}
6061		defer nc.Close()
6062		nc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
6063		// Switch to a little protocol that capitalize its input lines:
6064		for {
6065			line, err := brw.ReadString('\n')
6066			if err != nil {
6067				if err != io.EOF {
6068					t.Error(err)
6069				}
6070				return
6071			}
6072			io.WriteString(brw, strings.ToUpper(line))
6073			brw.Flush()
6074		}
6075	}))
6076	pr, pw := io.Pipe()
6077	defer pw.Close()
6078	req, err := NewRequest("CONNECT", cst.ts.URL, pr)
6079	if err != nil {
6080		t.Fatal(err)
6081	}
6082	req.URL.Opaque = target
6083	res, err := cst.c.Do(req)
6084	if err != nil {
6085		t.Fatal(err)
6086	}
6087	defer res.Body.Close()
6088	if res.StatusCode != 200 {
6089		t.Fatalf("status code = %d; want 200", res.StatusCode)
6090	}
6091	br := bufio.NewReader(res.Body)
6092	for _, str := range []string{"foo", "bar", "baz"} {
6093		fmt.Fprintf(pw, "%s\n", str)
6094		got, err := br.ReadString('\n')
6095		if err != nil {
6096			t.Fatal(err)
6097		}
6098		got = strings.TrimSpace(got)
6099		want := strings.ToUpper(str)
6100		if got != want {
6101			t.Fatalf("got %q; want %q", got, want)
6102		}
6103	}
6104}
6105
6106func TestTransportRequestReplayable(t *testing.T) {
6107	someBody := io.NopCloser(strings.NewReader(""))
6108	tests := []struct {
6109		name string
6110		req  *Request
6111		want bool
6112	}{
6113		{
6114			name: "GET",
6115			req:  &Request{Method: "GET"},
6116			want: true,
6117		},
6118		{
6119			name: "GET_http.NoBody",
6120			req:  &Request{Method: "GET", Body: NoBody},
6121			want: true,
6122		},
6123		{
6124			name: "GET_body",
6125			req:  &Request{Method: "GET", Body: someBody},
6126			want: false,
6127		},
6128		{
6129			name: "POST",
6130			req:  &Request{Method: "POST"},
6131			want: false,
6132		},
6133		{
6134			name: "POST_idempotency-key",
6135			req:  &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}},
6136			want: true,
6137		},
6138		{
6139			name: "POST_x-idempotency-key",
6140			req:  &Request{Method: "POST", Header: Header{"X-Idempotency-Key": {"x"}}},
6141			want: true,
6142		},
6143		{
6144			name: "POST_body",
6145			req:  &Request{Method: "POST", Header: Header{"Idempotency-Key": {"x"}}, Body: someBody},
6146			want: false,
6147		},
6148	}
6149	for _, tt := range tests {
6150		t.Run(tt.name, func(t *testing.T) {
6151			got := tt.req.ExportIsReplayable()
6152			if got != tt.want {
6153				t.Errorf("replyable = %v; want %v", got, tt.want)
6154			}
6155		})
6156	}
6157}
6158
// testMockTCPConn wraps a *net.TCPConn and records whether its ReadFrom
// method was invoked. Tests use it to see whether the Transport wrote a
// request body via ReadFrom.
type testMockTCPConn struct {
	*net.TCPConn

	ReadFromCalled bool
}

// ReadFrom records that it was called, then delegates to the wrapped
// TCPConn's ReadFrom.
func (m *testMockTCPConn) ReadFrom(src io.Reader) (int64, error) {
	m.ReadFromCalled = true
	n, err := m.TCPConn.ReadFrom(src)
	return n, err
}
6171
6172func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) }
6173func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
6174	nBytes := int64(1 << 10)
6175	newFileFunc := func() (r io.Reader, done func(), err error) {
6176		f, err := os.CreateTemp("", "net-http-newfilefunc")
6177		if err != nil {
6178			return nil, nil, err
6179		}
6180
6181		// Write some bytes to the file to enable reading.
6182		if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
6183			return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
6184		}
6185		if _, err := f.Seek(0, 0); err != nil {
6186			return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
6187		}
6188
6189		done = func() {
6190			f.Close()
6191			os.Remove(f.Name())
6192		}
6193
6194		return f, done, nil
6195	}
6196
6197	newBufferFunc := func() (io.Reader, func(), error) {
6198		return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
6199	}
6200
6201	cases := []struct {
6202		name             string
6203		readerFunc       func() (io.Reader, func(), error)
6204		contentLength    int64
6205		expectedReadFrom bool
6206	}{
6207		{
6208			name:             "file, length",
6209			readerFunc:       newFileFunc,
6210			contentLength:    nBytes,
6211			expectedReadFrom: true,
6212		},
6213		{
6214			name:       "file, no length",
6215			readerFunc: newFileFunc,
6216		},
6217		{
6218			name:          "file, negative length",
6219			readerFunc:    newFileFunc,
6220			contentLength: -1,
6221		},
6222		{
6223			name:          "buffer",
6224			contentLength: nBytes,
6225			readerFunc:    newBufferFunc,
6226		},
6227		{
6228			name:       "buffer, no length",
6229			readerFunc: newBufferFunc,
6230		},
6231		{
6232			name:          "buffer, length -1",
6233			contentLength: -1,
6234			readerFunc:    newBufferFunc,
6235		},
6236	}
6237
6238	for _, tc := range cases {
6239		t.Run(tc.name, func(t *testing.T) {
6240			r, cleanup, err := tc.readerFunc()
6241			if err != nil {
6242				t.Fatal(err)
6243			}
6244			defer cleanup()
6245
6246			tConn := &testMockTCPConn{}
6247			trFunc := func(tr *Transport) {
6248				tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
6249					var d net.Dialer
6250					conn, err := d.DialContext(ctx, network, addr)
6251					if err != nil {
6252						return nil, err
6253					}
6254
6255					tcpConn, ok := conn.(*net.TCPConn)
6256					if !ok {
6257						return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
6258					}
6259
6260					tConn.TCPConn = tcpConn
6261					return tConn, nil
6262				}
6263			}
6264
6265			cst := newClientServerTest(
6266				t,
6267				mode,
6268				HandlerFunc(func(w ResponseWriter, r *Request) {
6269					io.Copy(io.Discard, r.Body)
6270					r.Body.Close()
6271					w.WriteHeader(200)
6272				}),
6273				trFunc,
6274			)
6275
6276			req, err := NewRequest("PUT", cst.ts.URL, r)
6277			if err != nil {
6278				t.Fatal(err)
6279			}
6280			req.ContentLength = tc.contentLength
6281			req.Header.Set("Content-Type", "application/octet-stream")
6282			resp, err := cst.c.Do(req)
6283			if err != nil {
6284				t.Fatal(err)
6285			}
6286			defer resp.Body.Close()
6287			if resp.StatusCode != 200 {
6288				t.Fatalf("status code = %d; want 200", resp.StatusCode)
6289			}
6290
6291			expectedReadFrom := tc.expectedReadFrom
6292			if mode != http1Mode {
6293				expectedReadFrom = false
6294			}
6295			if !tConn.ReadFromCalled && expectedReadFrom {
6296				t.Fatalf("did not call ReadFrom")
6297			}
6298
6299			if tConn.ReadFromCalled && !expectedReadFrom {
6300				t.Fatalf("ReadFrom was unexpectedly invoked")
6301			}
6302		})
6303	}
6304}
6305
// TestTransportClone sets every listed exported Transport field to a
// non-zero value and checks that none of the exported fields is zero in
// the result of Clone. It also checks that the clone keeps the TLSNextProto
// entries, and that cloning a Transport with a nil TLSNextProto leaves it nil.
func TestTransportClone(t *testing.T) {
	src := &Transport{
		Proxy: func(*Request) (*url.URL, error) { panic("") },
		OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
			return nil
		},
		DialContext:            func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		Dial:                   func(network, addr string) (net.Conn, error) { panic("") },
		DialTLS:                func(network, addr string) (net.Conn, error) { panic("") },
		DialTLSContext:         func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
		TLSClientConfig:        new(tls.Config),
		TLSHandshakeTimeout:    time.Second,
		DisableKeepAlives:      true,
		DisableCompression:     true,
		MaxIdleConns:           1,
		MaxIdleConnsPerHost:    1,
		MaxConnsPerHost:        1,
		IdleConnTimeout:        time.Second,
		ResponseHeaderTimeout:  time.Second,
		ExpectContinueTimeout:  time.Second,
		ProxyConnectHeader:     Header{},
		GetProxyConnectHeader:  func(context.Context, *url.URL, string) (Header, error) { return nil, nil },
		MaxResponseHeaderBytes: 1,
		ForceAttemptHTTP2:      true,
		TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
			"foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
		},
		ReadBufferSize:  1,
		WriteBufferSize: 1,
	}
	clone := src.Clone()

	// Any exported field that Clone forgot to copy shows up as zero here.
	v := reflect.ValueOf(clone).Elem()
	typ := v.Type()
	for i, n := 0, typ.NumField(); i < n; i++ {
		field := typ.Field(i)
		if token.IsExported(field.Name) && v.Field(i).IsZero() {
			t.Errorf("cloned field t2.%s is zero", field.Name)
		}
	}

	if _, found := clone.TLSNextProto["foo"]; !found {
		t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
	}

	// But test that a nil TLSNextProto is kept nil:
	if got := new(Transport).Clone().TLSNextProto; got != nil {
		t.Errorf("Transport.TLSNextProto unexpected non-nil")
	}
}
6360
6361func TestIs408(t *testing.T) {
6362	tests := []struct {
6363		in   string
6364		want bool
6365	}{
6366		{"HTTP/1.0 408", true},
6367		{"HTTP/1.1 408", true},
6368		{"HTTP/1.8 408", true},
6369		{"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now.
6370		{"HTTP/1.1 408 ", true},
6371		{"HTTP/1.1 40", false},
6372		{"http/1.0 408", false},
6373		{"HTTP/1-1 408", false},
6374	}
6375	for _, tt := range tests {
6376		if got := Export_is408Message([]byte(tt.in)); got != tt.want {
6377			t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
6378		}
6379	}
6380}
6381
6382func TestTransportIgnores408(t *testing.T) {
6383	run(t, testTransportIgnores408, []testMode{http1Mode}, testNotParallel)
6384}
6385func testTransportIgnores408(t *testing.T, mode testMode) {
6386	// Not parallel. Relies on mutating the log package's global Output.
6387	defer log.SetOutput(log.Writer())
6388
6389	var logout strings.Builder
6390	log.SetOutput(&logout)
6391
6392	const target = "backend:443"
6393
6394	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6395		nc, _, err := w.(Hijacker).Hijack()
6396		if err != nil {
6397			t.Error(err)
6398			return
6399		}
6400		defer nc.Close()
6401		nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
6402		nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail
6403	}))
6404	req, err := NewRequest("GET", cst.ts.URL, nil)
6405	if err != nil {
6406		t.Fatal(err)
6407	}
6408	res, err := cst.c.Do(req)
6409	if err != nil {
6410		t.Fatal(err)
6411	}
6412	slurp, err := io.ReadAll(res.Body)
6413	if err != nil {
6414		t.Fatal(err)
6415	}
6416	if err != nil {
6417		t.Fatal(err)
6418	}
6419	if string(slurp) != "ok" {
6420		t.Fatalf("got %q; want ok", slurp)
6421	}
6422
6423	waitCondition(t, 1*time.Millisecond, func(d time.Duration) bool {
6424		if n := cst.tr.IdleConnKeyCountForTesting(); n != 0 {
6425			if d > 0 {
6426				t.Logf("%v idle conns still present after %v", n, d)
6427			}
6428			return false
6429		}
6430		return true
6431	})
6432	if got := logout.String(); got != "" {
6433		t.Fatalf("expected no log output; got: %s", got)
6434	}
6435}
6436
6437func TestInvalidHeaderResponse(t *testing.T) {
6438	run(t, testInvalidHeaderResponse, []testMode{http1Mode})
6439}
6440func testInvalidHeaderResponse(t *testing.T, mode testMode) {
6441	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6442		conn, buf, _ := w.(Hijacker).Hijack()
6443		buf.Write([]byte("HTTP/1.1 200 OK\r\n" +
6444			"Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" +
6445			"Content-Type: text/html; charset=utf-8\r\n" +
6446			"Content-Length: 0\r\n" +
6447			"Foo : bar\r\n\r\n"))
6448		buf.Flush()
6449		conn.Close()
6450	}))
6451	res, err := cst.c.Get(cst.ts.URL)
6452	if err != nil {
6453		t.Fatal(err)
6454	}
6455	defer res.Body.Close()
6456	if v := res.Header.Get("Foo"); v != "" {
6457		t.Errorf(`unexpected "Foo" header: %q`, v)
6458	}
6459	if v := res.Header.Get("Foo "); v != "bar" {
6460		t.Errorf(`bad "Foo " header value: %q, want %q`, v, "bar")
6461	}
6462}
6463
// bodyCloser is an always-empty request body that records whether Close
// has been called (the underlying bool becomes true).
type bodyCloser bool

// Read reports end of stream immediately.
func (b *bodyCloser) Read(p []byte) (int, error) {
	return 0, io.EOF
}

// Close marks the body as closed. It never fails.
func (b *bodyCloser) Close() error {
	*b = bodyCloser(true)
	return nil
}
6473
6474// Issue 35015: ensure that Transport closes the body on any error
6475// with an invalid request, as promised by Client.Do docs.
6476func TestTransportClosesBodyOnInvalidRequests(t *testing.T) {
6477	run(t, testTransportClosesBodyOnInvalidRequests)
6478}
6479func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) {
6480	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6481		t.Errorf("Should not have been invoked")
6482	})).ts
6483
6484	u, _ := url.Parse(cst.URL)
6485
6486	tests := []struct {
6487		name    string
6488		req     *Request
6489		wantErr string
6490	}{
6491		{
6492			name: "invalid method",
6493			req: &Request{
6494				Method: " ",
6495				URL:    u,
6496			},
6497			wantErr: `invalid method " "`,
6498		},
6499		{
6500			name: "nil URL",
6501			req: &Request{
6502				Method: "GET",
6503			},
6504			wantErr: `nil Request.URL`,
6505		},
6506		{
6507			name: "invalid header key",
6508			req: &Request{
6509				Method: "GET",
6510				Header: Header{"��": {"emoji"}},
6511				URL:    u,
6512			},
6513			wantErr: `invalid header field name "��"`,
6514		},
6515		{
6516			name: "invalid header value",
6517			req: &Request{
6518				Method: "POST",
6519				Header: Header{"key": {"\x19"}},
6520				URL:    u,
6521			},
6522			wantErr: `invalid header field value for "key"`,
6523		},
6524		{
6525			name: "non HTTP(s) scheme",
6526			req: &Request{
6527				Method: "POST",
6528				URL:    &url.URL{Scheme: "faux"},
6529			},
6530			wantErr: `unsupported protocol scheme "faux"`,
6531		},
6532		{
6533			name: "no Host in URL",
6534			req: &Request{
6535				Method: "POST",
6536				URL:    &url.URL{Scheme: "http"},
6537			},
6538			wantErr: `no Host in request URL`,
6539		},
6540	}
6541
6542	for _, tt := range tests {
6543		t.Run(tt.name, func(t *testing.T) {
6544			var bc bodyCloser
6545			req := tt.req
6546			req.Body = &bc
6547			_, err := cst.Client().Do(tt.req)
6548			if err == nil {
6549				t.Fatal("Expected an error")
6550			}
6551			if !bc {
6552				t.Fatal("Expected body to have been closed")
6553			}
6554			if g, w := err.Error(), tt.wantErr; !strings.HasSuffix(g, w) {
6555				t.Fatalf("Error mismatch: %q does not end with %q", g, w)
6556			}
6557		})
6558	}
6559}
6560
// breakableConn wraps a net.Conn. Its Write fails with an error whenever
// the shared brokenState says the connection is broken.
type breakableConn struct {
	net.Conn
	*brokenState
}

// brokenState is a mutex-guarded "broken" flag. It can be shared by several
// breakableConns and toggled from other goroutines.
type brokenState struct {
	sync.Mutex
	broken bool
}

// Write forwards to the wrapped conn unless the shared state is broken, in
// which case nothing is written and an error is returned. The lock is held
// for the duration of the underlying write.
func (c *breakableConn) Write(p []byte) (int, error) {
	c.Lock()
	defer c.Unlock()
	if !c.broken {
		return c.Conn.Write(p)
	}
	return 0, errors.New("some write error")
}
6581
6582// Issue 34978: don't cache a broken HTTP/2 connection
6583func TestDontCacheBrokenHTTP2Conn(t *testing.T) {
6584	run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode})
6585}
6586func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) {
6587	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog)
6588
6589	var brokenState brokenState
6590
6591	const numReqs = 5
6592	var numDials, gotConns uint32 // atomic
6593
6594	cst.tr.Dial = func(netw, addr string) (net.Conn, error) {
6595		atomic.AddUint32(&numDials, 1)
6596		c, err := net.Dial(netw, addr)
6597		if err != nil {
6598			t.Errorf("unexpected Dial error: %v", err)
6599			return nil, err
6600		}
6601		return &breakableConn{c, &brokenState}, err
6602	}
6603
6604	for i := 1; i <= numReqs; i++ {
6605		brokenState.Lock()
6606		brokenState.broken = false
6607		brokenState.Unlock()
6608
6609		// doBreak controls whether we break the TCP connection after the TLS
6610		// handshake (before the HTTP/2 handshake). We test a few failures
6611		// in a row followed by a final success.
6612		doBreak := i != numReqs
6613
6614		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
6615			GotConn: func(info httptrace.GotConnInfo) {
6616				t.Logf("got conn: %v, reused=%v, wasIdle=%v, idleTime=%v", info.Conn.LocalAddr(), info.Reused, info.WasIdle, info.IdleTime)
6617				atomic.AddUint32(&gotConns, 1)
6618			},
6619			TLSHandshakeDone: func(cfg tls.ConnectionState, err error) {
6620				brokenState.Lock()
6621				defer brokenState.Unlock()
6622				if doBreak {
6623					brokenState.broken = true
6624				}
6625			},
6626		})
6627		req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
6628		if err != nil {
6629			t.Fatal(err)
6630		}
6631		_, err = cst.c.Do(req)
6632		if doBreak != (err != nil) {
6633			t.Errorf("for iteration %d, doBreak=%v; unexpected error %v", i, doBreak, err)
6634		}
6635	}
6636	if got, want := atomic.LoadUint32(&gotConns), 1; int(got) != want {
6637		t.Errorf("GotConn calls = %v; want %v", got, want)
6638	}
6639	if got, want := atomic.LoadUint32(&numDials), numReqs; int(got) != want {
6640		t.Errorf("Dials = %v; want %v", got, want)
6641	}
6642}
6643
6644// Issue 34941
6645// When the client has too many concurrent requests on a single connection,
6646// http.http2noCachedConnError is reported on multiple requests. There should
6647// only be one decrement regardless of the number of failures.
6648func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) {
6649	run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode})
6650}
6651func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) {
6652	CondSkipHTTP2(t)
6653
6654	h := HandlerFunc(func(w ResponseWriter, r *Request) {
6655		_, err := w.Write([]byte("foo"))
6656		if err != nil {
6657			t.Fatalf("Write: %v", err)
6658		}
6659	})
6660
6661	ts := newClientServerTest(t, mode, h).ts
6662
6663	c := ts.Client()
6664	tr := c.Transport.(*Transport)
6665	tr.MaxConnsPerHost = 1
6666
6667	errCh := make(chan error, 300)
6668	doReq := func() {
6669		resp, err := c.Get(ts.URL)
6670		if err != nil {
6671			errCh <- fmt.Errorf("request failed: %v", err)
6672			return
6673		}
6674		defer resp.Body.Close()
6675		_, err = io.ReadAll(resp.Body)
6676		if err != nil {
6677			errCh <- fmt.Errorf("read body failed: %v", err)
6678		}
6679	}
6680
6681	var wg sync.WaitGroup
6682	for i := 0; i < 300; i++ {
6683		wg.Add(1)
6684		go func() {
6685			defer wg.Done()
6686			doReq()
6687		}()
6688	}
6689	wg.Wait()
6690	close(errCh)
6691
6692	for err := range errCh {
6693		t.Errorf("error occurred: %v", err)
6694	}
6695}
6696
6697// Issue 36820
6698// Test that we use the older backward compatible cancellation protocol
6699// when a RoundTripper is registered via RegisterProtocol.
6700func TestAltProtoCancellation(t *testing.T) {
6701	defer afterTest(t)
6702	tr := &Transport{}
6703	c := &Client{
6704		Transport: tr,
6705		Timeout:   time.Millisecond,
6706	}
6707	tr.RegisterProtocol("cancel", cancelProto{})
6708	_, err := c.Get("cancel://bar.com/path")
6709	if err == nil {
6710		t.Error("request unexpectedly succeeded")
6711	} else if !strings.Contains(err.Error(), errCancelProto.Error()) {
6712		t.Errorf("got error %q, does not contain expected string %q", err, errCancelProto)
6713	}
6714}
6715
6716var errCancelProto = errors.New("canceled as expected")
6717
6718type cancelProto struct{}
6719
6720func (cancelProto) RoundTrip(req *Request) (*Response, error) {
6721	<-req.Cancel
6722	return nil, errCancelProto
6723}
6724
// roundTripFunc adapts an ordinary function to the RoundTripper interface.
type roundTripFunc func(r *Request) (*Response, error)

// RoundTrip calls f with req and returns its results unchanged.
func (f roundTripFunc) RoundTrip(req *Request) (*Response, error) {
	return f(req)
}
6728
6729// Issue 32441: body is not reset after ErrSkipAltProtocol
6730func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) }
6731func testIssue32441(t *testing.T, mode testMode) {
6732	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6733		if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6734			t.Error("body length is zero")
6735		}
6736	})).ts
6737	c := ts.Client()
6738	c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) {
6739		// Draining body to trigger failure condition on actual request to server.
6740		if n, _ := io.Copy(io.Discard, r.Body); n == 0 {
6741			t.Error("body length is zero during round trip")
6742		}
6743		return nil, ErrSkipAltProtocol
6744	}))
6745	if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil {
6746		t.Error(err)
6747	}
6748}
6749
6750// Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers
6751// that contain a sign (eg. "+3"), per RFC 2616, Section 14.13.
6752func TestTransportRejectsSignInContentLength(t *testing.T) {
6753	run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode})
6754}
6755func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) {
6756	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6757		w.Header().Set("Content-Length", "+3")
6758		w.Write([]byte("abc"))
6759	})).ts
6760
6761	c := cst.Client()
6762	res, err := c.Get(cst.URL)
6763	if err == nil || res != nil {
6764		t.Fatal("Expected a non-nil error and a nil http.Response")
6765	}
6766	if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) {
6767		t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
6768	}
6769}
6770
// dumpConn is a fake net.Conn that sends writes to the embedded Writer and
// serves reads from the embedded Reader. Close, the address methods and the
// deadline setters do nothing; the address methods return nil.
type dumpConn struct {
	io.Writer
	io.Reader
}

func (*dumpConn) Close() error                     { return nil }
func (*dumpConn) LocalAddr() net.Addr              { return nil }
func (*dumpConn) RemoteAddr() net.Addr             { return nil }
func (*dumpConn) SetDeadline(time.Time) error      { return nil }
func (*dumpConn) SetReadDeadline(time.Time) error  { return nil }
func (*dumpConn) SetWriteDeadline(time.Time) error { return nil }
6783
// delegateReader is an io.Reader whose real source is supplied later. The
// first Read blocks until a reader arrives on c, and every Read is then
// forwarded to that reader.
type delegateReader struct {
	c chan io.Reader
	r io.Reader // nil until received from c
}

// Read forwards to the delegate, waiting for it on c if none has been
// received yet. If c is closed before a delegate arrives, Read returns a
// "delegate closed" error (and will do so again on later calls).
func (d *delegateReader) Read(p []byte) (int, error) {
	if d.r != nil {
		return d.r.Read(p)
	}
	next, ok := <-d.c
	d.r = next
	if !ok {
		return 0, errors.New("delegate closed")
	}
	return d.r.Read(p)
}
6800
// testTransportRace performs one RoundTrip of req over an in-memory
// connection and then restores req.Body to the value it had on entry.
//
// The connection is a dumpConn. Its writes go into an io.Pipe that a helper
// goroutine parses as an HTTP request, and its reads come from a
// delegateReader that the helper later feeds a canned "204 No Content"
// response. The result of RoundTrip is ignored. The final write to req.Body
// is the "caller mutates the processed Request" step that
// TestErrorWriteLoopRace relies on to expose a still-running writeLoop.
func testTransportRace(req *Request) {
	save := req.Body
	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()
	dr := &delegateReader{c: make(chan io.Reader)}

	t := &Transport{
		// Every dial returns a dumpConn backed by pw and dr, so nothing
		// touches the network.
		Dial: func(net, addr string) (net.Conn, error) {
			return &dumpConn{pw, dr}, nil
		},
	}
	defer t.CloseIdleConnections()

	quitReadCh := make(chan struct{})
	// Wait for the request before replying with a dummy response:
	go func() {
		defer close(quitReadCh)

		req, err := ReadRequest(bufio.NewReader(pr))
		if err == nil {
			// Ensure all the body is read; otherwise
			// we'll get a partial dump.
			io.Copy(io.Discard, req.Body)
			req.Body.Close()
		}
		// Either something is reading from dr (hand it the response), or the
		// main goroutine has finished RoundTrip and is waiting on quitReadCh.
		select {
		case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
		case quitReadCh <- struct{}{}:
			// Ensure delegate is closed so Read doesn't block forever.
			close(dr.c)
		}
	}()

	t.RoundTrip(req)

	// Ensure the reader returns before we reset req.Body to prevent
	// a data race on req.Body.
	pw.Close()
	<-quitReadCh

	req.Body = save
}
6844
6845// Issue 37669
6846// Test that a cancellation doesn't result in a data race due to the writeLoop
6847// goroutine being left running, if the caller mutates the processed Request
6848// upon completion.
6849func TestErrorWriteLoopRace(t *testing.T) {
6850	if testing.Short() {
6851		return
6852	}
6853	t.Parallel()
6854	for i := 0; i < 1000; i++ {
6855		delay := time.Duration(mrand.Intn(5)) * time.Millisecond
6856		ctx, cancel := context.WithTimeout(context.Background(), delay)
6857		defer cancel()
6858
6859		r := bytes.NewBuffer(make([]byte, 10000))
6860		req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
6861		if err != nil {
6862			t.Fatal(err)
6863		}
6864
6865		testTransportRace(req)
6866	}
6867}
6868
// Issue 41600
// Test that a new request which uses the connection of an active request
// cannot cause it to be canceled as well.
//
// Sequence:
//  1. Request 1 completes and its connection goes back to the idle pool.
//     Request 1's PutIdleConn trace hook is then held blocked.
//  2. Request 2 is sent. With MaxConnsPerHost = 1 it presumably reuses that
//     connection. It reaches the server, and its context is then canceled.
//  3. Request 2 must fail with context.Canceled. Its goroutine then releases
//     request 1's hook.
func TestCancelRequestWhenSharingConnection(t *testing.T) {
	run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode})
}
func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) {
	// Each server-side request hands the test a channel and blocks until the
	// test closes (or sends on) it. This lets the test decide when each
	// response is written.
	reqc := make(chan chan struct{}, 2)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
		ch := make(chan struct{}, 1)
		reqc <- ch
		<-ch
		w.Header().Add("Content-Length", "0")
	})).ts

	client := ts.Client()
	transport := client.Transport.(*Transport)
	transport.MaxIdleConns = 1
	transport.MaxConnsPerHost = 1

	var wg sync.WaitGroup

	wg.Add(1)
	putidlec := make(chan chan struct{}, 1)
	reqerrc := make(chan error, 1)
	// Request 1.
	go func() {
		defer wg.Done()
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			PutIdleConn: func(error) {
				// Signal that the idle conn has been returned to the pool,
				// and wait for the order to proceed.
				ch := make(chan struct{})
				putidlec <- ch
				close(putidlec) // panic if PutIdleConn runs twice for some reason
				<-ch
			},
		})
		req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err != nil {
			reqerrc <- err
		} else {
			res.Body.Close()
		}
	}()

	// Wait for the first request to receive a response and return the
	// connection to the idle pool.
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case r1c := <-reqc:
		close(r1c)
	}
	var idlec chan struct{}
	select {
	case err := <-reqerrc:
		t.Fatalf("request 1: got err %v, want nil", err)
	case idlec = <-putidlec:
	}

	// Request 2, on a context the test cancels below.
	wg.Add(1)
	cancelctx, cancel := context.WithCancel(context.Background())
	go func() {
		defer wg.Done()
		req, _ := NewRequestWithContext(cancelctx, "GET", ts.URL, nil)
		res, err := client.Do(req)
		if err == nil {
			res.Body.Close()
		}
		if !errors.Is(err, context.Canceled) {
			t.Errorf("request 2: got err %v, want Canceled", err)
		}

		// Unblock the first request.
		close(idlec)
	}()

	// Wait for the second request to arrive at the server, and then cancel
	// the request context.
	r2c := <-reqc
	cancel()

	// Returns once request 2's goroutine has closed idlec, i.e. after
	// request 2 has finished.
	<-idlec

	// Let request 2's handler return, then wait for both client goroutines.
	close(r2c)
	wg.Wait()
}
6957
6958func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) }
6959func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) {
6960	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
6961		go io.Copy(io.Discard, req.Body)
6962		panic(ErrAbortHandler)
6963	})).ts
6964
6965	var wg sync.WaitGroup
6966	for i := 0; i < 2; i++ {
6967		wg.Add(1)
6968		go func() {
6969			defer wg.Done()
6970			for j := 0; j < 10; j++ {
6971				const reqLen = 6 * 1024 * 1024
6972				req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen})
6973				req.ContentLength = reqLen
6974				resp, _ := ts.Client().Transport.RoundTrip(req)
6975				if resp != nil {
6976					resp.Body.Close()
6977				}
6978			}
6979		}()
6980	}
6981	wg.Wait()
6982}
6983
6984func TestRequestSanitization(t *testing.T) { run(t, testRequestSanitization) }
6985func testRequestSanitization(t *testing.T, mode testMode) {
6986	if mode == http2Mode {
6987		// Remove this after updating x/net.
6988		t.Skip("https://go.dev/issue/60374 test fails when run with HTTP/2")
6989	}
6990	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
6991		if h, ok := req.Header["X-Evil"]; ok {
6992			t.Errorf("request has X-Evil header: %q", h)
6993		}
6994	})).ts
6995	req, _ := NewRequest("GET", ts.URL, nil)
6996	req.Host = "go.dev\r\nX-Evil:evil"
6997	resp, _ := ts.Client().Do(req)
6998	if resp != nil {
6999		resp.Body.Close()
7000	}
7001}
7002
7003func TestProxyAuthHeader(t *testing.T) {
7004	// Not parallel: Sets an environment variable.
7005	run(t, testProxyAuthHeader, []testMode{http1Mode}, testNotParallel)
7006}
7007func testProxyAuthHeader(t *testing.T, mode testMode) {
7008	const username = "u"
7009	const password = "@/?!"
7010	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7011		// Copy the Proxy-Authorization header to a new Request,
7012		// since Request.BasicAuth only parses the Authorization header.
7013		var r2 Request
7014		r2.Header = Header{
7015			"Authorization": req.Header["Proxy-Authorization"],
7016		}
7017		gotuser, gotpass, ok := r2.BasicAuth()
7018		if !ok || gotuser != username || gotpass != password {
7019			t.Errorf("req.BasicAuth() = %q, %q, %v; want %q, %q, true", gotuser, gotpass, ok, username, password)
7020		}
7021	}))
7022	u, err := url.Parse(cst.ts.URL)
7023	if err != nil {
7024		t.Fatal(err)
7025	}
7026	u.User = url.UserPassword(username, password)
7027	t.Setenv("HTTP_PROXY", u.String())
7028	cst.tr.Proxy = ProxyURL(u)
7029	resp, err := cst.c.Get("http://_/")
7030	if err != nil {
7031		t.Fatal(err)
7032	}
7033	resp.Body.Close()
7034}
7035
// Issue 61708
// TestTransportReqCancelerCleanupOnRequestBodyWriteError sends a POST with a
// 1 GiB body to a hand-rolled server. The server replies after reading a
// single byte. Once the test closes done, the server closes the connection,
// so the remaining body write fails. The test then checks that the
// Transport is left with no pending requests.
func TestTransportReqCancelerCleanupOnRequestBodyWriteError(t *testing.T) {
	ln := newLocalListener(t)
	addr := ln.Addr().String()

	// Closed by the test after RoundTrip returns; the server goroutine then
	// closes its end of the connection.
	done := make(chan struct{})
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			t.Errorf("ln.Accept: %v", err)
			return
		}
		// Start reading request before sending response to avoid
		// "Unsolicited response received on idle HTTP channel" RoundTrip error.
		if _, err := io.ReadFull(conn, make([]byte, 1)); err != nil {
			t.Errorf("conn.Read: %v", err)
			return
		}
		io.WriteString(conn, "HTTP/1.1 200\r\nContent-Length: 3\r\n\r\nfoo")
		<-done
		conn.Close()
	}()

	// didRead is unbuffered, so the hook blocks until the test receives
	// from it below.
	didRead := make(chan bool)
	SetReadLoopBeforeNextReadHook(func() { didRead <- true })
	defer SetReadLoopBeforeNextReadHook(nil)

	tr := &Transport{}

	// Send a request with a body guaranteed to fail on write.
	req, err := NewRequest("POST", "http://"+addr, io.LimitReader(neverEnding('x'), 1<<30))
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}

	resp, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatalf("tr.RoundTrip: %v", err)
	}

	close(done)

	// Before closing response body wait for readLoopDone goroutine
	// to complete due to closed connection by writeLoop.
	<-didRead

	resp.Body.Close()

	// Verify no outstanding requests after readLoop/writeLoop
	// goroutines shut down.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		n := tr.NumPendingRequestsForTesting()
		if n > 0 {
			if d > 0 {
				t.Logf("pending requests = %d after %v (want 0)", n, d)
			}
			return false
		}
		return true
	})
}
7097
7098func TestValidateClientRequestTrailers(t *testing.T) {
7099	run(t, testValidateClientRequestTrailers)
7100}
7101
7102func testValidateClientRequestTrailers(t *testing.T, mode testMode) {
7103	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
7104		rw.Write([]byte("Hello"))
7105	})).ts
7106
7107	cases := []struct {
7108		trailer Header
7109		wantErr string
7110	}{
7111		{Header{"Trx": {"x\r\nX-Another-One"}}, `invalid trailer field value for "Trx"`},
7112		{Header{"\r\nTrx": {"X-Another-One"}}, `invalid trailer field name "\r\nTrx"`},
7113	}
7114
7115	for i, tt := range cases {
7116		testName := fmt.Sprintf("%s%d", mode, i)
7117		t.Run(testName, func(t *testing.T) {
7118			req, err := NewRequest("GET", cst.URL, nil)
7119			if err != nil {
7120				t.Fatal(err)
7121			}
7122			req.Trailer = tt.trailer
7123			res, err := cst.Client().Do(req)
7124			if err == nil {
7125				t.Fatal("Expected an error")
7126			}
7127			if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) {
7128				t.Fatalf("Mismatched error\n\t%q\ndoes not contain\n\t%q", g, w)
7129			}
7130			if res != nil {
7131				t.Fatal("Unexpected non-nil response")
7132			}
7133		})
7134	}
7135}
7136