# mypy: allow-untyped-defs
import itertools

from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.node_def_pb2 import NodeDef
from tensorboard.compat.proto.versions_pb2 import VersionDef
from tensorboard.compat.proto.attr_value_pb2 import AttrValue
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto


def load_onnx_graph(fname):
    """Load an ONNX model and convert its graph to a TensorBoard ``GraphDef``.

    Args:
        fname: Model location, passed straight to ``onnx.load``.

    Returns:
        The ``GraphDef`` produced by :func:`parse` for the model's graph.
    """
    # Local import: onnx is only needed when this function actually runs.
    import onnx

    model = onnx.load(fname)  # type: ignore[attr-defined]
    return parse(model.graph)


def _value_info_to_nodedef(value_info):
    """Build a ``Variable`` placeholder NodeDef for one graph input/output."""
    tensor_type = value_info.type.tensor_type
    # NOTE(review): symbolic dimensions (dim_param) carry dim_value 0 and are
    # therefore emitted as size 0 -- confirm whether -1 (unknown) is preferable.
    shape = TensorShapeProto(
        dim=[TensorShapeProto.Dim(size=d.dim_value) for d in tensor_type.shape.dim]
    )
    return NodeDef(
        name=value_info.name.encode(encoding="utf_8"),
        op="Variable",
        input=[],
        attr={
            # NOTE(review): the ONNX elem_type code is forwarded unchanged into
            # a TensorFlow DataType field; the two enums may not be numbered
            # identically -- verify before relying on the displayed dtype.
            "dtype": AttrValue(type=tensor_type.elem_type),
            "shape": AttrValue(shape=shape),
        },
    )


def _onnx_node_to_nodedef(node):
    """Build a NodeDef for one ONNX operator node.

    The node is named after its first output, or after ``node.name`` when it
    has no outputs. Its ONNX attributes are flattened into a single
    human-readable ``parameters`` string: each attribute's set field values
    are joined with " = ", and the attributes are joined with ", ".
    """
    params = ", ".join(
        " = ".join(str(value) for _, value in attribute.ListFields())
        for attribute in node.attribute
    )
    # Guard against output-less nodes, which previously raised IndexError.
    name = node.output[0] if node.output else node.name
    return NodeDef(
        name=name.encode(encoding="utf_8"),
        op=node.op_type,
        input=list(node.input),
        attr={"parameters": AttrValue(s=params.encode(encoding="utf_8"))},
    )


def parse(graph):
    """Convert an ONNX ``GraphProto`` into a TensorBoard ``GraphDef``.

    Graph inputs and outputs become ``Variable`` placeholder nodes that carry
    ``dtype`` and ``shape`` attributes. Each operator node becomes a NodeDef
    named after its first output; its inputs are the ONNX input names,
    copied unchanged.

    Args:
        graph: The ONNX graph, e.g. ``onnx.load(path).graph``.

    Returns:
        A ``GraphDef`` containing the placeholder nodes followed by the
        operator nodes, with ``versions.producer`` set to 22.
    """
    nodes = [
        _value_info_to_nodedef(value_info)
        for value_info in itertools.chain(graph.input, graph.output)
    ]
    nodes.extend(_onnx_node_to_nodedef(node) for node in graph.node)
    return GraphDef(node=nodes, versions=VersionDef(producer=22))