1 #include <torch/csrc/jit/frontend/parser.h>
2
3 #include <torch/csrc/jit/frontend/lexer.h>
4 #include <torch/csrc/jit/frontend/parse_string_literal.h>
5 #include <torch/csrc/jit/frontend/tree.h>
6 #include <torch/csrc/jit/frontend/tree_views.h>
7 #include <optional>
8
9 namespace torch::jit {
10
// Merges the parameter names/ranges of a parsed signature with the types
// taken from a type comment.
//
// decl:                 the signature as written in the `def` line.
// type_annotation_decl: the declaration parsed from the type comment; it
//                       supplies one type per parameter and the return type.
// is_method:            when true, the first parameter of `decl` (`self`) is
//                       kept as-is and does not consume an annotation.
//
// Returns a new Decl spanning `decl.range()` whose parameters carry the
// annotated types and whose return type is that of `type_annotation_decl`.
// Throws ErrorReport when the number of annotations does not match the
// number of annotatable parameters.
Decl mergeTypesFromTypeComment(
    const Decl& decl,
    const Decl& type_annotation_decl,
    bool is_method) {
  auto old = decl.params();
  auto _new = type_annotation_decl.params();
  const size_t num_params = old.size();
  const size_t num_annotations = _new.size();

  // A method needs at least the `self` parameter. Checking this up front
  // avoids an unsigned underflow in the expected-annotation count below.
  if (is_method && num_params == 0) {
    throw ErrorReport(decl.range())
        << "Expected a `self` parameter in the method declaration when "
        << "merging types from a type comment";
  }

  // `self` argument is not annotated in the type comment
  const size_t first_annotated = is_method ? 1 : 0;
  const size_t expected_num_annotations = num_params - first_annotated;
  if (expected_num_annotations != num_annotations) {
    throw ErrorReport(decl.range())
        << "Number of type annotations (" << num_annotations
        << ") did not match the number of "
        << (is_method ? "method" : "function") << " parameters ("
        << expected_num_annotations << ")";
  }

  // Merge signature idents and ranges with annotation types
  std::vector<Param> new_params;
  new_params.reserve(num_params);
  if (is_method) {
    new_params.push_back(old[0]);
  }
  for (size_t i = first_annotated; i < num_params; ++i) {
    new_params.emplace_back(old[i].withType(_new[i - first_annotated].type()));
  }
  return Decl::create(
      decl.range(),
      List<Param>::create(decl.range(), new_params),
      type_annotation_decl.return_type());
}
46
47 struct ParserImpl {
// Constructs the lexer `L` from `source` and binds `shared` to the object
// returned by sharedParserData().
explicit ParserImpl(const std::shared_ptr<Source>& source)
    : L(source), shared(sharedParserData()) {}
50
parseIdenttorch::jit::ParserImpl51 Ident parseIdent() {
52 auto t = L.expect(TK_IDENT);
53 // whenever we parse something that has a TreeView type we always
54 // use its create method so that the accessors and the constructor
55 // of the Compound tree are in the same place.
56 return Ident::create(t.range, t.text());
57 }
createApplytorch::jit::ParserImpl58 TreeRef createApply(const Expr& expr) {
59 TreeList attributes;
60 auto range = L.cur().range;
61 TreeList inputs;
62 parseArguments(inputs, attributes);
63 return Apply::create(
64 range,
65 expr,
66 List<Expr>(makeList(range, std::move(inputs))),
67 List<Attribute>(makeList(range, std::move(attributes))));
68 }
69
followsTupletorch::jit::ParserImpl70 static bool followsTuple(int kind) {
71 switch (kind) {
72 case TK_PLUS_EQ:
73 case TK_MINUS_EQ:
74 case TK_TIMES_EQ:
75 case TK_DIV_EQ:
76 case TK_MOD_EQ:
77 case TK_BIT_OR_EQ:
78 case TK_BIT_AND_EQ:
79 case TK_BIT_XOR_EQ:
80 case TK_LSHIFT_EQ:
81 case TK_RSHIFT_EQ:
82 case TK_POW_EQ:
83 case TK_NEWLINE:
84 case '=':
85 case ')':
86 return true;
87 default:
88 return false;
89 }
90 }
91
92 // exp | expr, | expr, expr, ...
parseExpOrExpTupletorch::jit::ParserImpl93 Expr parseExpOrExpTuple() {
94 auto prefix = parseExp();
95 if (L.cur().kind == ',') {
96 std::vector<Expr> exprs = {prefix};
97 while (L.nextIf(',')) {
98 if (followsTuple(L.cur().kind))
99 break;
100 exprs.push_back(parseExp());
101 }
102 auto list = List<Expr>::create(prefix.range(), exprs);
103 prefix = TupleLiteral::create(list.range(), list);
104 }
105 return prefix;
106 }
// Parses things like a 1.0 or a(4) that are not unary/binary expressions
// and have higher precedence than all of them.
//
// The leading switch parses one primary expression chosen by the current
// token kind; the trailing loop then applies any number of postfix
// operations (`.name`, call `(...)`, subscript `[...]`) to it.
TreeRef parseBaseExp() {
  TreeRef prefix;
  switch (L.cur().kind) {
    case TK_NUMBER: {
      prefix = parseConst();
    } break;
    case TK_TRUE:
    case TK_FALSE:
    case TK_NONE:
    case TK_NONE_TYPE: {
      // Keyword literals become childless compound nodes of the token's kind.
      auto k = L.cur().kind;
      auto r = L.cur().range;
      prefix = create_compound(k, r, {});
      L.next();
    } break;
    case '(': {
      L.next();
      if (L.nextIf(')')) {
        /// here we have the empty tuple case
        // NOTE(review): L.cur() is read after nextIf(')') succeeded, so the
        // range used here is presumably that of the token following ')' —
        // confirm whether that is intended.
        std::vector<Expr> vecExpr;
        List<Expr> listExpr = List<Expr>::create(L.cur().range, vecExpr);
        prefix = TupleLiteral::create(L.cur().range, listExpr);
        break;
      }
      prefix = parseExpOrExpTuple();
      L.expect(')');
    } break;
    case '[': {
      auto list = parseList('[', ',', ']', &ParserImpl::parseExp);

      // A single TK_LIST_COMP element means the brackets enclose a list
      // comprehension; it is used directly instead of being wrapped.
      if (list.size() == 1 && (*list.begin()).kind() == TK_LIST_COMP) {
        prefix = *list.begin();
      } else {
        // A comprehension mixed with other elements is rejected.
        for (auto se : list) {
          if (se.kind() == TK_LIST_COMP) {
            throw ErrorReport(list.range())
                << " expected a single list comprehension within '[' , ']'";
          }
        }
        prefix = ListLiteral::create(list.range(), List<Expr>(list));
      }

    } break;
    case '{': {
      L.next();
      // If we have a dict literal, `keys` and `values` will store the keys
      // and values used in the object's construction. EDGE CASE: We have a
      // dict comprehension, so we'll get the first element of the dict
      // comprehension in `keys` and a list comprehension in `values`.
      // For example, `{i : chr(i + 65) for i in range(4)}` would give us
      // `i` in `keys` and `chr(i + 65) for i in range(4)` in `values`.
      // The optimal way of handling this case is to simply splice the new
      // dict comprehension together from the existing list comprehension.
      // Splicing prevents breaking changes to our API and does not require
      // the use of global variables.
      std::vector<Expr> keys;
      std::vector<Expr> values;
      auto range = L.cur().range;
      if (L.cur().kind != '}') {
        do {
          keys.push_back(parseExp());
          L.expect(':');
          values.push_back(parseExp());
        } while (L.nextIf(','));
      }
      L.expect('}');
      // `keys` and `values` always have the same length, so a single key
      // implies `values.begin()` is valid.
      if (keys.size() == 1 && (*values.begin()).kind() == TK_LIST_COMP) {
        ListComp lc(*values.begin());
        prefix = DictComp::create(
            range, *keys.begin(), lc.elt(), lc.target(), lc.iter());
      } else {
        prefix = DictLiteral::create(
            range,
            List<Expr>::create(range, keys),
            List<Expr>::create(range, values));
      }
    } break;
    case TK_STRINGLITERAL: {
      prefix = parseConcatenatedStringLiterals();
    } break;
    case TK_ELLIPSIS:
    case TK_DOTS: {
      prefix = Dots::create(L.cur().range);
      L.next();
    } break;
    default: {
      // Anything else must be an identifier; parseIdent() expects TK_IDENT.
      Ident name = parseIdent();
      prefix = Var::create(name.range(), name);
    } break;
  }
  // Postfix operations: each one wraps the expression parsed so far.
  while (true) {
    if (L.nextIf('.')) {
      const auto name = parseIdent();
      prefix = Select::create(name.range(), Expr(prefix), Ident(name));
    } else if (L.cur().kind == '(') {
      prefix = createApply(Expr(prefix));
    } else if (L.cur().kind == '[') {
      prefix = parseSubscript(prefix);
    } else {
      break;
    }
  }
  return prefix;
}
maybeParseAssignmentOptorch::jit::ParserImpl213 std::optional<TreeRef> maybeParseAssignmentOp() {
214 auto r = L.cur().range;
215 switch (L.cur().kind) {
216 case TK_PLUS_EQ:
217 case TK_MINUS_EQ:
218 case TK_TIMES_EQ:
219 case TK_DIV_EQ:
220 case TK_BIT_OR_EQ:
221 case TK_BIT_AND_EQ:
222 case TK_BIT_XOR_EQ:
223 case TK_MOD_EQ: {
224 int modifier = L.next().text()[0];
225 return create_compound(modifier, r, {});
226 } break;
227 case TK_LSHIFT_EQ: {
228 L.next();
229 return create_compound(TK_LSHIFT, r, {});
230 } break;
231 case TK_RSHIFT_EQ: {
232 L.next();
233 return create_compound(TK_RSHIFT, r, {});
234 } break;
235 case TK_POW_EQ: {
236 L.next();
237 return create_compound(TK_POW, r, {});
238 } break;
239 case '=': {
240 L.next();
241 return create_compound('=', r, {}); // no reduction
242 } break;
243 default:
244 return std::nullopt;
245 }
246 }
parseTrinarytorch::jit::ParserImpl247 TreeRef parseTrinary(
248 TreeRef true_branch,
249 const SourceRange& range,
250 int binary_prec) {
251 auto cond = parseExp();
252 L.expect(TK_ELSE);
253 auto false_branch = parseExp(binary_prec);
254 return create_compound(
255 TK_IF_EXPR, range, {cond, std::move(true_branch), false_branch});
256 }
// Parse the longest expression whose binary operators have
// precedence strictly greater than 'precedence'.
// precedence == 0 will parse _all_ expressions.
// This is the core loop of 'top-down precedence parsing'.
//
// This overload parses a full expression by delegating with precedence 0.
Expr parseExp() {
  return parseExp(0);
}
// Parses an expression, consuming only binary operators whose precedence is
// strictly greater than `precedence`.
//
// A leading unary operator is handled first (its operand is parsed at the
// operator's own precedence); otherwise a base expression is parsed. The
// loop then repeatedly extends the expression with binary operators, with
// special handling for `not in`, the conditional expression (`if`/`else`),
// and comprehensions (`for`).
Expr parseExp(int precedence) {
  TreeRef prefix;
  int unary_prec = 0;
  if (shared.isUnary(L.cur().kind, &unary_prec)) {
    // Capture kind and range before advancing the lexer.
    auto kind = L.cur().kind;
    auto pos = L.cur().range;
    L.next();
    // '*' and '-' are mapped to dedicated unary kinds; other unary operators
    // keep their token kind.
    auto unary_kind = kind == '*' ? TK_STARRED
        : kind == '-'             ? TK_UNARY_MINUS
                                  : kind;
    auto subexp = parseExp(unary_prec);
    // fold '-' into constant numbers, so that attributes can accept
    // things like -1
    if (unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) {
      prefix = Const::create(subexp.range(), "-" + Const(subexp).text());
    } else {
      prefix = create_compound(unary_kind, pos, {subexp});
    }
  } else {
    prefix = parseBaseExp();
  }
  int binary_prec = 0;
  while (shared.isBinary(L.cur().kind, &binary_prec)) {
    if (binary_prec <= precedence) // not allowed to parse something which is
                                   // not greater than 'precedence'
      break;

    int kind = L.cur().kind;
    auto pos = L.cur().range;
    L.next();
    // Lowering the precedence by one lets the recursive call below accept
    // the same operator again, so it groups to the right.
    if (shared.isRightAssociative(kind))
      binary_prec--;

    if (kind == TK_NOTIN) {
      // NB: `not in` is just `not( in )`, so we don't introduce new tree view
      // but just make it a nested call in our tree view structure
      prefix = create_compound(TK_IN, pos, {prefix, parseExp(binary_prec)});
      prefix = create_compound(TK_NOT, pos, {prefix});
      continue;
    }

    // special case for trinary operator
    if (kind == TK_IF) {
      prefix = parseTrinary(prefix, pos, binary_prec);
      continue;
    }

    if (kind == TK_FOR) {
      // TK_FOR targets should only parse exprs with prec greater than 4,
      // which only includes the subset of Exprs that are supposed to be on
      // the LHS according to the python grammar
      // https://docs.python.org/3/reference/grammar.html
      auto target = parseLHSExp();
      L.expect(TK_IN);
      auto iter = parseExp();
      prefix = ListComp::create(pos, Expr(prefix), target, iter);
      continue;
    }

    // Ordinary binary operator: node kind is the operator's token kind.
    prefix = create_compound(kind, pos, {prefix, parseExp(binary_prec)});
  }
  return Expr(prefix);
}
327
parseSequencetorch::jit::ParserImpl328 void parseSequence(
329 int begin,
330 int sep,
331 int end,
332 const std::function<void()>& parse) {
333 if (begin != TK_NOTHING) {
334 L.expect(begin);
335 }
336 while (end != L.cur().kind) {
337 parse();
338 if (!L.nextIf(sep)) {
339 if (end != TK_NOTHING) {
340 L.expect(end);
341 }
342 return;
343 }
344 }
345 L.expect(end);
346 }
347 template <typename T>
parseListtorch::jit::ParserImpl348 List<T> parseList(int begin, int sep, int end, T (ParserImpl::*parse)()) {
349 auto r = L.cur().range;
350 std::vector<T> elements;
351 parseSequence(
352 begin, sep, end, [&] { elements.emplace_back((this->*parse)()); });
353 return List<T>::create(r, elements);
354 }
355
parseConsttorch::jit::ParserImpl356 Const parseConst() {
357 auto range = L.cur().range;
358 auto t = L.expect(TK_NUMBER);
359 return Const::create(t.range, t.text());
360 }
361
parseConcatenatedStringLiteralstorch::jit::ParserImpl362 StringLiteral parseConcatenatedStringLiterals() {
363 auto range = L.cur().range;
364 std::string ss;
365 while (L.cur().kind == TK_STRINGLITERAL) {
366 auto literal_range = L.cur().range;
367 ss.append(parseStringLiteral(literal_range, L.next().text()));
368 }
369 return StringLiteral::create(range, ss);
370 }
371
// Parses the value of a keyword argument (`name=<value>`); currently any
// expression accepted by parseExp().
Expr parseAttributeValue() {
  return parseExp();
}
375
parseArgumentstorch::jit::ParserImpl376 void parseArguments(TreeList& inputs, TreeList& attributes) {
377 parseSequence('(', ',', ')', [&] {
378 if (L.cur().kind == TK_IDENT && L.lookahead().kind == '=') {
379 auto ident = parseIdent();
380 L.expect('=');
381 auto v = parseAttributeValue();
382 attributes.push_back(Attribute::create(ident.range(), Ident(ident), v));
383 } else {
384 inputs.push_back(parseExp());
385 }
386 });
387 }
388
// Parse expressions acceptable on the LHS, which only includes the subset of
// Exprs whose precedence is greater than 4 according to the python grammar.
// Implemented by calling parseExp with precedence 4.
Expr parseLHSExp() {
  return parseExp(4);
}
394
395 // Parse expr's of the form [a:], [:b], [a:b], [:] and all variations with
396 // "::"
parseSubscriptExptorch::jit::ParserImpl397 Expr parseSubscriptExp() {
398 TreeRef first, second, third;
399 auto range = L.cur().range;
400 if (L.cur().kind != ':') {
401 first = parseExp();
402 }
403 if (L.nextIf(':')) {
404 if (L.cur().kind != ',' && L.cur().kind != ']' && L.cur().kind != ':') {
405 second = parseExp();
406 }
407 if (L.nextIf(':')) {
408 if (L.cur().kind != ',' && L.cur().kind != ']') {
409 third = parseExp();
410 }
411 }
412 auto maybe_first = first ? Maybe<Expr>::create(range, Expr(first))
413 : Maybe<Expr>::create(range);
414 auto maybe_second = second ? Maybe<Expr>::create(range, Expr(second))
415 : Maybe<Expr>::create(range);
416 auto maybe_third = third ? Maybe<Expr>::create(range, Expr(third))
417 : Maybe<Expr>::create(range);
418 return SliceExpr::create(range, maybe_first, maybe_second, maybe_third);
419 } else {
420 return Expr(first);
421 }
422 }
423
parseSubscripttorch::jit::ParserImpl424 TreeRef parseSubscript(const TreeRef& value) {
425 const auto range = L.cur().range;
426
427 auto subscript_exprs =
428 parseList('[', ',', ']', &ParserImpl::parseSubscriptExp);
429
430 const auto whole_range =
431 SourceRange(range.source(), range.start(), L.cur().range.start());
432 return Subscript::create(whole_range, Expr(value), subscript_exprs);
433 }
434
maybeParseTypeAnnotationtorch::jit::ParserImpl435 Maybe<Expr> maybeParseTypeAnnotation() {
436 if (L.nextIf(':')) {
437 // NB: parseExp must not be called inline, since argument evaluation order
438 // changes when L.cur().range is mutated with respect to the parseExp()
439 // call.
440 auto expr = parseExp();
441 return Maybe<Expr>::create(expr.range(), expr);
442 } else {
443 return Maybe<Expr>::create(L.cur().range);
444 }
445 }
446
parseFormalParamtorch::jit::ParserImpl447 TreeRef parseFormalParam(bool kwarg_only) {
448 auto ident = parseIdent();
449 TreeRef type = maybeParseTypeAnnotation();
450 TreeRef def;
451 if (L.nextIf('=')) {
452 // NB: parseExp must not be called inline, since argument evaluation order
453 // changes when L.cur().range is mutated with respect to the parseExp()
454 // call.
455 auto expr = parseExp();
456 def = Maybe<Expr>::create(expr.range(), expr);
457 } else {
458 def = Maybe<Expr>::create(L.cur().range);
459 }
460 return Param::create(
461 type->range(),
462 Ident(ident),
463 Maybe<Expr>(type),
464 Maybe<Expr>(def),
465 kwarg_only);
466 }
467
parseBareTypeAnnotationtorch::jit::ParserImpl468 Param parseBareTypeAnnotation() {
469 auto type = parseExp();
470 return Param::create(
471 type.range(),
472 Ident::create(type.range(), ""),
473 Maybe<Expr>::create(type.range(), type),
474 Maybe<Expr>::create(type.range()),
475 /*kwarg_only=*/false);
476 }
477
parseTypeCommenttorch::jit::ParserImpl478 Decl parseTypeComment() {
479 auto range = L.cur().range;
480 L.expect(TK_TYPE_COMMENT);
481 auto param_types =
482 parseList('(', ',', ')', &ParserImpl::parseBareTypeAnnotation);
483 TreeRef return_type;
484 if (L.nextIf(TK_ARROW)) {
485 auto return_type_range = L.cur().range;
486 return_type = Maybe<Expr>::create(return_type_range, parseExp());
487 } else {
488 return_type = Maybe<Expr>::create(L.cur().range);
489 }
490 return Decl::create(range, param_types, Maybe<Expr>(return_type));
491 }
492
493 // 'first' has already been parsed since expressions can exist
494 // alone on a line:
495 // first[,other,lhs] = rhs
parseAssigntorch::jit::ParserImpl496 TreeRef parseAssign(const Expr& lhs) {
497 auto type = maybeParseTypeAnnotation();
498 auto maybeOp = maybeParseAssignmentOp();
499 if (maybeOp) {
500 // There is an assignment operator, parse the RHS and generate the
501 // assignment.
502 auto rhs = parseExpOrExpTuple();
503 if (maybeOp.value()->kind() == '=') {
504 std::vector<Expr> lhs_list = {lhs};
505 while (L.nextIf('=')) {
506 lhs_list.push_back(rhs);
507 rhs = parseExpOrExpTuple();
508 }
509 if (type.present() && lhs_list.size() > 1) {
510 throw ErrorReport(type.range())
511 << "Annotated multiple assignment is not supported in python";
512 }
513 L.expect(TK_NEWLINE);
514 return Assign::create(
515 lhs.range(),
516 List<Expr>::create(lhs_list[0].range(), lhs_list),
517 Maybe<Expr>::create(rhs.range(), rhs),
518 type);
519 } else {
520 L.expect(TK_NEWLINE);
521 // this is an augmented assignment
522 if (lhs.kind() == TK_TUPLE_LITERAL) {
523 throw ErrorReport(lhs.range())
524 << " augmented assignment can only have one LHS expression";
525 }
526 return AugAssign::create(
527 lhs.range(), lhs, AugAssignKind(*maybeOp), Expr(rhs));
528 }
529 } else {
530 // There is no assignment operator, so this is of the form `lhs : <type>`
531 TORCH_INTERNAL_ASSERT(type.present());
532 L.expect(TK_NEWLINE);
533 return Assign::create(
534 lhs.range(),
535 List<Expr>::create(lhs.range(), {lhs}),
536 Maybe<Expr>::create(lhs.range()),
537 type);
538 }
539 }
540
// Parses a single statement, dispatching on the current token kind.
//
// in_class: forwarded as `is_method` when the statement is a `def`.
//
// Compound statements (if/while/for/with/def) delegate to their dedicated
// parsers. Simple statements consume their keyword, any operands, and the
// terminating TK_NEWLINE. Anything else is parsed as an expression (or
// expression tuple) that is either the start of an assignment or a
// standalone expression statement.
TreeRef parseStmt(bool in_class = false) {
  switch (L.cur().kind) {
    case TK_IF:
      return parseIf();
    case TK_WHILE:
      return parseWhile();
    case TK_FOR:
      return parseFor();
    case TK_GLOBAL: {
      // `global a, b, ...`
      auto range = L.next().range;
      auto idents =
          parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseIdent);
      L.expect(TK_NEWLINE);
      return Global::create(range, idents);
    }
    case TK_RETURN: {
      // A bare `return` is given a TK_NONE node as its value.
      auto range = L.next().range;
      Expr value = L.cur().kind != TK_NEWLINE
          ? parseExpOrExpTuple()
          : Expr(create_compound(TK_NONE, range, {}));
      L.expect(TK_NEWLINE);
      return Return::create(range, value);
    }
    case TK_RAISE: {
      auto range = L.next().range;
      auto expr = parseExp();
      L.expect(TK_NEWLINE);
      return Raise::create(range, expr);
    }
    case TK_ASSERT: {
      // `assert cond[, msg]`; `maybe_first` holds the optional message.
      auto range = L.next().range;
      auto cond = parseExp();
      Maybe<Expr> maybe_first = Maybe<Expr>::create(range);
      if (L.nextIf(',')) {
        auto msg = parseExp();
        maybe_first = Maybe<Expr>::create(range, Expr(msg));
      }
      L.expect(TK_NEWLINE);
      return Assert::create(range, cond, maybe_first);
    }
    case TK_BREAK: {
      auto range = L.next().range;
      L.expect(TK_NEWLINE);
      return Break::create(range);
    }
    case TK_CONTINUE: {
      auto range = L.next().range;
      L.expect(TK_NEWLINE);
      return Continue::create(range);
    }
    case TK_PASS: {
      auto range = L.next().range;
      L.expect(TK_NEWLINE);
      return Pass::create(range);
    }
    case TK_DEF: {
      return parseFunction(/*is_method=*/in_class);
    }
    case TK_DELETE: {
      // `del a, b, ...`
      auto range = L.next().range;
      auto targets =
          parseList(TK_NOTHING, ',', TK_NOTHING, &ParserImpl::parseExp);
      L.expect(TK_NEWLINE);
      return Delete::create(range, targets);
    }
    case TK_WITH: {
      return parseWith();
    }
    default: {
      // Expression statement or assignment: anything other than a newline
      // after the expression is handed to parseAssign.
      auto lhs = parseExpOrExpTuple();
      if (L.cur().kind != TK_NEWLINE) {
        return parseAssign(lhs);
      } else {
        L.expect(TK_NEWLINE);
        return ExprStmt::create(lhs.range(), lhs);
      }
    }
  }
}
620
parseWithItemtorch::jit::ParserImpl621 WithItem parseWithItem() {
622 auto target = parseExp();
623
624 if (L.cur().kind == TK_AS) {
625 // If the current token is TK_AS, this with item is of the form
626 // "expression as target".
627 auto token = L.expect(TK_AS);
628 Ident ident = parseIdent();
629 auto var = Var::create(ident.range(), ident);
630 return WithItem::create(
631 token.range, target, Maybe<Var>::create(ident.range(), var));
632 } else {
633 // If not, this with item is of the form "expression".
634 return WithItem::create(
635 target.range(), target, Maybe<Var>::create(target.range()));
636 }
637 }
638
parseIftorch::jit::ParserImpl639 TreeRef parseIf(bool expect_if = true) {
640 auto r = L.cur().range;
641 if (expect_if)
642 L.expect(TK_IF);
643 auto cond = parseExp();
644 L.expect(':');
645 auto true_branch = parseStatements(/*expect_indent=*/true);
646 auto false_branch = makeList(L.cur().range, {});
647 if (L.nextIf(TK_ELSE)) {
648 L.expect(':');
649 false_branch = parseStatements(/*expect_indent=*/true);
650 } else if (L.nextIf(TK_ELIF)) {
651 // NB: this needs to be a separate statement, since the call to parseIf
652 // mutates the lexer state, and thus causes a heap-use-after-free in
653 // compilers which evaluate argument expressions LTR
654 auto range = L.cur().range;
655 false_branch = makeList(range, {parseIf(false)});
656 }
657 return If::create(
658 r, Expr(cond), List<Stmt>(true_branch), List<Stmt>(false_branch));
659 }
parseWhiletorch::jit::ParserImpl660 TreeRef parseWhile() {
661 auto r = L.cur().range;
662 L.expect(TK_WHILE);
663 auto cond = parseExp();
664 L.expect(':');
665 auto body = parseStatements(/*expect_indent=*/true);
666 return While::create(r, Expr(cond), List<Stmt>(body));
667 }
668
parseFortorch::jit::ParserImpl669 TreeRef parseFor() {
670 auto r = L.cur().range;
671 L.expect(TK_FOR);
672 auto targets = parseList(TK_NOTHING, ',', TK_IN, &ParserImpl::parseLHSExp);
673 auto itrs = parseList(TK_NOTHING, ',', ':', &ParserImpl::parseExp);
674 auto body = parseStatements(/*expect_indent=*/true);
675 return For::create(r, targets, itrs, body);
676 }
677
parseWithtorch::jit::ParserImpl678 TreeRef parseWith() {
679 auto r = L.cur().range;
680 // Parse "with expression [as target][, expression [as target]]*:".
681 L.expect(TK_WITH);
682 auto targets = parseList(TK_NOTHING, ',', ':', &ParserImpl::parseWithItem);
683 // Parse the body.
684 auto body = parseStatements(/*expect_indent=*/true);
685 return With::create(r, targets, body);
686 }
687
parseStatementstorch::jit::ParserImpl688 TreeRef parseStatements(bool expect_indent, bool in_class = false) {
689 auto r = L.cur().range;
690 if (expect_indent) {
691 L.expect(TK_INDENT);
692 }
693 TreeList stmts;
694 do {
695 stmts.push_back(parseStmt(in_class));
696 } while (!L.nextIf(TK_DEDENT));
697 return create_compound(TK_LIST, r, std::move(stmts));
698 }
699
parseReturnAnnotationtorch::jit::ParserImpl700 Maybe<Expr> parseReturnAnnotation() {
701 if (L.nextIf(TK_ARROW)) {
702 // Exactly one expression for return type annotation
703 auto return_type_range = L.cur().range;
704 return Maybe<Expr>::create(return_type_range, parseExp());
705 } else {
706 return Maybe<Expr>::create(L.cur().range);
707 }
708 }
709
parseFormalParamstorch::jit::ParserImpl710 List<Param> parseFormalParams() {
711 auto r = L.cur().range;
712 std::vector<Param> params;
713 bool kwarg_only = false;
714 parseSequence('(', ',', ')', [&] {
715 if (!kwarg_only && L.nextIf('*')) {
716 kwarg_only = true;
717 } else {
718 params.emplace_back(parseFormalParam(kwarg_only));
719 }
720 });
721 return List<Param>::create(r, params);
722 }
parseDecltorch::jit::ParserImpl723 Decl parseDecl() {
724 // Parse return type annotation
725 List<Param> paramlist = parseFormalParams();
726 TreeRef return_type;
727 Maybe<Expr> return_annotation = parseReturnAnnotation();
728 L.expect(':');
729 return Decl::create(
730 paramlist.range(), List<Param>(paramlist), return_annotation);
731 }
732
parseClasstorch::jit::ParserImpl733 TreeRef parseClass() {
734 L.expect(TK_CLASS_DEF);
735 const auto name = parseIdent();
736 Maybe<Expr> superclass = Maybe<Expr>::create(name.range());
737 if (L.nextIf('(')) {
738 // Only support inheriting from NamedTuple right now.
739 auto id = parseExp();
740 superclass = Maybe<Expr>::create(id.range(), id);
741 L.expect(')');
742 }
743 L.expect(':');
744 const auto statements =
745 parseStatements(/*expect_indent=*/true, /*in_class=*/true);
746 return ClassDef::create(
747 name.range(), name, superclass, List<Stmt>(statements));
748 }
749
parseFunctiontorch::jit::ParserImpl750 TreeRef parseFunction(bool is_method) {
751 L.expect(TK_DEF);
752 auto name = parseIdent();
753 auto decl = parseDecl();
754
755 TreeRef stmts_list;
756 if (L.nextIf(TK_INDENT)) {
757 // Handle type annotations specified in a type comment as the first line
758 // of the function.
759 if (L.cur().kind == TK_TYPE_COMMENT) {
760 auto type_annotation_decl = Decl(parseTypeComment());
761 L.expect(TK_NEWLINE);
762 decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method);
763 }
764
765 stmts_list = parseStatements(false);
766 } else {
767 // Special case: the Python grammar allows one-line functions with a
768 // single statement.
769 if (L.cur().kind == TK_TYPE_COMMENT) {
770 auto type_annotation_decl = Decl(parseTypeComment());
771 decl = mergeTypesFromTypeComment(decl, type_annotation_decl, is_method);
772 }
773
774 TreeList stmts;
775 stmts.push_back(parseStmt(is_method));
776 stmts_list = create_compound(TK_LIST, L.cur().range, std::move(stmts));
777 }
778
779 return Def::create(
780 name.range(), Ident(name), Decl(decl), List<Stmt>(stmts_list));
781 }
// Returns a mutable reference to the lexer this parser reads tokens from.
Lexer& lexer() {
  return L;
}
785
786 private:
787 // short helpers to create nodes
create_compoundtorch::jit::ParserImpl788 TreeRef create_compound(
789 int kind,
790 const SourceRange& range,
791 TreeList&& trees) {
792 return Compound::create(kind, range, std::move(trees));
793 }
makeListtorch::jit::ParserImpl794 TreeRef makeList(const SourceRange& range, TreeList&& trees) {
795 return create_compound(TK_LIST, range, std::move(trees));
796 }
// Token source; every parse method reads and advances it.
Lexer L;
// Operator tables (isUnary / isBinary / isRightAssociative) obtained from
// sharedParserData() in the constructor.
SharedParserData& shared;
799 };
800
// Creates the ParserImpl over `src` that all Parser methods forward to.
Parser::Parser(const std::shared_ptr<Source>& src)
    : pImpl(new ParserImpl(src)) {}
803
// Defaulted in this file, after ParserImpl has been fully defined.
// NOTE(review): presumably required because pImpl is a smart pointer to a
// type that is only forward-declared in parser.h — confirm against the header.
Parser::~Parser() = default;
805
// Forwards to ParserImpl::parseFunction; `is_method` is passed through
// unchanged.
TreeRef Parser::parseFunction(bool is_method) {
  return pImpl->parseFunction(is_method);
}
// Forwards to ParserImpl::parseClass.
TreeRef Parser::parseClass() {
  return pImpl->parseClass();
}
// Returns the implementation's lexer by mutable reference.
Lexer& Parser::lexer() {
  return pImpl->lexer();
}
// Forwards to ParserImpl::parseTypeComment.
Decl Parser::parseTypeComment() {
  return pImpl->parseTypeComment();
}
// Forwards to the no-argument ParserImpl::parseExp, which parses a full
// expression.
Expr Parser::parseExp() {
  return pImpl->parseExp();
}
821
822 } // namespace torch::jit
823