"""Three small TorchScript / autograd snippets run back to back.

1. ``recurrent``: applies ``fn`` (``scale * x / shift``) 100 times.
2. ``recurrent_scaleshift``: applies ``scale * y + shift`` 64 times.
3. ``mean().backward()`` on an empty tensor, first on CPU and then on the
   selected device.

The snippets originally hard-coded ``device="cuda"``; the device is now chosen
once so the script also runs on hosts without CUDA. On a CUDA host the
behaviour is the same as before.
"""
import torch

# Use CUDA when present (the original target), otherwise fall back to CPU so
# the tensor allocations below do not raise on CPU-only machines.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# --- Snippet 1: repeated scale / divide through a scripted helper -----------


@torch.jit.script
def fn(x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
    """Return ``scale * x / shift`` (evaluated left to right)."""
    return scale * x / shift


@torch.jit.script
def recurrent(x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
    """Apply ``fn`` to ``x`` 100 times, feeding each result into the next call."""
    y = x
    for i in range(100):
        y = fn(y, scale, shift)
    return y


x = torch.randn(2, 2, device=DEVICE)
scale = torch.randn(2, 2, device=DEVICE, requires_grad=True)
shift = torch.randn(2, 2, device=DEVICE, requires_grad=True)
inputs = [x, scale, shift]


out = recurrent(x, scale, shift)
# The returned graph is discarded here; presumably it is inspected
# interactively -- TODO confirm.
recurrent.graph_for(x, scale, shift)


# --- Snippet 2: repeated scale-and-shift written inline ----------------------


@torch.jit.script
def recurrent_scaleshift(
    x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor
) -> torch.Tensor:
    """Apply ``y = scale * y + shift`` 64 times, starting from ``y = x``."""
    y = x
    for i in range(64):
        y = scale * y + shift
    return y


x = torch.randn(2, 2, device=DEVICE)
scale = torch.randn(2, 2, device=DEVICE, requires_grad=True)
shift = torch.randn(2, 2, device=DEVICE, requires_grad=True)
inputs = [x, scale, shift]
out = recurrent_scaleshift(x, scale, shift)
recurrent_scaleshift.graph_for(x, scale, shift)


# --- Snippet 3: backward through mean() of an empty tensor -------------------

x = torch.tensor([])
x.requires_grad = True
x.mean().backward()  # no error triggered
# Move to the selected device (equivalent to ``x.cuda()`` when CUDA is
# available) and repeat the same backward pass there.
# NOTE(review): the comment above suggests the CUDA call is the one that
# misbehaved in the original report -- confirm.
x = x.to(DEVICE)
x.mean().backward()