1# Copyright (c) Qualcomm Innovation Center, Inc. 2# All rights reserved 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import json 8import os 9from multiprocessing.connection import Client 10 11import numpy as np 12 13import torch 14from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype 15from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset 16from executorch.backends.qualcomm.utils.utils import ( 17 generate_htp_compiler_spec, 18 generate_qnn_executorch_compiler_spec, 19 skip_annotation, 20) 21from executorch.examples.qualcomm.utils import ( 22 build_executorch_binary, 23 make_output_dir, 24 make_quantizer, 25 parse_skip_delegation_node, 26 QnnPartitioner, 27 setup_common_args_and_variables, 28 SimpleADB, 29) 30from executorch.exir import to_edge 31from transformers import BertTokenizer, MobileBertForSequenceClassification 32 33 34def evaluate(model, data_val): 35 predictions, true_vals = [], [] 36 for data in data_val: 37 inputs = { 38 "input_ids": data[0].to(torch.long), 39 "attention_mask": data[1].to(torch.long), 40 "labels": data[2].to(torch.long), 41 } 42 logits = model(**inputs)[1].detach().numpy() 43 label_ids = inputs["labels"].numpy() 44 predictions.append(logits) 45 true_vals.append(label_ids) 46 47 return ( 48 np.concatenate(predictions, axis=0), 49 np.concatenate(true_vals, axis=0), 50 ) 51 52 53def accuracy_per_class(preds, goldens, labels): 54 labels_inverse = {v: k for k, v in labels.items()} 55 preds_flat = np.argmax(preds, axis=1).flatten() 56 goldens_flat = goldens.flatten() 57 58 result = {} 59 for golden in np.unique(goldens_flat): 60 pred = preds_flat[goldens_flat == golden] 61 true = goldens_flat[goldens_flat == golden] 62 result.update({labels_inverse[golden]: [len(pred[pred == golden]), len(true)]}) 63 64 return result 65 66 67def get_dataset(data_val): 68 # prepare input data 69 inputs, input_list = [], "" 70 # max_position_embeddings defaults to 512 71 position_ids = torch.arange(512).expand((1, -1)).to(torch.int32) 72 for index, data in enumerate(data_val): 73 data = [d.to(torch.int32) for d in data] 74 # input_ids, attention_mask, token_type_ids, position_ids 75 inputs.append( 76 ( 77 *data[:2], 78 torch.zeros(data[0].size(), dtype=torch.int32), 79 position_ids[:, : data[0].shape[1]], 80 ) 81 ) 82 input_text = " ".join( 83 [f"input_{index}_{i}.raw" for i in range(len(inputs[-1]))] 84 ) 85 input_list += f"{input_text}\n" 86 87 return inputs, input_list 88 89 90def get_fine_tuned_mobilebert(artifacts_dir, pretrained_weight, batch_size): 91 from io import BytesIO 92 93 import pandas as pd 94 import requests 95 from sklearn.model_selection import train_test_split 96 from torch.utils.data import ( 97 DataLoader, 98 RandomSampler, 99 SequentialSampler, 100 TensorDataset, 101 ) 102 from tqdm import tqdm 103 from transformers import get_linear_schedule_with_warmup 104 105 # grab dataset 106 url = ( 107 "https://raw.githubusercontent.com/susanli2016/NLP-with-Python" 108 "/master/data/title_conference.csv" 109 ) 110 content = requests.get(url, allow_redirects=True).content 111 data = pd.read_csv(BytesIO(content)) 112 113 # get training / validation data 114 labels = {key: index for index, key in enumerate(data.Conference.unique())} 115 data["label"] = data.Conference.replace(labels) 116 117 train, val, _, _ = train_test_split( 118 data.index.values, 119 data.label.values, 120 test_size=0.15, 121 random_state=42, 122 stratify=data.label.values, 123 ) 124 125 data["data_type"] = ["not_set"] * data.shape[0] 126 data.loc[train, "data_type"] = "train" 127 data.loc[val, "data_type"] = "val" 128 data.groupby(["Conference", "label", "data_type"]).count() 129 130 # get pre-trained mobilebert 131 tokenizer = BertTokenizer.from_pretrained( 132 "bert-base-uncased", 133 do_lower_case=True, 134 ) 135 model = MobileBertForSequenceClassification.from_pretrained( 136 "google/mobilebert-uncased", 137 num_labels=len(labels), 138 return_dict=False, 139 ) 140 141 # tokenize dataset 142 encoded_data_train = tokenizer.batch_encode_plus( 143 data[data.data_type == "train"].Title.values, 144 add_special_tokens=True, 145 return_attention_mask=True, 146 max_length=256, 147 padding="max_length", 148 truncation=True, 149 return_tensors="pt", 150 ) 151 encoded_data_val = tokenizer.batch_encode_plus( 152 data[data.data_type == "val"].Title.values, 153 add_special_tokens=True, 154 return_attention_mask=True, 155 max_length=256, 156 padding="max_length", 157 truncation=True, 158 return_tensors="pt", 159 ) 160 161 input_ids_train = encoded_data_train["input_ids"] 162 attention_masks_train = encoded_data_train["attention_mask"] 163 labels_train = torch.tensor(data[data.data_type == "train"].label.values) 164 165 input_ids_val = encoded_data_val["input_ids"] 166 attention_masks_val = encoded_data_val["attention_mask"] 167 labels_val = torch.tensor(data[data.data_type == "val"].label.values) 168 169 dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train) 170 dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val) 171 172 epochs = 5 173 dataloader_train = DataLoader( 174 dataset_train, 175 sampler=RandomSampler(dataset_train), 176 batch_size=batch_size, 177 ) 178 dataloader_val = DataLoader( 179 dataset_val, 180 sampler=SequentialSampler(dataset_val), 181 batch_size=batch_size, 182 drop_last=True, 183 ) 184 optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5) 185 scheduler = get_linear_schedule_with_warmup( 186 optimizer, num_warmup_steps=0, num_training_steps=len(dataloader_train) * epochs 187 ) 188 189 # start training 190 if not pretrained_weight: 191 for epoch in range(1, epochs + 1): 192 loss_train_total = 0 193 print(f"epoch {epoch}") 194 195 for batch in tqdm(dataloader_train): 196 model.zero_grad() 197 inputs = { 198 "input_ids": batch[0], 199 "attention_mask": batch[1], 200 "labels": batch[2], 201 } 202 loss = model(**inputs)[0] 203 loss_train_total += loss.item() 204 loss.backward() 205 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 206 optimizer.step() 207 scheduler.step() 208 209 torch.save( 210 model.state_dict(), 211 f"{artifacts_dir}/finetuned_mobilebert_epoch_{epoch}.model", 212 ) 213 214 model.load_state_dict( 215 torch.load( 216 ( 217 f"{artifacts_dir}/finetuned_mobilebert_epoch_{epochs}.model" 218 if pretrained_weight is None 219 else pretrained_weight 220 ), 221 map_location=torch.device("cpu"), 222 weights_only=True, 223 ), 224 ) 225 226 return model.eval(), dataloader_val, labels 227 228 229def main(args): 230 skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) 231 232 # ensure the working directory exist. 233 os.makedirs(args.artifact, exist_ok=True) 234 235 if not args.compile_only and args.device is None: 236 raise RuntimeError( 237 "device serial is required if not compile only. " 238 "Please specify a device serial by -s/--device argument." 239 ) 240 241 batch_size, pte_filename = 1, "ptq_mb_qnn" 242 model, data_val, labels = get_fine_tuned_mobilebert( 243 args.artifact, args.pretrained_weight, batch_size 244 ) 245 inputs, input_list = get_dataset(data_val) 246 247 try: 248 quant_dtype = getattr(QuantDtype, f"use_{args.ptq}") 249 except: 250 raise AssertionError( 251 f"No support for quant type {args.ptq}. Support 8a8w, 16a16w and 16a4w." 252 ) 253 254 if args.use_fp16: 255 quant_dtype = None 256 pte_filename = "mb_qnn" 257 build_executorch_binary( 258 model, 259 inputs[0], 260 args.model, 261 f"{args.artifact}/{pte_filename}", 262 inputs, 263 skip_node_id_set=skip_node_id_set, 264 skip_node_op_set=skip_node_op_set, 265 quant_dtype=quant_dtype, 266 shared_buffer=args.shared_buffer, 267 ) 268 else: 269 270 def calibrator(gm): 271 for input in inputs: 272 gm(*input) 273 274 quantizer = make_quantizer(quant_dtype=quant_dtype) 275 backend_options = generate_htp_compiler_spec(quant_dtype is not None) 276 partitioner = QnnPartitioner( 277 generate_qnn_executorch_compiler_spec( 278 soc_model=getattr(QcomChipset, args.model), 279 backend_options=backend_options, 280 ), 281 skip_node_id_set=skip_node_id_set, 282 skip_node_op_set=skip_node_op_set, 283 ) 284 # skip embedding layer cause it's quantization sensitive 285 graph_module, _ = skip_annotation( 286 nn_module=model, 287 quantizer=quantizer, 288 partitioner=partitioner, 289 sample_input=inputs[0], 290 calibration_cb=calibrator, 291 fp_node_op_set={torch.ops.aten.embedding.default}, 292 ) 293 # lower all graph again, the skipped operators will be left in CPU 294 exec_prog = to_edge( 295 torch.export.export(graph_module, inputs[0]), 296 ).to_executorch() 297 298 with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file: 299 file.write(exec_prog.buffer) 300 301 if args.compile_only: 302 return 303 304 adb = SimpleADB( 305 qnn_sdk=os.getenv("QNN_SDK_ROOT"), 306 build_path=f"{args.build_folder}", 307 pte_path=f"{args.artifact}/{pte_filename}.pte", 308 workspace=f"/data/local/tmp/executorch/{pte_filename}", 309 device_id=args.device, 310 host_id=args.host, 311 soc_model=args.model, 312 shared_buffer=args.shared_buffer, 313 ) 314 adb.push(inputs=inputs, input_list=input_list) 315 adb.execute() 316 317 # collect output data 318 output_data_folder = f"{args.artifact}/outputs" 319 make_output_dir(output_data_folder) 320 321 adb.pull(output_path=args.artifact) 322 323 # get torch cpu result 324 cpu_preds, true_vals = evaluate(model, data_val) 325 cpu_result = accuracy_per_class(cpu_preds, true_vals, labels) 326 327 # get QNN HTP result 328 htp_preds = [] 329 for i in range(len(data_val)): 330 result = np.fromfile( 331 os.path.join(output_data_folder, f"output_{i}_0.raw"), 332 dtype=np.float32, 333 ) 334 htp_preds.append(result.reshape(batch_size, -1)) 335 336 htp_result = accuracy_per_class( 337 np.concatenate(htp_preds, axis=0), true_vals, labels 338 ) 339 340 if args.ip and args.port != -1: 341 with Client((args.ip, args.port)) as conn: 342 conn.send(json.dumps({"CPU": cpu_result, "HTP": htp_result})) 343 else: 344 for target in zip(["CPU", "HTP"], [cpu_result, htp_result]): 345 print(f"\n[{target[0]}]") 346 for k, v in target[1].items(): 347 print(f"{k}: {v[0]}/{v[1]}") 348 349 350if __name__ == "__main__": 351 parser = setup_common_args_and_variables() 352 353 parser.add_argument( 354 "-a", 355 "--artifact", 356 help="path for storing generated artifacts by this example. Default ./mobilebert_fine_tune", 357 default="./mobilebert_fine_tune", 358 type=str, 359 ) 360 361 parser.add_argument( 362 "-p", 363 "--pretrained_weight", 364 help="Location of pretrained weight", 365 default=None, 366 type=str, 367 ) 368 369 parser.add_argument( 370 "-F", 371 "--use_fp16", 372 help="If specified, will run in fp16 precision and discard ptq setting", 373 action="store_true", 374 default=False, 375 ) 376 377 parser.add_argument( 378 "-P", 379 "--ptq", 380 help="If specified, will do PTQ quantization. default is 8bits activation and 8bits weight. Support 8a8w, 16a16w and 16a4w.", 381 default="8a8w", 382 ) 383 384 args = parser.parse_args() 385 try: 386 main(args) 387 except Exception as e: 388 if args.ip and args.port != -1: 389 with Client((args.ip, args.port)) as conn: 390 conn.send(json.dumps({"Error": str(e)})) 391 else: 392 raise Exception(e) 393