xref: /aosp_15_r20/external/pytorch/test/distributed/launcher/bin/test_script_local_rank.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Owner(s): ["oncall: r2p"]
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9
10import argparse
11import os
12
13
def parse_args(argv=None):
    """Parse the command line of this launcher test script.

    The rank option is accepted under both the dashed (``--local-rank``) and
    the underscored (``--local_rank``) spelling. argparse derives the
    destination from the first long option, so either spelling is stored in
    ``args.local_rank``.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default, matching the previous behavior) argparse reads
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with an integer ``local_rank`` attribute.

    Raises:
        SystemExit: if ``--local-rank`` is missing or not an integer
            (argparse prints a usage error and exits).
    """
    parser = argparse.ArgumentParser(description="test script")

    parser.add_argument(
        "--local-rank",
        "--local_rank",
        type=int,
        required=True,
        # This is the process's rank on its own node (compared against the
        # LOCAL_RANK env var by main()), not the rank of the node itself.
        help="The local rank of this process on its node",
    )

    return parser.parse_args(argv)
26
27
def main():
    """Verify that the ``--local-rank`` CLI value matches ``LOCAL_RANK``.

    Prints "Start execution" before the check and "End execution" only when
    the two values agree.

    Raises:
        KeyError: if ``LOCAL_RANK`` is not present in the environment.
        ValueError: if ``LOCAL_RANK`` is not an integer string.
        RuntimeError: if the CLI rank differs from the env-var rank.
    """
    print("Start execution")
    # Parse the CLI first, then read the env var (same order as before, so a
    # bad command line is reported before a missing LOCAL_RANK).
    cli_rank = parse_args().local_rank
    env_rank = int(os.environ["LOCAL_RANK"])

    if cli_rank == env_rank:
        print("End execution")
        return

    raise RuntimeError(
        "Parameters passed: --local-rank that has different value "
        f"from env var: expected: {env_rank}, got: {cli_rank}"
    )
39
40
# Run the rank check only when this file is executed as a script; importing
# the module (e.g. to reuse parse_args) has no side effects.
if __name__ == "__main__":
    main()