1"""Repository rule for CUDA autoconfiguration.
2
3`cuda_configure` depends on the following environment variables:
4
5  * `TF_NEED_CUDA`: Whether to enable building with CUDA.
6  * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path
7  * `TF_CUDA_CLANG`: Whether to use clang as a cuda compiler.
8  * `CLANG_CUDA_COMPILER_PATH`: The clang compiler path that will be used for
9    both host and device code compilation if TF_CUDA_CLANG is 1.
10  * `TF_SYSROOT`: The sysroot to use when compiling.
11  * `TF_DOWNLOAD_CLANG`: Whether to download a recent release of clang
12    compiler and use it to build tensorflow. When this option is set
13    CLANG_CUDA_COMPILER_PATH is ignored.
14  * `TF_CUDA_PATHS`: The base paths to look for CUDA and cuDNN. Default is
15    `/usr/local/cuda,usr/`.
16  * `CUDA_TOOLKIT_PATH` (deprecated): The path to the CUDA toolkit. Default is
17    `/usr/local/cuda`.
18  * `TF_CUDA_VERSION`: The version of the CUDA toolkit. If this is blank, then
19    use the system default.
20  * `TF_CUDNN_VERSION`: The version of the cuDNN library.
21  * `CUDNN_INSTALL_PATH` (deprecated): The path to the cuDNN library. Default is
22    `/usr/local/cuda`.
23  * `TF_CUDA_COMPUTE_CAPABILITIES`: The CUDA compute capabilities. Default is
24    `3.5,5.2`.
25  * `PYTHON_BIN_PATH`: The python binary path
26"""

load("//third_party/clang_toolchain:download_clang.bzl", "download_clang")
load(
    "@bazel_tools//tools/cpp:lib_cc_configure.bzl",
    "escape_string",
    "get_env_var",
)
load(
    "@bazel_tools//tools/cpp:windows_cc_configure.bzl",
    "find_msvc_tool",
    "find_vc_path",
    "setup_vc_env_vars",
)
load(
    "//third_party/remote_config:common.bzl",
    "config_repo_label",
    "err_out",
    "execute",
    "get_bash_bin",
    "get_cpu_value",
    "get_host_environ",
    "get_python_bin",
    "is_windows",
    "raw_exec",
    "read_dir",
    "realpath",
    "which",
)

_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
_GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX"
_CLANG_CUDA_COMPILER_PATH = "CLANG_CUDA_COMPILER_PATH"
_TF_SYSROOT = "TF_SYSROOT"
_CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH"
_TF_CUDA_VERSION = "TF_CUDA_VERSION"
_TF_CUDNN_VERSION = "TF_CUDNN_VERSION"
_CUDNN_INSTALL_PATH = "CUDNN_INSTALL_PATH"
_TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES"
_TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO"
_TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG"
_PYTHON_BIN_PATH = "PYTHON_BIN_PATH"

def to_list_of_strings(elements):
    """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'.

    This is to be used to put a list of strings into the bzl file templates
    so it gets interpreted as list of strings in Starlark.

    Args:
      elements: list of string elements

    Returns:
      single string of elements wrapped in quotes separated by a comma."""
    quoted_strings = ["\"" + element + "\"" for element in elements]
    return ", ".join(quoted_strings)
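
# Example (illustrative, not executed by the rule): to_list_of_strings(["a", "b"])
# returns the single string '"a", "b"'. Substituted into a template line such as
#     cxx_builtin_include_directories = [%{cxx_builtin_include_directories}]
# it yields
#     cxx_builtin_include_directories = ["a", "b"]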

def verify_build_defines(params):
    """Verify all variables that crosstool/BUILD.tpl expects are substituted.

    Args:
      params: dict of variables that will be passed to the BUILD.tpl template.
    """
    missing = []
    for param in [
        "cxx_builtin_include_directories",
        "extra_no_canonical_prefixes_flags",
        "host_compiler_path",
        "host_compiler_prefix",
        "host_compiler_warnings",
        "linker_bin_path",
        "compiler_deps",
        "msvc_cl_path",
        "msvc_env_include",
        "msvc_env_lib",
        "msvc_env_path",
        "msvc_env_tmp",
        "msvc_lib_path",
        "msvc_link_path",
        "msvc_ml_path",
        "unfiltered_compile_flags",
        "win_compiler_deps",
    ]:
        if ("%{" + param + "}") not in params:
            missing.append(param)

    if missing:
        auto_configure_fail(
            "BUILD.tpl template is missing these variables: " +
            str(missing) +
            ".\nWe only got: " +
            str(params) +
            ".",
        )

def _get_nvcc_tmp_dir_for_windows(repository_ctx):
    """Return the Windows tmp directory for nvcc to generate intermediate source files."""
    escaped_tmp_dir = escape_string(
        get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace(
            "\\",
            "\\\\",
        ),
    )
    return escaped_tmp_dir + "\\\\nvcc_inter_files_tmp_dir"

def _get_msvc_compiler(repository_ctx):
    vc_path = find_vc_path(repository_ctx)
    return find_msvc_tool(repository_ctx, vc_path, "cl.exe").replace("\\", "/")

def _get_win_cuda_defines(repository_ctx):
    """Return CROSSTOOL defines for Windows"""

    # If we are not on Windows, return fake values for Windows-specific fields.
    # This ensures the CROSSTOOL file parser is happy.
    if not is_windows(repository_ctx):
        return {
            "%{msvc_env_tmp}": "msvc_not_used",
            "%{msvc_env_path}": "msvc_not_used",
            "%{msvc_env_include}": "msvc_not_used",
            "%{msvc_env_lib}": "msvc_not_used",
            "%{msvc_cl_path}": "msvc_not_used",
            "%{msvc_ml_path}": "msvc_not_used",
            "%{msvc_link_path}": "msvc_not_used",
            "%{msvc_lib_path}": "msvc_not_used",
        }

    vc_path = find_vc_path(repository_ctx)
    if not vc_path:
        auto_configure_fail(
155            "Visual C++ build tools not found on your machine." +
156            "Please check your installation following https://docs.bazel.build/versions/master/windows.html#using",
157        )
158        return {}
159
160    env = setup_vc_env_vars(repository_ctx, vc_path)
161    escaped_paths = escape_string(env["PATH"])
162    escaped_include_paths = escape_string(env["INCLUDE"])
163    escaped_lib_paths = escape_string(env["LIB"])
164    escaped_tmp_dir = escape_string(
165        get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace(
166            "\\",
167            "\\\\",
168        ),
169    )
170
171    msvc_cl_path = get_python_bin(repository_ctx)
172    msvc_ml_path = find_msvc_tool(repository_ctx, vc_path, "ml64.exe").replace(
173        "\\",
174        "/",
175    )
176    msvc_link_path = find_msvc_tool(repository_ctx, vc_path, "link.exe").replace(
177        "\\",
178        "/",
179    )
180    msvc_lib_path = find_msvc_tool(repository_ctx, vc_path, "lib.exe").replace(
181        "\\",
182        "/",
183    )
184
    # nvcc will generate some temporary source files under %{nvcc_tmp_dir}.
    # The generated files are guaranteed to have unique names, so they can share
    # the same tmp directory.
    escaped_cxx_include_directories = [
        _get_nvcc_tmp_dir_for_windows(repository_ctx),
        "C:\\\\botcode\\\\w",
    ]
    for path in escaped_include_paths.split(";"):
        if path:
            escaped_cxx_include_directories.append(path)

    return {
        "%{msvc_env_tmp}": escaped_tmp_dir,
        "%{msvc_env_path}": escaped_paths,
        "%{msvc_env_include}": escaped_include_paths,
        "%{msvc_env_lib}": escaped_lib_paths,
        "%{msvc_cl_path}": msvc_cl_path,
        "%{msvc_ml_path}": msvc_ml_path,
        "%{msvc_link_path}": msvc_link_path,
        "%{msvc_lib_path}": msvc_lib_path,
        "%{cxx_builtin_include_directories}": to_list_of_strings(
            escaped_cxx_include_directories,
        ),
    }

# TODO(dzc): Once these functions have been factored out of Bazel's
# cc_configure.bzl, load them from @bazel_tools instead.
# BEGIN cc_configure common functions.
def find_cc(repository_ctx):
    """Find the C++ compiler."""
    if is_windows(repository_ctx):
        return _get_msvc_compiler(repository_ctx)

    if _use_cuda_clang(repository_ctx):
        target_cc_name = "clang"
        cc_path_envvar = _CLANG_CUDA_COMPILER_PATH
        if _flag_enabled(repository_ctx, _TF_DOWNLOAD_CLANG):
            return "extra_tools/bin/clang"
    else:
        target_cc_name = "gcc"
        cc_path_envvar = _GCC_HOST_COMPILER_PATH
    cc_name = target_cc_name

    cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar)
    if cc_name_from_env:
        cc_name = cc_name_from_env
    if cc_name.startswith("/"):
        # Absolute path, maybe we should make this supported by our which function.
        return cc_name
    cc = which(repository_ctx, cc_name)
    if cc == None:
        fail(("Cannot find {}, either correct your path or set the {}" +
              " environment variable").format(target_cc_name, cc_path_envvar))
    return cc

_INC_DIR_MARKER_BEGIN = "#include <...>"

# OSX adds " (framework directory)" to the end of the line; strip it.
_OSX_FRAMEWORK_SUFFIX = " (framework directory)"
_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX)

def _cxx_inc_convert(path):
247    """Convert path returned by cc -E xc++ in a complete path."""
    path = path.strip()
    if path.endswith(_OSX_FRAMEWORK_SUFFIX):
        path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip()
    return path

def _normalize_include_path(repository_ctx, path):
    """Normalizes include paths before writing them to the crosstool.

      If path points inside the 'crosstool' folder of the repository, a relative
      path is returned.
      If path points outside the 'crosstool' folder, an absolute path is returned.
      """
    path = str(repository_ctx.path(path))
    crosstool_folder = str(repository_ctx.path(".").get_child("crosstool"))

    if path.startswith(crosstool_folder):
        # We drop the path to "$REPO/crosstool" and a trailing path separator.
        return path[len(crosstool_folder) + 1:]
    return path

def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sysroot):
    """Compute the list of default C or C++ include directories."""
    if lang_is_cpp:
        lang = "c++"
    else:
        lang = "c"
    sysroot = []
    if tf_sysroot:
        sysroot += ["--sysroot", tf_sysroot]
    result = raw_exec(repository_ctx, [cc, "-E", "-x" + lang, "-", "-v"] +
                                      sysroot)
    stderr = err_out(result)
    index1 = stderr.find(_INC_DIR_MARKER_BEGIN)
    if index1 == -1:
        return []
    index1 = stderr.find("\n", index1)
    if index1 == -1:
        return []
    index2 = stderr.rfind("\n ")
    if index2 == -1 or index2 < index1:
        return []
    index2 = stderr.find("\n", index2 + 1)
    if index2 == -1:
        inc_dirs = stderr[index1 + 1:]
    else:
        inc_dirs = stderr[index1 + 1:index2].strip()

    return [
        _normalize_include_path(repository_ctx, _cxx_inc_convert(p))
        for p in inc_dirs.split("\n")
    ]
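
# Illustrative sketch of what the parser above consumes: running the host
# compiler as `cc -E -xc++ - -v` typically prints, on stderr, a block like
# (paths here are hypothetical):
#     #include <...> search starts here:
#      /usr/lib/gcc/x86_64-linux-gnu/9/include
#      /usr/local/include
#      /usr/include
#     End of search list.
# _get_cxx_inc_directories_impl extracts the indented lines between the marker
# and "End of search list." and normalizes each one with
# _normalize_include_path(), yielding e.g.
#     ["/usr/lib/gcc/x86_64-linux-gnu/9/include", "/usr/local/include", "/usr/include"]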

def get_cxx_inc_directories(repository_ctx, cc, tf_sysroot):
    """Compute the list of default C and C++ include directories."""

    # For some reason `clang -xc` sometimes returns include paths that are
    # different from the ones from `clang -xc++`. (Symlink and a dir)
    # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists
    includes_cpp = _get_cxx_inc_directories_impl(
        repository_ctx,
        cc,
        True,
        tf_sysroot,
    )
    includes_c = _get_cxx_inc_directories_impl(
        repository_ctx,
        cc,
        False,
        tf_sysroot,
    )

    return includes_cpp + [
        inc
        for inc in includes_c
        if inc not in includes_cpp
    ]

def auto_configure_fail(msg):
    """Output failure message when cuda configuration fails."""
    red = "\033[0;31m"
    no_color = "\033[0m"
    fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg))

# END cc_configure common functions (see TODO above).

def _cuda_include_path(repository_ctx, cuda_config):
334    """Generates the Starlark string with cuda include directories.
335
336      Args:
337        repository_ctx: The repository context.
338        cc: The path to the gcc host compiler.
339
340      Returns:
341        A list of the gcc host compiler include directories.
342      """
    nvcc_path = repository_ctx.path("%s/bin/nvcc%s" % (
        cuda_config.cuda_toolkit_path,
        ".exe" if cuda_config.cpu_value == "Windows" else "",
    ))

    # The expected exit code of this command is non-zero. Bazel remote execution
    # only caches commands with zero exit code. So force a zero exit code.
    cmd = "%s -v /dev/null -o /dev/null ; [ $? -eq 1 ]" % str(nvcc_path)
    result = raw_exec(repository_ctx, [get_bash_bin(repository_ctx), "-c", cmd])
    target_dir = ""
    for one_line in err_out(result).splitlines():
        if one_line.startswith("#$ _TARGET_DIR_="):
            target_dir = (
                cuda_config.cuda_toolkit_path + "/" + one_line.replace(
                    "#$ _TARGET_DIR_=",
                    "",
                ) + "/include"
            )
    inc_entries = []
    if target_dir != "":
        inc_entries.append(realpath(repository_ctx, target_dir))
    inc_entries.append(realpath(repository_ctx, cuda_config.cuda_toolkit_path + "/include"))
    return inc_entries

def enable_cuda(repository_ctx):
    """Returns whether to build with CUDA support."""
    return int(get_host_environ(repository_ctx, "TF_NEED_CUDA", False))

def matches_version(environ_version, detected_version):
    """Checks whether the user-specified version matches the detected version.

      This function performs a weak matching so that if the user specifies only
      the major or major and minor versions, the versions are still considered
      matching if the version parts match. To illustrate:

          environ_version  detected_version  result
          -----------------------------------------
          5.1.3            5.1.3             True
          5.1              5.1.3             True
          5                5.1               True
          5.1.3            5.1               False
          5.2.3            5.1.3             False

      Args:
        environ_version: The version specified by the user via environment
          variables.
        detected_version: The version autodetected from the CUDA installation on
          the system.
      Returns: True if user-specified version matches detected version and False
        otherwise.
    """
    environ_version_parts = environ_version.split(".")
    detected_version_parts = detected_version.split(".")
    if len(detected_version_parts) < len(environ_version_parts):
        return False
    for i, part in enumerate(detected_version_parts):
        if i >= len(environ_version_parts):
            break
        if part != environ_version_parts[i]:
            return False
    return True

_NVCC_VERSION_PREFIX = "Cuda compilation tools, release "

_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"

def compute_capabilities(repository_ctx):
    """Returns a list of strings representing cuda compute capabilities.

    Args:
      repository_ctx: the repo rule's context.
    Returns: list of cuda architectures to compile for. 'compute_xy' refers to
      both PTX and SASS, 'sm_xy' refers to SASS only.
    """
    capabilities = get_host_environ(
        repository_ctx,
        _TF_CUDA_COMPUTE_CAPABILITIES,
        "compute_35,compute_52",
    ).split(",")

    # Map old 'x.y' capabilities to 'compute_xy'.
    if len(capabilities) > 0 and all([len(x.split(".")) == 2 for x in capabilities]):
        # If all capabilities are in 'x.y' format, only include PTX for the
        # highest capability.
        cc_list = sorted([x.replace(".", "") for x in capabilities])
        capabilities = ["sm_%s" % x for x in cc_list[:-1]] + ["compute_%s" % cc_list[-1]]
    for i, capability in enumerate(capabilities):
        parts = capability.split(".")
        if len(parts) != 2:
            continue
        capabilities[i] = "compute_%s%s" % (parts[0], parts[1])

    # Make list unique
    capabilities = dict(zip(capabilities, capabilities)).keys()

    # Validate capabilities.
    for capability in capabilities:
        if not capability.startswith(("compute_", "sm_")):
            auto_configure_fail("Invalid compute capability: %s" % capability)
        for prefix in ["compute_", "sm_"]:
            if not capability.startswith(prefix):
                continue
            if len(capability) == len(prefix) + 2 and capability[-2:].isdigit():
                continue
            auto_configure_fail("Invalid compute capability: %s" % capability)

    return capabilities
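
# Example (illustrative): with TF_CUDA_COMPUTE_CAPABILITIES="3.5,5.2,7.0" the
# 'x.y' entries are mapped so that only the highest capability keeps PTX:
#     ["sm_35", "sm_52", "compute_70"]
# An already-explicit list such as "sm_52,compute_70" is passed through
# unchanged (after deduplication and validation).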

def lib_name(base_name, cpu_value, version = None, static = False):
    """Constructs the platform-specific name of a library.

      Args:
        base_name: The name of the library, such as "cudart"
        cpu_value: The name of the host operating system.
        version: The version of the library.
        static: True if the library is static, False if it is a shared object.

      Returns:
        The platform-specific name of the library.
      """
    version = "" if not version else "." + version
    if cpu_value in ("Linux", "FreeBSD"):
        if static:
            return "lib%s.a" % base_name
        return "lib%s.so%s" % (base_name, version)
    elif cpu_value == "Windows":
        return "%s.lib" % base_name
    elif cpu_value == "Darwin":
        if static:
            return "lib%s.a" % base_name
        return "lib%s%s.dylib" % (base_name, version)
    else:
        auto_configure_fail("Invalid cpu_value: %s" % cpu_value)
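
# Example (illustrative): lib_name("cudart", "Linux", "11.0") returns
# "libcudart.so.11.0", lib_name("cudart", "Darwin", "11.0") returns
# "libcudart.11.0.dylib", lib_name("cudart", "Windows") returns "cudart.lib",
# and lib_name("cudart_static", "Linux", static = True) returns
# "libcudart_static.a".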

def _lib_path(lib, cpu_value, basedir, version, static):
    file_name = lib_name(lib, cpu_value, version, static)
    return "%s/%s" % (basedir, file_name)

def _should_check_soname(version, static):
    return version and not static

def _check_cuda_lib_params(lib, cpu_value, basedir, version, static = False):
    return (
        _lib_path(lib, cpu_value, basedir, version, static),
        _should_check_soname(version, static),
    )

def _check_cuda_libs(repository_ctx, script_path, libs):
    python_bin = get_python_bin(repository_ctx)
    contents = repository_ctx.read(script_path).splitlines()

    cmd = "from os import linesep;"
    cmd += "f = open('script.py', 'w');"
    for line in contents:
        cmd += "f.write('%s' + linesep);" % line
    cmd += "f.close();"
    cmd += "from os import system;"
    args = " ".join(["\"" + path + "\" " + str(check) for path, check in libs])
    cmd += "system('%s script.py %s');" % (python_bin, args)

    all_paths = [path for path, _ in libs]
    checked_paths = execute(repository_ctx, [python_bin, "-c", cmd]).stdout.splitlines()

    # Filter out empty lines from splitting on '\r\n' on Windows
    checked_paths = [path for path in checked_paths if len(path) > 0]
    if all_paths != checked_paths:
        auto_configure_fail("Error with installed CUDA libs. Expected '%s'. Actual '%s'." % (all_paths, checked_paths))
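
# Sketch of what the command above expands to (paths are hypothetical): the
# generated Python one-liner first writes check_cuda_libs.py to 'script.py' and
# then runs something like
#     python script.py "/usr/local/cuda/lib64/libcudart.so.11.0" True "/usr/local/cuda/lib64/libcudart_static.a" False
# where the boolean after each path says whether that library's soname should
# be checked. The script is expected to print back each path it verified, so
# any mismatch against all_paths is turned into a configuration error.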

def _find_libs(repository_ctx, check_cuda_libs_script, cuda_config):
    """Returns the CUDA and cuDNN libraries on the system.

      Also verifies that the libraries actually exist on the system.

      Args:
        repository_ctx: The repository context.
        check_cuda_libs_script: The path to a script verifying that the cuda
          libraries exist on the system.
        cuda_config: The CUDA config as returned by _get_cuda_config

      Returns:
        Map of library names to structs of filename and path.
      """
    cpu_value = cuda_config.cpu_value
    stub_dir = "" if is_windows(repository_ctx) else "/stubs"

    check_cuda_libs_params = {
        "cuda": _check_cuda_lib_params(
            "cuda",
            cpu_value,
            cuda_config.config["cuda_library_dir"] + stub_dir,
            version = None,
            static = False,
        ),
        "cudart": _check_cuda_lib_params(
            "cudart",
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cudart_version,
            static = False,
        ),
        "cudart_static": _check_cuda_lib_params(
            "cudart_static",
            cpu_value,
            cuda_config.config["cuda_library_dir"],
            cuda_config.cudart_version,
            static = True,
        ),
        "cublas": _check_cuda_lib_params(
            "cublas",
            cpu_value,
            cuda_config.config["cublas_library_dir"],
            cuda_config.cublas_version,
            static = False,
        ),
        "cublasLt": _check_cuda_lib_params(
            "cublasLt",
            cpu_value,
            cuda_config.config["cublas_library_dir"],
            cuda_config.cublas_version,
            static = False,
        ),
        "cusolver": _check_cuda_lib_params(
            "cusolver",
            cpu_value,
            cuda_config.config["cusolver_library_dir"],
            cuda_config.cusolver_version,
            static = False,
        ),
        "curand": _check_cuda_lib_params(
            "curand",
            cpu_value,
            cuda_config.config["curand_library_dir"],
            cuda_config.curand_version,
            static = False,
        ),
        "cufft": _check_cuda_lib_params(
            "cufft",
            cpu_value,
            cuda_config.config["cufft_library_dir"],
            cuda_config.cufft_version,
            static = False,
        ),
        "cudnn": _check_cuda_lib_params(
            "cudnn",
            cpu_value,
            cuda_config.config["cudnn_library_dir"],
            cuda_config.cudnn_version,
            static = False,
        ),
        "cupti": _check_cuda_lib_params(
            "cupti",
            cpu_value,
            cuda_config.config["cupti_library_dir"],
            cuda_config.cuda_version,
            static = False,
        ),
        "cusparse": _check_cuda_lib_params(
            "cusparse",
            cpu_value,
            cuda_config.config["cusparse_library_dir"],
            cuda_config.cusparse_version,
            static = False,
        ),
    }

    # Verify that the libs actually exist at their locations.
    _check_cuda_libs(repository_ctx, check_cuda_libs_script, check_cuda_libs_params.values())

    paths = {filename: v[0] for (filename, v) in check_cuda_libs_params.items()}
    return paths

def _cudart_static_linkopt(cpu_value):
    """Returns additional platform-specific linkopts for cudart."""
    return "" if cpu_value == "Darwin" else "\"-lrt\","

def _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries):
    python_bin = get_python_bin(repository_ctx)

    # If used with remote execution then repository_ctx.execute() can't
    # access files from the source tree. A trick is to read the contents
    # of the file in Starlark and embed them as part of the command. In
    # this case the trick is not sufficient as the find_cuda_config.py
    # script has more than 8192 characters. 8192 is the command length
    # limit of cmd.exe on Windows. Thus we additionally need to compress
    # the contents locally and decompress them as part of the execute().
    compressed_contents = repository_ctx.read(script_path)
    decompress_and_execute_cmd = (
        "from zlib import decompress;" +
        "from base64 import b64decode;" +
        "from os import system;" +
        "script = decompress(b64decode('%s'));" % compressed_contents +
        "f = open('script.py', 'wb');" +
        "f.write(script);" +
        "f.close();" +
        "system('\"%s\" script.py %s');" % (python_bin, " ".join(cuda_libraries))
    )

    return execute(repository_ctx, [python_bin, "-c", decompress_and_execute_cmd])

# TODO(csigg): Only call once instead of from here, tensorrt_configure.bzl,
# and nccl_configure.bzl.
def find_cuda_config(repository_ctx, script_path, cuda_libraries):
    """Returns CUDA config dictionary from running find_cuda_config.py"""
    exec_result = _exec_find_cuda_config(repository_ctx, script_path, cuda_libraries)
    if exec_result.return_code:
        auto_configure_fail("Failed to run find_cuda_config.py: %s" % err_out(exec_result))

    # Parse the dict from stdout.
    return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()])
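
# Illustrative output: find_cuda_config.py prints one "key: value" pair per
# line, e.g. (values here are hypothetical)
#     cuda_version: 11.2
#     cuda_toolkit_path: /usr/local/cuda-11.2
#     cudnn_version: 8
# which the comprehension above turns into
#     {"cuda_version": "11.2", "cuda_toolkit_path": "/usr/local/cuda-11.2", "cudnn_version": "8"}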

def _get_cuda_config(repository_ctx, find_cuda_config_script):
656    """Detects and returns information about the CUDA installation on the system.
657
658      Args:
659        repository_ctx: The repository context.
660
661      Returns:
662        A struct containing the following fields:
663          cuda_toolkit_path: The CUDA toolkit installation directory.
664          cudnn_install_basedir: The cuDNN installation directory.
665          cuda_version: The version of CUDA on the system.
666          cudart_version: The CUDA runtime version on the system.
667          cudnn_version: The version of cuDNN on the system.
668          compute_capabilities: A list of the system's CUDA compute capabilities.
669          cpu_value: The name of the host operating system.
670      """
    config = find_cuda_config(repository_ctx, find_cuda_config_script, ["cuda", "cudnn"])
    cpu_value = get_cpu_value(repository_ctx)
    toolkit_path = config["cuda_toolkit_path"]

    is_windows = cpu_value == "Windows"
    cuda_version = config["cuda_version"].split(".")
    cuda_major = cuda_version[0]
    cuda_minor = cuda_version[1]

    cuda_version = ("64_%s%s" if is_windows else "%s.%s") % (cuda_major, cuda_minor)
    cudnn_version = ("64_%s" if is_windows else "%s") % config["cudnn_version"]

    if int(cuda_major) >= 11:
        # The libcudart soname in CUDA 11.x is versioned as 11.0 for backward compatibility.
        if int(cuda_major) == 11:
            cudart_version = "64_110" if is_windows else "11.0"
        else:
            cudart_version = ("64_%s" if is_windows else "%s") % cuda_major
        cublas_version = ("64_%s" if is_windows else "%s") % config["cublas_version"].split(".")[0]
        cusolver_version = ("64_%s" if is_windows else "%s") % config["cusolver_version"].split(".")[0]
        curand_version = ("64_%s" if is_windows else "%s") % config["curand_version"].split(".")[0]
        cufft_version = ("64_%s" if is_windows else "%s") % config["cufft_version"].split(".")[0]
        cusparse_version = ("64_%s" if is_windows else "%s") % config["cusparse_version"].split(".")[0]
    elif (int(cuda_major), int(cuda_minor)) >= (10, 1):
        # cuda_lib_version is for libraries like cuBLAS, cuFFT, cuSOLVER, etc.
        # It changed from 'x.y' to just 'x' in CUDA 10.1.
        cuda_lib_version = ("64_%s" if is_windows else "%s") % cuda_major
        cudart_version = cuda_version
        cublas_version = cuda_lib_version
        cusolver_version = cuda_lib_version
        curand_version = cuda_lib_version
        cufft_version = cuda_lib_version
        cusparse_version = cuda_lib_version
    else:
        cudart_version = cuda_version
        cublas_version = cuda_version
        cusolver_version = cuda_version
        curand_version = cuda_version
        cufft_version = cuda_version
        cusparse_version = cuda_version

    return struct(
        cuda_toolkit_path = toolkit_path,
        cuda_version = cuda_version,
        cuda_version_major = cuda_major,
        cudart_version = cudart_version,
        cublas_version = cublas_version,
        cusolver_version = cusolver_version,
        curand_version = curand_version,
        cufft_version = cufft_version,
        cusparse_version = cusparse_version,
        cudnn_version = cudnn_version,
        compute_capabilities = compute_capabilities(repository_ctx),
        cpu_value = cpu_value,
        config = config,
    )
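
# Illustrative result (assuming find_cuda_config reports CUDA 11.2, cuBLAS 11.x
# and cuDNN 8 on Linux): cuda_version = "11.2", cudart_version = "11.0" (the
# libcudart soname stays at 11.0 across 11.x), cublas_version = "11" and
# cudnn_version = "8". On Windows the same setup yields "64_112", "64_110",
# "64_11" and "64_8", matching the DLL naming scheme.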

def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    if not out:
        out = tpl.replace(":", "/")
    repository_ctx.template(
        out,
        Label("//third_party/gpus/%s.tpl" % tpl),
        substitutions,
    )

def _file(repository_ctx, label):
    repository_ctx.template(
        label.replace(":", "/"),
        Label("//third_party/gpus/%s.tpl" % label),
        {},
    )

_DUMMY_CROSSTOOL_BZL_FILE = """
def error_gpu_disabled():
  fail("ERROR: Building with --config=cuda but TensorFlow is not configured " +
       "to build with GPU support. Please re-run ./configure and enter 'Y' " +
       "at the prompt to build with GPU support.")

  native.genrule(
      name = "error_gen_crosstool",
      outs = ["CROSSTOOL"],
      cmd = "echo 'Should not be run.' && exit 1",
  )

  native.filegroup(
      name = "crosstool",
      srcs = [":CROSSTOOL"],
      output_licenses = ["unencumbered"],
  )
"""

_DUMMY_CROSSTOOL_BUILD_FILE = """
load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled")

error_gpu_disabled()
"""

def _create_dummy_repository(repository_ctx):
    cpu_value = get_cpu_value(repository_ctx)

    # Set up BUILD file for cuda/.
    _tpl(
        repository_ctx,
        "cuda:build_defs.bzl",
        {
            "%{cuda_is_configured}": "False",
            "%{cuda_extra_copts}": "[]",
            "%{cuda_gpu_architectures}": "[]",
        },
    )
    _tpl(
        repository_ctx,
        "cuda:BUILD",
        {
            "%{cuda_driver_lib}": lib_name("cuda", cpu_value),
            "%{cudart_static_lib}": lib_name(
                "cudart_static",
                cpu_value,
                static = True,
            ),
            "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
            "%{cudart_lib}": lib_name("cudart", cpu_value),
            "%{cublas_lib}": lib_name("cublas", cpu_value),
            "%{cublasLt_lib}": lib_name("cublasLt", cpu_value),
            "%{cusolver_lib}": lib_name("cusolver", cpu_value),
            "%{cudnn_lib}": lib_name("cudnn", cpu_value),
            "%{cufft_lib}": lib_name("cufft", cpu_value),
            "%{curand_lib}": lib_name("curand", cpu_value),
            "%{cupti_lib}": lib_name("cupti", cpu_value),
            "%{cusparse_lib}": lib_name("cusparse", cpu_value),
            "%{cub_actual}": ":cuda_headers",
            "%{copy_rules}": """
filegroup(name="cuda-include")
filegroup(name="cublas-include")
filegroup(name="cusolver-include")
filegroup(name="cufft-include")
filegroup(name="cusparse-include")
filegroup(name="curand-include")
filegroup(name="cudnn-include")
""",
        },
    )

    # Create dummy files for the CUDA toolkit since they are still required by
    # tensorflow/tsl/platform/default/build_config:cuda.
    repository_ctx.file("cuda/cuda/include/cuda.h")
    repository_ctx.file("cuda/cuda/include/cublas.h")
    repository_ctx.file("cuda/cuda/include/cudnn.h")
    repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h")
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cuda", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudart", cpu_value))
    repository_ctx.file(
        "cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value),
    )
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublasLt", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("curand", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cufft", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cupti", cpu_value))
    repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusparse", cpu_value))

    # Set up cuda_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    _tpl(
        repository_ctx,
        "cuda:cuda_config.h",
        {
            "%{cuda_version}": "",
            "%{cudart_version}": "",
            "%{cublas_version}": "",
            "%{cusolver_version}": "",
            "%{curand_version}": "",
            "%{cufft_version}": "",
            "%{cusparse_version}": "",
            "%{cudnn_version}": "",
            "%{cuda_toolkit_path}": "",
            "%{cuda_compute_capabilities}": "",
        },
        "cuda/cuda/cuda_config.h",
    )

    # Set up cuda_config.py, which is used by gen_build_info to provide
    # static build environment info to the API
    _tpl(
        repository_ctx,
        "cuda:cuda_config.py",
        _py_tmpl_dict({}),
        "cuda/cuda/cuda_config.py",
    )

    # If cuda_configure is not configured to build with GPU support, and the user
    # attempts to build with --config=cuda, add a dummy build rule to intercept
    # this and fail with an actionable error message.
    repository_ctx.file(
        "crosstool/error_gpu_disabled.bzl",
        _DUMMY_CROSSTOOL_BZL_FILE,
    )
    repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)

def _norm_path(path):
874    """Returns a path with '/' and remove the trailing slash."""
    path = path.replace("\\", "/")
    if path[-1] == "/":
        path = path[:-1]
    return path

def make_copy_files_rule(repository_ctx, name, srcs, outs):
    """Returns a rule to copy a set of files."""
    cmds = []

    # Copy files.
    for src, out in zip(srcs, outs):
        cmds.append('cp -f "%s" "$(location %s)"' % (src, out))
    outs = [('        "%s",' % out) for out in outs]
    return """genrule(
    name = "%s",
    outs = [
%s
    ],
    cmd = \"""%s \""",
)""" % (name, "\n".join(outs), " && \\\n".join(cmds))

def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir, exceptions = None):
    """Returns a rule to recursively copy a directory.
    If exceptions is not None, it must be a list of files or directories in
    'src_dir'; these will be excluded from copying.
    """
    src_dir = _norm_path(src_dir)
    out_dir = _norm_path(out_dir)
    outs = read_dir(repository_ctx, src_dir)
    post_cmd = ""
    if exceptions != None:
        outs = [x for x in outs if not any([
            x.startswith(src_dir + "/" + y)
            for y in exceptions
        ])]
    outs = [('        "%s",' % out.replace(src_dir, out_dir)) for out in outs]

    # '@D' already contains the relative path for a single file, see
    # http://docs.bazel.build/versions/master/be/make-variables.html#predefined_genrule_variables
    out_dir = "$(@D)/%s" % out_dir if len(outs) > 1 else "$(@D)"
    if exceptions != None:
        for x in exceptions:
            post_cmd += " ; rm -fR " + out_dir + "/" + x
    return """genrule(
    name = "%s",
    outs = [
%s
    ],
    cmd = \"""cp -rLf "%s/." "%s/" %s\""",
)""" % (name, "\n".join(outs), src_dir, out_dir, post_cmd)

def _flag_enabled(repository_ctx, flag_name):
    return get_host_environ(repository_ctx, flag_name) == "1"

def _use_cuda_clang(repository_ctx):
    return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")

def _tf_sysroot(repository_ctx):
    return get_host_environ(repository_ctx, _TF_SYSROOT, "")

def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
    copts = []
    for capability in compute_capabilities:
        if capability.startswith("compute_"):
            capability = capability.replace("compute_", "sm_")
            copts.append("--cuda-include-ptx=%s" % capability)
        copts.append("--cuda-gpu-arch=%s" % capability)

    return str(copts)
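
# Example (illustrative): _compute_cuda_extra_copts(ctx, ["sm_52", "compute_70"])
# returns the string
#     ["--cuda-gpu-arch=sm_52", "--cuda-include-ptx=sm_70", "--cuda-gpu-arch=sm_70"]
# i.e. SASS is emitted for every listed architecture, and PTX only for the
# 'compute_' entries; the string form is what gets substituted into
# %{cuda_extra_copts} in cuda/build_defs.bzl.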

def _tpl_path(repository_ctx, filename):
    return repository_ctx.path(Label("//third_party/gpus/%s.tpl" % filename))

def _basename(repository_ctx, path_str):
    """Returns the basename of a path of type string.

    This method is different from path.basename in that it also works if
    the host platform is different from the execution platform,
    e.g. linux -> windows.
    """

    num_chars = len(path_str)
    is_win = is_windows(repository_ctx)
    for i in range(num_chars):
        r_i = num_chars - 1 - i
        if (is_win and path_str[r_i] == "\\") or path_str[r_i] == "/":
            return path_str[r_i + 1:]
    return path_str
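
# Example (illustrative): _basename(repository_ctx, "/usr/local/cuda/lib64/libcudart.so")
# returns "libcudart.so"; on a Windows host,
# _basename(repository_ctx, "C:\\cuda\\bin\\nvcc.exe") returns "nvcc.exe",
# which is why this helper is used instead of path.basename when the host and
# execution platforms may differ.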

def _create_local_cuda_repository(repository_ctx):
    """Creates the repository containing files set up to build with CUDA."""

    # Resolve all labels before doing any real work. Resolving causes the
    # function to be restarted with all previous state being lost. This
    # can easily lead to a O(n^2) runtime in the number of labels.
    # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
    tpl_paths = {filename: _tpl_path(repository_ctx, filename) for filename in [
        "cuda:build_defs.bzl",
        "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc",
        "crosstool:windows/msvc_wrapper_for_nvcc.py",
        "crosstool:BUILD",
        "crosstool:cc_toolchain_config.bzl",
        "cuda:cuda_config.h",
        "cuda:cuda_config.py",
    ]}
    tpl_paths["cuda:BUILD"] = _tpl_path(repository_ctx, "cuda:BUILD.windows" if is_windows(repository_ctx) else "cuda:BUILD")
    find_cuda_config_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64"))

    cuda_config = _get_cuda_config(repository_ctx, find_cuda_config_script)

    cuda_include_path = cuda_config.config["cuda_include_dir"]
    cublas_include_path = cuda_config.config["cublas_include_dir"]
    cudnn_header_dir = cuda_config.config["cudnn_include_dir"]
    cupti_header_dir = cuda_config.config["cupti_include_dir"]
    nvvm_libdevice_dir = cuda_config.config["nvvm_library_dir"]

    # Create genrule to copy files from the installed CUDA toolkit into execroot.
    copy_rules = [
        make_copy_dir_rule(
            repository_ctx,
            name = "cuda-include",
            src_dir = cuda_include_path,
            out_dir = "cuda/include",
        ),
        make_copy_dir_rule(
            repository_ctx,
            name = "cuda-nvvm",
            src_dir = nvvm_libdevice_dir,
            out_dir = "cuda/nvvm/libdevice",
        ),
        make_copy_dir_rule(
            repository_ctx,
            name = "cuda-extras",
            src_dir = cupti_header_dir,
            out_dir = "cuda/extras/CUPTI/include",
        ),
    ]

    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cublas-include",
        srcs = [
            cublas_include_path + "/cublas.h",
            cublas_include_path + "/cublas_v2.h",
            cublas_include_path + "/cublas_api.h",
            cublas_include_path + "/cublasLt.h",
        ],
        outs = [
            "cublas/include/cublas.h",
            "cublas/include/cublas_v2.h",
            "cublas/include/cublas_api.h",
            "cublas/include/cublasLt.h",
        ],
    ))

    cusolver_include_path = cuda_config.config["cusolver_include_dir"]
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cusolver-include",
        srcs = [
            cusolver_include_path + "/cusolver_common.h",
            cusolver_include_path + "/cusolverDn.h",
        ],
        outs = [
            "cusolver/include/cusolver_common.h",
            "cusolver/include/cusolverDn.h",
        ],
    ))

    cufft_include_path = cuda_config.config["cufft_include_dir"]
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cufft-include",
        srcs = [
            cufft_include_path + "/cufft.h",
        ],
        outs = [
            "cufft/include/cufft.h",
        ],
    ))

    cusparse_include_path = cuda_config.config["cusparse_include_dir"]
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cusparse-include",
        srcs = [
            cusparse_include_path + "/cusparse.h",
        ],
        outs = [
            "cusparse/include/cusparse.h",
        ],
    ))

    curand_include_path = cuda_config.config["curand_include_dir"]
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "curand-include",
        srcs = [
            curand_include_path + "/curand.h",
        ],
        outs = [
            "curand/include/curand.h",
        ],
    ))

    check_cuda_libs_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:check_cuda_libs.py"))
    cuda_libs = _find_libs(repository_ctx, check_cuda_libs_script, cuda_config)
    cuda_lib_srcs = []
    cuda_lib_outs = []
    for path in cuda_libs.values():
        cuda_lib_srcs.append(path)
        cuda_lib_outs.append("cuda/lib/" + _basename(repository_ctx, path))
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cuda-lib",
        srcs = cuda_lib_srcs,
        outs = cuda_lib_outs,
    ))

    # copy files mentioned in third_party/nccl/build_defs.bzl.tpl
    file_ext = ".exe" if is_windows(repository_ctx) else ""
    bin_files = (
        ["crt/link.stub"] +
        [f + file_ext for f in ["bin2c", "fatbinary", "nvlink", "nvprune"]]
    )
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cuda-bin",
        srcs = [cuda_config.cuda_toolkit_path + "/bin/" + f for f in bin_files],
        outs = ["cuda/bin/" + f for f in bin_files],
    ))

    # Select the headers based on the cuDNN version (strip '64_' for Windows).
    cudnn_headers = ["cudnn.h"]
    if cuda_config.cudnn_version.rsplit("_", 1)[-1] >= "8":
        cudnn_headers += [
            "cudnn_backend.h",
            "cudnn_adv_infer.h",
            "cudnn_adv_train.h",
            "cudnn_cnn_infer.h",
            "cudnn_cnn_train.h",
            "cudnn_ops_infer.h",
            "cudnn_ops_train.h",
            "cudnn_version.h",
        ]

    cudnn_srcs = []
    cudnn_outs = []
    for header in cudnn_headers:
        cudnn_srcs.append(cudnn_header_dir + "/" + header)
        cudnn_outs.append("cudnn/include/" + header)

    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "cudnn-include",
        srcs = cudnn_srcs,
        outs = cudnn_outs,
    ))

    # Set up BUILD file for cuda/
    repository_ctx.template(
        "cuda/build_defs.bzl",
        tpl_paths["cuda:build_defs.bzl"],
        {
            "%{cuda_is_configured}": "True",
            "%{cuda_extra_copts}": _compute_cuda_extra_copts(
                repository_ctx,
                cuda_config.compute_capabilities,
            ),
            "%{cuda_gpu_architectures}": str(cuda_config.compute_capabilities),
        },
    )

    cub_actual = "@cub_archive//:cub"
    if int(cuda_config.cuda_version_major) >= 11:
        cub_actual = ":cuda_headers"

    repository_ctx.template(
        "cuda/BUILD",
        tpl_paths["cuda:BUILD"],
        {
            "%{cuda_driver_lib}": _basename(repository_ctx, cuda_libs["cuda"]),
            "%{cudart_static_lib}": _basename(repository_ctx, cuda_libs["cudart_static"]),
            "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value),
            "%{cudart_lib}": _basename(repository_ctx, cuda_libs["cudart"]),
            "%{cublas_lib}": _basename(repository_ctx, cuda_libs["cublas"]),
            "%{cublasLt_lib}": _basename(repository_ctx, cuda_libs["cublasLt"]),
            "%{cusolver_lib}": _basename(repository_ctx, cuda_libs["cusolver"]),
            "%{cudnn_lib}": _basename(repository_ctx, cuda_libs["cudnn"]),
            "%{cufft_lib}": _basename(repository_ctx, cuda_libs["cufft"]),
            "%{curand_lib}": _basename(repository_ctx, cuda_libs["curand"]),
            "%{cupti_lib}": _basename(repository_ctx, cuda_libs["cupti"]),
            "%{cusparse_lib}": _basename(repository_ctx, cuda_libs["cusparse"]),
            "%{cub_actual}": cub_actual,
            "%{copy_rules}": "\n".join(copy_rules),
        },
    )

    is_cuda_clang = _use_cuda_clang(repository_ctx)
    tf_sysroot = _tf_sysroot(repository_ctx)

    should_download_clang = is_cuda_clang and _flag_enabled(
        repository_ctx,
        _TF_DOWNLOAD_CLANG,
    )
    if should_download_clang:
        download_clang(repository_ctx, "crosstool/extra_tools")

    # Set up crosstool/
    cc = find_cc(repository_ctx)
    cc_fullpath = cc if not should_download_clang else "crosstool/" + cc

    host_compiler_includes = get_cxx_inc_directories(
        repository_ctx,
        cc_fullpath,
        tf_sysroot,
    )
    cuda_defines = {}
    cuda_defines["%{builtin_sysroot}"] = tf_sysroot
    cuda_defines["%{cuda_toolkit_path}"] = ""
    cuda_defines["%{compiler}"] = "unknown"
    if is_cuda_clang:
        cuda_defines["%{cuda_toolkit_path}"] = cuda_config.config["cuda_toolkit_path"]
        cuda_defines["%{compiler}"] = "clang"

    host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX)
    if not host_compiler_prefix:
        host_compiler_prefix = "/usr/bin"

    cuda_defines["%{host_compiler_prefix}"] = host_compiler_prefix

    # Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see
    # https://github.com/bazelbuild/bazel/issues/760).
    # However, this stops our custom clang toolchain from picking the provided
    # LLD linker, so we're only adding '-B/usr/bin' when using non-downloaded
    # toolchain.
    # TODO: when bazel stops adding '-B/usr/bin' by default, remove this
    #       flag from the CROSSTOOL completely (see
    #       https://github.com/bazelbuild/bazel/issues/5634)
    if should_download_clang:
        cuda_defines["%{linker_bin_path}"] = ""
    else:
        cuda_defines["%{linker_bin_path}"] = host_compiler_prefix

    cuda_defines["%{extra_no_canonical_prefixes_flags}"] = ""
    cuda_defines["%{unfiltered_compile_flags}"] = ""
    if is_cuda_clang:
        cuda_defines["%{host_compiler_path}"] = str(cc)
        cuda_defines["%{host_compiler_warnings}"] = """
        # Some parts of the codebase set -Werror and hit this warning, so
        # switch it off for now.
        "-Wno-invalid-partial-specialization"
    """
        cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(host_compiler_includes)
        cuda_defines["%{compiler_deps}"] = ":empty"
        cuda_defines["%{win_compiler_deps}"] = ":empty"
        repository_ctx.file(
            "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
            "",
        )
        repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.py", "")
    else:
        cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
        cuda_defines["%{host_compiler_warnings}"] = ""

        # nvcc has the system include paths built in and will automatically
        # search them; we cannot work around that, so we add the relevant cuda
        # system paths to the allowed compiler specific include paths.
        cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(
            host_compiler_includes + _cuda_include_path(
                repository_ctx,
                cuda_config,
            ) + [cupti_header_dir, cudnn_header_dir],
        )

        # For gcc, do not canonicalize system header paths; some versions of gcc
        # pick the shortest possible path for system includes when creating the
        # .d file - given that includes that are prefixed with "../" multiple
        # times quickly grow longer than the root of the tree, this can lead to
        # bazel's header check failing.
        cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\""

        file_ext = ".exe" if is_windows(repository_ctx) else ""
        nvcc_path = "%s/nvcc%s" % (cuda_config.config["cuda_binary_dir"], file_ext)
        cuda_defines["%{compiler_deps}"] = ":crosstool_wrapper_driver_is_not_gcc"
        cuda_defines["%{win_compiler_deps}"] = ":windows_msvc_wrapper_files"

        wrapper_defines = {
            "%{cpu_compiler}": str(cc),
            "%{cuda_version}": cuda_config.cuda_version,
            "%{nvcc_path}": nvcc_path,
            "%{gcc_host_compiler_path}": str(cc),
            "%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx),
        }
        repository_ctx.template(
            "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
            tpl_paths["crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc"],
            wrapper_defines,
        )
        repository_ctx.template(
            "crosstool/windows/msvc_wrapper_for_nvcc.py",
            tpl_paths["crosstool:windows/msvc_wrapper_for_nvcc.py"],
            wrapper_defines,
        )

    cuda_defines.update(_get_win_cuda_defines(repository_ctx))

    verify_build_defines(cuda_defines)

    # Only expand template variables in the BUILD file
    repository_ctx.template(
        "crosstool/BUILD",
        tpl_paths["crosstool:BUILD"],
        cuda_defines,
    )

    # No templating of cc_toolchain_config - use attributes and templatize the
    # BUILD file.
    repository_ctx.template(
        "crosstool/cc_toolchain_config.bzl",
        tpl_paths["crosstool:cc_toolchain_config.bzl"],
        {},
    )

    # Set up cuda_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    repository_ctx.template(
        "cuda/cuda/cuda_config.h",
        tpl_paths["cuda:cuda_config.h"],
        {
            "%{cuda_version}": cuda_config.cuda_version,
            "%{cudart_version}": cuda_config.cudart_version,
            "%{cublas_version}": cuda_config.cublas_version,
            "%{cusolver_version}": cuda_config.cusolver_version,
            "%{curand_version}": cuda_config.curand_version,
            "%{cufft_version}": cuda_config.cufft_version,
            "%{cusparse_version}": cuda_config.cusparse_version,
            "%{cudnn_version}": cuda_config.cudnn_version,
            "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path,
            "%{cuda_compute_capabilities}": ", ".join([
                cc.split("_")[1]
                for cc in cuda_config.compute_capabilities
            ]),
        },
    )

    # Set up cuda_config.py, which is used by gen_build_info to provide
    # static build environment info to the API
    repository_ctx.template(
        "cuda/cuda/cuda_config.py",
        tpl_paths["cuda:cuda_config.py"],
        _py_tmpl_dict({
            "cuda_version": cuda_config.cuda_version,
            "cudnn_version": cuda_config.cudnn_version,
            "cuda_compute_capabilities": cuda_config.compute_capabilities,
            "cpu_compiler": str(cc),
        }),
    )

def _py_tmpl_dict(d):
    return {"%{cuda_config}": str(d)}

def _create_remote_cuda_repository(repository_ctx, remote_config_repo):
    """Creates pointers to a remotely configured repo set up to build with CUDA."""
    _tpl(
        repository_ctx,
        "cuda:build_defs.bzl",
        {
            "%{cuda_is_configured}": "True",
            "%{cuda_extra_copts}": _compute_cuda_extra_copts(
                repository_ctx,
                compute_capabilities(repository_ctx),
            ),
        },
    )
    repository_ctx.template(
        "cuda/BUILD",
        config_repo_label(remote_config_repo, "cuda:BUILD"),
        {},
    )
    repository_ctx.template(
        "cuda/build_defs.bzl",
        config_repo_label(remote_config_repo, "cuda:build_defs.bzl"),
        {},
    )
    repository_ctx.template(
        "cuda/cuda/cuda_config.h",
        config_repo_label(remote_config_repo, "cuda:cuda/cuda_config.h"),
        {},
    )
    repository_ctx.template(
        "cuda/cuda/cuda_config.py",
        config_repo_label(remote_config_repo, "cuda:cuda/cuda_config.py"),
        _py_tmpl_dict({}),
    )

    repository_ctx.template(
        "crosstool/BUILD",
        config_repo_label(remote_config_repo, "crosstool:BUILD"),
        {},
    )

    repository_ctx.template(
        "crosstool/cc_toolchain_config.bzl",
        config_repo_label(remote_config_repo, "crosstool:cc_toolchain_config.bzl"),
        {},
    )

    repository_ctx.template(
        "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
        config_repo_label(remote_config_repo, "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc"),
        {},
    )

def _cuda_autoconf_impl(repository_ctx):
    """Implementation of the cuda_autoconf repository rule."""
    build_file = Label("//third_party/gpus:local_config_cuda.BUILD")

    if not enable_cuda(repository_ctx):
        _create_dummy_repository(repository_ctx)
    elif get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO) != None:
        has_cuda_version = get_host_environ(repository_ctx, _TF_CUDA_VERSION) != None
        has_cudnn_version = get_host_environ(repository_ctx, _TF_CUDNN_VERSION) != None
        if not has_cuda_version or not has_cudnn_version:
            auto_configure_fail("%s and %s must also be set if %s is specified" %
                                (_TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_CONFIG_REPO))
        _create_remote_cuda_repository(
            repository_ctx,
            get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO),
        )
    else:
        _create_local_cuda_repository(repository_ctx)

    repository_ctx.symlink(build_file, "BUILD")

# For @bazel_tools//tools/cpp:windows_cc_configure.bzl
_MSVC_ENVVARS = [
    "BAZEL_VC",
    "BAZEL_VC_FULL_VERSION",
    "BAZEL_VS",
    "BAZEL_WINSDK_FULL_VERSION",
    "VS90COMNTOOLS",
    "VS100COMNTOOLS",
    "VS110COMNTOOLS",
    "VS120COMNTOOLS",
    "VS140COMNTOOLS",
    "VS150COMNTOOLS",
    "VS160COMNTOOLS",
]

_ENVIRONS = [
    _GCC_HOST_COMPILER_PATH,
    _GCC_HOST_COMPILER_PREFIX,
    _CLANG_CUDA_COMPILER_PATH,
    "TF_NEED_CUDA",
    "TF_CUDA_CLANG",
    _TF_DOWNLOAD_CLANG,
    _CUDA_TOOLKIT_PATH,
    _CUDNN_INSTALL_PATH,
    _TF_CUDA_VERSION,
    _TF_CUDNN_VERSION,
    _TF_CUDA_COMPUTE_CAPABILITIES,
    "NVVMIR_LIBRARY_DIR",
    _PYTHON_BIN_PATH,
    "TMP",
    "TMPDIR",
    "TF_CUDA_PATHS",
] + _MSVC_ENVVARS

remote_cuda_configure = repository_rule(
    implementation = _create_local_cuda_repository,
    environ = _ENVIRONS,
    remotable = True,
    attrs = {
        "environ": attr.string_dict(),
    },
)

cuda_configure = repository_rule(
    implementation = _cuda_autoconf_impl,
    environ = _ENVIRONS + [_TF_CUDA_CONFIG_REPO],
)
1457"""Detects and configures the local CUDA toolchain.
1458
1459Add the following to your WORKSPACE FILE:
1460
1461```python
1462cuda_configure(name = "local_config_cuda")
1463```
1464
1465Args:
1466  name: A unique name for this workspace rule.
1467"""
1468