1// Copyright 2021 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package types
6
7import (
8	"go/ast"
9	"go/token"
10	. "internal/types/errors"
11)
12
13// ----------------------------------------------------------------------------
14// API
15
// A Union represents a union of terms embedded in an interface.
// The terms are kept exactly as written in the source, in source order;
// they are not normalized into a canonical term list. Use Len and Term
// to access them.
type Union struct {
	terms []*Term // list of syntactical terms (not a canonicalized termlist)
}
20
21// NewUnion returns a new [Union] type with the given terms.
22// It is an error to create an empty union; they are syntactically not possible.
23func NewUnion(terms []*Term) *Union {
24	if len(terms) == 0 {
25		panic("empty union")
26	}
27	return &Union{terms}
28}
29
30func (u *Union) Len() int         { return len(u.terms) }
31func (u *Union) Term(i int) *Term { return u.terms[i] }
32
33func (u *Union) Underlying() Type { return u }
34func (u *Union) String() string   { return TypeString(u, nil) }
35
// A Term represents a term in a [Union].
// A Term is either a plain type T or a tilde term ~T (see Tilde and Type).
// It has the same representation as the package-internal term type, so
// a *Term and a *term can be converted to each other directly.
type Term term
38
39// NewTerm returns a new union term.
40func NewTerm(tilde bool, typ Type) *Term { return &Term{tilde, typ} }
41
42func (t *Term) Tilde() bool    { return t.tilde }
43func (t *Term) Type() Type     { return t.typ }
44func (t *Term) String() string { return (*term)(t).String() }
45
46// ----------------------------------------------------------------------------
47// Implementation
48
// maxTermCount is the maximum number of terms accepted in a single union.
// parseUnion reports an InvalidUnion error (once) and yields Typ[Invalid]
// for unions with more terms than this.
// The limit avoids excessive type-checking times due to quadratic
// termlist operations.
const maxTermCount = 100
51
// parseUnion parses uexpr as a union of expressions.
// The result is a Union type, or Typ[Invalid] for some errors.
//
// A lone term T without a tilde is not wrapped in a Union; its type is
// returned directly. For each binary | node of uexpr, the union built from
// the terms up to and including that node's right operand is recorded as
// the node's type. A union with more than maxTermCount terms is reported
// once and yields Typ[Invalid]. The remaining validity checks are deferred
// via check.later because they need fully set-up types. These checks cover
// misuse of ~, interface terms with methods or comparable, and overlapping
// terms.
func parseUnion(check *Checker, uexpr ast.Expr) Type {
	blist, tlist := flattenUnion(nil, uexpr)
	// A union of n terms has exactly n-1 binary | nodes.
	assert(len(blist) == len(tlist)-1)

	var terms []*Term

	// u is the union built so far, or Typ[Invalid] once the term limit
	// has been exceeded.
	var u Type
	for i, x := range tlist {
		term := parseTilde(check, x)
		if len(tlist) == 1 && !term.tilde {
			// Single type. Ok to return early because all relevant
			// checks have been performed in parseTilde (no need to
			// run through term validity check below).
			return term.typ // typ already recorded through check.typ in parseTilde
		}
		if len(terms) >= maxTermCount {
			// Report the limit only the first time it is exceeded; any
			// further terms are still parsed (above) but dropped.
			if isValid(u) {
				check.errorf(x, InvalidUnion, "cannot handle more than %d union terms (implementation limitation)", maxTermCount)
				u = Typ[Invalid]
			}
		} else {
			terms = append(terms, term)
			u = &Union{terms}
		}

		if i > 0 {
			// blist[i-1] is the | node whose right operand is tlist[i];
			// record the union built so far as its type.
			check.recordTypeAndValue(blist[i-1], typexpr, u, nil)
		}
	}

	if !isValid(u) {
		return u
	}

	// Check validity of terms.
	// Do this check later because it requires types to be set up.
	// Note: This is a quadratic algorithm, but unions tend to be short.
	check.later(func() {
		// Here terms[i] was parsed from tlist[i]: the term limit was not
		// exceeded (we returned above otherwise), so terms covers all of tlist.
		for i, t := range terms {
			// Skip invalid terms; presumably an error was already reported
			// when their type was computed (e.g., in parseTilde).
			if !isValid(t.typ) {
				continue
			}

			// Note: this u shadows the outer union; it is t's underlying type.
			u := under(t.typ)
			f, _ := u.(*Interface)
			if t.tilde {
				if f != nil {
					check.errorf(tlist[i], InvalidUnion, "invalid use of ~ (%s is an interface)", t.typ)
					continue // don't report another error for t
				}

				// ~T requires T to be its own underlying type.
				if !Identical(u, t.typ) {
					check.errorf(tlist[i], InvalidUnion, "invalid use of ~ (underlying type of %s is %s)", t.typ, u)
					continue
				}
			}

			// Stand-alone embedded interfaces are ok and are handled by the single-type case
			// in the beginning. Embedded interfaces with tilde are excluded above. If we reach
			// here, we must have at least two terms in the syntactic term list (but not necessarily
			// in the term list of the union's type set).
			if f != nil {
				tset := f.typeSet()
				switch {
				case tset.NumMethods() != 0:
					check.errorf(tlist[i], InvalidUnion, "cannot use %s in union (%s contains methods)", t, t)
				case t.typ == universeComparable.Type():
					check.error(tlist[i], InvalidUnion, "cannot use comparable in union")
				case tset.comparable:
					check.errorf(tlist[i], InvalidUnion, "cannot use %s in union (%s embeds comparable)", t, t)
				}
				continue // terms with interface types are not subject to the no-overlap rule
			}

			// Report overlapping (non-disjoint) terms such as
			// a|a, a|~a, ~a|~a, and ~a|A (where under(A) == a).
			// Only earlier terms are compared, so each overlap is reported
			// once, at the later of the two terms.
			if j := overlappingTerm(terms[:i], t); j >= 0 {
				check.softErrorf(tlist[i], InvalidUnion, "overlapping terms %s and %s", t, terms[j])
			}
		}
	}).describef(uexpr, "check term validity %s", uexpr)

	return u
}
138
139func parseTilde(check *Checker, tx ast.Expr) *Term {
140	x := tx
141	var tilde bool
142	if op, _ := x.(*ast.UnaryExpr); op != nil && op.Op == token.TILDE {
143		x = op.X
144		tilde = true
145	}
146	typ := check.typ(x)
147	// Embedding stand-alone type parameters is not permitted (go.dev/issue/47127).
148	// We don't need this restriction anymore if we make the underlying type of a type
149	// parameter its constraint interface: if we embed a lone type parameter, we will
150	// simply use its underlying type (like we do for other named, embedded interfaces),
151	// and since the underlying type is an interface the embedding is well defined.
152	if isTypeParam(typ) {
153		if tilde {
154			check.errorf(x, MisplacedTypeParam, "type in term %s cannot be a type parameter", tx)
155		} else {
156			check.error(x, MisplacedTypeParam, "term cannot be a type parameter")
157		}
158		typ = Typ[Invalid]
159	}
160	term := NewTerm(tilde, typ)
161	if tilde {
162		check.recordTypeAndValue(tx, typexpr, &Union{[]*Term{term}}, nil)
163	}
164	return term
165}
166
167// overlappingTerm reports the index of the term x in terms which is
168// overlapping (not disjoint) from y. The result is < 0 if there is no
169// such term. The type of term y must not be an interface, and terms
170// with an interface type are ignored in the terms list.
171func overlappingTerm(terms []*Term, y *Term) int {
172	assert(!IsInterface(y.typ))
173	for i, x := range terms {
174		if IsInterface(x.typ) {
175			continue
176		}
177		// disjoint requires non-nil, non-top arguments,
178		// and non-interface types as term types.
179		if debug {
180			if x == nil || x.typ == nil || y == nil || y.typ == nil {
181				panic("empty or top union term")
182			}
183		}
184		if !(*term)(x).disjoint((*term)(y)) {
185			return i
186		}
187	}
188	return -1
189}
190
// flattenUnion walks a union type expression of the form A | B | C | ...,
// extracting both the binary exprs (blist) and leaf types (tlist).
//
// Only the left-nested chain of | operators is unrolled, which is the shape
// the parser produces for the left-associative |. The right operand of each
// | node is taken as a leaf even if it is itself a | expression. On return,
// blist[k] is the | node whose right operand is tlist[k+1], and
// len(blist) == len(tlist)-1. For an x that is not a | expression, blist is
// nil and tlist is [x].
//
// NOTE(review): the list parameter is never consulted; the caller visible
// here (parseUnion) passes nil.
func flattenUnion(list []ast.Expr, x ast.Expr) (blist, tlist []ast.Expr) {
	// Descend the left spine, collecting | nodes and their right operands
	// from the outermost (rightmost) node inward.
	leaf := x
	for {
		bin, _ := leaf.(*ast.BinaryExpr)
		if bin == nil || bin.Op != token.OR {
			break
		}
		blist = append(blist, bin)
		tlist = append(tlist, bin.Y)
		leaf = bin.X
	}
	tlist = append(tlist, leaf)

	// Both lists were built right-to-left; flip them into source order.
	reverse := func(s []ast.Expr) {
		for lo, hi := 0, len(s)-1; lo < hi; lo, hi = lo+1, hi-1 {
			s[lo], s[hi] = s[hi], s[lo]
		}
	}
	reverse(blist)
	reverse(tlist)
	return blist, tlist
}
201