1// Copyright 2015 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build unix || (js && wasm) || wasip1
6
7package socktest
8
9import "syscall"
10
11// Socket wraps [syscall.Socket].
12func (sw *Switch) Socket(family, sotype, proto int) (s int, err error) {
13	sw.once.Do(sw.init)
14
15	so := &Status{Cookie: cookie(family, sotype, proto)}
16	sw.fmu.RLock()
17	f := sw.fltab[FilterSocket]
18	sw.fmu.RUnlock()
19
20	af, err := f.apply(so)
21	if err != nil {
22		return -1, err
23	}
24	s, so.Err = syscall.Socket(family, sotype, proto)
25	if err = af.apply(so); err != nil {
26		if so.Err == nil {
27			syscall.Close(s)
28		}
29		return -1, err
30	}
31
32	sw.smu.Lock()
33	defer sw.smu.Unlock()
34	if so.Err != nil {
35		sw.stats.getLocked(so.Cookie).OpenFailed++
36		return -1, so.Err
37	}
38	nso := sw.addLocked(s, family, sotype, proto)
39	sw.stats.getLocked(nso.Cookie).Opened++
40	return s, nil
41}
42
43// Close wraps syscall.Close.
44func (sw *Switch) Close(s int) (err error) {
45	so := sw.sockso(s)
46	if so == nil {
47		return syscall.Close(s)
48	}
49	sw.fmu.RLock()
50	f := sw.fltab[FilterClose]
51	sw.fmu.RUnlock()
52
53	af, err := f.apply(so)
54	if err != nil {
55		return err
56	}
57	so.Err = syscall.Close(s)
58	if err = af.apply(so); err != nil {
59		return err
60	}
61
62	sw.smu.Lock()
63	defer sw.smu.Unlock()
64	if so.Err != nil {
65		sw.stats.getLocked(so.Cookie).CloseFailed++
66		return so.Err
67	}
68	delete(sw.sotab, s)
69	sw.stats.getLocked(so.Cookie).Closed++
70	return nil
71}
72
73// Connect wraps syscall.Connect.
74func (sw *Switch) Connect(s int, sa syscall.Sockaddr) (err error) {
75	so := sw.sockso(s)
76	if so == nil {
77		return syscall.Connect(s, sa)
78	}
79	sw.fmu.RLock()
80	f := sw.fltab[FilterConnect]
81	sw.fmu.RUnlock()
82
83	af, err := f.apply(so)
84	if err != nil {
85		return err
86	}
87	so.Err = syscall.Connect(s, sa)
88	if err = af.apply(so); err != nil {
89		return err
90	}
91
92	sw.smu.Lock()
93	defer sw.smu.Unlock()
94	if so.Err != nil {
95		sw.stats.getLocked(so.Cookie).ConnectFailed++
96		return so.Err
97	}
98	sw.stats.getLocked(so.Cookie).Connected++
99	return nil
100}
101
102// Listen wraps syscall.Listen.
103func (sw *Switch) Listen(s, backlog int) (err error) {
104	so := sw.sockso(s)
105	if so == nil {
106		return syscall.Listen(s, backlog)
107	}
108	sw.fmu.RLock()
109	f := sw.fltab[FilterListen]
110	sw.fmu.RUnlock()
111
112	af, err := f.apply(so)
113	if err != nil {
114		return err
115	}
116	so.Err = syscall.Listen(s, backlog)
117	if err = af.apply(so); err != nil {
118		return err
119	}
120
121	sw.smu.Lock()
122	defer sw.smu.Unlock()
123	if so.Err != nil {
124		sw.stats.getLocked(so.Cookie).ListenFailed++
125		return so.Err
126	}
127	sw.stats.getLocked(so.Cookie).Listened++
128	return nil
129}
130
131// Accept wraps syscall.Accept.
132func (sw *Switch) Accept(s int) (ns int, sa syscall.Sockaddr, err error) {
133	so := sw.sockso(s)
134	if so == nil {
135		return syscall.Accept(s)
136	}
137	sw.fmu.RLock()
138	f := sw.fltab[FilterAccept]
139	sw.fmu.RUnlock()
140
141	af, err := f.apply(so)
142	if err != nil {
143		return -1, nil, err
144	}
145	ns, sa, so.Err = syscall.Accept(s)
146	if err = af.apply(so); err != nil {
147		if so.Err == nil {
148			syscall.Close(ns)
149		}
150		return -1, nil, err
151	}
152
153	sw.smu.Lock()
154	defer sw.smu.Unlock()
155	if so.Err != nil {
156		sw.stats.getLocked(so.Cookie).AcceptFailed++
157		return -1, nil, so.Err
158	}
159	nso := sw.addLocked(ns, so.Cookie.Family(), so.Cookie.Type(), so.Cookie.Protocol())
160	sw.stats.getLocked(nso.Cookie).Accepted++
161	return ns, sa, nil
162}
163
164// GetsockoptInt wraps syscall.GetsockoptInt.
165func (sw *Switch) GetsockoptInt(s, level, opt int) (soerr int, err error) {
166	so := sw.sockso(s)
167	if so == nil {
168		return syscall.GetsockoptInt(s, level, opt)
169	}
170	sw.fmu.RLock()
171	f := sw.fltab[FilterGetsockoptInt]
172	sw.fmu.RUnlock()
173
174	af, err := f.apply(so)
175	if err != nil {
176		return -1, err
177	}
178	soerr, so.Err = syscall.GetsockoptInt(s, level, opt)
179	so.SocketErr = syscall.Errno(soerr)
180	if err = af.apply(so); err != nil {
181		return -1, err
182	}
183
184	if so.Err != nil {
185		return -1, so.Err
186	}
187	if opt == syscall.SO_ERROR && (so.SocketErr == syscall.Errno(0) || so.SocketErr == syscall.EISCONN) {
188		sw.smu.Lock()
189		sw.stats.getLocked(so.Cookie).Connected++
190		sw.smu.Unlock()
191	}
192	return soerr, nil
193}
194