1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssa
6
import (
	"bufio"
	"fmt"
	"os"
)
11
// debugPoset, when set to true by hand, makes the public poset query
// methods run CheckIntegrity before returning. Debugging aid only.
var debugPoset bool
14
const uintSize = 32 << (^uint(0) >> 63) // 32 or 64

// bitset is a dense bit array indexed by node index. Bit idx lives in
// word idx/uintSize, at bit position idx%uintSize within that word.
type bitset []uint

// newBitset returns an all-zero bitset with room for at least n bits.
func newBitset(n int) bitset {
	words := (n + uintSize - 1) / uintSize
	return make(bitset, words)
}

// Reset clears every bit, keeping the underlying storage.
func (bs bitset) Reset() {
	for w := 0; w < len(bs); w++ {
		bs[w] = 0
	}
}

// Set turns bit idx on.
func (bs bitset) Set(idx uint32) {
	word, bit := idx/uintSize, idx%uintSize
	bs[word] |= 1 << bit
}

// Clear turns bit idx off.
func (bs bitset) Clear(idx uint32) {
	word, bit := idx/uintSize, idx%uintSize
	bs[word] &^= 1 << bit
}

// Test reports whether bit idx is on.
func (bs bitset) Test(idx uint32) bool {
	word, bit := idx/uintSize, idx%uintSize
	return bs[word]>>bit&1 == 1
}
41
// undoType is the discriminant of a posetUndo entry: it says which kind of
// change the entry reverts.
type undoType uint8

const (
	// undoInvalid is the zero value (invalid entry).
	undoInvalid undoType = iota
	// undoCheckpoint is a checkpoint used to group undo passes.
	undoCheckpoint
	// undoSetChl changes back the left child of undo.idx to undo.edge.
	undoSetChl
	// undoSetChr changes back the right child of undo.idx to undo.edge.
	undoSetChr
	// undoNonEqual forgets that SSA value undo.ID is non-equal to undo.idx
	// (another ID).
	undoNonEqual
	// undoNewNode removes the new node created for SSA value undo.ID.
	undoNewNode
	// undoNewConstant removes the constant node idx from the constants map.
	undoNewConstant
	// undoAliasNode unaliases SSA value undo.ID so that it points back to
	// node index undo.idx.
	undoAliasNode
	// undoNewRoot removes node undo.idx from the root list.
	undoNewRoot
	// undoChangeRoot removes node undo.idx from the root list, and puts back
	// undo.edge.Target instead.
	undoChangeRoot
	// undoMergeRoot removes node undo.idx from the root list, and puts back
	// its children instead.
	undoMergeRoot
)
57
58// posetUndo represents an undo pass to be performed.
59// It's a union of fields that can be used to store information,
60// and typ is the discriminant, that specifies which kind
61// of operation must be performed. Not all fields are always used.
62type posetUndo struct {
63	typ  undoType
64	idx  uint32
65	ID   ID
66	edge posetEdge
67}
68
// Bit flags stored in poset.flags.
const (
	// posetFlagUnsigned makes the poset compare constants as unsigned numbers.
	posetFlagUnsigned = 1 << 0
)
73
// posetEdge is an edge of the poset DAG; the zero value is the null/empty
// edge. Bit 0 holds the strict flag and the upper 31 bits hold the target
// node index.
type posetEdge uint32

// newedge packs target node t and the strict flag into a posetEdge.
func newedge(t uint32, strict bool) posetEdge {
	e := posetEdge(t << 1)
	if strict {
		e |= 1
	}
	return e
}

// Target returns the index of the node the edge points to.
func (e posetEdge) Target() uint32 {
	return uint32(e >> 1)
}

// Strict reports whether the edge is strict (<) rather than non-strict (<=).
func (e posetEdge) Strict() bool {
	return e&1 == 1
}

// String formats the edge as its target index, suffixed by "*" when strict.
func (e posetEdge) String() string {
	if e.Strict() {
		return fmt.Sprint(e.Target()) + "*"
	}
	return fmt.Sprint(e.Target())
}
94
95// posetNode is a node of a DAG within the poset.
96type posetNode struct {
97	l, r posetEdge
98}
99
// poset is a union-find data structure that can represent a partially ordered set
// of SSA values. Given a binary relation that creates a partial order (eg: '<'),
// clients can record relations between SSA values using SetOrder, and later
// check relations (in the transitive closure) with Ordered. For instance,
// if SetOrder is called to record that A<B and B<C, Ordered will later confirm
// that A<C.
//
// It is possible to record equality relations between SSA values with SetEqual and check
// equality with Equal. Equality propagates into the transitive closure for the partial
// order so that if we know that A<B<C and later learn that A==D, Ordered will return
// true for D<C.
//
// It is also possible to record inequality relations between nodes with SetNonEqual;
// non-equality relations are not transitive, but they can still be useful: for instance
// if we know that A<=B and later we learn that A!=B, we can deduce that A<B.
// NonEqual can be used to check whether it is known that the nodes are different, either
// because SetNonEqual was called before, or because we know that they are strictly ordered.
//
// poset will refuse to record new relations that contradict existing relations:
// for instance if A<B<C, calling SetOrder for C<A will fail returning false; also
// calling SetEqual for C==A will fail.
//
// poset is implemented as a forest of DAGs; in each DAG, if there is a path (directed)
// from node A to B, it means that A<B (or A<=B). Equality is represented by mapping
// two SSA values to the same DAG node; when a new equality relation is recorded
// between two existing nodes, the nodes are merged, adjusting incoming and outgoing edges.
//
// Constants are specially treated. When a constant is added to the poset, it is
// immediately linked to other constants already present; so for instance if the
// poset knows that x<=3, and then x is tested against 5, 5 is first added and linked
// to 3 (using 3<5), so that the poset knows that x<=3<5; at that point, it is able
// to answer x<5 correctly. This means that all constants are always within the same
// DAG; as an implementation detail, we enforce that the DAG containing the constants
// is always the first in the forest.
//
// poset is designed to be memory efficient and do few allocations during normal usage.
// Most internal data structures are pre-allocated and flat, so for instance adding a
// new relation does not cause any allocation. For performance reasons,
// each node has only up to two outgoing edges (like a binary tree), so intermediate
// "extra" nodes are required to represent more than two relations. For instance,
// to record that A<I, A<J, A<K (with no known relation between I,J,K), we create the
// following DAG:
//
//	  A
//	 / \
//	I  extra
//	    /  \
//	   J    K
//
// Node index 0 is never a real node: a zero index (or zero posetEdge) means
// "no node" throughout the implementation.
type poset struct {
	lastidx   uint32            // last generated dense index
	flags     uint8             // internal flags
	values    map[ID]uint32     // map SSA values to dense indexes
	constants map[int64]uint32  // record SSA constants together with their value
	nodes     []posetNode       // nodes (in all DAGs)
	roots     []uint32          // list of root nodes (forest)
	noneq     map[uint32]bitset // non-equal relations
	undo      []posetUndo       // undo chain
}
158
159func newPoset() *poset {
160	return &poset{
161		values:    make(map[ID]uint32),
162		constants: make(map[int64]uint32, 8),
163		nodes:     make([]posetNode, 1, 16),
164		roots:     make([]uint32, 0, 4),
165		noneq:     make(map[uint32]bitset),
166		undo:      make([]posetUndo, 0, 4),
167	}
168}
169
170func (po *poset) SetUnsigned(uns bool) {
171	if uns {
172		po.flags |= posetFlagUnsigned
173	} else {
174		po.flags &^= posetFlagUnsigned
175	}
176}
177
178// Handle children
179func (po *poset) setchl(i uint32, l posetEdge) { po.nodes[i].l = l }
180func (po *poset) setchr(i uint32, r posetEdge) { po.nodes[i].r = r }
181func (po *poset) chl(i uint32) uint32          { return po.nodes[i].l.Target() }
182func (po *poset) chr(i uint32) uint32          { return po.nodes[i].r.Target() }
183func (po *poset) children(i uint32) (posetEdge, posetEdge) {
184	return po.nodes[i].l, po.nodes[i].r
185}
186
187// upush records a new undo step. It can be used for simple
188// undo passes that record up to one index and one edge.
189func (po *poset) upush(typ undoType, p uint32, e posetEdge) {
190	po.undo = append(po.undo, posetUndo{typ: typ, idx: p, edge: e})
191}
192
193// upushnew pushes an undo pass for a new node
194func (po *poset) upushnew(id ID, idx uint32) {
195	po.undo = append(po.undo, posetUndo{typ: undoNewNode, ID: id, idx: idx})
196}
197
198// upushneq pushes a new undo pass for a nonequal relation
199func (po *poset) upushneq(idx1 uint32, idx2 uint32) {
200	po.undo = append(po.undo, posetUndo{typ: undoNonEqual, ID: ID(idx1), idx: idx2})
201}
202
203// upushalias pushes a new undo pass for aliasing two nodes
204func (po *poset) upushalias(id ID, i2 uint32) {
205	po.undo = append(po.undo, posetUndo{typ: undoAliasNode, ID: id, idx: i2})
206}
207
208// upushconst pushes a new undo pass for a new constant
209func (po *poset) upushconst(idx uint32, old uint32) {
210	po.undo = append(po.undo, posetUndo{typ: undoNewConstant, idx: idx, ID: ID(old)})
211}
212
// addchild adds i2 as a direct child of i1. The new edge is strict
// (i1 < i2) when strict is true, non-strict (i1 <= i2) otherwise.
// Every modification of a pre-existing node is recorded on the undo chain.
func (po *poset) addchild(i1, i2 uint32, strict bool) {
	i1l, i1r := po.children(i1)
	e2 := newedge(i2, strict)

	if i1l == 0 {
		// Left slot is free; the undo step restores it to the empty edge.
		po.setchl(i1, e2)
		po.upush(undoSetChl, i1, 0)
	} else if i1r == 0 {
		// Right slot is free; the undo step restores it to the empty edge.
		po.setchr(i1, e2)
		po.upush(undoSetChr, i1, 0)
	} else {
		// If i1 already has two children, add an intermediate extra
		// node to record the relation correctly (without relating
		// i2 to other existing nodes). The parity of i1^i2 decides
		// whether the extra node takes the place of i1's right or left
		// child. The choice is deterministic, but it varies across calls,
		// which avoids always growing the same side into degenerate chains.
		//
		// Right-side case (the left-side case is symmetric, with extra
		// replacing i1l and adopting it as its left child):
		//
		//      i1
		//     /  \
		//   i1l  extra
		//        /   \
		//      i1r   i2
		//
		// The edge i1->extra is non-strict; the displaced child keeps its
		// original edge (including its strictness) under extra.
		extra := po.newnode(nil)
		if (i1^i2)&1 != 0 { // pseudo-random side choice (see above)
			po.setchl(extra, i1r)
			po.setchr(extra, e2)
			po.setchr(i1, newedge(extra, false))
			po.upush(undoSetChr, i1, i1r)
		} else {
			po.setchl(extra, i1l)
			po.setchr(extra, e2)
			po.setchl(i1, newedge(extra, false))
			po.upush(undoSetChl, i1, i1l)
		}
	}
}
251
252// newnode allocates a new node bound to SSA value n.
253// If n is nil, this is an extra node (= only used internally).
254func (po *poset) newnode(n *Value) uint32 {
255	i := po.lastidx + 1
256	po.lastidx++
257	po.nodes = append(po.nodes, posetNode{})
258	if n != nil {
259		if po.values[n.ID] != 0 {
260			panic("newnode for Value already inserted")
261		}
262		po.values[n.ID] = i
263		po.upushnew(n.ID, i)
264	} else {
265		po.upushnew(0, i)
266	}
267	return i
268}
269
270// lookup searches for a SSA value into the forest of DAGS, and return its node.
271// Constants are materialized on the fly during lookup.
272func (po *poset) lookup(n *Value) (uint32, bool) {
273	i, f := po.values[n.ID]
274	if !f && n.isGenericIntConst() {
275		po.newconst(n)
276		i, f = po.values[n.ID]
277	}
278	return i, f
279}
280
// newconst creates a node for the constant n and links it to the other
// constants already in the poset, so that, for instance, n<=5 is detected
// true when n<=3 is known to be true. Constants are compared as unsigned
// numbers when posetFlagUnsigned is set, and as signed numbers otherwise.
// If a constant with the same value is already present, n is simply aliased
// to its node. All constants are kept in the DAG of root #0.
// It panics if n is not a generic integer constant.
// TODO: this is O(N), fix it.
func (po *poset) newconst(n *Value) {
	if !n.isGenericIntConst() {
		panic("newconst on non-constant")
	}

	// If the same constant is already present in the poset through a different
	// Value, just alias to it without allocating a new node.
	// In unsigned mode the map key is the unsigned value reinterpreted as int64.
	val := n.AuxInt
	if po.flags&posetFlagUnsigned != 0 {
		val = int64(n.AuxUnsigned())
	}
	if c, found := po.constants[val]; found {
		po.values[n.ID] = c
		po.upushalias(n.ID, 0)
		return
	}

	// Create the new node for this constant
	i := po.newnode(n)

	// If this is the first constant, put it as a new root, as
	// we can't record an existing connection so we don't have
	// a specific DAG to add it to. Notice that we want all
	// constants to be in root #0, so make sure the new root
	// goes there (the previous root #0, if any, is swapped to
	// the end of the list).
	if len(po.constants) == 0 {
		idx := len(po.roots)
		po.roots = append(po.roots, i)
		po.roots[0], po.roots[idx] = po.roots[idx], po.roots[0]
		po.upush(undoNewRoot, i, 0)
		po.constants[val] = i
		po.upushconst(i, 0)
		return
	}

	// Find the lower and upper bound among existing constants. That is,
	// find the highest constant that is lower than the one that we're adding,
	// and the lowest constant that is higher.
	// The loop is duplicated to handle signed and unsigned comparison,
	// depending on how the poset was configured.
	// A pointer of 0 means "no bound found yet" (node 0 is never a real node).
	var lowerptr, higherptr uint32

	if po.flags&posetFlagUnsigned != 0 {
		var lower, higher uint64
		val1 := n.AuxUnsigned()
		for val2, ptr := range po.constants {
			val2 := uint64(val2)
			if val1 == val2 {
				// Equal values were handled by the alias check above.
				panic("unreachable")
			}
			if val2 < val1 && (lowerptr == 0 || val2 > lower) {
				lower = val2
				lowerptr = ptr
			} else if val2 > val1 && (higherptr == 0 || val2 < higher) {
				higher = val2
				higherptr = ptr
			}
		}
	} else {
		var lower, higher int64
		val1 := n.AuxInt
		for val2, ptr := range po.constants {
			if val1 == val2 {
				// Equal values were handled by the alias check above.
				panic("unreachable")
			}
			if val2 < val1 && (lowerptr == 0 || val2 > lower) {
				lower = val2
				lowerptr = ptr
			} else if val2 > val1 && (higherptr == 0 || val2 < higher) {
				higher = val2
				higherptr = ptr
			}
		}
	}

	if lowerptr == 0 && higherptr == 0 {
		// This should not happen, as at least one
		// other constant must exist if we get here.
		panic("no constant found")
	}

	// Connect the new node to the bounds, so that
	// lower < n < higher. We could have found both bounds or only one
	// of them, depending on what other constants are present in the poset.
	// Notice that we always link constants together, so they
	// are always part of the same DAG.
	switch {
	case lowerptr != 0 && higherptr != 0:
		// Both bounds are present, record lower < n < higher.
		po.addchild(lowerptr, i, true)
		po.addchild(i, higherptr, true)

	case lowerptr != 0:
		// Lower bound only, record lower < n.
		po.addchild(lowerptr, i, true)

	case higherptr != 0:
		// Higher bound only. To record n < higher, we need
		// an extra root:
		//
		//        extra
		//        /   \
		//      root   \
		//       /      n
		//     ....    /
		//       \    /
		//       higher
		//
		i2 := higherptr
		r2 := po.findroot(i2)
		if r2 != po.roots[0] { // all constants should be in root #0
			panic("constant not in root #0")
		}
		extra := po.newnode(nil)
		po.changeroot(r2, extra)
		po.upush(undoChangeRoot, extra, newedge(r2, false))
		po.addchild(extra, r2, false)
		po.addchild(extra, i, false)
		po.addchild(i, i2, true)
	}

	po.constants[val] = i
	po.upushconst(i, 0)
}
408
409// aliasnewnode records that a single node n2 (not in the poset yet) is an alias
410// of the master node n1.
411func (po *poset) aliasnewnode(n1, n2 *Value) {
412	i1, i2 := po.values[n1.ID], po.values[n2.ID]
413	if i1 == 0 || i2 != 0 {
414		panic("aliasnewnode invalid arguments")
415	}
416
417	po.values[n2.ID] = i1
418	po.upushalias(n2.ID, 0)
419}
420
// aliasnodes records that all the nodes i2s are aliases of a single master node n1.
// aliasnodes takes care of rearranging the DAG, changing references of parent/children
// of nodes in i2s, so that they point to n1 instead.
// It panics if n1 is not in the poset or if i2s contains n1's own node.
// Complexity is O(n) (with n being the total number of nodes in the poset, not just
// the number of nodes being aliased).
func (po *poset) aliasnodes(n1 *Value, i2s bitset) {
	i1 := po.values[n1.ID]
	if i1 == 0 {
		panic("aliasnode for non-existing node")
	}
	if i2s.Test(i1) {
		panic("aliasnode i2s contains n1 node")
	}

	// Go through all the nodes to adjust parent/children of nodes in i2s.
	// The range expression is evaluated once, so extra nodes appended by
	// addchild during the loop are not visited (they need no adjustment).
	for idx, n := range po.nodes {
		// Do not touch i1 itself, otherwise we can create useless self-loops.
		// Consequently, edges from i1 into i2s are left in place; their
		// targets end up childless and with no values mapped to them.
		if uint32(idx) == i1 {
			continue
		}
		l, r := n.l, n.r

		// Rename all references to i2s into i1 (keeping edge strictness)
		if i2s.Test(l.Target()) {
			po.setchl(uint32(idx), newedge(i1, l.Strict()))
			po.upush(undoSetChl, uint32(idx), l)
		}
		if i2s.Test(r.Target()) {
			po.setchr(uint32(idx), newedge(i1, r.Strict()))
			po.upush(undoSetChr, uint32(idx), r)
		}

		// Connect all children of i2s to i1 (unless those children
		// are in i2s as well, in which case it would be useless),
		// then detach the aliased node from its children.
		if i2s.Test(uint32(idx)) {
			if l != 0 && !i2s.Test(l.Target()) {
				po.addchild(i1, l.Target(), l.Strict())
			}
			if r != 0 && !i2s.Test(r.Target()) {
				po.addchild(i1, r.Target(), r.Strict())
			}
			po.setchl(uint32(idx), 0)
			po.setchr(uint32(idx), 0)
			po.upush(undoSetChl, uint32(idx), l)
			po.upush(undoSetChr, uint32(idx), r)
		}
	}

	// Reassign all existing IDs that point to i2 to i1.
	// This includes n2.ID.
	for k, v := range po.values {
		if i2s.Test(v) {
			po.values[k] = i1
			po.upushalias(k, v)
		}
	}

	// If one of the aliased nodes is a constant, then make sure
	// po.constants is updated to point to the master node
	// (the undo step remembers the previous node index).
	for val, idx := range po.constants {
		if i2s.Test(idx) {
			po.constants[val] = i1
			po.upushconst(i1, idx)
		}
	}
}
487
488func (po *poset) isroot(r uint32) bool {
489	for i := range po.roots {
490		if po.roots[i] == r {
491			return true
492		}
493	}
494	return false
495}
496
497func (po *poset) changeroot(oldr, newr uint32) {
498	for i := range po.roots {
499		if po.roots[i] == oldr {
500			po.roots[i] = newr
501			return
502		}
503	}
504	panic("changeroot on non-root")
505}
506
507func (po *poset) removeroot(r uint32) {
508	for i := range po.roots {
509		if po.roots[i] == r {
510			po.roots = append(po.roots[:i], po.roots[i+1:]...)
511			return
512		}
513	}
514	panic("removeroot on non-root")
515}
516
// dfs performs an iterative depth-first search within the DAG whose root is r.
// f is the visit function called for each node; if it returns true,
// the search is aborted and true is returned. Each node is visited at most once.
// In non-strict mode the root node is visited too.
// If strict, ignore edges across a path until at least one
// strict edge is found. For instance, for a chain A<=B<=C<D<=E<F,
// a strict walk visits D,E,F (the root itself is not visited unless it is
// reachable through a strict edge).
// If the visit ends, false is returned.
// Each call allocates a fresh "closed" bitset and open stack.
func (po *poset) dfs(r uint32, strict bool, f func(i uint32) bool) bool {
	closed := newBitset(int(po.lastidx + 1))
	open := make([]uint32, 1, 64)
	open[0] = r

	if strict {
		// Do a first DFS; walk all paths and stop when we find a strict
		// edge, building a "next" list of nodes reachable through strict
		// edges. This will be the bootstrap open list for the real DFS.
		next := make([]uint32, 0, 64)

		for len(open) > 0 {
			i := open[len(open)-1]
			open = open[:len(open)-1]

			// Don't visit the same node twice. Notice that all nodes
			// across non-strict paths are still visited at least once, so
			// a non-strict path can never obscure a strict path to the
			// same node.
			if !closed.Test(i) {
				closed.Set(i)

				l, r := po.children(i)
				if l != 0 {
					if l.Strict() {
						next = append(next, l.Target())
					} else {
						open = append(open, l.Target())
					}
				}
				if r != 0 {
					if r.Strict() {
						next = append(next, r.Target())
					} else {
						open = append(open, r.Target())
					}
				}
			}
		}
		open = next
		// The second phase starts from scratch: nodes seen while looking
		// for strict edges may still be visited below.
		closed.Reset()
	}

	for len(open) > 0 {
		i := open[len(open)-1]
		open = open[:len(open)-1]

		if !closed.Test(i) {
			if f(i) {
				return true
			}
			closed.Set(i)
			l, r := po.children(i)
			if l != 0 {
				open = append(open, l.Target())
			}
			if r != 0 {
				open = append(open, r.Target())
			}
		}
	}
	return false
}
588
589// Returns true if there is a path from i1 to i2.
590// If strict ==  true: if the function returns true, then i1 <  i2.
591// If strict == false: if the function returns true, then i1 <= i2.
592// If the function returns false, no relation is known.
593func (po *poset) reaches(i1, i2 uint32, strict bool) bool {
594	return po.dfs(i1, strict, func(n uint32) bool {
595		return n == i2
596	})
597}
598
599// findroot finds i's root, that is which DAG contains i.
600// Returns the root; if i is itself a root, it is returned.
601// Panic if i is not in any DAG.
602func (po *poset) findroot(i uint32) uint32 {
603	// TODO(rasky): if needed, a way to speed up this search is
604	// storing a bitset for each root using it as a mini bloom filter
605	// of nodes present under that root.
606	for _, r := range po.roots {
607		if po.reaches(r, i, false) {
608			return r
609		}
610	}
611	panic("findroot didn't find any root")
612}
613
614// mergeroot merges two DAGs into one DAG by creating a new extra root
615func (po *poset) mergeroot(r1, r2 uint32) uint32 {
616	// Root #0 is special as it contains all constants. Since mergeroot
617	// discards r2 as root and keeps r1, make sure that r2 is not root #0,
618	// otherwise constants would move to a different root.
619	if r2 == po.roots[0] {
620		r1, r2 = r2, r1
621	}
622	r := po.newnode(nil)
623	po.setchl(r, newedge(r1, false))
624	po.setchr(r, newedge(r2, false))
625	po.changeroot(r1, r)
626	po.removeroot(r2)
627	po.upush(undoMergeRoot, r, 0)
628	return r
629}
630
631// collapsepath marks n1 and n2 as equal and collapses as equal all
632// nodes across all paths between n1 and n2. If a strict edge is
633// found, the function does not modify the DAG and returns false.
634// Complexity is O(n).
635func (po *poset) collapsepath(n1, n2 *Value) bool {
636	i1, i2 := po.values[n1.ID], po.values[n2.ID]
637	if po.reaches(i1, i2, true) {
638		return false
639	}
640
641	// Find all the paths from i1 to i2
642	paths := po.findpaths(i1, i2)
643	// Mark all nodes in all the paths as aliases of n1
644	// (excluding n1 itself)
645	paths.Clear(i1)
646	po.aliasnodes(n1, paths)
647	return true
648}
649
650// findpaths is a recursive function that calculates all paths from cur to dst
651// and return them as a bitset (the index of a node is set in the bitset if
652// that node is on at least one path from cur to dst).
653// We do a DFS from cur (stopping going deep any time we reach dst, if ever),
654// and mark as part of the paths any node that has a children which is already
655// part of the path (or is dst itself).
656func (po *poset) findpaths(cur, dst uint32) bitset {
657	seen := newBitset(int(po.lastidx + 1))
658	path := newBitset(int(po.lastidx + 1))
659	path.Set(dst)
660	po.findpaths1(cur, dst, seen, path)
661	return path
662}
663
664func (po *poset) findpaths1(cur, dst uint32, seen bitset, path bitset) {
665	if cur == dst {
666		return
667	}
668	seen.Set(cur)
669	l, r := po.chl(cur), po.chr(cur)
670	if !seen.Test(l) {
671		po.findpaths1(l, dst, seen, path)
672	}
673	if !seen.Test(r) {
674		po.findpaths1(r, dst, seen, path)
675	}
676	if path.Test(l) || path.Test(r) {
677		path.Set(cur)
678	}
679}
680
681// Check whether it is recorded that i1!=i2
682func (po *poset) isnoneq(i1, i2 uint32) bool {
683	if i1 == i2 {
684		return false
685	}
686	if i1 < i2 {
687		i1, i2 = i2, i1
688	}
689
690	// Check if we recorded a non-equal relation before
691	if bs, ok := po.noneq[i1]; ok && bs.Test(i2) {
692		return true
693	}
694	return false
695}
696
697// Record that i1!=i2
698func (po *poset) setnoneq(n1, n2 *Value) {
699	i1, f1 := po.lookup(n1)
700	i2, f2 := po.lookup(n2)
701
702	// If any of the nodes do not exist in the poset, allocate them. Since
703	// we don't know any relation (in the partial order) about them, they must
704	// become independent roots.
705	if !f1 {
706		i1 = po.newnode(n1)
707		po.roots = append(po.roots, i1)
708		po.upush(undoNewRoot, i1, 0)
709	}
710	if !f2 {
711		i2 = po.newnode(n2)
712		po.roots = append(po.roots, i2)
713		po.upush(undoNewRoot, i2, 0)
714	}
715
716	if i1 == i2 {
717		panic("setnoneq on same node")
718	}
719	if i1 < i2 {
720		i1, i2 = i2, i1
721	}
722	bs := po.noneq[i1]
723	if bs == nil {
724		// Given that we record non-equality relations using the
725		// higher index as a key, the bitsize will never change size.
726		// TODO(rasky): if memory is a problem, consider allocating
727		// a small bitset and lazily grow it when higher indices arrive.
728		bs = newBitset(int(i1))
729		po.noneq[i1] = bs
730	} else if bs.Test(i2) {
731		// Already recorded
732		return
733	}
734	bs.Set(i2)
735	po.upushneq(i1, i2)
736}
737
738// CheckIntegrity verifies internal integrity of a poset. It is intended
739// for debugging purposes.
740func (po *poset) CheckIntegrity() {
741	// Record which index is a constant
742	constants := newBitset(int(po.lastidx + 1))
743	for _, c := range po.constants {
744		constants.Set(c)
745	}
746
747	// Verify that each node appears in a single DAG, and that
748	// all constants are within the first DAG
749	seen := newBitset(int(po.lastidx + 1))
750	for ridx, r := range po.roots {
751		if r == 0 {
752			panic("empty root")
753		}
754
755		po.dfs(r, false, func(i uint32) bool {
756			if seen.Test(i) {
757				panic("duplicate node")
758			}
759			seen.Set(i)
760			if constants.Test(i) {
761				if ridx != 0 {
762					panic("constants not in the first DAG")
763				}
764			}
765			return false
766		})
767	}
768
769	// Verify that values contain the minimum set
770	for id, idx := range po.values {
771		if !seen.Test(idx) {
772			panic(fmt.Errorf("spurious value [%d]=%d", id, idx))
773		}
774	}
775
776	// Verify that only existing nodes have non-zero children
777	for i, n := range po.nodes {
778		if n.l|n.r != 0 {
779			if !seen.Test(uint32(i)) {
780				panic(fmt.Errorf("children of unknown node %d->%v", i, n))
781			}
782			if n.l.Target() == uint32(i) || n.r.Target() == uint32(i) {
783				panic(fmt.Errorf("self-loop on node %d", i))
784			}
785		}
786	}
787}
788
789// CheckEmpty checks that a poset is completely empty.
790// It can be used for debugging purposes, as a poset is supposed to
791// be empty after it's fully rolled back through Undo.
792func (po *poset) CheckEmpty() error {
793	if len(po.nodes) != 1 {
794		return fmt.Errorf("non-empty nodes list: %v", po.nodes)
795	}
796	if len(po.values) != 0 {
797		return fmt.Errorf("non-empty value map: %v", po.values)
798	}
799	if len(po.roots) != 0 {
800		return fmt.Errorf("non-empty root list: %v", po.roots)
801	}
802	if len(po.constants) != 0 {
803		return fmt.Errorf("non-empty constants: %v", po.constants)
804	}
805	if len(po.undo) != 0 {
806		return fmt.Errorf("non-empty undo list: %v", po.undo)
807	}
808	if po.lastidx != 0 {
809		return fmt.Errorf("lastidx index is not zero: %v", po.lastidx)
810	}
811	for _, bs := range po.noneq {
812		for _, x := range bs {
813			if x != 0 {
814				return fmt.Errorf("non-empty noneq map")
815			}
816		}
817	}
818	return nil
819}
820
821// DotDump dumps the poset in graphviz format to file fn, with the specified title.
822func (po *poset) DotDump(fn string, title string) error {
823	f, err := os.Create(fn)
824	if err != nil {
825		return err
826	}
827	defer f.Close()
828
829	// Create reverse index mapping (taking aliases into account)
830	names := make(map[uint32]string)
831	for id, i := range po.values {
832		s := names[i]
833		if s == "" {
834			s = fmt.Sprintf("v%d", id)
835		} else {
836			s += fmt.Sprintf(", v%d", id)
837		}
838		names[i] = s
839	}
840
841	// Create reverse constant mapping
842	consts := make(map[uint32]int64)
843	for val, idx := range po.constants {
844		consts[idx] = val
845	}
846
847	fmt.Fprintf(f, "digraph poset {\n")
848	fmt.Fprintf(f, "\tedge [ fontsize=10 ]\n")
849	for ridx, r := range po.roots {
850		fmt.Fprintf(f, "\tsubgraph root%d {\n", ridx)
851		po.dfs(r, false, func(i uint32) bool {
852			if val, ok := consts[i]; ok {
853				// Constant
854				var vals string
855				if po.flags&posetFlagUnsigned != 0 {
856					vals = fmt.Sprint(uint64(val))
857				} else {
858					vals = fmt.Sprint(int64(val))
859				}
860				fmt.Fprintf(f, "\t\tnode%d [shape=box style=filled fillcolor=cadetblue1 label=<%s <font point-size=\"6\">%s [%d]</font>>]\n",
861					i, vals, names[i], i)
862			} else {
863				// Normal SSA value
864				fmt.Fprintf(f, "\t\tnode%d [label=<%s <font point-size=\"6\">[%d]</font>>]\n", i, names[i], i)
865			}
866			chl, chr := po.children(i)
867			for _, ch := range []posetEdge{chl, chr} {
868				if ch != 0 {
869					if ch.Strict() {
870						fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <\" color=\"red\"]\n", i, ch.Target())
871					} else {
872						fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <=\" color=\"green\"]\n", i, ch.Target())
873					}
874				}
875			}
876			return false
877		})
878		fmt.Fprintf(f, "\t}\n")
879	}
880	fmt.Fprintf(f, "\tlabelloc=\"t\"\n")
881	fmt.Fprintf(f, "\tlabeldistance=\"3.0\"\n")
882	fmt.Fprintf(f, "\tlabel=%q\n", title)
883	fmt.Fprintf(f, "}\n")
884	return nil
885}
886
887// Ordered reports whether n1<n2. It returns false either when it is
888// certain that n1<n2 is false, or if there is not enough information
889// to tell.
890// Complexity is O(n).
891func (po *poset) Ordered(n1, n2 *Value) bool {
892	if debugPoset {
893		defer po.CheckIntegrity()
894	}
895	if n1.ID == n2.ID {
896		panic("should not call Ordered with n1==n2")
897	}
898
899	i1, f1 := po.lookup(n1)
900	i2, f2 := po.lookup(n2)
901	if !f1 || !f2 {
902		return false
903	}
904
905	return i1 != i2 && po.reaches(i1, i2, true)
906}
907
908// OrderedOrEqual reports whether n1<=n2. It returns false either when it is
909// certain that n1<=n2 is false, or if there is not enough information
910// to tell.
911// Complexity is O(n).
912func (po *poset) OrderedOrEqual(n1, n2 *Value) bool {
913	if debugPoset {
914		defer po.CheckIntegrity()
915	}
916	if n1.ID == n2.ID {
917		panic("should not call Ordered with n1==n2")
918	}
919
920	i1, f1 := po.lookup(n1)
921	i2, f2 := po.lookup(n2)
922	if !f1 || !f2 {
923		return false
924	}
925
926	return i1 == i2 || po.reaches(i1, i2, false)
927}
928
929// Equal reports whether n1==n2. It returns false either when it is
930// certain that n1==n2 is false, or if there is not enough information
931// to tell.
932// Complexity is O(1).
933func (po *poset) Equal(n1, n2 *Value) bool {
934	if debugPoset {
935		defer po.CheckIntegrity()
936	}
937	if n1.ID == n2.ID {
938		panic("should not call Equal with n1==n2")
939	}
940
941	i1, f1 := po.lookup(n1)
942	i2, f2 := po.lookup(n2)
943	return f1 && f2 && i1 == i2
944}
945
946// NonEqual reports whether n1!=n2. It returns false either when it is
947// certain that n1!=n2 is false, or if there is not enough information
948// to tell.
949// Complexity is O(n) (because it internally calls Ordered to see if we
950// can infer n1!=n2 from n1<n2 or n2<n1).
951func (po *poset) NonEqual(n1, n2 *Value) bool {
952	if debugPoset {
953		defer po.CheckIntegrity()
954	}
955	if n1.ID == n2.ID {
956		panic("should not call NonEqual with n1==n2")
957	}
958
959	// If we never saw the nodes before, we don't
960	// have a recorded non-equality.
961	i1, f1 := po.lookup(n1)
962	i2, f2 := po.lookup(n2)
963	if !f1 || !f2 {
964		return false
965	}
966
967	// Check if we recorded inequality
968	if po.isnoneq(i1, i2) {
969		return true
970	}
971
972	// Check if n1<n2 or n2<n1, in which case we can infer that n1!=n2
973	if po.Ordered(n1, n2) || po.Ordered(n2, n1) {
974		return true
975	}
976
977	return false
978}
979
// setOrder records that n1<n2 (when strict is true) or n1<=n2 (when strict
// is false). It returns false if the new relation contradicts what the
// poset already knows, and true otherwise.
// Implements SetOrder() and SetOrderOrEqual(); both of those reject
// n1.ID==n2.ID before calling it.
func (po *poset) setOrder(n1, n2 *Value, strict bool) bool {
	i1, f1 := po.lookup(n1)
	i2, f2 := po.lookup(n2)

	switch {
	case !f1 && !f2:
		// Neither n1 nor n2 are in the poset, so they are not related
		// in any way to existing nodes.
		// Create a new DAG to record the relation.
		i1, i2 = po.newnode(n1), po.newnode(n2)
		po.roots = append(po.roots, i1)
		// Undoing this pass removes i1 from the root list again.
		po.upush(undoNewRoot, i1, 0)
		po.addchild(i1, i2, strict)

	case f1 && !f2:
		// n1 is in one of the DAGs, while n2 is not. Add n2 as children
		// of n1.
		i2 = po.newnode(n2)
		po.addchild(i1, i2, strict)

	case !f1 && f2:
		// n1 is not in any DAG but n2 is. If n2 is a root, we can put
		// n1 in its place as a root; otherwise, we need to create a new
		// extra root to record the relation.
		i1 = po.newnode(n1)

		if po.isroot(i2) {
			po.changeroot(i2, i1)
			// Undoing this pass puts i2 back in i1's slot of the root list.
			po.upush(undoChangeRoot, i1, newedge(i2, strict))
			po.addchild(i1, i2, strict)
			return true
		}

		// Search for i2's root; this requires a O(n) search on all
		// DAGs
		r := po.findroot(i2)

		// Re-parent as follows:
		//
		//                  extra
		//     r            /   \
		//      \   ===>   r    i1
		//      i2          \   /
		//                    i2
		//
		extra := po.newnode(nil)
		po.changeroot(r, extra)
		// Undoing this pass puts r back in extra's slot of the root list.
		po.upush(undoChangeRoot, extra, newedge(r, false))
		// The edges from extra are non-strict: extra carries no value and
		// only serves to join r and i1 under a single root.
		po.addchild(extra, r, false)
		po.addchild(extra, i1, false)
		po.addchild(i1, i2, strict)

	case f1 && f2:
		// If the nodes are aliased, fail only if we're setting a strict order
		// (that is, we cannot set n1<n2 if n1==n2).
		if i1 == i2 {
			return !strict
		}

		// If we are trying to record n1<=n2 but we learned that n1!=n2,
		// record n1<n2, as it provides more information.
		if !strict && po.isnoneq(i1, i2) {
			strict = true
		}

		// Both n1 and n2 are in the poset. This is the complex part of the algorithm
		// as we need to find many different cases and DAG shapes.

		// Check if n1 somehow reaches n2
		if po.reaches(i1, i2, false) {
			// This is the table of all cases we need to handle:
			//
			//      DAG          New      Action
			//      ---------------------------------------------------
			// #1:  N1<=X<=N2 |  N1<=N2 | do nothing
			// #2:  N1<=X<=N2 |  N1<N2  | add strict edge (N1<N2)
			// #3:  N1<X<N2   |  N1<=N2 | do nothing (we already know more)
			// #4:  N1<X<N2   |  N1<N2  | do nothing

			// Check if we're in case #2
			if strict && !po.reaches(i1, i2, true) {
				po.addchild(i1, i2, true)
				return true
			}

			// Case #1, #3, or #4: nothing to do
			return true
		}

		// Check if n2 somehow reaches n1
		if po.reaches(i2, i1, false) {
			// This is the table of all cases we need to handle:
			//
			//      DAG           New      Action
			//      ---------------------------------------------------
			// #5:  N2<=X<=N1  |  N1<=N2 | collapse path (learn that N1=X=N2)
			// #6:  N2<=X<=N1  |  N1<N2  | contradiction
			// #7:  N2<X<N1    |  N1<=N2 | contradiction in the path
			// #8:  N2<X<N1    |  N1<N2  | contradiction

			if strict {
				// Cases #6 and #8: contradiction
				return false
			}

			// We're in case #5 or #7. Try to collapse path, and that will
			// fail if it realizes that we are in case #7.
			return po.collapsepath(n2, n1)
		}

		// We don't know of any existing relation between n1 and n2. They could
		// be part of the same DAG or not.
		// Find their roots to check whether they are in the same DAG.
		r1, r2 := po.findroot(i1), po.findroot(i2)
		if r1 != r2 {
			// We need to merge the two DAGs to record a relation between the nodes
			po.mergeroot(r1, r2)
		}

		// Connect n1 and n2
		po.addchild(i1, i2, strict)
	}

	return true
}
1108
1109// SetOrder records that n1<n2. Returns false if this is a contradiction
1110// Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
1111func (po *poset) SetOrder(n1, n2 *Value) bool {
1112	if debugPoset {
1113		defer po.CheckIntegrity()
1114	}
1115	if n1.ID == n2.ID {
1116		panic("should not call SetOrder with n1==n2")
1117	}
1118	return po.setOrder(n1, n2, true)
1119}
1120
1121// SetOrderOrEqual records that n1<=n2. Returns false if this is a contradiction
1122// Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
1123func (po *poset) SetOrderOrEqual(n1, n2 *Value) bool {
1124	if debugPoset {
1125		defer po.CheckIntegrity()
1126	}
1127	if n1.ID == n2.ID {
1128		panic("should not call SetOrder with n1==n2")
1129	}
1130	return po.setOrder(n1, n2, false)
1131}
1132
1133// SetEqual records that n1==n2. Returns false if this is a contradiction
1134// (that is, if it is already recorded that n1<n2 or n2<n1).
1135// Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
1136func (po *poset) SetEqual(n1, n2 *Value) bool {
1137	if debugPoset {
1138		defer po.CheckIntegrity()
1139	}
1140	if n1.ID == n2.ID {
1141		panic("should not call Add with n1==n2")
1142	}
1143
1144	i1, f1 := po.lookup(n1)
1145	i2, f2 := po.lookup(n2)
1146
1147	switch {
1148	case !f1 && !f2:
1149		i1 = po.newnode(n1)
1150		po.roots = append(po.roots, i1)
1151		po.upush(undoNewRoot, i1, 0)
1152		po.aliasnewnode(n1, n2)
1153	case f1 && !f2:
1154		po.aliasnewnode(n1, n2)
1155	case !f1 && f2:
1156		po.aliasnewnode(n2, n1)
1157	case f1 && f2:
1158		if i1 == i2 {
1159			// Already aliased, ignore
1160			return true
1161		}
1162
1163		// If we recorded that n1!=n2, this is a contradiction.
1164		if po.isnoneq(i1, i2) {
1165			return false
1166		}
1167
1168		// If we already knew that n1<=n2, we can collapse the path to
1169		// record n1==n2 (and vice versa).
1170		if po.reaches(i1, i2, false) {
1171			return po.collapsepath(n1, n2)
1172		}
1173		if po.reaches(i2, i1, false) {
1174			return po.collapsepath(n2, n1)
1175		}
1176
1177		r1 := po.findroot(i1)
1178		r2 := po.findroot(i2)
1179		if r1 != r2 {
1180			// Merge the two DAGs so we can record relations between the nodes
1181			po.mergeroot(r1, r2)
1182		}
1183
1184		// Set n2 as alias of n1. This will also update all the references
1185		// to n2 to become references to n1
1186		i2s := newBitset(int(po.lastidx) + 1)
1187		i2s.Set(i2)
1188		po.aliasnodes(n1, i2s)
1189	}
1190	return true
1191}
1192
1193// SetNonEqual records that n1!=n2. Returns false if this is a contradiction
1194// (that is, if it is already recorded that n1==n2).
1195// Complexity is O(n).
1196func (po *poset) SetNonEqual(n1, n2 *Value) bool {
1197	if debugPoset {
1198		defer po.CheckIntegrity()
1199	}
1200	if n1.ID == n2.ID {
1201		panic("should not call SetNonEqual with n1==n2")
1202	}
1203
1204	// Check whether the nodes are already in the poset
1205	i1, f1 := po.lookup(n1)
1206	i2, f2 := po.lookup(n2)
1207
1208	// If either node wasn't present, we just record the new relation
1209	// and exit.
1210	if !f1 || !f2 {
1211		po.setnoneq(n1, n2)
1212		return true
1213	}
1214
1215	// See if we already know this, in which case there's nothing to do.
1216	if po.isnoneq(i1, i2) {
1217		return true
1218	}
1219
1220	// Check if we're contradicting an existing equality relation
1221	if po.Equal(n1, n2) {
1222		return false
1223	}
1224
1225	// Record non-equality
1226	po.setnoneq(n1, n2)
1227
1228	// If we know that i1<=i2 but not i1<i2, learn that as we
1229	// now know that they are not equal. Do the same for i2<=i1.
1230	// Do this check only if both nodes were already in the DAG,
1231	// otherwise there cannot be an existing relation.
1232	if po.reaches(i1, i2, false) && !po.reaches(i1, i2, true) {
1233		po.addchild(i1, i2, true)
1234	}
1235	if po.reaches(i2, i1, false) && !po.reaches(i2, i1, true) {
1236		po.addchild(i2, i1, true)
1237	}
1238
1239	return true
1240}
1241
1242// Checkpoint saves the current state of the DAG so that it's possible
1243// to later undo this state.
1244// Complexity is O(1).
1245func (po *poset) Checkpoint() {
1246	po.undo = append(po.undo, posetUndo{typ: undoCheckpoint})
1247}
1248
// Undo restores the state of the poset to the previous checkpoint.
// Complexity depends on the type of operations that were performed
// since the last checkpoint; each Set* operation creates an undo
// pass which Undo has to revert with a worst-case complexity of O(n).
// It panics if the undo stack is empty. The checkpoint marker itself is
// consumed. If the stack runs out without meeting a checkpoint, every
// recorded pass has been reverted. In debug mode the poset must then be empty.
func (po *poset) Undo() {
	if len(po.undo) == 0 {
		panic("empty undo stack")
	}
	if debugPoset {
		defer po.CheckIntegrity()
	}

	// Pop passes in LIFO order, so that mutations are reverted in the
	// reverse order in which they were applied.
	for len(po.undo) > 0 {
		pass := po.undo[len(po.undo)-1]
		po.undo = po.undo[:len(po.undo)-1]

		switch pass.typ {
		case undoCheckpoint:
			// Reached (and popped) the most recent checkpoint: stop here.
			return

		case undoSetChl:
			po.setchl(pass.idx, pass.edge)

		case undoSetChr:
			po.setchr(pass.idx, pass.edge)

		case undoNonEqual:
			po.noneq[uint32(pass.ID)].Clear(pass.idx)

		case undoNewNode:
			// Nodes are appended at the end, and passes are reverted in
			// reverse order, so the node being removed must be the last one.
			if pass.idx != po.lastidx {
				panic("invalid newnode index")
			}
			// NOTE(review): a zero ID presumably marks a node created
			// without an SSA value (e.g. newnode(nil) in setOrder).
			if pass.ID != 0 {
				if po.values[pass.ID] != pass.idx {
					panic("invalid newnode undo pass")
				}
				delete(po.values, pass.ID)
			}
			po.setchl(pass.idx, 0)
			po.setchr(pass.idx, 0)
			po.nodes = po.nodes[:pass.idx]
			po.lastidx--

		case undoNewConstant:
			// Find the constant value whose node index is pass.idx.
			// If no entry matches, i ends up holding the index from the
			// last iteration (or 0 for an empty map).
			// FIXME: remove this O(n) loop
			var val int64
			var i uint32
			for val, i = range po.constants {
				if i == pass.idx {
					break
				}
			}
			if i != pass.idx {
				panic("constant not found in undo pass")
			}
			// For this pass kind, the ID field is reused to hold the node
			// index previously associated with the constant (0 if none).
			if pass.ID == 0 {
				delete(po.constants, val)
			} else {
				// Restore previous index as constant node
				// (also restoring the invariant on correct bounds)
				oldidx := uint32(pass.ID)
				po.constants[val] = oldidx
			}

		case undoAliasNode:
			// pass.idx is the node the value mapped to before being
			// aliased (0 if it was not in the poset at all).
			ID, prev := pass.ID, pass.idx
			cur := po.values[ID]
			if prev == 0 {
				// Born as an alias, die as an alias
				delete(po.values, ID)
			} else {
				if cur == prev {
					panic("invalid aliasnode undo pass")
				}
				// Give it back previous value
				po.values[ID] = prev
			}

		case undoNewRoot:
			// By the time this pass is reverted, all edges added below the
			// root have already been undone, so it must be childless.
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo newroot")
			}
			po.removeroot(i)

		case undoChangeRoot:
			// Put the previous root (stored as the edge target) back in
			// the slot of i in the root list.
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo changeroot")
			}
			po.changeroot(i, pass.edge.Target())

		case undoMergeRoot:
			// i's left child takes i's slot in the root list; its right
			// child is appended as a separate root again.
			i := pass.idx
			l, r := po.children(i)
			po.changeroot(i, l.Target())
			po.roots = append(po.roots, r.Target())

		default:
			panic(pass.typ)
		}
	}

	if debugPoset && po.CheckEmpty() != nil {
		panic("poset not empty at the end of undo")
	}
}
1359