xref: /aosp_15_r20/external/tensorflow/third_party/gpus/rocm_configure.bzl (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1"""Repository rule for ROCm autoconfiguration.
2
3`rocm_configure` depends on the following environment variables:
4
5  * `TF_NEED_ROCM`: Whether to enable building with ROCm.
6  * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path
7  * `ROCM_PATH`: The path to the ROCm toolkit. Default is `/opt/rocm`.
8  * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets.
9"""
10
11load(
12    ":cuda_configure.bzl",
13    "make_copy_dir_rule",
14    "make_copy_files_rule",
15    "to_list_of_strings",
16)
17load(
18    "//third_party/remote_config:common.bzl",
19    "config_repo_label",
20    "err_out",
21    "execute",
22    "files_exist",
23    "get_bash_bin",
24    "get_cpu_value",
25    "get_host_environ",
26    "get_python_bin",
27    "raw_exec",
28    "realpath",
29    "which",
30)
31
# Names of the environment variables consulted by this file.
_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
_GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX"
_ROCM_TOOLKIT_PATH = "ROCM_PATH"
_TF_ROCM_AMDGPU_TARGETS = "TF_ROCM_AMDGPU_TARGETS"
_TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO"

# Toolkit location written into rocm_config.h by the dummy repository.
_DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm"
39
def verify_build_defines(params):
    """Checks that every variable crosstool/BUILD.rocm.tpl expects is present.

    Fails the configuration (via auto_configure_fail) listing the variables
    that have no `%{name}` entry in `params`.

    Args:
      params: dict of variables that will be passed to the BUILD.tpl template.
    """
    required = [
        "cxx_builtin_include_directories",
        "extra_no_canonical_prefixes_flags",
        "host_compiler_path",
        "host_compiler_prefix",
        "linker_bin_path",
        "unfiltered_compile_flags",
    ]

    # Substitution keys are spelled "%{name}" in the params dict.
    missing = [name for name in required if ("%{" + name + "}") not in params]
    if not missing:
        return

    auto_configure_fail(
        "BUILD.rocm.tpl template is missing these variables: " +
        str(missing) +
        ".\nWe only got: " +
        str(params) +
        ".",
    )
66
def find_cc(repository_ctx):
    """Locates the host C++ compiler.

    The compiler name defaults to "gcc" and may be overridden through the
    GCC_HOST_COMPILER_PATH environment variable. A value starting with "/" is
    returned as-is; anything else is looked up with `which`.

    Args:
      repository_ctx: The repository context.

    Returns:
      The compiler path: either the absolute path taken from the environment
      or whatever `which` returned for the compiler name.
    """
    default_cc_name = "gcc"

    # Fall back to the default name when the variable is unset or empty.
    cc_name = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PATH) or default_cc_name

    if cc_name.startswith("/"):
        # Absolute path, maybe we should make this supported by our which function.
        return cc_name

    resolved = which(repository_ctx, cc_name)
    if resolved == None:
        fail(
            "Cannot find {}, either correct your path or set the {} environment variable".format(
                default_cc_name,
                _GCC_HOST_COMPILER_PATH,
            ),
        )
    return resolved
86
# Text searched for in the compiler's verbose output to find the start of the
# include directory listing.
_INC_DIR_MARKER_BEGIN = "#include <...>"
88
def _cxx_inc_convert(path):
    """Returns an include path reported by `cc -E -xc++` without surrounding whitespace."""
    return path.strip()
93
def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp):
    """Computes the default include directories for either C or C++.

    Runs `cc -no-canonical-prefixes -E -x<lang> - -v` and extracts the list
    that follows the "#include <...>" marker in the output obtained through
    err_out.

    Args:
      repository_ctx: The repository context.
      cc: The compiler to query.
      lang_is_cpp: True to query for C++ ("c++"), False for C ("c").

    Returns:
      A list of include directory paths as strings, or an empty list when the
      expected markers are not found in the output.
    """
    lang = "c++" if lang_is_cpp else "c"

    # TODO: We pass -no-canonical-prefixes here to match the compiler flags,
    #       but in rocm_clang CROSSTOOL file that is a `feature` and we should
    #       handle the case when it's disabled and no flag is passed
    command = [cc, "-no-canonical-prefixes", "-E", "-x" + lang, "-", "-v"]
    output = err_out(raw_exec(repository_ctx, command))

    marker_pos = output.find(_INC_DIR_MARKER_BEGIN)
    if marker_pos == -1:
        return []

    # The listing starts on the line after the marker.
    listing_start = output.find("\n", marker_pos)
    if listing_start == -1:
        return []

    # The last line that begins with a space is taken as the final entry.
    last_entry = output.rfind("\n ")
    if last_entry == -1 or last_entry < listing_start:
        return []

    listing_end = output.find("\n", last_entry + 1)
    if listing_end == -1:
        listing = output[listing_start + 1:]
    else:
        listing = output[listing_start + 1:listing_end].strip()

    directories = []
    for entry in listing.split("\n"):
        directories.append(str(repository_ctx.path(_cxx_inc_convert(entry))))
    return directories
132
def get_cxx_inc_directories(repository_ctx, cc):
    """Computes the list of default C and C++ include directories.

    Args:
      repository_ctx: The repository context.
      cc: The compiler to query.

    Returns:
      The C++ include directories followed by any C include directories that
      are not already in the C++ list.
    """

    # For some reason `clang -xc` sometimes returns include paths that are
    # different from the ones from `clang -xc++`. (Symlink and a dir)
    # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists
    includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True)
    includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False)

    # Build the lookup table once instead of materialising a list for every
    # membership test.
    known_cpp = {inc: True for inc in includes_cpp}
    return includes_cpp + [inc for inc in includes_c if inc not in known_cpp]
148
def auto_configure_fail(msg):
    """Aborts with a "ROCm Configuration Error" message.

    Args:
      msg: The text appended after the error prefix.
    """

    # ANSI escapes: red for the prefix, then reset before the message.
    fail("\n%sROCm Configuration Error:%s %s\n" % ("\033[0;31m", "\033[0m", msg))
154
def auto_configure_warning(msg):
    """Prints an "Auto-Configuration Warning" message.

    Args:
      msg: The text appended after the warning prefix.
    """

    # ANSI escapes: bold yellow for the prefix, then reset before the message.
    print("\n%sAuto-Configuration Warning:%s %s\n" % ("\033[1;33m", "\033[0m", msg))
160
# END cc_configure common functions.
162
def _rocm_include_path(repository_ctx, rocm_config, bash_bin):
    """Generates the cxx_builtin_include_directory entries for rocm inc dirs.

    Args:
      repository_ctx: The repository context.
      rocm_config: Struct whose `rocm_toolkit_path` field is read here.
      bash_bin: Passed to `realpath` when resolving the toolkit path.

    Returns:
      A list of include directory path strings.
    """
    toolkit = rocm_config.rocm_toolkit_path

    inc_dirs = [
        # Add HSA headers (needs to match $HSA_PATH)
        toolkit + "/hsa/include",
        # Add HIP headers (needs to match $HIP_PATH)
        toolkit + "/hip/include",
    ]

    # Add HIP-Clang headers (realpath relative to compiler binary)
    resolved_toolkit = realpath(repository_ctx, toolkit, bash_bin)
    for clang_version in [
        "8.0",
        "9.0.0",
        "10.0.0",
        "11.0.0",
        "12.0.0",
        "13.0.0",
        "14.0.0",
    ]:
        inc_dirs.append(resolved_toolkit + "/llvm/lib/clang/" + clang_version + "/include")

    inc_dirs.extend([
        # Support hcc based off clang 10.0.0 (for ROCm 3.3)
        resolved_toolkit + "/hcc/compiler/lib/clang/10.0.0/include/",
        resolved_toolkit + "/hcc/lib/clang/10.0.0/include",
        # Add hcc headers
        resolved_toolkit + "/hcc/include",
    ])

    return inc_dirs
201
def _enable_rocm(repository_ctx):
    """Returns True only when TF_NEED_ROCM is "1" and the host is Linux.

    A warning is printed when ROCm was requested on a non-Linux host.
    """
    if get_host_environ(repository_ctx, "TF_NEED_ROCM") != "1":
        return False
    if get_cpu_value(repository_ctx) == "Linux":
        return True
    auto_configure_warning("ROCm configure is only supported on Linux")
    return False
210
def _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin):
    """Returns a list of strings representing AMDGPU targets.

    The targets come from TF_ROCM_AMDGPU_TARGETS (comma separated) when it is
    set; otherwise `<rocm_toolkit_path>/bin/rocm_agent_enumerator` is run and
    its output lines are used, dropping "gfx000" and duplicates.

    Args:
      repository_ctx: The repository context.
      rocm_toolkit_path: The ROCm toolkit installation directory.
      bash_bin: The bash interpreter used to run the enumerator.

    Returns:
      The list of targets. Fails the configuration when an entry does not
      start with "gfx".
    """
    targets_str = get_host_environ(repository_ctx, _TF_ROCM_AMDGPU_TARGETS)
    if not targets_str:
        enumerator = "%s/bin/rocm_agent_enumerator" % rocm_toolkit_path
        result = execute(repository_ctx, [bash_bin, "-c", enumerator])

        # Keep the first occurrence of each reported target, in order.
        detected = []
        for line in result.stdout.strip().split("\n"):
            if line != "gfx000" and line not in detected:
                detected.append(line)
        targets_str = ",".join(detected)

    targets = targets_str.split(",")
    for target in targets:
        if not target.startswith("gfx"):
            auto_configure_fail("Invalid AMDGPU target: %s" % target)
    return targets
226
def _hipcc_env(repository_ctx):
    """Returns the environment variable string for hipcc.

    Each listed variable that has a non-empty value in the host environment
    is rendered as `NAME="value";`, and the entries are separated by a single
    space.

    Args:
        repository_ctx: The repository context.

    Returns:
        A string containing environment variables for hipcc.
    """
    assignments = []
    for name in [
        "HIP_CLANG_PATH",
        "DEVICE_LIB_PATH",
        "HIP_VDI_HOME",
        "HIPCC_VERBOSE",
        "HIPCC_COMPILE_FLAGS_APPEND",
        # Was misspelled "HIPPCC_LINK_FLAGS_APPEND", so the variable was
        # never forwarded.
        "HIPCC_LINK_FLAGS_APPEND",
        "HCC_AMDGPU_TARGET",
        "HIP_PLATFORM",
    ]:
        env_value = get_host_environ(repository_ctx, name)
        if env_value:
            assignments.append(name + "=\"" + env_value + "\";")
    return " ".join(assignments)
251
def _crosstool_verbose(repository_ctx):
    """Reads the CROSSTOOL_VERBOSE environment variable.

    Args:
        repository_ctx: The repository context.

    Returns:
        The value of CROSSTOOL_VERBOSE; "0" is passed to get_host_environ as
        the default.
    """
    verbose = get_host_environ(repository_ctx, "CROSSTOOL_VERBOSE", "0")
    return verbose
262
def _lib_name(lib, version = "", static = False):
    """Builds the Linux file name of a library.

    Args:
      lib: The name of the library, such as "hip"
      version: The version of the library; appended as ".<version>" to shared
        object names when non-empty.
      static: True for a static library ("lib<lib>.a"), False for a shared
        object ("lib<lib>.so[.<version>]").

    Returns:
      The file name of the library.
    """
    if static:
        # Static archives never carry the version suffix.
        return "lib%s.a" % lib

    suffix = (".%s" % version) if version else version
    return "lib%s.so%s" % (lib, suffix)
280
def _rocm_lib_paths(repository_ctx, lib, basedir):
    """Lists the candidate locations of a shared library under `basedir`.

    Args:
      repository_ctx: The repository context.
      lib: The name of the library, such as "rocblas".
      basedir: The directory below which the library is searched.

    Returns:
      A list of paths, in the order in which they should be probed.
    """
    file_name = _lib_name(lib, version = "", static = False)
    candidate_patterns = [
        "%s/lib64/%s",
        "%s/lib64/stubs/%s",
        "%s/lib/x86_64-linux-gnu/%s",
        "%s/lib/%s",
        "%s/%s",
    ]
    return [
        repository_ctx.path(pattern % (basedir, file_name))
        for pattern in candidate_patterns
    ]
290
def _batch_files_exist(repository_ctx, libs_paths, bash_bin):
    """Checks all candidate paths of all libraries with one files_exist call.

    Args:
      repository_ctx: The repository context.
      libs_paths: A list of (name, candidate paths) pairs.
      bash_bin: Passed through to files_exist.

    Returns:
      The result of files_exist for the flattened list of candidate paths.
    """
    flattened = [
        candidate
        for _, candidates in libs_paths
        for candidate in candidates
    ]
    return files_exist(repository_ctx, flattened, bash_bin)
297
def _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin):
    """Picks, for every library, the first candidate path that exists.

    Args:
      repository_ctx: The repository context.
      libs_paths: A list of (name, candidate paths) pairs.
      bash_bin: Passed to the existence check and to realpath.

    Returns:
      A dict mapping each library name to a struct with `file_name` (basename
      of the selected candidate) and `path` (its realpath). Fails the
      configuration when no candidate of a library exists.
    """
    exists = _batch_files_exist(repository_ctx, libs_paths, bash_bin)

    selected = {}

    # `exists` is laid out in the same flattened order as the candidates, so
    # `offset` marks where the current library's results begin.
    offset = 0
    for name, candidates in libs_paths:
        found = None
        for position, candidate in enumerate(candidates):
            if exists[offset + position] and found == None:
                found = candidate
        offset += len(candidates)

        if found == None:
            auto_configure_fail("Cannot find rocm library %s" % name)

        selected[name] = struct(
            file_name = found.basename,
            path = realpath(repository_ctx, found, bash_bin),
        )

    return selected
316
def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, bash_bin):
    """Returns the ROCm libraries on the system.

    Args:
      repository_ctx: The repository context.
      rocm_config: The ROCm config as returned by _get_rocm_config
      hipfft_or_rocfft: Name of the FFT library to look for; also used as the
        name of its directory below the toolkit path.
      bash_bin: the path to the bash interpreter

    Returns:
      Map of library names to structs of filename and path
    """
    toolkit = rocm_config.rocm_toolkit_path

    # (library name, directory below the toolkit path that is searched)
    lib_dirs = [
        ("amdhip64", "/hip"),
        ("rocblas", "/rocblas"),
        (hipfft_or_rocfft, "/" + hipfft_or_rocfft),
        ("hiprand", ""),
        ("MIOpen", "/miopen"),
        ("rccl", "/rccl"),
        ("hipsparse", "/hipsparse"),
        ("roctracer64", "/roctracer"),
        ("rocsolver", "/rocsolver"),
    ]

    # hipsolver and hipblas are only looked up for version numbers >= 40500.
    if int(rocm_config.rocm_version_number) >= 40500:
        lib_dirs.append(("hipsolver", "/hipsolver"))
        lib_dirs.append(("hipblas", "/hipblas"))

    libs_paths = []
    for name, subdir in lib_dirs:
        libs_paths.append((name, _rocm_lib_paths(repository_ctx, name, toolkit + subdir)))
    return _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin)
346
def _exec_find_rocm_config(repository_ctx, script_path):
    """Runs the compressed configuration script with the Python interpreter.

    Args:
      repository_ctx: The repository context.
      script_path: Path of the file whose contents are base64-decoded and
        zlib-decompressed into `script.py` before execution.

    Returns:
      The result of `execute` for the generated Python command.
    """
    python_bin = get_python_bin(repository_ctx)

    # If used with remote execution then repository_ctx.execute() can't
    # access files from the source tree. A trick is to read the contents
    # of the file in Starlark and embed them as part of the command. In
    # this case the trick is not sufficient as the find_rocm_config.py
    # script has more than 8192 characters. 8192 is the command length
    # limit of cmd.exe on Windows. Thus we additionally need to compress
    # the contents locally and decompress them as part of the execute().
    compressed_contents = repository_ctx.read(script_path)
    statements = [
        "from zlib import decompress;",
        "from base64 import b64decode;",
        "from os import system;",
        "script = decompress(b64decode('%s'));" % compressed_contents,
        "f = open('script.py', 'wb');",
        "f.write(script);",
        "f.close();",
        "system('\"%s\" script.py');" % (python_bin),
    ]
    decompress_and_execute_cmd = "".join(statements)

    return execute(repository_ctx, [python_bin, "-c", decompress_and_execute_cmd])
370
def find_rocm_config(repository_ctx, script_path):
    """Returns ROCm config dictionary from running find_rocm_config.py

    Args:
      repository_ctx: The repository context.
      script_path: Path of the compressed script handed to
        _exec_find_rocm_config.

    Returns:
      A dict built from the "key: value" lines of the script's stdout. Fails
      the configuration when the script exits with a non-zero return code.
    """
    exec_result = _exec_find_rocm_config(repository_ctx, script_path)
    if exec_result.return_code:
        auto_configure_fail("Failed to run find_rocm_config.py: %s" % err_out(exec_result))

    # Parse the dict from stdout.
    config = {}
    for line in exec_result.stdout.splitlines():
        # Blank lines and lines without a separator would otherwise make the
        # dict construction fail with an unrelated error.
        if ": " not in line:
            continue

        # Split only once so values that contain ": " stay intact.
        key, value = line.split(": ", 1)
        config[key] = value
    return config
379
def _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script):
    """Detects and returns information about the ROCm installation on the system.

    Args:
      repository_ctx: The repository context.
      bash_bin: the path to the bash interpreter
      find_rocm_config_script: Path of the compressed detection script passed
        to find_rocm_config.

    Returns:
      A struct containing the following fields:
        rocm_toolkit_path: The ROCm toolkit installation directory.
        amdgpu_targets: A list of the system's AMDGPU targets.
        rocm_version_number: The version of ROCm on the system.
        miopen_version_number: The version of MIOpen on the system.
        hipruntime_version_number: The version of HIP Runtime on the system.
    """
    config = find_rocm_config(repository_ctx, find_rocm_config_script)

    # Read every required key before the AMDGPU targets are computed.
    fields = {
        key: config[key]
        for key in [
            "rocm_toolkit_path",
            "rocm_version_number",
            "miopen_version_number",
            "hipruntime_version_number",
        ]
    }
    return struct(
        amdgpu_targets = _amdgpu_targets(repository_ctx, fields["rocm_toolkit_path"], bash_bin),
        **fields
    )
407
def _tpl_path(repository_ctx, labelname):
    """Returns the path of the `.tpl` file for `labelname` under //third_party/gpus."""
    template_label = Label("//third_party/gpus/%s.tpl" % labelname)
    return repository_ctx.path(template_label)
410
def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    """Expands the template `tpl` into the repository.

    Args:
      repository_ctx: The repository context.
      tpl: Template name such as "rocm:BUILD"; resolved with _tpl_path.
      substitutions: Dict of substitutions passed to repository_ctx.template.
      out: Output path. When not given, `tpl` with ":" replaced by "/".
    """
    destination = out or tpl.replace(":", "/")
    source = _tpl_path(repository_ctx, tpl)
    repository_ctx.template(destination, source, substitutions)
419
# Contents of crosstool/error_gpu_disabled.bzl in the dummy repository.
_DUMMY_CROSSTOOL_BZL_FILE = """
def error_gpu_disabled():
  fail("ERROR: Building with --config=rocm but TensorFlow is not configured " +
       "to build with GPU support. Please re-run ./configure and enter 'Y' " +
       "at the prompt to build with GPU support.")

  native.genrule(
      name = "error_gen_crosstool",
      outs = ["CROSSTOOL"],
      cmd = "echo 'Should not be run.' && exit 1",
  )

  native.filegroup(
      name = "crosstool",
      srcs = [":CROSSTOOL"],
      output_licenses = ["unencumbered"],
  )
"""
438
# Contents of crosstool/BUILD in the dummy repository; it loads and calls
# error_gpu_disabled().
_DUMMY_CROSSTOOL_BUILD_FILE = """
load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled")

error_gpu_disabled()
"""
444
def _create_dummy_repository(repository_ctx):
    """Creates the placeholder repository used when ROCm is not enabled.

    Args:
      repository_ctx: The repository context.
    """

    # Set up BUILD file for rocm/.
    build_defs_substitutions = {
        "%{rocm_is_configured}": "False",
        "%{rocm_extra_copts}": "[]",
        "%{rocm_gpu_architectures}": "[]",
        "%{rocm_version_number}": "0",
    }
    _tpl(repository_ctx, "rocm:build_defs.bzl", build_defs_substitutions)

    # Template variable -> library whose default shared-object name is used.
    build_substitutions = {}
    for variable, lib in [
        ("hip_lib", "hip"),
        ("rocblas_lib", "rocblas"),
        ("hipblas_lib", "hipblas"),
        ("miopen_lib", "miopen"),
        ("rccl_lib", "rccl"),
        ("hipfft_or_rocfft_lib", "hipfft"),
        ("hiprand_lib", "hiprand"),
        ("hipsparse_lib", "hipsparse"),
        ("roctracer_lib", "roctracer64"),
        ("rocsolver_lib", "rocsolver"),
        ("hipsolver_lib", "hipsolver"),
    ]:
        build_substitutions["%{" + variable + "}"] = _lib_name(lib)
    build_substitutions["%{hipfft_or_rocfft}"] = "hipfft"
    build_substitutions["%{copy_rules}"] = ""
    build_substitutions["%{rocm_headers}"] = ""
    _tpl(repository_ctx, "rocm:BUILD", build_substitutions)

    # Create dummy files for the ROCm toolkit since they are still required by
    # tensorflow/tsl/platform/default/build_config:rocm.
    repository_ctx.file("rocm/hip/include/hip/hip_runtime.h", "")

    # Set up rocm_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    _tpl(
        repository_ctx,
        "rocm:rocm_config.h",
        {"%{rocm_toolkit_path}": _DEFAULT_ROCM_TOOLKIT_PATH},
        "rocm/rocm/rocm_config.h",
    )

    # If rocm_configure is not configured to build with GPU support, and the user
    # attempts to build with --config=rocm, add a dummy build rule to intercept
    # this and fail with an actionable error message.
    repository_ctx.file("crosstool/error_gpu_disabled.bzl", _DUMMY_CROSSTOOL_BZL_FILE)
    repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
501
def _norm_path(path):
    """Returns a path with '/' separators and without one trailing slash.

    Args:
      path: The path string to normalize; may be empty.

    Returns:
      `path` with every backslash replaced by "/" and, if the result ends
      with "/", that last character removed.
    """
    path = path.replace("\\", "/")

    # endswith() is safe on an empty string, unlike indexing path[-1].
    if path.endswith("/"):
        path = path[:-1]
    return path
508
def _genrule(src_dir, genrule_name, command, outs):
    """Returns the text of a genrule declaration.

    Genrule executes the given command and produces the given outputs.

    Args:
      src_dir: Not used by this function.
      genrule_name: Value of the rule's `name` attribute.
      command: Text placed inside the triple-quoted `cmd` attribute.
      outs: Text placed inside the `outs` list.

    Returns:
      The genrule declaration as a string.
    """
    pieces = [
        "genrule(\n",
        '    name = "',
        genrule_name,
        '",\n',
        "    outs = [\n",
        outs,
        "\n    ],\n",
        '    cmd = """\n',
        command,
        '\n   """,\n',
        ")\n",
    ]
    return "".join(pieces)
526
def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets):
    """Returns the string form of the `--amdgpu-target=` flag list.

    Args:
      repository_ctx: Not used by this function.
      amdgpu_targets: List of AMDGPU target names.

    Returns:
      str() of a list with one "--amdgpu-target=<target>" entry per target.
    """
    flags = []
    for target in amdgpu_targets:
        flags.append("--amdgpu-target=" + target)
    return str(flags)
531
def _create_local_rocm_repository(repository_ctx):
    """Creates the repository containing files set up to build with ROCm.

    Detects the local ROCm installation, generates copy rules for its headers,
    libraries and clang-offload-bundler, and expands the rocm/ and crosstool/
    templates.

    Args:
      repository_ctx: The repository context.
    """

    # Resolve all template labels before any other work is done.
    tpl_paths = {}
    for labelname in [
        "rocm:build_defs.bzl",
        "rocm:BUILD",
        "crosstool:BUILD.rocm",
        "crosstool:hipcc_cc_toolchain_config.bzl",
        "crosstool:clang/bin/crosstool_wrapper_driver_rocm",
        "rocm:rocm_config.h",
    ]:
        tpl_paths[labelname] = _tpl_path(repository_ctx, labelname)

    find_rocm_config_script = repository_ctx.path(
        Label("@org_tensorflow//third_party/gpus:find_rocm_config.py.gz.base64"),
    )

    bash_bin = get_bash_bin(repository_ctx)
    rocm_config = _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script)
    toolkit = rocm_config.rocm_toolkit_path
    version = int(rocm_config.rocm_version_number)

    # For ROCm 4.1 and above use hipfft, older ROCm versions use rocfft
    if version < 40100:
        fft_lib = "rocfft"
    else:
        fft_lib = "hipfft"

    # Hipsolver and hipblas are added on ROCm4.5+
    has_hipsolver = version >= 40500

    # Copy header and library files to execroot.
    copy_rules = [
        make_copy_dir_rule(
            repository_ctx,
            name = "rocm-include",
            src_dir = toolkit + "/include",
            out_dir = "rocm/include",
            exceptions = ["gtest", "gmock"],
        ),
    ]

    # (rule name, source directory below the toolkit path, output directory)
    component_dirs = [
        (fft_lib + "-include", "/" + fft_lib + "/include", "rocm/include/" + fft_lib),
        ("rocblas-include", "/rocblas/include", "rocm/include/rocblas"),
        ("rocblas-hsaco", "/rocblas/lib/library", "rocm/lib/rocblas/lib/library"),
        ("miopen-include", "/miopen/include", "rocm/include/miopen"),
        ("rccl-include", "/rccl/include", "rocm/include/rccl"),
        ("hipsparse-include", "/hipsparse/include", "rocm/include/hipsparse"),
        ("rocsolver-include", "/rocsolver/include", "rocm/include/rocsolver"),
    ]
    if has_hipsolver:
        component_dirs.append(("hipsolver-include", "/hipsolver/include", "rocm/include/hipsolver"))
        component_dirs.append(("hipblas-include", "/hipblas/include", "rocm/include/hipblas"))
    for rule_name, src_subdir, out_dir in component_dirs:
        copy_rules.append(
            make_copy_dir_rule(
                repository_ctx,
                name = rule_name,
                src_dir = toolkit + src_subdir,
                out_dir = out_dir,
            ),
        )

    # Labels listed in %{rocm_headers}; the order matches the generated text.
    header_labels = [
        "rocm-include",
        fft_lib + "-include",
        "rocblas-include",
        "miopen-include",
        "rccl-include",
    ]

    # explicitly copy (into the local_config_rocm repo) the $ROCM_PATH/hiprand/include and
    # $ROCM_PATH/rocrand/include dirs, only once the softlink to them in $ROCM_PATH/include
    # dir has been removed. This removal will happen in a near-future ROCm release.
    for component in ["hiprand", "rocrand"]:
        softlink = toolkit + "/include/" + component
        if files_exist(repository_ctx, [softlink], bash_bin)[0]:
            continue
        header_labels.append(component + "-include")
        copy_rules.append(
            make_copy_dir_rule(
                repository_ctx,
                name = component + "-include",
                src_dir = toolkit + "/" + component + "/include",
                out_dir = "rocm/include/" + component,
            ),
        )

    header_labels.append("hipsparse-include")
    header_labels.append("rocsolver-include")
    if has_hipsolver:
        header_labels.append("hipsolver-include")
        header_labels.append("hipblas-include")

    rocm_libs = _find_libs(repository_ctx, rocm_config, fft_lib, bash_bin)
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "rocm-lib",
        srcs = [lib.path for lib in rocm_libs.values()],
        outs = ["rocm/lib/" + lib.file_name for lib in rocm_libs.values()],
    ))

    # copy files mentioned in third_party/gpus/rocm/BUILD
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "rocm-bin",
        srcs = [toolkit + "/llvm/bin/clang-offload-bundler"],
        outs = ["rocm/bin/clang-offload-bundler"],
    ))

    # Set up BUILD file for rocm/
    repository_ctx.template(
        "rocm/build_defs.bzl",
        tpl_paths["rocm:build_defs.bzl"],
        {
            "%{rocm_is_configured}": "True",
            "%{rocm_extra_copts}": _compute_rocm_extra_copts(
                repository_ctx,
                rocm_config.amdgpu_targets,
            ),
            "%{rocm_gpu_architectures}": str(rocm_config.amdgpu_targets),
            "%{rocm_version_number}": str(version),
        },
    )

    repository_dict = {
        "%{hip_lib}": rocm_libs["amdhip64"].file_name,
        "%{rocblas_lib}": rocm_libs["rocblas"].file_name,
        "%{hipfft_or_rocfft}": fft_lib,
        "%{hipfft_or_rocfft_lib}": rocm_libs[fft_lib].file_name,
        "%{hiprand_lib}": rocm_libs["hiprand"].file_name,
        "%{miopen_lib}": rocm_libs["MIOpen"].file_name,
        "%{rccl_lib}": rocm_libs["rccl"].file_name,
        "%{hipsparse_lib}": rocm_libs["hipsparse"].file_name,
        "%{roctracer_lib}": rocm_libs["roctracer64"].file_name,
        "%{rocsolver_lib}": rocm_libs["rocsolver"].file_name,
        "%{copy_rules}": "\n".join(copy_rules),
        # Each label is rendered as `":<label>"`, separated by ",\n".
        "%{rocm_headers}": ",\n".join(['":%s"' % label for label in header_labels]),
    }
    if has_hipsolver:
        repository_dict["%{hipsolver_lib}"] = rocm_libs["hipsolver"].file_name
        repository_dict["%{hipblas_lib}"] = rocm_libs["hipblas"].file_name

    repository_ctx.template(
        "rocm/BUILD",
        tpl_paths["rocm:BUILD"],
        repository_dict,
    )

    # Set up crosstool/
    cc = find_cc(repository_ctx)
    host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc)
    host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX, "/usr/bin")

    rocm_defines = {
        "%{host_compiler_prefix}": host_compiler_prefix,
        "%{linker_bin_path}": toolkit + "/hcc/compiler/bin",
        # For gcc, do not canonicalize system header paths; some versions of gcc
        # pick the shortest possible path for system includes when creating the
        # .d file - given that includes that are prefixed with "../" multiple
        # time quickly grow longer than the root of the tree, this can lead to
        # bazel's header check failing.
        "%{extra_no_canonical_prefixes_flags}": "\"-fno-canonical-system-headers\"",
        "%{unfiltered_compile_flags}": to_list_of_strings([
            "-DTENSORFLOW_USE_ROCM=1",
            "-D__HIP_PLATFORM_HCC__",
            "-DEIGEN_USE_HIP",
        ]),
        "%{host_compiler_path}": "clang/bin/crosstool_wrapper_driver_is_not_gcc",
        "%{cxx_builtin_include_directories}": to_list_of_strings(
            host_compiler_includes + _rocm_include_path(repository_ctx, rocm_config, bash_bin),
        ),
    }

    verify_build_defines(rocm_defines)

    # Only expand template variables in the BUILD file
    repository_ctx.template(
        "crosstool/BUILD",
        tpl_paths["crosstool:BUILD.rocm"],
        rocm_defines,
    )

    # No templating of cc_toolchain_config - use attributes and templatize the
    # BUILD file.
    repository_ctx.template(
        "crosstool/cc_toolchain_config.bzl",
        tpl_paths["crosstool:hipcc_cc_toolchain_config.bzl"],
    )

    host_cc = str(cc)
    repository_ctx.template(
        "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
        tpl_paths["crosstool:clang/bin/crosstool_wrapper_driver_rocm"],
        {
            "%{cpu_compiler}": host_cc,
            "%{hipcc_path}": toolkit + "/hip/bin/hipcc",
            "%{hipcc_env}": _hipcc_env(repository_ctx),
            "%{rocr_runtime_path}": toolkit + "/lib",
            "%{rocr_runtime_library}": "hsa-runtime64",
            "%{hip_runtime_path}": toolkit + "/hip/lib",
            "%{hip_runtime_library}": "amdhip64",
            "%{crosstool_verbose}": _crosstool_verbose(repository_ctx),
            "%{gcc_host_compiler_path}": host_cc,
        },
    )

    # Set up rocm_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    quoted_targets = ['"%s"' % target for target in rocm_config.amdgpu_targets]
    repository_ctx.template(
        "rocm/rocm/rocm_config.h",
        tpl_paths["rocm:rocm_config.h"],
        {
            "%{rocm_amdgpu_targets}": ",".join(quoted_targets),
            "%{rocm_toolkit_path}": toolkit,
            "%{rocm_version_number}": rocm_config.rocm_version_number,
            "%{miopen_version_number}": rocm_config.miopen_version_number,
            "%{hipruntime_version_number}": rocm_config.hipruntime_version_number,
        },
    )
814
def _create_remote_rocm_repository(repository_ctx, remote_config_repo):
    """Creates pointers to a remotely configured repo set up to build with ROCm.

    Args:
      repository_ctx: The repository context.
      remote_config_repo: The remote config repo handed to config_repo_label
        to locate the pre-generated files.
    """
    _tpl(
        repository_ctx,
        "rocm:build_defs.bzl",
        {
            "%{rocm_is_configured}": "True",
            "%{rocm_extra_copts}": _compute_rocm_extra_copts(
                repository_ctx,
                [],  #_compute_capabilities(repository_ctx)
            ),
        },
    )

    # (output path in this repository, label name inside the remote config repo)
    remote_files = [
        ("rocm/BUILD", "rocm:BUILD"),
        ("rocm/build_defs.bzl", "rocm:build_defs.bzl"),
        ("rocm/rocm/rocm_config.h", "rocm:rocm/rocm_config.h"),
        ("crosstool/BUILD", "crosstool:BUILD"),
        ("crosstool/cc_toolchain_config.bzl", "crosstool:cc_toolchain_config.bzl"),
        (
            "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
            "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc",
        ),
    ]

    # Each file is expanded with an empty substitution dict.
    for out_path, label_name in remote_files:
        repository_ctx.template(
            out_path,
            config_repo_label(remote_config_repo, label_name),
            {},
        )
858
def _rocm_autoconf_impl(repository_ctx):
    """Implementation of the rocm_autoconf repository rule.

    Creates the dummy repository when ROCm is not enabled, the remote
    repository when TF_ROCM_CONFIG_REPO is set, and the local one otherwise.
    """
    if not _enable_rocm(repository_ctx):
        _create_dummy_repository(repository_ctx)
        return

    remote_config_repo = get_host_environ(repository_ctx, _TF_ROCM_CONFIG_REPO)
    if remote_config_repo == None:
        _create_local_rocm_repository(repository_ctx)
    else:
        _create_remote_rocm_repository(repository_ctx, remote_config_repo)
870
# Environment variables declared in the `environ` attribute of both
# repository rules below.
_ENVIRONS = [
    _GCC_HOST_COMPILER_PATH,
    _GCC_HOST_COMPILER_PREFIX,
    "TF_NEED_ROCM",
    _ROCM_TOOLKIT_PATH,
    _TF_ROCM_AMDGPU_TARGETS,
]
878
# Remotable rule that uses the local repository implementation directly.
remote_rocm_configure = repository_rule(
    implementation = _create_local_rocm_repository,
    attrs = {
        "environ": attr.string_dict(),
    },
    environ = _ENVIRONS,
    remotable = True,
)
887
# TF_ROCM_CONFIG_REPO is declared here in addition to the shared variables
# because _rocm_autoconf_impl reads it.
rocm_configure = repository_rule(
    environ = _ENVIRONS + [_TF_ROCM_CONFIG_REPO],
    implementation = _rocm_autoconf_impl,
)
"""Detects and configures the local ROCm toolchain.

Add the following to your WORKSPACE FILE:

```python
rocm_configure(name = "local_config_rocm")
```

Args:
  name: A unique name for this workspace rule.
"""
903