xref: /aosp_15_r20/external/pytorch/torch/utils/tensorboard/_onnx_graph.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from tensorboard.compat.proto.graph_pb2 import GraphDef
3from tensorboard.compat.proto.node_def_pb2 import NodeDef
4from tensorboard.compat.proto.versions_pb2 import VersionDef
5from tensorboard.compat.proto.attr_value_pb2 import AttrValue
6from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
7
8
def load_onnx_graph(fname):
    """Read the ONNX model stored at ``fname`` and convert its graph.

    Args:
        fname: whatever ``onnx.load`` accepts (passed through unchanged).

    Returns:
        The ``GraphDef`` produced by ``parse`` for the model's ``graph``.
    """
    # Imported here rather than at module level -- presumably so that onnx
    # stays an optional dependency of this module.
    import onnx

    model = onnx.load(fname)  # type: ignore[attr-defined]
    return parse(model.graph)
15
16
def parse(graph):
    """Convert an ONNX graph into a TensorBoard ``GraphDef``.

    Every graph input and graph output becomes a ``Variable`` node that
    carries ``dtype`` and ``shape`` attributes. Every operator in
    ``graph.node`` becomes a ``NodeDef`` named after its first output, with
    its ONNX attributes flattened into a single ``parameters`` string.

    Args:
        graph: an ONNX graph proto, e.g. ``onnx.load(path).graph``. It must
            expose ``input``, ``output`` and ``node``.

    Returns:
        A ``GraphDef`` that lists the input/output nodes first and then the
        operator nodes, with ``versions.producer`` set to 22.
    """
    import itertools

    nodes = [
        _value_info_to_node_def(value_info)
        for value_info in itertools.chain(graph.input, graph.output)
    ]
    nodes.extend(_onnx_node_to_node_def(node) for node in graph.node)
    # NOTE(review): a graph output usually shares its name with the first
    # output of the op that produces it, so the resulting GraphDef may contain
    # duplicate node names -- confirm how TensorBoard resolves them.
    return GraphDef(node=nodes, versions=VersionDef(producer=22))


def _dim_size(dim):
    """Return the TensorBoard size for one ONNX shape dimension.

    An ONNX dimension holds either a concrete ``dim_value`` or a symbolic
    ``dim_param`` (or neither). Reading ``dim_value`` unconditionally would
    report 0 for the symbolic and unset cases. Report -1 instead, which
    ``TensorShapeProto`` uses to mean an unknown size.
    """
    return dim.dim_value if dim.HasField("dim_value") else -1


def _value_info_to_node_def(value_info):
    """Build a ``Variable`` NodeDef for one graph input or output."""
    tensor_type = value_info.type.tensor_type
    shape = TensorShapeProto(
        dim=[TensorShapeProto.Dim(size=_dim_size(d)) for d in tensor_type.shape.dim]
    )
    return NodeDef(
        name=value_info.name,
        op="Variable",
        input=[],
        attr={
            # NOTE(review): the ONNX elem_type enum value is passed through
            # unchanged as a TensorBoard/TensorFlow DataType. The two enums
            # do not necessarily use the same numbering -- verify.
            "dtype": AttrValue(type=tensor_type.elem_type),
            "shape": AttrValue(shape=shape),
        },
    )


def _onnx_node_to_node_def(node):
    """Build a NodeDef for one ONNX operator node."""
    # Each attribute is rendered as the values of its set fields (as returned
    # by ListFields) joined by " = ", and the attributes are joined by ", ".
    parameters = ", ".join(
        " = ".join(str(value) for _, value in attribute.ListFields())
        for attribute in node.attribute
    )
    # Name the node after its first output, presumably so that other nodes'
    # ``input`` entries (ONNX tensor names) refer to it. Fall back to the ONNX
    # node name for an output-less node instead of raising IndexError.
    name = node.output[0] if node.output else node.name
    return NodeDef(
        name=name,
        op=node.op_type,
        input=list(node.input),
        attr={"parameters": AttrValue(s=parameters.encode(encoding="utf_8"))},
    )
64