1// Copyright 2021 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build unix
6
7package net
8
9import (
10	"internal/syscall/unix"
11	"os"
12	"syscall"
13	"testing"
14	"time"
15)
16
17func TestUnixConnReadMsgUnixSCMRightsCloseOnExec(t *testing.T) {
18	if !testableNetwork("unix") {
19		t.Skip("not unix system")
20	}
21
22	scmFile, err := os.Open(os.DevNull)
23	if err != nil {
24		t.Fatalf("file open: %v", err)
25	}
26	defer scmFile.Close()
27
28	rights := syscall.UnixRights(int(scmFile.Fd()))
29	fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0)
30	if err != nil {
31		t.Fatalf("Socketpair: %v", err)
32	}
33
34	writeFile := os.NewFile(uintptr(fds[0]), "write-socket")
35	defer writeFile.Close()
36	readFile := os.NewFile(uintptr(fds[1]), "read-socket")
37	defer readFile.Close()
38
39	cw, err := FileConn(writeFile)
40	if err != nil {
41		t.Fatalf("FileConn: %v", err)
42	}
43	defer cw.Close()
44	cr, err := FileConn(readFile)
45	if err != nil {
46		t.Fatalf("FileConn: %v", err)
47	}
48	defer cr.Close()
49
50	ucw, ok := cw.(*UnixConn)
51	if !ok {
52		t.Fatalf("got %T; want UnixConn", cw)
53	}
54	ucr, ok := cr.(*UnixConn)
55	if !ok {
56		t.Fatalf("got %T; want UnixConn", cr)
57	}
58
59	oob := make([]byte, syscall.CmsgSpace(4))
60	err = ucw.SetWriteDeadline(time.Now().Add(5 * time.Second))
61	if err != nil {
62		t.Fatalf("Can't set unix connection timeout: %v", err)
63	}
64	_, _, err = ucw.WriteMsgUnix(nil, rights, nil)
65	if err != nil {
66		t.Fatalf("UnixConn readMsg: %v", err)
67	}
68	err = ucr.SetReadDeadline(time.Now().Add(5 * time.Second))
69	if err != nil {
70		t.Fatalf("Can't set unix connection timeout: %v", err)
71	}
72	_, oobn, _, _, err := ucr.ReadMsgUnix(nil, oob)
73	if err != nil {
74		t.Fatalf("UnixConn readMsg: %v", err)
75	}
76
77	scms, err := syscall.ParseSocketControlMessage(oob[:oobn])
78	if err != nil {
79		t.Fatalf("ParseSocketControlMessage: %v", err)
80	}
81	if len(scms) != 1 {
82		t.Fatalf("got scms = %#v; expected 1 SocketControlMessage", scms)
83	}
84	scm := scms[0]
85	gotFDs, err := syscall.ParseUnixRights(&scm)
86	if err != nil {
87		t.Fatalf("syscall.ParseUnixRights: %v", err)
88	}
89	if len(gotFDs) != 1 {
90		t.Fatalf("got FDs %#v: wanted only 1 fd", gotFDs)
91	}
92	defer func() {
93		if err := syscall.Close(gotFDs[0]); err != nil {
94			t.Fatalf("fail to close gotFDs: %v", err)
95		}
96	}()
97
98	flags, err := unix.Fcntl(gotFDs[0], syscall.F_GETFD, 0)
99	if err != nil {
100		t.Fatalf("Can't get flags of fd:%#v, with err:%v", gotFDs[0], err)
101	}
102	if flags&syscall.FD_CLOEXEC == 0 {
103		t.Fatalf("got flags %#x, want %#x (FD_CLOEXEC) set", flags, syscall.FD_CLOEXEC)
104	}
105}
106