import collections
import contextlib
import dataclasses
import functools
import itertools
import logging
import os
import os.path
import pickle
import pstats
import shutil
import subprocess
from typing import Any, Callable, Dict, IO, Iterator, List, Optional, Type, Union
from unittest.mock import patch

import torch
from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled
from torch import fx as fx
from torch._dynamo.repro.after_aot import save_graph_repro
from torch._dynamo.utils import get_debug_dir
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.fx.passes.tools_common import legalize_graph
from torch.utils._pytree import tree_map

from . import config, ir  # noqa: F811, this is needed
from .scheduler import (
    BaseSchedulerNode,
    FusedSchedulerNode,
    NopKernelSchedulerNode,
    OutputNode,
    SchedulerNode,
)
from .virtualized import V


log = logging.getLogger(__name__)

SchedulerNodeList = List[Any]
BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]


@functools.lru_cache(None)
def has_dot() -> bool:
    try:
        subprocess.check_output(["which", "dot"], stderr=subprocess.PIPE)
        return True
    except subprocess.SubprocessError:
        return False


def draw_buffers(
    nodes: List[BaseSchedulerNode],
    print_graph: bool = False,
    fname: Optional[str] = None,
) -> None:
    """
    Draw a graph in fname.svg.
    """
    if not has_dot():
        log.warning("draw_buffers() requires `graphviz` package")
        return

    if fname is None:
        fname = get_graph_being_compiled()

    graph = create_fx_from_snodes(nodes)

    for node in graph.nodes:
        if "fusion_meta" not in node.meta:
            continue
        group = node.meta["fusion_meta"].group
        if isinstance(group, tuple):
            if isinstance(group[1], int):
                group = (group[1],)
            else:
                group = group[1]

        # gather meta data
        dtype = None
        if isinstance(node, ir.ComputedBuffer):
            dtype = node.data.dtype

        metadata = TensorMetadata(group, dtype, None, None, None, None, None)  # type: ignore[arg-type]
        node.meta["tensor_meta"] = metadata

    if print_graph:
        print(graph)

    gm = GraphModule({}, graph)
    legalize_graph(gm)
    gm.graph.lint()
    draw_graph(
        gm, fname, clear_meta=False, dot_graph_shape=config.trace.dot_graph_shape
    )
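
# NOTE [Example: producing a fusion diagram]
# draw_buffers() is normally reached through DebugFormatter.graph_diagram()
# below when inductor tracing is enabled. A rough sketch of driving it via
# config (the config attribute names here are assumptions inferred from how
# this module reads `config.trace`; see torch._inductor.config for the
# authoritative names):
#
#     from torch._inductor import config as inductor_config
#     inductor_config.trace.enabled = True
#     inductor_config.trace.graph_diagram = True
#     torch.compile(fn)(*example_inputs)
#     # graph_diagram.svg is written into the per-graph debug directory
#
# Rendering requires graphviz's `dot` binary (see has_dot() above).
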

def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
    """
    Creates an FX Graph from a list of SchedulerNode objects.
    """

    def get_fake_func(name: str) -> Callable[..., int]:
        def func1(*args: Any) -> int:
            return 0

        func1.__name__ = name
        return func1

    FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"])

    buf_to_fx_node = {}
    node_to_fx_node = {}
    graph = torch.fx.Graph()
    first_node = None

    outputs = []
    group: Any = None
    # create call_function node for each Buffer and Kernel
    for snode in snodes:
        if snode.is_extern():
            node_type = "extern"
            group = node_type
        elif snode.is_template():
            node_type = "template"
            group = node_type
        elif isinstance(snode, NopKernelSchedulerNode):
            node_type = "nop"
            group = node_type
        elif isinstance(snode, SchedulerNode):
            node_type = "compute"
            group = snode.group
        elif isinstance(snode, FusedSchedulerNode):
            node_type = "fused"
            group = snode.group
        else:
            raise RuntimeError("Unknown node type")

        fused_name = torch._inductor.utils.get_fused_kernel_name(
            snode.get_nodes(), "original_aten"
        )
        func_name = f"{node_type}: {fused_name}"
        node_func = get_fake_func(func_name)
        kwargs = {}
        if hasattr(snode, "get_device"):
            kwargs = {"device": snode.get_device()}
        fx_node = graph.call_function(node_func, args=(), kwargs=kwargs)  # type: ignore[arg-type]

        def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool:
            if isinstance(snode, FusedSchedulerNode):
                return any(in_output(x) for x in snode.snodes)
            return any(
                isinstance(user.node, OutputNode)
                for buf in snode.get_outputs()
                for user in buf.users
            )

        if in_output(snode):
            outputs.append(fx_node)
        name = snode.get_name()
        fx_node.name = name

        fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type)

        node_to_fx_node[name] = fx_node
        for buf in snode.get_outputs():
            buf_to_fx_node[buf.get_name()] = fx_node

        if first_node is None:
            first_node = fx_node

    # create edges between nodes
    for snode in snodes:
        name = snode.get_name()
        deps = snode.read_writes.reads

        fx_node = node_to_fx_node[name]
        new_args = []
        for dep in deps:
            if dep.name in buf_to_fx_node:
                dep_node = buf_to_fx_node[dep.name]
            else:
                with graph.inserting_before(first_node):
                    dep_node = graph.placeholder(dep.name)
                    buf_to_fx_node[dep.name] = dep_node
            if dep_node == fx_node:  # to avoid cycles
                continue
            new_args.append(dep_node)

        fx_node.args = tuple(new_args)

    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
    return graph


def update_orig_fx_node_name_to_buf_name(
    nodes: Optional[SchedulerNodeList],
    node_name_to_buf_name: Dict[str, str],
    parent_buf_name: Optional[str] = None,
    n_origins: int = 0,
) -> None:
    if nodes is None:
        return
    for node in nodes:
        # for FusedSchedulerNode, traverse recursively into get_nodes()
        buf_name = node.get_name()
        children_nodes = node.get_nodes()
        if children_nodes is not None and len(children_nodes) > 1:
            update_orig_fx_node_name_to_buf_name(
                children_nodes,
                node_name_to_buf_name,
                buf_name if parent_buf_name is None else parent_buf_name,
            )
            continue
        else:
            assert len(children_nodes) == 1 and children_nodes[0] == node

        ir_node = node.node
        if ir_node is None or ir_node.origins is None:
            continue
        for origin in ir_node.origins:
            node_name = origin.name
            # when buf1 and buf2 both have origin=node1,
            # we draw node1 according to buf1
            if node_name not in node_name_to_buf_name:
                node_name_to_buf_name[node_name] = (
                    buf_name if parent_buf_name is None else parent_buf_name
                )


def get_node_name_to_buf_meta(
    node_name_to_buf_name: Dict[str, str]
) -> Dict[str, BufMeta]:
    buf_name_to_n_node = {}
    for node_name, buf_name in node_name_to_buf_name.items():
        if buf_name not in buf_name_to_n_node:
            buf_name_to_n_node[buf_name] = {node_name}
        else:
            buf_name_to_n_node[buf_name].add(node_name)

    node_name_to_buf_meta = {}
    for node_name, buf_name in node_name_to_buf_name.items():
        n_node = len(buf_name_to_n_node[buf_name])
        node_name_to_buf_meta[node_name] = BufMeta(buf_name, n_node)
    return node_name_to_buf_meta
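
# NOTE [Example: origin-to-buffer mapping]
# A worked example of the two helpers above (buffer and node names are made
# up): given a SchedulerNode producing buf0 whose IR node has
# origins == {add, relu}, update_orig_fx_node_name_to_buf_name() records
#
#     {"add": "buf0", "relu": "buf0"}
#
# and get_node_name_to_buf_meta() turns that into BufMeta("buf0", 2) for both
# original node names, where n_origin counts how many original FX nodes share
# the buffer. For a FusedSchedulerNode, children map their origins to the
# parent's buffer name (parent_buf_name), and the first buffer to claim an
# origin wins.
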

def annotate_orig_fx_with_snodes(
    gm: torch.fx.GraphModule,
    snodes: SchedulerNodeList,
) -> None:
    """
    Annotate the nodes of the original FX graph with the buffer metadata of
    the SchedulerNodes that were generated from them.
    """
    node_name_to_buf_name: Dict[str, str] = {}
    update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name)
    if node_name_to_buf_name is None:
        return
    node_name_to_buf_meta = get_node_name_to_buf_meta(node_name_to_buf_name)
    for node in gm.graph.nodes:
        if node.name in node_name_to_buf_meta:
            node.meta["buf_meta"] = node_name_to_buf_meta.get(node.name)


@contextlib.contextmanager
def enable_aot_logging() -> Iterator[None]:
    compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"

    import torch._functorch.aot_autograd

    log = logging.getLogger(torch._functorch.aot_autograd.__name__)

    stack = contextlib.ExitStack()
    if not compile_debug:
        try:
            yield
        finally:
            stack.close()
        return

    # Enable all graphs to be logged to a file by setting the flags to True
    # and the log level of the file logger to DEBUG
    stack.enter_context(patch("functorch.compile.config.debug_partitioner", True))

    path = os.path.join(get_debug_dir(), "torchinductor")
    os.makedirs(path, exist_ok=True)

    fh = logging.FileHandler(
        os.path.join(
            path,
            f"aot_{get_aot_graph_name()}_debug.log",
        )
    )
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(
        logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
    )
    log.addHandler(fh)
    try:
        yield
    finally:
        log.removeHandler(fh)
        stack.close()
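
# NOTE [Example: enabling the AOT debug log]
# enable_aot_logging() only does extra work when the TORCH_COMPILE_DEBUG
# environment variable is set, e.g.
#
#     TORCH_COMPILE_DEBUG=1 python train.py
#
# which patches functorch.compile.config.debug_partitioner on and writes
# aot_<graph_name>_debug.log files under <debug_dir>/torchinductor.
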

class DebugContext:
    _counter = itertools.count()

    @staticmethod
    def create_debug_dir(folder_name: str) -> Optional[str]:
        debug_dir = config.trace.debug_dir or get_debug_dir()
        for n in DebugContext._counter:
            dirname = os.path.join(
                debug_dir,
                "torchinductor",
                f"{folder_name}.{n}",
            )
            if not os.path.exists(dirname):
                os.makedirs(dirname)
                return dirname
        return None

    def __init__(self) -> None:
        self._prof = None
        self._path = None
        self._stack = contextlib.ExitStack()

    def copy(self, new_path: str) -> None:
        if not self._path:
            return
        assert new_path.endswith(".debug"), new_path
        from filelock import FileLock

        try:
            with FileLock(f"{new_path}.lock"):
                if os.path.exists(new_path):
                    shutil.rmtree(new_path)
                shutil.copytree(self._path, new_path)
        except OSError:
            log.warning(
                "Failed to copy debug files from %s to %s", self._path, new_path
            )

    def fopen(
        self,
        filename: str,
        write_mode: str = "w",
        *args: Any,
        **kwargs: Any,
    ) -> IO[Any]:
        assert self._path
        return open(os.path.join(self._path, filename), write_mode, *args, **kwargs)

    @contextlib.contextmanager
    def fopen_context(
        self,
        filename: str,
        write_mode: str = "w",
        *args: Any,
        **kwargs: Any,
    ) -> Iterator[IO[Any]]:
        assert self._path
        with open(os.path.join(self._path, filename), write_mode, *args, **kwargs) as f:
            yield f

    def filename(self, suffix: str) -> str:
        assert self._path
        return os.path.join(self._path, suffix)

    def upload_tar(self) -> None:
        if config.trace.upload_tar is not None:
            import tarfile

            assert self._path
            tar_file = os.path.join(
                self._path, f"{os.path.basename(self._path)}.tar.gz"
            )
            with tarfile.open(tar_file, "w:gz") as tar:
                tar.add(self._path, arcname=os.path.basename(self._path))
            config.trace.upload_tar(tar_file)

    def __enter__(self) -> None:
        if config.debug:
            log = logging.getLogger("torch._dynamo")
            prev_level = log.level
            log.setLevel(logging.DEBUG)

            def reset_log_level(level: Any) -> None:
                log.setLevel(level)

            self._stack.callback(reset_log_level, prev_level)

        self._stack.enter_context(V.set_debug_handler(self))

        if not config.trace.enabled:
            return

        self._path = self.create_debug_dir(get_aot_graph_name())  # type: ignore[assignment]

        if config.trace.debug_log:
            self._setup_log_capture("debug.log", logging.DEBUG)
        if config.trace.info_log:
            self._setup_log_capture("info.log", logging.INFO)

    def _setup_log_capture(
        self,
        filename: str,
        level: int,
    ) -> None:
        log = logging.getLogger("torch._inductor")
        fd = self._stack.enter_context(self.fopen(filename))
        ch = logging.StreamHandler(fd)
        ch.setLevel(level)
        ch.setFormatter(
            logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
        )
        log.addHandler(ch)
        log.setLevel(min(log.level, level))
        self._stack.callback(log.removeHandler, ch)

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        if self._prof:
            self._prof.disable()
            self._save_profile_data()

        if self._path:
            self.upload_tar()
            log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
        self._stack.close()

    def _save_profile_data(self) -> None:
        assert self._prof
        self._prof.dump_stats(self.filename("compile.prof"))
        with self.fopen("compile.stats") as fd:
            stats = pstats.Stats(self._prof, stream=fd)
            stats.strip_dirs()
            stats.sort_stats("cumtime")
            stats.print_stats(100)
            stats.sort_stats("tottime")
            stats.print_stats(100)

    def __getattr__(self, name: str) -> Optional[Callable[..., None]]:
        if config.trace.enabled and getattr(config.trace, name):
            try:
                return getattr(DebugFormatter(self), name)
            except Exception:
                log.warning("Ignoring exception in debug code", exc_info=True)
                return None
        else:

            def ignored(*args: Any, **kwargs: Any) -> None:
                pass

            return ignored
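
# NOTE [Example: how DebugContext is used]
# A rough sketch of the intended usage (how callers reach the handler is an
# assumption inferred from V.set_debug_handler above):
#
#     with DebugContext():
#         ...  # compile one graph; debug hooks are reached through the
#         ...  # virtualized handler, e.g. V.debug.fx_graph(gm, example_inputs)
#
# With config.trace.enabled set, artifacts land in
# <debug_dir>/torchinductor/<graph_name>.<n>/ and the path is logged as a
# "debug trace" warning on __exit__; otherwise every debug call resolves to
# the no-op `ignored` function returned by __getattr__.
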

class DebugFormatter:
    def __init__(self, handler: DebugContext) -> None:
        self.fopen = handler.fopen
        self.fopen_context = handler.fopen_context
        self.filename = handler.filename
        self.handler = handler

    def fx_graph(
        self,
        gm: torch.fx.GraphModule,
        inputs: List[torch.Tensor],
    ) -> None:
        with self.fopen("fx_graph_runnable.py") as fd:
            save_graph_repro(fd, gm, inputs, "inductor")

        with self.fopen("fx_graph_readable.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def fx_graph_transformed(
        self,
        gm: torch.fx.GraphModule,
        inputs: List[torch.Tensor],
    ) -> None:
        with self.fopen("fx_graph_transformed.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def ir_pre_fusion(self, nodes: SchedulerNodeList) -> None:
        self._write_ir("ir_pre_fusion.txt", nodes)

    def ir_post_fusion(self, nodes: SchedulerNodeList) -> None:
        self._write_ir("ir_post_fusion.txt", nodes)

    def _write_ir(
        self,
        filename: str,
        nodes: SchedulerNodeList,
    ) -> None:
        with self.fopen(filename) as fd:
            log.info("Writing debug ir to %s", fd.name)
            for node in nodes:
                fd.write(node.debug_str())
                fd.write("\n\n\n")

    def graph_diagram(self, nodes: SchedulerNodeList) -> None:
        draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))

    def draw_orig_fx_graph(
        self,
        gm: torch.fx.GraphModule,
        nodes: SchedulerNodeList,
    ) -> None:
        annotate_orig_fx_with_snodes(gm, nodes)
        draw_graph(
            gm,
            fname=self.filename("orig_fx_graph_diagram.svg"),
            clear_meta=False,
            prog=GRAPHVIZ_COMMAND_SCALABLE,
            parse_stack_trace=True,
            dot_graph_shape=config.trace.dot_graph_shape,
        )

    def output_code(self, filename: str) -> None:
        shutil.copy(filename, self.filename("output_code.py"))

    def log_autotuning_results(
        self,
        name: str,
        input_nodes: List[ir.IRNode],
        timings: Dict["ChoiceCaller", float],  # type: ignore[name-defined] # noqa: F821
        elapse: float,
        precompile_elapse: float,
    ) -> None:
        import json

        from .ir import FixedLayout

        def build_node_info(node: ir.IRNode) -> Dict[str, str]:
            if hasattr(node, "name"):
                node_name = node.name
            else:
                node_name = ""
            node_info = {
                "name": node_name,
                "type": type(node).__name__,
            }
            try:
                layout = node.get_layout()
                if isinstance(layout, FixedLayout):
                    offset = 0
                    try:
                        offset = int(layout.offset)
                    except Exception:
                        try:
                            offset = V.graph.sizevars.size_hint(
                                layout.offset, fallback=0
                            )
                        except Exception:
                            pass
                    static_layout = FixedLayout(
                        layout.device,
                        dtype=layout.dtype,
                        size=list(V.graph.sizevars.size_hints(layout.size)),
                        stride=list(V.graph.sizevars.size_hints(layout.stride)),
                        offset=offset,
                    )
                    node_info["layout"] = str(static_layout)
                else:
                    node_info["layout"] = str(node.get_layout())
            except Exception:
                pass
            try:
                node_info["dtype"] = str(node.get_dtype())
            except Exception:
                pass
            try:
                node_info["device"] = str(node.get_device())
            except Exception:
                pass
            try:
                node_info["stride"] = str(
                    V.graph.sizevars.size_hints(node.get_stride())
                )
            except Exception:
                pass
            try:
                node_info["size"] = str(V.graph.sizevars.size_hints(node.get_size()))
            except Exception:
                pass
            try:
                node_info["numel"] = str(V.graph.sizevars.size_hint(node.get_numel()))
            except Exception:
                pass
            if hasattr(node, "data") and isinstance(node.data, ir.IRNode):
                node_info["data"] = build_node_info(node.data)
            return node_info

        general_properties = {
            "op_name": name,
            "cuda_device_name": torch.cuda.get_device_name(),
            "cuda_device_count": torch.cuda.device_count(),
            "input_nodes": [build_node_info(node) for node in input_nodes],
            "autotuning_time": elapse,
            "precompile_time": precompile_elapse,
        }
        with self.fopen_context(
            "autotuning_result_json_list.txt", "at", encoding="utf-8"
        ) as fd:
            for caller, time in timings.items():
                info_dict = dict(caller.info_dict())
                info_dict.update(general_properties)
                info_dict["benchmark_result"] = time
                json.dump(info_dict, fd)
                fd.write("\n")
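
# NOTE [Example: files produced by DebugFormatter]
# When the corresponding config.trace.* flags are enabled, a compiled graph's
# debug directory can contain files such as:
#
#     fx_graph_runnable.py             # standalone repro (save_graph_repro)
#     fx_graph_readable.py             # gm.print_readable() output
#     fx_graph_transformed.py          # post-transformation readable graph
#     ir_pre_fusion.txt                # scheduler IR before fusion
#     ir_post_fusion.txt               # scheduler IR after fusion
#     graph_diagram.svg                # draw_buffers() output
#     orig_fx_graph_diagram.svg        # annotated original FX graph
#     output_code.py                   # generated wrapper/kernel code
#     autotuning_result_json_list.txt  # one JSON object per autotuning choice
#
# The list reflects the methods above; which files actually appear depends on
# the enabled flags and the compilation path taken.
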

@dataclasses.dataclass
class TensorMetadataHolder:
    tensor_metadata: TensorMetadata
    device: torch.device


save_args_cnt = itertools.count()


def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None:
    """
    Save the arguments of a compile_fx_inner() call to the file system so the
    call can later be replayed with load_args_and_run_compile_fx_inner().
    """

    folder = "/tmp/inductor_saved_args"
    if not os.path.exists(folder):
        os.mkdir(folder)

    def handle_tensor(x: Any) -> Any:
        """
        Pickling a FakeTensor fails with:
            AttributeError: Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'

        Convert every Tensor to its metadata instead. This may also make
        pickling faster.
        """
        if isinstance(x, torch.Tensor):
            return TensorMetadataHolder(_extract_tensor_metadata(x), x.device)
        else:
            return x

    args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs))

    fn_name = "compile_fx_inner"
    path = f"{folder}/{fn_name}_{next(save_args_cnt)}.pkl"
    with open(path, "wb") as f:
        pickle.dump((args_to_save, kwargs_to_save), f)

    if log.isEnabledFor(logging.DEBUG):
        message = f"""
Arguments for a compile_fx_inner call are saved to {path}. To replay the call,
run the following:

from torch._inductor.debug import load_args_and_run_compile_fx_inner
load_args_and_run_compile_fx_inner({path!r})
        """
        # Call print rather than log.debug: log.debug would prepend a message
        # prefix to each line, which makes the code snippet harder to copy.
        # Not a big deal since this branch is already guarded by the log-level
        # check above.
        print(message)


def load_args_and_run_compile_fx_inner(path: str) -> Any:
    from torch._inductor.compile_fx import compile_fx_inner

    with open(path, "rb") as f:
        args, kwargs = pickle.load(f)

    def handle_tensor(x: Any) -> Any:
        if isinstance(x, TensorMetadataHolder):
            return torch._dynamo.testing.rand_strided(
                x.tensor_metadata.shape,
                x.tensor_metadata.stride,
                x.tensor_metadata.dtype,
                x.device,
            )
        else:
            return x

    fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
    with fake_mode, config.patch("save_args", False):
        args, kwargs = tree_map(handle_tensor, (args, kwargs))
        return compile_fx_inner(*args, **kwargs)
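
# NOTE [Example: saving and replaying compile_fx_inner arguments]
# A minimal sketch of the round trip, assuming a `config.save_args` flag wires
# save_args_for_compile_fx_inner() into compile_fx_inner (the flag name is
# inferred from the config.patch("save_args", False) call above) and that at
# least one .pkl file has already been written:
#
#     from torch._inductor import config as inductor_config
#     inductor_config.save_args = True
#     torch.compile(fn)(*example_inputs)
#     # dumps /tmp/inductor_saved_args/compile_fx_inner_<n>.pkl
#
#     from torch._inductor.debug import load_args_and_run_compile_fx_inner
#     load_args_and_run_compile_fx_inner(
#         "/tmp/inductor_saved_args/compile_fx_inner_0.pkl"  # hypothetical path
#     )
#
# Tensors are replayed as rand_strided() stand-ins built from the saved
# TensorMetadataHolder entries under FakeTensorMode, so this replays the
# compilation, not the original tensor values.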