xref: /aosp_15_r20/external/pytorch/test/test_throughput_benchmark.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: unknown"]
2
3import torch
4from torch.testing._internal.common_utils import run_tests, TemporaryFileName, TestCase
5from torch.utils import ThroughputBenchmark
6
7
class TwoLayerNet(torch.jit.ScriptModule):
    """TorchScript network mapping two inputs to one output.

    Both inputs are passed through the same first layer (``linear1``), clamped
    at zero, concatenated along dimension 1 and projected by ``linear2``.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # linear1 is applied to x1 and to x2, so after concatenation the
        # hidden width is 2 * H, which is what linear2 consumes.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(2 * H, D_out)

    @torch.jit.script_method
    def forward(self, x1, x2):
        left = self.linear1(x1).clamp(min=0)
        right = self.linear1(x2).clamp(min=0)
        return self.linear2(torch.cat((left, right), dim=1))
21
22
class TwoLayerNetModule(torch.nn.Module):
    """Eager-mode network mapping two inputs to one output.

    ``linear1`` is shared between ``x1`` and ``x2``; each branch is clamped
    at zero, the branches are concatenated along dimension 1, and
    ``linear2`` produces the final result.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # Two branches of width H are concatenated, hence 2 * H inputs below.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(2 * H, D_out)

    def forward(self, x1, x2):
        # Evaluated in order x1, then x2, matching the concatenation order.
        branches = [self.linear1(x).clamp(min=0) for x in (x1, x2)]
        return self.linear2(torch.cat(branches, 1))
35
36
class TestThroughputBenchmark(TestCase):
    """Tests for ``torch.utils.ThroughputBenchmark`` on small two-input nets."""

    def linear_test(self, Module, profiler_output_path=""):
        """Benchmark ``Module`` and check it against calling the module directly.

        Args:
            Module: class constructed as ``Module(D_in, H, D_out)`` whose
                ``forward`` takes two tensors, ``x1`` and ``x2``.
            profiler_output_path: forwarded to ``bench.benchmark``. The
                non-profiling tests leave it as ``""`` (presumably meaning
                "no profiling" — confirm against ThroughputBenchmark).
        """
        D_in = 10
        H = 5
        D_out = 15
        B = 8
        NUM_INPUTS = 2

        module = Module(D_in, H, D_out)

        # Each sample is a pair [x1, x2] of (B, D_in) tensors.
        inputs = [
            [torch.randn(B, D_in), torch.randn(B, D_in)] for _ in range(NUM_INPUTS)
        ]
        bench = ThroughputBenchmark(module)

        for x1, x2 in inputs:
            # add_input accepts both positional and keyword arguments.
            bench.add_input(x1, x2=x2)

        for sample in inputs:
            # run_once takes the same arguments as the module's forward.
            module_result = module(*sample)
            bench_result = bench.run_once(*sample)
            torch.testing.assert_close(bench_result, module_result)

        stats = bench.benchmark(
            num_calling_threads=4,
            num_warmup_iters=100,
            num_iters=1000,
            profiler_output_path=profiler_output_path,
        )

        print(stats)
        # Previously the stats were only printed; make sure the benchmark
        # actually executed the module at least once.
        self.assertGreater(stats.num_iters, 0)

    def test_script_module(self):
        """Benchmark a TorchScript ScriptModule."""
        self.linear_test(TwoLayerNet)

    def test_module(self):
        """Benchmark a plain eager ``nn.Module``."""
        self.linear_test(TwoLayerNetModule)

    def test_profiling(self):
        """Benchmark with a profiler output path pointing at a temporary file."""
        with TemporaryFileName() as fname:
            self.linear_test(TwoLayerNetModule, profiler_output_path=fname)
81
82
if __name__ == "__main__":
    # Only when executed directly as a script: delegate to the test runner
    # imported from torch.testing._internal.common_utils.
    run_tests()
85