import contextlib
import json
import os
import time

import numpy as np

import torch

from . import tensor_engine


class Benchmark:
    """Base class for a single tensor-op benchmark.

    Subclasses are expected to override ``forward``, ``config``, ``module``,
    ``memory_workload`` and ``default_configs`` (the base versions raise
    ``ValueError``).

    Besides the attributes assigned in ``__init__``, the harness reads
    ``self.inputs`` (positional arguments for ``forward``) and
    ``self.jit_mode``; neither is assigned in this class, so they must be
    set by the subclass or the caller before ``run`` is invoked.
    """

    def __init__(self, mode, device, dtype):
        """Store the benchmark settings and bind the tensor engine.

        Args:
            mode: ``"both"`` (forward and backward, sets ``requires_grad``
                to True) or ``"fwd"`` (forward only).
            device: device identifier; passed to ``engine.reset`` and
                compared against ``"cuda"`` when timing.
            dtype: dtype handed to ``torch.tensor`` in ``dtype_to_bytes``.

        Raises:
            ValueError: if ``mode`` is neither ``"both"`` nor ``"fwd"``.
        """
        self.mode = mode
        self.deterministic = False
        self.device = device
        self.dtype = dtype
        self.output_type = "stdout"
        self.print_ir = False
        self.print_kernel = False
        if mode == "both":
            self.requires_grad = True
        elif mode == "fwd":
            self.requires_grad = False
        else:
            raise ValueError(f"invalid mode: {mode}")
        self.result_grad = None
        self.grad_variables = []
        self.engine = tensor_engine.get_engine()
        self.engine.reset(device)

        # Forward all public member functions of self.engine to self.
        for method in dir(self.engine):
            if not callable(getattr(self.engine, method)):
                continue
            # Don't forward if this function is overridden here.
            if hasattr(self, method):
                continue
            # Don't forward if it is an internal function.
            if method.startswith("_"):
                continue
            method_engine = getattr(self.engine, method)
            setattr(self, method, method_engine)

    def forward(self):
        """Do one step worth of computation; must be overridden."""
        raise ValueError("this method should be reimplemented by subclass")

    def check(self):
        """Compare ``compute()`` against ``reference()`` when deterministic.

        Does nothing unless ``self.deterministic`` is true. ``reference`` and
        ``numpy`` are not defined in this class; ``numpy`` is presumably one
        of the methods forwarded from the engine -- confirm against the
        engine implementation.

        Raises:
            AssertionError: if the results differ by more than ``atol=1e-2``
                (plus the default relative tolerance of ``assert_allclose``).
        """
        if not self.deterministic:
            return
        np.testing.assert_allclose(
            self.reference(), self.numpy(self.compute()), atol=1e-2
        )

    def config(self):
        """Return an array for the current benchmark configs; must be overridden."""
        raise ValueError("this method should be reimplemented by subclass")

    def desc(self):
        """Return the description string of the current benchmark.

        The string combines the engine mode, the module name, the benchmark
        mode, the device (with ``NNC_NUM_THREADS`` appended when that
        environment variable is set) and the ``_``-joined config values.
        """
        config = self.config()
        config_str = "_".join([str(x) for x in config])
        device = self.device
        if "NNC_NUM_THREADS" in os.environ:
            num_threads_str = os.environ["NNC_NUM_THREADS"]
            device += num_threads_str
        return f"{self.engine.mode}: {self.module()}_{self.mode}_{device}_{config_str}"

    @staticmethod
    def module():
        """Return the benchmark's module name; must be overridden."""
        raise ValueError("this method should be reimplemented by subclass")

    def memory_workload(self):
        """Return a mapping with ``"sol"`` and ``"algorithmic"`` element counts.

        ``run_impl`` multiplies both entries by the dtype size to report
        bandwidth. Must be overridden.
        """
        raise ValueError("this method should be reimplemented by subclass")

    def compute_workload(self):
        """Return the number of scalar operations it takes to finish the tensor op.

        The base implementation returns ``None``, which makes ``run_impl``
        omit the compute-throughput figure.
        """
        return None

    @staticmethod
    def input_iterable():
        """A benchmark child class should return true if it utilizes the input iter arg."""
        return False

    def dtype_to_bytes(self):
        """Return the size in bytes of one element of ``self.dtype``."""
        return torch.tensor(0, dtype=self.dtype).element_size()

    @staticmethod
    def default_configs():
        """Return a list of default configs for this benchmark; must be overridden."""
        raise ValueError("this method should be reimplemented by subclass")

    def is_supported(self):
        """Return whether this benchmark can run; the base class always says yes."""
        return True

    def rand(self, shape, device=None, dtype=None, requires_grad=False):
        """Create a random tensor through the engine.

        When ``requires_grad`` is true the tensor is also recorded in
        ``self.grad_variables`` so ``run_impl`` can pass it to
        ``engine.backward``.
        """
        v = self.engine.rand(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )
        if requires_grad:
            self.grad_variables.append(v)
        return v

    def nchw_rand(self, shape, device=None, requires_grad=False):
        """Create a random tensor via ``engine.nchw_rand``.

        When ``requires_grad`` is true the tensor is also recorded in
        ``self.grad_variables``.
        """
        v = self.engine.nchw_rand(shape, device=device, requires_grad=requires_grad)
        if requires_grad:
            self.grad_variables.append(v)
        return v

    def compute(self):
        """Run one step on ``self.inputs`` and return its result.

        Uses the traced callable ``self.bm_jit`` when ``run_impl`` created
        one, otherwise the eager ``forward``.
        """
        # bm_jit is only assigned inside run_impl; fall back to eager mode
        # instead of raising AttributeError when compute() is called earlier.
        bm_jit = getattr(self, "bm_jit", None)
        if bm_jit:
            return bm_jit(*self.inputs)
        return self.forward(*self.inputs)

    def run(self, args):
        """Configure the requested fuser and run the benchmark.

        Args:
            args: object providing ``print_ir`` and ``cuda_fuser``; depending
                on the fuser also ``print_kernel`` and the
                ``cuda_pointwise_loop_levels`` / ``cuda_pointwise_block_count``
                / ``cuda_pointwise_block_size`` settings.

        Returns:
            Whatever ``run_impl`` returns.
        """
        self.print_ir = args.print_ir
        if args.cuda_fuser == "old":
            torch._C._jit_override_can_fuse_on_gpu(True)
            if args.print_kernel:
                os.environ["PYTORCH_FUSION_DEBUG"] = "1"
            return self.run_impl(True)
        elif args.cuda_fuser == "te":
            torch._C._jit_set_texpr_fuser_enabled(True)
            with cuda_pointwise_context(
                args.cuda_pointwise_loop_levels,
                args.cuda_pointwise_block_count,
                args.cuda_pointwise_block_size,
            ):
                return self.run_impl(True)
        elif args.cuda_fuser == "nvf":
            torch._C._jit_set_nvfuser_enabled(True)
            torch._C._jit_set_profiling_executor(True)
            torch._C._jit_set_profiling_mode(True)
            torch._C._jit_override_can_fuse_on_cpu(False)
            torch._C._jit_override_can_fuse_on_gpu(False)
            torch._C._jit_set_bailout_depth(20)
            if args.print_kernel:
                os.environ["PYTORCH_CUDA_FUSER_DEBUG"] = "1"
            return self.run_impl(True)
        else:
            return self.run_impl(False)

    def run_impl(self, use_fuser):
        """Warm up, time the benchmark loop and report the result.

        Runs 10 warm-up iterations followed by 1000 timed iterations on
        ``"cuda"`` or 10 otherwise. On the first iteration the forward
        function is traced (when ``self.jit_mode == "trace"`` and
        ``use_fuser`` is true) and the result is checked against the
        reference if one exists. The measurements are handed to
        ``dump_result``.

        Args:
            use_fuser: whether a JIT fuser was enabled by ``run``.
        """
        warmups = 10
        if self.device == "cuda":
            iters = 1000
        else:
            iters = 10
        engine = tensor_engine.get_engine()

        self.bm_jit = None
        for i in range(warmups + iters):
            if i == warmups:
                # Start the clock only after warm-up work has drained.
                if self.device == "cuda":
                    engine.sync_cuda()
                # perf_counter is monotonic and high resolution, unlike the
                # wall clock, which can be adjusted while the loop runs.
                time_start = time.perf_counter()

            if i == 0:
                if self.jit_mode == "trace" and use_fuser:
                    self.bm_jit = torch.jit.trace(
                        self.forward, example_inputs=self.inputs, check_trace=False
                    )
                if callable(getattr(self, "reference", None)):
                    self.check()
                else:
                    print("Warning: no reference result for ", self.module())
            elif i == 1:
                # The fusion graph is visible after the first iter is executed
                if self.jit_mode == "trace" and use_fuser and self.print_ir:
                    print(self.bm_jit.graph_for(*self.inputs))
            z = self.compute()
            if self.mode == "both":
                # NOTE(review): result_grad is created once from the first
                # output and reused; confirm this is valid when output shapes
                # vary between iterations (e.g. with DynamicShape).
                if self.result_grad is None:
                    self.result_grad = engine.rand_like(z)
                engine.backward([z], [self.result_grad], self.grad_variables)

        if self.device == "cuda":
            engine.sync_cuda()

        duration = time.perf_counter() - time_start
        iter_time = duration / iters
        memory_workload = self.memory_workload()
        compute_workload = self.compute_workload()

        result_dict = {
            "desc": self.desc(),
            "us": iter_time * 1e6,
            "sol": memory_workload["sol"] * self.dtype_to_bytes() / iter_time / 1e9,
            "algorithmic": memory_workload["algorithmic"]
            * self.dtype_to_bytes()
            / iter_time
            / 1e9,
        }
        if compute_workload:
            result_dict["compute_workload"] = compute_workload / iter_time / 1e9
        self.dump_result(result_dict)

    def dump_result(self, result_dict):
        """Print ``result_dict`` as JSON or as a human-readable line.

        Args:
            result_dict: mapping with ``"desc"``, ``"us"``, ``"sol"``,
                ``"algorithmic"`` and optionally ``"compute_workload"``.

        Raises:
            Exception: if ``self.output_type`` is neither ``"json"`` nor
                ``"stdout"``.
        """
        if self.output_type == "json":
            print(json.dumps(result_dict))
        elif self.output_type == "stdout":
            msg = "{}: {:.2f} us, SOL {:.2f} GB/s, algorithmic {:.2f} GB/s".format(
                result_dict["desc"],
                result_dict["us"],
                result_dict["sol"],
                result_dict["algorithmic"],
            )
            if "compute_workload" in result_dict:
                msg += f", compute {result_dict['compute_workload']:.2f} Gops/s"
            print(msg)
        else:
            raise Exception("Unknown output_type " + self.output_type)  # noqa: TRY002


@contextlib.contextmanager
def cuda_pointwise_context(loop_levels, block_count, block_size):
    """Temporarily override the TE CUDA pointwise code-generation settings.

    Each truthy argument replaces the corresponding global setting for the
    duration of the ``with`` block; the previous value is restored on exit,
    even when the body raises. Falsy arguments leave their setting untouched.
    """
    if loop_levels:
        old_loop_levels = torch._C._jit_get_te_cuda_pointwise_loop_levels()
        torch._C._jit_set_te_cuda_pointwise_loop_levels(loop_levels)
    if block_count:
        old_block_count = torch._C._jit_get_te_cuda_pointwise_block_count()
        torch._C._jit_set_te_cuda_pointwise_block_count(block_count)
    if block_size:
        old_block_size = torch._C._jit_get_te_cuda_pointwise_block_size()
        torch._C._jit_set_te_cuda_pointwise_block_size(block_size)

    try:
        yield
    finally:
        if loop_levels:
            torch._C._jit_set_te_cuda_pointwise_loop_levels(old_loop_levels)
        if block_count:
            torch._C._jit_set_te_cuda_pointwise_block_count(old_block_count)
        if block_size:
            torch._C._jit_set_te_cuda_pointwise_block_size(old_block_size)


# Auxiliary class to facilitate dynamic input shape
class DynamicShape:
    r"""
    An auxiliary class for dynamic shape benchmarks.

    Pre-computes inputs with random shapes and also
    modifies the compute method so in each call the
    fuser sees a different input tensor shape.

    ``compute`` and ``run`` delegate to ``super()``, so this class has to be
    combined with (and listed ahead of) a class that provides them.
    """

    # Number of random inputs in an instance
    SAMPLE_SIZE = 100

    def __init__(self, dynamic_range=1.2):
        """Set up the sample buffer.

        Args:
            dynamic_range: scaling bound for ``rand_shape``. Values above
                1.0 are inverted, so the stored lower bound for the random
                ratio is at most 1.0 in that case.
        """
        self._input_samples = []
        self._input_sample_index = 0
        self._dynamic_range = (
            1.0 / dynamic_range if dynamic_range > 1.0 else dynamic_range
        )
        self._enable_dynamic_shapes = True

    # Returns the input test case that current index points to
    @property
    def inputs(self):
        return self._input_samples[self._input_sample_index]

    # An inputs assignment actually adds a test case in the class buffer
    @inputs.setter
    def inputs(self, val):
        self._input_samples.append(val)

    # Runs normal compute while incrementing the test case index
    def compute(self):
        """Run the parent ``compute`` and advance to the next input sample.

        Returns:
            The result of the parent ``compute``; callers such as
            ``Benchmark.run_impl`` and ``Benchmark.check`` consume it.
        """
        result = super().compute()
        self._input_sample_index = (self._input_sample_index + 1) % self.SAMPLE_SIZE
        return result

    # Defined by benchmark, the benchmark needs to specify the input
    # tensor construction in this method, essentially the same way
    # a benchmark creates the inputs list in the initializer
    def instantiate_input(self):
        raise NotImplementedError

    # Instantiate random shaped inputs and start the benchmark run
    def run(self, args):
        """Pre-compute the input samples, then run the parent ``run``.

        Args:
            args: must provide ``no_dynamic_shape`` in addition to whatever
                the parent ``run`` reads.

        Returns:
            Whatever the parent ``run`` returns.
        """
        # force disable dynamic shape from command line
        if args.no_dynamic_shape:
            self._enable_dynamic_shapes = False
        self.load_inputs()
        return super().run(args)

    # pre-compute inputs so the creations of random tensors
    # do not add to the compute time
    def load_inputs(self):
        # One sample short of SAMPLE_SIZE: presumably the benchmark's
        # initializer already stored the first sample -- confirm in subclasses.
        for _ in range(self.SAMPLE_SIZE - 1):
            self.instantiate_input()

    # returns a randomized shape
    def rand_shape(self, shape):
        """Return ``shape`` with each dimension scaled by a random ratio.

        Ratios are drawn uniformly between ``self._dynamic_range`` and 1.0.
        When dynamic shapes are disabled, ``shape`` is returned unchanged.
        """
        if not self._enable_dynamic_shapes:
            return shape
        ratios = np.random.uniform(self._dynamic_range, 1.0, len(shape))
        # NOTE: astype(int) truncates, so a dimension of 1 scaled by a ratio
        # below 1.0 becomes 0.
        dyn_shape = list(np.multiply(shape, ratios).astype(int))
        return dyn_shape


benchmark_classes = []


def register_benchmark_class(benchmark_cls):
    """Append ``benchmark_cls`` to the module-level ``benchmark_classes`` list."""
    benchmark_classes.append(benchmark_cls)