# Owner(s): ["oncall: mobile"] import unittest import torch import torch.nn as nn import torch.utils.bundled_inputs from torch.testing._internal.common_utils import TestCase, run_tests, skipIfNoXNNPACK from torch.testing._internal.jit_utils import get_forward, get_forward_graph from torch.utils.mobile_optimizer import (LintCode, generate_mobile_module_lints, optimize_for_mobile, MobileOptimizerType) from torch.nn import functional as F from torch.testing._internal.common_quantized import override_quantized_engine try: import torchvision HAS_TORCHVISION = True except ImportError: HAS_TORCHVISION = False FileCheck = torch._C.FileCheck class TestOptimizer(TestCase): @skipIfNoXNNPACK def test_optimize_for_mobile(self): batch_size = 2 input_channels_per_group = 6 height = 16 width = 16 output_channels_per_group = 6 groups = 4 kernel_h = kernel_w = 3 stride_h = stride_w = 1 pad_h = pad_w = 1 dilation = 1 input_channels = input_channels_per_group * groups output_channels = output_channels_per_group * groups kernels = (kernel_h, kernel_w) strides = (stride_h, stride_w) paddings = (pad_h, pad_w) dilations = (dilation, dilation) conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w) conv_bias_shape = (output_channels) input_data = torch.rand((batch_size, input_channels, height, width)) conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w)) conv_bias = torch.rand(output_channels) result = F.conv2d(input_data, conv_weight, conv_bias, strides, paddings, dilations, groups) weight_output_dim = 24 linear_input_shape = result.shape[1] linear_weight_shape = (weight_output_dim, linear_input_shape) class MyTestModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv_weight = torch.nn.Parameter(torch.rand(conv_weight_shape)) self.conv_bias = torch.nn.Parameter(torch.rand(conv_bias_shape)) self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape)) self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim)) self.strides = strides self.paddings = paddings self.dilations = dilations self.groups = groups def forward(self, x): o = F.conv2d(x, self.conv_weight, self.conv_bias, self.strides, self.paddings, self.dilations, self.groups) o = F.relu(o) x = o.permute([0, 2, 3, 1]) o = F.linear(x, self.linear_weight, self.linear_bias) o = o + x return F.relu(o) @torch.jit.export def foo(self, x): o = F.conv2d(x, self.conv_weight, self.conv_bias, self.strides, self.paddings, self.dilations, self.groups) o = F.relu(o) x = o.permute([0, 2, 3, 1]) o = F.linear(x, self.linear_weight, self.linear_bias) o = o + x return F.relu(o) class BNTestModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 20, 5, 1) self.bn = torch.nn.BatchNorm2d(num_features=20) self.bn.eps = 0.0023 def forward(self, x): x = self.conv(x) x = self.bn(x) return x data_shape = (batch_size, input_channels, height, width) input_data = torch.normal(1, 20, size=data_shape) scripted_model = torch.jit.script(MyTestModule()) scripted_model.eval() initial_result = scripted_model(input_data) initial_foo_result = scripted_model.foo(input_data) optimized_scripted_model = optimize_for_mobile(scripted_model, preserved_methods=['foo']) optimized_result = optimized_scripted_model(input_data) optimized_foo_result = optimized_scripted_model.foo(input_data) FileCheck().check_not("Tensor = aten::conv2d") \ .check_not("Tensor = prim::CallFunction") \ .check_not("prepacked::conv2d_clamp_prepack") \ 
.check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \ .check_not("prepacked::linear_clamp_prepack") \ .check_count("prepacked::linear_clamp_run", 1, exactly=True) \ .check_not("aten::add(") \ .check_not("aten::relu(") \ .check_count("aten::_add_relu(", 1, exactly=True) \ .run(optimized_scripted_model.graph) torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3) FileCheck().check_not("Tensor = aten::conv2d") \ .check_not("Tensor = prim::CallFunction") \ .check_not("prepacked::conv2d_clamp_prepack") \ .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \ .check_not("prepacked::linear_clamp_prepack") \ .check_count("prepacked::linear_clamp_run", 1, exactly=True) \ .check_not("aten::add(") \ .check_not("aten::relu(") \ .check_count("aten::_add_relu(", 1, exactly=True) \ .run(optimized_scripted_model.foo.graph) torch.testing.assert_close(initial_foo_result, optimized_foo_result, rtol=1e-2, atol=1e-3) optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS} optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blocklist_no_prepack) optimized_result_no_prepack = optimized_scripted_model_no_prepack(input_data) FileCheck().check_count("Tensor = aten::conv2d", 1, exactly=True) \ .check_not("prepacked::linear_clamp_run") \ .check_not("prepacked::conv2d_clamp_run") \ .run(optimized_scripted_model_no_prepack.graph) torch.testing.assert_close(initial_result, optimized_result_no_prepack, rtol=1e-2, atol=1e-3) bn_test_module = BNTestModule() bn_scripted_module = torch.jit.script(bn_test_module) bn_scripted_module.eval() self.assertEqual(len(torch.jit.export_opnames(bn_scripted_module)), 11) FileCheck().check_count('prim::CallMethod[name="forward"]', 2, exactly=True) \ .run(str(get_forward(bn_scripted_module._c).graph)) optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS} bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_prepack) self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1) bn_input = torch.rand(1, 1, 6, 6) torch.testing.assert_close(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3) optimization_blocklist_no_fold_bn = {MobileOptimizerType.CONV_BN_FUSION} no_bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blocklist_no_fold_bn) FileCheck().check_count("aten::batch_norm", 1, exactly=True) \ .run(str(get_forward_graph(no_bn_fold_scripted_module._c))) bn_input = torch.rand(1, 1, 6, 6) torch.testing.assert_close(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3) class MyMobileOptimizedTagTest(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape)) self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim)) def forward(self, x): o = F.linear(x, self.linear_weight, self.linear_bias) return F.relu(o) mobile_optimized_tag_module = MyMobileOptimizedTagTest() m = torch.jit.script(mobile_optimized_tag_module) m.eval() opt_m = optimize_for_mobile(m) tag = getattr(opt_m, "mobile_optimized", None) self.assertTrue(tag) class MyPreserveMethodsTest(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape)) self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim)) def forward(self, x): o = F.linear(x, self.linear_weight, self.linear_bias) 
        class MyPreserveMethodsTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear_weight = torch.nn.Parameter(torch.rand(linear_weight_shape))
                self.linear_bias = torch.nn.Parameter(torch.rand(weight_output_dim))

            def forward(self, x):
                o = F.linear(x, self.linear_weight, self.linear_bias)
                return F.relu(o)

            @torch.jit.export
            def preserveThis(self):
                pass

        preserve_method_module = MyPreserveMethodsTest()
        m = torch.jit.script(preserve_method_module)
        m.eval()

        opt_m = optimize_for_mobile(m)
        no_preserveThis = getattr(opt_m, "preserveThis", None)
        self.assertEqual(no_preserveThis, None)

        opt_m = optimize_for_mobile(m, preserved_methods=["preserveThis"])
        preserveThis = getattr(opt_m, "preserveThis", None)
        self.assertNotEqual(preserveThis, None)

        class OptimizeNoForwardTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = nn.Linear(10, 100)
                self.l2 = nn.Linear(100, 1)
                self.d = nn.Dropout(p=0.2)

            @torch.jit.export
            def foo(self, x):
                x = self.d(F.relu(self.l(x)))
                x = self.l2(x)
                x = x + torch.ones(1, 100)
                return F.relu(x)

        input_data = torch.ones(1, 10)
        m = torch.jit.script(OptimizeNoForwardTest())
        m.eval()
        initial_result = m.foo(input_data)

        optimized_scripted_model = optimize_for_mobile(m, preserved_methods=['foo'])
        optimized_result = optimized_scripted_model.foo(input_data)

        FileCheck().check_not("dropout.__") \
                   .check_count("aten::_add_relu(", 1, exactly=True) \
                   .run(optimized_scripted_model.foo.graph)
        torch.testing.assert_close(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

        class BNTestNoForwardModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
                self.bn = torch.nn.BatchNorm2d(num_features=20)
                self.bn.eps = 0.0023

            @torch.jit.export
            def foo(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        bn_test_no_forward_module = BNTestNoForwardModule()
        bn_no_forward_scripted_module = torch.jit.script(bn_test_no_forward_module)
        bn_no_forward_scripted_module.eval()
        self.assertEqual(len(torch.jit.export_opnames(bn_no_forward_scripted_module)), 11)
        FileCheck().check_count('prim::CallMethod[name="forward"]', 2, exactly=True) \
                   .run(bn_no_forward_scripted_module.foo.graph)

        bn_fold_no_forward_scripted_module = optimize_for_mobile(bn_no_forward_scripted_module, preserved_methods=['foo'])
        self.assertEqual(len(torch.jit.export_opnames(bn_fold_no_forward_scripted_module)), 1)
        bn_input = torch.rand(1, 1, 6, 6)
        torch.testing.assert_close(
            bn_no_forward_scripted_module.foo(bn_input),
            bn_fold_no_forward_scripted_module.foo(bn_input),
            rtol=1e-2,
            atol=1e-3)

    @skipIfNoXNNPACK
    def test_quantized_conv_no_asan_failures(self):
        # There were ASAN failures when fold_conv_bn was run on
        # already quantized conv modules. Verifying that this does
        # not happen again.
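        # qnnpack is not available on every platform, so bail out quietly
        # instead of failing when the engine is missing.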
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            return

        class Child(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv2(x)
                return x

        class Parent(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.child = Child()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.dequant(x)
                return x

        with override_quantized_engine('qnnpack'):
            model = Parent()
            model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
            torch.ao.quantization.prepare(model, inplace=True)
            model(torch.randn(4, 1, 4, 4))
            torch.ao.quantization.convert(model, inplace=True)

            model = torch.jit.script(model)
            # this line should not have ASAN failures
            model_optim = optimize_for_mobile(model)

    def test_generate_mobile_module_lints(self):
        class MyTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(4, 4)
                self.dropout = torch.nn.Dropout(p=0.5)

            def forward(self, inputs):
                out = self.fc(inputs)
                out = self.dropout(out)
                return out

        class MyBNModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(4, affine=True)

            def forward(self, inputs):
                bn = self.bn(inputs)
                return bn

        class MyBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

        def get_lint_count_by_type(lint_type, module_lint_list):
            return len([lint_dict for lint_dict in module_lint_list if lint_dict['name'] == lint_type.name])

        test_module = torch.jit.script(MyTestModule())
        test_module_lint_list = generate_mobile_module_lints(test_module)
        self.assertEqual(len(test_module_lint_list), 4)
        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, test_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.DROPOUT, test_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, test_module_lint_list), 2)

        bn_module = torch.jit.script(MyBNModule())
        bn_module_lint_list = generate_mobile_module_lints(bn_module)
        self.assertEqual(len(bn_module_lint_list), 4)
        self.assertEqual(get_lint_count_by_type(LintCode.BUNDLED_INPUT, bn_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.BATCHNORM, bn_module_lint_list), 1)
        self.assertEqual(get_lint_count_by_type(LintCode.REQUIRES_GRAD, bn_module_lint_list), 2)

        bi_module = torch.jit.script(MyBundledInputModule())
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            bi_module, [(torch.tensor([1]),)], [])
        bi_module_lint_list = generate_mobile_module_lints(bi_module)
        self.assertEqual(len(bi_module_lint_list), 0)

    @skipIfNoXNNPACK
    def test_preserve_bundled_inputs_methods(self):
        class MyBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

        class MyIncompleteBundledInputModule(torch.nn.Module):
            def forward(self, inputs):
                return inputs

            @torch.jit.export
            def get_all_bundled_inputs(self):
                pass

        bi_module = torch.jit.script(MyBundledInputModule())
        module_optim_bi_not_preserved = optimize_for_mobile(bi_module)

        # Expected to be False since no bundled inputs methods were added
        self.assertFalse(
            hasattr(module_optim_bi_not_preserved, 'get_all_bundled_inputs')
            or hasattr(module_optim_bi_not_preserved, 'get_num_bundled_inputs')
        )
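        # Once the full set of bundled-inputs accessors is attached below,
        # optimize_for_mobile should preserve them without an explicit
        # preserved_methods list.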
        # Add bundled inputs methods to the module
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            bi_module, [(torch.tensor([1]),)], [])

        # Now they should be preserved
        module_optim_bi_preserved = optimize_for_mobile(bi_module)

        # All of the bundled inputs methods were preserved
        self.assertTrue(
            hasattr(module_optim_bi_preserved, 'get_all_bundled_inputs')
            and hasattr(module_optim_bi_preserved, 'get_num_bundled_inputs')
        )

        bundled_input = module_optim_bi_preserved.get_all_bundled_inputs()[0]
        module_optim_bi_preserved(*bundled_input)

        # If not all 3 bundled inputs methods are present in the module,
        # we will not try to preserve them unless specified by the user.
        incomplete_bi_module = torch.jit.script(MyIncompleteBundledInputModule())
        incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module)
        self.assertFalse(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))

        # Specifically preserve get_all_bundled_inputs even if it's the only one
        # bundled inputs method available.
        incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module, preserved_methods=['get_all_bundled_inputs'])
        self.assertTrue(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))

    @skipIfNoXNNPACK
    def test_hoist_conv_packed_params(self):
        if 'qnnpack' not in torch.backends.quantized.supported_engines:
            return

        class Standalone(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)
                self.relu = nn.ReLU()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.conv2(x)
                x = self.relu(x)
                x = self.dequant(x)
                return x

            def fuse_model(self):
                torch.ao.quantization.fuse_modules(self, [['conv2', 'relu']], inplace=True)

        class Child(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                return x

        class Parent(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.child = Child()
                # TODO: test nn.Sequential after #42039 is fixed
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.dequant(x)
                return x

            def fuse_model(self):
                pass

        with override_quantized_engine('qnnpack'):
            def _quant_script_and_optimize(model):
                model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
                model.fuse_model()
                torch.ao.quantization.prepare(model, inplace=True)
                model(torch.randn(4, 1, 4, 4))
                torch.ao.quantization.convert(model, inplace=True)
                model = torch.jit.script(model)
                model_optim = optimize_for_mobile(model)
                return model, model_optim

            # basic case
            m, m_optim = _quant_script_and_optimize(Standalone())
            FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]') \
                       .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
                       .run(m_optim.graph)
            self.assertFalse(hasattr(m_optim, "conv1"))
            self.assertFalse(hasattr(m_optim, "conv2"))

            data = torch.randn(4, 1, 4, 4)
            m_res = m(data)
            m_optim_res = m_optim(data)
            torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3)
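            # Parent reaches its second conv through a child module, so the
            # generic case below checks that packed params are hoisted across
            # submodule boundaries as well.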
            # generic case
            m, m_optim = _quant_script_and_optimize(Parent())
            FileCheck().check_not('Conv2d = prim::GetAttr[name="conv1"]') \
                       .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
                       .run(m_optim.graph)
            self.assertFalse(hasattr(m_optim, "conv1"))
            self.assertFalse(hasattr(m_optim, "child"))

            data = torch.randn(4, 1, 4, 4)
            m_res = m(data)
            m_optim_res = m_optim(data)
            torch.testing.assert_close(m_res, m_optim_res, rtol=1e-2, atol=1e-3)

    @skipIfNoXNNPACK
    @unittest.skipUnless(HAS_TORCHVISION, "Needs torchvision")
    def test_mobilenet_optimize_for_mobile(self):
        m = torchvision.models.mobilenet_v3_small()
        m = torch.jit.script(m)
        m = optimize_for_mobile(m)

        # run forward 3 times; this used to segfault,
        # see https://github.com/pytorch/pytorch/issues/52463
        x = torch.zeros(1, 3, 56, 56)
        self.assertEqual(m(x).numel(), 1000)
        self.assertEqual(m(x).numel(), 1000)
        self.assertEqual(m(x).numel(), 1000)

    def test_clone_module_with_class(self):
        class MyInnerTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.pqr = torch.Tensor([10., 20., 30.])

            def forward(self, inputs):
                return inputs

            @torch.jit.export
            def dummy_method_not_cloned(self):
                return 20

        class MyTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.abc = 23
                self.pqr = torch.Tensor([1., 2., 3.])
                self.inner = MyInnerTestModule()

            def forward(self, inputs):
                x = self.dummy_method_cloned()
                # The call to self.inner.dummy_method_not_cloned should not raise an error
                y = self.inner.dummy_method_not_cloned()
                # The call to self.inner.pqr should not raise an error
                z = self.inner.pqr
                return (inputs, x, y, z)

            @torch.jit.export
            def dummy_method_not_cloned2(self):
                # The call to self.inner.dummy_method_not_cloned should not raise an error
                y = self.inner.dummy_method_not_cloned()
                # The call to self.inner.pqr should not raise an error
                z = self.inner.pqr
                return self.pqr, self.dummy_method_not_cloned(), y, z

            @torch.jit.export
            def dummy_method_not_cloned(self):
                return None

            @torch.jit.export
            def dummy_method_cloned(self):
                return None

            @torch.jit.export
            def dummy_method_ref_attr_pqr(self):
                return self.pqr, self.inner.pqr

        m = torch.jit.script(MyTestModule())

        # Check that the methods exist on the original model.
        self.assertEqual(hasattr(m, "dummy_method_not_cloned"), True)
        self.assertEqual(hasattr(m, "dummy_method_cloned"), True)
        self.assertEqual(hasattr(m, "dummy_method_not_cloned2"), True)
        self.assertEqual(hasattr(m, "pqr"), True)

        # Case-1: Successfully clone, ignoring 2 methods, keeping all attributes.
        cloned = torch._C._hack_do_not_use_clone_module_with_class(
            m._c,
            ["dummy_method_not_cloned", "dummy_method_not_cloned2"],  # ignored_methods
            [],  # ignored_attributes
        )

        # Check that the ignored methods don't exist on the cloned model.
        self.assertEqual(hasattr(cloned, "dummy_method_not_cloned"), False)
        self.assertEqual(hasattr(cloned, "dummy_method_cloned"), True)
        self.assertEqual(hasattr(cloned, "dummy_method_not_cloned2"), False)
        self.assertEqual(hasattr(cloned, "pqr"), True)

        # Check that the cloned class has a classname that starts with __torch__.
        self.assertTrue(
            cloned.qualified_name.startswith('__torch__.'),
            ("Expected the cloned module's name to start with the string "
             f"'__torch__.', but got: {cloned.qualified_name}"),
        )
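        # Ignoring an attribute means every preserved method that references
        # it would hold a dangling reference, so those methods must be listed
        # in ignored_methods too (Case-4 below shows the failure otherwise).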
        # Case-2: Successfully clone the module, ignoring the attribute pqr,
        # and the methods that reference it.
        cloned = torch._C._hack_do_not_use_clone_module_with_class(
            m._c,
            ["dummy_method_not_cloned", "dummy_method_not_cloned2", "dummy_method_ref_attr_pqr"],
            ["pqr"],
        )

        # Check that the ignored methods and attribute don't exist on the cloned model.
        self.assertEqual(hasattr(cloned, "dummy_method_not_cloned"), False)
        self.assertEqual(hasattr(cloned, "dummy_method_cloned"), True)
        self.assertEqual(hasattr(cloned, "dummy_method_not_cloned2"), False)
        self.assertEqual(hasattr(cloned, "dummy_method_ref_attr_pqr"), False)
        self.assertEqual(hasattr(cloned, "pqr"), False)

        # Case-3: The statement below will throw since dummy_method_not_cloned2
        # is preserved, and references dummy_method_not_cloned, which is not cloned.
        with self.assertRaises(RuntimeError):
            cloned = torch._C._hack_do_not_use_clone_module_with_class(m._c, ["dummy_method_not_cloned"], [])

        # Case-4: The statement below will throw since dummy_method_ref_attr_pqr
        # is preserved, and references "pqr", which is not cloned.
        with self.assertRaises(RuntimeError):
            cloned = torch._C._hack_do_not_use_clone_module_with_class(
                m._c,
                ["dummy_method_not_cloned", "dummy_method_not_cloned2"],
                ["pqr"],
            )


if __name__ == '__main__':
    run_tests()