1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package sync
6
7import (
8	"internal/race"
9	"sync/atomic"
10	"unsafe"
11)
12
// A WaitGroup waits for a collection of goroutines to finish.
// The main goroutine calls [WaitGroup.Add] to set the number of
// goroutines to wait for. Then each of the goroutines
// runs and calls [WaitGroup.Done] when finished. At the same time,
// [WaitGroup.Wait] can be used to block until all goroutines have finished.
//
// A WaitGroup must not be copied after first use.
//
// In the terminology of [the Go memory model], a call to [WaitGroup.Done]
// “synchronizes before” the return of any Wait call that it unblocks.
//
// [the Go memory model]: https://go.dev/ref/mem
type WaitGroup struct {
	// noCopy is a marker field backing the "must not be copied" rule above.
	// NOTE(review): its type is declared outside this view; presumably it
	// lets vet's copy checks flag copies of a WaitGroup. Confirm.
	noCopy noCopy

	// state packs both counters into one word so that Add and Wait can
	// read and update them together with a single atomic operation.
	state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.

	// sema is the semaphore that Wait blocks on and that Add releases once
	// per registered waiter when the counter drops to zero. Its address is
	// also the location used for the race detector's Read/Write modeling
	// in Add and Wait.
	sema uint32
}
31
// Add adds delta, which may be negative, to the [WaitGroup] counter.
// If the counter becomes zero, all goroutines blocked on [WaitGroup.Wait] are released.
// If the counter goes negative, Add panics.
//
// Add also panics with a "WaitGroup misuse" message when it detects that
// it was called concurrently with Wait in a way the rules below forbid.
//
// Note that calls with a positive delta that occur when the counter is zero
// must happen before a Wait. Calls with a negative delta, or calls with a
// positive delta that start when the counter is greater than zero, may happen
// at any time.
// Typically this means the calls to Add should execute before the statement
// creating the goroutine or other event to be waited for.
// If a WaitGroup is reused to wait for several independent sets of events,
// new Add calls must happen after all previous Wait calls have returned.
// See the WaitGroup example.
func (wg *WaitGroup) Add(delta int) {
	if race.Enabled {
		if delta < 0 {
			// Synchronize decrements with Wait.
			race.ReleaseMerge(unsafe.Pointer(wg))
		}
		// race.Disable is in effect for the rest of Add; the deferred
		// race.Enable runs on every exit path, including the panics below.
		race.Disable()
		defer race.Enable()
	}
	// Apply delta to the counter half of the packed state in one atomic step.
	// Converting delta to uint64 before shifting makes a negative delta wrap
	// around, so the addition subtracts from the high 32 bits. The shifted
	// value has zero low bits, so the waiter count is left untouched.
	state := wg.state.Add(uint64(delta) << 32)
	// v is the counter after this Add; w is the current number of waiters.
	v := int32(state >> 32)
	w := uint32(state)
	// With delta > 0, v == int32(delta) means the counter was zero before
	// this call, i.e. this is a "first" increment.
	if race.Enabled && delta > 0 && v == int32(delta) {
		// The first increment must be synchronized with Wait.
		// Need to model this as a read, because there can be
		// several concurrent wg.counter transitions from 0.
		race.Read(unsafe.Pointer(&wg.sema))
	}
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
	// The counter was zero before this call, yet waiters are registered:
	// a positive Add ran concurrently with Wait, which the contract forbids.
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// Nothing to wake: either the counter is still positive or nobody waits.
	if v > 0 || w == 0 {
		return
	}
	// This goroutine has set counter to 0 when waiters > 0.
	// Now there can't be concurrent mutations of state:
	// - Adds must not happen concurrently with Wait,
	// - Wait does not increment waiters if it sees counter == 0.
	// Still do a cheap sanity check to detect WaitGroup misuse.
	if wg.state.Load() != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// Reset waiters count to 0.
	// The store happens before any waiter is released, so a woken Wait
	// expects to observe state == 0 and panics otherwise.
	wg.state.Store(0)
	// Release the semaphore once for each waiter that was recorded in state.
	for ; w != 0; w-- {
		runtime_Semrelease(&wg.sema, false, 0)
	}
}
86
87// Done decrements the [WaitGroup] counter by one.
88func (wg *WaitGroup) Done() {
89	wg.Add(-1)
90}
91
// Wait blocks until the [WaitGroup] counter is zero.
//
// If the counter is already zero, Wait returns immediately. Otherwise it
// registers itself as a waiter and sleeps on the semaphore until an Add
// call that brings the counter to zero releases it. Wait panics if, after
// being woken, it finds that the state is no longer zero.
func (wg *WaitGroup) Wait() {
	if race.Enabled {
		race.Disable()
	}
	// Retry loop: each iteration takes a fresh snapshot of state and either
	// returns (counter is zero) or tries to register as a waiter with a
	// compare-and-swap. A failed swap means state changed in between, so
	// the decision is re-made from a new snapshot.
	for {
		state := wg.state.Load()
		// v is the counter; w is the number of waiters already registered.
		v := int32(state >> 32)
		w := uint32(state)
		if v == 0 {
			// Counter is 0, no need to wait.
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
		// Increment waiters count.
		// The waiter count lives in the low 32 bits, so state+1 adds one waiter.
		if wg.state.CompareAndSwap(state, state+1) {
			if race.Enabled && w == 0 {
				// Wait must be synchronized with the first Add.
				// Need to model this as a write to race with the read in Add.
				// As a consequence, can do the write only for the first waiter,
				// otherwise concurrent Waits will race with each other.
				race.Write(unsafe.Pointer(&wg.sema))
			}
			// Sleep until Add releases the semaphore for this waiter.
			runtime_Semacquire(&wg.sema)
			// Add stores 0 into state before releasing waiters, so any other
			// value here means the WaitGroup was modified again before this
			// Wait had a chance to return.
			if wg.state.Load() != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
	}
}
130