1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package runtime
6
7import (
8	"internal/abi"
9	"internal/runtime/atomic"
10	"unsafe"
11)
12
// A Pinner is a set of Go objects each pinned to a fixed location in memory. The
// [Pinner.Pin] method pins one object, while [Pinner.Unpin] unpins all pinned
// objects. See their comments for more information.
//
// The zero value is ready to use: the underlying state is obtained lazily by
// the first call to Pin, either from a per-P cache or by a fresh allocation.
type Pinner struct {
	*pinner
}
19
20// Pin pins a Go object, preventing it from being moved or freed by the garbage
21// collector until the [Pinner.Unpin] method has been called.
22//
23// A pointer to a pinned object can be directly stored in C memory or can be
24// contained in Go memory passed to C functions. If the pinned object itself
25// contains pointers to Go objects, these objects must be pinned separately if they
26// are going to be accessed from C code.
27//
28// The argument must be a pointer of any type or an [unsafe.Pointer].
29// It's safe to call Pin on non-Go pointers, in which case Pin will do nothing.
30func (p *Pinner) Pin(pointer any) {
31	if p.pinner == nil {
32		// Check the pinner cache first.
33		mp := acquirem()
34		if pp := mp.p.ptr(); pp != nil {
35			p.pinner = pp.pinnerCache
36			pp.pinnerCache = nil
37		}
38		releasem(mp)
39
40		if p.pinner == nil {
41			// Didn't get anything from the pinner cache.
42			p.pinner = new(pinner)
43			p.refs = p.refStore[:0]
44
45			// We set this finalizer once and never clear it. Thus, if the
46			// pinner gets cached, we'll reuse it, along with its finalizer.
47			// This lets us avoid the relatively expensive SetFinalizer call
48			// when reusing from the cache. The finalizer however has to be
49			// resilient to an empty pinner being finalized, which is done
50			// by checking p.refs' length.
51			SetFinalizer(p.pinner, func(i *pinner) {
52				if len(i.refs) != 0 {
53					i.unpin() // only required to make the test idempotent
54					pinnerLeakPanic()
55				}
56			})
57		}
58	}
59	ptr := pinnerGetPtr(&pointer)
60	if setPinned(ptr, true) {
61		p.refs = append(p.refs, ptr)
62	}
63}
64
65// Unpin unpins all pinned objects of the [Pinner].
66func (p *Pinner) Unpin() {
67	p.pinner.unpin()
68
69	mp := acquirem()
70	if pp := mp.p.ptr(); pp != nil && pp.pinnerCache == nil {
71		// Put the pinner back in the cache, but only if the
72		// cache is empty. If application code is reusing Pinners
73		// on its own, we want to leave the backing store in place
74		// so reuse is more efficient.
75		pp.pinnerCache = p.pinner
76		p.pinner = nil
77	}
78	releasem(mp)
79}
80
// pinnerSize is the byte budget for a pinner: the refs slice header plus the
// inline refStore array. pinnerRefStoreSize is the number of pointer slots
// that fit in what remains after the slice header.
const (
	pinnerSize         = 64
	pinnerRefStoreSize = (pinnerSize - unsafe.Sizeof([]unsafe.Pointer(nil))) / unsafe.Sizeof(unsafe.Pointer(nil))
)
85
// pinner is the state behind a Pinner. refs records every pointer whose pin
// state was changed through this pinner. It initially aliases refStore, so the
// first pinnerRefStoreSize pins need no separately allocated backing array.
type pinner struct {
	refs     []unsafe.Pointer
	refStore [pinnerRefStoreSize]unsafe.Pointer
}
90
91func (p *pinner) unpin() {
92	if p == nil || p.refs == nil {
93		return
94	}
95	for i := range p.refs {
96		setPinned(p.refs[i], false)
97	}
98	// The following two lines make all pointers to references
99	// in p.refs unreachable, either by deleting them or dropping
100	// p.refs' backing store (if it was not backed by refStore).
101	p.refStore = [pinnerRefStoreSize]unsafe.Pointer{}
102	p.refs = p.refStore[:0]
103}
104
105func pinnerGetPtr(i *any) unsafe.Pointer {
106	e := efaceOf(i)
107	etyp := e._type
108	if etyp == nil {
109		panic(errorString("runtime.Pinner: argument is nil"))
110	}
111	if kind := etyp.Kind_ & abi.KindMask; kind != abi.Pointer && kind != abi.UnsafePointer {
112		panic(errorString("runtime.Pinner: argument is not a pointer: " + toRType(etyp).string()))
113	}
114	if inUserArenaChunk(uintptr(e.data)) {
115		// Arena-allocated objects are not eligible for pinning.
116		panic(errorString("runtime.Pinner: object was allocated into an arena"))
117	}
118	return e.data
119}
120
// isPinned reports whether the object that the Go pointer ptr points into is
// pinned.
//
// Pointers for which spanOfHeap finds no span are reported as pinned (see the
// comment in the body). A span without pinner bits has no pinned objects.
//
// nosplit, because it's called from nosplit code in cgocheck.
//
//go:nosplit
func isPinned(ptr unsafe.Pointer) bool {
	span := spanOfHeap(uintptr(ptr))
	if span == nil {
		// This code is only called for Go pointers, so this must be a
		// linker-allocated global object.
		return true
	}
	pinnerBits := span.getPinnerBits()
	// These pinnerBits might get unlinked by a concurrently running sweep, but
	// that's OK because gcBits don't get cleared until the following GC cycle
	// (nextMarkBitArenaEpoch).
	if pinnerBits == nil {
		// setPinned always allocates pinner bits before pinning anything in
		// the span, so nil means nothing in this span is pinned.
		return false
	}
	objIndex := span.objIndex(uintptr(ptr))
	pinState := pinnerBits.ofObject(objIndex)
	KeepAlive(ptr) // make sure ptr is alive until we are done so the span can't be freed
	return pinState.isPinned()
}
144
// setPinned pins (pin == true) or unpins (pin == false) the object that ptr
// points into.
//
// If ptr is not a Go heap pointer (spanOfHeap returns nil), pinning silently
// does nothing and returns false, while unpinning panics. The latter should not
// happen in normal usage, because Pinner only records pointers for which
// pinning returned true. In every other case setPinned returns true.
//
// The first pin of an object sets its pin bit. Each additional pin sets the
// multipin bit and increments a per-object pin counter kept as a span
// special. Unpinning first decrements that counter, clearing the multipin bit
// when the counter is gone, and clears the pin bit only when no multipins are
// recorded. Unpinning an object that is not pinned throws.
func setPinned(ptr unsafe.Pointer, pin bool) bool {
	span := spanOfHeap(uintptr(ptr))
	if span == nil {
		if !pin {
			panic(errorString("tried to unpin non-Go pointer"))
		}
		// This is a linker-allocated, zero size object or other object,
		// nothing to do, silently ignore it.
		return false
	}

	// ensure that the span is swept, b/c sweeping accesses the specials list
	// w/o locks.
	mp := acquirem()
	span.ensureSwept()
	KeepAlive(ptr) // make sure ptr is still alive after span is swept

	objIndex := span.objIndex(uintptr(ptr))

	lock(&span.speciallock) // guard against concurrent calls of setPinned on same span

	// Lazily allocate the span's pinner bits on first use.
	pinnerBits := span.getPinnerBits()
	if pinnerBits == nil {
		pinnerBits = span.newPinnerBits()
		span.setPinnerBits(pinnerBits)
	}
	pinState := pinnerBits.ofObject(objIndex)
	if pin {
		if pinState.isPinned() {
			// multiple pins on same object, set multipin bit
			pinState.setMultiPinned(true)
			// and increase the pin counter, keyed by the object's byte
			// offset within the span
			// TODO(mknyszek): investigate if systemstack is necessary here
			systemstack(func() {
				offset := objIndex * span.elemsize
				span.incPinCounter(offset)
			})
		} else {
			// set pin bit
			pinState.setPinned(true)
		}
	} else {
		// unpin
		if pinState.isPinned() {
			if pinState.isMultiPinned() {
				// Drop one additional pin; the pin bit itself stays set.
				var exists bool
				// TODO(mknyszek): investigate if systemstack is necessary here
				systemstack(func() {
					offset := objIndex * span.elemsize
					exists = span.decPinCounter(offset)
				})
				if !exists {
					// counter is 0, clear multipin bit
					pinState.setMultiPinned(false)
				}
			} else {
				// no multipins recorded. unpin object.
				pinState.setPinned(false)
			}
		} else {
			// unpinning unpinned object, bail out
			throw("runtime.Pinner: object already unpinned")
		}
	}
	unlock(&span.speciallock)
	releasem(mp)
	return true
}
217
// pinState is a view of one object's two pinner bits: the pin bit (mask) and
// the multipin bit (mask << 1), both in the byte at *bytep. byteVal is a
// snapshot of *bytep taken when the pinState was created by
// pinnerBits.ofObject. isPinned and isMultiPinned read that snapshot, while
// set writes *bytep atomically and does not refresh byteVal.
type pinState struct {
	bytep   *uint8 // byte holding this object's pinner bits
	byteVal uint8  // value of *bytep when the pinState was created
	mask    uint8  // selects the pin bit; mask << 1 selects the multipin bit
}
223
224// nosplit, because it's called by isPinned, which is nosplit
225//
226//go:nosplit
227func (v *pinState) isPinned() bool {
228	return (v.byteVal & v.mask) != 0
229}
230
231func (v *pinState) isMultiPinned() bool {
232	return (v.byteVal & (v.mask << 1)) != 0
233}
234
235func (v *pinState) setPinned(val bool) {
236	v.set(val, false)
237}
238
239func (v *pinState) setMultiPinned(val bool) {
240	v.set(val, true)
241}
242
243// set sets the pin bit of the pinState to val. If multipin is true, it
244// sets/unsets the multipin bit instead.
245func (v *pinState) set(val bool, multipin bool) {
246	mask := v.mask
247	if multipin {
248		mask <<= 1
249	}
250	if val {
251		atomic.Or8(v.bytep, mask)
252	} else {
253		atomic.And8(v.bytep, ^mask)
254	}
255}
256
// pinnerBits is the same type as gcBits but has different methods. It holds
// two bits per object in a span (see mspan.pinnerBitSize): a pin bit and a
// multipin bit.
type pinnerBits gcBits
259
260// ofObject returns the pinState of the n'th object.
261// nosplit, because it's called by isPinned, which is nosplit
262//
263//go:nosplit
264func (p *pinnerBits) ofObject(n uintptr) pinState {
265	bytep, mask := (*gcBits)(p).bitp(n * 2)
266	byteVal := atomic.Load8(bytep)
267	return pinState{bytep, byteVal, mask}
268}
269
270func (s *mspan) pinnerBitSize() uintptr {
271	return divRoundUp(uintptr(s.nelems)*2, 8)
272}
273
274// newPinnerBits returns a pointer to 8 byte aligned bytes to be used for this
275// span's pinner bits. newPinnerBits is used to mark objects that are pinned.
276// They are copied when the span is swept.
277func (s *mspan) newPinnerBits() *pinnerBits {
278	return (*pinnerBits)(newMarkBits(uintptr(s.nelems) * 2))
279}
280
281// nosplit, because it's called by isPinned, which is nosplit
282//
283//go:nosplit
284func (s *mspan) getPinnerBits() *pinnerBits {
285	return (*pinnerBits)(atomic.Loadp(unsafe.Pointer(&s.pinnerBits)))
286}
287
288func (s *mspan) setPinnerBits(p *pinnerBits) {
289	atomicstorep(unsafe.Pointer(&s.pinnerBits), unsafe.Pointer(p))
290}
291
292// refreshPinnerBits replaces pinnerBits with a fresh copy in the arenas for the
293// next GC cycle. If it does not contain any pinned objects, pinnerBits of the
294// span is set to nil.
295func (s *mspan) refreshPinnerBits() {
296	p := s.getPinnerBits()
297	if p == nil {
298		return
299	}
300
301	hasPins := false
302	bytes := alignUp(s.pinnerBitSize(), 8)
303
304	// Iterate over each 8-byte chunk and check for pins. Note that
305	// newPinnerBits guarantees that pinnerBits will be 8-byte aligned, so we
306	// don't have to worry about edge cases, irrelevant bits will simply be
307	// zero.
308	for _, x := range unsafe.Slice((*uint64)(unsafe.Pointer(&p.x)), bytes/8) {
309		if x != 0 {
310			hasPins = true
311			break
312		}
313	}
314
315	if hasPins {
316		newPinnerBits := s.newPinnerBits()
317		memmove(unsafe.Pointer(&newPinnerBits.x), unsafe.Pointer(&p.x), bytes)
318		s.setPinnerBits(newPinnerBits)
319	} else {
320		s.setPinnerBits(nil)
321	}
322}
323
324// incPinCounter is only called for multiple pins of the same object and records
325// the _additional_ pins.
326func (span *mspan) incPinCounter(offset uintptr) {
327	var rec *specialPinCounter
328	ref, exists := span.specialFindSplicePoint(offset, _KindSpecialPinCounter)
329	if !exists {
330		lock(&mheap_.speciallock)
331		rec = (*specialPinCounter)(mheap_.specialPinCounterAlloc.alloc())
332		unlock(&mheap_.speciallock)
333		// splice in record, fill in offset.
334		rec.special.offset = uint16(offset)
335		rec.special.kind = _KindSpecialPinCounter
336		rec.special.next = *ref
337		*ref = (*special)(unsafe.Pointer(rec))
338		spanHasSpecials(span)
339	} else {
340		rec = (*specialPinCounter)(unsafe.Pointer(*ref))
341	}
342	rec.counter++
343}
344
345// decPinCounter decreases the counter. If the counter reaches 0, the counter
346// special is deleted and false is returned. Otherwise true is returned.
347func (span *mspan) decPinCounter(offset uintptr) bool {
348	ref, exists := span.specialFindSplicePoint(offset, _KindSpecialPinCounter)
349	if !exists {
350		throw("runtime.Pinner: decreased non-existing pin counter")
351	}
352	counter := (*specialPinCounter)(unsafe.Pointer(*ref))
353	counter.counter--
354	if counter.counter == 0 {
355		*ref = counter.special.next
356		if span.specials == nil {
357			spanHasNoSpecials(span)
358		}
359		lock(&mheap_.speciallock)
360		mheap_.specialPinCounterAlloc.free(unsafe.Pointer(counter))
361		unlock(&mheap_.speciallock)
362		return false
363	}
364	return true
365}
366
367// only for tests
368func pinnerGetPinCounter(addr unsafe.Pointer) *uintptr {
369	_, span, objIndex := findObject(uintptr(addr), 0, 0)
370	offset := objIndex * span.elemsize
371	t, exists := span.specialFindSplicePoint(offset, _KindSpecialPinCounter)
372	if !exists {
373		return nil
374	}
375	counter := (*specialPinCounter)(unsafe.Pointer(*t))
376	return &counter.counter
377}
378
// pinnerLeakPanic is called by the finalizer installed in Pinner.Pin when a
// pinner becomes unreachable while it still holds pinned pointers. It is a
// variable rather than a function so that a test can overwrite it and check
// that the leak is detected.
var pinnerLeakPanic = func() {
	panic(errorString("runtime.Pinner: found leaking pinned pointer; forgot to call Unpin()?"))
}
384