# Owner(s): ["oncall: mobile"]

import torch
from torch.nn import functional as F

from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing import FileCheck
import io

class TestMetalRewritePass(TestCase):
    @staticmethod
    def validate_transformed_module(
            # To please flake
            self,
            pattern_count_map,
            data_shape,
            prepack_removal=False,
            fuse_clamping_ops=False):
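        # Scripts the module, applies the requested Metal rewrite passes,
        # round-trips it through save/load, and checks the expected op
        # patterns against the deserialized graph.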
        module_instance = self
        scripted_model = torch.jit.script(module_instance)
        scripted_model.eval()
        input_data = torch.normal(1, 20, size=data_shape)
        ref_result = scripted_model(input_data)
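        # Rewrite eligible aten::conv2d calls into metal_prepack::conv2d_prepack /
        # metal_prepack::conv2d_run pairs.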
        torch._C._jit_pass_metal_insert_prepacked_ops(scripted_model._c)
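        # The clamp-fusion and prepack-folding passes below operate on a frozen module.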
        if fuse_clamping_ops or prepack_removal:
            scripted_model._c = torch._C._freeze_module(scripted_model._c)
        if fuse_clamping_ops:
            torch._C._jit_pass_metal_fuse_clamp_w_prepacked_conv(scripted_model._c)
        if prepack_removal:
            torch._C._jit_pass_metal_fold_prepacking_ops(scripted_model._c)

        buffer = io.BytesIO()
        torch.jit.save(scripted_model, buffer)
        buffer.seek(0)
        deserialized_scripted_model = torch.jit.load(buffer)
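        # pattern_count_map values: 0 -> pattern must appear at least once,
        # -1 -> pattern must not appear, anything else -> exact occurrence count.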
        for pattern, v in pattern_count_map.items():
            if v == 0:
                FileCheck().check(pattern).run(deserialized_scripted_model.graph)
            elif v == -1:
                FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
            else:
                FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)

    def test_conv(self):
        # Conv params
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 4
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        kernels = (kernel_h, kernel_w)
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
        conv_bias_shape = (output_channels,)

        class Conv2D(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv2d(x, self.weight, self.bias,
                                self.strides, self.paddings, self.dilations, self.groups)

        data_shape = (batch_size, input_channels, height, width)
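        # Expect the plain conv2d to be rewritten away, leaving exactly one
        # prepack / run pair in the graph.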
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "metal_prepack::conv2d_prepack": 1,
                             "metal_prepack::conv2d_run": 1}
        TestMetalRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)

        class Conv2DRelu(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.weight, self.bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                return o

        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "metal_prepack::conv2d_prepack": 1,
                             "metal_prepack::conv2d_run": 1}
        TestMetalRewritePass.validate_transformed_module(
            Conv2DRelu(), pattern_count_map, data_shape)

        # With prepack folding alone, the relu stays in the graph; once the
        # clamping ops are fused into the prepacked conv, it should be gone too.
        pattern_count_map["aten::relu"] = 1
        pattern_count_map["metal_prepack::conv2d_prepack"] = -1
        TestMetalRewritePass.validate_transformed_module(
            Conv2DRelu(),
            pattern_count_map,
            data_shape,
            prepack_removal=True)
        pattern_count_map["aten::relu"] = -1
        TestMetalRewritePass.validate_transformed_module(
            Conv2DRelu(),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)

        class Conv2DHardtanh(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.weight, self.bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.hardtanh(o)
                return o

        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "metal_prepack::conv2d_prepack": 1,
                             "metal_prepack::conv2d_run": 1}
        TestMetalRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape)
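        # Same expectations for hardtanh: still present after prepack folding,
        # gone once the clamping ops are fused into the prepacked conv.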
        pattern_count_map["aten::hardtanh"] = 1
        pattern_count_map["metal_prepack::conv2d_prepack"] = -1
        TestMetalRewritePass.validate_transformed_module(
            Conv2DHardtanh(),
            pattern_count_map,
            data_shape,
            prepack_removal=True)
        pattern_count_map["aten::hardtanh"] = -1
        TestMetalRewritePass.validate_transformed_module(
            Conv2DHardtanh(),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)

if __name__ == "__main__":
    run_tests()