# mypy: allow-untyped-defs
import ast
import inspect
import textwrap
import warnings

import torch


class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
    """Check the ``__init__`` method of a given ``nn.Module``.

    It ensures that all instance-level attributes can be properly initialized.

    Specifically, we do type inference based on attribute values, even
    if the attribute in question has already been typed using
    Python3-style annotations or ``torch.jit.annotate``. This means that
    setting an instance-level attribute to ``[]`` (for ``List``),
    ``{}`` (for ``Dict``), or ``None`` (for ``Optional``) isn't enough
    information for us to properly initialize that attribute.

    An object of this class can walk a given ``nn.Module``'s AST and
    determine if it meets our requirements or not.

    Known limitations
    1. We can only check the AST nodes for certain constructs; we can't
    ``eval`` arbitrary expressions. This means that function calls,
    class instantiations, and complex expressions that resolve to one of
    the "empty" values specified above will NOT be flagged as
    problematic.
    2. We match on the bare names of the annotation types (e.g. ``List``),
    so if the user decides to use a non-standard import (e.g.
    ``from typing import List as foo``) or a qualified name (e.g.
    ``typing.List``), we won't catch it.

    Example:
        .. code-block:: python

            class M(torch.nn.Module):
                def fn(self):
                    return []

                def __init__(self) -> None:
                    super().__init__()
                    self.x: List[int] = self.fn()

                def forward(self, x: List[int]):
                    self.x = x
                    return 1

        The above code will pass the ``AttributeTypeIsSupportedChecker``
        check since we have a function call in ``__init__``. However,
        it will still fail later with the ``RuntimeError`` "Tried to set
        nonexistent attribute: x. Did you forget to initialize it in
        __init__()?".

    Args:
        nn_module: The instance of ``torch.nn.Module`` whose
            ``__init__`` method we wish to check
    """

    # Annotation names whose "empty" initial value is flagged. Lowercase
    # ``list``/``dict`` are the builtin-generic spellings (``list[int]``).
    # Both the annotated-assignment path and the ``torch.jit.annotate`` path
    # share this set so they stay consistent.
    _CONTAINERS = frozenset({"List", "list", "Dict", "dict", "Optional"})

    _UNSUPPORTED_MSG = (
        "The TorchScript type system doesn't support "
        "instance-level annotations on empty non-base "
        "types in `__init__`. Instead, either 1) use a "
        "type annotation in the class body, or 2) wrap "
        "the type in `torch.jit.Attribute`."
    )

    def check(self, nn_module: torch.nn.Module) -> None:
        """Walk the ``__init__`` of ``nn_module``'s class and warn on empty inits.

        Args:
            nn_module: The module instance whose class's ``__init__`` source
                is parsed and visited.
        """
        source = inspect.getsource(nn_module.__class__.__init__)

        # Ignore comment-only lines no matter their indentation: a comment
        # indented less than the surrounding code would otherwise shrink the
        # common prefix removed by `textwrap.dedent` and break parsing.
        # `# type:` comments are kept.
        def is_useless_comment(line):
            stripped = line.strip()
            return stripped.startswith("#") and not stripped.startswith("# type:")

        source = "\n".join(
            line for line in source.split("\n") if not is_useless_comment(line)
        )

        # This AST only contains the `__init__` method of the nn.Module
        init_ast = ast.parse(textwrap.dedent(source))

        # Get items annotated in the class body (empty if none can be found)
        self.class_level_annotations = list(
            getattr(nn_module, "__annotations__", {}).keys()
        )

        # Set by `visit_Assign` while visiting the right-hand side of an
        # assignment to a class-level annotated attribute.
        self.visiting_class_level_ann = False

        self.visit(init_ast)

    def _warn_unsupported(self) -> None:
        """Emit the warning for an empty container initialization."""
        warnings.warn(self._UNSUPPORTED_MSG)

    @staticmethod
    def _container_name(annotation):
        """Return ``X`` for an annotation ``X[...]`` if ``X`` is a checked container.

        Returns ``None`` for anything else (bare names such as ``int``,
        qualified names such as ``typing.List[int]``, string annotations, ...).
        """
        if isinstance(annotation, ast.Subscript) and isinstance(
            annotation.value, ast.Name
        ):
            name = annotation.value.id
            if name in AttributeTypeIsSupportedChecker._CONTAINERS:
                return name
        return None

    @staticmethod
    def _is_jit_annotate(func) -> bool:
        """Return True if the callee ``func`` is spelled ``torch.jit.annotate``
        or ``jit.annotate``."""
        if not isinstance(func, ast.Attribute) or func.attr != "annotate":
            return False
        owner = func.value
        if isinstance(owner, ast.Name):
            # `from torch import jit; jit.annotate(...)`
            return owner.id == "jit"
        # `torch.jit.annotate(...)`
        return (
            isinstance(owner, ast.Attribute)
            and owner.attr == "jit"
            and isinstance(owner.value, ast.Name)
            and owner.value.id == "torch"
        )

    def _is_empty_container(self, node: ast.AST, ann_type: str) -> bool:
        """Return whether ``node`` is the "empty" literal for ``ann_type``.

        The empty literals are ``[]`` for ``List``/``list``, ``{}`` for
        ``Dict``/``dict`` and ``None`` for ``Optional``. ``node`` may be
        ``None`` (an annotation with no assigned value), which is never
        considered empty. Any other ``ann_type`` yields ``False``.
        """
        if ann_type in ("List", "list"):
            # Assigning `[]` gives value=List(elts=[], ctx=Load())
            return isinstance(node, ast.List) and not node.elts
        if ann_type in ("Dict", "dict"):
            # Assigning `{}` gives value=Dict(keys=[], values=[])
            return isinstance(node, ast.Dict) and not node.keys
        if ann_type == "Optional":
            # Assigning `None` gives value=Constant(value=None). Compare by
            # identity so falsy constants like `0`, `False` or `""` don't count.
            return isinstance(node, ast.Constant) and node.value is None
        return False

    def visit_Assign(self, node):
        """Store assignment state when assigning to a Call Node.

        If we're visiting a Call Node (the right-hand side of an
        assignment statement), we won't be able to check the variable
        that we're assigning to (the left-hand side of an assignment).
        Because of this, we need to store this state in ``visit_Assign``.
        (Luckily, we only have to do this if we're assigning to a Call
        Node, i.e. ``torch.jit.annotate``. If we're using normal Python
        annotations, we'll be visiting an AnnAssign Node, which has its
        target built in.)
        """
        try:
            targets_class_level_ann = (
                isinstance(node.value, ast.Call)
                and node.targets[0].attr in self.class_level_annotations
            )
        except AttributeError:
            # The first target has no `.attr` (e.g. `x = f()` or tuple
            # unpacking), so this assignment is skipped entirely.
            return
        if targets_class_level_ann:
            self.visiting_class_level_ann = True
        self.generic_visit(node)
        self.visiting_class_level_ann = False

    def visit_AnnAssign(self, node):
        """Visit an AnnAssign node in an ``nn.Module``'s ``__init__`` method.

        It checks if it conforms to our attribute annotation rules."""
        target = node.target
        # Only `self.<attr>: T = ...` declares an instance attribute; local
        # variables, subscripts and other objects' attributes are skipped.
        if not (
            isinstance(target, ast.Attribute)
            and isinstance(target.value, ast.Name)
            and target.value.id == "self"
        ):
            return

        # If we have an attribute that's already been annotated at the
        # class level
        if target.attr in self.class_level_annotations:
            return

        # TODO @ansley: add `Union` once landed

        # NB: Even though `Tuple` is a "container", we don't want to
        # check for it here. `Tuple` functions as a type with an
        # "infinite" number of subtypes, in the sense that you can have
        # `Tuple[())]`, `Tuple[T1]`, `Tuple[T2]`, `Tuple[T1, T2]`,
        # `Tuple[T2, T1]` and so on, and none of these subtypes can be
        # used in place of the other. Therefore, assigning an empty
        # tuple in `__init__` CORRECTLY means that that variable
        # cannot be reassigned later to a non-empty tuple. Same
        # deal with `NamedTuple`

        # If we're not evaluating one of the specified problem types
        ann_type = self._container_name(node.annotation)
        if ann_type is None:
            return

        # Check if the assigned variable is empty
        if self._is_empty_container(node.value, ann_type):
            self._warn_unsupported()

    def visit_Call(self, node):
        """Determine if a Call node is 'torch.jit.annotate' in __init__.

        Visit a Call node in an ``nn.Module``'s ``__init__``
        method and determine if it's ``torch.jit.annotate`` (or
        ``jit.annotate``). If so, see if it conforms to our attribute
        annotation rules.
        """
        # If we have an attribute that's already been annotated at the
        # class level
        if self.visiting_class_level_ann:
            return

        # Always descend so nested calls such as
        # `foo(torch.jit.annotate(List[int], []))` are checked as well.
        self.generic_visit(node)

        # Only `torch.jit.annotate` / `jit.annotate` calls are checked below.
        if not self._is_jit_annotate(node.func):
            return

        # A Call Node for `torch.jit.annotate` should have an `args`
        # list of length 2 where args[0] represents the annotation and
        # args[1] represents the actual value
        if len(node.args) != 2:
            return

        # See notes in `visit_AnnAssign` r.e. containers
        ann_type = self._container_name(node.args[0])
        if ann_type is None:
            return

        # Check if the assigned variable is empty
        if self._is_empty_container(node.args[1], ann_type):
            self._warn_unsupported()