import torch

from ._common_operator_config_utils import (
    _get_binary_op_configs,
    _get_bn_configs,
    _get_cat_config,
    _get_conv_configs,
    _get_default_op_configs,
    _get_embedding_op_configs,
    _get_fixed_qparams_op_configs,
    _get_linear_configs,
    _get_rnn_op_configs,
    _get_share_qparams_op_configs,
    _get_tensor_info_op_configs,
)
from .backend_config import BackendConfig, DTypeConfig


__all__ = [
    "get_fbgemm_backend_config",
]

# ===================
# |  DTYPE CONFIGS  |
# ===================

# TODO: For now, these DTypeConfigs are identical to the ones defined in native.py.
# In the future, once we support specifying quant_min/quant_max and scale_min/scale_max,
# these will diverge. In particular, for FBGEMM, we will restrict the activation quantized
# values to within [0, 127].

# Weighted ops: quint8 input/output, qint8 weight, float bias.
fbgemm_weighted_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# Ops without weight/bias dtypes specified: quint8 input and output.
fbgemm_default_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

# All-float16 config. Defined at module level but not referenced by
# get_fbgemm_backend_config() below.
fbgemm_default_op_fp16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float16,
    weight_dtype=torch.float16,
    bias_dtype=torch.float16,
)

# Dynamic (is_dynamic=True) config: quint8 input, qint8 weight, float output/bias.
fbgemm_default_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Dynamic (is_dynamic=True) config: float16 input/weight, float output/bias.
fbgemm_default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Float input/output with a quint8 weight.
fbgemm_weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint8,
)

# Float input/output with a quint4x2 weight.
fbgemm_weight_only_quint4x2_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint4x2,
)


# =====================
# |  BACKEND CONFIGS  |
# =====================


def get_fbgemm_backend_config() -> BackendConfig:
    """
    Build and return the `BackendConfig` named "fbgemm", i.e. the config for
    PyTorch's native FBGEMM backend.

    Each operator family is registered with the list of `DTypeConfig`s it
    accepts. The families are registered in this order: conv, linear,
    binary ops, cat, default ops, fixed-qparams ops, share-qparams ops,
    tensor-info ops, batch norm, RNN ops, embedding ops.
    """
    # DTypeConfig lists, one per operator family. Each family gets its own
    # list object, exactly as many lists as before.
    conv_dtypes = [fbgemm_weighted_op_quint8_dtype_config]
    linear_dtypes = [
        fbgemm_weighted_op_quint8_dtype_config,
        fbgemm_default_dynamic_int8_dtype_config,
        fbgemm_default_dynamic_float16_dtype_config,
    ]
    binary_dtypes = [fbgemm_default_op_quint8_dtype_config]
    default_dtypes = [fbgemm_default_op_quint8_dtype_config]
    fixed_qparams_dtypes = [fbgemm_default_op_quint8_dtype_config]
    share_qparams_dtypes = [fbgemm_default_op_quint8_dtype_config]
    tensor_info_dtypes = [fbgemm_default_op_quint8_dtype_config]
    rnn_dtypes = [
        fbgemm_default_dynamic_int8_dtype_config,
        fbgemm_default_dynamic_float16_dtype_config,
    ]
    embedding_dtypes = [
        fbgemm_weight_only_quint8_dtype_config,
        fbgemm_weight_only_quint4x2_dtype_config,
    ]

    # Rebind on every step so the result is whatever the setter returns,
    # matching the semantics of a single fluent call chain.
    config = BackendConfig("fbgemm")
    config = config.set_backend_pattern_configs(_get_conv_configs(conv_dtypes))
    config = config.set_backend_pattern_configs(_get_linear_configs(linear_dtypes))
    config = config.set_backend_pattern_configs(_get_binary_op_configs(binary_dtypes))
    # cat is a single pattern config, hence the singular setter.
    config = config.set_backend_pattern_config(_get_cat_config(default_dtypes))
    config = config.set_backend_pattern_configs(_get_default_op_configs(default_dtypes))
    config = config.set_backend_pattern_configs(
        _get_fixed_qparams_op_configs(fixed_qparams_dtypes)
    )
    config = config.set_backend_pattern_configs(
        _get_share_qparams_op_configs(share_qparams_dtypes)
    )
    config = config.set_backend_pattern_configs(
        _get_tensor_info_op_configs(tensor_info_dtypes)
    )
    # Batch norm reuses the default-op dtype list.
    config = config.set_backend_pattern_configs(_get_bn_configs(default_dtypes))
    config = config.set_backend_pattern_configs(_get_rnn_op_configs(rnn_dtypes))
    config = config.set_backend_pattern_configs(
        _get_embedding_op_configs(embedding_dtypes)
    )
    return config