1# mypy: allow-untyped-defs 2import contextlib 3import functools 4from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union 5 6import torch 7from torch._dynamo.external_utils import ( 8 call_backward, 9 call_hook, 10 FakeCompiledAutogradEngine, 11) 12from torch._dynamo.source import GetItemSource, LocalSource 13from torch._dynamo.utils import counters, lazy_format_graph_code, set_locals_to_steal 14from torch._logging import getArtifactLogger, trace_structured 15from torch._prims_common import clone_preserve_strides 16from torch._subclasses import FakeTensorMode 17from torch.fx import GraphModule 18from torch.fx.experimental._backward_state import BackwardState 19from torch.fx.experimental.proxy_tensor import ( 20 decompose, 21 disable_autocast_cache, 22 disable_proxy_modes_tracing, 23 fetch_object_proxy, 24 ProxyTorchDispatchMode, 25 PythonKeyTracer, 26 track_tensor_tree, 27) 28from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv 29from torch.fx.traceback import preserve_node_meta, set_stack_trace 30from torch.utils._traceback import CapturedTraceback 31 32 33if TYPE_CHECKING: 34 from torch.fx.proxy import Proxy 35 36 37compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd") 38verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose") 39 40 41def snapshot_verbose_logging_enabled(): 42 return torch._logging._internal.log_state.is_artifact_enabled( 43 "compiled_autograd_verbose" 44 ) 45 46 47def cpp_verbose_log_fn(msg: str) -> None: 48 verbose_log.debug(msg) 49 50 51def snapshot_cudagraph_enabled(): 52 return torch._inductor.config.triton.cudagraphs 53 54 55def maybe_clone(x): 56 if x is not None: 57 return clone_preserve_strides(x) 58 return x 59 60 61class AutogradCompilerInstance: 62 def __init__(self, compiler_fn) -> None: 63 self.compiler_fn = compiler_fn 64 self.stack = contextlib.ExitStack() 65 self.close = self.stack.close 66 self.shape_env = ShapeEnv() 67 self.fake_tensor_mode = FakeTensorMode( 68 allow_fallback_kernels=True, 69 allow_non_fake_inputs=True, 70 shape_env=self.shape_env, 71 ) 72 self.fx_tracer = PythonKeyTracer() 73 self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic") 74 self.hooks_proxy: Optional[Proxy] = None 75 self.graph_placeholders = ["inputs", "sizes", "scalars", "hooks"] 76 77 def wrap_fake(self, x, source): 78 assert isinstance(x, torch.Tensor) 79 return self.fake_tensor_mode.from_tensor(x, source=source) 80 81 @staticmethod 82 def source(name, idx) -> GetItemSource: 83 return GetItemSource(LocalSource(name), idx) 84 85 def begin_capture( 86 self, 87 inputs: List[torch.Tensor], 88 sizes: List[int], 89 scalars: List[Union[int, float]], 90 ): 91 counters["compiled_autograd"]["captures"] += 1 92 self.aot_graph_cls_name: Optional[str] = None 93 self.aot_graph_infos: Dict[int, Dict[str, Any]] = {} 94 self.fx_tracer.root = torch.nn.Module() 95 self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer) 96 self.fx_tracer.tensor_attrs = {} 97 args_proxy, sizes_proxy, scalars_proxy, self.hooks_proxy = ( 98 self.fx_tracer.create_proxy("placeholder", name, (), {}) 99 for name in self.graph_placeholders 100 ) 101 102 # tensor inputs to fake tensors 103 inputs = [ 104 self.wrap_fake(x, self.source("inputs", idx)) 105 for idx, x in enumerate(inputs) 106 ] 107 self.bind_tensors_to_proxies(inputs, args_proxy) 108 109 # size inputs to symints 110 sizes = [ 111 self.shape_env.create_unspecified_symint_and_symbol( 112 val, 113 self.source("sizes", idx), 114 DimDynamic.DYNAMIC, 115 ) 116 for idx, val in enumerate(sizes) 117 ] 118 self.bind_tensors_to_proxies(sizes, sizes_proxy) 119 120 for idx, val in enumerate(scalars): 121 source = self.source("scalars", idx) 122 if isinstance(val, int): 123 scalars[idx] = self.shape_env.create_unspecified_symint_and_symbol( 124 val, 125 source, 126 DimDynamic.DYNAMIC, 127 ) 128 elif isinstance(val, float): 129 scalars[idx] = self.shape_env.create_symfloatnode( 130 self.shape_env.create_unspecified_symbol( 131 val, 132 source=source, 133 dynamic_dim=DimDynamic.DYNAMIC, 134 ), 135 hint=val, 136 source=source, 137 ) 138 else: 139 raise AssertionError("Unexpected scalar type: ", type(val)) 140 self.bind_tensors_to_proxies(scalars, scalars_proxy) 141 142 # TODO(jansel): are all these modes needed? 143 self.stack.enter_context(decompose({})) 144 self.stack.enter_context(self.fake_tensor_mode) 145 self.stack.enter_context(self.proxy_mode) 146 self.stack.enter_context(disable_autocast_cache()) 147 self.stack.enter_context(preserve_node_meta()) 148 return inputs, sizes, scalars 149 150 def proxy_call_backward( 151 self, 152 inputs, 153 output_metadatas, 154 saved_tensors, 155 backward_idx: int, 156 ): 157 assert self.hooks_proxy is not None 158 backward_c_function = self.hooks_proxy[backward_idx] # type: ignore[index] 159 proxies = self.fx_tracer.create_proxy( 160 kind="call_function", 161 target=call_backward, 162 args=( 163 backward_c_function, 164 self.to_proxy(saved_tensors), 165 *self.to_proxy(inputs), 166 ), 167 kwargs={}, 168 ) 169 170 with disable_proxy_modes_tracing(): 171 # create fake Tensors 172 grad_ins: List[Optional[torch.Tensor]] = [] 173 for output_metadata in output_metadatas: 174 if output_metadata is None: 175 grad_ins.append(None) 176 continue 177 178 layout, device, dtype, size = output_metadata 179 grad_ins.append( 180 torch.empty(size=size, dtype=dtype, layout=layout, device=device) 181 ) 182 self.bind_tensors_to_proxies(grad_ins, proxies) 183 return tuple(grad_ins) 184 185 def proxy_call_hook(self, hook, *args, **kwargs): 186 return self.fx_tracer.create_proxy( 187 "call_function", 188 call_hook, 189 ( 190 hook, 191 *[self.to_proxy(x) for x in args], 192 ), 193 kwargs, 194 ) 195 196 def tensor_pre_hook(self, inputs, hook_id, i: int): 197 assert self.hooks_proxy is not None 198 hook = self.hooks_proxy[hook_id] # type: ignore[index] 199 proxy = self.proxy_call_hook( 200 hook, 201 inputs[i], 202 hook_type="tensor_pre_hook", 203 ) 204 with disable_proxy_modes_tracing(): 205 inputs[i] = maybe_clone(inputs[i]) 206 self.bind_tensors_to_proxies([inputs[i]], [proxy]) 207 return inputs 208 209 def pre_hook(self, inputs, hook_id): 210 assert self.hooks_proxy is not None 211 hook = self.hooks_proxy[hook_id] # type: ignore[index] 212 proxies = self.proxy_call_hook( 213 hook, 214 inputs, 215 hook_type="pre_hook", 216 ) 217 with disable_proxy_modes_tracing(): 218 inputs = [maybe_clone(x) for x in inputs] 219 self.bind_tensors_to_proxies(inputs, proxies) 220 return inputs 221 222 def post_hook(self, outputs, inputs, hook_id): 223 assert self.hooks_proxy is not None 224 hook = self.hooks_proxy[hook_id] # type: ignore[index] 225 proxies = self.proxy_call_hook( 226 hook, 227 outputs, 228 inputs, 229 hook_type="post_hook", 230 ) 231 with disable_proxy_modes_tracing(): 232 outputs = [maybe_clone(x) for x in outputs] 233 self.bind_tensors_to_proxies(outputs, proxies) 234 return outputs 235 236 def post_acc_grad_hook(self, input, hook_id): 237 assert isinstance(input, torch.Tensor) 238 assert self.hooks_proxy is not None 239 hook = self.hooks_proxy[hook_id] # type: ignore[index] 240 proxy = self.proxy_call_hook( 241 hook, 242 input, 243 hook_type="post_acc_grad_hook", 244 ) 245 with disable_proxy_modes_tracing(): 246 input = [maybe_clone(input)] 247 self.bind_tensors_to_proxies(input, [proxy]) 248 return input 249 250 # Note: [Compiled autograd and cudagraphs] 251 # Eager autograd backward implements scalars as 0-dim tensors, see DivBackward0::other_. 252 # When compiled autograd traces those nodes, it lifts the scalar tensors, resulting in a graph 253 # with some cpu 0-dim tensor inputs. To prevent the entire graph from skipping cudagraph, we move the 254 # scalars tensors to cuda. This works because ATen/prims ops will accept cuda 0-dim tensors too. 255 def move_graph_nodes_to_cuda(self, graph) -> List[int]: 256 to_move: Dict[int, torch.fx.Node] = {} 257 has_cuda_inputs = False 258 nodes = list(graph.nodes) 259 assert nodes[0].target == "inputs" 260 inputs = nodes[0] 261 inputs_users = list(inputs.users.keys()) 262 # input access nodes should immediately follow placeholder nodes 263 first_getitem_idx = len(self.graph_placeholders) 264 assert nodes[first_getitem_idx] == inputs_users[0] 265 last_getitem_idx = first_getitem_idx + len(inputs_users) - 1 266 assert nodes[last_getitem_idx] == inputs_users[-1] 267 for i, node in enumerate(inputs_users): 268 if not has_cuda_inputs and node.meta["val"].device.type == "cuda": 269 has_cuda_inputs = True 270 continue 271 272 is_cpu = node.meta["val"].device.type == "cpu" 273 is_scalar = len(node.meta["val"].size()) == 0 274 if is_cpu and is_scalar: 275 node_users = list(node.users.keys()) 276 if all( 277 isinstance(user.target, torch._ops.OpOverload) 278 and user.target.namespace in ("prims", "aten") 279 for user in node_users 280 ): 281 # all users are prims/aten, can move safely 282 to_move[i] = node 283 284 # only move cpu scalars to cuda if there were cuda activations in this graph, 285 # this is to handle the case where cudagraphs is enabled on a cpu-only graph 286 if has_cuda_inputs: 287 for node in to_move.values(): 288 node.meta["val"] = node.meta["val"].cuda() 289 290 # return runtime indices we need to move to cuda 291 return list(to_move.keys()) 292 293 return [] 294 295 def end_capture(self, outputs): 296 self.fx_tracer.create_proxy( 297 "call_function", 298 FakeCompiledAutogradEngine._exec_final_callbacks_stub, 299 (), 300 {}, 301 ) 302 self.stack.close() 303 self.fx_tracer.create_node( 304 "output", 305 "output", 306 (self.fx_tracer.create_arg(self.to_proxy(outputs)),), 307 {}, 308 ) 309 self.rename_aot_dispatcher_nodes() 310 self.reorder_accumulate_grad_nodes() 311 runtime_inputs_to_move: List[int] = [] 312 if snapshot_cudagraph_enabled(): 313 runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph) 314 315 graph = GraphModule( 316 self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd" 317 ) 318 set_locals_to_steal(graph, ["inputs"]) 319 lazy_graph_code = lazy_format_graph_code( 320 "Compiled autograd graph", 321 graph, 322 include_device=True, 323 include_stride=True, 324 colored=True, 325 ) 326 compiled_autograd_log.info("%s", lazy_graph_code) 327 verbose_log.debug("%s", lazy_graph_code) 328 trace_structured( 329 "compiled_autograd_graph", 330 payload_fn=lambda: graph.print_readable(print_output=False), 331 ) 332 333 def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks): 334 global in_compiled_autograd_region 335 try: 336 in_compiled_autograd_region = True 337 for i in runtime_inputs_to_move: 338 inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True) 339 340 return compiled_fn(inputs, sizes, scalars, hooks) 341 finally: 342 in_compiled_autograd_region = False 343 344 return runtime_wrapper, self.compiler_fn(graph) 345 346 def rename_aot_dispatcher_nodes(self): 347 """ 348 Renames nodes as they appear in the AOTDispatcher backward graphs, prefixed by AOT id 349 e.g. AOTDispatcher backward graph X's `sin_Y` -> `aotX_sin_Y` 350 """ 351 if self.aot_graph_cls_name is None: 352 return 353 354 def is_similar(a: torch.fx.node.Node, b: torch.fx.node.Node): 355 target_match = a.target == b.target 356 if not target_match: 357 target_match = ( 358 hasattr(a.target, "__name__") 359 and hasattr(b.target, "__name__") 360 and a.target.__name__ == b.target.__name__ 361 ) 362 return ( 363 target_match 364 and a.op == b.op 365 and a.type == b.type 366 and len(a.all_input_nodes) == len(b.all_input_nodes) 367 ) 368 369 for nodecall_index, info in self.aot_graph_infos.items(): 370 ca_node_start_idx = info["ca_node_start_idx"] 371 aot_id = info["aot_id"] 372 aot_graph = info["aot_gm"].graph 373 374 # 1. Find the first op from user code in the AOT graph 375 aot_it = iter(aot_graph.nodes) 376 aot_node = next(aot_it) 377 assert aot_node is not None 378 try: 379 while aot_node.op != "call_function": 380 aot_node = next(aot_it) 381 except StopIteration: 382 continue 383 384 try: 385 # 2. Find the first op in the compiled autograd graph segment 386 ca_it = iter(self.fx_tracer.graph.nodes) 387 for _ in range(ca_node_start_idx): 388 next(ca_it) 389 ca_node = next(ca_it) 390 391 # Graphs should all end with output node 392 while ca_node.op != "output" and not is_similar(ca_node, aot_node): 393 # The compiled autograd graph may contain lazily inserted ops 394 # We skip those when aligning nodes 395 ca_node = next(ca_it) 396 397 # 3. Keep alligned and rename nodes 398 while aot_node.op != "output" and ca_node.op != "output": 399 if not ca_node.users: 400 # TODO: DCE for compiled autograd graph 401 ca_node = next(ca_it) 402 continue 403 404 if not is_similar(aot_node, ca_node): 405 # There should be no lazily inserted ops in the middle of a match 406 # So any deviation is an error 407 raise StopIteration 408 409 ca_node.name = f"aot{aot_id}_{aot_node.name}" 410 for i, inp in enumerate(aot_node.all_input_nodes): 411 ca_node.all_input_nodes[i].name = f"aot{aot_id}_{inp.name}" 412 413 aot_node = next(aot_it) 414 ca_node = next(ca_it) 415 except StopIteration: 416 verbose_log.debug( 417 "Failed to match %s%s (NodeCall %s) nodes with AOT backward graph %s nodes", 418 self.aot_graph_cls_name, 419 aot_id, 420 nodecall_index, 421 aot_id, 422 ) 423 424 def reorder_accumulate_grad_nodes(self): 425 """ 426 Usage of AOTAutograd causes all the accumulate_grad_ nodes to get pushed to the end of 427 the graph. This differs from eager mode, which schedules them as soon as possible. This 428 pass attempts to reorder the graph to mimic eager behavior. 429 """ 430 for node in self.fx_tracer.graph.find_nodes( 431 op="call_function", target=torch.ops.inductor.accumulate_grad_.default 432 ): 433 arg = max(node.args) # last arg 434 if arg is not node.prev and arg.op != "placeholder": 435 arg.append(node) 436 437 def to_proxy(self, t): 438 if t is None: 439 return None 440 if isinstance(t, list): 441 return [self.to_proxy(x) for x in t] 442 if isinstance(t, tuple): 443 return tuple(self.to_proxy(x) for x in t) 444 # can it be torch.SymInt as the code used to imply? 445 assert isinstance(t, torch.Tensor) 446 proxy_tensor = fetch_object_proxy(self.fx_tracer, t) 447 assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor) 448 return proxy_tensor.proxy 449 450 def bind_tensors_to_proxies(self, tensors, proxies): 451 if isinstance(proxies, torch.fx.Proxy): 452 proxies = [proxies[i] for i in range(len(tensors))] # type: ignore[index] 453 assert len(tensors) == len(proxies) 454 track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer) 455 456 def bind_backward_state(self, index: int): 457 assert self.hooks_proxy is not None 458 proxy = self.hooks_proxy[index] # type: ignore[index] 459 bw_state = BackwardState() 460 track_tensor_tree(bw_state, proxy, constant=None, tracer=self.fx_tracer) 461 return bw_state 462 463 def set_node_origin( 464 self, 465 node_name: str, 466 nodecall_index: int, 467 pyobj: Optional[torch.autograd.Function], 468 ): 469 maybe_aot_id = "" 470 if pyobj is not None: 471 forward_cls = pyobj._forward_cls # type: ignore[attr-defined] 472 if hasattr(forward_cls, "_aot_id"): 473 # backward was created by AOT Dispatcher 474 self.aot_graph_cls_name = node_name 475 maybe_aot_id = forward_cls._aot_id 476 self.aot_graph_infos[nodecall_index] = { 477 "ca_node_start_idx": len(self.fx_tracer.graph.nodes), 478 "aot_id": maybe_aot_id, 479 "aot_gm": forward_cls._lazy_backward_info.bw_module, 480 } 481 482 new_code = f"{node_name}{maybe_aot_id} (NodeCall {nodecall_index})" 483 raw_stack_trace = CapturedTraceback.extract().format()[-1] 484 new_stack_trace = raw_stack_trace.replace( 485 "raw_stack_trace = CapturedTraceback.extract().format()[-1]", new_code 486 ) 487 set_stack_trace(new_stack_trace) 488 489 490# state of the autograd engine dispatch, kept in sync by enable/disable context managers 491compiled_autograd_enabled = False 492 493# global flag to check if we are processing graphs produced from a compiled autograd graph 494in_compiled_autograd_region = False 495 496 497@contextlib.contextmanager 498def enable(compiler_fn): 499 prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler( 500 functools.partial(AutogradCompilerInstance, compiler_fn) 501 ) 502 if snapshot_verbose_logging_enabled(): 503 torch._C._dynamo.compiled_autograd.set_verbose_logger(cpp_verbose_log_fn) 504 global compiled_autograd_enabled 505 compiled_autograd_enabled = True 506 try: 507 with torch.autograd.set_multithreading_enabled(False): 508 yield 509 finally: 510 if not prior: 511 compiled_autograd_enabled = False 512 torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior) 513 514 515@contextlib.contextmanager 516def disable(): 517 prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None) 518 global compiled_autograd_enabled 519 compiled_autograd_enabled = False 520 try: 521 yield 522 finally: 523 if prior: 524 compiled_autograd_enabled = True 525 torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior) 526 527 528# return to starting state of a new process 529def reset() -> None: 530 compiled_autograd_enable = False 531 assert not in_compiled_autograd_region 532 torch._C._dynamo.compiled_autograd.set_autograd_compiler(None) 533 torch._C._dynamo.compiled_autograd.set_verbose_logger(None) 534