xref: /aosp_15_r20/external/pytorch/torch/distributed/argparse_util.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9import os
10from argparse import Action
11
12
class env(Action):
    """
    Argparse action that falls back to the ``PET_{dest}`` environment variable.

    Resolution order for the argument value is: command line, then the
    ``PET_{dest}`` environment variable (``dest`` upper-cased), then the
    ``default`` given to ``add_argument``.

    For boolean flags (e.g. ``--standalone``) use ``check_env`` instead.

    .. note:: ``dest`` is whatever argparse derives for the argument; with
              several option strings that is the first long one (e.g. for
              ``"-f", "--foo"`` the env var to set is ``PET_FOO``,
              not ``PET_F``).

    Example:
    ::

     parser.add_argument("-f", "--foo", action=env, default="bar")

     ./program                             -> args.foo="bar"
     ./program -f baz                      -> args.foo="baz"
     ./program --foo baz                   -> args.foo="baz"
     PET_FOO="env_bar" ./program -f baz    -> args.foo="baz"
     PET_FOO="env_bar" ./program --foo baz -> args.foo="baz"
     PET_FOO="env_bar" ./program           -> args.foo="env_bar"

     parser.add_argument("-f", "--foo", action=env, required=True)

     ./program                             -> fails
     ./program -f baz                      -> args.foo="baz"
     PET_FOO="env_bar" ./program           -> args.foo="env_bar"
     PET_FOO="env_bar" ./program -f baz    -> args.foo="baz"
    """

    def __init__(self, dest, default=None, required=False, **kwargs) -> None:
        # The environment is consulted once, when the argument is registered;
        # a value found there replaces the caller-supplied ``default``.
        fallback = os.environ.get("PET_" + dest.upper(), default)

        # For argparse, ``required`` means "must be given on the command
        # line", not "must end up with a value". Once a truthy fallback is
        # available (from the environment or from ``default``) the option no
        # longer has to be typed out, so the requirement is lifted.
        must_be_on_cli = False if fallback else required

        super().__init__(
            dest=dest,
            default=fallback,
            required=must_be_on_cli,
            **kwargs,
        )

    def __call__(self, parser, namespace, values, option_string=None):
        # An explicit command-line value always wins over the fallback.
        setattr(namespace, self.dest, values)
59
60
class check_env(Action):
    """
    Flag action whose default comes from the ``PET_{dest}`` environment variable.

    Behaves like argparse's built-in ``store_true`` action, except that the
    flag may be left off the command line when the env var ``PET_{dest}``
    (``dest`` upper-cased) is set to a non-zero integer. If the env var is
    not set, the given ``default`` is used.

    .. note:: it is redundant to pass ``default=True`` for arguments
              that use this action because a flag should be ``True``
              when present and ``False`` otherwise.

    Example:
    ::

     parser.add_argument("--verbose", action=check_env)

     ./program                         -> args.verbose=False
     ./program --verbose               -> args.verbose=True
     PET_VERBOSE=1 ./program           -> args.verbose=True
     PET_VERBOSE=0 ./program           -> args.verbose=False
     PET_VERBOSE=0 ./program --verbose -> args.verbose=True

    Anti-pattern (don't do this):

    ::

     parser.add_argument("--verbose", action=check_env, default=True)

     ./program                         -> args.verbose=True
     ./program --verbose               -> args.verbose=True
     PET_VERBOSE=1 ./program           -> args.verbose=True
     PET_VERBOSE=0 ./program           -> args.verbose=False

    Raises:
        ValueError: if ``PET_{dest}`` is set to something that is not an
            integer literal (e.g. ``"yes"``).
    """

    def __init__(self, dest, default=False, **kwargs) -> None:
        env_name = f"PET_{dest.upper()}"
        raw = os.environ.get(env_name)

        if raw is None:
            # No env var: fall back to the truthiness of ``default``.
            flag_default = bool(default)
        else:
            try:
                # Any non-zero integer enables the flag, ``0`` disables it.
                flag_default = bool(int(raw))
            except ValueError as err:
                # Same exception type as before, but say which variable is
                # malformed instead of a bare "invalid literal for int()".
                raise ValueError(
                    f"environment variable {env_name} must be an integer "
                    f"(e.g. 0 or 1), got {raw!r}"
                ) from err

        # ``nargs=0`` makes this a value-less flag; ``const`` is what gets
        # stored when the flag appears on the command line.
        super().__init__(
            dest=dest, const=True, default=flag_default, nargs=0, **kwargs
        )

    def __call__(self, parser, namespace, values, option_string=None):
        # Flag present on the command line: always store ``True``.
        setattr(namespace, self.dest, self.const)
105