xref: /aosp_15_r20/external/pytorch/test/dynamo/test_view.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Workerimport torch
3*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo
4*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker@torch._dynamo.config.patch("capture_scalar_outputs", True)
8*da0073e9SAndroid Build Coastguard Workerclass ViewTests(torch._dynamo.test_case.TestCase):
9*da0073e9SAndroid Build Coastguard Worker    def test_view_to_2d(self):
10*da0073e9SAndroid Build Coastguard Worker        @torch.compile(fullgraph=True, backend="eager")
11*da0073e9SAndroid Build Coastguard Worker        def f(t, _u0):
12*da0073e9SAndroid Build Coastguard Worker            u0 = t[0].item()
13*da0073e9SAndroid Build Coastguard Worker            u1 = t[1].item()
14*da0073e9SAndroid Build Coastguard Worker            torch._check_is_size(u0)
15*da0073e9SAndroid Build Coastguard Worker            torch._check_is_size(u1)
16*da0073e9SAndroid Build Coastguard Worker            n = u0 * u1
17*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(n)
18*da0073e9SAndroid Build Coastguard Worker            return a.view(-1, _u0)
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([2, 4], dtype=torch.int32)
21*da0073e9SAndroid Build Coastguard Worker        f(t, 2)
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker    def test_view_to_1d(self):
24*da0073e9SAndroid Build Coastguard Worker        @torch.compile(fullgraph=True, backend="eager")
25*da0073e9SAndroid Build Coastguard Worker        def f(t, _n):
26*da0073e9SAndroid Build Coastguard Worker            u0 = t[0].item()
27*da0073e9SAndroid Build Coastguard Worker            u1 = t[1].item()
28*da0073e9SAndroid Build Coastguard Worker            torch._check_is_size(u0)
29*da0073e9SAndroid Build Coastguard Worker            torch._check_is_size(u1)
30*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(u0, u1)
31*da0073e9SAndroid Build Coastguard Worker            return a.view(_n)
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([2, 4], dtype=torch.int32)
34*da0073e9SAndroid Build Coastguard Worker        f(t, 8)
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Run the suite through the dynamo test harness, which is already
    # imported at module level, when this file is executed as a script.
    torch._dynamo.test_case.run_tests()
41