"""Repository rule for TensorRT configuration.

`tensorrt_configure` depends on the following environment variables:

  * `TF_TENSORRT_VERSION`: The TensorRT libnvinfer version.
  * `TENSORRT_INSTALL_PATH`: The installation path of the TensorRT library.
"""

load(
    "//third_party/gpus:cuda_configure.bzl",
    "find_cuda_config",
    "lib_name",
    "make_copy_files_rule",
)
load(
    "//third_party/remote_config:common.bzl",
    "config_repo_label",
    "get_cpu_value",
    "get_host_environ",
)

# Names of the environment variables consulted by this repository rule.
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
_TF_TENSORRT_STATIC_PATH = "TF_TENSORRT_STATIC_PATH"
_TF_TENSORRT_CONFIG_REPO = "TF_TENSORRT_CONFIG_REPO"
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
_TF_NEED_TENSORRT = "TF_NEED_TENSORRT"

# TensorRT shared libraries that are always copied into the repository.
_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin"]

# Header sets per TensorRT major version; selected by _get_tensorrt_headers.
_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h"]
_TF_TENSORRT_HEADERS_V6 = [
    "NvInfer.h",
    "NvUtils.h",
    "NvInferPlugin.h",
    "NvInferVersion.h",
    "NvInferRuntime.h",
    "NvInferRuntimeCommon.h",
    "NvInferPluginUtils.h",
]
_TF_TENSORRT_HEADERS_V8 = [
    "NvInfer.h",
    "NvInferLegacyDims.h",
    "NvInferImpl.h",
    "NvUtils.h",
    "NvInferPlugin.h",
    "NvInferVersion.h",
    "NvInferRuntime.h",
    "NvInferRuntimeCommon.h",
    "NvInferPluginUtils.h",
]

_DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
_DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH"

# BUILD content substituted when TensorRT is disabled: an empty stand-in
# target so dependents of :nvinfer_plugin_nms still resolve.
_TENSORRT_OSS_DUMMY_BUILD_CONTENT = """
cc_library(
    name = "nvinfer_plugin_nms",
    visibility = ["//visibility:public"],
)
"""

# BUILD content substituted when TensorRT is enabled: forwards to the real
# OSS plugin target.
_TENSORRT_OSS_ARCHIVE_BUILD_CONTENT = """
alias(
    name = "nvinfer_plugin_nms",
    actual = "@tensorrt_oss_archive//:nvinfer_plugin_nms",
    visibility = ["//visibility:public"],
)
"""

def _at_least_version(actual_version, required_version):
    """Returns True if actual_version >= required_version.

    Versions are compared component-wise as integers ("10.2" >= "9").
    """
    actual = [int(v) for v in actual_version.split(".")]
    required = [int(v) for v in required_version.split(".")]
    return actual >= required

def _get_tensorrt_headers(tensorrt_version):
    """Returns the list of TensorRT headers to copy for the given version."""
    if _at_least_version(tensorrt_version, "8"):
        return _TF_TENSORRT_HEADERS_V8
    if _at_least_version(tensorrt_version, "6"):
        return _TF_TENSORRT_HEADERS_V6
    return _TF_TENSORRT_HEADERS

def _tpl_path(repository_ctx, filename):
    """Returns the path of the checked-in template for `filename`."""
    return repository_ctx.path(Label("//third_party/tensorrt:%s.tpl" % filename))

def _tpl(repository_ctx, tpl, substitutions):
    """Instantiates template `tpl` into the repository with `substitutions`."""
    repository_ctx.template(
        tpl,
        _tpl_path(repository_ctx, tpl),
        substitutions,
    )

def _create_dummy_repository(repository_ctx):
    """Create a dummy TensorRT repository."""
    _tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_false"})
    _tpl(repository_ctx, "BUILD", {
        "%{copy_rules}": "",
        "\":tensorrt_include\"": "",
        "\":tensorrt_lib\"": "",
        "%{oss_rules}": _TENSORRT_OSS_DUMMY_BUILD_CONTENT,
    })
    _tpl(repository_ctx, "tensorrt/include/tensorrt_config.h", {
        "%{tensorrt_version}": "",
    })

    # Copy license file in non-remote build.
    repository_ctx.template(
        "LICENSE",
        Label("//third_party/tensorrt:LICENSE"),
        {},
    )

    # Set up tensorrt_config.py, which is used by gen_build_info to provide
    # build environment info to the API
    _tpl(
        repository_ctx,
        "tensorrt/tensorrt_config.py",
        _py_tmpl_dict({}),
    )

def enable_tensorrt(repository_ctx):
    """Returns whether to build with TensorRT support."""
    return int(get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False))

def _get_tensorrt_static_path(repository_ctx):
    """Returns the path for TensorRT static libraries."""
    return get_host_environ(repository_ctx, _TF_TENSORRT_STATIC_PATH, None)

def _create_local_tensorrt_repository(repository_ctx):
    """Creates the repository containing files set up to build with TensorRT."""

    # Resolve all labels before doing any real work. Resolving causes the
    # function to be restarted with all previous state being lost. This
    # can easily lead to a O(n^2) runtime in the number of labels.
    # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
    find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64"))
    tpl_paths = {
        "build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"),
        "BUILD": _tpl_path(repository_ctx, "BUILD"),
        "tensorrt/include/tensorrt_config.h": _tpl_path(repository_ctx, "tensorrt/include/tensorrt_config.h"),
        "tensorrt/tensorrt_config.py": _tpl_path(repository_ctx, "tensorrt/tensorrt_config.py"),
        "plugin.BUILD": _tpl_path(repository_ctx, "plugin.BUILD"),
    }

    config = find_cuda_config(repository_ctx, find_cuda_config_path, ["tensorrt"])
    trt_version = config["tensorrt_version"]
    cpu_value = get_cpu_value(repository_ctx)

    # Copy the library and header files.
    libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]

    library_dir = config["tensorrt_library_dir"] + "/"
    headers = _get_tensorrt_headers(trt_version)
    include_dir = config["tensorrt_include_dir"] + "/"
    copy_rules = [
        make_copy_files_rule(
            repository_ctx,
            name = "tensorrt_lib",
            srcs = [library_dir + library for library in libraries],
            outs = ["tensorrt/lib/" + library for library in libraries],
        ),
        make_copy_files_rule(
            repository_ctx,
            name = "tensorrt_include",
            srcs = [include_dir + header for header in headers],
            outs = ["tensorrt/include/" + header for header in headers],
        ),
    ]

    tensorrt_static_path = _get_tensorrt_static_path(repository_ctx)
    if tensorrt_static_path:
        tensorrt_static_path = tensorrt_static_path + "/"
        if _at_least_version(trt_version, "8"):
            raw_static_library_names = _TF_TENSORRT_LIBS
        else:
            # TensorRT < 8 additionally requires these static dependencies.
            raw_static_library_names = _TF_TENSORRT_LIBS + [
                "nvrtc",
                "myelin_compiler",
                "myelin_executor",
                "myelin_pattern_library",
                "myelin_pattern_runtime",
            ]
        static_library_names = ["%s_static" % name for name in raw_static_library_names]
        static_libraries = [lib_name(lib, cpu_value, trt_version, static = True) for lib in static_library_names]

        # NOTE(review): a redundant `if tensorrt_static_path != None` guard
        # was removed here; it is always true inside this branch.
        copy_rules = copy_rules + [
            make_copy_files_rule(
                repository_ctx,
                name = "tensorrt_static_lib",
                srcs = [tensorrt_static_path + library for library in static_libraries],
                outs = ["tensorrt/lib/" + library for library in static_libraries],
            ),
        ]

    # Set up config file.
    repository_ctx.template(
        "build_defs.bzl",
        tpl_paths["build_defs.bzl"],
        {"%{if_tensorrt}": "if_true"},
    )

    # Set up BUILD file.
    repository_ctx.template(
        "BUILD",
        tpl_paths["BUILD"],
        {
            "%{copy_rules}": "\n".join(copy_rules),
        },
    )

    # Set up the plugins folder BUILD file.
    repository_ctx.template(
        "plugin/BUILD",
        tpl_paths["plugin.BUILD"],
        {
            "%{oss_rules}": _TENSORRT_OSS_ARCHIVE_BUILD_CONTENT,
        },
    )

    # Copy license file in non-remote build.
    repository_ctx.template(
        "LICENSE",
        Label("//third_party/tensorrt:LICENSE"),
        {},
    )

    # Set up tensorrt_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    repository_ctx.template(
        "tensorrt/include/tensorrt_config.h",
        tpl_paths["tensorrt/include/tensorrt_config.h"],
        {"%{tensorrt_version}": trt_version},
    )

    # Set up tensorrt_config.py, which is used by gen_build_info to provide
    # build environment info to the API
    repository_ctx.template(
        "tensorrt/tensorrt_config.py",
        tpl_paths["tensorrt/tensorrt_config.py"],
        _py_tmpl_dict({
            "tensorrt_version": trt_version,
        }),
    )

def _py_tmpl_dict(d):
    """Wraps `d` as the substitution dict for the tensorrt_config.py template."""
    return {"%{tensorrt_config}": str(d)}

def _tensorrt_configure_impl(repository_ctx):
    """Implementation of the tensorrt_configure repository rule."""

    if get_host_environ(repository_ctx, _TF_TENSORRT_CONFIG_REPO) != None:
        # Forward to the pre-configured remote repository.
        remote_config_repo = repository_ctx.os.environ[_TF_TENSORRT_CONFIG_REPO]
        repository_ctx.template("BUILD", config_repo_label(remote_config_repo, ":BUILD"), {})
        repository_ctx.template(
            "build_defs.bzl",
            config_repo_label(remote_config_repo, ":build_defs.bzl"),
            {},
        )
        repository_ctx.template(
            "tensorrt/include/tensorrt_config.h",
            config_repo_label(remote_config_repo, ":tensorrt/include/tensorrt_config.h"),
            {},
        )
        repository_ctx.template(
            "tensorrt/tensorrt_config.py",
            config_repo_label(remote_config_repo, ":tensorrt/tensorrt_config.py"),
            {},
        )
        repository_ctx.template(
            "LICENSE",
            config_repo_label(remote_config_repo, ":LICENSE"),
            {},
        )
        return

    if not enable_tensorrt(repository_ctx):
        _create_dummy_repository(repository_ctx)
        return

    _create_local_tensorrt_repository(repository_ctx)

_ENVIRONS = [
    _TENSORRT_INSTALL_PATH,
    _TF_TENSORRT_VERSION,
    _TF_NEED_TENSORRT,
    _TF_TENSORRT_STATIC_PATH,
    "TF_CUDA_PATHS",
]

remote_tensorrt_configure = repository_rule(
    implementation = _create_local_tensorrt_repository,
    environ = _ENVIRONS,
    remotable = True,
    attrs = {
        "environ": attr.string_dict(),
    },
)

tensorrt_configure = repository_rule(
    implementation = _tensorrt_configure_impl,
    environ = _ENVIRONS + [_TF_TENSORRT_CONFIG_REPO],
)
"""Detects and configures the local TensorRT toolchain.

Add the following to your WORKSPACE file:

```python
tensorrt_configure(name = "local_config_tensorrt")
```

Args:
  name: A unique name for this workspace rule.
"""