1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build !windows && !plan9 && !js && !wasip1
6
7package syslog
8
9import (
10	"bufio"
11	"fmt"
12	"io"
13	"net"
14	"os"
15	"path/filepath"
16	"runtime"
17	"sync"
18	"testing"
19	"time"
20)
21
// runPktSyslog reads datagrams from c and concatenates their payloads.
// Each read uses a fresh 100ms deadline. A *net.OpError that reports
// Temporary() (such as a deadline timeout) is tolerated up to three times.
// Any other error, or a fourth such error, ends the loop. Then c is closed and
// everything received is sent on done as a single string.
func runPktSyslog(c net.PacketConn, done chan<- string) {
	var (
		buf      [4096]byte
		received string
		retries  int
	)
	for {
		c.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
		n, _, err := c.ReadFrom(buf[:])
		// Keep whatever was read, even alongside an error.
		received += string(buf[:n])
		if err == nil {
			continue
		}
		opErr, isOpErr := err.(*net.OpError)
		if !isOpErr || retries >= 3 || !opErr.Temporary() {
			break
		}
		retries++
	}
	c.Close()
	done <- received
}
46
// crashy, when true, makes runStreamSyslog close each accepted connection after
// it has read seven lines, forcing clients to reconnect. It is false by
// default; TestConcurrentReconnect turns it on for its duration.
var crashy bool
48
// testableNetwork reports whether the tests can exercise the given network on
// the current platform. Unix-domain networks ("unix", "unixgram") are treated
// as untestable on iOS and Android. Every other network is always testable.
func testableNetwork(network string) bool {
	unixDomain := network == "unix" || network == "unixgram"
	mobile := runtime.GOOS == "ios" || runtime.GOOS == "android"
	return !(unixDomain && mobile)
}
59
60func runStreamSyslog(l net.Listener, done chan<- string, wg *sync.WaitGroup) {
61	for {
62		var c net.Conn
63		var err error
64		if c, err = l.Accept(); err != nil {
65			return
66		}
67		wg.Add(1)
68		go func(c net.Conn) {
69			defer wg.Done()
70			c.SetReadDeadline(time.Now().Add(5 * time.Second))
71			b := bufio.NewReader(c)
72			for ct := 1; !crashy || ct&7 != 0; ct++ {
73				s, err := b.ReadString('\n')
74				if err != nil {
75					break
76				}
77				done <- s
78			}
79			c.Close()
80		}(c)
81	}
82}
83
84func startServer(t *testing.T, n, la string, done chan<- string) (addr string, sock io.Closer, wg *sync.WaitGroup) {
85	if n == "udp" || n == "tcp" {
86		la = "127.0.0.1:0"
87	} else {
88		// unix and unixgram: choose an address if none given.
89		if la == "" {
90			// The address must be short to fit in the sun_path field of the
91			// sockaddr_un passed to the underlying system calls, so we use
92			// os.MkdirTemp instead of t.TempDir: t.TempDir generally includes all or
93			// part of the test name in the directory, which can be much more verbose
94			// and risks running up against the limit.
95			dir, err := os.MkdirTemp("", "")
96			if err != nil {
97				t.Fatal(err)
98			}
99			t.Cleanup(func() {
100				if err := os.RemoveAll(dir); err != nil {
101					t.Errorf("failed to remove socket temp directory: %v", err)
102				}
103			})
104			la = filepath.Join(dir, "sock")
105		}
106	}
107
108	wg = new(sync.WaitGroup)
109	if n == "udp" || n == "unixgram" {
110		l, e := net.ListenPacket(n, la)
111		if e != nil {
112			t.Helper()
113			t.Fatalf("startServer failed: %v", e)
114		}
115		addr = l.LocalAddr().String()
116		sock = l
117		wg.Add(1)
118		go func() {
119			defer wg.Done()
120			runPktSyslog(l, done)
121		}()
122	} else {
123		l, e := net.Listen(n, la)
124		if e != nil {
125			t.Helper()
126			t.Fatalf("startServer failed: %v", e)
127		}
128		addr = l.Addr().String()
129		sock = l
130		wg.Add(1)
131		go func() {
132			defer wg.Done()
133			runStreamSyslog(l, done, wg)
134		}()
135	}
136	return
137}
138
139func TestWithSimulated(t *testing.T) {
140	t.Parallel()
141
142	msg := "Test 123"
143	for _, tr := range []string{"unix", "unixgram", "udp", "tcp"} {
144		if !testableNetwork(tr) {
145			continue
146		}
147
148		tr := tr
149		t.Run(tr, func(t *testing.T) {
150			t.Parallel()
151
152			done := make(chan string)
153			addr, sock, srvWG := startServer(t, tr, "", done)
154			defer srvWG.Wait()
155			defer sock.Close()
156			if tr == "unix" || tr == "unixgram" {
157				defer os.Remove(addr)
158			}
159			s, err := Dial(tr, addr, LOG_INFO|LOG_USER, "syslog_test")
160			if err != nil {
161				t.Fatalf("Dial() failed: %v", err)
162			}
163			err = s.Info(msg)
164			if err != nil {
165				t.Fatalf("log failed: %v", err)
166			}
167			check(t, msg, <-done, tr)
168			s.Close()
169		})
170	}
171}
172
173func TestFlap(t *testing.T) {
174	net := "unix"
175	if !testableNetwork(net) {
176		t.Skipf("skipping on %s/%s; 'unix' is not supported", runtime.GOOS, runtime.GOARCH)
177	}
178
179	done := make(chan string)
180	addr, sock, srvWG := startServer(t, net, "", done)
181	defer srvWG.Wait()
182	defer os.Remove(addr)
183	defer sock.Close()
184
185	s, err := Dial(net, addr, LOG_INFO|LOG_USER, "syslog_test")
186	if err != nil {
187		t.Fatalf("Dial() failed: %v", err)
188	}
189	msg := "Moo 2"
190	err = s.Info(msg)
191	if err != nil {
192		t.Fatalf("log failed: %v", err)
193	}
194	check(t, msg, <-done, net)
195
196	// restart the server
197	if err := os.Remove(addr); err != nil {
198		t.Fatal(err)
199	}
200	_, sock2, srvWG2 := startServer(t, net, addr, done)
201	defer srvWG2.Wait()
202	defer sock2.Close()
203
204	// and try retransmitting
205	msg = "Moo 3"
206	err = s.Info(msg)
207	if err != nil {
208		t.Fatalf("log failed: %v", err)
209	}
210	check(t, msg, <-done, net)
211
212	s.Close()
213}
214
215func TestNew(t *testing.T) {
216	if LOG_LOCAL7 != 23<<3 {
217		t.Fatalf("LOG_LOCAL7 has wrong value")
218	}
219	if testing.Short() {
220		// Depends on syslog daemon running, and sometimes it's not.
221		t.Skip("skipping syslog test during -short")
222	}
223
224	s, err := New(LOG_INFO|LOG_USER, "the_tag")
225	if err != nil {
226		if err.Error() == "Unix syslog delivery error" {
227			t.Skip("skipping: syslogd not running")
228		}
229		t.Fatalf("New() failed: %s", err)
230	}
231	// Don't send any messages.
232	s.Close()
233}
234
235func TestNewLogger(t *testing.T) {
236	if testing.Short() {
237		t.Skip("skipping syslog test during -short")
238	}
239	f, err := NewLogger(LOG_USER|LOG_INFO, 0)
240	if f == nil {
241		if err.Error() == "Unix syslog delivery error" {
242			t.Skip("skipping: syslogd not running")
243		}
244		t.Error(err)
245	}
246}
247
248func TestDial(t *testing.T) {
249	if testing.Short() {
250		t.Skip("skipping syslog test during -short")
251	}
252	f, err := Dial("", "", (LOG_LOCAL7|LOG_DEBUG)+1, "syslog_test")
253	if f != nil {
254		t.Fatalf("Should have trapped bad priority")
255	}
256	f, err = Dial("", "", -1, "syslog_test")
257	if f != nil {
258		t.Fatalf("Should have trapped bad priority")
259	}
260	l, err := Dial("", "", LOG_USER|LOG_ERR, "syslog_test")
261	if err != nil {
262		if err.Error() == "Unix syslog delivery error" {
263			t.Skip("skipping: syslogd not running")
264		}
265		t.Fatalf("Dial() failed: %s", err)
266	}
267	l.Close()
268}
269
270func check(t *testing.T, in, out, transport string) {
271	hostname, err := os.Hostname()
272	if err != nil {
273		t.Errorf("Error retrieving hostname: %v", err)
274		return
275	}
276
277	if transport == "unixgram" || transport == "unix" {
278		var month, date, ts string
279		var pid int
280		tmpl := fmt.Sprintf("<%d>%%s %%s %%s syslog_test[%%d]: %s\n", LOG_USER+LOG_INFO, in)
281		n, err := fmt.Sscanf(out, tmpl, &month, &date, &ts, &pid)
282		if n != 4 || err != nil {
283			t.Errorf("Got %q, does not match template %q (%d %s)", out, tmpl, n, err)
284		}
285		return
286	}
287
288	// Non-UNIX domain transports.
289	var parsedHostname, timestamp string
290	var pid int
291	tmpl := fmt.Sprintf("<%d>%%s %%s syslog_test[%%d]: %s\n", LOG_USER+LOG_INFO, in)
292	n, err := fmt.Sscanf(out, tmpl, &timestamp, &parsedHostname, &pid)
293	if n != 3 || err != nil {
294		t.Errorf("Got %q, does not match template %q (%d %s)", out, tmpl, n, err)
295	}
296	if hostname != parsedHostname {
297		t.Errorf("Hostname got %q want %q in %q", parsedHostname, hostname, out)
298	}
299}
300
301func TestWrite(t *testing.T) {
302	t.Parallel()
303
304	tests := []struct {
305		pri Priority
306		pre string
307		msg string
308		exp string
309	}{
310		{LOG_USER | LOG_ERR, "syslog_test", "", "%s %s syslog_test[%d]: \n"},
311		{LOG_USER | LOG_ERR, "syslog_test", "write test", "%s %s syslog_test[%d]: write test\n"},
312		// Write should not add \n if there already is one
313		{LOG_USER | LOG_ERR, "syslog_test", "write test 2\n", "%s %s syslog_test[%d]: write test 2\n"},
314	}
315
316	if hostname, err := os.Hostname(); err != nil {
317		t.Fatalf("Error retrieving hostname")
318	} else {
319		for _, test := range tests {
320			done := make(chan string)
321			addr, sock, srvWG := startServer(t, "udp", "", done)
322			defer srvWG.Wait()
323			defer sock.Close()
324			l, err := Dial("udp", addr, test.pri, test.pre)
325			if err != nil {
326				t.Fatalf("syslog.Dial() failed: %v", err)
327			}
328			defer l.Close()
329			_, err = io.WriteString(l, test.msg)
330			if err != nil {
331				t.Fatalf("WriteString() failed: %v", err)
332			}
333			rcvd := <-done
334			test.exp = fmt.Sprintf("<%d>", test.pri) + test.exp
335			var parsedHostname, timestamp string
336			var pid int
337			if n, err := fmt.Sscanf(rcvd, test.exp, &timestamp, &parsedHostname, &pid); n != 3 || err != nil || hostname != parsedHostname {
338				t.Errorf("s.Info() = '%q', didn't match '%q' (%d %s)", rcvd, test.exp, n, err)
339			}
340		}
341	}
342}
343
344func TestConcurrentWrite(t *testing.T) {
345	addr, sock, srvWG := startServer(t, "udp", "", make(chan string, 1))
346	defer srvWG.Wait()
347	defer sock.Close()
348	w, err := Dial("udp", addr, LOG_USER|LOG_ERR, "how's it going?")
349	if err != nil {
350		t.Fatalf("syslog.Dial() failed: %v", err)
351	}
352	var wg sync.WaitGroup
353	for i := 0; i < 10; i++ {
354		wg.Add(1)
355		go func() {
356			defer wg.Done()
357			err := w.Info("test")
358			if err != nil {
359				t.Errorf("Info() failed: %v", err)
360				return
361			}
362		}()
363	}
364	wg.Wait()
365}
366
367func TestConcurrentReconnect(t *testing.T) {
368	crashy = true
369	defer func() { crashy = false }()
370
371	const N = 10
372	const M = 100
373	net := "unix"
374	if !testableNetwork(net) {
375		net = "tcp"
376		if !testableNetwork(net) {
377			t.Skipf("skipping on %s/%s; neither 'unix' or 'tcp' is supported", runtime.GOOS, runtime.GOARCH)
378		}
379	}
380	done := make(chan string, N*M)
381	addr, sock, srvWG := startServer(t, net, "", done)
382	if net == "unix" {
383		defer os.Remove(addr)
384	}
385
386	// count all the messages arriving
387	count := make(chan int, 1)
388	go func() {
389		ct := 0
390		for range done {
391			ct++
392			// we are looking for 500 out of 1000 events
393			// here because lots of log messages are lost
394			// in buffers (kernel and/or bufio)
395			if ct > N*M/2 {
396				break
397			}
398		}
399		count <- ct
400	}()
401
402	var wg sync.WaitGroup
403	wg.Add(N)
404	for i := 0; i < N; i++ {
405		go func() {
406			defer wg.Done()
407			w, err := Dial(net, addr, LOG_USER|LOG_ERR, "tag")
408			if err != nil {
409				t.Errorf("syslog.Dial() failed: %v", err)
410				return
411			}
412			defer w.Close()
413			for i := 0; i < M; i++ {
414				err := w.Info("test")
415				if err != nil {
416					t.Errorf("Info() failed: %v", err)
417					return
418				}
419			}
420		}()
421	}
422	wg.Wait()
423	sock.Close()
424	srvWG.Wait()
425	close(done)
426
427	select {
428	case <-count:
429	case <-time.After(100 * time.Millisecond):
430		t.Error("timeout in concurrent reconnect")
431	}
432}
433