1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package main_test
6
7import (
8	"archive/zip"
9	"bytes"
10	"encoding/json"
11	"errors"
12	"flag"
13	"fmt"
14	"internal/txtar"
15	"io"
16	"io/fs"
17	"log"
18	"net"
19	"net/http"
20	"os"
21	"path/filepath"
22	"strconv"
23	"strings"
24	"sync"
25	"testing"
26
27	"cmd/go/internal/modfetch/codehost"
28	"cmd/go/internal/par"
29
30	"golang.org/x/mod/module"
31	"golang.org/x/mod/semver"
32	"golang.org/x/mod/sumdb"
33	"golang.org/x/mod/sumdb/dirhash"
34)
35
// proxyAddr holds the -proxy flag value: when non-empty, StartProxy listens on
// that address instead of an ephemeral "localhost:0" port. StartProxy then
// overwrites it with the address actually bound.
//
// proxyURL is the GOPROXY value that reaches the running proxy
// ("http://<proxyAddr>/mod"); it is set by StartProxy.
var (
	proxyAddr = flag.String("proxy", "", "run proxy on this network address instead of running any tests")
	proxyURL  string
)

// proxyOnce makes StartProxy's setup run at most once.
var proxyOnce sync.Once
42
// StartProxy launches the test module proxy and records where it listens.
//
// The proxy listens on *proxyAddr (for example "localhost:1234"), or on an
// ephemeral localhost port when *proxyAddr is empty. Afterwards *proxyAddr is
// the bound address and proxyURL is the matching GOPROXY setting. Only the
// first call does anything; later calls return immediately.
//
// Content is served from testdata/mod. See testdata/mod/README.
func StartProxy() {
	proxyOnce.Do(func() {
		readModList()

		listenAddr := "localhost:0"
		if *proxyAddr != "" {
			listenAddr = *proxyAddr
		}
		ln, err := net.Listen("tcp", listenAddr)
		if err != nil {
			log.Fatal(err)
		}

		*proxyAddr = ln.Addr().String()
		proxyURL = "http://" + *proxyAddr + "/mod"
		fmt.Fprintf(os.Stderr, "go test proxy running at GOPROXY=%s\n", proxyURL)

		// http.Serve only returns on failure, which is fatal for the tests.
		go func() {
			serveErr := http.Serve(ln, http.HandlerFunc(proxyHandler))
			log.Fatalf("go proxy: http.Serve: %v", serveErr)
		}()

		// Warm the main checksum database with every known module version.
		// Lookup results (including errors) are deliberately ignored.
		// NOTE(review): a nil context is passed here; this relies on
		// TestServer.Lookup not using its ctx — consider context.Background().
		for i := range modList {
			sumdbOps.Lookup(nil, modList[i])
		}
	})
}
72
// modList holds every module version found in testdata/mod.
// It is populated by readModList when the proxy starts.
var modList []module.Version
74
// readModList appends to modList every module version that has an archive in
// testdata/mod.
//
// Archive files are named <escaped-path>_<escaped-version>.txt, where each "/"
// in the escaped module path has been replaced by "_" and the version is the
// part after the last "_v". Files that do not fit this pattern, or whose path
// or version fails to unescape, are skipped; version errors are always
// reported on stderr, path errors only in verbose mode.
func readModList() {
	// Resolve testdata/mod against cmdGoDir, exactly as readArchive does, so
	// the listing and the archive reads cannot disagree if the working
	// directory changes after package initialization.
	files, err := os.ReadDir(filepath.Join(cmdGoDir, "testdata/mod"))
	if err != nil {
		log.Fatal(err)
	}
	for _, f := range files {
		name, ok := strings.CutSuffix(f.Name(), ".txt")
		if !ok {
			continue
		}
		i := strings.LastIndex(name, "_v")
		if i < 0 {
			continue
		}
		// NOTE(review): this mapping is lossy — a module path that itself
		// contains "_" cannot be represented and will be mis-parsed.
		encPath := strings.ReplaceAll(name[:i], "_", "/")
		path, err := module.UnescapePath(encPath)
		if err != nil {
			// example.com/invalidpath/v1 is expected to be invalid; don't
			// report it even in verbose mode.
			if testing.Verbose() && encPath != "example.com/invalidpath/v1" {
				fmt.Fprintf(os.Stderr, "go proxy_test: %v\n", err)
			}
			continue
		}
		encVers := name[i+1:] // skip the "_" so the version starts at "v"
		vers, err := module.UnescapeVersion(encVers)
		if err != nil {
			fmt.Fprintf(os.Stderr, "go proxy_test: %v\n", err)
			continue
		}
		modList = append(modList, module.Version{Path: path, Version: vers})
	}
}
107
// zipCache memoizes, per parsed archive, the module zip that proxyHandler
// builds for ".zip" requests, so repeated requests reuse the same bytes.
var zipCache par.ErrCache[*txtar.Archive, []byte]
109
// Identity and keys of the test checksum database served by this proxy.
// testSumDBName is the database name used in /mod/sumdb/<name>/... URLs.
// testSumDBSignerKey is the private key the test servers below sign with;
// testSumDBVerifierKey is presumably the matching public key handed to
// clients (it is not used in this file) — confirm against callers.
const (
	testSumDBName        = "localhost.localdev/sumdb"
	testSumDBVerifierKey = "localhost.localdev/sumdb+00000c67+AcTrnkbUA+TU4heY3hkjiSES/DSQniBqIeQ/YppAUtK6"
	testSumDBSignerKey   = "PRIVATE+KEY+localhost.localdev/sumdb+00000c67+AXu6+oaVaOYuQOFrf1V59JK1owcFlJcHwwXHDfDGxSPk"
)
115
// Checksum database servers used by proxyHandler.
//
// sumdbServer answers with go.sum lines computed by proxyGoSum and is reached
// through /mod/sumdb/<testSumDBName>/... and /mod/sumdb-direct/....
// sumdbWrongServer answers with lines from proxyGoSumWrong, whose hashes are
// all wrong, and is reached through /mod/sumdb-wrong/....
var (
	sumdbOps    = sumdb.NewTestServer(testSumDBSignerKey, proxyGoSum)
	sumdbServer = sumdb.NewServer(sumdbOps)

	sumdbWrongOps    = sumdb.NewTestServer(testSumDBSignerKey, proxyGoSumWrong)
	sumdbWrongServer = sumdb.NewServer(sumdbWrongOps)
)
123
// proxyHandler serves the Go module proxy protocol.
// See the proxy section of https://research.swtch.com/vgo-module.
//
// Besides the standard endpoints (<path>/@v/list, <path>/@v/<version>.info,
// .mod, .zip and <path>/@latest), it recognizes these test-only URL prefixes
// under /mod/:
//
//	invalid/...             responds with the body "invalid"
//	<code>/...              responds with HTTP status <code> (when <code> >= 200)
//	sumdb-<code>/sumdb/...  responds with HTTP status <code>; any other
//	                        sumdb-<code>/... request is served without the prefix
//	sumdb-direct/...        serves the test checksum database directly
//	sumdb-wrong/...         serves a checksum database with wrong hashes
//	redirect/<count>/...    redirects <count> times before serving ...
func proxyHandler(w http.ResponseWriter, r *http.Request) {
	if !strings.HasPrefix(r.URL.Path, "/mod/") {
		http.NotFound(w, r)
		return
	}
	path := r.URL.Path[len("/mod/"):]

	// /mod/invalid returns faulty responses.
	if strings.HasPrefix(path, "invalid/") {
		w.Write([]byte("invalid"))
		return
	}

	// Next element may opt into special behavior.
	if j := strings.Index(path, "/"); j >= 0 {
		n, err := strconv.Atoi(path[:j])
		if err == nil && n >= 200 {
			w.WriteHeader(n)
			return
		}
		if strings.HasPrefix(path, "sumdb-") {
			n, err := strconv.Atoi(path[len("sumdb-"):j])
			if err == nil && n >= 200 {
				if strings.HasPrefix(path[j:], "/sumdb/") {
					w.WriteHeader(n)
					return
				}
				path = path[j+1:]
			}
		}
	}

	// Request for $GOPROXY/sumdb-direct is direct sumdb access.
	// (Client thinks it is talking directly to a sumdb.)
	if strings.HasPrefix(path, "sumdb-direct/") {
		r.URL.Path = path[len("sumdb-direct"):]
		sumdbServer.ServeHTTP(w, r)
		return
	}

	// Request for $GOPROXY/sumdb-wrong is direct sumdb access
	// but all the hashes are wrong.
	// (Client thinks it is talking directly to a sumdb.)
	if strings.HasPrefix(path, "sumdb-wrong/") {
		r.URL.Path = path[len("sumdb-wrong"):]
		sumdbWrongServer.ServeHTTP(w, r)
		return
	}

	// Request for $GOPROXY/redirect/<count>/... goes to redirects.
	if strings.HasPrefix(path, "redirect/") {
		path = path[len("redirect/"):]
		if j := strings.Index(path, "/"); j >= 0 {
			count, err := strconv.Atoi(path[:j])
			if err != nil {
				// A malformed count must not look like a successful (empty)
				// response to the client.
				http.NotFound(w, r)
				return
			}

			// The last redirect.
			if count <= 1 {
				http.Redirect(w, r, fmt.Sprintf("/mod/%s", path[j+1:]), http.StatusFound)
				return
			}
			http.Redirect(w, r, fmt.Sprintf("/mod/redirect/%d/%s", count-1, path[j+1:]), http.StatusFound)
			return
		}
	}

	// Request for $GOPROXY/sumdb/<name>/supported
	// is checking whether it's OK to access sumdb via the proxy.
	if path == "sumdb/"+testSumDBName+"/supported" {
		w.WriteHeader(http.StatusOK)
		return
	}

	// Request for $GOPROXY/sumdb/<name>/... goes to sumdb.
	if sumdbPrefix := "sumdb/" + testSumDBName + "/"; strings.HasPrefix(path, sumdbPrefix) {
		r.URL.Path = path[len(sumdbPrefix)-1:]
		sumdbServer.ServeHTTP(w, r)
		return
	}

	// Module proxy request: /mod/path/@latest
	// Rewrite to /mod/path/@v/<latest>.info where <latest> is the semantically
	// latest version, including pseudo-versions.
	if i := strings.LastIndex(path, "/@latest"); i >= 0 {
		enc := path[:i]
		modPath, err := module.UnescapePath(enc)
		if err != nil {
			if testing.Verbose() {
				fmt.Fprintf(os.Stderr, "go proxy_test: %v\n", err)
			}
			http.NotFound(w, r)
			return
		}

		// Imitate what "latest" does in direct mode and what proxy.golang.org does.
		// Use the latest released version.
		// If there is no released version, use the latest prereleased version.
		// Otherwise, use the latest pseudoversion.
		var latestRelease, latestPrerelease, latestPseudo string
		for _, m := range modList {
			if m.Path != modPath {
				continue
			}
			// Put each version into exactly one category, then keep the
			// highest version seen in that category. Pseudo-versions are
			// checked first because they also carry a prerelease suffix.
			var best *string
			switch {
			case module.IsPseudoVersion(m.Version):
				best = &latestPseudo
			case semver.Prerelease(m.Version) != "":
				best = &latestPrerelease
			default:
				best = &latestRelease
			}
			if *best == "" || semver.Compare(m.Version, *best) > 0 {
				*best = m.Version
			}
		}
		var latest string
		switch {
		case latestRelease != "":
			latest = latestRelease
		case latestPrerelease != "":
			latest = latestPrerelease
		case latestPseudo != "":
			latest = latestPseudo
		default:
			http.NotFound(w, r)
			return
		}

		encVers, err := module.EscapeVersion(latest)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		path = fmt.Sprintf("%s/@v/%s.info", enc, encVers)
	}

	// Module proxy request: /mod/path/@v/version[.suffix]
	i := strings.Index(path, "/@v/")
	if i < 0 {
		http.NotFound(w, r)
		return
	}
	enc, file := path[:i], path[i+len("/@v/"):]
	path, err := module.UnescapePath(enc)
	if err != nil {
		if testing.Verbose() {
			fmt.Fprintf(os.Stderr, "go proxy_test: %v\n", err)
		}
		http.NotFound(w, r)
		return
	}
	if file == "list" {
		// list returns a list of versions, not including pseudo-versions.
		// If the module has no tagged versions, we should serve an empty 200.
		// If the module doesn't exist, we should serve 404 or 410.
		found := false
		for _, m := range modList {
			if m.Path != path {
				continue
			}
			found = true
			if !module.IsPseudoVersion(m.Version) {
				if err := module.Check(m.Path, m.Version); err == nil {
					fmt.Fprintf(w, "%s\n", m.Version)
				}
			}
		}
		if !found {
			http.NotFound(w, r)
		}
		return
	}

	i = strings.LastIndex(file, ".")
	if i < 0 {
		http.NotFound(w, r)
		return
	}
	encVers, ext := file[:i], file[i+1:]
	vers, err := module.UnescapeVersion(encVers)
	if err != nil {
		if testing.Verbose() {
			fmt.Fprintf(os.Stderr, "go proxy_test: %v\n", err)
		}
		http.NotFound(w, r)
		return
	}

	if codehost.AllHex(vers) {
		var best string
		// Convert commit hash (only) to known version.
		// Use latest version in semver priority, to match similar logic
		// in the repo-based module server (see modfetch.(*codeRepo).convert).
		for _, m := range modList {
			if m.Path == path && semver.Compare(best, m.Version) < 0 {
				var hash string
				if module.IsPseudoVersion(m.Version) {
					hash = m.Version[strings.LastIndex(m.Version, "-")+1:]
				} else {
					hash = findHash(m)
				}
				// findHash returns "" when the hash is unknown; every string
				// has "" as a prefix, so without this guard such a version
				// would match any commit hash.
				if hash != "" && (strings.HasPrefix(hash, vers) || strings.HasPrefix(vers, hash)) {
					best = m.Version
				}
			}
		}
		if best != "" {
			vers = best
		}
	}

	a, err := readArchive(path, vers)
	if err != nil {
		if testing.Verbose() {
			fmt.Fprintf(os.Stderr, "go proxy: no archive %s %s: %v\n", path, vers, err)
		}
		if errors.Is(err, fs.ErrNotExist) {
			http.NotFound(w, r)
		} else {
			http.Error(w, "cannot load archive", http.StatusInternalServerError)
		}
		return
	}

	switch ext {
	case "info", "mod":
		want := "." + ext
		for _, f := range a.Files {
			if f.Name == want {
				w.Write(f.Data)
				return
			}
		}

	case "zip":
		zipBytes, err := zipCache.Do(a, func() ([]byte, error) {
			var buf bytes.Buffer
			z := zip.NewWriter(&buf)
			for _, f := range a.Files {
				// Metadata entries are served separately, not inside the zip.
				if f.Name == ".info" || f.Name == ".mod" || f.Name == ".zip" {
					continue
				}
				// A leading "/" means "use this exact zip path"; otherwise the
				// file is placed under the conventional path@version/ prefix.
				var zipName string
				if strings.HasPrefix(f.Name, "/") {
					zipName = f.Name[1:]
				} else {
					zipName = path + "@" + vers + "/" + f.Name
				}
				zf, err := z.Create(zipName)
				if err != nil {
					return nil, err
				}
				if _, err := zf.Write(f.Data); err != nil {
					return nil, err
				}
			}
			if err := z.Close(); err != nil {
				return nil, err
			}
			return buf.Bytes(), nil
		})

		if err != nil {
			if testing.Verbose() {
				fmt.Fprintf(os.Stderr, "go proxy: %v\n", err)
			}
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		w.Write(zipBytes)
		return

	}
	http.NotFound(w, r)
}
396
// findHash reports the abbreviated commit hash recorded in the "Short" field
// of the .info entry of m's testdata archive. It yields "" when the archive
// cannot be read, has no .info entry, or the entry carries no usable Short.
func findHash(m module.Version) string {
	archive, err := readArchive(m.Path, m.Version)
	if err != nil {
		return ""
	}
	var info struct{ Short string }
	for _, file := range archive.Files {
		if file.Name != ".info" {
			continue
		}
		// Best effort: a malformed .info simply leaves Short empty.
		_ = json.Unmarshal(file.Data, &info)
		return info.Short
	}
	return ""
}
413
// archiveCache memoizes parsed testdata archives, keyed by file name.
// A nil value records a file that could not be read, so it is not retried.
var archiveCache par.Cache[string, *txtar.Archive]

// cmdGoDir is the working directory captured at package initialization
// (presumably the cmd/go source directory — confirm); readArchive locates
// testdata/mod relative to it. The os.Getwd error is ignored: on failure
// cmdGoDir is "" and paths joined onto it stay relative.
var cmdGoDir, _ = os.Getwd()
417
// readArchive returns the parsed txtar archive for module path at version
// vers. The archive lives in testdata/mod (under cmdGoDir) and is named
// <escaped-path>_<escaped-version>.txt, with each "/" of the escaped path
// replaced by "_".
//
// Parsed archives and read failures are both memoized in archiveCache. Any
// failure to read the file is reported as fs.ErrNotExist; errors from escaping
// path or vers are returned unchanged.
func readArchive(path, vers string) (*txtar.Archive, error) {
	enc, err := module.EscapePath(path)
	if err != nil {
		return nil, err
	}
	encVers, err := module.EscapeVersion(vers)
	if err != nil {
		return nil, err
	}

	prefix := strings.ReplaceAll(enc, "/", "_")
	name := filepath.Join(cmdGoDir, "testdata/mod", prefix+"_"+encVers+".txt")
	a := archiveCache.Do(name, func() *txtar.Archive {
		a, err := txtar.ParseFile(name)
		if err != nil {
			// A missing file is the ordinary "unknown module version" case,
			// so only mention it in verbose mode; anything else is unexpected.
			if testing.Verbose() || !errors.Is(err, fs.ErrNotExist) {
				fmt.Fprintf(os.Stderr, "go proxy: %v\n", err)
			}
			// Cache the failure as nil so the file is not read again.
			return nil
		}
		return a
	})
	if a == nil {
		return nil, fs.ErrNotExist
	}
	return a, nil
}
445
// proxyGoSum produces the two go.sum lines for path@vers: one with the "h1:"
// dirhash of the module's files and one with the hash of its go.mod.
//
// Each archive file contributes path@vers/<name> to the module hash, except
// entries whose names begin with "." (metadata such as ".info" and ".mod").
// The ".mod" entry supplies the go.mod content.
//
// NOTE(review): the zip built by proxyHandler skips only ".info", ".mod" and
// ".zip", and maps names beginning with "/" to different zip paths, so for
// such archives these hashes may not match the served zip — confirm intent.
func proxyGoSum(path, vers string) ([]byte, error) {
	archive, err := readArchive(path, vers)
	if err != nil {
		return nil, err
	}

	contents := make(map[string][]byte)
	var fileNames []string
	var goModData []byte
	for _, file := range archive.Files {
		switch {
		case file.Name == ".mod":
			goModData = file.Data
		case strings.HasPrefix(file.Name, "."):
			// Other metadata entries are not part of the module contents.
		default:
			full := path + "@" + vers + "/" + file.Name
			fileNames = append(fileNames, full)
			contents[full] = file.Data
		}
	}

	openFile := func(name string) (io.ReadCloser, error) {
		return io.NopCloser(bytes.NewReader(contents[name])), nil
	}
	moduleHash, err := dirhash.Hash1(fileNames, openFile)
	if err != nil {
		return nil, err
	}

	openGoMod := func(string) (io.ReadCloser, error) {
		return io.NopCloser(bytes.NewReader(goModData)), nil
	}
	goModHash, err := dirhash.Hash1([]string{"go.mod"}, openGoMod)
	if err != nil {
		return nil, err
	}

	zipLine := fmt.Sprintf("%s %s %s\n", path, vers, moduleHash)
	modLine := fmt.Sprintf("%s %s/go.mod %s\n", path, vers, goModHash)
	return []byte(zipLine + modLine), nil
}
482
483// proxyGoSumWrong returns the wrong lines.
484func proxyGoSumWrong(path, vers string) ([]byte, error) {
485	data := []byte(fmt.Sprintf("%s %s %s\n%s %s/go.mod %s\n", path, vers, "h1:wrong", path, vers, "h1:wrong"))
486	return data, nil
487}
488