1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package zip
6
7import (
8	"bytes"
9	"compress/flate"
10	"encoding/binary"
11	"fmt"
12	"hash/crc32"
13	"io"
14	"io/fs"
15	"math/rand"
16	"os"
17	"strings"
18	"testing"
19	"testing/fstest"
20	"time"
21)
22
23// TODO(adg): a more sophisticated test suite
24
// WriteTest describes a single archive entry exercised by the writer tests.
type WriteTest struct {
	// Name is the entry name stored in the archive.
	Name string
	// Data is the entry's uncompressed contents.
	Data []byte
	// Method is the compression method to use (for example Store or Deflate).
	Method uint16
	// Mode is the file mode to record for the entry.
	Mode fs.FileMode
}
31
32var writeTests = []WriteTest{
33	{
34		Name:   "foo",
35		Data:   []byte("Rabbits, guinea pigs, gophers, marsupial rats, and quolls."),
36		Method: Store,
37		Mode:   0666,
38	},
39	{
40		Name:   "bar",
41		Data:   nil, // large data set in the test
42		Method: Deflate,
43		Mode:   0644,
44	},
45	{
46		Name:   "setuid",
47		Data:   []byte("setuid file"),
48		Method: Deflate,
49		Mode:   0755 | fs.ModeSetuid,
50	},
51	{
52		Name:   "setgid",
53		Data:   []byte("setgid file"),
54		Method: Deflate,
55		Mode:   0755 | fs.ModeSetgid,
56	},
57	{
58		Name:   "symlink",
59		Data:   []byte("../link/target"),
60		Method: Deflate,
61		Mode:   0755 | fs.ModeSymlink,
62	},
63	{
64		Name:   "device",
65		Data:   []byte("device file"),
66		Method: Deflate,
67		Mode:   0755 | fs.ModeDevice,
68	},
69	{
70		Name:   "chardevice",
71		Data:   []byte("char device file"),
72		Method: Deflate,
73		Mode:   0755 | fs.ModeDevice | fs.ModeCharDevice,
74	},
75}
76
77func TestWriter(t *testing.T) {
78	largeData := make([]byte, 1<<17)
79	if _, err := rand.Read(largeData); err != nil {
80		t.Fatal("rand.Read failed:", err)
81	}
82	writeTests[1].Data = largeData
83	defer func() {
84		writeTests[1].Data = nil
85	}()
86
87	// write a zip file
88	buf := new(bytes.Buffer)
89	w := NewWriter(buf)
90
91	for _, wt := range writeTests {
92		testCreate(t, w, &wt)
93	}
94
95	if err := w.Close(); err != nil {
96		t.Fatal(err)
97	}
98
99	// read it back
100	r, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
101	if err != nil {
102		t.Fatal(err)
103	}
104	for i, wt := range writeTests {
105		testReadFile(t, r.File[i], &wt)
106	}
107}
108
109// TestWriterComment is test for EOCD comment read/write.
110func TestWriterComment(t *testing.T) {
111	var tests = []struct {
112		comment string
113		ok      bool
114	}{
115		{"hi, hello", true},
116		{"hi, こんにちわ", true},
117		{strings.Repeat("a", uint16max), true},
118		{strings.Repeat("a", uint16max+1), false},
119	}
120
121	for _, test := range tests {
122		// write a zip file
123		buf := new(bytes.Buffer)
124		w := NewWriter(buf)
125		if err := w.SetComment(test.comment); err != nil {
126			if test.ok {
127				t.Fatalf("SetComment: unexpected error %v", err)
128			}
129			continue
130		} else {
131			if !test.ok {
132				t.Fatalf("SetComment: unexpected success, want error")
133			}
134		}
135
136		if err := w.Close(); test.ok == (err != nil) {
137			t.Fatal(err)
138		}
139
140		if w.closed != test.ok {
141			t.Fatalf("Writer.closed: got %v, want %v", w.closed, test.ok)
142		}
143
144		// skip read test in failure cases
145		if !test.ok {
146			continue
147		}
148
149		// read it back
150		r, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
151		if err != nil {
152			t.Fatal(err)
153		}
154		if r.Comment != test.comment {
155			t.Fatalf("Reader.Comment: got %v, want %v", r.Comment, test.comment)
156		}
157	}
158}
159
160func TestWriterUTF8(t *testing.T) {
161	var utf8Tests = []struct {
162		name    string
163		comment string
164		nonUTF8 bool
165		flags   uint16
166	}{
167		{
168			name:    "hi, hello",
169			comment: "in the world",
170			flags:   0x8,
171		},
172		{
173			name:    "hi, こんにちわ",
174			comment: "in the world",
175			flags:   0x808,
176		},
177		{
178			name:    "hi, こんにちわ",
179			comment: "in the world",
180			nonUTF8: true,
181			flags:   0x8,
182		},
183		{
184			name:    "hi, hello",
185			comment: "in the 世界",
186			flags:   0x808,
187		},
188		{
189			name:    "hi, こんにちわ",
190			comment: "in the 世界",
191			flags:   0x808,
192		},
193		{
194			name:    "the replacement rune is �",
195			comment: "the replacement rune is �",
196			flags:   0x808,
197		},
198		{
199			// Name is Japanese encoded in Shift JIS.
200			name:    "\x93\xfa\x96{\x8c\xea.txt",
201			comment: "in the 世界",
202			flags:   0x008, // UTF-8 must not be set
203		},
204	}
205
206	// write a zip file
207	buf := new(bytes.Buffer)
208	w := NewWriter(buf)
209
210	for _, test := range utf8Tests {
211		h := &FileHeader{
212			Name:    test.name,
213			Comment: test.comment,
214			NonUTF8: test.nonUTF8,
215			Method:  Deflate,
216		}
217		w, err := w.CreateHeader(h)
218		if err != nil {
219			t.Fatal(err)
220		}
221		w.Write([]byte{})
222	}
223
224	if err := w.Close(); err != nil {
225		t.Fatal(err)
226	}
227
228	// read it back
229	r, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
230	if err != nil {
231		t.Fatal(err)
232	}
233	for i, test := range utf8Tests {
234		flags := r.File[i].Flags
235		if flags != test.flags {
236			t.Errorf("CreateHeader(name=%q comment=%q nonUTF8=%v): flags=%#x, want %#x", test.name, test.comment, test.nonUTF8, flags, test.flags)
237		}
238	}
239}
240
241func TestWriterTime(t *testing.T) {
242	var buf bytes.Buffer
243	h := &FileHeader{
244		Name:     "test.txt",
245		Modified: time.Date(2017, 10, 31, 21, 11, 57, 0, timeZone(-7*time.Hour)),
246	}
247	w := NewWriter(&buf)
248	if _, err := w.CreateHeader(h); err != nil {
249		t.Fatalf("unexpected CreateHeader error: %v", err)
250	}
251	if err := w.Close(); err != nil {
252		t.Fatalf("unexpected Close error: %v", err)
253	}
254
255	want, err := os.ReadFile("testdata/time-go.zip")
256	if err != nil {
257		t.Fatalf("unexpected ReadFile error: %v", err)
258	}
259	if got := buf.Bytes(); !bytes.Equal(got, want) {
260		fmt.Printf("%x\n%x\n", got, want)
261		t.Error("contents of time-go.zip differ")
262	}
263}
264
265func TestWriterOffset(t *testing.T) {
266	largeData := make([]byte, 1<<17)
267	if _, err := rand.Read(largeData); err != nil {
268		t.Fatal("rand.Read failed:", err)
269	}
270	writeTests[1].Data = largeData
271	defer func() {
272		writeTests[1].Data = nil
273	}()
274
275	// write a zip file
276	buf := new(bytes.Buffer)
277	existingData := []byte{1, 2, 3, 1, 2, 3, 1, 2, 3}
278	n, _ := buf.Write(existingData)
279	w := NewWriter(buf)
280	w.SetOffset(int64(n))
281
282	for _, wt := range writeTests {
283		testCreate(t, w, &wt)
284	}
285
286	if err := w.Close(); err != nil {
287		t.Fatal(err)
288	}
289
290	// read it back
291	r, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
292	if err != nil {
293		t.Fatal(err)
294	}
295	for i, wt := range writeTests {
296		testReadFile(t, r.File[i], &wt)
297	}
298}
299
300func TestWriterFlush(t *testing.T) {
301	var buf bytes.Buffer
302	w := NewWriter(struct{ io.Writer }{&buf})
303	_, err := w.Create("foo")
304	if err != nil {
305		t.Fatal(err)
306	}
307	if buf.Len() > 0 {
308		t.Fatalf("Unexpected %d bytes already in buffer", buf.Len())
309	}
310	if err := w.Flush(); err != nil {
311		t.Fatal(err)
312	}
313	if buf.Len() == 0 {
314		t.Fatal("No bytes written after Flush")
315	}
316}
317
318func TestWriterDir(t *testing.T) {
319	w := NewWriter(io.Discard)
320	dw, err := w.Create("dir/")
321	if err != nil {
322		t.Fatal(err)
323	}
324	if _, err := dw.Write(nil); err != nil {
325		t.Errorf("Write(nil) to directory: got %v, want nil", err)
326	}
327	if _, err := dw.Write([]byte("hello")); err == nil {
328		t.Error(`Write("hello") to directory: got nil error, want non-nil`)
329	}
330}
331
332func TestWriterDirAttributes(t *testing.T) {
333	var buf bytes.Buffer
334	w := NewWriter(&buf)
335	if _, err := w.CreateHeader(&FileHeader{
336		Name:               "dir/",
337		Method:             Deflate,
338		CompressedSize64:   1234,
339		UncompressedSize64: 5678,
340	}); err != nil {
341		t.Fatal(err)
342	}
343	if err := w.Close(); err != nil {
344		t.Fatal(err)
345	}
346	b := buf.Bytes()
347
348	var sig [4]byte
349	binary.LittleEndian.PutUint32(sig[:], uint32(fileHeaderSignature))
350
351	idx := bytes.Index(b, sig[:])
352	if idx == -1 {
353		t.Fatal("file header not found")
354	}
355	b = b[idx:]
356
357	if !bytes.Equal(b[6:10], []byte{0, 0, 0, 0}) { // FileHeader.Flags: 0, FileHeader.Method: 0
358		t.Errorf("unexpected method and flags: %v", b[6:10])
359	}
360
361	if !bytes.Equal(b[14:26], make([]byte, 12)) { // FileHeader.{CRC32,CompressSize,UncompressedSize} all zero.
362		t.Errorf("unexpected crc, compress and uncompressed size to be 0 was: %v", b[14:26])
363	}
364
365	binary.LittleEndian.PutUint32(sig[:], uint32(dataDescriptorSignature))
366	if bytes.Contains(b, sig[:]) {
367		t.Error("there should be no data descriptor")
368	}
369}
370
371func TestWriterCopy(t *testing.T) {
372	// make a zip file
373	buf := new(bytes.Buffer)
374	w := NewWriter(buf)
375	for _, wt := range writeTests {
376		testCreate(t, w, &wt)
377	}
378	if err := w.Close(); err != nil {
379		t.Fatal(err)
380	}
381
382	// read it back
383	src, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
384	if err != nil {
385		t.Fatal(err)
386	}
387	for i, wt := range writeTests {
388		testReadFile(t, src.File[i], &wt)
389	}
390
391	// make a new zip file copying the old compressed data.
392	buf2 := new(bytes.Buffer)
393	dst := NewWriter(buf2)
394	for _, f := range src.File {
395		if err := dst.Copy(f); err != nil {
396			t.Fatal(err)
397		}
398	}
399	if err := dst.Close(); err != nil {
400		t.Fatal(err)
401	}
402
403	// read the new one back
404	r, err := NewReader(bytes.NewReader(buf2.Bytes()), int64(buf2.Len()))
405	if err != nil {
406		t.Fatal(err)
407	}
408	for i, wt := range writeTests {
409		testReadFile(t, r.File[i], &wt)
410	}
411}
412
413func TestWriterCreateRaw(t *testing.T) {
414	files := []struct {
415		name             string
416		content          []byte
417		method           uint16
418		flags            uint16
419		crc32            uint32
420		uncompressedSize uint64
421		compressedSize   uint64
422	}{
423		{
424			name:    "small store w desc",
425			content: []byte("gophers"),
426			method:  Store,
427			flags:   0x8,
428		},
429		{
430			name:    "small deflate wo desc",
431			content: bytes.Repeat([]byte("abcdefg"), 2048),
432			method:  Deflate,
433		},
434	}
435
436	// write a zip file
437	archive := new(bytes.Buffer)
438	w := NewWriter(archive)
439
440	for i := range files {
441		f := &files[i]
442		f.crc32 = crc32.ChecksumIEEE(f.content)
443		size := uint64(len(f.content))
444		f.uncompressedSize = size
445		f.compressedSize = size
446
447		var compressedContent []byte
448		if f.method == Deflate {
449			var buf bytes.Buffer
450			w, err := flate.NewWriter(&buf, flate.BestSpeed)
451			if err != nil {
452				t.Fatalf("flate.NewWriter err = %v", err)
453			}
454			_, err = w.Write(f.content)
455			if err != nil {
456				t.Fatalf("flate Write err = %v", err)
457			}
458			err = w.Close()
459			if err != nil {
460				t.Fatalf("flate Writer.Close err = %v", err)
461			}
462			compressedContent = buf.Bytes()
463			f.compressedSize = uint64(len(compressedContent))
464		}
465
466		h := &FileHeader{
467			Name:               f.name,
468			Method:             f.method,
469			Flags:              f.flags,
470			CRC32:              f.crc32,
471			CompressedSize64:   f.compressedSize,
472			UncompressedSize64: f.uncompressedSize,
473		}
474		w, err := w.CreateRaw(h)
475		if err != nil {
476			t.Fatal(err)
477		}
478		if compressedContent != nil {
479			_, err = w.Write(compressedContent)
480		} else {
481			_, err = w.Write(f.content)
482		}
483		if err != nil {
484			t.Fatalf("%s Write got %v; want nil", f.name, err)
485		}
486	}
487
488	if err := w.Close(); err != nil {
489		t.Fatal(err)
490	}
491
492	// read it back
493	r, err := NewReader(bytes.NewReader(archive.Bytes()), int64(archive.Len()))
494	if err != nil {
495		t.Fatal(err)
496	}
497	for i, want := range files {
498		got := r.File[i]
499		if got.Name != want.name {
500			t.Errorf("got Name %s; want %s", got.Name, want.name)
501		}
502		if got.Method != want.method {
503			t.Errorf("%s: got Method %#x; want %#x", want.name, got.Method, want.method)
504		}
505		if got.Flags != want.flags {
506			t.Errorf("%s: got Flags %#x; want %#x", want.name, got.Flags, want.flags)
507		}
508		if got.CRC32 != want.crc32 {
509			t.Errorf("%s: got CRC32 %#x; want %#x", want.name, got.CRC32, want.crc32)
510		}
511		if got.CompressedSize64 != want.compressedSize {
512			t.Errorf("%s: got CompressedSize64 %d; want %d", want.name, got.CompressedSize64, want.compressedSize)
513		}
514		if got.UncompressedSize64 != want.uncompressedSize {
515			t.Errorf("%s: got UncompressedSize64 %d; want %d", want.name, got.UncompressedSize64, want.uncompressedSize)
516		}
517
518		r, err := got.Open()
519		if err != nil {
520			t.Errorf("%s: Open err = %v", got.Name, err)
521			continue
522		}
523
524		buf, err := io.ReadAll(r)
525		if err != nil {
526			t.Errorf("%s: ReadAll err = %v", got.Name, err)
527			continue
528		}
529
530		if !bytes.Equal(buf, want.content) {
531			t.Errorf("%v: ReadAll returned unexpected bytes", got.Name)
532		}
533	}
534}
535
536func testCreate(t *testing.T, w *Writer, wt *WriteTest) {
537	header := &FileHeader{
538		Name:   wt.Name,
539		Method: wt.Method,
540	}
541	if wt.Mode != 0 {
542		header.SetMode(wt.Mode)
543	}
544	f, err := w.CreateHeader(header)
545	if err != nil {
546		t.Fatal(err)
547	}
548	_, err = f.Write(wt.Data)
549	if err != nil {
550		t.Fatal(err)
551	}
552}
553
554func testReadFile(t *testing.T, f *File, wt *WriteTest) {
555	if f.Name != wt.Name {
556		t.Fatalf("File name: got %q, want %q", f.Name, wt.Name)
557	}
558	testFileMode(t, f, wt.Mode)
559	rc, err := f.Open()
560	if err != nil {
561		t.Fatalf("opening %s: %v", f.Name, err)
562	}
563	b, err := io.ReadAll(rc)
564	if err != nil {
565		t.Fatalf("reading %s: %v", f.Name, err)
566	}
567	err = rc.Close()
568	if err != nil {
569		t.Fatalf("closing %s: %v", f.Name, err)
570	}
571	if !bytes.Equal(b, wt.Data) {
572		t.Errorf("File contents %q, want %q", b, wt.Data)
573	}
574}
575
576func BenchmarkCompressedZipGarbage(b *testing.B) {
577	bigBuf := bytes.Repeat([]byte("a"), 1<<20)
578
579	runOnce := func(buf *bytes.Buffer) {
580		buf.Reset()
581		zw := NewWriter(buf)
582		for j := 0; j < 3; j++ {
583			w, _ := zw.CreateHeader(&FileHeader{
584				Name:   "foo",
585				Method: Deflate,
586			})
587			w.Write(bigBuf)
588		}
589		zw.Close()
590	}
591
592	b.ReportAllocs()
593	// Run once and then reset the timer.
594	// This effectively discards the very large initial flate setup cost,
595	// as well as the initialization of bigBuf.
596	runOnce(&bytes.Buffer{})
597	b.ResetTimer()
598
599	b.RunParallel(func(pb *testing.PB) {
600		var buf bytes.Buffer
601		for pb.Next() {
602			runOnce(&buf)
603		}
604	})
605}
606
607func writeTestsToFS(tests []WriteTest) fs.FS {
608	fsys := fstest.MapFS{}
609	for _, wt := range tests {
610		fsys[wt.Name] = &fstest.MapFile{
611			Data: wt.Data,
612			Mode: wt.Mode,
613		}
614	}
615	return fsys
616}
617
618func TestWriterAddFS(t *testing.T) {
619	buf := new(bytes.Buffer)
620	w := NewWriter(buf)
621	tests := []WriteTest{
622		{
623			Name: "file.go",
624			Data: []byte("hello"),
625			Mode: 0644,
626		},
627		{
628			Name: "subfolder/another.go",
629			Data: []byte("world"),
630			Mode: 0644,
631		},
632	}
633	err := w.AddFS(writeTestsToFS(tests))
634	if err != nil {
635		t.Fatal(err)
636	}
637
638	if err := w.Close(); err != nil {
639		t.Fatal(err)
640	}
641
642	// read it back
643	r, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
644	if err != nil {
645		t.Fatal(err)
646	}
647	for i, wt := range tests {
648		testReadFile(t, r.File[i], &wt)
649	}
650}
651
652func TestIssue61875(t *testing.T) {
653	buf := new(bytes.Buffer)
654	w := NewWriter(buf)
655	tests := []WriteTest{
656		{
657			Name:   "symlink",
658			Data:   []byte("../link/target"),
659			Method: Deflate,
660			Mode:   0755 | fs.ModeSymlink,
661		},
662		{
663			Name:   "device",
664			Data:   []byte(""),
665			Method: Deflate,
666			Mode:   0755 | fs.ModeDevice,
667		},
668	}
669	err := w.AddFS(writeTestsToFS(tests))
670	if err == nil {
671		t.Errorf("expected error, got nil")
672	}
673}
674