1# Owner(s): ["oncall: mobile"] 2 3import torch 4from torch.nn import functional as F 5 6from torch.testing._internal.common_utils import TestCase, run_tests 7from torch.testing import FileCheck 8import io 9 10class TestMetalRewritePass(TestCase): 11 @staticmethod 12 def validate_transformed_module( 13 # To please flake 14 self, 15 pattern_count_map, 16 data_shape, 17 prepack_removal=False, 18 fuse_clamping_ops=False): 19 module_instance = self 20 scripted_model = torch.jit.script(module_instance) 21 scripted_model.eval() 22 input_data = torch.normal(1, 20, size=data_shape) 23 ref_result = scripted_model(input_data) 24 torch._C._jit_pass_metal_insert_prepacked_ops(scripted_model._c) 25 if fuse_clamping_ops or prepack_removal: 26 scripted_model._c = torch._C._freeze_module(scripted_model._c) 27 if fuse_clamping_ops: 28 torch._C._jit_pass_metal_fuse_clamp_w_prepacked_conv(scripted_model._c) 29 if prepack_removal: 30 torch._C._jit_pass_metal_fold_prepacking_ops(scripted_model._c) 31 32 buffer = io.BytesIO() 33 torch.jit.save(scripted_model, buffer) 34 buffer.seek(0) 35 deserialized_scripted_model = torch.jit.load(buffer) 36 for pattern, v in pattern_count_map.items(): 37 if (v == 0): 38 FileCheck().check(pattern).run(deserialized_scripted_model.graph) 39 elif (v == -1): 40 FileCheck().check_not(pattern).run(deserialized_scripted_model.graph) 41 else: 42 FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph) 43 44 def test_conv(self): 45 # Conv params 46 batch_size = 2 47 input_channels_per_group = 6 48 height = 16 49 width = 16 50 output_channels_per_group = 6 51 groups = 4 52 kernel_h = kernel_w = 3 53 stride_h = stride_w = 1 54 pad_h = pad_w = 1 55 dilation = 1 56 input_channels = input_channels_per_group * groups 57 output_channels = output_channels_per_group * groups 58 kernels = (kernel_h, kernel_w) 59 strides = (stride_h, stride_w) 60 paddings = (pad_h, pad_w) 61 dilations = (dilation, dilation) 62 conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w) 63 conv_bias_shape = (output_channels) 64 65 class Conv2D(torch.nn.Module): 66 def __init__(self) -> None: 67 super().__init__() 68 self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) 69 self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) 70 self.strides = strides 71 self.paddings = paddings 72 self.dilations = dilations 73 self.groups = groups 74 75 def forward(self, x): 76 return F.conv2d(x, self.weight, self.bias, 77 self.strides, self.paddings, self.dilations, self.groups) 78 79 data_shape = (batch_size, input_channels, height, width) 80 pattern_count_map = {"Tensor = aten::conv2d": -1, 81 "metal_prepack::conv2d_prepack": 1, 82 "metal_prepack::conv2d_run": 1} 83 TestMetalRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape) 84 85 class Conv2DRelu(torch.nn.Module): 86 def __init__(self) -> None: 87 super().__init__() 88 self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) 89 self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) 90 self.strides = strides 91 self.paddings = paddings 92 self.dilations = dilations 93 self.groups = groups 94 95 def forward(self, x): 96 o = F.conv2d(x, self.weight, self.bias, 97 self.strides, self.paddings, self.dilations, self.groups) 98 o = F.relu(o) 99 return o 100 101 data_shape = (batch_size, input_channels, height, width) 102 pattern_count_map = {"Tensor = aten::conv2d": -1, 103 "metal_prepack::conv2d_prepack": 1, 104 "metal_prepack::conv2d_run": 1} 105 TestMetalRewritePass.validate_transformed_module( 106 Conv2DRelu(), pattern_count_map, data_shape) 107 108 pattern_count_map["aten::relu"] = 1 109 pattern_count_map["metal_prepack::conv2d_prepack"] = -1 110 TestMetalRewritePass.validate_transformed_module( 111 Conv2DRelu(), 112 pattern_count_map, 113 data_shape, 114 prepack_removal=True) 115 pattern_count_map["aten::relu"] = -1 116 TestMetalRewritePass.validate_transformed_module( 117 Conv2DRelu(), 118 pattern_count_map, 119 data_shape, 120 prepack_removal=True, 121 fuse_clamping_ops=True) 122 123 124 class Conv2DHardtanh(torch.nn.Module): 125 def __init__(self) -> None: 126 super().__init__() 127 self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) 128 self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) 129 self.strides = strides 130 self.paddings = paddings 131 self.dilations = dilations 132 self.groups = groups 133 134 def forward(self, x): 135 o = F.conv2d(x, self.weight, self.bias, 136 self.strides, self.paddings, self.dilations, self.groups) 137 o = F.hardtanh(o) 138 return o 139 140 data_shape = (batch_size, input_channels, height, width) 141 pattern_count_map = {"Tensor = aten::conv2d": -1, 142 "metal_prepack::conv2d_prepack": 1, 143 "metal_prepack::conv2d_run": 1} 144 TestMetalRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape) 145 pattern_count_map["aten::hardtanh"] = 1 146 pattern_count_map["metal_prepack::conv2d_prepack"] = -1 147 TestMetalRewritePass.validate_transformed_module( 148 Conv2DHardtanh(), 149 pattern_count_map, 150 data_shape, 151 prepack_removal=True) 152 pattern_count_map["aten::hardtanh"] = -1 153 TestMetalRewritePass.validate_transformed_module( 154 Conv2DRelu(), 155 pattern_count_map, 156 data_shape, 157 prepack_removal=True, 158 fuse_clamping_ops=True) 159 160if __name__ == "__main__": 161 run_tests() 162