# xref: /aosp_15_r20/external/pytorch/test/dynamo/test_view.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: dynamo"]
2import torch
3import torch._dynamo
4import torch._dynamo.test_case
5
6
7@torch._dynamo.config.patch("capture_scalar_outputs", True)
8class ViewTests(torch._dynamo.test_case.TestCase):
9    def test_view_to_2d(self):
10        @torch.compile(fullgraph=True, backend="eager")
11        def f(t, _u0):
12            u0 = t[0].item()
13            u1 = t[1].item()
14            torch._check_is_size(u0)
15            torch._check_is_size(u1)
16            n = u0 * u1
17            a = torch.randn(n)
18            return a.view(-1, _u0)
19
20        t = torch.tensor([2, 4], dtype=torch.int32)
21        f(t, 2)
22
23    def test_view_to_1d(self):
24        @torch.compile(fullgraph=True, backend="eager")
25        def f(t, _n):
26            u0 = t[0].item()
27            u1 = t[1].item()
28            torch._check_is_size(u0)
29            torch._check_is_size(u1)
30            a = torch.randn(u0, u1)
31            return a.view(_n)
32
33        t = torch.tensor([2, 4], dtype=torch.int32)
34        f(t, 8)
35
36
37if __name__ == "__main__":
38    from torch._dynamo.test_case import run_tests
39
40    run_tests()
41