# Owner(s): ["oncall: quantization"]

from .common import AOMigrationTestCase

# Names checked by both the ``nn.intrinsic`` package test and the
# ``nn.intrinsic.modules.fused`` submodule test. Both tests check exactly
# these names, in this order. Callers pass ``list(...)`` so every call
# receives its own fresh list.
_INTRINSIC_FUSED_NAMES = (
    "_FusedModule",
    "ConvBn1d",
    "ConvBn2d",
    "ConvBn3d",
    "ConvBnReLU1d",
    "ConvBnReLU2d",
    "ConvBnReLU3d",
    "ConvReLU1d",
    "ConvReLU2d",
    "ConvReLU3d",
    "LinearReLU",
    "BNReLU2d",
    "BNReLU3d",
    "LinearBn1d",
)


class TestAOMigrationNNQuantized(AOMigrationTestCase):
    """Migration tests for ``torch.nn.quantized``, ``torch.nn.quantizable``
    and ``torch.nn.qat``.

    Each test passes three things to ``_test_function_import``, which is
    provided by ``AOMigrationTestCase`` and performs the actual check:
    a package name, the names expected in it, and a dotted ``base`` prefix.
    """

    def test_functional_import(self):
        """Tests the migration of torch.nn.quantized.functional."""
        names = [
            "avg_pool2d",
            "avg_pool3d",
            "adaptive_avg_pool2d",
            "adaptive_avg_pool3d",
            "conv1d",
            "conv2d",
            "conv3d",
            "interpolate",
            "linear",
            "max_pool1d",
            "max_pool2d",
            "celu",
            "leaky_relu",
            "hardtanh",
            "hardswish",
            "threshold",
            "elu",
            "hardsigmoid",
            "clamp",
            "upsample",
            "upsample_bilinear",
            "upsample_nearest",
        ]
        self._test_function_import("functional", names, base="nn.quantized")

    def test_modules_import(self):
        """Tests the migration of torch.nn.quantized.modules."""
        names = [
            # Modules
            "BatchNorm2d",
            "BatchNorm3d",
            "Conv1d",
            "Conv2d",
            "Conv3d",
            "ConvTranspose1d",
            "ConvTranspose2d",
            "ConvTranspose3d",
            "DeQuantize",
            "ELU",
            "Embedding",
            "EmbeddingBag",
            "GroupNorm",
            "Hardswish",
            "InstanceNorm1d",
            "InstanceNorm2d",
            "InstanceNorm3d",
            "LayerNorm",
            "LeakyReLU",
            "Linear",
            "MaxPool2d",
            "Quantize",
            "ReLU6",
            "Sigmoid",
            "Softmax",
            "Dropout",
            # Wrapper modules
            "FloatFunctional",
            "FXFloatFunctional",
            "QFunctional",
        ]
        self._test_function_import("modules", names, base="nn.quantized")

    def test_modules_activation(self):
        """Tests the migration of torch.nn.quantized.modules.activation."""
        names = [
            "ReLU6",
            "Hardswish",
            "ELU",
            "LeakyReLU",
            "Sigmoid",
            "Softmax",
        ]
        self._test_function_import("activation", names, base="nn.quantized.modules")

    def test_modules_batchnorm(self):
        """Tests the migration of torch.nn.quantized.modules.batchnorm."""
        names = ["BatchNorm2d", "BatchNorm3d"]
        self._test_function_import("batchnorm", names, base="nn.quantized.modules")

    def test_modules_conv(self):
        """Tests the migration of torch.nn.quantized.modules.conv."""
        names = [
            "_reverse_repeat_padding",
            "Conv1d",
            "Conv2d",
            "Conv3d",
            "ConvTranspose1d",
            "ConvTranspose2d",
            "ConvTranspose3d",
        ]
        self._test_function_import("conv", names, base="nn.quantized.modules")

    def test_modules_dropout(self):
        """Tests the migration of torch.nn.quantized.modules.dropout."""
        names = ["Dropout"]
        self._test_function_import("dropout", names, base="nn.quantized.modules")

    def test_modules_embedding_ops(self):
        """Tests the migration of torch.nn.quantized.modules.embedding_ops."""
        names = ["EmbeddingPackedParams", "Embedding", "EmbeddingBag"]
        self._test_function_import(
            "embedding_ops", names, base="nn.quantized.modules"
        )

    def test_modules_functional_modules(self):
        """Tests the migration of torch.nn.quantized.modules.functional_modules."""
        names = ["FloatFunctional", "FXFloatFunctional", "QFunctional"]
        self._test_function_import(
            "functional_modules", names, base="nn.quantized.modules"
        )

    def test_modules_linear(self):
        """Tests the migration of torch.nn.quantized.modules.linear."""
        names = ["Linear", "LinearPackedParams"]
        self._test_function_import("linear", names, base="nn.quantized.modules")

    def test_modules_normalization(self):
        """Tests the migration of torch.nn.quantized.modules.normalization."""
        names = [
            "LayerNorm",
            "GroupNorm",
            "InstanceNorm1d",
            "InstanceNorm2d",
            "InstanceNorm3d",
        ]
        self._test_function_import(
            "normalization", names, base="nn.quantized.modules"
        )

    def test_modules_utils(self):
        """Tests the migration of torch.nn.quantized.modules.utils."""
        names = [
            "_ntuple_from_first",
            "_pair_from_first",
            "_quantize_weight",
            "_hide_packed_params_repr",
            "WeightedQuantizedModule",
        ]
        self._test_function_import("utils", names, base="nn.quantized.modules")

    def test_import_nn_quantized_dynamic_import(self):
        """Tests the migration of torch.nn.quantized.dynamic."""
        names = [
            # Modules
            "Linear",
            "LSTM",
            "GRU",
            "LSTMCell",
            "RNNCell",
            "GRUCell",
            "Conv1d",
            "Conv2d",
            "Conv3d",
            "ConvTranspose1d",
            "ConvTranspose2d",
            "ConvTranspose3d",
        ]
        self._test_function_import("dynamic", names, base="nn.quantized")

    def test_import_nn_quantizable_activation(self):
        """Tests the migration of torch.nn.quantizable.modules.activation."""
        names = ["MultiheadAttention"]  # Modules
        self._test_function_import(
            "activation", names, base="nn.quantizable.modules"
        )

    def test_import_nn_quantizable_rnn(self):
        """Tests the migration of torch.nn.quantizable.modules.rnn."""
        names = ["LSTM", "LSTMCell"]  # Modules
        self._test_function_import("rnn", names, base="nn.quantizable.modules")

    def test_import_nn_qat_conv(self):
        """Tests the migration of torch.nn.qat.modules.conv."""
        names = ["Conv1d", "Conv2d", "Conv3d"]
        self._test_function_import("conv", names, base="nn.qat.modules")

    def test_import_nn_qat_embedding_ops(self):
        """Tests the migration of torch.nn.qat.modules.embedding_ops."""
        names = ["Embedding", "EmbeddingBag"]
        self._test_function_import("embedding_ops", names, base="nn.qat.modules")

    def test_import_nn_qat_linear(self):
        """Tests the migration of torch.nn.qat.modules.linear."""
        names = ["Linear"]
        self._test_function_import("linear", names, base="nn.qat.modules")

    def test_import_nn_qat_dynamic_linear(self):
        """Tests the migration of torch.nn.qat.dynamic.modules.linear."""
        names = ["Linear"]
        self._test_function_import("linear", names, base="nn.qat.dynamic.modules")


class TestAOMigrationNNIntrinsic(AOMigrationTestCase):
    """Migration tests for ``torch.nn.intrinsic`` and its ``qat`` and
    ``quantized`` subpackages.

    Each test passes a package name, the names expected in it, and a dotted
    ``base`` prefix to ``_test_function_import``, which is provided by
    ``AOMigrationTestCase``.
    """

    def test_modules_import_nn_intrinsic(self):
        """Tests the migration of torch.nn.intrinsic."""
        self._test_function_import(
            "intrinsic", list(_INTRINSIC_FUSED_NAMES), base="nn"
        )

    def test_modules_nn_intrinsic_fused(self):
        """Tests the migration of torch.nn.intrinsic.modules.fused."""
        self._test_function_import(
            "fused", list(_INTRINSIC_FUSED_NAMES), base="nn.intrinsic.modules"
        )

    def test_modules_import_nn_intrinsic_qat(self):
        """Tests the migration of torch.nn.intrinsic.qat."""
        names = [
            "LinearReLU",
            "LinearBn1d",
            "ConvReLU1d",
            "ConvReLU2d",
            "ConvReLU3d",
            "ConvBn1d",
            "ConvBn2d",
            "ConvBn3d",
            "ConvBnReLU1d",
            "ConvBnReLU2d",
            "ConvBnReLU3d",
            "update_bn_stats",
            "freeze_bn_stats",
        ]
        self._test_function_import("qat", names, base="nn.intrinsic")

    def test_modules_intrinsic_qat_conv_fused(self):
        """Tests the migration of torch.nn.intrinsic.qat.modules.conv_fused."""
        names = [
            "ConvBn1d",
            "ConvBnReLU1d",
            "ConvReLU1d",
            "ConvBn2d",
            "ConvBnReLU2d",
            "ConvReLU2d",
            "ConvBn3d",
            "ConvBnReLU3d",
            "ConvReLU3d",
            "update_bn_stats",
            "freeze_bn_stats",
        ]
        self._test_function_import(
            "conv_fused", names, base="nn.intrinsic.qat.modules"
        )

    def test_modules_intrinsic_qat_linear_fused(self):
        """Tests the migration of torch.nn.intrinsic.qat.modules.linear_fused."""
        names = ["LinearBn1d"]
        self._test_function_import(
            "linear_fused", names, base="nn.intrinsic.qat.modules"
        )

    def test_modules_intrinsic_qat_linear_relu(self):
        """Tests the migration of torch.nn.intrinsic.qat.modules.linear_relu."""
        names = ["LinearReLU"]
        self._test_function_import(
            "linear_relu", names, base="nn.intrinsic.qat.modules"
        )

    def test_modules_import_nn_intrinsic_quantized(self):
        """Tests the migration of torch.nn.intrinsic.quantized."""
        names = [
            "BNReLU2d",
            "BNReLU3d",
            "ConvReLU1d",
            "ConvReLU2d",
            "ConvReLU3d",
            "LinearReLU",
        ]
        self._test_function_import("quantized", names, base="nn.intrinsic")

    def test_modules_intrinsic_quantized_bn_relu(self):
        """Tests the migration of torch.nn.intrinsic.quantized.modules.bn_relu."""
        names = ["BNReLU2d", "BNReLU3d"]
        self._test_function_import(
            "bn_relu", names, base="nn.intrinsic.quantized.modules"
        )

    def test_modules_intrinsic_quantized_conv_relu(self):
        """Tests the migration of torch.nn.intrinsic.quantized.modules.conv_relu."""
        names = ["ConvReLU1d", "ConvReLU2d", "ConvReLU3d"]
        self._test_function_import(
            "conv_relu", names, base="nn.intrinsic.quantized.modules"
        )

    def test_modules_intrinsic_quantized_linear_relu(self):
        """Tests the migration of torch.nn.intrinsic.quantized.modules.linear_relu."""
        names = ["LinearReLU"]
        self._test_function_import(
            "linear_relu", names, base="nn.intrinsic.quantized.modules"
        )

    def test_modules_no_import_nn_intrinsic_quantized_dynamic(self):
        """Tests that ``quantized.dynamic`` is reachable by attribute access
        under both ``torch.ao.nn.intrinsic`` and ``torch.nn.intrinsic`` after
        only a plain ``import torch``.

        No assertion is made. The test passes as long as the import and both
        attribute lookups complete without raising.
        """
        # TODO(future PR): generalize this
        import torch

        # Evaluated only for the attribute lookups themselves.
        _ = torch.ao.nn.intrinsic.quantized.dynamic
        _ = torch.nn.intrinsic.quantized.dynamic