# xref: /aosp_15_r20/external/tensorflow/third_party/remote_config/common.bzl (revision b6fb3261f9314811a0f4371741dbb8839866f948)
"""Functions common across configure rules.

Helpers shared by repository (configure) rules: locating programs such as
bash and python, reading environment variables, running commands through
repository_ctx, and building config-repository labels.
"""

# Names of environment variables consulted by configure rules.
# BAZEL_SH, when set, overrides the bash interpreter returned by get_bash_bin().
BAZEL_SH = "BAZEL_SH"
# PYTHON_BIN_PATH, when set, overrides the interpreter returned by get_python_bin().
PYTHON_BIN_PATH = "PYTHON_BIN_PATH"
# Not read in this file; presumably consumed by other configure rules — TODO confirm.
PYTHON_LIB_PATH = "PYTHON_LIB_PATH"
# Not read in this file; presumably names the python config repository
# (cf. config_repo_label) — TODO confirm.
TF_PYTHON_CONFIG_REPO = "TF_PYTHON_CONFIG_REPO"

def auto_config_fail(msg):
    """Aborts configuration, reporting msg after a red "Configuration Error:" prefix.

    Args:
      msg: description of what went wrong; printed after the prefix.
    """
    color_on, color_off = "\033[0;31m", "\033[0m"
    fail("%sConfiguration Error:%s %s\n" % (color_on, color_off, msg))

def which(repository_ctx, program_name, allow_failure = False):
    """Returns the full path to a program on the execution platform.

    On Windows the lookup uses `where.exe` (".exe" is appended to the name if
    missing); elsewhere it uses `which`. Backslashes in the result are
    doubled — presumably so the path can be embedded in generated string
    literals; TODO confirm against callers.

    Args:
      repository_ctx: the repository_ctx
      program_name: name of the program on the PATH
      allow_failure: if True, a failed lookup yields an empty result instead
        of aborting (see execute()).

    Returns:
      The full path to the first match of the program on the execution
      platform, or an empty string if it was not found and allow_failure is
      True.
    """
    if is_windows(repository_ctx):
        if not program_name.endswith(".exe"):
            program_name = program_name + ".exe"
        lookup_cmd = ["C:\\Windows\\System32\\where.exe", program_name]
    else:
        lookup_cmd = ["which", program_name]

    out = execute(repository_ctx, lookup_cmd, allow_failure = allow_failure).stdout
    if out == None:
        return out

    # where.exe prints every match found on PATH, one per line; keep only the
    # first so callers always receive a single path (as `which` does).
    lines = out.strip().splitlines()
    first = lines[0].strip() if lines else ""
    return first.replace("\\", "\\\\")

def get_python_bin(repository_ctx):
    """Gets the python bin path.

    A non-empty PYTHON_BIN_PATH environment variable takes precedence.
    Otherwise the first of "python3" and "python" found on PATH is used.

    Args:
      repository_ctx: the repository_ctx

    Returns:
      The python bin path. Fails the configuration if none can be found.
    """
    python_bin = get_host_environ(repository_ctx, PYTHON_BIN_PATH)
    if python_bin:
        return python_bin

    # Prefer an explicit "python3"; some systems only call python3 "python".
    for candidate in ["python3", "python"]:
        python_bin = which(repository_ctx, candidate, allow_failure = True)
        if python_bin:
            return python_bin

    # PYTHON_BIN_PATH is read from the repository environment, which
    # --repo_env populates; --define sets no environment variables.
    auto_config_fail("Cannot find python in PATH, please make sure " +
                     "python is installed and add its directory in PATH, or --repo_env " +
                     "%s='/something/else'.\nPATH=%s" % (
                         PYTHON_BIN_PATH,
                         get_environ(repository_ctx, "PATH"),
                     ))
    return python_bin  # unreachable: auto_config_fail() aborts

def get_bash_bin(repository_ctx):
    """Gets the bash bin path.

    A non-empty BAZEL_SH environment variable takes precedence. Otherwise
    bash is looked up on PATH.

    Args:
      repository_ctx: the repository_ctx

    Returns:
      The bash bin path. Fails the configuration with a descriptive message
      if bash cannot be found.
    """
    bash_bin = get_host_environ(repository_ctx, BAZEL_SH)

    # An empty BAZEL_SH would otherwise be returned and later used as the
    # executable; treat it as unset.
    if bash_bin:
        return bash_bin

    # allow_failure lets a failed lookup reach the descriptive error below
    # instead of execute()'s generic "Repository command failed".
    bash_bin_path = which(repository_ctx, "bash", allow_failure = True)
    if not bash_bin_path:
        # BAZEL_SH is read from the repository environment, which --repo_env
        # populates; --define sets no environment variables.
        auto_config_fail("Cannot find bash in PATH, please make sure " +
                         "bash is installed and add its directory in PATH, or --repo_env " +
                         "%s='/path/to/bash'.\nPATH=%s" % (
                             BAZEL_SH,
                             get_environ(repository_ctx, "PATH"),
                         ))
    return bash_bin_path

def read_dir(repository_ctx, src_dir):
    """Lists every file below src_dir, sorted.

    Subdirectories are traversed recursively and symlinks are followed. The
    listing is best-effort: command failures are tolerated, and whatever
    output was produced is returned.

    Args:
      repository_ctx: the repository_ctx
      src_dir: the directory to traverse

    Returns:
      A sorted list of file paths found under src_dir.
    """
    windows = is_windows(repository_ctx)
    if windows:
        # `dir /b /s /a-d`: bare paths, recursive, files only.
        cmdline = [
            "C:\\Windows\\System32\\cmd.exe",
            "/c",
            "dir",
            src_dir.replace("/", "\\"),
            "/b",
            "/s",
            "/a-d",
        ]
    else:
        cmdline = ["find", src_dir, "-follow", "-type", "f"]

    listing = execute(repository_ctx, cmdline, allow_failure = True).stdout
    if windows:
        # The paths end up in genrule.outs, which require forward slashes.
        listing = listing.replace("\\", "/")
    return sorted(listing.splitlines())

def get_environ(repository_ctx, name, default_value = None):
    """Returns the value of an environment variable on the execution platform.

    Args:
      repository_ctx: the repository_ctx
      name: the name of environment variable
      default_value: the value to return if not set

    Returns:
      The value of the environment variable 'name' on the execution platform
      or 'default_value' if it's not set (or empty).
    """
    if is_windows(repository_ctx):
        placeholder = "%" + name + "%"
        result = execute(
            repository_ctx,
            ["C:\\Windows\\System32\\cmd.exe", "/c", "echo", placeholder],
            allow_failure = True,
        )

        # `echo` terminates its output with CRLF. cmd.exe leaves a reference
        # to an undefined variable unexpanded, so getting "%NAME%" back
        # verbatim means the variable is not set.
        value = result.stdout.rstrip("\r\n")
        if value == placeholder:
            value = ""
    else:
        # `echo -n` prints the value without a trailing newline.
        cmd = "echo -n \"$%s\"" % name
        result = execute(
            repository_ctx,
            [get_bash_bin(repository_ctx), "-c", cmd],
            allow_failure = True,
        )
        value = result.stdout

    if len(value) == 0:
        return default_value
    return value

def get_host_environ(repository_ctx, name, default_value = None):
    """Looks up an environment variable on the host platform.

    The host platform is the machine that Bazel runs on. The process
    environment (repository_ctx.os.environ) is consulted first, then the
    rule's `environ` attribute if the rule defines one. The value found has
    surrounding whitespace stripped.

    Args:
      repository_ctx: the repository_ctx
      name: the name of environment variable
      default_value: returned when the variable is found in neither place

    Returns:
      The stripped value of the environment variable 'name' on the host
      platform, or 'default_value' if it is not set.
    """
    sources = [repository_ctx.os.environ]
    if hasattr(repository_ctx.attr, "environ"):
        sources.append(repository_ctx.attr.environ)

    for env in sources:
        if name in env:
            return env[name].strip()
    return default_value

def is_windows(repository_ctx):
    """Tells whether the execution platform runs Windows.

    An "OSFamily" entry in the rule's `exec_properties` attribute takes
    precedence over the OS reported by repository_ctx.os.name. The match is a
    case-insensitive substring search for "windows".

    Args:
      repository_ctx: the repository_ctx

    Returns:
      True if the execution platform is Windows, False otherwise.
    """
    attrs = repository_ctx.attr
    os_name = repository_ctx.os.name
    if hasattr(attrs, "exec_properties") and "OSFamily" in attrs.exec_properties:
        os_name = attrs.exec_properties["OSFamily"]
    return "windows" in os_name.lower()

def get_cpu_value(repository_ctx):
    """Returns the name of the host operating system.

    Args:
      repository_ctx: The repository context.
    Returns:
      "Windows" when is_windows() reports a Windows platform; otherwise the
      whitespace-stripped output of `uname -s`.
    """
    if not is_windows(repository_ctx):
        # raw_exec: the output is used as-is, without execute()'s checks.
        return raw_exec(repository_ctx, ["uname", "-s"]).stdout.strip()
    return "Windows"

def execute(
        repository_ctx,
        cmdline,
        error_msg = None,
        error_details = None,
        allow_failure = False):
    """Runs a command and, unless failures are allowed, validates its output.

    A run counts as failed when it wrote anything to stderr or produced no
    stdout; the exit code is not inspected.

    Args:
      repository_ctx: the repository_ctx object
      cmdline: list of strings, the command to execute
      error_msg: string, a summary of the error if the command fails
      error_details: string, details about the error or steps to fix it
      allow_failure: bool, if True the result is returned without checking
        stdout/stderr
    Returns:
      The result of repository_ctx.execute(cmdline)
    """
    result = raw_exec(repository_ctx, cmdline)
    if allow_failure:
        return result

    if result.stderr or not result.stdout:
        summary = error_msg.strip() if error_msg else "Repository command failed"
        details = error_details if error_details else ""
        fail("\n".join([summary, result.stderr.strip(), details]))
    return result

def raw_exec(repository_ctx, cmdline):
    """Thin wrapper around repository_ctx.execute().

    Every command in this file goes through here, which makes it a single
    place to add debugging, e.g. logging each command and its return code.

    Args:
      repository_ctx: the repository_ctx
      cmdline: the list of args

    Returns:
      The exec_result produced by repository_ctx.execute(cmdline).
    """
    exec_result = repository_ctx.execute(cmdline)
    return exec_result

def files_exist(repository_ctx, paths, bash_bin = None):
    """Checks which files in paths exist.

    All checks run in a single bash invocation rather than one process per
    path.

    Args:
      repository_ctx: the repository_ctx
      paths: a list of paths
      bash_bin: path to the bash interpreter; looked up with get_bash_bin()
        when None

    Returns:
      A list of bools, one per entry of paths. True means that the path at
      the same position in the paths list exists. An empty paths list yields
      an empty list.
    """
    if not paths:
        # An empty script prints nothing, which execute() would treat as a
        # failure; there is nothing to check anyway.
        return []

    if bash_bin == None:
        bash_bin = get_bash_bin(repository_ctx)

    # Each check prints one line ("True"/"False"), so the output lines up
    # with the order of paths.
    cmd_tpl = "[ -e \"%s\" ] && echo True || echo False"
    cmd = " ; ".join([cmd_tpl % path for path in paths])

    stdout = execute(repository_ctx, [bash_bin, "-c", cmd]).stdout.strip()
    return [val == "True" for val in stdout.splitlines()]

def realpath(repository_ctx, path, bash_bin = None):
    """Resolves path using the `realpath` utility, run through bash.

    Args:
      repository_ctx: the repository_ctx
      path: a path on the file system
      bash_bin: path to the bash interpreter; looked up with get_bash_bin()
        when None

    Returns:
      The whitespace-stripped output of `realpath "path"`.
    """
    shell = bash_bin if bash_bin != None else get_bash_bin(repository_ctx)
    script = "realpath \"%s\"" % path
    return execute(repository_ctx, [shell, "-c", script]).stdout.strip()

def err_out(result):
    """Returns result.stderr when it is non-empty, otherwise result.stdout.

    Workaround for a bug in RBE where stderr is returned as stdout: call
    err_out(result) wherever result.stderr would otherwise be used.

    Args:
      result: the exec_result.

    Returns:
      The stderr if set, else stdout.
    """
    return result.stderr if len(result.stderr) > 0 else result.stdout

def config_repo_label(config_repo, target):
    """Builds a Label for target inside config_repo.

    Eases the migration from preconfig to remote config. In preconfig the
    TF_*_CONFIG_REPO environ variables name packages in the main repo; in
    remote config they name remote repositories.

    Three cases:
      * a bare repository such as "@repo" (no "//"): "@repo//" + target
      * a target starting with ":": config_repo + target
      * anything else: config_repo + "/" + target

    Args:
      config_repo: a remote repository or package.
      target: a target
    Returns:
      A label constructed from config_repo and target.
    """
    if config_repo.startswith("@") and config_repo.find("//") <= 0:
        # remote config is being used.
        return Label(config_repo + "//" + target)

    separator = "" if target.startswith(":") else "/"
    return Label(config_repo + separator + target)
