# mypy: allow-untyped-defs
"""Lightweight op-capture profiler for TorchDynamo.

Counts how many operators (and how much CPU time) fall inside
dynamo-compiled ``TORCHDYNAMO`` regions versus the whole run.
"""
import dataclasses
import os
from typing import Any, List

import torch

from .utils import print_once


@dataclasses.dataclass
class ProfileMetrics:
    # Aggregated counters for a set of profiled ops.
    microseconds: float = 0.0  # summed elapsed CPU time of counted ops
    operators: int = 0  # number of ops counted
    fusions: int = 0  # ops minus graph regions (ops "fused" into graphs)
    graphs: int = 0  # number of TORCHDYNAMO graph-call regions

    def __iadd__(self, other: "ProfileMetrics") -> "ProfileMetrics":
        self.microseconds += other.microseconds
        self.operators += other.operators
        self.fusions += other.fusions
        # BUGFIX: `graphs` was previously dropped here, so accumulating
        # ProfileResults (ProfileResult.__iadd__) lost the graph-call count
        # that __str__/tocsv() report via captured.graphs.
        self.graphs += other.graphs
        return self

    def __add__(self, other: "ProfileMetrics") -> "ProfileMetrics":
        assert isinstance(other, ProfileMetrics)
        return ProfileMetrics(
            self.microseconds + other.microseconds,
            self.operators + other.operators,
            self.fusions + other.fusions,
            self.graphs + other.graphs,  # BUGFIX: graphs was previously dropped
        )

    def __truediv__(self, other) -> "ProfileMetrics":
        # Divide by an int to average, or by another ProfileMetrics to get
        # elementwise ratios (see ProfileResult.percent()).  max(1, ...)
        # guards against division by zero on empty profiles.
        if isinstance(other, int):
            other = ProfileMetrics(other, other, other)
        return ProfileMetrics(
            self.microseconds / max(1, other.microseconds),
            self.operators / max(1, other.operators),
            self.fusions / max(1, other.fusions),
        )

    def __str__(self) -> str:
        # Rendered as percentages: meaningful when `self` came from percent().
        return f"{self.operators:4.0%} ops {self.microseconds:4.0%} time"

    def tocsv(self) -> List[Any]:
        return [self.operators, self.microseconds]


class ProfileResult:
    """Captured-vs-total metrics for one (or several, via ``+=``) profiling runs."""

    def __init__(self, captured, total, unique_graphs) -> None:
        # `captured`/`total` may be None; normalize to empty metrics.
        self.captured: ProfileMetrics = captured or ProfileMetrics()
        self.total: ProfileMetrics = total or ProfileMetrics()
        self.unique_graphs: int = unique_graphs

    def __iadd__(self, other: "ProfileResult") -> "ProfileResult":
        self.captured += other.captured
        self.total += other.total
        self.unique_graphs += other.unique_graphs
        return self

    def percent(self) -> ProfileMetrics:
        """Fraction of ops/time that ran inside captured dynamo regions."""
        return self.captured / self.total

    def __str__(self) -> str:
        return (
            f"{self.unique_graphs:2} graphs {self.captured.graphs:2} graph calls "
            f"{self.captured.operators:4}/{self.total.operators:4} = "
            + str(self.percent())
        )

    def tocsv(self) -> List[Any]:
        return [
            self.unique_graphs,
            self.captured.graphs,
            self.captured.operators,
            self.total.operators,
        ] + self.percent().tocsv()


def should_print_missing() -> bool:
    # Opt-in (env var) reporting of stacks for ops NOT captured in a graph.
    return os.environ.get("TORCHDYNAMO_PRINT_MISSING") == "1"


def print_missing(stack) -> None:
    """Print (once per unique message) the user frames of an uncaptured op.

    Frames coming from the autograd profiler itself, built-ins, and
    torch-internal files are filtered out; only the innermost 3 user
    frames are shown.
    """
    if any("/torch/autograd/profiler.py" in x for x in stack):
        return
    stack = [
        x for x in stack if ("<built-in" not in x and "site-packages/torch/" not in x)
    ]
    print_once("MISSING", " >> ".join(stack[-3:]))


class Profiler:
    """Wraps torch.profiler to measure dynamo graph-capture coverage."""

    # Class-level counter: incremented by fx_insert_profiling() once per
    # compiled graph, consumed and reset by results().
    unique_graphs = 0

    def __init__(self) -> None:
        self.prof = torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            # Stack collection is only needed when we will print missing ops.
            with_stack=should_print_missing(),
        )

    def results(self) -> ProfileResult:
        """Scan recorded events and split ops into captured vs. total.

        An op counts as "captured" when it ends inside the most recent
        ``TORCHDYNAMO`` record_function region; ops nested inside other
        ops are ignored.
        """
        captured_regions = 0
        captured_ops = 0
        captured_microseconds = 0
        total_ops = 0
        total_microseconds = 0

        last_op_end_time = -1
        captured_region_end_time = -1
        events = sorted(self.prof.events(), key=lambda x: x.time_range.start)
        for e in events:
            if e.name == "TORCHDYNAMO":
                captured_region_end_time = e.time_range.end
                captured_regions += 1
                # ignore `handle = torch.zeros(1)` in record_function.__init__()
                total_ops -= 1
            elif e.time_range.start >= last_op_end_time:
                # Top-level op (not nested inside the previous op).
                last_op_end_time = e.time_range.end
                if e.time_range.end <= captured_region_end_time:
                    captured_ops += 1
                    captured_microseconds += e.time_range.elapsed_us()
                elif should_print_missing():
                    print_missing(e.stack)
                total_ops += 1
                total_microseconds += e.time_range.elapsed_us()
            else:
                pass  # ops recursively called from other ops (ignored)

        unique_graphs = Profiler.unique_graphs
        Profiler.unique_graphs = 0
        # we counted one extra op that is part of the profiler setup code
        total_ops -= 1

        return ProfileResult(
            captured=ProfileMetrics(
                microseconds=captured_microseconds,
                operators=captured_ops,
                fusions=captured_ops - captured_regions,
                graphs=captured_regions,
            ),
            total=ProfileMetrics(
                microseconds=total_microseconds,
                operators=total_ops,
                fusions=total_ops - 1,
            ),
            unique_graphs=unique_graphs,
        )


def fx_insert_profiling(gm: torch.fx.GraphModule, example_inputs: List[Any]):
    """Dynamo backend: wrap the graph's forward in a TORCHDYNAMO region.

    Each call marks one unique compiled graph; every invocation of the
    returned wrapper shows up as a captured region in Profiler.results().
    """

    def _wrapped(*args):
        with torch.profiler.record_function("TORCHDYNAMO"):
            return gm.forward(*args)

    Profiler.unique_graphs += 1
    return _wrapped