# xref: /aosp_15_r20/external/pytorch/benchmarks/tensorexpr/benchmark.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import contextlib
import json
import os
import time

import numpy as np

import torch

from . import tensor_engine

class Benchmark:
    def __init__(self, mode, device, dtype):
        self.mode = mode
        self.deterministic = False
        self.device = device
        self.dtype = dtype
        self.output_type = "stdout"
        self.print_ir = False
        self.print_kernel = False
        if mode == "both":
            self.requires_grad = True
        elif mode == "fwd":
            self.requires_grad = False
        else:
            raise ValueError(f"invalid mode: {mode}")
        self.result_grad = None
        self.grad_variables = []
        self.engine = tensor_engine.get_engine()
        self.engine.reset(device)

        # forward all member functions in self.engine to self
        for method in dir(self.engine):
            if not callable(getattr(self.engine, method)):
                continue
            # don't forward if this function is overridden here
            if hasattr(self, method):
                continue
            # don't forward if it is an internal function
            if method.startswith("_"):
                continue
            method_engine = getattr(self.engine, method)
            setattr(self, method, method_engine)

    def forward(self):
        """Do one step's worth of computation."""
        raise NotImplementedError("this method should be reimplemented by a subclass")

    def check(self):
        if not self.deterministic:
            return
        # self.numpy() is forwarded from the underlying tensor engine (see the
        # loop in __init__); it is expected to convert an engine tensor to a
        # numpy array for comparison against the reference implementation.
        np.testing.assert_allclose(
            self.reference(), self.numpy(self.compute()), atol=1e-2
        )

    def config(self):
        """Return a list describing the current benchmark configuration."""
        raise NotImplementedError("this method should be reimplemented by a subclass")

    def desc(self):
        """Return the description of the current benchmark."""
        config = self.config()
        config_str = "_".join([str(x) for x in config])
        device = self.device
        if "NNC_NUM_THREADS" in os.environ:
            num_threads_str = os.environ["NNC_NUM_THREADS"]
            device += num_threads_str
        return f"{self.engine.mode}: {self.module()}_{self.mode}_{device}_{config_str}"

    @staticmethod
    def module():
        raise NotImplementedError("this method should be reimplemented by a subclass")

    def memory_workload(self):
        raise NotImplementedError("this method should be reimplemented by a subclass")

    def compute_workload(self):
        """Return the number of scalar operations it takes to finish the tensor op."""
        return None

    @staticmethod
    def input_iterable():
        """A benchmark child class should return True if it utilizes the input iter arg."""
        return False

    def dtype_to_bytes(self):
        # element_size() reports the number of bytes per scalar of self.dtype,
        # e.g. 4 for torch.float32.
        return torch.tensor(0, dtype=self.dtype).element_size()

    @staticmethod
    def default_configs():
        """Return a list of default configs for this benchmark."""
        raise NotImplementedError("this method should be reimplemented by a subclass")

    def is_supported(self):
        return True

    def rand(self, shape, device=None, dtype=None, requires_grad=False):
        v = self.engine.rand(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )
        if requires_grad:
            self.grad_variables.append(v)
        return v

    def nchw_rand(self, shape, device=None, requires_grad=False):
        v = self.engine.nchw_rand(shape, device=device, requires_grad=requires_grad)
        if requires_grad:
            self.grad_variables.append(v)
        return v

    def compute(self):
        if self.bm_jit:
            return self.bm_jit(*self.inputs)
        else:
            return self.forward(*self.inputs)

    def run(self, args):
        self.print_ir = args.print_ir
        if args.cuda_fuser == "old":
            torch._C._jit_override_can_fuse_on_gpu(True)
            if args.print_kernel:
                os.environ["PYTORCH_FUSION_DEBUG"] = "1"
            return self.run_impl(True)
        elif args.cuda_fuser == "te":
            torch._C._jit_set_texpr_fuser_enabled(True)
            with cuda_pointwise_context(
                args.cuda_pointwise_loop_levels,
                args.cuda_pointwise_block_count,
                args.cuda_pointwise_block_size,
            ):
                return self.run_impl(True)
        elif args.cuda_fuser == "nvf":
            torch._C._jit_set_nvfuser_enabled(True)
            torch._C._jit_set_profiling_executor(True)
            torch._C._jit_set_profiling_mode(True)
            torch._C._jit_override_can_fuse_on_cpu(False)
            torch._C._jit_override_can_fuse_on_gpu(False)
            torch._C._jit_set_bailout_depth(20)
            if args.print_kernel:
                os.environ["PYTORCH_CUDA_FUSER_DEBUG"] = "1"
            return self.run_impl(True)
        else:
            return self.run_impl(False)

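    # Note: `args` is expected to carry at least these attributes, as read in
    # run() above and in DynamicShape.run() below: print_ir, print_kernel,
    # cuda_fuser, cuda_pointwise_loop_levels, cuda_pointwise_block_count,
    # cuda_pointwise_block_size, and no_dynamic_shape. A minimal hand-rolled
    # stand-in for ad-hoc runs could look like this (a sketch, not part of
    # the benchmark harness):
    #
    #     import argparse
    #     args = argparse.Namespace(
    #         print_ir=False, print_kernel=False, cuda_fuser="te",
    #         cuda_pointwise_loop_levels=None, cuda_pointwise_block_count=None,
    #         cuda_pointwise_block_size=None, no_dynamic_shape=True,
    #     )
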
    def run_impl(self, use_fuser):
        warmups = 10
        if self.device == "cuda":
            iters = 1000
        else:
            iters = 10
        engine = tensor_engine.get_engine()

        self.bm_jit = None
        # The first `warmups` iterations are excluded from timing: the clock
        # starts only once i == warmups, after a CUDA sync where applicable.
        for i in range(warmups + iters):
            if i == warmups:
                if self.device == "cuda":
                    engine.sync_cuda()
                time_start = time.time()

            if i == 0:
                # self.jit_mode is expected to be set by the concrete benchmark.
                if self.jit_mode == "trace" and use_fuser:
                    self.bm_jit = torch.jit.trace(
                        self.forward, example_inputs=self.inputs, check_trace=False
                    )
                if callable(getattr(self, "reference", None)):
                    self.check()
                else:
                    print("Warning: no reference result for", self.module())
            elif i == 1:
                # The fusion graph is visible after the first iter is executed
                if self.jit_mode == "trace" and use_fuser and self.print_ir:
                    print(self.bm_jit.graph_for(*self.inputs))
            z = self.compute()
            if self.mode == "both":
                if self.result_grad is None:
                    self.result_grad = engine.rand_like(z)
                engine.backward([z], [self.result_grad], self.grad_variables)

        if self.device == "cuda":
            engine.sync_cuda()

        duration = time.time() - time_start
        iter_time = duration / iters
        memory_workload = self.memory_workload()
        compute_workload = self.compute_workload()

        result_dict = {
            "desc": self.desc(),
            "us": iter_time * 1e6,
            # "sol" and "algorithmic" are per-iteration element counts from
            # memory_workload(); scaling by bytes per element and dividing by
            # the iteration time yields bandwidth in GB/s.
            "sol": memory_workload["sol"] * self.dtype_to_bytes() / iter_time / 1e9,
            "algorithmic": memory_workload["algorithmic"]
            * self.dtype_to_bytes()
            / iter_time
            / 1e9,
        }
        if compute_workload:
            result_dict["compute_workload"] = compute_workload / iter_time / 1e9
        self.dump_result(result_dict)

    def dump_result(self, result_dict):
        if self.output_type == "json":
            print(json.dumps(result_dict))
        elif self.output_type == "stdout":
            msg = "{}: {:.2f} us, SOL {:.2f} GB/s, algorithmic {:.2f} GB/s".format(
                result_dict["desc"],
                result_dict["us"],
                result_dict["sol"],
                result_dict["algorithmic"],
            )
            if "compute_workload" in result_dict:
                msg += f", compute {result_dict['compute_workload']:.2f} Gops/s"
            print(msg)
        else:
            raise Exception("Unknown output_type " + self.output_type)  # noqa: TRY002


@contextlib.contextmanager
def cuda_pointwise_context(loop_levels, block_count, block_size):
    if loop_levels:
        old_loop_levels = torch._C._jit_get_te_cuda_pointwise_loop_levels()
        torch._C._jit_set_te_cuda_pointwise_loop_levels(loop_levels)
    if block_count:
        old_block_count = torch._C._jit_get_te_cuda_pointwise_block_count()
        torch._C._jit_set_te_cuda_pointwise_block_count(block_count)
    if block_size:
        old_block_size = torch._C._jit_get_te_cuda_pointwise_block_size()
        torch._C._jit_set_te_cuda_pointwise_block_size(block_size)

    try:
        yield
    finally:
        if loop_levels:
            torch._C._jit_set_te_cuda_pointwise_loop_levels(old_loop_levels)
        if block_count:
            torch._C._jit_set_te_cuda_pointwise_block_count(old_block_count)
        if block_size:
            torch._C._jit_set_te_cuda_pointwise_block_size(old_block_size)


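# A minimal usage sketch for cuda_pointwise_context (illustrative only; the
# numeric values are arbitrary placeholders, not recommendations). Each knob
# is overridden only when its argument is truthy, and restored on exit:
#
#     with cuda_pointwise_context(loop_levels=2, block_count=0, block_size=128):
#         ...  # run TE-fused work here; block_count=0 leaves that knob as-is

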
# Auxiliary class to facilitate dynamic input shapes
class DynamicShape:
    r"""
    An auxiliary mixin class for dynamic shape benchmarks.

    Pre-computes inputs with random shapes and overrides the compute
    method so that each call presents the fuser with a different input
    tensor shape.
    """

    # Number of random inputs in an instance
    SAMPLE_SIZE = 100

    def __init__(self, dynamic_range=1.2):
        self._input_samples = []
        self._input_sample_index = 0
        # Normalize to a lower bound in (0, 1]: shape ratios are later drawn
        # uniformly from [self._dynamic_range, 1.0] in rand_shape().
        self._dynamic_range = (
            1.0 / dynamic_range if dynamic_range > 1.0 else dynamic_range
        )
        self._enable_dynamic_shapes = True

    # Returns the input test case that the current index points to
    @property
    def inputs(self):
        return self._input_samples[self._input_sample_index]

    # An inputs assignment actually appends a test case to the instance buffer
    @inputs.setter
    def inputs(self, val):
        self._input_samples.append(val)

    # Runs the normal compute while advancing the test case index; the result
    # is returned so run_impl can still drive the backward pass in "both" mode
    def compute(self):
        ret = super().compute()
        self._input_sample_index = (self._input_sample_index + 1) % self.SAMPLE_SIZE
        return ret

    # To be defined by the benchmark: it must specify the input tensor
    # construction in this method, essentially the same way a benchmark
    # creates the inputs list in its initializer
    def instantiate_input(self):
        raise NotImplementedError

    # Instantiate randomly shaped inputs and start the benchmark run
    def run(self, args):
        # force-disable dynamic shapes from the command line
        if args.no_dynamic_shape:
            self._enable_dynamic_shapes = False
        self.load_inputs()
        super().run(args)

    # pre-compute inputs so the creation of random tensors
    # does not add to the compute time
    def load_inputs(self):
        # SAMPLE_SIZE - 1 iterations: the subclass initializer is assumed to
        # have already supplied the first input sample via the inputs setter.
        for _ in range(self.SAMPLE_SIZE - 1):
            self.instantiate_input()

    # returns a randomized shape
    def rand_shape(self, shape):
        if not self._enable_dynamic_shapes:
            return shape
        ratios = np.random.uniform(self._dynamic_range, 1.0, len(shape))
        dyn_shape = list(np.multiply(shape, ratios).astype(int))
        return dyn_shape

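# A sketch of how DynamicShape is meant to be combined with Benchmark
# (illustrative only; `SomeBenchmark` and `self.case_shape` are hypothetical):
#
#     class DynamicSomeBenchmark(DynamicShape, SomeBenchmark):
#         def instantiate_input(self):
#             # assigning to self.inputs appends one randomly-shaped sample
#             self.inputs = [self.rand(self.rand_shape(self.case_shape))]
#
# MRO note: DynamicShape must come first so its compute/run wrappers are
# invoked before the concrete benchmark's implementations.
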

benchmark_classes = []


def register_benchmark_class(benchmark_cls):
    benchmark_classes.append(benchmark_cls)
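

# End-to-end usage sketch (illustrative; the element-wise benchmark below is
# hypothetical and only meant to show which methods a subclass must provide:
# forward, config, module, memory_workload, default_configs):
#
#     class AddBench(Benchmark):
#         def __init__(self, mode, device, dtype, N):
#             super().__init__(mode, device, dtype)
#             self.N = N
#             self.inputs = [
#                 self.rand([N], device=device, dtype=dtype,
#                           requires_grad=self.requires_grad)
#                 for _ in range(2)
#             ]
#
#         def forward(self, a, b):
#             return a + b
#
#         def config(self):
#             return [self.N]
#
#         @staticmethod
#         def module():
#             return "add"
#
#         def memory_workload(self):
#             # 2 reads + 1 write per element, in both the "speed of light"
#             # and the algorithmic accounting of this toy example
#             return {"sol": 3 * self.N, "algorithmic": 3 * self.N}
#
#         @staticmethod
#         def default_configs():
#             return [[1 << 20]]
#
#     register_benchmark_class(AddBench)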