load("@fbsource//tools/build_defs:platform_defs.bzl", "ANDROID")
load("@fbsource//xplat/caffe2:pt_defs.bzl", "get_pt_ops_deps")
load("@fbsource//xplat/caffe2:pt_ops.bzl", "pt_operator_library")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

# Python package under which the op test / benchmark generator sources live.
_OP_TESTS_MODULE = "executorch.backends.vulkan.test.op_tests"

def _define_op_codegen(suite, out_file, aten_src_path):
    """Declare a Python code generator and the genrule that runs it.

    For a given `suite` fragment this declares three targets:
      :generate_<suite>_lib   python_library holding the generator sources
      :generate_<suite>       python_binary whose entry point is that script
      :generated_<suite>_cpp  genrule running the binary, exposing `out_file`

    Args:
        suite: Name fragment shared by the three targets, e.g. "op_benchmarks".
        out_file: File name used as the genrule's single named output.
        aten_src_path: Target whose `$(location ...)` is passed to the generator
            as the root containing aten/src/ATen/native/{tags,native_functions}.yaml.
    """
    generator = "generate_" + suite
    generator_lib = generator + "_lib"

    runtime.python_library(
        name = generator_lib,
        srcs = native.glob(["utils/*.py"]) + [
            generator + ".py",
            "cases.py",
        ],
        base_module = _OP_TESTS_MODULE,
        deps = [
            "fbsource//third-party/pypi/expecttest:expecttest",
        ],
        external_deps = ["torchgen"],
    )

    runtime.python_binary(
        name = generator,
        main_module = _OP_TESTS_MODULE + "." + generator,
        deps = [
            ":" + generator_lib,
        ],
    )

    # Both YAML inputs live in the same ATen directory; build its path once.
    aten_native_dir = "$(location {})/aten/src/ATen/native".format(aten_src_path)
    command_parts = [
        "$(exe :{})".format(generator),
        "--tags-path {}/tags.yaml".format(aten_native_dir),
        "--aten-yaml-path {}/native_functions.yaml".format(aten_native_dir),
        "-o $OUT",
    ]
    runtime.genrule(
        name = "generated_{}_cpp".format(suite),
        outs = {
            out_file: [out_file],
        },
        cmd = " ".join(command_parts),
        default_outs = ["."],
    )

def _vulkan_cxx_binary(name, srcs, deps, quiet_unused_variables = True):
    """Declare a runnable C++ binary without a static-library sibling target.

    When `quiet_unused_variables` is True, `-Wno-unused-variable` is passed as
    the binary's compiler flags; otherwise no `compiler_flags` attribute is set.
    """
    attrs = dict(
        name = name,
        srcs = srcs,
        define_static_target = False,
        deps = deps,
    )
    if quiet_unused_variables:
        attrs["compiler_flags"] = ["-Wno-unused-variable"]
    runtime.cxx_binary(**attrs)

def _vulkan_android_test(name, srcs, deps):
    """Declare an Android-only C++ test that runs as an instrumentation test.

    The Vulkan runtime, its shader library, and torch-code-gen are listed as
    additional shared objects for the Android harness to load.
    """
    runtime.cxx_test(
        name = name,
        srcs = srcs,
        contacts = ["[email protected]"],
        fbandroid_additional_loaded_sonames = [
            "torch-code-gen",
            "vulkan_graph_runtime",
            "vulkan_graph_runtime_shaderlib",
        ],
        platforms = [ANDROID],
        use_instrumentation_test = True,
        deps = deps,
    )

def define_common_targets(is_fbcode = False):
    """Declare the Vulkan backend op test, benchmark, and unit-test targets.

    Declares nothing when `is_fbcode` is True.

    Args:
        is_fbcode: True when evaluated for the fbcode build; the macro then
            returns immediately without declaring any targets.
    """
    if is_fbcode:
        return

    # Frequently repeated dependency labels.
    gtest_main = "//third-party/googletest:gtest_main"
    vulkan_runtime = "//executorch/backends/vulkan:vulkan_graph_runtime"
    custom_ops = "//executorch/extension/llm/custom_ops:custom_ops_aot_lib"
    tensor_ext = "//executorch/extension/tensor:tensor"
    libtorch = runtime.external_dep_location("libtorch")

    # Generated C++ sources: correctness tests and benchmarks derived from the
    # ATen operator YAML files.
    aten_src_path = runtime.external_dep_location("aten-src-path")
    _define_op_codegen("op_correctness_tests", "op_tests.cpp", aten_src_path)
    _define_op_codegen("op_benchmarks", "op_benchmarks.cpp", aten_src_path)

    # Every ATen operator, bundled into a library the generated binaries link.
    pt_operator_library(
        name = "all_aten_ops",
        check_decl = False,
        include_all_operators = True,
    )

    runtime.cxx_library(
        name = "all_aten_ops_lib",
        srcs = [],
        define_static_target = False,
        exported_deps = get_pt_ops_deps(
            name = "pt_ops_full",
            deps = [
                ":all_aten_ops",
            ],
        ),
    )

    _vulkan_cxx_binary(
        name = "compute_graph_op_tests_bin",
        srcs = [":generated_op_correctness_tests_cpp[op_tests.cpp]"],
        deps = [gtest_main, vulkan_runtime, ":all_aten_ops_lib"],
        quiet_unused_variables = False,
    )

    _vulkan_cxx_binary(
        name = "compute_graph_op_benchmarks_bin",
        srcs = [":generated_op_benchmarks_cpp[op_benchmarks.cpp]"],
        deps = ["//third-party/benchmark:benchmark", vulkan_runtime, ":all_aten_ops_lib"],
    )

    # NOTE(review): the Android test links libtorch rather than
    # :all_aten_ops_lib, unlike compute_graph_op_tests_bin — confirm intended.
    _vulkan_android_test(
        name = "compute_graph_op_tests",
        srcs = [":generated_op_correctness_tests_cpp[op_tests.cpp]"],
        deps = [gtest_main, vulkan_runtime, libtorch],
    )

    # Hand-written tests, each declared both as a runnable `<name>_bin` binary
    # and as an Android instrumentation test `<name>`, built from `<name>.cpp`.
    # Columns: (name, binary deps, Android test deps).
    # NOTE(review): linear_weight_int4_test's Android deps include custom_ops
    # while its _bin does not — possibly copied from sdpa_test; verify.
    handwritten_tests = [
        (
            "sdpa_test",
            [gtest_main, vulkan_runtime, custom_ops],
            [gtest_main, vulkan_runtime, custom_ops, tensor_ext, libtorch],
        ),
        (
            "linear_weight_int4_test",
            [gtest_main, vulkan_runtime, libtorch],
            [gtest_main, vulkan_runtime, custom_ops, tensor_ext, libtorch],
        ),
        (
            "rotary_embedding_test",
            [gtest_main, vulkan_runtime, libtorch],
            [gtest_main, vulkan_runtime, tensor_ext, libtorch],
        ),
    ]
    for test_name, bin_deps, android_deps in handwritten_tests:
        source_file = test_name + ".cpp"
        _vulkan_cxx_binary(
            name = test_name + "_bin",
            srcs = [source_file],
            deps = bin_deps,
        )
        _vulkan_android_test(
            name = test_name,
            srcs = [source_file],
            deps = android_deps,
        )