# mypy: allow-untyped-defs
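"""Reference representation rewrite for PT2E quantized graphs.

This module pairs each supported quantize/dequantize (q/dq) pattern (quantized
linear, dynamic quantized linear, conv2d, add, add + relu, max_pool2d, and
per-tensor / per-channel quantize and dequantize) with an equivalent integer
"reference" implementation, and exposes ``reference_representation_rewrite``,
which replaces the former with the latter in a ``GraphModule`` via FX subgraph
rewriting.
"""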
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional, Tuple

import torch
from torch._higher_order_ops.out_dtype import out_dtype
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
from torch.ao.quantization.pt2e.export_utils import _WrapperModule
from torch.ao.quantization.pt2e.utils import (
    _get_aten_graph_module_for_pattern,
    _replace_literals_with_existing_placeholders,
    _replace_literals_with_new_placeholders,
    remove_tensor_overload_for_qdq_ops,
)
from torch.fx import GraphModule
from torch.fx.subgraph_rewriter import replace_pattern


__all__ = [
    "reference_representation_rewrite",
]


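# Each *_EXAMPLE_INPUTS tuple below mirrors the signature of its pattern/replacement
# functions; the tuples are only used to export those functions into aten graphs
# for pattern matching (see _RewriteInfo and reference_representation_rewrite below).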
_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (2, 5), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-127], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randn(1, dtype=torch.float),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)


def _qdq_quantized_linear(
    x_i8,
    x_scale,
    x_zero_point,
    x_quant_min,
    x_quant_max,
    weight_i8,
    weight_scale,
    weight_zero_point,
    weight_quant_min,
    weight_quant_max,
    bias_fp32,
    out_scale,
    out_zero_point,
    out_quant_min,
    out_quant_max,
):
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
    )
    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        weight_i8,
        weight_scale,
        weight_zero_point,
        weight_quant_min,
        weight_quant_max,
        torch.int8,
    )
    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8
    )
    return out_i8


def _reference_quantized_linear(
    x_i8,
    x_scale,
    x_zero_point,
    x_quant_min,
    x_quant_max,
    weight_i8,
    weight_scale,
    weight_zero_point,
    weight_quant_min,
    weight_quant_max,
    bias_fp32,
    out_scale,
    out_zero_point,
    out_quant_min,
    out_quant_max,
):
    # If quant_min/quant_max are not used in the clamp, the traced graph will not
    # have the quant_min/quant_max args and the pattern will fail to match.
    # Therefore, we call torch.ops.aten.clamp here explicitly.
    x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)

    x_i16 = x_i8.to(torch.int16)
    weight_i16 = weight_i8.to(torch.int16)
    # bias is always set to None here so that the same representation works
    # whether or not bias_scale == x_scale * weight_scale
    acc_i32 = out_dtype(
        torch.ops.aten.linear.default,
        torch.int32,
        x_i16 - x_zero_point,
        weight_i16 - weight_zero_point,
        None,
    )
    # TODO: change to mul.Scalar
    # Note: we quantize the bias with these scales without input from the user; this should be OK
    bias_scale = x_scale * weight_scale
    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
    acc_i32 = acc_i32 + bias_i32
    # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
    acc_i32 = (
        out_dtype(
            torch.ops.aten.mul.Tensor,
            torch.int32,
            acc_i32,
            x_scale * weight_scale / out_scale,
        )
        + out_zero_point
    )
    out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
    return out_i8
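
# Illustrative note (a sketch, not executed anywhere in this module): a q/dq pattern
# and its reference counterpart implement the same computation, so they can be compared
# numerically by calling both on the example inputs, e.g.
#   qdq_out = _qdq_quantized_linear(*_QUANTIZED_LINEAR_EXAMPLE_INPUTS)
#   ref_out = _reference_quantized_linear(*_QUANTIZED_LINEAR_EXAMPLE_INPUTS)
# Any differences come from the different rounding/truncation behavior of the two
# implementations.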


_DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS = (
    torch.randn((2, 5), dtype=torch.float),
    -128,
    127,
    torch.finfo(torch.float32).eps,
    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-127], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randn(1, dtype=torch.float),
)


def _qdq_dynamic_quantized_linear(
    x_fp32,
    x_quant_min,
    x_quant_max,
    x_eps,
    weight_i8,
    weight_scale,
    weight_zero_point,
    weight_quant_min,
    weight_quant_max,
    bias_fp32,
):
    x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(
        x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8
    )
    x_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        x_fp32, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
    )
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
    )
    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        weight_i8,
        weight_scale,
        weight_zero_point,
        weight_quant_min,
        weight_quant_max,
        torch.int8,
    )
    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
    return out_fp32


def _reference_dynamic_quantized_linear(
    x_fp32,
    x_quant_min,
    x_quant_max,
    x_eps,
    weight_i8,
    weight_scale,
    weight_zero_point,
    weight_quant_min,
    weight_quant_max,
    bias_fp32,
):
    x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(
        x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8
    )
    # decomposed representation for quantize_per_tensor
    # TODO: use out_dtype(mul, ...) here when the op is ready
    x_fp32 = x_fp32 / x_scale  # fp32
    # rounding modes might differ here;
    # PyTorch rounds to even, which is also common for most backends
    x_fp32 = torch.round(x_fp32)  # fp32
    x_i32 = x_fp32.to(dtype=torch.int32)  # int32
    x_i32 = x_i32 + x_zero_point  # int32
    # clamp works for fp32, int32 and int8 dtypes
    x_i32 = torch.clamp(x_i32, x_quant_min, x_quant_max)  # int32
    x_i8 = x_i32.to(dtype=torch.int8)

    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)

    x_i16 = x_i8.to(torch.int16)
    weight_i16 = weight_i8.to(torch.int16)
    # bias is always set to None here so that the same representation works
    # whether or not bias_scale == x_scale * weight_scale
    acc_i32 = out_dtype(
        torch.ops.aten.linear.default,
        torch.int32,
        x_i16 - x_zero_point,
        weight_i16 - weight_zero_point,
        None,
    )
    bias_scale = x_scale * weight_scale
    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
    acc_i32 = acc_i32 + bias_i32
    # dequantize the accumulator back to fp32 (dynamically quantized linear outputs fp32)
    out_fp32 = acc_i32 * (x_scale * weight_scale)
    return out_fp32


_QUANTIZED_CONV2d_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-127], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randn(1, dtype=torch.float),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)


def _qdq_quantized_conv2d(
    x_i8,
    x_scale,
    x_zero_point,
    x_quant_min,
    x_quant_max,
    weight_i8,
    weight_scale,
    weight_zero_point,
    weight_quant_min,
    weight_quant_max,
    bias_fp32,
    out_scale,
    out_zero_point,
    out_quant_min,
    out_quant_max,
):
    stride = [1, 1]
    padding = [0, 0]
    dilation = [1, 1]
    transposed = False
    output_padding = [0, 0]
    groups = 1
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
    )
    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        weight_i8,
        weight_scale,
        weight_zero_point,
        weight_quant_min,
        weight_quant_max,
        torch.int8,
    )
    out_fp32 = torch.ops.aten.convolution.default(
        x_fp32,
        weight_fp32,
        bias_fp32,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
    )
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8
    )
    return out_i8


def _reference_quantized_conv2d(
    x_i8,
    x_scale,
    x_zero_point,
    x_quant_min,
    x_quant_max,
    weight_i8,
    weight_scale,
    weight_zero_point,
    weight_quant_min,
    weight_quant_max,
    bias_fp32,
    out_scale,
    out_zero_point,
    out_quant_min,
    out_quant_max,
):
    stride = [1, 1]
    padding = [0, 0]
    dilation = [1, 1]
    transposed = False
    output_padding = [0, 0]
    groups = 1
    # If quant_min/quant_max are not used in the clamp, the traced graph will not
    # have the quant_min/quant_max args and the pattern will fail to match.
    # Therefore, we call torch.ops.aten.clamp here explicitly.
    x_i8 = torch.ops.aten.clamp(x_i8, x_quant_min, x_quant_max)
    weight_i8 = torch.ops.aten.clamp(weight_i8, weight_quant_min, weight_quant_max)

    x_i16 = x_i8.to(torch.int16)
    weight_i16 = weight_i8.to(torch.int16)
    # bias is always set to None here so that the same representation works
    # whether or not bias_scale == x_scale * weight_scale
    acc_i32 = out_dtype(
        torch.ops.aten.convolution.default,
        torch.int32,
        x_i16 - x_zero_point,
        weight_i16 - weight_zero_point,
        None,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
    )
    # Note: we quantize the bias with these scales without input from the user; this should be OK
    bias_scale = x_scale * weight_scale
    # Bias quantization to int32 uses bias_scale = x_scale * weight_scale because of the following.
    # Take the linear calculation as an example:
    # Out_(i, j)_fp32 = Sum_(over k)[X_(i, k)_fp32 * W_(j, k)_fp32] + bias_(j)_fp32
    # Represent X and W in fp32 via their dequant transforms:
    # A_fp32 = (A_q - A_zero_point) * A_scale
    # Out_(i, j)_fp32 = Sum_(over k)[(X_(i, k)_q - X_zp) * X_scale * (W_(j, k)_q - W_zp) * W_scale] + bias_(j)_fp32
    # Factor out X_scale and W_scale:
    # Out_(i, j)_fp32 = (X_scale * W_scale) * Sum_(over k)[(X_(i, k)_q - X_zp) * (W_(j, k)_q - W_zp)] + bias_(j)_fp32
    # To move the addition of bias_(j)_fp32 inside the parentheses, we must write:
    # Out_(i, j)_fp32 = (X_scale * W_scale) * (Sum_(over k)[(X_(i, k)_q - X_zp) * (W_(j, k)_q - W_zp)] + (1 / (X_scale * W_scale)) * bias_(j)_fp32)  # noqa: B950
    # Note that we had to divide bias_fp32 by X_scale * W_scale = bias_scale.
    # Thus bias quantization to int32 must use bias_scale = X_scale * W_scale.
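    # For example (illustrative values): with x_scale = 0.02 and weight_scale = 0.05,
    # bias_scale = 0.001, so a bias value of 1.0 fp32 is quantized to int32(1.0 / 0.001) = 1000.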

    bias_i32 = out_dtype(torch.ops.aten.div.Tensor, torch.int32, bias_fp32, bias_scale)
    # Unsqueeze to match broadcast dims.
    # Unfortunately we cannot do bias_i32.unsqueeze(0) due to literal matching issues
    # in graph pattern replacement.
    bias_i32 = bias_i32.unsqueeze(-1)
    bias_i32 = bias_i32.unsqueeze(-1)
    acc_i32 = acc_i32 + bias_i32
    # TODO: change to mul.Scalar when we make x_scale/weight_scale etc. Scalar values
    acc_i32 = (
        out_dtype(
            torch.ops.aten.mul.Tensor,
            torch.int32,
            acc_i32,
            x_scale * weight_scale / out_scale,
        )
        + out_zero_point
    )
    out_i8 = torch.ops.aten.clamp(acc_i32, out_quant_min, out_quant_max).to(torch.int8)
    return out_i8


_QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)


def _qdq_quantized_add_relu(
    x_i8,
    x_scale,
    x_zero_point,
    y_i8,
    y_scale,
    y_zero_point,
    out_scale,
    out_zero_point,
    quant_min,
    quant_max,
):
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8
    )
    y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8
    )
    out_fp32 = x_fp32 + y_fp32
    out_fp32 = torch.ops.aten.relu(out_fp32)
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
    )
    return out_i8


def _reference_quantized_add_relu(
    x_i8,
    x_scale,
    x_zero_point,
    y_i8,
    y_scale,
    y_zero_point,
    out_scale,
    out_zero_point,
    quant_min,
    quant_max,
):
    """
    See comments for `_reference_quantized_add` for more information on
    how to derive the formula for out_i8 based on x_i8 and y_i8
    """
    x_i32 = x_i8.to(torch.int32)
    y_i32 = y_i8.to(torch.int32)
    # TODO: change this to mul.Scalar?
    x_i32 = out_dtype(
        torch.ops.aten.mul.Tensor,
        torch.int32,
        (x_i32 - x_zero_point),
        (x_scale / out_scale),
    )
    y_i32 = out_dtype(
        torch.ops.aten.mul.Tensor,
        torch.int32,
        (y_i32 - y_zero_point),
        (y_scale / out_scale),
    )
    out_i32 = x_i32 + y_i32 + out_zero_point
    # clamping the lower bound to out_zero_point implements the fused ReLU,
    # since fp32 zero corresponds to out_zero_point in the quantized output domain
    out_i8 = torch.ops.aten.clamp(out_i32, out_zero_point, quant_max).to(torch.int8)
    return out_i8


def _qdq_quantized_add(
    x_i8,
    x_scale,
    x_zero_point,
    y_i8,
    y_scale,
    y_zero_point,
    out_scale,
    out_zero_point,
    quant_min,
    quant_max,
):
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, quant_min, quant_max, torch.int8
    )
    y_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        y_i8, y_scale, y_zero_point, quant_min, quant_max, torch.int8
    )
    out_fp32 = x_fp32 + y_fp32
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, quant_min, quant_max, torch.int8
    )
    return out_i8


def _reference_quantized_add(
    x_i8,
    x_scale,
    x_zero_point,
    y_i8,
    y_scale,
    y_zero_point,
    out_scale,
    out_zero_point,
    quant_min,
    quant_max,
):
492    """
493        # How to Derive the formula for out_i8 based on x_i8 and y_i8
494        # (since quantized add takes x_i8, y_i8 and their quantization parameters, and produce an out_i8)
495
496        # out_i8 is quantized output, we can write down the formula for it first:
497    out_i8 = out_f32 / out_scale + out_zero_point           (1)
498
499        # then out_fp32 is computed from x_f32 + y_f32, and the x_fp32 and y_fp32 are the dequantized x_i8 and y_i8
500        out_f32 = x_f32 + y_f32           (2)
501        x_fp32 = (x_i8 - x_zero_point) * x_scale         (3)
502        y_fp32 = (y_i8 - y_zero_point) * y_scale         (4)
503
504        # applying the above fomula to the out_i8 equation we can get the following:
505        out_i8 = out_fp32 / out_scale + out_zero_point             # (1)
506           = (x_f32 + y_f32) / out_scale + out_zero_point      # applying (2) to substitute out_fp32 with x_fp32 + y_fp32
507           = ((x_i8 - x_zero_point) * x_scale + (y_i8 - y_zero_point) * y_scale) / out_scale + out_zero_point  # apply (3) and (4)
508    """
    x_i32 = x_i8.to(torch.int32)
    y_i32 = y_i8.to(torch.int32)
    # TODO: use out_dtype op
    x_i32 = torch.round((x_scale / out_scale) * (x_i32 - x_zero_point)).to(torch.int32)
    y_i32 = torch.round((y_scale / out_scale) * (y_i32 - y_zero_point)).to(torch.int32)
    out_i32 = x_i32 + y_i32 + out_zero_point
    # clamp to the full int8 range (overriding the quant_min/quant_max arguments)
    quant_min = -128
    quant_max = 127
    out_i8 = torch.ops.aten.clamp(out_i32, quant_min, quant_max).to(torch.int8)
    return out_i8


_QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)


def _qdq_quantized_max_pool2d(
    x_i8,
    x_scale,
    x_zero_point,
    x_quant_min,
    x_quant_max,
    out_scale,
    out_zero_point,
    out_quant_min,
    out_quant_max,
):
    kernel_size = 1
    stride = 1
    padding = 0
    dilation = 1
    ceil_mode = False
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8
    )
    out_fp32, _ = torch.ops.aten.max_pool2d_with_indices.default(
        x_fp32, kernel_size, stride, padding, dilation, ceil_mode
    )
    out_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
        out_fp32, out_scale, out_zero_point, out_quant_min, out_quant_max, torch.int8
    )
    return out_i8


def _reference_quantized_max_pool2d(
    x_i8,
    x_scale,
    x_zero_point,
    x_quant_min,
    x_quant_max,
    out_scale,
    out_zero_point,
    out_quant_min,
    out_quant_max,
):
    kernel_size = 1
    stride = 1
    padding = 0
    dilation = 1
    ceil_mode = False
    # to preserve x_quant_min, x_quant_max in the graph for pattern matching
    x_i8 = torch.clamp(x_i8, x_quant_min, x_quant_max)
    x_i32 = x_i8.to(torch.int32)
    out_i32, _ = torch.ops.aten.max_pool2d_with_indices.default(
        x_i32 - x_zero_point, kernel_size, stride, padding, dilation, ceil_mode
    )
    out_fp32 = out_i32 * (x_scale / out_scale) + out_zero_point
    out_fp32 = torch.clamp(out_fp32, out_quant_min, out_quant_max)
    out_i8 = out_fp32.to(torch.int8)
    return out_i8


_QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
    torch.randn(1, 3, 3, 3, dtype=torch.float),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)


def _quantize_per_tensor_int8(x_fp32, scale, zero_point, quant_min, quant_max):
    x = torch.ops.quantized_decomposed.quantize_per_tensor(
        x_fp32, scale, zero_point, quant_min, quant_max, torch.int8
    )
    return x


def _reference_quantize_per_tensor_int8(
    x_fp32, scale, zero_point, quant_min, quant_max
):
    # TODO: use out_dtype(mul, ...) here when the op is ready
    x = x_fp32 / scale  # fp32
    # rounding modes might differ here;
    # PyTorch rounds to even, which is also common for most backends
    x = torch.round(x)  # fp32
    x = x.to(dtype=torch.int32)  # int32
    x = x + zero_point  # int32
    # clamp works for fp32, int32 and int8 dtypes
    x = torch.clamp(x, quant_min, quant_max)  # int32
    x = x.to(dtype=torch.int8)
    return x


_DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(1, dtype=torch.float),
    torch.zeros(1, dtype=torch.int),
    torch.tensor([-128], dtype=torch.int),
    torch.tensor([127], dtype=torch.int),
)


def _dequantize_per_tensor_int8(x_i8, scale, zero_point, quant_min, quant_max):
    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
        x_i8, scale, zero_point, quant_min, quant_max, torch.int8
    )
    return x_fp32


def _reference_dequantize_per_tensor_int8(
    x_i8, scale, zero_point, quant_min, quant_max
):
    # If quant_min/quant_max are not used in the clamp, the traced graph will not
    # have the quant_min/quant_max args and the pattern will fail to match.
    # Therefore, we call torch.ops.aten.clamp here explicitly.
    x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max)
    # TODO: use out_dtype op
    # note: x_i8.to(torch.int32) does not work here
    # TODO: debug the implementation later when torchdynamo time out issue is resolved
    return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)


_QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
    torch.randn(1, 3, 3, 3, dtype=torch.float),
    torch.randn(3, dtype=torch.float),
    torch.zeros(3, dtype=torch.int),
    1,
    -128,
    127,
)


def _quantize_per_channel_int8(
    x_fp32, scales, zero_points, ch_axis, quant_min, quant_max
):
    out_i8 = torch.ops.quantized_decomposed.quantize_per_channel(
        x_fp32, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8
    )
    return out_i8


def _reference_quantize_per_channel_int8(
    x_fp32, scales, zero_points, ch_axis, quant_min, quant_max
):
    x_fp32 = torch.transpose(x_fp32, ch_axis, -1)
    out_i32 = torch.ops.aten.clamp(
        torch.round(x_fp32 / scales).to(torch.int32) + zero_points, quant_min, quant_max
    )
    out_i32 = torch.transpose(out_i32, ch_axis, -1)
    return out_i32.to(torch.int8)


_DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS = (
    torch.randint(-128, 127, (1, 3, 3, 3), dtype=torch.int8),
    torch.randn(3, dtype=torch.float),
    torch.zeros(3, dtype=torch.int),
    1,
    -128,
    127,
)


def _dequantize_per_channel_int8(
    x_i8, scales, zero_points, ch_axis, quant_min, quant_max
):
    # the literal arguments below will be replaced with placeholders
    # (see _replace_ph_qdq_per_channel_replacement)
    out_fp32 = torch.ops.quantized_decomposed.dequantize_per_channel(
        x_i8, scales, zero_points, ch_axis, quant_min, quant_max, torch.int8
    )
    return out_fp32


def _reference_dequantize_per_channel_int8(
    x_i8, scales, zero_points, ch_axis, quant_min, quant_max
):
    # the literal arguments below will be replaced with placeholders
    # (see _replace_ph_qdq_per_channel_replacement);
    # we call torch.ops.aten.clamp here to preserve the quant_min/quant_max args
    # for pattern matching (e.g. matching int4 quantized ops)
    x_i8 = torch.ops.aten.clamp(x_i8, quant_min, quant_max)
    x_i8 = torch.transpose(x_i8, ch_axis, -1)
    x_i32 = x_i8.to(torch.int32)
    out_fp32 = (x_i32 - zero_points).to(torch.float) * scales
    out_fp32 = torch.transpose(out_fp32, ch_axis, -1)
    return out_fp32


def _replace_ph_qdq_per_channel_replacement(gm: torch.fx.GraphModule):
    return _replace_literals_with_existing_placeholders(
        gm, exclude_literals=[-1], literal_to_ph_idx={1: 3, -128: 4, 127: 5}
    )


@dataclass
class _RewriteInfo:
723    """Data needed for rewrite, this includes example inputs, pattern and replacement functions
724    and post transformation functions for the exported pattern and replacement GraphModule
725    """

    # example inputs used for exporting the pattern into GraphModule
    example_inputs: Tuple[Any, ...]
    pattern: Callable
    replacement: Callable
    # post transformation on the exported pattern and replacement GraphModule
    pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
    replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None


_REWRITE_INFO_LIST = [
    _RewriteInfo(
        _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_dynamic_quantized_linear),
        _WrapperModule(_reference_dynamic_quantized_linear),
        partial(
            _replace_literals_with_existing_placeholders,
            literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
        ),
        partial(
            _replace_literals_with_existing_placeholders,
            literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
        ),
    ),
    _RewriteInfo(
        _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_linear),
        _WrapperModule(_reference_quantized_linear),
        _replace_literals_with_new_placeholders,
        _replace_literals_with_new_placeholders,
    ),
    _RewriteInfo(
        _QUANTIZED_CONV2d_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_conv2d),
        _WrapperModule(_reference_quantized_conv2d),
        partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
        partial(_replace_literals_with_new_placeholders, exclude_literals=[-1]),
    ),
    _RewriteInfo(
        _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_add_relu),
        _WrapperModule(_reference_quantized_add_relu),
    ),
    _RewriteInfo(
        _QUANTIZED_ADD_OR_ADD_RELU_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_add),
        _WrapperModule(_reference_quantized_add),
    ),
    _RewriteInfo(
        _QUANTIZED_MAX_POOL2D_EXAMPLE_INPUTS,
        _WrapperModule(_qdq_quantized_max_pool2d),
        _WrapperModule(_reference_quantized_max_pool2d),
        _replace_literals_with_new_placeholders,
        _replace_literals_with_new_placeholders,
    ),
    _RewriteInfo(
        _QUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
        _WrapperModule(_quantize_per_tensor_int8),
        _WrapperModule(_reference_quantize_per_tensor_int8),
    ),
    _RewriteInfo(
        _DEQUANTIZE_PER_TENSOR_INT8_EXAMPLE_INPUTS,
        _WrapperModule(_dequantize_per_tensor_int8),
        _WrapperModule(_reference_dequantize_per_tensor_int8),
    ),
    _RewriteInfo(
        _QUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
        _WrapperModule(_quantize_per_channel_int8),
        _WrapperModule(_reference_quantize_per_channel_int8),
        _replace_ph_qdq_per_channel_replacement,
        _replace_ph_qdq_per_channel_replacement,
    ),
    _RewriteInfo(
        _DEQUANTIZE_PER_CHANNEL_INT8_EXAMPLE_INPUTS,
        _WrapperModule(_dequantize_per_channel_int8),
        _WrapperModule(_reference_dequantize_per_channel_int8),
        _replace_ph_qdq_per_channel_replacement,
        _replace_ph_qdq_per_channel_replacement,
    ),
]


def reference_representation_rewrite(model: GraphModule) -> GraphModule:
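    """Rewrite the decomposed q/dq operator patterns in ``model`` with their integer
    reference representations and return the modified ``model``.

    Sketch of typical usage (the surrounding PT2E flow below is an assumption and is
    not defined in this file):

        converted = convert_pt2e(prepared_model)
        converted = reference_representation_rewrite(converted)

    ``convert_pt2e`` can also trigger this rewrite itself when called with
    ``use_reference_representation=True``.
    """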
    remove_tensor_overload_for_qdq_ops(model)
    for rewrite_info in _REWRITE_INFO_LIST:
        example_inputs = rewrite_info.example_inputs
        pattern = rewrite_info.pattern
        replacement = rewrite_info.replacement
        pattern_post_trans = rewrite_info.pattern_post_trans
        replacement_post_trans = rewrite_info.replacement_post_trans
        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs)  # type: ignore[arg-type, assignment]
        remove_tensor_overload_for_qdq_ops(pattern)  # type: ignore[arg-type]
        replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs)  # type: ignore[arg-type, assignment]
        remove_tensor_overload_for_qdq_ops(replacement)  # type: ignore[arg-type]
        if pattern_post_trans:
            pattern = pattern_post_trans(pattern)
        if replacement_post_trans:
            replacement = replacement_post_trans(replacement)
        pattern.recompile()  # type: ignore[attr-defined]
        replacement.recompile()  # type: ignore[attr-defined]
        matches = replace_pattern(model, pattern, replacement)
    return model