# mypy: allow-untyped-defs
import functools
from enum import Enum

import torch
import torch.distributed as dist


TORCH_HALF_MIN = torch.finfo(torch.float16).min
TORCH_HALF_MAX = torch.finfo(torch.float16).max


class DQuantType(Enum):
    """
    Quantization methods supported by the ``auto_quantize`` API.

    ``auto_quantize`` currently supports the fp16 and bfp16 methods.
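
    Example (illustrative)::

        >>> str(DQuantType.FP16)
        'fp16'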
    """

    FP16 = "fp16"
    BFP16 = "bfp16"

    def __str__(self) -> str:
        return self.value


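# Clamp to the representable float16 range before casting, so that out-of-range
# fp32 values saturate at the finite half limits instead of overflowing to +/-inf.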
def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
    return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half()


def _quantize_tensor(tensor, qtype):
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_quantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )
    if qtype == DQuantType.FP16:
        return _fp32_to_fp16_with_clamp(tensor)
    elif qtype == DQuantType.BFP16:
        return torch.ops.quantization._FloatToBfloat16Quantized(tensor)
    else:
        raise RuntimeError(f"Quantization type {qtype} is not supported")


def _quantize_tensor_list(tensor_list, qtype):
    if not isinstance(tensor_list, list) or not all(
        isinstance(p, torch.Tensor) for p in tensor_list
    ):
        raise RuntimeError(
            f"_quantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list]
    return quantized_tensor_list


def _dequantize_tensor(tensor, qtype, quant_loss=None):
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_dequantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )
    if qtype == DQuantType.FP16:
        if tensor.dtype != torch.float16:
            raise RuntimeError(
                f"tensor dtype is {tensor.dtype} while expected to be FP16."
            )
        elif quant_loss is None:
            return tensor.float()
        else:
            return tensor.float() / quant_loss
    elif qtype == DQuantType.BFP16:
        # The bfloat16-quantized representation is expected to arrive as a
        # float16-typed tensor, hence the dtype check below.
        if tensor.dtype != torch.float16:
            raise RuntimeError(
                f"tensor dtype is {tensor.dtype} while expected to be FP16."
            )
        else:
            return torch.ops.quantization._Bfloat16QuantizedToFloat(tensor)
    else:
        raise RuntimeError(f"Quantization type {qtype} is not supported")


def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None):
    if not isinstance(tensor_list, list) or not all(
        isinstance(p, torch.Tensor) for p in tensor_list
    ):
        raise RuntimeError(
            f"_dequantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    dequantized_tensor_list = [
        _dequantize_tensor(t, qtype, quant_loss=quant_loss) for t in tensor_list
    ]
    return dequantized_tensor_list


def auto_quantize(func, qtype, quant_loss=None):
    """
    Wrap a collective so that its inputs are quantized before communication and its outputs are dequantized afterwards.

    Currently it only supports:
        . FP16 and BFP16 quantization methods on the gloo and nccl backends
        . the all_gather, all_to_all, and all_to_all_single collective ops

    Note: BFP16 only supports 2D tensors.

    Args:
        func (Callable): the collective function to wrap.
        qtype (DQuantType): the quantization method to apply.
        quant_loss (float, optional): divisor applied to the dequantized FP16 output; can be used to improve accuracy of the dequantization.
    Returns:
        (Callable): the same collective as ``func``, but with automatic quantization/dequantization of its tensor arguments.
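
    Example (a minimal sketch; assumes the default process group is already
    initialized and that each rank holds fp32 tensors of matching shape)::

        >>> world_size = dist.get_world_size()
        >>> input_tensor = torch.randn(4, 4)
        >>> output_tensor_list = [torch.zeros(4, 4) for _ in range(world_size)]
        >>> quantized_all_gather = auto_quantize(dist.all_gather, DQuantType.FP16)
        >>> # Same call signature as dist.all_gather; async_op=True is not supported.
        >>> quantized_all_gather(output_tensor_list, input_tensor, group=None, async_op=False)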
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        group = kwargs.get("group", None)
        async_op = kwargs.get("async_op", False)
        if async_op is True:
            raise RuntimeError("The async_op=True mode is not supported yet.")
        if func == dist.all_gather:
            tensors = args[0]
            # Quantize both the input tensor and the output list before the collective.
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op)
            # Dequantize the gathered results back into the caller's output list.
            for i, t in enumerate(
                _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)
            ):
                tensors[i] = t

        elif func == dist.all_to_all:
            tensors = args[0]
            input_tensors = _quantize_tensor_list(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op)
            for i, t in enumerate(
                _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)
            ):
                tensors[i] = t

        elif func == dist.all_to_all_single:
            tensors = args[0]
            out_splits = kwargs.get("out_splits", None)
            in_splits = kwargs.get("in_splits", None)
            # Quantizing the input/output tensor
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor(tensors, qtype)
            dist.all_to_all_single(
                out_tensors, input_tensors, out_splits, in_splits, group=group
            )
            # Copy the dequantized result back into the caller's output tensor row by row.
            for i, t in enumerate(
                _dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss)
            ):
                tensors[i] = t
        else:
            raise RuntimeError(f"The collective op {func} is not supported yet")

    return wrapper