# xref: /aosp_15_r20/external/pytorch/test/mobile/model_test/quantization_ops.py
# (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import torch
import torch.nn as nn


class GeneralQuantModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embedding = torch.ao.nn.quantized.Embedding(
            num_embeddings=10, embedding_dim=12
        )
        self.embedding_input = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
        self.func = torch.ao.nn.quantized.QFunctional()
        self.conv1 = torch.ao.nn.quantized.ConvTranspose1d(16, 33, 3, stride=2)
        self.conv2 = torch.ao.nn.quantized.ConvTranspose2d(16, 33, 3, stride=2)
        self.conv3 = torch.ao.nn.quantized.ConvTranspose3d(16, 33, 3, stride=2)

    def forward(self):
        a = torch.quantize_per_tensor(torch.tensor([3.0]), 1.0, 0, torch.qint32)
        b = torch.quantize_per_tensor(torch.tensor(4.0), 1.0, 0, torch.qint32)
        c = torch.quantize_per_tensor(
            torch.tensor([3.0]), torch.tensor(1.0), torch.tensor(0), torch.qint32
        )
        input1 = torch.randn(1, 16, 4)
        input2 = torch.randn(1, 16, 4, 4)
        input3 = torch.randn(1, 16, 4, 4, 4)  # only used by the disabled conv3 case below
        # len() takes exactly one argument, so gather the op results into a
        # tuple first (as DynamicQuantModule.M.forward does below).
        return len(
            (
                self.func.add(a, b),
                self.func.cat((a, a), 0),
                self.func.mul(a, b),
                self.func.add_relu(a, b),
                self.func.add_scalar(a, b),
                self.func.mul_scalar(a, b),
                self.embedding(self.embedding_input),
                self.conv1(
                    torch.quantize_per_tensor(
                        input1, scale=1.0, zero_point=0, dtype=torch.quint8
                    )
                ),
                self.conv2(
                    torch.quantize_per_tensor(
                        input2, scale=1.0, zero_point=0, dtype=torch.quint8
                    )
                ),
                c,
                # self.conv3(torch.quantize_per_tensor(input3, scale=1.0, zero_point=0, dtype=torch.quint8)), # failed on iOS
            )
        )
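
# Usage sketch (an assumption, not part of the original file): modules in this
# test are typically scripted and bundled for the PyTorch lite interpreter.
# The output file name below is hypothetical.
#
#   module = torch.jit.script(GeneralQuantModule())
#   module._save_for_lite_interpreter("general_quant_ops.ptl")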


class DynamicQuantModule:
    def __init__(self) -> None:
        super().__init__()
        self.module = self.M()

    def getModule(self):
        return torch.ao.quantization.quantize_dynamic(self.module, dtype=torch.qint8)

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.rnn = nn.RNN(4, 8, 2)
            self.rnncell = nn.RNNCell(4, 8)
            self.gru = nn.GRU(4, 8, 2)
            self.grucell = nn.GRUCell(4, 8)
            self.lstm = nn.LSTM(4, 8, 2)
            self.lstmcell = nn.LSTMCell(4, 8)
            self.linears = nn.ModuleList(
                [
                    nn.Identity(54),
                    nn.Linear(20, 20),
                    nn.Bilinear(20, 20, 40),
                ]
            )
            self.transformers = nn.ModuleList(
                [
                    nn.Transformer(
                        d_model=2, nhead=2, num_encoder_layers=1, num_decoder_layers=1
                    ),
                    nn.TransformerEncoder(
                        nn.TransformerEncoderLayer(d_model=2, nhead=2), num_layers=1
                    ),
                    nn.TransformerDecoder(
                        nn.TransformerDecoderLayer(d_model=2, nhead=2), num_layers=1
                    ),
                ]
            )
            # self.a = torch.nn.utils.rnn.pad_sequence([torch.tensor([1,2,3]), torch.tensor([3,4])], batch_first=True)

        def forward(self):
            input = torch.randn(5, 3, 4)
            h = torch.randn(2, 3, 8)
            c = torch.randn(2, 3, 8)
            linear_input = torch.randn(32, 20)
            trans_input = torch.randn(1, 16, 2)
            tgt = torch.rand(1, 16, 2)

            return len(
                (
                    self.rnn(input, h),
                    self.rnncell(input[0], h[0]),
                    self.gru(input, h),
                    self.grucell(input[0], h[0]),
                    self.lstm(input, (h, c)),
                    # self.lstm(torch.nn.utils.rnn.pack_padded_sequence(self.a, lengths=torch.tensor([3,2,1])), (h, c)),
                    self.lstmcell(input[0], (h[0], c[0])),
                    self.transformers[0](trans_input, tgt),
                    self.transformers[1](trans_input),
                    self.transformers[2](trans_input, tgt),
                    self.linears[0](linear_input),
                    self.linears[1](linear_input),
                    self.linears[2](linear_input, linear_input),
                )
            )
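
# What getModule() above produces (a sketch, stated as an assumption): with
# qconfig_spec omitted and dtype=torch.qint8, quantize_dynamic swaps the
# default set of eligible submodules (nn.Linear plus the LSTM/GRU/RNN-cell
# family) for dynamically quantized versions; weights are quantized ahead of
# time, activations on the fly at each call. The transformer stacks are
# covered through their internal nn.Linear layers.
#
#   dynamic_model = DynamicQuantModule().getModule()
#   n = dynamic_model()  # runs every wrapped op; returns the tuple length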


class StaticQuantModule:
    def getModule(self):
        # Eager-mode post-training static quantization: attach a qconfig,
        # insert observers with prepare(), then convert() to quantized modules.
        # Note that no calibration pass is run between prepare and convert.
        model_fp32 = self.M()
        model_fp32.eval()
        model_fp32.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        model_fp32_prepared = torch.ao.quantization.prepare(model_fp32)
        model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
        return model_int8

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.quant = torch.ao.quantization.QuantStub()
            self.input1d = torch.randn(4, 2, 2)
            self.input2d = torch.randn((4, 2, 4, 4))
            self.input3d = torch.randn(4, 2, 2, 4, 4)
            self.linear_input = torch.randn(32, 20)

            self.layer1 = nn.Sequential(
                nn.Conv1d(2, 2, 1), nn.InstanceNorm1d(1), nn.Hardswish()
            )
            self.layer2 = nn.Sequential(
                nn.Conv2d(2, 2, 1),
                nn.BatchNorm2d(2),
                nn.InstanceNorm2d(1),
                nn.LeakyReLU(),
            )
            self.layer3 = nn.Sequential(
                nn.Conv3d(2, 2, 1), nn.BatchNorm3d(2), nn.InstanceNorm3d(1), nn.ReLU()
            )
            self.layer4 = nn.Sequential(nn.Linear(4, 3))
            self.dequant = torch.ao.quantization.DeQuantStub()

        def forward(self):
            x = self.quant(self.input1d)
            x = self.layer1(x)
            x = self.dequant(x)

            y = self.input2d
            y = self.quant(y)
            y = self.layer2(y)
            y = self.layer4(y)
            y = self.dequant(y)

            z = self.quant(self.input3d)
            z = self.layer3(z)
            z = self.dequant(z)

            return (x, y, z)
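
# Calibration sketch (an assumption about intended use, not in the original
# file): static quantization normally runs representative data through the
# prepared model so the observers can pick scales and zero points. getModule()
# above skips that step, so convert() falls back to default qparams (PyTorch
# emits a UserWarning). Since M bakes its inputs in, one forward call suffices:
#
#   m = StaticQuantModule.M().eval()
#   m.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
#   prepared = torch.ao.quantization.prepare(m)
#   prepared()  # calibration pass
#   quantized = torch.ao.quantization.convert(prepared)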


class FusedQuantModule:
    def getModule(self):
        model_fp32 = self.M()
        model_fp32.eval()
        model_fp32.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        # Fuse each conv/linear with its following ReLU so the pair is
        # prepared and converted as a single quantized op.
        model_fp32_fused = torch.ao.quantization.fuse_modules(
            model_fp32,
            [
                ["conv1d", "relu1"],
                ["conv2d", "relu2"],
                ["conv3d", "relu3"],
                ["linear", "relu4"],
            ],
        )
        model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)
        model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
        return model_int8

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.quant = torch.ao.quantization.QuantStub()
            self.input1d = torch.randn(4, 2, 2)
            self.input2d = torch.randn((4, 2, 4, 4))
            self.input3d = torch.randn(4, 2, 2, 4, 4)
            self.conv1d = nn.Conv1d(2, 2, 1)
            self.conv2d = nn.Conv2d(2, 2, 1)
            self.conv3d = nn.Conv3d(2, 2, 1)
            self.linear = nn.Linear(4, 2)
            self.relu1 = nn.ReLU()
            self.relu2 = nn.ReLU()
            self.relu3 = nn.ReLU()
            self.relu4 = nn.ReLU()
            self.dequant = torch.ao.quantization.DeQuantStub()

        def forward(self):
            x = self.input1d
            y = self.input2d
            z = self.input3d

            x = self.quant(x)
            x = self.conv1d(x)
            x = self.relu1(x)
            x = self.dequant(x)

            y = self.quant(y)
            y = self.conv2d(y)
            y = self.relu2(y)
            y = self.dequant(y)

            z = self.quant(z)
            z = self.conv3d(z)
            z = self.relu3(z)
            z = self.linear(z)
            z = self.relu4(z)
            z = self.dequant(z)

            return (x, y, z)
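
# What fuse_modules does here (a hedged sketch): each listed pair, e.g.
# ["conv2d", "relu2"] or ["linear", "relu4"], is replaced by a fused module
# (torch.ao.nn.intrinsic.ConvReLU2d, LinearReLU, ...), with the ReLU slot
# becoming nn.Identity, so prepare/convert quantize the pair as one op:
#
#   fused = torch.ao.quantization.fuse_modules(
#       FusedQuantModule.M().eval(), [["conv2d", "relu2"]]
#   )
#   print(type(fused.conv2d))  # ConvReLU2d; fused.relu2 is now nn.Identity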