1#!/usr/bin/env python3 2# Owner(s): ["oncall: r2p"] 3 4# Copyright (c) Facebook, Inc. and its affiliates. 5# All rights reserved. 6# 7# This source code is licensed under the BSD-style license found in the 8# LICENSE file in the root directory of this source tree. 9import ctypes 10import os 11import shutil 12import sys 13import tempfile 14import unittest 15 16from torch.distributed.elastic.multiprocessing.redirects import ( 17 redirect, 18 redirect_stderr, 19 redirect_stdout, 20) 21 22 23libc = ctypes.CDLL("libc.so.6") 24c_stderr = ctypes.c_void_p.in_dll(libc, "stderr") 25 26 27class RedirectsTest(unittest.TestCase): 28 def setUp(self): 29 self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_") 30 31 def tearDown(self): 32 shutil.rmtree(self.test_dir) 33 34 def test_redirect_invalid_std(self): 35 with self.assertRaises(ValueError): 36 with redirect("stdfoo", os.path.join(self.test_dir, "stdfoo.log")): 37 pass 38 39 def test_redirect_stdout(self): 40 stdout_log = os.path.join(self.test_dir, "stdout.log") 41 42 # printing to stdout before redirect should go to console not stdout.log 43 print("foo first from python") 44 libc.printf(b"foo first from c\n") 45 os.system("echo foo first from cmd") 46 47 with redirect_stdout(stdout_log): 48 print("foo from python") 49 libc.printf(b"foo from c\n") 50 os.system("echo foo from cmd") 51 52 # make sure stdout is restored 53 print("foo again from python") 54 libc.printf(b"foo again from c\n") 55 os.system("echo foo again from cmd") 56 57 with open(stdout_log) as f: 58 # since we print from python, c, cmd -> the stream is not ordered 59 # do a set comparison 60 lines = set(f.readlines()) 61 self.assertEqual( 62 {"foo from python\n", "foo from c\n", "foo from cmd\n"}, lines 63 ) 64 65 def test_redirect_stderr(self): 66 stderr_log = os.path.join(self.test_dir, "stderr.log") 67 68 print("bar first from python") 69 libc.fprintf(c_stderr, b"bar first from c\n") 70 os.system("echo bar first from cmd 1>&2") 71 72 with redirect_stderr(stderr_log): 73 print("bar from python", file=sys.stderr) 74 libc.fprintf(c_stderr, b"bar from c\n") 75 os.system("echo bar from cmd 1>&2") 76 77 print("bar again from python") 78 libc.fprintf(c_stderr, b"bar again from c\n") 79 os.system("echo bar again from cmd 1>&2") 80 81 with open(stderr_log) as f: 82 lines = set(f.readlines()) 83 self.assertEqual( 84 {"bar from python\n", "bar from c\n", "bar from cmd\n"}, lines 85 ) 86 87 def test_redirect_both(self): 88 stdout_log = os.path.join(self.test_dir, "stdout.log") 89 stderr_log = os.path.join(self.test_dir, "stderr.log") 90 91 print("first stdout from python") 92 libc.printf(b"first stdout from c\n") 93 94 print("first stderr from python", file=sys.stderr) 95 libc.fprintf(c_stderr, b"first stderr from c\n") 96 97 with redirect_stdout(stdout_log), redirect_stderr(stderr_log): 98 print("redir stdout from python") 99 print("redir stderr from python", file=sys.stderr) 100 libc.printf(b"redir stdout from c\n") 101 libc.fprintf(c_stderr, b"redir stderr from c\n") 102 103 print("again stdout from python") 104 libc.fprintf(c_stderr, b"again stderr from c\n") 105 106 with open(stdout_log) as f: 107 lines = set(f.readlines()) 108 self.assertEqual( 109 {"redir stdout from python\n", "redir stdout from c\n"}, lines 110 ) 111 112 with open(stderr_log) as f: 113 lines = set(f.readlines()) 114 self.assertEqual( 115 {"redir stderr from python\n", "redir stderr from c\n"}, lines 116 ) 117 118 def _redirect_large_buffer(self, print_fn, num_lines=500_000): 119 stdout_log = os.path.join(self.test_dir, "stdout.log") 120 121 with redirect_stdout(stdout_log): 122 for i in range(num_lines): 123 print_fn(i) 124 125 with open(stdout_log) as fp: 126 actual = {int(line.split(":")[1]) for line in fp} 127 expected = set(range(num_lines)) 128 self.assertSetEqual(expected, actual) 129 130 def test_redirect_large_buffer_py(self): 131 def py_print(i): 132 print(f"py:{i}") 133 134 self._redirect_large_buffer(py_print) 135 136 def test_redirect_large_buffer_c(self): 137 def c_print(i): 138 libc.printf(bytes(f"c:{i}\n", "utf-8")) 139 140 self._redirect_large_buffer(c_print) 141 142 143if __name__ == "__main__": 144 unittest.main() 145