xref: /aosp_15_r20/external/bazelbuild-rules_go/go/tools/fetch_repo/fetch_repo_test.go (revision 9bb1b549b6a84214c53be0924760be030e66b93a)
1package main
2
3import (
4	"os"
5	"reflect"
6	"testing"
7
8	"golang.org/x/tools/go/vcs"
9)
10
11var (
12	root = &vcs.RepoRoot{
13		VCS:  vcs.ByCmd("git"),
14		Repo: "https://github.com/bazeltest/rules_go",
15		Root: "github.com/bazeltest/rules_go",
16	}
17)
18
19func TestMain(m *testing.M) {
20	// Replace vcs.RepoRootForImportPath to disable any network calls.
21	repoRootForImportPath = func(_ string, _ bool) (*vcs.RepoRoot, error) {
22		return root, nil
23	}
24	os.Exit(m.Run())
25}
26
27func TestGetRepoRoot(t *testing.T) {
28	for _, tc := range []struct {
29		label      string
30		remote     string
31		cmd        string
32		importpath string
33		r          *vcs.RepoRoot
34	}{
35		{
36			label:      "all",
37			remote:     "https://github.com/bazeltest/rules_go",
38			cmd:        "git",
39			importpath: "github.com/bazeltest/rules_go",
40			r:          root,
41		},
42		{
43			label:      "different remote",
44			remote:     "https://example.com/rules_go",
45			cmd:        "git",
46			importpath: "github.com/bazeltest/rules_go",
47			r: &vcs.RepoRoot{
48				VCS:  vcs.ByCmd("git"),
49				Repo: "https://example.com/rules_go",
50				Root: "github.com/bazeltest/rules_go",
51			},
52		},
53		{
54			label:      "only importpath",
55			importpath: "github.com/bazeltest/rules_go",
56			r:          root,
57		},
58	} {
59		r, err := getRepoRoot(tc.remote, tc.cmd, tc.importpath)
60		if err != nil {
61			t.Errorf("[%s] %v", tc.label, err)
62		}
63		if !reflect.DeepEqual(r, tc.r) {
64			t.Errorf("[%s] Expected %+v, got %+v", tc.label, tc.r, r)
65		}
66	}
67}
68
69func TestGetRepoRoot_error(t *testing.T) {
70	for _, tc := range []struct {
71		label      string
72		remote     string
73		cmd        string
74		importpath string
75	}{
76		{
77			label:  "importpath as remote",
78			remote: "github.com/bazeltest/rules_go",
79		},
80		{
81			label:      "missing vcs",
82			remote:     "https://github.com/bazeltest/rules_go",
83			importpath: "github.com/bazeltest/rules_go",
84		},
85		{
86			label:      "missing remote",
87			cmd:        "git",
88			importpath: "github.com/bazeltest/rules_go",
89		},
90	} {
91		r, err := getRepoRoot(tc.remote, tc.cmd, tc.importpath)
92		if err == nil {
93			t.Errorf("[%s] expected error. Got %+v", tc.label, r)
94		}
95	}
96}
97