xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/tree_views.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <torch/csrc/jit/frontend/error_report.h>
3 #include <torch/csrc/jit/frontend/strtod.h>
4 #include <torch/csrc/jit/frontend/tree.h>
5 
6 #include <c10/util/complex.h>
7 #include <functional>
8 #include <iostream>
9 #include <string>
10 #include <utility>
11 
12 namespace torch::jit {
13 
14 // clang-format off
15 // TreeView provides a statically-typed way to traverse the tree, which should
16 // be formed according to the grammar below.
17 //
18 // A few notes on types and their aliases:
19 // - List<T> is really a Tree with kind TK_LIST and elements as subtrees
20 // - Maybe<T> is really a Tree with kind TK_OPTION that has 0 or 1 subtree of type T
21 // - Builtin types are: Ident (TK_IDENT), String (TK_STRING)
22 //
// Param = Param(Ident name, Maybe<Expr> type, Maybe<Expr> default, bool kwarg_only)   TK_PARAM
24 //
25 // Decl  = Decl(List<Param> params, Maybe<Expr> return_type)            TK_DECL
26 // Def   = Def(Ident name, Decl decl, List<Stmt> body)                  TK_DEF
27 // ClassDef = ClassDef(Ident name,                                      TK_CLASS_DEF
28 //                     Maybe<Expr> superclass,
29 //                     List<Stmt> body)
30 //
31 // Stmt  = If(Expr cond, List<Stmt> true_body, List<Stmt> false_body)   TK_IF
32 //       | For(List<Expr> targets, List<Expr> iters, List<Stmt> body)   TK_FOR
33 //       | While(Expr cond, List<Stmt> body)                            TK_WHILE
34 //       | Global(List<Ident> idents)                                   TK_GLOBAL
//       -- NB: the only Exprs allowed on the lhs are Var,
//          or a tuple of Vars with an optional terminating Starred
//       | Assign(List<Expr> lhs, Maybe<Expr> rhs, Maybe<Expr> type)    TK_ASSIGN
38 //       | AugAssign(Expr lhs, AugAssignKind aug_op, Expr rhs)          TK_AUG_ASSIGN
//       | Return(Expr value)                                           TK_RETURN
//       | ExprStmt(Expr expr)                                          TK_EXPR_STMT
41 //       | Raise(Expr expr)                                             TK_RAISE
42 //       | Def                                                          TK_DEF
43 //       | With(List<WithItem> targets, List<Stmt> body)                TK_WITH
44 //
45 // Expr  = TernaryIf(Expr cond, Expr true_expr, Expr false_expr)        TK_IF_EXPR
46 //       | BinOp(Expr lhs, Expr rhs)
47 //       |     And                                                      TK_AND
48 //       |     Or                                                       TK_OR
49 //       |     Lt                                                       '<'
50 //       |     Gt                                                       '>'
51 //       |     Eq                                                       TK_EQ
52 //       |     Le                                                       TK_LE
53 //       |     Ge                                                       TK_GE
54 //       |     Ne                                                       TK_NE
55 //       |     Is                                                       TK_IS
56 //       |     IsNot                                                    TK_ISNOT
57 //       |     Add                                                      '+'
58 //       |     Sub                                                      '-'
59 //       |     Mul                                                      '*'
60 //       |     Div                                                      '/'
61 //       |     Mod                                                      '%'
62 //       |     MatMult                                                  '@'
63 //       |     Pow                                                      TK_POW
64 //       | UnaryOp(Expr expr)
65 //       |     Not                                                      TK_NOT
66 //       |     USub                                                     '-'
67 //       | Const(String value)                                          TK_CONST
68 //       -- NB: x.name(y) is desugared into name(x, y)
69 //       | Apply(Ident name, List<Expr> args, List<Attribute> kwargs)   TK_APPLY
70 //       | Select(Expr value, Ident selector)                           '.'
71 //       | Subscript(Expr value, List<Expr> subscript_exprs)            TK_SUBSCRIPT
72 //       | SliceExpr(Maybe<Expr> start, Maybe<Expr> end)                TK_SLICE_EXPR
73 //       | Var(Ident name)                                              TK_VAR
74 //       | ListLiteral(List<Expr> inputs)                               TK_LIST_LITERAL
75 //       | TupleLiteral(List<Expr> inputs)                              TK_TUPLE_LITERAL
76 //       | Starred(Expr expr)                                           TK_STARRED
77 //       | WithItem(Expr target, Maybe<Var> var)                        TK_WITH_ITEM
78 // -- NB: only allowed expressions are Const or List(Const)
79 //        (List as a value, not type constructor)
80 // Attribute = Attribute(Ident name, Expr value)                        TK_ATTRIBUTE
81 //
// AugAssignKind =
//            | Add()                                                   '+'
//            | Sub()                                                   '-'
//            | Mul()                                                   '*'
//            | Div()                                                   '/'
//            | Mod()                                                   '%'
//            | BitOr()                                                 '|'
//            | BitAnd()                                                '&'
//            | BitXor()                                                '^'
//            | Pow()                                                   TK_POW
//            | LShift()                                                TK_LSHIFT
//            | RShift()                                                TK_RSHIFT
88 //
89 
90 // Each subclass of TreeView should provide:
91 // 1. Constructor that takes a TreeRef, and checks that it's of the right type.
92 // 2. Accessors that get underlying information out of the object. If they
93 //    return subtrees, they should wrap them in appropriate views too.
// 3. Static method 'create' that creates the underlying TreeRef object.
//    For every TreeRef kind that has a TreeView, the parser always uses the
//    view's factory (e.g. Ident::create) rather than Compound::create. This
//    means that changes to the structure of Ident are always made right here
//    rather than both in the parser and in this code.
99 // XXX: these structs should have no fields to prevent slicing when passing by value
100 // clang-format on
101 struct TreeView {
TreeViewTreeView102   explicit TreeView(TreeRef tree) : tree_(std::move(tree)) {}
treeTreeView103   TreeRef tree() const {
104     return tree_;
105   }
rangeTreeView106   const SourceRange& range() const {
107     return tree_->range();
108   }
TreeRefTreeView109   operator TreeRef() const {
110     return tree_;
111   }
getTreeView112   const TreeRef& get() const {
113     return tree_;
114   }
kindTreeView115   int kind() const {
116     return tree_->kind();
117   }
dumpTreeView118   void dump() const {
119     std::cout << tree_;
120   }
121 
122  protected:
subtreeTreeView123   const TreeRef& subtree(size_t i) const {
124     return tree_->trees().at(i);
125   }
126   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
127   TreeRef tree_;
128 };
129 
130 template <typename T>
131 struct ListIterator {
ListIteratorListIterator132   ListIterator(TreeList::const_iterator it) : it(it) {}
133   bool operator!=(const ListIterator& rhs) const {
134     return it != rhs.it;
135   }
136   bool operator==(const ListIterator& rhs) const {
137     return it == rhs.it;
138   }
139   T operator*() const {
140     return T(*it);
141   }
142   ListIterator& operator+=(std::ptrdiff_t n) {
143     it += n;
144     return *this;
145   }
146   ListIterator& operator++() {
147     ++it;
148     return *this;
149   }
150   ListIterator& operator--() {
151     --it;
152     return *this;
153   }
154 
155  private:
156   TreeList::const_iterator it;
157 };
158 
159 template <typename T>
160 struct List : public TreeView {
161   using iterator = ListIterator<T>;
162   using const_iterator = ListIterator<T>;
163 
ListList164   List(const TreeRef& tree) : TreeView(tree) {
165     tree->match(TK_LIST);
166     // Iterate over list to temporarily instantiate Ts that will check the type
167     for (const T& elem : *this) {
168       (void)elem; // silence unused warning
169     }
170   }
beginList171   iterator begin() const {
172     return iterator(tree_->trees().begin());
173   }
endList174   iterator end() const {
175     return iterator(tree_->trees().end());
176   }
emptyList177   bool empty() const {
178     return tree_->trees().begin() == tree_->trees().end();
179   }
180   T operator[](size_t i) const {
181     return T(subtree(i));
182   }
mapList183   TreeRef map(const std::function<TreeRef(const T&)>& fn) {
184     return tree_->map([&](TreeRef v) { return fn(T(v)); });
185   }
createList186   static List create(const SourceRange& range, const std::vector<T>& subtrees) {
187     TreeList type_erased_sub{subtrees.begin(), subtrees.end()};
188     return List(Compound::create(TK_LIST, range, std::move(type_erased_sub)));
189   }
unsafeCreateList190   static List unsafeCreate(const SourceRange& range, TreeList&& subtrees) {
191     return List(Compound::create(TK_LIST, range, std::move(subtrees)));
192   }
sizeList193   size_t size() const {
194     return tree_->trees().size();
195   }
196 };
197 
198 template <typename T>
199 struct Maybe : public TreeView {
MaybeMaybe200   explicit Maybe(const TreeRef& tree) : TreeView(tree) {
201     tree_->match(TK_OPTION);
202     if (tree_->trees().size() > 1)
203       throw(ErrorReport(tree) << "Maybe trees can have at most one subtree");
204   }
MaybeMaybe205   /* implicit */ Maybe(const T& tree) : TreeView(tree) {}
presentMaybe206   bool present() const {
207     return tree_->trees().size() > 0;
208   }
getMaybe209   T get() const {
210     return T(tree_->trees().at(0));
211   }
mapMaybe212   TreeRef map(const std::function<TreeRef(const T&)>& fn) {
213     return tree_->map([&](TreeRef v) { return fn(T(v)); });
214   }
createMaybe215   static Maybe<T> create(const SourceRange& range) {
216     return Maybe<T>(Compound::create(TK_OPTION, range, {}));
217   }
createMaybe218   static Maybe<T> create(const SourceRange& range, const T& value) {
219     return Maybe<T>(Compound::create(TK_OPTION, range, {value}));
220   }
221 };
222 
223 struct Ident : public TreeView {
IdentIdent224   explicit Ident(const TreeRef& tree) : TreeView(tree) {
225     tree_->match(TK_IDENT);
226   }
nameIdent227   const std::string& name() const {
228     return subtree(0)->stringValue();
229   }
createIdent230   static Ident create(const SourceRange& range, std::string name) {
231     return Ident(
232         Compound::create(TK_IDENT, range, {String::create(std::move(name))}));
233   }
234 };
235 
236 ////////////////////////////////////////////////////////////////////////////////
237 // Base types (production LHS)
238 ////////////////////////////////////////////////////////////////////////////////
239 
240 struct Stmt : public TreeView {
StmtStmt241   explicit Stmt(const TreeRef& tree) : TreeView(tree) {
242     switch (tree->kind()) {
243       case TK_IF:
244       case TK_FOR:
245       case TK_WHILE:
246       case TK_GLOBAL:
247       case TK_ASSIGN:
248       case TK_AUG_ASSIGN:
249       case TK_RETURN:
250       case TK_EXPR_STMT:
251       case TK_RAISE:
252       case TK_ASSERT:
253       case TK_PASS:
254       case TK_BREAK:
255       case TK_DELETE:
256       case TK_CONTINUE:
257       case TK_DEF:
258       case TK_WITH:
259         return;
260       default:
261         throw(
262             ErrorReport(tree)
263             << kindToString(tree->kind()) << " is not a valid Stmt");
264     }
265   }
266 };
267 
268 struct Expr : public TreeView {
ExprExpr269   explicit Expr(const TreeRef& tree) : TreeView(tree) {
270     switch (tree->kind()) {
271       case TK_IF_EXPR:
272       case TK_AND:
273       case TK_OR:
274       case '<':
275       case '>':
276       case TK_IS:
277       case TK_ISNOT:
278       case TK_EQ:
279       case TK_LE:
280       case TK_GE:
281       case TK_NE:
282       case '+':
283       case '-':
284       case TK_UNARY_MINUS:
285       case '~':
286       case '*':
287       case TK_STARRED:
288       case '/':
289       case '%':
290       case TK_NOT:
291       case TK_CONST:
292       case TK_STRINGLITERAL:
293       case TK_TRUE:
294       case TK_FALSE:
295       case TK_NONE:
296       case TK_NONE_TYPE:
297       case TK_CAST:
298       case TK_APPLY:
299       case '.':
300       case TK_SUBSCRIPT:
301       case TK_SLICE_EXPR:
302       case TK_VAR:
303       case TK_LIST_LITERAL:
304       case TK_TUPLE_LITERAL:
305       case TK_DICT_LITERAL:
306       case '@':
307       case TK_POW:
308       case TK_LSHIFT:
309       case TK_RSHIFT:
310       case TK_FLOOR_DIV:
311       case '&':
312       case '^':
313       case '|':
314       case TK_LIST_COMP:
315       case TK_DICT_COMP:
316       case TK_DOTS:
317       case TK_IN:
318       case TK_WITH_ITEM:
319         return;
320       default:
321         throw(
322             ErrorReport(tree)
323             << kindToString(tree->kind()) << " is not a valid Expr");
324     }
325   }
326 };
327 
328 ////////////////////////////////////////////////////////////////////////////////
329 // Helper nodes (mostly for function arguments)
330 ////////////////////////////////////////////////////////////////////////////////
331 
332 struct Attribute : public TreeView {
AttributeAttribute333   explicit Attribute(const TreeRef& tree) : TreeView(tree) {
334     tree_->match(TK_ATTRIBUTE);
335   }
nameAttribute336   Ident name() const {
337     return Ident(subtree(0));
338   }
valueAttribute339   Expr value() const {
340     return Expr(subtree(1));
341   }
createAttribute342   static Attribute create(
343       const SourceRange& range,
344       const Ident& name,
345       const TreeRef& value) {
346     return Attribute(Compound::create(TK_ATTRIBUTE, range, {name, value}));
347   }
348 };
349 
350 struct Param : public TreeView {
ParamParam351   explicit Param(const TreeRef& tree) : TreeView(tree) {
352     tree_->match(TK_PARAM);
353   }
createParam354   static Param create(
355       const SourceRange& range,
356       const Ident& ident,
357       const Maybe<Expr>& type,
358       const Maybe<Expr>& def,
359       bool kwarg_only) {
360     TreeRef kwarg_only_tree =
361         Compound::create(kwarg_only ? TK_TRUE : TK_FALSE, range, {});
362     return Param(Compound::create(
363         TK_PARAM, range, {ident, type, def, std::move(kwarg_only_tree)}));
364   }
identParam365   Ident ident() const {
366     return Ident(subtree(0));
367   }
typeParam368   Maybe<Expr> type() const {
369     return Maybe<Expr>(subtree(1));
370   }
defaultValueParam371   Maybe<Expr> defaultValue() const {
372     return Maybe<Expr>(subtree(2));
373   }
kwarg_onlyParam374   bool kwarg_only() const {
375     return TK_TRUE == subtree(3)->kind();
376   }
withTypeParam377   Param withType(const Maybe<Expr>& typ) const {
378     return Param::create(range(), ident(), typ, defaultValue(), kwarg_only());
379   }
380 };
381 
382 ////////////////////////////////////////////////////////////////////////////////
383 // Top level definitions
384 ////////////////////////////////////////////////////////////////////////////////
385 
386 struct Decl : public TreeView {
DeclDecl387   explicit Decl(const TreeRef& tree) : TreeView(tree) {
388     tree->match(TK_DECL);
389   }
paramsDecl390   List<Param> params() const {
391     return List<Param>(subtree(0));
392   }
return_typeDecl393   Maybe<Expr> return_type() const {
394     return Maybe<Expr>(subtree(1));
395   }
createDecl396   static Decl create(
397       const SourceRange& range,
398       const List<Param>& params,
399       const Maybe<Expr>& return_type) {
400     return Decl(Compound::create(TK_DECL, range, {params, return_type}));
401   }
402 };
403 
404 struct Def : public TreeView {
DefDef405   explicit Def(const TreeRef& tree) : TreeView(tree) {
406     tree->match(TK_DEF);
407   }
withNameDef408   Def withName(std::string new_name) const {
409     auto new_ident = Ident::create(name().range(), std::move(new_name));
410     return create(range(), new_ident, decl(), statements());
411   }
withDeclDef412   Def withDecl(const Decl& decl) const {
413     return create(range(), name(), decl, statements());
414   }
nameDef415   Ident name() const {
416     return Ident(subtree(0));
417   }
declDef418   Decl decl() const {
419     return Decl(subtree(1));
420   }
statementsDef421   List<Stmt> statements() const {
422     return List<Stmt>(subtree(2));
423   }
createDef424   static Def create(
425       const SourceRange& range,
426       const Ident& name,
427       const Decl& decl,
428       const List<Stmt>& stmts) {
429     return Def(Compound::create(TK_DEF, range, {name, decl, stmts}));
430   }
431 };
432 
433 // Property represents a named attribute combined with a getter and setter
434 // method to access and mutate that attribute.
435 struct Property : public TreeView {
PropertyProperty436   explicit Property(const TreeRef& tree) : TreeView(tree) {
437     tree->match(TK_PROP);
438   }
nameProperty439   Ident name() const {
440     return Ident(subtree(0));
441   }
getterProperty442   Def getter() const {
443     return Def(subtree(1));
444   }
setterProperty445   Maybe<Def> setter() const {
446     return Maybe<Def>(subtree(2));
447   }
createProperty448   static Property create(
449       const SourceRange& range,
450       const Ident& name,
451       const Def& getter,
452       const Maybe<Def>& setter) {
453     return Property(Compound::create(TK_PROP, range, {name, getter, setter}));
454   }
455 };
456 
457 struct Assign;
458 
459 struct ClassDef : public TreeView {
ClassDefClassDef460   explicit ClassDef(const TreeRef& tree) : TreeView(tree) {
461     tree->match(TK_CLASS_DEF);
462   }
ClassDefClassDef463   explicit ClassDef(TreeRef&& tree) : TreeView(std::move(tree)) {
464     tree_->match(TK_CLASS_DEF);
465   }
withNameClassDef466   ClassDef withName(std::string new_name) const {
467     auto new_ident = Ident::create(name().range(), std::move(new_name));
468     return create(range(), new_ident, superclass(), body());
469   }
nameClassDef470   Ident name() const {
471     return Ident(subtree(0));
472   }
superclassClassDef473   Maybe<Expr> superclass() const {
474     return Maybe<Expr>(subtree(1));
475   }
bodyClassDef476   List<Stmt> body() const {
477     return List<Stmt>(subtree(2));
478   }
propertiesClassDef479   Maybe<List<Property>> properties() const {
480     return Maybe<List<Property>>(subtree(3));
481   }
assignsClassDef482   Maybe<List<Assign>> assigns() const {
483     return Maybe<List<Assign>>(subtree(4));
484   }
createClassDef485   static ClassDef create(
486       const SourceRange& range,
487       const Ident& name,
488       const Maybe<Expr>& superclass,
489       const List<Stmt>& body) {
490     return ClassDef(Compound::create(
491         TK_CLASS_DEF,
492         range,
493         {name,
494          superclass,
495          body,
496          Maybe<List<Property>>::create(range),
497          Maybe<List<Assign>>::create(range)}));
498   }
499   static ClassDef create(
500       const SourceRange& range,
501       const Ident& name,
502       const Maybe<Expr>& superclass,
503       const List<Stmt>& body,
504       const List<Property>& properties,
505       const List<Assign>& assigns);
506 };
507 
508 TORCH_API std::vector<std::string> getUnresolvedClassAttributes(
509     const ClassDef& def);
510 
511 ////////////////////////////////////////////////////////////////////////////////
512 // Statements
513 ////////////////////////////////////////////////////////////////////////////////
514 
515 struct If : public Stmt {
IfIf516   explicit If(const TreeRef& tree) : Stmt(tree) {
517     tree_->match(TK_IF);
518   }
condIf519   Expr cond() const {
520     return Expr(subtree(0));
521   }
trueBranchIf522   List<Stmt> trueBranch() const {
523     return List<Stmt>(subtree(1));
524   }
falseBranchIf525   List<Stmt> falseBranch() const {
526     return List<Stmt>(subtree(2));
527   }
withNewBranchesIf528   If withNewBranches(
529       const List<Stmt>& true_branch,
530       const List<Stmt>& false_branch) const {
531     return create(range(), cond(), true_branch, false_branch);
532   }
createIf533   static If create(
534       const SourceRange& range,
535       const Expr& cond,
536       const List<Stmt>& true_branch,
537       const List<Stmt>& false_branch) {
538     return If(
539         Compound::create(TK_IF, range, {cond, true_branch, false_branch}));
540   }
541 };
542 
543 struct While : public Stmt {
WhileWhile544   explicit While(const TreeRef& tree) : Stmt(tree) {
545     tree_->match(TK_WHILE);
546   }
condWhile547   Expr cond() const {
548     return Expr(subtree(0));
549   }
bodyWhile550   List<Stmt> body() const {
551     return List<Stmt>(subtree(1));
552   }
createWhile553   static While create(
554       const SourceRange& range,
555       const Expr& cond,
556       const List<Stmt>& body) {
557     return While(Compound::create(TK_WHILE, range, {cond, body}));
558   }
559 };
560 
561 struct For : public Stmt {
ForFor562   explicit For(const TreeRef& tree) : Stmt(tree) {
563     tree->match(TK_FOR);
564   }
targetsFor565   List<Expr> targets() const {
566     return List<Expr>(subtree(0));
567   }
itrsFor568   List<Expr> itrs() const {
569     return List<Expr>(subtree(1));
570   }
bodyFor571   List<Stmt> body() const {
572     return List<Stmt>(subtree(2));
573   }
createFor574   static For create(
575       const SourceRange& range,
576       const List<Expr>& targets,
577       const List<Expr>& itrs,
578       const List<Stmt>& body) {
579     return For(Compound::create(TK_FOR, range, {targets, itrs, body}));
580   }
581 };
582 
583 // TODO: supports only single comprehension for now
584 struct ListComp : public Expr {
ListCompListComp585   explicit ListComp(const TreeRef& tree) : Expr(tree) {
586     tree->match(TK_LIST_COMP);
587   }
eltListComp588   Expr elt() const {
589     return Expr(subtree(0));
590   }
targetListComp591   Expr target() const {
592     return Expr(subtree(1));
593   }
iterListComp594   Expr iter() const {
595     return Expr(subtree(2));
596   }
597   // TODO: no ifs for now
createListComp598   static ListComp create(
599       const SourceRange& range,
600       const Expr& elt,
601       const Expr& target,
602       const Expr& iter) {
603     return ListComp(Compound::create(TK_LIST_COMP, range, {elt, target, iter}));
604   }
605 };
606 
607 // TODO: supports only single comprehension for now
608 struct DictComp : public Expr {
DictCompDictComp609   explicit DictComp(const TreeRef& tree) : Expr(tree) {
610     tree->match(TK_DICT_COMP);
611   }
keyDictComp612   Expr key() const {
613     return Expr(subtree(0));
614   }
valueDictComp615   Expr value() const {
616     return Expr(subtree(1));
617   }
targetDictComp618   Expr target() const {
619     return Expr(subtree(2));
620   }
iterDictComp621   Expr iter() const {
622     return Expr(subtree(3));
623   }
624   // TODO: no ifs for now
createDictComp625   static DictComp create(
626       const SourceRange& range,
627       const Expr& key,
628       const Expr& value,
629       const Expr& target,
630       const Expr& iter) {
631     return DictComp(
632         Compound::create(TK_DICT_COMP, range, {key, value, target, iter}));
633   }
634 };
635 
636 struct Global : public Stmt {
GlobalGlobal637   explicit Global(const TreeRef& tree) : Stmt(tree) {
638     tree_->match(TK_GLOBAL);
639   }
namesGlobal640   List<Ident> names() {
641     return List<Ident>(subtree(0));
642   }
createGlobal643   static Global create(const SourceRange& range, const List<Ident>& names) {
644     return Global(Compound::create(TK_GLOBAL, range, {names}));
645   }
646 };
647 
648 struct AugAssignKind : public TreeView {
AugAssignKindAugAssignKind649   explicit AugAssignKind(const TreeRef& tree) : TreeView(tree) {
650     switch (tree->kind()) {
651       case '+':
652       case '-':
653       case '*':
654       case '/':
655       case '%':
656       case '|':
657       case '&':
658       case '^':
659       case TK_POW:
660       case TK_LSHIFT:
661       case TK_RSHIFT:
662         return;
663       default:
664         throw(ErrorReport(tree) << "is not a valid AugAssignKind");
665     }
666   }
667 };
668 
669 // Augmented assignment, like "foo += bar"
670 struct AugAssign : public Stmt {
AugAssignAugAssign671   explicit AugAssign(const TreeRef& tree) : Stmt(tree) {
672     tree_->match(TK_AUG_ASSIGN);
673   }
createAugAssign674   static AugAssign create(
675       const SourceRange& range,
676       const Expr& lhs,
677       const AugAssignKind& aug_op,
678       const Expr& rhs) {
679     return AugAssign(
680         Compound::create(TK_AUG_ASSIGN, range, {lhs, aug_op, rhs}));
681   }
lhsAugAssign682   Expr lhs() const {
683     return Expr(subtree(0));
684   }
aug_opAugAssign685   int aug_op() const {
686     return subtree(1)->kind();
687   }
rhsAugAssign688   Expr rhs() const {
689     return Expr(subtree(2));
690   }
691 };
692 
693 struct Assign : public Stmt {
AssignAssign694   explicit Assign(const TreeRef& tree) : Stmt(tree) {
695     tree_->match(TK_ASSIGN);
696   }
createAssign697   static Assign create(
698       const SourceRange& range,
699       const List<Expr>& lhs,
700       const Maybe<Expr>& rhs,
701       const Maybe<Expr>& type) {
702     return Assign(Compound::create(TK_ASSIGN, range, {lhs, rhs, type}));
703   }
704 
lhs_listAssign705   List<Expr> lhs_list() const {
706     return List<Expr>(subtree(0));
707   }
708 
lhsAssign709   Expr lhs() const {
710     const auto& li = lhs_list();
711     TORCH_INTERNAL_ASSERT(li.size() == 1);
712     return *li.begin();
713   }
714 
rhsAssign715   Maybe<Expr> rhs() const {
716     return Maybe<Expr>(subtree(1));
717   }
718 
typeAssign719   Maybe<Expr> type() const {
720     return Maybe<Expr>(subtree(2));
721   }
722 };
723 
724 struct Return : public Stmt {
ReturnReturn725   explicit Return(const TreeRef& tree) : Stmt(tree) {
726     tree_->match(TK_RETURN);
727   }
exprReturn728   Expr expr() const {
729     return Expr(subtree(0));
730   }
createReturn731   static Return create(const SourceRange& range, const Expr& value) {
732     return Return(Compound::create(TK_RETURN, range, {value}));
733   }
734 };
735 
736 struct Raise : public Stmt {
RaiseRaise737   explicit Raise(const TreeRef& tree) : Stmt(tree) {
738     tree_->match(TK_RAISE);
739   }
exprRaise740   Expr expr() const {
741     return Expr(subtree(0));
742   }
createRaise743   static Raise create(const SourceRange& range, const Expr& expr) {
744     return Raise(Compound::create(TK_RAISE, range, {expr}));
745   }
746 };
747 
748 struct Assert : public Stmt {
AssertAssert749   explicit Assert(const TreeRef& tree) : Stmt(tree) {
750     tree_->match(TK_ASSERT);
751   }
testAssert752   Expr test() const {
753     return Expr(subtree(0));
754   }
msgAssert755   Maybe<Expr> msg() const {
756     return Maybe<Expr>(subtree(1));
757   }
createAssert758   static Assert create(
759       const SourceRange& range,
760       const Expr& test,
761       const Maybe<Expr>& msg) {
762     return Assert(Compound::create(TK_ASSERT, range, {test, msg}));
763   }
764 };
765 
766 struct Pass : public Stmt {
PassPass767   explicit Pass(const TreeRef& tree) : Stmt(tree) {
768     tree_->match(TK_PASS);
769   }
createPass770   static Pass create(const SourceRange& range) {
771     return Pass(Compound::create(TK_PASS, range, {}));
772   }
773 };
774 
775 struct Dots : public Expr {
DotsDots776   explicit Dots(const TreeRef& tree) : Expr(tree) {
777     tree_->match(TK_DOTS);
778   }
createDots779   static Dots create(const SourceRange& range) {
780     return Dots(Compound::create(TK_DOTS, range, {}));
781   }
782 };
783 
784 struct Break : public Stmt {
BreakBreak785   explicit Break(const TreeRef& tree) : Stmt(tree) {
786     tree_->match(TK_BREAK);
787   }
createBreak788   static Break create(const SourceRange& range) {
789     return Break(Compound::create(TK_BREAK, range, {}));
790   }
791 };
792 
793 struct Continue : public Stmt {
ContinueContinue794   explicit Continue(const TreeRef& tree) : Stmt(tree) {
795     tree_->match(TK_CONTINUE);
796   }
createContinue797   static Continue create(const SourceRange& range) {
798     return Continue(Compound::create(TK_CONTINUE, range, {}));
799   }
800 };
801 
802 struct ExprStmt : public Stmt {
ExprStmtExprStmt803   explicit ExprStmt(const TreeRef& tree) : Stmt(tree) {
804     tree_->match(TK_EXPR_STMT);
805   }
exprExprStmt806   Expr expr() {
807     return Expr(subtree(0));
808   }
createExprStmt809   static ExprStmt create(const SourceRange& range, const Expr& list) {
810     return ExprStmt(Compound::create(TK_EXPR_STMT, range, {list}));
811   }
812 };
813 
814 ////////////////////////////////////////////////////////////////////////////////
815 // Expressions
816 ////////////////////////////////////////////////////////////////////////////////
817 
818 struct BinOp : public Expr {
BinOpBinOp819   explicit BinOp(const TreeRef& tree) : Expr(tree) {
820     switch (tree->kind()) {
821       case TK_AND:
822       case TK_OR:
823       case '<':
824       case '>':
825       case TK_IS:
826       case TK_ISNOT:
827       case TK_EQ:
828       case TK_LE:
829       case TK_GE:
830       case TK_NE:
831       case '+':
832       case '*':
833       case '/':
834       case '-':
835       case '@':
836       case TK_POW:
837       case TK_LSHIFT:
838       case TK_RSHIFT:
839       case '%':
840       case '&':
841       case '^':
842       case '|':
843       case TK_FLOOR_DIV:
844       case TK_IN:
845         if (tree->trees().size() != 2)
846           throw(
847               ErrorReport(tree)
848               << "BinOp expected 2 subtrees, found " << tree->trees().size());
849         return;
850       default:
851         throw(
852             ErrorReport(tree)
853             << kindToString(tree->kind()) << " is not a valid BinOp");
854     }
855   }
lhsBinOp856   Expr lhs() const {
857     return Expr(subtree(0));
858   }
rhsBinOp859   Expr rhs() const {
860     return Expr(subtree(1));
861   }
createBinOp862   static BinOp create(
863       const SourceRange& range,
864       int kind,
865       const Expr& lhs,
866       const Expr& rhs) {
867     return BinOp(Compound::create(kind, range, {lhs, rhs}));
868   }
869 };
870 
871 struct UnaryOp : public Expr {
UnaryOpUnaryOp872   explicit UnaryOp(const TreeRef& tree) : Expr(tree) {
873     switch (tree->kind()) {
874       case TK_UNARY_MINUS:
875       case '~':
876       case TK_NOT:
877         if (tree->trees().size() != 1)
878           throw(
879               ErrorReport(tree)
880               << "UnaryOp expected 1 subtree, found " << tree->trees().size());
881         return;
882       default:
883         throw(
884             ErrorReport(tree)
885             << kindToString(tree->kind()) << " is not a valid UnaryOp");
886     }
887   }
createUnaryOp888   static UnaryOp create(const SourceRange& range, int kind, const Expr& expr) {
889     return UnaryOp(Compound::create(kind, range, {expr}));
890   }
891 };
892 
893 struct Const : public Expr {
ConstConst894   explicit Const(const TreeRef& tree) : Expr(tree) {
895     tree_->matchNumSubtrees(TK_CONST, 1);
896   }
isFloatingPointConst897   bool isFloatingPoint() const {
898     if (isComplex())
899       return false;
900 
901     bool is_inf = subtree(0)->stringValue() == "inf";
902     return is_inf ||
903         subtree(0)->stringValue().find_first_of(".eE") != std::string::npos;
904   }
isIntegralConst905   bool isIntegral() const {
906     return !isFloatingPoint() && !isComplex();
907   }
isComplexConst908   bool isComplex() const {
909     return subtree(0)->stringValue().find_first_of('j') != std::string::npos;
910   }
asIntegralConst911   int64_t asIntegral() const {
912     try {
913       return std::stoll(subtree(0)->stringValue(), nullptr, 0);
914     } catch (const std::out_of_range&) {
915       throw(
916           ErrorReport(range()) << "Integral constant out of range "
917                                   "(must fit in a signed 64 bit integer)");
918     }
919   }
asFloatingPointConst920   double asFloatingPoint() const {
921     // We can't pass in nullptr as the dummy pointer gets dereferenced for
922     // Android version of strtod_c().
923     char* dummy = nullptr;
924     return torch::jit::strtod_c(subtree(0)->stringValue().c_str(), &dummy);
925   }
asComplexConst926   c10::complex<double> asComplex() const {
927     char* dummy = nullptr;
928     auto str = subtree(0)->stringValue();
929     // Complex numbers (a+bj, where a is non-zero) are parsed as an addition
930     // between float/int a and a complex number "bj". When a is 0, a complex
931     // number bj is created as above. So, while parsing the string, we don't
932     // have to worry about the real component of the complex number.
933     auto imag =
934         torch::jit::strtod_c(str.substr(0, str.size() - 1).c_str(), &dummy);
935     return c10::complex<double>(0, imag);
936   }
textConst937   const std::string& text() const {
938     return subtree(0)->stringValue();
939   }
createConst940   static Const create(const SourceRange& range, const std::string& value) {
941     return Const(Compound::create(TK_CONST, range, {String::create(value)}));
942   }
943 };
944 
945 struct StringLiteral : public Expr {
StringLiteralStringLiteral946   explicit StringLiteral(const TreeRef& tree) : Expr(tree) {
947     tree_->matchNumSubtrees(TK_STRINGLITERAL, 1);
948   }
textStringLiteral949   const std::string& text() const {
950     return subtree(0)->stringValue();
951   }
createStringLiteral952   static StringLiteral create(
953       const SourceRange& range,
954       const std::string& value) {
955     return StringLiteral(
956         Compound::create(TK_STRINGLITERAL, range, {String::create(value)}));
957   }
958 };
959 
960 struct Apply : public Expr {
ApplyApply961   explicit Apply(const TreeRef& tree) : Expr(tree) {
962     tree_->match(TK_APPLY);
963   }
calleeApply964   Expr callee() const {
965     return Expr(subtree(0));
966   }
inputsApply967   List<Expr> inputs() const {
968     return List<Expr>(subtree(1));
969   }
attributesApply970   List<Attribute> attributes() const {
971     return List<Attribute>(subtree(2));
972   }
createApply973   static Apply create(
974       const SourceRange& range,
975       const Expr& callee,
976       const List<Expr>& inputs,
977       const List<Attribute>& attributes) {
978     return Apply(
979         Compound::create(TK_APPLY, range, {callee, inputs, attributes}));
980   }
981 };
982 
983 struct Select : public Expr {
SelectSelect984   explicit Select(const TreeRef& tree) : Expr(tree) {
985     tree_->match('.');
986   }
valueSelect987   Expr value() const {
988     return Expr(subtree(0));
989   }
selectorSelect990   Ident selector() const {
991     return Ident(subtree(1));
992   }
createSelect993   static Select create(
994       const SourceRange& range,
995       const Expr& value,
996       const Ident& selector) {
997     return Select(Compound::create('.', range, {value, selector}));
998   }
999 };
1000 
1001 struct SliceExpr : public Expr {
SliceExprSliceExpr1002   explicit SliceExpr(const TreeRef& tree) : Expr(tree) {
1003     tree_->match(TK_SLICE_EXPR);
1004   }
startSliceExpr1005   Maybe<Expr> start() const {
1006     return Maybe<Expr>(subtree(0));
1007   }
endSliceExpr1008   Maybe<Expr> end() const {
1009     return Maybe<Expr>(subtree(1));
1010   }
stepSliceExpr1011   Maybe<Expr> step() const {
1012     return Maybe<Expr>(subtree(2));
1013   }
startOrSliceExpr1014   Expr startOr(int64_t alternative) const {
1015     const auto startOption = start();
1016     return startOption.present() ? startOption.get() : createInt(alternative);
1017   }
endOrSliceExpr1018   Expr endOr(int64_t alternative) const {
1019     const auto endOption = end();
1020     return endOption.present() ? endOption.get() : createInt(alternative);
1021   }
stepOrSliceExpr1022   Expr stepOr(int64_t alternative) const {
1023     const auto stepOption = step();
1024     return stepOption.present() ? stepOption.get() : createInt(alternative);
1025   }
createSliceExpr1026   static SliceExpr create(
1027       const SourceRange& range,
1028       const Maybe<Expr>& start,
1029       const Maybe<Expr>& end,
1030       const Maybe<Expr>& step) {
1031     return SliceExpr(
1032         Compound::create(TK_SLICE_EXPR, range, {start, end, step}));
1033   }
1034 
1035  private:
createIntSliceExpr1036   Expr createInt(int64_t value) const {
1037     return Expr(Const::create(range(), std::to_string(value)));
1038   }
1039 };
1040 
1041 struct Subscript : public Expr {
SubscriptSubscript1042   explicit Subscript(const TreeRef& tree) : Expr(tree) {
1043     tree_->match(TK_SUBSCRIPT);
1044   }
valueSubscript1045   Expr value() const {
1046     return Expr(subtree(0));
1047   }
subscript_exprsSubscript1048   List<Expr> subscript_exprs() const {
1049     return List<Expr>(subtree(1));
1050   }
createSubscript1051   static Subscript create(
1052       const SourceRange& range,
1053       const Expr& value,
1054       const List<Expr>& subscript_exprs) {
1055     auto whole_range = SourceRange(
1056         range.source(), range.start(), subscript_exprs.range().end() + 1);
1057     return Subscript(
1058         Compound::create(TK_SUBSCRIPT, whole_range, {value, subscript_exprs}));
1059   }
1060 };
1061 
1062 struct Var : public Expr {
VarVar1063   explicit Var(const TreeRef& tree) : Expr(tree) {
1064     tree_->match(TK_VAR);
1065   };
nameVar1066   Ident name() const {
1067     return Ident(subtree(0));
1068   }
createVar1069   static Var create(const SourceRange& range, const Ident& name) {
1070     return Var(Compound::create(TK_VAR, range, {name}));
1071   }
1072 };
1073 
1074 // WithItem represents an item using with a WithStmt.
1075 struct WithItem : public Expr {
WithItemWithItem1076   explicit WithItem(const TreeRef& tree) : Expr(tree) {
1077     tree_->match(TK_WITH_ITEM);
1078   }
1079 
targetWithItem1080   Expr target() const {
1081     return Expr(subtree(0));
1082   }
1083 
varWithItem1084   Maybe<Var> var() const {
1085     return Maybe<Var>(subtree(1));
1086   }
1087 
createWithItem1088   static WithItem create(
1089       const SourceRange& range,
1090       const Expr& target,
1091       const Maybe<Var>& var) {
1092     return WithItem(Compound::create(TK_WITH_ITEM, range, {target, var}));
1093   }
1094 };
1095 
1096 // With represents a with statement consisting of a list of with items and a
1097 // body of statements.
1098 struct With : public Stmt {
WithWith1099   explicit With(const TreeRef& tree) : Stmt(tree) {
1100     tree_->match(TK_WITH);
1101   }
1102 
targetsWith1103   List<WithItem> targets() const {
1104     return List<WithItem>(subtree(0));
1105   }
1106 
bodyWith1107   List<Stmt> body() const {
1108     return List<Stmt>(subtree(1));
1109   }
1110 
createWith1111   static With create(
1112       const SourceRange& range,
1113       const List<WithItem>& targets,
1114       const List<Stmt>& body) {
1115     return With(Compound::create(TK_WITH, range, {targets, body}));
1116   }
1117 };
1118 
1119 struct TernaryIf : public Expr {
TernaryIfTernaryIf1120   explicit TernaryIf(const TreeRef& tree) : Expr(tree) {
1121     tree_->matchNumSubtrees(TK_IF_EXPR, 3);
1122   };
condTernaryIf1123   Expr cond() const {
1124     return Expr(subtree(0));
1125   }
true_exprTernaryIf1126   Expr true_expr() const {
1127     return Expr(subtree(1));
1128   }
false_exprTernaryIf1129   Expr false_expr() const {
1130     return Expr(subtree(2));
1131   }
createTernaryIf1132   static TernaryIf create(
1133       const SourceRange& range,
1134       const Expr& cond,
1135       const Expr& true_expr,
1136       const Expr& false_expr) {
1137     return TernaryIf(
1138         Compound::create(TK_IF_EXPR, range, {cond, true_expr, false_expr}));
1139   };
1140 };
1141 
1142 struct ListLiteral : public Expr {
ListLiteralListLiteral1143   explicit ListLiteral(const TreeRef& tree) : Expr(tree) {
1144     tree_->match(TK_LIST_LITERAL);
1145   }
inputsListLiteral1146   List<Expr> inputs() const {
1147     return subtree(0);
1148   }
createListLiteral1149   static ListLiteral create(
1150       const SourceRange& range,
1151       const List<Expr>& inputs) {
1152     return ListLiteral(Compound::create(TK_LIST_LITERAL, range, {inputs}));
1153   }
1154 };
1155 
1156 struct TupleLiteral : public Expr {
TupleLiteralTupleLiteral1157   explicit TupleLiteral(const TreeRef& tree) : Expr(tree) {
1158     tree_->match(TK_TUPLE_LITERAL);
1159   }
inputsTupleLiteral1160   List<Expr> inputs() const {
1161     return subtree(0);
1162   }
createTupleLiteral1163   static TupleLiteral create(
1164       const SourceRange& range,
1165       const List<Expr>& inputs) {
1166     return TupleLiteral(Compound::create(TK_TUPLE_LITERAL, range, {inputs}));
1167   }
1168 };
1169 
1170 struct DictLiteral : public Expr {
DictLiteralDictLiteral1171   explicit DictLiteral(const TreeRef& tree) : Expr(tree) {
1172     tree_->match(TK_DICT_LITERAL);
1173   }
key_inputsDictLiteral1174   List<Expr> key_inputs() const {
1175     return subtree(0);
1176   }
value_inputsDictLiteral1177   List<Expr> value_inputs() const {
1178     return subtree(1);
1179   }
createDictLiteral1180   static DictLiteral create(
1181       const SourceRange& range,
1182       const List<Expr>& keys,
1183       const List<Expr>& values) {
1184     return DictLiteral(
1185         Compound::create(TK_DICT_LITERAL, range, {keys, values}));
1186   }
1187 };
1188 
1189 struct Starred : public Expr {
StarredStarred1190   explicit Starred(const TreeRef& tree) : Expr(tree) {
1191     tree_->match(TK_STARRED);
1192   }
exprStarred1193   Expr expr() const {
1194     return Expr(subtree(0));
1195   }
createStarred1196   static Starred create(const SourceRange& range, const Expr& expr) {
1197     return Starred(Compound::create(TK_STARRED, range, {expr}));
1198   }
1199 };
1200 
1201 struct Delete : public Stmt {
DeleteDelete1202   explicit Delete(const TreeRef& tree) : Stmt(tree) {
1203     tree_->match(TK_DELETE);
1204   }
targetsDelete1205   List<Expr> targets() const {
1206     return subtree(0);
1207   }
createDelete1208   static Delete create(const SourceRange& range, const List<Expr>& targets) {
1209     return Delete(Compound::create(TK_DELETE, range, {targets}));
1210   }
1211 };
1212 
1213 /*
1214  * NOTE: transforming PEP 604 union into equivalent union type
1215  *
1216  * NOTE: Union[int, float] parses into:
1217  * <EXPR> expr:(subscript
1218  *  (variable (ident Union))
1219  *  (list
1220  *    (variable (ident int))
1221  *    (variable (ident float))))
1222  * <KIND> subscript
1223  *
1224  * NOTE: (int | float) parses into:
1225  * <EXPR> expr:(|
1226  *  (variable (ident int))
1227  *  (variable (ident float)))
1228  * <KIND> |
1229  */
1230 
_flatten_pep604_union(const torch::jit::Expr & node,std::vector<torch::jit::Expr> * result)1231 inline void _flatten_pep604_union(
1232     const torch::jit::Expr& node,
1233     std::vector<torch::jit::Expr>* result) {
1234   // flatten possibly nested union expressions like (int | (float | str))
1235   // into a flat list of expressions like [int, float, str]
1236   if (node.kind() == '|') {
1237     auto as_binop = torch::jit::BinOp(node);
1238     _flatten_pep604_union(as_binop.lhs(), result);
1239     _flatten_pep604_union(as_binop.rhs(), result);
1240   } else {
1241     result->push_back(node);
1242   }
1243 }
1244 
get_pep604_union_members(const Expr & node)1245 inline std::vector<Expr> get_pep604_union_members(const Expr& node) {
1246   std::vector<Expr> result;
1247   _flatten_pep604_union(node, &result);
1248   return result;
1249 }
1250 
1251 // Flattens a PEP 604 union into a classical union.
1252 // For example, ((x | y) | z) is transformed into Union[x, y, z].
pep604union_to_union(const Expr & expr)1253 inline Expr pep604union_to_union(const Expr& expr) {
1254   // noop if not a pep604 union
1255   if (expr.kind() != '|')
1256     return expr;
1257 
1258   // In order to support unions with more than 2 operands ((x|y)|z), we need to
1259   // recursively flatten the tree of | expressions.
1260   auto members = get_pep604_union_members(expr);
1261   auto synthesised_union = Subscript::create(
1262       expr.range(),
1263       Var::create(expr.range(), Ident::create(expr.range(), "Union")),
1264       List<Expr>::create(expr.range(), members));
1265   return std::move(synthesised_union);
1266 }
1267 
1268 } // namespace torch::jit
1269 
1270 namespace std {
1271 
1272 template <typename T>
1273 struct iterator_traits<torch::jit::ListIterator<T>>
1274     : std::iterator_traits<torch::jit::TreeList::const_iterator> {};
1275 
1276 } // namespace std
1277