xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/backend_config/x86.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2
3from ._common_operator_config_utils import (
4    _get_binary_op_configs,
5    _get_bn_configs,
6    _get_cat_config,
7    _get_conv_configs,
8    _get_default_op_configs,
9    _get_embedding_op_configs,
10    _get_fixed_qparams_op_configs,
11    _get_linear_configs,
12    _get_rnn_op_configs,
13    _get_share_qparams_op_configs,
14    _get_tensor_info_op_configs,
15)
16from .backend_config import BackendConfig, DTypeConfig
17
18
# Public API of this module: only the backend-config factory is exported.
__all__ = ["get_x86_backend_config"]
22
23# ===================
24# |  DTYPE CONFIGS  |
25# ===================
26
27# X86 aligns with FBGEMM for now
28
# Statically quantized weighted ops: quint8 activations in and out,
# qint8 weights, float bias.
x86_weighted_op_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# Ops without weights: quint8 activations in and out.
x86_default_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8, output_dtype=torch.quint8
)

# All-float16 config: activations, weights and bias are float16.
x86_default_op_fp16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float16,
    weight_dtype=torch.float16,
    bias_dtype=torch.float16,
)

# Dynamic int8: quint8 input, qint8 weights, float output and bias.
x86_default_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Dynamic float16: float16 input and weights, float output and bias.
x86_default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Weight-only quint8: float activations in and out, quint8 weights.
x86_weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float, output_dtype=torch.float, weight_dtype=torch.quint8
)

# Weight-only quint4x2: float activations in and out, quint4x2 weights.
x86_weight_only_quint4x2_dtype_config = DTypeConfig(
    input_dtype=torch.float, output_dtype=torch.float, weight_dtype=torch.quint4x2
)
75
76
77# =====================
78# |  BACKEND CONFIGS  |
79# =====================
80
81
def get_x86_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch's native x86 backend.

    A `BackendConfig` named "x86" is created and the pattern configs produced
    by each shared `_get_*_configs` helper are registered on it, one operator
    family at a time, using the module-level x86 dtype configs.
    """
    # Short aliases for the module-level dtype configs used below.
    int8_static = x86_weighted_op_int8_dtype_config
    quint8_default = x86_default_op_quint8_dtype_config
    int8_dynamic = x86_default_dynamic_int8_dtype_config
    fp16_dynamic = x86_default_dynamic_float16_dtype_config

    # cat, the default ops and batch norm are all given this same list.
    default_op_dtypes = [quint8_default]

    backend = BackendConfig("x86")

    # Weighted ops: conv is static int8 only; linear also allows the
    # dynamic int8 and dynamic float16 configs.
    backend = backend.set_backend_pattern_configs(_get_conv_configs([int8_static]))
    backend = backend.set_backend_pattern_configs(
        _get_linear_configs([int8_static, int8_dynamic, fp16_dynamic])
    )

    # Binary ops.
    backend = backend.set_backend_pattern_configs(
        _get_binary_op_configs([int8_static])
    )

    # cat yields a single pattern config, hence the singular setter.
    backend = backend.set_backend_pattern_config(_get_cat_config(default_op_dtypes))
    backend = backend.set_backend_pattern_configs(
        _get_default_op_configs(default_op_dtypes)
    )

    # Fixed-qparams, shared-qparams and tensor-info op families.
    backend = backend.set_backend_pattern_configs(
        _get_fixed_qparams_op_configs([int8_static])
    )
    backend = backend.set_backend_pattern_configs(
        _get_share_qparams_op_configs([quint8_default])
    )
    backend = backend.set_backend_pattern_configs(
        _get_tensor_info_op_configs([quint8_default])
    )

    # Batch norm reuses the default-op dtype list.
    backend = backend.set_backend_pattern_configs(_get_bn_configs(default_op_dtypes))

    # RNN ops are registered with the two dynamic configs only.
    backend = backend.set_backend_pattern_configs(
        _get_rnn_op_configs([int8_dynamic, fp16_dynamic])
    )

    # Embedding ops are registered with the weight-only configs.
    backend = backend.set_backend_pattern_configs(
        _get_embedding_op_configs(
            [
                x86_weight_only_quint8_dtype_config,
                x86_weight_only_quint4x2_dtype_config,
            ]
        )
    )
    return backend
127