import torch


class TorchTensorEngine:
    """Thin adapter exposing a uniform tensor-engine interface backed by
    eager-mode PyTorch ops."""

    def rand(self, shape, device=None, dtype=None, requires_grad=False):
        return torch.rand(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    def randn(self, shape, device=None, dtype=None, requires_grad=False):
        return torch.randn(
            shape, device=device, dtype=dtype, requires_grad=requires_grad
        )

    def nchw_rand(self, shape, device=None, requires_grad=False):
        # NCHW is PyTorch's default layout, so a plain rand() suffices here.
        return self.rand(shape, device=device, requires_grad=requires_grad)

    def reset(self, _):
        # No per-run state to reset for eager PyTorch; kept for interface parity.
        pass

    def rand_like(self, v):
        return torch.rand_like(v)

    def numpy(self, t):
        # Copy to host first: .numpy() requires a CPU tensor.
        return t.cpu().numpy()

    def mul(self, t1, t2):
        return t1 * t2

    def add(self, t1, t2):
        return t1 + t2

    def batch_norm(self, data, mean, var, training):
        return torch.nn.functional.batch_norm(data, mean, var, training=training)

    def instance_norm(self, data):
        return torch.nn.functional.instance_norm(data)

    def layer_norm(self, data, shape):
        return torch.nn.functional.layer_norm(data, shape)

    def sync_cuda(self):
        # Block until all queued CUDA kernels have finished.
        torch.cuda.synchronize()

    def backward(self, tensors, grad_tensors, _):
        torch.autograd.backward(tensors, grad_tensors=grad_tensors)

    def sum(self, data, dims):
        return torch.sum(data, dims)

    def softmax(self, data, dim=None, dtype=None):
        # Pass dim/dtype by keyword: F.softmax's third positional parameter
        # is an internal _stacklevel, not dtype.
        return torch.nn.functional.softmax(data, dim=dim, dtype=dtype)

    def cat(self, inputs, dim=0):
        return torch.cat(inputs, dim=dim)

    def clamp(self, data, min, max):
        return torch.clamp(data, min=min, max=max)

    def relu(self, data):
        return torch.nn.functional.relu(data)

    def tanh(self, data):
        return torch.tanh(data)

    def max_pool2d(self, data, kernel_size, stride=1):
        return torch.nn.functional.max_pool2d(data, kernel_size, stride=stride)

    def avg_pool2d(self, data, kernel_size, stride=1):
        return torch.nn.functional.avg_pool2d(data, kernel_size, stride=stride)

    def conv2d_layer(self, ic, oc, kernel_size, groups=1):
        return torch.nn.Conv2d(ic, oc, kernel_size, groups=groups)

    def matmul(self, t1, t2):
        return torch.matmul(t1, t2)

    def to_device(self, module, device):
        return module.to(device)
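

# A minimal usage sketch (not part of the original engine): exercises a few of
# the wrappers above on CPU. The tensor shapes and the `__main__` guard are
# illustrative assumptions, not part of any surrounding benchmark harness.
if __name__ == "__main__":
    engine = TorchTensorEngine()
    x = engine.rand((2, 3, 8, 8))  # NCHW-shaped random tensor
    y = engine.relu(engine.add(x, engine.rand_like(x)))
    pooled = engine.avg_pool2d(y, kernel_size=2, stride=2)
    print(engine.numpy(pooled).shape)  # (2, 3, 4, 4)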