1#!/usr/bin/env python3 2# Owner(s): ["oncall: r2p"] 3 4# Copyright (c) Facebook, Inc. and its affiliates. 5# All rights reserved. 6# 7# This source code is licensed under the BSD-style license found in the 8# LICENSE file in the root directory of this source tree. 9import io 10import os 11import shutil 12import sys 13import tempfile 14import time 15import unittest 16from concurrent.futures import wait 17from concurrent.futures._base import ALL_COMPLETED 18from concurrent.futures.thread import ThreadPoolExecutor 19from typing import Dict, Set 20from unittest import mock 21 22from torch.distributed.elastic.multiprocessing.tail_log import TailLog 23 24 25def write(max: int, sleep: float, file: str): 26 with open(file, "w") as fp: 27 for i in range(max): 28 print(i, file=fp, flush=True) 29 time.sleep(sleep) 30 31 32class TailLogTest(unittest.TestCase): 33 def setUp(self): 34 self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_") 35 self.threadpool = ThreadPoolExecutor() 36 37 def tearDown(self): 38 shutil.rmtree(self.test_dir) 39 40 def test_tail(self): 41 """ 42 writer() writes 0 - max (on number on each line) to a log file. 43 Run nprocs such writers and tail the log files into an IOString 44 and validate that all lines are accounted for. 45 """ 46 nprocs = 32 47 max = 1000 48 interval_sec = 0.0001 49 50 log_files = { 51 local_rank: os.path.join(self.test_dir, f"{local_rank}_stdout.log") 52 for local_rank in range(nprocs) 53 } 54 55 dst = io.StringIO() 56 tail = TailLog( 57 name="writer", log_files=log_files, dst=dst, interval_sec=interval_sec 58 ).start() 59 # sleep here is intentional to ensure that the log tail 60 # can gracefully handle and wait for non-existent log files 61 time.sleep(interval_sec * 10) 62 63 futs = [] 64 for local_rank, file in log_files.items(): 65 f = self.threadpool.submit( 66 write, max=max, sleep=interval_sec * local_rank, file=file 67 ) 68 futs.append(f) 69 70 wait(futs, return_when=ALL_COMPLETED) 71 self.assertFalse(tail.stopped()) 72 tail.stop() 73 74 dst.seek(0) 75 actual: Dict[int, Set[int]] = {} 76 77 for line in dst.readlines(): 78 header, num = line.split(":") 79 nums = actual.setdefault(header, set()) 80 nums.add(int(num)) 81 82 self.assertEqual(nprocs, len(actual)) 83 self.assertEqual( 84 {f"[writer{i}]": set(range(max)) for i in range(nprocs)}, actual 85 ) 86 self.assertTrue(tail.stopped()) 87 88 def test_tail_with_custom_prefix(self): 89 """ 90 writer() writes 0 - max (on number on each line) to a log file. 91 Run nprocs such writers and tail the log files into an IOString 92 and validate that all lines are accounted for. 93 """ 94 nprocs = 3 95 max = 10 96 interval_sec = 0.0001 97 98 log_files = { 99 local_rank: os.path.join(self.test_dir, f"{local_rank}_stdout.log") 100 for local_rank in range(nprocs) 101 } 102 103 dst = io.StringIO() 104 log_line_prefixes = {n: f"[worker{n}][{n}]:" for n in range(nprocs)} 105 tail = TailLog( 106 "writer", 107 log_files, 108 dst, 109 interval_sec=interval_sec, 110 log_line_prefixes=log_line_prefixes, 111 ).start() 112 # sleep here is intentional to ensure that the log tail 113 # can gracefully handle and wait for non-existent log files 114 time.sleep(interval_sec * 10) 115 futs = [] 116 for local_rank, file in log_files.items(): 117 f = self.threadpool.submit( 118 write, max=max, sleep=interval_sec * local_rank, file=file 119 ) 120 futs.append(f) 121 wait(futs, return_when=ALL_COMPLETED) 122 self.assertFalse(tail.stopped()) 123 tail.stop() 124 dst.seek(0) 125 126 headers: Set[str] = set() 127 for line in dst.readlines(): 128 header, _ = line.split(":") 129 headers.add(header) 130 self.assertEqual(nprocs, len(headers)) 131 for i in range(nprocs): 132 self.assertIn(f"[worker{i}][{i}]", headers) 133 self.assertTrue(tail.stopped()) 134 135 def test_tail_no_files(self): 136 """ 137 Ensures that the log tail can gracefully handle no log files 138 in which case it does nothing. 139 """ 140 tail = TailLog("writer", log_files={}, dst=sys.stdout).start() 141 self.assertFalse(tail.stopped()) 142 tail.stop() 143 self.assertTrue(tail.stopped()) 144 145 def test_tail_logfile_never_generates(self): 146 """ 147 Ensures that we properly shutdown the threadpool 148 even when the logfile never generates. 149 """ 150 151 tail = TailLog("writer", log_files={0: "foobar.log"}, dst=sys.stdout).start() 152 tail.stop() 153 self.assertTrue(tail.stopped()) 154 self.assertTrue(tail._threadpool._shutdown) 155 156 @mock.patch("torch.distributed.elastic.multiprocessing.tail_log.logger") 157 def test_tail_logfile_error_in_tail_fn(self, mock_logger): 158 """ 159 Ensures that when there is an error in the tail_fn (the one that runs in the 160 threadpool), it is dealt with and raised properly. 161 """ 162 163 # try giving tail log a directory (should fail with an IsADirectoryError 164 tail = TailLog("writer", log_files={0: self.test_dir}, dst=sys.stdout).start() 165 tail.stop() 166 167 mock_logger.error.assert_called_once() 168