xref: /aosp_15_r20/external/bazelbuild-rules_go/go/tools/releaser/upgradedep.go (revision 9bb1b549b6a84214c53be0924760be030e66b93a)
1// Copyright 2021 The Bazel Authors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package main
16
17import (
18	"bytes"
19	"context"
20	"crypto/sha256"
21	"encoding/hex"
22	"errors"
23	"flag"
24	"fmt"
25	"io"
26	"net/http"
27	"os"
28	"os/exec"
29	"path"
30	"path/filepath"
31	"strings"
32	"time"
33
34	bzl "github.com/bazelbuild/buildtools/build"
35	"github.com/google/go-github/v36/github"
36	"golang.org/x/mod/semver"
37	"golang.org/x/oauth2"
38	"golang.org/x/sync/errgroup"
39)
40
41var upgradeDepCmd = command{
42	name:        "upgrade-dep",
43	description: "upgrades a dependency in WORKSPACE or go_repositories.bzl",
44	help: `releaser upgrade-dep [-githubtoken=token] [-mirror] [-work] deps...
45
46upgrade-dep upgrades one or more rules_go dependencies in WORKSPACE or
47go/private/repositories.bzl. Dependency names (matching the name attributes)
48can be specified with positional arguments. "all" may be specified to upgrade
49all upgradeable dependencies.
50
51For each dependency, upgrade-dep finds the highest version available in the
52upstream repository. If no version is available, upgrade-dep uses the commit
53at the tip of the default branch. If a version is part of a release,
54upgrade-dep will try to use an archive attached to the release; if none is
55available, upgrade-dep uses an archive generated by GitHub.
56
57Once upgrade-dep has found the URL for the latest version, it will:
58
59* Download the archive.
60* Upload the archive to mirror.bazel.build.
61* Re-generate patches, either by running a command or by re-applying the
62  old patches.
63* Update dependency attributes in WORKSPACE and repositories.bzl, then format
64  and rewrite those files.
65
66Upgradeable dependencies need a comment like '# releaser:upgrade-dep org repo'
67where org and repo are the GitHub organization and repository. We could
68potentially fetch archives from proxy.golang.org instead, but it's not available
69in as many countries.
70
71Patches may have a comment like '# releaser:patch-cmd name args...'. If this
72comment is present, upgrade-dep will generate the patch by running the specified
73command in a temporary directory containing the extracted archive with the
74previous patches applied.
75`,
76}
77
78func init() {
79	// break init cycle
80	upgradeDepCmd.run = runUpgradeDep
81}
82
83func runUpgradeDep(ctx context.Context, stderr io.Writer, args []string) error {
84	// Parse arguments.
85	flags := flag.NewFlagSet("releaser upgrade-dep", flag.ContinueOnError)
86	var githubToken githubTokenFlag
87	var uploadToMirror, leaveWorkDir bool
88	flags.Var(&githubToken, "githubtoken", "GitHub personal access token or path to a file containing it")
89	flags.BoolVar(&uploadToMirror, "mirror", true, "whether to upload dependency archives to mirror.bazel.build")
90	flags.BoolVar(&leaveWorkDir, "work", false, "don't delete temporary work directory (for debugging)")
91	if err := flags.Parse(args); err != nil {
92		return err
93	}
94	if flags.NArg() == 0 {
95		return usageErrorf(&upgradeDepCmd, "No dependencies specified")
96	}
97	upgradeAll := false
98	for _, arg := range flags.Args() {
99		if arg == "all" {
100			upgradeAll = true
101			break
102		}
103	}
104	if upgradeAll && flags.NArg() != 1 {
105		return usageErrorf(&upgradeDepCmd, "When 'all' is specified, it must be the only argument")
106	}
107
108	httpClient := http.DefaultClient
109	if githubToken != "" {
110		ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: string(githubToken)})
111		httpClient = oauth2.NewClient(ctx, ts)
112	}
113	gh := &githubClient{Client: github.NewClient(httpClient)}
114
115	workDir, err := os.MkdirTemp("", "releaser-upgrade-dep-*")
116	if leaveWorkDir {
117		fmt.Fprintf(stderr, "work dir: %s\n", workDir)
118	} else {
119		defer func() {
120			if rerr := os.RemoveAll(workDir); err == nil && rerr != nil {
121				err = rerr
122			}
123		}()
124	}
125
126	// Make sure we have everything we need.
127	// upgrade-dep must be run inside rules_go (though we just check for
128	// WORKSPACE), and a few tools must be available.
129	rootDir, err := repoRoot()
130	if err != nil {
131		return err
132	}
133	for _, tool := range []string{"diff", "gazelle", "gsutil", "patch"} {
134		if _, err := exec.LookPath(tool); err != nil {
135			return fmt.Errorf("%s must be installed in PATH", tool)
136		}
137	}
138
139	// Parse and index files we might want to update.
140	type file struct {
141		path     string
142		funcName string
143		parsed   *bzl.File
144		body     []bzl.Expr
145	}
146	files := []file{
147		{path: filepath.Join(rootDir, "WORKSPACE")},
148		{path: filepath.Join(rootDir, "go/private/repositories.bzl"), funcName: "go_rules_dependencies"},
149	}
150	depIndex := make(map[string]*bzl.CallExpr)
151
152	for i := range files {
153		f := &files[i]
154		data, err := os.ReadFile(f.path)
155		if err != nil {
156			return err
157		}
158		f.parsed, err = bzl.Parse(f.path, data)
159		if err != nil {
160			return err
161		}
162
163		if f.funcName == "" {
164			f.body = f.parsed.Stmt
165		} else {
166			for _, expr := range f.parsed.Stmt {
167				def, ok := expr.(*bzl.DefStmt)
168				if !ok {
169					continue
170				}
171				if def.Name == f.funcName {
172					f.body = def.Body
173					break
174				}
175			}
176			if f.body == nil {
177				return fmt.Errorf("in file %s, could not find function %s", f.path, f.funcName)
178			}
179		}
180
181		for _, expr := range f.body {
182			call, ok := expr.(*bzl.CallExpr)
183			if !ok {
184				continue
185			}
186			for _, arg := range call.List {
187				kwarg, ok := arg.(*bzl.AssignExpr)
188				if !ok {
189					continue
190				}
191				key := kwarg.LHS.(*bzl.Ident) // required by parser
192				if key.Name != "name" {
193					continue
194				}
195				value, ok := kwarg.RHS.(*bzl.StringExpr)
196				if !ok {
197					continue
198				}
199				depIndex[value.Value] = call
200			}
201		}
202	}
203
204	// Update dependencies in those files.
205	eg, egctx := errgroup.WithContext(ctx)
206	if upgradeAll {
207		for name := range depIndex {
208			name := name
209			if _, _, err := parseUpgradeDepDirective(depIndex[name]); err != nil {
210				continue
211			}
212			eg.Go(func() error {
213				return upgradeDepDecl(egctx, gh, workDir, name, depIndex[name], uploadToMirror)
214			})
215		}
216	} else {
217		for _, arg := range flags.Args() {
218			if depIndex[arg] == nil {
219				return fmt.Errorf("could not find dependency %s", arg)
220			}
221		}
222		for _, arg := range flags.Args() {
223			arg := arg
224			eg.Go(func() error {
225				return upgradeDepDecl(egctx, gh, workDir, arg, depIndex[arg], uploadToMirror)
226			})
227		}
228	}
229	if err := eg.Wait(); err != nil {
230		return err
231	}
232
233	// Format and write files back to disk.
234	for _, f := range files {
235		if err := os.WriteFile(f.path, bzl.Format(f.parsed), 0666); err != nil {
236			return err
237		}
238	}
239	return nil
240}
241
242// upgradeDepDecl upgrades a specific dependency.
243func upgradeDepDecl(ctx context.Context, gh *githubClient, workDir, name string, call *bzl.CallExpr, uploadToMirror bool) (err error) {
244	defer func() {
245		if err != nil {
246			err = fmt.Errorf("upgrading %s: %w", name, err)
247		}
248	}()
249
250	// Find a '# releaser:upgrade-dep org repo' comment. We could probably
251	// figure this out from URLs but this also serves to mark a dependency as
252	// being automatically upgradeable.
253	orgName, repoName, err := parseUpgradeDepDirective(call)
254	if err != nil {
255		return err
256	}
257
258	// Find attributes we'll need to read or write. We'll modify these directly
259	// in the AST. Nothing else should read or write them while we're working.
260	attrs := map[string]*bzl.Expr{
261		"patches":      nil,
262		"sha256":       nil,
263		"strip_prefix": nil,
264		"urls":         nil,
265	}
266	var urlsKwarg *bzl.AssignExpr
267	for _, arg := range call.List {
268		kwarg, ok := arg.(*bzl.AssignExpr)
269		if !ok {
270			continue
271		}
272		key := kwarg.LHS.(*bzl.Ident) // required by parser
273		if _, ok := attrs[key.Name]; ok {
274			attrs[key.Name] = &kwarg.RHS
275		}
276		if key.Name == "urls" {
277			urlsKwarg = kwarg
278		}
279	}
280	for key := range attrs {
281		if key == "patches" {
282			// Don't add optional attributes.
283			continue
284		}
285		if attrs[key] == nil {
286			kwarg := &bzl.AssignExpr{LHS: &bzl.Ident{Name: key}, Op: "="}
287			call.List = append(call.List, kwarg)
288			attrs[key] = &kwarg.RHS
289		}
290	}
291
292	// Find the highest tag in semver order, ignoring whether the version has a
293	// leading "v" or not. If there are no tags, find the commit at the tip of the
294	// default branch.
295	tags, err := gh.listTags(ctx, orgName, repoName)
296	if err != nil {
297		return err
298	}
299
300	vname := func(name string) string {
301		if !strings.HasPrefix(name, "v") {
302			return "v" + name
303		}
304		return name
305	}
306
307	w := 0
308	for r := range tags {
309		name := vname(*tags[r].Name)
310		if name != semver.Canonical(name) {
311			continue
312		}
313		tags[w] = tags[r]
314		w++
315	}
316	tags = tags[:w]
317
318	var highestTag *github.RepositoryTag
319	var highestVname string
320	for _, tag := range tags {
321		name := vname(*tag.Name)
322		if highestTag == nil || semver.Compare(name, highestVname) > 0 {
323			highestTag = tag
324			highestVname = name
325		}
326	}
327
328	var ghURL, stripPrefix, urlComment string
329	date := time.Now().Format("2006-01-02")
330	if highestTag != nil {
331		// If the tag is part of a release, check whether there is a release
332		// artifact we should use.
333		release, _, err := gh.Repositories.GetReleaseByTag(ctx, orgName, repoName, *highestTag.Name)
334		if err == nil {
335			wantNames := []string{
336				fmt.Sprintf("%s-%s.tar.gz", repoName, *highestTag.Name),
337				fmt.Sprintf("%s-%s.zip", repoName, *highestTag.Name),
338			}
339		AssetName:
340			for _, asset := range release.Assets {
341				for _, wantName := range wantNames {
342					if *asset.Name == wantName {
343						ghURL = asset.GetBrowserDownloadURL()
344						stripPrefix = "" // may not always be correct
345						break AssetName
346					}
347				}
348			}
349		}
350		if ghURL == "" {
351			ghURL = fmt.Sprintf("https://github.com/%s/%s/archive/refs/tags/%s.zip", orgName, repoName, *highestTag.Name)
352			stripPrefix = repoName + "-" + strings.TrimPrefix(*highestTag.Name, "v")
353		}
354		urlComment = fmt.Sprintf("%s, latest as of %s", *highestTag.Name, date)
355	} else {
356		repo, _, err := gh.Repositories.Get(ctx, orgName, repoName)
357		if err != nil {
358			return err
359		}
360		defaultBranchName := "main"
361		if repo.DefaultBranch != nil {
362			defaultBranchName = *repo.DefaultBranch
363		}
364		branch, _, err := gh.Repositories.GetBranch(ctx, orgName, repoName, defaultBranchName)
365		if err != nil {
366			return err
367		}
368		ghURL = fmt.Sprintf("https://github.com/%s/%s/archive/%s.zip", orgName, repoName, *branch.Commit.SHA)
369		stripPrefix = repoName + "-" + *branch.Commit.SHA
370		urlComment = fmt.Sprintf("%s, as of %s", defaultBranchName, date)
371	}
372	ghURLWithoutScheme := ghURL[len("https://"):]
373	mirrorURL := "https://mirror.bazel.build/" + ghURLWithoutScheme
374
375	// Download the archive and find the SHA.
376	archiveFile, err := os.CreateTemp("", "")
377	if err != nil {
378		return err
379	}
380	defer func() {
381		archiveFile.Close()
382		if rerr := os.Remove(archiveFile.Name()); err == nil && rerr != nil {
383			err = rerr
384		}
385	}()
386	resp, err := http.Get(ghURL)
387	if err != nil {
388		return err
389	}
390	hw := sha256.New()
391	mw := io.MultiWriter(hw, archiveFile)
392	if _, err := io.Copy(mw, resp.Body); err != nil {
393		resp.Body.Close()
394		return err
395	}
396	if err := resp.Body.Close(); err != nil {
397		return err
398	}
399	sha256Sum := hex.EncodeToString(hw.Sum(nil))
400	if _, err := archiveFile.Seek(0, io.SeekStart); err != nil {
401		return err
402	}
403
404	// Upload the archive to mirror.bazel.build.
405	if uploadToMirror {
406		if err := copyFileToMirror(ctx, ghURLWithoutScheme, archiveFile.Name()); err != nil {
407			return err
408		}
409	}
410
411	// If there are patches, re-apply or re-generate them.
412	// Patch labels may have "# releaser:patch-cmd name args..." directives
413	// that instruct this program to generate the patch by running a commnad
414	// in the directory. If there is no such directive, we apply the old patch
415	// using "patch". In either case, we'll generate a new patch with "diff".
416	// We'll scrub the timestamps to avoid excessive diffs in the PR that
417	// updates dependencies.
418	rootDir, err := repoRoot()
419	if err != nil {
420		return err
421	}
422	if attrs["patches"] != nil {
423		if err != nil {
424			return err
425		}
426		patchDir := filepath.Join(workDir, name, "a")
427		if err := extractArchive(archiveFile, path.Base(ghURL), patchDir, stripPrefix); err != nil {
428			return err
429		}
430
431		patchesList, ok := (*attrs["patches"]).(*bzl.ListExpr)
432		if !ok {
433			return fmt.Errorf("\"patches\" attribute is not a list")
434		}
435		for patchIndex, patchLabelExpr := range patchesList.List {
436			patchLabelValue, comments, err := parsePatchesItem(patchLabelExpr)
437			if err != nil {
438				return fmt.Errorf("parsing expr %#v : %w", patchLabelExpr, err)
439			}
440
441			if !strings.HasPrefix(patchLabelValue, "//third_party:") {
442				return fmt.Errorf("patch does not start with '//third_party:': %q", patchLabelValue)
443			}
444			patchName := patchLabelValue[len("//third_party:"):]
445			patchPath := filepath.Join(rootDir, "third_party", patchName)
446			prevDir := filepath.Join(workDir, name, string('a'+patchIndex))
447			patchDir := filepath.Join(workDir, name, string('a'+patchIndex+1))
448			var patchCmd []string
449			for _, c := range comments.Before {
450				words := strings.Fields(strings.TrimPrefix(c.Token, "#"))
451				if len(words) > 0 && words[0] == "releaser:patch-cmd" {
452					patchCmd = words[1:]
453					break
454				}
455			}
456
457			if err := copyDir(patchDir, prevDir); err != nil {
458				return err
459			}
460			if patchCmd == nil {
461				if err := runForError(ctx, patchDir, "patch", "-Np1", "-i", patchPath); err != nil {
462					return err
463				}
464			} else {
465				if err := runForError(ctx, patchDir, patchCmd[0], patchCmd[1:]...); err != nil {
466					return err
467				}
468			}
469			patch, _ := runForOutput(ctx, filepath.Join(workDir, name), "diff", "-urN", string('a'+patchIndex), string('a'+patchIndex+1))
470			patch = sanitizePatch(patch)
471			if err := os.WriteFile(patchPath, patch, 0666); err != nil {
472				return err
473			}
474		}
475	}
476
477	// Update the attributes.
478	*attrs["sha256"] = &bzl.StringExpr{Value: sha256Sum}
479	*attrs["strip_prefix"] = &bzl.StringExpr{Value: stripPrefix}
480	*attrs["urls"] = &bzl.ListExpr{
481		List: []bzl.Expr{
482			&bzl.StringExpr{Value: mirrorURL},
483			&bzl.StringExpr{Value: ghURL},
484		},
485		ForceMultiLine: true,
486	}
487	urlsKwarg.Before = []bzl.Comment{{Token: "# " + urlComment}}
488
489	return nil
490}
491
492func parsePatchesItem(patchLabelExpr bzl.Expr) (value string, comments *bzl.Comments, err error) {
493	switch patchLabel := patchLabelExpr.(type) {
494	case *bzl.CallExpr:
495		// Verify the identifier, should be Label
496		if ident, ok := patchLabel.X.(*bzl.Ident); !ok {
497			return "", nil, fmt.Errorf("invalid identifier while parsing patch label")
498		} else if ident.Name != "Label" {
499			return "", nil, fmt.Errorf("invalid patch function: %q", ident.Name)
500		}
501
502		// Expect 1 String argument with the patch
503		if len(patchLabel.List) != 1 {
504			return "", nil, fmt.Errorf("Label expr should have 1 argument, found %d", len(patchLabel.List))
505		}
506
507		// Parse patch as a string
508		patchLabelStr, ok := patchLabel.List[0].(*bzl.StringExpr)
509		if !ok {
510			return "", nil, fmt.Errorf("Label expr does not contain a string literal")
511		}
512		return patchLabelStr.Value, patchLabel.Comment(), nil
513	case *bzl.StringExpr:
514		return strings.TrimPrefix(patchLabel.Value, "@io_bazel_rules_go"), patchLabel.Comment(), nil
515	default:
516		return "", nil, fmt.Errorf("not all patches are string literals or Label()")
517	}
518}
519
520// parseUpgradeDepDirective parses a '# releaser:upgrade-dep org repo' directive
521// and returns the organization and repository name or an error if the directive
522// was not found or malformed.
523func parseUpgradeDepDirective(call *bzl.CallExpr) (orgName, repoName string, err error) {
524	// TODO: support other upgrade strategies. For example, support git_repository
525	// and go_repository (possibly wrapped in _maybe).
526	for _, c := range call.Comment().Before {
527		words := strings.Fields(strings.TrimPrefix(c.Token, "#"))
528		if len(words) == 0 || words[0] != "releaser:upgrade-dep" {
529			continue
530		}
531		if len(words) != 3 {
532			return "", "", errors.New("invalid upgrade-dep directive; expected org, and name fields")
533		}
534		return words[1], words[2], nil
535	}
536	return "", "", errors.New("releaser:upgrade-dep directive not found")
537}
538
// sanitizePatch sets all of the non-zero file timestamps in a unified diff's
// "---"/"+++" header lines to the same fixed value. This reduces churn in the
// PR that updates the patches.
//
// We avoid changing zero-valued patch dates, which are used in added or
// deleted files. Since zero-valued dates can vary a bit by time zone, we assume
// that any year starting with "19" is a zero-valued date.
//
// Only an adjacent "--- " line followed by a "+++ " line is treated as a file
// header. This way a hunk line that merely starts with those characters (for
// example, a removed line whose text begins with "-- ") is not rewritten.
// The input slice is never modified.
func sanitizePatch(patch []byte) []byte {
	const fixedDate = "2000-01-01 00:00:00.000000000 -0000"

	// fixDate replaces everything after the last tab of a header line with
	// fixedDate, unless there is no tab or the date looks zero-valued.
	fixDate := func(line []byte) []byte {
		tab := bytes.LastIndexByte(line, '\t')
		if tab < 0 || bytes.HasPrefix(line[tab+1:], []byte("19")) {
			return line
		}
		// The full slice expression caps capacity so append copies instead of
		// writing into patch's backing array (the last element returned by
		// bytes.Split shares it).
		return append(line[:tab+1:tab+1], fixedDate...)
	}

	lines := bytes.Split(patch, []byte{'\n'})
	for i := 0; i+1 < len(lines); i++ {
		if !bytes.HasPrefix(lines[i], []byte("--- ")) || !bytes.HasPrefix(lines[i+1], []byte("+++ ")) {
			continue
		}
		lines[i] = fixDate(lines[i])
		lines[i+1] = fixDate(lines[i+1])
		i++ // skip the "+++ " line we just handled
	}
	return bytes.Join(lines, []byte{'\n'})
}
562