xref: /aosp_15_r20/external/executorch/backends/vulkan/test/op_tests/generate_op_correctness_tests.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import os
9
10from typing import Dict
11
12from executorch.backends.vulkan.test.op_tests.cases import test_suites
13from executorch.backends.vulkan.test.op_tests.utils.gen_computegraph import (
14    ComputeGraphGen,
15)
16
17from executorch.backends.vulkan.test.op_tests.utils.gen_correctness_vk import (
18    VkCorrectnessTestFileGen,
19)
20from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite
21from torchgen import local
22
23from torchgen.gen import parse_native_yaml, ParsedYaml
24from torchgen.model import DispatchKey, NativeFunction
25
26
def registry_name(f: NativeFunction) -> str:
    """Return the registry key used to look up a native function.

    The key is ``"<namespace>.<operator name>"``, where both parts are the
    ``str()`` forms of ``f.namespace`` and ``f.func.name``. When the
    operator's overload name is empty, ``".default"`` is appended.
    """
    has_overload = len(f.func.name.overload_name) != 0
    suffix = "" if has_overload else ".default"
    return ".".join((str(f.namespace), str(f.func.name))) + suffix
32
33
def construct_f_map(parsed_yaml: ParsedYaml) -> Dict[str, NativeFunction]:
    """Index the parsed native functions by their registry name.

    Each entry of ``parsed_yaml.native_functions`` is keyed by
    ``registry_name(f)``. If two functions produce the same key, the one
    that appears later wins.
    """
    return {registry_name(fn): fn for fn in parsed_yaml.native_functions}
39
40
def process_test_suites(
    cpp_generator: VkCorrectnessTestFileGen,
    f_map: Dict[str, NativeFunction],
    test_suites: Dict[str, TestSuite],
) -> None:
    """Register every test suite with the C++ test generator.

    For each ``(name, value)`` pair in ``test_suites`` the matching native
    function is looked up in ``f_map`` (a missing name raises ``KeyError``).
    ``value`` may be a single suite or a list of suites; each suite is
    passed to ``cpp_generator.add_suite(name, function, suite)``.
    """
    for op_name, entry in test_suites.items():
        native_fn = f_map[op_name]
        # Treat a lone suite as a one-element list so both shapes share a path.
        suites = entry if isinstance(entry, list) else [entry]
        for suite in suites:
            cpp_generator.add_suite(op_name, native_fn, suite)
53
54
@local.parametrize(
    use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
)
def generate_cpp(
    native_functions_yaml_path: str, tags_path: str, output_dir: str
) -> None:
    """Generate ``op_tests.cpp`` inside ``output_dir``.

    Parses the native functions yaml (together with the tags yaml), points
    ``ComputeGraphGen.backend_key`` at the parsed backend index for
    ``DispatchKey.CPU``, registers every suite from the imported
    ``test_suites`` mapping with a ``VkCorrectnessTestFileGen``, and writes
    the generator's output to ``<output_dir>/op_tests.cpp``.
    """
    cpp_path = os.path.join(output_dir, "op_tests.cpp")
    generator = VkCorrectnessTestFileGen(cpp_path)

    native_yaml = parse_native_yaml(native_functions_yaml_path, tags_path)
    functions_by_name = construct_f_map(native_yaml)

    # Assigned before any suite is added, as in the original ordering.
    ComputeGraphGen.backend_key = native_yaml.backend_indices[DispatchKey.CPU]

    process_test_suites(generator, functions_by_name, test_suites)

    with open(cpp_path, "w") as out:
        out.write(generator.generate_cpp())
73
74
if __name__ == "__main__":
    # Command-line entry point: collect the yaml inputs and the output
    # directory, then generate the op correctness test source file.
    parser = argparse.ArgumentParser()
    # Both yaml paths are handed straight to generate_cpp(); mark them as
    # required so a missing flag yields a usage error instead of passing
    # None down to the yaml parser.
    parser.add_argument(
        "--aten-yaml-path",
        help="path to native_functions.yaml file.",
        required=True,
    )
    parser.add_argument(
        "--tags-path",
        help="Path to tags.yaml. Required by yaml parsing in gen_correctness_vk system.",
        required=True,
    )
    parser.add_argument("-o", "--output", help="Output directory", required=True)
    args = parser.parse_args()
    generate_cpp(args.aten_yaml_path, args.tags_path, args.output)
88