1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package main
6
7import (
8	"bytes"
9	"context"
10	"flag"
11	"fmt"
12	"go/ast"
13	"go/parser"
14	"go/printer"
15	"go/scanner"
16	"go/token"
17	"internal/diff"
18	"io"
19	"io/fs"
20	"math/rand"
21	"os"
22	"path/filepath"
23	"runtime"
24	"runtime/pprof"
25	"strconv"
26	"strings"
27
28	"cmd/internal/telemetry/counter"
29
30	"golang.org/x/sync/semaphore"
31)
32
// Command-line flags.
//
// When none of -l, -w, and -d is given, the formatted source is written to
// standard output. -w is rejected when the input is standard input.
var (
	// main operation modes
	list        = flag.Bool("l", false, "list files whose formatting differs from gofmt's")
	write       = flag.Bool("w", false, "write result to (source) file instead of stdout")
	rewriteRule = flag.String("r", "", "rewrite rule (e.g., 'a[b:len(a)] -> a[b:]')")
	simplifyAST = flag.Bool("s", false, "simplify code")
	doDiff      = flag.Bool("d", false, "display diffs instead of rewriting files")
	allErrors   = flag.Bool("e", false, "report all errors (not just the first 10 on different lines)")

	// debugging
	cpuprofile = flag.String("cpuprofile", "", "write cpu profile to this file")
)
45
// Keep these in sync with go/format/format.go.
const (
	// tabWidth and printerMode together form the printer.Config that
	// processFile uses when printing a formatted file.
	tabWidth    = 8
	printerMode = printer.UseSpaces | printer.TabIndent | printerNormalizeNumbers

	// printerNormalizeNumbers means to canonicalize number literal prefixes
	// and exponents while printing. See https://golang.org/doc/go1.13#gofmt.
	//
	// This value is defined in go/printer specifically for go/format and cmd/gofmt.
	printerNormalizeNumbers = 1 << 30
)
57
// fdSem guards the number of concurrently-open file descriptors.
//
// A slot is taken by sending on the channel before a file is opened and is
// given back by receiving from the channel once that file has been closed.
//
// For now, this is arbitrarily set to 200, based on the observation that many
// platforms default to a kernel limit of 256. Ideally, perhaps we should derive
// it from rlimit on platforms that support that system call.
//
// File descriptors opened from outside of this package are not tracked,
// so this limit may be approximate.
var fdSem = make(chan bool, 200)
67
// rewrite, when non-nil, is applied by processFile to the AST of each complete
// source file before it is printed; for program fragments it is skipped with
// a warning. (Presumably it is set by initRewrite, which is not shown here.)
//
// parserMode is computed by initParserMode from the -e and -r flags.
// NOTE(review): it is presumably consumed by parse, defined elsewhere — confirm.
var (
	rewrite    func(*token.FileSet, *ast.File) *ast.File
	parserMode parser.Mode
)
72
// usage writes a one-line synopsis of the command to standard error,
// followed by the defaults of every registered flag.
func usage() {
	fmt.Fprintln(os.Stderr, "usage: gofmt [flags] [path ...]")
	flag.PrintDefaults()
}
77
78func initParserMode() {
79	parserMode = parser.ParseComments
80	if *allErrors {
81		parserMode |= parser.AllErrors
82	}
83	// It's only -r that makes use of go/ast's object resolution,
84	// so avoid the unnecessary work if the flag isn't used.
85	if *rewriteRule == "" {
86		parserMode |= parser.SkipObjectResolution
87	}
88}
89
// isGoFile reports whether f names a Go source file that should be formatted:
// a non-directory entry whose name ends in ".go" and does not start with ".".
func isGoFile(f fs.DirEntry) bool {
	name := f.Name()
	if strings.HasPrefix(name, ".") || !strings.HasSuffix(name, ".go") {
		// Hidden entries and non-Go files are ignored.
		return false
	}
	return !f.IsDir()
}
95
96// A sequencer performs concurrent tasks that may write output, but emits that
97// output in a deterministic order.
98type sequencer struct {
99	maxWeight int64
100	sem       *semaphore.Weighted   // weighted by input bytes (an approximate proxy for memory overhead)
101	prev      <-chan *reporterState // 1-buffered
102}
103
104// newSequencer returns a sequencer that allows concurrent tasks up to maxWeight
105// and writes tasks' output to out and err.
106func newSequencer(maxWeight int64, out, err io.Writer) *sequencer {
107	sem := semaphore.NewWeighted(maxWeight)
108	prev := make(chan *reporterState, 1)
109	prev <- &reporterState{out: out, err: err}
110	return &sequencer{
111		maxWeight: maxWeight,
112		sem:       sem,
113		prev:      prev,
114	}
115}
116
117// exclusive is a weight that can be passed to a sequencer to cause
118// a task to be executed without any other concurrent tasks.
119const exclusive = -1
120
121// Add blocks until the sequencer has enough weight to spare, then adds f as a
122// task to be executed concurrently.
123//
124// If the weight is either negative or larger than the sequencer's maximum
125// weight, Add blocks until all other tasks have completed, then the task
126// executes exclusively (blocking all other calls to Add until it completes).
127//
128// f may run concurrently in a goroutine, but its output to the passed-in
129// reporter will be sequential relative to the other tasks in the sequencer.
130//
131// If f invokes a method on the reporter, execution of that method may block
132// until the previous task has finished. (To maximize concurrency, f should
133// avoid invoking the reporter until it has finished any parallelizable work.)
134//
135// If f returns a non-nil error, that error will be reported after f's output
136// (if any) and will cause a nonzero final exit code.
137func (s *sequencer) Add(weight int64, f func(*reporter) error) {
138	if weight < 0 || weight > s.maxWeight {
139		weight = s.maxWeight
140	}
141	if err := s.sem.Acquire(context.TODO(), weight); err != nil {
142		// Change the task from "execute f" to "report err".
143		weight = 0
144		f = func(*reporter) error { return err }
145	}
146
147	r := &reporter{prev: s.prev}
148	next := make(chan *reporterState, 1)
149	s.prev = next
150
151	// Start f in parallel: it can run until it invokes a method on r, at which
152	// point it will block until the previous task releases the output state.
153	go func() {
154		if err := f(r); err != nil {
155			r.Report(err)
156		}
157		next <- r.getState() // Release the next task.
158		s.sem.Release(weight)
159	}()
160}
161
162// AddReport prints an error to s after the output of any previously-added
163// tasks, causing the final exit code to be nonzero.
164func (s *sequencer) AddReport(err error) {
165	s.Add(0, func(*reporter) error { return err })
166}
167
168// GetExitCode waits for all previously-added tasks to complete, then returns an
169// exit code for the sequence suitable for passing to os.Exit.
170func (s *sequencer) GetExitCode() int {
171	c := make(chan int, 1)
172	s.Add(0, func(r *reporter) error {
173		c <- r.ExitCode()
174		return nil
175	})
176	return <-c
177}
178
// A reporter is the handle through which a task emits output, warnings,
// and errors.
type reporter struct {
	prev  <-chan *reporterState // yields the shared state once prior reporters are finished
	state *reporterState        // nil until received from prev
}

// reporterState is the state handed from one reporter to the next.
//
// At most one reporter may hold a given reporterState at any time.
type reporterState struct {
	out, err io.Writer
	exitCode int
}

// getState returns the reporter state, first waiting (on the initial call
// only) until all prior reporters have finished with it.
func (r *reporter) getState() *reporterState {
	if r.state != nil {
		return r.state
	}
	r.state = <-r.prev
	return r.state
}

// Warnf writes a formatted warning to the reporter's error stream.
// The exit code is left untouched.
func (r *reporter) Warnf(format string, args ...any) {
	st := r.getState()
	fmt.Fprintf(st.err, format, args...)
}

// Write sends p to the reporter's output stream.
//
// A write error is handed back to the caller and has no other effect on the
// reporter's exit code.
func (r *reporter) Write(p []byte) (int, error) {
	st := r.getState()
	return st.out.Write(p)
}

// Report prints err, which must be non-nil, to the reporter's error stream
// and sets the exit code to a nonzero value.
func (r *reporter) Report(err error) {
	if err == nil {
		panic("Report with nil error")
	}
	st := r.getState()
	scanner.PrintError(st.err, err)
	st.exitCode = 2
}

// ExitCode returns the exit code recorded in the reporter state, waiting for
// prior reporters if the state has not been obtained yet.
func (r *reporter) ExitCode() int {
	st := r.getState()
	return st.exitCode
}
230
231// If info == nil, we are formatting stdin instead of a file.
232// If in == nil, the source is the contents of the file with the given filename.
233func processFile(filename string, info fs.FileInfo, in io.Reader, r *reporter) error {
234	src, err := readFile(filename, info, in)
235	if err != nil {
236		return err
237	}
238
239	fileSet := token.NewFileSet()
240	// If we are formatting stdin, we accept a program fragment in lieu of a
241	// complete source file.
242	fragmentOk := info == nil
243	file, sourceAdj, indentAdj, err := parse(fileSet, filename, src, fragmentOk)
244	if err != nil {
245		return err
246	}
247
248	if rewrite != nil {
249		if sourceAdj == nil {
250			file = rewrite(fileSet, file)
251		} else {
252			r.Warnf("warning: rewrite ignored for incomplete programs\n")
253		}
254	}
255
256	ast.SortImports(fileSet, file)
257
258	if *simplifyAST {
259		simplify(file)
260	}
261
262	res, err := format(fileSet, file, sourceAdj, indentAdj, src, printer.Config{Mode: printerMode, Tabwidth: tabWidth})
263	if err != nil {
264		return err
265	}
266
267	if !bytes.Equal(src, res) {
268		// formatting has changed
269		if *list {
270			fmt.Fprintln(r, filename)
271		}
272		if *write {
273			if info == nil {
274				panic("-w should not have been allowed with stdin")
275			}
276
277			perm := info.Mode().Perm()
278			if err := writeFile(filename, src, res, perm, info.Size()); err != nil {
279				return err
280			}
281		}
282		if *doDiff {
283			newName := filepath.ToSlash(filename)
284			oldName := newName + ".orig"
285			r.Write(diff.Diff(oldName, src, newName, res))
286		}
287	}
288
289	if !*list && !*write && !*doDiff {
290		_, err = r.Write(res)
291	}
292
293	return err
294}
295
296// readFile reads the contents of filename, described by info.
297// If in is non-nil, readFile reads directly from it.
298// Otherwise, readFile opens and reads the file itself,
299// with the number of concurrently-open files limited by fdSem.
300func readFile(filename string, info fs.FileInfo, in io.Reader) ([]byte, error) {
301	if in == nil {
302		fdSem <- true
303		var err error
304		f, err := os.Open(filename)
305		if err != nil {
306			return nil, err
307		}
308		in = f
309		defer func() {
310			f.Close()
311			<-fdSem
312		}()
313	}
314
315	// Compute the file's size and read its contents with minimal allocations.
316	//
317	// If we have the FileInfo from filepath.WalkDir, use it to make
318	// a buffer of the right size and avoid ReadAll's reallocations.
319	//
320	// If the size is unknown (or bogus, or overflows an int), fall back to
321	// a size-independent ReadAll.
322	size := -1
323	if info != nil && info.Mode().IsRegular() && int64(int(info.Size())) == info.Size() {
324		size = int(info.Size())
325	}
326	if size+1 <= 0 {
327		// The file is not known to be regular, so we don't have a reliable size for it.
328		var err error
329		src, err := io.ReadAll(in)
330		if err != nil {
331			return nil, err
332		}
333		return src, nil
334	}
335
336	// We try to read size+1 bytes so that we can detect modifications: if we
337	// read more than size bytes, then the file was modified concurrently.
338	// (If that happens, we could, say, append to src to finish the read, or
339	// proceed with a truncated buffer — but the fact that it changed at all
340	// indicates a possible race with someone editing the file, so we prefer to
341	// stop to avoid corrupting it.)
342	src := make([]byte, size+1)
343	n, err := io.ReadFull(in, src)
344	switch err {
345	case nil, io.EOF, io.ErrUnexpectedEOF:
346		// io.ReadFull returns io.EOF (for an empty file) or io.ErrUnexpectedEOF
347		// (for a non-empty file) if the file was changed unexpectedly. Continue
348		// with comparing file sizes in those cases.
349	default:
350		return nil, err
351	}
352	if n < size {
353		return nil, fmt.Errorf("error: size of %s changed during reading (from %d to %d bytes)", filename, size, n)
354	} else if n > size {
355		return nil, fmt.Errorf("error: size of %s changed during reading (from %d to >=%d bytes)", filename, size, len(src))
356	}
357	return src[:n], nil
358}
359
360func main() {
361	// Arbitrarily limit in-flight work to 2MiB times the number of threads.
362	//
363	// The actual overhead for the parse tree and output will depend on the
364	// specifics of the file, but this at least keeps the footprint of the process
365	// roughly proportional to GOMAXPROCS.
366	maxWeight := (2 << 20) * int64(runtime.GOMAXPROCS(0))
367	s := newSequencer(maxWeight, os.Stdout, os.Stderr)
368
369	// call gofmtMain in a separate function
370	// so that it can use defer and have them
371	// run before the exit.
372	gofmtMain(s)
373	os.Exit(s.GetExitCode())
374}
375
376func gofmtMain(s *sequencer) {
377	counter.Open()
378	flag.Usage = usage
379	flag.Parse()
380	counter.Inc("gofmt/invocations")
381	counter.CountFlags("gofmt/flag:", *flag.CommandLine)
382
383	if *cpuprofile != "" {
384		fdSem <- true
385		f, err := os.Create(*cpuprofile)
386		if err != nil {
387			s.AddReport(fmt.Errorf("creating cpu profile: %s", err))
388			return
389		}
390		defer func() {
391			f.Close()
392			<-fdSem
393		}()
394		pprof.StartCPUProfile(f)
395		defer pprof.StopCPUProfile()
396	}
397
398	initParserMode()
399	initRewrite()
400
401	args := flag.Args()
402	if len(args) == 0 {
403		if *write {
404			s.AddReport(fmt.Errorf("error: cannot use -w with standard input"))
405			return
406		}
407		s.Add(0, func(r *reporter) error {
408			return processFile("<standard input>", nil, os.Stdin, r)
409		})
410		return
411	}
412
413	for _, arg := range args {
414		switch info, err := os.Stat(arg); {
415		case err != nil:
416			s.AddReport(err)
417		case !info.IsDir():
418			// Non-directory arguments are always formatted.
419			arg := arg
420			s.Add(fileWeight(arg, info), func(r *reporter) error {
421				return processFile(arg, info, nil, r)
422			})
423		default:
424			// Directories are walked, ignoring non-Go files.
425			err := filepath.WalkDir(arg, func(path string, f fs.DirEntry, err error) error {
426				if err != nil || !isGoFile(f) {
427					return err
428				}
429				info, err := f.Info()
430				if err != nil {
431					s.AddReport(err)
432					return nil
433				}
434				s.Add(fileWeight(path, info), func(r *reporter) error {
435					return processFile(path, info, nil, r)
436				})
437				return nil
438			})
439			if err != nil {
440				s.AddReport(err)
441			}
442		}
443	}
444}
445
446func fileWeight(path string, info fs.FileInfo) int64 {
447	if info == nil {
448		return exclusive
449	}
450	if info.Mode().Type() == fs.ModeSymlink {
451		var err error
452		info, err = os.Stat(path)
453		if err != nil {
454			return exclusive
455		}
456	}
457	if !info.Mode().IsRegular() {
458		// For non-regular files, FileInfo.Size is system-dependent and thus not a
459		// reliable indicator of weight.
460		return exclusive
461	}
462	return info.Size()
463}
464
465// writeFile updates a file with the new formatted data.
466func writeFile(filename string, orig, formatted []byte, perm fs.FileMode, size int64) error {
467	// Make a temporary backup file before rewriting the original file.
468	bakname, err := backupFile(filename, orig, perm)
469	if err != nil {
470		return err
471	}
472
473	fdSem <- true
474	defer func() { <-fdSem }()
475
476	fout, err := os.OpenFile(filename, os.O_WRONLY, perm)
477	if err != nil {
478		// We couldn't even open the file, so it should
479		// not have changed.
480		os.Remove(bakname)
481		return err
482	}
483	defer fout.Close() // for error paths
484
485	restoreFail := func(err error) {
486		fmt.Fprintf(os.Stderr, "gofmt: %s: error restoring file to original: %v; backup in %s\n", filename, err, bakname)
487	}
488
489	n, err := fout.Write(formatted)
490	if err == nil && int64(n) < size {
491		err = fout.Truncate(int64(n))
492	}
493
494	if err != nil {
495		// Rewriting the file failed.
496
497		if n == 0 {
498			// Original file unchanged.
499			os.Remove(bakname)
500			return err
501		}
502
503		// Try to restore the original contents.
504
505		no, erro := fout.WriteAt(orig, 0)
506		if erro != nil {
507			// That failed too.
508			restoreFail(erro)
509			return err
510		}
511
512		if no < n {
513			// Original file is shorter. Truncate.
514			if erro = fout.Truncate(int64(no)); erro != nil {
515				restoreFail(erro)
516				return err
517			}
518		}
519
520		if erro := fout.Close(); erro != nil {
521			restoreFail(erro)
522			return err
523		}
524
525		// Original contents restored.
526		os.Remove(bakname)
527		return err
528	}
529
530	if err := fout.Close(); err != nil {
531		restoreFail(err)
532		return err
533	}
534
535	// File updated.
536	os.Remove(bakname)
537	return nil
538}
539
540// backupFile writes data to a new file named filename<number> with permissions perm,
541// with <number> randomly chosen such that the file name is unique. backupFile returns
542// the chosen file name.
543func backupFile(filename string, data []byte, perm fs.FileMode) (string, error) {
544	fdSem <- true
545	defer func() { <-fdSem }()
546
547	nextRandom := func() string {
548		return strconv.Itoa(rand.Int())
549	}
550
551	dir, base := filepath.Split(filename)
552	var (
553		bakname string
554		f       *os.File
555	)
556	for {
557		bakname = filepath.Join(dir, base+"."+nextRandom())
558		var err error
559		f, err = os.OpenFile(bakname, os.O_RDWR|os.O_CREATE|os.O_EXCL, perm)
560		if err == nil {
561			break
562		}
563		if !os.IsExist(err) {
564			return "", err
565		}
566	}
567
568	// write data to backup file
569	_, err := f.Write(data)
570	if err1 := f.Close(); err == nil {
571		err = err1
572	}
573
574	return bakname, err
575}
576