1#!/usr/bin/env python3 2# Owner(s): ["oncall: r2p"] 3 4# Copyright (c) Facebook, Inc. and its affiliates. 5# All rights reserved. 6# 7# This source code is licensed under the BSD-style license found in the 8# LICENSE file in the root directory of this source tree.abs 9 10import json 11import logging 12from dataclasses import asdict 13from unittest.mock import patch 14 15from torch.distributed.elastic.events import ( 16 _get_or_create_logger, 17 construct_and_record_rdzv_event, 18 Event, 19 EventSource, 20 NodeState, 21 RdzvEvent, 22) 23from torch.testing._internal.common_utils import run_tests, TestCase 24 25 26class EventLibTest(TestCase): 27 def assert_event(self, actual_event, expected_event): 28 self.assertEqual(actual_event.name, expected_event.name) 29 self.assertEqual(actual_event.source, expected_event.source) 30 self.assertEqual(actual_event.timestamp, expected_event.timestamp) 31 self.assertDictEqual(actual_event.metadata, expected_event.metadata) 32 33 @patch("torch.distributed.elastic.events.get_logging_handler") 34 def test_get_or_create_logger(self, logging_handler_mock): 35 logging_handler_mock.return_value = logging.NullHandler() 36 logger = _get_or_create_logger("test_destination") 37 self.assertIsNotNone(logger) 38 self.assertEqual(1, len(logger.handlers)) 39 self.assertIsInstance(logger.handlers[0], logging.NullHandler) 40 41 def test_event_created(self): 42 event = Event( 43 name="test_event", 44 source=EventSource.AGENT, 45 metadata={"key1": "value1", "key2": 2}, 46 ) 47 self.assertEqual("test_event", event.name) 48 self.assertEqual(EventSource.AGENT, event.source) 49 self.assertDictEqual({"key1": "value1", "key2": 2}, event.metadata) 50 51 def test_event_deser(self): 52 event = Event( 53 name="test_event", 54 source=EventSource.AGENT, 55 metadata={"key1": "value1", "key2": 2, "key3": 1.0}, 56 ) 57 json_event = event.serialize() 58 deser_event = Event.deserialize(json_event) 59 self.assert_event(event, deser_event) 60 61 62class RdzvEventLibTest(TestCase): 63 @patch("torch.distributed.elastic.events.record_rdzv_event") 64 @patch("torch.distributed.elastic.events.get_logging_handler") 65 def test_construct_and_record_rdzv_event(self, get_mock, record_mock): 66 get_mock.return_value = logging.StreamHandler() 67 construct_and_record_rdzv_event( 68 run_id="test_run_id", 69 message="test_message", 70 node_state=NodeState.RUNNING, 71 ) 72 record_mock.assert_called_once() 73 74 @patch("torch.distributed.elastic.events.record_rdzv_event") 75 @patch("torch.distributed.elastic.events.get_logging_handler") 76 def test_construct_and_record_rdzv_event_does_not_run_if_invalid_dest( 77 self, get_mock, record_mock 78 ): 79 get_mock.return_value = logging.NullHandler() 80 construct_and_record_rdzv_event( 81 run_id="test_run_id", 82 message="test_message", 83 node_state=NodeState.RUNNING, 84 ) 85 record_mock.assert_not_called() 86 87 def assert_rdzv_event(self, actual_event: RdzvEvent, expected_event: RdzvEvent): 88 self.assertEqual(actual_event.name, expected_event.name) 89 self.assertEqual(actual_event.run_id, expected_event.run_id) 90 self.assertEqual(actual_event.message, expected_event.message) 91 self.assertEqual(actual_event.hostname, expected_event.hostname) 92 self.assertEqual(actual_event.pid, expected_event.pid) 93 self.assertEqual(actual_event.node_state, expected_event.node_state) 94 self.assertEqual(actual_event.master_endpoint, expected_event.master_endpoint) 95 self.assertEqual(actual_event.rank, expected_event.rank) 96 self.assertEqual(actual_event.local_id, expected_event.local_id) 97 self.assertEqual(actual_event.error_trace, expected_event.error_trace) 98 99 def get_test_rdzv_event(self) -> RdzvEvent: 100 return RdzvEvent( 101 name="test_name", 102 run_id="test_run_id", 103 message="test_message", 104 hostname="test_hostname", 105 pid=1, 106 node_state=NodeState.RUNNING, 107 master_endpoint="test_master_endpoint", 108 rank=3, 109 local_id=4, 110 error_trace="test_error_trace", 111 ) 112 113 def test_rdzv_event_created(self): 114 event = self.get_test_rdzv_event() 115 self.assertEqual(event.name, "test_name") 116 self.assertEqual(event.run_id, "test_run_id") 117 self.assertEqual(event.message, "test_message") 118 self.assertEqual(event.hostname, "test_hostname") 119 self.assertEqual(event.pid, 1) 120 self.assertEqual(event.node_state, NodeState.RUNNING) 121 self.assertEqual(event.master_endpoint, "test_master_endpoint") 122 self.assertEqual(event.rank, 3) 123 self.assertEqual(event.local_id, 4) 124 self.assertEqual(event.error_trace, "test_error_trace") 125 126 def test_rdzv_event_deserialize(self): 127 event = self.get_test_rdzv_event() 128 json_event = event.serialize() 129 deserialized_event = RdzvEvent.deserialize(json_event) 130 self.assert_rdzv_event(event, deserialized_event) 131 self.assert_rdzv_event(event, RdzvEvent.deserialize(event)) 132 133 def test_rdzv_event_str(self): 134 event = self.get_test_rdzv_event() 135 self.assertEqual(str(event), json.dumps(asdict(event))) 136 137 138if __name__ == "__main__": 139 run_tests() 140