// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package concurrent

import (
	"internal/abi"
	"internal/goarch"
	"math/rand/v2"
	"sync"
	"sync/atomic"
	"unsafe"
)

// HashTrieMap is an implementation of a concurrent hash-trie. The implementation
// is designed around frequent loads, but offers decent performance for stores
// and deletes as well, especially if the map is larger. Its primary use case is
// the unique package, but it can be used elsewhere as well.
type HashTrieMap[K, V comparable] struct {
	root     *indirect[K, V]
	keyHash  hashFunc
	keyEqual equalFunc
	valEqual equalFunc
	seed     uintptr
}

// NewHashTrieMap creates a new HashTrieMap for the provided key and value types.
func NewHashTrieMap[K, V comparable]() *HashTrieMap[K, V] {
	var m map[K]V
	mapType := abi.TypeOf(m).MapType()
	ht := &HashTrieMap[K, V]{
		root:     newIndirectNode[K, V](nil),
		keyHash:  mapType.Hasher,
		keyEqual: mapType.Key.Equal,
		valEqual: mapType.Elem.Equal,
		seed:     uintptr(rand.Uint64()),
	}
	return ht
}
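
// exampleUsage is an illustrative sketch, not part of the package API: it
// shows typical use of a HashTrieMap. The function name is hypothetical and
// exists only for illustration.
func exampleUsage() {
	m := NewHashTrieMap[string, int]()
	v, loaded := m.LoadOrStore("answer", 42)   // stores 42; loaded == false
	v, ok := m.Load("answer")                  // v == 42, ok == true
	deleted := m.CompareAndDelete("answer", v) // true: key present with value 42
	_, _, _ = loaded, ok, deleted
}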

type hashFunc func(unsafe.Pointer, uintptr) uintptr
type equalFunc func(unsafe.Pointer, unsafe.Pointer) bool

// Load returns the value stored in the map for a key, or the zero value if no
// value is present.
// The ok result indicates whether value was found in the map.
func (ht *HashTrieMap[K, V]) Load(key K) (value V, ok bool) {
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)

	i := ht.root
	hashShift := 8 * goarch.PtrSize
	for hashShift != 0 {
		hashShift -= nChildrenLog2

		n := i.children[(hash>>hashShift)&nChildrenMask].Load()
		if n == nil {
			return *new(V), false
		}
		if n.isEntry {
			return n.entry().lookup(key, ht.keyEqual)
		}
		i = n.indirect()
	}
	panic("internal/concurrent.HashTrieMap: ran out of hash bits while iterating")
}
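
// childIndex is an illustrative helper, hypothetical and unused by the
// implementation above: it shows how Load consumes the hash from the most
// significant bits downward, slicing out nChildrenLog2 bits per level to
// pick a child slot.
func childIndex(hash uintptr, hashShift uint) uintptr {
	return (hash >> hashShift) & nChildrenMask
}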

// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (ht *HashTrieMap[K, V]) LoadOrStore(key K, value V) (result V, loaded bool) {
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
	var i *indirect[K, V]
	var hashShift uint
	var slot *atomic.Pointer[node[K, V]]
	var n *node[K, V]
	for {
		// Find the key or a candidate location for insertion.
		i = ht.root
		hashShift = 8 * goarch.PtrSize
		haveInsertPoint := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil {
				// We found a nil slot which is a candidate for insertion.
				haveInsertPoint = true
				break
			}
			if n.isEntry {
				// We found an existing entry, which is as far as we can go.
				// If it stays this way, we'll have to replace it with an
				// indirect node.
				if v, ok := n.entry().lookup(key, ht.keyEqual); ok {
					return v, true
				}
				haveInsertPoint = true
				break
			}
			i = n.indirect()
		}
		if !haveInsertPoint {
			panic("internal/concurrent.HashTrieMap: ran out of hash bits while iterating")
		}

		// Grab the lock and double-check what we saw.
		i.mu.Lock()
		n = slot.Load()
		if (n == nil || n.isEntry) && !i.dead.Load() {
			// What we saw is still true, so we can continue with the insert.
			break
		}
		// We have to start over.
		i.mu.Unlock()
	}
	// N.B. This lock is held from when we broke out of the outer loop above.
	// We specifically break this out so that we can use defer here safely.
	// One option is to break this out into a new function instead, but
	// there's so much local iteration state used below that this turns out
	// to be cleaner.
	defer i.mu.Unlock()

	var oldEntry *entry[K, V]
	if n != nil {
		oldEntry = n.entry()
		if v, ok := oldEntry.lookup(key, ht.keyEqual); ok {
			// Easy case: by loading again, it turns out exactly what we wanted is here!
			return v, true
		}
	}
	newEntry := newEntryNode(key, value)
	if oldEntry == nil {
		// Easy case: create a new entry and store it.
		slot.Store(&newEntry.node)
	} else {
		// We possibly need to expand the entry already there into one or more new nodes.
		//
		// Publish the node last, which will make both oldEntry and newEntry visible. We
		// don't want readers to be able to observe that oldEntry isn't in the tree.
		slot.Store(ht.expand(oldEntry, newEntry, hash, hashShift, i))
	}
	return value, false
}
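
// exampleLoadOrStore is an illustrative sketch, hypothetical and not part of
// the package: concurrent LoadOrStore calls for the same key agree on a single
// winner, and every caller observes the value that winner stored.
func exampleLoadOrStore(m *HashTrieMap[string, int]) {
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			v, _ := m.LoadOrStore("key", i)
			_ = v // Every goroutine sees the same v: whichever store won.
		}(i)
	}
	wg.Wait()
}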

// expand takes oldEntry and newEntry whose hashes conflict from the most
// significant bit down to hashShift and produces a subtree of indirect nodes
// to hold the two new entries.
func (ht *HashTrieMap[K, V]) expand(oldEntry, newEntry *entry[K, V], newHash uintptr, hashShift uint, parent *indirect[K, V]) *node[K, V] {
	// Check for a hash collision.
	oldHash := ht.keyHash(unsafe.Pointer(&oldEntry.key), ht.seed)
	if oldHash == newHash {
		// Store the old entry in the new entry's overflow list, then store
		// the new entry.
		newEntry.overflow.Store(oldEntry)
		return &newEntry.node
	}
	// We have to add an indirect node. Worse still, we may need to add more than one.
	newIndirect := newIndirectNode(parent)
	top := newIndirect
	for {
		if hashShift == 0 {
			panic("internal/concurrent.HashTrieMap: ran out of hash bits while inserting")
		}
		hashShift -= nChildrenLog2 // hashShift is for the level parent is at. We need to go deeper.
		oi := (oldHash >> hashShift) & nChildrenMask
		ni := (newHash >> hashShift) & nChildrenMask
		if oi != ni {
			newIndirect.children[oi].Store(&oldEntry.node)
			newIndirect.children[ni].Store(&newEntry.node)
			break
		}
		nextIndirect := newIndirectNode(newIndirect)
		newIndirect.children[oi].Store(&nextIndirect.node)
		newIndirect = nextIndirect
	}
	return &top.node
}
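
// exampleExpand is an illustrative sketch, hypothetical and not part of the
// package, assuming 64-bit hashes: two hashes that agree in their top two
// 4-bit slices but differ in the third make expand build two nested indirect
// nodes before the entries land in distinct child slots.
func exampleExpand() {
	const oldHash, newHash uint64 = 0xAB1 << 52, 0xAB2 << 52
	for shift := 60; shift >= 52; shift -= nChildrenLog2 {
		oi := (oldHash >> shift) & nChildrenMask
		ni := (newHash >> shift) & nChildrenMask
		_, _ = oi, ni // 0xA==0xA, then 0xB==0xB, then 0x1 != 0x2: split at depth 2.
	}
}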

// CompareAndDelete deletes the entry for key if its value is equal to old.
//
// If there is no current value for key in the map, CompareAndDelete returns false
// (even if old is the zero value).
func (ht *HashTrieMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
	var i *indirect[K, V]
	var hashShift uint
	var slot *atomic.Pointer[node[K, V]]
	var n *node[K, V]
	for {
		// Find the key or return when there's nothing to delete.
		i = ht.root
		hashShift = 8 * goarch.PtrSize
		found := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil {
				// Nothing to delete. Give up.
				return
			}
			if n.isEntry {
				// We found an entry. Check if it matches.
				if _, ok := n.entry().lookup(key, ht.keyEqual); !ok {
					// No match, nothing to delete.
					return
				}
				// We've got something to delete.
				found = true
				break
			}
			i = n.indirect()
		}
		if !found {
			panic("internal/concurrent.HashTrieMap: ran out of hash bits while iterating")
		}

		// Grab the lock and double-check what we saw.
		i.mu.Lock()
		n = slot.Load()
		if !i.dead.Load() {
			if n == nil {
				// Valid node that doesn't contain what we need. Nothing to delete.
				i.mu.Unlock()
				return
			}
			if n.isEntry {
				// What we saw is still true, so we can continue with the delete.
				break
			}
		}
		// We have to start over.
		i.mu.Unlock()
	}
	// Try to delete the entry.
	e, deleted := n.entry().compareAndDelete(key, old, ht.keyEqual, ht.valEqual)
	if !deleted {
		// Nothing was actually deleted, which means the node is no longer there.
		i.mu.Unlock()
		return false
	}
	if e != nil {
		// We didn't actually delete the whole entry, just one entry in the chain.
		// Nothing else to do, since the parent is definitely not empty.
		slot.Store(&e.node)
		i.mu.Unlock()
		return true
	}
	// Delete the entry.
	slot.Store(nil)

	// Check if the node is now empty (and isn't the root), and delete it if able.
	for i.parent != nil && i.empty() {
		if hashShift == 8*goarch.PtrSize {
			panic("internal/concurrent.HashTrieMap: ran out of hash bits while iterating")
		}
		hashShift += nChildrenLog2

		// Delete the current node in the parent.
		parent := i.parent
		parent.mu.Lock()
		i.dead.Store(true)
		parent.children[(hash>>hashShift)&nChildrenMask].Store(nil)
		i.mu.Unlock()
		i = parent
	}
	i.mu.Unlock()
	return true
}

// All returns an iter.Seq2 that produces all key-value pairs in the map.
// The enumeration does not represent any consistent snapshot of the map,
// but is guaranteed to visit each unique key-value pair only once. It is
// safe to operate on the tree during iteration. No particular enumeration
// order is guaranteed.
func (ht *HashTrieMap[K, V]) All() func(yield func(K, V) bool) {
	return func(yield func(key K, value V) bool) {
		ht.iter(ht.root, yield)
	}
}
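
// exampleIterate is an illustrative sketch, hypothetical and not part of the
// package: because All returns a function with the iter.Seq2 shape, it can be
// ranged over directly with Go 1.23+ range-over-func.
func exampleIterate(m *HashTrieMap[string, int]) (sum int) {
	for _, v := range m.All() {
		sum += v
	}
	return sum
}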

func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V) bool) bool {
	for j := range i.children {
		n := i.children[j].Load()
		if n == nil {
			continue
		}
		if !n.isEntry {
			if !ht.iter(n.indirect(), yield) {
				return false
			}
			continue
		}
		e := n.entry()
		for e != nil {
			if !yield(e.key, e.value) {
				return false
			}
			e = e.overflow.Load()
		}
	}
	return true
}

const (
	// 16 children. This seems to be the sweet spot for
	// load performance: any smaller and we lose out on
	// 50% or more in CPU performance. Any larger and the
	// returns are minuscule (~1% improvement for 32 children).
	nChildrenLog2 = 4
	nChildren     = 1 << nChildrenLog2
	nChildrenMask = nChildren - 1
)
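
// Illustrative note on depth, following from the constants above: with
// nChildrenLog2 = 4 bits consumed per level, a 64-bit hash supports
// 64/4 = 16 levels of indirect nodes and a 32-bit hash supports 8; the
// "ran out of hash bits" panics above fire only if that budget is somehow
// exceeded, which would indicate a bug.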

// indirect is an internal node in the hash-trie.
type indirect[K, V comparable] struct {
	node[K, V]
	dead     atomic.Bool
	mu       sync.Mutex // Protects mutation to children and any children that are entry nodes.
	parent   *indirect[K, V]
	children [nChildren]atomic.Pointer[node[K, V]]
}

func newIndirectNode[K, V comparable](parent *indirect[K, V]) *indirect[K, V] {
	return &indirect[K, V]{node: node[K, V]{isEntry: false}, parent: parent}
}

func (i *indirect[K, V]) empty() bool {
	nc := 0
	for j := range i.children {
		if i.children[j].Load() != nil {
			nc++
		}
	}
	return nc == 0
}

// entry is a leaf node in the hash-trie.
type entry[K, V comparable] struct {
	node[K, V]
	overflow atomic.Pointer[entry[K, V]] // Overflow for hash collisions.
	key      K
	value    V
}

func newEntryNode[K, V comparable](key K, value V) *entry[K, V] {
	return &entry[K, V]{
		node:  node[K, V]{isEntry: true},
		key:   key,
		value: value,
	}
}

func (e *entry[K, V]) lookup(key K, equal equalFunc) (V, bool) {
	for e != nil {
		if equal(unsafe.Pointer(&e.key), abi.NoEscape(unsafe.Pointer(&key))) {
			return e.value, true
		}
		e = e.overflow.Load()
	}
	return *new(V), false
}

// compareAndDelete deletes an entry in the overflow chain if both the key and value compare
// equal. Returns the new entry chain and whether or not anything was deleted.
//
// compareAndDelete must be called under the mutex of the indirect node of which head is a child.
func (head *entry[K, V]) compareAndDelete(key K, value V, keyEqual, valEqual equalFunc) (*entry[K, V], bool) {
	if keyEqual(unsafe.Pointer(&head.key), abi.NoEscape(unsafe.Pointer(&key))) &&
		valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&value))) {
		// Drop the head of the list.
		return head.overflow.Load(), true
	}
	i := &head.overflow
	e := i.Load()
	for e != nil {
		if keyEqual(unsafe.Pointer(&e.key), abi.NoEscape(unsafe.Pointer(&key))) &&
			valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value))) {
			i.Store(e.overflow.Load())
			return head, true
		}
		i = &e.overflow
		e = e.overflow.Load()
	}
	return head, false
}

// node is the header for a node. It's polymorphic and
// is actually either an entry or an indirect.
type node[K, V comparable] struct {
	isEntry bool
}

func (n *node[K, V]) entry() *entry[K, V] {
	if !n.isEntry {
		panic("called entry on non-entry node")
	}
	return (*entry[K, V])(unsafe.Pointer(n))
}

func (n *node[K, V]) indirect() *indirect[K, V] {
	if n.isEntry {
		panic("called indirect on entry node")
	}
	return (*indirect[K, V])(unsafe.Pointer(n))
}
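
// Note on the casts above: the conversions in entry and indirect are valid
// because node is the first field of both entry[K, V] and indirect[K, V],
// so a *node aliases the address of its enclosing struct.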