# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
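
"""Fine-tune MobileBERT for sequence classification, lower it through the
Qualcomm QNN backend, and compare the on-device (HTP) per-class accuracy
against the torch CPU baseline.

Usage sketch (the common flags -b/--build_folder, -s/--device and -m/--model
are assumed to come from setup_common_args_and_variables; check that helper
for the authoritative list):

    python mobilebert_fine_tune.py -b build-android -s <serial> -m SM8550 -P 16a4w
"""
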
import json
import os
from multiprocessing.connection import Client

import numpy as np

import torch
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    skip_annotation,
)
from executorch.examples.qualcomm.utils import (
    build_executorch_binary,
    make_output_dir,
    make_quantizer,
    parse_skip_delegation_node,
    QnnPartitioner,
    setup_common_args_and_variables,
    SimpleADB,
)
from executorch.exir import to_edge
from transformers import BertTokenizer, MobileBertForSequenceClassification


def evaluate(model, data_val):
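    """Collect logits and gold labels over the validation set.

    With labels supplied and return_dict=False, the model returns a
    (loss, logits) tuple, hence the [1] index below.
    """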
    predictions, true_vals = [], []
    for data in data_val:
        inputs = {
            "input_ids": data[0].to(torch.long),
            "attention_mask": data[1].to(torch.long),
            "labels": data[2].to(torch.long),
        }
        logits = model(**inputs)[1].detach().numpy()
        label_ids = inputs["labels"].numpy()
        predictions.append(logits)
        true_vals.append(label_ids)

    return (
        np.concatenate(predictions, axis=0),
        np.concatenate(true_vals, axis=0),
    )


def accuracy_per_class(preds, goldens, labels):
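    """Return {class name: [correct predictions, total samples]} counts."""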
    labels_inverse = {v: k for k, v in labels.items()}
    preds_flat = np.argmax(preds, axis=1).flatten()
    goldens_flat = goldens.flatten()

    result = {}
    for golden in np.unique(goldens_flat):
        pred = preds_flat[goldens_flat == golden]
        true = goldens_flat[goldens_flat == golden]
        result.update({labels_inverse[golden]: [len(pred[pred == golden]), len(true)]})

    return result


def get_dataset(data_val):
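    """Build int32 sample tuples and the input_list manifest for the device.

    Each tuple holds (input_ids, attention_mask, token_type_ids, position_ids);
    each input_list line names the raw files belonging to one sample.
    """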
    # prepare input data
    inputs, input_list = [], ""
    # max_position_embeddings defaults to 512
    position_ids = torch.arange(512).expand((1, -1)).to(torch.int32)
    for index, data in enumerate(data_val):
        data = [d.to(torch.int32) for d in data]
        # input_ids, attention_mask, token_type_ids, position_ids
        inputs.append(
            (
                *data[:2],
                torch.zeros(data[0].size(), dtype=torch.int32),
                position_ids[:, : data[0].shape[1]],
            )
        )
        input_text = " ".join(
            [f"input_{index}_{i}.raw" for i in range(len(inputs[-1]))]
        )
        input_list += f"{input_text}\n"

    return inputs, input_list


def get_fine_tuned_mobilebert(artifacts_dir, pretrained_weight, batch_size):
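    """Fine-tune google/mobilebert-uncased on the title_conference dataset.

    When pretrained_weight is given, training is skipped and that checkpoint
    is loaded instead. Returns the eval-mode model, the validation dataloader,
    and the label-to-index mapping.
    """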
    from io import BytesIO

    import pandas as pd
    import requests
    from sklearn.model_selection import train_test_split
    from torch.utils.data import (
        DataLoader,
        RandomSampler,
        SequentialSampler,
        TensorDataset,
    )
    from tqdm import tqdm
    from transformers import get_linear_schedule_with_warmup

    # grab dataset
    url = (
        "https://raw.githubusercontent.com/susanli2016/NLP-with-Python"
        "/master/data/title_conference.csv"
    )
    content = requests.get(url, allow_redirects=True).content
    data = pd.read_csv(BytesIO(content))

    # get training / validation data
    labels = {key: index for index, key in enumerate(data.Conference.unique())}
    data["label"] = data.Conference.replace(labels)

    train, val, _, _ = train_test_split(
        data.index.values,
        data.label.values,
        test_size=0.15,
        random_state=42,
        stratify=data.label.values,
    )

    data["data_type"] = ["not_set"] * data.shape[0]
    data.loc[train, "data_type"] = "train"
    data.loc[val, "data_type"] = "val"
    # group counts are informational only; the result is not used below
    data.groupby(["Conference", "label", "data_type"]).count()

    # get pre-trained mobilebert
    tokenizer = BertTokenizer.from_pretrained(
        "bert-base-uncased",
        do_lower_case=True,
    )
    model = MobileBertForSequenceClassification.from_pretrained(
        "google/mobilebert-uncased",
        num_labels=len(labels),
        return_dict=False,
    )

    # tokenize dataset
    encoded_data_train = tokenizer.batch_encode_plus(
        data[data.data_type == "train"].Title.values,
        add_special_tokens=True,
        return_attention_mask=True,
        max_length=256,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    encoded_data_val = tokenizer.batch_encode_plus(
        data[data.data_type == "val"].Title.values,
        add_special_tokens=True,
        return_attention_mask=True,
        max_length=256,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    input_ids_train = encoded_data_train["input_ids"]
    attention_masks_train = encoded_data_train["attention_mask"]
    labels_train = torch.tensor(data[data.data_type == "train"].label.values)

    input_ids_val = encoded_data_val["input_ids"]
    attention_masks_val = encoded_data_val["attention_mask"]
    labels_val = torch.tensor(data[data.data_type == "val"].label.values)

    dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)
    dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)

    epochs = 5
    dataloader_train = DataLoader(
        dataset_train,
        sampler=RandomSampler(dataset_train),
        batch_size=batch_size,
    )
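    # drop_last keeps every validation batch at exactly batch_size samples,
    # matching the static input shapes the model is later lowered with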
    dataloader_val = DataLoader(
        dataset_val,
        sampler=SequentialSampler(dataset_val),
        batch_size=batch_size,
        drop_last=True,
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=len(dataloader_train) * epochs
    )

    # start training
    if not pretrained_weight:
        for epoch in range(1, epochs + 1):
            loss_train_total = 0
            print(f"epoch {epoch}")

            for batch in tqdm(dataloader_train):
                model.zero_grad()
                inputs = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                    "labels": batch[2],
                }
                loss = model(**inputs)[0]
                loss_train_total += loss.item()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

            torch.save(
                model.state_dict(),
                f"{artifacts_dir}/finetuned_mobilebert_epoch_{epoch}.model",
            )

    model.load_state_dict(
        torch.load(
            (
                f"{artifacts_dir}/finetuned_mobilebert_epoch_{epochs}.model"
                if pretrained_weight is None
                else pretrained_weight
            ),
            map_location=torch.device("cpu"),
            weights_only=True,
        ),
    )

    return model.eval(), dataloader_val, labels


def main(args):
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "A device serial is required unless compiling only. "
            "Please specify one with the -s/--device argument."
        )

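    # batch size is fixed at 1; the lowered graph expects static input shapes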
    batch_size, pte_filename = 1, "ptq_mb_qnn"
    model, data_val, labels = get_fine_tuned_mobilebert(
        args.artifact, args.pretrained_weight, batch_size
    )
    inputs, input_list = get_dataset(data_val)

    try:
        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
    except AttributeError:
        raise AssertionError(
            f"Unsupported quant type {args.ptq}. Supported types are 8a8w, 16a16w and 16a4w."
        )

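    # fp16 path: skip quantization and delegate the floating-point graph as-is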
    if args.use_fp16:
        quant_dtype = None
        pte_filename = "mb_qnn"
        build_executorch_binary(
            model,
            inputs[0],
            args.model,
            f"{args.artifact}/{pte_filename}",
            inputs,
            skip_node_id_set=skip_node_id_set,
            skip_node_op_set=skip_node_op_set,
            quant_dtype=quant_dtype,
            shared_buffer=args.shared_buffer,
        )
    else:

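        # run every sample through the annotated graph so the observers can
        # gather the activation statistics needed for quantization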
        def calibrator(gm):
            for sample in inputs:
                gm(*sample)

        quantizer = make_quantizer(quant_dtype=quant_dtype)
        backend_options = generate_htp_compiler_spec(quant_dtype is not None)
        partitioner = QnnPartitioner(
            generate_qnn_executorch_compiler_spec(
                soc_model=getattr(QcomChipset, args.model),
                backend_options=backend_options,
            ),
            skip_node_id_set=skip_node_id_set,
            skip_node_op_set=skip_node_op_set,
        )
        # skip the embedding layer because it is quantization-sensitive
        graph_module, _ = skip_annotation(
            nn_module=model,
            quantizer=quantizer,
            partitioner=partitioner,
            sample_input=inputs[0],
            calibration_cb=calibrator,
            fp_node_op_set={torch.ops.aten.embedding.default},
        )
        # lower the whole graph again; the skipped operators stay on CPU
        exec_prog = to_edge(
            torch.export.export(graph_module, inputs[0]),
        ).to_executorch()

        with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file:
            file.write(exec_prog.buffer)

    if args.compile_only:
        return

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    adb.pull(output_path=args.artifact)

    # get torch cpu result
    cpu_preds, true_vals = evaluate(model, data_val)
    cpu_result = accuracy_per_class(cpu_preds, true_vals, labels)

    # get QNN HTP result
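    # each output_<i>_0.raw file holds the float32 logits of one batch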
    htp_preds = []
    for i in range(len(data_val)):
        result = np.fromfile(
            os.path.join(output_data_folder, f"output_{i}_0.raw"),
            dtype=np.float32,
        )
        htp_preds.append(result.reshape(batch_size, -1))

    htp_result = accuracy_per_class(
        np.concatenate(htp_preds, axis=0), true_vals, labels
    )

    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({"CPU": cpu_result, "HTP": htp_result}))
    else:
        for name, result in zip(["CPU", "HTP"], [cpu_result, htp_result]):
            print(f"\n[{name}]")
            for k, v in result.items():
                print(f"{k}: {v[0]}/{v[1]}")


if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="Path for storing artifacts generated by this example. Default: ./mobilebert_fine_tune",
        default="./mobilebert_fine_tune",
        type=str,
    )

    parser.add_argument(
        "-p",
        "--pretrained_weight",
        help="Location of a pretrained weight file. If given, fine-tuning is skipped.",
        default=None,
        type=str,
    )

    parser.add_argument(
        "-F",
        "--use_fp16",
        help="If specified, run in fp16 precision and ignore the PTQ setting.",
        action="store_true",
        default=False,
    )

    parser.add_argument(
        "-P",
        "--ptq",
        help="If specified, do PTQ quantization. Default is 8-bit activations and 8-bit weights. Supported: 8a8w, 16a16w and 16a4w.",
        default="8a8w",
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            raise