1load("@bazel_skylib//lib:paths.bzl", "paths")
2load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
3load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test")
4load("@rules_python//python:defs.bzl", "py_library", "py_test")
5load("@pytorch//third_party:substitution.bzl", "header_template_rule", "template_rule")
6load("@pytorch//:tools/bazel.bzl", "rules")
7load("@pytorch//tools/rules:cu.bzl", "cu_library")
8load("@pytorch//tools/config:defs.bzl", "if_cuda")
9load("@pytorch//:aten.bzl", "generate_aten", "intern_build_aten_ops")
10load(":build.bzl", "GENERATED_AUTOGRAD_CPP", "GENERATED_AUTOGRAD_PYTHON", "define_targets")
11load(":build_variables.bzl", "jit_core_sources", "lazy_tensor_ts_sources", "libtorch_core_sources", "libtorch_cuda_sources", "libtorch_distributed_sources", "libtorch_extra_sources", "libtorch_python_core_sources", "torch_cpp_srcs", "libtorch_python_cuda_sources", "libtorch_python_distributed_sources")
12load(":ufunc_defs.bzl", "aten_ufunc_generated_cpu_kernel_sources", "aten_ufunc_generated_cpu_sources", "aten_ufunc_generated_cuda_sources")
13load("//:tools/bazel.bzl", "rules")

define_targets(rules = rules)

COMMON_COPTS = [
    "-DHAVE_MALLOC_USABLE_SIZE=1",
    "-DHAVE_MMAP=1",
    "-DHAVE_SHM_OPEN=1",
    "-DHAVE_SHM_UNLINK=1",
    "-D_FILE_OFFSET_BITS=64",
    "-DUSE_FBGEMM",
    "-DUSE_DISTRIBUTED",
    "-DAT_PER_OPERATOR_HEADERS",
    "-DATEN_THREADING=NATIVE",
    "-DNO_CUDNN_DESTROY_HANDLE",
] + if_cuda([
    "-DUSE_CUDA",
    "-DUSE_CUDNN",
    # TODO: This should be passed only when building for CUDA 11.5 or newer,
    # so that cub is used in a safe manner; see
    # https://github.com/pytorch/pytorch/pull/55292
34    "-DCUB_WRAPPED_NAMESPACE=at_cuda_detail",
35])
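
# `if_cuda(a, b)` comes from @pytorch//tools/config:defs.bzl and picks the CUDA or
# non-CUDA branch via select(). A minimal sketch of such a macro, assuming a
# config_setting named ":cuda" in that package (the real definition may differ):
#
#   def if_cuda(if_true, if_false = []):
#       return select({
#           "@pytorch//tools/config:cuda": if_true,
#           "//conditions:default": if_false,
#       })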

aten_generation_srcs = [
    "aten/src/ATen/native/native_functions.yaml",
    "aten/src/ATen/native/tags.yaml",
] + glob(["aten/src/ATen/templates/**"])

generated_cpu_cpp = [
    "aten/src/ATen/RegisterBackendSelect.cpp",
    "aten/src/ATen/RegisterCPU.cpp",
    "aten/src/ATen/RegisterFunctionalization_0.cpp",
    "aten/src/ATen/RegisterFunctionalization_1.cpp",
    "aten/src/ATen/RegisterFunctionalization_2.cpp",
    "aten/src/ATen/RegisterFunctionalization_3.cpp",
    # "aten/src/ATen/RegisterFunctionalizationEverything.cpp",
    "aten/src/ATen/RegisterMkldnnCPU.cpp",
    "aten/src/ATen/RegisterNestedTensorCPU.cpp",
    "aten/src/ATen/RegisterQuantizedCPU.cpp",
    "aten/src/ATen/RegisterSparseCPU.cpp",
    "aten/src/ATen/RegisterSparseCsrCPU.cpp",
    "aten/src/ATen/RegisterZeroTensor.cpp",
    "aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
    "aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp",
    "aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
    "aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp",
    "aten/src/ATen/RegisterMeta.cpp",
    "aten/src/ATen/RegisterSparseMeta.cpp",
    "aten/src/ATen/RegisterQuantizedMeta.cpp",
    "aten/src/ATen/RegisterNestedTensorMeta.cpp",
    "aten/src/ATen/RegisterSchema.cpp",
    "aten/src/ATen/CPUFunctions.h",
    "aten/src/ATen/CPUFunctions_inl.h",
    "aten/src/ATen/CompositeExplicitAutogradFunctions.h",
    "aten/src/ATen/CompositeExplicitAutogradFunctions_inl.h",
    "aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h",
    "aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
    "aten/src/ATen/CompositeImplicitAutogradFunctions.h",
    "aten/src/ATen/CompositeImplicitAutogradFunctions_inl.h",
    "aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions.h",
    "aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h",
    "aten/src/ATen/CompositeViewCopyKernels.cpp",
    "aten/src/ATen/FunctionalInverses.h",
    "aten/src/ATen/Functions.h",
    "aten/src/ATen/Functions.cpp",
    "aten/src/ATen/RedispatchFunctions.h",
    "aten/src/ATen/Operators.h",
    "aten/src/ATen/Operators_0.cpp",
    "aten/src/ATen/Operators_1.cpp",
    "aten/src/ATen/Operators_2.cpp",
    "aten/src/ATen/Operators_3.cpp",
    "aten/src/ATen/Operators_4.cpp",
    "aten/src/ATen/NativeFunctions.h",
    "aten/src/ATen/MetaFunctions.h",
    "aten/src/ATen/MetaFunctions_inl.h",
    "aten/src/ATen/MethodOperators.h",
    "aten/src/ATen/NativeMetaFunctions.h",
    "aten/src/ATen/RegistrationDeclarations.h",
    "aten/src/ATen/VmapGeneratedPlumbing.h",
    "aten/src/ATen/core/aten_interned_strings.h",
    "aten/src/ATen/core/enum_tag.h",
    "aten/src/ATen/core/TensorBody.h",
    "aten/src/ATen/core/TensorMethods.cpp",
    "aten/src/ATen/core/ATenOpList.cpp",
]

generated_cuda_cpp = [
    "aten/src/ATen/CUDAFunctions.h",
    "aten/src/ATen/CUDAFunctions_inl.h",
    "aten/src/ATen/RegisterCUDA.cpp",
    "aten/src/ATen/RegisterNestedTensorCUDA.cpp",
    "aten/src/ATen/RegisterQuantizedCUDA.cpp",
    "aten/src/ATen/RegisterSparseCUDA.cpp",
    "aten/src/ATen/RegisterSparseCsrCUDA.cpp",
]

generate_aten(
    name = "generated_aten_cpp",
    srcs = aten_generation_srcs,
    outs = (
        generated_cpu_cpp +
        generated_cuda_cpp +
        aten_ufunc_generated_cpu_sources("aten/src/ATen/{}") +
        aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}") +
        aten_ufunc_generated_cuda_sources("aten/src/ATen/{}") + [
            "aten/src/ATen/Declarations.yaml",
        ]
    ),
    generator = "//torchgen:gen",
)
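
# generate_aten (from aten.bzl) runs the //torchgen:gen code generator over the YAML and
# template inputs above and declares every file listed in `outs` as an output of this
# rule, so downstream targets can refer to the generated .cpp/.h files by label.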

filegroup(
    name = "cpp_generated_code",
    srcs = GENERATED_AUTOGRAD_CPP,
    data = [":generate-code"],
)

# ATen
filegroup(
    name = "aten_base_cpp",
    srcs = glob([
        "aten/src/ATen/*.cpp",
        "aten/src/ATen/functorch/*.cpp",
        "aten/src/ATen/detail/*.cpp",
        "aten/src/ATen/cpu/*.cpp",
    ]),
)

filegroup(
    name = "ATen_CORE_SRCS",
    srcs = glob(
        [
            "aten/src/ATen/core/**/*.cpp",
        ],
        exclude = [
            "aten/src/ATen/core/**/*_test.cpp",
        ],
    ),
)

filegroup(
    name = "aten_native_cpp",
    srcs = glob(["aten/src/ATen/native/*.cpp"]),
)

filegroup(
    name = "aten_native_sparse_cpp",
    srcs = glob(["aten/src/ATen/native/sparse/*.cpp"]),
)

filegroup(
    name = "aten_native_nested_cpp",
    srcs = glob(["aten/src/ATen/native/nested/*.cpp"]),
)

filegroup(
    name = "aten_native_quantized_cpp",
    srcs = glob(
        [
            "aten/src/ATen/native/quantized/*.cpp",
            "aten/src/ATen/native/quantized/cpu/*.cpp",
        ],
    ),
)

filegroup(
    name = "aten_native_transformers_cpp",
    srcs = glob(["aten/src/ATen/native/transformers/*.cpp"]),
)

filegroup(
    name = "aten_native_mkl_cpp",
    srcs = glob([
        "aten/src/ATen/native/mkl/*.cpp",
        "aten/src/ATen/mkl/*.cpp",
    ]),
)

filegroup(
    name = "aten_native_mkldnn_cpp",
    srcs = glob(["aten/src/ATen/native/mkldnn/*.cpp"]),
)

filegroup(
    name = "aten_native_xnnpack",
    srcs = glob(["aten/src/ATen/native/xnnpack/*.cpp"]),
)

filegroup(
    name = "aten_base_vulkan",
    srcs = glob(["aten/src/ATen/vulkan/*.cpp"]),
)

filegroup(
    name = "aten_base_metal",
    srcs = glob(["aten/src/ATen/metal/*.cpp"]),
)

filegroup(
    name = "ATen_QUANTIZED_SRCS",
    srcs = glob(
        [
            "aten/src/ATen/quantized/**/*.cpp",
        ],
        exclude = [
            "aten/src/ATen/quantized/**/*_test.cpp",
        ],
    ),
)

filegroup(
    name = "aten_cuda_cpp_srcs",
    srcs = glob(
        [
            "aten/src/ATen/cuda/*.cpp",
            "aten/src/ATen/cuda/detail/*.cpp",
            "aten/src/ATen/cuda/tunable/*.cpp",
            "aten/src/ATen/cudnn/*.cpp",
            "aten/src/ATen/native/cuda/*.cpp",
            "aten/src/ATen/native/cuda/linalg/*.cpp",
            "aten/src/ATen/native/cudnn/*.cpp",
            "aten/src/ATen/native/miopen/*.cpp",
            "aten/src/ATen/native/nested/cuda/*.cpp",
            "aten/src/ATen/native/quantized/cuda/*.cpp",
            "aten/src/ATen/native/quantized/cudnn/*.cpp",
            "aten/src/ATen/native/sparse/cuda/*.cpp",
            "aten/src/ATen/native/transformers/cuda/*.cpp",
        ],
    ),
)

filegroup(
    name = "aten_cu_srcs",
    srcs = glob([
        "aten/src/ATen/cuda/*.cu",
        "aten/src/ATen/cuda/detail/*.cu",
        "aten/src/ATen/native/cuda/*.cu",
        "aten/src/ATen/native/nested/cuda/*.cu",
        "aten/src/ATen/native/quantized/cuda/*.cu",
        "aten/src/ATen/native/sparse/cuda/*.cu",
        "aten/src/ATen/native/transformers/cuda/*.cu",
    ]) + aten_ufunc_generated_cuda_sources("aten/src/ATen/{}"),
    # Note: the rule that generates these sources does not need to be listed
    # explicitly; the files are declared as outputs of :generated_aten_cpp in
    # this package, so Bazel resolves the file labels to that generating rule.
)

header_template_rule(
    name = "aten_src_ATen_config",
    src = "aten/src/ATen/Config.h.in",
    out = "aten/src/ATen/Config.h",
    include = "aten/src",
    substitutions = {
        "@AT_MKLDNN_ENABLED@": "1",
        "@AT_MKLDNN_ACL_ENABLED@": "0",
        "@AT_MKL_ENABLED@": "1",
        "@AT_MKL_SEQUENTIAL@": "0",
        "@AT_POCKETFFT_ENABLED@": "0",
        "@AT_NNPACK_ENABLED@": "0",
        "@CAFFE2_STATIC_LINK_CUDA_INT@": "0",
        "@AT_BUILD_WITH_BLAS@": "1",
        "@AT_BUILD_WITH_LAPACK@": "1",
        "@AT_PARALLEL_OPENMP@": "0",
        "@AT_PARALLEL_NATIVE@": "1",
        "@AT_BLAS_F2C@": "0",
        "@AT_BLAS_USE_CBLAS_DOT@": "1",
    },
)

header_template_rule(
    name = "aten_src_ATen_cuda_config",
    src = "aten/src/ATen/cuda/CUDAConfig.h.in",
    out = "aten/src/ATen/cuda/CUDAConfig.h",
    include = "aten/src",
    substitutions = {
        "@AT_CUDNN_ENABLED@": "1",
        "@AT_CUSPARSELT_ENABLED@": "0",
        "@AT_ROCM_ENABLED@": "0",
        "@AT_MAGMA_ENABLED@": "0",
        "@NVCC_FLAGS_EXTRA@": "",
    },
)

cc_library(
    name = "aten_headers",
    hdrs = [
        "torch/csrc/Export.h",
        "torch/csrc/jit/frontend/function_schema_parser.h",
    ] + glob(
        [
            "aten/src/**/*.h",
            "aten/src/**/*.hpp",
            "aten/src/ATen/cuda/**/*.cuh",
            "aten/src/ATen/native/**/*.cuh",
            "aten/src/THC/*.cuh",
        ],
    ) + [
        ":aten_src_ATen_config",
        ":generated_aten_cpp",
    ],
    includes = [
        "aten/src",
    ],
    deps = [
        "//c10",
    ],
)

ATEN_COPTS = COMMON_COPTS + [
    "-DCAFFE2_BUILD_MAIN_LIBS",
    "-DHAVE_AVX_CPU_DEFINITION",
    "-DHAVE_AVX2_CPU_DEFINITION",
    "-fvisibility-inlines-hidden",
    "-fno-math-errno",
    "-fno-trapping-math",
]

intern_build_aten_ops(
    copts = ATEN_COPTS,
    extra_impls = aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}"),
    deps = [
        ":aten_headers",
        "@fbgemm",
        "@mkl",
        "@sleef",
        "@mkl_dnn//:mkl-dnn",
    ],
)
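
# intern_build_aten_ops (from aten.bzl) compiles the vectorized ATen kernels once per CPU
# capability (roughly, a baseline build plus AVX/AVX2 variants built with extra copts on
# top of ATEN_COPTS) and wraps the results in the :ATen_CPU target that :aten depends on
# below.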

cc_library(
    name = "aten",
    srcs = [
        ":ATen_CORE_SRCS",
        ":ATen_QUANTIZED_SRCS",
        ":aten_base_cpp",
        ":aten_base_metal",
        ":aten_base_vulkan",
        ":aten_native_cpp",
        ":aten_native_mkl_cpp",
        ":aten_native_mkldnn_cpp",
        ":aten_native_nested_cpp",
        ":aten_native_quantized_cpp",
        ":aten_native_sparse_cpp",
        ":aten_native_transformers_cpp",
        ":aten_native_xnnpack",
        ":aten_src_ATen_config",
    ] + generated_cpu_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}"),
    copts = ATEN_COPTS,
    linkopts = [
        "-ldl",
    ],
    data = if_cuda(
        [":libcaffe2_nvrtc.so"],
        [],
    ),
    visibility = ["//visibility:public"],
    deps = [
        ":ATen_CPU",
        ":aten_headers",
        ":caffe2_for_aten_headers",
        ":torch_headers",
        "@fbgemm",
        "@ideep",
    ],
    alwayslink = True,
)

cc_library(
    name = "aten_nvrtc",
    srcs = glob([
        "aten/src/ATen/cuda/nvrtc_stub/*.cpp",
    ]),
    copts = ATEN_COPTS,
    linkstatic = True,
    visibility = ["//visibility:public"],
    deps = [
        ":aten_headers",
        "//c10",
        "@cuda",
        "@cuda//:cuda_driver",
        "@cuda//:nvrtc",
    ],
    alwayslink = True,
)

cc_binary(
    name = "libcaffe2_nvrtc.so",
    linkshared = True,
    visibility = ["//visibility:public"],
    deps = [
        ":aten_nvrtc",
    ],
)

cc_library(
    name = "aten_cuda_cpp",
    srcs = [":aten_cuda_cpp_srcs"] + generated_cuda_cpp,
    hdrs = [":aten_src_ATen_cuda_config"],
    copts = ATEN_COPTS,
    visibility = ["//visibility:public"],
    deps = [
        ":aten",
        "@cuda",
        "@cuda//:cusolver",
        "@cuda//:nvrtc",
        "@cudnn",
        "@cudnn_frontend",
    ],
    alwayslink = True,
)

torch_cuda_half_options = [
    "-DCUDA_HAS_FP16=1",
    "-D__CUDA_NO_HALF_OPERATORS__",
    "-D__CUDA_NO_HALF_CONVERSIONS__",
    "-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "-D__CUDA_NO_HALF2_OPERATORS__",
]

cu_library(
    name = "aten_cuda",
    srcs = [":aten_cu_srcs"],
    copts = ATEN_COPTS + torch_cuda_half_options,
    visibility = ["//visibility:public"],
    deps = [
        ":aten_cuda_cpp",
        "//c10/util:bit_cast",
        "@cuda//:cublas",
        "@cuda//:cufft",
        "@cuda//:cusparse",
        "@cutlass",
    ],
    alwayslink = True,
)

# caffe2
CAFFE2_COPTS = COMMON_COPTS + [
    "-Dcaffe2_EXPORTS",
    "-DCAFFE2_USE_CUDNN",
    "-DCAFFE2_BUILD_MAIN_LIB",
    "-fvisibility-inlines-hidden",
    "-fno-math-errno",
    "-fno-trapping-math",
]

filegroup(
    name = "caffe2_core_srcs",
    srcs = [
        "caffe2/core/common.cc",
    ],
)

filegroup(
    name = "caffe2_perfkernels_srcs",
    srcs = [
        "caffe2/perfkernels/embedding_lookup_idx.cc",
    ],
)

filegroup(
    name = "caffe2_serialize_srcs",
    srcs = [
        "caffe2/serialize/file_adapter.cc",
        "caffe2/serialize/inline_container.cc",
        "caffe2/serialize/istream_adapter.cc",
        "caffe2/serialize/read_adapter_interface.cc",
    ],
)

filegroup(
    name = "caffe2_utils_srcs",
    srcs = [
        "caffe2/utils/proto_wrap.cc",
        "caffe2/utils/string_utils.cc",
        "caffe2/utils/threadpool/ThreadPool.cc",
        "caffe2/utils/threadpool/pthreadpool.cc",
        "caffe2/utils/threadpool/pthreadpool_impl.cc",
        "caffe2/utils/threadpool/thread_pool_guard.cpp",
    ],
)

# To achieve finer granularity and make debugging easier, caffe2 is split into three
# libraries: ATen, caffe2, and caffe2_for_aten_headers. The ATen library groups the
# sources under the aten/ directory, and caffe2 contains most files under the caffe2/
# directory. Since the ATen and caffe2 libraries would otherwise depend on each other,
# caffe2_for_aten_headers is split out of caffe2 to avoid a dependency cycle.
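# Concretely (from the deps below): caffe2 -> aten -> caffe2_for_aten_headers, and
# caffe2 -> caffe2_headers -> caffe2_for_aten_headers, so only the shared header
# target sits on both sides.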
cc_library(
    name = "caffe2_for_aten_headers",
    hdrs = [
        "caffe2/core/common.h",
        "caffe2/perfkernels/common.h",
        "caffe2/perfkernels/embedding_lookup_idx.h",
        "caffe2/utils/fixed_divisor.h",
    ] + glob([
        "caffe2/utils/threadpool/*.h",
    ]),
    copts = CAFFE2_COPTS,
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2_core_macros",
        "//c10",
    ],
)

cc_library(
    name = "caffe2_headers",
    hdrs = glob(
        [
            "caffe2/perfkernels/*.h",
            "caffe2/serialize/*.h",
            "caffe2/utils/*.h",
            "caffe2/utils/threadpool/*.h",
            "modules/**/*.h",
        ],
        exclude = [
            "caffe2/core/macros.h",
        ],
    ) + if_cuda(glob([
        "caffe2/**/*.cuh",
    ])),
    copts = CAFFE2_COPTS,
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2_core_macros",
        ":caffe2_for_aten_headers",
    ],
)

cc_library(
    name = "caffe2",
    srcs = [
        ":caffe2_core_srcs",
        ":caffe2_perfkernels_srcs",
        ":caffe2_serialize_srcs",
        ":caffe2_utils_srcs",
    ],
    copts = CAFFE2_COPTS + ["-mf16c"],
    linkstatic = 1,
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2_core_macros",
        ":caffe2_headers",
        ":caffe2_perfkernels_avx",
        ":caffe2_perfkernels_avx2",
        "//third_party/miniz-2.1.0:miniz",
        "@com_google_protobuf//:protobuf",
        "@eigen",
        "@fbgemm//:fbgemm_src_headers",
        "@fmt",
        "@onnx",
    ] + if_cuda(
        [
            ":aten_cuda",
            "@tensorpipe//:tensorpipe_cuda",
        ],
        [
            ":aten",
            "@tensorpipe//:tensorpipe_cpu",
        ],
    ),
    alwayslink = True,
)

cu_library(
    name = "torch_cuda",
    srcs = [
        "torch/csrc/distributed/c10d/intra_node_comm.cu",
        "torch/csrc/distributed/c10d/NanCheck.cu",
        "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
    ],
    copts = torch_cuda_half_options,
    visibility = ["//visibility:public"],
    deps = [
        ":aten",
        "@cuda//:cublas",
        "@cuda//:curand",
        "@cudnn",
        "@eigen",
        "@tensorpipe//:tensorpipe_cuda",
    ],
    alwayslink = True,
)

PERF_COPTS = [
    "-DHAVE_AVX_CPU_DEFINITION",
    "-DHAVE_AVX2_CPU_DEFINITION",
    "-DENABLE_ALIAS=1",
    "-DHAVE_MALLOC_USABLE_SIZE=1",
    "-DHAVE_MMAP=1",
    "-DHAVE_SHM_OPEN=1",
    "-DHAVE_SHM_UNLINK=1",
    "-DSLEEF_STATIC_LIBS=1",
603    "-DTH_BALS_MKL",
604    "-D_FILE_OFFSET_BITS=64",
605    "-DUSE_FBGEMM",
606    "-fvisibility-inlines-hidden",
607    "-Wunused-parameter",
608    "-fno-math-errno",
609    "-fno-trapping-math",
610    "-mf16c",
611]
612
613PERF_HEADERS = glob([
614    "caffe2/perfkernels/*.h",
615    "caffe2/core/*.h",
616])
617
618cc_library(
619    name = "caffe2_perfkernels_avx",
620    srcs = glob([
621        "caffe2/perfkernels/*_avx.cc",
622    ]),
623    hdrs = PERF_HEADERS,
624    copts = PERF_COPTS + [
625        "-mavx",
626    ],
627    visibility = ["//visibility:public"],
628    deps = [
629        ":caffe2_headers",
630        "//c10",
631    ],
632    alwayslink = True,
633)
634
635cc_library(
636    name = "caffe2_perfkernels_avx2",
637    srcs = glob([
638        "caffe2/perfkernels/*_avx2.cc",
639    ]),
640    hdrs = PERF_HEADERS,
641    copts = PERF_COPTS + [
642        "-mavx2",
643        "-mfma",
644        "-mavx",
645    ],
646    visibility = ["//visibility:public"],
647    deps = [
648        ":caffe2_headers",
649        "//c10",
650    ],
651    alwayslink = True,
652)
653
654# torch
655torch_cuda_headers = glob(["torch/csrc/cuda/*.h"])
656
657cc_library(
658    name = "torch_headers",
659    hdrs = if_cuda(
660        torch_cuda_headers,
661    ) + glob(
662        [
663            "torch/*.h",
664            "torch/csrc/**/*.h",
665            "torch/csrc/distributed/c10d/**/*.hpp",
666            "torch/lib/libshm/*.h",
667        ],
668        exclude = [
669            "torch/csrc/*/generated/*.h",
670        ] + torch_cuda_headers,
671    ) + GENERATED_AUTOGRAD_CPP + [":version_h"],
672    includes = [
673        "third_party/kineto/libkineto/include",
674        "torch/csrc",
675        "torch/csrc/api/include",
676        "torch/csrc/distributed",
677        "torch/lib",
678        "torch/lib/libshm",
679    ],
680    visibility = ["//visibility:public"],
681    deps = [
682        ":aten_headers",
683        ":caffe2_headers",
684        "//c10",
685        "@com_github_google_flatbuffers//:flatbuffers",
686        "@local_config_python//:python_headers",
687        "@onnx",
688    ],
689    alwayslink = True,
690)
691
692TORCH_COPTS = COMMON_COPTS + [
693    "-Dtorch_EXPORTS",
694    "-DHAVE_AVX_CPU_DEFINITION",
695    "-DHAVE_AVX2_CPU_DEFINITION",
696    "-DCAFFE2_USE_GLOO",
697    "-fvisibility-inlines-hidden",
698    "-fno-math-errno ",
699    "-fno-trapping-math",
700    "-Wno-error=unused-function",
701]
702
703torch_sources = {
704    k: ""
705    for k in (
706        libtorch_core_sources +
707        libtorch_distributed_sources +
708        torch_cpp_srcs +
709        libtorch_extra_sources +
710        jit_core_sources +
711        lazy_tensor_ts_sources +
712        GENERATED_AUTOGRAD_CPP
713    )
714}.keys()
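
# The dict-comprehension/.keys() round trip above deduplicates the combined source
# lists (Starlark dicts keep insertion order), guarding against files that appear in
# more than one of the .bzl lists.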

cc_library(
    name = "torch",
    srcs = if_cuda(glob(
        libtorch_cuda_sources,
        exclude = [
            "torch/csrc/cuda/python_nccl.cpp",
            "torch/csrc/cuda/nccl.cpp",
            "torch/csrc/distributed/c10d/intra_node_comm.cu",
            "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu",
            "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu",
            "torch/csrc/distributed/c10d/NanCheck.cu",
            "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
        ],
    )) + torch_sources,
    copts = TORCH_COPTS,
    linkopts = [
        "-lrt",
    ],
    defines = [
        "CAFFE2_NIGHTLY_VERSION=20200115",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2",
        ":torch_headers",
        "@kineto",
        "@cpp-httplib",
        "@nlohmann",
    ] + if_cuda([
        "@cuda//:nvToolsExt",
        "@cutlass",
        ":torch_cuda",
    ]),
    alwayslink = True,
)

cc_library(
    name = "shm",
    srcs = glob(["torch/lib/libshm/*.cpp"]),
    linkopts = [
        "-lrt",
    ],
    deps = [
        ":torch",
    ],
)

cc_library(
    name = "libtorch_headers",
    hdrs = glob([
        "**/*.h",
        "**/*.cuh",
    ]) + [
        # We need the filegroup here because the raw list causes Bazel
        # to see duplicate files. It knows how to deduplicate with the
        # filegroup.
        ":cpp_generated_code",
    ],
    includes = [
        "torch/csrc/api/include",
        "torch/csrc/distributed",
        "torch/lib",
        "torch/lib/libshm",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":torch_headers",
    ],
)

cc_library(
    name = "torch_python",
    srcs = libtorch_python_core_sources
        + if_cuda(libtorch_python_cuda_sources)
        + if_cuda(libtorch_python_distributed_sources)
        + GENERATED_AUTOGRAD_PYTHON,
    hdrs = glob([
        "torch/csrc/generic/*.cpp",
    ]),
    copts = COMMON_COPTS + if_cuda(["-DUSE_CUDA=1"]),
    deps = [
        ":torch",
        ":shm",
        "@pybind11",
    ],
)

pybind_extension(
    name = "torch/_C",
    srcs = ["torch/csrc/stub.c"],
    deps = [
        ":torch_python",
        ":aten_nvrtc",
    ],
)

cc_library(
    name = "functorch",
    hdrs = glob([
        "functorch/csrc/dim/*.h",
    ]),
    srcs = glob([
        "functorch/csrc/dim/*.cpp",
    ]),
    deps = [
        ":aten_nvrtc",
        ":torch_python",
        "@pybind11",
    ],
)

pybind_extension(
    name = "functorch/_C",
    copts = [
        "-DTORCH_EXTENSION_NAME=_C",
    ],
    srcs = [
        "functorch/csrc/init_dim_only.cpp",
    ],
    deps = [
        ":functorch",
        ":torch_python",
        ":aten_nvrtc",
    ],
)

cc_binary(
    name = "torch/bin/torch_shm_manager",
    srcs = [
        "torch/lib/libshm/manager.cpp",
    ],
    deps = [
        ":shm",
    ],
    linkstatic = False,
)

template_rule(
    name = "gen_version_py",
    src = ":torch/version.py.tpl",
    out = "torch/version.py",
    substitutions = if_cuda({
        # Set default to 11.2. Otherwise Torchvision complains about incompatibility.
        "{{CUDA_VERSION}}": "11.2",
        "{{VERSION}}": "2.0.0",
    }, {
        "{{CUDA_VERSION}}": "None",
        "{{VERSION}}": "2.0.0",
    }),
)
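
# With the substitutions above, the generated torch/version.py is expected to report
# version "2.0.0" and, in a CUDA build, a CUDA version of "11.2"; the exact file
# contents come from torch/version.py.tpl, which is not shown here.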

py_library(
    name = "pytorch_py",
    visibility = ["//visibility:public"],
    srcs = glob(["torch/**/*.py"], exclude = ["torch/version.py"]) + [":torch/version.py"] + glob(["functorch/**/*.py"]),
    deps = [
        rules.requirement("numpy"),
        rules.requirement("pyyaml"),
        rules.requirement("requests"),
        rules.requirement("setuptools"),
        rules.requirement("sympy"),
        rules.requirement("typing_extensions"),
        "//torchgen",
    ],
    data = [
        ":torch/_C.so",
        ":functorch/_C.so",
        ":torch/bin/torch_shm_manager",
    ],
)

# cpp api tests
cc_library(
    name = "test_support",
    testonly = True,
    srcs = [
        "test/cpp/api/support.cpp",
    ],
    hdrs = [
        "test/cpp/api/init_baseline.h",
        "test/cpp/api/optim_baseline.h",
        "test/cpp/api/support.h",
        "test/cpp/common/support.h",
    ],
    deps = [
        ":torch",
        "@com_google_googletest//:gtest_main",
    ],
)

# Torch integration tests rely on a labeled data set from the MNIST database.
# http://yann.lecun.com/exdb/mnist/

cpp_api_tests = glob(
    ["test/cpp/api/*.cpp"],
    exclude = [
        "test/cpp/api/imethod.cpp",
        "test/cpp/api/integration.cpp",
    ],
)

cc_test(
    name = "integration_test",
    size = "medium",
    srcs = ["test/cpp/api/integration.cpp"],
    data = [
        ":download_mnist",
    ],
    tags = [
        "gpu-required",
    ],
    deps = [
        ":test_support",
        "@com_google_googletest//:gtest_main",
    ],
)

[
    cc_test(
        name = paths.split_extension(paths.basename(filename))[0].replace("-", "_") + "_test",
        size = "medium",
        srcs = [filename],
        deps = [
            ":test_support",
            "@com_google_googletest//:gtest_main",
        ],
    )
    for filename in cpp_api_tests
]
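
# The comprehension above derives one cc_test per file: a hypothetical
# "test/cpp/api/nn-utils.cpp" would build as ":nn_utils_test" (basename with the
# extension stripped and dashes replaced by underscores).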

test_suite(
    name = "api_tests",
    tests = [
        "any_test",
        "autograd_test",
        "dataloader_test",
        "enum_test",
        "expanding_array_test",
        "functional_test",
        "init_test",
        "integration_test",
        "jit_test",
        "memory_test",
        "misc_test",
        "module_test",
        "modulelist_test",
        "modules_test",
        "nn_utils_test",
        "optim_test",
        "ordered_dict_test",
        "rnn_test",
        "sequential_test",
        "serialize_test",
        "static_test",
        "tensor_options_test",
        "tensor_test",
        "torch_include_test",
    ],
)

# dist autograd tests
cc_test(
    name = "torch_dist_autograd_test",
    size = "small",
    srcs = ["test/cpp/dist_autograd/test_dist_autograd.cpp"],
    tags = [
        "exclusive",
        "gpu-required",
    ],
    deps = [
        ":torch",
        "@com_google_googletest//:gtest_main",
    ],
)

# jit tests
# Because these individual unit tests require custom registration,
# it is easier to mimic the CMake build by globbing them together into a single test.
cc_test(
    name = "jit_tests",
    size = "small",
    srcs = glob(
        [
            "test/cpp/jit/*.cpp",
            "test/cpp/jit/*.h",
            "test/cpp/tensorexpr/*.cpp",
            "test/cpp/tensorexpr/*.h",
        ],
        exclude = [
            # Skip this one since <pybind11/embed.h> is not found in the OSS build.
1006            "test/cpp/jit/test_exception.cpp",
1007        ],
1008    ),
1009    linkstatic = True,
1010    tags = [
1011        "exclusive",
1012        "gpu-required",
1013    ],
1014    deps = [
1015        ":torch",
1016        "@com_google_googletest//:gtest_main",
1017    ],
1018)
1019
1020cc_test(
1021    name = "lazy_tests",
1022    size = "small",
1023    srcs = glob(
1024        [
1025            "test/cpp/lazy/*.cpp",
1026            "test/cpp/lazy/*.h",
1027        ],
1028        exclude = [
            # Skip these since they depend on the generated LazyIr.h, which isn't
            # available in the Bazel build yet.
1030            "test/cpp/lazy/test_ir.cpp",
1031            "test/cpp/lazy/test_lazy_ops.cpp",
1032            "test/cpp/lazy/test_lazy_ops_util.cpp",
1033        ],
1034    ),
1035    linkstatic = True,
1036    tags = [
1037        "exclusive",
1038    ],
1039    deps = [
1040        ":torch",
1041        "@com_google_googletest//:gtest_main",
1042    ],
1043)

# python api tests

py_test(
    name = "test_bazel",
    srcs = ["test/_test_bazel.py"],
    main = "test/_test_bazel.py",
    deps = [":pytorch_py"],
)

# all tests
test_suite(
    name = "all_tests",
    tests = [
        "api_tests",
        "jit_tests",
        "torch_dist_autograd_test",
        "//c10/test:tests",
    ],
)

# An internal genrule that we are converging with refers to these files
# as if they were from this package, so we alias them for compatibility.

[
    alias(
        name = paths.basename(path),
        actual = path,
    )
    for path in [
        "aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp",
        "aten/src/ATen/templates/DispatchKeyNativeFunctions.h",
        "aten/src/ATen/templates/LazyIr.h",
        "aten/src/ATen/templates/LazyNonNativeIr.h",
        "aten/src/ATen/templates/RegisterDispatchKey.cpp",
        "aten/src/ATen/templates/RegisterDispatchDefinitions.ini",
        "aten/src/ATen/native/native_functions.yaml",
        "aten/src/ATen/native/tags.yaml",
        "aten/src/ATen/native/ts_native_functions.yaml",
        "torch/csrc/lazy/core/shape_inference.h",
        "torch/csrc/lazy/ts_backend/ts_native_functions.cpp",
    ]
]
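
# For example, this makes ":native_functions.yaml" an alias for
# "aten/src/ATen/native/native_functions.yaml" within this package.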

genrule(
    name = "download_mnist",
    srcs = ["//:tools/download_mnist.py"],
    outs = [
        "mnist/train-images-idx3-ubyte",
        "mnist/train-labels-idx1-ubyte",
        "mnist/t10k-images-idx3-ubyte",
        "mnist/t10k-labels-idx1-ubyte",
    ],
    cmd = "python3 tools/download_mnist.py -d $(RULEDIR)/mnist",
)