xref: /aosp_15_r20/build/blueprint/gotestmain/gotestmain.go (revision 1fa6dee971e1612fa5cc0aa5ca2d35a22e2c34a3)
1// Copyright 2015 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package main
16
import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/doc"
	"go/parser"
	"go/token"
	"io/ioutil"
	"os"
	"reflect"
	"sort"
	"strings"
	"testing"
	"text/template"
	"unicode"
	"unicode/utf8"
)
33
// Command-line configuration for the generator.

// output names the file the generated test main is written to.
var output = flag.String("o", "", "output filename")

// pkg is the import path of the package under test; the generated program
// imports it under the name pkg.
var pkg = flag.String("pkg", "", "test package")

// exitCode is set to 1 by main when it rejects its arguments.
var exitCode int
39
// data is the value testMainTmpl is executed with.
type data struct {
	// Package is the import path of the package under test.
	Package string

	// Tests holds the names of the test functions to register.
	Tests []string
	// Examples holds the package's examples; the template registers only
	// those that declare expected output.
	Examples []*doc.Example

	// HasMain reports whether the package defines its own TestMain.
	HasMain bool
	// MainStartTakesFuzzers selects the testing.MainStart call form that
	// includes the extra fuzz-target argument.
	MainStartTakesFuzzers bool
}
47
// findTests parses each Go source file in srcs and collects what the
// generated test main needs:
//
//   - tests: names of package-level test functions, sorted. A name counts as
//     a test only under the rule `go test` applies (see isTest), so helpers
//     such as Testable are not registered.
//   - examples: the examples go/doc finds in each file, in file order.
//   - hasMain: whether any file declares a package-level TestMain.
//
// It panics if a file cannot be parsed.
func findTests(srcs []string) (tests []string, examples []*doc.Example, hasMain bool) {
	for _, src := range srcs {
		// ParseComments keeps the comments doc.Examples reads expected output from.
		f, err := parser.ParseFile(token.NewFileSet(), src, nil, parser.ParseComments)
		if err != nil {
			panic(err)
		}
		for _, decl := range f.Decls {
			fn, ok := decl.(*ast.FuncDecl)
			// Only package-level functions can be tests; skip methods.
			if !ok || fn.Recv != nil {
				continue
			}
			name := fn.Name.Name
			switch {
			case name == "TestMain":
				hasMain = true
			case isTest(name, "Test"):
				tests = append(tests, name)
			}
		}

		examples = append(examples, doc.Examples(f)...)
	}
	sort.Strings(tests)
	return tests, examples, hasMain
}

// isTest reports whether name is a test-style name for prefix: it must start
// with prefix, and the rune following prefix (if any) must not be a
// lowercase letter. This matches the rule `go test` uses, so "Test",
// "TestFoo", "Test_foo" and "Test1" qualify while "Testable" does not.
func isTest(name, prefix string) bool {
	if !strings.HasPrefix(name, prefix) {
		return false
	}
	rest := name[len(prefix):]
	if rest == "" {
		return true
	}
	r, _ := utf8.DecodeRuneInString(rest)
	return !unicode.IsLower(r)
}
70
// mainStartTakesFuzzers reports whether testing.MainStart has more than the
// four parameters of older releases, i.e. whether it takes the additional
// fuzz-target slice that Go 1.18 introduced.
func mainStartTakesFuzzers() bool {
	const legacyParamCount = 4
	sig := reflect.TypeOf(testing.MainStart)
	return sig.NumIn() > legacyParamCount
}
75
76func main() {
77	flag.Parse()
78
79	if flag.NArg() == 0 {
80		fmt.Fprintln(os.Stderr, "error: must pass at least one input")
81		exitCode = 1
82		return
83	}
84
85	buf := &bytes.Buffer{}
86
87	tests, examples, hasMain := findTests(flag.Args())
88
89	d := data{
90		Package:               *pkg,
91		Tests:                 tests,
92		Examples:              examples,
93		HasMain:               hasMain,
94		MainStartTakesFuzzers: mainStartTakesFuzzers(),
95	}
96
97	err := testMainTmpl.Execute(buf, d)
98	if err != nil {
99		panic(err)
100	}
101
102	err = ioutil.WriteFile(*output, buf.Bytes(), 0666)
103	if err != nil {
104		panic(err)
105	}
106}
107
108var testMainTmpl = template.Must(template.New("testMain").Parse(`
109package main
110
111import (
112	"io"
113{{if not .HasMain}}
114	"os"
115{{end}}
116	"reflect"
117	"regexp"
118	"testing"
119	"time"
120
121	pkg "{{.Package}}"
122)
123
124var t = []testing.InternalTest{
125{{range .Tests}}
126	{"{{.}}", pkg.{{.}}},
127{{end}}
128}
129
130var e = []testing.InternalExample{
131{{range .Examples}}
132	{{if or .Output .EmptyOutput}}
133		{"{{.Name}}", pkg.Example{{.Name}}, {{.Output | printf "%q" }}, {{.Unordered}}},
134	{{end}}
135{{end}}
136}
137
138var matchPat string
139var matchRe *regexp.Regexp
140
141type matchString struct{}
142
143func MatchString(pat, str string) (result bool, err error) {
144	if matchRe == nil || matchPat != pat {
145		matchPat = pat
146		matchRe, err = regexp.Compile(matchPat)
147		if err != nil {
148			return
149		}
150	}
151	return matchRe.MatchString(str), nil
152}
153
154func (matchString) MatchString(pat, str string) (bool, error) {
155	return MatchString(pat, str)
156}
157
158func (matchString) StartCPUProfile(w io.Writer) error {
159	panic("shouldn't get here")
160}
161
162func (matchString) StopCPUProfile() {
163}
164
165func (matchString) WriteHeapProfile(w io.Writer) error {
166    panic("shouldn't get here")
167}
168
169func (matchString) WriteProfileTo(string, io.Writer, int) error {
170    panic("shouldn't get here")
171}
172
173func (matchString) ImportPath() string {
174	return "{{.Package}}"
175}
176
177func (matchString) StartTestLog(io.Writer) {
178	panic("shouldn't get here")
179}
180
181func (matchString) StopTestLog() error {
182	panic("shouldn't get here")
183}
184
185func (matchString) SetPanicOnExit0(bool) {
186	panic("shouldn't get here")
187}
188
189func (matchString) CoordinateFuzzing(time.Duration, int64, time.Duration, int64, int, []corpusEntry, []reflect.Type, string, string) error {
190	panic("shouldn't get here")
191}
192
193func (matchString) RunFuzzWorker(func(corpusEntry) error) error {
194	panic("shouldn't get here")
195}
196
197func (matchString) ReadCorpus(string, []reflect.Type) ([]corpusEntry, error) {
198	panic("shouldn't get here")
199}
200
201func (matchString) CheckCorpus([]interface{}, []reflect.Type) error {
202	panic("shouldn't get here")
203}
204
205func (matchString) ResetCoverage() {
206	panic("shouldn't get here")
207}
208
209func (matchString) SnapshotCoverage() {
210	panic("shouldn't get here")
211}
212
213func (f matchString) InitRuntimeCoverage() (mode string, tearDown func(string, string) (string, error), snapcov func() float64) {
214	return
215}
216
217type corpusEntry = struct {
218	Parent     string
219	Path       string
220	Data       []byte
221	Values     []interface{}
222	Generation int
223	IsSeed     bool
224}
225
226func main() {
227{{if .MainStartTakesFuzzers }}
228	m := testing.MainStart(matchString{}, t, nil, nil, e)
229{{else}}
230	m := testing.MainStart(matchString{}, t, nil, e)
231{{end}}
232{{if .HasMain}}
233	pkg.TestMain(m)
234{{else}}
235	os.Exit(m.Run())
236{{end}}
237}
238`))
239