#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]
# Source: /aosp_15_r20/external/pytorch/test/test_model_dump.py
# (revision da0073e96a02ea20f0ac840b70461e3646d07c45)

import functools
import io
import os
import tempfile
import unittest
import urllib
import urllib.parse

import torch
import torch.backends.xnnpack
import torch.utils.model_dump
import torch.utils.mobile_optimizer
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, skipIfNoXNNPACK
from torch.testing._internal.common_quantized import supported_qengines


class SimpleModel(torch.nn.Module):
    """Two-layer perceptron: Linear(16, 64) -> ReLU -> Linear(64, 8) -> ReLU.

    The submodule names (``layer1``, ``relu1``, ``layer2``, ``relu2``) are
    referenced by name elsewhere in this file when fusing modules, so they
    must stay as they are.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = torch.nn.Linear(16, 64)
        self.relu1 = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(64, 8)
        self.relu2 = torch.nn.ReLU()

    def forward(self, features):
        """Apply both linear+ReLU stages to ``features`` and return the result."""
        hidden = self.relu1(self.layer1(features))
        return self.relu2(self.layer2(hidden))


class QuantModel(torch.nn.Module):
    """``SimpleModel`` wrapped between a ``QuantStub`` and a ``DeQuantStub``.

    The wrapped model is stored as ``core``; that name is used elsewhere in
    this file in dotted module paths such as ``"core.layer1"``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.core = SimpleModel()

    def forward(self, x):
        """Pass ``x`` through ``quant``, then ``core``, then ``dequant``."""
        quantized = self.quant(x)
        processed = self.core(quantized)
        return self.dequant(processed)


class ModelWithLists(torch.nn.Module):
    """Module whose attributes are Python lists of tensors.

    ``rt`` holds only tensors; ``ot`` holds a tensor followed by ``None``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.rt = [torch.zeros(1)]
        self.ot = [torch.zeros(1), None]

    def forward(self, arg):
        """Return ``arg`` plus ``rt[0]``, plus ``ot[0]`` when that is not None."""
        result = arg + self.rt[0]
        # Bind the list element to a local before testing it against None.
        optional_term = self.ot[0]
        if optional_term is not None:
            result = result + optional_term
        return result


def webdriver_test(testfunc):
    """Decorate a test method so that it runs once per browser driver.

    The returned wrapper first calls ``self.needs_resources()``, then skips
    the test unless the ``RUN_WEBDRIVER`` environment variable equals
    ``"1"``.  Otherwise it instantiates ``selenium.webdriver.Firefox`` and
    ``selenium.webdriver.Chrome`` in turn and, inside a ``subTest`` for each,
    calls ``testfunc(self, wd, *args, **kwds)``.

    ``selenium`` is imported inside the wrapper, after the skip check, so it
    is only required when webdriver tests were actually requested.
    """
    @functools.wraps(testfunc)
    def wrapper(self, *args, **kwds):
        self.needs_resources()

        if os.environ.get("RUN_WEBDRIVER") != "1":
            self.skipTest("Webdriver not requested")
        from selenium import webdriver

        for driver in ("Firefox", "Chrome"):
            with self.subTest(driver=driver):
                wd = getattr(webdriver, driver)()
                try:
                    testfunc(self, wd, *args, **kwds)
                finally:
                    # Always end the browser session, even when the test
                    # body raises; quit() shuts down the whole driver
                    # rather than only closing the current window.
                    wd.quit()

    return wrapper


class TestModelDump(TestCase):
    """Tests for ``torch.utils.model_dump``.

    Most tests only check that a saved TorchScript model can be dumped
    without raising.  ``test_memory_computation`` additionally renders the
    HTML page in a browser (when webdriver tests are enabled) and checks the
    reported tensor memory.
    """

    def needs_resources(self):
        """Hook called at the start of tests that use packaged resources.

        It does nothing in this class.
        """
        pass

    def test_inline_skeleton(self):
        """The inline skeleton must not reference external scripts."""
        self.needs_resources()
        skel = torch.utils.model_dump.get_inline_skeleton()
        # unittest assertions are used instead of bare ``assert`` so the
        # checks still run under ``python -O`` and report useful messages.
        self.assertNotIn("unpkg.org", skel)
        self.assertNotIn("src=", skel)

    def do_dump_model(self, model, extra_files=None):
        """Save ``model`` to memory and check that ``get_model_info`` succeeds.

        Args:
            model: object accepted by ``torch.jit.save``.
            extra_files: forwarded to ``torch.jit.save`` as ``_extra_files``.
        """
        # Just check that we're able to run successfully.
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        info = torch.utils.model_dump.get_model_info(buf)
        self.assertIsNotNone(info)

    def open_html_model(self, wd, model, extra_files=None):
        """Render ``model`` as an HTML page and load it in webdriver ``wd``.

        The page is passed to the browser as a percent-encoded ``data:`` URL.
        """
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        page = torch.utils.model_dump.get_info_and_burn_skeleton(buf)
        wd.get("data:text/html;charset=utf-8," + urllib.parse.quote(page))

    def open_section_and_get_body(self, wd, name):
        """Expand the section titled ``name`` and return its first ``div``.

        The section is located by its ``data-hider-title`` attribute; its
        caret is clicked only when ``data-shown`` is not ``"true"``.
        """
        # ``find_element(By..., value)`` is used because the
        # ``find_element_by_*`` helpers were removed in Selenium 4.
        from selenium.webdriver.common.by import By

        container = wd.find_element(By.XPATH, f"//div[@data-hider-title='{name}']")
        caret = container.find_element(By.CLASS_NAME, "caret")
        if container.get_attribute("data-shown") != "true":
            caret.click()
        content = container.find_element(By.TAG_NAME, "div")
        return content

    def test_scripted_model(self):
        """A scripted ``SimpleModel`` can be dumped."""
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model)

    def test_traced_model(self):
        """A traced ``SimpleModel`` can be dumped."""
        model = torch.jit.trace(SimpleModel(), torch.zeros(2, 16))
        self.do_dump_model(model)

    def test_main(self):
        """``model_dump.main`` emits JSON and HTML for a model saved on disk."""
        self.needs_resources()
        if IS_WINDOWS:
            # I was getting tempfile errors in CI.  Just skip it.
            self.skipTest("Disabled on Windows.")

        with tempfile.NamedTemporaryFile() as tf:
            torch.jit.save(torch.jit.script(SimpleModel()), tf)
            # Actually write contents to disk so we can read it below
            tf.flush()

            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,
                    "--style=json",
                    tf.name,
                ],
                stdout=stdout)
            self.assertRegex(stdout.getvalue(), r'\A{.*SimpleModel')

            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,
                    "--style=html",
                    tf.name,
                ],
                stdout=stdout)
            # Newlines are flattened so that ``.*`` can span the whole page.
            self.assertRegex(
                stdout.getvalue().replace("\n", " "),
                r'\A<!DOCTYPE.*SimpleModel.*componentDidMount')

    def get_quant_model(self):
        """Return a converted (quantized) ``QuantModel``.

        Steps: fuse each linear/ReLU pair in ``core``, attach the default
        ``"qnnpack"`` qconfig, ``prepare``, run one random ``(2, 16)`` batch
        through the prepared model, then ``convert``.
        """
        fmodel = QuantModel().eval()
        fmodel = torch.ao.quantization.fuse_modules(fmodel, [
            ["core.layer1", "core.relu1"],
            ["core.layer2", "core.relu2"],
        ])
        fmodel.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        prepped = torch.ao.quantization.prepare(fmodel)
        prepped(torch.randn(2, 16))
        qmodel = torch.ao.quantization.convert(prepped)
        return qmodel

    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_quantized_model(self):
        """A scripted quantized model can be dumped."""
        qmodel = self.get_quant_model()
        self.do_dump_model(torch.jit.script(qmodel))

    @skipIfNoXNNPACK
    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_optimized_quantized_model(self):
        """A traced quantized model passed through ``optimize_for_mobile`` can be dumped."""
        qmodel = self.get_quant_model()
        smodel = torch.jit.trace(qmodel, torch.zeros(2, 16))
        omodel = torch.utils.mobile_optimizer.optimize_for_mobile(smodel)
        self.do_dump_model(omodel)

    def test_model_with_lists(self):
        """A scripted model holding list attributes can be dumped."""
        model = torch.jit.script(ModelWithLists())
        self.do_dump_model(model)

    def test_invalid_json(self):
        """Dumping still succeeds when an extra ``.json`` file is not valid JSON."""
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model, extra_files={"foo.json": "{"})

    @webdriver_test
    def test_memory_computation(self, wd):
        """The "Tensor Memory" section reports the expected byte counts."""
        from selenium.webdriver.common.by import By

        def check_memory(model, expected):
            self.open_html_model(wd, model)
            memory_table = self.open_section_and_get_body(wd, "Tensor Memory")
            # NOTE(review): an XPath that starts with "//" is evaluated from
            # the document root even when called on an element, so these
            # lookups match the first table on the page rather than one
            # scoped to ``memory_table``.  Consider ".//table/..." after
            # confirming against the rendered page.
            device = memory_table.find_element(By.XPATH, "//table/tbody/tr[1]/td[1]").text
            self.assertEqual("cpu", device)
            memory_usage_str = memory_table.find_element(By.XPATH, "//table/tbody/tr[1]/td[2]").text
            self.assertEqual(expected, int(memory_usage_str))

        simple_model_memory = (
            # First layer, including bias.
            64 * (16 + 1) +
            # Second layer, including bias.
            8 * (64 + 1)
            # 32-bit float
        ) * 4

        check_memory(torch.jit.script(SimpleModel()), simple_model_memory)

        # The same SimpleModel instance appears twice in this model.
        # The tensors will be shared, so ensure no double-counting.
        a_simple_model = SimpleModel()
        check_memory(
            torch.jit.script(
                torch.nn.Sequential(a_simple_model, a_simple_model)),
            simple_model_memory)

        # The freezing process will move the weight and bias
        # from data to constants.  Ensure they are still counted.
        check_memory(
            torch.jit.freeze(torch.jit.script(SimpleModel()).eval()),
            simple_model_memory)

        # Make sure we can handle a model with both constants and data tensors.
        class ComposedModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w1 = torch.zeros(1, 2)
                self.w2 = torch.ones(2, 2)

            def forward(self, arg):
                return arg * self.w2 + self.w1

        # w1 has 2 elements and w2 has 4, each a 4-byte float.
        check_memory(
            torch.jit.freeze(
                torch.jit.script(ComposedModule()).eval(),
                preserved_attrs=["w1"]),
            4 * (2 + 4))


# Run the suite when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()