1// Copyright (c) 2023, Google Inc. 2// 3// Permission to use, copy, modify, and/or distribute this software for any 4// purpose with or without fee is hereby granted, provided that the above 5// copyright notice and this permission notice appear in all copies. 6// 7// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY 10// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION 12// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN 13// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 14 15package runner 16 17import ( 18 "context" 19 "encoding/binary" 20 "fmt" 21 "io" 22 "net" 23 "os" 24 "sync" 25 "time" 26) 27 28type shimDispatcher struct { 29 lock sync.Mutex 30 nextShimID uint64 31 listener *net.TCPListener 32 shims map[uint64]*shimListener 33 err error 34} 35 36func newShimDispatcher() (*shimDispatcher, error) { 37 listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv6loopback}) 38 if err != nil { 39 listener, err = net.ListenTCP("tcp4", &net.TCPAddr{IP: net.IP{127, 0, 0, 1}}) 40 } 41 42 if err != nil { 43 return nil, err 44 } 45 d := &shimDispatcher{listener: listener, shims: make(map[uint64]*shimListener)} 46 go d.acceptLoop() 47 return d, nil 48} 49 50func (d *shimDispatcher) NewShim() (*shimListener, error) { 51 d.lock.Lock() 52 defer d.lock.Unlock() 53 if d.err != nil { 54 return nil, d.err 55 } 56 57 l := &shimListener{dispatcher: d, shimID: d.nextShimID, connChan: make(chan net.Conn, 1)} 58 d.shims[l.shimID] = l 59 d.nextShimID++ 60 return l, nil 61} 62 63func (d *shimDispatcher) unregisterShim(l *shimListener) { 64 d.lock.Lock() 65 delete(d.shims, l.shimID) 66 d.lock.Unlock() 67} 68 69func (d *shimDispatcher) acceptLoop() { 70 for { 71 conn, err := d.listener.Accept() 72 if err != nil { 73 // Something went wrong. Shut down the listener. 74 d.closeWithError(err) 75 return 76 } 77 78 go func() { 79 if err := d.dispatch(conn); err != nil { 80 // To be robust against port scanners, etc., we log a warning 81 // but otherwise treat undispatchable connections as non-fatal. 82 fmt.Fprintf(os.Stderr, "Error dispatching connection: %s\n", err) 83 conn.Close() 84 } 85 }() 86 } 87} 88 89func (d *shimDispatcher) dispatch(conn net.Conn) error { 90 conn.SetReadDeadline(time.Now().Add(*idleTimeout)) 91 var buf [8]byte 92 if _, err := io.ReadFull(conn, buf[:]); err != nil { 93 return err 94 } 95 conn.SetReadDeadline(time.Time{}) 96 97 shimID := binary.LittleEndian.Uint64(buf[:]) 98 d.lock.Lock() 99 shim, ok := d.shims[shimID] 100 d.lock.Unlock() 101 if !ok { 102 return fmt.Errorf("shim ID %d not found", shimID) 103 } 104 105 shim.connChan <- conn 106 return nil 107} 108 109func (d *shimDispatcher) Close() error { 110 return d.closeWithError(net.ErrClosed) 111} 112 113func (d *shimDispatcher) closeWithError(err error) error { 114 closeErr := d.listener.Close() 115 116 d.lock.Lock() 117 shims := d.shims 118 d.shims = make(map[uint64]*shimListener) 119 d.err = err 120 d.lock.Unlock() 121 122 for _, shim := range shims { 123 shim.closeWithError(err) 124 } 125 return closeErr 126} 127 128type shimListener struct { 129 dispatcher *shimDispatcher 130 shimID uint64 131 // connChan contains connections from the dispatcher. On fatal error, it is 132 // closed, with the error available in err. 133 connChan chan net.Conn 134 err error 135 lock sync.Mutex 136} 137 138func (l *shimListener) Port() int { 139 return l.dispatcher.listener.Addr().(*net.TCPAddr).Port 140} 141 142func (l *shimListener) IsIPv6() bool { 143 return len(l.dispatcher.listener.Addr().(*net.TCPAddr).IP) == net.IPv6len 144} 145 146func (l *shimListener) ShimID() uint64 { 147 return l.shimID 148} 149 150func (l *shimListener) Close() error { 151 l.dispatcher.unregisterShim(l) 152 l.closeWithError(net.ErrClosed) 153 return nil 154} 155 156func (l *shimListener) closeWithError(err error) { 157 // Multiple threads may close the listener at once, so protect closing with 158 // a lock. 159 l.lock.Lock() 160 if l.err == nil { 161 l.err = err 162 close(l.connChan) 163 } 164 l.lock.Unlock() 165} 166 167func (l *shimListener) Accept(deadline time.Time) (net.Conn, error) { 168 var timerChan <-chan time.Time 169 if !deadline.IsZero() { 170 remaining := time.Until(deadline) 171 if remaining < 0 { 172 return nil, context.DeadlineExceeded 173 } 174 timer := time.NewTimer(remaining) 175 defer timer.Stop() 176 timerChan = timer.C 177 } 178 179 select { 180 case <-timerChan: 181 return nil, context.DeadlineExceeded 182 case conn, ok := <-l.connChan: 183 if !ok { 184 return nil, l.err 185 } 186 return conn, nil 187 } 188} 189