xref: /aosp_15_r20/external/libopus/dnn/torch/osce/utils/softquant.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1import torch
2
@torch.no_grad()
def compute_optimal_scale(weight):
    """Compute a per-output-row quantization scale for a 2-D weight matrix.

    Args:
        weight: 2-D tensor of shape (n_out, n_in); n_in must be a multiple
            of 4 (asserted).

    Returns:
        1-D tensor of length n_out, one scale per row, chosen so that both
        individual weights (int8 range, / 127) and sums of adjacent
        (even, odd) weight pairs (/ 129 headroom) fit after quantization.
    """
    n_out, n_in = weight.shape
    assert n_in % 4 == 0
    if n_out % 8:
        # Pad the row count up to the next multiple of 8 with zero rows.
        # (The previous formula `n_out - n_out % 8` over-padded well past
        # the next multiple of 8; the returned scales are unaffected either
        # way because rows are independent and the padding is sliced off
        # below, so this fix is behavior-safe.)
        pad = 8 - n_out % 8
        weight = torch.cat(
            (weight, torch.zeros((pad, n_in), dtype=weight.dtype, device=weight.device)),
            dim=0,
        )

    # Largest absolute single weight in each row.
    weight_max_abs, _ = torch.max(torch.abs(weight), dim=1)
    # Largest absolute sum of adjacent (even, odd) weight pairs in each row.
    weight_max_sum, _ = torch.max(torch.abs(weight[:, : n_in : 2] + weight[:, 1 : n_in : 2]), dim=1)
    scale_max = weight_max_abs / 127
    scale_sum = weight_max_sum / 129

    # Whichever constraint is binding per row determines that row's scale.
    scale = torch.maximum(scale_max, scale_sum)

    # Drop scales belonging to the zero padding rows.
    return scale[:n_out]
21
@torch.no_grad()
def q_scaled_noise(module, weight):
    """Draw uniform noise for `weight`, scaled per quantization row.

    The weight is first brought into a 2-D (rows, cols) layout matching the
    layer type's quantized storage order, a per-row scale is derived with
    compute_optimal_scale, and the scaled noise is mapped back to the
    original weight shape.

    Raises:
        ValueError: for module/weight combinations with no known layout.
    """
    if isinstance(module, torch.nn.Conv1d):
        # (out, in, k) -> (out, k*in)
        flat = weight.permute(0, 2, 1).flatten(1)
        per_row = compute_optimal_scale(flat)
        flat_noise = (torch.rand_like(flat) - 0.5) * per_row.unsqueeze(-1)
        out_ch, in_ch, kernel = weight.size(0), weight.size(1), weight.size(2)
        noise = flat_noise.reshape(out_ch, kernel, in_ch).permute(0, 2, 1)
    elif isinstance(module, torch.nn.ConvTranspose1d):
        # (in, out, k) -> (k*out, in)
        i, o, k = weight.shape
        flat = weight.permute(2, 1, 0).reshape(k * o, i)
        per_row = compute_optimal_scale(flat)
        flat_noise = (torch.rand_like(flat) - 0.5) * per_row.unsqueeze(-1)
        noise = flat_noise.reshape(k, o, i).permute(2, 1, 0)
    elif weight.dim() == 2:
        per_row = compute_optimal_scale(weight)
        noise = (torch.rand_like(weight) - 0.5) * per_row.unsqueeze(-1)
    else:
        raise ValueError('unknown quantization setting')

    return noise
45
class SoftQuant:
    """Forward hooks that inject, then remove, quantization noise.

    During training, a pre-forward hook adds uniform noise to the listed
    weight attributes of a module; the matching post-forward hook subtracts
    the exact same noise afterwards, restoring the weights. Evaluation-mode
    forwards are left untouched.
    """

    def __init__(self, names, scale):
        # names: iterable of attribute names (e.g. ['weight']) to perturb.
        # scale: float noise magnitude relative to the weight's max |value|,
        #        or None to derive per-row scales via q_scaled_noise.
        self.names = names
        self.quantization_noise = None
        self.scale = scale

    def __call__(self, module, inputs, *args, before=True):
        """Hook body; before=True for the pre-hook, False for the post-hook."""
        if not module.training: return

        if before:
            self.quantization_noise = dict()
            for name in self.names:
                weight = getattr(module, name)
                if self.scale is None:
                    self.quantization_noise[name] = q_scaled_noise(module, weight)
                else:
                    self.quantization_noise[name] = \
                        self.scale * weight.abs().max() * (torch.rand_like(weight) - 0.5)
                with torch.no_grad():
                    weight.data[:] = weight + self.quantization_noise[name]
        else:
            # Subtract the exact noise added in the pre-hook so the stored
            # weights are restored after the forward pass.
            for name in self.names:
                weight = getattr(module, name)
                with torch.no_grad():
                    weight.data[:] = weight - self.quantization_noise[name]
            self.quantization_noise = None

    @staticmethod
    def apply(module, names=['weight'], scale=None):
        """Attach SoftQuant pre/post forward hooks to `module`.

        Returns the SoftQuant instance driving both hooks.

        Raises:
            ValueError: if `module` lacks one of the named attributes.
        """
        fn = SoftQuant(names, scale)

        for name in names:
            if not hasattr(module, name):
                raise ValueError(f"module has no attribute named {name}")

        fn_before = lambda *x: fn(*x, before=True)
        fn_after = lambda *x: fn(*x, before=False)
        # Tag the hook closures so remove_soft_quant can find them later.
        setattr(fn_before, 'sqm', fn)
        setattr(fn_after, 'sqm', fn)

        module.register_forward_pre_hook(fn_before)
        module.register_forward_hook(fn_after)

        return fn
94
95
def soft_quant(module, names=['weight'], scale=None):
    """Attach SoftQuant noise hooks to `module` and return the module.

    Convenience wrapper around SoftQuant.apply for fluent use, e.g.
    `layer = soft_quant(layer)`.
    """
    # Return value of apply (the SoftQuant instance) is intentionally
    # discarded; remove_soft_quant finds the hooks via their 'sqm' tag.
    SoftQuant.apply(module, names, scale)
    return module
99
def remove_soft_quant(module, names=['weight']):
    """Detach SoftQuant hooks previously added for `names`; return module.

    Hooks not created by SoftQuant.apply, or registered for a different
    attribute list, are left in place.
    """
    # Snapshot items() before deleting: removing entries from an
    # OrderedDict while iterating it raises RuntimeError.
    for k, hook in list(module._forward_pre_hooks.items()):
        if hasattr(hook, 'sqm'):
            if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
                del module._forward_pre_hooks[k]
    for k, hook in list(module._forward_hooks.items()):
        if hasattr(hook, 'sqm'):
            if isinstance(hook.sqm, SoftQuant) and hook.sqm.names == names:
                del module._forward_hooks[k]

    return module
110    return module