xref: /aosp_15_r20/external/executorch/examples/apple/mps/scripts/mps_example.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1#
2#  Copyright (c) 2023 Apple Inc. All rights reserved.
3#  Provided subject to the LICENSE file in the top level directory.
4#
5
# Example script for exporting models to an ExecuTorch program (.pte flatbuffer) lowered to the Apple MPS backend
7
8import argparse
9import copy
10import logging
11
12import torch
13from examples.apple.mps.scripts.bench_utils import bench_torch, compare_outputs
14from executorch import exir
15from executorch.backends.apple.mps import MPSBackend
16from executorch.backends.apple.mps.partition import MPSPartitioner
17from executorch.devtools import BundledProgram, generate_etrecord
18from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
19from executorch.devtools.bundled_program.serialize import (
20    serialize_from_bundled_program_to_flatbuffer,
21)
22
23from executorch.exir import (
24    EdgeCompileConfig,
25    EdgeProgramManager,
26    ExecutorchProgramManager,
27)
28from executorch.exir.backend.backend_api import to_backend
29from executorch.exir.backend.backend_details import CompileSpec
30from executorch.exir.capture._config import ExecutorchBackendConfig
31from executorch.extension.export_util.utils import export_to_edge, save_pte_program
32
33from ....models import MODEL_NAME_TO_MODEL
34from ....models.model_factory import EagerModelFactory
35
# Log line layout: level, timestamp and source location, then the message.
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(format=FORMAT, level=logging.INFO)
38
39
def get_bundled_program(executorch_program, example_inputs, expected_output):
    """Serialize ``executorch_program`` together with a single reference test.

    The bundle holds one test suite for the ``forward`` method. That suite
    contains exactly one test case, which feeds ``example_inputs`` and expects
    ``expected_output``.

    Args:
        executorch_program: The ExecuTorch program to embed in the bundle.
        example_inputs: Inputs used by the ``forward`` test case.
        expected_output: Reference result for those inputs. It is stored as
            the sole element of the test case's ``expected_outputs`` list.

    Returns:
        The buffer produced by ``serialize_from_bundled_program_to_flatbuffer``.
    """
    forward_case = MethodTestCase(
        inputs=example_inputs, expected_outputs=[expected_output]
    )
    forward_suite = MethodTestSuite(method_name="forward", test_cases=[forward_case])

    logging.info(f"Expected output: {expected_output}")

    return serialize_from_bundled_program_to_flatbuffer(
        BundledProgram(executorch_program, [forward_suite])
    )
58
59
def parse_args(argv=None):
    """Parse the command-line options for this MPS export script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, matching the
            original zero-argument behavior.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"Provide model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
    )

    parser.add_argument(
        "--use_fp16",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="Whether to automatically convert float32 operations to float16 operations.",
    )

    parser.add_argument(
        "--use_partitioner",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="Use MPS partitioner to run the model instead of using whole graph lowering.",
    )

    parser.add_argument(
        "--bench_pytorch",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Bench ExecuTorch MPS forward pass with PyTorch MPS forward pass.",
    )

    parser.add_argument(
        "-b",
        "--bundled",
        action="store_true",
        required=False,
        default=False,
        help="Flag for bundling inputs and outputs in the final flatbuffer program",
    )

    parser.add_argument(
        "-c",
        "--check_correctness",
        action="store_true",
        required=False,
        default=False,
        help="Whether to compare the ExecuTorch MPS results with the PyTorch forward pass",
    )

    parser.add_argument(
        "--generate_etrecord",
        action="store_true",
        required=False,
        default=False,
        help="Generate ETRecord metadata to link with runtime results (used for profiling)",
    )

    # --checkpoint / --params are only read by get_model_config when
    # --model_name is "llama2".
    parser.add_argument(
        "--checkpoint",
        required=False,
        default=None,
        help="checkpoint for llama model",
    )

    parser.add_argument(
        "--params",
        required=False,
        default=None,
        help="params for llama model",
    )

    return parser.parse_args(argv)
132
133
def get_model_config(args):
    """Build the keyword arguments passed to ``EagerModelFactory.create_model``.

    ``module_name`` and ``model_class_name`` come from the first two fields of
    the ``MODEL_NAME_TO_MODEL`` entry for ``args.model_name``. For ``llama2``,
    the optional ``checkpoint`` and ``params`` paths are forwarded when they
    were given, and ``use_kv_cache`` is always set to True.

    Args:
        args: Parsed command-line namespace from ``parse_args``.

    Returns:
        dict of model-factory keyword arguments.
    """
    registry_entry = MODEL_NAME_TO_MODEL[args.model_name]
    config = {
        "module_name": registry_entry[0],
        "model_class_name": registry_entry[1],
    }

    if args.model_name != "llama2":
        return config

    if args.checkpoint:
        config["checkpoint"] = args.checkpoint
    if args.params:
        config["params"] = args.params
    config["use_kv_cache"] = True
    return config
146
147
def main():
    """Export the selected model, lower it to MPS, and write the program.

    Driven by the command-line options from ``parse_args``:
      1. Build the eager model and example inputs via ``EagerModelFactory``.
      2. Export to Edge IR. Then lower either through ``MPSPartitioner``
         (``--use_partitioner``) or by lowering the whole graph to
         ``MPSBackend``.
      3. Save a plain ``.pte``. With ``--bundled``, instead save a bundled
         program that also embeds the example inputs and a reference output.
      4. Optionally emit ``etrecord.bin``, benchmark against PyTorch, and/or
         compare outputs with the eager PyTorch model.

    Raises:
        RuntimeError: If ``--model_name`` is not in ``MODEL_NAME_TO_MODEL``.
    """
    args = parse_args()

    if args.model_name not in MODEL_NAME_TO_MODEL:
        raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.")

    model_config = get_model_config(args)
    model, example_inputs, _, _ = EagerModelFactory.create_model(**model_config)

    model = model.eval()

    # Deep copy the model and inputs so the PyTorch benchmark / correctness
    # check at the end runs against the untouched eager model.
    if args.check_correctness or args.bench_pytorch:
        model_copy = copy.deepcopy(model)
        inputs_copy = tuple(t.detach().clone() for t in example_inputs)

    # pre-autograd export. eventually this will become torch.export
    with torch.no_grad():
        model = torch.export.export_for_training(model, example_inputs).module()
        edge: EdgeProgramManager = export_to_edge(
            model,
            example_inputs,
            edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )

    # Snapshot of the Edge program before lowering. It is only consumed by
    # generate_etrecord below, so skip the (potentially large) deep copy
    # unless an ETRecord was requested.
    edge_program_manager_copy = copy.deepcopy(edge) if args.generate_etrecord else None

    # The boolean flag is encoded as a single byte (b"\x00" or b"\x01").
    compile_specs = [CompileSpec("use_fp16", bytes([args.use_fp16]))]

    logging.info(f"Edge IR graph:\n{edge.exported_program().graph}")
    executorch_program: ExecutorchProgramManager
    if args.use_partitioner:
        edge = edge.to_backend(MPSPartitioner(compile_specs=compile_specs))
        logging.info(f"Lowered graph:\n{edge.exported_program().graph}")

        executorch_program = edge.to_executorch(
            config=ExecutorchBackendConfig(extract_delegate_segments=False)
        )
    else:
        # Whole-graph lowering: lower to MPSBackend, then re-export the
        # lowered module to Edge so it can be converted to ExecuTorch.
        lowered_module = to_backend(
            MPSBackend.__name__, edge.exported_program(), compile_specs
        )
        executorch_program = export_to_edge(
            lowered_module,
            example_inputs,
            edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
        ).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))

    dtype = "float16" if args.use_fp16 else "float32"
    model_name = f"{args.model_name}_mps_{dtype}"

    if args.bundled:
        # Reference output from the exported (pre-lowering) module, embedded
        # in the bundle. No autograd graph is needed to compute it.
        with torch.no_grad():
            expected_output = model(*example_inputs)
        bundled_program_buffer = get_bundled_program(
            executorch_program, example_inputs, expected_output
        )
        model_name = f"{model_name}_bundled.pte"

    if args.generate_etrecord:
        etrecord_path = "etrecord.bin"
        logging.info("generating etrecord.bin")
        generate_etrecord(etrecord_path, edge_program_manager_copy, executorch_program)

    if args.bundled:
        with open(model_name, "wb") as file:
            file.write(bundled_program_buffer)
        logging.info(f"Saved bundled program to {model_name}")
    else:
        save_pte_program(executorch_program, model_name)

    if args.bench_pytorch:
        bench_torch(executorch_program, model_copy, example_inputs, model_name)

    if args.check_correctness:
        compare_outputs(
            executorch_program, model_copy, inputs_copy, model_name, args.use_fp16
        )


if __name__ == "__main__":
    main()
227