# Macros for building CUDA code.

def if_cuda(if_true, if_false = []):
    """Shorthand for select()'ing on whether we're building with CUDA.

    Returns a select statement which evaluates to if_true if we're building
    with CUDA enabled. Otherwise, the select statement evaluates to if_false.

    Args:
      if_true: Value chosen when @local_config_cuda//:is_cuda_enabled matches.
      if_false: Value chosen otherwise. Defaults to an empty list.

    Returns:
      A select() over if_true / if_false.
    """
    return select({
        "@local_config_cuda//:is_cuda_enabled": if_true,
        "//conditions:default": if_false,
    })

def if_cuda_clang(if_true, if_false = []):
    """Shorthand for select()'ing on whether we're building with cuda-clang.

    Returns a select statement which evaluates to if_true if we're building
    with cuda-clang. Otherwise, the select statement evaluates to if_false.

    Args:
      if_true: Value chosen when @local_config_cuda//cuda:using_clang matches.
      if_false: Value chosen otherwise. Defaults to an empty list.

    Returns:
      A select() over if_true / if_false.
    """
    return select({
        "@local_config_cuda//cuda:using_clang": if_true,
        "//conditions:default": if_false,
    })

def if_cuda_clang_opt(if_true, if_false = []):
    """Shorthand for select()'ing on whether we're building with cuda-clang
    in opt mode.

    Returns a select statement which evaluates to if_true if we're building
    with cuda-clang in opt mode. Otherwise, the select statement evaluates to
    if_false.

    Args:
      if_true: Value chosen when @local_config_cuda//cuda:using_clang_opt
        matches.
      if_false: Value chosen otherwise. Defaults to an empty list.

    Returns:
      A select() over if_true / if_false.
    """
    return select({
        "@local_config_cuda//cuda:using_clang_opt": if_true,
        "//conditions:default": if_false,
    })

def cuda_default_copts():
    """Default options for all CUDA compilations.

    Returns:
      A select()-valued copts list. When CUDA is enabled it contains
      "-x cuda", -DGOOGLE_CUDA=1, fatbinary compression and the extra copts
      filled in when this template is expanded; when building with
      cuda-clang in opt mode, "-O3" is added as well.
    """

    # The list added to the literal below is a template placeholder that is
    # replaced with a list of extra copts when this file is generated.
    return if_cuda([
        "-x", "cuda",
        "-DGOOGLE_CUDA=1",
        "-Xcuda-fatbinary=--compress-all",
    ] + %{cuda_extra_copts}) + if_cuda_clang_opt(
        # Some important CUDA optimizations are only enabled at O3.
        ["-O3"],
    )

def cuda_gpu_architectures():
    """Returns a list of supported GPU architectures.

    The value is a template placeholder filled in when this file is
    generated.
    """
    return %{cuda_gpu_architectures}

def if_cuda_is_configured(x, if_false = []):
    """Tests if CUDA was enabled during the configure process.

    Unlike if_cuda(), this does not require that we are building with
    --config=cuda. Used to allow non-CUDA code to depend on CUDA libraries.

    Args:
      x: Value returned (wrapped in select()) when CUDA was configured.
      if_false: Value returned (wrapped in select()) when CUDA was not
        configured. Defaults to an empty list, matching the previous
        behavior of this macro.

    Returns:
      A select() with only a default branch, so the result does not depend
      on command-line configuration.
    """

    # The condition below is a template placeholder; it is replaced with a
    # boolean literal when this file is generated.
    if %{cuda_is_configured}:
        return select({"//conditions:default": x})
    return select({"//conditions:default": if_false})

def cuda_header_library(
        name,
        hdrs,
        include_prefix = None,
        strip_include_prefix = None,
        deps = [],
        **kwargs):
    """Generates a cc_library containing both virtual and system include paths.

    Generates both a header-only target with virtual includes plus the full
    target without virtual includes. This works around the fact that bazel can't
    mix 'includes' and 'include_prefix' in the same target.

    Args:
      name: Name of the public target; a private "<name>_virtual" target is
        also created.
      hdrs: Header files. They are exposed as hdrs on the virtual target and
        as textual_hdrs on the public target.
      include_prefix: Forwarded to the virtual target's include_prefix.
      strip_include_prefix: Forwarded to the virtual target's
        strip_include_prefix.
      deps: Dependencies of both targets.
      **kwargs: Extra cc_library attributes, applied only to the public target
        (e.g. includes, visibility).
    """

    native.cc_library(
        name = name + "_virtual",
        hdrs = hdrs,
        include_prefix = include_prefix,
        strip_include_prefix = strip_include_prefix,
        deps = deps,
        visibility = ["//visibility:private"],
    )

    native.cc_library(
        name = name,
        textual_hdrs = hdrs,
        deps = deps + [":%s_virtual" % name],
        **kwargs
    )

def cuda_library(copts = [], **kwargs):
    """Wrapper over cc_library which adds default CUDA options.

    Args:
      copts: Extra copts, appended after cuda_default_copts().
      **kwargs: Forwarded unchanged to native.cc_library.
    """
    native.cc_library(copts = cuda_default_copts() + copts, **kwargs)

EnableCudaInfo = provider(
    doc = "Carries the resolved value of the enable_cuda build setting.",
    fields = {
        "value": "bool: whether CUDA was enabled by the flag or its override.",
    },
)

def _enable_cuda_flag_impl(ctx):
    """Resolves the enable_cuda flag value, honoring the legacy override.

    Args:
      ctx: Rule context; reads ctx.build_setting_value and
        ctx.attr.enable_override.

    Returns:
      EnableCudaInfo whose value is True if enable_override is set, otherwise
      the flag's own value.
    """
    value = ctx.build_setting_value
    if ctx.attr.enable_override:
        # Legacy path: force-enable and warn users to move to the new flag.
        print(
            "\n\033[1;33mWarning:\033[0m '--define=using_cuda_nvcc' will be " +
            "unsupported soon. Use '--@local_config_cuda//:enable_cuda' " +
            "instead."
        )
        value = True
    return EnableCudaInfo(value = value)

enable_cuda_flag = rule(
    implementation = _enable_cuda_flag_impl,
    build_setting = config.bool(flag = True),
    attrs = {
        "enable_override": attr.bool(
            doc = "If True, force the flag to True and print a deprecation " +
                  "warning about --define=using_cuda_nvcc (presumably wired " +
                  "to that define in the BUILD file; confirm there).",
        ),
    },
    doc = "Boolean command-line build setting that enables CUDA builds.",
)