1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build !race
6
7package maphash
8
9import (
10	"fmt"
11	"internal/testenv"
12	"math"
13	"math/rand"
14	"runtime"
15	"slices"
16	"strings"
17	"testing"
18	"unsafe"
19)
20
// Smhasher is a torture test for hash functions.
// https://code.google.com/p/smhasher/
// (Google Code has been shut down; the project now lives at
// https://github.com/aappleby/smhasher.)
// This code is a port of some of the Smhasher tests to Go.

// Note: due to the long running time of these tests, they are
// currently disabled in -race mode.
27
28var fixedSeed = MakeSeed()
29
30// Sanity checks.
31// hash should not depend on values outside key.
32// hash should not depend on alignment.
33func TestSmhasherSanity(t *testing.T) {
34	t.Parallel()
35	r := rand.New(rand.NewSource(1234))
36	const REP = 10
37	const KEYMAX = 128
38	const PAD = 16
39	const OFFMAX = 16
40	for k := 0; k < REP; k++ {
41		for n := 0; n < KEYMAX; n++ {
42			for i := 0; i < OFFMAX; i++ {
43				var b [KEYMAX + OFFMAX + 2*PAD]byte
44				var c [KEYMAX + OFFMAX + 2*PAD]byte
45				randBytes(r, b[:])
46				randBytes(r, c[:])
47				copy(c[PAD+i:PAD+i+n], b[PAD:PAD+n])
48				if bytesHash(b[PAD:PAD+n]) != bytesHash(c[PAD+i:PAD+i+n]) {
49					t.Errorf("hash depends on bytes outside key")
50				}
51			}
52		}
53	}
54}
55
56func bytesHash(b []byte) uint64 {
57	var h Hash
58	h.SetSeed(fixedSeed)
59	h.Write(b)
60	return h.Sum64()
61}
62func stringHash(s string) uint64 {
63	var h Hash
64	h.SetSeed(fixedSeed)
65	h.WriteString(s)
66	return h.Sum64()
67}
68
// hashSize is the number of bits in the hash values under test; the
// collision and avalanche statistics below are computed for this width.
const hashSize = 64
70
// randBytes fills b with pseudo-random bytes drawn from r.
func randBytes(r *rand.Rand, b []byte) {
	// (*rand.Rand).Read cannot fail, so both of its results are
	// discarded deliberately.
	_, _ = r.Read(b)
}
74
75// A hashSet measures the frequency of hash collisions.
76type hashSet struct {
77	list []uint64 // list of hashes added
78}
79
80func newHashSet() *hashSet {
81	return &hashSet{list: make([]uint64, 0, 1024)}
82}
83func (s *hashSet) add(h uint64) {
84	s.list = append(s.list, h)
85}
86func (s *hashSet) addS(x string) {
87	s.add(stringHash(x))
88}
89func (s *hashSet) addB(x []byte) {
90	s.add(bytesHash(x))
91}
92func (s *hashSet) addS_seed(x string, seed Seed) {
93	var h Hash
94	h.SetSeed(seed)
95	h.WriteString(x)
96	s.add(h.Sum64())
97}
98func (s *hashSet) check(t *testing.T) {
99	list := s.list
100	slices.Sort(list)
101
102	collisions := 0
103	for i := 1; i < len(list); i++ {
104		if list[i] == list[i-1] {
105			collisions++
106		}
107	}
108	n := len(list)
109
110	const SLOP = 10.0
111	pairs := int64(n) * int64(n-1) / 2
112	expected := float64(pairs) / math.Pow(2.0, float64(hashSize))
113	stddev := math.Sqrt(expected)
114	if float64(collisions) > expected+SLOP*(3*stddev+1) {
115		t.Errorf("unexpected number of collisions: got=%d mean=%f stddev=%f", collisions, expected, stddev)
116	}
117	// Reset for reuse
118	s.list = s.list[:0]
119}
120
121// a string plus adding zeros must make distinct hashes
122func TestSmhasherAppendedZeros(t *testing.T) {
123	t.Parallel()
124	s := "hello" + strings.Repeat("\x00", 256)
125	h := newHashSet()
126	for i := 0; i <= len(s); i++ {
127		h.addS(s[:i])
128	}
129	h.check(t)
130}
131
132// All 0-3 byte strings have distinct hashes.
133func TestSmhasherSmallKeys(t *testing.T) {
134	testenv.ParallelOn64Bit(t)
135	h := newHashSet()
136	var b [3]byte
137	for i := 0; i < 256; i++ {
138		b[0] = byte(i)
139		h.addB(b[:1])
140		for j := 0; j < 256; j++ {
141			b[1] = byte(j)
142			h.addB(b[:2])
143			if !testing.Short() {
144				for k := 0; k < 256; k++ {
145					b[2] = byte(k)
146					h.addB(b[:3])
147				}
148			}
149		}
150	}
151	h.check(t)
152}
153
154// Different length strings of all zeros have distinct hashes.
155func TestSmhasherZeros(t *testing.T) {
156	t.Parallel()
157	N := 256 * 1024
158	if testing.Short() {
159		N = 1024
160	}
161	h := newHashSet()
162	b := make([]byte, N)
163	for i := 0; i <= N; i++ {
164		h.addB(b[:i])
165	}
166	h.check(t)
167}
168
169// Strings with up to two nonzero bytes all have distinct hashes.
170func TestSmhasherTwoNonzero(t *testing.T) {
171	if runtime.GOARCH == "wasm" {
172		t.Skip("Too slow on wasm")
173	}
174	if testing.Short() {
175		t.Skip("Skipping in short mode")
176	}
177	testenv.ParallelOn64Bit(t)
178	h := newHashSet()
179	for n := 2; n <= 16; n++ {
180		twoNonZero(h, n)
181	}
182	h.check(t)
183}
184func twoNonZero(h *hashSet, n int) {
185	b := make([]byte, n)
186
187	// all zero
188	h.addB(b)
189
190	// one non-zero byte
191	for i := 0; i < n; i++ {
192		for x := 1; x < 256; x++ {
193			b[i] = byte(x)
194			h.addB(b)
195			b[i] = 0
196		}
197	}
198
199	// two non-zero bytes
200	for i := 0; i < n; i++ {
201		for x := 1; x < 256; x++ {
202			b[i] = byte(x)
203			for j := i + 1; j < n; j++ {
204				for y := 1; y < 256; y++ {
205					b[j] = byte(y)
206					h.addB(b)
207					b[j] = 0
208				}
209			}
210			b[i] = 0
211		}
212	}
213}
214
215// Test strings with repeats, like "abcdabcdabcdabcd..."
216func TestSmhasherCyclic(t *testing.T) {
217	if testing.Short() {
218		t.Skip("Skipping in short mode")
219	}
220	t.Parallel()
221	r := rand.New(rand.NewSource(1234))
222	const REPEAT = 8
223	const N = 1000000
224	h := newHashSet()
225	for n := 4; n <= 12; n++ {
226		b := make([]byte, REPEAT*n)
227		for i := 0; i < N; i++ {
228			b[0] = byte(i * 79 % 97)
229			b[1] = byte(i * 43 % 137)
230			b[2] = byte(i * 151 % 197)
231			b[3] = byte(i * 199 % 251)
232			randBytes(r, b[4:n])
233			for j := n; j < n*REPEAT; j++ {
234				b[j] = b[j-n]
235			}
236			h.addB(b)
237		}
238		h.check(t)
239	}
240}
241
242// Test strings with only a few bits set
243func TestSmhasherSparse(t *testing.T) {
244	if runtime.GOARCH == "wasm" {
245		t.Skip("Too slow on wasm")
246	}
247	if testing.Short() {
248		t.Skip("Skipping in short mode")
249	}
250	t.Parallel()
251	h := newHashSet()
252	sparse(t, h, 32, 6)
253	sparse(t, h, 40, 6)
254	sparse(t, h, 48, 5)
255	sparse(t, h, 56, 5)
256	sparse(t, h, 64, 5)
257	sparse(t, h, 96, 4)
258	sparse(t, h, 256, 3)
259	sparse(t, h, 2048, 2)
260}
261func sparse(t *testing.T, h *hashSet, n int, k int) {
262	b := make([]byte, n/8)
263	setbits(h, b, 0, k)
264	h.check(t)
265}
266
267// set up to k bits at index i and greater
268func setbits(h *hashSet, b []byte, i int, k int) {
269	h.addB(b)
270	if k == 0 {
271		return
272	}
273	for j := i; j < len(b)*8; j++ {
274		b[j/8] |= byte(1 << uint(j&7))
275		setbits(h, b, j+1, k-1)
276		b[j/8] &= byte(^(1 << uint(j&7)))
277	}
278}
279
280// Test all possible combinations of n blocks from the set s.
281// "permutation" is a bad name here, but it is what Smhasher uses.
282func TestSmhasherPermutation(t *testing.T) {
283	if runtime.GOARCH == "wasm" {
284		t.Skip("Too slow on wasm")
285	}
286	if testing.Short() {
287		t.Skip("Skipping in short mode")
288	}
289	testenv.ParallelOn64Bit(t)
290	h := newHashSet()
291	permutation(t, h, []uint32{0, 1, 2, 3, 4, 5, 6, 7}, 8)
292	permutation(t, h, []uint32{0, 1 << 29, 2 << 29, 3 << 29, 4 << 29, 5 << 29, 6 << 29, 7 << 29}, 8)
293	permutation(t, h, []uint32{0, 1}, 20)
294	permutation(t, h, []uint32{0, 1 << 31}, 20)
295	permutation(t, h, []uint32{0, 1, 2, 3, 4, 5, 6, 7, 1 << 29, 2 << 29, 3 << 29, 4 << 29, 5 << 29, 6 << 29, 7 << 29}, 6)
296}
297func permutation(t *testing.T, h *hashSet, s []uint32, n int) {
298	b := make([]byte, n*4)
299	genPerm(h, b, s, 0)
300	h.check(t)
301}
302func genPerm(h *hashSet, b []byte, s []uint32, n int) {
303	h.addB(b[:n])
304	if n == len(b) {
305		return
306	}
307	for _, v := range s {
308		b[n] = byte(v)
309		b[n+1] = byte(v >> 8)
310		b[n+2] = byte(v >> 16)
311		b[n+3] = byte(v >> 24)
312		genPerm(h, b, s, n+4)
313	}
314}
315
// key abstracts a hashable key whose individual bits can be manipulated,
// as needed by the avalanche and windowed tests.
type key interface {
	// clear sets all bits of the key to 0.
	clear()
	// random sets the key to something random, drawing from r.
	random(r *rand.Rand)
	// bits reports how many bits the key has.
	bits() int
	// flipBit flips bit i of the key.
	flipBit(i int)
	// hash returns the hash of the key.
	hash() uint64
	// name identifies the key in error reports.
	name() string
}
324
325type bytesKey struct {
326	b []byte
327}
328
329func (k *bytesKey) clear() {
330	clear(k.b)
331}
332func (k *bytesKey) random(r *rand.Rand) {
333	randBytes(r, k.b)
334}
335func (k *bytesKey) bits() int {
336	return len(k.b) * 8
337}
338func (k *bytesKey) flipBit(i int) {
339	k.b[i>>3] ^= byte(1 << uint(i&7))
340}
341func (k *bytesKey) hash() uint64 {
342	return bytesHash(k.b)
343}
344func (k *bytesKey) name() string {
345	return fmt.Sprintf("bytes%d", len(k.b))
346}
347
348// Flipping a single bit of a key should flip each output bit with 50% probability.
349func TestSmhasherAvalanche(t *testing.T) {
350	if runtime.GOARCH == "wasm" {
351		t.Skip("Too slow on wasm")
352	}
353	if testing.Short() {
354		t.Skip("Skipping in short mode")
355	}
356	t.Parallel()
357	avalancheTest1(t, &bytesKey{make([]byte, 2)})
358	avalancheTest1(t, &bytesKey{make([]byte, 4)})
359	avalancheTest1(t, &bytesKey{make([]byte, 8)})
360	avalancheTest1(t, &bytesKey{make([]byte, 16)})
361	avalancheTest1(t, &bytesKey{make([]byte, 32)})
362	avalancheTest1(t, &bytesKey{make([]byte, 200)})
363}
364func avalancheTest1(t *testing.T, k key) {
365	const REP = 100000
366	r := rand.New(rand.NewSource(1234))
367	n := k.bits()
368
369	// grid[i][j] is a count of whether flipping
370	// input bit i affects output bit j.
371	grid := make([][hashSize]int, n)
372
373	for z := 0; z < REP; z++ {
374		// pick a random key, hash it
375		k.random(r)
376		h := k.hash()
377
378		// flip each bit, hash & compare the results
379		for i := 0; i < n; i++ {
380			k.flipBit(i)
381			d := h ^ k.hash()
382			k.flipBit(i)
383
384			// record the effects of that bit flip
385			g := &grid[i]
386			for j := 0; j < hashSize; j++ {
387				g[j] += int(d & 1)
388				d >>= 1
389			}
390		}
391	}
392
393	// Each entry in the grid should be about REP/2.
394	// More precisely, we did N = k.bits() * hashSize experiments where
395	// each is the sum of REP coin flips. We want to find bounds on the
396	// sum of coin flips such that a truly random experiment would have
397	// all sums inside those bounds with 99% probability.
398	N := n * hashSize
399	var c float64
400	// find c such that Prob(mean-c*stddev < x < mean+c*stddev)^N > .9999
401	for c = 0.0; math.Pow(math.Erf(c/math.Sqrt(2)), float64(N)) < .9999; c += .1 {
402	}
403	c *= 11.0 // allowed slack: 40% to 60% - we don't need to be perfectly random
404	mean := .5 * REP
405	stddev := .5 * math.Sqrt(REP)
406	low := int(mean - c*stddev)
407	high := int(mean + c*stddev)
408	for i := 0; i < n; i++ {
409		for j := 0; j < hashSize; j++ {
410			x := grid[i][j]
411			if x < low || x > high {
412				t.Errorf("bad bias for %s bit %d -> bit %d: %d/%d\n", k.name(), i, j, x, REP)
413			}
414		}
415	}
416}
417
418// All bit rotations of a set of distinct keys
419func TestSmhasherWindowed(t *testing.T) {
420	t.Parallel()
421	windowed(t, &bytesKey{make([]byte, 128)})
422}
423func windowed(t *testing.T, k key) {
424	if runtime.GOARCH == "wasm" {
425		t.Skip("Too slow on wasm")
426	}
427	if testing.Short() {
428		t.Skip("Skipping in short mode")
429	}
430	const BITS = 16
431
432	h := newHashSet()
433	for r := 0; r < k.bits(); r++ {
434		for i := 0; i < 1<<BITS; i++ {
435			k.clear()
436			for j := 0; j < BITS; j++ {
437				if i>>uint(j)&1 != 0 {
438					k.flipBit((j + r) % k.bits())
439				}
440			}
441			h.add(k.hash())
442		}
443		h.check(t)
444	}
445}
446
447// All keys of the form prefix + [A-Za-z0-9]*N + suffix.
448func TestSmhasherText(t *testing.T) {
449	if testing.Short() {
450		t.Skip("Skipping in short mode")
451	}
452	t.Parallel()
453	h := newHashSet()
454	text(t, h, "Foo", "Bar")
455	text(t, h, "FooBar", "")
456	text(t, h, "", "FooBar")
457}
458func text(t *testing.T, h *hashSet, prefix, suffix string) {
459	const N = 4
460	const S = "ABCDEFGHIJKLMNOPQRSTabcdefghijklmnopqrst0123456789"
461	const L = len(S)
462	b := make([]byte, len(prefix)+N+len(suffix))
463	copy(b, prefix)
464	copy(b[len(prefix)+N:], suffix)
465	c := b[len(prefix):]
466	for i := 0; i < L; i++ {
467		c[0] = S[i]
468		for j := 0; j < L; j++ {
469			c[1] = S[j]
470			for k := 0; k < L; k++ {
471				c[2] = S[k]
472				for x := 0; x < L; x++ {
473					c[3] = S[x]
474					h.addB(b)
475				}
476			}
477		}
478	}
479	h.check(t)
480}
481
482// Make sure different seed values generate different hashes.
483func TestSmhasherSeed(t *testing.T) {
484	if unsafe.Sizeof(uintptr(0)) == 4 {
485		t.Skip("32-bit platforms don't have ideal seed-input distributions (see issue 33988)")
486	}
487	t.Parallel()
488	h := newHashSet()
489	const N = 100000
490	s := "hello"
491	for i := 0; i < N; i++ {
492		h.addS_seed(s, Seed{s: uint64(i + 1)})
493		h.addS_seed(s, Seed{s: uint64(i+1) << 32}) // make sure high bits are used
494	}
495	h.check(t)
496}
497