xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/multiprocessing/redirects_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Owner(s): ["oncall: r2p"]
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9import ctypes
10import os
11import shutil
12import sys
13import tempfile
14import unittest
15
16from torch.distributed.elastic.multiprocessing.redirects import (
17    redirect,
18    redirect_stderr,
19    redirect_stdout,
20)
21
22
# Handle to the C runtime library. Tests call its printf/fprintf so that output
# goes through C stdio buffers rather than Python's sys.stdout/sys.stderr.
libc = ctypes.cdll.LoadLibrary("libc.so.6")
# Value of libc's global ``stderr`` (a FILE*), passed as fprintf's first argument.
c_stderr = ctypes.c_void_p.in_dll(libc, "stderr")
25
26
class RedirectsTest(unittest.TestCase):
    """Tests for the ``redirect`` / ``redirect_stdout`` / ``redirect_stderr``
    context managers.

    The tests write through some combination of Python's ``sys`` streams, C
    stdio (via ``libc``) and child processes (via ``os.system``). They check
    that exactly the writes made inside the context manager end up in the log
    file.
    """

    def setUp(self):
        # Fresh scratch directory per test for the log files.
        self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    def _assert_log_lines(self, log_file, expected_lines):
        """Assert that ``log_file`` holds exactly ``expected_lines``, in any order.

        Python, C and child-process output are buffered independently, so the
        order in which lines reach the file is not deterministic. Unlike a set
        comparison, ``assertCountEqual`` still fails when a line is duplicated.
        """
        with open(log_file) as f:
            self.assertCountEqual(expected_lines, f.readlines())

    def test_redirect_invalid_std(self):
        """An unknown stream name is rejected with ``ValueError``."""
        with self.assertRaises(ValueError):
            with redirect("stdfoo", os.path.join(self.test_dir, "stdfoo.log")):
                pass

    def test_redirect_stdout(self):
        stdout_log = os.path.join(self.test_dir, "stdout.log")

        # printing to stdout before redirect should go to console not stdout.log
        print("foo first from python")
        libc.printf(b"foo first from c\n")
        os.system("echo foo first from cmd")

        with redirect_stdout(stdout_log):
            print("foo from python")
            libc.printf(b"foo from c\n")
            os.system("echo foo from cmd")

        # make sure stdout is restored
        print("foo again from python")
        libc.printf(b"foo again from c\n")
        os.system("echo foo again from cmd")

        self._assert_log_lines(
            stdout_log, {"foo from python\n", "foo from c\n", "foo from cmd\n"}
        )

    def test_redirect_stderr(self):
        stderr_log = os.path.join(self.test_dir, "stderr.log")

        # writing to stderr before redirect should go to console not stderr.log
        print("bar first from python", file=sys.stderr)
        libc.fprintf(c_stderr, b"bar first from c\n")
        os.system("echo bar first from cmd 1>&2")

        with redirect_stderr(stderr_log):
            print("bar from python", file=sys.stderr)
            libc.fprintf(c_stderr, b"bar from c\n")
            os.system("echo bar from cmd 1>&2")

        # make sure stderr is restored
        print("bar again from python", file=sys.stderr)
        libc.fprintf(c_stderr, b"bar again from c\n")
        os.system("echo bar again from cmd 1>&2")

        self._assert_log_lines(
            stderr_log, {"bar from python\n", "bar from c\n", "bar from cmd\n"}
        )

    def test_redirect_both(self):
        stdout_log = os.path.join(self.test_dir, "stdout.log")
        stderr_log = os.path.join(self.test_dir, "stderr.log")

        print("first stdout from python")
        libc.printf(b"first stdout from c\n")

        print("first stderr from python", file=sys.stderr)
        libc.fprintf(c_stderr, b"first stderr from c\n")

        with redirect_stdout(stdout_log), redirect_stderr(stderr_log):
            print("redir stdout from python")
            print("redir stderr from python", file=sys.stderr)
            libc.printf(b"redir stdout from c\n")
            libc.fprintf(c_stderr, b"redir stderr from c\n")

        # make sure both streams are restored, from python and from c
        print("again stdout from python")
        libc.printf(b"again stdout from c\n")
        print("again stderr from python", file=sys.stderr)
        libc.fprintf(c_stderr, b"again stderr from c\n")

        self._assert_log_lines(
            stdout_log, {"redir stdout from python\n", "redir stdout from c\n"}
        )
        self._assert_log_lines(
            stderr_log, {"redir stderr from python\n", "redir stderr from c\n"}
        )

    def _redirect_large_buffer(self, print_fn, num_lines=500_000):
        """Call ``print_fn(i)`` for each ``i`` in ``range(num_lines)`` under
        ``redirect_stdout`` and check each ``i`` reaches the log exactly once.

        Assumes ``print_fn`` writes one line of the form ``<prefix>:<i>``.
        """
        stdout_log = os.path.join(self.test_dir, "stdout.log")

        with redirect_stdout(stdout_log):
            for i in range(num_lines):
                print_fn(i)

        with open(stdout_log) as fp:
            actual = [int(line.split(":")[1]) for line in fp]
        # count-sensitive so that lost *or* duplicated lines are both detected
        self.assertCountEqual(range(num_lines), actual)

    def test_redirect_large_buffer_py(self):
        def py_print(i):
            print(f"py:{i}")

        self._redirect_large_buffer(py_print)

    def test_redirect_large_buffer_c(self):
        def c_print(i):
            libc.printf(bytes(f"c:{i}\n", "utf-8"))

        self._redirect_large_buffer(c_print)
141
142
# Entry point for running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
145