1// Copyright 2011 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package main 6 7import ( 8 "bytes" 9 "flag" 10 "internal/diff" 11 "os" 12 "path/filepath" 13 "strings" 14 "testing" 15 "text/scanner" 16) 17 18var update = flag.Bool("update", false, "update .golden files") 19 20// gofmtFlags looks for a comment of the form 21// 22// //gofmt flags 23// 24// within the first maxLines lines of the given file, 25// and returns the flags string, if any. Otherwise it 26// returns the empty string. 27func gofmtFlags(filename string, maxLines int) string { 28 f, err := os.Open(filename) 29 if err != nil { 30 return "" // ignore errors - they will be found later 31 } 32 defer f.Close() 33 34 // initialize scanner 35 var s scanner.Scanner 36 s.Init(f) 37 s.Error = func(*scanner.Scanner, string) {} // ignore errors 38 s.Mode = scanner.GoTokens &^ scanner.SkipComments // want comments 39 40 // look for //gofmt comment 41 for s.Line <= maxLines { 42 switch s.Scan() { 43 case scanner.Comment: 44 const prefix = "//gofmt " 45 if t := s.TokenText(); strings.HasPrefix(t, prefix) { 46 return strings.TrimSpace(t[len(prefix):]) 47 } 48 case scanner.EOF: 49 return "" 50 } 51 } 52 53 return "" 54} 55 56func runTest(t *testing.T, in, out string) { 57 // process flags 58 *simplifyAST = false 59 *rewriteRule = "" 60 info, err := os.Lstat(in) 61 if err != nil { 62 t.Error(err) 63 return 64 } 65 for _, flag := range strings.Split(gofmtFlags(in, 20), " ") { 66 elts := strings.SplitN(flag, "=", 2) 67 name := elts[0] 68 value := "" 69 if len(elts) == 2 { 70 value = elts[1] 71 } 72 switch name { 73 case "": 74 // no flags 75 case "-r": 76 *rewriteRule = value 77 case "-s": 78 *simplifyAST = true 79 case "-stdin": 80 // fake flag - pretend input is from stdin 81 info = nil 82 default: 83 t.Errorf("unrecognized flag name: %s", name) 84 } 85 } 86 87 initParserMode() 88 initRewrite() 89 90 const maxWeight = 2 << 20 91 var buf, errBuf bytes.Buffer 92 s := newSequencer(maxWeight, &buf, &errBuf) 93 s.Add(fileWeight(in, info), func(r *reporter) error { 94 return processFile(in, info, nil, r) 95 }) 96 if errBuf.Len() > 0 { 97 t.Logf("%q", errBuf.Bytes()) 98 } 99 if s.GetExitCode() != 0 { 100 t.Fail() 101 } 102 103 expected, err := os.ReadFile(out) 104 if err != nil { 105 t.Error(err) 106 return 107 } 108 109 if got := buf.Bytes(); !bytes.Equal(got, expected) { 110 if *update { 111 if in != out { 112 if err := os.WriteFile(out, got, 0666); err != nil { 113 t.Error(err) 114 } 115 return 116 } 117 // in == out: don't accidentally destroy input 118 t.Errorf("WARNING: -update did not rewrite input file %s", in) 119 } 120 121 t.Errorf("(gofmt %s) != %s (see %s.gofmt)\n%s", in, out, in, 122 diff.Diff("expected", expected, "got", got)) 123 if err := os.WriteFile(in+".gofmt", got, 0666); err != nil { 124 t.Error(err) 125 } 126 } 127} 128 129// TestRewrite processes testdata/*.input files and compares them to the 130// corresponding testdata/*.golden files. The gofmt flags used to process 131// a file must be provided via a comment of the form 132// 133// //gofmt flags 134// 135// in the processed file within the first 20 lines, if any. 136func TestRewrite(t *testing.T) { 137 // determine input files 138 match, err := filepath.Glob("testdata/*.input") 139 if err != nil { 140 t.Fatal(err) 141 } 142 143 // add larger examples 144 match = append(match, "gofmt.go", "gofmt_test.go") 145 146 for _, in := range match { 147 name := filepath.Base(in) 148 t.Run(name, func(t *testing.T) { 149 out := in // for files where input and output are identical 150 if strings.HasSuffix(in, ".input") { 151 out = in[:len(in)-len(".input")] + ".golden" 152 } 153 runTest(t, in, out) 154 if in != out && !t.Failed() { 155 // Check idempotence. 156 runTest(t, out, out) 157 } 158 }) 159 } 160} 161 162// Test case for issue 3961. 163func TestCRLF(t *testing.T) { 164 const input = "testdata/crlf.input" // must contain CR/LF's 165 const golden = "testdata/crlf.golden" // must not contain any CR's 166 167 data, err := os.ReadFile(input) 168 if err != nil { 169 t.Error(err) 170 } 171 if !bytes.Contains(data, []byte("\r\n")) { 172 t.Errorf("%s contains no CR/LF's", input) 173 } 174 175 data, err = os.ReadFile(golden) 176 if err != nil { 177 t.Error(err) 178 } 179 if bytes.Contains(data, []byte("\r")) { 180 t.Errorf("%s contains CR's", golden) 181 } 182} 183 184func TestBackupFile(t *testing.T) { 185 dir, err := os.MkdirTemp("", "gofmt_test") 186 if err != nil { 187 t.Fatal(err) 188 } 189 defer os.RemoveAll(dir) 190 name, err := backupFile(filepath.Join(dir, "foo.go"), []byte(" package main"), 0644) 191 if err != nil { 192 t.Fatal(err) 193 } 194 t.Logf("Created: %s", name) 195} 196