1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package main
6
7import (
8	"bytes"
9	"flag"
10	"internal/diff"
11	"os"
12	"path/filepath"
13	"strings"
14	"testing"
15	"text/scanner"
16)
17
// update is the -update test flag. When it is set, runTest overwrites a
// mismatching golden file with the actual output, unless the input and the
// golden file are the same file.
var update = flag.Bool("update", false, "update .golden files")
19
// gofmtFlags looks for a comment of the form
//
//	//gofmt flags
//
// within the first maxLines lines of the given file,
// and returns the flags string, if any. Otherwise it
// returns the empty string.
//
// Only comments that start on a line <= maxLines are considered.
// A file that cannot be opened yields the empty string as well.
func gofmtFlags(filename string, maxLines int) string {
	f, err := os.Open(filename)
	if err != nil {
		return "" // ignore errors - they will be found later
	}
	defer f.Close()

	// initialize scanner (Init resets Error and Mode, so set them afterwards)
	var s scanner.Scanner
	s.Init(f)
	s.Error = func(*scanner.Scanner, string) {}       // ignore errors
	s.Mode = scanner.GoTokens &^ scanner.SkipComments // want comments

	// look for //gofmt comment
	const prefix = "//gofmt "
	for {
		tok := s.Scan()
		// s.Line is the start line of the token just scanned, so the limit
		// must be checked after Scan. Checking it before Scan would test the
		// previous token's line and let a comment beyond maxLines through.
		if tok == scanner.EOF || s.Line > maxLines {
			return ""
		}
		if tok != scanner.Comment {
			continue
		}
		if text := s.TokenText(); strings.HasPrefix(text, prefix) {
			return strings.TrimSpace(text[len(prefix):])
		}
	}
}
55
56func runTest(t *testing.T, in, out string) {
57	// process flags
58	*simplifyAST = false
59	*rewriteRule = ""
60	info, err := os.Lstat(in)
61	if err != nil {
62		t.Error(err)
63		return
64	}
65	for _, flag := range strings.Split(gofmtFlags(in, 20), " ") {
66		elts := strings.SplitN(flag, "=", 2)
67		name := elts[0]
68		value := ""
69		if len(elts) == 2 {
70			value = elts[1]
71		}
72		switch name {
73		case "":
74			// no flags
75		case "-r":
76			*rewriteRule = value
77		case "-s":
78			*simplifyAST = true
79		case "-stdin":
80			// fake flag - pretend input is from stdin
81			info = nil
82		default:
83			t.Errorf("unrecognized flag name: %s", name)
84		}
85	}
86
87	initParserMode()
88	initRewrite()
89
90	const maxWeight = 2 << 20
91	var buf, errBuf bytes.Buffer
92	s := newSequencer(maxWeight, &buf, &errBuf)
93	s.Add(fileWeight(in, info), func(r *reporter) error {
94		return processFile(in, info, nil, r)
95	})
96	if errBuf.Len() > 0 {
97		t.Logf("%q", errBuf.Bytes())
98	}
99	if s.GetExitCode() != 0 {
100		t.Fail()
101	}
102
103	expected, err := os.ReadFile(out)
104	if err != nil {
105		t.Error(err)
106		return
107	}
108
109	if got := buf.Bytes(); !bytes.Equal(got, expected) {
110		if *update {
111			if in != out {
112				if err := os.WriteFile(out, got, 0666); err != nil {
113					t.Error(err)
114				}
115				return
116			}
117			// in == out: don't accidentally destroy input
118			t.Errorf("WARNING: -update did not rewrite input file %s", in)
119		}
120
121		t.Errorf("(gofmt %s) != %s (see %s.gofmt)\n%s", in, out, in,
122			diff.Diff("expected", expected, "got", got))
123		if err := os.WriteFile(in+".gofmt", got, 0666); err != nil {
124			t.Error(err)
125		}
126	}
127}
128
129// TestRewrite processes testdata/*.input files and compares them to the
130// corresponding testdata/*.golden files. The gofmt flags used to process
131// a file must be provided via a comment of the form
132//
133//	//gofmt flags
134//
135// in the processed file within the first 20 lines, if any.
136func TestRewrite(t *testing.T) {
137	// determine input files
138	match, err := filepath.Glob("testdata/*.input")
139	if err != nil {
140		t.Fatal(err)
141	}
142
143	// add larger examples
144	match = append(match, "gofmt.go", "gofmt_test.go")
145
146	for _, in := range match {
147		name := filepath.Base(in)
148		t.Run(name, func(t *testing.T) {
149			out := in // for files where input and output are identical
150			if strings.HasSuffix(in, ".input") {
151				out = in[:len(in)-len(".input")] + ".golden"
152			}
153			runTest(t, in, out)
154			if in != out && !t.Failed() {
155				// Check idempotence.
156				runTest(t, out, out)
157			}
158		})
159	}
160}
161
// TestCRLF is the test case for issue 3961. It checks the fixtures
// themselves: the input file must contain CR/LF line endings and the
// golden file must not contain any CR at all.
//
// The two files are checked independently. A file that cannot be read is
// reported once with the read error; its contents are not inspected, so no
// misleading follow-up error is produced for it.
func TestCRLF(t *testing.T) {
	const input = "testdata/crlf.input"   // must contain CR/LF's
	const golden = "testdata/crlf.golden" // must not contain any CR's

	if data, err := os.ReadFile(input); err != nil {
		t.Error(err)
	} else if !bytes.Contains(data, []byte("\r\n")) {
		t.Errorf("%s contains no CR/LF's", input)
	}

	if data, err := os.ReadFile(golden); err != nil {
		t.Error(err)
	} else if bytes.Contains(data, []byte("\r")) {
		t.Errorf("%s contains CR's", golden)
	}
}
183
184func TestBackupFile(t *testing.T) {
185	dir, err := os.MkdirTemp("", "gofmt_test")
186	if err != nil {
187		t.Fatal(err)
188	}
189	defer os.RemoveAll(dir)
190	name, err := backupFile(filepath.Join(dir, "foo.go"), []byte("  package main"), 0644)
191	if err != nil {
192		t.Fatal(err)
193	}
194	t.Logf("Created: %s", name)
195}
196