1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package wasi_test
6
import (
	"bytes"
	"fmt"
	"io"
	"math/rand"
	"net"
	"os"
	"os/exec"
	"testing"
	"time"
)
17
18func TestTCPEcho(t *testing.T) {
19	if target != "wasip1/wasm" {
20		t.Skip()
21	}
22
23	// We're unable to use port 0 here (let the OS choose a spare port).
24	// Although the WASM runtime accepts port 0, and the WASM module listens
25	// successfully, there's no way for this test to query the selected port
26	// so that it can connect to the WASM module. The WASM module itself
27	// cannot access any information about the socket due to limitations
28	// with WASI preview 1 networking, and the WASM runtimes do not log the
29	// port when you pre-open a socket. So, we probe for a free port here.
30	// Given there's an unavoidable race condition, the test is disabled by
31	// default.
32	if os.Getenv("GOWASIENABLERACYTEST") != "1" {
33		t.Skip("skipping WASI test with unavoidable race condition")
34	}
35	var host string
36	port := rand.Intn(10000) + 40000
37	for attempts := 0; attempts < 10; attempts++ {
38		host = fmt.Sprintf("127.0.0.1:%d", port)
39		l, err := net.Listen("tcp", host)
40		if err == nil {
41			l.Close()
42			break
43		}
44		port++
45	}
46
47	subProcess := exec.Command("go", "run", "./testdata/tcpecho.go")
48
49	subProcess.Env = append(os.Environ(), "GOOS=wasip1", "GOARCH=wasm")
50
51	switch os.Getenv("GOWASIRUNTIME") {
52	case "wazero":
53		subProcess.Env = append(subProcess.Env, "GOWASIRUNTIMEARGS=--listen="+host)
54	case "wasmtime", "":
55		subProcess.Env = append(subProcess.Env, "GOWASIRUNTIMEARGS=--tcplisten="+host)
56	default:
57		t.Skip("WASI runtime does not support sockets")
58	}
59
60	var b bytes.Buffer
61	subProcess.Stdout = &b
62	subProcess.Stderr = &b
63
64	if err := subProcess.Start(); err != nil {
65		t.Log(b.String())
66		t.Fatal(err)
67	}
68	defer subProcess.Process.Kill()
69
70	var conn net.Conn
71	for {
72		var err error
73		conn, err = net.Dial("tcp", host)
74		if err == nil {
75			break
76		}
77		time.Sleep(500 * time.Millisecond)
78	}
79	defer conn.Close()
80
81	payload := []byte("foobar")
82	if _, err := conn.Write(payload); err != nil {
83		t.Fatal(err)
84	}
85	var buf [256]byte
86	n, err := conn.Read(buf[:])
87	if err != nil {
88		t.Fatal(err)
89	}
90	if string(buf[:n]) != string(payload) {
91		t.Error("unexpected payload")
92		t.Logf("expect: %d bytes (%v)", len(payload), payload)
93		t.Logf("actual: %d bytes (%v)", n, buf[:n])
94	}
95}
96