1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package main
6
7import (
8	"fmt"
9	"go/ast"
10	"go/parser"
11	"internal/diff"
12	"internal/testenv"
13	"strings"
14	"testing"
15)
16
// testCase describes one cmd/fix rewrite test run by TestRewrite.
type testCase struct {
	Name    string               // subtest name; also passed to the parser as the file name
	Fn      func(*ast.File) bool // fix to apply; nil means every fix in the fixes table (see parseFixPrint)
	Version string               // if non-empty, *goVersion is set to this while the case runs
	In      string               // input source; must already be gofmt-formatted
	Out     string               // expected output; empty means the output must equal In
}

// testCases holds every case registered through addTestCases;
// TestRewrite runs each of them as a subtest.
var testCases []testCase

// addTestCases registers the cases in t. When fn is non-nil it is used as
// the fix for every case that does not name its own Fn, so a table of cases
// need not repeat the same fix on each entry. The elements of t are filled
// in place, so the caller's slice observes the assigned Fn values too.
func addTestCases(t []testCase, fn func(*ast.File) bool) {
	for i := range t {
		if c := &t[i]; fn != nil && c.Fn == nil {
			c.Fn = fn
		}
	}
	testCases = append(testCases, t...)
}
38
// fnop is a fix that does nothing: it never touches the file and always
// reports that no change was made. TestRewrite passes it to parseFixPrint
// to re-print output without applying any fix.
func fnop(_ *ast.File) bool {
	return false
}
40
41func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string, mustBeGofmt bool) (out string, fixed, ok bool) {
42	file, err := parser.ParseFile(fset, desc, in, parserMode)
43	if err != nil {
44		t.Errorf("parsing: %v", err)
45		return
46	}
47
48	outb, err := gofmtFile(file)
49	if err != nil {
50		t.Errorf("printing: %v", err)
51		return
52	}
53	if s := string(outb); in != s && mustBeGofmt {
54		t.Errorf("not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s",
55			desc, in, desc, s)
56		tdiff(t, "want", in, "have", s)
57		return
58	}
59
60	if fn == nil {
61		for _, fix := range fixes {
62			if fix.f(file) {
63				fixed = true
64			}
65		}
66	} else {
67		fixed = fn(file)
68	}
69
70	outb, err = gofmtFile(file)
71	if err != nil {
72		t.Errorf("printing: %v", err)
73		return
74	}
75
76	return string(outb), fixed, true
77}
78
79func TestRewrite(t *testing.T) {
80	// If cgo is enabled, enforce that cgo commands invoked by cmd/fix
81	// do not fail during testing.
82	if testenv.HasCGO() {
83		testenv.MustHaveGoBuild(t) // Really just 'go tool cgo', but close enough.
84
85		// The reportCgoError hook is global, so we can't set it per-test
86		// if we want to be able to run those tests in parallel.
87		// Instead, simply set it to panic on error: the goroutine dump
88		// from the panic should help us determine which test failed.
89		prevReportCgoError := reportCgoError
90		reportCgoError = func(err error) {
91			panic(fmt.Sprintf("unexpected cgo error: %v", err))
92		}
93		t.Cleanup(func() { reportCgoError = prevReportCgoError })
94	}
95
96	for _, tt := range testCases {
97		tt := tt
98		t.Run(tt.Name, func(t *testing.T) {
99			if tt.Version == "" {
100				if testing.Verbose() {
101					// Don't run in parallel: cmd/fix sometimes writes directly to stderr,
102					// and since -v prints which test is currently running we want that
103					// information to accurately correlate with the stderr output.
104				} else {
105					t.Parallel()
106				}
107			} else {
108				old := *goVersion
109				*goVersion = tt.Version
110				defer func() {
111					*goVersion = old
112				}()
113			}
114
115			// Apply fix: should get tt.Out.
116			out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In, true)
117			if !ok {
118				return
119			}
120
121			// reformat to get printing right
122			out, _, ok = parseFixPrint(t, fnop, tt.Name, out, false)
123			if !ok {
124				return
125			}
126
127			if tt.Out == "" {
128				tt.Out = tt.In
129			}
130			if out != tt.Out {
131				t.Errorf("incorrect output.\n")
132				if !strings.HasPrefix(tt.Name, "testdata/") {
133					t.Errorf("--- have\n%s\n--- want\n%s", out, tt.Out)
134				}
135				tdiff(t, "have", out, "want", tt.Out)
136				return
137			}
138
139			if changed := out != tt.In; changed != fixed {
140				t.Errorf("changed=%v != fixed=%v", changed, fixed)
141				return
142			}
143
144			// Should not change if run again.
145			out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out, true)
146			if !ok {
147				return
148			}
149
150			if fixed2 {
151				t.Errorf("applied fixes during second round")
152				return
153			}
154
155			if out2 != out {
156				t.Errorf("changed output after second round of fixes.\n--- output after first round\n%s\n--- output after second round\n%s",
157					out, out2)
158				tdiff(t, "first", out, "second", out2)
159			}
160		})
161	}
162}
163
164func tdiff(t *testing.T, aname, a, bname, b string) {
165	t.Errorf("%s", diff.Diff(aname, []byte(a), bname, []byte(b)))
166}
167