xref: /aosp_15_r20/external/executorch/codegen/tools/gen_ops_def.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1#!/usr/bin/env fbpython
2# Copyright (c) Meta Platforms, Inc. and affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8# Generates a template `functions.yaml` from a model binary. Ignoring all custom ops
9import argparse
10import os
11import sys
12
13from typing import Any, List
14
15import torch
16import yaml
17from executorch.codegen.tools.yaml_util import BlankLineDumper
18from executorch.exir._serialize import _deserialize_pte_binary
19from executorch.exir.schema import Operator
20
21
def get_operators(model_file: str) -> List[Operator]:
    """Load a serialized program from disk and return its operator list.

    Args:
        model_file: Path of the program binary to read.

    Returns:
        The ``operators`` attribute of the first entry in the deserialized
        program's ``execution_plan``.
    """
    print("Processing model file: ", model_file)
    with open(model_file, "rb") as model_fp:
        raw_bytes = model_fp.read()
    deserialized = _deserialize_pte_binary(raw_bytes)
    print(f"Program loaded from model file: {model_file}")
    # Only the first execution plan is inspected.
    first_plan = deserialized.execution_plan[0]
    return first_plan.operators
30
31
def dump_yaml(model_file: str, output_file: str) -> None:
    """Write a template ``functions.yaml`` for the aten operators of a model.

    Each operator returned by ``get_operators`` whose name starts with
    ``aten::`` is looked up with ``torch._C._jit_get_schemas_for_operator``;
    the schemas whose overload name equals the operator's overload become one
    YAML entry each, with a placeholder CPU kernel name in the
    ``torch::executor`` namespace. All other operators are reported on stdout
    and skipped.

    Args:
        model_file: Path to the model binary, forwarded to ``get_operators``.
        output_file: Path of the YAML file to write; an existing file is
            overwritten.
    """
    matched_schemas = []
    for op in get_operators(model_file):
        if not op.name.startswith("aten::"):
            print(f"Warning: not generating template for {op.name}")
            continue
        candidates = torch._C._jit_get_schemas_for_operator(op.name)
        # A comprehension (rather than a lambda inside the loop) binds `op`
        # immediately and avoids the late-binding closure pitfall.
        hits = [s for s in candidates if s.overload_name == op.overload]
        if not hits:
            # Make a dropped operator visible instead of skipping it silently.
            print(
                f"Warning: no schema found for {op.name} "
                f"(overload {op.overload!r})"
            )
        matched_schemas.extend(hits)

    entries = []
    for schema in matched_schemas:
        print(str(schema))
        kernel = schema.name.replace("aten::", "torch::executor::")
        # The default overload has an empty overload name; appending it
        # unconditionally would leave a trailing underscore (e.g. `relu_`),
        # which reads like the in-place variant of the operator.
        if schema.overload_name:
            kernel = f"{kernel}_{schema.overload_name}"
        entries.append(
            {
                "func": str(schema),
                "variants": "function",
                "dispatch": {
                    "CPU": kernel,
                },
            }
        )

    with open(output_file, "w") as f:
        yaml.dump(
            entries,
            f,
            Dumper=BlankLineDumper,
            default_flow_style=False,
            sort_keys=False,
            # Wide enough that long schema strings stay on one line.
            width=1000,
        )
63
64
def main(args: List[Any]) -> None:
    """Generate a template functions.yaml to be consumed by ExecuTorch codegen.

    Reads the model file, deserializes it and dumps all of its operators into
    a new functions.yaml. The generated file contains placeholder kernels; it
    needs to be updated with proper kernel names.

    Args:
        args: Command-line arguments, without the program name.

    Raises:
        AssertionError: If ``--model_file_path`` is missing or does not point
            to an existing file.
    """
    parser = argparse.ArgumentParser(
        description="Generate operator list from a model file"
    )
    parser.add_argument(
        "--output_path",
        help=("The path to the output yaml file (default: ./functions.yaml)"),
        # Optional: the call below already falls back to ./functions.yaml,
        # which was unreachable while this flag was marked as required.
        required=False,
        default="./functions.yaml",
    )
    parser.add_argument(
        "--model_file_path",
        help=("Path to an executorch program"),
        required=False,
    )
    options = parser.parse_args(args)

    # Explicit raises instead of `assert` statements so the validation still
    # runs under `python -O`; AssertionError is kept so callers observe the
    # same exception type and messages as before.
    if not options.model_file_path:
        raise AssertionError("Need to provide a model_file_path.")
    if not os.path.isfile(options.model_file_path):
        raise AssertionError(
            "The value for --model_file_path needs to be a valid file."
        )

    dump_yaml(
        model_file=options.model_file_path,
        # `or` also covers an explicitly empty `--output_path ""`.
        output_file=options.output_path or "./functions.yaml",
    )
93
94
# Script entry point: hand the command-line arguments (without the program
# name) to main().
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    main(cli_args)
97