# mypy: allow-untyped-defs
import ast
import functools
import inspect
import re
from textwrap import dedent
from typing import Any, List, NamedTuple, Optional, Tuple

from torch._C import ErrorReport
from torch._C._jit_tree_views import SourceRangeFactory


# Matches the ``def`` keyword at the start of a left-stripped line. The
# trailing word boundary prevents identifiers such as ``default`` or
# ``deferred`` from being mistaken for the function definition. Such
# identifiers can start a line inside a multi-line decorator call.
_DEF_KEYWORD_RE = re.compile(r"def\b")


def get_source_lines_and_file(
    obj: Any,
    error_msg: Optional[str] = None,
) -> Tuple[List[str], int, Optional[str]]:
    """
    Wrapper around inspect.getsourcelines and inspect.getsourcefile.

    Args:
        obj: the object whose source is wanted (e.g. a function).
        error_msg: optional extra text appended to the error message when
            the source cannot be retrieved.

    Returns: (sourcelines, file_lineno, filename)
        ``filename`` is whatever ``inspect.getsourcefile`` returned, which
        may be ``None``.

    Raises:
        OSError: if ``inspect`` raises OSError while locating the source. The
            original exception is chained as the cause.
    """
    filename = None  # in case getsourcefile throws
    try:
        filename = inspect.getsourcefile(obj)
        sourcelines, file_lineno = inspect.getsourcelines(obj)
    except OSError as e:
        msg = (
            f"Can't get source for {obj}. TorchScript requires source access in "
            "order to carry out compilation, make sure original .py files are "
            "available."
        )
        if error_msg:
            msg += "\n" + error_msg
        raise OSError(msg) from e

    return sourcelines, file_lineno, filename


def normalize_source_lines(sourcelines: List[str]) -> List[str]:
    """
    Align every source line to at least the indentation of the ``def`` line.

    This helper finds the indentation level of the function definition
    (``def``). It then indents every other line, both before it
    (e.g. decorators) and after it (the body), to a point at or greater
    than that level. This allows for comments and continued string literals
    that are at a lower indentation than the rest of the code.

    Only a line whose first word is the ``def`` keyword counts as the
    definition. A line that merely begins with a longer identifier, such as
    ``default=...`` inside a decorator's arguments, is skipped.

    Args:
        sourcelines: function source code, split into lines (each line
            typically keeps its trailing newline).
    Returns:
        A list of source lines that have been correctly aligned. If no
        ``def`` line is found (e.g. the source is a lambda), ``sourcelines``
        is returned unchanged.
    """

    def remove_prefix(text, prefix):
        # Same as str.removeprefix, which needs Python 3.9+.
        return text[len(prefix) :] if text.startswith(prefix) else text

    # Find the line and line number containing the function definition
    idx = None
    for i, line in enumerate(sourcelines):
        if _DEF_KEYWORD_RE.match(line.lstrip()):
            idx = i
            break

    # This will happen when the function is a lambda: we won't find "def"
    # anywhere in the source lines in that case. Currently, trying to JIT
    # compile a lambda throws an error up in `parse_def()`, but we might want
    # to handle this case in the future.
    if idx is None:
        return sourcelines

    # Get a string representing the amount of leading whitespace. The
    # stripped line starts with the `def` keyword, so everything before the
    # first "def" is indentation.
    fn_def = sourcelines[idx]
    whitespace = fn_def.split("def")[0]

    # Add this leading whitespace to all lines before and after the `def`.
    # Lines already at or beyond that indentation keep their relative
    # indentation, because the common prefix is removed first and then
    # re-added.
    aligned_prefix = [
        whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]
    ]
    aligned_suffix = [
        whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1 :]
    ]

    # Put it together again; the `def` line itself is left untouched.
    aligned_prefix.append(fn_def)
    return aligned_prefix + aligned_suffix


# Thin wrapper around SourceRangeFactory to store extra metadata
# about the function-to-be-compiled.
class SourceContext(SourceRangeFactory):
    """SourceRangeFactory that also records compilation metadata.

    The first four arguments are forwarded unchanged to
    ``SourceRangeFactory``. ``uses_true_division`` and ``funcname`` are
    stored only as Python attributes, alongside ``filename``.
    """

    def __init__(
        self,
        source,
        filename,
        file_lineno,
        leading_whitespace_len,
        uses_true_division=True,
        funcname=None,
    ):
        super().__init__(source, filename, file_lineno, leading_whitespace_len)
        self.uses_true_division = uses_true_division
        self.filename = filename
        self.funcname = funcname


@functools.lru_cache(maxsize=None)
def make_source_context(*args):
    """Build a SourceContext, memoized on the positional arguments.

    Repeated calls with equal arguments return the same object. The cache is
    unbounded, and all arguments must be hashable.
    """
    return SourceContext(*args)


def fake_range():
    """Return a raw range from an empty, file-less SourceContext.

    The context has empty source, no filename and line 0. The range spans
    offsets 0 to 1 of that context.
    """
    return SourceContext("", None, 0, 0).make_raw_range(0, 1)


class ParsedDef(NamedTuple):
    """Result of ``parse_def``."""

    # Module returned by ast.parse on the dedented source; its body holds
    # exactly one FunctionDef.
    ast: ast.Module
    # Memoized SourceContext built from the (non-dedented) source.
    ctx: SourceContext
    # Normalized source text (aligned by normalize_source_lines, not dedented).
    source: str
    # File the source came from, as reported by inspect (may be None).
    filename: Optional[str]
    # Starting line number reported by inspect.getsourcelines.
    file_lineno: int


def parse_def(fn):
    """Parse the source of ``fn`` into a ParsedDef.

    The source is fetched with ``get_source_lines_and_file``, aligned with
    ``normalize_source_lines``, dedented, and parsed with ``ast.parse``.

    Raises:
        OSError: if the source cannot be retrieved.
        RuntimeError: if the parsed module is not exactly one top-level
            ``def`` (e.g. for a lambda).
    """
    sourcelines, file_lineno, filename = get_source_lines_and_file(
        fn, ErrorReport.call_stack()
    )
    sourcelines = normalize_source_lines(sourcelines)
    source = "".join(sourcelines)
    dedent_src = dedent(source)
    py_ast = ast.parse(dedent_src)
    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
        raise RuntimeError(
            f"Expected a single top-level function: {filename}:{file_lineno}"
        )
    # Number of indentation characters dedent() stripped, measured on the
    # first line. This works because normalize_source_lines gave every line a
    # common prefix.
    leading_whitespace_len = len(source.split("\n", 1)[0]) - len(
        dedent_src.split("\n", 1)[0]
    )
    ctx = make_source_context(
        source, filename, file_lineno, leading_whitespace_len, True, fn.__name__
    )
    return ParsedDef(py_ast, ctx, source, filename, file_lineno)