xref: /aosp_15_r20/external/pytorch/torch/utils/tensorboard/_pytorch_graph.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from collections import OrderedDict
3import contextlib
4from typing import Dict, Any
5
6from tensorboard.compat.proto.config_pb2 import RunMetadata
7from tensorboard.compat.proto.graph_pb2 import GraphDef
8from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
9from tensorboard.compat.proto.versions_pb2 import VersionDef
10
11import torch
12from ._proto_graph import node_proto
13
# Names of the zero-argument accessor methods read from each operator node.
# NodePyOP hands this list to NodePy, which calls every listed method on the
# wrapped node and stores the result under the same attribute name.
methods_OP = [
    "attributeNames",
    "hasMultipleOutputs",
    "hasUses",
    "inputs",
    "kind",
    "outputs",
    "outputsSize",
    "scopeName",
]
# Some additional methods to explore for methods_IO are
#
#   'unique' (type int)
#   'type' (type <Tensor<class 'torch._C.Type'>>)
#
# But the below are sufficient for now.
# Accessor methods read from graph input/output values (used by NodePyIO).
methods_IO = ["node", "offset", "debugName"]

# Node kind that parse() treats as an attribute lookup when building scopes.
GETATTR_KIND = "prim::GetAttr"
# Type kind of values that parse() does not turn into graph nodes.
CLASSTYPE_KIND = "ClassType"
34
35
class NodeBase:
    """Plain attribute container describing a single graph node.

    The constructor stores its arguments unchanged; note that ``op_type`` is
    kept under the attribute name ``kind``.
    """

    def __init__(
        self,
        debugName=None,
        inputs=None,
        scope=None,
        tensor_size=None,
        op_type="UnSpecified",
        attributes="",
    ):
        # TODO; Specify a __slots__ for this class or potentially
        # used namedtuple instead
        self.debugName = debugName
        self.inputs = inputs
        self.scope = scope
        self.tensor_size = tensor_size
        self.kind = op_type
        self.attributes = attributes

    def __repr__(self):
        """Return a multi-line dump: the class, then ``name: value<type>`` per attribute."""
        lines = [str(type(self))]
        for attr_name in dir(self):
            # Skip dunder members so only the data attributes are listed.
            if "__" in attr_name:
                continue
            value = getattr(self, attr_name)
            lines.append(f"{attr_name}: {value!s}{type(value)!s}")
        return "\n".join(lines) + "\n\n"
64
65
class NodePy(NodeBase):
    """Node that copies selected accessor results from a wrapped graph object.

    Every name in ``valid_methods`` is called on ``node_cpp`` with no
    arguments and the result is stored on this instance under the same name.
    ``"inputs"`` and ``"outputs"`` are special-cased: the attribute receives
    the list of debugName strings, and a companion attribute named
    ``<name>tensor_size`` (for example ``outputstensor_size``) receives the
    matching sizes, with ``None`` for values that are not complete tensors.
    """

    def __init__(self, node_cpp, valid_methods):
        super().__init__(node_cpp)
        self.inputs = []

        # Iterate over a snapshot so the caller's sequence is never touched.
        for method_name in list(valid_methods):
            accessor = getattr(node_cpp, method_name)
            if method_name not in ("inputs", "outputs"):
                setattr(self, method_name, accessor())
                continue

            names = []
            sizes = []
            for value in list(accessor()):
                names.append(value.debugName())
                complete = value.isCompleteTensor()
                sizes.append(value.type().sizes() if complete else None)

            setattr(self, method_name, names)
            setattr(self, method_name + "tensor_size", sizes)
89
90
class NodePyIO(NodePy):
    """Node wrapping a graph value, read through the ``methods_IO`` accessors.

    Args:
        node_cpp: The wrapped value; must provide the ``methods_IO`` accessors
            and ``type()``.
        input_or_output: Optional label. When truthy it is stored on the
            instance and the node kind becomes ``"IO Node"``; otherwise the
            kind is ``"Parameter"``.
    """

    def __init__(self, node_cpp, input_or_output=None):
        super().__init__(node_cpp, methods_IO)
        try:
            self.tensor_size = node_cpp.type().sizes()
        except RuntimeError:
            # fail when constant model is used.
            self.tensor_size = [1]

        # Kind attribute string is purely descriptive and will be shown
        # in detailed information for the node in TensorBoard's graph plugin.
        #
        # NodePyOP nodes get this from their kind() method.
        if input_or_output:
            self.input_or_output = input_or_output
            self.kind = "IO Node"
        else:
            self.kind = "Parameter"
109
110
class NodePyOP(NodePy):
    """Node wrapping a graph operator, read through the ``methods_OP`` accessors.

    In addition to the copied accessor results, ``attributes`` holds a string
    rendering of the operator's attribute name/value mapping and ``kind``
    holds the operator's ``kind()``.
    """

    def __init__(self, node_cpp):
        super().__init__(node_cpp, methods_OP)
        attr_values = {}
        for attr_name in node_cpp.attributeNames():
            attr_values[attr_name] = _node_get(node_cpp, attr_name)
        # Replace single quote which causes strange behavior in TensorBoard
        # TODO: See if we can remove this in the future
        rendered = str(attr_values)
        self.attributes = rendered.replace("'", " ")
        self.kind = node_cpp.kind()
120
121
class GraphPy:
    """Collect parsed graph nodes and convert them to TensorBoard node protos.

    Nodes arrive through ``append``: ``NodePyIO`` instances are stored in
    ``nodes_io`` (an ordered mapping keyed by debugName) and ``NodePyOP``
    instances are stored in the ``nodes_op`` list.

    ``populate_namespace_from_OP_to_IO`` then registers every operator output
    as a ``NodeBase`` entry of ``nodes_io``, records the scope names it saw in
    ``scope_name_appeared``, and fills ``unique_name_to_scoped_name``, a
    mapping from a debugName to its scoped name (``<scope>/<debugName>``,
    e.g. Net1/Linear[0]/1). Finally it rewrites each entry's name and inputs
    to the scoped form. ``find_common_root`` maintains
    ``shallowest_scope_name``, the prefix used for entries whose scope is the
    empty string.

    ``to_proto`` emits one node proto per entry of ``nodes_io``.
    """

    def __init__(self):
        self.nodes_op = []
        self.nodes_io = OrderedDict()
        self.unique_name_to_scoped_name = {}
        self.shallowest_scope_name = "default"
        self.scope_name_appeared = []

    def append(self, x):
        """Store ``x`` in ``nodes_io`` or ``nodes_op`` according to its class."""
        if isinstance(x, NodePyIO):
            self.nodes_io[x.debugName] = x
        if isinstance(x, NodePyOP):
            self.nodes_op.append(x)

    def printall(self):
        """Print every operator node, then every entry of ``nodes_io``."""
        print("all nodes")
        for op_node in self.nodes_op:
            print(op_node)
        for io_node in self.nodes_io.values():
            print(io_node)

    def find_common_root(self):
        """Set ``shallowest_scope_name`` from the recorded scope names.

        Each non-empty scope overwrites the value with its first
        ``/``-separated component, so the last non-empty scope wins.
        """
        for scope_path in self.scope_name_appeared:
            if not scope_path:
                continue
            self.shallowest_scope_name = scope_path.split("/")[0]

    def populate_namespace_from_OP_to_IO(self):
        """Register operator outputs and rewrite all names to their scoped form."""
        scoped = self.unique_name_to_scoped_name

        # Each operator output becomes its own NodeBase entry in nodes_io.
        for op_node in self.nodes_op:
            output_pairs = zip(op_node.outputs, op_node.outputstensor_size)
            for out_name, out_size in output_pairs:
                self.scope_name_appeared.append(op_node.scopeName)
                self.nodes_io[out_name] = NodeBase(
                    out_name,
                    op_node.inputs,
                    op_node.scopeName,
                    out_size,
                    op_type=op_node.kind,
                    attributes=op_node.attributes,
                )

        self.find_common_root()

        # Provisional names: an input is placed under the scope of its consumer.
        for op_node in self.nodes_op:
            for in_name in op_node.inputs:
                scoped[in_name] = "/".join((op_node.scopeName, in_name))

        # The checks below deliberately run in this order; later ones
        # overwrite the name chosen by earlier ones.
        for key, entry in self.nodes_io.items():
            if type(entry) is NodeBase:
                scoped[key] = "/".join((entry.scope, entry.debugName))
            if hasattr(entry, "input_or_output"):
                scoped[key] = "/".join((entry.input_or_output, entry.debugName))

            if getattr(entry, "scope", None) is not None:
                scoped[key] = "/".join((entry.scope, entry.debugName))
                if entry.scope == "" and self.shallowest_scope_name:
                    scoped[entry.debugName] = "/".join(
                        (self.shallowest_scope_name, entry.debugName)
                    )

        # replace name
        for entry in self.nodes_io.values():
            entry.inputs = [scoped[in_name] for in_name in entry.inputs]
            if entry.debugName in scoped:
                entry.debugName = scoped[entry.debugName]

    def to_proto(self):
        """Convert graph representation of GraphPy object to TensorBoard required format."""
        # TODO: compute correct memory usage and CPU time once
        # PyTorch supports it
        return [
            node_proto(
                entry.debugName,
                input=entry.inputs,
                outputsize=entry.tensor_size,
                op=entry.kind,
                attributes=entry.attributes,
            )
            for entry in self.nodes_io.values()
        ]
231
232
def parse(graph, trace, args=None, omit_useless_nodes=True):
    """Parse an optimized PyTorch model graph and produce a list of node protos.

    Useful for eventual conversion to TensorBoard protobuf format.

    Args:
      graph: The model graph to be parsed; ``graph()`` in this module passes
        ``trace.graph``.
      trace (PyTorch JIT TracedModule): The model trace to be parsed. Its
        ``named_modules`` supply the module names used in scope names.
      args (tuple): input tensor[s] for the model. Not used by the parser;
        accepted for interface compatibility and may be ``None``.
      omit_useless_nodes (boolean): Whether to skip graph inputs that have
        no uses.

    Returns:
      The list returned by ``GraphPy.to_proto()`` after all inputs, operator
      nodes and output sink nodes have been collected.
    """
    nodes_py = GraphPy()
    for node in graph.inputs():
        # number of user of the node (= number of outputs/ fanout)
        if omit_useless_nodes and len(node.uses()) == 0:
            continue

        if node.type().kind() != CLASSTYPE_KIND:
            nodes_py.append(NodePyIO(node, "input"))

    # Maps the debugName of each GetAttr output to its "/"-separated scope.
    attr_to_scope: Dict[Any, str] = {}
    for node in graph.nodes():
        if node.kind() != GETATTR_KIND:
            nodes_py.append(NodePyOP(node))
            continue

        attr_name = node.s("name")
        attr_key = node.output().debugName()
        parent = node.input().node()
        if parent.kind() == GETATTR_KIND:
            # If the parent node is not the top-level "self" node, extend
            # the parent's scope with this attribute's dotted name.
            parent_scope = attr_to_scope[parent.output().debugName()]
            attr_scope = parent_scope.split("/")[-1]
            attr_to_scope[attr_key] = f"{parent_scope}/{attr_scope}.{attr_name}"
        else:
            attr_to_scope[attr_key] = f"__module.{attr_name}"
        # We don't need classtype nodes; scope will provide this information
        if node.output().type().kind() != CLASSTYPE_KIND:
            node_py = NodePyOP(node)
            node_py.scopeName = attr_to_scope[attr_key]  # type: ignore[attr-defined]
            nodes_py.append(node_py)

    for i, node in enumerate(graph.outputs()):  # Create sink nodes for output ops
        node_pyio = NodePyIO(node, "output")
        node_pyio.debugName = f"output.{i + 1}"
        node_pyio.inputs = [node.debugName()]
        nodes_py.append(node_pyio)

    def parse_traced_name(module):
        """Return the display name of a (traced) module."""
        if isinstance(module, torch.jit.TracedModule):
            return module._name
        return getattr(module, "original_name", "Module")

    # Maps a dotted module path ("__module.a.b") to "<ModuleName>[<attr>]".
    alias_to_name = {}
    base_name = parse_traced_name(trace)
    for name, module in trace.named_modules(prefix="__module"):
        mod_name = parse_traced_name(module)
        attr_name = name.split(".")[-1]
        alias_to_name[name] = f"{mod_name}[{attr_name}]"

    for node in nodes_py.nodes_op:
        # Unknown aliases fall back to the last dotted component.
        replacements = [
            alias_to_name.get(alias, alias.split(".")[-1])
            for alias in node.scopeName.split("/")
        ]
        node.scopeName = base_name
        if any(replacements):
            node.scopeName += "/" + "/".join(replacements)

    nodes_py.populate_namespace_from_OP_to_IO()
    return nodes_py.to_proto()
314
315
def graph(model, args, verbose=False, use_strict_trace=True):
    """
    Trace a PyTorch model and build a `GraphDef` proto that can be logged to TensorBoard.

    Args:
      model (PyTorch module): The model to be parsed.
      args (tuple): input tensor[s] for the model.
      verbose (bool): Whether to print the traced graph while processing.
      use_strict_trace (bool): Forwarded to `torch.jit.trace` as the keyword
        argument `strict`. Pass False when you want the tracer to
        record your mutable container types (list, dict)

    Returns:
      A ``(GraphDef, RunMetadata)`` tuple.

    Raises:
      RuntimeError: re-raised after being printed when tracing or inlining fails.
    """
    with _set_model_to_eval(model):
        try:
            trace = torch.jit.trace(model, args, strict=use_strict_trace)
            traced_graph = trace.graph
            torch._C._jit_pass_inline(traced_graph)
        except RuntimeError as e:
            print(e)
            print("Error occurs, No graph saved")
            raise e

    if verbose:
        print(traced_graph)
    node_protos = parse(traced_graph, trace, args)

    # We are hardcoding that this was run on CPU even though it might have actually
    # run on GPU. Note this is what is shown in TensorBoard and has no bearing
    # on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    cpu_stats = DeviceStepStats(device="/device:CPU:0")
    stepstats = RunMetadata(step_stats=StepStats(dev_stats=[cpu_stats]))

    # The producer version has been reverse engineered from standard
    # TensorBoard logged data.
    graph_def = GraphDef(node=node_protos, versions=VersionDef(producer=22))
    return graph_def, stepstats
358
359
360@contextlib.contextmanager
361def _set_model_to_eval(model):
362    """Context manager to temporarily set the training mode of ``model`` to eval."""
363    if not isinstance(model, torch.jit.ScriptFunction):
364        originally_training = model.training
365        model.train(False)
366        try:
367            yield
368        finally:
369            model.train(originally_training)
370    else:
371        # Do nothing for ScriptFunction
372        try:
373            yield
374        finally:
375            pass
376
377
378def _node_get(node: torch._C.Node, key: str):
379    """Get attributes of a node which is polymorphic over return type."""
380    sel = node.kindOf(key)
381    return getattr(node, sel)(key)
382