# mypy: allow-untyped-defs
import time
from collections import defaultdict
from functools import partial
from typing import DefaultDict

import torch


# Unfortunately it doesn't seem as if there was any way to get TensorBoard to do
# anything without having TF installed, and so this file has a hard dependency on it
# as well. It really is a debugging tool, so it doesn't matter.
try:
    from tensorflow.core.util import event_pb2
    from tensorflow.core.framework import graph_pb2
    from tensorflow.python.summary.writer.writer import FileWriter
except ImportError:
    raise ImportError("TensorBoard visualization of GraphExecutors requires having "
                      "TensorFlow installed") from None


def dump_tensorboard_summary(graph_executor, logdir):
    with FileWriter(logdir) as w:
        pb_graph = visualize(graph_executor)
        evt = event_pb2.Event(wall_time=time.time(), graph_def=pb_graph.SerializeToString())
        w.add_event(evt)


def visualize(graph, name_prefix='', pb_graph=None, executors_it=None):
    """Visualizes an independent graph, or a graph executor."""
    value_map = {}
    pb_graph = pb_graph or graph_pb2.GraphDef()

    if isinstance(graph, torch._C.GraphExecutorState):
        visualize_graph_executor(graph, name_prefix, pb_graph,
                                 partial(visualize, pb_graph=pb_graph))
        return pb_graph

    # Set up an input node
    input_node = pb_graph.node.add(op='input', name=name_prefix + 'input')
    for i, value in enumerate(graph.param_node().outputs()):
        value_map[value.unique()] = name_prefix + 'input:' + str(i)

    visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it)

    # Gather all outputs
    return_node = pb_graph.node.add(op='output', name=name_prefix + 'output')
    for value in graph.return_node().inputs():
        return_node.input.append(value_map[value.unique()])

    return pb_graph


def visualize_graph_executor(state, name_prefix, pb_graph, inline_graph):
    """Append the state of a given GraphExecutor to the graph protobuf.

    Args:
        state (GraphExecutor or GraphExecutorState): GraphExecutor to display.
        name_prefix (str): Name prefix of the containing subgraph.
        pb_graph (GraphDef): graph to append to.
        inline_graph (Callable): a function that handles setting up a value_map,
            so that some graphs in here can be inlined. This is necessary, because
            this will simply be `visualize` for the top-level GraphExecutor,
            or `inline_graph` for all nested ones.

            The signature should look like (Graph, name_prefix) -> ().
            It will be called exactly once.

    The strategy is to embed all different configurations as independent subgraphs,
    while inlining the original graph as the one that actually produces the values.
    """
    if state.autograd_fallback_graph is not None:
        visualize(graph=state.autograd_fallback_graph,
                  name_prefix=name_prefix + 'autograd_fallback/',
                  pb_graph=pb_graph,
                  executors_it=iter(state.autograd_fallback.executors()))

    for i, (arg_spec, plan) in enumerate(state.execution_plans.items()):
        subgraph_name = name_prefix + f'plan{i}/'

        # Create a disconnected node that will keep information regarding the input
        # types of this trace. This is unfortunately a bit too verbose to be included
        # in the subgraph name.
        input_kinds = pb_graph.node.add(op='INPUT_KIND', name=subgraph_name)
        input_kinds.attr['inputs'].s = repr(arg_spec).encode('ascii')

        visualize(plan.graph, subgraph_name, pb_graph, iter(plan.code.executors()))

        # Show gradient as an independent subgraph of this plan
        if plan.grad_executor is not None:
            grad_subgraph_name = subgraph_name + 'grad/'
            visualize(plan.grad_executor, grad_subgraph_name, pb_graph)

    return inline_graph(state.graph, name_prefix + 'original/')


def visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it=None):
    """Recursive part of visualize (basically skips setting up the input and output nodes)."""
    def inline_graph(subgraph, name, node):
        rec_value_map = {inp.unique(): value_map[val.unique()]
                         for inp, val in zip(subgraph.inputs(), node.inputs())}
        visualize_rec(graph=subgraph,
                      value_map=rec_value_map,
                      name_prefix=name,
                      pb_graph=pb_graph)
        for out, val in zip(subgraph.outputs(), node.outputs()):
            value_map[val.unique()] = rec_value_map[out.unique()]

    op_id_counter: DefaultDict[str, int] = defaultdict(int)

    def name_for(node):
        kind = node.kind()[node.kind().index('::') + 2:]
        op_id_counter[kind] += 1
        return kind, name_prefix + kind + '_' + str(op_id_counter[kind])

    def add_fusion_group(node):
        op, name = name_for(node)
        inline_graph(node.g('Subgraph'), name + '/', node)

    def add_graph_executor(node):
        op, name = name_for(node)
        if executors_it is None:
            add_node(node)
        else:
            ge = next(executors_it)
            visualize_graph_executor(ge, name + '/', pb_graph,
                                     partial(inline_graph, node=node))

    def add_node(node):
        if node.kind() == 'prim::FusionGroup':
            return add_fusion_group(node)
        elif node.kind() == 'prim::GraphExecutor':
            return add_graph_executor(node)
        op, name = name_for(node)
        pb_node = pb_graph.node.add(op=op, name=name)
        for value in node.inputs():
            pb_node.input.append(value_map[value.unique()])
        # TODO: handle attrs
        for i, value in enumerate(node.outputs()):
            value_map[value.unique()] = name + ':' + str(i)

    for node in graph.nodes():
        add_node(node)
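# A minimal usage sketch, not part of this module's public surface: it assumes that
# `ScriptFunction.get_debug_state()` returns the `torch._C.GraphExecutorState` that
# `visualize` expects, that the function has been run at least once so an execution
# plan exists, and that TensorFlow is installed so the imports above succeed.
# The logdir path below is only illustrative.
if __name__ == "__main__":
    @torch.jit.script
    def example_fn(x):
        return x * 2 + 1

    # Run the scripted function once so its graph executor records an execution plan.
    example_fn(torch.randn(3))

    # Serialize the executor state as a TensorBoard graph event.
    dump_tensorboard_summary(example_fn.get_debug_state(), "/tmp/graph_executor_vis")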