"""Benchmarks for elementwise operations whose inputs broadcast against each other."""

import itertools
import math
import operator

import numpy as np

import torch

from . import benchmark


class BroadcastMulBench(benchmark.Benchmark):
    """Combine two 3-D inputs whose shapes broadcast to ``[M, N, K]``.

    ``case`` selects the pair of input shapes:

    * ``"row"``: ``[M, N, 1]`` and ``[M, 1, K]``
    * ``"mid"``: ``[M, N, 1]`` and ``[1, N, K]``
    * ``"col"``: ``[M, 1, K]`` and ``[1, N, K]``

    NOTE(review): despite "Mul" in the class name, ``forward`` and
    ``reference`` both compute an addition -- confirm which one is intended.
    """

    def __init__(self, mode, device, dtype, case, M, N, K):
        """Create the two random inputs for the requested broadcast ``case``.

        Raises:
            ValueError: if ``case`` is not one of "row", "mid" or "col".
        """
        super().__init__(mode, device, dtype)
        self.case = case
        self.M = M
        self.N = N
        self.K = K

        if case == "row":
            shape1, shape2 = [M, N, 1], [M, 1, K]
        elif case == "mid":
            shape1, shape2 = [M, N, 1], [1, N, K]
        elif case == "col":
            shape1, shape2 = [M, 1, K], [1, N, K]
        else:
            raise ValueError(f"invalid case: {case}")

        # `rand` and `requires_grad` come from the base class (not shown here).
        self.d1 = self.rand(
            shape1, device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d2 = self.rand(
            shape2, device=device, dtype=dtype, requires_grad=self.requires_grad
        )

        self.inputs = [self.d1, self.d2]

    def forward(self, d1, d2):
        """Return the broadcast sum of the two inputs."""
        y = d1 + d2
        return y

    def reference(self):
        """Return the same sum computed on the NumPy versions of the inputs."""
        return self.numpy(self.d1) + self.numpy(self.d2)

    def config(self):
        """Return the ``[M, N, K]`` configuration of this instance."""
        return [self.M, self.N, self.K]

    @staticmethod
    def default_configs():
        """Return the default list of ``[M, N, K]`` configurations."""
        return [[128, 256, 128]]

    def memory_workload(self):
        """Return the "sol" and "algorithmic" workloads as multiples of M*N*K."""
        if self.mode == "fwd":
            sol_count = 1
            algorithmic_count = 1
        else:
            sol_count = 1 + 1
            algorithmic_count = 1 + (1 + 1)

        # NOTE(review): the other benchmarks in this file scale the buffer
        # size by 4; this one does not -- confirm whether that is intended.
        buffer_size = self.M * self.N * self.K
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }


class BroadcastRowBench(BroadcastMulBench):
    """``BroadcastMulBench`` fixed to the "row" case."""

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, "row", M, N, K)

    @staticmethod
    def module():
        """Return the module name of this benchmark."""
        return "broadcast_row"


class BroadcastMidBench(BroadcastMulBench):
    """``BroadcastMulBench`` fixed to the "mid" case."""

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, "mid", M, N, K)

    @staticmethod
    def module():
        """Return the module name of this benchmark."""
        return "broadcast_mid"


class BroadcastColBench(BroadcastMulBench):
    """``BroadcastMulBench`` fixed to the "col" case."""

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, "col", M, N, K)

    @staticmethod
    def module():
        """Return the module name of this benchmark."""
        return "broadcast_col"


class BroadcastThreeArgs(benchmark.Benchmark):
    """Add three inputs of shapes ``[M, N]``, ``[K, M, 1]`` and ``[L, K, 1, 1]``."""

    def __init__(self, mode, device, dtype, M, N, K, L):
        """Create the three random inputs."""
        super().__init__(mode, device, dtype)
        self.M = M
        self.N = N
        self.K = K
        self.L = L

        self.d1 = self.rand(
            [M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d2 = self.rand(
            [K, M, 1], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d3 = self.rand(
            [L, K, 1, 1], device=device, dtype=dtype, requires_grad=self.requires_grad
        )

        self.inputs = [self.d1, self.d2, self.d3]

    def forward(self, d1, d2, d3):
        """Return the broadcast sum of the three inputs."""
        y = d1 + d2 + d3
        return y

    def reference(self):
        """Return the same sum computed on the NumPy versions of the inputs."""
        return self.numpy(self.d1) + self.numpy(self.d2) + self.numpy(self.d3)

    def config(self):
        """Return the ``[M, N, K, L]`` configuration of this instance."""
        return [self.M, self.N, self.K, self.L]

    @staticmethod
    def default_configs():
        """Return the default list of ``[M, N, K, L]`` configurations."""
        return [[32, 16, 64, 128]]

    def memory_workload(self):
        """Return the "sol" and "algorithmic" workloads as multiples of M*N*K*L*4."""
        if self.mode == "fwd":
            sol_count = 1
            algorithmic_count = 1
        else:
            sol_count = 1 + 1
            algorithmic_count = 1 + (1 + 1 + 1)

        # NOTE(review): the factor 4 is presumably the element size in bytes
        # (float32) -- confirm.
        buffer_size = self.M * self.N * self.K * self.L * 4
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def module():
        """Return the module name of this benchmark."""
        return "broadcast_3args"


# Registration of the classes above is currently disabled.
# benchmark.register_benchmark_class(BroadcastRowBench)
# benchmark.register_benchmark_class(BroadcastMidBench)
# benchmark.register_benchmark_class(BroadcastColBench)
# benchmark.register_benchmark_class(BroadcastThreeArgs)


# TODO: merge this with elementwise bench
# A template class for elementwise operations.
# A derived class overrides the class variables below to customize its behavior.
class BroadcastBench(benchmark.Benchmark):
    """Template benchmark: ``op(d1, d2) + op(d3, d4)`` over broadcasting inputs.

    The inputs have shapes ``[M, N]``, ``[K, 1, N]``, ``[M, N]`` and
    ``[K, M, 1]``. An optional unary op is applied to every input first, and
    an optional binary op (addition by default) combines each pair.
    """

    # List of customization class variables.
    op_str = None  # suffix used by module() to build the benchmark name
    binary_op_pt_func = None  # binary op used by forward()
    binary_op_np_func = None  # binary op used by reference()
    unary_op_pt_func = None  # unary op used by forward()
    unary_op_np_func = None  # unary op used by reference()
    split_input = True  # False: the third operand is derived from d1, not d3

    def __init__(self, mode, device, dtype, M, N, K):
        """Create the four random inputs."""
        super().__init__(mode, device, dtype)
        self.M = M
        self.N = N
        self.K = K
        self.d1 = self.rand(
            [M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d2 = self.rand(
            [K, 1, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d3 = self.rand(
            [M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d4 = self.rand(
            [K, M, 1], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.d1, self.d2, self.d3, self.d4]

    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
        """Evaluate ``binary_op(u(d1), u(d2)) + binary_op(u(d3'), u(d4))``.

        ``u`` is ``unary_op`` (identity when not given) and ``binary_op``
        defaults to addition. ``d3'`` is ``d3`` when ``split_input`` is true,
        otherwise ``d1 + 0.001``.
        """
        # Fall back to the defaults when a derived class did not set an op.
        binary_op = binary_op or operator.add
        unary_op = unary_op or (lambda x: x)

        if self.split_input:
            d1 = unary_op(d1)
            d2 = unary_op(d2)
            d3 = unary_op(d3)
            d4 = unary_op(d4)
        else:
            # The right-hand side is fully evaluated before assignment, so the
            # third operand is built from the *original* d1.
            d1, d2, d3, d4 = (
                unary_op(d1),
                unary_op(d2),
                unary_op(d1 + 0.001),
                unary_op(d4),
            )
        a = binary_op(d1, d2)
        b = binary_op(d3, d4)
        return a + b

    def forward(self, d1, d2, d3, d4):
        """Evaluate the benchmark expression with the class's ``*_pt_func`` ops."""
        binary_op = self.__class__.binary_op_pt_func
        unary_op = self.__class__.unary_op_pt_func
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def reference(self):
        """Evaluate the same expression with the ``*_np_func`` ops on NumPy inputs."""
        binary_op = self.__class__.binary_op_np_func
        unary_op = self.__class__.unary_op_np_func
        d1, d2, d3, d4 = (
            self.numpy(d) for d in [self.d1, self.d2, self.d3, self.d4]
        )
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def config(self):
        """Return the ``[M, N, K]`` configuration of this instance."""
        return [self.M, self.N, self.K]

    @classmethod
    def module(cls):
        """Return ``"broadcast_"`` followed by the class's ``op_str``."""
        return "broadcast_" + cls.op_str

    def memory_workload(self):
        """Return the "sol" and "algorithmic" workloads as multiples of M*N*K*4.

        The counts do not depend on ``split_input``.
        """
        input_count = len(self.inputs)
        if self.mode == "fwd":
            sol_count = 1
            algorithmic_count = 1
        else:
            sol_count = 1
            algorithmic_count = input_count

        # NOTE(review): the factor 4 is presumably the element size in bytes
        # (float32) -- confirm.
        buffer_size = self.M * self.N * self.K * 4
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default list of ``[M, N, K]`` configurations."""
        return [[1 << 8, 1 << 7, 1 << 9]]


def _np_erf(x):
    """Apply the error function elementwise to a NumPy array (or scalar).

    NumPy has no ``erf`` function, so ``math.erf`` is applied to every element.
    Floating-point input dtypes are preserved; any other dtype yields float64.
    """
    arr = np.asarray(x)
    out_dtype = arr.dtype if np.issubdtype(arr.dtype, np.floating) else np.float64
    return np.vectorize(math.erf, otypes=[out_dtype])(arr)


def _register_variants(op_list, pt_attr, np_attr):
    """Register a "split" and a "shared" ``BroadcastBench`` subclass per op.

    Each entry of ``op_list`` is ``[name, pt_func]`` or
    ``[name, pt_func, np_func]``; when ``np_func`` is omitted, ``pt_func`` is
    used for the reference as well. ``pt_attr`` and ``np_attr`` name the class
    variables that receive the two functions.
    """
    for split_input, entry in itertools.product([True, False], op_list):
        op_name, op_pt_func, *rest = entry
        op_np_func = rest[0] if rest else op_pt_func
        op_str = ("split" if split_input else "shared") + "_" + op_name

        # Make a customized copy of BroadcastBench.
        bm_cls = type("BroadcastBench_" + op_str, (BroadcastBench,), {})
        bm_cls.op_str = op_str
        setattr(bm_cls, pt_attr, op_pt_func)
        setattr(bm_cls, np_attr, op_np_func)
        bm_cls.split_input = split_input
        benchmark.register_benchmark_class(bm_cls)


def register_broadcast_ops():
    """Register every binary-op and unary-op variant of ``BroadcastBench``.

    Binary-op classes are registered first, then unary-op classes. Within each
    group, all "split" variants precede the "shared" ones.
    """
    binary_op_list = [
        ["mul", operator.mul],
        ["add", operator.add],
        ["sub", operator.sub],
        ["div", lambda a, b: a / (b + 1e-4)],
        [
            "pow",
            torch.pow,
            np.power,
        ],  # no fusion triggered
        ["max", torch.max, np.maximum],
        ["min", torch.min, np.minimum],
    ]

    unary_op_list = [
        # `np.erf` does not exist; referencing it raised AttributeError as
        # soon as this function ran, i.e. at import time.
        ["erf", torch.erf, _np_erf],
        ["exp", torch.exp, np.exp],
        ["sin", torch.sin, np.sin],
        ["cos", torch.cos, np.cos],
    ]

    _register_variants(binary_op_list, "binary_op_pt_func", "binary_op_np_func")
    _register_variants(unary_op_list, "unary_op_pt_func", "unary_op_np_func")


register_broadcast_ops()