xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/multiprocessing/tail_log_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Owner(s): ["oncall: r2p"]
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9import io
10import os
11import shutil
12import sys
13import tempfile
14import time
15import unittest
16from concurrent.futures import wait
17from concurrent.futures._base import ALL_COMPLETED
18from concurrent.futures.thread import ThreadPoolExecutor
19from typing import Dict, Set
20from unittest import mock
21
22from torch.distributed.elastic.multiprocessing.tail_log import TailLog
23
24
def write(max: int, sleep: float, file: str):
    """Write the integers ``0 .. max-1`` to ``file``, one per line.

    Every line is flushed as soon as it is written and followed by a
    ``sleep``-second pause, so a concurrent tailer observes the file
    growing incrementally rather than all at once.
    """
    with open(file, "w") as out:
        for line_no in range(max):
            out.write(f"{line_no}\n")
            out.flush()
            time.sleep(sleep)
30
31
class TailLogTest(unittest.TestCase):
    """Tests for TailLog: tailing many writer log files into one destination."""

    def setUp(self):
        self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")
        self.threadpool = ThreadPoolExecutor()

    def tearDown(self):
        # Shut down the executor first so no writer thread can outlive the
        # test and touch self.test_dir after (or while) it is removed.
        self.threadpool.shutdown(wait=True)
        shutil.rmtree(self.test_dir)

    def test_tail(self):
        """
        writer() writes 0 - max (one number on each line) to a log file.
        Run nprocs such writers and tail the log files into an IOString
        and validate that all lines are accounted for.
        """
        nprocs = 32
        max_lines = 1000
        interval_sec = 0.0001

        log_files = {
            local_rank: os.path.join(self.test_dir, f"{local_rank}_stdout.log")
            for local_rank in range(nprocs)
        }

        dst = io.StringIO()
        tail = TailLog(
            name="writer", log_files=log_files, dst=dst, interval_sec=interval_sec
        ).start()
        # sleep here is intentional to ensure that the log tail
        # can gracefully handle and wait for non-existent log files
        time.sleep(interval_sec * 10)

        futs = []
        for local_rank, file in log_files.items():
            # stagger the per-line sleep by rank so writers finish at
            # different times and the tailer sees interleaved progress
            f = self.threadpool.submit(
                write, max=max_lines, sleep=interval_sec * local_rank, file=file
            )
            futs.append(f)

        wait(futs, return_when=ALL_COMPLETED)
        self.assertFalse(tail.stopped())
        tail.stop()

        dst.seek(0)
        # header -> set of line numbers seen for that writer; headers are
        # strings like "[writer0]", hence Dict[str, ...]
        actual: Dict[str, Set[int]] = {}

        for line in dst.readlines():
            header, num = line.split(":")
            nums = actual.setdefault(header, set())
            nums.add(int(num))

        self.assertEqual(nprocs, len(actual))
        self.assertEqual(
            {f"[writer{i}]": set(range(max_lines)) for i in range(nprocs)}, actual
        )
        self.assertTrue(tail.stopped())

    def test_tail_with_custom_prefix(self):
        """
        writer() writes 0 - max (one number on each line) to a log file.
        Run nprocs such writers and tail the log files into an IOString
        and validate that all lines are accounted for.
        """
        nprocs = 3
        max_lines = 10
        interval_sec = 0.0001

        log_files = {
            local_rank: os.path.join(self.test_dir, f"{local_rank}_stdout.log")
            for local_rank in range(nprocs)
        }

        dst = io.StringIO()
        log_line_prefixes = {n: f"[worker{n}][{n}]:" for n in range(nprocs)}
        tail = TailLog(
            "writer",
            log_files,
            dst,
            interval_sec=interval_sec,
            log_line_prefixes=log_line_prefixes,
        ).start()
        # sleep here is intentional to ensure that the log tail
        # can gracefully handle and wait for non-existent log files
        time.sleep(interval_sec * 10)
        futs = []
        for local_rank, file in log_files.items():
            f = self.threadpool.submit(
                write, max=max_lines, sleep=interval_sec * local_rank, file=file
            )
            futs.append(f)
        wait(futs, return_when=ALL_COMPLETED)
        self.assertFalse(tail.stopped())
        tail.stop()
        dst.seek(0)

        # every line should carry its worker's custom prefix as the header
        headers: Set[str] = set()
        for line in dst.readlines():
            header, _ = line.split(":")
            headers.add(header)
        self.assertEqual(nprocs, len(headers))
        for i in range(nprocs):
            self.assertIn(f"[worker{i}][{i}]", headers)
        self.assertTrue(tail.stopped())

    def test_tail_no_files(self):
        """
        Ensures that the log tail can gracefully handle no log files
        in which case it does nothing.
        """
        tail = TailLog("writer", log_files={}, dst=sys.stdout).start()
        self.assertFalse(tail.stopped())
        tail.stop()
        self.assertTrue(tail.stopped())

    def test_tail_logfile_never_generates(self):
        """
        Ensures that we properly shutdown the threadpool
        even when the logfile never generates.
        """

        tail = TailLog("writer", log_files={0: "foobar.log"}, dst=sys.stdout).start()
        tail.stop()
        self.assertTrue(tail.stopped())
        self.assertTrue(tail._threadpool._shutdown)

    @mock.patch("torch.distributed.elastic.multiprocessing.tail_log.logger")
    def test_tail_logfile_error_in_tail_fn(self, mock_logger):
        """
        Ensures that when there is an error in the tail_fn (the one that runs in the
        threadpool), it is dealt with and raised properly.
        """

        # try giving tail log a directory (should fail with an IsADirectoryError)
        tail = TailLog("writer", log_files={0: self.test_dir}, dst=sys.stdout).start()
        tail.stop()

        mock_logger.error.assert_called_once()
168