xref: /aosp_15_r20/external/pytorch/test/test_module_tracker.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: unknown"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom copy import copy
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Workerfrom torch import nn
7*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo
8*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.checkpoint import checkpoint
9*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.module_tracker import ModuleTracker
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workerclass TestModuleTracker(TestCase):
13*da0073e9SAndroid Build Coastguard Worker    # "https://github.com/pytorch/pytorch/issues/127112
14*da0073e9SAndroid Build Coastguard Worker    @xfailIfTorchDynamo
15*da0073e9SAndroid Build Coastguard Worker    def test_module_hierarchy(self):
16*da0073e9SAndroid Build Coastguard Worker        seen_fw = []
17*da0073e9SAndroid Build Coastguard Worker        seen_bw = []
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker        class Foo(nn.Module):
20*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
21*da0073e9SAndroid Build Coastguard Worker                x = x["a"].relu_()
22*da0073e9SAndroid Build Coastguard Worker                seen_fw.append((copy(tracker.parents), tracker.is_bw))
23*da0073e9SAndroid Build Coastguard Worker                x.register_hook(
24*da0073e9SAndroid Build Coastguard Worker                    lambda grad: seen_bw.append((copy(tracker.parents), tracker.is_bw))
25*da0073e9SAndroid Build Coastguard Worker                )
26*da0073e9SAndroid Build Coastguard Worker                return {"a": torch.mm(x, x)}
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker        class Mod(nn.Module):
29*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
30*da0073e9SAndroid Build Coastguard Worker                super().__init__()
31*da0073e9SAndroid Build Coastguard Worker                self.a = Foo()
32*da0073e9SAndroid Build Coastguard Worker                self.b = nn.ModuleDict({"nest": Foo()})
33*da0073e9SAndroid Build Coastguard Worker                self.c = nn.ModuleList([Foo()])
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
36*da0073e9SAndroid Build Coastguard Worker                x = self.c[0](x)
37*da0073e9SAndroid Build Coastguard Worker                return self.b["nest"](self.a(x))
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker        mod = Mod()
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker        with ModuleTracker() as tracker:
42*da0073e9SAndroid Build Coastguard Worker            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
43*da0073e9SAndroid Build Coastguard Worker                "a"
44*da0073e9SAndroid Build Coastguard Worker            ].sum().backward()
45*da0073e9SAndroid Build Coastguard Worker            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
46*da0073e9SAndroid Build Coastguard Worker                "a"
47*da0073e9SAndroid Build Coastguard Worker            ].sum().backward()
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
50*da0073e9SAndroid Build Coastguard Worker            seen_fw,
51*da0073e9SAndroid Build Coastguard Worker            [
52*da0073e9SAndroid Build Coastguard Worker                ({"Global", "Mod", "Mod.c.0"}, False),
53*da0073e9SAndroid Build Coastguard Worker                ({"Global", "Mod", "Mod.a"}, False),
54*da0073e9SAndroid Build Coastguard Worker                ({"Global", "Mod", "Mod.b.nest"}, False),
55*da0073e9SAndroid Build Coastguard Worker                ({"Global", "Mod", "Mod.c.0"}, False),
56*da0073e9SAndroid Build Coastguard Worker                ({"Global", "Mod", "Mod.a"}, False),
57*da0073e9SAndroid Build Coastguard Worker                ({"Global", "Mod", "Mod.b.nest"}, False),
58*da0073e9SAndroid Build Coastguard Worker            ],
59*da0073e9SAndroid Build Coastguard Worker        )
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
62*da0073e9SAndroid Build Coastguard Worker            seen_bw,
63*da0073e9SAndroid Build Coastguard Worker            [
64*da0073e9SAndroid Build Coastguard Worker                ({"Global", "Mod", "Mod.b.nest"}, True),
65*da0073e9SAndroid Build Coastguard Worker                ({"Global", "Mod", "Mod.a"}, True),
66*da0073e9SAndroid Build Coastguard Worker                ({"Global", "Mod", "Mod.c.0"}, True),
67*da0073e9SAndroid Build Coastguard Worker                ({"Global", "Mod", "Mod.b.nest"}, True),
68*da0073e9SAndroid Build Coastguard Worker                ({"Global", "Mod", "Mod.a"}, True),
69*da0073e9SAndroid Build Coastguard Worker                ({"Global", "Mod", "Mod.c.0"}, True),
70*da0073e9SAndroid Build Coastguard Worker            ],
71*da0073e9SAndroid Build Coastguard Worker        )
72*da0073e9SAndroid Build Coastguard Worker
73*da0073e9SAndroid Build Coastguard Worker    def test_confused_hierarchy(self):
74*da0073e9SAndroid Build Coastguard Worker        class MyMod(nn.Module):
75*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
76*da0073e9SAndroid Build Coastguard Worker                super().__init__()
77*da0073e9SAndroid Build Coastguard Worker                self.inner = nn.Linear(2, 2)
78*da0073e9SAndroid Build Coastguard Worker                self.ran = False
79*da0073e9SAndroid Build Coastguard Worker
80*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp):
81*da0073e9SAndroid Build Coastguard Worker                if not self.ran:
82*da0073e9SAndroid Build Coastguard Worker                    self.ran = True
83*da0073e9SAndroid Build Coastguard Worker                    return self(inp)
84*da0073e9SAndroid Build Coastguard Worker                else:
85*da0073e9SAndroid Build Coastguard Worker                    self.ran = False
86*da0073e9SAndroid Build Coastguard Worker                    return self.inner(inp)
87*da0073e9SAndroid Build Coastguard Worker
88*da0073e9SAndroid Build Coastguard Worker        mod = MyMod()
89*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(1, 2, requires_grad=True)
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker        # Should not fail
92*da0073e9SAndroid Build Coastguard Worker        with ModuleTracker() as tracker:
93*da0073e9SAndroid Build Coastguard Worker            res = mod(inp)
94*da0073e9SAndroid Build Coastguard Worker            res.sum().backward()
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker        # Should not fail
97*da0073e9SAndroid Build Coastguard Worker        with ModuleTracker() as tracker:
98*da0073e9SAndroid Build Coastguard Worker            res = checkpoint(lambda inp: mod(inp), inp)
99*da0073e9SAndroid Build Coastguard Worker            res.sum().backward()
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker    def test_bw_detection(self):
102*da0073e9SAndroid Build Coastguard Worker        mod = nn.Linear(2, 2)
103*da0073e9SAndroid Build Coastguard Worker
104*da0073e9SAndroid Build Coastguard Worker        with ModuleTracker() as tracker:
105*da0073e9SAndroid Build Coastguard Worker            mod(torch.rand(2, requires_grad=True)).sum().backward()
106*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(tracker.is_bw)
107*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tracker.parents, {"Global"})
108*da0073e9SAndroid Build Coastguard Worker
109*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Script entry point: hand this file's tests to PyTorch's shared runner.
    run_tests()
112