# xref: /aosp_15_r20/external/pytorch/benchmarks/fastrnns/scratch.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2
3
@torch.jit.script
def fn(x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
    """Return ``(scale * x) / shift`` as a TorchScript function.

    Despite its name, ``shift`` is used as a divisor here, not an addend.
    """
    scaled = scale * x
    return scaled / shift
7
8
@torch.jit.script
def recurrent(x, scale, shift):
    """Feed ``x`` through ``fn`` 100 times, chaining each output into the next call.

    Returns the tensor produced by the final call. ``scale`` and ``shift`` are
    passed unchanged to every iteration.
    """
    state = x
    for _ in range(100):  # fixed trip count; the index itself is not needed
        state = fn(state, scale, shift)
    return state
15
16
# Driver for `recurrent`: build small random CUDA inputs, run the scripted
# function once, then show the graph the JIT reports for these inputs.
x = torch.randn(2, 2, device="cuda")
# scale/shift require grad; presumably so the reported graph covers the
# differentiable path -- confirm against the executor in use.
scale = torch.randn(2, 2, device="cuda", requires_grad=True)
shift = torch.randn(2, 2, device="cuda", requires_grad=True)
inputs = [x, scale, shift]  # not used below in this script

out = recurrent(x, scale, shift)
# graph_for returns the graph rather than displaying it; as a bare expression
# statement in a script its result was silently discarded, so print it.
print(recurrent.graph_for(x, scale, shift))
25
26
27import torch
28
29
@torch.jit.script
def recurrent_scaleshift(x, scale, shift):
    """Iterate the affine update ``y <- scale * y + shift`` 64 times, starting at ``x``.

    Returns the value after the last update.
    """
    acc = x
    for _ in range(64):  # fixed trip count; the index itself is not needed
        acc = scale * acc + shift
    return acc
36
37
# Driver for `recurrent_scaleshift`: build small random CUDA inputs, run the
# scripted function once, then show the graph the JIT reports for these inputs.
x = torch.randn(2, 2, device="cuda")
# scale/shift require grad; presumably so the reported graph covers the
# differentiable path -- confirm against the executor in use.
scale = torch.randn(2, 2, device="cuda", requires_grad=True)
shift = torch.randn(2, 2, device="cuda", requires_grad=True)
inputs = [x, scale, shift]  # not used below in this script
out = recurrent_scaleshift(x, scale, shift)
# graph_for returns the graph rather than displaying it; as a bare expression
# statement in a script its result was silently discarded, so print it.
print(recurrent_scaleshift.graph_for(x, scale, shift))
44
45
46import torch
47
48
# Minimal repro: backward through mean() of an empty tensor, first on CPU,
# then again after moving the tensor to CUDA.
x = torch.tensor([], requires_grad=True)  # empty tensor that tracks gradients
x.mean().backward()  # no error triggered
x = x.cuda()
# NOTE(review): presumably this is the call expected to misbehave (contrast
# with the CPU line above); confirm against the issue this snippet targets.
x.mean().backward()
54