xref: /aosp_15_r20/external/pytorch/torch/jit/_check.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import ast
3import inspect
4import textwrap
5import warnings
6
7import torch
8
9
class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
    """Check the ``__init__`` method of a given ``nn.Module``.

    It ensures that all instance-level attributes can be properly initialized.

    Specifically, we do type inference based on attribute values...even
    if the attribute in question has already been typed using
    Python3-style annotations or ``torch.jit.annotate``. This means that
    setting an instance-level attribute to ``[]`` (for ``List``),
    ``{}`` (for ``Dict``), or ``None`` (for ``Optional``) isn't enough
    information for us to properly initialize that attribute.

    An object of this class can walk a given ``nn.Module``'s AST and
    determine if it meets our requirements or not.

    Known limitations
    1. We can only check the AST nodes for certain constructs; we can't
    ``eval`` arbitrary expressions. This means that function calls,
    class instantiations, and complex expressions that resolve to one of
    the "empty" values specified above will NOT be flagged as
    problematic.
    2. We match on string literals, so if the user decides to use a
    non-standard import (e.g. `from typing import List as foo`), we
    won't catch it.

    Example:
        .. code-block:: python

            class M(torch.nn.Module):
                def fn(self):
                    return []

                def __init__(self) -> None:
                    super().__init__()
                    self.x: List[int] = []

                def forward(self, x: List[int]):
                    self.x = x
                    return 1

        The above code will pass the ``AttributeTypeIsSupportedChecker``
        check since we have a function call in ``__init__``. However,
        it will still fail later with the ``RuntimeError`` "Tried to set
        nonexistent attribute: x. Did you forget to initialize it in
        __init__()?".

    Args:
        nn_module - The instance of ``torch.nn.Module`` whose
            ``__init__`` method we wish to check
    """

    def check(self, nn_module: torch.nn.Module) -> None:
        """Parse the source of ``nn_module.__class__.__init__`` and walk its AST."""
        source_lines = inspect.getsource(nn_module.__class__.__init__)

        # Ignore comments no matter the indentation, but keep `# type:`
        # comments since they carry annotation information
        def is_useless_comment(line):
            line = line.strip()
            return line.startswith("#") and not line.startswith("# type:")

        source_lines = "\n".join(
            [l for l in source_lines.split("\n") if not is_useless_comment(l)]
        )

        # This AST only contains the `__init__` method of the nn.Module
        init_ast = ast.parse(textwrap.dedent(source_lines))

        # Get items annotated in the class body
        self.class_level_annotations = list(nn_module.__annotations__.keys())

        # Flag for later
        self.visiting_class_level_ann = False

        self.visit(init_ast)

    def _is_empty_container(self, node: ast.AST, ann_type: str) -> bool:
        """Return True iff ``node`` is the "empty" literal for ``ann_type``.

        ``ann_type`` is the string name of the annotation's outer type
        (e.g. ``"List"`` for ``List[int]``).
        """
        if ann_type in ("List", "list"):
            # Assigning `[]` to a `List` type gives you a Node where
            # value=List(elts=[], ctx=Load())
            return isinstance(node, ast.List) and not node.elts
        if ann_type in ("Dict", "dict"):
            # Assigning `{}` to a `Dict` type gives you a Node where
            # value=Dict(keys=[], values=[])
            return isinstance(node, ast.Dict) and not node.keys
        if ann_type == "Optional":
            # Assigning `None` to an `Optional` type gives you a
            # Node where value=Constant(value=None, kind=None). Only
            # `None` itself counts as "empty" here; other falsy
            # constants (0, "", False) are legitimate initial values.
            return isinstance(node, ast.Constant) and node.value is None
        # Not one of our "empty container" cases. (Falling through to
        # `return True` here used to cause spurious warnings, e.g. for
        # a non-empty `self.x: list[int] = [1]`, since `visit_AnnAssign`
        # also passes lowercase "list"/"dict" annotation names.)
        return False

    def visit_Assign(self, node):
        """Store assignment state when assigning to a Call Node.

        If we're visiting a Call Node (the right-hand side of an
        assignment statement), we won't be able to check the variable
        that we're assigning to (the left-hand side of an assignment).
        Because of this, we need to store this state in visitAssign.
        (Luckily, we only have to do this if we're assigning to a Call
        Node, i.e. ``torch.jit.annotate``. If we're using normal Python
        annotations, we'll be visiting an AnnAssign Node, which has its
        target built in.)
        """
        try:
            if (
                isinstance(node.value, ast.Call)
                and node.targets[0].attr in self.class_level_annotations
            ):
                self.visiting_class_level_ann = True
        except AttributeError:
            # Target isn't an attribute access (`self.x = ...`); there's
            # nothing for us to check
            return
        self.generic_visit(node)
        self.visiting_class_level_ann = False

    def visit_AnnAssign(self, node):
        """Visit an AnnAssign node in an ``nn.Module``'s ``__init__`` method.

        It checks if it conforms to our attribute annotation rules."""
        # If we have a local variable
        try:
            if node.target.value.id != "self":
                return
        except AttributeError:
            return

        # If we have an attribute that's already been annotated at the
        # class level
        if node.target.attr in self.class_level_annotations:
            return

        # TODO @ansley: add `Union` once landed

        # NB: Even though `Tuple` is a "container", we don't want to
        # check for it here. `Tuple` functions as an type with an
        # "infinite" number of subtypes, in the sense that you can have
        # `Tuple[())]`, `Tuple[T1]`, `Tuple[T2]`, `Tuple[T1, T2]`,
        # `Tuple[T2, T1]` and so on, and none of these subtypes can be
        # used in place of the other. Therefore, assigning an empty
        # tuple in `__init__` CORRECTLY means that that variable
        # cannot be reassigned later to a non-empty tuple. Same
        # deal with `NamedTuple`

        containers = {"List", "list", "Dict", "dict", "Optional"}

        # If we're not evaluating one of the specified problem types
        try:
            if node.annotation.value.id not in containers:
                return
        except AttributeError:
            # To evaluate a base type (`str`, `int`, etc.), we would
            # have needed to get the name through `node.annotation.id`
            # instead of `node.annotation.value.id`. Seems that we're
            # not evaluating one of our "containers"
            return

        # Check if the assigned variable is empty
        ann_type = node.annotation.value.id
        if not self._is_empty_container(node.value, ann_type):
            return

        warnings.warn(
            "The TorchScript type system doesn't support "
            "instance-level annotations on empty non-base "
            "types in `__init__`. Instead, either 1) use a "
            "type annotation in the class body, or 2) wrap "
            "the type in `torch.jit.Attribute`."
        )

    def visit_Call(self, node):
        """Determine if a Call node is 'torch.jit.annotate' in __init__.

        Visit a Call node in an ``nn.Module``'s ``__init__``
        method and determine if it's ``torch.jit.annotate``. If so,
        see if it conforms to our attribute annotation rules.
        """
        # If we have an attribute that's already been annotated at the
        # class level
        if self.visiting_class_level_ann:
            return

        # If this isn't a call to `torch.jit.annotate`, just visit the
        # children (there may be nested calls) and stop: the argument
        # checks below only make sense for `torch.jit.annotate`.
        # (Previously, control fell through here, so any two-argument
        # call such as `foo(List[int], [])` could warn spuriously.)
        try:
            if (
                node.func.value.value.id != "torch"
                or node.func.value.attr != "jit"
                or node.func.attr != "annotate"
            ):
                self.generic_visit(node)
                return
        except AttributeError:
            # Looks like we didn't even have the right node structure
            # to check for `torch.jit.annotate` in the first place
            self.generic_visit(node)
            return

        # Invariant: we have a `torch.jit.annotate` call. Still visit
        # its children, since the value argument may itself contain
        # nested calls
        self.generic_visit(node)

        # A Call Node for `torch.jit.annotate` should have an `args`
        # list of length 2 where args[0] represents the annotation and
        # args[1] represents the actual value
        if len(node.args) != 2:
            return

        if not isinstance(node.args[0], ast.Subscript):
            return

        # See notes in `visit_AnnAssign` r.e. containers

        containers = {"List", "Dict", "Optional"}

        try:
            ann_type = node.args[0].value.id  # type: ignore[attr-defined]
        except AttributeError:
            return

        if ann_type not in containers:
            return

        # Check if the assigned variable is empty
        if not self._is_empty_container(node.args[1], ann_type):
            return

        warnings.warn(
            "The TorchScript type system doesn't support "
            "instance-level annotations on empty non-base "
            "types in `__init__`. Instead, either 1) use a "
            "type annotation in the class body, or 2) wrap "
            "the type in `torch.jit.Attribute`."
        )
250