xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/utils/api.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2
3# Copyright (c) Facebook, Inc. and its affiliates.
4# All rights reserved.
5#
6# This source code is licensed under the BSD-style license found in the
7# LICENSE file in the root directory of this source tree.
8
9import os
10import socket
11from string import Template
12from typing import Any, List
13
14
def get_env_variable_or_raise(env_name: str) -> str:
    r"""
    Looks up ``env_name`` in the process environment and returns its value.

    Args:
        env_name (str): Name of the env variable

    Raises:
        ValueError: if ``env_name`` is not set in ``os.environ``.
    """
    found = os.environ.get(env_name)
    if found is not None:
        return found
    raise ValueError(f"Environment variable {env_name} expected, but not set")
28
29
def get_socket_with_port() -> socket.socket:
    """
    Returns a listening TCP socket bound to ``("localhost", 0)``.

    Every address that ``localhost`` resolves to is tried in order; the first
    one on which a socket can be created, bound and put into listening mode is
    returned. The caller owns the returned socket and must close it.

    Raises:
        RuntimeError: if none of the candidate addresses yields a usable
            socket. The last ``OSError`` encountered is attached as the cause.
    """
    addrs = socket.getaddrinfo(
        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )
    last_error = None
    for family, socktype, proto, _, _ in addrs:
        # Creating the socket can itself fail (e.g. the address family is not
        # supported on this host); treat that like a bind failure and move on
        # to the next candidate instead of aborting the whole search.
        try:
            s = socket.socket(family, socktype, proto)
        except OSError as e:
            last_error = e
            continue
        try:
            s.bind(("localhost", 0))
            s.listen(0)
        except OSError as e:
            # Do not leak the descriptor of a socket we are not returning.
            s.close()
            last_error = e
            continue
        return s
    raise RuntimeError("Failed to create a socket") from last_error
44
45
class macros:
    """
    Placeholder tokens usable in caffe2.distributed.launch cmd args, together
    with the helper that expands them.
    """

    local_rank = "${local_rank}"

    @staticmethod
    def substitute(args: List[Any], local_rank: str) -> List[str]:
        """
        Builds a new list from ``args`` with the ``local_rank`` placeholder
        expanded.

        String elements are run through ``string.Template.safe_substitute``,
        so placeholders other than ``local_rank`` are left as they are.
        Non-string elements are copied over unchanged.

        Args:
            args (List[Any]): cmd args, possibly containing ``${local_rank}``
            local_rank (str): value to substitute for the placeholder
        """
        return [
            Template(item).safe_substitute(local_rank=local_rank)
            if isinstance(item, str)
            else item
            for item in args
        ]
63