xref: /aosp_15_r20/external/pytorch/aten.bzl (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerload("@bazel_skylib//lib:paths.bzl", "paths")
2*da0073e9SAndroid Build Coastguard Workerload("@rules_cc//cc:defs.bzl", "cc_library")
3*da0073e9SAndroid Build Coastguard Worker
# Source-tree prefixes used by intern_build_aten_ops: each is stripped from an
# input path to form a target name and prepended again to form the path of the
# per-capability copy.
PREFIX = "aten/src/ATen/native/"
EXTRA_PREFIX = "aten/src/ATen/"

# CPU capabilities to build for. intern_build_aten_ops emits one
# "ATen_CPU_<name>" cc_library per entry, in this order.
CPU_CAPABILITY_NAMES = [
    "DEFAULT",
    "AVX2",
]

# Extra compiler flags appended to the copts of each capability's library.
# Every name in CPU_CAPABILITY_NAMES must have an entry here.
CAPABILITY_COMPILER_FLAGS = {
    "AVX2": [
        "-mavx2",
        "-mfma",
        "-mf16c",
    ],
    "DEFAULT": [],
}

12*da0073e9SAndroid Build Coastguard Worker
def _copy_for_capability(impl, prefix, cpu_capability):
    """Declares a genrule that copies `impl` to a capability-specific file name.

    Args:
        impl: path of the source file to copy.
        prefix: path prefix removed from `impl` to form the target name and
            re-added to form the output path.
        cpu_capability: capability name embedded in the target and output names.

    Returns:
        The path of the copied file, "<prefix><name>.<cpu_capability>.cpp".
    """
    relative = impl.replace(prefix, "")
    copied = prefix + relative + "." + cpu_capability + ".cpp"
    native.genrule(
        name = relative + "_" + cpu_capability + "_cp",
        srcs = [impl],
        outs = [copied],
        cmd = "cp $< $@",
    )
    return copied

def intern_build_aten_ops(copts, deps, extra_impls):
    """Declares one ATen CPU kernel library per CPU capability plus an umbrella.

    For every name in CPU_CAPABILITY_NAMES, each kernel source is copied to a
    capability-suffixed file and the copies are compiled into a
    "ATen_CPU_<capability>" cc_library with capability-specific defines and
    flags. A final "ATen_CPU" cc_library depends on all of them.

    Args:
        copts: base compiler options shared by every capability library.
        deps: dependencies passed to every capability library.
        extra_impls: additional source paths (under EXTRA_PREFIX) to build per
            capability, besides the globbed kernel sources.
    """
    capability_libs = []
    for cpu_capability in CPU_CAPABILITY_NAMES:
        kernel_impls = native.glob(
            [
                PREFIX + "cpu/*.cpp",
                PREFIX + "quantized/cpu/kernels/*.cpp",
            ],
        )

        # Globbed kernels first, then the caller-supplied extras, so the
        # genrules are declared in the same order as the resulting srcs list.
        srcs = [
            _copy_for_capability(impl, PREFIX, cpu_capability)
            for impl in kernel_impls
        ]
        srcs += [
            _copy_for_capability(impl, EXTRA_PREFIX, cpu_capability)
            for impl in extra_impls
        ]

        capability_defines = [
            "-DCPU_CAPABILITY=" + cpu_capability,
            "-DCPU_CAPABILITY_" + cpu_capability,
        ]
        cc_library(
            name = "ATen_CPU_" + cpu_capability,
            srcs = srcs,
            copts = copts + capability_defines + CAPABILITY_COMPILER_FLAGS[cpu_capability],
            deps = deps,
            linkstatic = 1,
        )
        capability_libs.append(":ATen_CPU_" + cpu_capability)

    # Umbrella target that pulls in every per-capability library.
    cc_library(
        name = "ATen_CPU",
        deps = capability_libs,
        linkstatic = 1,
    )
58*da0073e9SAndroid Build Coastguard Worker
def generate_aten_impl(ctx):
    """Implementation of the generate_aten rule: runs the generator in one action.

    Declares "aten/src/ATen/ops" as a directory output in addition to the files
    listed in the `outs` attribute, then runs the `generator` executable with
    `--install_dir` set to the parent of that directory.

    Args:
        ctx: the rule context; reads `ctx.outputs.outs`, `ctx.files.srcs` and
            `ctx.executable.generator`.

    Returns:
        A one-element list holding a DefaultInfo whose files are the ops
        directory followed by the declared `outs`.
    """

    # The whole ATen/ops/ directory is a single (tree) output of the action.
    per_operator_dir = ctx.actions.declare_directory("aten/src/ATen/ops")
    generated = [per_operator_dir] + ctx.outputs.outs

    generator_args = [
        "--source-path",
        "aten/src/ATen",
        "--per-operator-headers",
        "--install_dir",
        # Install into the directory that contains ops/.
        paths.dirname(per_operator_dir.path),
    ]

    ctx.actions.run(
        mnemonic = "GenerateAten",
        executable = ctx.executable.generator,
        arguments = generator_args,
        inputs = ctx.files.srcs,
        outputs = generated,
        use_default_shell_env = True,
    )
    return [DefaultInfo(files = depset(generated))]
80*da0073e9SAndroid Build Coastguard Worker
# Attribute schema for the generate_aten rule.
_GENERATE_ATEN_ATTRS = {
    # Executable run by the action; built for the execution platform.
    "generator": attr.label(
        executable = True,
        allow_files = True,
        mandatory = True,
        cfg = "exec",
    ),
    # Files the action declares as outputs, in addition to the ops directory.
    "outs": attr.output_list(),
    # Files made available to the generator as action inputs.
    "srcs": attr.label_list(allow_files = True),
}

# Rule that runs `generator` over `srcs` to produce `outs` and the
# aten/src/ATen/ops directory (see generate_aten_impl).
generate_aten = rule(
    implementation = generate_aten_impl,
    attrs = _GENERATE_ATEN_ATTRS,
)
94