xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/multiprocessing/errors/error_handler_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Owner(s): ["oncall: r2p"]
3
4import filecmp
5import json
6import os
7import shutil
8import tempfile
9import unittest
10from unittest.mock import patch
11
12from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler
13from torch.distributed.elastic.multiprocessing.errors.handlers import get_error_handler
14
15
def raise_exception_fn():
    """Unconditionally raise ``RuntimeError("foobar")``.

    Tests call this from inside a ``try`` block so that the exception they
    hand to the error handler was actually raised rather than merely built.
    """
    failure = RuntimeError("foobar")
    raise failure
18
19
class GetErrorHandlerTest(unittest.TestCase):
    """Checks the object returned by the ``get_error_handler()`` factory."""

    def test_get_error_handler(self):
        """``get_error_handler()`` must return an ``ErrorHandler`` instance."""
        # assertIsInstance reports the offending type on failure, unlike
        # assertTrue(isinstance(...)) which only says "False is not true".
        self.assertIsInstance(get_error_handler(), ErrorHandler)
23
24
class ErrorHandlerTest(unittest.TestCase):
    """Exercises ``ErrorHandler``: initialization, recording and dumping errors.

    Every test gets its own temporary directory. The tests point the handler
    at an error file through the ``TORCHELASTIC_ERROR_FILE`` environment
    variable, always inside a ``patch.dict`` scope so the process environment
    is restored afterwards.
    """

    def setUp(self):
        """Create a scratch directory and the default error-file path in it."""
        self.test_dir = tempfile.mkdtemp(prefix=self.__class__.__name__)
        self.test_error_file = os.path.join(self.test_dir, "error.json")

    def tearDown(self):
        """Remove the scratch directory and everything written into it."""
        shutil.rmtree(self.test_dir)

    @patch("faulthandler.enable")
    def test_initialize(self, fh_enable_mock):
        """``initialize()`` enables ``faulthandler`` exactly once."""
        ErrorHandler().initialize()
        fh_enable_mock.assert_called_once()

    @patch("faulthandler.enable", side_effect=RuntimeError)
    def test_initialize_error(self, fh_enable_mock):
        """``initialize()`` does not propagate a failure of ``faulthandler.enable``."""
        # makes sure that initialize handles errors gracefully
        ErrorHandler().initialize()
        fh_enable_mock.assert_called_once()

    def test_record_exception(self):
        """A recorded exception lands in the error file as JSON."""
        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": self.test_error_file}):
            eh = ErrorHandler()
            eh.initialize()

            try:
                raise_exception_fn()
            except Exception as e:
                eh.record_exception(e)

            with open(self.test_error_file) as fp:
                err = json.load(fp)
                # error file content example:
                # {
                #   "message": {
                #     "message": "RuntimeError: foobar",
                #     "extraInfo": {
                #       "py_callstack": "Traceback (most recent call last):\n  <... OMITTED ...>",
                #       "timestamp": "1605774851"
                #     }
                #   }
                # }
            self.assertIsNotNone(err["message"]["message"])
            self.assertIsNotNone(err["message"]["extraInfo"]["py_callstack"])
            self.assertIsNotNone(err["message"]["extraInfo"]["timestamp"])

    def test_record_exception_no_error_file(self):
        """Recording must not fail when no error file is configured."""
        # make sure record does not fail when no error file is specified in env vars
        with patch.dict(os.environ):
            # patch.dict(os.environ, {}) would leave an inherited value in
            # place; drop it explicitly (patch.dict restores it on exit).
            os.environ.pop("TORCHELASTIC_ERROR_FILE", None)
            eh = ErrorHandler()
            eh.initialize()
            try:
                raise_exception_fn()
            except Exception as e:
                eh.record_exception(e)

    def test_dump_error_file(self):
        """``dump_error_file`` copies a source error file to the configured one."""
        src_error_file = os.path.join(self.test_dir, "src_error.json")
        eh = ErrorHandler()
        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": src_error_file}):
            eh.record_exception(RuntimeError("foobar"))

        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": self.test_error_file}):
            eh.dump_error_file(src_error_file)
            # shallow=False forces a byte-by-byte comparison; the default only
            # compares os.stat() signatures when they happen to match.
            self.assertTrue(
                filecmp.cmp(src_error_file, self.test_error_file, shallow=False)
            )

        with patch.dict(os.environ):
            # Really unset the variable so the "no error file" path is the one
            # exercised, even if the surrounding environment defines it.
            os.environ.pop("TORCHELASTIC_ERROR_FILE", None)
            eh.dump_error_file(src_error_file)
            # just validate that dump_error_file works when
            # my error file is not set
            # should just log an error with src_error_file pretty printed

    def test_dump_error_file_overwrite_existing(self):
        """``dump_error_file`` replaces an already existing destination file."""
        dst_error_file = os.path.join(self.test_dir, "dst_error.json")
        src_error_file = os.path.join(self.test_dir, "src_error.json")
        eh = ErrorHandler()
        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": dst_error_file}):
            eh.record_exception(RuntimeError("foo"))

        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": src_error_file}):
            eh.record_exception(RuntimeError("bar"))

        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": dst_error_file}):
            eh.dump_error_file(src_error_file)
            # The "foo" and "bar" files are written back to back with
            # equal-length messages, so a shallow (stat-based) comparison
            # could match even if the overwrite never happened; compare
            # the actual contents.
            self.assertTrue(
                filecmp.cmp(src_error_file, dst_error_file, shallow=False)
            )
108