from . import benchmark


class ConvImplBench(benchmark.Benchmark):
    """Shared implementation for the 2-D convolution benchmarks.

    ``case`` selects the flavour: ``"conv"`` uses a single group, while
    ``"depthwise_conv"`` uses one group per input channel. The remaining
    constructor arguments describe the input tensor ``[N, iC, H, W]``, the
    number of output channels ``oC`` and the (square) ``kernel_size``.
    """

    def __init__(self, case, mode, device, dtype, kernel_size, N, iC, H, W, oC):
        """Create the input tensor and the convolution layer under test.

        Raises:
            ValueError: if ``case`` is neither ``"conv"`` nor
                ``"depthwise_conv"``.
        """
        super().__init__(mode, device, dtype)
        self.case = case
        self.kernel_size = kernel_size
        self.N, self.iC, self.H, self.W, self.oC = N, iC, H, W, oC

        # The input is created before ``case`` is validated, so an invalid
        # case still goes through ``self.rand`` first.
        # NOTE(review): ``rand`` / ``requires_grad`` are not defined in this
        # class; presumably provided by ``benchmark.Benchmark`` -- confirm.
        input_shape = [N, iC, H, W]
        self.data = self.rand(
            input_shape, device=device, requires_grad=self.requires_grad
        )

        if case not in ("conv", "depthwise_conv"):
            raise ValueError(f"invalid case: {case}")
        # Depthwise: one group per input channel; regular conv: one group.
        self.groups = iC if case == "depthwise_conv" else 1

        self.conv = self.conv2d_layer(iC, oC, kernel_size, groups=self.groups)
        if device != "cpu":
            self.to_device(self.conv, device)

    def forward(self):
        """Apply the convolution layer to the stored input and return it."""
        return self.conv(self.data)

    def config(self):
        """Return ``[kernel_size, N, iC, H, W, oC]`` for this instance."""
        return [self.kernel_size, self.N, self.iC, self.H, self.W, self.oC]

    def memory_workload(self):
        """Estimate the number of elements moved through memory.

        Three buffers are counted: the input, the output (sized with the
        input's ``H`` and ``W``) and the kernel weights. Each buffer is
        weighted once for mode ``"fwd"``; for any other mode the ``"sol"``
        figure weights each buffer twice and the ``"algorithmic"`` figure
        three times.

        Returns:
            dict with the keys ``"sol"`` and ``"algorithmic"``.
        """
        # NOTE(review): the non-"fwd" weights presumably account for the
        # backward pass (and "sol" for speed-of-light) -- confirm.
        if self.mode == "fwd":
            sol_passes, algorithmic_passes = 1, 1
        else:
            sol_passes, algorithmic_passes = 2, 3

        # True division on purpose: the kernel element count is a float.
        in_channels_per_group = self.iC / self.groups
        element_counts = (
            self.N * self.iC * self.H * self.W,  # input
            self.N * self.oC * self.H * self.W,  # output
            self.oC
            * in_channels_per_group
            * self.kernel_size
            * self.kernel_size,  # kernel weights
        )

        return {
            "sol": sum(count * sol_passes for count in element_counts),
            "algorithmic": sum(
                count * algorithmic_passes for count in element_counts
            ),
        }

    def compute_workload(self):
        """Estimate the arithmetic operation count of one benchmark step.

        Returns:
            The per-pass operation count multiplied by 1 for mode ``"fwd"``
            or by 3 for mode ``"both"``.

        Raises:
            ValueError: if ``self.mode`` is neither ``"fwd"`` nor ``"both"``.
        """
        if self.mode not in ("fwd", "both"):
            raise ValueError(f"invalid mode: {self.mode}")
        passes = 1 if self.mode == "fwd" else 3

        # Evaluated strictly left to right, exactly like the single chained
        # expression N * iC / groups * oC * k * k * H * W.
        scaled_input = self.N * self.iC / self.groups
        mac_count = (
            scaled_input
            * self.oC
            * self.kernel_size
            * self.kernel_size
            * self.H
            * self.W
        )
        # Factor of two: presumably one multiply plus one add per MAC.
        flop_count = mac_count * 2

        return flop_count * passes

    @staticmethod
    def default_configs():
        """Return the default ``[kernel_size, N, iC, H, W, oC]`` configs."""
        return [[3, 64, 32, 128, 128, 64]]


class ConvBench(ConvImplBench):
    """Benchmark for the ``"conv"`` case of ``ConvImplBench``."""

    def __init__(self, *args):
        """Forward ``args`` to ``ConvImplBench`` with ``case="conv"``."""
        super().__init__("conv", *args)

    @staticmethod
    def module():
        """Return the module name of this benchmark."""
        return "conv"


class DepthwiseConvBench(ConvImplBench):
    """Benchmark for the ``"depthwise_conv"`` case of ``ConvImplBench``."""

    def __init__(self, *args):
        """Forward ``args`` to ``ConvImplBench`` with the depthwise case."""
        super().__init__("depthwise_conv", *args)

    @staticmethod
    def module():
        """Return the module name of this benchmark."""
        return "depthwise_conv"


benchmark.register_benchmark_class(ConvBench)
benchmark.register_benchmark_class(DepthwiseConvBench)