1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || windows
6
7package net
8
9import (
10	"syscall"
11	"testing"
12	"time"
13)
14
15func getCurrentKeepAliveSettings(fd fdType) (cfg KeepAliveConfig, err error) {
16	tcpKeepAlive, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE)
17	if err != nil {
18		return
19	}
20	tcpKeepAliveIdle, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall_TCP_KEEPIDLE)
21	if err != nil {
22		return
23	}
24	tcpKeepAliveInterval, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall_TCP_KEEPINTVL)
25	if err != nil {
26		return
27	}
28	tcpKeepAliveCount, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall_TCP_KEEPCNT)
29	if err != nil {
30		return
31	}
32	cfg = KeepAliveConfig{
33		Enable:   tcpKeepAlive != 0,
34		Idle:     time.Duration(tcpKeepAliveIdle) * time.Second,
35		Interval: time.Duration(tcpKeepAliveInterval) * time.Second,
36		Count:    tcpKeepAliveCount,
37	}
38	return
39}
40
41func verifyKeepAliveSettings(t *testing.T, fd fdType, oldCfg, cfg KeepAliveConfig) {
42	if cfg.Idle == 0 {
43		cfg.Idle = defaultTCPKeepAliveIdle
44	}
45	if cfg.Interval == 0 {
46		cfg.Interval = defaultTCPKeepAliveInterval
47	}
48	if cfg.Count == 0 {
49		cfg.Count = defaultTCPKeepAliveCount
50	}
51	if cfg.Idle == -1 {
52		cfg.Idle = oldCfg.Idle
53	}
54	if cfg.Interval == -1 {
55		cfg.Interval = oldCfg.Interval
56	}
57	if cfg.Count == -1 {
58		cfg.Count = oldCfg.Count
59	}
60
61	tcpKeepAlive, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE)
62	if err != nil {
63		t.Fatal(err)
64	}
65	if (tcpKeepAlive != 0) != cfg.Enable {
66		t.Fatalf("SO_KEEPALIVE: got %t; want %t", tcpKeepAlive != 0, cfg.Enable)
67	}
68
69	tcpKeepAliveIdle, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall_TCP_KEEPIDLE)
70	if err != nil {
71		t.Fatal(err)
72	}
73	if time.Duration(tcpKeepAliveIdle)*time.Second != cfg.Idle {
74		t.Fatalf("TCP_KEEPIDLE: got %ds; want %v", tcpKeepAliveIdle, cfg.Idle)
75	}
76
77	tcpKeepAliveInterval, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall_TCP_KEEPINTVL)
78	if err != nil {
79		t.Fatal(err)
80	}
81	if time.Duration(tcpKeepAliveInterval)*time.Second != cfg.Interval {
82		t.Fatalf("TCP_KEEPINTVL: got %ds; want %v", tcpKeepAliveInterval, cfg.Interval)
83	}
84
85	tcpKeepAliveCount, err := syscall.GetsockoptInt(fd, syscall.IPPROTO_TCP, syscall_TCP_KEEPCNT)
86	if err != nil {
87		t.Fatal(err)
88	}
89	if tcpKeepAliveCount != cfg.Count {
90		t.Fatalf("TCP_KEEPCNT: got %d; want %d", tcpKeepAliveCount, cfg.Count)
91	}
92}
93