# xref: /aosp_15_r20/external/pytorch/test/dynamo/test_export_mutations.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: dynamo"]
2import unittest
3
4import torch
5import torch._dynamo.test_case
6import torch._dynamo.testing
7from torch.testing._internal.common_utils import IS_FBCODE
8
9
10class MutationExportTests(torch._dynamo.test_case.TestCase):
11    def check_failure_on_export(self, mod, *args):
12        with self.assertRaises(AssertionError):
13            torch._dynamo.export(mod)(*args)
14
15    def check_same_with_export(self, mod, arg):
16        real_result = mod(arg)
17        graph, _ = torch._dynamo.export(mod)(arg)
18        result = graph(arg)
19        self.assertEqual(result, real_result)
20
21    def test_module_attribute_mutation_violation_positive_1(self):
22        # Mutating attribute with a Tensor type
23        class Foo(torch.nn.Module):
24            def __init__(self) -> None:
25                super().__init__()
26                self.a = torch.randn(3, 2)
27
28            def forward(self, x):
29                self.a = self.a.to(torch.float64)
30                return x.sum() + self.a.sum()
31
32        self.check_failure_on_export(Foo(), torch.randn(3, 2))
33
34    def test_module_attribute_mutation_violation_negative_1(self):
35        # Mutating attribute with a Tensor type inside __init__ but
36        # not in forward()
37        class Foo(torch.nn.Module):
38            def __init__(self) -> None:
39                super().__init__()
40                self.a = torch.randn(3, 2)
41
42            def forward(self, x):
43                return x.sum() + self.a.to(torch.float64).sum()
44
45        self.check_same_with_export(Foo(), torch.randn(3, 2))
46
47    def test_module_attribute_mutation_violation_negative_2(self):
48        # Mutating attribute with a Tensor type inside __init__ twice
49        class Foo(torch.nn.Module):
50            def __init__(self) -> None:
51                super().__init__()
52                self.a = torch.randn(3, 2)
53                self.a = self.a.to(torch.float64)
54
55            def forward(self, x):
56                return x.sum() + self.a.sum()
57
58        self.check_same_with_export(Foo(), torch.randn(3, 2))
59
60    def test_module_attribute_mutation_violation_negative_3(self):
61        # Mutating local variable inside forward()
62        class Foo(torch.nn.Module):
63            def __init__(self) -> None:
64                super().__init__()
65                self.a = torch.randn(3, 2)
66
67            def forward(self, x):
68                b = 1
69                b = b * 5
70                return x.sum() + self.a.sum() + b
71
72        self.check_same_with_export(Foo(), torch.randn(3, 2))
73
74    @unittest.skipIf(IS_FBCODE, "Broken in fbcode")
75    def test_module_attribute_mutation_violation_negative_4(self):
76        # Mutating attribute with a Tensor type
77        # But not exporting but using eager mode as well as dynamo optimize mode
78        class Foo(torch.nn.Module):
79            def __init__(self) -> None:
80                super().__init__()
81                self.a = torch.randn(3, 2)
82
83            def forward(self, x):
84                self.a = self.a.to(torch.float64)
85                return x.sum() + self.a.sum()
86
87        mod = Foo()
88        arg = torch.randn(3, 2)
89        real_result = mod(arg)
90        opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod)
91        self.assertEqual(opt_mod(arg), real_result)
92
93
# Script entry point: run this module's tests only when executed directly,
# never on import.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
98