xref: /aosp_15_r20/external/pytorch/torch/fx/_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import sys
3from typing import Dict, Optional
4
5import torch
6from torch._logging import LazyString
7
8
9def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs):
10    """
11    Returns a LazyString that formats the graph code.
12    """
13
14    def format_name():
15        if maybe_id is not None:
16            return f"{name} {maybe_id}"
17        else:
18            return name
19
20    if "print_output" not in kwargs:
21        kwargs["print_output"] = False
22
23    if "colored" in kwargs and not sys.stdout.isatty():
24        kwargs["colored"] = False
25
26    return LazyString(
27        lambda: _format_graph_code(
28            f"===== {format_name()} =====\n",
29            gm.forward.__code__.co_filename,
30            gm.print_readable(**kwargs),
31        )
32    )
33
34
def _format_graph_code(name, filename, graph_str):
    """
    Build the log text for a traced graph.

    The text is the literal ``TRACED GRAPH``, a newline, a single space, then
    ``name``, ``filename`` and ``graph_str`` separated by single spaces, and a
    final trailing newline.
    """
    template = "TRACED GRAPH\n {} {} {}\n"
    return template.format(name, filename, graph_str)
40
41
def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[Dict]:
    """
    Return the ``nn_module_stack`` metadata of the earliest qualifying node.

    Nodes are scanned in ``graph.nodes`` order. The first node whose ``op`` is
    ``"call_function"`` *and* whose ``meta`` contains an ``"nn_module_stack"``
    entry wins; ``call_function`` nodes lacking that entry are skipped. The
    stored object itself (not a copy) is returned, or ``None`` when no node
    qualifies.
    """
    matches = (
        candidate.meta["nn_module_stack"]
        for candidate in graph.nodes
        if candidate.op == "call_function" and "nn_module_stack" in candidate.meta
    )
    return next(matches, None)
50
51
def get_node_context(node, num_nodes=2) -> str:
    """
    Format ``node`` together with the nodes that precede it.

    Starting at ``node`` and following ``.prev`` links, at most ``num_nodes``
    nodes are rendered with ``format_node()``. The walk stops early right after
    rendering a node whose ``op`` is ``"root"`` (presumably the graph's
    sentinel node); that node is still included. The rendered lines are joined
    with newlines oldest-first, so ``node`` itself is the last line whenever
    ``num_nodes`` is at least 1. A non-positive ``num_nodes`` yields ``""``.
    """
    rendered = []
    walker = node
    for _ in range(num_nodes):
        rendered.append(walker.format_node())
        # Nothing meaningful precedes the root node, so stop walking here.
        if walker.op == "root":
            break
        walker = walker.prev
    return "\n".join(reversed(rendered))
64