1// Copyright 2011 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package main 6 7import ( 8 "fmt" 9 "go/ast" 10 "go/parser" 11 "internal/diff" 12 "internal/testenv" 13 "strings" 14 "testing" 15) 16 17type testCase struct { 18 Name string 19 Fn func(*ast.File) bool 20 Version string 21 In string 22 Out string 23} 24 25var testCases []testCase 26 27func addTestCases(t []testCase, fn func(*ast.File) bool) { 28 // Fill in fn to avoid repetition in definitions. 29 if fn != nil { 30 for i := range t { 31 if t[i].Fn == nil { 32 t[i].Fn = fn 33 } 34 } 35 } 36 testCases = append(testCases, t...) 37} 38 39func fnop(*ast.File) bool { return false } 40 41func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string, mustBeGofmt bool) (out string, fixed, ok bool) { 42 file, err := parser.ParseFile(fset, desc, in, parserMode) 43 if err != nil { 44 t.Errorf("parsing: %v", err) 45 return 46 } 47 48 outb, err := gofmtFile(file) 49 if err != nil { 50 t.Errorf("printing: %v", err) 51 return 52 } 53 if s := string(outb); in != s && mustBeGofmt { 54 t.Errorf("not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s", 55 desc, in, desc, s) 56 tdiff(t, "want", in, "have", s) 57 return 58 } 59 60 if fn == nil { 61 for _, fix := range fixes { 62 if fix.f(file) { 63 fixed = true 64 } 65 } 66 } else { 67 fixed = fn(file) 68 } 69 70 outb, err = gofmtFile(file) 71 if err != nil { 72 t.Errorf("printing: %v", err) 73 return 74 } 75 76 return string(outb), fixed, true 77} 78 79func TestRewrite(t *testing.T) { 80 // If cgo is enabled, enforce that cgo commands invoked by cmd/fix 81 // do not fail during testing. 82 if testenv.HasCGO() { 83 testenv.MustHaveGoBuild(t) // Really just 'go tool cgo', but close enough. 84 85 // The reportCgoError hook is global, so we can't set it per-test 86 // if we want to be able to run those tests in parallel. 87 // Instead, simply set it to panic on error: the goroutine dump 88 // from the panic should help us determine which test failed. 89 prevReportCgoError := reportCgoError 90 reportCgoError = func(err error) { 91 panic(fmt.Sprintf("unexpected cgo error: %v", err)) 92 } 93 t.Cleanup(func() { reportCgoError = prevReportCgoError }) 94 } 95 96 for _, tt := range testCases { 97 tt := tt 98 t.Run(tt.Name, func(t *testing.T) { 99 if tt.Version == "" { 100 if testing.Verbose() { 101 // Don't run in parallel: cmd/fix sometimes writes directly to stderr, 102 // and since -v prints which test is currently running we want that 103 // information to accurately correlate with the stderr output. 104 } else { 105 t.Parallel() 106 } 107 } else { 108 old := *goVersion 109 *goVersion = tt.Version 110 defer func() { 111 *goVersion = old 112 }() 113 } 114 115 // Apply fix: should get tt.Out. 116 out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In, true) 117 if !ok { 118 return 119 } 120 121 // reformat to get printing right 122 out, _, ok = parseFixPrint(t, fnop, tt.Name, out, false) 123 if !ok { 124 return 125 } 126 127 if tt.Out == "" { 128 tt.Out = tt.In 129 } 130 if out != tt.Out { 131 t.Errorf("incorrect output.\n") 132 if !strings.HasPrefix(tt.Name, "testdata/") { 133 t.Errorf("--- have\n%s\n--- want\n%s", out, tt.Out) 134 } 135 tdiff(t, "have", out, "want", tt.Out) 136 return 137 } 138 139 if changed := out != tt.In; changed != fixed { 140 t.Errorf("changed=%v != fixed=%v", changed, fixed) 141 return 142 } 143 144 // Should not change if run again. 145 out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out, true) 146 if !ok { 147 return 148 } 149 150 if fixed2 { 151 t.Errorf("applied fixes during second round") 152 return 153 } 154 155 if out2 != out { 156 t.Errorf("changed output after second round of fixes.\n--- output after first round\n%s\n--- output after second round\n%s", 157 out, out2) 158 tdiff(t, "first", out, "second", out2) 159 } 160 }) 161 } 162} 163 164func tdiff(t *testing.T, aname, a, bname, b string) { 165 t.Errorf("%s", diff.Diff(aname, []byte(a), bname, []byte(b))) 166} 167