# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from multiprocessing.connection import Client

import numpy as np
import torch
from executorch.backends.qualcomm.quantizer.annotators import (
    QuantizationConfig,
    QuantizationSpec,
)
from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import (
    PerChannelParamObserver,
)
from executorch.backends.qualcomm.quantizer.qconfig import (
    _derived_bias_quant_spec,
    MovingAverageMinMaxObserver,
)

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.backends.qualcomm.utils.constants import (
    QCOM_PASS_EXPAND_BROADCAST_SHAPE,
)
from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
from executorch.examples.qualcomm.utils import (
    build_executorch_binary,
    get_imagenet_dataset,
    make_output_dir,
    make_quantizer,
    parse_skip_delegation_node,
    setup_common_args_and_variables,
    SimpleADB,
    topk_accuracy,
)


def get_instance(repo_path: str, checkpoint_path: str):
    import sys

    sys.path.insert(0, repo_path)

    from models.modules.mobileone import reparameterize_model
    from timm.models import create_model

    # load the reparameterized FastViT-S12 checkpoint into an inference-only model
    checkpoint = torch.load(checkpoint_path, weights_only=True)
    model = create_model("fastvit_s12")
    model = reparameterize_model(model).eval()
    model.load_state_dict(checkpoint["state_dict"])
    return model


def main(args):
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    data_num = 100
    inputs, targets, input_list = get_imagenet_dataset(
        dataset_path=f"{args.dataset}",
        data_size=data_num,
        image_shape=(256, 256),
    )

    pte_filename = "fastvit_qnn"
    quantizer = make_quantizer(quant_dtype=QuantDtype.use_8a8w)

    # FastViT parameters contain many outliers, so a special quantization
    # configuration is applied to mitigate their impact
    act_qspec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(
            **{"averaging_constant": 0.02}
        ),
    )
    weight_qspec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=torch.iinfo(torch.int8).min + 1,
        quant_max=torch.iinfo(torch.int8).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(
            **{"steps": 200, "use_mse": True}
        ),
    )
    # override the default per-channel PTQ config
    quantizer.per_channel_quant_config = QuantizationConfig(
        input_activation=act_qspec,
        output_activation=act_qspec,
        weight=weight_qspec,
        bias=_derived_bias_quant_spec,
    )
    # override the default PTQ config
    q_config = quantizer.bit8_quant_config
    quantizer.bit8_quant_config = QuantizationConfig(
        input_activation=act_qspec,
        output_activation=act_qspec,
        weight=q_config.weight,
        bias=q_config.bias,
    )
    # lower to QNN
    build_executorch_binary(
        convert_linear_to_conv2d(get_instance(args.oss_repo, args.pretrained_weight)),
        inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        dataset=inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        custom_quantizer=quantizer,
        custom_pass_config={QCOM_PASS_EXPAND_BROADCAST_SHAPE},
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    adb.pull(output_path=args.artifact)

    # top-k analysis
    predictions = []
    for i in range(data_num):
        predictions.append(
            np.fromfile(
                os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
            )
        )

    k_val = [1, 5]
    topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)}))
    else:
        for i, k in enumerate(k_val):
            print(f"top_{k}->{topk[i]}%")


if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing artifacts generated by this example. Default ./fastvit",
        default="./fastvit",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of the ImageNet dataset, "
            "e.g. --dataset imagenet-mini/val "
            "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000"
        ),
        type=str,
        required=True,
    )

    parser.add_argument(
        "--oss_repo",
        help="path to a clone of https://github.com/apple/ml-fastvit",
        type=str,
        required=True,
    )

    parser.add_argument(
        "-p",
        "--pretrained_weight",
        help=(
            "location of the pretrained model weights, "
            "e.g. -p ./fastvit_s12_reparam.pth.tar. "
            "The pretrained model can be found at "
            "https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_s12_reparam.pth.tar"
        ),
        type=str,
        required=True,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            raise
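# A minimal invocation sketch, assuming the -b/--build_folder, -s/--device, and
# -m/--model arguments are provided by setup_common_args_and_variables; the
# script name and all paths below are illustrative placeholders, not fixed values:
#
#   python fastvit.py \
#       -b build-android \
#       -s <device_serial> \
#       -m SM8550 \
#       -d ./imagenet-mini/val \
#       --oss_repo ./ml-fastvit \
#       -p ./fastvit_s12_reparam.pth.tar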