"""Repository rule for ROCm autoconfiguration.

`rocm_configure` depends on the following environment variables:

  * `TF_NEED_ROCM`: Whether to enable building with ROCm.
  * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path.
  * `GCC_HOST_COMPILER_PREFIX`: The host compiler prefix substituted into the
    generated crosstool BUILD file. Default is `/usr/bin`.
  * `ROCM_PATH`: The path to the ROCm toolkit. Default is `/opt/rocm`.
  * `TF_ROCM_AMDGPU_TARGETS`: Comma-separated AMDGPU targets (for example
    `gfx906,gfx908`). If unset, targets are detected by running
    `rocm_agent_enumerator` from the ROCm toolkit.
  * `TF_ROCM_CONFIG_REPO`: If set, the files of this remotely configured ROCm
    repository are used instead of detecting the local installation.
"""

load(
    ":cuda_configure.bzl",
    "make_copy_dir_rule",
    "make_copy_files_rule",
    "to_list_of_strings",
)
load(
    "//third_party/remote_config:common.bzl",
    "config_repo_label",
    "err_out",
    "execute",
    "files_exist",
    "get_bash_bin",
    "get_cpu_value",
    "get_host_environ",
    "get_python_bin",
    "raw_exec",
    "realpath",
    "which",
)

_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
_GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX"
_ROCM_TOOLKIT_PATH = "ROCM_PATH"
_TF_ROCM_AMDGPU_TARGETS = "TF_ROCM_AMDGPU_TARGETS"
_TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO"

_DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm"

def verify_build_defines(params):
    """Verify all variables that crosstool/BUILD.rocm.tpl expects are substituted.

    Fails the configuration (via auto_configure_fail) if any expected
    `%{...}` key is absent from `params`.

    Args:
      params: dict of variables that will be passed to the BUILD.tpl template.
    """
    missing = []
    for param in [
        "cxx_builtin_include_directories",
        "extra_no_canonical_prefixes_flags",
        "host_compiler_path",
        "host_compiler_prefix",
        "linker_bin_path",
        "unfiltered_compile_flags",
    ]:
        if ("%{" + param + "}") not in params:
            missing.append(param)

    if missing:
        auto_configure_fail(
            "BUILD.rocm.tpl template is missing these variables: " +
            str(missing) +
            ".\nWe only got: " +
            str(params) +
            ".",
        )

def find_cc(repository_ctx):
    """Find the C++ compiler.

    Uses GCC_HOST_COMPILER_PATH when set, otherwise looks up `gcc`.

    Args:
      repository_ctx: The repository context.

    Returns:
      The compiler path: returned verbatim when it is absolute, otherwise the
      result of `which`.
    """

    # Return a dummy value for GCC detection here to avoid error
    target_cc_name = "gcc"
    cc_path_envvar = _GCC_HOST_COMPILER_PATH
    cc_name = target_cc_name

    cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar)
    if cc_name_from_env:
        cc_name = cc_name_from_env
    if cc_name.startswith("/"):
        # Absolute path, maybe we should make this supported by our which function.
        return cc_name
    cc = which(repository_ctx, cc_name)
    if cc == None:
        # Report the name that was actually searched for (which may come from
        # GCC_HOST_COMPILER_PATH), not always the "gcc" default.
        fail(("Cannot find {}, either correct your path or set the {}" +
              " environment variable").format(cc_name, cc_path_envvar))
    return cc

_INC_DIR_MARKER_BEGIN = "#include <...>"

def _cxx_inc_convert(path):
    """Convert path returned by cc -E xc++ in a complete path."""
    path = path.strip()
    return path

def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp):
    """Compute the list of default C or C++ include directories.

    Runs the compiler in verbose preprocessing mode and parses the indented
    include-path lines that follow the `#include <...>` marker in stderr.

    Args:
      repository_ctx: The repository context.
      cc: Path to the compiler.
      lang_is_cpp: True to query C++ include paths, False for C.

    Returns:
      A list of include directory path strings; empty if the output could not
      be parsed.
    """
    if lang_is_cpp:
        lang = "c++"
    else:
        lang = "c"

    # TODO: We pass -no-canonical-prefixes here to match the compiler flags,
    #       but in rocm_clang CROSSTOOL file that is a `feature` and we should
    #       handle the case when it's disabled and no flag is passed
    result = raw_exec(repository_ctx, [
        cc,
        "-no-canonical-prefixes",
        "-E",
        "-x" + lang,
        "-",
        "-v",
    ])
    stderr = err_out(result)
    index1 = stderr.find(_INC_DIR_MARKER_BEGIN)
    if index1 == -1:
        return []
    index1 = stderr.find("\n", index1)
    if index1 == -1:
        return []

    # The last line starting with a space is the last include directory.
    index2 = stderr.rfind("\n ")
    if index2 == -1 or index2 < index1:
        return []
    index2 = stderr.find("\n", index2 + 1)
    if index2 == -1:
        inc_dirs = stderr[index1 + 1:]
    else:
        inc_dirs = stderr[index1 + 1:index2].strip()

    return [
        str(repository_ctx.path(_cxx_inc_convert(p)))
        for p in inc_dirs.split("\n")
    ]

def get_cxx_inc_directories(repository_ctx, cc):
    """Compute the list of default C and C++ include directories.

    Args:
      repository_ctx: The repository context.
      cc: Path to the compiler.

    Returns:
      The C++ include directories followed by any C include directories not
      already in that list.
    """

    # For some reason `clang -xc` sometimes returns include paths that are
    # different from the ones from `clang -xc++`. (Symlink and a dir)
    # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists
    includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True)
    includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False)

    # Build the membership index once instead of flattening a depset on every
    # iteration of the comprehension.
    includes_cpp_index = {inc: None for inc in includes_cpp}
    return includes_cpp + [
        inc
        for inc in includes_c
        if inc not in includes_cpp_index
    ]

def auto_configure_fail(msg):
    """Output failure message when rocm configuration fails."""
    red = "\033[0;31m"
    no_color = "\033[0m"
    fail("\n%sROCm Configuration Error:%s %s\n" % (red, no_color, msg))

def auto_configure_warning(msg):
    """Output warning message during auto configuration."""
    yellow = "\033[1;33m"
    no_color = "\033[0m"
    print("\n%sAuto-Configuration Warning:%s %s\n" % (yellow, no_color, msg))

# END cc_configure common functions.

# Clang versions whose builtin header directories are added for HIP-Clang.
_HIP_CLANG_VERSIONS = [
    "8.0",
    "9.0.0",
    "10.0.0",
    "11.0.0",
    "12.0.0",
    "13.0.0",
    "14.0.0",
]

def _rocm_include_path(repository_ctx, rocm_config, bash_bin):
    """Generates the cxx_builtin_include_directory entries for rocm inc dirs.

    Args:
      repository_ctx: The repository context.
      rocm_config: The ROCm config as returned by _get_rocm_config.
      bash_bin: The path to the bash interpreter, used to resolve the real
        path of the ROCm toolkit.

    Returns:
      A list of include directory path strings for the ROCm toolkit, which can
      be added to the CROSSTOOL file.
    """
    inc_dirs = []

    # Add HSA headers (needs to match $HSA_PATH)
    inc_dirs.append(rocm_config.rocm_toolkit_path + "/hsa/include")

    # Add HIP headers (needs to match $HIP_PATH)
    inc_dirs.append(rocm_config.rocm_toolkit_path + "/hip/include")

    # Add HIP-Clang headers (realpath relative to compiler binary)
    rocm_toolkit_path = realpath(repository_ctx, rocm_config.rocm_toolkit_path, bash_bin)
    for clang_version in _HIP_CLANG_VERSIONS:
        inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/" + clang_version + "/include")

    # Support hcc based off clang 10.0.0 (for ROCm 3.3)
    inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/")
    inc_dirs.append(rocm_toolkit_path + "/hcc/lib/clang/10.0.0/include")

    # Add hcc headers
    inc_dirs.append(rocm_toolkit_path + "/hcc/include")

    return inc_dirs

def _enable_rocm(repository_ctx):
    """Returns True if TF_NEED_ROCM is "1" and the host CPU is Linux."""
    enable_rocm = get_host_environ(repository_ctx, "TF_NEED_ROCM")
    if enable_rocm == "1":
        if get_cpu_value(repository_ctx) != "Linux":
            auto_configure_warning("ROCm configure is only supported on Linux")
            return False
        return True
    return False

def _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin):
    """Returns a list of strings representing AMDGPU targets.

    Uses TF_ROCM_AMDGPU_TARGETS when set; otherwise runs
    `<rocm_toolkit_path>/bin/rocm_agent_enumerator`, ignoring the `gfx000`
    entry and duplicates.

    Args:
      repository_ctx: The repository context.
      rocm_toolkit_path: The ROCm toolkit installation directory.
      bash_bin: The path to the bash interpreter.

    Returns:
      A non-empty list of target names, each starting with "gfx". Fails the
      configuration otherwise.
    """
    amdgpu_targets_str = get_host_environ(repository_ctx, _TF_ROCM_AMDGPU_TARGETS)
    if not amdgpu_targets_str:
        cmd = "%s/bin/rocm_agent_enumerator" % rocm_toolkit_path
        result = execute(repository_ctx, [bash_bin, "-c", cmd])
        targets = [
            target.strip()
            for target in result.stdout.strip().split("\n")
            if target.strip() != "gfx000"
        ]
        targets = {x: None for x in targets}
        targets = list(targets.keys())
        amdgpu_targets_str = ",".join(targets)

    # Tolerate whitespace around entries (e.g. "gfx900, gfx906") and drop
    # empty entries (trailing commas, or no GPU agents reported).
    amdgpu_targets = [
        target.strip()
        for target in amdgpu_targets_str.split(",")
        if target.strip()
    ]
    if not amdgpu_targets:
        auto_configure_fail(
            "No AMDGPU targets found. Set %s to a comma-separated list of " % _TF_ROCM_AMDGPU_TARGETS +
            "targets (for example gfx906,gfx908).",
        )
    for amdgpu_target in amdgpu_targets:
        if not amdgpu_target.startswith("gfx"):
            auto_configure_fail("Invalid AMDGPU target: %s" % amdgpu_target)
    return amdgpu_targets

def _hipcc_env(repository_ctx):
    """Returns the environment variable string for hipcc.

    Args:
      repository_ctx: The repository context.

    Returns:
      A string of `NAME="value";` assignments, space separated, for each of
      the listed variables that is set to a non-empty value on the host.
    """
    hipcc_env = ""
    for name in [
        "HIP_CLANG_PATH",
        "DEVICE_LIB_PATH",
        "HIP_VDI_HOME",
        "HIPCC_VERBOSE",
        "HIPCC_COMPILE_FLAGS_APPEND",
        # hipcc reads HIPCC_LINK_FLAGS_APPEND (previously misspelled here as
        # HIPPCC_LINK_FLAGS_APPEND, so the flags were never forwarded).
        "HIPCC_LINK_FLAGS_APPEND",
        "HCC_AMDGPU_TARGET",
        "HIP_PLATFORM",
    ]:
        env_value = get_host_environ(repository_ctx, name)
        if env_value:
            hipcc_env = (hipcc_env + " " + name + "=\"" + env_value + "\";")
    return hipcc_env.strip()

def _crosstool_verbose(repository_ctx):
    """Returns the environment variable value CROSSTOOL_VERBOSE.

    Args:
      repository_ctx: The repository context.

    Returns:
      A string containing value of environment variable CROSSTOOL_VERBOSE,
      or "0" if unset.
    """
    return get_host_environ(repository_ctx, "CROSSTOOL_VERBOSE", "0")

def _lib_name(lib, version = "", static = False):
    """Constructs the name of a library on Linux.

    Args:
      lib: The name of the library, such as "hip"
      version: The version of the library.
      static: True the library is static or False if it is a shared object.

    Returns:
      The platform-specific name of the library.
    """
    if static:
        return "lib%s.a" % lib
    else:
        if version:
            version = ".%s" % version
        return "lib%s.so%s" % (lib, version)

def _rocm_lib_paths(repository_ctx, lib, basedir):
    """Returns the candidate paths for shared library `lib` under `basedir`, in search order."""
    file_name = _lib_name(lib, version = "", static = False)
    return [
        repository_ctx.path("%s/lib64/%s" % (basedir, file_name)),
        repository_ctx.path("%s/lib64/stubs/%s" % (basedir, file_name)),
        repository_ctx.path("%s/lib/x86_64-linux-gnu/%s" % (basedir, file_name)),
        repository_ctx.path("%s/lib/%s" % (basedir, file_name)),
        repository_ctx.path("%s/%s" % (basedir, file_name)),
    ]

def _batch_files_exist(repository_ctx, libs_paths, bash_bin):
    """Checks existence of every candidate path in a single files_exist call.

    Args:
      repository_ctx: The repository context.
      libs_paths: List of (name, [path, ...]) tuples.
      bash_bin: The path to the bash interpreter.

    Returns:
      The files_exist result for all paths, flattened in input order.
    """
    all_paths = []
    for _, lib_paths in libs_paths:
        for lib_path in lib_paths:
            all_paths.append(lib_path)
    return files_exist(repository_ctx, all_paths, bash_bin)

def _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin):
    """Selects, for each library, the first candidate path that exists.

    Args:
      repository_ctx: The repository context.
      libs_paths: List of (name, [path, ...]) tuples.
      bash_bin: The path to the bash interpreter.

    Returns:
      Dict of library name to struct(file_name, path), where `path` is the
      real path of the selected file. Fails if any library is not found.
    """
    test_results = _batch_files_exist(repository_ctx, libs_paths, bash_bin)

    libs = {}
    i = 0
    for name, lib_paths in libs_paths:
        selected_path = None
        for path in lib_paths:
            if test_results[i] and selected_path == None:
                # For each lib select the first path that exists.
                selected_path = path
            i = i + 1
        if selected_path == None:
            auto_configure_fail("Cannot find rocm library %s" % name)

        libs[name] = struct(file_name = selected_path.basename, path = realpath(repository_ctx, selected_path, bash_bin))

    return libs

def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, bash_bin):
    """Returns the ROCm libraries on the system.

    Args:
      repository_ctx: The repository context.
      rocm_config: The ROCm config as returned by _get_rocm_config
      hipfft_or_rocfft: The FFT library to look for, "hipfft" or "rocfft".
      bash_bin: the path to the bash interpreter

    Returns:
      Map of library names to structs of filename and path
    """
    libs_paths = [
        (name, _rocm_lib_paths(repository_ctx, name, path))
        for name, path in [
            ("amdhip64", rocm_config.rocm_toolkit_path + "/hip"),
            ("rocblas", rocm_config.rocm_toolkit_path + "/rocblas"),
            (hipfft_or_rocfft, rocm_config.rocm_toolkit_path + "/" + hipfft_or_rocfft),
            ("hiprand", rocm_config.rocm_toolkit_path),
            ("MIOpen", rocm_config.rocm_toolkit_path + "/miopen"),
            ("rccl", rocm_config.rocm_toolkit_path + "/rccl"),
            ("hipsparse", rocm_config.rocm_toolkit_path + "/hipsparse"),
            ("roctracer64", rocm_config.rocm_toolkit_path + "/roctracer"),
            ("rocsolver", rocm_config.rocm_toolkit_path + "/rocsolver"),
        ]
    ]
    if int(rocm_config.rocm_version_number) >= 40500:
        libs_paths.append(("hipsolver", _rocm_lib_paths(repository_ctx, "hipsolver", rocm_config.rocm_toolkit_path + "/hipsolver")))
        libs_paths.append(("hipblas", _rocm_lib_paths(repository_ctx, "hipblas", rocm_config.rocm_toolkit_path + "/hipblas")))
    return _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin)

def _exec_find_rocm_config(repository_ctx, script_path):
    """Runs the compressed find_rocm_config.py script and returns the exec result."""
    python_bin = get_python_bin(repository_ctx)

    # If used with remote execution then repository_ctx.execute() can't
    # access files from the source tree. A trick is to read the contents
    # of the file in Starlark and embed them as part of the command. In
    # this case the trick is not sufficient as the find_rocm_config.py
    # script has more than 8192 characters. 8192 is the command length
    # limit of cmd.exe on Windows. Thus we additionally need to compress
    # the contents locally and decompress them as part of the execute().
    compressed_contents = repository_ctx.read(script_path)
    decompress_and_execute_cmd = (
        "from zlib import decompress;" +
        "from base64 import b64decode;" +
        "from os import system;" +
        "script = decompress(b64decode('%s'));" % compressed_contents +
        "f = open('script.py', 'wb');" +
        "f.write(script);" +
        "f.close();" +
        "system('\"%s\" script.py');" % (python_bin)
    )

    return execute(repository_ctx, [python_bin, "-c", decompress_and_execute_cmd])

def find_rocm_config(repository_ctx, script_path):
    """Returns ROCm config dictionary from running find_rocm_config.py

    The script prints one `key: value` pair per line.

    Args:
      repository_ctx: The repository context.
      script_path: Path to the compressed, base64-encoded script.

    Returns:
      Dict mapping each key to its (string) value.
    """
    exec_result = _exec_find_rocm_config(repository_ctx, script_path)
    if exec_result.return_code:
        auto_configure_fail("Failed to run find_rocm_config.py: %s" % err_out(exec_result))

    # Parse the dict from stdout. Split only on the first separator so values
    # that themselves contain ": " are preserved, and ignore blank lines.
    config = {}
    for line in exec_result.stdout.splitlines():
        if not line.strip():
            continue
        parts = line.split(": ", 1)
        if len(parts) != 2:
            auto_configure_fail("Unexpected output from find_rocm_config.py: %s" % line)
        config[parts[0]] = parts[1]
    return config

def _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script):
    """Detects and returns information about the ROCm installation on the system.

    Args:
      repository_ctx: The repository context.
      bash_bin: the path to the bash interpreter
      find_rocm_config_script: Path to the compressed find_rocm_config.py.

    Returns:
      A struct containing the following fields:
        rocm_toolkit_path: The ROCm toolkit installation directory.
        amdgpu_targets: A list of the system's AMDGPU targets.
        rocm_version_number: The version of ROCm on the system.
        miopen_version_number: The version of MIOpen on the system.
        hipruntime_version_number: The version of HIP Runtime on the system.
    """
    config = find_rocm_config(repository_ctx, find_rocm_config_script)
    rocm_toolkit_path = config["rocm_toolkit_path"]
    rocm_version_number = config["rocm_version_number"]
    miopen_version_number = config["miopen_version_number"]
    hipruntime_version_number = config["hipruntime_version_number"]
    return struct(
        amdgpu_targets = _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin),
        rocm_toolkit_path = rocm_toolkit_path,
        rocm_version_number = rocm_version_number,
        miopen_version_number = miopen_version_number,
        hipruntime_version_number = hipruntime_version_number,
    )

def _tpl_path(repository_ctx, labelname):
    """Returns the path of template `//third_party/gpus:<labelname>.tpl`."""
    return repository_ctx.path(Label("//third_party/gpus/%s.tpl" % labelname))

def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    """Expands template `tpl` into `out` (default: `tpl` with ':' replaced by '/')."""
    if not out:
        out = tpl.replace(":", "/")
    repository_ctx.template(
        out,
        _tpl_path(repository_ctx, tpl),
        substitutions,
    )

_DUMMY_CROSSTOOL_BZL_FILE = """
def error_gpu_disabled():
  fail("ERROR: Building with --config=rocm but TensorFlow is not configured " +
       "to build with GPU support. Please re-run ./configure and enter 'Y' " +
       "at the prompt to build with GPU support.")

  native.genrule(
      name = "error_gen_crosstool",
      outs = ["CROSSTOOL"],
      cmd = "echo 'Should not be run.' && exit 1",
  )

  native.filegroup(
      name = "crosstool",
      srcs = [":CROSSTOOL"],
      output_licenses = ["unencumbered"],
  )
"""

_DUMMY_CROSSTOOL_BUILD_FILE = """
load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled")

error_gpu_disabled()
"""

def _create_dummy_repository(repository_ctx):
    """Populates the repository with placeholders used when ROCm is disabled."""

    # Set up BUILD file for rocm/.
    _tpl(
        repository_ctx,
        "rocm:build_defs.bzl",
        {
            "%{rocm_is_configured}": "False",
            "%{rocm_extra_copts}": "[]",
            "%{rocm_gpu_architectures}": "[]",
            "%{rocm_version_number}": "0",
        },
    )
    _tpl(
        repository_ctx,
        "rocm:BUILD",
        {
            "%{hip_lib}": _lib_name("hip"),
            "%{rocblas_lib}": _lib_name("rocblas"),
            "%{hipblas_lib}": _lib_name("hipblas"),
            "%{miopen_lib}": _lib_name("miopen"),
            "%{rccl_lib}": _lib_name("rccl"),
            "%{hipfft_or_rocfft}": "hipfft",
            "%{hipfft_or_rocfft_lib}": _lib_name("hipfft"),
            "%{hiprand_lib}": _lib_name("hiprand"),
            "%{hipsparse_lib}": _lib_name("hipsparse"),
            "%{roctracer_lib}": _lib_name("roctracer64"),
            "%{rocsolver_lib}": _lib_name("rocsolver"),
            "%{hipsolver_lib}": _lib_name("hipsolver"),
            "%{copy_rules}": "",
            "%{rocm_headers}": "",
        },
    )

    # Create dummy files for the ROCm toolkit since they are still required by
    # tensorflow/tsl/platform/default/build_config:rocm.
    repository_ctx.file("rocm/hip/include/hip/hip_runtime.h", "")

    # Set up rocm_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    _tpl(
        repository_ctx,
        "rocm:rocm_config.h",
        {
            "%{rocm_toolkit_path}": _DEFAULT_ROCM_TOOLKIT_PATH,
        },
        "rocm/rocm/rocm_config.h",
    )

    # If rocm_configure is not configured to build with GPU support, and the user
    # attempts to build with --config=rocm, add a dummy build rule to intercept
    # this and fail with an actionable error message.
    repository_ctx.file(
        "crosstool/error_gpu_disabled.bzl",
        _DUMMY_CROSSTOOL_BZL_FILE,
    )
    repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)

def _norm_path(path):
    """Returns a path with '/' and remove the trailing slash."""
    path = path.replace("\\", "/")

    # endswith() instead of path[-1] so an empty path does not raise.
    if path.endswith("/"):
        path = path[:-1]
    return path

def _genrule(src_dir, genrule_name, command, outs):
    """Returns a string with a genrule.

    Genrule executes the given command and produces the given outputs.
    `src_dir` is accepted for interface compatibility but is not used.
    """
    return (
        "genrule(\n" +
        '    name = "' +
        genrule_name + '",\n' +
        "    outs = [\n" +
        outs +
        "\n    ],\n" +
        '    cmd = """\n' +
        command +
        '\n   """,\n' +
        ")\n"
    )

def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets):
    """Returns the string form of a list of `--amdgpu-target=` flags."""
    amdgpu_target_flags = ["--amdgpu-target=" +
                           amdgpu_target for amdgpu_target in amdgpu_targets]
    return str(amdgpu_target_flags)

def _create_local_rocm_repository(repository_ctx):
    """Creates the repository containing files set up to build with ROCm."""

    tpl_paths = {labelname: _tpl_path(repository_ctx, labelname) for labelname in [
        "rocm:build_defs.bzl",
        "rocm:BUILD",
        "crosstool:BUILD.rocm",
        "crosstool:hipcc_cc_toolchain_config.bzl",
        "crosstool:clang/bin/crosstool_wrapper_driver_rocm",
        "rocm:rocm_config.h",
    ]}

    find_rocm_config_script = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_rocm_config.py.gz.base64"))

    bash_bin = get_bash_bin(repository_ctx)
    rocm_config = _get_rocm_config(repository_ctx, bash_bin, find_rocm_config_script)

    # For ROCm 4.1 and above use hipfft, older ROCm versions use rocfft
    rocm_version_number = int(rocm_config.rocm_version_number)
    hipfft_or_rocfft = "rocfft" if rocm_version_number < 40100 else "hipfft"

    # Copy header and library files to execroot.
    # rocm_toolkit_path
    rocm_toolkit_path = rocm_config.rocm_toolkit_path
    copy_rules = [
        make_copy_dir_rule(
            repository_ctx,
            name = "rocm-include",
            src_dir = rocm_toolkit_path + "/include",
            out_dir = "rocm/include",
            exceptions = ["gtest", "gmock"],
        ),
        make_copy_dir_rule(
            repository_ctx,
            name = hipfft_or_rocfft + "-include",
            src_dir = rocm_toolkit_path + "/" + hipfft_or_rocfft + "/include",
            out_dir = "rocm/include/" + hipfft_or_rocfft,
        ),
        make_copy_dir_rule(
            repository_ctx,
            name = "rocblas-include",
            src_dir = rocm_toolkit_path + "/rocblas/include",
            out_dir = "rocm/include/rocblas",
        ),
        make_copy_dir_rule(
            repository_ctx,
            name = "rocblas-hsaco",
            src_dir = rocm_toolkit_path + "/rocblas/lib/library",
            out_dir = "rocm/lib/rocblas/lib/library",
        ),
        make_copy_dir_rule(
            repository_ctx,
            name = "miopen-include",
            src_dir = rocm_toolkit_path + "/miopen/include",
            out_dir = "rocm/include/miopen",
        ),
        make_copy_dir_rule(
            repository_ctx,
            name = "rccl-include",
            src_dir = rocm_toolkit_path + "/rccl/include",
            out_dir = "rocm/include/rccl",
        ),
        make_copy_dir_rule(
            repository_ctx,
            name = "hipsparse-include",
            src_dir = rocm_toolkit_path + "/hipsparse/include",
            out_dir = "rocm/include/hipsparse",
        ),
        make_copy_dir_rule(
            repository_ctx,
            name = "rocsolver-include",
            src_dir = rocm_toolkit_path + "/rocsolver/include",
            out_dir = "rocm/include/rocsolver",
        ),
    ]

    # Add Hipsolver on ROCm4.5+
    if rocm_version_number >= 40500:
        copy_rules.append(
            make_copy_dir_rule(
                repository_ctx,
                name = "hipsolver-include",
                src_dir = rocm_toolkit_path + "/hipsolver/include",
                out_dir = "rocm/include/hipsolver",
            ),
        )
        copy_rules.append(
            make_copy_dir_rule(
                repository_ctx,
                name = "hipblas-include",
                src_dir = rocm_toolkit_path + "/hipblas/include",
                out_dir = "rocm/include/hipblas",
            ),
        )

    # explicitly copy (into the local_config_rocm repo) the $ROCM_PATH/hiprand/include and
    # $ROCM_PATH/rocrand/include dirs, only once the softlink to them in $ROCM_PATH/include
    # dir has been removed. This removal will happen in a near-future ROCm release.
    hiprand_include = ""
    hiprand_include_softlink = rocm_config.rocm_toolkit_path + "/include/hiprand"
    softlink_exists = files_exist(repository_ctx, [hiprand_include_softlink], bash_bin)
    if not softlink_exists[0]:
        hiprand_include = '":hiprand-include",\n'
        copy_rules.append(
            make_copy_dir_rule(
                repository_ctx,
                name = "hiprand-include",
                src_dir = rocm_toolkit_path + "/hiprand/include",
                out_dir = "rocm/include/hiprand",
            ),
        )

    rocrand_include = ""
    rocrand_include_softlink = rocm_config.rocm_toolkit_path + "/include/rocrand"
    softlink_exists = files_exist(repository_ctx, [rocrand_include_softlink], bash_bin)
    if not softlink_exists[0]:
        rocrand_include = '":rocrand-include",\n'
        copy_rules.append(
            make_copy_dir_rule(
                repository_ctx,
                name = "rocrand-include",
                src_dir = rocm_toolkit_path + "/rocrand/include",
                out_dir = "rocm/include/rocrand",
            ),
        )

    rocm_libs = _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, bash_bin)
    rocm_lib_srcs = []
    rocm_lib_outs = []
    for lib in rocm_libs.values():
        rocm_lib_srcs.append(lib.path)
        rocm_lib_outs.append("rocm/lib/" + lib.file_name)
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "rocm-lib",
        srcs = rocm_lib_srcs,
        outs = rocm_lib_outs,
    ))

    clang_offload_bundler_path = rocm_toolkit_path + "/llvm/bin/clang-offload-bundler"

    # copy files mentioned in third_party/gpus/rocm/BUILD
    copy_rules.append(make_copy_files_rule(
        repository_ctx,
        name = "rocm-bin",
        srcs = [
            clang_offload_bundler_path,
        ],
        outs = [
            "rocm/bin/" + "clang-offload-bundler",
        ],
    ))

    # Set up BUILD file for rocm/
    repository_ctx.template(
        "rocm/build_defs.bzl",
        tpl_paths["rocm:build_defs.bzl"],
        {
            "%{rocm_is_configured}": "True",
            "%{rocm_extra_copts}": _compute_rocm_extra_copts(
                repository_ctx,
                rocm_config.amdgpu_targets,
            ),
            "%{rocm_gpu_architectures}": str(rocm_config.amdgpu_targets),
            "%{rocm_version_number}": str(rocm_version_number),
        },
    )

    repository_dict = {
        "%{hip_lib}": rocm_libs["amdhip64"].file_name,
        "%{rocblas_lib}": rocm_libs["rocblas"].file_name,
        "%{hipfft_or_rocfft}": hipfft_or_rocfft,
        "%{hipfft_or_rocfft_lib}": rocm_libs[hipfft_or_rocfft].file_name,
        "%{hiprand_lib}": rocm_libs["hiprand"].file_name,
        "%{miopen_lib}": rocm_libs["MIOpen"].file_name,
        "%{rccl_lib}": rocm_libs["rccl"].file_name,
        "%{hipsparse_lib}": rocm_libs["hipsparse"].file_name,
        "%{roctracer_lib}": rocm_libs["roctracer64"].file_name,
        "%{rocsolver_lib}": rocm_libs["rocsolver"].file_name,
        "%{copy_rules}": "\n".join(copy_rules),
        "%{rocm_headers}": ('":rocm-include",\n' +
                            '":' + hipfft_or_rocfft + '-include",\n' +
                            '":rocblas-include",\n' +
                            '":miopen-include",\n' +
                            '":rccl-include",\n' +
                            hiprand_include +
                            rocrand_include +
                            '":hipsparse-include",\n' +
                            '":rocsolver-include"'),
    }
    if rocm_version_number >= 40500:
        repository_dict["%{hipsolver_lib}"] = rocm_libs["hipsolver"].file_name
        repository_dict["%{rocm_headers}"] += ',\n":hipsolver-include"'
        repository_dict["%{hipblas_lib}"] = rocm_libs["hipblas"].file_name
        repository_dict["%{rocm_headers}"] += ',\n":hipblas-include"'

    repository_ctx.template(
        "rocm/BUILD",
        tpl_paths["rocm:BUILD"],
        repository_dict,
    )

    # Set up crosstool/

    cc = find_cc(repository_ctx)

    host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc)

    host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX, "/usr/bin")

    rocm_defines = {}

    rocm_defines["%{host_compiler_prefix}"] = host_compiler_prefix

    rocm_defines["%{linker_bin_path}"] = rocm_config.rocm_toolkit_path + "/hcc/compiler/bin"

    # For gcc, do not canonicalize system header paths; some versions of gcc
    # pick the shortest possible path for system includes when creating the
    # .d file - given that includes that are prefixed with "../" multiple
    # time quickly grow longer than the root of the tree, this can lead to
    # bazel's header check failing.
    rocm_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\""

    rocm_defines["%{unfiltered_compile_flags}"] = to_list_of_strings([
        "-DTENSORFLOW_USE_ROCM=1",
        "-D__HIP_PLATFORM_HCC__",
        "-DEIGEN_USE_HIP",
    ])

    rocm_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"

    rocm_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings(
        host_compiler_includes + _rocm_include_path(repository_ctx, rocm_config, bash_bin),
    )

    verify_build_defines(rocm_defines)

    # Only expand template variables in the BUILD file
    repository_ctx.template(
        "crosstool/BUILD",
        tpl_paths["crosstool:BUILD.rocm"],
        rocm_defines,
    )

    # No templating of cc_toolchain_config - use attributes and templatize the
    # BUILD file.
    repository_ctx.template(
        "crosstool/cc_toolchain_config.bzl",
        tpl_paths["crosstool:hipcc_cc_toolchain_config.bzl"],
    )

    repository_ctx.template(
        "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
        tpl_paths["crosstool:clang/bin/crosstool_wrapper_driver_rocm"],
        {
            "%{cpu_compiler}": str(cc),
            "%{hipcc_path}": rocm_config.rocm_toolkit_path + "/hip/bin/hipcc",
            "%{hipcc_env}": _hipcc_env(repository_ctx),
            "%{rocr_runtime_path}": rocm_config.rocm_toolkit_path + "/lib",
            "%{rocr_runtime_library}": "hsa-runtime64",
            "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/hip/lib",
            "%{hip_runtime_library}": "amdhip64",
            "%{crosstool_verbose}": _crosstool_verbose(repository_ctx),
            "%{gcc_host_compiler_path}": str(cc),
        },
    )

    # Set up rocm_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    repository_ctx.template(
        "rocm/rocm/rocm_config.h",
        tpl_paths["rocm:rocm_config.h"],
        {
            "%{rocm_amdgpu_targets}": ",".join(
                ["\"%s\"" % c for c in rocm_config.amdgpu_targets],
            ),
            "%{rocm_toolkit_path}": rocm_config.rocm_toolkit_path,
            "%{rocm_version_number}": rocm_config.rocm_version_number,
            "%{miopen_version_number}": rocm_config.miopen_version_number,
            "%{hipruntime_version_number}": rocm_config.hipruntime_version_number,
        },
    )

def _create_remote_rocm_repository(repository_ctx, remote_config_repo):
    """Creates pointers to a remotely configured repo set up to build with ROCm.

    Every generated file is copied from the corresponding file of
    `remote_config_repo`. (A previous local expansion of rocm/build_defs.bzl
    was immediately overwritten by the copy below and has been removed.)
    """
    repository_ctx.template(
        "rocm/BUILD",
        config_repo_label(remote_config_repo, "rocm:BUILD"),
        {},
    )
    repository_ctx.template(
        "rocm/build_defs.bzl",
        config_repo_label(remote_config_repo, "rocm:build_defs.bzl"),
        {},
    )
    repository_ctx.template(
        "rocm/rocm/rocm_config.h",
        config_repo_label(remote_config_repo, "rocm:rocm/rocm_config.h"),
        {},
    )
    repository_ctx.template(
        "crosstool/BUILD",
        config_repo_label(remote_config_repo, "crosstool:BUILD"),
        {},
    )
    repository_ctx.template(
        "crosstool/cc_toolchain_config.bzl",
        config_repo_label(remote_config_repo, "crosstool:cc_toolchain_config.bzl"),
        {},
    )
    repository_ctx.template(
        "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
        config_repo_label(remote_config_repo, "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc"),
        {},
    )

def _rocm_autoconf_impl(repository_ctx):
    """Implementation of the rocm_autoconf repository rule."""
    if not _enable_rocm(repository_ctx):
        _create_dummy_repository(repository_ctx)
        return

    remote_config_repo = get_host_environ(repository_ctx, _TF_ROCM_CONFIG_REPO)
    if remote_config_repo != None:
        _create_remote_rocm_repository(repository_ctx, remote_config_repo)
    else:
        _create_local_rocm_repository(repository_ctx)

_ENVIRONS = [
    _GCC_HOST_COMPILER_PATH,
    _GCC_HOST_COMPILER_PREFIX,
    "TF_NEED_ROCM",
    _ROCM_TOOLKIT_PATH,
    _TF_ROCM_AMDGPU_TARGETS,
]

remote_rocm_configure = repository_rule(
    implementation = _create_local_rocm_repository,
    environ = _ENVIRONS,
    remotable = True,
    attrs = {
        "environ": attr.string_dict(),
    },
)

rocm_configure = repository_rule(
    implementation = _rocm_autoconf_impl,
    environ = _ENVIRONS + [_TF_ROCM_CONFIG_REPO],
)
"""Detects and configures the local ROCm toolchain.

Add the following to your WORKSPACE FILE:

```python
rocm_configure(name = "local_config_rocm")
```

Args:
  name: A unique name for this workspace rule.
"""