// xref: /aosp_15_r20/external/boringssl/src/ssl/test/runner/shim_dispatcher.go (revision 8fb009dc861624b67b6cdb62ea21f0f22d0c584b)
// Copyright (c) 2023, Google Inc.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14
15package runner
16
17import (
18	"context"
19	"encoding/binary"
20	"fmt"
21	"io"
22	"net"
23	"os"
24	"sync"
25	"time"
26)
27
// shimDispatcher multiplexes many shims over a single loopback TCP listener.
// Each incoming connection is routed to a shimListener by the 8-byte shim ID
// it sends first.
type shimDispatcher struct {
	// lock guards nextShimID, shims, and err.
	lock sync.Mutex
	// nextShimID is the ID NewShim will assign to the next shim it creates.
	nextShimID uint64
	// listener is the one TCP listener shared by every shim; all incoming
	// connections are accepted here and then routed by their shim ID.
	listener *net.TCPListener
	// shims maps each registered shim ID to the listener that receives its
	// connections.
	shims map[uint64]*shimListener
	// err is non-nil once the dispatcher has been shut down; NewShim then
	// returns it.
	err error
}
35
36func newShimDispatcher() (*shimDispatcher, error) {
37	listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv6loopback})
38	if err != nil {
39		listener, err = net.ListenTCP("tcp4", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}})
40	}
41
42	if err != nil {
43		return nil, err
44	}
45	d := &shimDispatcher{listener: listener, shims: make(map[uint64]*shimListener)}
46	go d.acceptLoop()
47	return d, nil
48}
49
50func (d *shimDispatcher) NewShim() (*shimListener, error) {
51	d.lock.Lock()
52	defer d.lock.Unlock()
53	if d.err != nil {
54		return nil, d.err
55	}
56
57	l := &shimListener{dispatcher: d, shimID: d.nextShimID, connChan: make(chan net.Conn, 1)}
58	d.shims[l.shimID] = l
59	d.nextShimID++
60	return l, nil
61}
62
63func (d *shimDispatcher) unregisterShim(l *shimListener) {
64	d.lock.Lock()
65	delete(d.shims, l.shimID)
66	d.lock.Unlock()
67}
68
69func (d *shimDispatcher) acceptLoop() {
70	for {
71		conn, err := d.listener.Accept()
72		if err != nil {
73			// Something went wrong. Shut down the listener.
74			d.closeWithError(err)
75			return
76		}
77
78		go func() {
79			if err := d.dispatch(conn); err != nil {
80				// To be robust against port scanners, etc., we log a warning
81				// but otherwise treat undispatchable connections as non-fatal.
82				fmt.Fprintf(os.Stderr, "Error dispatching connection: %s\n", err)
83				conn.Close()
84			}
85		}()
86	}
87}
88
89func (d *shimDispatcher) dispatch(conn net.Conn) error {
90	conn.SetReadDeadline(time.Now().Add(*idleTimeout))
91	var buf [8]byte
92	if _, err := io.ReadFull(conn, buf[:]); err != nil {
93		return err
94	}
95	conn.SetReadDeadline(time.Time{})
96
97	shimID := binary.LittleEndian.Uint64(buf[:])
98	d.lock.Lock()
99	shim, ok := d.shims[shimID]
100	d.lock.Unlock()
101	if !ok {
102		return fmt.Errorf("shim ID %d not found", shimID)
103	}
104
105	shim.connChan <- conn
106	return nil
107}
108
109func (d *shimDispatcher) Close() error {
110	return d.closeWithError(net.ErrClosed)
111}
112
113func (d *shimDispatcher) closeWithError(err error) error {
114	closeErr := d.listener.Close()
115
116	d.lock.Lock()
117	shims := d.shims
118	d.shims = make(map[uint64]*shimListener)
119	d.err = err
120	d.lock.Unlock()
121
122	for _, shim := range shims {
123		shim.closeWithError(err)
124	}
125	return closeErr
126}
127
// shimListener receives the connections that the shimDispatcher routes to a
// single shim ID.
type shimListener struct {
	// dispatcher is the shimDispatcher that created this listener and owns the
	// shared TCP listener.
	dispatcher *shimDispatcher
	// shimID is this listener's key in dispatcher.shims. The dispatcher routes
	// a connection here when its first 8 bytes, read as a little-endian
	// uint64, equal shimID.
	shimID uint64
	// connChan contains connections from the dispatcher. On fatal error, it is
	// closed, with the error available in err.
	connChan chan net.Conn
	// err is the reason the listener was closed, or nil while it is open. It
	// is set at most once, under lock, immediately before connChan is closed.
	err error
	// lock guards err and serializes closing connChan so it happens only once.
	lock sync.Mutex
}
137
138func (l *shimListener) Port() int {
139	return l.dispatcher.listener.Addr().(*net.TCPAddr).Port
140}
141
142func (l *shimListener) IsIPv6() bool {
143	return len(l.dispatcher.listener.Addr().(*net.TCPAddr).IP) == net.IPv6len
144}
145
146func (l *shimListener) ShimID() uint64 {
147	return l.shimID
148}
149
150func (l *shimListener) Close() error {
151	l.dispatcher.unregisterShim(l)
152	l.closeWithError(net.ErrClosed)
153	return nil
154}
155
156func (l *shimListener) closeWithError(err error) {
157	// Multiple threads may close the listener at once, so protect closing with
158	// a lock.
159	l.lock.Lock()
160	if l.err == nil {
161		l.err = err
162		close(l.connChan)
163	}
164	l.lock.Unlock()
165}
166
167func (l *shimListener) Accept(deadline time.Time) (net.Conn, error) {
168	var timerChan <-chan time.Time
169	if !deadline.IsZero() {
170		remaining := time.Until(deadline)
171		if remaining < 0 {
172			return nil, context.DeadlineExceeded
173		}
174		timer := time.NewTimer(remaining)
175		defer timer.Stop()
176		timerChan = timer.C
177	}
178
179	select {
180	case <-timerChan:
181		return nil, context.DeadlineExceeded
182	case conn, ok := <-l.connChan:
183		if !ok {
184			return nil, l.err
185		}
186		return conn, nil
187	}
188}
189