import torch

from ._common_operator_config_utils import (
    _get_binary_op_configs,
    _get_bn_configs,
    _get_cat_config,
    _get_conv_configs,
    _get_default_op_configs,
    _get_embedding_op_configs,
    _get_fixed_qparams_op_configs,
    _get_linear_configs,
    _get_rnn_op_configs,
    _get_share_qparams_op_configs,
    _get_tensor_info_op_configs,
)
from .backend_config import BackendConfig, DTypeConfig


__all__ = [
    "get_x86_backend_config",
]

# ===================
# |  DTYPE CONFIGS  |
# ===================

# The x86 backend mirrors FBGEMM's dtype choices for now.

# Static quantization of weighted ops (conv/linear): quint8 activations,
# qint8 weights, float bias.
x86_weighted_op_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# Static quantization of non-weighted ops: quint8 in, quint8 out.
x86_default_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

# NOTE(review): defined for the module's public surface but not referenced
# by get_x86_backend_config() below — kept for importers.
x86_default_op_fp16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float16,
    weight_dtype=torch.float16,
    bias_dtype=torch.float16,
)

# Dynamic quantization: quantized int8 input, float output.
x86_default_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Dynamic quantization with fp16 weights.
x86_default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Weight-only quantization (e.g. embeddings): float activations.
x86_weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint8,
)

x86_weight_only_quint4x2_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint4x2,
)


# =====================
# |  BACKEND CONFIGS  |
# =====================


def get_x86_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch's native x86 backend.
    """
    # Group the dtype configs once so each pattern family below can pick
    # the combination it supports.
    static_int8 = [x86_weighted_op_int8_dtype_config]
    dynamic = [
        x86_default_dynamic_int8_dtype_config,
        x86_default_dynamic_float16_dtype_config,
    ]
    quint8_only = [x86_default_op_quint8_dtype_config]
    weight_only = [
        x86_weight_only_quint8_dtype_config,
        x86_weight_only_quint4x2_dtype_config,
    ]

    # Each setter returns the config itself, so sequential calls are
    # equivalent to the fluent chain; registration order is preserved.
    backend_config = BackendConfig("x86")
    backend_config.set_backend_pattern_configs(_get_conv_configs(static_int8))
    # Linear supports static int8 plus both dynamic flavors.
    backend_config.set_backend_pattern_configs(
        _get_linear_configs(static_int8 + dynamic)
    )
    backend_config.set_backend_pattern_configs(_get_binary_op_configs(static_int8))
    # cat produces a single pattern config, hence the singular setter.
    backend_config.set_backend_pattern_config(_get_cat_config(quint8_only))
    backend_config.set_backend_pattern_configs(_get_default_op_configs(quint8_only))
    backend_config.set_backend_pattern_configs(
        _get_fixed_qparams_op_configs(static_int8)
    )
    backend_config.set_backend_pattern_configs(
        _get_share_qparams_op_configs(quint8_only)
    )
    backend_config.set_backend_pattern_configs(
        _get_tensor_info_op_configs(quint8_only)
    )
    backend_config.set_backend_pattern_configs(_get_bn_configs(quint8_only))
    backend_config.set_backend_pattern_configs(_get_rnn_op_configs(dynamic))
    backend_config.set_backend_pattern_configs(
        _get_embedding_op_configs(weight_only)
    )
    return backend_config