# mypy: allow-untyped-defs
import functools
from enum import Enum

import torch
import torch.distributed as dist


TORCH_HALF_MIN = torch.finfo(torch.float16).min
TORCH_HALF_MAX = torch.finfo(torch.float16).max


class DQuantType(Enum):
    """
    Different quantization methods for the auto_quantize API are identified here.

    The auto_quantize API currently supports the fp16 and bfp16 methods.
    """

    FP16 = "fp16"
    BFP16 = "bfp16"

    def __str__(self) -> str:
        return self.value


def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
    return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half()


def _quantize_tensor(tensor, qtype):
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_quantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )
    if qtype == DQuantType.FP16:
        return _fp32_to_fp16_with_clamp(tensor)
    elif qtype == DQuantType.BFP16:
        return torch.ops.quantization._FloatToBfloat16Quantized(tensor)
    else:
        raise RuntimeError(f"Quantization type {qtype} is not supported")


def _quantize_tensor_list(tensor_list, qtype):
    if not isinstance(tensor_list, list) or not all(
        isinstance(p, torch.Tensor) for p in tensor_list
    ):
        raise RuntimeError(
            f"_quantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list]
    return quantized_tensor_list


def _dequantize_tensor(tensor, qtype, quant_loss=None):
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_dequantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )
    if qtype == DQuantType.FP16:
        if tensor.dtype != torch.float16:
            raise RuntimeError(
                f"tensor dtype is {tensor.dtype} while expected to be FP16."
            )
        elif quant_loss is None:
            return tensor.float()
        else:
            return tensor.float() / quant_loss
    elif qtype == DQuantType.BFP16:
        if tensor.dtype != torch.float16:
            raise RuntimeError(
                f"tensor dtype is {tensor.dtype} while expected to be FP16."
            )
        else:
            return torch.ops.quantization._Bfloat16QuantizedToFloat(tensor)
    else:
        raise RuntimeError(f"Quantization type {qtype} is not supported")


def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None):
    if not isinstance(tensor_list, list) or not all(
        isinstance(p, torch.Tensor) for p in tensor_list
    ):
        raise RuntimeError(
            f"_dequantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    dequantized_tensor_list = [
        _dequantize_tensor(t, qtype, quant_loss=quant_loss) for t in tensor_list
    ]
    return dequantized_tensor_list


def auto_quantize(func, qtype, quant_loss=None):
    """
    Quantize the input tensors, run the collective with the chosen precision, and dequantize the output.

    Currently it only supports:
        . the FP16 and BFP16 quantization methods, for the gloo and nccl backends
        . the all_gather, all_to_all, and all_to_all_single collective ops
    Note: BFP16 only supports 2D tensors.
    Args:
        func (Callable): A function representing a collective operation.
        qtype (DQuantType): Quantization method.
        quant_loss (float, optional): A scale that can be used to improve accuracy during dequantization.
    Returns:
        (Callable): the same collective as func but with automatic quantization/dequantization enabled.
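
    Example (a minimal sketch, not run as a doctest; it assumes a process
    group has already been initialized, e.g. via ``dist.init_process_group``,
    and that ``world_size`` is the size of that group)::

        fp16_all_to_all = auto_quantize(dist.all_to_all, DQuantType.FP16)
        out = [torch.zeros(4) for _ in range(world_size)]
        inp = list(
            torch.arange(4 * world_size, dtype=torch.float32).chunk(world_size)
        )
        fp16_all_to_all(out, inp)
        # `out` now holds FP32 tensors that were exchanged over the wire in FP16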
105 """ 106 107 @functools.wraps(func) 108 def wrapper(*args, **kwargs): 109 group = kwargs.get("group", None) 110 async_op = kwargs.get("async_op", False) 111 if async_op is True: 112 raise RuntimeError("The async_op=True mode is not supported yet.") 113 if func == dist.all_gather: 114 tensors = args[0] 115 input_tensors = _quantize_tensor(args[1], qtype) 116 out_tensors = _quantize_tensor_list(tensors, qtype) 117 dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op) 118 for i, t in enumerate( 119 _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss) 120 ): 121 tensors[i] = t 122 123 elif func == dist.all_to_all: 124 tensors = args[0] 125 input_tensors = _quantize_tensor_list(args[1], qtype) 126 out_tensors = _quantize_tensor_list(tensors, qtype) 127 dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op) 128 for i, t in enumerate( 129 _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss) 130 ): 131 tensors[i] = t 132 133 elif func == dist.all_to_all_single: 134 tensors = args[0] 135 out_splits = kwargs.get("out_splits", None) 136 in_splits = kwargs.get("in_splits", None) 137 # Quantizing the input/output tensor 138 input_tensors = _quantize_tensor(args[1], qtype) 139 out_tensors = _quantize_tensor(tensors, qtype) 140 dist.all_to_all_single( 141 out_tensors, input_tensors, out_splits, in_splits, group=group 142 ) 143 for i, t in enumerate( 144 _dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss) 145 ): 146 tensors[i] = t 147 else: 148 raise RuntimeError(f"The collective op {func} is not supported yet") 149 150 return wrapper 151