1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// This file implements syntax tree walking.
6
7package syntax
8
9import "fmt"
10
11// Inspect traverses an AST in pre-order: it starts by calling f(root);
12// root must not be nil. If f returns true, Inspect invokes f recursively
13// for each of the non-nil children of root, followed by a call of f(nil).
14//
15// See Walk for caveats about shared nodes.
16func Inspect(root Node, f func(Node) bool) {
17	Walk(root, inspector(f))
18}
19
20type inspector func(Node) bool
21
22func (v inspector) Visit(node Node) Visitor {
23	if v(node) {
24		return v
25	}
26	return nil
27}
28
29// Walk traverses an AST in pre-order: It starts by calling
30// v.Visit(node); node must not be nil. If the visitor w returned by
31// v.Visit(node) is not nil, Walk is invoked recursively with visitor
32// w for each of the non-nil children of node, followed by a call of
33// w.Visit(nil).
34//
35// Some nodes may be shared among multiple parent nodes (e.g., types in
36// field lists such as type T in "a, b, c T"). Such shared nodes are
37// walked multiple times.
38// TODO(gri) Revisit this design. It may make sense to walk those nodes
39// only once. A place where this matters is types2.TestResolveIdents.
40func Walk(root Node, v Visitor) {
41	walker{v}.node(root)
42}
43
44// A Visitor's Visit method is invoked for each node encountered by Walk.
45// If the result visitor w is not nil, Walk visits each of the children
46// of node with the visitor w, followed by a call of w.Visit(nil).
47type Visitor interface {
48	Visit(node Node) (w Visitor)
49}
50
51type walker struct {
52	v Visitor
53}
54
// node walks the subtree rooted at n. It panics if n is nil.
//
// It first calls w.v.Visit(n). If that returns nil, n's children are
// skipped and no trailing Visit(nil) is made. Otherwise each child of n
// is walked with the returned visitor, in the order written in the
// switch below, and finally that visitor's Visit(nil) is called.
// It panics on node types the switch does not handle.
func (w walker) node(n Node) {
	if n == nil {
		panic("nil node")
	}

	// w is a copy (value receiver), so replacing v here only changes the
	// visitor used for n's children, not the caller's walker.
	w.v = w.v.Visit(n)
	if w.v == nil {
		return
	}

	// Optional children are checked for nil before being walked because
	// node panics on a nil argument.
	switch n := n.(type) {
	// packages
	case *File:
		w.node(n.PkgName)
		w.declList(n.DeclList)

	// declarations
	case *ImportDecl:
		if n.LocalPkgName != nil {
			w.node(n.LocalPkgName)
		}
		w.node(n.Path)

	case *ConstDecl:
		w.nameList(n.NameList)
		if n.Type != nil {
			w.node(n.Type)
		}
		if n.Values != nil {
			w.node(n.Values)
		}

	case *TypeDecl:
		w.node(n.Name)
		w.fieldList(n.TParamList)
		w.node(n.Type)

	case *VarDecl:
		w.nameList(n.NameList)
		if n.Type != nil {
			w.node(n.Type)
		}
		if n.Values != nil {
			w.node(n.Values)
		}

	case *FuncDecl:
		// Recv and Body are optional; Name, type parameters and Type
		// are always walked.
		if n.Recv != nil {
			w.node(n.Recv)
		}
		w.node(n.Name)
		w.fieldList(n.TParamList)
		w.node(n.Type)
		if n.Body != nil {
			w.node(n.Body)
		}

	// expressions
	case *BadExpr: // nothing to do
	case *Name: // nothing to do
	case *BasicLit: // nothing to do

	case *CompositeLit:
		if n.Type != nil {
			w.node(n.Type)
		}
		w.exprList(n.ElemList)

	case *KeyValueExpr:
		w.node(n.Key)
		w.node(n.Value)

	case *FuncLit:
		w.node(n.Type)
		w.node(n.Body)

	case *ParenExpr:
		w.node(n.X)

	case *SelectorExpr:
		w.node(n.X)
		w.node(n.Sel)

	case *IndexExpr:
		w.node(n.X)
		w.node(n.Index)

	case *SliceExpr:
		w.node(n.X)
		// Unset entries of Index are nil and are skipped.
		for _, x := range n.Index {
			if x != nil {
				w.node(x)
			}
		}

	case *AssertExpr:
		w.node(n.X)
		w.node(n.Type)

	case *TypeSwitchGuard:
		if n.Lhs != nil {
			w.node(n.Lhs)
		}
		w.node(n.X)

	case *Operation:
		// Y is optional (presumably nil for unary operations — confirm
		// against the Operation definition).
		w.node(n.X)
		if n.Y != nil {
			w.node(n.Y)
		}

	case *CallExpr:
		w.node(n.Fun)
		w.exprList(n.ArgList)

	case *ListExpr:
		w.exprList(n.ElemList)

	// types
	case *ArrayType:
		if n.Len != nil {
			w.node(n.Len)
		}
		w.node(n.Elem)

	case *SliceType:
		w.node(n.Elem)

	case *DotsType:
		w.node(n.Elem)

	case *StructType:
		w.fieldList(n.FieldList)
		// TagList may contain nil entries; only present tags are walked.
		for _, t := range n.TagList {
			if t != nil {
				w.node(t)
			}
		}

	case *Field:
		if n.Name != nil {
			w.node(n.Name)
		}
		w.node(n.Type)

	case *InterfaceType:
		w.fieldList(n.MethodList)

	case *FuncType:
		w.fieldList(n.ParamList)
		w.fieldList(n.ResultList)

	case *MapType:
		w.node(n.Key)
		w.node(n.Value)

	case *ChanType:
		w.node(n.Elem)

	// statements
	case *EmptyStmt: // nothing to do

	case *LabeledStmt:
		w.node(n.Label)
		w.node(n.Stmt)

	case *BlockStmt:
		w.stmtList(n.List)

	case *ExprStmt:
		w.node(n.X)

	case *SendStmt:
		w.node(n.Chan)
		w.node(n.Value)

	case *DeclStmt:
		w.declList(n.DeclList)

	case *AssignStmt:
		w.node(n.Lhs)
		if n.Rhs != nil {
			w.node(n.Rhs)
		}

	case *BranchStmt:
		if n.Label != nil {
			w.node(n.Label)
		}
		// Target points to nodes elsewhere in the syntax tree
		// (so it is deliberately not walked here).

	case *CallStmt:
		w.node(n.Call)

	case *ReturnStmt:
		if n.Results != nil {
			w.node(n.Results)
		}

	case *IfStmt:
		if n.Init != nil {
			w.node(n.Init)
		}
		w.node(n.Cond)
		w.node(n.Then)
		if n.Else != nil {
			w.node(n.Else)
		}

	case *ForStmt:
		if n.Init != nil {
			w.node(n.Init)
		}
		if n.Cond != nil {
			w.node(n.Cond)
		}
		if n.Post != nil {
			w.node(n.Post)
		}
		w.node(n.Body)

	case *SwitchStmt:
		if n.Init != nil {
			w.node(n.Init)
		}
		if n.Tag != nil {
			w.node(n.Tag)
		}
		for _, s := range n.Body {
			w.node(s)
		}

	case *SelectStmt:
		for _, s := range n.Body {
			w.node(s)
		}

	// helper nodes
	case *RangeClause:
		if n.Lhs != nil {
			w.node(n.Lhs)
		}
		w.node(n.X)

	case *CaseClause:
		if n.Cases != nil {
			w.node(n.Cases)
		}
		w.stmtList(n.Body)

	case *CommClause:
		if n.Comm != nil {
			w.node(n.Comm)
		}
		w.stmtList(n.Body)

	default:
		panic(fmt.Sprintf("internal error: unknown node type %T", n))
	}

	// All children of n have been walked; tell the visitor so. The
	// returned visitor is intentionally ignored.
	w.v.Visit(nil)
}
317
318func (w walker) declList(list []Decl) {
319	for _, n := range list {
320		w.node(n)
321	}
322}
323
324func (w walker) exprList(list []Expr) {
325	for _, n := range list {
326		w.node(n)
327	}
328}
329
330func (w walker) stmtList(list []Stmt) {
331	for _, n := range list {
332		w.node(n)
333	}
334}
335
336func (w walker) nameList(list []*Name) {
337	for _, n := range list {
338		w.node(n)
339	}
340}
341
342func (w walker) fieldList(list []*Field) {
343	for _, n := range list {
344		w.node(n)
345	}
346}
347