xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/parser.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/frontend/parser.h>
2 
3 #include <torch/csrc/jit/frontend/lexer.h>
4 #include <torch/csrc/jit/frontend/parse_string_literal.h>
5 #include <torch/csrc/jit/frontend/tree.h>
6 #include <torch/csrc/jit/frontend/tree_views.h>
7 #include <optional>
8 
9 namespace torch::jit {
10 
// Merges the types from a function type comment into the signature they
// annotate.
//
// `decl` supplies parameter names and source ranges; `type_annotation_decl`
// (as produced by parseTypeComment) supplies one type per parameter plus the
// return type. When `is_method` is true and `decl` has at least one parameter,
// that first parameter is the implicit `self`: it has no entry in the type
// comment and is copied through unchanged.
//
// Throws ErrorReport (at decl's range) when the number of annotated types does
// not equal the number of annotatable parameters.
Decl mergeTypesFromTypeComment(
    const Decl& decl,
    const Decl& type_annotation_decl,
    bool is_method) {
  const size_t num_params = decl.params().size();
  // Only skip `self` if there is a parameter to skip. Unconditionally
  // subtracting 1 wraps the unsigned count around for a method declared with
  // no parameters, which made the check below always fail with a bogus count.
  const size_t num_implicit = (is_method && num_params > 0) ? 1 : 0;
  const size_t expected_num_annotations = num_params - num_implicit;
  if (expected_num_annotations != type_annotation_decl.params().size()) {
    throw ErrorReport(decl.range())
        << "Number of type annotations ("
        << type_annotation_decl.params().size()
        << ") did not match the number of "
        << (is_method ? "method" : "function") << " parameters ("
        << expected_num_annotations << ")";
  }
  auto old = decl.params();
  auto annotated = type_annotation_decl.params();

  // Merge signature idents and ranges with annotation types.
  std::vector<Param> new_params;
  new_params.reserve(num_params);
  if (num_implicit == 1) {
    new_params.push_back(old[0]);
  }
  for (size_t i = num_implicit; i < num_params; ++i) {
    new_params.emplace_back(
        old[i].withType(annotated[i - num_implicit].type()));
  }
  return Decl::create(
      decl.range(),
      List<Param>::create(decl.range(), new_params),
      type_annotation_decl.return_type());
}
46 
47 struct ParserImpl {
ParserImpltorch::jit::ParserImpl48   explicit ParserImpl(const std::shared_ptr<Source>& source)
49       : L(source), shared(sharedParserData()) {}
50 
parseIdenttorch::jit::ParserImpl51   Ident parseIdent() {
52     auto t = L.expect(TK_IDENT);
53     // whenever we parse something that has a TreeView type we always
54     // use its create method so that the accessors and the constructor
55     // of the Compound tree are in the same place.
56     return Ident::create(t.range, t.text());
57   }
createApplytorch::jit::ParserImpl58   TreeRef createApply(const Expr& expr) {
59     TreeList attributes;
60     auto range = L.cur().range;
61     TreeList inputs;
62     parseArguments(inputs, attributes);
63     return Apply::create(
64         range,
65         expr,
66         List<Expr>(makeList(range, std::move(inputs))),
67         List<Attribute>(makeList(range, std::move(attributes))));
68   }
69 
followsTupletorch::jit::ParserImpl70   static bool followsTuple(int kind) {
71     switch (kind) {
72       case TK_PLUS_EQ:
73       case TK_MINUS_EQ:
74       case TK_TIMES_EQ:
75       case TK_DIV_EQ:
76       case TK_MOD_EQ:
77       case TK_BIT_OR_EQ:
78       case TK_BIT_AND_EQ:
79       case TK_BIT_XOR_EQ:
80       case TK_LSHIFT_EQ:
81       case TK_RSHIFT_EQ:
82       case TK_POW_EQ:
83       case TK_NEWLINE:
84       case '=':
85       case ')':
86         return true;
87       default:
88         return false;
89     }
90   }
91 
92   // exp | expr, | expr, expr, ...
parseExpOrExpTupletorch::jit::ParserImpl93   Expr parseExpOrExpTuple() {
94     auto prefix = parseExp();
95     if (L.cur().kind == ',') {
96       std::vector<Expr> exprs = {prefix};
97       while (L.nextIf(',')) {
98         if (followsTuple(L.cur().kind))
99           break;
100         exprs.push_back(parseExp());
101       }
102       auto list = List<Expr>::create(prefix.range(), exprs);
103       prefix = TupleLiteral::create(list.range(), list);
104     }
105     return prefix;
106   }
107   // things like a 1.0 or a(4) that are not unary/binary expressions
108   // and have higher precedence than all of them
parseBaseExptorch::jit::ParserImpl109   TreeRef parseBaseExp() {
110     TreeRef prefix;
111     switch (L.cur().kind) {
112       case TK_NUMBER: {
113         prefix = parseConst();
114       } break;
115       case TK_TRUE:
116       case TK_FALSE:
117       case TK_NONE:
118       case TK_NONE_TYPE: {
119         auto k = L.cur().kind;
120         auto r = L.cur().range;
121         prefix = create_compound(k, r, {});
122         L.next();
123       } break;
124       case '(': {
125         L.next();
126         if (L.nextIf(')')) {
127           /// here we have the empty tuple case
128           std::vector<Expr> vecExpr;
129           List<Expr> listExpr = List<Expr>::create(L.cur().range, vecExpr);
130           prefix = TupleLiteral::create(L.cur().range, listExpr);
131           break;
132         }
133         prefix = parseExpOrExpTuple();
134         L.expect(')');
135       } break;
136       case '[': {
137         auto list = parseList('[', ',', ']', &ParserImpl::parseExp);
138 
139         if (list.size() == 1 && (*list.begin()).kind() == TK_LIST_COMP) {
140           prefix = *list.begin();
141         } else {
142           for (auto se : list) {
143             if (se.kind() == TK_LIST_COMP) {
144               throw ErrorReport(list.range())
145                   << " expected a single list comprehension within '[' , ']'";
146             }
147           }
148           prefix = ListLiteral::create(list.range(), List<Expr>(list));
149         }
150 
151       } break;
152       case '{': {
153         L.next();
154         // If we have a dict literal, `keys` and `values` will store the keys
155         // and values used in the object's construction. EDGE CASE: We have a
156         // dict comprehension, so we'll get the first element of the dict
157         // comprehension in `keys` and a list comprehension in `values`.
158         // For example, `{i : chr(i + 65) for i in range(4)}` would give us
159         // `i` in `keys` and `chr(i + 65) for i in range(4)` in `values`.
160         // The optimal way of handling this case is to simply splice the new
161         // dict comprehension together from the existing list comprehension.
162         // Splicing prevents breaking changes to our API and does not require
163         // the use of global variables.
164         std::vector<Expr> keys;
165         std::vector<Expr> values;
166         auto range = L.cur().range;
167         if (L.cur().kind != '}') {
168           do {
169             keys.push_back(parseExp());
170             L.expect(':');
171             values.push_back(parseExp());
172           } while (L.nextIf(','));
173         }
174         L.expect('}');
175         if (keys.size() == 1 && (*values.begin()).kind() == TK_LIST_COMP) {
176           ListComp lc(*values.begin());
177           prefix = DictComp::create(
178               range, *keys.begin(), lc.elt(), lc.target(), lc.iter());
179         } else {
180           prefix = DictLiteral::create(
181               range,
182               List<Expr>::create(range, keys),
183               List<Expr>::create(range, values));
184         }
185       } break;
186       case TK_STRINGLITERAL: {
187         prefix = parseConcatenatedStringLiterals();
188       } break;
189       case TK_ELLIPSIS:
190       case TK_DOTS: {
191         prefix = Dots::create(L.cur().range);
192         L.next();
193       } break;
194       default: {
195         Ident name = parseIdent();
196         prefix = Var::create(name.range(), name);
197       } break;
198     }
199     while (true) {
200       if (L.nextIf('.')) {
201         const auto name = parseIdent();
202         prefix = Select::create(name.range(), Expr(prefix), Ident(name));
203       } else if (L.cur().kind == '(') {
204         prefix = createApply(Expr(prefix));
205       } else if (L.cur().kind == '[') {
206         prefix = parseSubscript(prefix);
207       } else {
208         break;
209       }
210     }
211     return prefix;
212   }
maybeParseAssignmentOptorch::jit::ParserImpl213   std::optional<TreeRef> maybeParseAssignmentOp() {
214     auto r = L.cur().range;
215     switch (L.cur().kind) {
216       case TK_PLUS_EQ:
217       case TK_MINUS_EQ:
218       case TK_TIMES_EQ:
219       case TK_DIV_EQ:
220       case TK_BIT_OR_EQ:
221       case TK_BIT_AND_EQ:
222       case TK_BIT_XOR_EQ:
223       case TK_MOD_EQ: {
224         int modifier = L.next().text()[0];
225         return create_compound(modifier, r, {});
226       } break;
227       case TK_LSHIFT_EQ: {
228         L.next();
229         return create_compound(TK_LSHIFT, r, {});
230       } break;
231       case TK_RSHIFT_EQ: {
232         L.next();
233         return create_compound(TK_RSHIFT, r, {});
234       } break;
235       case TK_POW_EQ: {
236         L.next();
237         return create_compound(TK_POW, r, {});
238       } break;
239       case '=': {
240         L.next();
241         return create_compound('=', r, {}); // no reduction
242       } break;
243       default:
244         return std::nullopt;
245     }
246   }
parseTrinarytorch::jit::ParserImpl247   TreeRef parseTrinary(
248       TreeRef true_branch,
249       const SourceRange& range,
250       int binary_prec) {
251     auto cond = parseExp();
252     L.expect(TK_ELSE);
253     auto false_branch = parseExp(binary_prec);
254     return create_compound(
255         TK_IF_EXPR, range, {cond, std::move(true_branch), false_branch});
256   }
257   // parse the longest expression whose binary operators have
258   // precedence strictly greater than 'precedence'
259   // precedence == 0 will parse _all_ expressions
260   // this is the core loop of 'top-down precedence parsing'
parseExptorch::jit::ParserImpl261   Expr parseExp() {
262     return parseExp(0);
263   }
parseExptorch::jit::ParserImpl264   Expr parseExp(int precedence) {
265     TreeRef prefix;
266     int unary_prec = 0;
267     if (shared.isUnary(L.cur().kind, &unary_prec)) {
268       auto kind = L.cur().kind;
269       auto pos = L.cur().range;
270       L.next();
271       auto unary_kind = kind == '*' ? TK_STARRED
272           : kind == '-'             ? TK_UNARY_MINUS
273                                     : kind;
274       auto subexp = parseExp(unary_prec);
275       // fold '-' into constant numbers, so that attributes can accept
276       // things like -1
277       if (unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) {
278         prefix = Const::create(subexp.range(), "-" + Const(subexp).text());
279       } else {
280         prefix = create_compound(unary_kind, pos, {subexp});
281       }
282     } else {
283       prefix = parseBaseExp();
284     }
285     int binary_prec = 0;
286     while (shared.isBinary(L.cur().kind, &binary_prec)) {
287       if (binary_prec <= precedence) // not allowed to parse something which is
288         // not greater than 'precedence'
289         break;
290 
291       int kind = L.cur().kind;
292       auto pos = L.cur().range;
293       L.next();
294       if (shared.isRightAssociative(kind))
295         binary_prec--;
296 
297       if (kind == TK_NOTIN) {
298         // NB: `not in` is just `not( in )`, so we don't introduce new tree view
299         // but just make it a nested call in our tree view structure
300         prefix = create_compound(TK_IN, pos, {prefix, parseExp(binary_prec)});
301         prefix = create_compound(TK_NOT, pos, {prefix});
302         continue;
303       }
304 
305       // special case for trinary operator
306       if (kind == TK_IF) {
307         prefix = parseTrinary(prefix, pos, binary_prec);
308         continue;
309       }
310 
311       if (kind == TK_FOR) {
312         // TK_FOR targets should only parse exprs prec greater than 4, which
313         // only includes subset of Exprs that suppose to be on the LHS according
314         // to the python grammar
315         // https://docs.python.org/3/reference/grammar.html
316         auto target = parseLHSExp();
317         L.expect(TK_IN);
318         auto iter = parseExp();
319         prefix = ListComp::create(pos, Expr(prefix), target, iter);
320         continue;
321       }
322 
323       prefix = create_compound(kind, pos, {prefix, parseExp(binary_prec)});
324     }
325     return Expr(prefix);
326   }
327 
parseSequencetorch::jit::ParserImpl328   void parseSequence(
329       int begin,
330       int sep,
331       int end,
332       const std::function<void()>& parse) {
333     if (begin != TK_NOTHING) {
334       L.expect(begin);
335     }
336     while (end != L.cur().kind) {
337       parse();
338       if (!L.nextIf(sep)) {
339         if (end != TK_NOTHING) {
340           L.expect(end);
341         }
342         return;
343       }
344     }
345     L.expect(end);
346   }
347   template <typename T>
parseListtorch::jit::ParserImpl348   List<T> parseList(int begin, int sep, int end, T (ParserImpl::*parse)()) {
349     auto r = L.cur().range;
350     std::vector<T> elements;
351     parseSequence(
352         begin, sep, end, [&] { elements.emplace_back((this->*parse)()); });
353     return List<T>::create(r, elements);
354   }
355 
parseConsttorch::jit::ParserImpl356   Const parseConst() {
357     auto range = L.cur().range;
358     auto t = L.expect(TK_NUMBER);
359     return Const::create(t.range, t.text());
360   }
361 
parseConcatenatedStringLiteralstorch::jit::ParserImpl362   StringLiteral parseConcatenatedStringLiterals() {
363     auto range = L.cur().range;
364     std::string ss;
365     while (L.cur().kind == TK_STRINGLITERAL) {
366       auto literal_range = L.cur().range;
367       ss.append(parseStringLiteral(literal_range, L.next().text()));
368     }
369     return StringLiteral::create(range, ss);
370   }
371 
parseAttributeValuetorch::jit::ParserImpl372   Expr parseAttributeValue() {
373     return parseExp();
374   }
375 
parseArgumentstorch::jit::ParserImpl376   void parseArguments(TreeList& inputs, TreeList& attributes) {
377     parseSequence('(', ',', ')', [&] {
378       if (L.cur().kind == TK_IDENT && L.lookahead().kind == '=') {
379         auto ident = parseIdent();
380         L.expect('=');
381         auto v = parseAttributeValue();
382         attributes.push_back(Attribute::create(ident.range(), Ident(ident), v));
383       } else {
384         inputs.push_back(parseExp());
385       }
386     });
387   }
388 
389   // parse LHS acceptable exprs, which only includes subset of Exprs that prec
390   // is greater than 4 according to the python grammar
parseLHSExptorch::jit::ParserImpl391   Expr parseLHSExp() {
392     return parseExp(4);
393   }
394 
395   // Parse expr's of the form [a:], [:b], [a:b], [:] and all variations with
396   // "::"
parseSubscriptExptorch::jit::ParserImpl397   Expr parseSubscriptExp() {
398     TreeRef first, second, third;
399     auto range = L.cur().range;
400     if (L.cur().kind != ':') {
401       first = parseExp();
402     }
403     if (L.nextIf(':')) {
404       if (L.cur().kind != ',' && L.cur().kind != ']' && L.cur().kind != ':') {
405         second = parseExp();
406       }
407       if (L.nextIf(':')) {
408         if (L.cur().kind != ',' && L.cur().kind != ']') {
409           third = parseExp();
410         }
411       }
412       auto maybe_first = first ? Maybe<Expr>::create(range, Expr(first))
413                                : Maybe<Expr>::create(range);
414       auto maybe_second = second ? Maybe<Expr>::create(range, Expr(second))
415                                  : Maybe<Expr>::create(range);
416       auto maybe_third = third ? Maybe<Expr>::create(range, Expr(third))
417                                : Maybe<Expr>::create(range);
418       return SliceExpr::create(range, maybe_first, maybe_second, maybe_third);
419     } else {
420       return Expr(first);
421     }
422   }
423 
parseSubscripttorch::jit::ParserImpl424   TreeRef parseSubscript(const TreeRef& value) {
425     const auto range = L.cur().range;
426 
427     auto subscript_exprs =
428         parseList('[', ',', ']', &ParserImpl::parseSubscriptExp);
429 
430     const auto whole_range =
431         SourceRange(range.source(), range.start(), L.cur().range.start());
432     return Subscript::create(whole_range, Expr(value), subscript_exprs);
433   }
434 
maybeParseTypeAnnotationtorch::jit::ParserImpl435   Maybe<Expr> maybeParseTypeAnnotation() {
436     if (L.nextIf(':')) {
437       // NB: parseExp must not be called inline, since argument evaluation order
438       // changes when L.cur().range is mutated with respect to the parseExp()
439       // call.
440       auto expr = parseExp();
441       return Maybe<Expr>::create(expr.range(), expr);
442     } else {
443       return Maybe<Expr>::create(L.cur().range);
444     }
445   }
446 
parseFormalParamtorch::jit::ParserImpl447   TreeRef parseFormalParam(bool kwarg_only) {
448     auto ident = parseIdent();
449     TreeRef type = maybeParseTypeAnnotation();
450     TreeRef def;
451     if (L.nextIf('=')) {
452       // NB: parseExp must not be called inline, since argument evaluation order
453       // changes when L.cur().range is mutated with respect to the parseExp()
454       // call.
455       auto expr = parseExp();
456       def = Maybe<Expr>::create(expr.range(), expr);
457     } else {
458       def = Maybe<Expr>::create(L.cur().range);
459     }
460     return Param::create(
461         type->range(),
462         Ident(ident),
463         Maybe<Expr>(type),
464         Maybe<Expr>(def),
465         kwarg_only);
466   }
467 
parseBareTypeAnnotationtorch::jit::ParserImpl468   Param parseBareTypeAnnotation() {
469     auto type = parseExp();
470     return Param::create(
471         type.range(),
472         Ident::create(type.range(), ""),
473         Maybe<Expr>::create(type.range(), type),
474         Maybe<Expr>::create(type.range()),
475         /*kwarg_only=*/false);
476   }
477 
parseTypeCommenttorch::jit::ParserImpl478   Decl parseTypeComment() {
479     auto range = L.cur().range;
480     L.expect(TK_TYPE_COMMENT);
481     auto param_types =
482         parseList('(', ',', ')', &ParserImpl::parseBareTypeAnnotation);
483     TreeRef return_type;
484     if (L.nextIf(TK_ARROW)) {
485       auto return_type_range = L.cur().range;
486       return_type = Maybe<Expr>::create(return_type_range, parseExp());
487     } else {
488       return_type = Maybe<Expr>::create(L.cur().range);
489     }
490     return Decl::create(range, param_types, Maybe<Expr>(return_type));
491   }
492 
493   // 'first' has already been parsed since expressions can exist
494   // alone on a line:
495   // first[,other,lhs] = rhs
parseAssigntorch::jit::ParserImpl496   TreeRef parseAssign(const Expr& lhs) {
497     auto type = maybeParseTypeAnnotation();
498     auto maybeOp = maybeParseAssignmentOp();
499     if (maybeOp) {
500       // There is an assignment operator, parse the RHS and generate the
501       // assignment.
502       auto rhs = parseExpOrExpTuple();
503       if (maybeOp.value()->kind() == '=') {
504         std::vector<Expr> lhs_list = {lhs};
505         while (L.nextIf('=')) {
506           lhs_list.push_back(rhs);
507           rhs = parseExpOrExpTuple();
508         }
509         if (type.present() && lhs_list.size() > 1) {
510           throw ErrorReport(type.range())
511               << "Annotated multiple assignment is not supported in python";
512         }
513         L.expect(TK_NEWLINE);
514         return Assign::create(
515             lhs.range(),
516             List<Expr>::create(lhs_list[0].range(), lhs_list),
517             Maybe<Expr>::create(rhs.range(), rhs),
518             type);
519       } else {
520         L.expect(TK_NEWLINE);
521         // this is an augmented assignment
522         if (lhs.kind() == TK_TUPLE_LITERAL) {
523           throw ErrorReport(lhs.range())
524               << " augmented assignment can only have one LHS expression";
525         }
526         return AugAssign::create(
527             lhs.range(), lhs, AugAssignKind(*maybeOp), Expr(rhs));
528       }
529     } else {
530       // There is no assignment operator, so this is of the form `lhs : <type>`
531       TORCH_INTERNAL_ASSERT(type.present());
532       L.expect(TK_NEWLINE);
533       return Assign::create(
534           lhs.range(),
535           List<Expr>::create(lhs.range(), {lhs}),
536           Maybe<Expr>::create(lhs.range()),
537           type);
538     }
539   }
540 
parseStmttorch::jit::ParserImpl541   TreeRef parseStmt(bool in_class = false) {
542     switch (L.cur().kind) {
543       case TK_IF:
544         return parseIf();
545       case TK_WHILE:
546         return parseWhile();
547       case TK_FOR:
548         return parseFor();
549       case TK_GLOBAL: {
550         auto range = L.next().range;
551         auto idents =
552             parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseIdent);
553         L.expect(TK_NEWLINE);
554         return Global::create(range, idents);
555       }
556       case TK_RETURN: {
557         auto range = L.next().range;
558         Expr value = L.cur().kind != TK_NEWLINE
559             ? parseExpOrExpTuple()
560             : Expr(create_compound(TK_NONE, range, {}));
561         L.expect(TK_NEWLINE);
562         return Return::create(range, value);
563       }
564       case TK_RAISE: {
565         auto range = L.next().range;
566         auto expr = parseExp();
567         L.expect(TK_NEWLINE);
568         return Raise::create(range, expr);
569       }
570       case TK_ASSERT: {
571         auto range = L.next().range;
572         auto cond = parseExp();
573         Maybe<Expr> maybe_first = Maybe<Expr>::create(range);
574         if (L.nextIf(',')) {
575           auto msg = parseExp();
576           maybe_first = Maybe<Expr>::create(range, Expr(msg));
577         }
578         L.expect(TK_NEWLINE);
579         return Assert::create(range, cond, maybe_first);
580       }
581       case TK_BREAK: {
582         auto range = L.next().range;
583         L.expect(TK_NEWLINE);
584         return Break::create(range);
585       }
586       case TK_CONTINUE: {
587         auto range = L.next().range;
588         L.expect(TK_NEWLINE);
589         return Continue::create(range);
590       }
591       case TK_PASS: {
592         auto range = L.next().range;
593         L.expect(TK_NEWLINE);
594         return Pass::create(range);
595       }
596       case TK_DEF: {
597         return parseFunction(/*is_method=*/in_class);
598       }
599       case TK_DELETE: {
600         auto range = L.next().range;
601         auto targets =
602             parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseExp);
603         L.expect(TK_NEWLINE);
604         return Delete::create(range, targets);
605       }
606       case TK_WITH: {
607         return parseWith();
608       }
609       default: {
610         auto lhs = parseExpOrExpTuple();
611         if (L.cur().kind != TK_NEWLINE) {
612           return parseAssign(lhs);
613         } else {
614           L.expect(TK_NEWLINE);
615           return ExprStmt::create(lhs.range(), lhs);
616         }
617       }
618     }
619   }
620 
parseWithItemtorch::jit::ParserImpl621   WithItem parseWithItem() {
622     auto target = parseExp();
623 
624     if (L.cur().kind == TK_AS) {
625       // If the current token is TK_AS, this with item is of the form
626       // "expression as target".
627       auto token = L.expect(TK_AS);
628       Ident ident = parseIdent();
629       auto var = Var::create(ident.range(), ident);
630       return WithItem::create(
631           token.range, target, Maybe<Var>::create(ident.range(), var));
632     } else {
633       // If not, this with item is of the form "expression".
634       return WithItem::create(
635           target.range(), target, Maybe<Var>::create(target.range()));
636     }
637   }
638 
parseIftorch::jit::ParserImpl639   TreeRef parseIf(bool expect_if = true) {
640     auto r = L.cur().range;
641     if (expect_if)
642       L.expect(TK_IF);
643     auto cond = parseExp();
644     L.expect(':');
645     auto true_branch = parseStatements(/*expect_indent=*/true);
646     auto false_branch = makeList(L.cur().range, {});
647     if (L.nextIf(TK_ELSE)) {
648       L.expect(':');
649       false_branch = parseStatements(/*expect_indent=*/true);
650     } else if (L.nextIf(TK_ELIF)) {
651       // NB: this needs to be a separate statement, since the call to parseIf
652       // mutates the lexer state, and thus causes a heap-use-after-free in
653       // compilers which evaluate argument expressions LTR
654       auto range = L.cur().range;
655       false_branch = makeList(range, {parseIf(false)});
656     }
657     return If::create(
658         r, Expr(cond), List<Stmt>(true_branch), List<Stmt>(false_branch));
659   }
parseWhiletorch::jit::ParserImpl660   TreeRef parseWhile() {
661     auto r = L.cur().range;
662     L.expect(TK_WHILE);
663     auto cond = parseExp();
664     L.expect(':');
665     auto body = parseStatements(/*expect_indent=*/true);
666     return While::create(r, Expr(cond), List<Stmt>(body));
667   }
668 
parseFortorch::jit::ParserImpl669   TreeRef parseFor() {
670     auto r = L.cur().range;
671     L.expect(TK_FOR);
672     auto targets = parseList(TK_NOTHING, ',', TK_IN, &ParserImpl::parseLHSExp);
673     auto itrs = parseList(TK_NOTHING, ',', ':', &ParserImpl::parseExp);
674     auto body = parseStatements(/*expect_indent=*/true);
675     return For::create(r, targets, itrs, body);
676   }
677 
parseWithtorch::jit::ParserImpl678   TreeRef parseWith() {
679     auto r = L.cur().range;
680     // Parse "with expression [as target][, expression [as target]]*:".
681     L.expect(TK_WITH);
682     auto targets = parseList(TK_NOTHING, ',', ':', &ParserImpl::parseWithItem);
683     // Parse the body.
684     auto body = parseStatements(/*expect_indent=*/true);
685     return With::create(r, targets, body);
686   }
687 
parseStatementstorch::jit::ParserImpl688   TreeRef parseStatements(bool expect_indent, bool in_class = false) {
689     auto r = L.cur().range;
690     if (expect_indent) {
691       L.expect(TK_INDENT);
692     }
693     TreeList stmts;
694     do {
695       stmts.push_back(parseStmt(in_class));
696     } while (!L.nextIf(TK_DEDENT));
697     return create_compound(TK_LIST, r, std::move(stmts));
698   }
699 
parseReturnAnnotationtorch::jit::ParserImpl700   Maybe<Expr> parseReturnAnnotation() {
701     if (L.nextIf(TK_ARROW)) {
702       // Exactly one expression for return type annotation
703       auto return_type_range = L.cur().range;
704       return Maybe<Expr>::create(return_type_range, parseExp());
705     } else {
706       return Maybe<Expr>::create(L.cur().range);
707     }
708   }
709 
parseFormalParamstorch::jit::ParserImpl710   List<Param> parseFormalParams() {
711     auto r = L.cur().range;
712     std::vector<Param> params;
713     bool kwarg_only = false;
714     parseSequence('(', ',', ')', [&] {
715       if (!kwarg_only && L.nextIf('*')) {
716         kwarg_only = true;
717       } else {
718         params.emplace_back(parseFormalParam(kwarg_only));
719       }
720     });
721     return List<Param>::create(r, params);
722   }
parseDecltorch::jit::ParserImpl723   Decl parseDecl() {
724     // Parse return type annotation
725     List<Param> paramlist = parseFormalParams();
726     TreeRef return_type;
727     Maybe<Expr> return_annotation = parseReturnAnnotation();
728     L.expect(':');
729     return Decl::create(
730         paramlist.range(), List<Param>(paramlist), return_annotation);
731   }
732 
parseClasstorch::jit::ParserImpl733   TreeRef parseClass() {
734     L.expect(TK_CLASS_DEF);
735     const auto name = parseIdent();
736     Maybe<Expr> superclass = Maybe<Expr>::create(name.range());
737     if (L.nextIf('(')) {
738       // Only support inheriting from NamedTuple right now.
739       auto id = parseExp();
740       superclass = Maybe<Expr>::create(id.range(), id);
741       L.expect(')');
742     }
743     L.expect(':');
744     const auto statements =
745         parseStatements(/*expect_indent=*/true, /*in_class=*/true);
746     return ClassDef::create(
747         name.range(), name, superclass, List<Stmt>(statements));
748   }
749 
parseFunctiontorch::jit::ParserImpl750   TreeRef parseFunction(bool is_method) {
751     L.expect(TK_DEF);
752     auto name = parseIdent();
753     auto decl = parseDecl();
754 
755     TreeRef stmts_list;
756     if (L.nextIf(TK_INDENT)) {
757       // Handle type annotations specified in a type comment as the first line
758       // of the function.
759       if (L.cur().kind == TK_TYPE_COMMENT) {
760         auto type_annotation_decl = Decl(parseTypeComment());
761         L.expect(TK_NEWLINE);
762         decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method);
763       }
764 
765       stmts_list = parseStatements(false);
766     } else {
767       // Special case: the Python grammar allows one-line functions with a
768       // single statement.
769       if (L.cur().kind == TK_TYPE_COMMENT) {
770         auto type_annotation_decl = Decl(parseTypeComment());
771         decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method);
772       }
773 
774       TreeList stmts;
775       stmts.push_back(parseStmt(is_method));
776       stmts_list = create_compound(TK_LIST, L.cur().range, std::move(stmts));
777     }
778 
779     return Def::create(
780         name.range(), Ident(name), Decl(decl), List<Stmt>(stmts_list));
781   }
lexertorch::jit::ParserImpl782   Lexer& lexer() {
783     return L;
784   }
785 
786  private:
787   // short helpers to create nodes
create_compoundtorch::jit::ParserImpl788   TreeRef create_compound(
789       int kind,
790       const SourceRange& range,
791       TreeList&& trees) {
792     return Compound::create(kind, range, std::move(trees));
793   }
makeListtorch::jit::ParserImpl794   TreeRef makeList(const SourceRange& range, TreeList&& trees) {
795     return create_compound(TK_LIST, range, std::move(trees));
796   }
797   Lexer L;
798   SharedParserData& shared;
799 };
800 
Parser(const std::shared_ptr<Source> & src)801 Parser::Parser(const std::shared_ptr<Source>& src)
802     : pImpl(new ParserImpl(src)) {}
803 
804 Parser::~Parser() = default;
805 
parseFunction(bool is_method)806 TreeRef Parser::parseFunction(bool is_method) {
807   return pImpl->parseFunction(is_method);
808 }
parseClass()809 TreeRef Parser::parseClass() {
810   return pImpl->parseClass();
811 }
lexer()812 Lexer& Parser::lexer() {
813   return pImpl->lexer();
814 }
parseTypeComment()815 Decl Parser::parseTypeComment() {
816   return pImpl->parseTypeComment();
817 }
parseExp()818 Expr Parser::parseExp() {
819   return pImpl->parseExp();
820 }
821 
822 } // namespace torch::jit
823