// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package walk

import (
	"encoding/binary"
	"fmt"
	"go/constant"
	"hash/fnv"
	"io"

	"cmd/compile/internal/base"
	"cmd/compile/internal/compare"
	"cmd/compile/internal/ir"
	"cmd/compile/internal/reflectdata"
	"cmd/compile/internal/ssagen"
	"cmd/compile/internal/typecheck"
	"cmd/compile/internal/types"
)

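// fakePC computes a fake program counter for n, for use as the PC argument
// to the libfuzzer trace calls below. The value is a hash of the node's
// position and textual form, so the IDs are deterministic across builds
// rather than tied to real code addresses.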
func fakePC(n ir.Node) ir.Node {
	// In order to get deterministic IDs, we include the package path, absolute filename,
	// line number, and column number in the calculation of the fakePC for the IR node.
	hash := fnv.New32()
	// We ignore the errors here because the `io.Writer` in the `hash.Hash` interface never returns an error.
	io.WriteString(hash, base.Ctxt.Pkgpath)
	io.WriteString(hash, base.Ctxt.PosTable.Pos(n.Pos()).AbsFilename())
	binary.Write(hash, binary.LittleEndian, int64(n.Pos().Line()))
	binary.Write(hash, binary.LittleEndian, int64(n.Pos().Col()))
	// We also include the string representation of the node to distinguish autogenerated
	// expressions, since those get the same `src.XPos`.
	io.WriteString(hash, fmt.Sprintf("%v", n))

	return ir.NewInt(base.Pos, int64(hash.Sum32()))
}

// The result of walkCompare MUST be assigned back to n, e.g.
//
//	n.Left = walkCompare(n.Left, init)
func walkCompare(n *ir.BinaryExpr, init *ir.Nodes) ir.Node {
	if n.X.Type().IsInterface() && n.Y.Type().IsInterface() && n.X.Op() != ir.ONIL && n.Y.Op() != ir.ONIL {
		return walkCompareInterface(n, init)
	}

	if n.X.Type().IsString() && n.Y.Type().IsString() {
		return walkCompareString(n, init)
	}

	n.X = walkExpr(n.X, init)
	n.Y = walkExpr(n.Y, init)

	// Given a mixed interface/concrete comparison,
	// rewrite into types-equal && data-equal.
	// This is efficient, avoids allocations, and avoids runtime calls.
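	//
	// For example (schematic, not literal IR), for an interface value l and
	// a concrete value r of type T, l == r becomes roughly:
	//
	//	l.tab == type(T) && *(*T)(l.data) == r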
	//
	// TODO(mdempsky): It would be more general and probably overall
	// simpler to just extend walkCompareInterface to optimize when one
	// operand is an OCONVIFACE.
	if n.X.Type().IsInterface() != n.Y.Type().IsInterface() {
		// Preserve side-effects in case of short-circuiting; see #32187.
		l := cheapExpr(n.X, init)
		r := cheapExpr(n.Y, init)
		// Swap so that l is the interface value and r is the concrete value.
		if n.Y.Type().IsInterface() {
			l, r = r, l
		}

		// Handle both == and !=.
		eq := n.Op()
		andor := ir.OOROR
		if eq == ir.OEQ {
			andor = ir.OANDAND
		}
		// Check for types equal.
		// For empty interface, this is:
		//   l.tab == type(r)
		// For non-empty interface, this is:
		//   l.tab != nil && l.tab._type == type(r)
		//
		// TODO(mdempsky): For non-empty interface comparisons, just
		// compare against the itab address directly?
		var eqtype ir.Node
		tab := ir.NewUnaryExpr(base.Pos, ir.OITAB, l)
		rtyp := reflectdata.CompareRType(base.Pos, n)
		if l.Type().IsEmptyInterface() {
			tab.SetType(types.NewPtr(types.Types[types.TUINT8]))
			tab.SetTypecheck(1)
			eqtype = ir.NewBinaryExpr(base.Pos, eq, tab, rtyp)
		} else {
			nonnil := ir.NewBinaryExpr(base.Pos, brcom(eq), typecheck.NodNil(), tab)
			match := ir.NewBinaryExpr(base.Pos, eq, itabType(tab), rtyp)
			eqtype = ir.NewLogicalExpr(base.Pos, andor, nonnil, match)
		}
		// Check for data equal.
		eqdata := ir.NewBinaryExpr(base.Pos, eq, ifaceData(n.Pos(), l, r.Type()), r)
		// Put it all together.
		expr := ir.NewLogicalExpr(base.Pos, andor, eqtype, eqdata)
		return finishCompare(n, expr, init)
	}

	// Must be a comparison of an array or struct.
	// Otherwise the back end handles it.
	// While we're here, decide whether to
	// inline or call an eq alg.
	t := n.X.Type()
	var inline bool

	maxcmpsize := int64(4)
	unalignedLoad := ssagen.Arch.LinkArch.CanMergeLoads
	if unalignedLoad {
		// Keep this low enough to generate less code than a function call.
		maxcmpsize = 2 * int64(ssagen.Arch.LinkArch.RegSize)
	}

	switch t.Kind() {
	default:
		if base.Debug.Libfuzzer != 0 && t.IsInteger() && (n.X.Name() == nil || !n.X.Name().Libfuzzer8BitCounter()) {
			n.X = cheapExpr(n.X, init)
			n.Y = cheapExpr(n.Y, init)

			// If exactly one comparison operand is
			// constant, invoke the constcmp functions
			// instead, and arrange for the constant
			// operand to be the first argument.
			l, r := n.X, n.Y
			if r.Op() == ir.OLITERAL {
				l, r = r, l
			}
			constcmp := l.Op() == ir.OLITERAL && r.Op() != ir.OLITERAL

			var fn string
			var paramType *types.Type
			switch t.Size() {
			case 1:
				fn = "libfuzzerTraceCmp1"
				if constcmp {
					fn = "libfuzzerTraceConstCmp1"
				}
				paramType = types.Types[types.TUINT8]
			case 2:
				fn = "libfuzzerTraceCmp2"
				if constcmp {
					fn = "libfuzzerTraceConstCmp2"
				}
				paramType = types.Types[types.TUINT16]
			case 4:
				fn = "libfuzzerTraceCmp4"
				if constcmp {
					fn = "libfuzzerTraceConstCmp4"
				}
				paramType = types.Types[types.TUINT32]
			case 8:
				fn = "libfuzzerTraceCmp8"
				if constcmp {
					fn = "libfuzzerTraceConstCmp8"
				}
				paramType = types.Types[types.TUINT64]
			default:
				base.Fatalf("unexpected integer size %d for %v", t.Size(), t)
			}
			init.Append(mkcall(fn, nil, init, tracecmpArg(l, paramType, init), tracecmpArg(r, paramType, init), fakePC(n)))
		}
		return n
	case types.TARRAY:
		// We can compare several elements at once with 2/4/8 byte integer compares.
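		// For example, where merged loads are available (unalignedLoad above),
		// a [4]int8 comparison can be lowered to a single 32-bit compare.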
		inline = t.NumElem() <= 1 || (types.IsSimple[t.Elem().Kind()] && (t.NumElem() <= 4 || t.Elem().Size()*t.NumElem() <= maxcmpsize))
	case types.TSTRUCT:
		inline = compare.EqStructCost(t) <= 4
	}

	cmpl := n.X
	for cmpl != nil && cmpl.Op() == ir.OCONVNOP {
		cmpl = cmpl.(*ir.ConvExpr).X
	}
	cmpr := n.Y
	for cmpr != nil && cmpr.Op() == ir.OCONVNOP {
		cmpr = cmpr.(*ir.ConvExpr).X
	}

	// We chose not to inline. Call the equality function directly.
	if !inline {
		// eq algs take pointers; cmpl and cmpr must be addressable.
		if !ir.IsAddressable(cmpl) || !ir.IsAddressable(cmpr) {
			base.Fatalf("arguments of comparison must be lvalues - %v %v", cmpl, cmpr)
		}

		// We should only arrive here with large memory or
		// a struct/array containing a non-memory field/element.
		// Small memory is handled inline, and a single non-memory
		// field/element is handled by walkCompare.
		fn, needsLength := reflectdata.EqFor(t)
		call := ir.NewCallExpr(base.Pos, ir.OCALL, fn, nil)
		call.Args.Append(typecheck.NodAddr(cmpl))
		call.Args.Append(typecheck.NodAddr(cmpr))
		if needsLength {
			call.Args.Append(ir.NewInt(base.Pos, t.Size()))
		}
		res := ir.Node(call)
		if n.Op() != ir.OEQ {
			res = ir.NewUnaryExpr(base.Pos, ir.ONOT, res)
		}
		return finishCompare(n, res, init)
	}

	// Inline: build a boolean expression comparing the values element by element.
	andor := ir.OANDAND
	if n.Op() == ir.ONE {
		andor = ir.OOROR
	}
	var expr ir.Node
	comp := func(el, er ir.Node) {
		a := ir.NewBinaryExpr(base.Pos, n.Op(), el, er)
		if expr == nil {
			expr = a
		} else {
			expr = ir.NewLogicalExpr(base.Pos, andor, expr, a)
		}
	}
	and := func(cond ir.Node) {
		if expr == nil {
			expr = cond
		} else {
			expr = ir.NewLogicalExpr(base.Pos, andor, expr, cond)
		}
	}
	cmpl = safeExpr(cmpl, init)
	cmpr = safeExpr(cmpr, init)
	if t.IsStruct() {
		conds, _ := compare.EqStruct(t, cmpl, cmpr)
		if n.Op() == ir.OEQ {
			for _, cond := range conds {
				and(cond)
			}
		} else {
			for _, cond := range conds {
				notCond := ir.NewUnaryExpr(base.Pos, ir.ONOT, cond)
				and(notCond)
			}
		}
	} else {
		step := int64(1)
		remains := t.NumElem() * t.Elem().Size()
		combine64bit := unalignedLoad && types.RegSize == 8 && t.Elem().Size() <= 4 && t.Elem().IsInteger()
		combine32bit := unalignedLoad && t.Elem().Size() <= 2 && t.Elem().IsInteger()
		combine16bit := unalignedLoad && t.Elem().Size() == 1 && t.Elem().IsInteger()
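		// Where merged loads are available, adjacent small elements are packed
		// into one wide integer compare below; schematically, eight 1-byte
		// elements become a single 8-byte comparison.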
		for i := int64(0); remains > 0; {
			var convType *types.Type
			switch {
			case remains >= 8 && combine64bit:
				convType = types.Types[types.TINT64]
				step = 8 / t.Elem().Size()
			case remains >= 4 && combine32bit:
				convType = types.Types[types.TUINT32]
				step = 4 / t.Elem().Size()
			case remains >= 2 && combine16bit:
				convType = types.Types[types.TUINT16]
				step = 2 / t.Elem().Size()
			default:
				step = 1
			}
			if step == 1 {
				comp(
					ir.NewIndexExpr(base.Pos, cmpl, ir.NewInt(base.Pos, i)),
					ir.NewIndexExpr(base.Pos, cmpr, ir.NewInt(base.Pos, i)),
				)
				i++
				remains -= t.Elem().Size()
			} else {
				elemType := t.Elem().ToUnsigned()
				cmplw := ir.Node(ir.NewIndexExpr(base.Pos, cmpl, ir.NewInt(base.Pos, i)))
				cmplw = typecheck.Conv(cmplw, elemType) // convert to unsigned
				cmplw = typecheck.Conv(cmplw, convType) // widen
				cmprw := ir.Node(ir.NewIndexExpr(base.Pos, cmpr, ir.NewInt(base.Pos, i)))
				cmprw = typecheck.Conv(cmprw, elemType)
				cmprw = typecheck.Conv(cmprw, convType)
				// For code like this:  uint32(s[0]) | uint32(s[1])<<8 | uint32(s[2])<<16 ...
				// SSA will generate a single large load.
				for offset := int64(1); offset < step; offset++ {
					lb := ir.Node(ir.NewIndexExpr(base.Pos, cmpl, ir.NewInt(base.Pos, i+offset)))
					lb = typecheck.Conv(lb, elemType)
					lb = typecheck.Conv(lb, convType)
					lb = ir.NewBinaryExpr(base.Pos, ir.OLSH, lb, ir.NewInt(base.Pos, 8*t.Elem().Size()*offset))
					cmplw = ir.NewBinaryExpr(base.Pos, ir.OOR, cmplw, lb)
					rb := ir.Node(ir.NewIndexExpr(base.Pos, cmpr, ir.NewInt(base.Pos, i+offset)))
					rb = typecheck.Conv(rb, elemType)
					rb = typecheck.Conv(rb, convType)
					rb = ir.NewBinaryExpr(base.Pos, ir.OLSH, rb, ir.NewInt(base.Pos, 8*t.Elem().Size()*offset))
					cmprw = ir.NewBinaryExpr(base.Pos, ir.OOR, cmprw, rb)
				}
				comp(cmplw, cmprw)
				i += step
				remains -= step * t.Elem().Size()
			}
		}
	}
	if expr == nil {
		expr = ir.NewBool(base.Pos, n.Op() == ir.OEQ)
		// We still need to use cmpl and cmpr, in case they contain
		// an expression which might panic. See issue 23837.
		a1 := typecheck.Stmt(ir.NewAssignStmt(base.Pos, ir.BlankNode, cmpl))
		a2 := typecheck.Stmt(ir.NewAssignStmt(base.Pos, ir.BlankNode, cmpr))
		init.Append(a1, a2)
	}
	return finishCompare(n, expr, init)
}

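// walkCompareInterface walks a comparison between two interface values,
// combining a type-word equality check with a runtime data-equality call.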
func walkCompareInterface(n *ir.BinaryExpr, init *ir.Nodes) ir.Node {
	n.Y = cheapExpr(n.Y, init)
	n.X = cheapExpr(n.X, init)
	eqtab, eqdata := compare.EqInterface(n.X, n.Y)
	var cmp ir.Node
	if n.Op() == ir.OEQ {
		cmp = ir.NewLogicalExpr(base.Pos, ir.OANDAND, eqtab, eqdata)
	} else {
		eqtab.SetOp(ir.ONE)
		cmp = ir.NewLogicalExpr(base.Pos, ir.OOROR, eqtab, ir.NewUnaryExpr(base.Pos, ir.ONOT, eqdata))
	}
	return finishCompare(n, cmp, init)
}

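// walkCompareString walks a comparison between two string values,
// rewriting comparisons against short constant strings into length and
// byte-wise checks, and other comparisons into runtime calls.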
func walkCompareString(n *ir.BinaryExpr, init *ir.Nodes) ir.Node {
	if base.Debug.Libfuzzer != 0 {
		if !ir.IsConst(n.X, constant.String) || !ir.IsConst(n.Y, constant.String) {
			fn := "libfuzzerHookStrCmp"
			n.X = cheapExpr(n.X, init)
			n.Y = cheapExpr(n.Y, init)
			paramType := types.Types[types.TSTRING]
			init.Append(mkcall(fn, nil, init, tracecmpArg(n.X, paramType, init), tracecmpArg(n.Y, paramType, init), fakePC(n)))
		}
	}
	// Rewrite comparisons to short constant strings as length+byte-wise comparisons.
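	// For example, s == "ab" becomes, schematically,
	//	len(s) == 2 && s[0] == 'a' && s[1] == 'b'
	// (byte-wise checks may be merged into wider loads below).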
	var cs, ncs ir.Node // const string, non-const string
	switch {
	case ir.IsConst(n.X, constant.String) && ir.IsConst(n.Y, constant.String):
		// ignore; will be constant evaluated
	case ir.IsConst(n.X, constant.String):
		cs = n.X
		ncs = n.Y
	case ir.IsConst(n.Y, constant.String):
		cs = n.Y
		ncs = n.X
	}
	if cs != nil {
		cmp := n.Op()
		// Our comparison below assumes that the non-constant string
		// is on the left-hand side, so rewrite "" cmp x to x cmp "".
		// See issue 24817.
		if ir.IsConst(n.X, constant.String) {
			cmp = brrev(cmp)
		}

		// maxRewriteLen was chosen empirically.
		// It is the value that minimizes cmd/go file size
		// across most architectures.
		// See the commit description for CL 26758 for details.
		maxRewriteLen := 6
		// Some architectures can load an unaligned byte sequence as a single
		// word, so we can cover longer strings with the same amount of code.
		canCombineLoads := ssagen.Arch.LinkArch.CanMergeLoads
		combine64bit := false
		if canCombineLoads {
			// Keep this low enough to generate less code than a function call.
			maxRewriteLen = 2 * ssagen.Arch.LinkArch.RegSize
			combine64bit = ssagen.Arch.LinkArch.RegSize >= 8
		}

		var and ir.Op
		switch cmp {
		case ir.OEQ:
			and = ir.OANDAND
		case ir.ONE:
			and = ir.OOROR
		default:
			// Don't do byte-wise comparisons for <, <=, etc.
			// They're fairly complicated.
			// Length-only checks are ok, though.
			maxRewriteLen = 0
		}
		if s := ir.StringVal(cs); len(s) <= maxRewriteLen {
			if len(s) > 0 {
				ncs = safeExpr(ncs, init)
			}
			r := ir.Node(ir.NewBinaryExpr(base.Pos, cmp, ir.NewUnaryExpr(base.Pos, ir.OLEN, ncs), ir.NewInt(base.Pos, int64(len(s)))))
			remains := len(s)
			for i := 0; remains > 0; {
				if remains == 1 || !canCombineLoads {
					cb := ir.NewInt(base.Pos, int64(s[i]))
					ncb := ir.NewIndexExpr(base.Pos, ncs, ir.NewInt(base.Pos, int64(i)))
					r = ir.NewLogicalExpr(base.Pos, and, r, ir.NewBinaryExpr(base.Pos, cmp, ncb, cb))
					remains--
					i++
					continue
				}
				var step int
				var convType *types.Type
				switch {
				case remains >= 8 && combine64bit:
					convType = types.Types[types.TINT64]
					step = 8
				case remains >= 4:
					convType = types.Types[types.TUINT32]
					step = 4
				case remains >= 2:
					convType = types.Types[types.TUINT16]
					step = 2
				}
				ncsubstr := typecheck.Conv(ir.NewIndexExpr(base.Pos, ncs, ir.NewInt(base.Pos, int64(i))), convType)
				csubstr := int64(s[i])
				// Calculate the large constant from the bytes as a sequence of shifts and ORs.
				// Like this:  uint32(s[0]) | uint32(s[1])<<8 | uint32(s[2])<<16 ...
				// SSA will combine this into a single large load.
				for offset := 1; offset < step; offset++ {
					b := typecheck.Conv(ir.NewIndexExpr(base.Pos, ncs, ir.NewInt(base.Pos, int64(i+offset))), convType)
					b = ir.NewBinaryExpr(base.Pos, ir.OLSH, b, ir.NewInt(base.Pos, int64(8*offset)))
					ncsubstr = ir.NewBinaryExpr(base.Pos, ir.OOR, ncsubstr, b)
					csubstr |= int64(s[i+offset]) << uint8(8*offset)
				}
				csubstrPart := ir.NewInt(base.Pos, csubstr)
				// Compare "step" bytes at once.
				r = ir.NewLogicalExpr(base.Pos, and, r, ir.NewBinaryExpr(base.Pos, cmp, csubstrPart, ncsubstr))
				remains -= step
				i += step
			}
			return finishCompare(n, r, init)
		}
	}

	var r ir.Node
	if n.Op() == ir.OEQ || n.Op() == ir.ONE {
		// Prepare for the rewrite below.
		n.X = cheapExpr(n.X, init)
		n.Y = cheapExpr(n.Y, init)
		eqlen, eqmem := compare.EqString(n.X, n.Y)
		// Do a quick length check before the full comparison for == or !=;
		// memequal then tests equality up to that length.
		if n.Op() == ir.OEQ {
			// len(left) == len(right) && memequal(left, right, len)
			r = ir.NewLogicalExpr(base.Pos, ir.OANDAND, eqlen, eqmem)
		} else {
			// len(left) != len(right) || !memequal(left, right, len)
			eqlen.SetOp(ir.ONE)
			r = ir.NewLogicalExpr(base.Pos, ir.OOROR, eqlen, ir.NewUnaryExpr(base.Pos, ir.ONOT, eqmem))
		}
	} else {
		// Rewrite as cmpstring(s1, s2) <op> 0.
		r = mkcall("cmpstring", types.Types[types.TINT], init, typecheck.Conv(n.X, types.Types[types.TSTRING]), typecheck.Conv(n.Y, types.Types[types.TSTRING]))
		r = ir.NewBinaryExpr(base.Pos, n.Op(), r, ir.NewInt(base.Pos, 0))
	}

	return finishCompare(n, r, init)
}

// The result of finishCompare MUST be assigned back to n, e.g.
//
//	n.Left = finishCompare(n.Left, r, init)
func finishCompare(n *ir.BinaryExpr, r ir.Node, init *ir.Nodes) ir.Node {
	r = typecheck.Expr(r)
	r = typecheck.Conv(r, n.Type())
	r = walkExpr(r, init)
	return r
}

// brcom returns !(op).
// For example, brcom(==) is !=.
func brcom(op ir.Op) ir.Op {
	switch op {
	case ir.OEQ:
		return ir.ONE
	case ir.ONE:
		return ir.OEQ
	case ir.OLT:
		return ir.OGE
	case ir.OGT:
		return ir.OLE
	case ir.OLE:
		return ir.OGT
	case ir.OGE:
		return ir.OLT
	}
	base.Fatalf("brcom: no com for %v\n", op)
	return op
}

// brrev returns reverse(op).
// For example, brrev(<) is >.
func brrev(op ir.Op) ir.Op {
	switch op {
	case ir.OEQ:
		return ir.OEQ
	case ir.ONE:
		return ir.ONE
	case ir.OLT:
		return ir.OGT
	case ir.OGT:
		return ir.OLT
	case ir.OLE:
		return ir.OGE
	case ir.OGE:
		return ir.OLE
	}
	base.Fatalf("brrev: no rev for %v\n", op)
	return op
}

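// tracecmpArg prepares n for use as an argument of type t to one of the
// libfuzzer trace calls above.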
func tracecmpArg(n ir.Node, t *types.Type, init *ir.Nodes) ir.Node {
	// Ugly hack to avoid "constant -1 overflows uintptr" errors, etc.
	if n.Op() == ir.OLITERAL && n.Type().IsSigned() && ir.Int64Val(n) < 0 {
		n = copyExpr(n, n.Type(), init)
	}

	return typecheck.Conv(n, t)
}