xref: /aosp_15_r20/external/pytorch/torch/_inductor/debug.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import collections
import contextlib
import dataclasses
import functools
import itertools
import logging
import os
import os.path
import pickle
import pstats
import shutil
import subprocess
from typing import Any, Callable, Dict, IO, Iterator, List, Optional, Type, Union
from unittest.mock import patch

import torch
from functorch.compile import draw_graph, get_aot_graph_name, get_graph_being_compiled
from torch import fx as fx
from torch._dynamo.repro.after_aot import save_graph_repro
from torch._dynamo.utils import get_debug_dir
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
from torch.fx.passes.tools_common import legalize_graph
from torch.utils._pytree import tree_map

from . import config, ir  # noqa: F811, this is needed
from .scheduler import (
    BaseSchedulerNode,
    FusedSchedulerNode,
    NopKernelSchedulerNode,
    OutputNode,
    SchedulerNode,
)
from .virtualized import V


log = logging.getLogger(__name__)

SchedulerNodeList = List[Any]
BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
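# Graphviz invocation used when drawing large graphs: the nslimit/nslimit1 and
# maxiter settings bound how much work the layout solver does, so huge graphs
# still render in a reasonable amount of time.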
GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]


@functools.lru_cache(None)
def has_dot() -> bool:
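    """Return True if the Graphviz ``dot`` executable is available on PATH."""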
    try:
        subprocess.check_output(["which", "dot"], stderr=subprocess.PIPE)
        return True
    except subprocess.SubprocessError:
        return False


def draw_buffers(
    nodes: List[BaseSchedulerNode],
    print_graph: bool = False,
    fname: Optional[str] = None,
) -> None:
58    """
59    Draw a graph in fname.svg.
60    """
61    if not has_dot():
62        log.warning("draw_buffers() requires `graphviz` package")
63        return
64
65    if fname is None:
66        fname = get_graph_being_compiled()
67
68    graph = create_fx_from_snodes(nodes)
69
70    for node in graph.nodes:
71        if "fusion_meta" not in node.meta:
72            continue
73        group = node.meta["fusion_meta"].group
74        if isinstance(group, tuple):
75            if isinstance(group[1], int):
76                group = (group[1],)
77            else:
78                group = group[1]
79
80        # gather meta data
81        dtype = None
82        if isinstance(node, ir.ComputedBuffer):
83            dtype = node.data.dtype
84
85        metadata = TensorMetadata(group, dtype, None, None, None, None, None)  # type: ignore[arg-type]
86        node.meta["tensor_meta"] = metadata
87
88    if print_graph:
89        print(graph)
90
91    gm = GraphModule({}, graph)
92    legalize_graph(gm)
93    gm.graph.lint()
94    draw_graph(
95        gm, fname, clear_meta=False, dot_graph_shape=config.trace.dot_graph_shape
96    )
97
98
99def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
100    """
101    Creates a FX Graph from a list of SchedulerNode objects.
102    """

    def get_fake_func(name: str) -> Callable[..., int]:
        def func1(*args: Any) -> int:
            return 0

        func1.__name__ = name
        return func1

    FusionMeta = collections.namedtuple("FusionMeta", ["group", "snode", "type"])

    buf_to_fx_node = {}
    node_to_fx_node = {}
    graph = torch.fx.Graph()
    first_node = None

    outputs = []
    group: Any = None
    # create call_function node for each Buffer and Kernel
    for snode in snodes:
        if snode.is_extern():
            node_type = "extern"
            group = node_type
        elif snode.is_template():
            node_type = "template"
            group = node_type
        elif isinstance(snode, NopKernelSchedulerNode):
            node_type = "nop"
            group = node_type
        elif isinstance(snode, SchedulerNode):
            node_type = "compute"
            group = snode.group
        elif isinstance(snode, FusedSchedulerNode):
            node_type = "fused"
            group = snode.group
        else:
            raise RuntimeError("Unknown node type")

        fused_name = torch._inductor.utils.get_fused_kernel_name(
            snode.get_nodes(), "original_aten"
        )
        func_name = f"{node_type}: {fused_name}"
        node_func = get_fake_func(func_name)
        kwargs = {}
        if hasattr(snode, "get_device"):
            kwargs = {"device": snode.get_device()}
        fx_node = graph.call_function(node_func, args=(), kwargs=kwargs)  # type: ignore[arg-type]

        def in_output(snode: Union[BaseSchedulerNode, FusedSchedulerNode]) -> bool:
            if isinstance(snode, FusedSchedulerNode):
                return any(in_output(x) for x in snode.snodes)
            return any(
                isinstance(user.node, OutputNode)
                for buf in snode.get_outputs()
                for user in buf.users
            )

        if in_output(snode):
            outputs.append(fx_node)
        name = snode.get_name()
        fx_node.name = name

        fx_node.meta["fusion_meta"] = FusionMeta(group, snode, node_type)

        node_to_fx_node[name] = fx_node
        for buf in snode.get_outputs():
            buf_to_fx_node[buf.get_name()] = fx_node

        if first_node is None:
            first_node = fx_node

    # create edges between nodes
    for snode in snodes:
        name = snode.get_name()
        deps = snode.read_writes.reads

        fx_node = node_to_fx_node[name]
        new_args = []
        for dep in deps:
            if dep.name in buf_to_fx_node:
                dep_node = buf_to_fx_node[dep.name]
            else:
                with graph.inserting_before(first_node):
                    dep_node = graph.placeholder(dep.name)
                    buf_to_fx_node[dep.name] = dep_node
            if dep_node == fx_node:  # to avoid cycles
                continue
            new_args.append(dep_node)

        fx_node.args = tuple(new_args)

    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
    return graph


def update_orig_fx_node_name_to_buf_name(
    nodes: Optional[SchedulerNodeList],
    node_name_to_buf_name: Dict[str, str],
    parent_buf_name: Optional[str] = None,
    n_origins: int = 0,
) -> None:
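    """
    Populate node_name_to_buf_name with a mapping from each original FX node
    name (taken from IR node origins) to the scheduler buffer it was compiled
    into; for fused nodes, all origins are attributed to the parent buffer.
    """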
    if nodes is None:
        return
    for node in nodes:
        # for FusedSchedulerNode, traverse recursively into get_nodes()
        buf_name = node.get_name()
        children_nodes = node.get_nodes()
        if children_nodes is not None and len(children_nodes) > 1:
            update_orig_fx_node_name_to_buf_name(
                children_nodes,
                node_name_to_buf_name,
                buf_name if parent_buf_name is None else parent_buf_name,
            )
            continue
        else:
            assert len(children_nodes) == 1 and children_nodes[0] == node

        ir_node = node.node
        if ir_node is None or ir_node.origins is None:
            continue
        for origin in ir_node.origins:
            node_name = origin.name
            # when buf1 and buf2 both have origin=node1
            # we draw node1 according to buf1
            if node_name not in node_name_to_buf_name:
                node_name_to_buf_name[node_name] = (
                    buf_name if parent_buf_name is None else parent_buf_name
                )


def get_node_name_to_buf_meta(
    node_name_to_buf_name: Dict[str, str]
) -> Dict[str, BufMeta]:
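    """
    For each FX node name, build a BufMeta holding the buffer it maps to and
    how many FX nodes share that buffer.
    """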
    buf_name_to_n_node = {}
    for node_name, buf_name in node_name_to_buf_name.items():
        if buf_name not in buf_name_to_n_node:
            buf_name_to_n_node[buf_name] = {node_name}
        else:
            buf_name_to_n_node[buf_name].add(node_name)

    node_name_to_buf_meta = {}
    for node_name, buf_name in node_name_to_buf_name.items():
        n_node = len(buf_name_to_n_node[buf_name])
        node_name_to_buf_meta[node_name] = BufMeta(buf_name, n_node)
    return node_name_to_buf_meta


def annotate_orig_fx_with_snodes(
    gm: torch.fx.GraphModule,
    snodes: SchedulerNodeList,
) -> None:
    """
    Annotates the nodes of the original FX graph with BufMeta describing the
    scheduler buffers they were compiled into.
    """
    node_name_to_buf_name: Dict[str, str] = {}
    update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name)
    if node_name_to_buf_name is None:
        return
    node_name_to_buf_meta = get_node_name_to_buf_meta(node_name_to_buf_name)
    for node in gm.graph.nodes:
        if node.name in node_name_to_buf_meta:
            node.meta["buf_meta"] = node_name_to_buf_meta.get(node.name)


@contextlib.contextmanager
def enable_aot_logging() -> Iterator[None]:
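    """
    When TORCH_COMPILE_DEBUG=1, mirror AOTAutograd debug logging into a
    per-graph log file under the torchinductor debug directory.
    """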
    compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"

    import torch._functorch.aot_autograd

    log = logging.getLogger(torch._functorch.aot_autograd.__name__)

    stack = contextlib.ExitStack()
    if not compile_debug:
        try:
            yield
        finally:
            stack.close()
        return

    # Enable all graphs to be logged to a file by setting the flags to True
    # and the log level of the file logger to DEBUG
    stack.enter_context(patch("functorch.compile.config.debug_partitioner", True))

    path = os.path.join(get_debug_dir(), "torchinductor")
    os.makedirs(path, exist_ok=True)

    fh = logging.FileHandler(
        os.path.join(
            path,
            f"aot_{get_aot_graph_name()}_debug.log",
        )
    )
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(
        logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
    )
    log.addHandler(fh)
    try:
        yield
    finally:
        log.removeHandler(fh)
        stack.close()


class DebugContext:
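    """
    Per-compilation debug handler installed via V.set_debug_handler().  When
    config.trace.enabled is set, it creates a dedicated debug directory and
    dispatches the DebugFormatter hooks (fx_graph, ir_pre_fusion, output_code,
    ...) through __getattr__.
    """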
    _counter = itertools.count()

    @staticmethod
    def create_debug_dir(folder_name: str) -> Optional[str]:
        debug_dir = config.trace.debug_dir or get_debug_dir()
        for n in DebugContext._counter:
            dirname = os.path.join(
                debug_dir,
                "torchinductor",
                f"{folder_name}.{n}",
            )
            if not os.path.exists(dirname):
                os.makedirs(dirname)
                return dirname
        return None

    def __init__(self) -> None:
        self._prof = None
        self._path = None
        self._stack = contextlib.ExitStack()

    def copy(self, new_path: str) -> None:
        if not self._path:
            return
        assert new_path.endswith(".debug"), new_path
        from filelock import FileLock

        try:
            with FileLock(f"{new_path}.lock"):
                if os.path.exists(new_path):
                    shutil.rmtree(new_path)
                shutil.copytree(self._path, new_path)
        except OSError:
            log.warning(
                "Failed to copy debug files from %s to %s", self._path, new_path
            )

    def fopen(
        self,
        filename: str,
        write_mode: str = "w",
        *args: Any,
        **kwargs: Any,
    ) -> IO[Any]:
        assert self._path
        return open(os.path.join(self._path, filename), write_mode, *args, **kwargs)

    @contextlib.contextmanager
    def fopen_context(
        self,
        filename: str,
        write_mode: str = "w",
        *args: Any,
        **kwargs: Any,
    ) -> Iterator[IO[Any]]:
        assert self._path
        with open(os.path.join(self._path, filename), write_mode, *args, **kwargs) as f:
            yield f

    def filename(self, suffix: str) -> str:
        assert self._path
        return os.path.join(self._path, suffix)

    def upload_tar(self) -> None:
        if config.trace.upload_tar is not None:
            import tarfile

            assert self._path
            tar_file = os.path.join(
                self._path, f"{os.path.basename(self._path)}.tar.gz"
            )
            with tarfile.open(tar_file, "w:gz") as tar:
                tar.add(self._path, arcname=os.path.basename(self._path))
            config.trace.upload_tar(tar_file)

    def __enter__(self) -> None:
        if config.debug:
            log = logging.getLogger("torch._dynamo")
            prev_level = log.level
            log.setLevel(logging.DEBUG)

            def reset_log_level(level: Any) -> None:
                log.setLevel(level)

            self._stack.callback(reset_log_level, prev_level)

        self._stack.enter_context(V.set_debug_handler(self))

        if not config.trace.enabled:
            return

        self._path = self.create_debug_dir(get_aot_graph_name())  # type: ignore[assignment]

        if config.trace.debug_log:
            self._setup_log_capture("debug.log", logging.DEBUG)
        if config.trace.info_log:
            self._setup_log_capture("info.log", logging.INFO)

    def _setup_log_capture(
        self,
        filename: str,
        level: int,
    ) -> None:
        log = logging.getLogger("torch._inductor")
        fd = self._stack.enter_context(self.fopen(filename))
        ch = logging.StreamHandler(fd)
        ch.setLevel(level)
        ch.setFormatter(
            logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
        )
        log.addHandler(ch)
        log.setLevel(min(log.level, level))
        self._stack.callback(log.removeHandler, ch)

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> None:
        if self._prof:
            self._prof.disable()
            self._save_profile_data()

        if self._path:
            self.upload_tar()
            log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
        self._stack.close()

    def _save_profile_data(self) -> None:
        assert self._prof
        self._prof.dump_stats(self.filename("compile.prof"))
        with self.fopen("compile.stats") as fd:
            stats = pstats.Stats(self._prof, stream=fd)
            stats.strip_dirs()
            stats.sort_stats("cumtime")
            stats.print_stats(100)
            stats.sort_stats("tottime")
            stats.print_stats(100)

    def __getattr__(self, name: str) -> Optional[Callable[..., None]]:
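        # Debug hooks (fx_graph, ir_pre_fusion, output_code, ...) are resolved
        # dynamically: when tracing is enabled and the matching config.trace
        # flag is set, dispatch to the DebugFormatter method of the same name;
        # otherwise return a no-op.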
        if config.trace.enabled and getattr(config.trace, name):
            try:
                return getattr(DebugFormatter(self), name)
            except Exception:
                log.warning("Ignoring exception in debug code", exc_info=True)
                return None
        else:

            def ignored(*args: Any, **kwargs: Any) -> None:
                pass

            return ignored


class DebugFormatter:
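    """
    Writes the individual trace artifacts (runnable/readable FX graphs, pre-
    and post-fusion IR, graph diagrams, output code, autotuning results) into
    the DebugContext's debug directory.
    """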
    def __init__(self, handler: DebugContext) -> None:
        self.fopen = handler.fopen
        self.fopen_context = handler.fopen_context
        self.filename = handler.filename
        self.handler = handler

    def fx_graph(
        self,
        gm: torch.fx.GraphModule,
        inputs: List[torch.Tensor],
    ) -> None:
        with self.fopen("fx_graph_runnable.py") as fd:
            save_graph_repro(fd, gm, inputs, "inductor")

        with self.fopen("fx_graph_readable.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def fx_graph_transformed(
        self,
        gm: torch.fx.GraphModule,
        inputs: List[torch.Tensor],
    ) -> None:
        with self.fopen("fx_graph_transformed.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def ir_pre_fusion(self, nodes: SchedulerNodeList) -> None:
        self._write_ir("ir_pre_fusion.txt", nodes)

    def ir_post_fusion(self, nodes: SchedulerNodeList) -> None:
        self._write_ir("ir_post_fusion.txt", nodes)

    def _write_ir(
        self,
        filename: str,
        nodes: SchedulerNodeList,
    ) -> None:
        with self.fopen(filename) as fd:
            log.info("Writing debug IR to %s", fd.name)
            for node in nodes:
                fd.write(node.debug_str())
                fd.write("\n\n\n")

    def graph_diagram(self, nodes: SchedulerNodeList) -> None:
        draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))

    def draw_orig_fx_graph(
        self,
        gm: torch.fx.GraphModule,
        nodes: SchedulerNodeList,
    ) -> None:
        annotate_orig_fx_with_snodes(gm, nodes)
        draw_graph(
            gm,
            fname=self.filename("orig_fx_graph_diagram.svg"),
            clear_meta=False,
            prog=GRAPHVIZ_COMMAND_SCALABLE,
            parse_stack_trace=True,
            dot_graph_shape=config.trace.dot_graph_shape,
        )

    def output_code(self, filename: str) -> None:
        shutil.copy(filename, self.filename("output_code.py"))

    def log_autotuning_results(
        self,
        name: str,
        input_nodes: List[ir.IRNode],
        timings: Dict["ChoiceCaller", float],  # type: ignore[name-defined] # noqa: F821
        elapse: float,
        precompile_elapse: float,
    ) -> None:
        import json

        from .ir import FixedLayout

        def build_node_info(node: ir.IRNode) -> Dict[str, str]:
            if hasattr(node, "name"):
                node_name = node.name
            else:
                node_name = ""
            node_info = {
                "name": node_name,
                "type": type(node).__name__,
            }
            try:
                layout = node.get_layout()
                if isinstance(layout, FixedLayout):
                    offset = 0
                    try:
                        offset = int(layout.offset)
                    except Exception:
                        try:
                            offset = V.graph.sizevars.size_hint(
                                layout.offset, fallback=0
                            )
                        except Exception:
                            pass
                    static_layout = FixedLayout(
                        layout.device,
                        dtype=layout.dtype,
                        size=list(V.graph.sizevars.size_hints(layout.size)),
                        stride=list(V.graph.sizevars.size_hints(layout.stride)),
                        offset=offset,
                    )
                    node_info["layout"] = str(static_layout)
                else:
                    node_info["layout"] = str(node.get_layout())
            except Exception:
                pass
            try:
                node_info["dtype"] = str(node.get_dtype())
            except Exception:
                pass
            try:
                node_info["device"] = str(node.get_device())
            except Exception:
                pass
            try:
                node_info["stride"] = str(
                    V.graph.sizevars.size_hints(node.get_stride())
                )
            except Exception:
                pass
            try:
                node_info["size"] = str(V.graph.sizevars.size_hints(node.get_size()))
            except Exception:
                pass
            try:
                node_info["numel"] = str(V.graph.sizevars.size_hint(node.get_numel()))
            except Exception:
                pass
            if hasattr(node, "data") and isinstance(node.data, ir.IRNode):
                node_info["data"] = build_node_info(node.data)
            return node_info

        general_properties = {
            "op_name": name,
            "cuda_device_name": torch.cuda.get_device_name(),
            "cuda_device_count": torch.cuda.device_count(),
            "input_nodes": [build_node_info(node) for node in input_nodes],
            "autotuning_time": elapse,
            "precompile_time": precompile_elapse,
        }
        with self.fopen_context(
            "autotuning_result_json_list.txt", "at", encoding="utf-8"
        ) as fd:
            for caller, time in timings.items():
                info_dict = dict(caller.info_dict())
                info_dict.update(general_properties)
                info_dict["benchmark_result"] = time
                json.dump(info_dict, fd)
                fd.write("\n")


@dataclasses.dataclass
class TensorMetadataHolder:
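    """
    Picklable stand-in for a (possibly fake) tensor argument: only the
    extracted TensorMetadata and the device the tensor lived on are kept.
    """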
    tensor_metadata: TensorMetadata
    device: torch.device


save_args_cnt = itertools.count()


def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None:
    """
    Save the arguments of a compile_fx_inner call to the file system so that
    the call can later be replayed with load_args_and_run_compile_fx_inner.
    """

    folder = "/tmp/inductor_saved_args"
    if not os.path.exists(folder):
        os.mkdir(folder)

    def handle_tensor(x: Any) -> Any:
        """
        Pickling a FakeTensor results in the error:
        AttributeError: Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'

        Convert all Tensors to metadata instead. This may also make pickling faster.
        """
        if isinstance(x, torch.Tensor):
            return TensorMetadataHolder(_extract_tensor_metadata(x), x.device)
        else:
            return x

    args_to_save, kwargs_to_save = tree_map(handle_tensor, (args, kwargs))

    fn_name = "compile_fx_inner"
    path = f"{folder}/{fn_name}_{next(save_args_cnt)}.pkl"
    with open(path, "wb") as f:
        pickle.dump((args_to_save, kwargs_to_save), f)

    if log.isEnabledFor(logging.DEBUG):
        message = f"""
Arguments for a compile_fx_inner call are saved to {path}. To replay the call,
run the following:

from torch._inductor.debug import load_args_and_run_compile_fx_inner
load_args_and_run_compile_fx_inner({path!r})
        """
        # Call print rather than log.debug: log.debug would prefix every line,
        # which makes the code snippet harder to copy.
        # Not a big deal since this code is already guarded by checking the
        # log level.
        print(message)


def load_args_and_run_compile_fx_inner(path: str) -> Any:
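    """
    Replay a compile_fx_inner call saved by save_args_for_compile_fx_inner:
    load the pickled arguments, rebuild random strided tensors matching the
    saved metadata, and re-run compile_fx_inner under FakeTensorMode.
    """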
    from torch._inductor.compile_fx import compile_fx_inner

    with open(path, "rb") as f:
        args, kwargs = pickle.load(f)

    def handle_tensor(x: Any) -> Any:
        if isinstance(x, TensorMetadataHolder):
            return torch._dynamo.testing.rand_strided(
                x.tensor_metadata.shape,
                x.tensor_metadata.stride,
                x.tensor_metadata.dtype,
                x.device,
            )
        else:
            return x

    fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
    with fake_mode, config.patch("save_args", False):
        args, kwargs = tree_map(handle_tensor, (args, kwargs))
        return compile_fx_inner(*args, **kwargs)