1# Owner(s): ["module: dynamo"] 2import torch 3import torch._dynamo 4import torch._dynamo.test_case 5 6 7@torch._dynamo.config.patch("capture_scalar_outputs", True) 8class ViewTests(torch._dynamo.test_case.TestCase): 9 def test_view_to_2d(self): 10 @torch.compile(fullgraph=True, backend="eager") 11 def f(t, _u0): 12 u0 = t[0].item() 13 u1 = t[1].item() 14 torch._check_is_size(u0) 15 torch._check_is_size(u1) 16 n = u0 * u1 17 a = torch.randn(n) 18 return a.view(-1, _u0) 19 20 t = torch.tensor([2, 4], dtype=torch.int32) 21 f(t, 2) 22 23 def test_view_to_1d(self): 24 @torch.compile(fullgraph=True, backend="eager") 25 def f(t, _n): 26 u0 = t[0].item() 27 u1 = t[1].item() 28 torch._check_is_size(u0) 29 torch._check_is_size(u1) 30 a = torch.randn(u0, u1) 31 return a.view(_n) 32 33 t = torch.tensor([2, 4], dtype=torch.int32) 34 f(t, 8) 35 36 37if __name__ == "__main__": 38 from torch._dynamo.test_case import run_tests 39 40 run_tests() 41