"""Helper types and constants for ``__torch_function__`` override benchmarks."""

import torch


# Repetition counts. Presumably the inner/outer loop sizes for the timing
# code that imports this module -- NOTE(review): confirm against callers.
NUM_REPEATS = 1000
NUM_REPEAT_OF_REPEATS = 1000


class SubTensor(torch.Tensor):
    """Bare ``torch.Tensor`` subclass that adds no overrides of its own."""


class WithTorchFunction:
    """Non-tensor wrapper that takes part in dispatch via ``__torch_function__``.

    The wrapped tensor is stored in ``self._tensor``.
    """

    def __init__(self, data, requires_grad=False):
        # An existing tensor is stored as-is (``requires_grad`` is NOT applied
        # to it); any other data is converted with ``torch.tensor``.
        self._tensor = (
            data
            if isinstance(data, torch.Tensor)
            else torch.tensor(data, requires_grad=requires_grad)
        )

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Answer any intercepted torch call with the sum of two operands.

        ``func``, ``types`` and ``kwargs`` are not consulted: the result is
        always a new ``WithTorchFunction`` wrapping
        ``args[0]._tensor + args[1]._tensor``.

        NOTE(review): this is only meaningful for binary add-style calls; any
        other torch function reaching here is silently treated as addition.
        """
        kwargs = {} if kwargs is None else kwargs  # normalized, never read
        first, second = args[0], args[1]
        return WithTorchFunction(first._tensor + second._tensor)


class SubWithTorchFunction(torch.Tensor):
    """Tensor subclass whose ``__torch_function__`` simply defers upward."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Normalize ``kwargs`` and delegate to the next class in the MRO."""
        normalized = {} if kwargs is None else kwargs
        return super().__torch_function__(func, types, args, normalized)