1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package singleflight
6
7import (
8	"errors"
9	"fmt"
10	"sync"
11	"sync/atomic"
12	"testing"
13	"time"
14)
15
16func TestDo(t *testing.T) {
17	var g Group
18	v, err, _ := g.Do("key", func() (any, error) {
19		return "bar", nil
20	})
21	if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
22		t.Errorf("Do = %v; want %v", got, want)
23	}
24	if err != nil {
25		t.Errorf("Do error = %v", err)
26	}
27}
28
29func TestDoErr(t *testing.T) {
30	var g Group
31	someErr := errors.New("some error")
32	v, err, _ := g.Do("key", func() (any, error) {
33		return nil, someErr
34	})
35	if err != someErr {
36		t.Errorf("Do error = %v; want someErr %v", err, someErr)
37	}
38	if v != nil {
39		t.Errorf("unexpected non-nil value %#v", v)
40	}
41}
42
43func TestDoDupSuppress(t *testing.T) {
44	var g Group
45	var wg1, wg2 sync.WaitGroup
46	c := make(chan string, 1)
47	var calls atomic.Int32
48	fn := func() (any, error) {
49		if calls.Add(1) == 1 {
50			// First invocation.
51			wg1.Done()
52		}
53		v := <-c
54		c <- v // pump; make available for any future calls
55
56		time.Sleep(10 * time.Millisecond) // let more goroutines enter Do
57
58		return v, nil
59	}
60
61	const n = 10
62	wg1.Add(1)
63	for i := 0; i < n; i++ {
64		wg1.Add(1)
65		wg2.Add(1)
66		go func() {
67			defer wg2.Done()
68			wg1.Done()
69			v, err, _ := g.Do("key", fn)
70			if err != nil {
71				t.Errorf("Do error: %v", err)
72				return
73			}
74			if s, _ := v.(string); s != "bar" {
75				t.Errorf("Do = %T %v; want %q", v, v, "bar")
76			}
77		}()
78	}
79	wg1.Wait()
80	// At least one goroutine is in fn now and all of them have at
81	// least reached the line before the Do.
82	c <- "bar"
83	wg2.Wait()
84	if got := calls.Load(); got <= 0 || got >= n {
85		t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
86	}
87}
88
89func TestForgetUnshared(t *testing.T) {
90	var g Group
91
92	var firstStarted, firstFinished sync.WaitGroup
93
94	firstStarted.Add(1)
95	firstFinished.Add(1)
96
97	key := "key"
98	firstCh := make(chan struct{})
99	go func() {
100		g.Do(key, func() (i interface{}, e error) {
101			firstStarted.Done()
102			<-firstCh
103			return
104		})
105		firstFinished.Done()
106	}()
107
108	firstStarted.Wait()
109	g.ForgetUnshared(key) // from this point no two function using same key should be executed concurrently
110
111	secondCh := make(chan struct{})
112	go func() {
113		g.Do(key, func() (i interface{}, e error) {
114			// Notify that we started
115			secondCh <- struct{}{}
116			<-secondCh
117			return 2, nil
118		})
119	}()
120
121	<-secondCh
122
123	resultCh := g.DoChan(key, func() (i interface{}, e error) {
124		panic("third must not be started")
125	})
126
127	if g.ForgetUnshared(key) {
128		t.Errorf("Before first goroutine finished, key %q is shared, should return false", key)
129	}
130
131	close(firstCh)
132	firstFinished.Wait()
133
134	if g.ForgetUnshared(key) {
135		t.Errorf("After first goroutine finished, key %q is still shared, should return false", key)
136	}
137
138	secondCh <- struct{}{}
139
140	if result := <-resultCh; result.Val != 2 {
141		t.Errorf("We should receive result produced by second call, expected: 2, got %d", result.Val)
142	}
143}
144
145func TestDoAndForgetUnsharedRace(t *testing.T) {
146	t.Parallel()
147
148	var g Group
149	key := "key"
150	d := time.Millisecond
151	for {
152		var calls, shared atomic.Int64
153		const n = 1000
154		var wg sync.WaitGroup
155		wg.Add(n)
156		for i := 0; i < n; i++ {
157			go func() {
158				g.Do(key, func() (interface{}, error) {
159					time.Sleep(d)
160					return calls.Add(1), nil
161				})
162				if !g.ForgetUnshared(key) {
163					shared.Add(1)
164				}
165				wg.Done()
166			}()
167		}
168		wg.Wait()
169
170		if calls.Load() != 1 {
171			// The goroutines didn't park in g.Do in time,
172			// so the key was re-added and may have been shared after the call.
173			// Try again with more time to park.
174			d *= 2
175			continue
176		}
177
178		// All of the Do calls ended up sharing the first
179		// invocation, so the key should have been unused
180		// (and therefore unshared) when they returned.
181		if shared.Load() > 0 {
182			t.Errorf("after a single shared Do, ForgetUnshared returned false %d times", shared.Load())
183		}
184		break
185	}
186}
187