xref: /aosp_15_r20/external/pytorch/torch/contrib/_tensorboard_vis.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import time
3from collections import defaultdict
4from functools import partial
5from typing import DefaultDict
6
7import torch
8
9
# Unfortunately, there seems to be no way to get TensorBoard to do anything
# without TensorFlow installed, so this file has a hard dependency on it as
# well. It is purely a debugging tool, so that is acceptable.
13try:
14    from tensorflow.core.util import event_pb2
15    from tensorflow.core.framework import graph_pb2
16    from tensorflow.python.summary.writer.writer import FileWriter
17except ImportError:
18    raise ImportError("TensorBoard visualization of GraphExecutors requires having "
19                      "TensorFlow installed") from None
20
21
def dump_tensorboard_summary(graph_executor, logdir):
    """Write a TensorBoard graph event describing ``graph_executor`` to ``logdir``.

    Args:
        graph_executor: a graph or ``GraphExecutorState``; anything accepted
            by ``visualize``.
        logdir (str): directory handed to TensorFlow's ``FileWriter``.
    """
    # Build the GraphDef before opening the writer, so that a failure during
    # conversion does not leave behind an event file that contains no graph.
    graph_def = visualize(graph_executor)
    with FileWriter(logdir) as writer:
        event = event_pb2.Event(wall_time=time.time(),
                                graph_def=graph_def.SerializeToString())
        writer.add_event(event)
27
28
def visualize(graph, name_prefix='', pb_graph=None, executors_it=None):
    """Visualize an independent graph, or a graph executor.

    Args:
        graph: a TorchScript graph, or a ``torch._C.GraphExecutorState``. The
            latter is delegated to ``visualize_graph_executor``.
        name_prefix (str): prefix prepended to every generated node name.
        pb_graph (GraphDef, optional): protobuf to append to. A fresh
            ``GraphDef`` is created when omitted.
        executors_it (iterator, optional): executor states for the
            ``prim::GraphExecutor`` nodes of ``graph``, in node order. It is
            forwarded to ``visualize_rec``.

    Returns:
        GraphDef: ``pb_graph`` (or the newly created one) with the nodes appended.
    """
    # Explicit None check: do not rely on the truthiness of protobuf messages.
    if pb_graph is None:
        pb_graph = graph_pb2.GraphDef()

    if isinstance(graph, torch._C.GraphExecutorState):
        visualize_graph_executor(graph, name_prefix, pb_graph,
                                 partial(visualize, pb_graph=pb_graph))
        return pb_graph

    value_map = {}

    # One node stands in for all graph inputs; input i is referenced as
    # "<prefix>input:i", following the GraphDef "node:output_index" convention.
    input_name = name_prefix + 'input'
    pb_graph.node.add(op='input', name=input_name)
    for i, value in enumerate(graph.param_node().outputs()):
        value_map[value.unique()] = input_name + ':' + str(i)

    visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it)

    # A single output node consumes every value returned by the graph.
    return_node = pb_graph.node.add(op='output', name=name_prefix + 'output')
    for value in graph.return_node().inputs():
        return_node.input.append(value_map[value.unique()])

    return pb_graph
52
53
def visualize_graph_executor(state, name_prefix, pb_graph, inline_graph):
    """Append the state of a given GraphExecutor to the graph protobuf.

    The autograd fallback (if any) and every execution plan are emitted as
    independent subgraphs. Only the executor's original graph is passed to
    ``inline_graph``, so that it is the one that actually produces the values.

    Args:
        state (GraphExecutorState): executor state to display. Its
            ``autograd_fallback_graph``, ``autograd_fallback``,
            ``execution_plans`` and ``graph`` attributes are read.
        name_prefix (str): name prefix of the containing subgraph.
        pb_graph (GraphDef): graph to append to.
        inline_graph (Callable): called exactly once, as
            ``inline_graph(state.graph, name_prefix + 'original/')``. For the
            top-level executor this is ``visualize`` bound to ``pb_graph``.
            For nested executors it is ``visualize_rec``'s inlining helper,
            which wires the graph into the caller's value map.

    Returns:
        Whatever ``inline_graph`` returns.
    """
    fallback_graph = state.autograd_fallback_graph
    if fallback_graph is not None:
        visualize(
            graph=fallback_graph,
            name_prefix=name_prefix + 'autograd_fallback/',
            pb_graph=pb_graph,
            executors_it=iter(state.autograd_fallback.executors()),
        )

    for plan_index, (arg_spec, plan) in enumerate(state.execution_plans.items()):
        plan_prefix = name_prefix + 'plan' + str(plan_index) + '/'

        # The argument spec is too verbose to fit into a subgraph name. Record it
        # as an attribute on a disconnected marker node instead.
        marker = pb_graph.node.add(op='INPUT_KIND', name=plan_prefix)
        marker.attr['inputs'].s = repr(arg_spec).encode('ascii')

        visualize(plan.graph, plan_prefix, pb_graph, iter(plan.code.executors()))

        # The gradient executor becomes its own subgraph nested under the plan.
        grad_state = plan.grad_executor
        if grad_state is not None:
            visualize(grad_state, plan_prefix + 'grad/', pb_graph)

    return inline_graph(state.graph, name_prefix + 'original/')
95
96
def visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it=None):
    """Recursive part of visualize (basically skips setting up the input and output nodes).

    Appends one GraphDef node per node of ``graph`` to ``pb_graph``, with two
    special cases:
    - ``prim::FusionGroup``: its ``Subgraph`` is inlined under a name scope.
    - ``prim::GraphExecutor``: expanded through ``visualize_graph_executor``
      when ``executors_it`` is provided, otherwise drawn as an ordinary node.

    Args:
        graph: graph whose ``nodes()`` are visualized.
        value_map (dict): maps ``Value.unique()`` ids to the ``"node_name:index"``
            string of the GraphDef output producing that value. It must already
            contain every value consumed from outside ``graph``. It is updated in
            place with the outputs of each visited node.
        name_prefix (str): prefix prepended to every generated node name.
        pb_graph (GraphDef): protobuf to append nodes to.
        executors_it (iterator, optional): yields one executor state per
            ``prim::GraphExecutor`` node, consumed in node order.
    """
    # Per-call counters, so nodes of the same kind get distinct, numbered names.
    op_id_counter: DefaultDict[str, int] = defaultdict(int)

    def inline_graph(subgraph, name, node):
        # Bind the subgraph's inputs to the values feeding `node`, draw the
        # subgraph, then expose its outputs as the outputs of `node`.
        rec_value_map = {inp.unique(): value_map[val.unique()]
                         for inp, val in zip(subgraph.inputs(), node.inputs())}
        visualize_rec(graph=subgraph,
                      value_map=rec_value_map,
                      name_prefix=name,
                      pb_graph=pb_graph)
        for out, val in zip(subgraph.outputs(), node.outputs()):
            value_map[val.unique()] = rec_value_map[out.unique()]

    def name_for(node):
        # Strip the namespace ("aten::mul" -> "mul") and number the occurrence.
        kind = node.kind()[node.kind().index('::') + 2:]
        op_id_counter[kind] += 1
        return kind, name_prefix + kind + '_' + str(op_id_counter[kind])

    def add_plain_node(node, op, name):
        # Emit a single GraphDef node wired to its inputs; output i of the node
        # is referenced as "name:i".
        pb_node = pb_graph.node.add(op=op, name=name)
        for value in node.inputs():
            pb_node.input.append(value_map[value.unique()])
        # TODO: handle attrs
        for i, value in enumerate(node.outputs()):
            value_map[value.unique()] = name + ':' + str(i)

    def add_node(node):
        # Name each node exactly once, then dispatch on its kind.
        kind = node.kind()
        op, name = name_for(node)
        if kind == 'prim::FusionGroup':
            inline_graph(node.g('Subgraph'), name + '/', node)
        elif kind == 'prim::GraphExecutor' and executors_it is not None:
            ge = next(executors_it)
            visualize_graph_executor(ge, name + '/', pb_graph,
                                     partial(inline_graph, node=node))
        else:
            # Ordinary nodes, plus GraphExecutor nodes with no executor state
            # available. Routing the latter back through the dispatcher would
            # recurse forever.
            add_plain_node(node, op, name)

    for node in graph.nodes():
        add_node(node)
144