xref: /aosp_15_r20/external/tensorflow/third_party/tensorrt/tensorrt_configure.bzl (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1"""Repository rule for TensorRT configuration.
2
3`tensorrt_configure` depends on the following environment variables:
4
5  * `TF_TENSORRT_VERSION`: The TensorRT libnvinfer version.
6  * `TENSORRT_INSTALL_PATH`: The installation path of the TensorRT library.
7"""
8
9load(
10    "//third_party/gpus:cuda_configure.bzl",
11    "find_cuda_config",
12    "lib_name",
13    "make_copy_files_rule",
14)
15load(
16    "//third_party/remote_config:common.bzl",
17    "config_repo_label",
18    "get_cpu_value",
19    "get_host_environ",
20)
21
# Names of the environment variables this rule depends on (see `_ENVIRONS`).
_TF_NEED_TENSORRT = "TF_NEED_TENSORRT"
_TF_TENSORRT_CONFIG_REPO = "TF_TENSORRT_CONFIG_REPO"
_TF_TENSORRT_STATIC_PATH = "TF_TENSORRT_STATIC_PATH"
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"

# Base names of the TensorRT libraries copied into the repository.
_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin"]

# Headers copied for TensorRT versions older than 6.
_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h", "NvInferPlugin.h"]

# Headers copied for TensorRT 6 up to (not including) 8: the pre-6 headers
# followed by the version, runtime and plugin-utility headers.
_TF_TENSORRT_HEADERS_V6 = _TF_TENSORRT_HEADERS + [
    "NvInferVersion.h",
    "NvInferRuntime.h",
    "NvInferRuntimeCommon.h",
    "NvInferPluginUtils.h",
]

# Headers copied for TensorRT 8 and newer. Spelled out in full because the
# additional headers sit near the front of the list rather than at the end.
_TF_TENSORRT_HEADERS_V8 = [
    "NvInfer.h",
    "NvInferLegacyDims.h",
    "NvInferImpl.h",
    "NvUtils.h",
    "NvInferPlugin.h",
    "NvInferVersion.h",
    "NvInferRuntime.h",
    "NvInferRuntimeCommon.h",
    "NvInferPluginUtils.h",
]

# Prefixes of the `#define` lines for the TensorRT SONAME macros.
# NOTE(review): not referenced anywhere else in this file.
_DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
_DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH"
54
# BUILD text substituted for `%{oss_rules}` when TensorRT is disabled (see
# `_create_dummy_repository`): an `nvinfer_plugin_nms` cc_library with no
# sources. NOTE(review): presumably a stand-in so labels that refer to
# `nvinfer_plugin_nms` still resolve without TensorRT — confirm with callers.
_TENSORRT_OSS_DUMMY_BUILD_CONTENT = """
cc_library(
  name = "nvinfer_plugin_nms",
  visibility = ["//visibility:public"],
)
"""

# BUILD text substituted for `%{oss_rules}` in the generated plugin/BUILD when
# configuring a local TensorRT installation (see
# `_create_local_tensorrt_repository`): forwards `nvinfer_plugin_nms` to the
# target of the same name in `@tensorrt_oss_archive`.
_TENSORRT_OSS_ARCHIVE_BUILD_CONTENT = """
alias(
  name = "nvinfer_plugin_nms",
  actual = "@tensorrt_oss_archive//:nvinfer_plugin_nms",
  visibility = ["//visibility:public"],
)
"""
69
def _at_least_version(actual_version, required_version):
    """Returns whether `actual_version` is at least `required_version`.

    Both arguments are dot-separated strings of integers, e.g. "8" or
    "8.2.1". Components are compared numerically from left to right, and the
    shorter version is padded with zeros, so "8" and "8.0" compare equal.

    Args:
      actual_version: The version that was found, e.g. "8.2.1".
      required_version: The minimum acceptable version, e.g. "8".

    Returns:
      True if `actual_version` >= `required_version`, False otherwise.
    """
    actual = [int(v) for v in actual_version.split(".")]
    required = [int(v) for v in required_version.split(".")]

    # Pad to a common length: plain list comparison treats a prefix as
    # smaller, which would make [8] < [8, 0].
    length = max(len(actual), len(required))
    actual = actual + [0] * (length - len(actual))
    required = required + [0] * (length - len(required))
    return actual >= required
74
def _get_tensorrt_headers(tensorrt_version):
    """Selects the TensorRT header files to copy for a given version.

    Args:
      tensorrt_version: Dotted TensorRT version string, e.g. "8.2.1".

    Returns:
      `_TF_TENSORRT_HEADERS_V8` for version 8 and newer,
      `_TF_TENSORRT_HEADERS_V6` for version 6 and newer, and
      `_TF_TENSORRT_HEADERS` otherwise.
    """

    # Checked newest first; the first threshold that is met wins.
    thresholds = [
        ("8", _TF_TENSORRT_HEADERS_V8),
        ("6", _TF_TENSORRT_HEADERS_V6),
    ]
    for min_version, headers in thresholds:
        if _at_least_version(tensorrt_version, min_version):
            return headers
    return _TF_TENSORRT_HEADERS
81
def _tpl_path(repository_ctx, filename):
    """Returns the path of the template `//third_party/tensorrt:<filename>.tpl`.

    Args:
      repository_ctx: The repository context.
      filename: Template name without the `.tpl` suffix.
    """
    template_label = Label("//third_party/tensorrt:%s.tpl" % filename)
    return repository_ctx.path(template_label)
84
def _tpl(repository_ctx, tpl, substitutions):
    """Expands the `<tpl>.tpl` template into the repository file `tpl`.

    Args:
      repository_ctx: The repository context.
      tpl: Output file name, which is also the template name without `.tpl`.
      substitutions: Placeholder-to-replacement dict passed to
        `repository_ctx.template`.
    """
    output = tpl
    template_path = _tpl_path(repository_ctx, tpl)
    repository_ctx.template(output, template_path, substitutions)
91
def _create_dummy_repository(repository_ctx):
    """Create a dummy TensorRT repository.

    Used when TensorRT support is disabled. Expands the same templates as the
    local configuration (including plugin/BUILD), but copies no libraries or
    headers, expands `%{if_tensorrt}` to `if_false`, and leaves the TensorRT
    version empty.

    Args:
      repository_ctx: The repository context.
    """
    _tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_false"})

    # Nothing is copied, so drop the copy rules and the references to the
    # copied header and library targets.
    _tpl(repository_ctx, "BUILD", {
        "%{copy_rules}": "",
        "\":tensorrt_include\"": "",
        "\":tensorrt_lib\"": "",
        "%{oss_rules}": _TENSORRT_OSS_DUMMY_BUILD_CONTENT,
    })

    # The local configuration defines `nvinfer_plugin_nms` in plugin/BUILD,
    # expanded from the plugin.BUILD template. It never fills `%{oss_rules}`
    # in the BUILD template, so the entry above is presumably a no-op and the
    # dummy target would otherwise never be generated. Emit plugin/BUILD here
    # too so both configurations expose the same package.
    repository_ctx.template(
        "plugin/BUILD",
        _tpl_path(repository_ctx, "plugin.BUILD"),
        {"%{oss_rules}": _TENSORRT_OSS_DUMMY_BUILD_CONTENT},
    )

    _tpl(repository_ctx, "tensorrt/include/tensorrt_config.h", {
        "%{tensorrt_version}": "",
    })

    # Copy license file in non-remote build.
    repository_ctx.template(
        "LICENSE",
        Label("//third_party/tensorrt:LICENSE"),
        {},
    )

    # Set up tensorrt_config.py, which is used by gen_build_info to provide
    # build environment info to the API
    _tpl(
        repository_ctx,
        "tensorrt/tensorrt_config.py",
        _py_tmpl_dict({}),
    )
119
def enable_tensorrt(repository_ctx):
    """Returns whether to build with TensorRT support.

    Args:
      repository_ctx: The repository context.

    Returns:
      The value of `TF_NEED_TENSORRT` converted with `int()`. When the helper
      falls back to its `False` default this is 0; any non-zero value means
      TensorRT is enabled.
    """
    need_tensorrt = get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False)
    return int(need_tensorrt)
123
def _get_tensorrt_static_path(repository_ctx):
    """Returns the path for TensorRT static libraries.

    Args:
      repository_ctx: The repository context.

    Returns:
      The value of `TF_TENSORRT_STATIC_PATH` as reported by `get_host_environ`,
      or None when the helper falls back to its default.
    """
    static_path = get_host_environ(repository_ctx, _TF_TENSORRT_STATIC_PATH, None)
    return static_path
127
def _create_local_tensorrt_repository(repository_ctx):
    """Creates a TensorRT repository from a local TensorRT installation.

    Locates TensorRT with `find_cuda_config` and generates rules that copy its
    shared libraries and headers into the repository. When
    `TF_TENSORRT_STATIC_PATH` is set, its static libraries are copied too. It
    then expands the build_defs.bzl, BUILD, plugin/BUILD, tensorrt_config.h
    and tensorrt_config.py templates and copies the LICENSE file.

    Also used as the implementation of `remote_tensorrt_configure`.

    Args:
      repository_ctx: The repository context.
    """

    # Resolve all labels before doing any real work. Resolving causes the
    # function to be restarted with all previous state being lost. This
    # can easily lead to a O(n^2) runtime in the number of labels.
    # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
    find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py.gz.base64"))
    tpl_paths = {
        "build_defs.bzl": _tpl_path(repository_ctx, "build_defs.bzl"),
        "BUILD": _tpl_path(repository_ctx, "BUILD"),
        "tensorrt/include/tensorrt_config.h": _tpl_path(repository_ctx, "tensorrt/include/tensorrt_config.h"),
        "tensorrt/tensorrt_config.py": _tpl_path(repository_ctx, "tensorrt/tensorrt_config.py"),
        "plugin.BUILD": _tpl_path(repository_ctx, "plugin.BUILD"),
    }

    # LICENSE is not a template, but its label must be resolved up front as
    # well. Resolving it at the end would restart the function after
    # find_cuda_config has already run.
    license_path = repository_ctx.path(Label("//third_party/tensorrt:LICENSE"))

    config = find_cuda_config(repository_ctx, find_cuda_config_path, ["tensorrt"])
    trt_version = config["tensorrt_version"]
    cpu_value = get_cpu_value(repository_ctx)

    # Copy the library and header files.
    libraries = [lib_name(lib, cpu_value, trt_version) for lib in _TF_TENSORRT_LIBS]
    library_dir = config["tensorrt_library_dir"] + "/"
    headers = _get_tensorrt_headers(trt_version)
    include_dir = config["tensorrt_include_dir"] + "/"
    copy_rules = [
        make_copy_files_rule(
            repository_ctx,
            name = "tensorrt_lib",
            srcs = [library_dir + library for library in libraries],
            outs = ["tensorrt/lib/" + library for library in libraries],
        ),
        make_copy_files_rule(
            repository_ctx,
            name = "tensorrt_include",
            srcs = [include_dir + header for header in headers],
            outs = ["tensorrt/include/" + header for header in headers],
        ),
    ]

    # Optionally copy the static libraries as well. A non-empty value is
    # required; None and "" both skip this step.
    tensorrt_static_path = _get_tensorrt_static_path(repository_ctx)
    if tensorrt_static_path:
        tensorrt_static_path = tensorrt_static_path + "/"
        raw_static_library_names = _TF_TENSORRT_LIBS
        if not _at_least_version(trt_version, "8"):
            # For TensorRT older than 8, nvrtc and the myelin libraries are
            # copied as well.
            raw_static_library_names = raw_static_library_names + [
                "nvrtc",
                "myelin_compiler",
                "myelin_executor",
                "myelin_pattern_library",
                "myelin_pattern_runtime",
            ]
        static_library_names = ["%s_static" % name for name in raw_static_library_names]
        static_libraries = [lib_name(lib, cpu_value, trt_version, static = True) for lib in static_library_names]
        copy_rules.append(
            make_copy_files_rule(
                repository_ctx,
                name = "tensorrt_static_lib",
                srcs = [tensorrt_static_path + library for library in static_libraries],
                outs = ["tensorrt/lib/" + library for library in static_libraries],
            ),
        )

    # Set up config file.
    repository_ctx.template(
        "build_defs.bzl",
        tpl_paths["build_defs.bzl"],
        {"%{if_tensorrt}": "if_true"},
    )

    # Set up BUILD file.
    repository_ctx.template(
        "BUILD",
        tpl_paths["BUILD"],
        {
            "%{copy_rules}": "\n".join(copy_rules),
        },
    )

    # Set up the plugins folder BUILD file.
    repository_ctx.template(
        "plugin/BUILD",
        tpl_paths["plugin.BUILD"],
        {
            "%{oss_rules}": _TENSORRT_OSS_ARCHIVE_BUILD_CONTENT,
        },
    )

    # Copy license file in non-remote build.
    repository_ctx.template(
        "LICENSE",
        license_path,
        {},
    )

    # Set up tensorrt_config.h, which is used by
    # tensorflow/stream_executor/dso_loader.cc.
    repository_ctx.template(
        "tensorrt/include/tensorrt_config.h",
        tpl_paths["tensorrt/include/tensorrt_config.h"],
        {"%{tensorrt_version}": trt_version},
    )

    # Set up tensorrt_config.py, which is used by gen_build_info to provide
    # build environment info to the API
    repository_ctx.template(
        "tensorrt/tensorrt_config.py",
        tpl_paths["tensorrt/tensorrt_config.py"],
        _py_tmpl_dict({
            "tensorrt_version": trt_version,
        }),
    )
235
def _py_tmpl_dict(d):
    """Builds the substitution dict for the tensorrt_config.py template.

    Args:
      d: Build-information values to embed in the generated Python file.

    Returns:
      A one-entry dict mapping the `%{tensorrt_config}` placeholder to `str(d)`.
    """
    rendered = str(d)
    return {"%{tensorrt_config}": rendered}
238
def _tensorrt_configure_impl(repository_ctx):
    """Implementation of the tensorrt_configure repository rule.

    Picks one of three strategies:
      * `TF_TENSORRT_CONFIG_REPO` is set: copy the generated files from that
        pre-configured repository.
      * `TF_NEED_TENSORRT` is unset or zero: create a dummy repository.
      * Otherwise: configure against the local TensorRT installation.

    Args:
      repository_ctx: The repository context.
    """

    # Read the variable once, through the same helper used for the presence
    # check, so that the check and the value used below cannot disagree.
    remote_config_repo = get_host_environ(repository_ctx, _TF_TENSORRT_CONFIG_REPO)
    if remote_config_repo != None:
        # Forward to the pre-configured remote repository.
        # NOTE(review): plugin/BUILD, which the local configuration generates,
        # is not forwarded here; confirm the remote repository does not need it.
        forwarded_files = [
            "BUILD",
            "build_defs.bzl",
            "tensorrt/include/tensorrt_config.h",
            "tensorrt/tensorrt_config.py",
            "LICENSE",
        ]
        for filename in forwarded_files:
            repository_ctx.template(
                filename,
                config_repo_label(remote_config_repo, ":" + filename),
                {},
            )
        return

    if not enable_tensorrt(repository_ctx):
        _create_dummy_repository(repository_ctx)
        return

    _create_local_tensorrt_repository(repository_ctx)
273
# Environment variables that affect the generated repository. Bazel re-runs
# the repository rule when any of them changes.
_ENVIRONS = [
    _TENSORRT_INSTALL_PATH,
    _TF_TENSORRT_VERSION,
    _TF_NEED_TENSORRT,
    _TF_TENSORRT_STATIC_PATH,
    "TF_CUDA_PATHS",
]

# Always runs the local configuration, with no dummy or forwarding branch.
# NOTE(review): presumably used to generate pre-configured remote config
# repositories (it is `remotable`); the `environ` attribute presumably supplies
# variable values to `get_host_environ` — confirm.
remote_tensorrt_configure = repository_rule(
    implementation = _create_local_tensorrt_repository,
    environ = _ENVIRONS,
    remotable = True,
    attrs = {
        "environ": attr.string_dict(),
    },
)

tensorrt_configure = repository_rule(
    implementation = _tensorrt_configure_impl,
    environ = _ENVIRONS + [_TF_TENSORRT_CONFIG_REPO],
)
"""Detects and configures the local TensorRT installation.

If `TF_TENSORRT_CONFIG_REPO` is set, the generated files are copied from that
pre-configured repository instead. Otherwise, if `TF_NEED_TENSORRT` is unset
or zero, a dummy repository without TensorRT is created.

Add the following to your WORKSPACE FILE:

```python
tensorrt_configure(name = "local_config_tensorrt")
```

Args:
  name: A unique name for this workspace rule.
"""
306