# xref: /aosp_15_r20/external/pytorch/c10/cuda/test/build.bzl (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# CUDA sources for the CUDA assertion tests (the `dsa` prefix presumably
# abbreviates "device-side assertions"; confirm with owners). define_targets()
# turns every entry into a `test_<stem>_lib` cuda_library plus a `test_<stem>`
# cc_test, where <stem> is the file name without the "impl/" prefix and the
# ".cu" suffix.
dsa_tests = [
    "impl/CUDAAssertionsTest_1_var_test.cu",
    "impl/CUDAAssertionsTest_catches_stream.cu",
    "impl/CUDAAssertionsTest_catches_thread_and_block_and_device.cu",
    "impl/CUDAAssertionsTest_from_2_processes.cu",
    "impl/CUDAAssertionsTest_multiple_writes_from_blocks_and_threads.cu",
    "impl/CUDAAssertionsTest_multiple_writes_from_multiple_blocks.cu",
    "impl/CUDAAssertionsTest_multiple_writes_from_same_block.cu",
]

def define_targets(rules, gtest_deps):
    """Declares the c10/cuda unit-test targets.

    Creates one `test` cc_test built from impl/CUDATest.cpp and, for every
    source in `dsa_tests`, a `test_<stem>_lib` cuda_library plus a
    `test_<stem>` cc_test that depends on it. All targets are constrained by
    `rules.requires_cuda_enabled()`.

    Args:
      rules: build-rule shim; the members used here are `cc_test`,
        `cuda_library` and `requires_cuda_enabled`.
      gtest_deps: list of dependency labels (presumably the gtest /
        gtest_main targets) appended to the deps of every target.
    """
    rules.cc_test(
        name = "test",
        srcs = [
            "impl/CUDATest.cpp",
        ],
        target_compatible_with = rules.requires_cuda_enabled(),
        deps = [
            "//c10/cuda",
        ] + gtest_deps,
    )

    # One cuda_library + cc_test pair per .cu source: the library compiles the
    # .cu file, and the cc_test (which has no srcs of its own) gets the test
    # code through its dependency on that library.
    for src in dsa_tests:
        # "impl/Foo.cu" -> "Foo". Strip only an actual leading prefix and
        # trailing suffix; str.replace would also delete interior matches
        # (e.g. the ".cu" inside ".cuh") and produce a mangled target name.
        name = src
        if name.startswith("impl/"):
            name = name[len("impl/"):]
        if name.endswith(".cu"):
            name = name[:-len(".cu")]

        test_name = "test_" + name
        lib_name = test_name + "_lib"

        rules.cuda_library(
            name = lib_name,
            srcs = [
                src,
            ],
            target_compatible_with = rules.requires_cuda_enabled(),
            deps = [
                "//c10/cuda",
            ] + gtest_deps,
        )
        rules.cc_test(
            name = test_name,
            # Same constraint as the library and the `test` target above, stated
            # explicitly rather than relying on it being inherited via deps.
            target_compatible_with = rules.requires_cuda_enabled(),
            deps = [
                ":" + lib_name,
            ],
        )
