1# Owner(s): ["module: dynamo"] 2 3import torch 4import torch._dynamo 5import torch._dynamo.test_case 6import torch.nn as nn 7from torch._dynamo.source import ( 8 AttrSource, 9 GlobalSource, 10 is_from_local_source, 11 LocalSource, 12) 13 14 15class CausalLMOutputWithPast: 16 value = 5 17 18 19class SourceTests(torch._dynamo.test_case.TestCase): 20 def test_is_local(self): 21 x_src = LocalSource("x") 22 y_src = GlobalSource("y") 23 24 attr_x_a = AttrSource(x_src, "a") 25 attr_y_b = AttrSource(y_src, "b") 26 27 self.assertTrue(is_from_local_source(attr_x_a)) 28 self.assertEqual(is_from_local_source(attr_y_b), False) 29 30 def test_property_closure(self): 31 def external_property(): 32 closed_value = 7 33 34 def internal_function(self): 35 return closed_value 36 37 return internal_function 38 39 class Elements: 40 myprop = property(external_property()) 41 42 def func(elements): 43 if not elements.myprop: 44 return torch.tensor([1, 2, 3]) 45 else: 46 return torch.tensor([4, 5, 6]) 47 48 e = Elements() 49 a = func(e) 50 b = torch.compile(func, backend="eager", fullgraph=True)(e) 51 self.assertEqual(a, b) 52 53 def test_supported_nodes(self): 54 class Model(nn.Module): 55 def __init__(self) -> None: 56 super().__init__() 57 self.x = torch.randn(10, 10) 58 59 def forward(self): 60 if ( 61 torch.utils._pytree.SUPPORTED_NODES[CausalLMOutputWithPast].type 62 == int 63 ): 64 x = torch.sin(self.x) 65 else: 66 x = torch.cos(self.x) 67 return x 68 69 torch.utils._pytree.register_pytree_node( 70 CausalLMOutputWithPast, 71 lambda x: ((), None), 72 lambda x, _: CausalLMOutputWithPast(), 73 ) 74 75 torch.export.export(Model(), ()) 76 77 78if __name__ == "__main__": 79 torch._dynamo.test_case.run_tests() 80