1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package runtime
6
7import (
8	"internal/goarch"
9	"internal/runtime/atomic"
10	"unsafe"
11)
12
// _DWORD_MAX is the largest value representable in 32 bits. netpollinit
// passes it as the last argument of CreateIoCompletionPort.
const _DWORD_MAX = 0xffffffff

// _INVALID_HANDLE_VALUE is a uintptr with every bit set. It is the initial
// value of iocphandle and the file handle argument netpollinit passes to
// CreateIoCompletionPort.
const _INVALID_HANDLE_VALUE = ^uintptr(0)
16
// Sources are used to identify the event that created an overlapped entry.
// The source values are arbitrary. There is no risk of collision with
// user-defined values because the only way to set the key of an overlapped
// entry is through iocphandle, which is not accessible to user code.
const (
	// netpollSourceReady tags entries for handles registered by netpollopen.
	netpollSourceReady = iota + 1
	// netpollSourceBreak tags the wakeup packet posted by netpollBreak.
	netpollSourceBreak
	// netpollSourceTimer tags the wait completion packet set up by netpollQueueTimer.
	netpollSourceTimer
)
26
const (
	// sourceBits is the number of bits needed to represent a source.
	// 4 bits can hold 16 different sources, which is more than enough.
	// It is set to a low value so the overlapped entry key can
	// contain as many bits as possible for the pollDesc pointer.
	sourceBits = 4
	// sourceMasks is the bit mask that selects the low sourceBits bits,
	// which is where the source lives in a key on 32-bit systems.
	sourceMasks = 1<<sourceBits - 1
)
35
36// packNetpollKey creates a key from a source and a tag.
37// Bits that don't fit in the result are discarded.
38func packNetpollKey(source uint8, pd *pollDesc) uintptr {
39	// TODO: Consider combining the source with pd.fdseq to detect stale pollDescs.
40	if source > (1<<sourceBits)-1 {
41		// Also fail on 64-bit systems, even though it can hold more bits.
42		throw("runtime: source value is too large")
43	}
44	if goarch.PtrSize == 4 {
45		return uintptr(unsafe.Pointer(pd))<<sourceBits | uintptr(source)
46	}
47	return uintptr(taggedPointerPack(unsafe.Pointer(pd), uintptr(source)))
48}
49
50// unpackNetpollSource returns the source packed key.
51func unpackNetpollSource(key uintptr) uint8 {
52	if goarch.PtrSize == 4 {
53		return uint8(key & sourceMasks)
54	}
55	return uint8(taggedPointer(key).tag())
56}
57
// pollOperation must match the beginning of internal/poll.operation.
// Keep these in sync.
//
// pollOperationFromOverlappedEntry reinterprets the overlapped pointer of a
// completion entry as a *pollOperation, so the overlapped value must remain
// the first field.
type pollOperation struct {
	// used by windows
	_ overlapped
	// used by netpoll
	pd   *pollDesc
	mode int32
}
67
68// pollOperationFromOverlappedEntry returns the pollOperation contained in
69// e. It can return nil if the entry is not from internal/poll.
70// See go.dev/issue/58870
71func pollOperationFromOverlappedEntry(e *overlappedEntry) *pollOperation {
72	if e.ov == nil {
73		return nil
74	}
75	op := (*pollOperation)(unsafe.Pointer(e.ov))
76	// Check that the key matches the pollDesc pointer.
77	var keyMatch bool
78	if goarch.PtrSize == 4 {
79		keyMatch = e.key&^sourceMasks == uintptr(unsafe.Pointer(op.pd))<<sourceBits
80	} else {
81		keyMatch = (*pollDesc)(taggedPointer(e.key).pointer()) == op.pd
82	}
83	if !keyMatch {
84		return nil
85	}
86	return op
87}
88
// overlappedEntry contains the information returned by a call to GetQueuedCompletionStatusEx.
// The field order mirrors the Windows OVERLAPPED_ENTRY structure, so it must not be changed.
// https://learn.microsoft.com/en-us/windows/win32/api/minwinbase/ns-minwinbase-overlapped_entry
type overlappedEntry struct {
	key      uintptr     // completion key; netpoll decodes it with unpackNetpollSource
	ov       *overlapped // overlapped pointer of the completed operation; may be nil
	internal uintptr     // not read by the code in this file
	qty      uint32      // not read by the code in this file
}
97
var (
	// iocphandle is the completion port io handle. It holds
	// _INVALID_HANDLE_VALUE until netpollinit stores the handle returned by
	// CreateIoCompletionPort; netpoll returns immediately while it is unset.
	iocphandle uintptr = _INVALID_HANDLE_VALUE // completion port io handle

	// netpollWakeSig is used to avoid duplicate calls of netpollBreak:
	// netpollBreak switches it from 0 to 1 before posting the wakeup packet,
	// and netpoll stores 0 again when it dequeues that packet.
	netpollWakeSig atomic.Uint32 // used to avoid duplicate calls of netpollBreak
)
103
104func netpollinit() {
105	iocphandle = stdcall4(_CreateIoCompletionPort, _INVALID_HANDLE_VALUE, 0, 0, _DWORD_MAX)
106	if iocphandle == 0 {
107		println("runtime: CreateIoCompletionPort failed (errno=", getlasterror(), ")")
108		throw("runtime: netpollinit failed")
109	}
110}
111
// netpollIsPollDescriptor reports whether fd is the completion port handle
// used by the poller (iocphandle).
func netpollIsPollDescriptor(fd uintptr) bool {
	return fd == iocphandle
}
115
116func netpollopen(fd uintptr, pd *pollDesc) int32 {
117	key := packNetpollKey(netpollSourceReady, pd)
118	if stdcall4(_CreateIoCompletionPort, fd, iocphandle, key, 0) == 0 {
119		return int32(getlasterror())
120	}
121	return 0
122}
123
// netpollclose performs no work for fd in this implementation and always
// returns 0.
func netpollclose(fd uintptr) int32 {
	// nothing to do
	return 0
}
128
// netpollarm is not expected to be called in this implementation; it
// unconditionally throws.
func netpollarm(pd *pollDesc, mode int) {
	throw("runtime: unused")
}
132
133func netpollBreak() {
134	// Failing to cas indicates there is an in-flight wakeup, so we're done here.
135	if !netpollWakeSig.CompareAndSwap(0, 1) {
136		return
137	}
138
139	key := packNetpollKey(netpollSourceBreak, nil)
140	if stdcall4(_PostQueuedCompletionStatus, iocphandle, 0, key, 0) == 0 {
141		println("runtime: netpoll: PostQueuedCompletionStatus failed (errno=", getlasterror(), ")")
142		throw("runtime: netpoll: PostQueuedCompletionStatus failed")
143	}
144}
145
146// netpoll checks for ready network connections.
147// Returns list of goroutines that become runnable.
148// delay < 0: blocks indefinitely
149// delay == 0: does not block, just polls
150// delay > 0: block for up to that many nanoseconds
151func netpoll(delay int64) (gList, int32) {
152	if iocphandle == _INVALID_HANDLE_VALUE {
153		return gList{}, 0
154	}
155
156	var entries [64]overlappedEntry
157	var wait uint32
158	var toRun gList
159	mp := getg().m
160
161	if delay >= 1e15 {
162		// An arbitrary cap on how long to wait for a timer.
163		// 1e15 ns == ~11.5 days.
164		delay = 1e15
165	}
166
167	if delay > 0 && mp.waitIocpHandle != 0 {
168		// GetQueuedCompletionStatusEx doesn't use a high resolution timer internally,
169		// so we use a separate higher resolution timer associated with a wait completion
170		// packet to wake up the poller. Note that the completion packet can be delivered
171		// to another thread, and the Go scheduler expects netpoll to only block up to delay,
172		// so we still need to use a timeout with GetQueuedCompletionStatusEx.
173		// TODO: Improve the Go scheduler to support non-blocking timers.
174		signaled := netpollQueueTimer(delay)
175		if signaled {
176			// There is a small window between the SetWaitableTimer and the NtAssociateWaitCompletionPacket
177			// where the timer can expire. We can return immediately in this case.
178			return gList{}, 0
179		}
180	}
181	if delay < 0 {
182		wait = _INFINITE
183	} else if delay == 0 {
184		wait = 0
185	} else if delay < 1e6 {
186		wait = 1
187	} else {
188		wait = uint32(delay / 1e6)
189	}
190	n := len(entries) / int(gomaxprocs)
191	if n < 8 {
192		n = 8
193	}
194	if delay != 0 {
195		mp.blocked = true
196	}
197	if stdcall6(_GetQueuedCompletionStatusEx, iocphandle, uintptr(unsafe.Pointer(&entries[0])), uintptr(n), uintptr(unsafe.Pointer(&n)), uintptr(wait), 0) == 0 {
198		mp.blocked = false
199		errno := getlasterror()
200		if errno == _WAIT_TIMEOUT {
201			return gList{}, 0
202		}
203		println("runtime: GetQueuedCompletionStatusEx failed (errno=", errno, ")")
204		throw("runtime: netpoll failed")
205	}
206	mp.blocked = false
207	delta := int32(0)
208	for i := 0; i < n; i++ {
209		e := &entries[i]
210		switch unpackNetpollSource(e.key) {
211		case netpollSourceReady:
212			op := pollOperationFromOverlappedEntry(e)
213			if op == nil {
214				// Entry from outside the Go runtime and internal/poll, ignore.
215				continue
216			}
217			// Entry from internal/poll.
218			mode := op.mode
219			if mode != 'r' && mode != 'w' {
220				println("runtime: GetQueuedCompletionStatusEx returned net_op with invalid mode=", mode)
221				throw("runtime: netpoll failed")
222			}
223			delta += netpollready(&toRun, op.pd, mode)
224		case netpollSourceBreak:
225			netpollWakeSig.Store(0)
226			if delay == 0 {
227				// Forward the notification to the blocked poller.
228				netpollBreak()
229			}
230		case netpollSourceTimer:
231			// TODO: We could avoid calling NtCancelWaitCompletionPacket for expired wait completion packets.
232		default:
233			println("runtime: GetQueuedCompletionStatusEx returned net_op with invalid key=", e.key)
234			throw("runtime: netpoll failed")
235		}
236	}
237	return toRun, delta
238}
239
240// netpollQueueTimer queues a timer to wake up the poller after the given delay.
241// It returns true if the timer expired during this call.
242func netpollQueueTimer(delay int64) (signaled bool) {
243	const (
244		STATUS_SUCCESS   = 0x00000000
245		STATUS_PENDING   = 0x00000103
246		STATUS_CANCELLED = 0xC0000120
247	)
248	mp := getg().m
249	// A wait completion packet can only be associated with one timer at a time,
250	// so we need to cancel the previous one if it exists. This wouldn't be necessary
251	// if the poller would only be woken up by the timer, in which case the association
252	// would be automatically canceled, but it can also be woken up by other events,
253	// such as a netpollBreak, so we can get to this point with a timer that hasn't
254	// expired yet. In this case, the completion packet can still be picked up by
255	// another thread, so defer the cancellation until it is really necessary.
256	errno := stdcall2(_NtCancelWaitCompletionPacket, mp.waitIocpHandle, 1)
257	switch errno {
258	case STATUS_CANCELLED:
259		// STATUS_CANCELLED is returned when the associated timer has already expired,
260		// in which automatically cancels the wait completion packet.
261		fallthrough
262	case STATUS_SUCCESS:
263		dt := -delay / 100 // relative sleep (negative), 100ns units
264		if stdcall6(_SetWaitableTimer, mp.waitIocpTimer, uintptr(unsafe.Pointer(&dt)), 0, 0, 0, 0) == 0 {
265			println("runtime: SetWaitableTimer failed; errno=", getlasterror())
266			throw("runtime: netpoll failed")
267		}
268		key := packNetpollKey(netpollSourceTimer, nil)
269		if errno := stdcall8(_NtAssociateWaitCompletionPacket, mp.waitIocpHandle, iocphandle, mp.waitIocpTimer, key, 0, 0, 0, uintptr(unsafe.Pointer(&signaled))); errno != 0 {
270			println("runtime: NtAssociateWaitCompletionPacket failed; errno=", errno)
271			throw("runtime: netpoll failed")
272		}
273	case STATUS_PENDING:
274		// STATUS_PENDING is returned if the wait operation can't be canceled yet.
275		// This can happen if this thread was woken up by another event, such as a netpollBreak,
276		// and the timer expired just while calling NtCancelWaitCompletionPacket, in which case
277		// this call fails to cancel the association to avoid a race condition.
278		// This is a rare case, so we can just avoid using the high resolution timer this time.
279	default:
280		println("runtime: NtCancelWaitCompletionPacket failed; errno=", errno)
281		throw("runtime: netpoll failed")
282	}
283	return signaled
284}
285