1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package inlheur
6
7import (
8	"cmd/compile/internal/ir"
9	"fmt"
10	"os"
11)
12
// paramsAnalyzer holds state information for the phase that computes
// flags for a Go function's parameters, for use in inline heuristics.
// Note that the params slice below includes entries for blanks.
type paramsAnalyzer struct {
	// fname is the name of the function being analyzed.
	fname string
	// values holds the property flags computed so far, one entry per
	// element of params.
	values []ParamPropBits
	// params is the parameter list of the function (receiver included,
	// if any), as returned by getParams. Entries may be nil.
	params []*ir.Name
	// top[i] is true while params[i] is a candidate for analysis and
	// no flag has been recorded for it yet; it is cleared as soon as a
	// flag is or'd into values[i].
	top []bool
	*condLevelTracker
	*nameFinder
}
24
25// getParams returns an *ir.Name slice containing all params for the
26// function (plus rcvr as well if applicable).
27func getParams(fn *ir.Func) []*ir.Name {
28	sig := fn.Type()
29	numParams := sig.NumRecvs() + sig.NumParams()
30	return fn.Dcl[:numParams]
31}
32
33// addParamsAnalyzer creates a new paramsAnalyzer helper object for
34// the function fn, appends it to the analyzers list, and returns the
35// new list. If the function in question doesn't have any interesting
36// parameters then the analyzer list is returned unchanged, and the
37// params flags in "fp" are updated accordingly.
38func addParamsAnalyzer(fn *ir.Func, analyzers []propAnalyzer, fp *FuncProps, nf *nameFinder) []propAnalyzer {
39	pa, props := makeParamsAnalyzer(fn, nf)
40	if pa != nil {
41		analyzers = append(analyzers, pa)
42	} else {
43		fp.ParamFlags = props
44	}
45	return analyzers
46}
47
48// makeParamsAnalyzer creates a new helper object to analyze parameters
49// of function fn. If the function doesn't have any interesting
50// params, a nil helper is returned along with a set of default param
51// flags for the func.
52func makeParamsAnalyzer(fn *ir.Func, nf *nameFinder) (*paramsAnalyzer, []ParamPropBits) {
53	params := getParams(fn) // includes receiver if applicable
54	if len(params) == 0 {
55		return nil, nil
56	}
57	vals := make([]ParamPropBits, len(params))
58	if fn.Inl == nil {
59		return nil, vals
60	}
61	top := make([]bool, len(params))
62	interestingToAnalyze := false
63	for i, pn := range params {
64		if pn == nil {
65			continue
66		}
67		pt := pn.Type()
68		if !pt.IsScalar() && !pt.HasNil() {
69			// existing properties not applicable here (for things
70			// like structs, arrays, slices, etc).
71			continue
72		}
73		// If param is reassigned, skip it.
74		if ir.Reassigned(pn) {
75			continue
76		}
77		top[i] = true
78		interestingToAnalyze = true
79	}
80	if !interestingToAnalyze {
81		return nil, vals
82	}
83
84	if debugTrace&debugTraceParams != 0 {
85		fmt.Fprintf(os.Stderr, "=-= param analysis of func %v:\n",
86			fn.Sym().Name)
87		for i := range vals {
88			n := "_"
89			if params[i] != nil {
90				n = params[i].Sym().String()
91			}
92			fmt.Fprintf(os.Stderr, "=-=  %d: %q %s top=%v\n",
93				i, n, vals[i].String(), top[i])
94		}
95	}
96	pa := &paramsAnalyzer{
97		fname:            fn.Sym().Name,
98		values:           vals,
99		params:           params,
100		top:              top,
101		condLevelTracker: new(condLevelTracker),
102		nameFinder:       nf,
103	}
104	return pa, nil
105}
106
// setResults publishes the parameter flags computed by this analyzer
// by storing them in funcProps.ParamFlags. The slice is stored
// directly, not copied.
func (pa *paramsAnalyzer) setResults(funcProps *FuncProps) {
	funcProps.ParamFlags = pa.values
}
110
111func (pa *paramsAnalyzer) findParamIdx(n *ir.Name) int {
112	if n == nil {
113		panic("bad")
114	}
115	for i := range pa.params {
116		if pa.params[i] == n {
117			return i
118		}
119	}
120	return -1
121}
122
// testfType is the signature of the predicate handed to checkParams.
// It receives the expression x under test, a candidate parameter, and
// that parameter's index. The first result reports whether the test
// succeeded for the parameter; the second, when true, makes
// checkParams record the "may" variant of the flag instead of the
// definite one.
type testfType func(x ir.Node, param *ir.Name, idx int) (bool, bool)
124
125// paramsAnalyzer invokes function 'testf' on the specified expression
126// 'x' for each parameter, and if the result is TRUE, or's 'flag' into
127// the flags for that param.
128func (pa *paramsAnalyzer) checkParams(x ir.Node, flag ParamPropBits, mayflag ParamPropBits, testf testfType) {
129	for idx, p := range pa.params {
130		if !pa.top[idx] && pa.values[idx] == ParamNoInfo {
131			continue
132		}
133		result, may := testf(x, p, idx)
134		if debugTrace&debugTraceParams != 0 {
135			fmt.Fprintf(os.Stderr, "=-= test expr %v param %s result=%v flag=%s\n", x, p.Sym().Name, result, flag.String())
136		}
137		if result {
138			v := flag
139			if pa.condLevel != 0 || may {
140				v = mayflag
141			}
142			pa.values[idx] |= v
143			pa.top[idx] = false
144		}
145	}
146}
147
148// foldCheckParams checks expression 'x' (an 'if' condition or
149// 'switch' stmt expr) to see if the expr would fold away if a
150// specific parameter had a constant value.
151func (pa *paramsAnalyzer) foldCheckParams(x ir.Node) {
152	pa.checkParams(x, ParamFeedsIfOrSwitch, ParamMayFeedIfOrSwitch,
153		func(x ir.Node, p *ir.Name, idx int) (bool, bool) {
154			return ShouldFoldIfNameConstant(x, []*ir.Name{p}), false
155		})
156}
157
158// callCheckParams examines the target of call expression 'ce' to see
159// if it is making a call to the value passed in for some parameter.
160func (pa *paramsAnalyzer) callCheckParams(ce *ir.CallExpr) {
161	switch ce.Op() {
162	case ir.OCALLINTER:
163		if ce.Op() != ir.OCALLINTER {
164			return
165		}
166		sel := ce.Fun.(*ir.SelectorExpr)
167		r := pa.staticValue(sel.X)
168		if r.Op() != ir.ONAME {
169			return
170		}
171		name := r.(*ir.Name)
172		if name.Class != ir.PPARAM {
173			return
174		}
175		pa.checkParams(r, ParamFeedsInterfaceMethodCall,
176			ParamMayFeedInterfaceMethodCall,
177			func(x ir.Node, p *ir.Name, idx int) (bool, bool) {
178				name := x.(*ir.Name)
179				return name == p, false
180			})
181	case ir.OCALLFUNC:
182		if ce.Fun.Op() != ir.ONAME {
183			return
184		}
185		called := ir.StaticValue(ce.Fun)
186		if called.Op() != ir.ONAME {
187			return
188		}
189		name := called.(*ir.Name)
190		if name.Class == ir.PPARAM {
191			pa.checkParams(called, ParamFeedsIndirectCall,
192				ParamMayFeedIndirectCall,
193				func(x ir.Node, p *ir.Name, idx int) (bool, bool) {
194					name := x.(*ir.Name)
195					return name == p, false
196				})
197		} else {
198			cname := pa.funcName(called)
199			if cname != nil {
200				pa.deriveFlagsFromCallee(ce, cname.Func)
201			}
202		}
203	}
204}
205
206// deriveFlagsFromCallee tries to derive flags for the current
207// function based on a call this function makes to some other
208// function. Example:
209//
210//	/* Simple */                /* Derived from callee */
211//	func foo(f func(int)) {     func foo(f func(int)) {
212//	  f(2)                        bar(32, f)
213//	}                           }
214//	                            func bar(x int, f func()) {
215//	                              f(x)
216//	                            }
217//
218// Here we can set the "param feeds indirect call" flag for
219// foo's param 'f' since we know that bar has that flag set for
220// its second param, and we're passing that param a function.
221func (pa *paramsAnalyzer) deriveFlagsFromCallee(ce *ir.CallExpr, callee *ir.Func) {
222	calleeProps := propsForFunc(callee)
223	if calleeProps == nil {
224		return
225	}
226	if debugTrace&debugTraceParams != 0 {
227		fmt.Fprintf(os.Stderr, "=-= callee props for %v:\n%s",
228			callee.Sym().Name, calleeProps.String())
229	}
230
231	must := []ParamPropBits{ParamFeedsInterfaceMethodCall, ParamFeedsIndirectCall, ParamFeedsIfOrSwitch}
232	may := []ParamPropBits{ParamMayFeedInterfaceMethodCall, ParamMayFeedIndirectCall, ParamMayFeedIfOrSwitch}
233
234	for pidx, arg := range ce.Args {
235		// Does the callee param have any interesting properties?
236		// If not we can skip this one.
237		pflag := calleeProps.ParamFlags[pidx]
238		if pflag == 0 {
239			continue
240		}
241		// See if one of the caller's parameters is flowing unmodified
242		// into this actual expression.
243		r := pa.staticValue(arg)
244		if r.Op() != ir.ONAME {
245			return
246		}
247		name := r.(*ir.Name)
248		if name.Class != ir.PPARAM {
249			return
250		}
251		callerParamIdx := pa.findParamIdx(name)
252		// note that callerParamIdx may return -1 in the case where
253		// the param belongs not to the current closure func we're
254		// analyzing but to an outer enclosing func.
255		if callerParamIdx == -1 {
256			return
257		}
258		if pa.params[callerParamIdx] == nil {
259			panic("something went wrong")
260		}
261		if !pa.top[callerParamIdx] &&
262			pa.values[callerParamIdx] == ParamNoInfo {
263			continue
264		}
265		if debugTrace&debugTraceParams != 0 {
266			fmt.Fprintf(os.Stderr, "=-= pflag for arg %d is %s\n",
267				pidx, pflag.String())
268		}
269		for i := range must {
270			mayv := may[i]
271			mustv := must[i]
272			if pflag&mustv != 0 && pa.condLevel == 0 {
273				pa.values[callerParamIdx] |= mustv
274			} else if pflag&(mustv|mayv) != 0 {
275				pa.values[callerParamIdx] |= mayv
276			}
277		}
278		pa.top[callerParamIdx] = false
279	}
280}
281
282func (pa *paramsAnalyzer) nodeVisitPost(n ir.Node) {
283	if len(pa.values) == 0 {
284		return
285	}
286	pa.condLevelTracker.post(n)
287	switch n.Op() {
288	case ir.OCALLFUNC:
289		ce := n.(*ir.CallExpr)
290		pa.callCheckParams(ce)
291	case ir.OCALLINTER:
292		ce := n.(*ir.CallExpr)
293		pa.callCheckParams(ce)
294	case ir.OIF:
295		ifst := n.(*ir.IfStmt)
296		pa.foldCheckParams(ifst.Cond)
297	case ir.OSWITCH:
298		swst := n.(*ir.SwitchStmt)
299		if swst.Tag != nil {
300			pa.foldCheckParams(swst.Tag)
301		}
302	}
303}
304
305func (pa *paramsAnalyzer) nodeVisitPre(n ir.Node) {
306	if len(pa.values) == 0 {
307		return
308	}
309	pa.condLevelTracker.pre(n)
310}
311
// condLevelTracker helps keep track very roughly of "level of conditional
// nesting", e.g. how many "if" statements you have to go through to
// get to the point where a given stmt executes. Example:
//
//	                      cond nesting level
//	func foo() {
//	 G = 1                   0
//	 if x < 10 {             0
//	  if y < 10 {            1
//	   G = 0                 2
//	  }
//	 }
//	}
//
// The intent here is to provide some sort of very abstract relative
// hotness metric, e.g. "G = 1" above is expected to be executed more
// often than "G = 0" (in the aggregate, across large numbers of
// functions).
type condLevelTracker struct {
	// condLevel is the current nesting estimate. The pre hook raises
	// it on entry to an if/switch and lowers it on entry to a
	// for/range; the post hook undoes those adjustments.
	condLevel int
}
333
334func (c *condLevelTracker) pre(n ir.Node) {
335	// Increment level of "conditional testing" if we see
336	// an "if" or switch statement, and decrement if in
337	// a loop.
338	switch n.Op() {
339	case ir.OIF, ir.OSWITCH:
340		c.condLevel++
341	case ir.OFOR, ir.ORANGE:
342		c.condLevel--
343	}
344}
345
346func (c *condLevelTracker) post(n ir.Node) {
347	switch n.Op() {
348	case ir.OFOR, ir.ORANGE:
349		c.condLevel++
350	case ir.OIF:
351		c.condLevel--
352	case ir.OSWITCH:
353		c.condLevel--
354	}
355}
356