import sys

from benchmark_base import BenchmarkBase

import torch
import torch._dynamo
import torch.nn as nn
import torch.nn.functional as F
from torch._inductor.utils import fresh_inductor_cache


class ListOfLinears(nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for _ in range(20)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, layer in enumerate(self.linears):
            x = self.linears[i // 2](x) + layer(x)
        return x


class BasicModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        return F.relu(self.linear1(x)) * self.scale


class ModuleForwardHasGraphBreak(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.layer3 = torch.nn.Sequential(BasicModule(), BasicModule())
        self.layer4 = torch.nn.ModuleList(
            [
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
            ]
        )
        self.layer5 = torch.nn.ModuleDict(
            {
                "0": torch.nn.Linear(10, 10),
            }
        )
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        """
        This is used to test whether the results of functions like
        `named_parameters` can be reconstructed correctly after a graph break.

        https://github.com/pytorch/torchdynamo/issues/1931
        """
        x = self.layer1(x)
        params1 = dict(self.named_parameters())
        params2 = list(self.parameters())
        buffers1 = dict(self.named_buffers())
        buffers2 = list(self.buffers())
        modules1 = dict(self.named_modules())
        modules2 = list(self.modules())
        torch._dynamo.graph_break()
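        # Referencing these locals after the graph break forces Dynamo to
        # reconstruct them; y is ultimately bound to the named_parameters dict.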
        y = modules2
        y = modules1
        y = buffers2
        y = buffers1
        y = params2
        y = params1
        x = (
            self.layer2(x)
            + y["layer3.1.linear1.weight"]
            + y["layer4.2.weight"]
            + y["layer5.0.weight"]
        )
        return x * self.scale


class SequentialWithDuplicatedModule(torch.nn.Module):
    # The Sequential module (self.layer) contains the same ReLU module three times.
    def __init__(self) -> None:
        super().__init__()
        self.relu = torch.nn.ReLU()
        self.layer = torch.nn.Sequential(
            torch.nn.Linear(10, 20),
            self.relu,
            torch.nn.Linear(20, 20),
            self.relu,
            torch.nn.Linear(20, 10),
            self.relu,
        )

    def forward(self, x):
        return self.layer(x)


class ModuleComparison(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer0 = torch.nn.Linear(10, 10)
        self.layer1 = torch.nn.Linear(10, 10)
        self.layer2 = torch.nn.Linear(10, 10)

    @property
    def encoder_layers(self):
        return [self.layer0, self.layer1, self.layer2]

    def forward(self, x):
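        # The `is` and `==` checks below exercise Dynamo's handling of module
        # identity and equality comparisons inside the compiled region.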
        for layer in self.encoder_layers:
            output = layer(x)
            if layer is None or layer == self.layer0:
                output = F.relu6(output)
            else:
                output = F.relu(output)
        return output


class Benchmark(BenchmarkBase):
    def __init__(self, ModuleClass, backend):
        self.ModuleClass = ModuleClass
        self.backend = backend
        self._name = ModuleClass.__name__

    def name(self):
        return f"basic_modules_{self._name}_{self.backend}"

    def _prepare_once(self):
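        # Construct the module and its input once; they are reused across runs.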
        self.m = self.ModuleClass()
        self.input = torch.ones(10)

    def _prepare(self):
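        # Reset Dynamo state so every run measures a cold compile.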
        torch._dynamo.reset()

    def _work(self):
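        # Compile and execute under a fresh Inductor cache so previously cached
        # kernels do not hide compilation work.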
        with fresh_inductor_cache():
            opt_m = torch.compile(backend=self.backend)(self.m)
            opt_m(self.input)


def main():
    result_path = sys.argv[1]
    benchmarks = [
        Benchmark(ListOfLinears, "inductor"),
        Benchmark(ListOfLinears, "eager"),
        Benchmark(ModuleForwardHasGraphBreak, "inductor"),
        Benchmark(ModuleForwardHasGraphBreak, "eager"),
        Benchmark(SequentialWithDuplicatedModule, "inductor"),
        Benchmark(SequentialWithDuplicatedModule, "eager"),
        Benchmark(ModuleComparison, "inductor"),
        Benchmark(ModuleComparison, "eager"),
    ]
    for b in benchmarks:
        b.enable_compile_time_instruction_count().collect_all().append_results(
            result_path
        )


if __name__ == "__main__":
    main()