# flake8: noqa: E266, C417, B950
import torch
import torch.nn as nn
import torch.nn.functional as F


##### Quantization Primitives ######


def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scales and zero_points based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scales = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scales is the same dtype as the original tensor
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scales/zp
    # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x / scales.unsqueeze(-1)
    x_round = torch.round(x_div)
    x_zp = x_round + zero_points.unsqueeze(-1)
    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)

    return quant, scales, zero_points


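# Illustrative sketch, not part of the original benchmark: round-trip a random
# weight matrix through dynamically_quantize_per_channel and dequantize with the
# returned scales. The helper name and the error bound below are assumptions for
# this example only.
def _demo_dynamically_quantize_per_channel():
    w = torch.randn(8, 16)  # (out_channels, in_channels); quantized along dim 0
    quant, scales, zero_points = dynamically_quantize_per_channel(
        w, -128, 127, torch.int8
    )
    # symmetric scheme: zero_points are all zero, so dequantization is q * scale
    w_hat = (quant.float() - zero_points.unsqueeze(-1)) * scales.unsqueeze(-1)
    # rounding error per element is at most half a quantization step
    assert torch.all((w - w_hat).abs() <= scales.unsqueeze(-1) / 2 + 1e-6)

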
##### Weight-only int8 per-channel quantized code ######


def replace_linear_weight_only_int8_per_channel(module):
    # recursively swap every nn.Linear for a WeightOnlyInt8Linear of the same
    # shape; note that WeightOnlyInt8Linear registers no bias buffer, so any
    # bias on the original nn.Linear is dropped
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(
                module,
                name,
                WeightOnlyInt8Linear(child.in_features, child.out_features),
            )
        else:
            replace_linear_weight_only_int8_per_channel(child)


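# Minimal sketch with a hypothetical toy container, showing that the swap
# recurses through nested modules. WeightOnlyInt8Linear is defined further down
# in this file, so call this only after the module has fully loaded.
def _demo_replace_linear_weight_only_int8_per_channel():
    toy = nn.Sequential(nn.Linear(16, 32), nn.Sequential(nn.ReLU(), nn.Linear(32, 4)))
    replace_linear_weight_only_int8_per_channel(toy)
    assert isinstance(toy[0], WeightOnlyInt8Linear)
    assert isinstance(toy[1][1], WeightOnlyInt8Linear)

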
class WeightOnlyInt8QuantHandler:
    # two-step handler: create_quantized_state_dict() quantizes every nn.Linear
    # weight to int8 with per-channel scales, and convert_for_runtime() swaps in
    # WeightOnlyInt8Linear modules that consume that state dict
    def __init__(self, mod):
        self.mod = mod

    @torch.no_grad()
    def create_quantized_state_dict(self):
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, torch.nn.Linear):
                int8_weight, scales, _ = dynamically_quantize_per_channel(
                    mod.weight.float(), -128, 127, torch.int8
                )
                cur_state_dict[f"{fqn}.weight"] = int8_weight.to("cpu")
                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype).to("cpu")

        return cur_state_dict

    def convert_for_runtime(self):
        replace_linear_weight_only_int8_per_channel(self.mod)
        return self.mod


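# Sketch of the intended workflow (the toy model is hypothetical; the handler
# calls mirror how this class is driven): quantize the weights first, then swap
# the modules, then load the quantized state dict into the swapped model.
def _demo_weight_only_int8_quant_handler():
    model = nn.Sequential(nn.Linear(16, 16, bias=False))
    handler = WeightOnlyInt8QuantHandler(model)
    quantized_state_dict = handler.create_quantized_state_dict()
    model = handler.convert_for_runtime()
    model.load_state_dict(quantized_state_dict)
    out = model(torch.randn(2, 16))  # runs through WeightOnlyInt8Linear.forward
    assert out.shape == (2, 16)

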
class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        # bias/device/dtype are accepted only for signature parity with
        # nn.Linear; factory_kwargs is currently unused and no bias buffer is
        # registered
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer(
            "weight", torch.empty((out_features, in_features), dtype=torch.int8)
        )
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # cast the int8 weight up to the input dtype for the matmul, then apply
        # the per-output-channel scales to dequantize the result
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales

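
# Numerical sanity sketch, an assumption for this example rather than part of
# the benchmark: the quantized forward should closely track the float nn.Linear
# it replaces, up to per-channel rounding error.
def _demo_weight_only_int8_linear_accuracy():
    torch.manual_seed(0)
    float_linear = nn.Linear(64, 8, bias=False)
    int8_weight, scales, _ = dynamically_quantize_per_channel(
        float_linear.weight.detach().float(), -128, 127, torch.int8
    )
    q_linear = WeightOnlyInt8Linear(64, 8)
    q_linear.weight.copy_(int8_weight)
    q_linear.scales = scales  # swap the bfloat16 buffer for the float32 scales
    x = torch.randn(4, 64)
    torch.testing.assert_close(q_linear(x), float_linear(x), atol=3e-2, rtol=3e-2)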