1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package main
6
7import (
8	"bytes"
9	"encoding/json"
10	"internal/trace"
11	"io"
12	"net/http/httptest"
13	"os"
14	"path/filepath"
15	"slices"
16	"strconv"
17	"strings"
18	"testing"
19	"time"
20
21	"internal/trace/raw"
22	"internal/trace/traceviewer/format"
23)
24
25func TestJSONTraceHandler(t *testing.T) {
26	testPaths, err := filepath.Glob("./testdata/*.test")
27	if err != nil {
28		t.Fatalf("discovering tests: %v", err)
29	}
30	for _, testPath := range testPaths {
31		t.Run(filepath.Base(testPath), func(t *testing.T) {
32			parsed := getTestTrace(t, testPath)
33			data := recordJSONTraceHandlerResponse(t, parsed)
34			// TODO(mknyszek): Check that there's one at most goroutine per proc at any given time.
35			checkExecutionTimes(t, data)
36			checkPlausibleHeapMetrics(t, data)
37			// TODO(mknyszek): Check for plausible thread and goroutine metrics.
38			checkMetaNamesEmitted(t, data, "process_name", []string{"STATS", "PROCS"})
39			checkMetaNamesEmitted(t, data, "thread_name", []string{"GC", "Network", "Timers", "Syscalls", "Proc 0"})
40			checkProcStartStop(t, data)
41			checkSyscalls(t, data)
42			checkNetworkUnblock(t, data)
43			// TODO(mknyszek): Check for flow events.
44		})
45	}
46}
47
48func checkSyscalls(t *testing.T, data format.Data) {
49	data = filterViewerTrace(data,
50		filterEventName("syscall"),
51		filterStackRootFunc("main.blockingSyscall"))
52	if len(data.Events) <= 1 {
53		t.Errorf("got %d events, want > 1", len(data.Events))
54	}
55	data = filterViewerTrace(data, filterBlocked("yes"))
56	if len(data.Events) != 1 {
57		t.Errorf("got %d events, want 1", len(data.Events))
58	}
59}
60
// eventFilterFn reports whether an event should be kept by filterViewerTrace.
// The *format.Data argument gives the filter access to trace-wide tables,
// such as the stack frames that an event's Stack field refers to.
type eventFilterFn func(*format.Event, *format.Data) bool
62
63func filterEventName(name string) eventFilterFn {
64	return func(e *format.Event, _ *format.Data) bool {
65		return e.Name == name
66	}
67}
68
69// filterGoRoutineName returns an event filter that returns true if the event's
70// goroutine name is equal to name.
71func filterGoRoutineName(name string) eventFilterFn {
72	return func(e *format.Event, _ *format.Data) bool {
73		return parseGoroutineName(e) == name
74	}
75}
76
77// parseGoroutineName returns the goroutine name from the event's name field.
78// E.g. if e.Name is "G42 main.cpu10", this returns "main.cpu10".
79func parseGoroutineName(e *format.Event) string {
80	parts := strings.SplitN(e.Name, " ", 2)
81	if len(parts) != 2 || !strings.HasPrefix(parts[0], "G") {
82		return ""
83	}
84	return parts[1]
85}
86
87// filterBlocked returns an event filter that returns true if the event's
88// "blocked" argument is equal to blocked.
89func filterBlocked(blocked string) eventFilterFn {
90	return func(e *format.Event, _ *format.Data) bool {
91		m, ok := e.Arg.(map[string]any)
92		if !ok {
93			return false
94		}
95		return m["blocked"] == blocked
96	}
97}
98
99// filterStackRootFunc returns an event filter that returns true if the function
100// at the root of the stack trace is named name.
101func filterStackRootFunc(name string) eventFilterFn {
102	return func(e *format.Event, data *format.Data) bool {
103		frames := stackFrames(data, e.Stack)
104		rootFrame := frames[len(frames)-1]
105		return strings.HasPrefix(rootFrame, name+":")
106	}
107}
108
109// filterViewerTrace returns a copy of data with only the events that pass all
110// of the given filters.
111func filterViewerTrace(data format.Data, fns ...eventFilterFn) (filtered format.Data) {
112	filtered = data
113	filtered.Events = nil
114	for _, e := range data.Events {
115		keep := true
116		for _, fn := range fns {
117			keep = keep && fn(e, &filtered)
118		}
119		if keep {
120			filtered.Events = append(filtered.Events, e)
121		}
122	}
123	return
124}
125
126func stackFrames(data *format.Data, stackID int) (frames []string) {
127	for {
128		frame, ok := data.Frames[strconv.Itoa(stackID)]
129		if !ok {
130			return
131		}
132		frames = append(frames, frame.Name)
133		stackID = frame.Parent
134	}
135}
136
137func checkProcStartStop(t *testing.T, data format.Data) {
138	procStarted := map[uint64]bool{}
139	for _, e := range data.Events {
140		if e.Name == "proc start" {
141			if procStarted[e.TID] == true {
142				t.Errorf("proc started twice: %d", e.TID)
143			}
144			procStarted[e.TID] = true
145		}
146		if e.Name == "proc stop" {
147			if procStarted[e.TID] == false {
148				t.Errorf("proc stopped twice: %d", e.TID)
149			}
150			procStarted[e.TID] = false
151		}
152	}
153	if got, want := len(procStarted), 8; got != want {
154		t.Errorf("wrong number of procs started/stopped got=%d want=%d", got, want)
155	}
156}
157
158func checkNetworkUnblock(t *testing.T, data format.Data) {
159	count := 0
160	var netBlockEv *format.Event
161	for _, e := range data.Events {
162		if e.TID == trace.NetpollP && e.Name == "unblock (network)" && e.Phase == "I" && e.Scope == "t" {
163			count++
164			netBlockEv = e
165		}
166	}
167	if netBlockEv == nil {
168		t.Error("failed to find a network unblock")
169	}
170	if count == 0 {
171		t.Errorf("found zero network block events, want at least one")
172	}
173	// TODO(mknyszek): Check for the flow of this event to some slice event of a goroutine running.
174}
175
176func checkExecutionTimes(t *testing.T, data format.Data) {
177	cpu10 := sumExecutionTime(filterViewerTrace(data, filterGoRoutineName("main.cpu10")))
178	cpu20 := sumExecutionTime(filterViewerTrace(data, filterGoRoutineName("main.cpu20")))
179	if cpu10 <= 0 || cpu20 <= 0 || cpu10 >= cpu20 {
180		t.Errorf("bad execution times: cpu10=%v, cpu20=%v", cpu10, cpu20)
181	}
182}
183
184func checkMetaNamesEmitted(t *testing.T, data format.Data, category string, want []string) {
185	t.Helper()
186	names := metaEventNameArgs(category, data)
187	for _, wantName := range want {
188		if !slices.Contains(names, wantName) {
189			t.Errorf("%s: names=%v, want %q", category, names, wantName)
190		}
191	}
192}
193
194func metaEventNameArgs(category string, data format.Data) (names []string) {
195	for _, e := range data.Events {
196		if e.Name == category && e.Phase == "M" {
197			names = append(names, e.Arg.(map[string]any)["name"].(string))
198		}
199	}
200	return
201}
202
203func checkPlausibleHeapMetrics(t *testing.T, data format.Data) {
204	hms := heapMetrics(data)
205	var nonZeroAllocated, nonZeroNextGC bool
206	for _, hm := range hms {
207		if hm.Allocated > 0 {
208			nonZeroAllocated = true
209		}
210		if hm.NextGC > 0 {
211			nonZeroNextGC = true
212		}
213	}
214
215	if !nonZeroAllocated {
216		t.Errorf("nonZeroAllocated=%v, want true", nonZeroAllocated)
217	}
218	if !nonZeroNextGC {
219		t.Errorf("nonZeroNextGC=%v, want true", nonZeroNextGC)
220	}
221}
222
223func heapMetrics(data format.Data) (metrics []format.HeapCountersArg) {
224	for _, e := range data.Events {
225		if e.Phase == "C" && e.Name == "Heap" {
226			j, _ := json.Marshal(e.Arg)
227			var metric format.HeapCountersArg
228			json.Unmarshal(j, &metric)
229			metrics = append(metrics, metric)
230		}
231	}
232	return
233}
234
235func recordJSONTraceHandlerResponse(t *testing.T, parsed *parsedTrace) format.Data {
236	h := JSONTraceHandler(parsed)
237	recorder := httptest.NewRecorder()
238	r := httptest.NewRequest("GET", "/jsontrace", nil)
239	h.ServeHTTP(recorder, r)
240
241	var data format.Data
242	if err := json.Unmarshal(recorder.Body.Bytes(), &data); err != nil {
243		t.Fatal(err)
244	}
245	return data
246}
247
248func sumExecutionTime(data format.Data) (sum time.Duration) {
249	for _, e := range data.Events {
250		sum += time.Duration(e.Dur) * time.Microsecond
251	}
252	return
253}
254
255func getTestTrace(t *testing.T, testPath string) *parsedTrace {
256	t.Helper()
257
258	// First read in the text trace and write it out as bytes.
259	f, err := os.Open(testPath)
260	if err != nil {
261		t.Fatalf("failed to open test %s: %v", testPath, err)
262	}
263	r, err := raw.NewTextReader(f)
264	if err != nil {
265		t.Fatalf("failed to read test %s: %v", testPath, err)
266	}
267	var trace bytes.Buffer
268	w, err := raw.NewWriter(&trace, r.Version())
269	if err != nil {
270		t.Fatalf("failed to write out test %s: %v", testPath, err)
271	}
272	for {
273		ev, err := r.ReadEvent()
274		if err == io.EOF {
275			break
276		}
277		if err != nil {
278			t.Fatalf("failed to read test %s: %v", testPath, err)
279		}
280		if err := w.WriteEvent(ev); err != nil {
281			t.Fatalf("failed to write out test %s: %v", testPath, err)
282		}
283	}
284
285	// Parse the test trace.
286	parsed, err := parseTrace(&trace, int64(trace.Len()))
287	if err != nil {
288		t.Fatalf("failed to parse trace: %v", err)
289	}
290	return parsed
291}
292