xref: /aosp_15_r20/external/pytorch/torch/onnx/symbolic_caffe2.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# mypy: disable-error-code=arg-type
3import importlib
4import inspect
5
6from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
7from torch.onnx._internal import jit_utils, registration
8
9
def register_quantized_ops(domain: str, version: int):
    """Register this module's functions as quantized ONNX symbolics.

    Every plain function found on ``torch.onnx.symbolic_caffe2`` is
    registered under ``"{domain}::{name}"`` for ``version``, unless a
    symbolic with that qualified name is already registered for ``version``.
    Functions whose name is in the aten override set below are additionally
    registered as custom ``"aten::{name}"`` symbolics, so the Caffe2 lowering
    takes the place of the built-in aten exporter.

    NOTE(review): ``inspect.getmembers`` yields *every* function on the
    module, so private helpers (e.g. ``_permute_helper``) and this function
    itself are registered under ``domain`` as well.

    Args:
        domain: Namespace prefix for the registered op names.
        version: Opset version the symbolics are registered for.
    """
    caffe2_symbolics = importlib.import_module("torch.onnx.symbolic_caffe2")
    # aten ops whose default exporters are replaced by the versions in this module.
    overridden_aten_ops = frozenset(
        (
            "relu",
            "_empty_affine_quantized",
            "dequantize",
            "quantize_per_tensor",
            "upsample_nearest2d",
            "avg_pool2d",
            "reshape",
            "slice",
            "cat",
            "max_pool2d",
            "sigmoid",
        )
    )
    registry = registration.registry
    for op_name, candidate in inspect.getmembers(caffe2_symbolics):
        qualified_name = f"{domain}::{op_name}"
        if not inspect.isfunction(candidate):
            continue
        if registry.is_registered_op(qualified_name, version):
            continue
        if op_name in overridden_aten_ops:
            registry.register(f"aten::{op_name}", version, candidate, custom=True)
        registry.register(qualified_name, version, candidate)
38
39
def _permute_helper(g: jit_utils.GraphContext, input, axes):
    """Emit a ``_caffe2::Int8Transpose`` of quantized ``input`` by ``axes``.

    The emitted node copies the ``Y_scale`` / ``Y_zero_point`` attributes of
    the node that produced ``input``. The result is recorded in
    ``symbolic_helper._quantized_ops`` so later symbolics treat it as quantized.

    Args:
        g: Graph context the node is emitted into.
        input: Quantized value to transpose.
        axes: Permutation passed as the ``axes`` attribute.

    Returns:
        The transposed value.
    """
    producer = input.node()
    transposed = g.op(
        "_caffe2::Int8Transpose",
        input,
        axes_i=axes,
        Y_scale_f=symbolic_helper._node_get(producer, "Y_scale"),
        Y_zero_point_i=symbolic_helper._node_get(producer, "Y_zero_point"),
    )
    symbolic_helper._quantized_ops.add(transposed)
    return transposed
49
50
def nchw2nhwc(g: jit_utils.GraphContext, input):
    """Transpose a quantized NCHW ``input`` to NHWC via ``Int8Transpose``."""
    # Permutation (0, 2, 3, 1) reorders (N, C, H, W) into (N, H, W, C).
    return _permute_helper(g, input, [0, 2, 3, 1])
54
55
def nhwc2nchw(g: jit_utils.GraphContext, input):
    """Transpose a quantized NHWC ``input`` back to NCHW via ``Int8Transpose``."""
    # Permutation (0, 3, 1, 2) reorders (N, H, W, C) into (N, C, H, W).
    return _permute_helper(g, input, [0, 3, 1, 2])
59
60
def linear_prepack(g: jit_utils.GraphContext, weight, bias):
    """Emit a placeholder ``_caffe2::WeightPrepack`` node for a linear layer.

    The node keeps the original ``weight`` and ``bias`` as inputs, so the
    ONNX -> Caffe2 conversion can look them up from it later. The result is
    recorded in ``symbolic_helper._quantized_ops``.
    """
    packed = g.op("_caffe2::WeightPrepack", weight, bias)
    symbolic_helper._quantized_ops.add(packed)
    return packed
68
69
@symbolic_helper.parse_args("v", "v", "v", "f", "i")
def linear(g: jit_utils.GraphContext, input, weight, bias, scale, zero_point):
    """Lower a quantized linear layer to ``_caffe2::Int8FC``.

    Args:
        g: Graph context the node is emitted into.
        input, weight, bias: Values forwarded as the op's inputs.
        scale: Output quantization scale (``Y_scale``).
        zero_point: Output quantization zero point (``Y_zero_point``).

    Returns:
        The ``Int8FC`` output, recorded in ``symbolic_helper._quantized_ops``.
    """
    fully_connected = g.op(
        "_caffe2::Int8FC",
        input,
        weight,
        bias,
        Y_scale_f=scale,
        Y_zero_point_i=zero_point,
    )
    symbolic_helper._quantized_ops.add(fully_connected)
    return fully_connected
79
80
def conv_prepack(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    """Emit a placeholder ``_caffe2::WeightPrepack`` node for a convolution.

    Only ``input``, ``weight`` and ``bias`` are attached to the node. The
    ONNX -> Caffe2 conversion looks the original weight and bias up from it.
    ``stride``, ``padding``, ``dilation`` and ``groups`` are accepted but not
    used here. The result is recorded in ``symbolic_helper._quantized_ops``.
    """
    packed = g.op("_caffe2::WeightPrepack", input, weight, bias)
    symbolic_helper._quantized_ops.add(packed)
    return packed
90
91
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
def conv2d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    scale,
    zero_point,
):
    """Lower a quantized 2-D convolution to ``_caffe2::Int8Conv`` in NHWC order.

    Args:
        g: Graph context the node is emitted into.
        input, weight, bias: Values forwarded as the op's inputs.
        stride, padding, dilation: Per-spatial-dim integer lists.
        groups: Number of convolution groups.
        scale, zero_point: Output quantization parameters.

    Returns:
        The ``Int8Conv`` output, recorded in ``symbolic_helper._quantized_ops``.
    """
    # Kernel extent = entries 1:3 of the weight producer's "shape" attribute
    # (assumed to be an [M, kH, kW, C]-style layout — TODO confirm).
    kernel_hw = weight.node()["shape"][1:3]
    # Symmetric padding: the per-dim pads are repeated for begin and end.
    symmetric_pads = padding + padding
    conv = g.op(
        "_caffe2::Int8Conv",
        input,
        weight,
        bias,
        strides_i=stride,
        pads_i=symmetric_pads,
        dilations_i=dilation,
        group_i=groups,
        kernels_i=kernel_hw,
        order_s="NHWC",
        Y_scale_f=scale,
        Y_zero_point_i=zero_point,
    )
    symbolic_helper._quantized_ops.add(conv)
    return conv
119
120
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
def conv2d_relu(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    scale,
    zero_point,
):
    """Lower a fused quantized conv + ReLU to ``_caffe2::Int8ConvRelu`` (NHWC).

    Takes the same arguments as ``conv2d`` and emits the same attributes;
    only the op type differs.

    Args:
        g: Graph context the node is emitted into.
        input, weight, bias: Values forwarded as the op's inputs.
        stride, padding, dilation: Per-spatial-dim integer lists.
        groups: Number of convolution groups.
        scale, zero_point: Output quantization parameters.

    Returns:
        The ``Int8ConvRelu`` output, recorded in
        ``symbolic_helper._quantized_ops``.
    """
    # Kernel extent = entries 1:3 of the weight producer's "shape" attribute
    # (assumed to be an [M, kH, kW, C]-style layout — TODO confirm).
    kernel_hw = weight.node()["shape"][1:3]
    # Symmetric padding: the per-dim pads are repeated for begin and end.
    symmetric_pads = padding + padding
    fused = g.op(
        "_caffe2::Int8ConvRelu",
        input,
        weight,
        bias,
        strides_i=stride,
        pads_i=symmetric_pads,
        dilations_i=dilation,
        group_i=groups,
        kernels_i=kernel_hw,
        order_s="NHWC",
        Y_scale_f=scale,
        Y_zero_point_i=zero_point,
    )
    symbolic_helper._quantized_ops.add(fused)
    return fused
148
149
@symbolic_helper.parse_args("v", "v", "f", "i")
def add(g: jit_utils.GraphContext, input_a, input_b, scale, zero_point):
    """Lower a quantized element-wise add to ``_caffe2::Int8Add``.

    Args:
        g: Graph context the node is emitted into.
        input_a, input_b: The two operands.
        scale, zero_point: Output quantization parameters.

    Returns:
        The ``Int8Add`` output, recorded in ``symbolic_helper._quantized_ops``.
    """
    total = g.op(
        "_caffe2::Int8Add",
        input_a,
        input_b,
        Y_scale_f=scale,
        Y_zero_point_i=zero_point,
    )
    symbolic_helper._quantized_ops.add(total)
    return total
159
160
@symbolic_helper.parse_args("v")
def relu(g: jit_utils.GraphContext, input):
    """Export ReLU, lowering quantized inputs to ``_caffe2::Int8Relu``.

    If ``input`` is recorded as quantized, the emitted ``Int8Relu`` copies the
    ``Y_scale`` / ``Y_zero_point`` of the node that produced ``input``, and
    its output is recorded as quantized too. Otherwise the opset 9 ``relu``
    exporter is used.
    """
    if input in symbolic_helper._quantized_ops:
        producer = input.node()
        activated = g.op(
            "_caffe2::Int8Relu",
            input,
            Y_scale_f=symbolic_helper._node_get(producer, "Y_scale"),
            Y_zero_point_i=symbolic_helper._node_get(producer, "Y_zero_point"),
        )
        symbolic_helper._quantized_ops.add(activated)
        return activated
    return opset9.relu(g, input)
172
173
@symbolic_helper.parse_args("v", "f", "i", "t")
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
    """Lower per-tensor quantization to ``_caffe2::Int8Quantize``.

    ``scale`` and ``zero_point`` become the output's ``Y_scale`` /
    ``Y_zero_point``. ``dtype`` is accepted but not used. The output is
    recorded in ``symbolic_helper._quantized_ops``.
    """
    quantized = g.op(
        "_caffe2::Int8Quantize",
        input,
        Y_scale_f=scale,
        Y_zero_point_i=zero_point,
    )
    symbolic_helper._quantized_ops.add(quantized)
    return quantized
183
184
@symbolic_helper.parse_args("v")
def dequantize(g: jit_utils.GraphContext, input):
    """Lower dequantization to ``_caffe2::Int8Dequantize``.

    Unlike the other symbolics in this module, the result is not added to
    ``symbolic_helper._quantized_ops``. Later symbolics therefore take their
    non-quantized paths for it.
    """
    dequantized = g.op("_caffe2::Int8Dequantize", input)
    return dequantized
188
189
@symbolic_helper.parse_args("v", "t", "t", "t", "t", "t", "t", "t")
def _empty_affine_quantized(
    g: jit_utils.GraphContext,
    input,
    shape,
    scale,
    zero_point,
    dtype,
    pin_memory,
    memory_format,
    layout,
):
    """Export ``_empty_affine_quantized`` as a pass-through.

    No graph node is emitted. ``input`` is returned unchanged, and every
    other argument is ignored.
    """
    passthrough = input
    return passthrough
203
204
def upsample_nearest2d(
    g: jit_utils.GraphContext,
    input,
    output_size,
    align_corners=None,
    scales_h=None,
    scales_w=None,
):
    """Export nearest-neighbour 2-D upsampling.

    Non-quantized inputs go to the opset 9 exporter.

    NOTE(review): the fallback forwards only ``align_corners`` after
    ``output_size``; ``scales_h`` / ``scales_w`` are dropped.

    Quantized inputs are handled in three steps. The tensor is transposed
    NCHW -> NHWC, resized with ``_caffe2::Int8ResizeNearest`` to the constant
    ``output_size``, and transposed back. The resize copies the input
    producer's ``Y_scale`` / ``Y_zero_point``. The final NCHW value is recorded
    as quantized.
    """
    if input not in symbolic_helper._quantized_ops:
        return opset9.upsample_nearest2d(g, input, output_size, align_corners)  # type: ignore[attr-defined]

    producer = input.node()
    resize_attrs = {
        "output_size_i": symbolic_helper._parse_arg(output_size, "is"),
        "Y_scale_f": symbolic_helper._node_get(producer, "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(producer, "Y_zero_point"),
    }
    nhwc_input = nchw2nhwc(g, input)
    resized = g.op("_caffe2::Int8ResizeNearest", nhwc_input, **resize_attrs)
    result = nhwc2nchw(g, resized)
    symbolic_helper._quantized_ops.add(result)
    return result
227
228
@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
def max_pool2d(
    g: jit_utils.GraphContext,
    input,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    """Export 2-D max pooling, lowering quantized inputs to ``Int8MaxPool``.

    Non-quantized inputs are delegated unchanged to the opset 9 exporter.

    For quantized inputs, the tensor is transposed NCHW -> NHWC, pooled with
    ``_caffe2::Int8MaxPool`` and transposed back. The pool node copies the
    input producer's ``Y_scale`` / ``Y_zero_point``, and the final value is
    recorded as quantized.

    NOTE(review): the quantized path forwards only ``kernel_size[0]``, so a
    non-square kernel is exported as square. It also ignores ``dilation``
    and ``ceil_mode``.
    """
    if input not in symbolic_helper._quantized_ops:
        return opset9.max_pool2d(  # type: ignore[attr-defined]
            g, input, kernel_size, stride, padding, dilation, ceil_mode
        )
    # aten encodes an omitted stride as an empty list, meaning "stride equals
    # kernel_size". The opset 9 exporter applies the same rule. Forwarding
    # the empty list as ``strides`` would not encode that default.
    if not stride:
        stride = kernel_size
    kwargs = {
        "strides_i": stride,
        # Symmetric padding: per-dim pads repeated for begin and end.
        "pads_i": padding + padding,
        "kernel_i": kernel_size[0],
        "order_s": "NHWC",
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8MaxPool", input, **kwargs)
    output = nhwc2nchw(g, output)
    symbolic_helper._quantized_ops.add(output)
    return output
256
257
@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
def avg_pool2d(
    g: jit_utils.GraphContext,
    input,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override=None,
):
    """Export 2-D average pooling, lowering quantized inputs to ``Int8AveragePool``.

    Non-quantized inputs are delegated unchanged to the opset 9 exporter.

    For quantized inputs, the tensor is transposed NCHW -> NHWC, pooled with
    ``_caffe2::Int8AveragePool`` and transposed back. The pool node copies the
    input producer's ``Y_scale`` / ``Y_zero_point``, and the final value is
    recorded as quantized.

    NOTE(review): the quantized path forwards only ``kernel_size[0]``, so a
    non-square kernel is exported as square. It also ignores ``ceil_mode``,
    ``count_include_pad`` and ``divisor_override``.
    """
    if input not in symbolic_helper._quantized_ops:
        return opset9.avg_pool2d(  # type: ignore[attr-defined]
            g,
            input,
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )
    # aten encodes an omitted stride as an empty list, meaning "stride equals
    # kernel_size". The opset 9 exporter applies the same rule. Forwarding
    # the empty list as ``strides`` would not encode that default.
    if not stride:
        stride = kernel_size
    kwargs = {
        "strides_i": stride,
        # Symmetric padding: per-dim pads repeated for begin and end.
        "pads_i": padding + padding,
        "kernel_i": kernel_size[0],
        "order_s": "NHWC",
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8AveragePool", input, **kwargs)
    output = nhwc2nchw(g, output)
    symbolic_helper._quantized_ops.add(output)
    return output
293
294
def reshape(g: jit_utils.GraphContext, input, shape):
    """Export reshape, lowering quantized inputs to ``_caffe2::Int8Reshape``.

    For a quantized ``input``, ``shape`` is passed as the op's second input,
    and the input producer's ``Y_scale`` / ``Y_zero_point`` are copied onto
    the new node, whose output is recorded as quantized. Other inputs go to
    the opset 9 ``reshape`` exporter.
    """
    if input in symbolic_helper._quantized_ops:
        producer = input.node()
        reshaped = g.op(
            "_caffe2::Int8Reshape",
            input,
            shape,
            Y_scale_f=symbolic_helper._node_get(producer, "Y_scale"),
            Y_zero_point_i=symbolic_helper._node_get(producer, "Y_zero_point"),
        )
        symbolic_helper._quantized_ops.add(reshaped)
        return reshaped
    return opset9.reshape(g, input, shape)
306
307
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
def slice(g: jit_utils.GraphContext, input, dim, start, end, step):
    """Export slicing, lowering quantized inputs to ``_caffe2::Int8Slice``.

    Non-quantized inputs go to the opset 9 ``slice`` exporter with the raw
    ``dim`` / ``start`` / ``end`` values.

    For quantized inputs, ``dim``, ``start`` and ``end`` are parsed as
    constant ints and emitted as attributes, together with the input
    producer's ``Y_scale`` / ``Y_zero_point``. The output is recorded as
    quantized.

    Raises:
        RuntimeError: if ``input`` is quantized and ``step`` is not 1.
    """
    if input not in symbolic_helper._quantized_ops:
        return opset9.slice(g, input, dim, start, end, step)

    if step != 1:
        raise RuntimeError("ONNX quantized slice export only works for step 1.")

    producer = input.node()
    sliced = g.op(
        "_caffe2::Int8Slice",
        input,
        start_idx_i=symbolic_helper._parse_arg(start, "i"),
        end_idx_i=symbolic_helper._parse_arg(end, "i"),
        dim_i=symbolic_helper._parse_arg(dim, "i"),
        Y_scale_f=symbolic_helper._node_get(producer, "Y_scale"),
        Y_zero_point_i=symbolic_helper._node_get(producer, "Y_zero_point"),
    )
    symbolic_helper._quantized_ops.add(sliced)
    return sliced
329
330
def cat(g: jit_utils.GraphContext, tensor_list, dim, scale=None, zero_point=None):
    """Export concatenation, lowering quantized inputs to ``_caffe2::Int8Concat``.

    Whether the quantized path is taken is decided by the first tensor in
    the list only.

    Args:
        g: Graph context the node is emitted into.
        tensor_list: Value produced by a ``prim::ListConstruct`` of tensors.
        dim: Concatenation axis (parsed as a constant int on the quantized path).
        scale, zero_point: Not used. The output reuses the first input's
            ``Y_scale`` / ``Y_zero_point``. (Presumably present to match a
            quantized ``cat`` schema — TODO confirm.)

    Returns:
        The concatenated value. On the quantized path it is recorded in
        ``symbolic_helper._quantized_ops``.
    """
    tensors = symbolic_helper._unpack_list(tensor_list)
    first = tensors[0]
    if first not in symbolic_helper._quantized_ops:
        return opset9.cat(g, tensor_list, dim)

    axis = symbolic_helper._parse_arg(dim, "i")
    # Read attributes through the shared helper, as every other symbolic in
    # this module does, rather than relying on a patched ``Node.__getitem__``.
    first_node = first.node()
    output = g.op(
        "_caffe2::Int8Concat",
        *tensors,
        axis_i=axis,
        Y_scale_f=symbolic_helper._node_get(first_node, "Y_scale"),
        Y_zero_point_i=symbolic_helper._node_get(first_node, "Y_zero_point"),
    )
    symbolic_helper._quantized_ops.add(output)
    return output
345
346
@symbolic_helper.parse_args("v")
def sigmoid(g: jit_utils.GraphContext, input):
    """Export sigmoid, lowering quantized inputs to ``_caffe2::Int8Sigmoid``.

    Quantized inputs get a fixed output quantization: scale 1/256 and zero
    point 0. The output is recorded as quantized. Other inputs go to the
    opset 9 ``sigmoid`` exporter.
    """
    if input in symbolic_helper._quantized_ops:
        # Caffe2 expects the output scale to be 1/2^8 and the output
        # zero_point to be 0 (quint8 type).
        activated = g.op(
            "_caffe2::Int8Sigmoid",
            input,
            Y_scale_f=1.0 / 256,
            Y_zero_point_i=0,
        )
        symbolic_helper._quantized_ops.add(activated)
        return activated
    return opset9.sigmoid(g, input)
362