xref: /aosp_15_r20/external/pytorch/torch/_sources.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import ast
3import functools
4import inspect
5from textwrap import dedent
6from typing import Any, List, NamedTuple, Optional, Tuple
7
8from torch._C import ErrorReport
9from torch._C._jit_tree_views import SourceRangeFactory
10
11
def get_source_lines_and_file(
    obj: Any,
    error_msg: Optional[str] = None,
) -> Tuple[List[str], int, Optional[str]]:
    """
    Fetch the source lines, starting line number and file name of ``obj``.

    Thin wrapper around ``inspect.getsourcefile`` and
    ``inspect.getsourcelines`` that replaces a missing-source ``OSError``
    with a TorchScript-specific explanation.

    Args:
        obj: the object whose source should be retrieved.
        error_msg: optional extra text; when non-empty it is appended, on
            its own line, to the message of the raised ``OSError``.

    Returns:
        ``(sourcelines, file_lineno, filename)``, where ``filename`` is
        whatever ``inspect.getsourcefile`` returned (possibly ``None``).

    Raises:
        OSError: if ``inspect`` raises ``OSError`` while looking up the
            source. The original error is chained as ``__cause__``. Other
            exceptions from ``inspect`` propagate unchanged.
    """
    try:
        source_file = inspect.getsourcefile(obj)
        lines, first_lineno = inspect.getsourcelines(obj)
    except OSError as err:
        message_parts = [
            f"Can't get source for {obj}. TorchScript requires source access in "
            "order to carry out compilation, make sure original .py files are "
            "available."
        ]
        if error_msg:
            message_parts.append(error_msg)
        raise OSError("\n".join(message_parts)) from err

    return lines, first_lineno, source_file
36
37
def normalize_source_lines(sourcelines: List[str]) -> List[str]:
    """
    Align all lines of a function's source to its `def` indentation.

    This helper function accepts a list of source lines. It finds the
    indentation level of the function definition (the first line whose
    first token is the `def` keyword), then it indents all other lines to a
    point at or greater than that level. This allows for comments and
    continued string literals that are at a lower indentation than the rest
    of the code.

    Args:
        sourcelines: function source code, one entry per line.

    Returns:
        A list of source lines that have been correctly aligned. If no
        `def` line is found (e.g. the source is a lambda), ``sourcelines``
        itself is returned unchanged.
    """

    def remove_prefix(text, prefix):
        # Same as str.removeprefix (Python 3.9+): slice off `prefix` only if present.
        return text[text.startswith(prefix) and len(prefix) :]

    def is_def_line(line):
        # Require the `def` keyword followed by whitespace. A bare
        # startswith("def") also matches identifiers such as `default=...`
        # in a multi-line decorator call. That would select the wrong line
        # and therefore the wrong indentation, and the re-indented source
        # would then fail to parse. `async def` lines are not matched.
        return line.lstrip().startswith(("def ", "def\t"))

    # Find the index of the line containing the function definition.
    idx = next(
        (i for i, line in enumerate(sourcelines) if is_def_line(line)), None
    )

    # This will happen when the function is a lambda- we won't find "def" anywhere in the source
    # lines in that case. Currently trying to JIT compile a lambda will throw an error up in
    # `parse_def()`, but we might want to handle this case in the future.
    if idx is None:
        return sourcelines

    # Exact leading whitespace of the `def` line.
    fn_def = sourcelines[idx]
    whitespace = fn_def[: len(fn_def) - len(fn_def.lstrip())]

    # Add this leading whitespace to all lines before and after the `def`.
    # Lines already starting with it keep their original indentation.
    aligned_prefix = [
        whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]
    ]
    aligned_suffix = [
        whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1 :]
    ]

    # Put it together again, with the `def` line itself untouched.
    return aligned_prefix + [fn_def] + aligned_suffix
83
84
class SourceContext(SourceRangeFactory):
    """
    Thin wrapper around ``SourceRangeFactory`` that also stores extra
    metadata about the function-to-be-compiled.

    Attributes set here, in addition to whatever the base class provides:
        uses_true_division: the value passed in (default ``True``).
        filename: the ``filename`` passed to the constructor.
        funcname: the ``funcname`` passed in (default ``None``).
    """

    def __init__(
        self,
        source,
        filename,
        file_lineno,
        leading_whitespace_len,
        uses_true_division=True,
        funcname=None,
    ):
        # Only the four range-related arguments go to the base class, positionally.
        super().__init__(source, filename, file_lineno, leading_whitespace_len)
        # Python-side metadata kept alongside the range factory.
        self.funcname = funcname
        self.filename = filename
        self.uses_true_division = uses_true_division
101
102
@functools.lru_cache(maxsize=None)
def make_source_context(*ctx_args):
    """
    Memoized constructor for ``SourceContext``.

    All arguments are forwarded positionally to ``SourceContext`` and must
    therefore be hashable. Calls with equal argument tuples return the
    very same cached instance. The cache is unbounded (``maxsize=None``), so
    entries live until ``make_source_context.cache_clear()`` is called.
    """
    return SourceContext(*ctx_args)
106
107
def fake_range():
    """
    Return a placeholder source range.

    The range is produced by ``make_raw_range(0, 1)`` on a throwaway
    ``SourceContext`` with empty source, no filename, line number 0 and
    no leading whitespace.
    """
    placeholder_ctx = SourceContext("", None, 0, 0)
    return placeholder_ctx.make_raw_range(0, 1)
110
111
class ParsedDef(NamedTuple):
    """
    Bundle returned by ``parse_def`` for a single function definition.

    Fields are listed in positional order.
    """

    # Module object from ast.parse (parse_def passes the dedented source).
    ast: ast.Module
    # SourceContext describing where the source came from.
    ctx: SourceContext
    # Source text; parse_def stores the normalized, not-yet-dedented text.
    source: str
    # File name reported by inspect; None when it could not be determined.
    filename: Optional[str]
    # Starting line number reported alongside the source lines.
    file_lineno: int
118
119
def parse_def(fn):
    """
    Parse the source of ``fn`` into an AST plus a ``SourceContext``.

    Steps: fetch the source (attaching the current ``ErrorReport`` call
    stack to any missing-source error), align it with
    ``normalize_source_lines``, dedent it, then parse it with ``ast.parse``.

    Args:
        fn: the function object to parse; its ``__name__`` is recorded
            in the created context.

    Returns:
        A ``ParsedDef`` holding the parsed module, the (memoized) source
        context, the normalized source text, its file name and starting
        line number.

    Raises:
        OSError: propagated from ``get_source_lines_and_file`` when the
            source cannot be retrieved.
        RuntimeError: if the parsed module body is not exactly one
            ``ast.FunctionDef``.
        Exceptions raised by ``ast.parse`` (e.g. ``SyntaxError``) are
        not caught here.
    """
    lines, first_lineno, source_file = get_source_lines_and_file(
        fn, ErrorReport.call_stack()
    )
    source = "".join(normalize_source_lines(lines))
    dedented = dedent(source)
    tree = ast.parse(dedented)

    body = tree.body
    is_single_function = len(body) == 1 and isinstance(body[0], ast.FunctionDef)
    if not is_single_function:
        raise RuntimeError(
            f"Expected a single top-level function: {source_file}:{first_lineno}"
        )

    # How many characters dedent() stripped, measured on the first line only.
    original_first_line = source.partition("\n")[0]
    dedented_first_line = dedented.partition("\n")[0]
    leading_whitespace_len = len(original_first_line) - len(dedented_first_line)

    ctx = make_source_context(
        source, source_file, first_lineno, leading_whitespace_len, True, fn.__name__
    )
    return ParsedDef(tree, ctx, source, source_file, first_lineno)
139