1# Owner(s): ["module: unknown"] 2 3import torch 4from torch.testing._internal.common_utils import run_tests, TemporaryFileName, TestCase 5from torch.utils import ThroughputBenchmark 6 7 8class TwoLayerNet(torch.jit.ScriptModule): 9 def __init__(self, D_in, H, D_out): 10 super().__init__() 11 self.linear1 = torch.nn.Linear(D_in, H) 12 self.linear2 = torch.nn.Linear(2 * H, D_out) 13 14 @torch.jit.script_method 15 def forward(self, x1, x2): 16 h1_relu = self.linear1(x1).clamp(min=0) 17 h2_relu = self.linear1(x2).clamp(min=0) 18 cat = torch.cat((h1_relu, h2_relu), 1) 19 y_pred = self.linear2(cat) 20 return y_pred 21 22 23class TwoLayerNetModule(torch.nn.Module): 24 def __init__(self, D_in, H, D_out): 25 super().__init__() 26 self.linear1 = torch.nn.Linear(D_in, H) 27 self.linear2 = torch.nn.Linear(2 * H, D_out) 28 29 def forward(self, x1, x2): 30 h1_relu = self.linear1(x1).clamp(min=0) 31 h2_relu = self.linear1(x2).clamp(min=0) 32 cat = torch.cat((h1_relu, h2_relu), 1) 33 y_pred = self.linear2(cat) 34 return y_pred 35 36 37class TestThroughputBenchmark(TestCase): 38 def linear_test(self, Module, profiler_output_path=""): 39 D_in = 10 40 H = 5 41 D_out = 15 42 B = 8 43 NUM_INPUTS = 2 44 45 module = Module(D_in, H, D_out) 46 47 inputs = [] 48 49 for i in range(NUM_INPUTS): 50 inputs.append([torch.randn(B, D_in), torch.randn(B, D_in)]) 51 bench = ThroughputBenchmark(module) 52 53 for input in inputs: 54 # can do both args and kwargs here 55 bench.add_input(input[0], x2=input[1]) 56 57 for i in range(NUM_INPUTS): 58 # or just unpack the list of inputs 59 module_result = module(*inputs[i]) 60 bench_result = bench.run_once(*inputs[i]) 61 torch.testing.assert_close(bench_result, module_result) 62 63 stats = bench.benchmark( 64 num_calling_threads=4, 65 num_warmup_iters=100, 66 num_iters=1000, 67 profiler_output_path=profiler_output_path, 68 ) 69 70 print(stats) 71 72 def test_script_module(self): 73 self.linear_test(TwoLayerNet) 74 75 def test_module(self): 76 self.linear_test(TwoLayerNetModule) 77 78 def test_profiling(self): 79 with TemporaryFileName() as fname: 80 self.linear_test(TwoLayerNetModule, profiler_output_path=fname) 81 82 83if __name__ == "__main__": 84 run_tests() 85