xref: /aosp_15_r20/external/tflite-support/third_party/tensorflow/tf_configure.bzl (revision b16991f985baa50654c05c5adbb3c8bbcfb40082)
1"""Setup TensorFlow as external dependency"""
2
# Names of the environment variables read by _tf_pip_impl.
# Directory whose files are copied into the repository under "include".
_TF_HEADER_DIR = "TF_HEADER_DIR"

# Directory that contains the TensorFlow shared library.
_TF_SHARED_LIBRARY_DIR = "TF_SHARED_LIBRARY_DIR"

# File name of the shared library inside _TF_SHARED_LIBRARY_DIR.
_TF_SHARED_LIBRARY_NAME = "TF_SHARED_LIBRARY_NAME"
6
def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    """Expands the template //third_party/tensorflow:<tpl>.tpl.

    Args:
      repository_ctx: the repository_ctx object.
      tpl: base name of the template; also used as the output path when
        out is not given.
      substitutions: dict of substitutions handed to repository_ctx.template.
      out: output path; any falsy value falls back to tpl.
    """
    destination = out if out else tpl
    template_label = Label("//third_party/tensorflow:%s.tpl" % tpl)
    repository_ctx.template(destination, template_label, substitutions)
15
def _fail(msg):
    """Aborts configuration with a red-highlighted error message.

    Args:
      msg: string, description of what went wrong.
    """
    red = "\033[0;31m"
    no_color = "\033[0m"

    # The prefix names TensorFlow: this file configures the TensorFlow headers
    # and shared library, so a "Python" label would point users at the wrong
    # part of their setup.
    fail("%sTensorFlow Configuration Error:%s %s\n" % (red, no_color, msg))
21
def _is_windows(repository_ctx):
    """Tells whether the host OS name contains "windows" (case-insensitive).

    Args:
      repository_ctx: the repository_ctx object.

    Returns:
      True if repository_ctx.os.name mentions windows, False otherwise.
    """
    return "windows" in repository_ctx.os.name.lower()
28
def _execute(
        repository_ctx,
        cmdline,
        error_msg = None,
        error_details = None,
        empty_stdout_fine = False):
    """Runs a command via repository_ctx.execute and fails on error.

    The command is treated as failed when it writes anything to stderr, or
    when it produces no stdout while empty_stdout_fine is False. The return
    code is not inspected. On failure, configuration is aborted through _fail
    with a message built from error_msg, the captured stderr and
    error_details.

    Args:
      repository_ctx: the repository_ctx object.
      cmdline: list of strings, the command to execute.
      error_msg: string, a summary of the error if the command fails.
      error_details: string, details about the error or steps to fix it.
      empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise
        it's an error.

    Returns:
      The result of repository_ctx.execute(cmdline).
    """
    outcome = repository_ctx.execute(cmdline)
    stdout_missing = not empty_stdout_fine and not outcome.stdout
    if outcome.stderr or stdout_missing:
        summary = "Repository command failed"
        if error_msg:
            summary = error_msg.strip()
        details = ""
        if error_details:
            details = error_details
        _fail("\n".join([summary, outcome.stderr.strip(), details]))
    return outcome
58
def _read_dir(repository_ctx, src_dir):
    """Lists every file below a directory, one path per line.

    On Windows the listing comes from `dir /b /s /a-d` and backslashes in the
    output are turned into forward slashes; elsewhere it comes from
    `find -follow -type f`, which also follows symlinks.

    Args:
        repository_ctx: the repository_ctx object.
        src_dir: directory to find files from.

    Returns:
        A string of all files inside the given dir, separated by line breaks.
    """
    on_windows = _is_windows(repository_ctx)
    if on_windows:
        native_dir = src_dir.replace("/", "\\")
        listing_cmd = ["cmd.exe", "/c", "dir", native_dir, "/b", "/s", "/a-d"]
    else:
        listing_cmd = ["find", src_dir, "-follow", "-type", "f"]

    listing = _execute(
        repository_ctx,
        listing_cmd,
        empty_stdout_fine = True,
    ).stdout

    if on_windows:
        # The paths end up in genrule.outs, which requires forward slashes.
        return listing.replace("\\", "/")
    return listing
92
def _genrule(genrule_name, command, outs):
    """Renders the BUILD-file text of a genrule target.

    Args:
        genrule_name: A unique name for the genrule target.
        command: The command placed in the genrule's cmd attribute.
        outs: A string with the already formatted entries of the outs list
          (it is spliced between the brackets as-is).

    Returns:
        A string containing the genrule definition.
    """
    pieces = [
        "genrule(\n",
        '    name = "',
        genrule_name,
        '",\n',
        "    outs = [\n",
        outs,
        "\n    ],\n",
        '    cmd = """\n',
        command,
        '\n   """,\n',
        ")\n",
    ]
    return "".join(pieces)
118
def _norm_path(path):
    """Returns path with '/' separators and one trailing slash removed.

    Args:
        path: string, a path that may use backslashes as separators.

    Returns:
        The path with every backslash replaced by '/' and, if present, a
        single trailing '/' dropped. An empty path is returned unchanged.
    """
    path = path.replace("\\", "/")

    # endswith (rather than path[-1]) keeps the empty string from raising an
    # index error.
    if path.endswith("/"):
        path = path[:-1]
    return path
125
def _symlink_genrule_for_dir(
        repository_ctx,
        src_dir,
        dest_dir,
        genrule_name,
        src_files = [],
        dest_files = [],
        tf_pip_dir_rename_pair = []):
    """Returns a genrule that copies a set of files into the repository.

    Despite the name, the generated command always copies with `cp -f`; no
    symlinks are created.

    If src_dir is passed, files will be read from the given directory; otherwise
    we assume files are in src_files and dest_files.

    Args:
        repository_ctx: the repository_ctx object.
        src_dir: source directory, or None to use src_files/dest_files.
        dest_dir: directory (relative to the genrule output dir) to copy into.
        genrule_name: genrule name.
        src_files: list of source files instead of src_dir.
        dest_files: list of corresponding destination files; empty entries
          are skipped.
        tf_pip_dir_rename_pair: list of the pair of tf pip parent directory to
          replace. For example, in TF pip package, the source code is under
          "tensorflow_core", and we might want to replace it with
          "tensorflow" to match the header includes.

    Returns:
        genrule target (as BUILD-file text) that copies the files.
    """

    # Check that tf_pip_dir_rename_pair has the right length
    tf_pip_dir_rename_pair_len = len(tf_pip_dir_rename_pair)
    if tf_pip_dir_rename_pair_len != 0 and tf_pip_dir_rename_pair_len != 2:
        _fail("The size of argument tf_pip_dir_rename_pair should be either 0 or 2, but %d is given." % tf_pip_dir_rename_pair_len)

    if src_dir != None:
        src_dir = _norm_path(src_dir)
        dest_dir = _norm_path(dest_dir)
        files = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines()))

        # Create a list with the src_dir stripped to use for outputs.
        stripped = files.replace(src_dir, "")
        if tf_pip_dir_rename_pair_len:
            stripped = stripped.replace(tf_pip_dir_rename_pair[0], tf_pip_dir_rename_pair[1])
        dest_files = stripped.splitlines()
        src_files = files.splitlines()

    # If we have only one file to copy we do not want to use the dest_dir, as
    # $(@D) will include the full path to the file.
    single_output = len(dest_files) == 1

    command = []
    outs = []
    for i, dest_file in enumerate(dest_files):
        if dest_file == "":
            continue
        if single_output:
            dest = "$(@D)/" + dest_file
        else:
            dest = "$(@D)/" + dest_dir + dest_file

        # Copy the files to create a sandboxable setup.
        command.append('cp -f "%s" "%s"' % (src_files[i], dest))
        outs.append('        "' + dest_dir + dest_file + '",')

    return _genrule(
        genrule_name,
        " && ".join(command),
        "\n".join(outs),
    )
189
def _tf_pip_impl(repository_ctx):
    """Implementation of the tf_configure repository rule.

    Reads the header directory and the shared library location from the
    environment, builds one genrule for the headers (copied under "include",
    with "tensorflow_core" renamed to "tensorflow") and one for the shared
    library, and writes them into the BUILD file expanded from BUILD.tpl.

    Args:
        repository_ctx: the repository_ctx object.
    """
    env = repository_ctx.os.environ

    # Fail with an actionable message instead of an opaque dict key error
    # when a required variable is not set.
    required = [_TF_HEADER_DIR, _TF_SHARED_LIBRARY_DIR, _TF_SHARED_LIBRARY_NAME]
    missing = [name for name in required if name not in env]
    if missing:
        _fail("Required environment variable(s) not set: %s" % ", ".join(missing))

    tf_header_dir = env[_TF_HEADER_DIR]
    tf_header_rule = _symlink_genrule_for_dir(
        repository_ctx,
        tf_header_dir,
        "include",
        "tf_header_include",
        tf_pip_dir_rename_pair = ["tensorflow_core", "tensorflow"],
    )

    tf_shared_library_dir = env[_TF_SHARED_LIBRARY_DIR]
    tf_shared_library_name = env[_TF_SHARED_LIBRARY_NAME]
    tf_shared_library_path = "%s/%s" % (tf_shared_library_dir, tf_shared_library_name)
    tf_shared_library_rule = _symlink_genrule_for_dir(
        repository_ctx,
        None,
        "",
        "libtensorflow_framework.so",
        [tf_shared_library_path],
        ["_pywrap_tensorflow_internal.lib" if _is_windows(repository_ctx) else "libtensorflow_framework.so"],
    )

    _tpl(repository_ctx, "BUILD", {
        "%{TF_HEADER_GENRULE}": tf_header_rule,
        "%{TF_SHARED_LIBRARY_GENRULE}": tf_shared_library_rule,
    })
216
# Repository rule backed by _tf_pip_impl, which generates a BUILD file with
# genrules for the TensorFlow headers and shared library.
# NOTE(review): `environ` declares the environment variables the
# implementation reads; per Bazel's repository_rule documentation a change to
# any of them should cause the repository to be re-fetched — confirm for the
# Bazel version in use.
tf_configure = repository_rule(
    implementation = _tf_pip_impl,
    environ = [
        _TF_HEADER_DIR,
        _TF_SHARED_LIBRARY_DIR,
        _TF_SHARED_LIBRARY_NAME,
    ],
)
225