# Source: /aosp_15_r20/external/pytorch/ufunc_defs.bzl (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
load("@bazel_skylib//lib:paths.bzl", "paths")
load(":build_variables.bzl", "aten_ufunc_headers")

def _ufunc_name_from_header(header):
    """Returns the file name of `header` with its directory and extension removed.

    For example, "some/dir/foo.h" yields "foo".
    """
    base = paths.basename(header)
    return paths.split_extension(base)[0]

# One ufunc name per entry in `aten_ufunc_headers`, in the same order: the
# header's basename with its extension stripped. The source-list helpers below
# derive one generated file name from each entry.
aten_ufunc_names = [_ufunc_name_from_header(header) for header in aten_ufunc_headers]

def aten_ufunc_generated_cpu_sources(gencode_pattern = "{}"):
    """Lists the generated `UfuncCPU_<name>.cpp` files, one per ufunc name.

    Args:
      gencode_pattern: format string into which each generated file name is
        substituted via `str.format`. The default "{}" yields the bare file
        names; a pattern such as "some/prefix/{}" can be used to decorate them.

    Returns:
      A list of strings in the same order as `aten_ufunc_names`.
    """
    file_template = "UfuncCPU_{}.cpp"
    return [
        gencode_pattern.format(file_template.format(ufunc))
        for ufunc in aten_ufunc_names
    ]

def aten_ufunc_generated_cpu_kernel_sources(gencode_pattern = "{}"):
    """Lists the generated `UfuncCPUKernel_<name>.cpp` files, one per ufunc name.

    Args:
      gencode_pattern: format string into which each generated file name is
        substituted via `str.format`. The default "{}" yields the bare file
        names; a pattern such as "some/prefix/{}" can be used to decorate them.

    Returns:
      A list of strings in the same order as `aten_ufunc_names`.
    """
    file_template = "UfuncCPUKernel_{}.cpp"
    return [
        gencode_pattern.format(file_template.format(ufunc))
        for ufunc in aten_ufunc_names
    ]

def aten_ufunc_generated_cuda_sources(gencode_pattern = "{}"):
    """Lists the generated `UfuncCUDA_<name>.cu` files, one per ufunc name.

    Args:
      gencode_pattern: format string into which each generated file name is
        substituted via `str.format`. The default "{}" yields the bare file
        names; a pattern such as "some/prefix/{}" can be used to decorate them.

    Returns:
      A list of strings in the same order as `aten_ufunc_names`.
    """
    file_template = "UfuncCUDA_{}.cu"
    return [
        gencode_pattern.format(file_template.format(ufunc))
        for ufunc in aten_ufunc_names
    ]