# mypy: allow-untyped-defs
from .graph_module import GraphModule
from ._lazy_graph_module import _make_graph_module
from .graph import Graph
from .node import Argument, Node, Target, map_arg, map_aggregate
from .proxy import Proxy
from ._symbolic_trace import Tracer
from ._compatibility import compatibility
from . import config
import torch.fx.traceback as fx_traceback
import torch
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import inspect
from contextlib import contextmanager
from torch.hub import tqdm

__all__ = ['Interpreter', 'Transformer']

@compatibility(is_backward_compatible=True)
class Interpreter:
    """
    An Interpreter executes an FX graph Node-by-Node. This pattern
    can be useful for many things, including writing code
    transformations as well as analysis passes.

    Methods in the Interpreter class can be overridden to customize
    the behavior of execution. The map of overridable methods
    in terms of call hierarchy::

        run()
            +-- run_node
                +-- placeholder()
                +-- get_attr()
                +-- call_function()
                +-- call_method()
                +-- call_module()
                +-- output()

    Example:

        Suppose we want to swap all instances of ``torch.neg`` with
        ``torch.sigmoid`` and vice versa (including their ``Tensor``
        method equivalents). We could subclass Interpreter like so::

            class NegSigmSwapInterpreter(Interpreter):
                def call_function(self, target : Target,
                                  args : Tuple, kwargs : Dict) -> Any:
                    if target == torch.sigmoid:
                        return torch.neg(*args, **kwargs)
                    return super().call_function(target, args, kwargs)

                def call_method(self, target : Target,
                                args : Tuple, kwargs : Dict) -> Any:
                    if target == 'neg':
                        call_self, *args_tail = args
                        return call_self.sigmoid(*args_tail, **kwargs)
                    return super().call_method(target, args, kwargs)

            def fn(x):
                return torch.sigmoid(x).neg()

            gm = torch.fx.symbolic_trace(fn)
            input = torch.randn(3, 4)
            result = NegSigmSwapInterpreter(gm).run(input)
            torch.testing.assert_close(result, torch.neg(input).sigmoid())

    Args:
        module (torch.nn.Module): The module to be executed
        garbage_collect_values (bool): Whether to delete values after their last
            use within the Module's execution. This ensures optimal memory usage during
            execution. This can be disabled to, for example, examine all of the intermediate
            values in the execution by looking at the ``Interpreter.env`` attribute.
        graph (Optional[Graph]): If passed, the interpreter will execute this
            graph instead of `module.graph`, using the provided `module`
            argument to satisfy any requests for state.
    """
    @compatibility(is_backward_compatible=True)
    def __init__(self, module: torch.nn.Module, garbage_collect_values: bool = True, graph: Optional[Graph] = None):
        self.module = module
        # Qualified-name -> submodule lookup table built once up front.
        self.submodules = dict(self.module.named_modules())
        if graph is not None:
            self.graph = graph
        else:
            self.graph = self.module.graph
        # Maps each executed Node to the concrete value it produced.
        self.env : Dict[Node, Any] = {}
        # Used as the progress-bar label in `run`.
        self.name = "Interpreter"
        self.garbage_collect_values = garbage_collect_values
        # When True, `run` appends the failing node and its stack trace to exceptions.
        self.extra_traceback = True

        if self.garbage_collect_values:
            # Run through reverse nodes and record the first instance of a use
            # of a given node. This represents the *last* use of the node in the
            # execution order of the program, which we will use to free unused
            # values
            node_to_last_use : Dict[Node, Node] = {}
            # user node -> nodes whose values may be freed once `user` has run
            self.user_to_last_uses : Dict[Node, List[Node]] = {}

            def register_last_uses(n : Node, user : Node):
                if n not in node_to_last_use:
                    node_to_last_use[n] = user
                    self.user_to_last_uses.setdefault(user, []).append(n)

            for node in reversed(self.graph.nodes):
                # The lambdas are invoked immediately by map_arg, so capturing
                # the loop variable `node` here is safe.
                map_arg(node.args, lambda n: register_last_uses(n, node))
                map_arg(node.kwargs, lambda n: register_last_uses(n, node))

    @compatibility(is_backward_compatible=True)
    def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any:
        """
        Run `module` via interpretation and return the result.

        Args:
            *args: The arguments to the Module to run, in positional order
            initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
                This is a dict mapping `Node` to any value. This can be used, for example, to
                pre-populate results for certain `Nodes` so as to do only partial evaluation within
                the interpreter. Note that the dict is used directly as ``self.env`` (not copied).
            enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and
                process_outputs function first before using them.

        Returns:
            Any: The value returned from executing the Module, or ``None`` if the
            graph contains no ``output`` node.

        Raises:
            RuntimeError: If ``extra_traceback`` is enabled and a node raises
                ``KeyError``, it is re-raised as ``RuntimeError`` with the node
                context appended. Other exceptions propagate unchanged in type.
        """
        self.env = initial_env if initial_env is not None else {}

        # Positional function args are consumed left-to-right by
        # `placeholder` nodes. Use an iterator to keep track of
        # position and extract those values.
        if enable_io_processing:
            args = self.graph.process_inputs(*args)
        self.args_iter : Iterator[Any] = iter(args)
        pbar = tqdm(total=len(self.graph.nodes),
                    desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}",
                    initial=0, position=0, leave=True, disable=config.disable_progress, delay=0)

        for node in self.graph.nodes:
            pbar.update(1)
            if node in self.env:
                # Short circuit if we have this value. This could
                # be used, for example, for partial evaluation
                # where the caller has pre-populated `env` with
                # values for a subset of the program.
                # NOTE: skipping here also skips the garbage-collection
                # step below for this node's last-used inputs.
                continue

            try:
                self.env[node] = self.run_node(node)
            except Exception as e:
                if self.extra_traceback:
                    msg = f"While executing {node.format_node()}"
                    msg = f'{e.args[0]}\n\n{msg}' if e.args else str(msg)
                    msg += f"\nOriginal traceback:\n{node.stack_trace}"
                    e.args = (msg,) + e.args[1:]
                    # KeyError's str() quotes its message, which would mangle
                    # the multi-line context; surface it as RuntimeError instead.
                    if isinstance(e, KeyError):
                        raise RuntimeError(*e.args) from e
                raise

            if self.garbage_collect_values:
                # Free every value whose final consumer was this node.
                for to_delete in self.user_to_last_uses.get(node, []):
                    del self.env[to_delete]

            if node.op == 'output':
                output_val = self.env[node]
                return self.graph.process_outputs(output_val) if enable_io_processing else output_val

    @compatibility(is_backward_compatible=True)
    def boxed_run(self, args_list):
        """
        Run `module` via interpretation and return the result. This uses the "boxed"
        calling convention, where you pass a list of arguments, which will be cleared
        by the interpreter. This ensures that input tensors are promptly deallocated.

        Args:
            args_list (list): Positional inputs, one per ``placeholder`` node in
                graph order. The list is emptied in place before execution starts.

        Returns:
            Any: The value returned by :meth:`run`.
        """
        args_iter = iter(args_list)
        env = {}
        # Bind inputs directly into the environment so `run` skips the
        # placeholder nodes, then drop the caller's references to them.
        for n in self.graph.nodes:
            if n.op == "placeholder":
                env[n] = next(args_iter)
        args_list.clear()
        return self.run(initial_env=env)

    @contextmanager
    def _set_current_node(self, node):
        """Context manager that exposes ``node`` as the current node to
        ``torch.fx.traceback`` for the duration of the ``with`` body."""
        with fx_traceback.set_current_meta(node):
            yield

    @compatibility(is_backward_compatible=True)
    def run_node(self, n : Node) -> Any:
        """
        Run a specific node ``n`` and return the result.
        Calls into placeholder, get_attr, call_function,
        call_method, call_module, or output depending
        on ``node.op``

        Args:
            n (Node): The Node to execute

        Returns:
            Any: The result of executing ``n``
        """
        with self._set_current_node(n):
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            assert isinstance(args, tuple)
            assert isinstance(kwargs, dict)
            # Dispatch on the opcode name: e.g. op 'call_function' -> self.call_function.
            return getattr(self, n.op)(n.target, args, kwargs)

    # Main Node running APIs
    @compatibility(is_backward_compatible=True)
    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``placeholder`` node. Note that this is stateful:
        ``Interpreter`` maintains an internal iterator over
        arguments passed to ``run`` and this method returns
        next() on that iterator.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation. If the
                input iterator is exhausted, ``args[0]`` (when present) is
                returned as the parameter's default value.
            kwargs (Dict): Dict of keyword arguments for this invocation

        Returns:
            Any: The argument value that was retrieved.

        Raises:
            RuntimeError: If no input remains and the placeholder has no default.
        """
        assert isinstance(target, str)
        if target.startswith('*'):
            # For a starred parameter e.g. `*args`, retrieve all
            # remaining values from the args list.
            return list(self.args_iter)
        else:
            try:
                return next(self.args_iter)
            except StopIteration as si:
                if len(args) > 0:
                    return args[0]
                else:
                    raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si

    @compatibility(is_backward_compatible=True)
    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``get_attr`` node. Will retrieve an attribute
        value from the ``Module`` hierarchy of ``self.module``.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The value of the attribute that was retrieved
        """
        assert isinstance(target, str)
        return self.fetch_attr(target)

    @compatibility(is_backward_compatible=True)
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The value returned by the function invocation
        """
        assert not isinstance(target, str)

        # Execute the function and return the result
        return target(*args, **kwargs)

    @compatibility(is_backward_compatible=True)
    def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_method`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The value returned by the method invocation
        """
        # args[0] is the `self` object for this method call
        self_obj, *args_tail = args

        # Execute the method and return the result
        assert isinstance(target, str)
        return getattr(self_obj, target)(*args_tail, **kwargs)

    @compatibility(is_backward_compatible=True)
    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_module`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The value returned by the module invocation
        """
        # `args`/`kwargs` already hold concrete values (resolved by run_node).
        # Look up the submodule by its qualified name and invoke it.
        assert isinstance(target, str)
        submod = self.fetch_attr(target)

        return submod(*args, **kwargs)

    @compatibility(is_backward_compatible=True)
    def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute an ``output`` node. This really just retrieves
        the value referenced by the ``output`` node and returns it.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return:
            Any: The return value referenced by the output node
        """
        return args[0]

    # Helper methods
    @compatibility(is_backward_compatible=True)
    def fetch_attr(self, target : str):
        """
        Fetch an attribute from the ``Module`` hierarchy of ``self.module``.

        Args:
            target (str): The fully-qualified name of the attribute to fetch

        Return:
            Any: The value of the attribute.

        Raises:
            RuntimeError: If any dotted component of ``target`` does not exist.
        """
        target_atoms = target.split('.')
        attr_itr = self.module
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i+1])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    @compatibility(is_backward_compatible=True)
    def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]:
        """
        Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
        from the current execution environment.

        Args:
            n (Node): The node for which ``args`` and ``kwargs`` should be fetched.

        Return:
            Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``.
        """
        args = self.map_nodes_to_values(n.args, n)
        assert isinstance(args, tuple)
        kwargs = self.map_nodes_to_values(n.kwargs, n)
        assert isinstance(kwargs, dict)
        return args, kwargs

    @compatibility(is_backward_compatible=True)
    def map_nodes_to_values(self, args : Argument, n : Node) -> Argument:
        """
        Recursively descend through ``args`` and look up the concrete value
        for each ``Node`` in the current execution environment.

        Args:
            args (Argument): Data structure within which to look up concrete values

            n (Node): Node to which ``args`` belongs. This is only used for error reporting.

        Return:
            Argument: ``args`` with every ``Node`` replaced by its value from ``self.env``.

        Raises:
            RuntimeError: If a referenced ``Node`` has no value in ``self.env``.
        """
        def load_arg(n_arg : Node) -> Any:
            if n_arg not in self.env:
                raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() '
                                   f'to diagnose such issues')
            return self.env[n_arg]
        return map_arg(args, load_arg)

@compatibility(is_backward_compatible=True)
class Transformer(Interpreter):
    """
    ``Transformer`` is a special type of interpreter that produces a
    new ``Module``. It exposes a ``transform()`` method that returns
    the transformed ``Module``. ``Transformer`` does not require
    arguments to run, as ``Interpreter`` does. ``Transformer`` works
    entirely symbolically.

    Example:

        Suppose we want to swap all instances of ``torch.neg`` with
        ``torch.sigmoid`` and vice versa (including their ``Tensor``
        method equivalents). We could subclass ``Transformer`` like so::

            class NegSigmSwapXformer(Transformer):
                def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
                    if target == torch.sigmoid:
                        return torch.neg(*args, **kwargs)
                    return super().call_function(target, args, kwargs)

                def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
                    if target == 'neg':
                        call_self, *args_tail = args
                        return call_self.sigmoid(*args_tail, **kwargs)
                    return super().call_method(target, args, kwargs)

            def fn(x):
                return torch.sigmoid(x).neg()

            gm = torch.fx.symbolic_trace(fn)

            transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
            input = torch.randn(3, 4)
            torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())

    Args:
        module (GraphModule): The ``Module`` to be transformed.
    """

    @compatibility(is_backward_compatible=True)
    def __init__(self, module):
        super().__init__(module)
        # Output graph that the symbolic re-execution records into; it reuses
        # the source graph's codegen so the generated forward matches.
        self.new_graph = Graph()
        self.new_graph.set_codegen(module.graph._codegen)

        class TransformerTracer(Tracer):
            def __init__(self, graph: Graph):
                super().__init__()
                self.graph = graph
                self.tensor_attrs: Dict[torch.Tensor, str] = {}  # type: ignore[assignment]

            def is_leaf_module(self, _, __) -> bool:
                # Treat every submodule as a leaf so it is not traced into.
                return True

        self.tracer = TransformerTracer(self.new_graph)
        self.tracer.root = module

    @compatibility(is_backward_compatible=True)
    def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
        """
        Execute a ``placeholder`` node. In ``Transformer``, this is
        overridden to insert a new ``placeholder`` into the output
        graph.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation. The
                first element, if any, is carried over as the default value.
            kwargs (Dict): Dict of keyword arguments for this invocation
        """
        assert isinstance(target, str)
        default_value = next(iter(args)) if args else inspect.Signature.empty
        return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer)

    @compatibility(is_backward_compatible=True)
    def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
        """
        Execute a ``get_attr`` node. In ``Transformer``, this is
        overridden to insert a new ``get_attr`` node into the output
        graph.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation
        """
        assert isinstance(target, str)
        return self.tracer.create_proxy("get_attr", target, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_module`` node. In ``Transformer``, this is
        overridden to hand the submodule call to ``self.tracer.call_module``
        instead of running it eagerly.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation
        """
        # Override so that the leaf module policy from `self.tracer` is respected.
        assert isinstance(target, str)
        submod = self.fetch_attr(target)
        return self.tracer.call_module(submod, submod.forward, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node. In ``Transformer``, this is
        overridden to record a ``call_function`` node in the output graph
        via ``self.tracer.create_proxy``.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/main/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation
        """
        # Override so that functions that were wrapped are still wrapped.
        return self.tracer.create_proxy('call_function', target, args, kwargs)

    @compatibility(is_backward_compatible=True)
    def transform(self) -> GraphModule:
        """
        Transform ``self.module`` and return the transformed
        ``GraphModule``.

        The source graph is re-executed symbolically (see the overridden node
        handlers). If it produces a result, an ``output`` node is added to the
        new graph and the old output node's ``meta`` entries are copied onto it.
        """
        with fx_traceback.preserve_node_meta():
            result = super().run(enable_io_processing=False)
        if result is not None:
            def strip_proxy(a : Union[Argument, Proxy]) -> Any:
                return a.node if isinstance(a, Proxy) else a
            new_output_node = self.new_graph.output(map_aggregate(result, strip_proxy))
            # also preserve the metadata from the old output node, if it exists
            old_output_node = list(self.graph.nodes)[-1]
            assert old_output_node.op == "output"
            for k, v in old_output_node.meta.items():
                new_output_node.meta[k] = v


        return _make_graph_module(self.module, self.new_graph)