#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]

import filecmp
import json
import os
import shutil
import tempfile
import unittest
from unittest.mock import patch

from torch.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler
from torch.distributed.elastic.multiprocessing.errors.handlers import get_error_handler


def raise_exception_fn():
    """Raise ``RuntimeError("foobar")`` so tests can record a real, raised exception."""
    raise RuntimeError("foobar")


class GetErrorHandlerTest(unittest.TestCase):
    """Tests for the ``get_error_handler`` factory."""

    def test_get_error_handler(self):
        self.assertIsInstance(get_error_handler(), ErrorHandler)


class ErrorHandlerTest(unittest.TestCase):
    """Tests for ``ErrorHandler`` initialization, exception recording and error-file dumping.

    These tests point the handler at an error file through the
    ``TORCHELASTIC_ERROR_FILE`` environment variable (patched per test), using
    paths inside a per-test temporary directory.
    """

    def setUp(self):
        # Fresh scratch directory per test; removed again in tearDown.
        self.test_dir = tempfile.mkdtemp(prefix=self.__class__.__name__)
        self.test_error_file = os.path.join(self.test_dir, "error.json")

    def tearDown(self):
        shutil.rmtree(self.test_dir)

    @patch("faulthandler.enable")
    def test_initialize(self, fh_enable_mock):
        ErrorHandler().initialize()
        fh_enable_mock.assert_called_once()

    @patch("faulthandler.enable", side_effect=RuntimeError)
    def test_initialize_error(self, fh_enable_mock):
        # makes sure that initialize handles errors gracefully
        ErrorHandler().initialize()
        fh_enable_mock.assert_called_once()

    def test_record_exception(self):
        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": self.test_error_file}):
            eh = ErrorHandler()
            eh.initialize()

            try:
                raise_exception_fn()
            except Exception as e:
                eh.record_exception(e)

        with open(self.test_error_file) as fp:
            err = json.load(fp)
            # error file content example:
            # {
            #   "message": {
            #     "message": "RuntimeError: foobar",
            #     "extraInfo": {
            #       "py_callstack": "Traceback (most recent call last):\n  <... OMITTED ...>",
            #       "timestamp": "1605774851"
            #     }
            #   }
            # }
            self.assertIn("foobar", err["message"]["message"])
            self.assertIsNotNone(err["message"]["extraInfo"]["py_callstack"])
            self.assertIsNotNone(err["message"]["extraInfo"]["timestamp"])

    def test_record_exception_no_error_file(self):
        # make sure record does not fail when no error file is specified in env vars.
        # patch.dict(os.environ, {}) would leave an inherited TORCHELASTIC_ERROR_FILE
        # in place, so remove it explicitly; patch.dict restores os.environ on exit.
        with patch.dict(os.environ):
            os.environ.pop("TORCHELASTIC_ERROR_FILE", None)
            eh = ErrorHandler()
            eh.initialize()
            try:
                raise_exception_fn()
            except Exception as e:
                eh.record_exception(e)

    def test_dump_error_file(self):
        src_error_file = os.path.join(self.test_dir, "src_error.json")
        eh = ErrorHandler()
        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": src_error_file}):
            eh.record_exception(RuntimeError("foobar"))

        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": self.test_error_file}):
            eh.dump_error_file(src_error_file)
            # shallow=False: compare file contents, not just os.stat() signatures.
            self.assertTrue(
                filecmp.cmp(src_error_file, self.test_error_file, shallow=False)
            )

        # Unset (rather than merely not add) the env var for the last step;
        # patch.dict restores os.environ on exit.
        with patch.dict(os.environ):
            os.environ.pop("TORCHELASTIC_ERROR_FILE", None)
            eh.dump_error_file(src_error_file)
            # just validate that dump_error_file works when
            # my error file is not set
            # should just log an error with src_error_file pretty printed

    def test_dump_error_file_overwrite_existing(self):
        dst_error_file = os.path.join(self.test_dir, "dst_error.json")
        src_error_file = os.path.join(self.test_dir, "src_error.json")
        eh = ErrorHandler()
        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": dst_error_file}):
            eh.record_exception(RuntimeError("foo"))

        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": src_error_file}):
            eh.record_exception(RuntimeError("bar"))

        with patch.dict(os.environ, {"TORCHELASTIC_ERROR_FILE": dst_error_file}):
            eh.dump_error_file(src_error_file)
            # The recorded messages differ only in "foo" vs "bar", so both files can
            # have the same size; a shallow cmp() treats equal (type, size, mtime)
            # signatures as equal without reading contents, which could mask a
            # missing overwrite. Force a byte-by-byte comparison.
            self.assertTrue(filecmp.cmp(src_error_file, dst_error_file, shallow=False))