xref: /aosp_15_r20/external/pytorch/benchmarks/gpt_fast/mixtral_moe_quantize.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# flake8: noqa: E266, C417, B950
from mixtral_moe_model import ConditionalFeedForward

import torch
import torch.nn as nn
import torch.nn.functional as F


##### Quantization Primitives ######


def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scales and zero_points based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scales = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scales is the same dtype as the original tensor
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scales/zp
    # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x / scales.unsqueeze(-1)
    x_round = torch.round(x_div)
    x_zp = x_round + zero_points.unsqueeze(-1)
    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)

    return quant, scales, zero_points

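
# Example usage (a hedged sketch, not part of the benchmark): quantize a toy
# weight matrix per output channel and dequantize it with the returned scales.
# The tensor sizes and tolerance below are illustrative assumptions.
def _example_dynamic_quant_roundtrip():
    w = torch.randn(8, 16)  # [out_features, in_features]
    int8_weight, scales, zero_points = dynamically_quantize_per_channel(
        w, -128, 127, torch.int8
    )
    # symmetric quantization: zero_points are all zero, so dequantization is
    # just a per-row rescale of the int8 values
    w_hat = int8_weight.to(torch.float32) * scales.unsqueeze(-1)
    # per-element error is bounded by half a quantization step
    assert torch.allclose(w, w_hat, atol=scales.max().item())
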

##### Weight-only int8 per-channel quantized code ######


# Swap eligible submodules in place: every nn.Linear except the one named
# "gate" (the MoE router) becomes a WeightOnlyInt8Linear, and every
# ConditionalFeedForward becomes a ConditionalFeedForwardInt8 of matching shape.
def replace_linear_weight_only_int8_per_channel(module):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and name != "gate":
            setattr(
                module,
                name,
                WeightOnlyInt8Linear(
                    child.in_features, child.out_features, target_dtype=torch.int8
                ),
            )
        elif isinstance(child, ConditionalFeedForward):
            num_experts, intermediate_size, dim = child.w1.shape
            setattr(
                module,
                name,
                ConditionalFeedForwardInt8(
                    num_experts, intermediate_size, dim, target_dtype=torch.int8
                ),
            )
        else:
            replace_linear_weight_only_int8_per_channel(child)

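
# A minimal sketch (not part of the benchmark) of the in-place swap on a
# hypothetical parent module: the "gate" projection stays a full-precision
# nn.Linear while other Linear children are replaced.
def _example_replace_linear_weight_only_int8():
    block = nn.Module()
    block.gate = nn.Linear(16, 8, bias=False)
    block.w_proj = nn.Linear(16, 32, bias=False)
    replace_linear_weight_only_int8_per_channel(block)
    assert isinstance(block.gate, nn.Linear)
    assert isinstance(block.w_proj, WeightOnlyInt8Linear)
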

# Builds an int8 weight-only state dict for the wrapped model and swaps in the
# quantized module implementations for runtime.
class WeightOnlyInt8QuantHandler:
    def __init__(self, mod):
        self.mod = mod

    @torch.no_grad()
    def create_quantized_state_dict(self):
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, torch.nn.Linear) and not fqn.endswith(".gate"):
                int8_weight, scales, _ = dynamically_quantize_per_channel(
                    mod.weight.float(), -128, 127, torch.int8
                )
                cur_state_dict[f"{fqn}.weight"] = int8_weight
                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
            elif isinstance(mod, ConditionalFeedForward):
                # quantize each expert's w1/w2/w3 per output channel, then
                # re-stack the results along the expert dimension
                for weight_idx in range(0, 3):
                    weight_name = f"w{weight_idx + 1}"
                    scales_name = f"scales{weight_idx + 1}"
                    weight = getattr(mod, weight_name)
                    # note: for w2 the unpacked names below are swapped relative
                    # to its (dim, intermediate_size) layout, but the reshapes
                    # further down restore the original shapes either way
                    num_experts, intermediate_size, dim = weight.shape

                    bit8_weight_list = []
                    scales_list = []
                    for expert_idx in range(num_experts):
                        bit8_weight, scales, _ = dynamically_quantize_per_channel(
                            weight[expert_idx].float(), -128, 127, torch.int8
                        )
                        bit8_weight_list.append(
                            bit8_weight.reshape(1, intermediate_size, dim)
                        )
                        scales_list.append(scales.reshape(1, intermediate_size))

                    cur_state_dict[f"{fqn}.{weight_name}"] = torch.cat(
                        bit8_weight_list, dim=0
                    )
                    cur_state_dict[f"{fqn}.{scales_name}"] = torch.cat(
                        scales_list, dim=0
                    )

        return cur_state_dict

    def convert_for_runtime(self):
        replace_linear_weight_only_int8_per_channel(self.mod)
        return self.mod

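
# End-to-end sketch (hedged, not part of the benchmark): given a Mixtral-style
# model built from mixtral_moe_model, materialize the int8 state dict, swap in
# the quantized modules, and load the quantized weights back.
def _example_quantize_model(model):
    quant_handler = WeightOnlyInt8QuantHandler(model)
    quantized_state_dict = quant_handler.create_quantized_state_dict()
    model = quant_handler.convert_for_runtime()
    model.load_state_dict(quantized_state_dict)
    return model.eval()
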

# Linear layer that stores an int8 weight plus a per-output-channel scale and
# dequantizes on the fly in forward (weight-only quantization; bias is unused).
class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        target_dtype=None,
    ) -> None:
        assert target_dtype is not None
        factory_kwargs = {"device": device, "dtype": dtype}  # currently unused
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer(
            "weight", torch.empty((out_features, in_features), dtype=target_dtype)
        )
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales

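
# Numerical sketch (hedged assumptions: float32 activations, no bias, toy
# sizes): a WeightOnlyInt8Linear filled from dynamically_quantize_per_channel
# should closely track the full-precision F.linear result.
def _example_weight_only_linear():
    torch.manual_seed(0)
    ref = nn.Linear(16, 8, bias=False)
    int8_weight, scales, _ = dynamically_quantize_per_channel(
        ref.weight.detach().float(), -128, 127, torch.int8
    )
    qlinear = WeightOnlyInt8Linear(16, 8, target_dtype=torch.int8)
    qlinear.weight.copy_(int8_weight)
    qlinear.scales = scales  # replace the placeholder ones-buffer with real scales
    x = torch.randn(4, 16)
    assert torch.allclose(ref(x), qlinear(x), atol=1e-1)
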

# Int8 variant of ConditionalFeedForward: expert weights are stored as int8
# buffers with one per-output-channel scale vector per expert, and are
# dequantized on the fly in forward.
class ConditionalFeedForwardInt8(nn.Module):
    def __init__(self, num_experts, intermediate_size, dim, target_dtype):
        super().__init__()

        self.target_dtype = target_dtype

        self.register_buffer(
            "w1", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype)
        )
        self.register_buffer(
            "w2", torch.empty(num_experts, dim, intermediate_size, dtype=target_dtype)
        )
        self.register_buffer(
            "w3", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype)
        )

        self.register_buffer(
            "scales1", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16)
        )
        self.register_buffer(
            "scales2", torch.empty(num_experts, dim, dtype=torch.bfloat16)
        )
        self.register_buffer(
            "scales3", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16)
        )

    def forward(self, x, expert_indices):
        # shapes: T = tokens, A = activated experts per token,
        # D = dim, I = intermediate_size
        w1_weights = self.w1.to(x.dtype)[expert_indices]  # [T, A, I, D]
        w3_weights = self.w3.to(x.dtype)[expert_indices]  # [T, A, I, D]
        w2_weights = self.w2.to(x.dtype)[expert_indices]  # [T, A, D, I]
        x1 = F.silu(
            torch.einsum("ti,taoi -> tao", x, w1_weights)
            * self.scales1[expert_indices].to(x.dtype)
        )
        x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) * self.scales3[
            expert_indices
        ].to(x.dtype)
        expert_outs = torch.einsum(
            "tao, taio -> tai", (x1 * x3), w2_weights
        ) * self.scales2[expert_indices].to(
            x.dtype
        )  # [T, A, D]
        return expert_outs
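

# Shape sketch (hedged; toy sizes and random data, not part of the benchmark):
# routing two experts per token through ConditionalFeedForwardInt8 yields one
# output vector per (token, activated expert) pair.
def _example_conditional_ff_int8():
    num_experts, intermediate_size, dim = 4, 32, 16
    moe = ConditionalFeedForwardInt8(
        num_experts, intermediate_size, dim, target_dtype=torch.int8
    )
    # fill the uninitialized buffers with plausible values for the sketch
    for name in ("w1", "w2", "w3"):
        getattr(moe, name).copy_(torch.randint_like(getattr(moe, name), -128, 127))
    for name in ("scales1", "scales2", "scales3"):
        getattr(moe, name).fill_(0.01)
    x = torch.randn(8, dim)  # [T, D]
    expert_indices = torch.randint(0, num_experts, (8, 2))  # [T, A]
    out = moe(x, expert_indices)
    assert out.shape == (8, 2, dim)  # [T, A, D]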