xref: /aosp_15_r20/external/executorch/backends/example/example_quantizer.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
from dataclasses import dataclass
from typing import List, Optional

import torch
from executorch.backends.example.example_operators.ops import module_to_annotator
from torch import fx
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OperatorConfig
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Worker
19*523fa7a6SAndroid Build Coastguard Workerdef get_uint8_tensor_spec(observer_or_fake_quant_ctr):
20*523fa7a6SAndroid Build Coastguard Worker    return QuantizationSpec(
21*523fa7a6SAndroid Build Coastguard Worker        dtype=torch.uint8,
22*523fa7a6SAndroid Build Coastguard Worker        quant_min=0,
23*523fa7a6SAndroid Build Coastguard Worker        quant_max=255,
24*523fa7a6SAndroid Build Coastguard Worker        qscheme=torch.per_tensor_affine,
25*523fa7a6SAndroid Build Coastguard Worker        is_dynamic=False,
26*523fa7a6SAndroid Build Coastguard Worker        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
27*523fa7a6SAndroid Build Coastguard Worker    )
28*523fa7a6SAndroid Build Coastguard Worker
29*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class ExampleQuantConfig:
    """Bundle of quantization specs handed to the example backend's annotators.

    One spec is held for each tensor role: input, output, weight and bias.
    """

    input_quant_spec: QuantizationSpec
    output_quant_spec: QuantizationSpec
    weight_quant_spec: QuantizationSpec
    # ``None`` means no bias spec is configured here; bias quantization can
    # instead be done in a pass later on.
    bias_quant_spec: Optional[QuantizationSpec] = None


# Default static configuration: histogram observers for inputs and outputs,
# a min/max observer for weights, and no bias spec.
default_static_config = ExampleQuantConfig(
    input_quant_spec=get_uint8_tensor_spec(HistogramObserver),
    output_quant_spec=get_uint8_tensor_spec(HistogramObserver),
    weight_quant_spec=get_uint8_tensor_spec(MinMaxObserver),
    # Bias quantization can be configured here or done in a pass later on.
    bias_quant_spec=None,
)
45*523fa7a6SAndroid Build Coastguard Worker
46*523fa7a6SAndroid Build Coastguard Worker
def check_for_outside_users(partitions) -> bool:
    """
    Report whether any partition other than the last one leaks outside the chain.

    Every partition except the last must have exactly one output node, and that
    node must have exactly one user. If we quantize partitions that have users
    outside this subgraph then delegation of this partition to the backend will
    not be possible.

    Returns True when at least one non-final partition breaks that rule, and
    False otherwise (including when there are no non-final partitions).
    """
    return any(
        # The `or` short-circuits, so output_nodes[0] is only read when there
        # is exactly one output node.
        len(candidate.output_nodes) != 1
        or len(candidate.output_nodes[0].users) != 1
        for candidate in partitions[:-1]
    )
59*523fa7a6SAndroid Build Coastguard Worker
60*523fa7a6SAndroid Build Coastguard Worker
61*523fa7a6SAndroid Build Coastguard Workerclass ExampleQuantizer(Quantizer):
62*523fa7a6SAndroid Build Coastguard Worker    def __init__(self, quantizer_supported_modules=None, quant_config=None):
63*523fa7a6SAndroid Build Coastguard Worker        super().__init__()
64*523fa7a6SAndroid Build Coastguard Worker        if quantizer_supported_modules is not None:
65*523fa7a6SAndroid Build Coastguard Worker            self.quantizer_supported_modules = quantizer_supported_modules
66*523fa7a6SAndroid Build Coastguard Worker            for module in self.quantizer_supported_modules:
67*523fa7a6SAndroid Build Coastguard Worker                if module not in module_to_annotator.keys():
68*523fa7a6SAndroid Build Coastguard Worker                    assert 0, f"{module} is not supported by this quantizer"
69*523fa7a6SAndroid Build Coastguard Worker        else:
70*523fa7a6SAndroid Build Coastguard Worker            self.quantizer_supported_modules = module_to_annotator.keys()
71*523fa7a6SAndroid Build Coastguard Worker        if quant_config is not None:
72*523fa7a6SAndroid Build Coastguard Worker            self.quant_config = quant_config
73*523fa7a6SAndroid Build Coastguard Worker        else:
74*523fa7a6SAndroid Build Coastguard Worker            self.quant_config = default_static_config
75*523fa7a6SAndroid Build Coastguard Worker
76*523fa7a6SAndroid Build Coastguard Worker    def annotate(self, model):
77*523fa7a6SAndroid Build Coastguard Worker        for supported_modules in self.quantizer_supported_modules:
78*523fa7a6SAndroid Build Coastguard Worker            # print("supported modules: ", supported_modules)
79*523fa7a6SAndroid Build Coastguard Worker            fused_partitions = find_sequential_partitions(
80*523fa7a6SAndroid Build Coastguard Worker                model,
81*523fa7a6SAndroid Build Coastguard Worker                list(supported_modules),
82*523fa7a6SAndroid Build Coastguard Worker            )
83*523fa7a6SAndroid Build Coastguard Worker
84*523fa7a6SAndroid Build Coastguard Worker            for partitions in fused_partitions:
85*523fa7a6SAndroid Build Coastguard Worker                if check_for_outside_users(partitions):
86*523fa7a6SAndroid Build Coastguard Worker                    continue
87*523fa7a6SAndroid Build Coastguard Worker
88*523fa7a6SAndroid Build Coastguard Worker                source_module_list = ()
89*523fa7a6SAndroid Build Coastguard Worker                for partition in partitions:
90*523fa7a6SAndroid Build Coastguard Worker                    source_module_list += (partition,)
91*523fa7a6SAndroid Build Coastguard Worker
92*523fa7a6SAndroid Build Coastguard Worker                annotator = module_to_annotator[supported_modules].annotate_handle
93*523fa7a6SAndroid Build Coastguard Worker                annotator(partitions, self.quant_config)
94*523fa7a6SAndroid Build Coastguard Worker
95*523fa7a6SAndroid Build Coastguard Worker        return model
96*523fa7a6SAndroid Build Coastguard Worker
97*523fa7a6SAndroid Build Coastguard Worker    def validate(self, model: fx.GraphModule) -> None:
98*523fa7a6SAndroid Build Coastguard Worker        pass
99*523fa7a6SAndroid Build Coastguard Worker
100*523fa7a6SAndroid Build Coastguard Worker    @classmethod
101*523fa7a6SAndroid Build Coastguard Worker    def get_supported_operators(cls) -> List[OperatorConfig]:
102*523fa7a6SAndroid Build Coastguard Worker        return []
103