// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package concurrent

import (
	"internal/abi"
	"internal/goarch"
	"math/rand/v2"
	"sync"
	"sync/atomic"
	"unsafe"
)

// HashTrieMap is an implementation of a concurrent hash-trie. The implementation
// is designed around frequent loads, but offers decent performance for stores
// and deletes as well, especially if the map is larger. Its primary use case is
// the unique package, but it can be used elsewhere as well.
type HashTrieMap[K, V comparable] struct {
	root     *indirect[K, V]
	keyHash  hashFunc
	keyEqual equalFunc
	valEqual equalFunc
	seed     uintptr
}

// NewHashTrieMap creates a new HashTrieMap for the provided key and value types.
func NewHashTrieMap[K, V comparable]() *HashTrieMap[K, V] {
	var m map[K]V
	mapType := abi.TypeOf(m).MapType()
	ht := &HashTrieMap[K, V]{
		root:     newIndirectNode[K, V](nil),
		keyHash:  mapType.Hasher,
		keyEqual: mapType.Key.Equal,
		valEqual: mapType.Elem.Equal,
		seed:     uintptr(rand.Uint64()),
	}
	return ht
}

type hashFunc func(unsafe.Pointer, uintptr) uintptr
type equalFunc func(unsafe.Pointer, unsafe.Pointer) bool

// Load returns the value stored in the map for a key, or the zero value if no
// value is present.
// The ok result indicates whether the value was found in the map.
func (ht *HashTrieMap[K, V]) Load(key K) (value V, ok bool) {
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)

	i := ht.root
	hashShift := 8 * goarch.PtrSize
	for hashShift != 0 {
		hashShift -= nChildrenLog2

		n := i.children[(hash>>hashShift)&nChildrenMask].Load()
		if n == nil {
			return *new(V), false
		}
		if n.isEntry {
			return n.entry().lookup(key, ht.keyEqual)
		}
		i = n.indirect()
	}
	panic("internal/concurrent.HashTrieMap: ran out of hash bits while iterating")
}
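
// exampleLoadUsage is an illustrative sketch, not part of the original file:
// it shows the intended read-mostly usage pattern, in which the map is
// constructed once and Load is then called freely from many goroutines. The
// function name and the string/int type arguments are assumptions made for
// this example.
func exampleLoadUsage() (int, bool) {
	m := NewHashTrieMap[string, int]()
	// LoadOrStore (defined below) publishes an entry; Load then finds it
	// lock-free by walking indirect nodes, consuming hash bits as it goes.
	m.LoadOrStore("gopher", 1)
	return m.Load("gopher")
}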

// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (ht *HashTrieMap[K, V]) LoadOrStore(key K, value V) (result V, loaded bool) {
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
	var i *indirect[K, V]
	var hashShift uint
	var slot *atomic.Pointer[node[K, V]]
	var n *node[K, V]
	for {
		// Find the key or a candidate location for insertion.
		i = ht.root
		hashShift = 8 * goarch.PtrSize
		haveInsertPoint := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil {
				// We found a nil slot, which is a candidate for insertion.
				haveInsertPoint = true
				break
			}
			if n.isEntry {
				// We found an existing entry, which is as far as we can go.
				// If it stays this way, we'll have to replace it with an
				// indirect node.
				if v, ok := n.entry().lookup(key, ht.keyEqual); ok {
					return v, true
				}
				haveInsertPoint = true
				break
			}
			i = n.indirect()
		}
		if !haveInsertPoint {
			panic("internal/concurrent.HashTrieMap: ran out of hash bits while iterating")
		}

		// Grab the lock and double-check what we saw.
		i.mu.Lock()
		n = slot.Load()
		if (n == nil || n.isEntry) && !i.dead.Load() {
			// What we saw is still true, so we can continue with the insert.
			break
		}
		// We have to start over.
		i.mu.Unlock()
	}
	// N.B. This lock is held from when we broke out of the outer loop above.
	// We specifically break this out so that we can use defer here safely.
	// One option is to break this out into a new function instead, but
	// there's so much local iteration state used below that this turns out
	// to be cleaner.
	defer i.mu.Unlock()

	var oldEntry *entry[K, V]
	if n != nil {
		oldEntry = n.entry()
		if v, ok := oldEntry.lookup(key, ht.keyEqual); ok {
			// Easy case: by loading again, it turns out exactly what we wanted is here!
			return v, true
		}
	}
	newEntry := newEntryNode(key, value)
	if oldEntry == nil {
		// Easy case: create a new entry and store it.
		slot.Store(&newEntry.node)
	} else {
		// We possibly need to expand the entry already there into one or more new nodes.
		//
		// Publish the node last, which will make both oldEntry and newEntry visible. We
		// don't want readers to be able to observe that oldEntry isn't in the tree.
		slot.Store(ht.expand(oldEntry, newEntry, hash, hashShift, i))
	}
	return value, false
}

// expand takes oldEntry and newEntry, whose hashes conflict from the top bit of
// the hash down to hashShift, and produces a subtree of indirect nodes to hold
// the two new entries.
func (ht *HashTrieMap[K, V]) expand(oldEntry, newEntry *entry[K, V], newHash uintptr, hashShift uint, parent *indirect[K, V]) *node[K, V] {
	// Check for a hash collision.
	oldHash := ht.keyHash(unsafe.Pointer(&oldEntry.key), ht.seed)
	if oldHash == newHash {
		// Store the old entry in the new entry's overflow list, then store
		// the new entry.
		newEntry.overflow.Store(oldEntry)
		return &newEntry.node
	}
	// We have to add an indirect node. Worse still, we may need to add more than one.
	newIndirect := newIndirectNode(parent)
	top := newIndirect
	for {
		if hashShift == 0 {
			panic("internal/concurrent.HashTrieMap: ran out of hash bits while inserting")
		}
		hashShift -= nChildrenLog2 // hashShift is for the level parent is at. We need to go deeper.
		oi := (oldHash >> hashShift) & nChildrenMask
		ni := (newHash >> hashShift) & nChildrenMask
		if oi != ni {
			newIndirect.children[oi].Store(&oldEntry.node)
			newIndirect.children[ni].Store(&newEntry.node)
			break
		}
		nextIndirect := newIndirectNode(newIndirect)
		newIndirect.children[oi].Store(&nextIndirect.node)
		newIndirect = nextIndirect
	}
	return &top.node
}
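
// childIndexes is an illustrative helper, not part of the original file: it
// shows how lookups and expand consume a hash nChildrenLog2 (4) bits at a
// time, from the most significant bits downward, to select a child slot at
// each level of the trie. Two keys whose hashes agree on every 4-bit group
// down to some level share all the indirect nodes above that level, which is
// exactly the chain that expand builds.
func childIndexes(hash uintptr) []uintptr {
	var idxs []uintptr
	hashShift := 8 * goarch.PtrSize
	for hashShift != 0 {
		hashShift -= nChildrenLog2
		idxs = append(idxs, (hash>>hashShift)&nChildrenMask)
	}
	return idxs
}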

// CompareAndDelete deletes the entry for key if its value is equal to old.
//
// If there is no current value for key in the map, CompareAndDelete returns false
// (even if old is the zero value).
func (ht *HashTrieMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
	hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
	var i *indirect[K, V]
	var hashShift uint
	var slot *atomic.Pointer[node[K, V]]
	var n *node[K, V]
	for {
		// Find the key or return when there's nothing to delete.
		i = ht.root
		hashShift = 8 * goarch.PtrSize
		found := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil {
				// Nothing to delete. Give up.
				return
			}
			if n.isEntry {
				// We found an entry. Check if it matches.
				if _, ok := n.entry().lookup(key, ht.keyEqual); !ok {
					// No match, nothing to delete.
					return
				}
				// We've got something to delete.
				found = true
				break
			}
			i = n.indirect()
		}
		if !found {
			panic("internal/concurrent.HashTrieMap: ran out of hash bits while iterating")
		}

		// Grab the lock and double-check what we saw.
		i.mu.Lock()
		n = slot.Load()
		if !i.dead.Load() {
			if n == nil {
				// Valid node that doesn't contain what we need. Nothing to delete.
				i.mu.Unlock()
				return
			}
			if n.isEntry {
				// What we saw is still true, so we can continue with the delete.
				break
			}
		}
		// We have to start over.
		i.mu.Unlock()
	}
	// Try to delete the entry.
	e, deleted := n.entry().compareAndDelete(key, old, ht.keyEqual, ht.valEqual)
	if !deleted {
		// Nothing was actually deleted, which means the node is no longer there.
		i.mu.Unlock()
		return false
	}
	if e != nil {
		// We didn't actually delete the whole entry, just one entry in the chain.
		// Nothing else to do, since the parent is definitely not empty.
		slot.Store(&e.node)
		i.mu.Unlock()
		return true
	}
	// Delete the entry.
	slot.Store(nil)

	// Check if the node is now empty (and isn't the root), and delete it if able.
	for i.parent != nil && i.empty() {
		if hashShift == 8*goarch.PtrSize {
			panic("internal/concurrent.HashTrieMap: ran out of hash bits while iterating")
		}
		hashShift += nChildrenLog2

		// Delete the current node in the parent.
		parent := i.parent
		parent.mu.Lock()
		i.dead.Store(true)
		parent.children[(hash>>hashShift)&nChildrenMask].Store(nil)
		i.mu.Unlock()
		i = parent
	}
	i.mu.Unlock()
	return true
}

// All returns an iter.Seq2 that produces all key-value pairs in the map.
// The enumeration does not represent any consistent snapshot of the map,
// but is guaranteed to visit each unique key-value pair only once. It is
// safe to operate on the map during iteration. No particular enumeration
// order is guaranteed.
func (ht *HashTrieMap[K, V]) All() func(yield func(K, V) bool) {
	return func(yield func(key K, value V) bool) {
		ht.iter(ht.root, yield)
	}
}

func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V) bool) bool {
	for j := range i.children {
		n := i.children[j].Load()
		if n == nil {
			continue
		}
		if !n.isEntry {
			if !ht.iter(n.indirect(), yield) {
				return false
			}
			continue
		}
		e := n.entry()
		for e != nil {
			if !yield(e.key, e.value) {
				return false
			}
			e = e.overflow.Load()
		}
	}
	return true
}

const (
	// 16 children. This seems to be the sweet spot for
	// load performance: any smaller and we lose out on
	// 50% or more in CPU performance. Any larger and the
	// returns are minuscule (~1% improvement for 32 children).
	nChildrenLog2 = 4
	nChildren     = 1 << nChildrenLog2
	nChildrenMask = nChildren - 1
)
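
// exampleIteration is an illustrative sketch, not part of the original file:
// the function returned by All is compatible with range-over-func iteration
// (an iter.Seq2[K, V]). Because the enumeration is not a consistent snapshot,
// entries stored concurrently may or may not be observed. The function name
// and the string/int type arguments are assumptions made for this example.
func exampleIteration(m *HashTrieMap[string, int]) int {
	sum := 0
	m.All()(func(_ string, v int) bool {
		sum += v
		return true // Returning false stops the enumeration early.
	})
	return sum
}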

// indirect is an internal node in the hash-trie.
type indirect[K, V comparable] struct {
	node[K, V]
	dead     atomic.Bool
	mu       sync.Mutex // Protects mutation to children and any children that are entry nodes.
	parent   *indirect[K, V]
	children [nChildren]atomic.Pointer[node[K, V]]
}

func newIndirectNode[K, V comparable](parent *indirect[K, V]) *indirect[K, V] {
	return &indirect[K, V]{node: node[K, V]{isEntry: false}, parent: parent}
}

func (i *indirect[K, V]) empty() bool {
	nc := 0
	for j := range i.children {
		if i.children[j].Load() != nil {
			nc++
		}
	}
	return nc == 0
}

// entry is a leaf node in the hash-trie.
type entry[K, V comparable] struct {
	node[K, V]
	overflow atomic.Pointer[entry[K, V]] // Overflow for hash collisions.
	key      K
	value    V
}

func newEntryNode[K, V comparable](key K, value V) *entry[K, V] {
	return &entry[K, V]{
		node:  node[K, V]{isEntry: true},
		key:   key,
		value: value,
	}
}

func (e *entry[K, V]) lookup(key K, equal equalFunc) (V, bool) {
	for e != nil {
		if equal(unsafe.Pointer(&e.key), abi.NoEscape(unsafe.Pointer(&key))) {
			return e.value, true
		}
		e = e.overflow.Load()
	}
	return *new(V), false
}

// compareAndDelete deletes an entry in the overflow chain if both the key and value compare
// equal. Returns the new entry chain and whether or not anything was deleted.
//
// compareAndDelete must be called under the mutex of the indirect node that head is a child of.
func (head *entry[K, V]) compareAndDelete(key K, value V, keyEqual, valEqual equalFunc) (*entry[K, V], bool) {
	if keyEqual(unsafe.Pointer(&head.key), abi.NoEscape(unsafe.Pointer(&key))) &&
		valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&value))) {
		// Drop the head of the list.
		return head.overflow.Load(), true
	}
	i := &head.overflow
	e := i.Load()
	for e != nil {
		if keyEqual(unsafe.Pointer(&e.key), abi.NoEscape(unsafe.Pointer(&key))) &&
			valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value))) {
			i.Store(e.overflow.Load())
			return head, true
		}
		i = &e.overflow
		e = e.overflow.Load()
	}
	return head, false
}

// node is the header for a node. It's polymorphic and
// is actually either an entry or an indirect.
type node[K, V comparable] struct {
	isEntry bool
}

func (n *node[K, V]) entry() *entry[K, V] {
	if !n.isEntry {
		panic("called entry on non-entry node")
	}
	return (*entry[K, V])(unsafe.Pointer(n))
}

func (n *node[K, V]) indirect() *indirect[K, V] {
	if n.isEntry {
		panic("called indirect on entry node")
	}
	return (*indirect[K, V])(unsafe.Pointer(n))
}
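
// headerRoundTrip is an illustrative sketch, not part of the original file:
// the unsafe casts in entry and indirect above are valid because node is the
// first field of both structs, so a pointer to the embedded node has the same
// address as a pointer to the enclosing entry or indirect.
func headerRoundTrip[K, V comparable](e *entry[K, V]) *entry[K, V] {
	n := &e.node     // Same address as e, but typed as the polymorphic header.
	return n.entry() // Recovers the original *entry via the unsafe cast; panics
	// if e was not built by newEntryNode, which sets isEntry.
}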