1r"""
2**This file is EXPERIMENTAL and is mostly used for testing purposes! Do not
3rely on it for anything!**
4"""
5import operator
6import sys
7from typing import Optional
8
9import torch
10from torch.fx import Graph, GraphModule, Node
11from torch.fx.graph import map_arg
12from torch.fx.proxy import Proxy
13from torch.nn.utils import fuse_conv_bn_weights
14
15
16# can be a
17#  module type, a builtin function, or a string to match target
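# Illustrative only (a hedged sketch of the pattern forms registered below):
# patterns may also be nested tuples matched outside-in against a node and its
# args, e.g. (torch.nn.ReLU, torch.nn.Conv2d) matches a relu(conv(x)) call
# chain, with the ReLU node acting as the root of the match.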


def _minmax_scale_zeropoint(
    min_val, max_val, qmin=-127, qmax=128, eps=torch.finfo(torch.float32).eps
):
    min_val = min(0.0, min_val)
    max_val = max(0.0, max_val)
    if max_val == min_val:
        return 1.0, 0
    else:
        scale = (max_val - min_val) / float(qmax - qmin)
        scale = max(scale, eps)
        zero_point = qmin - round(min_val / scale)
        zero_point = max(qmin, zero_point)
        zero_point = min(qmax, zero_point)
        zero_point = int(zero_point)
        return scale, zero_point
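# For intuition (hedged example, not exercised by this file): with the observer
# range qmin=0, qmax=255, a tensor spanning [-1.0, 1.0] yields
# scale = 2 / 255 ~= 0.00784 and a zero_point of roughly 128, so that
# real_value ~= (quantized_value - zero_point) * scale.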


class MinMaxObserver:
    def __init__(self, quantizer, node):
        self.min, self.max = float("inf"), float("-inf")
        self.all_tensors = True

    def observe(self, node, env):
        v = env[node.name]
        if not isinstance(v, torch.Tensor):
            self.all_tensors = False
            return
        self.max = max(self.max, float(v.max()))
        self.min = min(self.min, float(v.min()))

    def scale_zeropoint(self):
        return _minmax_scale_zeropoint(self.min, self.max, qmin=0, qmax=255)


class NoObserver:
    def __init__(self, quantizer, node):
        pass

    def observe(self, node, env):
        pass


_DEFAULT_QUANTIZATION_PATTERNS = {}


def register_pattern(pattern):
    def insert(fn):
        _DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
        return fn

    return insert


@register_pattern(operator.add)
class Add(MinMaxObserver):
    def quantize(self, quantizer, node, load_arg):
        if not self.all_tensors:
            return NotImplemented
        scale, zeropoint = self.scale_zeropoint()
        return quantizer.quantized_graph.create_node(
            "call_function",
            torch.ops.quantized.add,
            load_arg(node.args),
            {"scale": scale, "zero_point": zeropoint},
        )


class Relu(NoObserver):
    def quantize(self, quantizer, node, load_arg):
        return torch.relu(
            load_arg(node.args[0])
        )  # torch.relu works directly on quantized tensors?


# these ops have quantized equivalents that do not need any extra information
@register_pattern(torch.nn.ReLU)
@register_pattern(torch.nn.AvgPool2d)
@register_pattern(torch.nn.MaxPool2d)
@register_pattern(torch.nn.AdaptiveAvgPool2d)
class CopyNode(NoObserver):
    def quantize(self, quantizer, node, load_arg):
        return quantizer.quantized_graph.node_copy(node, load_arg)


class IdentityModule(torch.nn.Module):
    def forward(self, x):
        return x


# handle conv, maybe followed by bn, maybe followed by relu
@register_pattern(torch.nn.modules.conv.Conv2d)
@register_pattern((torch.nn.ReLU, torch.nn.modules.conv.Conv2d))
@register_pattern(
    (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d)
)
@register_pattern(
    (
        torch.nn.ReLU,
        (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d),
    )
)
class ConvNormRelu(MinMaxObserver):
    def __init__(self, quantizer, node):
        super().__init__(quantizer, node)
        self.relu_node, self.bn_node = None, None
        if isinstance(quantizer.modules[node.target], torch.nn.ReLU):
            self.relu_node = node
            node = node.args[0]
        if isinstance(quantizer.modules[node.target], torch.nn.BatchNorm2d):
            self.bn_node = node
            self.bn = quantizer.modules[self.bn_node.target]
            node = node.args[0]
        assert isinstance(quantizer.modules[node.target], torch.nn.modules.Conv2d)
        self.conv_node = node
        self.conv = quantizer.modules[self.conv_node.target]

    def quantize(self, quantizer, node, load_arg):
        mod = self.conv
        weight, bias = mod.weight, mod.bias

        if self.bn_node is not None:
            weight, bias = fuse_conv_bn_weights(
                weight,
                bias,
                self.bn.running_mean,
                self.bn.running_var,
                self.bn.eps,
                self.bn.weight,
                self.bn.bias,
            )

        min_val, max_val = float(weight.min()), float(weight.max())

        act_scale, act_zp = self.scale_zeropoint()

        weight_scale, weight_zp = _minmax_scale_zeropoint(min_val, max_val)
        qweight = torch.quantize_per_tensor(
            weight, weight_scale, weight_zp, torch.qint8
        )

        ctor = (
            torch.ao.nn.intrinsic.quantized.ConvReLU2d
            if self.relu_node is not None
            else torch.ao.nn.quantized.Conv2d
        )

        qconv = ctor(
            mod.in_channels,
            mod.out_channels,
            mod.kernel_size,
            mod.stride,
            mod.padding,
            mod.dilation,
            mod.groups,
            mod.bias is not None,
            mod.padding_mode,
        )

        qconv.set_weight_bias(qweight, bias)
        qconv.scale = float(act_scale)
        qconv.zero_point = int(act_zp)
        parent_name, name = _parent_name(self.conv_node.target)
        setattr(quantizer.modules[parent_name], name, qconv)
        if self.bn_node is not None:
            parent_bn, bn_name = _parent_name(self.bn_node.target)
            # we can't just delete the batch norm module, because submodules' forwards
            # (which are no longer used) still try to call it, so replace it with a
            # module that does nothing.
            setattr(quantizer.modules[parent_bn], bn_name, IdentityModule())

        return quantizer.quantized_graph.create_node(
            "call_module",
            self.conv_node.target,
            (load_arg(self.conv_node.args[0]),),
            {},
        )


# turn foo.bar -> ['foo', 'bar']
def _parent_name(target):
    r = target.rsplit(".", 1)
    if len(r) == 1:
        return "", r[0]
    else:
        return r[0], r[1]


class DefaultQuant(MinMaxObserver):
    def quantize(self, input):
        assert self.all_tensors
        scale, zeropoint = self.scale_zeropoint()
        return torch.quantize_per_tensor(
            Proxy(input), scale, zeropoint, torch.quint8
        ).node


def matches(modules, node, pattern, max_uses=sys.maxsize):
    if isinstance(pattern, tuple):
        self_match, *arg_matches = pattern
    else:
        self_match = pattern
        arg_matches = None

    if len(node.users) > max_uses:
        return False

    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
        if node.op != "call_module":
            return False
        if not isinstance(modules[node.target], self_match):
            return False
    elif callable(self_match):
        if node.op != "call_function" or node.target is not self_match:
            return False
    elif node.target != self_match:
        return False

    if not arg_matches:
        return True

    if len(arg_matches) != len(node.args):
        return False

    return all(
        matches(modules, node, arg_match, max_uses=1)
        for node, arg_match in zip(node.args, arg_matches)
    )
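# Illustrative only (hedged sketch): for a graph produced by relu(bn(conv(x))),
# matches(modules, relu_node, (torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# returns True, because the ReLU node is a call_module targeting a ReLU and its single
# argument recursively matches the (BatchNorm2d, Conv2d) sub-pattern with max_uses=1.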


class Quantizer:
    def __init__(
        self, mod, patterns=_DEFAULT_QUANTIZATION_PATTERNS, quant_ctor=DefaultQuant
    ):
        self.root = mod
        self.graph = mod.graph
        self.quant_ctor = quant_ctor

        # cached information for observe
        self.state_dict = self.root.state_dict()
        self.modules = dict(self.root.named_modules())

        # match the patterns that will get quantized
        self.matches = self._find_matches(patterns)
        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats, so we
        # initialize a quant_ctor object for each of them
        self.quants = self._find_quants(quant_ctor)

    def observe(self, args):
        # most of this function is just an interpreter for the graph.
        # it would be possible to put this in some abstraction, but
        # it is pretty nice to be able to see exactly what is happening here
        # and hack on it.
        # maybe we should just provide an example interpreter that people
        # copy/paste and then edit.
        args_iter = iter(args)
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        output_node: Optional[Node] = None
        for node in self.graph.nodes:
            if node.op == "placeholder":
                result = next(args_iter)
            elif node.op == "get_attr":
                result = self.state_dict[node.target]
            elif node.op == "call_function":
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == "call_method":
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == "call_module":
                result = self.modules[node.target](
                    *load_arg(node.args), **load_arg(node.kwargs)
                )
            elif node.op == "output":
                return load_arg(node.args[0])

            env[node.name] = result
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is node:
                obj.observe(node, env)
            if node.name in self.quants:
                self.quants[node.name].observe(node, env)

        raise RuntimeError("Graph had no output node!")

    def quantize(self):
        self.quantized_graph = Graph()

        # env holds nodes producing float values, quant_env holds nodes producing
        # quantized values; load_arg converts lazily between the two on demand
        env = {}
        quant_env = {}

        def load_arg(n, quantized):
            if not quantized:
                if n.name not in env and n.name in quant_env:
                    env[n.name] = Proxy(quant_env[n.name]).dequantize().node
                return env[n.name]
            else:
                if n.name not in quant_env and n.name in env:
                    quant_env[n.name] = self.quants[n.name].quantize(env[n.name])
                return quant_env[n.name]

        def copy_recursive(node):
            def load_or_emit(n):
                if n.name in env or n.name in quant_env:
                    return load_arg(n, quantized=False)
                else:
                    return copy_recursive(n)

            r = env[node.name] = self.quantized_graph.node_copy(
                node, lambda n: load_arg(n, quantized=False)
            )
            return r

        for node in self.graph.nodes:
            root_node, obj = self.matches.get(node.name, (None, None))
            if root_node is None:
                # not quantized, just copy it
                env[node.name] = self.quantized_graph.node_copy(
                    node, lambda n: load_arg(n, quantized=False)
                )

            elif root_node is node:
                r = obj.quantize(
                    self,
                    node,
                    lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)),
                )
                if r is NotImplemented:
                    # the quantizer chose not to quantize this node; take the
                    # entire match and just copy it over
                    env[node.name] = copy_recursive(node)
                else:
                    quant_env[node.name] = r

        return GraphModule(self.root, self.quantized_graph)

    def _find_matches(self, patterns):
        modules = dict(self.root.named_modules())
        match_map = {}  # node name -> (root_node, match_value)

        def apply_match(pattern, node, match):
            if isinstance(pattern, tuple):
                s, *args = pattern
                apply_match(s, node, match)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match)
            else:
                match_map[node.name] = match

        for node in reversed(self.graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    if matches(modules, node, pattern):
                        apply_match(pattern, node, (node, value(self, node)))

        return match_map

    def _find_quants(self, quant_ctor):
        quants = {}

        def visit_arg(n):
            # note: we have to measure quantization information even for inputs
            # whose values might already be quantized by another match. This is
            # because each match has the option to return NotImplemented (if, for
            # instance, it is an __add__ and the data type is not appropriate).
            if n.name not in quants:
                quants[n.name] = quant_ctor(self, n)

        for node in self.graph.nodes:
            if node.name in self.matches:
                map_arg(node.args, visit_arg)
                map_arg(node.kwargs, visit_arg)
        return quants
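

# A minimal end-to-end sketch of the intended workflow (hedged: `_ExampleConvReLU`
# and the shapes below are illustrative assumptions, not part of this file's API).
# Tracing the module, running calibration data through `observe`, and then calling
# `quantize` should yield a GraphModule whose conv has been swapped for a quantized
# ConvReLU2d, with quantize_per_tensor/dequantize nodes inserted at the boundaries.
if __name__ == "__main__":
    from torch.fx import symbolic_trace

    class _ExampleConvReLU(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(self.conv(x))

    traced = symbolic_trace(_ExampleConvReLU().eval())
    quantizer = Quantizer(traced)
    # calibration pass: records min/max stats for matched nodes and their inputs
    quantizer.observe((torch.randn(1, 3, 16, 16),))
    quantized = quantizer.quantize()
    print(quantized.graph)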