1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package loopvar applies the proper variable capture, according
6// to experiment, flags, language version, etc.
7package loopvar
8
9import (
10	"cmd/compile/internal/base"
11	"cmd/compile/internal/ir"
12	"cmd/compile/internal/logopt"
13	"cmd/compile/internal/typecheck"
14	"cmd/compile/internal/types"
15	"cmd/internal/src"
16	"fmt"
17)
18
19type VarAndLoop struct {
20	Name    *ir.Name
21	Loop    ir.Node  // the *ir.RangeStmt or *ir.ForStmt. Used for identity and position
22	LastPos src.XPos // the last position observed within Loop
23}
24
25// ForCapture transforms for and range loops that declare variables that might be
26// captured by a closure or escaped to the heap, using a syntactic check that
27// conservatively overestimates the loops where capture occurs, but still avoids
28// transforming the (large) majority of loops. It returns the list of names
29// subject to this change, that may (once transformed) be heap allocated in the
30// process. (This allows checking after escape analysis to call out any such
31// variables, in case it causes allocation/performance problems).
32//
33// The decision to transform loops is normally encoded in the For/Range loop node
34// field DistinctVars but is also dependent on base.LoopVarHash, and some values
35// of base.Debug.LoopVar (which is set per-package).  Decisions encoded in DistinctVars
36// are preserved across inlining, so if package a calls b.F and loops in b.F are
37// transformed, then they are always transformed, whether b.F is inlined or not.
38//
39// Per-package, the debug flag settings that affect this transformer:
40//
41// base.LoopVarHash != nil => use hash setting to govern transformation.
42// note that LoopVarHash != nil sets base.Debug.LoopVar to 1 (unless it is >= 11, for testing/debugging).
43//
44// base.Debug.LoopVar == 11 => transform ALL loops ignoring syntactic/potential escape. Do not log, can be in addition to GOEXPERIMENT.
45//
46// The effect of GOEXPERIMENT=loopvar is to change the default value (0) of base.Debug.LoopVar to 1 for all packages.
47func ForCapture(fn *ir.Func) []VarAndLoop {
48	// if a loop variable is transformed it is appended to this slice for later logging
49	var transformed []VarAndLoop
50
51	describe := func(n *ir.Name) string {
52		pos := n.Pos()
53		inner := base.Ctxt.InnermostPos(pos)
54		outer := base.Ctxt.OutermostPos(pos)
55		if inner == outer {
56			return fmt.Sprintf("loop variable %v now per-iteration", n)
57		}
58		return fmt.Sprintf("loop variable %v now per-iteration (loop inlined into %s:%d)", n, outer.Filename(), outer.Line())
59	}
60
61	forCapture := func() {
62		seq := 1
63
64		dclFixups := make(map[*ir.Name]ir.Stmt)
65
66		// possibly leaked includes names of declared loop variables that may be leaked;
67		// the mapped value is true if the name is *syntactically* leaked, and those loops
68		// will be transformed.
69		possiblyLeaked := make(map[*ir.Name]bool)
70
71		// these enable an optimization of "escape" under return statements
72		loopDepth := 0
73		returnInLoopDepth := 0
74
75		// noteMayLeak is called for candidate variables in for range/3-clause, and
76		// adds them (mapped to false) to possiblyLeaked.
77		noteMayLeak := func(x ir.Node) {
78			if n, ok := x.(*ir.Name); ok {
79				if n.Type().Kind() == types.TBLANK {
80					return
81				}
82				// default is false (leak candidate, not yet known to leak), but flag can make all variables "leak"
83				possiblyLeaked[n] = base.Debug.LoopVar >= 11
84			}
85		}
86
87		// For reporting, keep track of the last position within any loop.
88		// Loops nest, also need to be sensitive to inlining.
89		var lastPos src.XPos
90
91		updateLastPos := func(p src.XPos) {
92			pl, ll := p.Line(), lastPos.Line()
93			if p.SameFile(lastPos) &&
94				(pl > ll || pl == ll && p.Col() > lastPos.Col()) {
95				lastPos = p
96			}
97		}
98
99		// maybeReplaceVar unshares an iteration variable for a range loop,
100		// if that variable was actually (syntactically) leaked,
101		// subject to hash-variable debugging.
102		maybeReplaceVar := func(k ir.Node, x *ir.RangeStmt) ir.Node {
103			if n, ok := k.(*ir.Name); ok && possiblyLeaked[n] {
104				desc := func() string {
105					return describe(n)
106				}
107				if base.LoopVarHash.MatchPos(n.Pos(), desc) {
108					// Rename the loop key, prefix body with assignment from loop key
109					transformed = append(transformed, VarAndLoop{n, x, lastPos})
110					tk := typecheck.TempAt(base.Pos, fn, n.Type())
111					tk.SetTypecheck(1)
112					as := ir.NewAssignStmt(x.Pos(), n, tk)
113					as.Def = true
114					as.SetTypecheck(1)
115					x.Body.Prepend(as)
116					dclFixups[n] = as
117					return tk
118				}
119			}
120			return k
121		}
122
123		// scanChildrenThenTransform processes node x to:
124		//  1. if x is a for/range w/ DistinctVars, note declared iteration variables possiblyLeaked (PL)
125		//  2. search all of x's children for syntactically escaping references to v in PL,
126		//     meaning either address-of-v or v-captured-by-a-closure
127		//  3. for all v in PL that had a syntactically escaping reference, transform the declaration
128		//     and (in case of 3-clause loop) the loop to the unshared loop semantics.
129		//  This is all much simpler for range loops; 3-clause loops can have an arbitrary number
130		//  of iteration variables and the transformation is more involved, range loops have at most 2.
131		var scanChildrenThenTransform func(x ir.Node) bool
132		scanChildrenThenTransform = func(n ir.Node) bool {
133
134			if loopDepth > 0 {
135				updateLastPos(n.Pos())
136			}
137
138			switch x := n.(type) {
139			case *ir.ClosureExpr:
140				if returnInLoopDepth >= loopDepth {
141					// This expression is a child of a return, which escapes all loops above
142					// the return, but not those between this expression and the return.
143					break
144				}
145				for _, cv := range x.Func.ClosureVars {
146					v := cv.Canonical()
147					if _, ok := possiblyLeaked[v]; ok {
148						possiblyLeaked[v] = true
149					}
150				}
151
152			case *ir.AddrExpr:
153				if returnInLoopDepth >= loopDepth {
154					// This expression is a child of a return, which escapes all loops above
155					// the return, but not those between this expression and the return.
156					break
157				}
158				// Explicitly note address-taken so that return-statements can be excluded
159				y := ir.OuterValue(x.X)
160				if y.Op() != ir.ONAME {
161					break
162				}
163				z, ok := y.(*ir.Name)
164				if !ok {
165					break
166				}
167				switch z.Class {
168				case ir.PAUTO, ir.PPARAM, ir.PPARAMOUT, ir.PAUTOHEAP:
169					if _, ok := possiblyLeaked[z]; ok {
170						possiblyLeaked[z] = true
171					}
172				}
173
174			case *ir.ReturnStmt:
175				savedRILD := returnInLoopDepth
176				returnInLoopDepth = loopDepth
177				defer func() { returnInLoopDepth = savedRILD }()
178
179			case *ir.RangeStmt:
180				if !(x.Def && x.DistinctVars) {
181					// range loop must define its iteration variables AND have distinctVars.
182					x.DistinctVars = false
183					break
184				}
185				noteMayLeak(x.Key)
186				noteMayLeak(x.Value)
187				loopDepth++
188				savedLastPos := lastPos
189				lastPos = x.Pos() // this sets the file.
190				ir.DoChildren(n, scanChildrenThenTransform)
191				loopDepth--
192				x.Key = maybeReplaceVar(x.Key, x)
193				x.Value = maybeReplaceVar(x.Value, x)
194				thisLastPos := lastPos
195				lastPos = savedLastPos
196				updateLastPos(thisLastPos) // this will propagate lastPos if in the same file.
197				x.DistinctVars = false
198				return false
199
200			case *ir.ForStmt:
201				if !x.DistinctVars {
202					break
203				}
204				forAllDefInInit(x, noteMayLeak)
205				loopDepth++
206				savedLastPos := lastPos
207				lastPos = x.Pos() // this sets the file.
208				ir.DoChildren(n, scanChildrenThenTransform)
209				loopDepth--
210				var leaked []*ir.Name
211				// Collect the leaking variables for the much-more-complex transformation.
212				forAllDefInInit(x, func(z ir.Node) {
213					if n, ok := z.(*ir.Name); ok && possiblyLeaked[n] {
214						desc := func() string {
215							return describe(n)
216						}
217						// Hash on n.Pos() for most precise failure location.
218						if base.LoopVarHash.MatchPos(n.Pos(), desc) {
219							leaked = append(leaked, n)
220						}
221					}
222				})
223
224				if len(leaked) > 0 {
225					// need to transform the for loop just so.
226
227					/* Contrived example, w/ numbered comments from the transformation:
228									BEFORE:
229										var escape []*int
230										for z := 0; z < n; z++ {
231											if reason() {
232												escape = append(escape, &z)
233												continue
234											}
235											z = z + z
236											stuff
237										}
238									AFTER:
239										for z', tmp_first := 0, true; ; { // (4)
240											                              // (5) body' follows:
241											z := z'                       // (1)
242											if tmp_first {tmp_first = false} else {z++} // (6)
243											if ! (z < n) { break }        // (7)
244											                              // (3, 8) body_continue
245											if reason() {
246					                            escape = append(escape, &z)
247												goto next                 // rewritten continue
248											}
249											z = z + z
250											stuff
251										next:                             // (9)
252											z' = z                       // (2)
253										}
254
255										In the case that the loop contains no increment (z++),
256										there is no need for step 6,
257										and thus no need to test, update, or declare tmp_first (part of step 4).
258										Similarly if the loop contains no exit test (z < n),
259										then there is no need for step 7.
260					*/
261
262					// Expressed in terms of the input ForStmt
263					//
264					// 	type ForStmt struct {
265					// 	init     Nodes
266					// 	Label    *types.Sym
267					// 	Cond     Node  // empty if OFORUNTIL
268					// 	Post     Node
269					// 	Body     Nodes
270					// 	HasBreak bool
271					// }
272
273					// OFOR: init; loop: if !Cond {break}; Body; Post; goto loop
274
275					// (1) prebody = {z := z' for z in leaked}
276					// (2) postbody = {z' = z for z in leaked}
277					// (3) body_continue = {body : s/continue/goto next}
278					// (4) init' = (init : s/z/z' for z in leaked) + tmp_first := true
279					// (5) body' = prebody +        // appears out of order below
280					// (6)         if tmp_first {tmp_first = false} else {Post} +
281					// (7)         if !cond {break} +
282					// (8)         body_continue (3) +
283					// (9)         next: postbody (2)
284					// (10) cond' = {}
285					// (11) post' = {}
286
287					// minor optimizations:
288					//   if Post is empty, tmp_first and step 6 can be skipped.
289					//   if Cond is empty, that code can also be skipped.
290
291					var preBody, postBody ir.Nodes
292
293					// Given original iteration variable z, what is the corresponding z'
294					// that carries the value from iteration to iteration?
295					zPrimeForZ := make(map[*ir.Name]*ir.Name)
296
297					// (1,2) initialize preBody and postBody
298					for _, z := range leaked {
299						transformed = append(transformed, VarAndLoop{z, x, lastPos})
300
301						tz := typecheck.TempAt(base.Pos, fn, z.Type())
302						tz.SetTypecheck(1)
303						zPrimeForZ[z] = tz
304
305						as := ir.NewAssignStmt(x.Pos(), z, tz)
306						as.Def = true
307						as.SetTypecheck(1)
308						preBody.Append(as)
309						dclFixups[z] = as
310
311						as = ir.NewAssignStmt(x.Pos(), tz, z)
312						as.SetTypecheck(1)
313						postBody.Append(as)
314
315					}
316
317					// (3) rewrite continues in body -- rewrite is inplace, so works for top level visit, too.
318					label := typecheck.Lookup(fmt.Sprintf(".3clNext_%d", seq))
319					seq++
320					labelStmt := ir.NewLabelStmt(x.Pos(), label)
321					labelStmt.SetTypecheck(1)
322
323					loopLabel := x.Label
324					loopDepth := 0
325					var editContinues func(x ir.Node) bool
326					editContinues = func(x ir.Node) bool {
327
328						switch c := x.(type) {
329						case *ir.BranchStmt:
330							// If this is a continue targeting the loop currently being rewritten, transform it to an appropriate GOTO
331							if c.Op() == ir.OCONTINUE && (loopDepth == 0 && c.Label == nil || loopLabel != nil && c.Label == loopLabel) {
332								c.Label = label
333								c.SetOp(ir.OGOTO)
334							}
335						case *ir.RangeStmt, *ir.ForStmt:
336							loopDepth++
337							ir.DoChildren(x, editContinues)
338							loopDepth--
339							return false
340						}
341						ir.DoChildren(x, editContinues)
342						return false
343					}
344					for _, y := range x.Body {
345						editContinues(y)
346					}
347					bodyContinue := x.Body
348
349					// (4) rewrite init
350					forAllDefInInitUpdate(x, func(z ir.Node, pz *ir.Node) {
351						// note tempFor[n] can be nil if hash searching.
352						if n, ok := z.(*ir.Name); ok && possiblyLeaked[n] && zPrimeForZ[n] != nil {
353							*pz = zPrimeForZ[n]
354						}
355					})
356
357					postNotNil := x.Post != nil
358					var tmpFirstDcl ir.Node
359					if postNotNil {
360						// body' = prebody +
361						// (6)     if tmp_first {tmp_first = false} else {Post} +
362						//         if !cond {break} + ...
363						tmpFirst := typecheck.TempAt(base.Pos, fn, types.Types[types.TBOOL])
364						tmpFirstDcl = typecheck.Stmt(ir.NewAssignStmt(x.Pos(), tmpFirst, ir.NewBool(base.Pos, true)))
365						tmpFirstSetFalse := typecheck.Stmt(ir.NewAssignStmt(x.Pos(), tmpFirst, ir.NewBool(base.Pos, false)))
366						ifTmpFirst := ir.NewIfStmt(x.Pos(), tmpFirst, ir.Nodes{tmpFirstSetFalse}, ir.Nodes{x.Post})
367						ifTmpFirst.PtrInit().Append(typecheck.Stmt(ir.NewDecl(base.Pos, ir.ODCL, tmpFirst))) // declares tmpFirst
368						preBody.Append(typecheck.Stmt(ifTmpFirst))
369					}
370
371					// body' = prebody +
372					//         if tmp_first {tmp_first = false} else {Post} +
373					// (7)     if !cond {break} + ...
374					if x.Cond != nil {
375						notCond := ir.NewUnaryExpr(x.Cond.Pos(), ir.ONOT, x.Cond)
376						notCond.SetType(x.Cond.Type())
377						notCond.SetTypecheck(1)
378						newBreak := ir.NewBranchStmt(x.Pos(), ir.OBREAK, nil)
379						newBreak.SetTypecheck(1)
380						ifNotCond := ir.NewIfStmt(x.Pos(), notCond, ir.Nodes{newBreak}, nil)
381						ifNotCond.SetTypecheck(1)
382						preBody.Append(ifNotCond)
383					}
384
385					if postNotNil {
386						x.PtrInit().Append(tmpFirstDcl)
387					}
388
389					// (8)
390					preBody.Append(bodyContinue...)
391					// (9)
392					preBody.Append(labelStmt)
393					preBody.Append(postBody...)
394
395					// (5) body' = prebody + ...
396					x.Body = preBody
397
398					// (10) cond' = {}
399					x.Cond = nil
400
401					// (11) post' = {}
402					x.Post = nil
403				}
404				thisLastPos := lastPos
405				lastPos = savedLastPos
406				updateLastPos(thisLastPos) // this will propagate lastPos if in the same file.
407				x.DistinctVars = false
408
409				return false
410			}
411
412			ir.DoChildren(n, scanChildrenThenTransform)
413
414			return false
415		}
416		scanChildrenThenTransform(fn)
417		if len(transformed) > 0 {
418			// editNodes scans a slice C of ir.Node, looking for declarations that
419			// appear in dclFixups.  Any declaration D whose "fixup" is an assignmnt
420			// statement A is removed from the C and relocated to the Init
421			// of A.  editNodes returns the modified slice of ir.Node.
422			editNodes := func(c ir.Nodes) ir.Nodes {
423				j := 0
424				for _, n := range c {
425					if d, ok := n.(*ir.Decl); ok {
426						if s := dclFixups[d.X]; s != nil {
427							switch a := s.(type) {
428							case *ir.AssignStmt:
429								a.PtrInit().Prepend(d)
430								delete(dclFixups, d.X) // can't be sure of visit order, wouldn't want to visit twice.
431							default:
432								base.Fatalf("not implemented yet for node type %v", s.Op())
433							}
434							continue // do not copy this node, and do not increment j
435						}
436					}
437					c[j] = n
438					j++
439				}
440				for k := j; k < len(c); k++ {
441					c[k] = nil
442				}
443				return c[:j]
444			}
445			// fixup all tagged declarations in all the statements lists in fn.
446			rewriteNodes(fn, editNodes)
447		}
448	}
449	ir.WithFunc(fn, forCapture)
450	return transformed
451}
452
453// forAllDefInInitUpdate applies "do" to all the defining assignments in the Init clause of a ForStmt.
454// This abstracts away some of the boilerplate from the already complex and verbose for-3-clause case.
455func forAllDefInInitUpdate(x *ir.ForStmt, do func(z ir.Node, update *ir.Node)) {
456	for _, s := range x.Init() {
457		switch y := s.(type) {
458		case *ir.AssignListStmt:
459			if !y.Def {
460				continue
461			}
462			for i, z := range y.Lhs {
463				do(z, &y.Lhs[i])
464			}
465		case *ir.AssignStmt:
466			if !y.Def {
467				continue
468			}
469			do(y.X, &y.X)
470		}
471	}
472}
473
474// forAllDefInInit is forAllDefInInitUpdate without the update option.
475func forAllDefInInit(x *ir.ForStmt, do func(z ir.Node)) {
476	forAllDefInInitUpdate(x, func(z ir.Node, _ *ir.Node) { do(z) })
477}
478
479// rewriteNodes applies editNodes to all statement lists in fn.
480func rewriteNodes(fn *ir.Func, editNodes func(c ir.Nodes) ir.Nodes) {
481	var forNodes func(x ir.Node) bool
482	forNodes = func(n ir.Node) bool {
483		if stmt, ok := n.(ir.InitNode); ok {
484			// process init list
485			stmt.SetInit(editNodes(stmt.Init()))
486		}
487		switch x := n.(type) {
488		case *ir.Func:
489			x.Body = editNodes(x.Body)
490		case *ir.InlinedCallExpr:
491			x.Body = editNodes(x.Body)
492
493		case *ir.CaseClause:
494			x.Body = editNodes(x.Body)
495		case *ir.CommClause:
496			x.Body = editNodes(x.Body)
497
498		case *ir.BlockStmt:
499			x.List = editNodes(x.List)
500
501		case *ir.ForStmt:
502			x.Body = editNodes(x.Body)
503		case *ir.RangeStmt:
504			x.Body = editNodes(x.Body)
505		case *ir.IfStmt:
506			x.Body = editNodes(x.Body)
507			x.Else = editNodes(x.Else)
508		case *ir.SelectStmt:
509			x.Compiled = editNodes(x.Compiled)
510		case *ir.SwitchStmt:
511			x.Compiled = editNodes(x.Compiled)
512		}
513		ir.DoChildren(n, forNodes)
514		return false
515	}
516	forNodes(fn)
517}
518
519func LogTransformations(transformed []VarAndLoop) {
520	print := 2 <= base.Debug.LoopVar && base.Debug.LoopVar != 11
521
522	if print || logopt.Enabled() { // 11 is do them all, quietly, 12 includes debugging.
523		fileToPosBase := make(map[string]*src.PosBase) // used to remove inline context for innermost reporting.
524
525		// trueInlinedPos rebases inner w/o inline context so that it prints correctly in WarnfAt; otherwise it prints as outer.
526		trueInlinedPos := func(inner src.Pos) src.XPos {
527			afn := inner.AbsFilename()
528			pb, ok := fileToPosBase[afn]
529			if !ok {
530				pb = src.NewFileBase(inner.Filename(), afn)
531				fileToPosBase[afn] = pb
532			}
533			inner.SetBase(pb)
534			return base.Ctxt.PosTable.XPos(inner)
535		}
536
537		type unit struct{}
538		loopsSeen := make(map[ir.Node]unit)
539		type loopPos struct {
540			loop  ir.Node
541			last  src.XPos
542			curfn *ir.Func
543		}
544		var loops []loopPos
545		for _, lv := range transformed {
546			n := lv.Name
547			if _, ok := loopsSeen[lv.Loop]; !ok {
548				l := lv.Loop
549				loopsSeen[l] = unit{}
550				loops = append(loops, loopPos{l, lv.LastPos, n.Curfn})
551			}
552			pos := n.Pos()
553
554			inner := base.Ctxt.InnermostPos(pos)
555			outer := base.Ctxt.OutermostPos(pos)
556
557			if logopt.Enabled() {
558				// For automated checking of coverage of this transformation, include this in the JSON information.
559				var nString interface{} = n
560				if inner != outer {
561					nString = fmt.Sprintf("%v (from inline)", n)
562				}
563				if n.Esc() == ir.EscHeap {
564					logopt.LogOpt(pos, "iteration-variable-to-heap", "loopvar", ir.FuncName(n.Curfn), nString)
565				} else {
566					logopt.LogOpt(pos, "iteration-variable-to-stack", "loopvar", ir.FuncName(n.Curfn), nString)
567				}
568			}
569			if print {
570				if inner == outer {
571					if n.Esc() == ir.EscHeap {
572						base.WarnfAt(pos, "loop variable %v now per-iteration, heap-allocated", n)
573					} else {
574						base.WarnfAt(pos, "loop variable %v now per-iteration, stack-allocated", n)
575					}
576				} else {
577					innerXPos := trueInlinedPos(inner)
578					if n.Esc() == ir.EscHeap {
579						base.WarnfAt(innerXPos, "loop variable %v now per-iteration, heap-allocated (loop inlined into %s:%d)", n, outer.Filename(), outer.Line())
580					} else {
581						base.WarnfAt(innerXPos, "loop variable %v now per-iteration, stack-allocated (loop inlined into %s:%d)", n, outer.Filename(), outer.Line())
582					}
583				}
584			}
585		}
586		for _, l := range loops {
587			pos := l.loop.Pos()
588			last := l.last
589			loopKind := "range"
590			if _, ok := l.loop.(*ir.ForStmt); ok {
591				loopKind = "for"
592			}
593			if logopt.Enabled() {
594				// Intended to help with performance debugging, we record whole loop ranges
595				logopt.LogOptRange(pos, last, "loop-modified-"+loopKind, "loopvar", ir.FuncName(l.curfn))
596			}
597			if print && 4 <= base.Debug.LoopVar {
598				// TODO decide if we want to keep this, or not.  It was helpful for validating logopt, otherwise, eh.
599				inner := base.Ctxt.InnermostPos(pos)
600				outer := base.Ctxt.OutermostPos(pos)
601
602				if inner == outer {
603					base.WarnfAt(pos, "%s loop ending at %d:%d was modified", loopKind, last.Line(), last.Col())
604				} else {
605					pos = trueInlinedPos(inner)
606					last = trueInlinedPos(base.Ctxt.InnermostPos(last))
607					base.WarnfAt(pos, "%s loop ending at %d:%d was modified (loop inlined into %s:%d)", loopKind, last.Line(), last.Col(), outer.Filename(), outer.Line())
608				}
609			}
610		}
611	}
612}
613