1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package http
6
7import (
8	"bufio"
9	"bytes"
10	"crypto/rand"
11	"fmt"
12	"io"
13	"os"
14	"reflect"
15	"strings"
16	"testing"
17)
18
19func TestBodyReadBadTrailer(t *testing.T) {
20	b := &body{
21		src: strings.NewReader("foobar"),
22		hdr: true, // force reading the trailer
23		r:   bufio.NewReader(strings.NewReader("")),
24	}
25	buf := make([]byte, 7)
26	n, err := b.Read(buf[:3])
27	got := string(buf[:n])
28	if got != "foo" || err != nil {
29		t.Fatalf(`first Read = %d (%q), %v; want 3 ("foo")`, n, got, err)
30	}
31
32	n, err = b.Read(buf[:])
33	got = string(buf[:n])
34	if got != "bar" || err != nil {
35		t.Fatalf(`second Read = %d (%q), %v; want 3 ("bar")`, n, got, err)
36	}
37
38	n, err = b.Read(buf[:])
39	got = string(buf[:n])
40	if err == nil {
41		t.Errorf("final Read was successful (%q), expected error from trailer read", got)
42	}
43}
44
45func TestFinalChunkedBodyReadEOF(t *testing.T) {
46	res, err := ReadResponse(bufio.NewReader(strings.NewReader(
47		"HTTP/1.1 200 OK\r\n"+
48			"Transfer-Encoding: chunked\r\n"+
49			"\r\n"+
50			"0a\r\n"+
51			"Body here\n\r\n"+
52			"09\r\n"+
53			"continued\r\n"+
54			"0\r\n"+
55			"\r\n")), nil)
56	if err != nil {
57		t.Fatal(err)
58	}
59	want := "Body here\ncontinued"
60	buf := make([]byte, len(want))
61	n, err := res.Body.Read(buf)
62	if n != len(want) || err != io.EOF {
63		t.Logf("body = %#v", res.Body)
64		t.Errorf("Read = %v, %v; want %d, EOF", n, err, len(want))
65	}
66	if string(buf) != want {
67		t.Errorf("buf = %q; want %q", buf, want)
68	}
69}
70
71func TestDetectInMemoryReaders(t *testing.T) {
72	pr, _ := io.Pipe()
73	tests := []struct {
74		r    io.Reader
75		want bool
76	}{
77		{pr, false},
78
79		{bytes.NewReader(nil), true},
80		{bytes.NewBuffer(nil), true},
81		{strings.NewReader(""), true},
82
83		{io.NopCloser(pr), false},
84
85		{io.NopCloser(bytes.NewReader(nil)), true},
86		{io.NopCloser(bytes.NewBuffer(nil)), true},
87		{io.NopCloser(strings.NewReader("")), true},
88	}
89	for i, tt := range tests {
90		got := isKnownInMemoryReader(tt.r)
91		if got != tt.want {
92			t.Errorf("%d: got = %v; want %v", i, got, tt.want)
93		}
94	}
95}
96
// mockTransferWriter is a write destination that discards all data while
// recording which io.Reader, if any, was passed to ReadFrom and whether
// Write was ever invoked.
type mockTransferWriter struct {
	CalledReader io.Reader
	WriteCalled  bool
}

// The mock must satisfy io.ReaderFrom so that callers which probe for it
// (as io.Copy does) take the ReadFrom path instead of calling Write.
var _ io.ReaderFrom = (*mockTransferWriter)(nil)

// ReadFrom records r, then drains it and reports how many bytes were consumed.
func (m *mockTransferWriter) ReadFrom(r io.Reader) (int64, error) {
	m.CalledReader = r
	return io.Copy(io.Discard, r)
}

// Write marks that it was called and reports every byte of p as written.
func (m *mockTransferWriter) Write(p []byte) (int, error) {
	m.WriteCalled = true
	return len(p), nil
}
113
114func TestTransferWriterWriteBodyReaderTypes(t *testing.T) {
115	fileType := reflect.TypeFor[*os.File]()
116	bufferType := reflect.TypeFor[*bytes.Buffer]()
117
118	nBytes := int64(1 << 10)
119	newFileFunc := func() (r io.Reader, done func(), err error) {
120		f, err := os.CreateTemp("", "net-http-newfilefunc")
121		if err != nil {
122			return nil, nil, err
123		}
124
125		// Write some bytes to the file to enable reading.
126		if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
127			return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
128		}
129		if _, err := f.Seek(0, 0); err != nil {
130			return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
131		}
132
133		done = func() {
134			f.Close()
135			os.Remove(f.Name())
136		}
137
138		return f, done, nil
139	}
140
141	newBufferFunc := func() (io.Reader, func(), error) {
142		return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
143	}
144
145	cases := []struct {
146		name             string
147		bodyFunc         func() (io.Reader, func(), error)
148		method           string
149		contentLength    int64
150		transferEncoding []string
151		limitedReader    bool
152		expectedReader   reflect.Type
153		expectedWrite    bool
154	}{
155		{
156			name:           "file, non-chunked, size set",
157			bodyFunc:       newFileFunc,
158			method:         "PUT",
159			contentLength:  nBytes,
160			limitedReader:  true,
161			expectedReader: fileType,
162		},
163		{
164			name:   "file, non-chunked, size set, nopCloser wrapped",
165			method: "PUT",
166			bodyFunc: func() (io.Reader, func(), error) {
167				r, cleanup, err := newFileFunc()
168				return io.NopCloser(r), cleanup, err
169			},
170			contentLength:  nBytes,
171			limitedReader:  true,
172			expectedReader: fileType,
173		},
174		{
175			name:           "file, non-chunked, negative size",
176			method:         "PUT",
177			bodyFunc:       newFileFunc,
178			contentLength:  -1,
179			expectedReader: fileType,
180		},
181		{
182			name:           "file, non-chunked, CONNECT, negative size",
183			method:         "CONNECT",
184			bodyFunc:       newFileFunc,
185			contentLength:  -1,
186			expectedReader: fileType,
187		},
188		{
189			name:             "file, chunked",
190			method:           "PUT",
191			bodyFunc:         newFileFunc,
192			transferEncoding: []string{"chunked"},
193			expectedWrite:    true,
194		},
195		{
196			name:           "buffer, non-chunked, size set",
197			bodyFunc:       newBufferFunc,
198			method:         "PUT",
199			contentLength:  nBytes,
200			limitedReader:  true,
201			expectedReader: bufferType,
202		},
203		{
204			name:   "buffer, non-chunked, size set, nopCloser wrapped",
205			method: "PUT",
206			bodyFunc: func() (io.Reader, func(), error) {
207				r, cleanup, err := newBufferFunc()
208				return io.NopCloser(r), cleanup, err
209			},
210			contentLength:  nBytes,
211			limitedReader:  true,
212			expectedReader: bufferType,
213		},
214		{
215			name:          "buffer, non-chunked, negative size",
216			method:        "PUT",
217			bodyFunc:      newBufferFunc,
218			contentLength: -1,
219			expectedWrite: true,
220		},
221		{
222			name:          "buffer, non-chunked, CONNECT, negative size",
223			method:        "CONNECT",
224			bodyFunc:      newBufferFunc,
225			contentLength: -1,
226			expectedWrite: true,
227		},
228		{
229			name:             "buffer, chunked",
230			method:           "PUT",
231			bodyFunc:         newBufferFunc,
232			transferEncoding: []string{"chunked"},
233			expectedWrite:    true,
234		},
235	}
236
237	for _, tc := range cases {
238		t.Run(tc.name, func(t *testing.T) {
239			body, cleanup, err := tc.bodyFunc()
240			if err != nil {
241				t.Fatal(err)
242			}
243			defer cleanup()
244
245			mw := &mockTransferWriter{}
246			tw := &transferWriter{
247				Body:             body,
248				ContentLength:    tc.contentLength,
249				TransferEncoding: tc.transferEncoding,
250			}
251
252			if err := tw.writeBody(mw); err != nil {
253				t.Fatal(err)
254			}
255
256			if tc.expectedReader != nil {
257				if mw.CalledReader == nil {
258					t.Fatal("did not call ReadFrom")
259				}
260
261				var actualReader reflect.Type
262				lr, ok := mw.CalledReader.(*io.LimitedReader)
263				if ok && tc.limitedReader {
264					actualReader = reflect.TypeOf(lr.R)
265				} else {
266					actualReader = reflect.TypeOf(mw.CalledReader)
267					// We have to handle this special case for genericWriteTo in os,
268					// this struct is introduced to support a zero-copy optimization,
269					// check out https://go.dev/issue/58808 for details.
270					if actualReader.Kind() == reflect.Struct && actualReader.PkgPath() == "os" && actualReader.Name() == "fileWithoutWriteTo" {
271						actualReader = actualReader.Field(1).Type
272					}
273				}
274
275				if tc.expectedReader != actualReader {
276					t.Fatalf("got reader %s want %s", actualReader, tc.expectedReader)
277				}
278			}
279
280			if tc.expectedWrite && !mw.WriteCalled {
281				t.Fatal("did not invoke Write")
282			}
283		})
284	}
285}
286
287func TestParseTransferEncoding(t *testing.T) {
288	tests := []struct {
289		hdr     Header
290		wantErr error
291	}{
292		{
293			hdr:     Header{"Transfer-Encoding": {"fugazi"}},
294			wantErr: &unsupportedTEError{`unsupported transfer encoding: "fugazi"`},
295		},
296		{
297			hdr:     Header{"Transfer-Encoding": {"chunked, chunked", "identity", "chunked"}},
298			wantErr: &unsupportedTEError{`too many transfer encodings: ["chunked, chunked" "identity" "chunked"]`},
299		},
300		{
301			hdr:     Header{"Transfer-Encoding": {""}},
302			wantErr: &unsupportedTEError{`unsupported transfer encoding: ""`},
303		},
304		{
305			hdr:     Header{"Transfer-Encoding": {"chunked, identity"}},
306			wantErr: &unsupportedTEError{`unsupported transfer encoding: "chunked, identity"`},
307		},
308		{
309			hdr:     Header{"Transfer-Encoding": {"chunked", "identity"}},
310			wantErr: &unsupportedTEError{`too many transfer encodings: ["chunked" "identity"]`},
311		},
312		{
313			hdr:     Header{"Transfer-Encoding": {"\x0bchunked"}},
314			wantErr: &unsupportedTEError{`unsupported transfer encoding: "\vchunked"`},
315		},
316		{
317			hdr:     Header{"Transfer-Encoding": {"chunked"}},
318			wantErr: nil,
319		},
320	}
321
322	for i, tt := range tests {
323		tr := &transferReader{
324			Header:     tt.hdr,
325			ProtoMajor: 1,
326			ProtoMinor: 1,
327		}
328		gotErr := tr.parseTransferEncoding()
329		if !reflect.DeepEqual(gotErr, tt.wantErr) {
330			t.Errorf("%d.\ngot error:\n%v\nwant error:\n%v\n\n", i, gotErr, tt.wantErr)
331		}
332	}
333}
334
335// issue 39017 - disallow Content-Length values such as "+3"
336func TestParseContentLength(t *testing.T) {
337	tests := []struct {
338		cl      string
339		wantErr error
340	}{
341		{
342			cl:      "",
343			wantErr: badStringError("invalid empty Content-Length", ""),
344		},
345		{
346			cl:      "3",
347			wantErr: nil,
348		},
349		{
350			cl:      "+3",
351			wantErr: badStringError("bad Content-Length", "+3"),
352		},
353		{
354			cl:      "-3",
355			wantErr: badStringError("bad Content-Length", "-3"),
356		},
357		{
358			// max int64, for safe conversion before returning
359			cl:      "9223372036854775807",
360			wantErr: nil,
361		},
362		{
363			cl:      "9223372036854775808",
364			wantErr: badStringError("bad Content-Length", "9223372036854775808"),
365		},
366	}
367
368	for _, tt := range tests {
369		if _, gotErr := parseContentLength([]string{tt.cl}); !reflect.DeepEqual(gotErr, tt.wantErr) {
370			t.Errorf("%q:\n\tgot=%v\n\twant=%v", tt.cl, gotErr, tt.wantErr)
371		}
372	}
373}
374