1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerfrom copy import copy 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerimport torch 6*da0073e9SAndroid Build Coastguard Workerfrom torch import nn 7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo 8*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.checkpoint import checkpoint 9*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.module_tracker import ModuleTracker 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Workerclass TestModuleTracker(TestCase): 13*da0073e9SAndroid Build Coastguard Worker # "https://github.com/pytorch/pytorch/issues/127112 14*da0073e9SAndroid Build Coastguard Worker @xfailIfTorchDynamo 15*da0073e9SAndroid Build Coastguard Worker def test_module_hierarchy(self): 16*da0073e9SAndroid Build Coastguard Worker seen_fw = [] 17*da0073e9SAndroid Build Coastguard Worker seen_bw = [] 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Worker class Foo(nn.Module): 20*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 21*da0073e9SAndroid Build Coastguard Worker x = x["a"].relu_() 22*da0073e9SAndroid Build Coastguard Worker seen_fw.append((copy(tracker.parents), tracker.is_bw)) 23*da0073e9SAndroid Build Coastguard Worker x.register_hook( 24*da0073e9SAndroid Build Coastguard Worker lambda grad: seen_bw.append((copy(tracker.parents), tracker.is_bw)) 25*da0073e9SAndroid Build Coastguard Worker ) 26*da0073e9SAndroid Build Coastguard Worker return {"a": torch.mm(x, x)} 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker class Mod(nn.Module): 29*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 30*da0073e9SAndroid Build Coastguard Worker super().__init__() 31*da0073e9SAndroid Build Coastguard Worker self.a = Foo() 32*da0073e9SAndroid Build Coastguard Worker self.b = nn.ModuleDict({"nest": Foo()}) 33*da0073e9SAndroid Build Coastguard Worker self.c = nn.ModuleList([Foo()]) 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 36*da0073e9SAndroid Build Coastguard Worker x = self.c[0](x) 37*da0073e9SAndroid Build Coastguard Worker return self.b["nest"](self.a(x)) 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker mod = Mod() 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker with ModuleTracker() as tracker: 42*da0073e9SAndroid Build Coastguard Worker mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[ 43*da0073e9SAndroid Build Coastguard Worker "a" 44*da0073e9SAndroid Build Coastguard Worker ].sum().backward() 45*da0073e9SAndroid Build Coastguard Worker mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[ 46*da0073e9SAndroid Build Coastguard Worker "a" 47*da0073e9SAndroid Build Coastguard Worker ].sum().backward() 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 50*da0073e9SAndroid Build Coastguard Worker seen_fw, 51*da0073e9SAndroid Build Coastguard Worker [ 52*da0073e9SAndroid Build Coastguard Worker ({"Global", "Mod", "Mod.c.0"}, False), 53*da0073e9SAndroid Build Coastguard Worker ({"Global", "Mod", "Mod.a"}, False), 54*da0073e9SAndroid Build Coastguard Worker ({"Global", "Mod", "Mod.b.nest"}, False), 55*da0073e9SAndroid Build Coastguard Worker ({"Global", "Mod", "Mod.c.0"}, False), 56*da0073e9SAndroid Build Coastguard Worker ({"Global", "Mod", "Mod.a"}, False), 57*da0073e9SAndroid Build Coastguard Worker ({"Global", "Mod", "Mod.b.nest"}, False), 58*da0073e9SAndroid Build Coastguard Worker ], 59*da0073e9SAndroid Build Coastguard Worker ) 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 62*da0073e9SAndroid Build Coastguard Worker seen_bw, 63*da0073e9SAndroid Build Coastguard Worker [ 64*da0073e9SAndroid Build Coastguard Worker ({"Global", "Mod", "Mod.b.nest"}, True), 65*da0073e9SAndroid Build Coastguard Worker ({"Global", "Mod", "Mod.a"}, True), 66*da0073e9SAndroid Build Coastguard Worker ({"Global", "Mod", "Mod.c.0"}, True), 67*da0073e9SAndroid Build Coastguard Worker ({"Global", "Mod", "Mod.b.nest"}, True), 68*da0073e9SAndroid Build Coastguard Worker ({"Global", "Mod", "Mod.a"}, True), 69*da0073e9SAndroid Build Coastguard Worker ({"Global", "Mod", "Mod.c.0"}, True), 70*da0073e9SAndroid Build Coastguard Worker ], 71*da0073e9SAndroid Build Coastguard Worker ) 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker def test_confused_hierarchy(self): 74*da0073e9SAndroid Build Coastguard Worker class MyMod(nn.Module): 75*da0073e9SAndroid Build Coastguard Worker def __init__(self): 76*da0073e9SAndroid Build Coastguard Worker super().__init__() 77*da0073e9SAndroid Build Coastguard Worker self.inner = nn.Linear(2, 2) 78*da0073e9SAndroid Build Coastguard Worker self.ran = False 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker def forward(self, inp): 81*da0073e9SAndroid Build Coastguard Worker if not self.ran: 82*da0073e9SAndroid Build Coastguard Worker self.ran = True 83*da0073e9SAndroid Build Coastguard Worker return self(inp) 84*da0073e9SAndroid Build Coastguard Worker else: 85*da0073e9SAndroid Build Coastguard Worker self.ran = False 86*da0073e9SAndroid Build Coastguard Worker return self.inner(inp) 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Worker mod = MyMod() 89*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(1, 2, requires_grad=True) 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker # Should not fail 92*da0073e9SAndroid Build Coastguard Worker with ModuleTracker() as tracker: 93*da0073e9SAndroid Build Coastguard Worker res = mod(inp) 94*da0073e9SAndroid Build Coastguard Worker res.sum().backward() 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker # Should not fail 97*da0073e9SAndroid Build Coastguard Worker with ModuleTracker() as tracker: 98*da0073e9SAndroid Build Coastguard Worker res = checkpoint(lambda inp: mod(inp), inp) 99*da0073e9SAndroid Build Coastguard Worker res.sum().backward() 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker def test_bw_detection(self): 102*da0073e9SAndroid Build Coastguard Worker mod = nn.Linear(2, 2) 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker with ModuleTracker() as tracker: 105*da0073e9SAndroid Build Coastguard Worker mod(torch.rand(2, requires_grad=True)).sum().backward() 106*da0073e9SAndroid Build Coastguard Worker self.assertFalse(tracker.is_bw) 107*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tracker.parents, {"Global"}) 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 111*da0073e9SAndroid Build Coastguard Worker run_tests() 112