import numpy as np

from . import benchmark


class MatMulBench(benchmark.Benchmark):
    """Batched matrix-multiply benchmark: ``[B, M, N] @ [B, N, K] -> [B, M, K]``.

    ``rand``, ``matmul``, ``numpy``, ``requires_grad`` and ``mode`` come from
    the ``benchmark.Benchmark`` base class, which is defined elsewhere.
    """

    def __init__(self, mode, device, dtype, B, M, N, K):
        super().__init__(mode, device, dtype)
        self.B = B
        self.M = M
        self.N = N
        self.K = K
        # Left operand: a batch of B matrices, each (M, N).
        self.d1 = self.rand(
            [B, M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        # Right operand: a batch of B matrices, each (N, K).
        self.d2 = self.rand(
            [B, N, K], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.d1, self.d2]

    def forward(self, d1, d2):
        """Return ``d1 @ d2`` computed with the backend's ``matmul``."""
        y = self.matmul(d1, d2)
        return y

    def reference(self):
        """Return the NumPy ``matmul`` of the two inputs, for result checking."""
        return np.matmul(self.numpy(self.d1), self.numpy(self.d2))

    def config(self):
        """Return the problem sizes ``[B, M, N, K]`` of this instance."""
        return [self.B, self.M, self.N, self.K]

    @staticmethod
    def module():
        """Return the name this benchmark is registered under."""
        return "batch_matmul"

    def memory_workload(self):
        """Estimate memory traffic, in tensor elements (not bytes).

        Returns:
            dict with keys ``"sol"`` and ``"algorithmic"``. Each value is the
            total element count of both inputs plus the output, multiplied by
            a per-mode pass count. NOTE(review): "sol" presumably means a
            speed-of-light lower bound; confirm against the base class.
        """
        if self.mode == "fwd":
            sol_count = 1
            algorithmic_count = 1
        else:
            # Any mode other than "fwd" also counts the backward pass.
            sol_count = 1 + 1
            algorithmic_count = 1 + (1 + 1)

        # d1 is [B, M, N], d2 is [B, N, K], and the result is [B, M, K].
        # The previous version counted d1 twice and omitted the output.
        buffer_size = (
            self.B * self.M * self.N
            + self.B * self.N * self.K
            + self.B * self.M * self.K
        )
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    def compute_workload(self):
        """Return the floating-point operation count for the current mode.

        One batched matmul costs ``2 * B * M * N * K`` operations (a multiply
        and an add per inner-product term). For modes other than "fwd", two
        more matmuls of the same size are counted. Those are the gradients
        with respect to ``d1`` and ``d2``.
        """
        if self.mode == "fwd":
            count = 1
        else:
            count = 1 + (1 + 1)

        op_count = 2 * self.B * self.M * self.N * self.K

        return op_count * count

    @staticmethod
    def default_configs():
        """Return the default ``[B, M, N, K]`` configurations to benchmark."""
        return [[128, 64, 128, 256]]


benchmark.register_benchmark_class(MatMulBench)