1#!/usr/bin/env python3 2# Owner(s): ["oncall: mobile"] 3 4import os 5import io 6import functools 7import tempfile 8import urllib 9import unittest 10 11import torch 12import torch.backends.xnnpack 13import torch.utils.model_dump 14import torch.utils.mobile_optimizer 15from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, skipIfNoXNNPACK 16from torch.testing._internal.common_quantized import supported_qengines 17 18 19class SimpleModel(torch.nn.Module): 20 def __init__(self) -> None: 21 super().__init__() 22 self.layer1 = torch.nn.Linear(16, 64) 23 self.relu1 = torch.nn.ReLU() 24 self.layer2 = torch.nn.Linear(64, 8) 25 self.relu2 = torch.nn.ReLU() 26 27 def forward(self, features): 28 act = features 29 act = self.layer1(act) 30 act = self.relu1(act) 31 act = self.layer2(act) 32 act = self.relu2(act) 33 return act 34 35 36class QuantModel(torch.nn.Module): 37 def __init__(self) -> None: 38 super().__init__() 39 self.quant = torch.ao.quantization.QuantStub() 40 self.dequant = torch.ao.quantization.DeQuantStub() 41 self.core = SimpleModel() 42 43 def forward(self, x): 44 x = self.quant(x) 45 x = self.core(x) 46 x = self.dequant(x) 47 return x 48 49 50class ModelWithLists(torch.nn.Module): 51 def __init__(self) -> None: 52 super().__init__() 53 self.rt = [torch.zeros(1)] 54 self.ot = [torch.zeros(1), None] 55 56 def forward(self, arg): 57 arg = arg + self.rt[0] 58 o = self.ot[0] 59 if o is not None: 60 arg = arg + o 61 return arg 62 63 64def webdriver_test(testfunc): 65 @functools.wraps(testfunc) 66 def wrapper(self, *args, **kwds): 67 self.needs_resources() 68 69 if os.environ.get("RUN_WEBDRIVER") != "1": 70 self.skipTest("Webdriver not requested") 71 from selenium import webdriver 72 73 for driver in [ 74 "Firefox", 75 "Chrome", 76 ]: 77 with self.subTest(driver=driver): 78 wd = getattr(webdriver, driver)() 79 testfunc(self, wd, *args, **kwds) 80 wd.close() 81 82 return wrapper 83 84 85class TestModelDump(TestCase): 86 def needs_resources(self): 87 pass 88 89 def test_inline_skeleton(self): 90 self.needs_resources() 91 skel = torch.utils.model_dump.get_inline_skeleton() 92 assert "unpkg.org" not in skel 93 assert "src=" not in skel 94 95 def do_dump_model(self, model, extra_files=None): 96 # Just check that we're able to run successfully. 97 buf = io.BytesIO() 98 torch.jit.save(model, buf, _extra_files=extra_files) 99 info = torch.utils.model_dump.get_model_info(buf) 100 assert info is not None 101 102 def open_html_model(self, wd, model, extra_files=None): 103 buf = io.BytesIO() 104 torch.jit.save(model, buf, _extra_files=extra_files) 105 page = torch.utils.model_dump.get_info_and_burn_skeleton(buf) 106 wd.get("data:text/html;charset=utf-8," + urllib.parse.quote(page)) 107 108 def open_section_and_get_body(self, wd, name): 109 container = wd.find_element_by_xpath(f"//div[@data-hider-title='{name}']") 110 caret = container.find_element_by_class_name("caret") 111 if container.get_attribute("data-shown") != "true": 112 caret.click() 113 content = container.find_element_by_tag_name("div") 114 return content 115 116 def test_scripted_model(self): 117 model = torch.jit.script(SimpleModel()) 118 self.do_dump_model(model) 119 120 def test_traced_model(self): 121 model = torch.jit.trace(SimpleModel(), torch.zeros(2, 16)) 122 self.do_dump_model(model) 123 124 def test_main(self): 125 self.needs_resources() 126 if IS_WINDOWS: 127 # I was getting tempfile errors in CI. Just skip it. 128 self.skipTest("Disabled on Windows.") 129 130 with tempfile.NamedTemporaryFile() as tf: 131 torch.jit.save(torch.jit.script(SimpleModel()), tf) 132 # Actually write contents to disk so we can read it below 133 tf.flush() 134 135 stdout = io.StringIO() 136 torch.utils.model_dump.main( 137 [ 138 None, 139 "--style=json", 140 tf.name, 141 ], 142 stdout=stdout) 143 self.assertRegex(stdout.getvalue(), r'\A{.*SimpleModel') 144 145 stdout = io.StringIO() 146 torch.utils.model_dump.main( 147 [ 148 None, 149 "--style=html", 150 tf.name, 151 ], 152 stdout=stdout) 153 self.assertRegex( 154 stdout.getvalue().replace("\n", " "), 155 r'\A<!DOCTYPE.*SimpleModel.*componentDidMount') 156 157 def get_quant_model(self): 158 fmodel = QuantModel().eval() 159 fmodel = torch.ao.quantization.fuse_modules(fmodel, [ 160 ["core.layer1", "core.relu1"], 161 ["core.layer2", "core.relu2"], 162 ]) 163 fmodel.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack") 164 prepped = torch.ao.quantization.prepare(fmodel) 165 prepped(torch.randn(2, 16)) 166 qmodel = torch.ao.quantization.convert(prepped) 167 return qmodel 168 169 @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available") 170 def test_quantized_model(self): 171 qmodel = self.get_quant_model() 172 self.do_dump_model(torch.jit.script(qmodel)) 173 174 @skipIfNoXNNPACK 175 @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available") 176 def test_optimized_quantized_model(self): 177 qmodel = self.get_quant_model() 178 smodel = torch.jit.trace(qmodel, torch.zeros(2, 16)) 179 omodel = torch.utils.mobile_optimizer.optimize_for_mobile(smodel) 180 self.do_dump_model(omodel) 181 182 def test_model_with_lists(self): 183 model = torch.jit.script(ModelWithLists()) 184 self.do_dump_model(model) 185 186 def test_invalid_json(self): 187 model = torch.jit.script(SimpleModel()) 188 self.do_dump_model(model, extra_files={"foo.json": "{"}) 189 190 @webdriver_test 191 def test_memory_computation(self, wd): 192 def check_memory(model, expected): 193 self.open_html_model(wd, model) 194 memory_table = self.open_section_and_get_body(wd, "Tensor Memory") 195 device = memory_table.find_element_by_xpath("//table/tbody/tr[1]/td[1]").text 196 self.assertEqual("cpu", device) 197 memory_usage_str = memory_table.find_element_by_xpath("//table/tbody/tr[1]/td[2]").text 198 self.assertEqual(expected, int(memory_usage_str)) 199 200 simple_model_memory = ( 201 # First layer, including bias. 202 64 * (16 + 1) + 203 # Second layer, including bias. 204 8 * (64 + 1) 205 # 32-bit float 206 ) * 4 207 208 check_memory(torch.jit.script(SimpleModel()), simple_model_memory) 209 210 # The same SimpleModel instance appears twice in this model. 211 # The tensors will be shared, so ensure no double-counting. 212 a_simple_model = SimpleModel() 213 check_memory( 214 torch.jit.script( 215 torch.nn.Sequential(a_simple_model, a_simple_model)), 216 simple_model_memory) 217 218 # The freezing process will move the weight and bias 219 # from data to constants. Ensure they are still counted. 220 check_memory( 221 torch.jit.freeze(torch.jit.script(SimpleModel()).eval()), 222 simple_model_memory) 223 224 # Make sure we can handle a model with both constants and data tensors. 225 class ComposedModule(torch.nn.Module): 226 def __init__(self) -> None: 227 super().__init__() 228 self.w1 = torch.zeros(1, 2) 229 self.w2 = torch.ones(2, 2) 230 231 def forward(self, arg): 232 return arg * self.w2 + self.w1 233 234 check_memory( 235 torch.jit.freeze( 236 torch.jit.script(ComposedModule()).eval(), 237 preserved_attrs=["w1"]), 238 4 * (2 + 4)) 239 240 241if __name__ == '__main__': 242 run_tests() 243