1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3 2*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: mobile"] 3*da0073e9SAndroid Build Coastguard Worker 4*da0073e9SAndroid Build Coastguard Workerimport os 5*da0073e9SAndroid Build Coastguard Workerimport io 6*da0073e9SAndroid Build Coastguard Workerimport functools 7*da0073e9SAndroid Build Coastguard Workerimport tempfile 8*da0073e9SAndroid Build Coastguard Workerimport urllib 9*da0073e9SAndroid Build Coastguard Workerimport unittest 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Workerimport torch 12*da0073e9SAndroid Build Coastguard Workerimport torch.backends.xnnpack 13*da0073e9SAndroid Build Coastguard Workerimport torch.utils.model_dump 14*da0073e9SAndroid Build Coastguard Workerimport torch.utils.mobile_optimizer 15*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS, skipIfNoXNNPACK 16*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_quantized import supported_qengines 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Workerclass SimpleModel(torch.nn.Module): 20*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 21*da0073e9SAndroid Build Coastguard Worker super().__init__() 22*da0073e9SAndroid Build Coastguard Worker self.layer1 = torch.nn.Linear(16, 64) 23*da0073e9SAndroid Build Coastguard Worker self.relu1 = torch.nn.ReLU() 24*da0073e9SAndroid Build Coastguard Worker self.layer2 = torch.nn.Linear(64, 8) 25*da0073e9SAndroid Build Coastguard Worker self.relu2 = torch.nn.ReLU() 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker def forward(self, features): 28*da0073e9SAndroid Build Coastguard Worker act = features 29*da0073e9SAndroid Build Coastguard Worker act = self.layer1(act) 30*da0073e9SAndroid Build Coastguard Worker act = self.relu1(act) 31*da0073e9SAndroid Build Coastguard Worker act = self.layer2(act) 32*da0073e9SAndroid Build Coastguard Worker act = self.relu2(act) 33*da0073e9SAndroid Build Coastguard Worker return act 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Workerclass QuantModel(torch.nn.Module): 37*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 38*da0073e9SAndroid Build Coastguard Worker super().__init__() 39*da0073e9SAndroid Build Coastguard Worker self.quant = torch.ao.quantization.QuantStub() 40*da0073e9SAndroid Build Coastguard Worker self.dequant = torch.ao.quantization.DeQuantStub() 41*da0073e9SAndroid Build Coastguard Worker self.core = SimpleModel() 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 44*da0073e9SAndroid Build Coastguard Worker x = self.quant(x) 45*da0073e9SAndroid Build Coastguard Worker x = self.core(x) 46*da0073e9SAndroid Build Coastguard Worker x = self.dequant(x) 47*da0073e9SAndroid Build Coastguard Worker return x 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker 50*da0073e9SAndroid Build Coastguard Workerclass ModelWithLists(torch.nn.Module): 51*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 52*da0073e9SAndroid Build Coastguard Worker super().__init__() 53*da0073e9SAndroid Build Coastguard Worker self.rt = [torch.zeros(1)] 54*da0073e9SAndroid Build Coastguard Worker self.ot = [torch.zeros(1), None] 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker def forward(self, arg): 57*da0073e9SAndroid Build Coastguard Worker arg = arg + self.rt[0] 58*da0073e9SAndroid Build Coastguard Worker o = self.ot[0] 59*da0073e9SAndroid Build Coastguard Worker if o is not None: 60*da0073e9SAndroid Build Coastguard Worker arg = arg + o 61*da0073e9SAndroid Build Coastguard Worker return arg 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Workerdef webdriver_test(testfunc): 65*da0073e9SAndroid Build Coastguard Worker @functools.wraps(testfunc) 66*da0073e9SAndroid Build Coastguard Worker def wrapper(self, *args, **kwds): 67*da0073e9SAndroid Build Coastguard Worker self.needs_resources() 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker if os.environ.get("RUN_WEBDRIVER") != "1": 70*da0073e9SAndroid Build Coastguard Worker self.skipTest("Webdriver not requested") 71*da0073e9SAndroid Build Coastguard Worker from selenium import webdriver 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker for driver in [ 74*da0073e9SAndroid Build Coastguard Worker "Firefox", 75*da0073e9SAndroid Build Coastguard Worker "Chrome", 76*da0073e9SAndroid Build Coastguard Worker ]: 77*da0073e9SAndroid Build Coastguard Worker with self.subTest(driver=driver): 78*da0073e9SAndroid Build Coastguard Worker wd = getattr(webdriver, driver)() 79*da0073e9SAndroid Build Coastguard Worker testfunc(self, wd, *args, **kwds) 80*da0073e9SAndroid Build Coastguard Worker wd.close() 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker return wrapper 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Workerclass TestModelDump(TestCase): 86*da0073e9SAndroid Build Coastguard Worker def needs_resources(self): 87*da0073e9SAndroid Build Coastguard Worker pass 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Worker def test_inline_skeleton(self): 90*da0073e9SAndroid Build Coastguard Worker self.needs_resources() 91*da0073e9SAndroid Build Coastguard Worker skel = torch.utils.model_dump.get_inline_skeleton() 92*da0073e9SAndroid Build Coastguard Worker assert "unpkg.org" not in skel 93*da0073e9SAndroid Build Coastguard Worker assert "src=" not in skel 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker def do_dump_model(self, model, extra_files=None): 96*da0073e9SAndroid Build Coastguard Worker # Just check that we're able to run successfully. 97*da0073e9SAndroid Build Coastguard Worker buf = io.BytesIO() 98*da0073e9SAndroid Build Coastguard Worker torch.jit.save(model, buf, _extra_files=extra_files) 99*da0073e9SAndroid Build Coastguard Worker info = torch.utils.model_dump.get_model_info(buf) 100*da0073e9SAndroid Build Coastguard Worker assert info is not None 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker def open_html_model(self, wd, model, extra_files=None): 103*da0073e9SAndroid Build Coastguard Worker buf = io.BytesIO() 104*da0073e9SAndroid Build Coastguard Worker torch.jit.save(model, buf, _extra_files=extra_files) 105*da0073e9SAndroid Build Coastguard Worker page = torch.utils.model_dump.get_info_and_burn_skeleton(buf) 106*da0073e9SAndroid Build Coastguard Worker wd.get("data:text/html;charset=utf-8," + urllib.parse.quote(page)) 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker def open_section_and_get_body(self, wd, name): 109*da0073e9SAndroid Build Coastguard Worker container = wd.find_element_by_xpath(f"//div[@data-hider-title='{name}']") 110*da0073e9SAndroid Build Coastguard Worker caret = container.find_element_by_class_name("caret") 111*da0073e9SAndroid Build Coastguard Worker if container.get_attribute("data-shown") != "true": 112*da0073e9SAndroid Build Coastguard Worker caret.click() 113*da0073e9SAndroid Build Coastguard Worker content = container.find_element_by_tag_name("div") 114*da0073e9SAndroid Build Coastguard Worker return content 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker def test_scripted_model(self): 117*da0073e9SAndroid Build Coastguard Worker model = torch.jit.script(SimpleModel()) 118*da0073e9SAndroid Build Coastguard Worker self.do_dump_model(model) 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker def test_traced_model(self): 121*da0073e9SAndroid Build Coastguard Worker model = torch.jit.trace(SimpleModel(), torch.zeros(2, 16)) 122*da0073e9SAndroid Build Coastguard Worker self.do_dump_model(model) 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker def test_main(self): 125*da0073e9SAndroid Build Coastguard Worker self.needs_resources() 126*da0073e9SAndroid Build Coastguard Worker if IS_WINDOWS: 127*da0073e9SAndroid Build Coastguard Worker # I was getting tempfile errors in CI. Just skip it. 128*da0073e9SAndroid Build Coastguard Worker self.skipTest("Disabled on Windows.") 129*da0073e9SAndroid Build Coastguard Worker 130*da0073e9SAndroid Build Coastguard Worker with tempfile.NamedTemporaryFile() as tf: 131*da0073e9SAndroid Build Coastguard Worker torch.jit.save(torch.jit.script(SimpleModel()), tf) 132*da0073e9SAndroid Build Coastguard Worker # Actually write contents to disk so we can read it below 133*da0073e9SAndroid Build Coastguard Worker tf.flush() 134*da0073e9SAndroid Build Coastguard Worker 135*da0073e9SAndroid Build Coastguard Worker stdout = io.StringIO() 136*da0073e9SAndroid Build Coastguard Worker torch.utils.model_dump.main( 137*da0073e9SAndroid Build Coastguard Worker [ 138*da0073e9SAndroid Build Coastguard Worker None, 139*da0073e9SAndroid Build Coastguard Worker "--style=json", 140*da0073e9SAndroid Build Coastguard Worker tf.name, 141*da0073e9SAndroid Build Coastguard Worker ], 142*da0073e9SAndroid Build Coastguard Worker stdout=stdout) 143*da0073e9SAndroid Build Coastguard Worker self.assertRegex(stdout.getvalue(), r'\A{.*SimpleModel') 144*da0073e9SAndroid Build Coastguard Worker 145*da0073e9SAndroid Build Coastguard Worker stdout = io.StringIO() 146*da0073e9SAndroid Build Coastguard Worker torch.utils.model_dump.main( 147*da0073e9SAndroid Build Coastguard Worker [ 148*da0073e9SAndroid Build Coastguard Worker None, 149*da0073e9SAndroid Build Coastguard Worker "--style=html", 150*da0073e9SAndroid Build Coastguard Worker tf.name, 151*da0073e9SAndroid Build Coastguard Worker ], 152*da0073e9SAndroid Build Coastguard Worker stdout=stdout) 153*da0073e9SAndroid Build Coastguard Worker self.assertRegex( 154*da0073e9SAndroid Build Coastguard Worker stdout.getvalue().replace("\n", " "), 155*da0073e9SAndroid Build Coastguard Worker r'\A<!DOCTYPE.*SimpleModel.*componentDidMount') 156*da0073e9SAndroid Build Coastguard Worker 157*da0073e9SAndroid Build Coastguard Worker def get_quant_model(self): 158*da0073e9SAndroid Build Coastguard Worker fmodel = QuantModel().eval() 159*da0073e9SAndroid Build Coastguard Worker fmodel = torch.ao.quantization.fuse_modules(fmodel, [ 160*da0073e9SAndroid Build Coastguard Worker ["core.layer1", "core.relu1"], 161*da0073e9SAndroid Build Coastguard Worker ["core.layer2", "core.relu2"], 162*da0073e9SAndroid Build Coastguard Worker ]) 163*da0073e9SAndroid Build Coastguard Worker fmodel.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack") 164*da0073e9SAndroid Build Coastguard Worker prepped = torch.ao.quantization.prepare(fmodel) 165*da0073e9SAndroid Build Coastguard Worker prepped(torch.randn(2, 16)) 166*da0073e9SAndroid Build Coastguard Worker qmodel = torch.ao.quantization.convert(prepped) 167*da0073e9SAndroid Build Coastguard Worker return qmodel 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build Coastguard Worker @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available") 170*da0073e9SAndroid Build Coastguard Worker def test_quantized_model(self): 171*da0073e9SAndroid Build Coastguard Worker qmodel = self.get_quant_model() 172*da0073e9SAndroid Build Coastguard Worker self.do_dump_model(torch.jit.script(qmodel)) 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker @skipIfNoXNNPACK 175*da0073e9SAndroid Build Coastguard Worker @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available") 176*da0073e9SAndroid Build Coastguard Worker def test_optimized_quantized_model(self): 177*da0073e9SAndroid Build Coastguard Worker qmodel = self.get_quant_model() 178*da0073e9SAndroid Build Coastguard Worker smodel = torch.jit.trace(qmodel, torch.zeros(2, 16)) 179*da0073e9SAndroid Build Coastguard Worker omodel = torch.utils.mobile_optimizer.optimize_for_mobile(smodel) 180*da0073e9SAndroid Build Coastguard Worker self.do_dump_model(omodel) 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker def test_model_with_lists(self): 183*da0073e9SAndroid Build Coastguard Worker model = torch.jit.script(ModelWithLists()) 184*da0073e9SAndroid Build Coastguard Worker self.do_dump_model(model) 185*da0073e9SAndroid Build Coastguard Worker 186*da0073e9SAndroid Build Coastguard Worker def test_invalid_json(self): 187*da0073e9SAndroid Build Coastguard Worker model = torch.jit.script(SimpleModel()) 188*da0073e9SAndroid Build Coastguard Worker self.do_dump_model(model, extra_files={"foo.json": "{"}) 189*da0073e9SAndroid Build Coastguard Worker 190*da0073e9SAndroid Build Coastguard Worker @webdriver_test 191*da0073e9SAndroid Build Coastguard Worker def test_memory_computation(self, wd): 192*da0073e9SAndroid Build Coastguard Worker def check_memory(model, expected): 193*da0073e9SAndroid Build Coastguard Worker self.open_html_model(wd, model) 194*da0073e9SAndroid Build Coastguard Worker memory_table = self.open_section_and_get_body(wd, "Tensor Memory") 195*da0073e9SAndroid Build Coastguard Worker device = memory_table.find_element_by_xpath("//table/tbody/tr[1]/td[1]").text 196*da0073e9SAndroid Build Coastguard Worker self.assertEqual("cpu", device) 197*da0073e9SAndroid Build Coastguard Worker memory_usage_str = memory_table.find_element_by_xpath("//table/tbody/tr[1]/td[2]").text 198*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, int(memory_usage_str)) 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Worker simple_model_memory = ( 201*da0073e9SAndroid Build Coastguard Worker # First layer, including bias. 202*da0073e9SAndroid Build Coastguard Worker 64 * (16 + 1) + 203*da0073e9SAndroid Build Coastguard Worker # Second layer, including bias. 204*da0073e9SAndroid Build Coastguard Worker 8 * (64 + 1) 205*da0073e9SAndroid Build Coastguard Worker # 32-bit float 206*da0073e9SAndroid Build Coastguard Worker ) * 4 207*da0073e9SAndroid Build Coastguard Worker 208*da0073e9SAndroid Build Coastguard Worker check_memory(torch.jit.script(SimpleModel()), simple_model_memory) 209*da0073e9SAndroid Build Coastguard Worker 210*da0073e9SAndroid Build Coastguard Worker # The same SimpleModel instance appears twice in this model. 211*da0073e9SAndroid Build Coastguard Worker # The tensors will be shared, so ensure no double-counting. 212*da0073e9SAndroid Build Coastguard Worker a_simple_model = SimpleModel() 213*da0073e9SAndroid Build Coastguard Worker check_memory( 214*da0073e9SAndroid Build Coastguard Worker torch.jit.script( 215*da0073e9SAndroid Build Coastguard Worker torch.nn.Sequential(a_simple_model, a_simple_model)), 216*da0073e9SAndroid Build Coastguard Worker simple_model_memory) 217*da0073e9SAndroid Build Coastguard Worker 218*da0073e9SAndroid Build Coastguard Worker # The freezing process will move the weight and bias 219*da0073e9SAndroid Build Coastguard Worker # from data to constants. Ensure they are still counted. 220*da0073e9SAndroid Build Coastguard Worker check_memory( 221*da0073e9SAndroid Build Coastguard Worker torch.jit.freeze(torch.jit.script(SimpleModel()).eval()), 222*da0073e9SAndroid Build Coastguard Worker simple_model_memory) 223*da0073e9SAndroid Build Coastguard Worker 224*da0073e9SAndroid Build Coastguard Worker # Make sure we can handle a model with both constants and data tensors. 225*da0073e9SAndroid Build Coastguard Worker class ComposedModule(torch.nn.Module): 226*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 227*da0073e9SAndroid Build Coastguard Worker super().__init__() 228*da0073e9SAndroid Build Coastguard Worker self.w1 = torch.zeros(1, 2) 229*da0073e9SAndroid Build Coastguard Worker self.w2 = torch.ones(2, 2) 230*da0073e9SAndroid Build Coastguard Worker 231*da0073e9SAndroid Build Coastguard Worker def forward(self, arg): 232*da0073e9SAndroid Build Coastguard Worker return arg * self.w2 + self.w1 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker check_memory( 235*da0073e9SAndroid Build Coastguard Worker torch.jit.freeze( 236*da0073e9SAndroid Build Coastguard Worker torch.jit.script(ComposedModule()).eval(), 237*da0073e9SAndroid Build Coastguard Worker preserved_attrs=["w1"]), 238*da0073e9SAndroid Build Coastguard Worker 4 * (2 + 4)) 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker 241*da0073e9SAndroid Build Coastguard Workerif __name__ == '__main__': 242*da0073e9SAndroid Build Coastguard Worker run_tests() 243