# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import List, Optional

import torch
from executorch.backends.example.example_operators.ops import module_to_annotator
from torch import fx
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OperatorConfig


def get_uint8_tensor_spec(observer_or_fake_quant_ctr):
    """Build a static, per-tensor affine uint8 ``QuantizationSpec``.

    Args:
        observer_or_fake_quant_ctr: Observer (or fake-quant) constructor that
            is forwarded unchanged to the spec.

    Returns:
        A ``QuantizationSpec`` with ``dtype=torch.uint8``, range ``[0, 255]``,
        ``torch.per_tensor_affine`` scheme and ``is_dynamic=False``.
    """
    return QuantizationSpec(
        dtype=torch.uint8,
        quant_min=0,
        quant_max=255,
        qscheme=torch.per_tensor_affine,
        is_dynamic=False,
        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr,
    )


@dataclass
class ExampleQuantConfig:
    """Quantization specs handed to each annotator by ``ExampleQuantizer``."""

    input_quant_spec: QuantizationSpec
    output_quant_spec: QuantizationSpec
    weight_quant_spec: QuantizationSpec
    # ``None`` means the bias is left unannotated here; the module-level
    # default config below relies on this.
    bias_quant_spec: Optional[QuantizationSpec]


default_static_config = ExampleQuantConfig(
    get_uint8_tensor_spec(HistogramObserver),
    get_uint8_tensor_spec(HistogramObserver),
    get_uint8_tensor_spec(MinMaxObserver),
    None,  # Bias quantization can be configured here or done in a later pass.
)


def check_for_outside_users(partitions) -> bool:
    """Report whether a partition sequence is unsafe to quantize as a unit.

    Make sure that all the users of each partition are within the delegatable
    subgraph, except for the last partition. If we quantize partitions that
    have users outside this subgraph then delegation of this partition to the
    backend will not be possible.

    The check is structural: every partition except the last must expose
    exactly one output node, and that node must have exactly one user.
    NOTE(review): the identity of that single user is not inspected here; it
    is presumably the next partition in the sequence - confirm.

    Args:
        partitions: Sequence of partition objects exposing ``output_nodes``;
            each output node exposes ``users``.

    Returns:
        ``True`` if any non-final partition violates the condition above,
        ``False`` otherwise (including for empty or single-element input).
    """
    for source_partition in partitions[:-1]:
        output_nodes = source_partition.output_nodes
        if len(output_nodes) != 1:
            return True
        if len(output_nodes[0].users) != 1:
            return True
    return False


class ExampleQuantizer(Quantizer):
    """Quantizer that annotates the patterns registered in ``module_to_annotator``."""

    def __init__(self, quantizer_supported_modules=None, quant_config=None):
        """Configure which patterns to annotate and with which specs.

        Args:
            quantizer_supported_modules: Iterable of keys of
                ``module_to_annotator`` to restrict annotation to. When
                ``None``, every registered key is used.
            quant_config: ``ExampleQuantConfig`` to pass to the annotators.
                When ``None``, ``default_static_config`` is used.

        Raises:
            AssertionError: If a requested entry is not a key of
                ``module_to_annotator``.
        """
        super().__init__()
        if quantizer_supported_modules is not None:
            self.quantizer_supported_modules = quantizer_supported_modules
            for module in self.quantizer_supported_modules:
                if module not in module_to_annotator:
                    # Raised explicitly (rather than via ``assert``) so the
                    # validation is not stripped under ``python -O``; the
                    # exception type and message are unchanged.
                    raise AssertionError(
                        f"{module} is not supported by this quantizer"
                    )
        else:
            self.quantizer_supported_modules = module_to_annotator.keys()
        if quant_config is not None:
            self.quant_config = quant_config
        else:
            self.quant_config = default_static_config

    def annotate(self, model):
        """Annotate every matching, fully-contained pattern in ``model``.

        For each supported pattern, sequential partitions are located with
        ``find_sequential_partitions``; sequences flagged by
        ``check_for_outside_users`` are skipped, and the remaining ones are
        passed to the pattern's ``annotate_handle`` together with
        ``self.quant_config``.

        Args:
            model: The graph module to annotate.

        Returns:
            The same ``model`` object that was passed in.
        """
        for supported_modules in self.quantizer_supported_modules:
            fused_partitions = find_sequential_partitions(
                model,
                list(supported_modules),
            )

            for partitions in fused_partitions:
                if check_for_outside_users(partitions):
                    continue

                annotator = module_to_annotator[supported_modules].annotate_handle
                annotator(partitions, self.quant_config)

        return model

    def validate(self, model: fx.GraphModule) -> None:
        """Perform no validation; present to satisfy the ``Quantizer`` interface."""
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        """Return an empty list; no operator configs are advertised."""
        return []