#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Generates a template `functions.yaml` from a model binary. Custom (non-aten)
# ops are ignored.
import argparse
import os
import sys

from typing import Any, List

import torch
import yaml
from executorch.codegen.tools.yaml_util import BlankLineDumper
from executorch.exir._serialize import _deserialize_pte_binary
from executorch.exir.schema import Operator


def get_operators(model_file: str) -> List[Operator]:
    """Load an ExecuTorch program and return the operators it references.

    Operators are collected from every execution plan in the program (not
    just the first one), de-duplicated by ``(name, overload)`` while keeping
    the order in which they are first seen.

    Args:
        model_file: Path to a serialized ExecuTorch program.

    Returns:
        The unique operators referenced by the program's execution plans.
    """
    print("Processing model file: ", model_file)
    with open(model_file, "rb") as f:
        flatbuffer = f.read()
    program = _deserialize_pte_binary(flatbuffer)
    print(f"Program loaded from model file: {model_file}")

    operators: List[Operator] = []
    seen = set()
    for plan in program.execution_plan:
        for op in plan.operators:
            # The same operator may be listed by several execution plans;
            # emit it only once in the template.
            key = (op.name, op.overload)
            if key not in seen:
                seen.add(key)
                operators.append(op)
    return operators


def dump_yaml(model_file: str, output_file: str) -> None:
    """Write a template ``functions.yaml`` for the aten operators in a model.

    Operators whose name does not start with ``aten::`` are skipped with a
    warning. Each aten operator is resolved to the TorchScript schema(s)
    whose overload name matches the one recorded in the program, and is
    emitted with a placeholder CPU kernel name derived from the operator
    name (e.g. ``aten::add`` overload ``out`` -> ``torch::executor::add_out``).
    The placeholder must be replaced with the real kernel name.

    Args:
        model_file: Path to a serialized ExecuTorch program.
        output_file: Destination path of the generated YAML file.
    """
    ops = get_operators(model_file)
    schemas = []
    for op in ops:
        if not op.name.startswith("aten::"):
            print(f"Warning: not generating template for {op.name}")
            continue
        matching = [
            s
            for s in torch._C._jit_get_schemas_for_operator(op.name)
            if s.overload_name == op.overload
        ]
        if not matching:
            # Surface the miss instead of silently dropping the operator
            # from the generated template.
            print(
                f"Warning: no schema found for {op.name}.{op.overload}; "
                "not generating template"
            )
        schemas.extend(matching)

    output = []
    for s in schemas:
        print(str(s))
        name = s.name.replace("aten::", "torch::executor::")
        output.append(
            {
                "func": str(s),
                "variants": "function",
                "dispatch": {
                    "CPU": f"{name}_{s.overload_name}",
                },
            }
        )
    with open(output_file, "w", encoding="utf-8") as f:
        yaml.dump(
            output,
            f,
            Dumper=BlankLineDumper,
            default_flow_style=False,
            sort_keys=False,
            width=1000,
        )


def main(args: List[Any]) -> None:
    """Generate a template functions.yaml to be consumed by ExecuTorch codegen.

    Reads the model file, deserializes it and dumps all of its aten operators
    into a new functions.yaml. The generated file contains placeholder
    kernels; they need to be updated with the proper kernel names.

    Args:
        args: Command-line arguments, excluding the program name.
    """
    parser = argparse.ArgumentParser(
        description="Generate operator list from a model file"
    )
    parser.add_argument(
        "--output_path",
        default="./functions.yaml",
        help="The path to the output yaml file (default: ./functions.yaml)",
    )
    parser.add_argument(
        "--model_file_path",
        help="Path to an executorch program",
        required=True,
    )
    options = parser.parse_args(args)

    # Use parser.error rather than assert: asserts are stripped under
    # `python -O`, and parser.error prints usage and exits with status 2.
    if not os.path.isfile(options.model_file_path):
        parser.error("The value for --model_file_path needs to be a valid file.")

    dump_yaml(
        model_file=options.model_file_path,
        output_file=options.output_path,
    )


if __name__ == "__main__":
    main(sys.argv[1:])