import itertools
import math
import operator

import numpy as np
import torch

from . import benchmark


# NumPy does not provide an ``erf`` function, so the reference implementation
# is built from the standard library's scalar ``math.erf``.
_ERF_ELEMENTWISE = np.vectorize(math.erf, otypes=[np.float64])


def _np_erf(x):
    """Return the element-wise error function of *x* as a NumPy array.

    The result keeps the floating dtype of the input so that it lines up with
    the other NumPy reference functions used in this file (``np.exp``,
    ``np.sin``, ``np.cos``).
    """
    values = np.asarray(x)
    result = _ERF_ELEMENTWISE(values)
    if np.issubdtype(values.dtype, np.floating):
        result = result.astype(values.dtype, copy=False)
    return result


class BroadcastMulBench(benchmark.Benchmark):
    """Benchmark adding two 3-D operands that broadcast against each other.

    ``case`` selects which pair of shapes is used:

    * ``"row"``: ``[M, N, 1]`` and ``[M, 1, K]``
    * ``"mid"``: ``[M, N, 1]`` and ``[1, N, K]``
    * ``"col"``: ``[M, 1, K]`` and ``[1, N, K]``

    NOTE: despite the class name, ``forward`` and ``reference`` compute a sum.
    """

    def __init__(self, mode, device, dtype, case, M, N, K):
        """Create the two operands for the requested broadcast ``case``.

        Raises:
            ValueError: if ``case`` is not ``"row"``, ``"mid"`` or ``"col"``.
        """
        super().__init__(mode, device, dtype)
        self.case = case
        self.M = M
        self.N = N
        self.K = K

        if case == "row":
            shape1, shape2 = [M, N, 1], [M, 1, K]
        elif case == "mid":
            shape1, shape2 = [M, N, 1], [1, N, K]
        elif case == "col":
            shape1, shape2 = [M, 1, K], [1, N, K]
        else:
            raise ValueError(f"invalid case: {case}")

        # d1 is created before d2, matching the order of the rand() calls.
        self.d1 = self.rand(
            shape1, device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d2 = self.rand(
            shape2, device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.d1, self.d2]

    def forward(self, d1, d2):
        """Return the broadcast sum of the two operands."""
        y = d1 + d2
        return y

    def reference(self):
        """Return the same sum computed on the ``self.numpy`` conversions."""
        return self.numpy(self.d1) + self.numpy(self.d2)

    def config(self):
        """Return the size parameters ``[M, N, K]``."""
        return [self.M, self.N, self.K]

    @staticmethod
    def default_configs():
        """Return the default list of ``[M, N, K]`` configurations."""
        return [[128, 256, 128]]

    def memory_workload(self):
        """Return the ``"sol"`` and ``"algorithmic"`` memory workload figures.

        Both are ``M * N * K`` scaled by a count that depends on ``self.mode``.
        """
        if self.mode == "fwd":
            sol_count = 1
            algorithmic_count = 1
        else:
            sol_count = 1 + 1
            algorithmic_count = 1 + (1 + 1)

        # NOTE(review): the other benchmarks in this file multiply the buffer
        # size by 4 (presumably bytes per element); this one does not. Left
        # unchanged -- confirm which unit is intended.
        buffer_size = self.M * self.N * self.K
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }


class BroadcastRowBench(BroadcastMulBench):
    """``BroadcastMulBench`` fixed to the ``"row"`` case."""

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, "row", M, N, K)

    @staticmethod
    def module():
        """Return the benchmark module name."""
        return "broadcast_row"


class BroadcastMidBench(BroadcastMulBench):
    """``BroadcastMulBench`` fixed to the ``"mid"`` case."""

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, "mid", M, N, K)

    @staticmethod
    def module():
        """Return the benchmark module name."""
        return "broadcast_mid"


class BroadcastColBench(BroadcastMulBench):
    """``BroadcastMulBench`` fixed to the ``"col"`` case."""

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, "col", M, N, K)

    @staticmethod
    def module():
        """Return the benchmark module name."""
        return "broadcast_col"


class BroadcastThreeArgs(benchmark.Benchmark):
    """Benchmark summing three operands of rank 2, 3 and 4.

    The operand shapes are ``[M, N]``, ``[K, M, 1]`` and ``[L, K, 1, 1]``.
    """

    def __init__(self, mode, device, dtype, M, N, K, L):
        """Create the three operands and record the size parameters."""
        super().__init__(mode, device, dtype)
        self.M = M
        self.N = N
        self.K = K
        self.L = L

        self.d1 = self.rand(
            [M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d2 = self.rand(
            [K, M, 1], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d3 = self.rand(
            [L, K, 1, 1], device=device, dtype=dtype, requires_grad=self.requires_grad
        )

        self.inputs = [self.d1, self.d2, self.d3]

    def forward(self, d1, d2, d3):
        """Return the broadcast sum of the three operands."""
        y = d1 + d2 + d3
        return y

    def reference(self):
        """Return the same sum computed on the ``self.numpy`` conversions."""
        return self.numpy(self.d1) + self.numpy(self.d2) + self.numpy(self.d3)

    def config(self):
        """Return the size parameters ``[M, N, K, L]``."""
        return [self.M, self.N, self.K, self.L]

    @staticmethod
    def default_configs():
        """Return the default list of ``[M, N, K, L]`` configurations."""
        return [[32, 16, 64, 128]]

    def memory_workload(self):
        """Return the ``"sol"`` and ``"algorithmic"`` memory workload figures.

        Both are ``M * N * K * L * 4`` scaled by a count that depends on
        ``self.mode``.
        """
        if self.mode == "fwd":
            sol_count = 1
            algorithmic_count = 1
        else:
            sol_count = 1 + 1
            algorithmic_count = 1 + (1 + 1 + 1)

        # The factor 4 is presumably the size in bytes of one element --
        # TODO confirm.
        buffer_size = self.M * self.N * self.K * self.L * 4
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def module():
        """Return the benchmark module name."""
        return "broadcast_3args"


# benchmark.register_benchmark_class(BroadcastRowBench)
# benchmark.register_benchmark_class(BroadcastMidBench)
# benchmark.register_benchmark_class(BroadcastColBench)
# benchmark.register_benchmark_class(BroadcastThreeArgs)


# TODO: merge this with elementwise bench
# A template class for elementwise operations.
# A derived class will override the class instance to customize its behavior.
class BroadcastBench(benchmark.Benchmark):
    """Template benchmark combining four broadcasting operands.

    Subclasses created by ``register_broadcast_ops`` override the class
    variables below to select the binary and/or unary operation under test.
    """

    # List of customization class variables.
    op_str = None
    binary_op_pt_func = None
    binary_op_np_func = None
    unary_op_pt_func = None
    unary_op_np_func = None
    split_input = True

    def __init__(self, mode, device, dtype, M, N, K):
        """Create the four operands and record the size parameters."""
        super().__init__(mode, device, dtype)
        self.M = M
        self.N = N
        self.K = K

        self.d1 = self.rand(
            [M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d2 = self.rand(
            [K, 1, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d3 = self.rand(
            [M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d4 = self.rand(
            [K, M, 1], device=device, dtype=dtype, requires_grad=self.requires_grad
        )

        self.inputs = [self.d1, self.d2, self.d3, self.d4]

    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
        """Evaluate ``binary_op(u(d1), u(d2)) + binary_op(u(d3), u(d4))``.

        ``u`` is ``unary_op``. A missing ``binary_op`` defaults to addition
        and a missing ``unary_op`` defaults to the identity. When
        ``split_input`` is false, the third operand is derived from ``d1``
        (offset by ``0.001``) instead of using ``d3``.
        """
        if binary_op is None:

            def binary_op(x, y):
                return x + y

        if unary_op is None:

            def unary_op(x):
                return x

        if self.split_input:
            d1 = unary_op(d1)
            d2 = unary_op(d2)
            d3 = unary_op(d3)
            d4 = unary_op(d4)
        else:
            # The right-hand side is evaluated in full before any name is
            # rebound, so ``d1 + 0.001`` uses the original ``d1``.
            d1, d2, d3, d4 = (
                unary_op(d1),
                unary_op(d2),
                unary_op(d1 + 0.001),
                unary_op(d4),
            )
        a = binary_op(d1, d2)
        b = binary_op(d3, d4)
        c = a + b
        return c

    def forward(self, d1, d2, d3, d4):
        """Evaluate the benchmark expression with the ``*_pt_func`` operations."""
        # Looked up on the class so the plain function is returned rather than
        # a method bound to ``self``.
        binary_op = self.__class__.binary_op_pt_func
        unary_op = self.__class__.unary_op_pt_func
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def reference(self):
        """Evaluate the same expression with the ``*_np_func`` operations."""
        binary_op = self.__class__.binary_op_np_func
        unary_op = self.__class__.unary_op_np_func
        d1, d2, d3, d4 = (
            self.numpy(d) for d in (self.d1, self.d2, self.d3, self.d4)
        )
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def config(self):
        """Return the size parameters ``[M, N, K]``."""
        return [self.M, self.N, self.K]

    @classmethod
    def module(cls):
        """Return the benchmark module name, ``"broadcast_" + op_str``."""
        return "broadcast_" + cls.op_str

    def memory_workload(self):
        """Return the ``"sol"`` and ``"algorithmic"`` memory workload figures.

        The counts do not depend on ``split_input``: ``"sol"`` always uses a
        count of 1, while ``"algorithmic"`` uses 1 in ``"fwd"`` mode and the
        number of inputs otherwise.
        """
        sol_count = 1
        if self.mode == "fwd":
            algorithmic_count = 1
        else:
            algorithmic_count = len(self.inputs)

        # The factor 4 is presumably the size in bytes of one element --
        # TODO confirm.
        buffer_size = self.M * self.N * self.K * 4
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default list of ``[M, N, K]`` configurations."""
        return [[1 << 8, 1 << 7, 1 << 9]]


def _unpack_op_spec(spec):
    """Split an op spec into ``(name, pt_func, np_func)``.

    A two-element spec ``[name, func]`` uses ``func`` for both roles; a
    three-element spec ``[name, pt_func, np_func]`` is returned as given.

    Raises:
        ValueError: if the spec has any other length.
    """
    if len(spec) == 2:
        name, pt_func = spec
        return name, pt_func, pt_func
    if len(spec) == 3:
        name, pt_func, np_func = spec
        return name, pt_func, np_func
    raise ValueError(f"invalid op spec: {spec!r}")


def register_broadcast_ops():
    """Create and register one ``BroadcastBench`` subclass per operation.

    Each binary and unary operation is registered twice: once with
    ``split_input=True`` (name prefix ``"split_"``) and once with
    ``split_input=False`` (name prefix ``"shared_"``).
    """
    binary_op_list = [
        ["mul", operator.mul],
        ["add", operator.add],
        ["sub", operator.sub],
        ["div", lambda a, b: a / (b + 1e-4)],
        [
            "pow",
            torch.pow,
            np.power,
        ],  # no fusion triggered
        ["max", torch.max, np.maximum],
        ["min", torch.min, np.minimum],
    ]

    unary_op_list = [
        # NumPy has no ``erf``; ``_np_erf`` is the stdlib-backed reference.
        ["erf", torch.erf, _np_erf],
        ["exp", torch.exp, np.exp],
        ["sin", torch.sin, np.sin],
        ["cos", torch.cos, np.cos],
    ]

    for split_input, binary_op in itertools.product([True, False], binary_op_list):
        # Make a copy of BroadcastBench
        op_str, op_pt_func, op_np_func = _unpack_op_spec(binary_op)
        split_str = "split" if split_input else "shared"
        op_str = split_str + "_" + op_str
        bm_cls = type("BroadcastBench_" + op_str, (BroadcastBench,), {})
        bm_cls.op_str = op_str
        bm_cls.binary_op_pt_func = op_pt_func
        bm_cls.binary_op_np_func = op_np_func
        bm_cls.split_input = split_input
        benchmark.register_benchmark_class(bm_cls)

    for split_input, unary_op in itertools.product([True, False], unary_op_list):
        # Make a copy of BroadcastBench
        op_str, op_pt_func, op_np_func = _unpack_op_spec(unary_op)
        split_str = "split" if split_input else "shared"
        op_str = split_str + "_" + op_str
        bm_cls = type("BroadcastBench_" + op_str, (BroadcastBench,), {})
        bm_cls.op_str = op_str
        bm_cls.unary_op_pt_func = op_pt_func
        bm_cls.unary_op_np_func = op_np_func
        bm_cls.split_input = split_input
        benchmark.register_benchmark_class(bm_cls)


register_broadcast_ops()