// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssa

import (
	"cmd/compile/internal/base"
	"cmd/compile/internal/types"
	"cmd/internal/src"
	"sort"
)

// memcombine combines smaller loads and stores into larger ones.
// We ensure this generates good code for encoding/binary operations.
// It may help other cases also.
func memcombine(f *Func) {
	// This optimization requires that the architecture has
	// unaligned loads and unaligned stores.
	if !f.Config.unalignedOK {
		return
	}

	memcombineLoads(f)
	memcombineStores(f)
}

func memcombineLoads(f *Func) {
	// Find "OR trees" to start with.
	mark := f.newSparseSet(f.NumValues())
	defer f.retSparseSet(mark)
	var order []*Value

	// Mark all values that are the argument of an OR.
	for _, b := range f.Blocks {
		for _, v := range b.Values {
			if v.Op == OpOr16 || v.Op == OpOr32 || v.Op == OpOr64 {
				mark.add(v.Args[0].ID)
				mark.add(v.Args[1].ID)
			}
		}
	}
	for _, b := range f.Blocks {
		order = order[:0]
		for _, v := range b.Values {
			if v.Op != OpOr16 && v.Op != OpOr32 && v.Op != OpOr64 {
				continue
			}
			if mark.contains(v.ID) {
				// marked - means it is not the root of an OR tree
				continue
			}
			// Add the OR tree rooted at v to the order.
			// We use BFS here, but any walk that puts roots before leaves would work.
			i := len(order)
			order = append(order, v)
			for ; i < len(order); i++ {
				x := order[i]
				for j := 0; j < 2; j++ {
					a := x.Args[j]
					if a.Op == OpOr16 || a.Op == OpOr32 || a.Op == OpOr64 {
						order = append(order, a)
					}
				}
			}
		}
		for _, v := range order {
			max := f.Config.RegSize
			switch v.Op {
			case OpOr64:
			case OpOr32:
				max = 4
			case OpOr16:
				max = 2
			default:
				continue
			}
			for n := max; n > 1; n /= 2 {
				if combineLoads(v, n) {
					break
				}
			}
		}
	}
}

// A BaseAddress represents the address ptr+idx, where
// ptr is a pointer type and idx is an integer type.
// idx may be nil, in which case it is treated as 0.
type BaseAddress struct {
	ptr *Value
	idx *Value
}

// splitPtr returns the base address of ptr and any
// constant offset from that base.
// BaseAddress{ptr,nil},0 is always a valid result, but splitPtr
// tries to peel away as many constants into off as possible.
func splitPtr(ptr *Value) (BaseAddress, int64) {
	var idx *Value
	var off int64
	for {
		if ptr.Op == OpOffPtr {
			off += ptr.AuxInt
			ptr = ptr.Args[0]
		} else if ptr.Op == OpAddPtr {
			if idx != nil {
				// We have two or more indexing values.
				// Pick the first one we found.
				return BaseAddress{ptr: ptr, idx: idx}, off
			}
			idx = ptr.Args[1]
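			// If the index is itself index+constant, fold the constant
			// into off so more accesses share the same BaseAddress.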
			if idx.Op == OpAdd32 || idx.Op == OpAdd64 {
				if idx.Args[0].Op == OpConst32 || idx.Args[0].Op == OpConst64 {
					off += idx.Args[0].AuxInt
					idx = idx.Args[1]
				} else if idx.Args[1].Op == OpConst32 || idx.Args[1].Op == OpConst64 {
					off += idx.Args[1].AuxInt
					idx = idx.Args[0]
				}
			}
			ptr = ptr.Args[0]
		} else {
			return BaseAddress{ptr: ptr, idx: idx}, off
		}
	}
}

func combineLoads(root *Value, n int64) bool {
	orOp := root.Op
	var shiftOp Op
	switch orOp {
	case OpOr64:
		shiftOp = OpLsh64x64
	case OpOr32:
		shiftOp = OpLsh32x64
	case OpOr16:
		shiftOp = OpLsh16x64
	default:
		return false
	}

	// Find n values that are ORed together with the above op.
	a := make([]*Value, 0, 8)
	a = append(a, root)
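	// Expand OR nodes in place: each OR found is replaced by its first
	// argument, its second argument is appended, and the slot is revisited.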
	for i := 0; i < len(a) && int64(len(a)) < n; i++ {
		v := a[i]
		if v.Uses != 1 && v != root {
			// Something in this subtree is used somewhere else.
			return false
		}
		if v.Op == orOp {
			a[i] = v.Args[0]
			a = append(a, v.Args[1])
			i--
		}
	}
	if int64(len(a)) != n {
		return false
	}

	// Check the first entry to see what ops we're looking for.
	// All the entries should be of the form shift(extend(load)), maybe with no shift.
	v := a[0]
	if v.Op == shiftOp {
		v = v.Args[0]
	}
	var extOp Op
	if orOp == OpOr64 && (v.Op == OpZeroExt8to64 || v.Op == OpZeroExt16to64 || v.Op == OpZeroExt32to64) ||
		orOp == OpOr32 && (v.Op == OpZeroExt8to32 || v.Op == OpZeroExt16to32) ||
		orOp == OpOr16 && v.Op == OpZeroExt8to16 {
		extOp = v.Op
		v = v.Args[0]
	} else {
		return false
	}
	if v.Op != OpLoad {
		return false
	}
	base, _ := splitPtr(v.Args[0])
	mem := v.Args[1]
	size := v.Type.Size()

	if root.Block.Func.Config.arch == "S390X" {
		// s390x can't handle unaligned accesses to global variables.
		if base.ptr.Op == OpAddr {
			return false
		}
	}

	// Check all the entries, extract useful info.
	type LoadRecord struct {
		load   *Value
		offset int64 // offset of load address from base
		shift  int64
	}
	r := make([]LoadRecord, n, 8)
	for i := int64(0); i < n; i++ {
		v := a[i]
		if v.Uses != 1 {
			return false
		}
		shift := int64(0)
		if v.Op == shiftOp {
			if v.Args[1].Op != OpConst64 {
				return false
			}
			shift = v.Args[1].AuxInt
			v = v.Args[0]
			if v.Uses != 1 {
				return false
			}
		}
		if v.Op != extOp {
			return false
		}
		load := v.Args[0]
		if load.Op != OpLoad {
			return false
		}
		if load.Uses != 1 {
			return false
		}
		if load.Args[1] != mem {
			return false
		}
		p, off := splitPtr(load.Args[0])
		if p != base {
			return false
		}
		r[i] = LoadRecord{load: load, offset: off, shift: shift}
	}

	// Sort in memory address order.
	sort.Slice(r, func(i, j int) bool {
		return r[i].offset < r[j].offset
	})

	// Check that we have contiguous offsets.
	for i := int64(0); i < n; i++ {
		if r[i].offset != r[0].offset+i*size {
			return false
		}
	}

	// Check for reads in little-endian or big-endian order.
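	// Little-endian means the shift amounts increase with memory address;
	// big-endian means they decrease.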
	shift0 := r[0].shift
	isLittleEndian := true
	for i := int64(0); i < n; i++ {
		if r[i].shift != shift0+i*size*8 {
			isLittleEndian = false
			break
		}
	}
	isBigEndian := true
	for i := int64(0); i < n; i++ {
		if r[i].shift != shift0-i*size*8 {
			isBigEndian = false
			break
		}
	}
	if !isLittleEndian && !isBigEndian {
		return false
	}

	// Find a place to put the new load.
	// This is tricky, because it has to be at a point where
	// its memory argument is live. We can't just put it in root.Block.
	// We use the block of the latest load.
	loads := make([]*Value, n, 8)
	for i := int64(0); i < n; i++ {
		loads[i] = r[i].load
	}
	loadBlock := mergePoint(root.Block, loads...)
	if loadBlock == nil {
		return false
	}
	// Find a source position to use.
	pos := src.NoXPos
	for _, load := range loads {
		if load.Block == loadBlock {
			pos = load.Pos
			break
		}
	}
	if pos == src.NoXPos {
		return false
	}

	// Check to see if we need a byte swap after loading.
	needSwap := isLittleEndian && root.Block.Func.Config.BigEndian ||
		isBigEndian && !root.Block.Func.Config.BigEndian
	if needSwap && (size != 1 || !root.Block.Func.Config.haveByteSwap(n)) {
		return false
	}

	// This is the commit point.

	// First, issue load at lowest address.
	v = loadBlock.NewValue2(pos, OpLoad, sizeType(n*size), r[0].load.Args[0], mem)

	// Byte swap if needed.
	if needSwap {
		v = byteSwap(loadBlock, pos, v)
	}

	// Extend if needed.
	if n*size < root.Type.Size() {
		v = zeroExtend(loadBlock, pos, v, n*size, root.Type.Size())
	}

	// Shift if needed.
	if isLittleEndian && shift0 != 0 {
		v = leftShift(loadBlock, pos, v, shift0)
	}
	if isBigEndian && shift0-(n-1)*size*8 != 0 {
		v = leftShift(loadBlock, pos, v, shift0-(n-1)*size*8)
	}

	// Install with (Copy v).
	root.reset(OpCopy)
	root.AddArg(v)

	// Clobber the loads, just to prevent additional work being done on
	// subtrees (which are now unreachable).
	for i := int64(0); i < n; i++ {
		clobber(r[i].load)
	}
	return true
}

func memcombineStores(f *Func) {
	mark := f.newSparseSet(f.NumValues())
	defer f.retSparseSet(mark)
	var order []*Value

	for _, b := range f.Blocks {
		// Mark all stores which are not last in a store sequence.
		mark.clear()
		for _, v := range b.Values {
			if v.Op == OpStore {
				mark.add(v.MemoryArg().ID)
			}
		}

		// pick an order for visiting stores such that
		// later stores come earlier in the ordering.
		order = order[:0]
		for _, v := range b.Values {
			if v.Op != OpStore {
				continue
			}
			if mark.contains(v.ID) {
				continue // not last in a chain of stores
			}
			for {
				order = append(order, v)
				v = v.Args[2]
				if v.Block != b || v.Op != OpStore {
					break
				}
			}
		}

		// Look for combining opportunities at each store in queue order.
		for _, v := range order {
			if v.Op != OpStore { // already rewritten
				continue
			}

			size := v.Aux.(*types.Type).Size()
			if size >= f.Config.RegSize || size == 0 {
				continue
			}

			for n := f.Config.RegSize / size; n > 1; n /= 2 {
				if combineStores(v, n) {
					break
				}
			}
		}
	}
}

// Try to combine the n stores ending in root.
// Returns true if successful.
func combineStores(root *Value, n int64) bool {
	// Helper functions.
	type StoreRecord struct {
		store  *Value
		offset int64
	}
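	// getShiftBase returns the common base value that the first two stored
	// values are truncations (and possibly right shifts) of, or nil if
	// there is no such value.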
	getShiftBase := func(a []StoreRecord) *Value {
		x := a[0].store.Args[1]
		y := a[1].store.Args[1]
		switch x.Op {
		case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
			x = x.Args[0]
		default:
			return nil
		}
		switch y.Op {
		case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
			y = y.Args[0]
		default:
			return nil
		}
		var x2 *Value
		switch x.Op {
		case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
			x2 = x.Args[0]
		default:
		}
		var y2 *Value
		switch y.Op {
		case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
			y2 = y.Args[0]
		default:
		}
		if y2 == x {
			// a shift of x and x itself.
			return x
		}
		if x2 == y {
			// a shift of y and y itself.
			return y
		}
		if x2 == y2 {
			// 2 shifts both of the same argument.
			return x2
		}
		return nil
	}
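	// isShiftBase reports whether the value stored by v is base, or a right
	// shift of base, once the truncation is stripped.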
	isShiftBase := func(v, base *Value) bool {
		val := v.Args[1]
		switch val.Op {
		case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
			val = val.Args[0]
		default:
			return false
		}
		if val == base {
			return true
		}
		switch val.Op {
		case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
			val = val.Args[0]
		default:
			return false
		}
		return val == base
	}
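	// shift returns the number of bits base is right-shifted by in the value
	// stored by v (0 if it is stored unshifted), or -1 if the pattern does
	// not match.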
	shift := func(v, base *Value) int64 {
		val := v.Args[1]
		switch val.Op {
		case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
			val = val.Args[0]
		default:
			return -1
		}
		if val == base {
			return 0
		}
		switch val.Op {
		case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
			val = val.Args[1]
		default:
			return -1
		}
		if val.Op != OpConst64 {
			return -1
		}
		return val.AuxInt
	}

	// Element size of the individual stores.
	size := root.Aux.(*types.Type).Size()
	if size*n > root.Block.Func.Config.RegSize {
		return false
	}

	// Gather n stores to look at. Check easy conditions we require.
	a := make([]StoreRecord, 0, 8)
	rbase, roff := splitPtr(root.Args[0])
	if root.Block.Func.Config.arch == "S390X" {
		// s390x can't handle unaligned accesses to global variables.
		if rbase.ptr.Op == OpAddr {
			return false
		}
	}
	a = append(a, StoreRecord{root, roff})
	for i, x := int64(1), root.Args[2]; i < n; i, x = i+1, x.Args[2] {
		if x.Op != OpStore {
			return false
		}
		if x.Block != root.Block {
			return false
		}
		if x.Uses != 1 { // Note: root can have more than one use.
			return false
		}
		if x.Aux.(*types.Type).Size() != size {
			// TODO: the constant source and consecutive load source cases
			// do not need all the stores to be the same size.
			return false
		}
		base, off := splitPtr(x.Args[0])
		if base != rbase {
			return false
		}
		a = append(a, StoreRecord{x, off})
	}
	// Before we sort, grab the memory arg the result should have.
	mem := a[n-1].store.Args[2]
	// Also grab position of first store (last in array = first in memory order).
	pos := a[n-1].store.Pos

	// Sort stores in increasing address order.
	sort.Slice(a, func(i, j int) bool {
		return a[i].offset < a[j].offset
	})

	// Check that everything is written to sequential locations.
	for i := int64(0); i < n; i++ {
		if a[i].offset != a[0].offset+i*size {
			return false
		}
	}

	// Memory location we're going to write at (the lowest one).
	ptr := a[0].store.Args[0]

	// Check for constant stores
	isConst := true
	for i := int64(0); i < n; i++ {
		switch a[i].store.Args[1].Op {
		case OpConst32, OpConst16, OpConst8, OpConstBool:
		default:
			isConst = false
			break
		}
	}
	if isConst {
		// Modify root to do all the stores.
		var c int64
		mask := int64(1)<<(8*size) - 1
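		// Assemble the combined constant: each store contributes its low
		// 8*size bits at its byte offset (mirrored on big-endian targets).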
		for i := int64(0); i < n; i++ {
			s := 8 * size * int64(i)
			if root.Block.Func.Config.BigEndian {
				s = 8*size*(n-1) - s
			}
			c |= (a[i].store.Args[1].AuxInt & mask) << s
		}
		var cv *Value
		switch size * n {
		case 2:
			cv = root.Block.Func.ConstInt16(types.Types[types.TUINT16], int16(c))
		case 4:
			cv = root.Block.Func.ConstInt32(types.Types[types.TUINT32], int32(c))
		case 8:
			cv = root.Block.Func.ConstInt64(types.Types[types.TUINT64], c)
		}

		// Move all the stores to the root.
		for i := int64(0); i < n; i++ {
			v := a[i].store
			if v == root {
				v.Aux = cv.Type // widen store type
				v.Pos = pos
				v.SetArg(0, ptr)
				v.SetArg(1, cv)
				v.SetArg(2, mem)
			} else {
				clobber(v)
				v.Type = types.Types[types.TBOOL] // erase memory type
			}
		}
		return true
	}

	// Check for consecutive loads as the source of the stores.
	var loadMem *Value
	var loadBase BaseAddress
	var loadIdx int64
	for i := int64(0); i < n; i++ {
		load := a[i].store.Args[1]
		if load.Op != OpLoad {
			loadMem = nil
			break
		}
		if load.Uses != 1 {
			loadMem = nil
			break
		}
		if load.Type.IsPtr() {
			// Don't combine stores containing a pointer, as we need
			// a write barrier for those. This can't currently happen,
			// but might in the future if we ever have another
			// 8-byte-reg/4-byte-ptr architecture like amd64p32.
			loadMem = nil
			break
		}
		mem := load.Args[1]
		base, idx := splitPtr(load.Args[0])
		if loadMem == nil {
			// First one we found
			loadMem = mem
			loadBase = base
			loadIdx = idx
			continue
		}
		if base != loadBase || mem != loadMem {
			loadMem = nil
			break
		}
		if idx != loadIdx+(a[i].offset-a[0].offset) {
			loadMem = nil
			break
		}
	}
	if loadMem != nil {
		// Modify the first load to do a larger load instead.
		load := a[0].store.Args[1]
		switch size * n {
		case 2:
			load.Type = types.Types[types.TUINT16]
		case 4:
			load.Type = types.Types[types.TUINT32]
		case 8:
			load.Type = types.Types[types.TUINT64]
		}

		// Modify root to do the store.
		for i := int64(0); i < n; i++ {
			v := a[i].store
			if v == root {
				v.Aux = load.Type // widen store type
				v.Pos = pos
				v.SetArg(0, ptr)
				v.SetArg(1, load)
				v.SetArg(2, mem)
			} else {
				clobber(v)
				v.Type = types.Types[types.TBOOL] // erase memory type
			}
		}
		return true
	}

	// Check that all the shift/trunc are of the same base value.
	shiftBase := getShiftBase(a)
	if shiftBase == nil {
		return false
	}
	for i := int64(0); i < n; i++ {
		if !isShiftBase(a[i].store, shiftBase) {
			return false
		}
	}

	// Check for writes in little-endian or big-endian order.
	isLittleEndian := true
	shift0 := shift(a[0].store, shiftBase)
	for i := int64(1); i < n; i++ {
		if shift(a[i].store, shiftBase) != shift0+i*size*8 {
			isLittleEndian = false
			break
		}
	}
	isBigEndian := true
	for i := int64(1); i < n; i++ {
		if shift(a[i].store, shiftBase) != shift0-i*size*8 {
			isBigEndian = false
			break
		}
	}
	if !isLittleEndian && !isBigEndian {
		return false
	}

	// Check to see if we need byte swap before storing.
	needSwap := isLittleEndian && root.Block.Func.Config.BigEndian ||
		isBigEndian && !root.Block.Func.Config.BigEndian
	if needSwap && (size != 1 || !root.Block.Func.Config.haveByteSwap(n)) {
		return false
	}

	// This is the commit point.

	// Modify root to do all the stores.
	sv := shiftBase
	if isLittleEndian && shift0 != 0 {
		sv = rightShift(root.Block, root.Pos, sv, shift0)
	}
	if isBigEndian && shift0-(n-1)*size*8 != 0 {
		sv = rightShift(root.Block, root.Pos, sv, shift0-(n-1)*size*8)
	}
	if sv.Type.Size() > size*n {
		sv = truncate(root.Block, root.Pos, sv, sv.Type.Size(), size*n)
	}
	if needSwap {
		sv = byteSwap(root.Block, root.Pos, sv)
	}

	// Move all the stores to the root.
	for i := int64(0); i < n; i++ {
		v := a[i].store
		if v == root {
			v.Aux = sv.Type // widen store type
			v.Pos = pos
			v.SetArg(0, ptr)
			v.SetArg(1, sv)
			v.SetArg(2, mem)
		} else {
			clobber(v)
			v.Type = types.Types[types.TBOOL] // erase memory type
		}
	}
	return true
}

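// sizeType returns the unsigned integer type of the given size in bytes
// (2, 4, or 8).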
func sizeType(size int64) *types.Type {
	switch size {
	case 8:
		return types.Types[types.TUINT64]
	case 4:
		return types.Types[types.TUINT32]
	case 2:
		return types.Types[types.TUINT16]
	default:
		base.Fatalf("bad size %d\n", size)
		return nil
	}
}

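// truncate truncates v from a from-byte value to a to-byte value.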
func truncate(b *Block, pos src.XPos, v *Value, from, to int64) *Value {
	switch from*10 + to {
	case 82:
		return b.NewValue1(pos, OpTrunc64to16, types.Types[types.TUINT16], v)
	case 84:
		return b.NewValue1(pos, OpTrunc64to32, types.Types[types.TUINT32], v)
	case 42:
		return b.NewValue1(pos, OpTrunc32to16, types.Types[types.TUINT16], v)
	default:
		base.Fatalf("bad sizes %d %d\n", from, to)
		return nil
	}
}
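
// zeroExtend zero-extends v from a from-byte value to a to-byte value.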
func zeroExtend(b *Block, pos src.XPos, v *Value, from, to int64) *Value {
	switch from*10 + to {
	case 24:
		return b.NewValue1(pos, OpZeroExt16to32, types.Types[types.TUINT32], v)
	case 28:
		return b.NewValue1(pos, OpZeroExt16to64, types.Types[types.TUINT64], v)
	case 48:
		return b.NewValue1(pos, OpZeroExt32to64, types.Types[types.TUINT64], v)
	default:
		base.Fatalf("bad sizes %d %d\n", from, to)
		return nil
	}
}

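// leftShift returns v shifted left by shift bits.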
func leftShift(b *Block, pos src.XPos, v *Value, shift int64) *Value {
	s := b.Func.ConstInt64(types.Types[types.TUINT64], shift)
	size := v.Type.Size()
	switch size {
	case 8:
		return b.NewValue2(pos, OpLsh64x64, v.Type, v, s)
	case 4:
		return b.NewValue2(pos, OpLsh32x64, v.Type, v, s)
	case 2:
		return b.NewValue2(pos, OpLsh16x64, v.Type, v, s)
	default:
		base.Fatalf("bad size %d\n", size)
		return nil
	}
}
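
// rightShift returns v shifted right (unsigned) by shift bits.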
func rightShift(b *Block, pos src.XPos, v *Value, shift int64) *Value {
	s := b.Func.ConstInt64(types.Types[types.TUINT64], shift)
	size := v.Type.Size()
	switch size {
	case 8:
		return b.NewValue2(pos, OpRsh64Ux64, v.Type, v, s)
	case 4:
		return b.NewValue2(pos, OpRsh32Ux64, v.Type, v, s)
	case 2:
		return b.NewValue2(pos, OpRsh16Ux64, v.Type, v, s)
	default:
		base.Fatalf("bad size %d\n", size)
		return nil
	}
}
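
// byteSwap returns v with its bytes reversed.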
func byteSwap(b *Block, pos src.XPos, v *Value) *Value {
	switch v.Type.Size() {
	case 8:
		return b.NewValue1(pos, OpBswap64, v.Type, v)
	case 4:
		return b.NewValue1(pos, OpBswap32, v.Type, v)
	case 2:
		return b.NewValue1(pos, OpBswap16, v.Type, v)

	default:
		v.Fatalf("bad size %d\n", v.Type.Size())
		return nil
	}
}