# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torchvision

from executorch.backends.xnnpack.test.tester import Quantize, Tester


class TestResNet18(unittest.TestCase):
    """End-to-end XNNPACK lowering tests for torchvision's ResNet18.

    Each test exports the model, lowers it to the XNNPACK delegate, checks
    that convolution / mean ops were delegated, serializes the program and
    compares its outputs against eager mode.
    """

    # Single NCHW sample input shared by every test in this class.
    inputs = (torch.randn(1, 3, 224, 224),)
    # Dynamic-shape spec for the first (and only) input: dims 2 and 3
    # (height and width) may each range over [224, 455].
    dynamic_shapes = (
        {
            2: torch.export.Dim("height", min=224, max=455),
            3: torch.export.Dim("width", min=224, max=455),
        },
    )

    class DynamicResNet(torch.nn.Module):
        """ResNet18 wrapper that accepts dynamically sized spatial inputs.

        The input is bilinearly resized to 224x224 before the backbone, so
        only the interpolate op sees the dynamic height/width.
        """

        def __init__(self):
            super().__init__()
            self.model = torchvision.models.resnet18()

        def forward(self, x):
            # Normalize any allowed input size down to the fixed 224x224
            # resolution the backbone is exercised with.
            x = torch.nn.functional.interpolate(
                x,
                size=(224, 224),
                mode="bilinear",
                align_corners=True,
                antialias=False,
            )
            return self.model(x)

    def _test_exported_resnet(self, tester):
        """Run the shared export -> lower -> serialize -> compare pipeline.

        Args:
            tester: a ``Tester`` wrapping the model under test (optionally
                already configured for quantization).
        """
        (
            tester.export()
            .to_edge_transform_and_lower()
            # Convolution and mean must have been consumed by the delegate ...
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_convolution_default",
                    "executorch_exir_dialects_edge__ops_aten_mean_dim",
                ]
            )
            # ... and a delegate call must be present in the graph.
            .check(["torch.ops.higher_order.executorch_call_delegate"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp32_resnet18(self):
        self._test_exported_resnet(Tester(torchvision.models.resnet18(), self.inputs))

    # Named with the ``test_`` prefix so unittest collects it and reports it
    # as skipped; with a leading underscore the skip decorator never applied.
    @unittest.skip("T187799178: Debugging Numerical Issues with Calibration")
    def test_qs8_resnet18(self):
        quantized_tester = Tester(torchvision.models.resnet18(), self.inputs).quantize()
        self._test_exported_resnet(quantized_tester)

    # TODO: Delete and only use the calibrated test after T187799178
    def test_qs8_resnet18_no_calibration(self):
        quantized_tester = Tester(torchvision.models.resnet18(), self.inputs).quantize(
            Quantize(calibrate=False)
        )
        self._test_exported_resnet(quantized_tester)

    def test_fp32_resnet18_dynamic(self):
        self._test_exported_resnet(
            Tester(self.DynamicResNet(), self.inputs, self.dynamic_shapes)
        )

    # Named with the ``test_`` prefix so unittest collects it and reports it
    # as skipped; with a leading underscore the skip decorator never applied.
    @unittest.skip("T187799178: Debugging Numerical Issues with Calibration")
    def test_qs8_resnet18_dynamic(self):
        self._test_exported_resnet(
            Tester(self.DynamicResNet(), self.inputs, self.dynamic_shapes).quantize()
        )

    # TODO: Delete and only use the calibrated test after T187799178
    def test_qs8_resnet18_dynamic_no_calibration(self):
        self._test_exported_resnet(
            Tester(self.DynamicResNet(), self.inputs, self.dynamic_shapes).quantize(
                Quantize(calibrate=False)
            )
        )