#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]

import os
import io
import functools
import tempfile
import urllib.parse
import unittest

import torch
import torch.backends.xnnpack
import torch.utils.model_dump
import torch.utils.mobile_optimizer
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, skipIfNoXNNPACK
from torch.testing._internal.common_quantized import supported_qengines


class SimpleModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer1 = torch.nn.Linear(16, 64)
        self.relu1 = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(64, 8)
        self.relu2 = torch.nn.ReLU()

    def forward(self, features):
        act = features
        act = self.layer1(act)
        act = self.relu1(act)
        act = self.layer2(act)
        act = self.relu2(act)
        return act


class QuantModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.core = SimpleModel()

    def forward(self, x):
        x = self.quant(x)
        x = self.core(x)
        x = self.dequant(x)
        return x


class ModelWithLists(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.rt = [torch.zeros(1)]
        self.ot = [torch.zeros(1), None]

    def forward(self, arg):
        arg = arg + self.rt[0]
        o = self.ot[0]
        if o is not None:
            arg = arg + o
        return arg


def webdriver_test(testfunc):
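    # Decorator for tests that drive the generated HTML report in a real
    # browser.  These checks are opt-in: they are skipped unless the
    # RUN_WEBDRIVER=1 environment variable is set, and they also need
    # selenium plus working Firefox and Chrome drivers.  A likely invocation
    # (assuming both drivers are on PATH) is:
    #
    #   RUN_WEBDRIVER=1 python test/test_model_dump.py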
    @functools.wraps(testfunc)
    def wrapper(self, *args, **kwds):
        self.needs_resources()

        if os.environ.get("RUN_WEBDRIVER") != "1":
            self.skipTest("Webdriver not requested")
        from selenium import webdriver

        for driver in [
                "Firefox",
                "Chrome",
        ]:
            with self.subTest(driver=driver):
                wd = getattr(webdriver, driver)()
                testfunc(self, wd, *args, **kwds)
                wd.close()

    return wrapper


class TestModelDump(TestCase):
    def needs_resources(self):
        pass

    def test_inline_skeleton(self):
        self.needs_resources()
        skel = torch.utils.model_dump.get_inline_skeleton()
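        # The inline skeleton is expected to be fully self-contained HTML:
        # all JS/CSS should be bundled into the page, so there must be no
        # unpkg.org CDN links and no external resources referenced via src=.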
        assert "unpkg.org" not in skel
        assert "src=" not in skel

    def do_dump_model(self, model, extra_files=None):
        # Just check that we're able to run successfully.
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        info = torch.utils.model_dump.get_model_info(buf)
        assert info is not None

    def open_html_model(self, wd, model, extra_files=None):
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        page = torch.utils.model_dump.get_info_and_burn_skeleton(buf)
        wd.get("data:text/html;charset=utf-8," + urllib.parse.quote(page))

    def open_section_and_get_body(self, wd, name):
        container = wd.find_element_by_xpath(f"//div[@data-hider-title='{name}']")
        caret = container.find_element_by_class_name("caret")
        if container.get_attribute("data-shown") != "true":
            caret.click()
        content = container.find_element_by_tag_name("div")
        return content

    def test_scripted_model(self):
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model)

    def test_traced_model(self):
        model = torch.jit.trace(SimpleModel(), torch.zeros(2, 16))
        self.do_dump_model(model)

    def test_main(self):
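        # Drive torch.utils.model_dump.main() the same way the command-line
        # entry point would.  The equivalent shell usage is most likely along
        # the lines of:
        #
        #   python -m torch.utils.model_dump --style=json model.pt
        #   python -m torch.utils.model_dump --style=html model.pt > model.html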
        self.needs_resources()
        if IS_WINDOWS:
            # I was getting tempfile errors in CI.  Just skip it.
            self.skipTest("Disabled on Windows.")

        with tempfile.NamedTemporaryFile() as tf:
            torch.jit.save(torch.jit.script(SimpleModel()), tf)
            # Actually write contents to disk so we can read it below
            tf.flush()

            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,
                    "--style=json",
                    tf.name,
                ],
                stdout=stdout)
            self.assertRegex(stdout.getvalue(), r'\A{.*SimpleModel')

            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,
                    "--style=html",
                    tf.name,
                ],
                stdout=stdout)
            self.assertRegex(
                stdout.getvalue().replace("\n", " "),
                r'\A<!DOCTYPE.*SimpleModel.*componentDidMount')

    def get_quant_model(self):
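        # Build a quantized model with the eager-mode post-training flow:
        # fuse the Linear+ReLU pairs, attach a default qconfig, prepare
        # (insert observers), calibrate on one random batch, then convert
        # the observed model to quantized ops.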
        fmodel = QuantModel().eval()
        fmodel = torch.ao.quantization.fuse_modules(fmodel, [
            ["core.layer1", "core.relu1"],
            ["core.layer2", "core.relu2"],
        ])
        fmodel.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        prepped = torch.ao.quantization.prepare(fmodel)
        prepped(torch.randn(2, 16))
        qmodel = torch.ao.quantization.convert(prepped)
        return qmodel

    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_quantized_model(self):
        qmodel = self.get_quant_model()
        self.do_dump_model(torch.jit.script(qmodel))

    @skipIfNoXNNPACK
    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_optimized_quantized_model(self):
        qmodel = self.get_quant_model()
        smodel = torch.jit.trace(qmodel, torch.zeros(2, 16))
        omodel = torch.utils.mobile_optimizer.optimize_for_mobile(smodel)
        self.do_dump_model(omodel)

    def test_model_with_lists(self):
        model = torch.jit.script(ModelWithLists())
        self.do_dump_model(model)

    def test_invalid_json(self):
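        # The extra file deliberately contains malformed JSON ("{"), checking
        # that model dumping copes with extra files it cannot parse.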
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model, extra_files={"foo.json": "{"})

    @webdriver_test
    def test_memory_computation(self, wd):
        def check_memory(model, expected):
            self.open_html_model(wd, model)
            memory_table = self.open_section_and_get_body(wd, "Tensor Memory")
            device = memory_table.find_element_by_xpath("//table/tbody/tr[1]/td[1]").text
            self.assertEqual("cpu", device)
            memory_usage_str = memory_table.find_element_by_xpath("//table/tbody/tr[1]/td[2]").text
            self.assertEqual(expected, int(memory_usage_str))

        simple_model_memory = (
            # First layer, including bias.
            64 * (16 + 1) +
            # Second layer, including bias.
            8 * (64 + 1)
            # 32-bit float
        ) * 4
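        # That works out to (1088 + 520) parameters * 4 bytes = 6432 bytes.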

        check_memory(torch.jit.script(SimpleModel()), simple_model_memory)

        # The same SimpleModel instance appears twice in this model.
        # The tensors will be shared, so ensure no double-counting.
        a_simple_model = SimpleModel()
        check_memory(
            torch.jit.script(
                torch.nn.Sequential(a_simple_model, a_simple_model)),
            simple_model_memory)

        # The freezing process will move the weight and bias
        # from data to constants.  Ensure they are still counted.
        check_memory(
            torch.jit.freeze(torch.jit.script(SimpleModel()).eval()),
            simple_model_memory)

        # Make sure we can handle a model with both constants and data tensors.
        class ComposedModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w1 = torch.zeros(1, 2)
                self.w2 = torch.ones(2, 2)

            def forward(self, arg):
                return arg * self.w2 + self.w1

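        # w1 (1x2, preserved as a data tensor) and w2 (2x2, folded into a
        # constant by freezing) are both float32, so the expected total is
        # 4 bytes * (2 + 4) elements = 24 bytes.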
        check_memory(
            torch.jit.freeze(
                torch.jit.script(ComposedModule()).eval(),
                preserved_attrs=["w1"]),
            4 * (2 + 4))


if __name__ == '__main__':
    run_tests()