# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

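"""Export FastViT-S12 to an ExecuTorch program for the Qualcomm QNN backend.

The model is quantized with 8-bit PTQ (using custom observers to cope with
FastViT's outlier-heavy parameters), lowered to a .pte file, and optionally
run on a connected device via adb to report top-1/top-5 ImageNet accuracy.
"""
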
import json
import os
from multiprocessing.connection import Client

import numpy as np
import torch
from executorch.backends.qualcomm.quantizer.annotators import (
    QuantizationConfig,
    QuantizationSpec,
)
from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import (
    PerChannelParamObserver,
)
from executorch.backends.qualcomm.quantizer.qconfig import (
    _derived_bias_quant_spec,
    MovingAverageMinMaxObserver,
)

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.backends.qualcomm.utils.constants import (
    QCOM_PASS_EXPAND_BROADCAST_SHAPE,
)
from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
from executorch.examples.qualcomm.utils import (
    build_executorch_binary,
    get_imagenet_dataset,
    make_output_dir,
    make_quantizer,
    parse_skip_delegation_node,
    setup_common_args_and_variables,
    SimpleADB,
    topk_accuracy,
)


def get_instance(repo_path: str, checkpoint_path: str):
    import sys

    sys.path.insert(0, repo_path)

    # ml-fastvit is not packaged, so make the cloned repo importable
    from models.modules.mobileone import reparameterize_model
    from timm.models import create_model

    # the checkpoint is expected to hold the reparameterized (inference-time)
    # weights, e.g. fastvit_s12_reparam.pth.tar, so the train-time branches are
    # folded before the state_dict is loaded
    checkpoint = torch.load(checkpoint_path, weights_only=True)
    model = create_model("fastvit_s12")
    model = reparameterize_model(model).eval()
    model.load_state_dict(checkpoint["state_dict"])
    return model


def main(args):
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial via the -s/--device argument."
        )

    data_num = 100
    inputs, targets, input_list = get_imagenet_dataset(
        dataset_path=f"{args.dataset}",
        data_size=data_num,
        image_shape=(256, 256),
    )
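    # these 100 samples are used both to calibrate the PTQ observers (passed as
    # `dataset` to build_executorch_binary below) and to score top-k accuracy
    # after on-device inference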

    pte_filename = "fastvit_qnn"
    quantizer = make_quantizer(quant_dtype=QuantDtype.use_8a8w)

    # FastViT parameters contain many outliers, so the default quantization
    # config is overridden below to limit their impact on the quantization ranges
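    # activations: per-tensor affine uint8; a MovingAverageMinMaxObserver with a
    # small averaging constant (0.02) keeps occasional activation spikes from
    # dominating the calibrated range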
    act_qspec = QuantizationSpec(
        dtype=torch.uint8,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(
            **{"averaging_constant": 0.02}
        ),
    )
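    # weights: symmetric per-channel int8; quant_min is raised to -127 to keep
    # the range symmetric. PerChannelParamObserver runs a stepwise search
    # (200 steps, MSE criterion) for quantization parameters, which copes better
    # with outlier weights than a plain min/max observer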
    weight_qspec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=torch.iinfo(torch.int8).min + 1,
        quant_max=torch.iinfo(torch.int8).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(
            **{"steps": 200, "use_mse": True}
        ),
    )
    # override the default per-channel PTQ config with the custom activation
    # and weight specs above
    quantizer.per_channel_quant_config = QuantizationConfig(
        input_activation=act_qspec,
        output_activation=act_qspec,
        weight=weight_qspec,
        bias=_derived_bias_quant_spec,
    )
    # override the default 8-bit PTQ config, keeping its weight/bias specs but
    # swapping in the custom activation spec
    q_config = quantizer.bit8_quant_config
    quantizer.bit8_quant_config = QuantizationConfig(
        input_activation=act_qspec,
        output_activation=act_qspec,
        weight=q_config.weight,
        bias=q_config.bias,
    )
    # lower to QNN; nn.Linear layers are first rewritten as conv2d
    # (convert_linear_to_conv2d) before quantization and compilation
    build_executorch_binary(
        convert_linear_to_conv2d(get_instance(args.oss_repo, args.pretrained_weight)),
        inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        dataset=inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        custom_quantizer=quantizer,
        custom_pass_config={QCOM_PASS_EXPAND_BROADCAST_SHAPE},
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        return

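    # run on device: push the compiled .pte and the prepared inputs over adb,
    # execute, then pull the raw outputs back for scoring.
    # QNN_SDK_ROOT must point to the Qualcomm AI Engine Direct (QNN) SDK.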
    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)

    adb.pull(output_path=args.artifact)

    # top-k analysis: each output_{i}_0.raw holds the float32 scores for sample i
    predictions = []
    for i in range(data_num):
        predictions.append(
            np.fromfile(
                os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
            )
        )

    k_val = [1, 5]
    topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)}))
    else:
        for i, k in enumerate(k_val):
            print(f"top_{k}->{topk[i]}%")


if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing artifacts generated by this example. Default: ./fastvit",
        default="./fastvit",
        type=str,
    )

    parser.add_argument(
        "-d",
        "--dataset",
        help=(
            "path to the validation folder of the ImageNet dataset, "
            "e.g. --dataset imagenet-mini/val "
            "(https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
        ),
        type=str,
        required=True,
    )

    parser.add_argument(
        "--oss_repo",
        help="Path to a local clone of https://github.com/apple/ml-fastvit",
        type=str,
        required=True,
    )

    parser.add_argument(
        "-p",
        "--pretrained_weight",
        help=(
            "Location of the pretrained model weights, "
            "e.g. -p ./fastvit_s12_reparam.pth.tar. "
            "The pretrained model can be found at "
            "https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_distilled_models/fastvit_s12_reparam.pth.tar"
        ),
        type=str,
        required=True,
    )

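    # Example invocation (a sketch; -b/-s/-m are common flags provided by
    # setup_common_args_and_variables, and the paths below are placeholders):
    #   python fastvit.py -b build-android -m SM8550 -s <device_serial> \
    #       -d ./imagenet-mini/val --oss_repo ./ml-fastvit \
    #       -p ./fastvit_s12_reparam.pth.tar -a ./fastvit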
    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            # forward the error to the listener at --ip/--port instead of raising
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            raise e