1# Owner(s): ["module: dynamo"] 2import unittest 3 4import torch 5import torch._dynamo.test_case 6import torch._dynamo.testing 7from torch.testing._internal.common_utils import IS_FBCODE 8 9 10class MutationExportTests(torch._dynamo.test_case.TestCase): 11 def check_failure_on_export(self, mod, *args): 12 with self.assertRaises(AssertionError): 13 torch._dynamo.export(mod)(*args) 14 15 def check_same_with_export(self, mod, arg): 16 real_result = mod(arg) 17 graph, _ = torch._dynamo.export(mod)(arg) 18 result = graph(arg) 19 self.assertEqual(result, real_result) 20 21 def test_module_attribute_mutation_violation_positive_1(self): 22 # Mutating attribute with a Tensor type 23 class Foo(torch.nn.Module): 24 def __init__(self) -> None: 25 super().__init__() 26 self.a = torch.randn(3, 2) 27 28 def forward(self, x): 29 self.a = self.a.to(torch.float64) 30 return x.sum() + self.a.sum() 31 32 self.check_failure_on_export(Foo(), torch.randn(3, 2)) 33 34 def test_module_attribute_mutation_violation_negative_1(self): 35 # Mutating attribute with a Tensor type inside __init__ but 36 # not in forward() 37 class Foo(torch.nn.Module): 38 def __init__(self) -> None: 39 super().__init__() 40 self.a = torch.randn(3, 2) 41 42 def forward(self, x): 43 return x.sum() + self.a.to(torch.float64).sum() 44 45 self.check_same_with_export(Foo(), torch.randn(3, 2)) 46 47 def test_module_attribute_mutation_violation_negative_2(self): 48 # Mutating attribute with a Tensor type inside __init__ twice 49 class Foo(torch.nn.Module): 50 def __init__(self) -> None: 51 super().__init__() 52 self.a = torch.randn(3, 2) 53 self.a = self.a.to(torch.float64) 54 55 def forward(self, x): 56 return x.sum() + self.a.sum() 57 58 self.check_same_with_export(Foo(), torch.randn(3, 2)) 59 60 def test_module_attribute_mutation_violation_negative_3(self): 61 # Mutating local variable inside forward() 62 class Foo(torch.nn.Module): 63 def __init__(self) -> None: 64 super().__init__() 65 self.a = torch.randn(3, 2) 66 67 def forward(self, x): 68 b = 1 69 b = b * 5 70 return x.sum() + self.a.sum() + b 71 72 self.check_same_with_export(Foo(), torch.randn(3, 2)) 73 74 @unittest.skipIf(IS_FBCODE, "Broken in fbcode") 75 def test_module_attribute_mutation_violation_negative_4(self): 76 # Mutating attribute with a Tensor type 77 # But not exporting but using eager mode as well as dynamo optimize mode 78 class Foo(torch.nn.Module): 79 def __init__(self) -> None: 80 super().__init__() 81 self.a = torch.randn(3, 2) 82 83 def forward(self, x): 84 self.a = self.a.to(torch.float64) 85 return x.sum() + self.a.sum() 86 87 mod = Foo() 88 arg = torch.randn(3, 2) 89 real_result = mod(arg) 90 opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod) 91 self.assertEqual(opt_mod(arg), real_result) 92 93 94if __name__ == "__main__": 95 from torch._dynamo.test_case import run_tests 96 97 run_tests() 98