1// Copyright 2022 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ld
6
7import (
8	"cmd/internal/obj"
9	"cmd/internal/objabi"
10	"cmd/link/internal/loader"
11	"fmt"
12	"internal/buildcfg"
13	"sort"
14	"strings"
15)
16
17type stackCheck struct {
18	ctxt      *Link
19	ldr       *loader.Loader
20	morestack loader.Sym
21	callSize  int // The number of bytes added by a CALL
22
23	// height records the maximum number of bytes a function and
24	// its callees can add to the stack without a split check.
25	height map[loader.Sym]int16
26
27	// graph records the out-edges from each symbol. This is only
28	// populated on a second pass if the first pass reveals an
29	// over-limit function.
30	graph map[loader.Sym][]stackCheckEdge
31}
32
33type stackCheckEdge struct {
34	growth int        // Stack growth in bytes at call to target
35	target loader.Sym // 0 for stack growth without a call
36}
37
38// stackCheckCycle is a sentinel stored in the height map to detect if
39// we've found a cycle. This is effectively an "infinite" stack
40// height, so we use the closest value to infinity that we can.
41const stackCheckCycle int16 = 1<<15 - 1
42
43// stackCheckIndirect is a sentinel Sym value used to represent the
44// target of an indirect/closure call.
45const stackCheckIndirect loader.Sym = ^loader.Sym(0)
46
47// doStackCheck walks the call tree to check that there is always
48// enough stack space for call frames, especially for a chain of
49// nosplit functions.
50//
51// It walks all functions to accumulate the number of bytes they can
52// grow the stack by without a split check and checks this against the
53// limit.
54func (ctxt *Link) doStackCheck() {
55	sc := newStackCheck(ctxt, false)
56
57	// limit is number of bytes a splittable function ensures are
58	// available on the stack. If any call chain exceeds this
59	// depth, the stack check test fails.
60	//
61	// The call to morestack in every splittable function ensures
62	// that there are at least StackLimit bytes available below SP
63	// when morestack returns.
64	limit := objabi.StackNosplit(*flagRace) - sc.callSize
65	if buildcfg.GOARCH == "arm64" {
66		// Need an extra 8 bytes below SP to save FP.
67		limit -= 8
68	}
69
70	// Compute stack heights without any back-tracking information.
71	// This will almost certainly succeed and we can simply
72	// return. If it fails, we do a second pass with back-tracking
73	// to produce a good error message.
74	//
75	// This accumulates stack heights bottom-up so it only has to
76	// visit every function once.
77	var failed []loader.Sym
78	for _, s := range ctxt.Textp {
79		if sc.check(s) > limit {
80			failed = append(failed, s)
81		}
82	}
83
84	if len(failed) > 0 {
85		// Something was over-limit, so now we do the more
86		// expensive work to report a good error. First, for
87		// the over-limit functions, redo the stack check but
88		// record the graph this time.
89		sc = newStackCheck(ctxt, true)
90		for _, s := range failed {
91			sc.check(s)
92		}
93
94		// Find the roots of the graph (functions that are not
95		// called by any other function).
96		roots := sc.findRoots()
97
98		// Find and report all paths that go over the limit.
99		// This accumulates stack depths top-down. This is
100		// much less efficient because we may have to visit
101		// the same function multiple times at different
102		// depths, but lets us find all paths.
103		for _, root := range roots {
104			ctxt.Errorf(root, "nosplit stack over %d byte limit", limit)
105			chain := []stackCheckChain{{stackCheckEdge{0, root}, false}}
106			sc.report(root, limit, &chain)
107		}
108	}
109}
110
111func newStackCheck(ctxt *Link, graph bool) *stackCheck {
112	sc := &stackCheck{
113		ctxt:      ctxt,
114		ldr:       ctxt.loader,
115		morestack: ctxt.loader.Lookup("runtime.morestack", 0),
116		height:    make(map[loader.Sym]int16, len(ctxt.Textp)),
117	}
118	// Compute stack effect of a CALL operation. 0 on LR machines.
119	// 1 register pushed on non-LR machines.
120	if !ctxt.Arch.HasLR {
121		sc.callSize = ctxt.Arch.RegSize
122	}
123
124	if graph {
125		// We're going to record the call graph.
126		sc.graph = make(map[loader.Sym][]stackCheckEdge)
127	}
128
129	return sc
130}
131
132func (sc *stackCheck) symName(sym loader.Sym) string {
133	switch sym {
134	case stackCheckIndirect:
135		return "indirect"
136	case 0:
137		return "leaf"
138	}
139	return fmt.Sprintf("%s<%d>", sc.ldr.SymName(sym), sc.ldr.SymVersion(sym))
140}
141
142// check returns the stack height of sym. It populates sc.height and
143// sc.graph for sym and every function in its call tree.
144func (sc *stackCheck) check(sym loader.Sym) int {
145	if h, ok := sc.height[sym]; ok {
146		// We've already visited this symbol or we're in a cycle.
147		return int(h)
148	}
149	// Store the sentinel so we can detect cycles.
150	sc.height[sym] = stackCheckCycle
151	// Compute and record the height and optionally edges.
152	h, edges := sc.computeHeight(sym, *flagDebugNosplit || sc.graph != nil)
153	if h > int(stackCheckCycle) { // Prevent integer overflow
154		h = int(stackCheckCycle)
155	}
156	sc.height[sym] = int16(h)
157	if sc.graph != nil {
158		sc.graph[sym] = edges
159	}
160
161	if *flagDebugNosplit {
162		for _, edge := range edges {
163			fmt.Printf("nosplit: %s +%d", sc.symName(sym), edge.growth)
164			if edge.target == 0 {
165				// Local stack growth or leaf function.
166				fmt.Printf("\n")
167			} else {
168				fmt.Printf(" -> %s\n", sc.symName(edge.target))
169			}
170		}
171	}
172
173	return h
174}
175
// computeHeight returns the stack height of sym, meaning the largest
// number of bytes sym and its callees can add to the stack without a
// split check. If graph is true, it also returns the out-edges of sym
// in the order they were found.
//
// Heights by kind of symbol:
//   - runtime.morestack, external symbols, and symbols without a valid
//     FuncInfo have height 0.
//   - The stackCheckIndirect pseudo-symbol has height callSize, which is
//     just enough room to call morestack.
//   - A splittable function's height is callSize plus morestack's
//     height.
//   - A nosplit function's height is the maximum of two things. The
//     first is, over every call site, the SP offset at the call plus
//     callSize plus the callee's height. The second is the function's
//     largest SP offset anywhere in its body.
//
// Caching is applied to this in check. Call check instead of calling
// this directly.
func (sc *stackCheck) computeHeight(sym loader.Sym, graph bool) (int, []stackCheckEdge) {
	ldr := sc.ldr

	// Check special cases.
	if sym == sc.morestack {
		// morestack looks like it calls functions, but they
		// either happen only when already on the system stack
		// (where there is ~infinite space), or after
		// switching to the system stack. Hence, its stack
		// height on the user stack is 0.
		return 0, nil
	}
	if sym == stackCheckIndirect {
		// Assume that indirect/closure calls are always to
		// splittable functions, so they just need enough room
		// to call morestack.
		return sc.callSize, []stackCheckEdge{{sc.callSize, sc.morestack}}
	}

	// Ignore calls to external functions. Assume that these calls
	// are only ever happening on the system stack, where there's
	// plenty of room.
	if ldr.AttrExternal(sym) {
		return 0, nil
	}
	if info := ldr.FuncInfo(sym); !info.Valid() { // also external
		return 0, nil
	}

	// Track the maximum height of this function and, if we're
	// recording the graph, its out-edges.
	var edges []stackCheckEdge
	maxHeight := 0
	ctxt := sc.ctxt
	// addEdge adds a stack growth out of this function to
	// function "target" or, if target == 0, a local stack growth
	// within the function. For a real target it recurses into
	// check, so the callee's height is computed (and cached)
	// before this function's height is final.
	addEdge := func(growth int, target loader.Sym) {
		if graph {
			edges = append(edges, stackCheckEdge{growth, target})
		}
		height := growth
		if target != 0 { // Don't walk into the leaf "edge"
			height += sc.check(target)
		}
		if height > maxHeight {
			maxHeight = height
		}
	}

	if !ldr.IsNoSplit(sym) {
		// Splittable functions start with a call to
		// morestack, after which their height is 0. Account
		// for the height of the call to morestack.
		addEdge(sc.callSize, sc.morestack)
		return maxHeight, edges
	}

	// This function is nosplit, so it adjusts SP without a split
	// check.
	//
	// Walk through SP adjustments in function, consuming relocs
	// and following calls.
	//
	// NOTE(review): relocations are consumed in index order and the
	// inner loop stops at the first one at or past the current span's
	// end. This presumes relocs are sorted by offset. It also means a
	// reloc past the final span's NextPC is never examined. Confirm
	// both hold for every nosplit function.
	maxLocalHeight := 0
	relocs, ri := ldr.Relocs(sym), 0
	pcsp := obj.NewPCIter(uint32(ctxt.Arch.MinLC))
	for pcsp.Init(ldr.Data(ldr.Pcsp(sym))); !pcsp.Done; pcsp.Next() {
		// pcsp.value is in effect for [pcsp.pc, pcsp.nextpc).
		height := int(pcsp.Value)
		if height > maxLocalHeight {
			maxLocalHeight = height
		}

		// Process calls in this span. ri persists across spans,
		// so each relocation is examined at most once.
		for ; ri < relocs.Count(); ri++ {
			r := relocs.At(ri)
			if uint32(r.Off()) >= pcsp.NextPC {
				break
			}
			t := r.Type()
			if t.IsDirectCall() || t == objabi.R_CALLIND {
				// The stack at the call is the current SP
				// offset plus what the CALL itself pushes.
				growth := height + sc.callSize
				var target loader.Sym
				if t == objabi.R_CALLIND {
					target = stackCheckIndirect
				} else {
					target = r.Sym()
				}
				addEdge(growth, target)
			}
		}
	}
	if maxLocalHeight > maxHeight {
		// This is either a leaf function, or the function
		// grew its stack to larger than the maximum call
		// height between calls. Either way, record that local
		// stack growth.
		addEdge(maxLocalHeight, 0)
	}

	return maxHeight, edges
}
283
284func (sc *stackCheck) findRoots() []loader.Sym {
285	// Collect all nodes.
286	nodes := make(map[loader.Sym]struct{})
287	for k := range sc.graph {
288		nodes[k] = struct{}{}
289	}
290
291	// Start a DFS from each node and delete all reachable
292	// children. If we encounter an unrooted cycle, this will
293	// delete everything in that cycle, so we detect this case and
294	// track the lowest-numbered node encountered in the cycle and
295	// put that node back as a root.
296	var walk func(origin, sym loader.Sym) (cycle bool, lowest loader.Sym)
297	walk = func(origin, sym loader.Sym) (cycle bool, lowest loader.Sym) {
298		if _, ok := nodes[sym]; !ok {
299			// We already deleted this node.
300			return false, 0
301		}
302		delete(nodes, sym)
303
304		if origin == sym {
305			// We found an unrooted cycle. We already
306			// deleted all children of this node. Walk
307			// back up, tracking the lowest numbered
308			// symbol in this cycle.
309			return true, sym
310		}
311
312		// Delete children of this node.
313		for _, out := range sc.graph[sym] {
314			if c, l := walk(origin, out.target); c {
315				cycle = true
316				if lowest == 0 {
317					// On first cycle detection,
318					// add sym to the set of
319					// lowest-numbered candidates.
320					lowest = sym
321				}
322				if l < lowest {
323					lowest = l
324				}
325			}
326		}
327		return
328	}
329	for k := range nodes {
330		// Delete all children of k.
331		for _, out := range sc.graph[k] {
332			if cycle, lowest := walk(k, out.target); cycle {
333				// This is an unrooted cycle so we
334				// just deleted everything. Put back
335				// the lowest-numbered symbol.
336				nodes[lowest] = struct{}{}
337			}
338		}
339	}
340
341	// Sort roots by height. This makes the result deterministic
342	// and also improves the error reporting.
343	var roots []loader.Sym
344	for k := range nodes {
345		roots = append(roots, k)
346	}
347	sort.Slice(roots, func(i, j int) bool {
348		h1, h2 := sc.height[roots[i]], sc.height[roots[j]]
349		if h1 != h2 {
350			return h1 > h2
351		}
352		// Secondary sort by Sym.
353		return roots[i] < roots[j]
354	})
355	return roots
356}
357
358type stackCheckChain struct {
359	stackCheckEdge
360	printed bool
361}
362
363func (sc *stackCheck) report(sym loader.Sym, depth int, chain *[]stackCheckChain) {
364	// Walk the out-edges of sym. We temporarily pull the edges
365	// out of the graph to detect cycles and prevent infinite
366	// recursion.
367	edges, ok := sc.graph[sym]
368	isCycle := !(ok || sym == 0)
369	delete(sc.graph, sym)
370	for _, out := range edges {
371		*chain = append(*chain, stackCheckChain{out, false})
372		sc.report(out.target, depth-out.growth, chain)
373		*chain = (*chain)[:len(*chain)-1]
374	}
375	sc.graph[sym] = edges
376
377	// If we've reached the end of a chain and it went over the
378	// stack limit or was a cycle that would eventually go over,
379	// print the whole chain.
380	//
381	// We should either be in morestack (which has no out-edges)
382	// or the sentinel 0 Sym "called" from a leaf function (which
383	// has no out-edges), or we came back around a cycle (possibly
384	// to ourselves) and edges was temporarily nil'd.
385	if len(edges) == 0 && (depth < 0 || isCycle) {
386		var indent string
387		for i := range *chain {
388			ent := &(*chain)[i]
389			if ent.printed {
390				// Already printed on an earlier part
391				// of this call tree.
392				continue
393			}
394			ent.printed = true
395
396			if i == 0 {
397				// chain[0] is just the root function,
398				// not a stack growth.
399				fmt.Printf("%s\n", sc.symName(ent.target))
400				continue
401			}
402
403			indent = strings.Repeat("    ", i)
404			fmt.Print(indent)
405			// Grows the stack X bytes and (maybe) calls Y.
406			fmt.Printf("grows %d bytes", ent.growth)
407			if ent.target == 0 {
408				// Not a call, just a leaf. Print nothing.
409			} else {
410				fmt.Printf(", calls %s", sc.symName(ent.target))
411			}
412			fmt.Printf("\n")
413		}
414		// Print how far over this chain went.
415		if isCycle {
416			fmt.Printf("%sinfinite cycle\n", indent)
417		} else {
418			fmt.Printf("%s%d bytes over limit\n", indent, -depth)
419		}
420	}
421}
422