import sys

from benchmark_base import BenchmarkBase

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._inductor.utils import fresh_inductor_cache


class ListOfLinears(nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for _ in range(20)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, layer in enumerate(self.linears):
            x = self.linears[i // 2](x) + layer(x)
        return x


class BasicModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        return F.relu(self.linear1(x)) * self.scale


class ModuleForwardHasGraphBreak(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.layer3 = torch.nn.Sequential(BasicModule(), BasicModule())
        self.layer4 = torch.nn.ModuleList(
            [
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
            ]
        )
        self.layer5 = torch.nn.ModuleDict(
            {
                "0": torch.nn.Linear(10, 10),
            }
        )
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        """
        This is used to test whether the results of functions like `named_parameters`
        can be reconstructed correctly after a graph break.

        https://github.com/pytorch/torchdynamo/issues/1931
        """
        x = self.layer1(x)
        params1 = dict(self.named_parameters())
        params2 = list(self.parameters())
        buffers1 = dict(self.named_buffers())
        buffers2 = list(self.buffers())
        modules1 = dict(self.named_modules())
        modules2 = list(self.modules())
        torch._dynamo.graph_break()
        # Reference each collection after the graph break so Dynamo has to
        # reconstruct it on the other side of the break.
        y = modules2
        y = modules1
        y = buffers2
        y = buffers1
        y = params2
        y = params1
        x = (
            self.layer2(x)
            + y["layer3.1.linear1.weight"]
            + y["layer4.2.weight"]
            + y["layer5.0.weight"]
        )
        return x * self.scale


class SequentialWithDuplicatedModule(torch.nn.Module):
    # The Sequential module (self.layer) contains the same ReLU module three times.
    def __init__(self) -> None:
        super().__init__()
        self.relu = torch.nn.ReLU()
        self.layer = torch.nn.Sequential(
            torch.nn.Linear(10, 20),
            self.relu,
            torch.nn.Linear(20, 20),
            self.relu,
            torch.nn.Linear(20, 10),
            self.relu,
        )

    def forward(self, x):
        return self.layer(x)


class ModuleComparison(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer0 = torch.nn.Linear(10, 10)
        self.layer1 = torch.nn.Linear(10, 10)
        self.layer2 = torch.nn.Linear(10, 10)

    @property
    def encoder_layers(self):
        return [self.layer0, self.layer1, self.layer2]

    def forward(self, x):
        for layer in self.encoder_layers:
            output = layer(x)
            if layer is None or layer == self.layer0:
                output = F.relu6(output)
            else:
                output = F.relu(output)
        return output


class Benchmark(BenchmarkBase):
    def __init__(self, ModuleClass, backend):
        self.ModuleClass = ModuleClass
        self.backend = backend
        self._name = ModuleClass.__name__

    def name(self):
        return f"basic_modules_{self._name}_{self.backend}"

    def _prepare_once(self):
        self.m = self.ModuleClass()
        self.input = torch.ones(10)

    def _prepare(self):
        # Clear Dynamo state so every measurement compiles from scratch.
        torch._dynamo.reset()

    def _work(self):
        # Measure torch.compile plus the first call, using a fresh Inductor
        # cache so no compiled artifacts are reused across runs.
        with fresh_inductor_cache():
            opt_m = torch.compile(backend=self.backend)(self.m)
            opt_m(self.input)


def main():
    result_path = sys.argv[1]
    benchmarks = [
        Benchmark(ListOfLinears, "inductor"),
        Benchmark(ListOfLinears, "eager"),
        Benchmark(ModuleForwardHasGraphBreak, "inductor"),
        Benchmark(ModuleForwardHasGraphBreak, "eager"),
        Benchmark(SequentialWithDuplicatedModule, "inductor"),
        Benchmark(SequentialWithDuplicatedModule, "eager"),
        Benchmark(ModuleComparison, "inductor"),
        Benchmark(ModuleComparison, "eager"),
    ]
    for b in benchmarks:
        b.enable_compile_time_instruction_count().collect_all().append_results(
            result_path
        )


if __name__ == "__main__":
    main()
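
# Usage note (a sketch, not part of the harness): main() reads the output path
# from sys.argv[1], and enable_compile_time_instruction_count / collect_all /
# append_results come from BenchmarkBase (imported from benchmark_base). The
# file name below is a placeholder:
#
#     python path/to/this_benchmark_file.py results.csv
#
# To sanity-check one of the modules under torch.compile outside the harness,
# an assumed minimal snippet (the "eager" backend skips Inductor codegen):
#
#     m = ModuleComparison()
#     opt_m = torch.compile(m, backend="eager")
#     out = opt_m(torch.ones(10))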