# xref: /aosp_15_r20/external/executorch/backends/vulkan/test/op_tests/targets.bzl (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID")
2load("@fbsource//xplat/caffe2:pt_defs.bzl", "get_pt_ops_deps")
3load("@fbsource//xplat/caffe2:pt_ops.bzl", "pt_operator_library")
4load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
5
def _define_op_generator(name):
    """Define a Python op-test generator: `<name>_lib` plus the `<name>` binary.

    The library bundles the shared `utils/*.py` helpers, `cases.py` and the
    generator script `<name>.py`. The binary runs
    `executorch.backends.vulkan.test.op_tests.<name>` as its main module.

    Args:
        name: Basename of the generator script without the `.py` suffix,
            e.g. "generate_op_correctness_tests".
    """
    lib_name = name + "_lib"

    runtime.python_library(
        name = lib_name,
        srcs = native.glob(["utils/*.py"]) + [
            name + ".py",
            "cases.py",
        ],
        base_module = "executorch.backends.vulkan.test.op_tests",
        deps = [
            "fbsource//third-party/pypi/expecttest:expecttest",
        ],
        external_deps = ["torchgen"],
    )

    runtime.python_binary(
        name = name,
        main_module = "executorch.backends.vulkan.test.op_tests." + name,
        deps = [
            ":" + lib_name,
        ],
    )

def _define_op_codegen_genrule(name, generator, out, aten_src_path):
    """Run an op-test generator over ATen's YAML files to emit one C++ file.

    The generator is invoked with `--tags-path`, `--aten-yaml-path` and
    `-o $OUT`. The produced file is exposed as the named output `[out]`
    (e.g. `:<name>[op_tests.cpp]`).

    Args:
        name: Name of the genrule target.
        generator: Target reference of the python_binary to execute,
            e.g. ":generate_op_benchmarks".
        out: File name of the generated source, e.g. "op_tests.cpp".
        aten_src_path: Target reference whose `$(location ...)` contains
            `aten/src/ATen/native/{tags,native_functions}.yaml`.
    """
    cmd = [
        "$(exe {})".format(generator),
        "--tags-path $(location {})/aten/src/ATen/native/tags.yaml".format(aten_src_path),
        "--aten-yaml-path $(location {})/aten/src/ATen/native/native_functions.yaml".format(aten_src_path),
        "-o $OUT",
    ]

    runtime.genrule(
        name = name,
        outs = {
            out: [out],
        },
        cmd = " ".join(cmd),
        default_outs = ["."],
    )

def _define_vulkan_android_test(name, srcs, extra_deps):
    """Define an Android instrumentation gtest for the Vulkan graph runtime.

    Every such test depends on gtest_main, the Vulkan graph runtime and
    libtorch. `extra_deps` are inserted between the runtime and libtorch, so
    the dependency order matches the hand-written targets this replaces.

    Args:
        name: Name of the cxx_test target.
        srcs: C++ sources (files or `:genrule[out]` references).
        extra_deps: Additional dependencies specific to this test.
    """
    runtime.cxx_test(
        name = name,
        srcs = srcs,
        # NOTE(review): this value looks like an e-mail obfuscation
        # placeholder left by the page this file was captured from; restore
        # the real on-call contact address.
        contacts = ["[email protected]"],
        # Extra shared libraries the instrumentation runner is told to load.
        fbandroid_additional_loaded_sonames = [
            "torch-code-gen",
            "vulkan_graph_runtime",
            "vulkan_graph_runtime_shaderlib",
        ],
        platforms = [ANDROID],
        use_instrumentation_test = True,
        deps = [
            "//third-party/googletest:gtest_main",
            "//executorch/backends/vulkan:vulkan_graph_runtime",
        ] + extra_deps + [
            runtime.external_dep_location("libtorch"),
        ],
    )

def _define_vulkan_cpp_test(name, bin_extra_deps, test_extra_deps):
    """Define `<name>_bin` and the Android test `<name>`, both from `<name>.cpp`.

    Args:
        name: Basename of the hand-written test source without `.cpp`.
        bin_extra_deps: Dependencies appended after gtest_main and the Vulkan
            graph runtime for the `<name>_bin` cxx_binary.
        test_extra_deps: Extra dependencies for the Android cxx_test (see
            `_define_vulkan_android_test`).
    """
    runtime.cxx_binary(
        name = name + "_bin",
        srcs = [
            name + ".cpp",
        ],
        compiler_flags = [
            "-Wno-unused-variable",
        ],
        define_static_target = False,
        deps = [
            "//third-party/googletest:gtest_main",
            "//executorch/backends/vulkan:vulkan_graph_runtime",
        ] + bin_extra_deps,
    )

    _define_vulkan_android_test(
        name = name,
        srcs = [
            name + ".cpp",
        ],
        extra_deps = test_extra_deps,
    )

def define_common_targets(is_fbcode = False):
    """Define the Vulkan backend op-test, op-benchmark and codegen targets.

    Declares:
      * Python generators (`generate_op_correctness_tests`,
        `generate_op_benchmarks`) and the genrules that run them over ATen's
        `tags.yaml` / `native_functions.yaml` to emit `op_tests.cpp` and
        `op_benchmarks.cpp`;
      * `all_aten_ops_lib`, which exports the full PyTorch operator set;
      * `*_bin` binaries and Android instrumentation tests for the generated
        sources and for the hand-written `sdpa_test`,
        `linear_weight_int4_test` and `rotary_embedding_test`.

    Args:
        is_fbcode: When True, nothing is defined.
    """
    if is_fbcode:
        return

    # Code generators: one python_library + python_binary pair each.
    _define_op_generator("generate_op_correctness_tests")
    _define_op_generator("generate_op_benchmarks")

    # Both generators read the same ATen YAML files, so resolve the location once.
    aten_src_path = runtime.external_dep_location("aten-src-path")

    _define_op_codegen_genrule(
        name = "generated_op_correctness_tests_cpp",
        generator = ":generate_op_correctness_tests",
        out = "op_tests.cpp",
        aten_src_path = aten_src_path,
    )

    _define_op_codegen_genrule(
        name = "generated_op_benchmarks_cpp",
        generator = ":generate_op_benchmarks",
        out = "op_benchmarks.cpp",
        aten_src_path = aten_src_path,
    )

    # Every PyTorch operator (include_all_operators), re-exported through an
    # empty cxx_library so the generated-source binaries below can link it.
    pt_operator_library(
        name = "all_aten_ops",
        check_decl = False,
        include_all_operators = True,
    )

    runtime.cxx_library(
        name = "all_aten_ops_lib",
        srcs = [],
        define_static_target = False,
        exported_deps = get_pt_ops_deps(
            name = "pt_ops_full",
            deps = [
                ":all_aten_ops",
            ],
        ),
    )

    runtime.cxx_binary(
        name = "compute_graph_op_tests_bin",
        srcs = [
            ":generated_op_correctness_tests_cpp[op_tests.cpp]",
        ],
        define_static_target = False,
        deps = [
            "//third-party/googletest:gtest_main",
            "//executorch/backends/vulkan:vulkan_graph_runtime",
            ":all_aten_ops_lib",
        ],
    )

    runtime.cxx_binary(
        name = "compute_graph_op_benchmarks_bin",
        srcs = [
            ":generated_op_benchmarks_cpp[op_benchmarks.cpp]",
        ],
        compiler_flags = [
            "-Wno-unused-variable",
        ],
        define_static_target = False,
        deps = [
            "//third-party/benchmark:benchmark",
            "//executorch/backends/vulkan:vulkan_graph_runtime",
            ":all_aten_ops_lib",
        ],
    )

    _define_vulkan_android_test(
        name = "compute_graph_op_tests",
        srcs = [
            ":generated_op_correctness_tests_cpp[op_tests.cpp]",
        ],
        extra_deps = [],
    )

    # Hand-written tests.
    # NOTE(review): the `*_bin` dependencies are kept exactly as before but
    # do not match the Android tests. For example, sdpa_test_bin has neither
    # libtorch nor //executorch/extension/tensor, and
    # linear_weight_int4_test_bin lacks custom_ops_aot_lib and tensor.
    # Confirm this is intended.
    _define_vulkan_cpp_test(
        name = "sdpa_test",
        bin_extra_deps = [
            "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
        ],
        test_extra_deps = [
            "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
            "//executorch/extension/tensor:tensor",
        ],
    )

    _define_vulkan_cpp_test(
        name = "linear_weight_int4_test",
        bin_extra_deps = [
            runtime.external_dep_location("libtorch"),
        ],
        test_extra_deps = [
            "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
            "//executorch/extension/tensor:tensor",
        ],
    )

    _define_vulkan_cpp_test(
        name = "rotary_embedding_test",
        bin_extra_deps = [
            runtime.external_dep_location("libtorch"),
        ],
        test_extra_deps = [
            "//executorch/extension/tensor:tensor",
        ],
    )
264