xref: /aosp_15_r20/external/executorch/examples/xnnpack/aot_compiler.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-unsafe
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker# Example script for exporting simple models to flatbuffer
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workerimport argparse
12*523fa7a6SAndroid Build Coastguard Workerimport copy
13*523fa7a6SAndroid Build Coastguard Workerimport logging
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Workerimport torch
16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.devtools import generate_etrecord
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir import (
19*523fa7a6SAndroid Build Coastguard Worker    EdgeCompileConfig,
20*523fa7a6SAndroid Build Coastguard Worker    ExecutorchBackendConfig,
21*523fa7a6SAndroid Build Coastguard Worker    to_edge_transform_and_lower,
22*523fa7a6SAndroid Build Coastguard Worker)
23*523fa7a6SAndroid Build Coastguard Workerfrom executorch.extension.export_util.utils import save_pte_program
24*523fa7a6SAndroid Build Coastguard Worker
25*523fa7a6SAndroid Build Coastguard Workerfrom ..models import MODEL_NAME_TO_MODEL
26*523fa7a6SAndroid Build Coastguard Workerfrom ..models.model_factory import EagerModelFactory
27*523fa7a6SAndroid Build Coastguard Workerfrom . import MODEL_NAME_TO_OPTIONS
28*523fa7a6SAndroid Build Coastguard Workerfrom .quantization.utils import quantize
29*523fa7a6SAndroid Build Coastguard Worker
30*523fa7a6SAndroid Build Coastguard Worker
# Log record layout: "[LEVEL timestamp file:line] message".
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
# Configure the root logger so the INFO-level progress messages below are emitted.
logging.basicConfig(format=FORMAT, level=logging.INFO)
33*523fa7a6SAndroid Build Coastguard Worker
34*523fa7a6SAndroid Build Coastguard Worker
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the XNNPACK AOT export example."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"Model name. Valid ones: {list(MODEL_NAME_TO_OPTIONS.keys())}",
    )
    parser.add_argument(
        "-q",
        "--quantize",
        action="store_true",
        required=False,
        default=False,
        help="Produce an 8-bit quantized model",
    )
    parser.add_argument(
        "-d",
        "--delegate",
        action="store_true",
        required=False,
        default=True,
        help="Produce an XNNPACK delegated model",
    )
    parser.add_argument(
        "-r",
        "--etrecord",
        required=False,
        help="Generate and save an ETRecord to the given file location",
    )
    parser.add_argument("-o", "--output_dir", default=".", help="output directory")
    return parser.parse_args()


def main() -> None:
    """Export an example model, lower it to XNNPACK and save the resulting program.

    The model named by ``--model_name`` is created through ``EagerModelFactory``,
    exported with ``torch.export.export_for_training``, optionally quantized
    (``--quantize``), lowered with ``XnnpackPartitioner`` and written to
    ``--output_dir`` via ``save_pte_program``. When ``--etrecord`` is given, an
    ETRecord is also written to that location.

    Raises:
        NotImplementedError: if delegation is disabled.
        RuntimeError: if the model name is not registered, or if quantization is
            requested for a model that has no entry in ``MODEL_NAME_TO_OPTIONS``.
    """
    args = _parse_args()

    # NOTE: ``--delegate`` is a ``store_true`` flag whose default is already True,
    # so this guard cannot trigger from the command line today. It is kept so a
    # future non-delegated mode fails loudly instead of silently delegating.
    if not args.delegate:
        raise NotImplementedError(
            "T161880157: Quantization-only without delegation is not supported yet"
        )

    # Fail with an actionable message instead of a bare KeyError from the lookup
    # further down.
    if args.model_name not in MODEL_NAME_TO_MODEL:
        raise RuntimeError(
            f"Model {args.model_name} is not a valid name. "
            f"Available models are {list(MODEL_NAME_TO_MODEL)}."
        )

    if args.quantize and args.model_name not in MODEL_NAME_TO_OPTIONS:
        raise RuntimeError(
            f"Model {args.model_name} is not quantizable right now, "
            "please contact executorch team if you want to learn why or how to support "
            "quantization for the requested model. "
            f"Available models are {list(MODEL_NAME_TO_OPTIONS.keys())}."
        )

    model, example_inputs, _, _ = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL[args.model_name]
    )

    model = model.eval()
    # pre-autograd export. eventually this will become torch.export
    ep = torch.export.export_for_training(model, example_inputs)

    if args.quantize:
        logging.info("Quantizing Model...")
        # TODO(T165162973): This pass shall eventually be folded into quantizer
        quantized_model = quantize(ep.module(), example_inputs)
        # Lowering below consumes `ep`, so the quantized module has to be exported
        # again; otherwise the un-quantized graph would be lowered and saved under
        # the "q8" file name.
        ep = torch.export.export_for_training(quantized_model, example_inputs)

    edge = to_edge_transform_and_lower(
        ep,
        partitioner=[XnnpackPartitioner()],
        compile_config=EdgeCompileConfig(
            # IR validity checking is turned off for quantized graphs.
            _check_ir_validity=not args.quantize,
            _skip_dim_order=True,  # TODO(T182187531): enable dim order in xnnpack
        ),
    )
    logging.info(f"Exported and lowered graph:\n{edge.exported_program().graph}")

    # this is needed for the ETRecord as lowering modifies the graph in-place;
    # the copy is only taken when an ETRecord was actually requested.
    edge_copy = copy.deepcopy(edge) if args.etrecord is not None else None

    exec_prog = edge.to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)
    )

    if args.etrecord is not None:
        generate_etrecord(args.etrecord, edge_copy, exec_prog)
        logging.info(f"Saved ETRecord to {args.etrecord}")

    # Output file name encodes the backend and the numeric precision.
    quant_tag = "q8" if args.quantize else "fp32"
    model_name = f"{args.model_name}_xnnpack_{quant_tag}"
    save_pte_program(exec_prog, model_name, args.output_dir)


if __name__ == "__main__":
    main()
120