import numpy as np
import torch

from . import benchmark


class Concat2D2InputBench(benchmark.Benchmark):
    def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
        super().__init__(mode, device, dtype)
        self.I1_D1 = I1_D1
        self.I1_D2 = I1_D2
        self.I2_D1 = I2_D1
        self.I2_D2 = I2_D2
        self.concat_dim = concat_dim
        self.input1 = self.randn(
            [I1_D1, I1_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.input2 = self.randn(
            [I2_D1, I2_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.input1, self.input2]

    def forward(self, input1, input2):
        # Add a small constant to each input so the concat cannot be
        # reduced to a plain copy of the original inputs.
        x1 = self.add(input1, 0.00001)
        x2 = self.add(input2, 0.00001)
        y = self.cat((x1, x2), dim=self.concat_dim)
        return y

    def reference(self):
        # NumPy reference for correctness checking; the 1e-5 offset applied
        # in forward() is assumed to fall within the comparison tolerance.
        return np.concatenate(
            (self.numpy(self.input1), self.numpy(self.input2)),
            axis=self.concat_dim,
        )

    def config(self):
        return [self.I1_D1, self.I1_D2, self.I2_D1, self.I2_D2, self.concat_dim]

    @staticmethod
    def module():
        return "concat2d2input"

    def memory_workload(self):
        if self.mode == "fwd":
            # "sol" (speed-of-light): each input read once, output written once.
            # "algorithmic" also counts the intermediate buffers from the adds.
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            # Backward roughly doubles the traffic.
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        buffer_size = self.I1_D1 * self.I1_D2 + self.I2_D1 * self.I2_D2
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        # Each entry: [I1_D1, I1_D2, I2_D1, I2_D2, concat_dim]
        return [
            [1, 160, 1, 14, 1],
            [1, 580, 1, 174, 1],
            [20, 160, 20, 14, 1],
            [20, 580, 20, 174, 1],
            [8, 512, 8, 512, 1],
            [1 << 13, 1060, 1 << 13, 1040, 1],
            [1 << 13, 2000, 1 << 13, 1074, 1],
            [1 << 15, 1060, 1 << 15, 2670, 1],
            [1 << 15, 5120, 1 << 15, 2512, 1],
        ]


benchmark.register_benchmark_class(Concat2D2InputBench)


class ConcatGraphOptBench(benchmark.Benchmark):
    def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
        super().__init__(mode, device, dtype)
        self.I1_D1 = I1_D1
        self.I1_D2 = I1_D2
        self.I2_D1 = I2_D1
        self.I2_D2 = I2_D2
        self.concat_dim = concat_dim
        self.input1 = self.randn(
            [I1_D1, I1_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.input2 = self.randn(
            [I2_D1, I2_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.input1, self.input2]
        # Enable CPU fusion and conditional-free concat lowering so the JIT
        # can fuse the add -> cat -> relu graph.
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_cat_wo_conditionals(True)

    def forward(self, input1, input2):
        x1 = self.add(input1, 0.00001)
        x2 = self.add(input2, 0.00001)
        y = self.cat((x1, x2), dim=self.concat_dim)
        z = self.relu(y)
        return z

    def reference(self):
        # Mirror forward(): concatenate, then apply ReLU (np.maximum with 0)
        # so the reference matches the relu in the benchmarked graph.
        return np.maximum(
            np.concatenate(
                (self.numpy(self.input1), self.numpy(self.input2)),
                axis=self.concat_dim,
            ),
            0,
        )

    def config(self):
        return [self.I1_D1, self.I1_D2, self.I2_D1, self.I2_D2, self.concat_dim]

    @staticmethod
    def module():
        return "concatGraphOpt"

    def memory_workload(self):
        if self.mode == "fwd":
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        buffer_size = self.I1_D1 * self.I1_D2 + self.I2_D1 * self.I2_D2
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        return [
            [1 << 13, 1060, 1 << 13, 1040, 1],
            [1 << 13, 2000, 1 << 13, 1074, 1],
            [1 << 15, 1060, 1 << 15, 2670, 1],
            [1 << 15, 5120, 1 << 15, 2512, 1],
        ]


benchmark.register_benchmark_class(ConcatGraphOptBench)
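

# --- Standalone usage sketch (illustrative only, not part of the benchmark
# framework). It reproduces the fwd workload of Concat2D2InputBench with
# eager-mode ops; the framework normally routes self.add/self.cat through
# backend-specific implementations, so using plain torch.add/torch.cat here
# is an assumption made purely for demonstration.
if __name__ == "__main__":
    i1 = torch.randn(20, 160)
    i2 = torch.randn(20, 14)
    out = torch.cat((torch.add(i1, 0.00001), torch.add(i2, 0.00001)), dim=1)
    ref = np.concatenate((i1.numpy(), i2.numpy()), axis=1)
    # The 1e-5 offset makes out differ from ref by exactly 1e-5 per element,
    # so compare with a tolerance rather than exact equality.
    np.testing.assert_allclose(out.numpy(), ref, atol=1e-4)
    print("concat sketch OK:", out.shape)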