1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport torch 4*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo 5*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case 6*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn 7*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.source import ( 8*da0073e9SAndroid Build Coastguard Worker AttrSource, 9*da0073e9SAndroid Build Coastguard Worker GlobalSource, 10*da0073e9SAndroid Build Coastguard Worker is_from_local_source, 11*da0073e9SAndroid Build Coastguard Worker LocalSource, 12*da0073e9SAndroid Build Coastguard Worker) 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Workerclass CausalLMOutputWithPast: 16*da0073e9SAndroid Build Coastguard Worker value = 5 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Workerclass SourceTests(torch._dynamo.test_case.TestCase): 20*da0073e9SAndroid Build Coastguard Worker def test_is_local(self): 21*da0073e9SAndroid Build Coastguard Worker x_src = LocalSource("x") 22*da0073e9SAndroid Build Coastguard Worker y_src = GlobalSource("y") 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker attr_x_a = AttrSource(x_src, "a") 25*da0073e9SAndroid Build Coastguard Worker attr_y_b = AttrSource(y_src, "b") 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker self.assertTrue(is_from_local_source(attr_x_a)) 28*da0073e9SAndroid Build Coastguard Worker self.assertEqual(is_from_local_source(attr_y_b), False) 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Worker def test_property_closure(self): 31*da0073e9SAndroid Build Coastguard Worker def external_property(): 32*da0073e9SAndroid Build Coastguard Worker closed_value = 7 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard Worker def internal_function(self): 35*da0073e9SAndroid Build Coastguard Worker return closed_value 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Worker return internal_function 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker class Elements: 40*da0073e9SAndroid Build Coastguard Worker myprop = property(external_property()) 41*da0073e9SAndroid Build Coastguard Worker 42*da0073e9SAndroid Build Coastguard Worker def func(elements): 43*da0073e9SAndroid Build Coastguard Worker if not elements.myprop: 44*da0073e9SAndroid Build Coastguard Worker return torch.tensor([1, 2, 3]) 45*da0073e9SAndroid Build Coastguard Worker else: 46*da0073e9SAndroid Build Coastguard Worker return torch.tensor([4, 5, 6]) 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker e = Elements() 49*da0073e9SAndroid Build Coastguard Worker a = func(e) 50*da0073e9SAndroid Build Coastguard Worker b = torch.compile(func, backend="eager", fullgraph=True)(e) 51*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker def test_supported_nodes(self): 54*da0073e9SAndroid Build Coastguard Worker class Model(nn.Module): 55*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 56*da0073e9SAndroid Build Coastguard Worker super().__init__() 57*da0073e9SAndroid Build Coastguard Worker self.x = torch.randn(10, 10) 58*da0073e9SAndroid Build Coastguard Worker 59*da0073e9SAndroid Build Coastguard Worker def forward(self): 60*da0073e9SAndroid Build Coastguard Worker if ( 61*da0073e9SAndroid Build Coastguard Worker torch.utils._pytree.SUPPORTED_NODES[CausalLMOutputWithPast].type 62*da0073e9SAndroid Build Coastguard Worker == int 63*da0073e9SAndroid Build Coastguard Worker ): 64*da0073e9SAndroid Build Coastguard Worker x = torch.sin(self.x) 65*da0073e9SAndroid Build Coastguard Worker else: 66*da0073e9SAndroid Build Coastguard Worker x = torch.cos(self.x) 67*da0073e9SAndroid Build Coastguard Worker return x 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker torch.utils._pytree.register_pytree_node( 70*da0073e9SAndroid Build Coastguard Worker CausalLMOutputWithPast, 71*da0073e9SAndroid Build Coastguard Worker lambda x: ((), None), 72*da0073e9SAndroid Build Coastguard Worker lambda x, _: CausalLMOutputWithPast(), 73*da0073e9SAndroid Build Coastguard Worker ) 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker torch.export.export(Model(), ()) 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 79*da0073e9SAndroid Build Coastguard Worker torch._dynamo.test_case.run_tests() 80