xref: /aosp_15_r20/external/pytorch/benchmarks/overrides_benchmark/common.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2
3
# Repeat counts shared by the override benchmarks. How they are consumed is
# not visible in this module -- presumably an inner timing loop count and an
# outer "repeat the whole measurement" count; confirm against the scripts
# that import them.
NUM_REPEATS = 1_000
NUM_REPEAT_OF_REPEATS = 1_000
6
7
class SubTensor(torch.Tensor):
    """Bare ``torch.Tensor`` subclass that adds no attributes or overrides.

    NOTE(review): presumably the baseline for measuring plain-subclass
    dispatch cost -- confirm against the benchmark scripts.
    """
10
11
class WithTorchFunction:
    """Tensor-like wrapper (not a ``torch.Tensor`` subclass).

    The wrapped tensor lives in ``_tensor``; ``__torch_function__`` is
    defined so that ``torch`` functions called on instances are routed here.
    """

    def __init__(self, data, requires_grad=False):
        """Wrap *data*.

        A ``torch.Tensor`` is stored as-is (``requires_grad`` is not applied
        to it); anything else is converted with
        ``torch.tensor(data, requires_grad=requires_grad)``.
        """
        if not isinstance(data, torch.Tensor):
            data = torch.tensor(data, requires_grad=requires_grad)
        self._tensor = data

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Return a new wrapper around the sum of the first two arguments.

        ``func``, ``types`` and ``kwargs`` are not consulted: whichever
        function is being overridden, the result is always
        ``args[0]._tensor + args[1]._tensor`` wrapped in ``WithTorchFunction``.
        """
        # Normalised to a dict even though nothing below reads it.
        # NOTE(review): presumably kept so the timed path mirrors the usual
        # __torch_function__ boilerplate -- confirm before removing.
        kwargs = {} if kwargs is None else kwargs

        lhs = args[0]._tensor
        rhs = args[1]._tensor
        return WithTorchFunction(lhs + rhs)
26
27
class SubWithTorchFunction(torch.Tensor):
    """``torch.Tensor`` subclass with an explicit ``__torch_function__``.

    The override adds no behaviour of its own: it only normalises ``kwargs``
    and then defers to the parent class implementation.
    """

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Forward the call to the parent ``__torch_function__``.

        ``kwargs`` of ``None`` is replaced by an empty dict before forwarding;
        every other argument is passed through unchanged.
        """
        forwarded_kwargs = {} if kwargs is None else kwargs
        return super().__torch_function__(func, types, args, forwarded_kwargs)
35