# xref: /aosp_15_r20/external/pytorch/ufunc_defs.bzl (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1load("@bazel_skylib//lib:paths.bzl", "paths")
2load(":build_variables.bzl", "aten_ufunc_headers")
3
def _ufunc_name_from_header(header_path):
    """Returns the file name of `header_path` with directory and extension removed.

    Args:
      header_path: path string of one ufunc header.

    Returns:
      The extension-less base name (the first element of
      `paths.split_extension` applied to `paths.basename(header_path)`).
    """
    stem, _extension = paths.split_extension(paths.basename(header_path))
    return stem

# One ufunc name per entry of `aten_ufunc_headers`, in the same order, taken
# from each header's file name without its directory or extension.
aten_ufunc_names = [_ufunc_name_from_header(h) for h in aten_ufunc_headers]
8
def aten_ufunc_generated_cpu_sources(gencode_pattern = "{}"):
    """Returns the generated CPU ufunc source file names.

    For every entry of `aten_ufunc_names`, the ufunc name is filled into the
    `UfuncCPU_{}.cpp` template, and the resulting file name is then
    substituted into `gencode_pattern`.

    Args:
      gencode_pattern: format string applied to each generated file name via
        `.format()`; the default "{}" yields the bare file names.

    Returns:
      A list of strings, one per ufunc, in the same order as `aten_ufunc_names`.
    """
    file_template = "UfuncCPU_{}.cpp"
    return [
        gencode_pattern.format(file_template.format(ufunc_name))
        for ufunc_name in aten_ufunc_names
    ]
14
def aten_ufunc_generated_cpu_kernel_sources(gencode_pattern = "{}"):
    """Returns the generated CPU ufunc kernel source file names.

    For every entry of `aten_ufunc_names`, the ufunc name is filled into the
    `UfuncCPUKernel_{}.cpp` template, and the resulting file name is then
    substituted into `gencode_pattern`.

    Args:
      gencode_pattern: format string applied to each generated file name via
        `.format()`; the default "{}" yields the bare file names.

    Returns:
      A list of strings, one per ufunc, in the same order as `aten_ufunc_names`.
    """
    file_template = "UfuncCPUKernel_{}.cpp"
    return [
        gencode_pattern.format(file_template.format(ufunc_name))
        for ufunc_name in aten_ufunc_names
    ]
20
def aten_ufunc_generated_cuda_sources(gencode_pattern = "{}"):
    """Returns the generated CUDA ufunc source file names.

    For every entry of `aten_ufunc_names`, the ufunc name is filled into the
    `UfuncCUDA_{}.cu` template, and the resulting file name is then
    substituted into `gencode_pattern`.

    Args:
      gencode_pattern: format string applied to each generated file name via
        `.format()`; the default "{}" yields the bare file names.

    Returns:
      A list of strings, one per ufunc, in the same order as `aten_ufunc_names`.
    """
    file_template = "UfuncCUDA_{}.cu"
    return [
        gencode_pattern.format(file_template.format(ufunc_name))
        for ufunc_name in aten_ufunc_names
    ]
26