# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
"""Symbolic functions that emit ``_caffe2::Int8*`` nodes for quantized ops.

Every symbolic function below that produces a quantized value adds that value
to ``symbolic_helper._quantized_ops``.  Functions that also have a regular
(non-quantized) meaning check that set first and delegate to the opset 9
implementation when their input is not a tracked quantized value.
"""
import importlib
import inspect

from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration


def register_quantized_ops(domain: str, version: int):
    """Register the functions of this module as symbolic functions.

    Every function found in ``torch.onnx.symbolic_caffe2`` is registered as
    ``{domain}::{function_name}`` for ``version`` unless that name is already
    registered.  Functions whose name is listed in ``aten_q_ops`` are
    additionally registered as a custom override of ``aten::{function_name}``.

    Args:
        domain: Domain prefix used to build the registered op name.
        version: Opset version the functions are registered for.
    """
    # Register all quantized ops
    module = importlib.import_module("torch.onnx.symbolic_caffe2")
    quant_version_ops = inspect.getmembers(module)
    aten_q_ops = {
        "relu",
        "_empty_affine_quantized",
        "dequantize",
        "quantize_per_tensor",
        "upsample_nearest2d",
        "avg_pool2d",
        "reshape",
        "slice",
        "cat",
        "max_pool2d",
        "sigmoid",
    }
    for op, func in quant_version_ops:
        name = f"{domain}::{op}"
        # NOTE: the "already registered" check is made against the domain
        # name only; when it is registered, the aten override is skipped too.
        if inspect.isfunction(func) and not registration.registry.is_registered_op(
            name, version
        ):
            if op in aten_q_ops:
                # Override the builtin aten ops
                registration.registry.register(
                    f"aten::{op}", version, func, custom=True
                )
            registration.registry.register(name, version, func)


def _permute_helper(g: jit_utils.GraphContext, input, axes):
    """Emit ``_caffe2::Int8Transpose`` of ``input`` with the given ``axes``.

    The output reuses the ``Y_scale`` / ``Y_zero_point`` attributes read from
    the node that produced ``input`` and is tracked as a quantized value.
    """
    quant_args = {
        "axes_i": axes,
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Transpose", input, **quant_args)
    symbolic_helper._quantized_ops.add(output)
    return output


def nchw2nhwc(g: jit_utils.GraphContext, input):
    """Transpose ``input`` with axes ``[0, 2, 3, 1]`` (NCHW -> NHWC)."""
    axes = [0, 2, 3, 1]
    return _permute_helper(g, input, axes)


def nhwc2nchw(g: jit_utils.GraphContext, input):
    """Transpose ``input`` with axes ``[0, 3, 1, 2]`` (NHWC -> NCHW)."""
    axes = [0, 3, 1, 2]
    return _permute_helper(g, input, axes)


def linear_prepack(g: jit_utils.GraphContext, weight, bias):
    """Emit a ``_caffe2::WeightPrepack`` node holding ``weight`` and ``bias``."""
    # Mapping to a dummy caffe2 prepack node.
    # During the onnx -> c2 conversion we can look up original weight and bias
    # from this node
    output = g.op("_caffe2::WeightPrepack", weight, bias)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "v", "f", "i")
def linear(g: jit_utils.GraphContext, input, weight, bias, scale, zero_point):
    """Emit ``_caffe2::Int8FC`` with the given output scale and zero point."""
    kwargs = {
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8FC", input, weight, bias, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


def conv_prepack(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    """Emit a ``_caffe2::WeightPrepack`` node for a convolution.

    Only ``input``, ``weight`` and ``bias`` are forwarded to the node;
    ``stride``, ``padding``, ``dilation`` and ``groups`` are accepted but not
    used here.
    """
    # Mapping to a dummy caffe2 prepack node.
    # During the onnx -> c2 conversion we can look up original weight and bias
    # from this node
    output = g.op("_caffe2::WeightPrepack", input, weight, bias)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
def conv2d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    scale,
    zero_point,
):
    """Emit ``_caffe2::Int8Conv`` (``order_s="NHWC"``).

    The kernel size is taken from entries ``[1:3]`` of the ``shape`` attribute
    of the node producing ``weight``.
    """
    # Read the attribute through the shared helper, like the rest of this
    # module, instead of subscripting the node directly.
    kernel_size = symbolic_helper._node_get(weight.node(), "shape")[1:3]
    kwargs = {
        "strides_i": stride,
        # The padding list is repeated so it is passed twice in one attribute.
        "pads_i": padding + padding,
        "dilations_i": dilation,
        "group_i": groups,
        "kernels_i": kernel_size,
        "order_s": "NHWC",
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8Conv", input, weight, bias, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
def conv2d_relu(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    scale,
    zero_point,
):
    """Emit ``_caffe2::Int8ConvRelu`` (``order_s="NHWC"``).

    Builds the same attributes as ``conv2d``; only the emitted op differs.
    """
    # Read the attribute through the shared helper, like the rest of this
    # module, instead of subscripting the node directly.
    kernel_size = symbolic_helper._node_get(weight.node(), "shape")[1:3]
    kwargs = {
        "strides_i": stride,
        # The padding list is repeated so it is passed twice in one attribute.
        "pads_i": padding + padding,
        "dilations_i": dilation,
        "group_i": groups,
        "kernels_i": kernel_size,
        "order_s": "NHWC",
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8ConvRelu", input, weight, bias, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "f", "i")
def add(g: jit_utils.GraphContext, input_a, input_b, scale, zero_point):
    """Emit ``_caffe2::Int8Add`` with the given output scale and zero point."""
    kwargs = {
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8Add", input_a, input_b, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v")
def relu(g: jit_utils.GraphContext, input):
    """Emit ``_caffe2::Int8Relu``, or the opset 9 relu for untracked inputs."""
    if input not in symbolic_helper._quantized_ops:
        return opset9.relu(g, input)
    # The output keeps the quantization parameters of the input's node.
    kwargs = {
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Relu", input, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "f", "i", "t")
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
    """Emit ``_caffe2::Int8Quantize``; ``dtype`` is accepted but not used."""
    kwargs = {
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8Quantize", input, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v")
def dequantize(g: jit_utils.GraphContext, input):
    """Emit ``_caffe2::Int8Dequantize``; the result is not tracked as quantized."""
    return g.op("_caffe2::Int8Dequantize", input)


@symbolic_helper.parse_args("v", "t", "t", "t", "t", "t", "t", "t")
def _empty_affine_quantized(
    g: jit_utils.GraphContext,
    input,
    shape,
    scale,
    zero_point,
    dtype,
    pin_memory,
    memory_format,
    layout,
):
    """Return ``input`` unchanged; no node is emitted and other args are ignored."""
    return input


def upsample_nearest2d(
    g: jit_utils.GraphContext,
    input,
    output_size,
    align_corners=None,
    scales_h=None,
    scales_w=None,
):
    """Emit ``_caffe2::Int8ResizeNearest`` wrapped in layout transposes.

    Untracked inputs are delegated to the opset 9 implementation.
    ``scales_h`` and ``scales_w`` are accepted but not used.
    """
    if input not in symbolic_helper._quantized_ops:
        return opset9.upsample_nearest2d(g, input, output_size, align_corners)  # type: ignore[attr-defined]

    output_size = symbolic_helper._parse_arg(output_size, "is")
    kwargs = {
        "output_size_i": output_size,
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    # Transpose around the op: NCHW -> NHWC before it, NHWC -> NCHW after it.
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8ResizeNearest", input, **kwargs)
    output = nhwc2nchw(g, output)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
def max_pool2d(
    g: jit_utils.GraphContext,
    input,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    """Emit ``_caffe2::Int8MaxPool`` wrapped in layout transposes.

    Untracked inputs are delegated to the opset 9 implementation.  On the
    quantized path ``dilation`` and ``ceil_mode`` are not forwarded.
    """
    if input not in symbolic_helper._quantized_ops:
        return opset9.max_pool2d(  # type: ignore[attr-defined]
            g, input, kernel_size, stride, padding, dilation, ceil_mode
        )
    kwargs = {
        "strides_i": stride,
        "pads_i": padding + padding,
        # Only the first kernel dimension is forwarded.
        # NOTE(review): presumably assumes a square kernel - confirm.
        "kernel_i": kernel_size[0],
        "order_s": "NHWC",
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8MaxPool", input, **kwargs)
    output = nhwc2nchw(g, output)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
def avg_pool2d(
    g: jit_utils.GraphContext,
    input,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override=None,
):
    """Emit ``_caffe2::Int8AveragePool`` wrapped in layout transposes.

    Untracked inputs are delegated to the opset 9 implementation.  On the
    quantized path ``ceil_mode``, ``count_include_pad`` and
    ``divisor_override`` are not forwarded.
    """
    if input not in symbolic_helper._quantized_ops:
        return opset9.avg_pool2d(  # type: ignore[attr-defined]
            g,
            input,
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )
    kwargs = {
        "strides_i": stride,
        "pads_i": padding + padding,
        # Only the first kernel dimension is forwarded.
        # NOTE(review): presumably assumes a square kernel - confirm.
        "kernel_i": kernel_size[0],
        "order_s": "NHWC",
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8AveragePool", input, **kwargs)
    output = nhwc2nchw(g, output)
    symbolic_helper._quantized_ops.add(output)
    return output


def reshape(g: jit_utils.GraphContext, input, shape):
    """Emit ``_caffe2::Int8Reshape``, or the opset 9 reshape for untracked inputs."""
    if input not in symbolic_helper._quantized_ops:
        return opset9.reshape(g, input, shape)

    kwargs = {
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Reshape", input, shape, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "v", "v", "i")
def slice(g: jit_utils.GraphContext, input, dim, start, end, step):
    """Emit ``_caffe2::Int8Slice``, or the opset 9 slice for untracked inputs.

    Raises:
        RuntimeError: If the input is quantized and ``step`` is not 1.
    """
    if input not in symbolic_helper._quantized_ops:
        return opset9.slice(g, input, dim, start, end, step)

    if step != 1:
        raise RuntimeError("ONNX quantized slice export only works for step 1.")
    # dim/start/end arrive as graph values ("v"); they are parsed to ints here
    # because the quantized op takes them as attributes.
    start = symbolic_helper._parse_arg(start, "i")
    end = symbolic_helper._parse_arg(end, "i")
    dim = symbolic_helper._parse_arg(dim, "i")

    kwargs = {
        "start_idx_i": start,
        "end_idx_i": end,
        "dim_i": dim,
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Slice", input, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


def cat(g: jit_utils.GraphContext, tensor_list, dim, scale=None, zero_point=None):
    """Emit ``_caffe2::Int8Concat``, or the opset 9 cat for untracked inputs.

    Whether the quantized path is taken is decided by the first tensor of
    ``tensor_list`` alone, and the output reuses that tensor's ``Y_scale`` /
    ``Y_zero_point``.  ``scale`` and ``zero_point`` are accepted but not used.
    """
    tensors = symbolic_helper._unpack_list(tensor_list)
    input = tensors[0]
    if input not in symbolic_helper._quantized_ops:
        return opset9.cat(g, tensor_list, dim)

    dim = symbolic_helper._parse_arg(dim, "i")
    # Read the attributes through the shared helper, like the rest of this
    # module, instead of subscripting the node directly.
    kwargs = {
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Concat", *tensors, axis_i=dim, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v")
def sigmoid(g: jit_utils.GraphContext, input):
    """Emit ``_caffe2::Int8Sigmoid``, or the opset 9 sigmoid for untracked inputs."""
    if input not in symbolic_helper._quantized_ops:
        return opset9.sigmoid(g, input)
    # Caffe2 expects the output scale to be 1/2^8
    # and output zero_point to be 0 (quint8 type)
    out_scale = 1.0 / 256
    zero_point = 0
    kwargs = {
        "Y_scale_f": out_scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8Sigmoid", input, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output