1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package atomic_test
6
7import (
8	"math/rand"
9	"runtime"
10	"strconv"
11	"sync"
12	"sync/atomic"
13	. "sync/atomic"
14	"testing"
15)
16
// TestValue checks the basic single-goroutine contract of Value with a
// word-sized payload: a zero Value loads as nil, and each Store is
// observed by the next Load.
func TestValue(t *testing.T) {
	var v Value
	if got := v.Load(); got != nil {
		t.Fatal("initial Value is not nil")
	}

	v.Store(42)
	// Comparing against any(42) checks both the dynamic type (int) and
	// the value in one step.
	if got := v.Load(); got != any(42) {
		t.Fatalf("wrong value: got %+v, want 42", got)
	}

	v.Store(84)
	if got := v.Load(); got != any(84) {
		t.Fatalf("wrong value: got %+v, want 84", got)
	}
}
33
// TestValueLarge repeats the store/load round trip with string payloads,
// including overwriting a stored string with a longer one.
func TestValueLarge(t *testing.T) {
	var v Value

	v.Store("foo")
	loaded := v.Load()
	if s, isString := loaded.(string); !(isString && s == "foo") {
		t.Fatalf("wrong value: got %+v, want foo", loaded)
	}

	v.Store("barbaz")
	loaded = v.Load()
	if s, isString := loaded.(string); !(isString && s == "barbaz") {
		t.Fatalf("wrong value: got %+v, want barbaz", loaded)
	}
}
47
// TestValuePanic checks the panics raised by Store: storing nil must panic
// both before and after a first successful store, and storing a value whose
// type differs from the first stored value must panic as well.
func TestValuePanic(t *testing.T) {
	const (
		nilErr = "sync/atomic: store of nil value into Value"
		badErr = "sync/atomic: store of inconsistently typed value into Value"
	)
	var v Value

	// storeExpectingPanic stores val into v and aborts the test unless the
	// store panics with exactly the message want. recover is called directly
	// in the deferred function so that it actually intercepts the panic.
	storeExpectingPanic := func(val any, want string) {
		defer func() {
			if got := recover(); got != want {
				t.Fatalf("inconsistent store panic: got '%v', want '%v'", got, want)
			}
		}()
		v.Store(val)
	}

	storeExpectingPanic(nil, nilErr) // nil into an empty Value
	v.Store(42)                      // fixes the Value's type to int
	storeExpectingPanic("foo", badErr)
	storeExpectingPanic(nil, nilErr) // nil into a populated Value
}
81
// TestValueConcurrent hammers a single Value from many goroutines. For each
// set of candidate values, p goroutines repeatedly store a randomly chosen
// candidate and then load; every loaded value must be one of the candidates
// (anything else would indicate a torn or corrupted read).
//
// In short mode both the goroutine count and the iteration count are reduced.
func TestValueConcurrent(t *testing.T) {
	tests := [][]any{
		{uint16(0), ^uint16(0), uint16(1 + 2<<8), uint16(3 + 4<<8)},
		{uint32(0), ^uint32(0), uint32(1 + 2<<16), uint32(3 + 4<<16)},
		{uint64(0), ^uint64(0), uint64(1 + 2<<32), uint64(3 + 4<<32)},
		{complex(0, 0), complex(1, 2), complex(3, 4), complex(5, 6)},
	}
	p := 4 * runtime.GOMAXPROCS(0)
	N := int(1e5)
	if testing.Short() {
		p /= 2
		N = 1e3
	}
	for _, test := range tests {
		var v Value
		// Buffered so that no worker ever blocks on reporting its result.
		done := make(chan bool, p)
		for i := 0; i < p; i++ {
			go func() {
				r := rand.New(rand.NewSource(rand.Int63()))
				expected := true
			loop:
				for j := 0; j < N; j++ {
					x := test[r.Intn(len(test))]
					v.Store(x)
					x = v.Load()
					for _, x1 := range test {
						if x == x1 {
							continue loop
						}
					}
					t.Logf("loaded unexpected value %+v, want %+v", x, test)
					expected = false
					break
				}
				done <- expected
			}()
		}
		// Collect every worker's result before failing. Bailing out on the
		// first bad result would end the test while other workers are still
		// running, and a t.Logf from one of them after the test has
		// completed makes the testing package panic.
		ok := true
		for i := 0; i < p; i++ {
			if !<-done {
				ok = false
			}
		}
		if !ok {
			t.FailNow()
		}
	}
}
126
// BenchmarkValueRead measures parallel Load calls on a Value that holds a
// *int pointing at zero, checking on every iteration that the loaded
// pointer still refers to zero.
func BenchmarkValueRead(b *testing.B) {
	var v Value
	v.Store(new(int))
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			x := v.Load().(*int)
			if *x != 0 {
				// RunParallel runs this body on goroutines it creates, and
				// FailNow (hence Fatalf) must only be called from the
				// goroutine running the benchmark function. Record the
				// failure with Errorf and stop this worker instead.
				b.Errorf("wrong value: got %v, want 0", *x)
				return
			}
		}
	})
}
139
// Value_SwapTests is the table driving TestValue_Swap. For each entry a
// fresh Value is optionally primed with init (skipped when init is nil),
// then Swap(new) is called.
//
//   - want is the value Swap is expected to return.
//   - err is the exact value Swap is expected to panic with; nil means the
//     call must not panic.
var Value_SwapTests = []struct {
	init any
	new  any
	want any
	err  any
}{
	{init: nil, new: nil, err: "sync/atomic: swap of nil value into Value"},
	{init: nil, new: true, want: nil, err: nil},
	{init: true, new: "", err: "sync/atomic: swap of inconsistently typed value into Value"},
	{init: true, new: false, want: true, err: nil},
}

// TestValue_Swap runs every entry of Value_SwapTests as a subtest named by
// its index. It checks the value returned by Swap, the value subsequently
// returned by Load, and the panic (or absence of one) raised by Swap.
func TestValue_Swap(t *testing.T) {
	for i, tt := range Value_SwapTests {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			var v Value
			if tt.init != nil {
				v.Store(tt.init)
			}
			defer func() {
				// Compare the recovered value itself, not just whether a
				// panic happened, so that a panic with the wrong message
				// (or an unrelated panic) is reported as a failure.
				err := recover()
				switch {
				case tt.err == nil && err != nil:
					t.Errorf("should not panic, got %v", err)
				case tt.err != nil && err == nil:
					t.Errorf("should panic %v, got <nil>", tt.err)
				case err != tt.err:
					t.Errorf("wrong panic: got %v, want %v", err, tt.err)
				}
			}()
			if got := v.Swap(tt.new); got != tt.want {
				t.Errorf("got %v, want %v", got, tt.want)
			}
			if got := v.Load(); got != tt.new {
				t.Errorf("got %v, want %v", got, tt.new)
			}
		})
	}
}
177
// TestValueSwapConcurrent checks that concurrent Swap calls never lose or
// duplicate a value. m goroutines each swap in n consecutive integers, so
// that together they store every integer in [0, m*n) exactly once, and each
// goroutine sums the values its swaps displaced. Every stored value is
// either displaced by a later swap or is the one left in the Value at the
// end, so the displaced total plus the final value must equal
// 0 + 1 + ... + (m*n - 1).
//
// In short mode m and n are reduced from 10000 to 1000.
func TestValueSwapConcurrent(t *testing.T) {
	var (
		v     Value
		total uint64
		wg    sync.WaitGroup
	)
	m, n := uint64(10000), uint64(10000)
	if testing.Short() {
		m, n = 1000, 1000
	}
	for base := uint64(0); base < m*n; base += n {
		wg.Add(1)
		go func(lo uint64) {
			defer wg.Done()
			var displaced uint64
			for val := lo; val < lo+n; val++ {
				prev := v.Swap(val)
				if prev == nil {
					// Only the very first swap into the empty Value
					// displaces nothing.
					continue
				}
				displaced += prev.(uint64)
			}
			atomic.AddUint64(&total, displaced)
		}(base)
	}
	wg.Wait()
	if want, got := (m*n-1)*(m*n)/2, total+v.Load().(uint64); got != want {
		t.Errorf("sum from 0 to %d was %d, want %v", m*n-1, got, want)
	}
}
206
207var heapA, heapB = struct{ uint }{0}, struct{ uint }{0}
208
209var Value_CompareAndSwapTests = []struct {
210	init any
211	new  any
212	old  any
213	want bool
214	err  any
215}{
216	{init: nil, new: nil, old: nil, err: "sync/atomic: compare and swap of nil value into Value"},
217	{init: nil, new: true, old: "", err: "sync/atomic: compare and swap of inconsistently typed values into Value"},
218	{init: nil, new: true, old: true, want: false, err: nil},
219	{init: nil, new: true, old: nil, want: true, err: nil},
220	{init: true, new: "", err: "sync/atomic: compare and swap of inconsistently typed value into Value"},
221	{init: true, new: true, old: false, want: false, err: nil},
222	{init: true, new: true, old: true, want: true, err: nil},
223	{init: heapA, new: struct{ uint }{1}, old: heapB, want: true, err: nil},
224}
225
226func TestValue_CompareAndSwap(t *testing.T) {
227	for i, tt := range Value_CompareAndSwapTests {
228		t.Run(strconv.Itoa(i), func(t *testing.T) {
229			var v Value
230			if tt.init != nil {
231				v.Store(tt.init)
232			}
233			defer func() {
234				err := recover()
235				switch {
236				case tt.err == nil && err != nil:
237					t.Errorf("got %v, wanted no panic", err)
238				case tt.err != nil && err == nil:
239					t.Errorf("did not panic, want %v", tt.err)
240				}
241			}()
242			if got := v.CompareAndSwap(tt.old, tt.new); got != tt.want {
243				t.Errorf("got %v, want %v", got, tt.want)
244			}
245		})
246	}
247}
248
// TestValueCompareAndSwapConcurrent has m goroutines cooperatively count a
// shared Value from 0 up to m*n using only CompareAndSwap. Goroutine i is
// responsible for the steps j -> j+1 with j = i, i+m, i+2m, ...; it retries
// each step until its CompareAndSwap succeeds, yielding the processor after
// every attempt. The count can only reach m*n if every step succeeds exactly
// once and in order.
//
// In short mode m is reduced from 1000 to 100.
func TestValueCompareAndSwapConcurrent(t *testing.T) {
	var (
		v  Value
		wg sync.WaitGroup
	)
	v.Store(0)
	m, n := 1000, 100
	if testing.Short() {
		m, n = 100, 100
	}
	limit := m * n

	wg.Add(m)
	for i := 0; i < m; i++ {
		go func(next int) {
			defer wg.Done()
			for next < limit {
				if v.CompareAndSwap(next, next+1) {
					next += m
				}
				// Yield after every attempt so the goroutine whose turn it
				// is gets a chance to run.
				runtime.Gosched()
			}
		}(i)
	}
	wg.Wait()

	if stop := v.Load().(int); stop != limit {
		t.Errorf("did not get to %v, stopped at %v", limit, stop)
	}
}
275