1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Workerimport torch 3*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo 4*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Worker@torch._dynamo.config.patch("capture_scalar_outputs", True) 8*da0073e9SAndroid Build Coastguard Workerclass ViewTests(torch._dynamo.test_case.TestCase): 9*da0073e9SAndroid Build Coastguard Worker def test_view_to_2d(self): 10*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True, backend="eager") 11*da0073e9SAndroid Build Coastguard Worker def f(t, _u0): 12*da0073e9SAndroid Build Coastguard Worker u0 = t[0].item() 13*da0073e9SAndroid Build Coastguard Worker u1 = t[1].item() 14*da0073e9SAndroid Build Coastguard Worker torch._check_is_size(u0) 15*da0073e9SAndroid Build Coastguard Worker torch._check_is_size(u1) 16*da0073e9SAndroid Build Coastguard Worker n = u0 * u1 17*da0073e9SAndroid Build Coastguard Worker a = torch.randn(n) 18*da0073e9SAndroid Build Coastguard Worker return a.view(-1, _u0) 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([2, 4], dtype=torch.int32) 21*da0073e9SAndroid Build Coastguard Worker f(t, 2) 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker def test_view_to_1d(self): 24*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True, backend="eager") 25*da0073e9SAndroid Build Coastguard Worker def f(t, _n): 26*da0073e9SAndroid Build Coastguard Worker u0 = t[0].item() 27*da0073e9SAndroid Build Coastguard Worker u1 = t[1].item() 28*da0073e9SAndroid Build Coastguard Worker torch._check_is_size(u0) 29*da0073e9SAndroid Build Coastguard Worker torch._check_is_size(u1) 30*da0073e9SAndroid Build Coastguard Worker a = torch.randn(u0, u1) 31*da0073e9SAndroid Build Coastguard Worker return a.view(_n) 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([2, 4], dtype=torch.int32) 34*da0073e9SAndroid Build Coastguard Worker f(t, 8) 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 38*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.test_case import run_tests 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker run_tests() 41