# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Generate the Vulkan operator correctness test file ``op_tests.cpp``.

Test suites registered in ``cases.test_suites`` are matched (by registry name)
against the operators parsed from ``native_functions.yaml`` and rendered into a
single C++ source file by ``VkCorrectnessTestFileGen``.
"""

import argparse
import os

from typing import Dict, List, Union

from executorch.backends.vulkan.test.op_tests.cases import test_suites
from executorch.backends.vulkan.test.op_tests.utils.gen_computegraph import (
    ComputeGraphGen,
)

from executorch.backends.vulkan.test.op_tests.utils.gen_correctness_vk import (
    VkCorrectnessTestFileGen,
)
from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite
from torchgen import local

from torchgen.gen import parse_native_yaml, ParsedYaml
from torchgen.model import DispatchKey, NativeFunction


def registry_name(f: NativeFunction) -> str:
    """Return the key under which ``f`` is looked up in the test-suite registry.

    The key is ``"<namespace>.<func name>"``; when the function has no
    overload name, ``".default"`` is appended so every key names a concrete
    overload.
    """
    name = str(f.namespace) + "." + str(f.func.name)
    if not f.func.name.overload_name:
        name += ".default"
    return name


def construct_f_map(parsed_yaml: ParsedYaml) -> Dict[str, NativeFunction]:
    """Index every parsed native function by its ``registry_name``.

    If two functions produce the same key, the one appearing later in
    ``parsed_yaml.native_functions`` overwrites the earlier one.
    """
    return {registry_name(f): f for f in parsed_yaml.native_functions}


def process_test_suites(
    cpp_generator: VkCorrectnessTestFileGen,
    f_map: Dict[str, NativeFunction],
    test_suites: Dict[str, Union[TestSuite, List[TestSuite]]],
) -> None:
    """Register every test suite with ``cpp_generator``.

    Args:
        cpp_generator: Generator that accumulates the test cases to emit.
        f_map: Native functions keyed by registry name (see ``construct_f_map``).
        test_suites: Maps a registry name to either a single ``TestSuite`` or a
            list of them; each suite is added separately.

    Raises:
        KeyError: If a registry name in ``test_suites`` has no matching
            operator in ``f_map``.
    """
    # NOTE: the loop variable is deliberately not called ``registry_name`` so
    # it does not shadow the module-level function of that name.
    for op_name, op_test_suites in test_suites.items():
        try:
            f = f_map[op_name]
        except KeyError:
            raise KeyError(
                f"Test suite registered for {op_name!r}, but no such operator "
                "was found in the parsed native functions yaml"
            ) from None

        # Normalize the "single suite" form to a one-element list.
        suites = (
            op_test_suites if isinstance(op_test_suites, list) else [op_test_suites]
        )
        for suite in suites:
            cpp_generator.add_suite(op_name, f, suite)


# Pin torchgen's `local` code-generation settings for the duration of this call.
@local.parametrize(
    use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
)
def generate_cpp(
    native_functions_yaml_path: str, tags_path: str, output_dir: str
) -> None:
    """Generate ``<output_dir>/op_tests.cpp`` from the registered test suites.

    Args:
        native_functions_yaml_path: Path to ``native_functions.yaml``.
        tags_path: Path to ``tags.yaml``; passed to ``parse_native_yaml``.
        output_dir: Directory in which ``op_tests.cpp`` is written; it must
            already exist.

    Raises:
        KeyError: If a registered test suite names an operator that is not in
            the parsed yaml.
    """
    output_file = os.path.join(output_dir, "op_tests.cpp")
    cpp_generator = VkCorrectnessTestFileGen(output_file)

    parsed_yaml = parse_native_yaml(native_functions_yaml_path, tags_path)
    f_map = construct_f_map(parsed_yaml)

    # Class-level attribute shared by all ComputeGraphGen instances; set it
    # before any suite is added or rendered.
    ComputeGraphGen.backend_key = parsed_yaml.backend_indices[DispatchKey.CPU]

    process_test_suites(cpp_generator, f_map, test_suites)

    with open(output_file, "w", encoding="utf-8") as file:
        file.write(cpp_generator.generate_cpp())


def main() -> None:
    """Command-line entry point: parse arguments and run ``generate_cpp``."""
    parser = argparse.ArgumentParser()
    # Both yaml paths are needed by parse_native_yaml; without them the run
    # would fail later with an opaque error, so argparse enforces them here.
    parser.add_argument(
        "--aten-yaml-path",
        help="path to native_functions.yaml file.",
        required=True,
    )
    parser.add_argument(
        "--tags-path",
        help="Path to tags.yaml. Required by yaml parsing in gen_correctness_vk system.",
        required=True,
    )
    parser.add_argument("-o", "--output", help="Output directory", required=True)
    args = parser.parse_args()
    generate_cpp(args.aten_yaml_path, args.tags_path, args.output)


if __name__ == "__main__":
    main()