# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Example script for exporting simple models to flatbuffer
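#
# Example invocation (a sketch only; the flags map to the argparse options
# defined in get_args() below, and the paths are illustrative):
#
#   python -m examples.arm.aot_arm_compiler --model_name=add --delegate \
#       --quantize --target=ethos-u55-128 --intermediates=./intermediates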

import argparse
import json
import logging
import os

from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import torch
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
from executorch.backends.arm.arm_partitioner import ArmPartitioner
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)

from executorch.backends.arm.util.arm_model_evaluator import (
    GenericModelEvaluator,
    MobileNetV2Evaluator,
)
from executorch.devtools.backend_debug import get_delegation_info
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)
from executorch.extension.export_util.utils import save_pte_program
from tabulate import tabulate

# Quantize model if required using the standard export quantization flow.
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.utils.data import DataLoader

from ..models import MODEL_NAME_TO_MODEL
from ..models.model_factory import EagerModelFactory

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.WARNING, format=FORMAT)


def get_model_and_inputs_from_name(model_name: str) -> Tuple[torch.nn.Module, Any]:
    """Given the name of an example PyTorch model, return it and example inputs.

    Raises RuntimeError if there is no example model corresponding to the given name.
    """
    # Case 1: Model is defined in this file
    if model_name in models.keys():
        model = models[model_name]()
        example_inputs = models[model_name].example_input
    # Case 2: Model is defined in examples/models/
    elif model_name in MODEL_NAME_TO_MODEL.keys():
        logging.warning(
            "Using a model from examples/models; not all of these are currently supported"
        )
        model, example_inputs, _, _ = EagerModelFactory.create_model(
            *MODEL_NAME_TO_MODEL[model_name]
        )
    # Case 3: Model is in an external Python file loaded as a module.
    #         ModelUnderTest should be a torch.nn.Module instance
    #         ModelInputs should be a tuple of inputs to the forward function
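    #         For illustration, a hypothetical my_model.py passed via
    #         --model_name could look like this (names must match exactly):
    #
    #             import torch
    #
    #             class MyModel(torch.nn.Module):
    #                 def forward(self, x):
    #                     return x * 2
    #
    #             ModelUnderTest = MyModel()
    #             ModelInputs = (torch.randn(1, 3),)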
    elif model_name.endswith(".py"):
        import importlib.util

        # Load the model's module and execute it
        spec = importlib.util.spec_from_file_location("tmp_model", model_name)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        model = module.ModelUnderTest
        example_inputs = module.ModelInputs

    else:
        raise RuntimeError(
            f"Model '{model_name}' is not a valid name. Use --help for a list of available models."
        )

    return model, example_inputs


def quantize(
    model: torch.nn.Module,
    model_name: str,
    example_inputs: Tuple[torch.Tensor],
    evaluator_name: str | None,
    evaluator_config: Dict[str, Any] | None,
) -> torch.nn.Module:
    """Quantize the model using the recommended PyTorch 2 export (PT2E) flow."""
    logging.info("Quantizing Model...")
    logging.debug(f"Original model: {model}")
    quantizer = ArmQuantizer()

    # If we set is_per_channel to True, we also need to add out variants of quantize_per_channel/dequantize_per_channel
    operator_config = get_symmetric_quantization_config(is_per_channel=False)
    quantizer.set_global(operator_config)
    m = prepare_pt2e(model, quantizer)

    dataset = get_calibration_data(
        model_name, example_inputs, evaluator_name, evaluator_config
    )

    # The dataset can be either a tuple of tensors or a DataLoader;
    # both cases need to be handled.
    if isinstance(dataset, DataLoader):
        for sample, _ in dataset:
            m(sample)
    else:
        m(*dataset)

    m = convert_pt2e(m)
    logging.debug(f"Quantized model: {m}")
    return m

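
# Minimal sketch of how quantize() is driven from __main__ below (hypothetical
# variable names; the real flow also wires up evaluators and calibration data):
#
#   exported = torch.export.export_for_training(model, example_inputs)
#   quantized = quantize(exported.module(), "add", example_inputs, None, None)
#   exported = torch.export.export_for_training(quantized, example_inputs)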

# Simple example models
class AddModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x + x

    example_input = (torch.ones(5, dtype=torch.int32),)
    can_delegate = True


class AddModule2(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return x + y

    example_input = (
        torch.ones(5, dtype=torch.int32),
        torch.ones(5, dtype=torch.int32),
    )
    can_delegate = True


class AddModule3(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return (x + y, x + x)

    example_input = (
        torch.ones(5, dtype=torch.int32),
        torch.ones(5, dtype=torch.int32),
    )
    can_delegate = True


class SoftmaxModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.softmax = torch.nn.Softmax(dim=0)

    def forward(self, x):
        z = self.softmax(x)
        return z

    example_input = (torch.ones(2, 2),)
    can_delegate = False


models = {
    "add": AddModule,
    "add2": AddModule2,
    "add3": AddModule3,
    "softmax": SoftmaxModule,
}

calibration_data = {
    "add": (torch.randn(1, 5),),
    "add2": (
        torch.randn(1, 5),
        torch.randn(1, 5),
    ),
    "add3": (
        torch.randn(32, 5),
        torch.randn(32, 5),
    ),
    "softmax": (torch.randn(32, 2, 2),),
}

evaluators = {
    "generic": GenericModelEvaluator,
    "mv2": MobileNetV2Evaluator,
}

targets = [
    "ethos-u55-32",
    "ethos-u55-64",
    "ethos-u55-128",
    "ethos-u55-256",
    "ethos-u85-128",
    "ethos-u85-256",
    "ethos-u85-512",
    "ethos-u85-1024",
    "ethos-u85-2048",
    "TOSA",
]
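# The numeric suffix on the ethos-u targets selects the NPU MAC configuration
# (e.g. ethos-u55-128 is a 128 MAC/cycle Ethos-U55), while "TOSA" produces a
# plain TOSA flatbuffer without compiling for a specific NPU.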


def get_calibration_data(
    model_name: str,
    example_inputs: Tuple[torch.Tensor],
    evaluator_name: str | None,
    evaluator_config: str | None,
):
    # First, if the model is being evaluated, use the evaluator's calibration function if it has one
    if evaluator_name is not None:
        evaluator = evaluators[evaluator_name]

        if hasattr(evaluator, "get_calibrator"):
            assert evaluator_config is not None

            config_path = Path(evaluator_config)
            with config_path.open() as f:
                config = json.load(f)

            if evaluator_name == "mv2":
                return evaluator.get_calibrator(
                    training_dataset_path=config["training_dataset_path"]
                )
            else:
                raise RuntimeError(f"Unknown evaluator: {evaluator_name}")

    # If the model is in the calibration_data dictionary, get the data from there.
    # This is used for the simple example models provided in this file.
    if model_name in calibration_data:
        return calibration_data[model_name]

    # As a last resort, fall back to the script's previous behavior and return the example inputs
    return example_inputs


def get_compile_spec(
    target: str, intermediates: Optional[str] = None
) -> ArmCompileSpecBuilder:
    spec_builder = None
    if target == "TOSA":
        spec_builder = (
            ArmCompileSpecBuilder()
            .tosa_compile_spec("TOSA-0.80.0+BI")
            .set_permute_memory_format(True)
        )
    elif "ethos-u55" in target:
        spec_builder = (
            ArmCompileSpecBuilder()
            .ethosu_compile_spec(
                target,
                system_config="Ethos_U55_High_End_Embedded",
                memory_mode="Shared_Sram",
                extra_flags="--debug-force-regor --output-format=raw",
            )
            .set_permute_memory_format(True)
            .set_quantize_io(True)
        )
    elif "ethos-u85" in target:
        spec_builder = (
            ArmCompileSpecBuilder()
            .ethosu_compile_spec(
                target,
                system_config="Ethos_U85_SYS_DRAM_Mid",
                memory_mode="Shared_Sram",
                extra_flags="--output-format=raw",
            )
            .set_permute_memory_format(True)
            .set_quantize_io(True)
        )

    if intermediates is not None:
        spec_builder.dump_intermediate_artifacts_to(intermediates)

    return spec_builder.build()

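# Usage sketch (target chosen for illustration; mirrors the call in __main__):
#
#   compile_spec = get_compile_spec("ethos-u85-256", intermediates="./intermediates")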

def evaluate_model(
    model_name: str,
    intermediates: str,
    model_fp32: torch.nn.Module,
    model_int8: torch.nn.Module,
    example_inputs: Tuple[torch.Tensor],
    evaluator_name: str,
    evaluator_config: str | None,
) -> None:
    evaluator = evaluators[evaluator_name]

    # Get the path of the TOSA flatbuffer that is dumped
    intermediates_path = Path(intermediates)
    tosa_paths = list(intermediates_path.glob("*.tosa"))

    if evaluator.REQUIRES_CONFIG:
        assert evaluator_config is not None

        config_path = Path(evaluator_config)
        with config_path.open() as f:
            config = json.load(f)

        if evaluator_name == "mv2":
            init_evaluator = evaluator(
                model_name,
                model_fp32,
                model_int8,
                example_inputs,
                str(tosa_paths[0]),
                config["batch_size"],
                config["validation_dataset_path"],
            )
        else:
            raise RuntimeError(f"Unknown evaluator {evaluator_name}")
    else:
        init_evaluator = evaluator(
            model_name, model_fp32, model_int8, example_inputs, str(tosa_paths[0])
        )

    quant_metrics = init_evaluator.evaluate()
    output_json_path = intermediates_path / "quant_metrics.json"

    with output_json_path.open("w") as json_file:
        json.dump(quant_metrics, json_file)

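# Example --evaluate_config JSON for the "mv2" evaluator (a sketch; the paths
# are placeholders, but the keys are the ones read by get_calibration_data()
# and evaluate_model() above):
#
#   {
#       "training_dataset_path": "/path/to/dataset/train",
#       "validation_dataset_path": "/path/to/dataset/val",
#       "batch_size": 32
#   }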

def dump_delegation_info(edge, intermediate_files_folder: Optional[str] = None):
    graph_module = edge.exported_program().graph_module
    delegation_info = get_delegation_info(graph_module)
    df = delegation_info.get_operator_delegation_dataframe()
    table = tabulate(df, headers="keys", tablefmt="fancy_grid")
    delegation_info_string = f"Delegation info:\n{delegation_info.get_summary()}\nDelegation table:\n{table}\n"
    logging.info(delegation_info_string)
    if intermediate_files_folder is not None:
        delegation_file_path = os.path.join(
            intermediate_files_folder, "delegation_info.txt"
        )
        with open(delegation_file_path, "w") as file:
            file.write(delegation_info_string)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"Provide model name. Valid ones: {set(list(models.keys())+list(MODEL_NAME_TO_MODEL.keys()))}",
    )
    parser.add_argument(
        "-d",
        "--delegate",
        action="store_true",
        required=False,
        default=False,
        help="Flag for producing ArmBackend delegated model",
    )
    parser.add_argument(
        "-t",
        "--target",
        action="store",
        required=False,
        default="ethos-u55-128",
        choices=targets,
        help=f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. Valid targets are {targets}",
    )
    parser.add_argument(
        "-e",
        "--evaluate",
        required=False,
        nargs="?",
        const="generic",
        choices=["generic", "mv2"],
        help="Flag for running evaluation of the model.",
    )
    parser.add_argument(
        "-c",
        "--evaluate_config",
        required=False,
        default=None,
        help="Provide path to evaluator config, if it is required.",
    )
    parser.add_argument(
        "-q",
        "--quantize",
        action="store_true",
        required=False,
        default=False,
        help="Produce a quantized model",
    )
    parser.add_argument(
        "-s",
        "--so_library",
        required=False,
        default=None,
        help="Provide path to so library. E.g., cmake-out/examples/portable/custom_ops/libcustom_ops_aot_lib.so",
    )
    parser.add_argument(
        "--debug", action="store_true", help="Set the logging level to debug."
    )
    parser.add_argument(
        "-i",
        "--intermediates",
        action="store",
        required=False,
        help="Store intermediate output (like TOSA artefacts) somewhere.",
    )
    parser.add_argument(
        "-o",
        "--output",
        action="store",
        required=False,
        help="Location for outputs, if not the default of cwd.",
    )
    args = parser.parse_args()

    if args.evaluate and (
        not args.quantize or args.intermediates is None or not args.delegate
    ):
        raise RuntimeError(
            "--evaluate requires --quantize, --intermediates and --delegate to be enabled."
        )

    if args.debug:
        logging.basicConfig(level=logging.DEBUG, format=FORMAT, force=True)

    if args.quantize and not args.so_library:
        logging.warning(
            "Quantization enabled without supplying path to libcustom_ops_aot_lib using -s flag. "
            + "This is required for running quantized models with unquantized input."
        )

    # if we have custom ops, register them before processing the model
    if args.so_library is not None:
        logging.info(f"Loading custom ops from {args.so_library}")
        torch.ops.load_library(args.so_library)

    if (
        args.model_name in models.keys()
        and args.delegate is True
        and models[args.model_name].can_delegate is False
    ):
        raise RuntimeError(f"Model {args.model_name} cannot be delegated.")

    return args


if __name__ == "__main__":
    args = get_args()

    # Pick model from one of the supported lists
    model, example_inputs = get_model_and_inputs_from_name(args.model_name)
    model = model.eval()

    # Use export_for_training under the assumption that we quantize; the
    # exported form also works in to_edge if we don't quantize.
    exported_program = torch.export.export_for_training(model, example_inputs)
    model = exported_program.module()
    model_fp32 = model

    # Quantize if required
    model_int8 = None
    if args.quantize:
        model = quantize(
            model, args.model_name, example_inputs, args.evaluate, args.evaluate_config
        )
        model_int8 = model
        # Wrap quantized model back into an exported_program
        exported_program = torch.export.export_for_training(model, example_inputs)

    if args.intermediates:
        os.makedirs(args.intermediates, exist_ok=True)

    if args.delegate:
        # As we can target multiple output encodings from ArmBackend, one must
        # be specified.
        compile_spec = get_compile_spec(args.target, args.intermediates)
        edge = to_edge_transform_and_lower(
            exported_program,
            partitioner=[ArmPartitioner(compile_spec)],
            compile_config=EdgeCompileConfig(
                _check_ir_validity=False,
                _skip_dim_order=True,
            ),
        )
    else:
        edge = to_edge_transform_and_lower(
            exported_program,
            compile_config=EdgeCompileConfig(
                _check_ir_validity=False,
                _skip_dim_order=True,
            ),
        )

    dump_delegation_info(edge, args.intermediates)

    try:
        exec_prog = edge.to_executorch(
            config=ExecutorchBackendConfig(extract_delegate_segments=False)
        )
    except RuntimeError as e:
        if "Missing out variants" in str(e.args[0]):
            raise RuntimeError(
                e.args[0]
                + ".\nThis is likely due to an external .so library not being loaded. Supply a path to it with the -s flag."
            ).with_traceback(e.__traceback__) from None
        else:
            raise e

    model_name = os.path.basename(os.path.splitext(args.model_name)[0])
    output_name = f"{model_name}" + (
        f"_arm_delegate_{args.target}"
        if args.delegate is True
        else f"_arm_{args.target}"
    )
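    # e.g. "add_arm_delegate_ethos-u55-128"; save_pte_program() is expected to
    # append the ".pte" suffix when writing the file.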

    if args.output is not None:
        output_name = os.path.join(args.output, output_name)

    save_pte_program(exec_prog, output_name)

    if args.evaluate:
        evaluate_model(
            args.model_name,
            args.intermediates,
            model_fp32,
            model_int8,
            example_inputs,
            args.evaluate,
            args.evaluate_config,
        )