# flake8: noqa: E266, C417, B950
import torch
import torch.nn as nn
import torch.nn.functional as F


##### Quantization Primitives ######


def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scales and zero_points based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scales = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scales is the same dtype as the original tensor
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scales/zp
    # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x / scales.unsqueeze(-1)
    x_round = torch.round(x_div)
    x_zp = x_round + zero_points.unsqueeze(-1)
    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)

    return quant, scales, zero_points


##### Weight-only int8 per-channel quantized code ######


def replace_linear_weight_only_int8_per_channel(module):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(
                module,
                name,
                WeightOnlyInt8Linear(child.in_features, child.out_features),
            )
        else:
            replace_linear_weight_only_int8_per_channel(child)


class WeightOnlyInt8QuantHandler:
    def __init__(self, mod):
        self.mod = mod

    @torch.no_grad()
    def create_quantized_state_dict(self):
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, torch.nn.Linear):
                int8_weight, scales, _ = dynamically_quantize_per_channel(
                    mod.weight.float(), -128, 127, torch.int8
                )
                cur_state_dict[f"{fqn}.weight"] = int8_weight.to("cpu")
                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype).to("cpu")

        return cur_state_dict

    def convert_for_runtime(self):
        replace_linear_weight_only_int8_per_channel(self.mod)
        return self.mod


class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer(
            "weight", torch.empty((out_features, in_features), dtype=torch.int8)
        )
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
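

# --- Illustrative usage sketch (an assumption added for demonstration, not
# --- part of the original module). It shows one plausible way to quantize a
# --- toy model with WeightOnlyInt8QuantHandler; the toy model, shapes, and
# --- the `__main__` guard below are hypothetical.
if __name__ == "__main__":
    # Hypothetical float model with two linear layers.
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))

    handler = WeightOnlyInt8QuantHandler(model)
    # Build a state dict whose linear weights are int8 with per-channel scales.
    quantized_state_dict = handler.create_quantized_state_dict()

    # Swap nn.Linear modules for WeightOnlyInt8Linear, then load the quantized
    # weights. strict=False because the original bias entries have no
    # counterpart in WeightOnlyInt8Linear.
    model = handler.convert_for_runtime()
    model.load_state_dict(quantized_state_dict, strict=False)

    x = torch.randn(4, 16)
    print(model(x).shape)  # expected: torch.Size([4, 8])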