1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package net
6
7import (
8	"cmp"
9	"context"
10	"encoding/json"
11	"errors"
12	"fmt"
13	"internal/testenv"
14	"os/exec"
15	"reflect"
16	"regexp"
17	"slices"
18	"strings"
19	"syscall"
20	"testing"
21)
22
// Test fixtures: nslookupTestServers are the names whose MX, NS, CNAME and
// TXT records are compared against the system nslookup tool, and
// lookupTestIPs are the addresses used for reverse (PTR) lookups.
var (
	nslookupTestServers = []string{"mail.golang.com", "gmail.com"}
	lookupTestIPs       = []string{"8.8.8.8", "1.1.1.1"}
)
25
// toJson renders v as compact JSON for use in test failure messages.
// It is best-effort: if v cannot be marshaled, the empty string is returned.
func toJson(v any) string {
	encoded, err := json.Marshal(v)
	if err != nil {
		return ""
	}
	return string(encoded)
}
30
31func testLookup(t *testing.T, fn func(*testing.T, *Resolver, string)) {
32	for _, def := range []bool{true, false} {
33		def := def
34		for _, server := range nslookupTestServers {
35			server := server
36			var name string
37			if def {
38				name = "default/"
39			} else {
40				name = "go/"
41			}
42			t.Run(name+server, func(t *testing.T) {
43				t.Parallel()
44				r := DefaultResolver
45				if !def {
46					r = &Resolver{PreferGo: true}
47				}
48				fn(t, r, server)
49			})
50		}
51	}
52}
53
54func TestNSLookupMX(t *testing.T) {
55	testenv.MustHaveExternalNetwork(t)
56
57	testLookup(t, func(t *testing.T, r *Resolver, server string) {
58		mx, err := r.LookupMX(context.Background(), server)
59		if err != nil {
60			t.Fatal(err)
61		}
62		if len(mx) == 0 {
63			t.Fatal("no results")
64		}
65		expected, err := nslookupMX(server)
66		if err != nil {
67			t.Skipf("skipping failed nslookup %s test: %s", server, err)
68		}
69		byPrefAndHost := func(a, b *MX) int {
70			if r := cmp.Compare(a.Pref, b.Pref); r != 0 {
71				return r
72			}
73			return strings.Compare(a.Host, b.Host)
74		}
75		slices.SortFunc(expected, byPrefAndHost)
76		slices.SortFunc(mx, byPrefAndHost)
77		if !reflect.DeepEqual(expected, mx) {
78			t.Errorf("different results %s:\texp:%v\tgot:%v", server, toJson(expected), toJson(mx))
79		}
80	})
81}
82
83func TestNSLookupCNAME(t *testing.T) {
84	testenv.MustHaveExternalNetwork(t)
85
86	testLookup(t, func(t *testing.T, r *Resolver, server string) {
87		cname, err := r.LookupCNAME(context.Background(), server)
88		if err != nil {
89			t.Fatalf("failed %s: %s", server, err)
90		}
91		if cname == "" {
92			t.Fatalf("no result %s", server)
93		}
94		expected, err := nslookupCNAME(server)
95		if err != nil {
96			t.Skipf("skipping failed nslookup %s test: %s", server, err)
97		}
98		if expected != cname {
99			t.Errorf("different results %s:\texp:%v\tgot:%v", server, expected, cname)
100		}
101	})
102}
103
104func TestNSLookupNS(t *testing.T) {
105	testenv.MustHaveExternalNetwork(t)
106
107	testLookup(t, func(t *testing.T, r *Resolver, server string) {
108		ns, err := r.LookupNS(context.Background(), server)
109		if err != nil {
110			t.Fatalf("failed %s: %s", server, err)
111		}
112		if len(ns) == 0 {
113			t.Fatal("no results")
114		}
115		expected, err := nslookupNS(server)
116		if err != nil {
117			t.Skipf("skipping failed nslookup %s test: %s", server, err)
118		}
119		byHost := func(a, b *NS) int {
120			return strings.Compare(a.Host, b.Host)
121		}
122		slices.SortFunc(expected, byHost)
123		slices.SortFunc(ns, byHost)
124		if !reflect.DeepEqual(expected, ns) {
125			t.Errorf("different results %s:\texp:%v\tgot:%v", toJson(server), toJson(expected), ns)
126		}
127	})
128}
129
130func TestNSLookupTXT(t *testing.T) {
131	testenv.MustHaveExternalNetwork(t)
132
133	testLookup(t, func(t *testing.T, r *Resolver, server string) {
134		txt, err := r.LookupTXT(context.Background(), server)
135		if err != nil {
136			t.Fatalf("failed %s: %s", server, err)
137		}
138		if len(txt) == 0 {
139			t.Fatalf("no results")
140		}
141		expected, err := nslookupTXT(server)
142		if err != nil {
143			t.Skipf("skipping failed nslookup %s test: %s", server, err)
144		}
145		slices.Sort(expected)
146		slices.Sort(txt)
147		if !reflect.DeepEqual(expected, txt) {
148			t.Errorf("different results %s:\texp:%v\tgot:%v", server, toJson(expected), toJson(txt))
149		}
150	})
151}
152
153func TestLookupLocalPTR(t *testing.T) {
154	testenv.MustHaveExternalNetwork(t)
155
156	addr, err := localIP()
157	if err != nil {
158		t.Errorf("failed to get local ip: %s", err)
159	}
160	names, err := LookupAddr(addr.String())
161	if err != nil {
162		t.Errorf("failed %s: %s", addr, err)
163	}
164	if len(names) == 0 {
165		t.Errorf("no results")
166	}
167	expected, err := lookupPTR(addr.String())
168	if err != nil {
169		t.Skipf("skipping failed lookup %s test: %s", addr.String(), err)
170	}
171	slices.Sort(expected)
172	slices.Sort(names)
173	if !reflect.DeepEqual(expected, names) {
174		t.Errorf("different results %s:\texp:%v\tgot:%v", addr, toJson(expected), toJson(names))
175	}
176}
177
178func TestLookupPTR(t *testing.T) {
179	testenv.MustHaveExternalNetwork(t)
180
181	for _, addr := range lookupTestIPs {
182		names, err := LookupAddr(addr)
183		if err != nil {
184			// The DNSError type stores the error as a string, so it cannot wrap the
185			// original error code and we cannot check for it here. However, we can at
186			// least use its error string to identify the correct localized text for
187			// the error to skip.
188			var DNS_ERROR_RCODE_SERVER_FAILURE syscall.Errno = 9002
189			if strings.HasSuffix(err.Error(), DNS_ERROR_RCODE_SERVER_FAILURE.Error()) {
190				testenv.SkipFlaky(t, 38111)
191			}
192			t.Errorf("failed %s: %s", addr, err)
193		}
194		if len(names) == 0 {
195			t.Errorf("no results")
196		}
197		expected, err := lookupPTR(addr)
198		if err != nil {
199			t.Logf("skipping failed lookup %s test: %s", addr, err)
200			continue
201		}
202		slices.Sort(expected)
203		slices.Sort(names)
204		if !reflect.DeepEqual(expected, names) {
205			t.Errorf("different results %s:\texp:%v\tgot:%v", addr, toJson(expected), toJson(names))
206		}
207	}
208}
209
// nslookup runs `nslookup -querytype=<qtype> <name>` and returns its standard
// output with CRLF line endings normalized to LF.
//
// If the command fails, the returned error wraps the exec error and includes
// whatever nslookup wrote to stderr. The exit code is not a reliable failure
// signal, so a stderr containing "can't find" is also reported as an error.
// In that case the normalized stdout is returned alongside the error.
func nslookup(qtype, name string) (string, error) {
	var stdout, stderr strings.Builder
	cmd := exec.Command("nslookup", "-querytype="+qtype, name)
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	if err := cmd.Run(); err != nil {
		// Keep the stderr text so callers' skip messages explain the failure.
		return "", fmt.Errorf("%w: %s", err, stderr.String())
	}
	out := strings.ReplaceAll(stdout.String(), "\r\n", "\n")
	// nslookup stderr output contains also debug information such as
	// "Non-authoritative answer" and it doesn't return the correct errcode,
	// so look for the failure text explicitly.
	if msg := stderr.String(); strings.Contains(msg, "can't find") {
		return out, errors.New(msg)
	}
	return out, nil
}
227
228func nslookupMX(name string) (mx []*MX, err error) {
229	var r string
230	if r, err = nslookup("mx", name); err != nil {
231		return
232	}
233	mx = make([]*MX, 0, 10)
234	// linux nslookup syntax
235	// golang.org      mail exchanger = 2 alt1.aspmx.l.google.com.
236	rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+mail exchanger\s*=\s*([0-9]+)\s*([a-z0-9.\-]+)$`)
237	for _, ans := range rx.FindAllStringSubmatch(r, -1) {
238		pref, _, _ := dtoi(ans[2])
239		mx = append(mx, &MX{absDomainName(ans[3]), uint16(pref)})
240	}
241	// windows nslookup syntax
242	// gmail.com       MX preference = 30, mail exchanger = alt3.gmail-smtp-in.l.google.com
243	rx = regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+MX preference\s*=\s*([0-9]+)\s*,\s*mail exchanger\s*=\s*([a-z0-9.\-]+)$`)
244	for _, ans := range rx.FindAllStringSubmatch(r, -1) {
245		pref, _, _ := dtoi(ans[2])
246		mx = append(mx, &MX{absDomainName(ans[3]), uint16(pref)})
247	}
248	return
249}
250
251func nslookupNS(name string) (ns []*NS, err error) {
252	var r string
253	if r, err = nslookup("ns", name); err != nil {
254		return
255	}
256	ns = make([]*NS, 0, 10)
257	// golang.org      nameserver = ns1.google.com.
258	rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+nameserver\s*=\s*([a-z0-9.\-]+)$`)
259	for _, ans := range rx.FindAllStringSubmatch(r, -1) {
260		ns = append(ns, &NS{absDomainName(ans[2])})
261	}
262	return
263}
264
265func nslookupCNAME(name string) (cname string, err error) {
266	var r string
267	if r, err = nslookup("cname", name); err != nil {
268		return
269	}
270	// mail.golang.com canonical name = golang.org.
271	rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+canonical name\s*=\s*([a-z0-9.\-]+)$`)
272	// assumes the last CNAME is the correct one
273	last := name
274	for _, ans := range rx.FindAllStringSubmatch(r, -1) {
275		last = ans[2]
276	}
277	return absDomainName(last), nil
278}
279
280func nslookupTXT(name string) (txt []string, err error) {
281	var r string
282	if r, err = nslookup("txt", name); err != nil {
283		return
284	}
285	txt = make([]string, 0, 10)
286	// linux
287	// golang.org      text = "v=spf1 redirect=_spf.google.com"
288
289	// windows
290	// golang.org      text =
291	//
292	//    "v=spf1 redirect=_spf.google.com"
293	rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+text\s*=\s*"(.*)"$`)
294	for _, ans := range rx.FindAllStringSubmatch(r, -1) {
295		txt = append(txt, ans[2])
296	}
297	return
298}
299
// ping runs `ping -n 1 -a <name>` and returns its combined stdout and stderr
// with CRLF line endings normalized to LF. The flags follow Windows ping
// syntax: -n sets the echo count and -a asks ping to resolve the address to a
// host name. On failure, the returned error wraps the command error and
// includes the tool's output.
func ping(name string) (string, error) {
	cmd := exec.Command("ping", "-n", "1", "-a", name)
	output, err := cmd.CombinedOutput()
	if err != nil {
		// %w (not %v) keeps the underlying *exec.ExitError / exec.ErrNotFound
		// reachable through errors.As / errors.Is.
		return "", fmt.Errorf("%w: %s", err, output)
	}
	return strings.ReplaceAll(string(output), "\r\n", "\n"), nil
}
309
310func lookupPTR(name string) (ptr []string, err error) {
311	var r string
312	if r, err = ping(name); err != nil {
313		return
314	}
315	ptr = make([]string, 0, 10)
316	rx := regexp.MustCompile(`(?m)^Pinging\s+([a-zA-Z0-9.\-]+)\s+\[.*$`)
317	for _, ans := range rx.FindAllStringSubmatch(r, -1) {
318		ptr = append(ptr, absDomainName(ans[1]))
319	}
320	return
321}
322
323func localIP() (ip IP, err error) {
324	conn, err := Dial("udp", "golang.org:80")
325	if err != nil {
326		return nil, err
327	}
328	defer conn.Close()
329
330	localAddr := conn.LocalAddr().(*UDPAddr)
331
332	return localAddr.IP, nil
333}
334