import torch
import torch.nn as nn


class GeneralQuantModule(torch.nn.Module):
    # Exercises common quantized ops: QFunctional arithmetic, quantized
    # Embedding, and quantized ConvTranspose1d/2d/3d.
    def __init__(self) -> None:
        super().__init__()
        self.embedding = torch.ao.nn.quantized.Embedding(
            num_embeddings=10, embedding_dim=12
        )
        self.embedding_input = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8])
        self.func = torch.ao.nn.quantized.QFunctional()
        self.conv1 = torch.ao.nn.quantized.ConvTranspose1d(16, 33, 3, stride=2)
        self.conv2 = torch.ao.nn.quantized.ConvTranspose2d(16, 33, 3, stride=2)
        self.conv3 = torch.ao.nn.quantized.ConvTranspose3d(16, 33, 3, stride=2)

    def forward(self):
        a = torch.quantize_per_tensor(torch.tensor([3.0]), 1.0, 0, torch.qint32)
        b = torch.quantize_per_tensor(torch.tensor([4.0]), 1.0, 0, torch.qint32)
        c = torch.quantize_per_tensor(
            torch.tensor([3.0]), torch.tensor(1.0), torch.tensor(0), torch.qint32
        )
        input1 = torch.randn(1, 16, 4)
        input2 = torch.randn(1, 16, 4, 4)
        input3 = torch.randn(1, 16, 4, 4, 4)
        # len() takes a single argument, so the outputs are collected in a tuple.
        return len(
            (
                self.func.add(a, b),
                self.func.cat((a, a), 0),
                self.func.mul(a, b),
                self.func.add_relu(a, b),
                # The *_scalar variants take a Python number, not a quantized tensor.
                self.func.add_scalar(a, 4.0),
                self.func.mul_scalar(a, 4.0),
                self.embedding(self.embedding_input),
                self.conv1(
                    torch.quantize_per_tensor(
                        input1, scale=1.0, zero_point=0, dtype=torch.quint8
                    )
                ),
                self.conv2(
                    torch.quantize_per_tensor(
                        input2, scale=1.0, zero_point=0, dtype=torch.quint8
                    )
                ),
                c,
                # self.conv3(torch.quantize_per_tensor(input3, scale=1.0, zero_point=0, dtype=torch.quint8)),  # failed on iOS
            )
        )


class DynamicQuantModule:
    # Wraps a float model and applies dynamic (weight-only) int8 quantization.
    def __init__(self) -> None:
        super().__init__()
        self.module = self.M()

    def getModule(self):
        return torch.ao.quantization.quantize_dynamic(self.module, dtype=torch.qint8)

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.rnn = nn.RNN(4, 8, 2)
            self.rnncell = nn.RNNCell(4, 8)
            self.gru = nn.GRU(4, 8, 2)
            self.grucell = nn.GRUCell(4, 8)
            self.lstm = nn.LSTM(4, 8, 2)
            self.lstmcell = nn.LSTMCell(4, 8)
            self.linears = nn.ModuleList(
                [
                    nn.Identity(54),
                    nn.Linear(20, 20),
                    nn.Bilinear(20, 20, 40),
                ]
            )
            self.transformers = nn.ModuleList(
                [
                    nn.Transformer(
                        d_model=2, nhead=2, num_encoder_layers=1, num_decoder_layers=1
                    ),
                    nn.TransformerEncoder(
                        nn.TransformerEncoderLayer(d_model=2, nhead=2), num_layers=1
                    ),
                    nn.TransformerDecoder(
                        nn.TransformerDecoderLayer(d_model=2, nhead=2), num_layers=1
                    ),
                ]
            )
            # self.a = torch.nn.utils.rnn.pad_sequence([torch.tensor([1,2,3]), torch.tensor([3,4])], batch_first=True)

        def forward(self):
            input = torch.randn(5, 3, 4)
            h = torch.randn(2, 3, 8)
            c = torch.randn(2, 3, 8)
            linear_input = torch.randn(32, 20)
            trans_input = torch.randn(1, 16, 2)
            tgt = torch.rand(1, 16, 2)

            return len(
                (
                    self.rnn(input, h),
                    self.rnncell(input[0], h[0]),
                    self.gru(input, h),
                    self.grucell(input[0], h[0]),
                    self.lstm(input, (h, c)),
                    # self.lstm(torch.nn.utils.rnn.pack_padded_sequence(self.a, lengths=torch.tensor([3,2,1])), (h, c)),
                    self.lstmcell(input[0], (h[0], c[0])),
                    self.transformers[0](trans_input, tgt),
                    self.transformers[1](trans_input),
                    self.transformers[2](trans_input, tgt),
                    self.linears[0](linear_input),
                    self.linears[1](linear_input),
                    self.linears[2](linear_input, linear_input),
                )
            )
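
# A minimal usage sketch (the helper below is illustrative, not part of the
# test set): DynamicQuantModule.getModule relies on quantize_dynamic's default
# module set, whereas an explicit qconfig_spec pins down exactly which
# submodule types get swapped for dynamic (weight-only) int8 kernels.
def _dynamic_quant_example():
    model = DynamicQuantModule().module
    # Only these float module types are replaced; everything else stays fp32.
    return torch.ao.quantization.quantize_dynamic(
        model, qconfig_spec={nn.Linear, nn.LSTM, nn.GRU}, dtype=torch.qint8
    )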

class StaticQuantModule:
    # Eager-mode post-training static quantization: observe, then convert.
    def getModule(self):
        model_fp32 = self.M()
        model_fp32.eval()
        model_fp32.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        model_fp32_prepared = torch.ao.quantization.prepare(model_fp32)
        model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
        return model_int8

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.quant = torch.ao.quantization.QuantStub()
            self.input1d = torch.randn(4, 2, 2)
            self.input2d = torch.randn((4, 2, 4, 4))
            self.input3d = torch.randn(4, 2, 2, 4, 4)
            self.linear_input = torch.randn(32, 20)

            # num_features matches the 2 channels produced by each conv.
            self.layer1 = nn.Sequential(
                nn.Conv1d(2, 2, 1), nn.InstanceNorm1d(2), nn.Hardswish()
            )
            self.layer2 = nn.Sequential(
                nn.Conv2d(2, 2, 1),
                nn.BatchNorm2d(2),
                nn.InstanceNorm2d(2),
                nn.LeakyReLU(),
            )
            self.layer3 = nn.Sequential(
                nn.Conv3d(2, 2, 1), nn.BatchNorm3d(2), nn.InstanceNorm3d(2), nn.ReLU()
            )
            self.layer4 = nn.Sequential(nn.Linear(4, 3))
            self.dequant = torch.ao.quantization.DeQuantStub()

        def forward(self):
            x = self.quant(self.input1d)
            x = self.layer1(x)
            x = self.dequant(x)

            y = self.input2d
            y = self.quant(y)
            y = self.layer2(y)
            y = self.layer4(y)
            y = self.dequant(y)

            z = self.quant(self.input3d)
            z = self.layer3(z)
            z = self.dequant(z)

            return (x, y, z)


class FusedQuantModule:
    # Same static recipe, but conv/linear + relu pairs are fused first.
    def getModule(self):
        model_fp32 = self.M()
        model_fp32.eval()
        model_fp32.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        model_fp32_fused = torch.ao.quantization.fuse_modules(
            model_fp32,
            [
                ["conv1d", "relu1"],
                ["conv2d", "relu2"],
                ["conv3d", "relu3"],
                ["linear", "relu4"],
            ],
        )
        model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)
        model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
        return model_int8

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.quant = torch.ao.quantization.QuantStub()
            self.input1d = torch.randn(4, 2, 2)
            self.input2d = torch.randn((4, 2, 4, 4))
            self.input3d = torch.randn(4, 2, 2, 4, 4)
            self.conv1d = nn.Conv1d(2, 2, 1)
            self.conv2d = nn.Conv2d(2, 2, 1)
            self.conv3d = nn.Conv3d(2, 2, 1)
            self.linear = nn.Linear(4, 2)
            self.relu1 = nn.ReLU()
            self.relu2 = nn.ReLU()
            self.relu3 = nn.ReLU()
            self.relu4 = nn.ReLU()
            self.dequant = torch.ao.quantization.DeQuantStub()

        def forward(self):
            x = self.input1d
            y = self.input2d
            z = self.input3d

            x = self.quant(x)
            x = self.conv1d(x)
            x = self.relu1(x)
            x = self.dequant(x)

            y = self.quant(y)
            y = self.conv2d(y)
            y = self.relu2(y)
            y = self.dequant(y)

            z = self.quant(z)
            z = self.conv3d(z)
            z = self.relu3(z)
            z = self.linear(z)
            z = self.relu4(z)
            z = self.dequant(z)

            return (x, y, z)
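
# Smoke-test sketch (an illustrative addition, not part of the original
# harness). Assumes a PyTorch build where the "qnnpack" quantized engine is
# available, matching the qconfig used by the static/fused modules above.
if __name__ == "__main__":
    torch.backends.quantized.engine = "qnnpack"
    dynamic_int8 = DynamicQuantModule().getModule()
    static_int8 = StaticQuantModule().getModule()
    fused_int8 = FusedQuantModule().getModule()
    # Each forward() builds its own random inputs and takes no arguments.
    print("dynamic output count:", dynamic_int8())
    print("static output shapes:", [t.shape for t in static_int8()])
    print("fused output shapes:", [t.shape for t in fused_int8()])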