# mypy: allow-untyped-defs
from collections import OrderedDict
import contextlib
from typing import Dict, Any

from tensorboard.compat.proto.config_pb2 import RunMetadata
from tensorboard.compat.proto.graph_pb2 import GraphDef
from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
from tensorboard.compat.proto.versions_pb2 import VersionDef

import torch
from ._proto_graph import node_proto

methods_OP = [
    "attributeNames",
    "hasMultipleOutputs",
    "hasUses",
    "inputs",
    "kind",
    "outputs",
    "outputsSize",
    "scopeName",
]
# Some additional methods to explore for methods_IO are
#
#   'unique' (type int)
#   'type' (type <Tensor<class 'torch._C.Type'>>)
#
# But the below are sufficient for now.
methods_IO = ["node", "offset", "debugName"]

GETATTR_KIND = "prim::GetAttr"
CLASSTYPE_KIND = "ClassType"


class NodeBase:
    """Plain record describing one node that is later emitted via ``node_proto``.

    Instances are created directly for operator outputs (see
    ``GraphPy.populate_namespace_from_OP_to_IO``) and also serve as the base
    class of ``NodePy`` and its subclasses.
    """

    def __init__(
        self,
        debugName=None,
        inputs=None,
        scope=None,
        tensor_size=None,
        op_type="UnSpecified",
        attributes="",
    ):
        # TODO: Specify a __slots__ for this class or potentially
        # use a namedtuple instead
        self.debugName = debugName
        self.inputs = inputs
        self.tensor_size = tensor_size
        self.kind = op_type
        self.attributes = attributes
        self.scope = scope

    def __repr__(self):
        # One "name: value<type>" line per non-dunder attribute.
        lines = [str(type(self))]
        for name in dir(self):
            if "__" not in name:
                value = getattr(self, name)
                lines.append(name + ": " + str(value) + str(type(value)))
        return "\n".join(lines) + "\n\n"


class NodePy(NodeBase):
    """Python-side snapshot of a TorchScript node or value.

    For every name in ``valid_methods``, the same-named zero-argument method is
    called on ``node_cpp`` and the result is stored as an attribute.

    ``inputs`` and ``outputs`` are special-cased: they are stored as lists of
    debug names. The matching tensor sizes are stored under
    ``inputstensor_size`` / ``outputstensor_size``; the entry is ``None`` for
    values that are not complete tensors.
    """

    def __init__(self, node_cpp, valid_methods):
        # NodeBase's first parameter is ``debugName``. IO nodes overwrite it in
        # the loop below because ``methods_IO`` contains "debugName".
        super().__init__(node_cpp)
        self.inputs = []

        for m in valid_methods:
            if m in ("inputs", "outputs"):
                io_unique_names = []
                io_tensor_sizes = []
                for n in getattr(node_cpp, m)():
                    io_unique_names.append(n.debugName())
                    io_tensor_sizes.append(
                        n.type().sizes() if n.isCompleteTensor() else None
                    )
                setattr(self, m, io_unique_names)
                setattr(self, m + "tensor_size", io_tensor_sizes)
            else:
                setattr(self, m, getattr(node_cpp, m)())


class NodePyIO(NodePy):
    """A TorchScript value wrapped as an input/output (or parameter) node."""

    def __init__(self, node_cpp, input_or_output=None):
        super().__init__(node_cpp, methods_IO)
        try:
            tensor_size = node_cpp.type().sizes()
        except RuntimeError:
            tensor_size = [1]  # fail when constant model is used.
        self.tensor_size = tensor_size
        # Kind attribute string is purely descriptive and will be shown
        # in detailed information for the node in TensorBoard's graph plugin.
        #
        # NodePyOP nodes get this from their kind() method.
        self.kind = "Parameter"
        if input_or_output:
            self.input_or_output = input_or_output
            self.kind = "IO Node"


class NodePyOP(NodePy):
    """A TorchScript operator node together with its stringified attributes."""

    def __init__(self, node_cpp):
        super().__init__(node_cpp, methods_OP)
        # Replace single quote which causes strange behavior in TensorBoard
        # TODO: See if we can remove this in the future
        self.attributes = str(
            {k: _node_get(node_cpp, k) for k in node_cpp.attributeNames()}
        ).replace("'", " ")
        self.kind = node_cpp.kind()


class GraphPy:
    """Helper class to convert torch.nn.Module to GraphDef proto and visualization with TensorBoard.

    GraphDef generation operates in two passes.

    In the first pass, all nodes are read and saved into two containers:

    * ``nodes_io`` (an OrderedDict keyed by debug name) holds input/output
      nodes, which only have inbound or outbound connections, but not both.
    * ``nodes_op`` (a list) holds internal operator nodes.

    In the second pass, scope names are fully applied to all nodes. The scope
    names seen on operator nodes are collected in ``scope_name_appeared``.
    ``unique_name_to_scoped_name`` maps a node's ID to its fully qualified
    scope name, e.g. Net1/Linear[0]/1.

    Unfortunately torch.jit doesn't have totally correct scope output, so this
    is nontrivial. The methods ``populate_namespace_from_OP_to_IO`` and
    ``find_common_root`` assign scope names to nodes heuristically, based on
    the connections between nodes. Bookkeeping is done with
    ``shallowest_scope_name`` and ``scope_name_appeared``.
    """

    def __init__(self):
        self.nodes_op = []
        self.nodes_io = OrderedDict()
        self.unique_name_to_scoped_name = {}
        self.shallowest_scope_name = "default"
        self.scope_name_appeared = []

    def append(self, x):
        """File ``x`` under ``nodes_io`` (NodePyIO) or ``nodes_op`` (NodePyOP)."""
        if isinstance(x, NodePyIO):
            self.nodes_io[x.debugName] = x
        if isinstance(x, NodePyOP):
            self.nodes_op.append(x)

    def printall(self):
        """Print every operator node, then every IO node (debugging aid)."""
        print("all nodes")
        for node in self.nodes_op:
            print(node)
        for node in self.nodes_io.values():
            print(node)

    def find_common_root(self):
        """Set ``shallowest_scope_name`` from the recorded scope names.

        Uses the first path component of the *last* non-empty entry in
        ``scope_name_appeared``. Leaves the current value untouched if there
        is no non-empty entry.
        """
        for fullscope in self.scope_name_appeared:
            if fullscope:
                self.shallowest_scope_name = fullscope.split("/")[0]

    def populate_namespace_from_OP_to_IO(self):
        """Resolve fully scoped names for every node in ``nodes_io``.

        1. Each operator output becomes a ``NodeBase`` entry in ``nodes_io``.
           The entry carries the operator's inputs, scope, kind and attributes.
        2. ``unique_name_to_scoped_name`` is filled in. Later assignments
           overwrite earlier ones for the same key:

           * operator inputs are prefixed with the consuming operator's scope;
           * ``nodes_io`` entries are prefixed with their own scope, or with
             ``input/`` / ``output/`` for graph IO nodes;
           * entries whose scope is empty fall back to
             ``shallowest_scope_name``.
        3. ``inputs`` and ``debugName`` of every ``nodes_io`` entry are
           rewritten to the scoped names.
        """
        for node in self.nodes_op:
            for node_output, outputSize in zip(node.outputs, node.outputstensor_size):
                self.scope_name_appeared.append(node.scopeName)
                self.nodes_io[node_output] = NodeBase(
                    node_output,
                    node.inputs,
                    node.scopeName,
                    outputSize,
                    op_type=node.kind,
                    attributes=node.attributes,
                )

        self.find_common_root()

        for node in self.nodes_op:
            for input_node_id in node.inputs:
                self.unique_name_to_scoped_name[input_node_id] = (
                    node.scopeName + "/" + input_node_id
                )

        for key, node in self.nodes_io.items():
            # Exact type check on purpose: NodePyIO is also a NodeBase subclass
            # and must not take this branch.
            if type(node) is NodeBase:
                self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
            if hasattr(node, "input_or_output"):
                self.unique_name_to_scoped_name[key] = (
                    node.input_or_output + "/" + node.debugName
                )

            if hasattr(node, "scope") and node.scope is not None:
                self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
                if node.scope == "" and self.shallowest_scope_name:
                    self.unique_name_to_scoped_name[node.debugName] = (
                        self.shallowest_scope_name + "/" + node.debugName
                    )

        # replace name
        for node in self.nodes_io.values():
            node.inputs = [
                self.unique_name_to_scoped_name[node_input_id]
                for node_input_id in node.inputs
            ]
            if node.debugName in self.unique_name_to_scoped_name:
                node.debugName = self.unique_name_to_scoped_name[node.debugName]

    def to_proto(self):
        """Convert graph representation of GraphPy object to TensorBoard required format.

        Returns:
            A list with one ``node_proto(...)`` result per entry in ``nodes_io``.
        """
        # TODO: compute correct memory usage and CPU time once
        # PyTorch supports it
        return [
            node_proto(
                v.debugName,
                input=v.inputs,
                outputsize=v.tensor_size,
                op=v.kind,
                attributes=v.attributes,
            )
            for v in self.nodes_io.values()
        ]


def parse(graph, trace, args=None, omit_useless_nodes=True):
    """Parse an optimized PyTorch model graph and produce a list of nodes.

    Useful for eventual conversion to TensorBoard protobuf format.

    Args:
        graph: The TorchScript graph to be parsed. ``graph()`` passes
            ``trace.graph`` after inlining.
        trace (PyTorch JIT TracedModule): The model trace to be parsed. Its
            named submodules are used to turn scope aliases into readable
            names.
        args (tuple): input tensor[s] for the model. Not used by this
            function; kept for backward compatibility.
        omit_useless_nodes (boolean): Whether to skip graph inputs that have
            no uses.

    Returns:
        The list produced by ``GraphPy.to_proto()``.
    """
    # NOTE: ``args`` used to be passed to ``len()`` here (result unused),
    # which made the documented default ``args=None`` raise TypeError.
    nodes_py = GraphPy()
    for node in graph.inputs():
        # ``uses()`` lists the users of the value (its fan-out).
        if omit_useless_nodes and len(node.uses()) == 0:
            continue

        if node.type().kind() != CLASSTYPE_KIND:
            nodes_py.append(NodePyIO(node, "input"))

    attr_to_scope: Dict[Any, str] = {}
    for node in graph.nodes():
        if node.kind() == GETATTR_KIND:
            attr_name = node.s("name")
            attr_key = node.output().debugName()
            parent = node.input().node()
            # If the parent node is not the top-level "self" node
            if parent.kind() == GETATTR_KIND:
                parent_attr_key = parent.output().debugName()
                parent_scope = attr_to_scope[parent_attr_key]
                attr_scope = parent_scope.split("/")[-1]
                attr_to_scope[attr_key] = f"{parent_scope}/{attr_scope}.{attr_name}"
            else:
                attr_to_scope[attr_key] = f"__module.{attr_name}"
            # We don't need classtype nodes; scope will provide this information
            if node.output().type().kind() != CLASSTYPE_KIND:
                node_py = NodePyOP(node)
                node_py.scopeName = attr_to_scope[attr_key]  # type: ignore[attr-defined]
                nodes_py.append(node_py)
        else:
            nodes_py.append(NodePyOP(node))

    # Create sink nodes for output ops
    for i, node in enumerate(graph.outputs()):
        node_pyio = NodePyIO(node, "output")
        node_pyio.debugName = f"output.{i + 1}"
        node_pyio.inputs = [node.debugName()]
        nodes_py.append(node_pyio)

    def parse_traced_name(module):
        if isinstance(module, torch.jit.TracedModule):
            return module._name
        return getattr(module, "original_name", "Module")

    alias_to_name = {}
    base_name = parse_traced_name(trace)
    for name, module in trace.named_modules(prefix="__module"):
        mod_name = parse_traced_name(module)
        attr_name = name.split(".")[-1]
        alias_to_name[name] = f"{mod_name}[{attr_name}]"

    for node in nodes_py.nodes_op:
        module_aliases = node.scopeName.split("/")
        replacements = [
            alias_to_name[alias] if alias in alias_to_name else alias.split(".")[-1]
            for alias in module_aliases
        ]
        node.scopeName = base_name
        if any(replacements):
            node.scopeName += "/" + "/".join(replacements)

    nodes_py.populate_namespace_from_OP_to_IO()
    return nodes_py.to_proto()


def graph(model, args, verbose=False, use_strict_trace=True):
    """
    Process a PyTorch model and produce a `GraphDef` proto that can be logged to TensorBoard.

    Args:
        model (PyTorch module): The model to be parsed.
        args (tuple): input tensor[s] for the model.
        verbose (bool): Whether to print out verbose information while
            processing.
        use_strict_trace (bool): Whether to pass keyword argument `strict` to
            `torch.jit.trace`. Pass False when you want the tracer to
            record your mutable container types (list, dict).

    Returns:
        A ``(GraphDef, RunMetadata)`` tuple. The RunMetadata's step stats name
        a single ``/device:CPU:0`` device.

    Raises:
        RuntimeError: if tracing or inlining fails; the error is printed first.
    """
    with _set_model_to_eval(model):
        try:
            trace = torch.jit.trace(model, args, strict=use_strict_trace)
            traced_graph = trace.graph
            torch._C._jit_pass_inline(traced_graph)
        except RuntimeError as e:
            print(e)
            print("Error occurs, No graph saved")
            raise

    if verbose:
        print(traced_graph)
    list_of_nodes = parse(traced_graph, trace, args)
    # We are hardcoding that this was run on CPU even though it might have actually
    # run on GPU. Note this is what is shown in TensorBoard and has no bearing
    # on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(
        step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")])
    )
    # The producer version has been reverse engineered from standard
    # TensorBoard logged data.
    return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats


@contextlib.contextmanager
def _set_model_to_eval(model):
    """Context manager to temporarily set the training mode of ``model`` to eval.

    The original ``model.training`` flag is restored on exit, even if the body
    raises. ``torch.jit.ScriptFunction`` objects are left untouched.
    """
    if isinstance(model, torch.jit.ScriptFunction):
        # Do nothing for ScriptFunction
        yield
        return
    originally_training = model.training
    model.train(False)
    try:
        yield
    finally:
        model.train(originally_training)


def _node_get(node: torch._C.Node, key: str):
    """Get attributes of a node which is polymorphic over return type.

    ``node.kindOf(key)`` yields the name of the node's typed getter method
    (for example ``s``, as called directly in ``parse``). That getter is then
    invoked with ``key``.
    """
    sel = node.kindOf(key)
    return getattr(node, sel)(key)