import torch

from . import benchmark


class SwishBench(benchmark.Benchmark):
    """Benchmark an elementwise swish-style expression on an ``[M, N]`` input.

    The measured expression is::

        y = x * (min(relu(x), 6) + 3) * (1 / 6)

    where the constants 6, 3 and 1/6 are pre-materialized ``[M, N]`` tensors
    so the kernel operates purely on tensor operands.
    """

    def __init__(self, mode, device, dtype, M, N):
        """Create the random input and the constant operand tensors.

        Args:
            mode: Benchmark mode; ``memory_workload`` treats ``"fwd"``
                specially and every other value as the larger workload.
            device: Device on which the input and constants are allocated.
            dtype: Element type forwarded to ``self.rand`` for the input.
            M: Number of rows of the input.
            N: Number of columns of the input.
        """
        super().__init__(mode, device, dtype)
        self.M = M
        self.N = N
        # NOTE(review): ``self.requires_grad`` is not set in this class;
        # presumably the base class derives it from ``mode`` - confirm.
        self.data = self.rand(
            [M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.data]

        # Build the constants in the same dtype as the input. Without an
        # explicit dtype they would use torch's default dtype, and the mixed
        # dtype arithmetic in ``forward`` would be type-promoted instead of
        # running in the dtype the benchmark was asked to measure.
        self.zeros = torch.zeros(M, N, device=device, dtype=self.data.dtype)
        self.six = self.zeros + 6.0
        self.three = self.zeros + 3.0
        self.sixth = self.zeros + 1.0 / 6.0

    def forward(self, inp):
        """Return ``inp * (min(relu(inp), 6) + 3) * (1 / 6)`` elementwise."""
        clamped = torch.min(torch.relu(inp), self.six)
        y = inp * (clamped + self.three) * self.sixth
        return y

    def reference(self):
        """Return ``forward(self.data)`` converted through ``self.numpy``."""
        return self.numpy(self.forward(self.data))

    def config(self):
        """Return the shape parameters of this instance as ``[M, N]``."""
        return [self.M, self.N]

    @staticmethod
    def module():
        """Return the name under which this benchmark is identified."""
        return "swish"

    def memory_workload(self):
        """Return the memory traffic estimate as multiples of ``M * N``.

        Returns:
            A dict with keys ``"sol"`` and ``"algorithmic"``. Each value is
            ``M * N`` times a per-mode buffer count; the counts for any mode
            other than ``"fwd"`` are twice the ``"fwd"`` counts.
        """
        # NOTE(review): each "a + b" pair looks like (buffers read) +
        # (buffers written) - confirm against the base class contract.
        if self.mode == "fwd":
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        buffer_size = self.M * self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default ``[M, N]`` configurations to benchmark."""
        return [[128, 1 << 16]]


benchmark.register_benchmark_class(SwishBench)