1# mypy: allow-untyped-defs 2from collections import defaultdict 3from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name 4import torch.utils._pytree as pytree 5from . import _pytree as fx_pytree 6from ._compatibility import compatibility 7from torch._C import _NodeIter 8 9import os 10import contextlib 11from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type, Iterable 12from dataclasses import dataclass 13from contextlib import contextmanager 14import copy 15import enum 16import torch 17import keyword 18import re 19import builtins 20import math 21import warnings 22import inspect 23 24__all__ = ["PythonCode", "CodeGen", "Graph"] 25 26if TYPE_CHECKING: 27 from .graph_module import GraphModule # noqa: F401 28 from ._symbolic_trace import Tracer # noqa: F401 29 30 31# Mapping of builtins to their `typing` equivalent. 32_origin_type_map = { 33 list: List, 34 dict: Dict, 35 set: Set, 36 frozenset: FrozenSet, 37 tuple: Tuple, 38} 39 40 41# Signature for functions thattransforms the body (`list[str]`) of the 42# generated code 43TransformCodeFunc = Callable[[List[str]], List[str]] 44 45 46class _CustomBuiltin(NamedTuple): 47 """Additional objs that we add to every graph's globals. 48 49 The repr() for some standard library objects is not valid Python code without 50 an import. For common objects of this sort, we bundle them in the globals of 51 every FX graph. 52 """ 53 # How to import this object from the standard library. 54 import_str: str 55 # The actual object, produced from that import string. 56 obj: Any 57 58_custom_builtins: Dict[str, _CustomBuiltin] = {} 59 60 61def _register_custom_builtin(name: str, import_str: str, obj: Any): 62 _custom_builtins[name] = _CustomBuiltin(import_str, obj) 63 64 65_register_custom_builtin('inf', 'from math import inf', math.inf) 66_register_custom_builtin('nan', 'from math import nan', math.nan) 67_register_custom_builtin('NoneType', 'NoneType = type(None)', type(None)) 68_register_custom_builtin('torch', 'import torch', torch) 69_register_custom_builtin('device', 'from torch import device', torch.device) 70_register_custom_builtin('fx_pytree', 'import torch.fx._pytree as fx_pytree', fx_pytree) 71_register_custom_builtin('pytree', 'import torch.utils._pytree as pytree', pytree) 72 73 74def _is_magic(x: str) -> bool: 75 return x.startswith('__') and x.endswith('__') 76 77 78def _snake_case(s: str) -> str: 79 """ 80 Transforms the given string ``s`` to a Python-style variable name 81 82 Examples: 83 ``mod.snake_case`` -> ``mod.snake_case`` 84 ``mod.pascalCase``-> ``mod.pascal_case`` 85 ``mod.ALL_CAPS`` -> ``mod.all_caps`` 86 """ 87 chars = [] 88 prev_lower = False 89 for c in s: 90 if prev_lower and c.isupper(): 91 chars.append('_') 92 chars.append(c.lower()) 93 prev_lower = c.islower() 94 return ''.join(chars) 95 96 97def _is_from_torch(obj: Any) -> bool: 98 module_name = getattr(obj, '__module__', None) 99 if module_name is not None: 100 base_module = module_name.partition('.')[0] 101 return ( 102 base_module == 'torch' and 103 not module_name.startswith("torch._dynamo.") and 104 not module_name.startswith("torch._inductor.") 105 ) 106 107 name = getattr(obj, '__name__', None) 108 # exclude torch because torch.torch.torch.torch works. idk mang 109 if name is not None and name != 'torch': 110 for guess in [torch, torch.nn.functional]: 111 if getattr(guess, name, None) is obj: 112 return True 113 114 return False 115 116 117class _Namespace: 118 """A context for associating names uniquely with objects. 119 120 The following invariants are enforced: 121 - Each object gets a single name. 122 - Each name is unique within a given namespace. 123 - Names generated do not shadow builtins, unless the object is indeed that builtin. 124 """ 125 def __init__(self): 126 self._obj_to_name: Dict[Any, str] = {} 127 self._unassociated_names = set() 128 self._used_names: Set[str] = set() 129 self._base_count: Dict[str, int] = defaultdict(int) 130 131 self._illegal_char_regex = re.compile('[^0-9a-zA-Z_]+') 132 self._name_suffix_regex = re.compile(r"(.*)_(\d+)$") 133 134 def create_name(self, candidate: str, obj: Optional[Any]) -> str: 135 """Create a unique name. 136 137 Arguments: 138 candidate: used as the basis for the unique name, relevant to the user. 139 obj: If not None, an object that will be associated with the unique name. 140 """ 141 if obj is not None and obj in self._obj_to_name: 142 return self._obj_to_name[obj] 143 144 # delete all characters that are illegal in a Python identifier 145 candidate = self._illegal_char_regex.sub('_', candidate) 146 147 if not candidate: 148 candidate = '_unnamed' 149 150 if candidate[0].isdigit(): 151 candidate = f'_{candidate}' 152 153 match = self._name_suffix_regex.match(candidate) 154 if match is None: 155 base = candidate 156 num = None 157 else: 158 base, num_str = match.group(1, 2) 159 num = int(num_str) 160 161 candidate = base if num is None else f'{base}_{num}' 162 if not num: 163 num = self._base_count[base] 164 165 while candidate in self._used_names or self._is_illegal_name(candidate, obj): 166 num += 1 167 candidate = f'{base}_{num}' 168 169 self._used_names.add(candidate) 170 self._base_count[base] = num 171 if obj is None: 172 self._unassociated_names.add(candidate) 173 else: 174 self._obj_to_name[obj] = candidate 175 return candidate 176 177 def associate_name_with_obj(self, name: str, obj: Any): 178 """Associate a unique name with an object. 179 180 Neither `name` nor `obj` should be associated already. 181 """ 182 assert obj not in self._obj_to_name 183 assert name in self._unassociated_names 184 self._obj_to_name[obj] = name 185 self._unassociated_names.remove(name) 186 187 def _is_illegal_name(self, name: str, obj: Any) -> bool: 188 # 1. keywords are never allowed as names. 189 if name in keyword.kwlist: 190 return True 191 192 # 2. Can't shadow a builtin name, unless you *are* that builtin. 193 if name in builtins.__dict__: 194 return obj is not builtins.__dict__[name] 195 196 # 3. Can't shadow our custom builtins either 197 if name in _custom_builtins: 198 return obj is not _custom_builtins[name].obj 199 200 return False 201 202 def _rename_object(self, obj: Any, name: str): 203 assert obj in self._obj_to_name 204 self._obj_to_name[obj] = name 205 self._used_names.add(name) 206 207dtype_abbrs = { 208 torch.bfloat16: 'bf16', 209 torch.float64: 'f64', 210 torch.float32: 'f32', 211 torch.float16: 'f16', 212 torch.float8_e4m3fn: 'f8e4m3fn', 213 torch.float8_e5m2: 'f8e5m2', 214 torch.float8_e4m3fnuz: 'f8e4m3fnuz', 215 torch.float8_e5m2fnuz: 'f8e5m2fnuz', 216 torch.complex32: 'c32', 217 torch.complex64: 'c64', 218 torch.complex128: 'c128', 219 torch.int8: 'i8', 220 torch.int16: 'i16', 221 torch.int32: 'i32', 222 torch.int64: 'i64', 223 torch.bool: 'b8', 224 torch.uint8: 'u8', 225 torch.uint16: 'u16', 226 torch.uint32: 'u32', 227 torch.uint64: 'u64', 228 torch.bits16: 'b16', 229} 230 231@compatibility(is_backward_compatible=True) 232@dataclass 233class PythonCode: 234 """ 235 Represents all the information necessary to exec or save a graph as Python code. 236 """ 237 # Python source code for the forward function definition. 238 src: str 239 # Values in global scope during execution of `src_def`. 240 globals: Dict[str, Any] 241 # Optional mapping from the forward function's line number to 242 # node index. 243 _lineno_map: Optional[Dict[int, Optional[int]]] 244 245 246def _format_target(base: str, target: str) -> str: 247 elems = target.split('.') 248 r = base 249 for e in elems: 250 if not e.isidentifier(): 251 r = f'getattr({r}, "{e}")' 252 else: 253 r = f'{r}.{e}' 254 return r 255 256class _InsertPoint: 257 def __init__(self, graph, new_insert): 258 self.graph = graph 259 self.orig_insert, graph._insert = graph._insert, new_insert 260 261 def __enter__(self): 262 pass 263 264 def __exit__(self, type, value, tb): 265 self.graph._insert = self.orig_insert 266 267class _node_list: 268 def __init__(self, graph: 'Graph', direction: str = '_next'): 269 assert direction in ['_next', '_prev'] 270 self.graph = graph 271 self.direction = direction 272 273 def __len__(self): 274 return self.graph._len 275 276 def __iter__(self): 277 assert self.direction == "_prev" or self.direction == "_next" 278 yield from _NodeIter(self.graph._root, self.direction == "_prev") 279 280 def __reversed__(self): 281 return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev') 282 283class _PyTreeInfo(NamedTuple): 284 """ 285 Contains extra info stored when we're using Pytrees 286 """ 287 orig_args: List[str] 288 in_spec: pytree.TreeSpec 289 out_spec: Optional[pytree.TreeSpec] 290 291@dataclass(frozen=True) 292class _ParsedStackTrace: 293 """ 294 Represents the top-most frame of a parsed stack trace 295 """ 296 file: str 297 lineno: str 298 name: str 299 code: str 300 301 def get_summary_str(self): 302 return f'File: {self.file}:{self.lineno} in {self.name}, code: {self.code}' 303 304# get File:lineno code from stack_trace 305def _parse_stack_trace(stack_trace: str): 306 if stack_trace is None: 307 return None 308 pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$") 309 lines = stack_trace.strip().split('\n') 310 # stacktrace should have innermost frame last, so we 311 # iterate backwards to find the first line that starts 312 # with 'File ' 313 summary_str = "" 314 for idx in range(len(lines) - 2, -1, -1): 315 line = lines[idx].strip() 316 matches = pattern.match(line) 317 if matches: 318 file = matches.group(1) 319 lineno = matches.group(2) 320 name = matches.group(3) 321 # next line should be the code 322 code = lines[idx + 1].strip() 323 return _ParsedStackTrace(file, lineno, name, code) 324 return None 325 326@compatibility(is_backward_compatible=False) 327class CodeGen: 328 def __init__(self): 329 self._body_transformer: Optional[TransformCodeFunc] = None 330 self._func_name: str = "forward" 331 332 def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str: 333 """ 334 Given the free variables and a return annotation, generates the beginning of the FX function. 335 By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'` 336 """ 337 # If the original function didn't have self as its first argument, we 338 # would have added it. 339 if len(free_vars) == 0 or free_vars[0] != 'self': 340 free_vars.insert(0, 'self') 341 return f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:" 342 343 def generate_output(self, output_args: Argument) -> str: 344 """ 345 Given the output arguments, generates the return statement of the FX function. 346 Note: The returned statement should not be indented. 347 """ 348 return f'return {repr(output_args)}' 349 350 def process_inputs(self, *args: Any) -> Any: 351 """ 352 Transforms the inputs so that the graph can take them as arguments, as 353 non-default codegen may result in the inputs to the function being 354 different from the inputs to the graph. 355 356 If the graph was directly runnable, this invariant should hold true 357 `f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)` 358 """ 359 return args 360 361 def process_outputs(self, outputs: Any) -> Any: 362 """ 363 Transforms the outputs of the graph to be identical to the codegen. 364 365 See ``process_inputs`` for more details. 366 """ 367 return outputs 368 369 def additional_globals(self) -> List[Tuple[str, Any]]: 370 """ 371 If your codegen uses extra global values, add tuples of (identifier,reference to the value) here. 372 For example, return ['List', typing.List] if you need ``List`` in the global context. 373 """ 374 return [] 375 376 def _gen_python_code( 377 self, nodes, root_module: str, namespace: _Namespace, *, 378 verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False 379 ) -> PythonCode: 380 free_vars: List[str] = [] 381 body: List[str] = [] 382 globals_: Dict[str, Any] = {} 383 wrapped_fns: Dict[str, None] = {} 384 385 # Wrap string in list to pass by reference 386 maybe_return_annotation : List[str] = [''] 387 include_stride = include_stride or (os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1") 388 include_device = include_device or (os.environ.get("FX_GRAPH_SHOW_DEVICE", "0") == "1") 389 390 def add_global(name_hint: str, obj: Any): 391 """Add an obj to be tracked as a global. 392 393 We call this for names that reference objects external to the 394 Graph, like functions or types. 395 396 Returns: the global name that should be used to reference 'obj' in generated source. 397 """ 398 if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device 399 # HACK: workaround for how torch custom ops are registered. We 400 # can't import them like normal modules so they must retain their 401 # fully qualified name. 402 return _get_qualified_name(obj) 403 404 # normalize the name hint to get a proper identifier 405 global_name = namespace.create_name(name_hint, obj) 406 407 if global_name in globals_: 408 assert globals_[global_name] is obj 409 return global_name 410 globals_[global_name] = obj 411 return global_name 412 413 # Pre-fill the globals table with registered builtins. 414 for name, (_, obj) in _custom_builtins.items(): 415 add_global(name, obj) 416 417 def type_repr(o : Any): 418 if o == (): 419 # Empty tuple is used for empty tuple type annotation Tuple[()] 420 return '()' 421 422 typename = _type_repr(o) 423 424 if hasattr(o, '__origin__'): 425 # This is a generic type, e.g. typing.List[torch.Tensor] 426 origin_type = _origin_type_map.get(o.__origin__, o.__origin__) 427 origin_typename = add_global(_type_repr(origin_type), origin_type) 428 429 if hasattr(o, '__args__'): 430 # Assign global names for each of the inner type variables. 431 args = [type_repr(arg) for arg in o.__args__] 432 433 if len(args) == 0: 434 # Bare type, such as `typing.Tuple` with no subscript 435 # This code-path used in Python < 3.9 436 return origin_typename 437 438 return f'{origin_typename}[{",".join(args)}]' 439 else: 440 # Bare type, such as `typing.Tuple` with no subscript 441 # This code-path used in Python 3.9+ 442 return origin_typename 443 444 # Common case: this is a regular module name like 'foo.bar.baz' 445 return add_global(typename, o) 446 447 codes = { 448 "yellow": "\033[33m", 449 "cyan": "\033[36m", 450 "green": "\033[32m", 451 "blue": "\033[34m", 452 "red": "\033[31m", 453 "dim": "\033[2m", 454 "dim_blue": "\033[2m\033[34m", 455 "dim_green": "\033[2m\033[32m", 456 "reset": "\033[0m", 457 } 458 459 def make_wrapper_func(name): 460 def f(s): 461 if colored: 462 return f"{codes[name]}{s}{codes['reset']}" 463 return s 464 return f 465 466 yellow = make_wrapper_func("yellow") 467 cyan = make_wrapper_func("cyan") 468 red = make_wrapper_func("red") 469 green = make_wrapper_func("green") 470 dim_green = make_wrapper_func("dim_green") 471 dim = make_wrapper_func("dim") 472 dim_blue = make_wrapper_func("dim_blue") 473 blue = make_wrapper_func("blue") 474 475 def _get_repr(arg: Any) -> str: 476 # Handle NamedTuples (if it has `_fields`) via add_global. 477 if isinstance(arg, tuple) and hasattr(arg, '_fields'): 478 qualified_name = _get_qualified_name(type(arg)) 479 global_name = add_global(qualified_name, type(arg)) 480 return f"{global_name}{repr(tuple(arg))}" 481 elif isinstance(arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): 482 qualified_name = _get_qualified_name(arg) 483 global_name = add_global(qualified_name, arg) 484 return f"{global_name}" 485 elif isinstance(arg, enum.Enum): 486 cls = arg.__class__ 487 clsname = add_global(cls.__name__, cls) 488 return f"{clsname}.{arg.name}" 489 elif isinstance(arg, Node): 490 return repr(arg) 491 elif isinstance(arg, torch.Tensor): 492 size = list(arg.size()) 493 dtype = str(arg.dtype).split(".")[-1] 494 return f"torch.Tensor(size={size}, dtype={dtype})" 495 else: 496 return blue(repr(arg)) 497 498 499 def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: 500 args_s = ', '.join(_get_repr(a) for a in args) 501 kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) 502 if args_s and kwargs_s: 503 return f'{args_s}, {kwargs_s}' 504 return args_s or kwargs_s 505 506 # Run through reverse nodes and record the first instance of a use 507 # of a given node. This represents the *last* use of the node in the 508 # execution order of the program, which we will use to free unused 509 # values 510 node_to_last_use : Dict[Node, Node] = {} 511 user_to_last_uses : Dict[Node, List[Node]] = {} 512 513 def register_last_uses(n : Node, user : Node): 514 if n not in node_to_last_use: 515 node_to_last_use[n] = user 516 user_to_last_uses.setdefault(user, []).append(n) 517 518 for node in reversed(nodes): 519 map_arg(node.args, lambda n: register_last_uses(n, node)) 520 map_arg(node.kwargs, lambda n: register_last_uses(n, node)) 521 522 def delete_unused_values(user : Node): 523 """ 524 Delete values after their last use. This ensures that values that are 525 not used in the remainder of the code are freed and the memory usage 526 of the code is optimal. 527 """ 528 if user.op == 'placeholder': 529 return 530 if user.op == 'output': 531 body.append('\n') 532 return 533 nodes_to_delete = user_to_last_uses.get(user, []) 534 535 if len(user.users.keys()) == 0: 536 # This node is not used by any others. however it's also not 537 # removed by DCE since side-effect. We want to free it's outputs 538 # right after its execution done to save memory. 539 nodes_to_delete.append(user) 540 541 if len(nodes_to_delete): 542 to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) 543 body.append(f'; {dim(to_delete_str)}\n') 544 else: 545 body.append('\n') 546 547 prev_stacktrace = None 548 549 def append_stacktrace_summary(node : Node): 550 """ 551 Append a summary of the stacktrace to the generated code. This is 552 useful for debugging. 553 """ 554 nonlocal prev_stacktrace 555 556 if node.op not in {'placeholder', 'output'}: 557 if node.stack_trace: 558 if node.stack_trace != prev_stacktrace: 559 prev_stacktrace = node.stack_trace 560 summary_str = "" 561 562 if parsed_stack_trace := _parse_stack_trace(node.stack_trace): 563 summary_str = parsed_stack_trace.get_summary_str() 564 565 body.append(f'\n {dim("# " + summary_str)}\n') 566 elif prev_stacktrace != "": 567 prev_stacktrace = "" 568 no_stacktrace_msg = "# No stacktrace found for following nodes" 569 body.append(f'\n{dim(no_stacktrace_msg)}\n') 570 571 def stringify_shape(shape : Iterable) -> str: 572 return f"[{', '.join(str(x) for x in shape)}]" 573 574 def emit_node(node : Node): 575 maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' 576 577 if verbose: 578 # override annotation with more detailed information 579 from torch.fx.experimental.proxy_tensor import py_sym_types 580 from torch.fx.passes.shape_prop import TensorMetadata 581 582 meta_val = node.meta.get('val', node.meta.get('tensor_meta', node.meta.get('example_value', None))) 583 # use string as annotation, to make it valid python code 584 585 if isinstance(meta_val, torch.Tensor): 586 stride_annotation = f"{stringify_shape(meta_val.stride())}" if include_stride else "" 587 device_annotation = f"{meta_val.device}" if include_device else "" 588 maybe_type_annotation = \ 589 f': "{red(dtype_abbrs[meta_val.dtype])}{blue(stringify_shape(meta_val.shape))}' \ 590 f'{dim_blue(stride_annotation)}{dim_green(device_annotation)}"' 591 elif isinstance(meta_val, py_sym_types): 592 maybe_type_annotation = f': "Sym({meta_val})"' 593 elif isinstance(meta_val, TensorMetadata): 594 maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"' 595 596 if node.op == 'placeholder': 597 assert isinstance(node.target, str) 598 maybe_default_arg = '' if not node.args else f' = {_get_repr(node.args[0])}' 599 free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') 600 raw_name = node.target.replace('*', '') 601 if raw_name != repr(node): 602 body.append(f'{repr(node)} = {raw_name}\n') 603 return 604 elif node.op == 'call_method': 605 assert isinstance(node.target, str) 606 body.append( 607 f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}' 608 f'({_format_args(node.args[1:], node.kwargs)})') 609 return 610 elif node.op == 'call_function': 611 assert callable(node.target) 612 # pretty print operators 613 if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in magic_methods: 614 assert isinstance(node.args, tuple) 615 body.append(f'{repr(node)}{maybe_type_annotation} = ' 616 f'{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}') 617 return 618 619 # pretty print inplace operators; required for jit.script to work properly 620 # not currently supported in normal FX graphs, but generated by torchdynamo 621 if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in inplace_methods: 622 body.append(f'{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}; ' 623 f'{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}') 624 return 625 626 qualified_name = _get_qualified_name(node.target) 627 global_name = add_global(qualified_name, node.target) 628 # special case for getattr: node.args could be 2-argument or 3-argument 629 # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value 630 if global_name == 'getattr' and \ 631 isinstance(node.args, tuple) and \ 632 isinstance(node.args[1], str) and \ 633 node.args[1].isidentifier() and \ 634 len(node.args) == 2: 635 body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}') 636 return 637 body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') 638 if node.meta.get('is_wrapped', False): 639 wrapped_fns.setdefault(global_name) 640 return 641 elif node.op == 'call_module': 642 assert isinstance(node.target, str) 643 body.append(f'{repr(node)}{maybe_type_annotation} = ' 644 f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') 645 return 646 elif node.op == 'get_attr': 647 assert isinstance(node.target, str) 648 body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') 649 return 650 elif node.op == 'output': 651 if node.type is not None: 652 maybe_return_annotation[0] = f" -> {type_repr(node.type)}" 653 body.append(self.generate_output(node.args[0])) 654 return 655 raise NotImplementedError(f'node: {node.op} {node.target}') 656 657 for i, node in enumerate(nodes): 658 # NOTE: emit_node does not emit a string with newline. It depends 659 # on delete_unused_values to append one 660 if verbose: 661 append_stacktrace_summary(node) 662 # emit a counter comment to keep track of 663 # node index, which will be deleted later 664 # after going through _body_transformer 665 body.append(f"# COUNTER: {i}\n") 666 emit_node(node) 667 delete_unused_values(node) 668 669 if len(body) == 0: 670 # If the Graph has no non-placeholder nodes, no lines for the body 671 # have been emitted. To continue to have valid Python code, emit a 672 # single pass statement 673 body.append('pass\n') 674 675 676 677 if len(wrapped_fns) > 0: 678 wrap_name = add_global('wrap', torch.fx.wrap) 679 wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) 680 else: 681 wrap_stmts = '' 682 683 if self._body_transformer: 684 body = self._body_transformer(body) 685 686 for name, value in self.additional_globals(): 687 add_global(name, value) 688 689 prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) 690 691 # remove counter and generate lineno to node index mapping 692 lineno_map: Dict[int, Optional[int]] = {} 693 prologue_len = prologue.count('\n') + 1 694 new_lines: List[str] = [] 695 cur_idx = None 696 for line in ''.join(body).split('\n'): 697 counter = re.search(r"# COUNTER: (\d+)", line) 698 if counter and counter.group(1) is not None: 699 cur_idx = int(counter.group(1)) 700 else: 701 lineno_map[len(new_lines) + prologue_len] = cur_idx 702 new_lines.append(line) 703 704 code = "\n".join(new_lines).lstrip('\n') 705 code = '\n'.join(' ' + line for line in code.split('\n')) 706 707 fn_code = f""" 708{wrap_stmts} 709 710{prologue} 711{code}""" 712 return PythonCode(fn_code, globals_, _lineno_map=lineno_map) 713 714 715# Ideally, we'd like to refactor all of the pytree logic into this codegen 716# class. Unfortunately, there are 3 areas we currently need extra logic in FX. 717# 1. In the initial symbolic trace, the pytree logic is tied up with `concrete_args`. 718# 2. In the FX graph, we need to access 2 attributes - in_spec and out_spec. 719# Since we can't access .graph within the FX forward, we need to copy the attribute to the module. 720# 3. We currently can't register the pytree imports with `add_global` - not sure why. 721class _PyTreeCodeGen(CodeGen): 722 def __init__(self, pytree_info: _PyTreeInfo): 723 super().__init__() 724 self.pytree_info: _PyTreeInfo = pytree_info 725 726 def process_inputs(self, *inputs: Any) -> Any: 727 flat_args = pytree.arg_tree_leaves(*inputs) 728 return flat_args 729 730 def process_outputs(self, out: Any) -> Any: 731 if self.pytree_info is None or self.pytree_info.out_spec is None: 732 return out 733 if not isinstance(out, (list, tuple)): 734 out = [out] 735 assert self.pytree_info.out_spec is not None 736 return pytree.tree_unflatten(out, self.pytree_info.out_spec) 737 738 def gen_fn_def(self, free_vars, maybe_return_annotation): 739 # Given a user function/model: 740 # myargs = [myargs0, myargs1] 741 # mykwargs = {'mykwargs0': ..., 'mykwargs1': ...} 742 # def forward(self, mypos, *myargs, mykey=None, **mykwargs): 743 # 744 # The generated code flattens all keywords into positional arguments for `forward()` 745 # e.g forward(self, mypos, myargs0, myargs1, mykey, mykwargs0, mykwargs1): 746 # 747 # Within `forward`, `tree_flatten_spec``still parses args and kwargs separately 748 # e.g. tree_flatten_spec(([mypos, myargs0, myargs1], 749 # {'mykey':mykey, 'mykwargs0':mykwargs0, 'mykwargs1':mykwargs1}), 750 # self._in_spec) 751 # 752 # If the user function/model does not have keywords, the dict is suppressed from tree_flatten_spec 753 # e.g. tree_flatten_spec([mypos, myargs0, myargs1]), self._in_spec) 754 if self.pytree_info is None: 755 return super().gen_fn_def(free_vars, maybe_return_annotation) 756 757 fn_args = self.pytree_info.orig_args 758 has_orig_self = (fn_args[0] == 'self') if len(fn_args) > 0 else False 759 if has_orig_self: 760 free_vars.insert(0, 'self') 761 fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation) 762 763 if len(free_vars) > 0: # pytree has placeholders in it 764 # when kwargs is present, in_spec is tuple(args, kwargs) 765 has_args_kwargs_tuple = self.pytree_info.in_spec.type == tuple and \ 766 self.pytree_info.in_spec.num_children == 2 and \ 767 self.pytree_info.in_spec.children_specs[0].type == tuple and \ 768 self.pytree_info.in_spec.children_specs[1].type == dict 769 fn_kwargs = '{}' 770 fn_signature = f"[{', '.join(fn_args)}], self._in_spec" 771 if has_args_kwargs_tuple: 772 count_args = self.pytree_info.in_spec.children_specs[0].num_children 773 fn_args = self.pytree_info.orig_args[:count_args] 774 fn_kwargs = '{' + ', '.join(f"'{k}':{v}" for k, v in zip( 775 self.pytree_info.in_spec.children_specs[1].context, 776 self.pytree_info.orig_args[count_args:])) + '}' 777 fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec" 778 779 # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid. 780 # we need to split it to two lines: 781 # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon) 782 # one for code: `var1, var2, = function_call()` 783 without_annotation = [x.split(":")[0] for x in free_vars] 784 has_annotation = [x + "; " for x in free_vars if ":" in x] 785 if len(has_annotation) > 0: 786 fn_definition += "\n " + "".join(has_annotation) + "\n" 787 fn_definition += f""" 788 {', '.join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})""" 789 return fn_definition 790 791 def generate_output(self, output_args): 792 if self.pytree_info and self.pytree_info.out_spec: 793 return f'return pytree.tree_unflatten({repr(output_args)}, self._out_spec)' 794 else: 795 return super().generate_output(output_args) 796 797class _FindNodesLookupTable: 798 """ 799 Side table for the graph for the purpose of doing fast queries 800 """ 801 def __init__(self): 802 self.table: Dict[Tuple[str, Optional[Target]], Dict[Node, None]] = defaultdict(dict) 803 804 def _key(self, node) -> Tuple[str, Optional[Target]]: 805 return (node.op, node.target if node.op == "call_function" else None) 806 807 def __contains__(self, node) -> bool: 808 return node in self.table[self._key(node)] 809 810 def insert(self, node: Node) -> None: 811 self.table[self._key(node)][node] = None 812 813 def remove(self, node: Node) -> None: 814 self.table[self._key(node)].pop(node) 815 816 def find_nodes(self, *, op: str, target: Optional['Target'] = None): 817 if op == "call_function": 818 assert target is not None 819 return dict(self.table[(op, target)]).keys() 820 821 if target is None: 822 return dict(self.table[(op, None)]).keys() 823 824 # op is call_method, get_attr, call_module 825 return [node for node in self.table[(op, None)].keys() if node.target == target] 826 827@compatibility(is_backward_compatible=True) 828class Graph: 829 """ 830 ``Graph`` is the main data structure used in the FX Intermediate Representation. 831 It consists of a series of ``Node`` s, each representing callsites (or other 832 syntactic constructs). The list of ``Node`` s, taken together, constitute a 833 valid Python function. 834 835 For example, the following code 836 837 .. code-block:: python 838 839 import torch 840 import torch.fx 841 842 class MyModule(torch.nn.Module): 843 def __init__(self): 844 super().__init__() 845 self.param = torch.nn.Parameter(torch.rand(3, 4)) 846 self.linear = torch.nn.Linear(4, 5) 847 848 def forward(self, x): 849 return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) 850 851 m = MyModule() 852 gm = torch.fx.symbolic_trace(m) 853 854 Will produce the following Graph:: 855 856 print(gm.graph) 857 858 .. code-block:: text 859 860 graph(x): 861 %linear_weight : [num_users=1] = self.linear.weight 862 %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) 863 %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) 864 %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) 865 %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) 866 %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) 867 return topk_1 868 869 For the semantics of operations represented in the ``Graph``, please see :class:`Node`. 870 """ 871 872 @compatibility(is_backward_compatible=True) 873 def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None, 874 tracer_extras: Optional[Dict[str, Any]] = None): 875 """ 876 Construct an empty Graph. 877 """ 878 self._root : Node = Node(self, '', 'root', '', (), {}) 879 self._used_names : Dict[str, int] = {} # base name -> number 880 self._insert = self._root.prepend 881 self._len = 0 882 self._graph_namespace = _Namespace() 883 self._owning_module = owning_module 884 self._tracer_cls = tracer_cls 885 self._tracer_extras = tracer_extras 886 self._codegen = CodeGen() 887 self._co_fields : Dict[str, Any] = {} 888 self._find_nodes_lookup_table = _FindNodesLookupTable() 889 890 @property 891 def owning_module(self): 892 return self._owning_module 893 894 @owning_module.setter 895 def owning_module(self, mod: Optional["GraphModule"]): 896 self._owning_module = mod 897 898 @property 899 def nodes(self) -> _node_list: 900 """ 901 Get the list of Nodes that constitute this Graph. 902 903 Note that this ``Node`` list representation is a doubly-linked list. Mutations 904 during iteration (e.g. delete a Node, add a Node) are safe. 905 906 Returns: 907 908 A doubly-linked list of Nodes. Note that ``reversed`` can be called on 909 this list to switch iteration order. 910 """ 911 return _node_list(self) 912 913 @compatibility(is_backward_compatible=False) 914 def find_nodes(self, *, op: str, target: Optional['Target'] = None, sort: bool = True): 915 """ 916 Allows for fast query of nodes 917 918 Args: 919 920 op (str): the name of the operation 921 922 target (Optional[Target]): the target of the node. For call_function, 923 the target is required. For other ops, the target is optional. 924 925 sort (bool): whether to return nodes in the order they appear on 926 on the graph. 927 928 Returns: 929 930 Iteratable of nodes with the requested op and target. 931 """ 932 node_list = self._find_nodes_lookup_table.find_nodes(op=op, target=target) 933 if sort: 934 return sorted(node_list) 935 return node_list 936 937 @compatibility(is_backward_compatible=True) 938 def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]': 939 """ 940 Copy all nodes from a given graph into ``self``. 941 942 Args: 943 944 g (Graph): The source graph from which to copy Nodes. 945 946 val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping 947 from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed 948 in with values in it already to override copying of certain values. 949 950 Returns: 951 952 The value in ``self`` that is now equivalent to the output value in ``g``, 953 if ``g`` had an ``output`` node. ``None`` otherwise. 954 """ 955 for node in g.nodes: 956 if node in val_map: 957 continue 958 if node.op == 'output': 959 rv = map_arg(node.args[0], lambda n: val_map[n]) 960 return rv if not return_output_node else (rv, node) 961 val_map[node] = self.node_copy(node, lambda n : val_map[n]) 962 return None 963 964 def __deepcopy__(self, memo=None) -> 'Graph': 965 """ 966 Explicitly implement __deepcopy__ to prevent excessive recursion depth 967 from the default implementation. This uses graph_copy to copy the nodes 968 in an iterative way, rather than recursive. It also populates the 969 memoization table to prevent unnecessary copies (e.g. references to 970 nodes or other parts of the Graph from a custom GraphModule implementation. 971 """ 972 memo = memo if memo else {} 973 g = Graph(tracer_cls=self._tracer_cls) 974 output_vals = g.graph_copy(self, val_map=memo, return_output_node=True) 975 g._codegen = copy.deepcopy(self._codegen) 976 assert isinstance(output_vals, tuple) 977 output_val, old_output_node = output_vals 978 new_output_node = g.output(output_val, type_expr=getattr(old_output_node, 'type', None)) 979 new_output_node.meta = copy.copy(old_output_node.meta) 980 return g 981 982 @compatibility(is_backward_compatible=True) 983 def create_node(self, op: str, target: 'Target', 984 args: Optional[Tuple['Argument', ...]] = None, 985 kwargs: Optional[Dict[str, 'Argument']] = None, 986 name: Optional[str] = None, 987 type_expr: Optional[Any] = None) -> Node: 988 """ 989 Create a ``Node`` and add it to the ``Graph`` at the current insert-point. 990 Note that the current insert-point can be set via :meth:`Graph.inserting_before` 991 and :meth:`Graph.inserting_after`. 992 993 Args: 994 op (str): the opcode for this Node. One of 'call_function', 'call_method', 'get_attr', 995 'call_module', 'placeholder', or 'output'. The semantics of these opcodes are 996 described in the ``Graph`` docstring. 997 998 args (Optional[Tuple[Argument, ...]]): is a tuple of arguments to this node. 999 1000 kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node 1001 1002 name (Optional[str]): an optional string name for the ``Node``. 1003 This will influence the name of the value assigned to in the 1004 Python generated code. 1005 1006 type_expr (Optional[Any]): an optional type annotation representing the 1007 Python type the output of this node will have. 1008 1009 Returns: 1010 1011 The newly-created and inserted node. 1012 """ 1013 assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output') 1014 args = () if args is None else args 1015 kwargs = {} if kwargs is None else kwargs 1016 assert isinstance(args, tuple), "args must be a tuple" 1017 assert isinstance(kwargs, dict), "kwargs must be a dict" 1018 1019 candidate = name if name is not None else self._target_to_str(target) 1020 name = self._graph_namespace.create_name(candidate, None) 1021 n = Node(self, name, op, target, args, kwargs, type_expr) 1022 1023 if self.owning_module is not None and getattr(self.owning_module, "_create_node_hooks", None) is not None: 1024 for f in self.owning_module._create_node_hooks: 1025 f(n) 1026 1027 self._graph_namespace.associate_name_with_obj(name, n) 1028 1029 self._insert(n) 1030 self._find_nodes_lookup_table.insert(n) 1031 self._len += 1 1032 return n 1033 1034 @compatibility(is_backward_compatible=False) 1035 def process_inputs(self, *args): 1036 """ 1037 Processes args so that they can be passed to the FX graph. 1038 """ 1039 return self._codegen.process_inputs(*args) 1040 1041 @compatibility(is_backward_compatible=False) 1042 def process_outputs(self, out): 1043 return self._codegen.process_outputs(out) 1044 1045 1046 @compatibility(is_backward_compatible=True) 1047 def erase_node(self, to_erase : Node) -> None: 1048 """ 1049 Erases a ``Node`` from the ``Graph``. Throws an exception if 1050 there are still users of that node in the ``Graph``. 1051 1052 Args: 1053 1054 to_erase (Node): The ``Node`` to erase from the ``Graph``. 1055 """ 1056 if len(to_erase.users) > 0: 1057 raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} ' 1058 f'users in the graph: {to_erase.users}!') 1059 if to_erase.graph != self: 1060 raise RuntimeError(f"Attempting to remove {to_erase} from wrong graph!") 1061 if to_erase._erased: 1062 warnings.warn(f"erase_node({to_erase}) on an already erased node") 1063 return 1064 1065 if self.owning_module is not None and getattr(self.owning_module, "_erase_node_hooks", None) is not None: 1066 for f in self.owning_module._erase_node_hooks: 1067 f(to_erase) 1068 1069 self._find_nodes_lookup_table.remove(to_erase) 1070 to_erase._remove_from_list() 1071 to_erase._erased = True # iterators may retain handles to erased nodes 1072 self._len -= 1 1073 1074 # Null out this Node's argument nodes so that the Nodes referred to 1075 # can update their ``users`` accordingly 1076 new_args = map_arg(to_erase.args, lambda n: None) 1077 assert isinstance(new_args, tuple) 1078 to_erase.args = new_args 1079 new_kwargs = map_arg(to_erase.kwargs, lambda n: None) 1080 assert isinstance(new_kwargs, dict) 1081 to_erase.kwargs = new_kwargs 1082 1083 @compatibility(is_backward_compatible=True) 1084 def inserting_before(self, n: Optional[Node] = None): 1085 """Set the point at which create_node and companion methods will insert into the graph. 1086 When used within a 'with' statement, this will temporary set the insert point and 1087 then restore it when the with statement exits:: 1088 1089 with g.inserting_before(n): 1090 ... # inserting before node n 1091 ... # insert point restored to what it was previously 1092 g.inserting_before(n) # set the insert point permanently 1093 1094 Args: 1095 1096 n (Optional[Node]): The node before which to insert. If None this will insert before 1097 the beginning of the entire graph. 1098 1099 Returns: 1100 A resource manager that will restore the insert point on ``__exit__``. 1101 """ 1102 if n is None: 1103 return self.inserting_after(self._root) 1104 assert n.graph == self, "Node to insert before is not in graph." 1105 return _InsertPoint(self, n.prepend) 1106 1107 @compatibility(is_backward_compatible=True) 1108 def inserting_after(self, n: Optional[Node] = None): 1109 """Set the point at which create_node and companion methods will insert into the graph. 1110 When used within a 'with' statement, this will temporary set the insert point and 1111 then restore it when the with statement exits:: 1112 1113 with g.inserting_after(n): 1114 ... # inserting after node n 1115 ... # insert point restored to what it was previously 1116 g.inserting_after(n) # set the insert point permanently 1117 1118 Args: 1119 1120 n (Optional[Node]): The node before which to insert. If None this will insert after 1121 the beginning of the entire graph. 1122 1123 Returns: 1124 A resource manager that will restore the insert point on ``__exit__``. 1125 """ 1126 if n is None: 1127 return self.inserting_before(self._root) 1128 assert n.graph == self, "Node to insert after is not in graph." 1129 return _InsertPoint(self, n.append) 1130 1131 @compatibility(is_backward_compatible=True) 1132 def placeholder(self, name: str, type_expr: Optional[Any] = None, 1133 default_value : Any = inspect.Signature.empty) -> Node: 1134 """ 1135 Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents 1136 a function input. 1137 1138 Args: 1139 1140 name (str): A name for the input value. This corresponds to the name 1141 of the positional argument to the function this ``Graph`` represents. 1142 1143 type_expr (Optional[Any]): an optional type annotation representing the 1144 Python type the output of this node will have. This is needed in some 1145 cases for proper code generation (e.g. when the function is used 1146 subsequently in TorchScript compilation). 1147 1148 default_value (Any): The default value this function argument should take 1149 on. NOTE: to allow for `None` as a default value, `inspect.Signature.empty` 1150 should be passed as this argument to specify that the parameter does _not_ 1151 have a default value. 1152 1153 .. note:: 1154 The same insertion point and type expression rules apply for this method 1155 as ``Graph.create_node``. 1156 """ 1157 args = () if default_value is inspect.Signature.empty else (default_value,) 1158 return self.create_node('placeholder', name, args=args, type_expr=type_expr) 1159 1160 @compatibility(is_backward_compatible=True) 1161 def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node: 1162 """ 1163 Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the 1164 fetch of an attribute from the ``Module`` hierarchy. 1165 1166 Args: 1167 1168 qualified_name (str): the fully-qualified name of the attribute to be retrieved. 1169 For example, if the traced Module has a submodule named ``foo``, which has a 1170 submodule named ``bar``, which has an attribute named ``baz``, the qualified 1171 name ``foo.bar.baz`` should be passed as ``qualified_name``. 1172 1173 type_expr (Optional[Any]): an optional type annotation representing the 1174 Python type the output of this node will have. 1175 1176 1177 Returns: 1178 1179 The newly-created and inserted ``get_attr`` node. 1180 1181 .. note:: 1182 The same insertion point and type expression rules apply for this method 1183 as ``Graph.create_node``. 1184 """ 1185 def _get_attr_reference_exists(mod: torch.nn.Module, qualified_name: str) -> bool: 1186 module_path, _, name = qualified_name.rpartition(".") 1187 1188 try: 1189 submod: torch.nn.Module = mod.get_submodule(module_path) 1190 except AttributeError: 1191 warnings.warn(f"Failed to fetch module {module_path}!") 1192 return False 1193 1194 if not hasattr(submod, name): 1195 return False 1196 1197 res = getattr(submod, name) 1198 1199 if (not isinstance(res, torch.nn.Module) 1200 and not isinstance(res, torch.nn.Parameter) 1201 and name not in submod._buffers): 1202 return False 1203 1204 return True 1205 1206 if (self.owning_module and 1207 not _get_attr_reference_exists(self.owning_module, qualified_name)): 1208 warnings.warn("Attempted to insert a get_attr Node with no " 1209 "underlying reference in the owning " 1210 "GraphModule! Call " 1211 "GraphModule.add_submodule to add the " 1212 "necessary submodule, " 1213 "GraphModule.add_parameter to add the " 1214 "necessary Parameter, or " 1215 "nn.Module.register_buffer to add the " 1216 "necessary buffer", stacklevel=2) 1217 return self.create_node('get_attr', qualified_name, type_expr=type_expr) 1218 1219 @compatibility(is_backward_compatible=True) 1220 def call_module(self, 1221 module_name: str, 1222 args: Optional[Tuple['Argument', ...]] = None, 1223 kwargs: Optional[Dict[str, 'Argument']] = None, 1224 type_expr: Optional[Any] = None) -> Node: 1225 """ 1226 Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node 1227 represents a call to the forward() function of a ``Module`` in the ``Module`` 1228 hierarchy. 1229 1230 Args: 1231 1232 module_name (str): The qualified name of the ``Module`` in the ``Module`` 1233 hierarchy to be called. For example, if the traced ``Module`` has a 1234 submodule named ``foo``, which has a submodule named ``bar``, the 1235 qualified name ``foo.bar`` should be passed as ``module_name`` to 1236 call that module. 1237 1238 args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed 1239 to the called method. Note that this should *not* include a ``self`` argument. 1240 1241 kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed 1242 to the called method 1243 1244 type_expr (Optional[Any]): an optional type annotation representing the 1245 Python type the output of this node will have. 1246 1247 Returns: 1248 1249 The newly-created and inserted ``call_module`` node. 1250 1251 .. note:: 1252 The same insertion point and type expression rules apply for this method 1253 as :meth:`Graph.create_node`. 1254 """ 1255 if (self.owning_module and 1256 self.owning_module.get_submodule(module_name) is None): 1257 warnings.warn("Attempted to insert a call_module Node with " 1258 "no underlying reference in the owning " 1259 "GraphModule! Call " 1260 "GraphModule.add_submodule to add the " 1261 "necessary submodule") 1262 return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr) 1263 1264 @compatibility(is_backward_compatible=True) 1265 def call_method(self, 1266 method_name: str, 1267 args: Optional[Tuple['Argument', ...]] = None, 1268 kwargs: Optional[Dict[str, 'Argument']] = None, 1269 type_expr: Optional[Any] = None) -> Node: 1270 """ 1271 Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node 1272 represents a call to a given method on the 0th element of ``args``. 1273 1274 Args: 1275 1276 method_name (str): The name of the method to apply to the self argument. 1277 For example, if args[0] is a ``Node`` representing a ``Tensor``, 1278 then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``. 1279 1280 args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed 1281 to the called method. Note that this *should* include a ``self`` argument. 1282 1283 kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed 1284 to the called method 1285 1286 type_expr (Optional[Any]): an optional type annotation representing the 1287 Python type the output of this node will have. 1288 1289 Returns: 1290 1291 The newly created and inserted ``call_method`` node. 1292 1293 .. note:: 1294 The same insertion point and type expression rules apply for this method 1295 as :meth:`Graph.create_node`. 1296 """ 1297 return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr) 1298 1299 @compatibility(is_backward_compatible=True) 1300 def call_function(self, 1301 the_function: Callable[..., Any], 1302 args: Optional[Tuple['Argument', ...]] = None, 1303 kwargs: Optional[Dict[str, 'Argument']] = None, 1304 type_expr: Optional[Any] = None) -> Node: 1305 """ 1306 Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node 1307 represents a call to a Python callable, specified by ``the_function``. 1308 1309 Args: 1310 1311 the_function (Callable[..., Any]): The function to be called. Can be any PyTorch 1312 operator, Python function, or member of the ``builtins`` or ``operator`` 1313 namespaces. 1314 1315 args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed 1316 to the called function. 1317 1318 kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed 1319 to the called function 1320 1321 type_expr (Optional[Any]): an optional type annotation representing the 1322 Python type the output of this node will have. 1323 1324 Returns: 1325 1326 The newly created and inserted ``call_function`` node. 1327 1328 .. note:: 1329 The same insertion point and type expression rules apply for this method 1330 as :meth:`Graph.create_node`. 1331 """ 1332 return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr) 1333 1334 @compatibility(is_backward_compatible=True) 1335 def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node: 1336 """ 1337 Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from 1338 the graph of node to the graph of self. Example:: 1339 1340 # Copying all the nodes in `g` into `new_graph` 1341 g : torch.fx.Graph = ... 1342 new_graph = torch.fx.graph() 1343 value_remap = {} 1344 for node in g.nodes: 1345 value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n]) 1346 1347 Args: 1348 1349 node (Node): The node to copy into ``self``. 1350 1351 arg_transform (Callable[[Node], Argument]): A function that transforms 1352 ``Node`` arguments in node's ``args`` and ``kwargs`` into the 1353 equivalent argument in ``self``. In the simplest case, this should 1354 retrieve a value out of a table mapping Nodes in the original 1355 graph to ``self``. 1356 """ 1357 args = map_arg(node.args, arg_transform) 1358 kwargs = map_arg(node.kwargs, arg_transform) 1359 assert isinstance(args, tuple) 1360 assert isinstance(kwargs, dict) 1361 result_node = self.create_node(node.op, node.target, args, kwargs, node.name, node.type) 1362 result_node.meta = copy.copy(node.meta) 1363 return result_node 1364 1365 @compatibility(is_backward_compatible=True) 1366 def output(self, result: 'Argument', type_expr: Optional[Any] = None): 1367 """ 1368 Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents 1369 a ``return`` statement in Python code. ``result`` is the value that should 1370 be returned. 1371 1372 Args: 1373 1374 result (Argument): The value to be returned. 1375 1376 type_expr (Optional[Any]): an optional type annotation representing the 1377 Python type the output of this node will have. 1378 1379 .. note:: 1380 1381 The same insertion point and type expression rules apply for this method 1382 as ``Graph.create_node``. 1383 """ 1384 return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr) 1385 1386 def _target_to_str(self, target : Target) -> str: 1387 if callable(target): 1388 op = target.__name__ 1389 else: 1390 assert isinstance(target, str) 1391 op = target 1392 if _is_magic(op): 1393 op = op[2:-2] 1394 op = _snake_case(op) 1395 return op 1396 1397 @compatibility(is_backward_compatible=True) 1398 def python_code( 1399 self, root_module: str, *, 1400 verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False 1401 ) -> PythonCode: 1402 """ 1403 Turn this ``Graph`` into valid Python code. 1404 1405 Args: 1406 1407 root_module (str): The name of the root module on which to look-up 1408 qualified name targets. This is usually 'self'. 1409 1410 Returns: 1411 1412 A PythonCode object, consisting of two fields: 1413 src: the Python source code representing the object 1414 globals: a dictionary of global names in `src` -> the objects that they reference. 1415 """ 1416 # NOTE: [Graph Namespaces] 1417 # 1418 # There are two types of symbols in generated Python source code: 1419 # locals and globals. 1420 # Locals are locally defined by the output of a node in the Graph. 1421 # Globals are references to external objects, like functions or types. 1422 # 1423 # When generating Python code, we need to make sure to name things 1424 # appropriately. In particular: 1425 # - All names should be unique, to avoid weird shadowing bugs. 1426 # - These names need to be consistent, e.g. a object should always be 1427 # referenced by the same name. 1428 # 1429 # To do this, we create a new namespace just for this source. All names 1430 # that get printed must come from this namespace. 1431 # 1432 # Why can't we re-use node.name? Because it was generated within the 1433 # namespace `self._graph_namespace`. In order to provide uniqueness 1434 # over both locals (node.name) *and* globals, we create a completely 1435 # new namespace to put all identifiers in. 1436 namespace = _Namespace() 1437 1438 # Override Node's repr to generate a valid name within our namespace. 1439 # Since repr() is designed to produce a valid Python expression, it 1440 # makes sense to re-use it. This way, it's easy to print something like 1441 # Tuple[Node, Node] by simply calling repr() on it. Node's __repr__ is 1442 # implemented cooperatively to allow this. 1443 def node_repr(n: Node): 1444 return namespace.create_name(n.name, n) 1445 1446 @contextmanager 1447 def override_node_repr(graph: Graph): 1448 orig_repr_fns = {} 1449 for node in graph.nodes: 1450 orig_repr_fns[node] = node._repr_fn 1451 node._repr_fn = node_repr 1452 try: 1453 yield None 1454 finally: 1455 # restore the original repr functions 1456 for node in graph.nodes: 1457 node._repr_fn = orig_repr_fns[node] 1458 1459 with override_node_repr(self): 1460 return self._python_code( 1461 root_module, namespace, 1462 verbose=verbose, include_stride=include_stride, include_device=include_device, colored=colored 1463 ) 1464 1465 def _python_code( 1466 self, root_module: str, namespace: _Namespace, *, 1467 verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, 1468 ) -> PythonCode: 1469 return self._codegen._gen_python_code( 1470 self.nodes, root_module, namespace, 1471 verbose=verbose, include_stride=include_stride, include_device=include_device, colored=colored 1472 ) 1473 1474 1475 def __str__(self) -> str: 1476 """ 1477 Return a human-readable (not machine-readable) string representation 1478 of this Graph 1479 """ 1480 placeholder_names : List[str] = [] 1481 # This is a one-element array just so ``format_node`` can modify the closed 1482 # over value 1483 maybe_return_typename : List[str] = [''] 1484 1485 node_strs = [node.format_node(placeholder_names) for node in self.nodes] 1486 param_str = ', '.join(placeholder_names) 1487 s = f'graph({param_str}){maybe_return_typename[0]}:' 1488 for node_str in node_strs: 1489 if node_str: 1490 s += '\n ' + node_str 1491 return s 1492 1493 @compatibility(is_backward_compatible=True) 1494 def print_tabular(self): 1495 """ 1496 Prints the intermediate representation of the graph in tabular 1497 format. Note that this API requires the ``tabulate`` module to be 1498 installed. 1499 """ 1500 try: 1501 from tabulate import tabulate 1502 except ImportError: 1503 print("`print_tabular` relies on the library `tabulate`, " 1504 "which could not be found on this machine. Run `pip " 1505 "install tabulate` to install the library.") 1506 raise 1507 1508 node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] 1509 for n in self.nodes] 1510 print(tabulate(node_specs, 1511 headers=['opcode', 'name', 'target', 'args', 'kwargs'])) 1512 1513 @compatibility(is_backward_compatible=True) 1514 def lint(self): 1515 """ 1516 Runs various checks on this Graph to make sure it is well-formed. In 1517 particular: 1518 - Checks Nodes have correct ownership (owned by this graph) 1519 - Checks Nodes appear in topological order 1520 - If this Graph has an owning GraphModule, checks that targets 1521 exist in that GraphModule 1522 """ 1523 1524 # Check topo order 1525 def check_arg(arg : Node, n : Optional[Node] = None) -> None: 1526 context_str = f' of Node \'{n}\' ' if n else ' ' 1527 if arg.graph is not self: 1528 raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, ' 1529 f'but was used as an argument! If you are copying nodes from another graph, make ' 1530 f'sure to use ``arg_transform`` on node_copy() to remap values\n{self}') 1531 if arg not in seen_values: 1532 raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been ' 1533 f'defined! Please check that Nodes in the graph are topologically ordered\n{self}') 1534 1535 seen_names : Set[str] = set() 1536 seen_values : Set[Node] = set() 1537 for node in self.nodes: 1538 if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']: 1539 raise RuntimeError(f'Node {node} had unknown opcode {node.op}!') 1540 if node.graph is not self: 1541 raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!') 1542 if node not in self._find_nodes_lookup_table: 1543 raise RuntimeError(f"Node '{node}' is not added to the side table") 1544 map_arg(node.args, lambda arg: check_arg(arg, node)) 1545 map_arg(node.kwargs, lambda arg: check_arg(arg, node)) 1546 seen_values.add(node) 1547 1548 if node.name in seen_names: 1549 raise RuntimeError(f'Node redefined name {node.name}!') 1550 seen_names.add(node.name) 1551 1552 # Check targets are legit 1553 if self.owning_module: 1554 num_warnings = 0 1555 MAX_WARNINGS = 5 1556 for node in self.nodes: 1557 if node.op == 'call_function': 1558 if not callable(node.target): 1559 raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' 1560 'a Callable is expected') 1561 else: 1562 if not isinstance(node.target, str): 1563 raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' 1564 'a str is expected') 1565 if node.op in ['get_attr', 'call_module']: 1566 target_atoms = node.target.split('.') 1567 m_itr = self.owning_module 1568 for i, atom in enumerate(target_atoms): 1569 new_m_itr = getattr(m_itr, atom, None) 1570 seen_qualname = '.'.join(target_atoms[:i]) 1571 if new_m_itr is None: 1572 raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute ' 1573 f'{atom} of {seen_qualname}') 1574 if (node.op == "call_module" 1575 and not isinstance(new_m_itr, torch.nn.Module)): 1576 raise RuntimeError(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' 1577 'not reference an nn.Module') 1578 elif (node.op == "get_attr" 1579 and not isinstance(new_m_itr, torch.nn.Module) 1580 and not isinstance(new_m_itr, torch.nn.Parameter) 1581 and atom not in m_itr._buffers): 1582 if num_warnings < MAX_WARNINGS: 1583 # Don't emit this warning too frequently, 1584 # for very large graphs this can become very expensive 1585 # from a performance perspective. 1586 warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' 1587 'not reference an nn.Module, nn.Parameter, or buffer, which is ' 1588 'what \'get_attr\' Nodes typically target') 1589 num_warnings += 1 1590 else: 1591 m_itr = new_m_itr 1592 if num_warnings > MAX_WARNINGS: 1593 warnings.warn( 1594 f'Additional {num_warnings - MAX_WARNINGS} warnings ' 1595 'suppressed about get_attr references' 1596 ) 1597 1598 @compatibility(is_backward_compatible=True) 1599 def eliminate_dead_code(self, is_impure_node: Optional[Callable[[Node], bool]] = None): 1600 """ 1601 Remove all dead code from the graph, based on each node's number of 1602 users, and whether the nodes have any side effects. The graph must be 1603 topologically sorted before calling. 1604 1605 Args: 1606 is_impure_node (Optional[Callable[[Node], bool]]): A function that returns 1607 whether a node is impure. If this is None, then the default behavior is to 1608 use Node.is_impure. 1609 1610 Returns: 1611 bool: Whether the graph was changed as a result of the pass. 1612 1613 Example: 1614 1615 Before dead code is eliminated, `a` from `a = x + 1` below has no users 1616 and thus can be eliminated from the graph without having an effect. 1617 1618 .. code-block:: python 1619 1620 def forward(self, x): 1621 a = x + 1 1622 return x + self.attr_1 1623 1624 After dead code is eliminated, `a = x + 1` has been removed, and the rest 1625 of `forward` remains. 1626 1627 .. code-block:: python 1628 1629 def forward(self, x): 1630 return x + self.attr_1 1631 1632 .. warning:: 1633 1634 Dead code elimination has some heuristics to avoid removing 1635 side-effectful nodes (see Node.is_impure) but in general coverage 1636 is very bad, so you should assume that this method is not sound 1637 to call unless you know that your FX graph consists entirely 1638 of functional operations or you supply your own custom 1639 function for detecting side-effectful nodes. 1640 """ 1641 # Lint the graph first to make sure its topologically sorted, otherwise 1642 # DCE below will not behave as expected. 1643 self.lint() 1644 1645 def has_side_effect(node): 1646 if is_impure_node is not None: 1647 return is_impure_node(node) 1648 return node.is_impure() 1649 1650 # Reverse iterate so that when we remove a node, any nodes used as an 1651 # input to that node have an updated user count that no longer reflects 1652 # the removed node. 1653 changed = False 1654 for node in reversed(self.nodes): 1655 if not has_side_effect(node) and len(node.users) == 0: 1656 self.erase_node(node) 1657 changed = True 1658 1659 return changed 1660 1661 @compatibility(is_backward_compatible=False) 1662 def set_codegen(self, codegen: CodeGen): 1663 self._codegen = codegen 1664 1665 @compatibility(is_backward_compatible=False) 1666 def on_generate_code( 1667 self, 1668 make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc] 1669 ): 1670 """Register a transformer function when python code is generated 1671 1672 Args: 1673 make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]): 1674 a function that returns a code transformer to be registered. 1675 This function is called by `on_generate_code` to obtain the 1676 code transformer. 1677 1678 This function is also given as its input the currently 1679 registered code transformer (or None if nothing is registered), 1680 in case it is not desirable to overwrite it. This is useful to 1681 chain code transformers together. 1682 1683 Returns: 1684 a context manager that when used in a `with` statement, to automatically 1685 restore the previously registered code transformer. 1686 1687 Example: 1688 1689 .. code-block:: python 1690 1691 1692 gm: fx.GraphModule = ... 1693 1694 # This is a code transformer we want to register. This code 1695 # transformer prepends a pdb import and trace statement at the very 1696 # beginning of the generated torch.fx code to allow for manual 1697 # debugging with the PDB library. 1698 def insert_pdb(body): 1699 return ["import pdb; pdb.set_trace()\\n", *body] 1700 1701 # Registers `insert_pdb`, and overwrites the current registered 1702 # code transformer (given by `_` to the lambda): 1703 gm.graph.on_generate_code( 1704 lambda _: insert_pdb 1705 ) 1706 1707 # Or alternatively, registers a code transformer which first 1708 # runs `body` through existing registered transformer, then 1709 # through `insert_pdb`: 1710 gm.graph.on_generate_code( 1711 lambda current_trans: ( 1712 lambda body: insert_pdb( 1713 current_trans(body) if current_trans 1714 else body 1715 ) 1716 ) 1717 ) 1718 1719 gm.recompile() 1720 gm(*inputs) # drops into pdb 1721 1722 1723 This function can also be used as a context manager, with the benefit to 1724 automatically restores the previously registered code transformer: 1725 1726 .. code-block:: python 1727 1728 # ... continue from previous example 1729 1730 with gm.graph.on_generate_code(lambda _: insert_pdb): 1731 # do more stuff with `gm`... 1732 gm.recompile() 1733 gm(*inputs) # drops into pdb 1734 1735 # now previous code transformer is restored (but `gm`'s code with pdb 1736 # remains - that means you can run `gm` with pdb here too, until you 1737 # run next `recompile()`). 1738 """ 1739 on_gen_code_old = self._codegen._body_transformer 1740 self._codegen._body_transformer = make_transformer(on_gen_code_old) 1741 1742 @contextlib.contextmanager 1743 def on_generate_code_context_manager(): 1744 try: 1745 yield 1746 finally: 1747 self._codegen._body_transformer = on_gen_code_old 1748 1749 return on_generate_code_context_manager() 1750 1751 1752reflectable_magic_methods = { 1753 'add': '{} + {}', 1754 'sub': '{} - {}', 1755 'mul': '{} * {}', 1756 'floordiv': '{} // {}', 1757 'truediv': '{} / {}', 1758 'div': '{} / {}', 1759 'mod': '{} % {}', 1760 'pow': '{} ** {}', 1761 'lshift': '{} << {}', 1762 'rshift': '{} >> {}', 1763 'and_': '{} & {}', 1764 'or_': '{} | {}', 1765 'xor': '{} ^ {}', 1766 'getitem': '{}[{}]', 1767 'matmul': '{} @ {}', 1768} 1769 1770magic_methods = dict({ 1771 'eq': '{} == {}', 1772 'ne': '{} != {}', 1773 'lt': '{} < {}', 1774 'gt': '{} > {}', 1775 'le': '{} <= {}', 1776 'ge': '{} >= {}', 1777 'pos': '+{}', 1778 'neg': '-{}', 1779 'invert': '~{}'}, **reflectable_magic_methods) 1780 1781inplace_methods = { 1782 'iadd': '{} += {}', 1783 'iand': '{} &= {}', 1784 'ifloordiv': '{} //= {}', 1785 'ilshift': '{} <<= {}', 1786 'imod': '{} %= {}', 1787 'imul': '{} *= {}', 1788 'imatmul': '{} @= {}', 1789 'ior': '{} |= {}', 1790 'ipow': '{} **= {}', 1791 'irshift': '{} >>= {}', 1792 'isub': '{} -= {}', 1793 'itruediv': '{} /= {}', 1794 'ixor': '{} ^= {}', 1795 'setitem': '{}[{}] = {}', 1796} 1797