xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/events/lib_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Owner(s): ["oncall: r2p"]
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
9
10import json
11import logging
12from dataclasses import asdict
13from unittest.mock import patch
14
15from torch.distributed.elastic.events import (
16    _get_or_create_logger,
17    construct_and_record_rdzv_event,
18    Event,
19    EventSource,
20    NodeState,
21    RdzvEvent,
22)
23from torch.testing._internal.common_utils import run_tests, TestCase
24
25
26class EventLibTest(TestCase):
27    def assert_event(self, actual_event, expected_event):
28        self.assertEqual(actual_event.name, expected_event.name)
29        self.assertEqual(actual_event.source, expected_event.source)
30        self.assertEqual(actual_event.timestamp, expected_event.timestamp)
31        self.assertDictEqual(actual_event.metadata, expected_event.metadata)
32
33    @patch("torch.distributed.elastic.events.get_logging_handler")
34    def test_get_or_create_logger(self, logging_handler_mock):
35        logging_handler_mock.return_value = logging.NullHandler()
36        logger = _get_or_create_logger("test_destination")
37        self.assertIsNotNone(logger)
38        self.assertEqual(1, len(logger.handlers))
39        self.assertIsInstance(logger.handlers[0], logging.NullHandler)
40
41    def test_event_created(self):
42        event = Event(
43            name="test_event",
44            source=EventSource.AGENT,
45            metadata={"key1": "value1", "key2": 2},
46        )
47        self.assertEqual("test_event", event.name)
48        self.assertEqual(EventSource.AGENT, event.source)
49        self.assertDictEqual({"key1": "value1", "key2": 2}, event.metadata)
50
51    def test_event_deser(self):
52        event = Event(
53            name="test_event",
54            source=EventSource.AGENT,
55            metadata={"key1": "value1", "key2": 2, "key3": 1.0},
56        )
57        json_event = event.serialize()
58        deser_event = Event.deserialize(json_event)
59        self.assert_event(event, deser_event)
60
61
62class RdzvEventLibTest(TestCase):
63    @patch("torch.distributed.elastic.events.record_rdzv_event")
64    @patch("torch.distributed.elastic.events.get_logging_handler")
65    def test_construct_and_record_rdzv_event(self, get_mock, record_mock):
66        get_mock.return_value = logging.StreamHandler()
67        construct_and_record_rdzv_event(
68            run_id="test_run_id",
69            message="test_message",
70            node_state=NodeState.RUNNING,
71        )
72        record_mock.assert_called_once()
73
74    @patch("torch.distributed.elastic.events.record_rdzv_event")
75    @patch("torch.distributed.elastic.events.get_logging_handler")
76    def test_construct_and_record_rdzv_event_does_not_run_if_invalid_dest(
77        self, get_mock, record_mock
78    ):
79        get_mock.return_value = logging.NullHandler()
80        construct_and_record_rdzv_event(
81            run_id="test_run_id",
82            message="test_message",
83            node_state=NodeState.RUNNING,
84        )
85        record_mock.assert_not_called()
86
87    def assert_rdzv_event(self, actual_event: RdzvEvent, expected_event: RdzvEvent):
88        self.assertEqual(actual_event.name, expected_event.name)
89        self.assertEqual(actual_event.run_id, expected_event.run_id)
90        self.assertEqual(actual_event.message, expected_event.message)
91        self.assertEqual(actual_event.hostname, expected_event.hostname)
92        self.assertEqual(actual_event.pid, expected_event.pid)
93        self.assertEqual(actual_event.node_state, expected_event.node_state)
94        self.assertEqual(actual_event.master_endpoint, expected_event.master_endpoint)
95        self.assertEqual(actual_event.rank, expected_event.rank)
96        self.assertEqual(actual_event.local_id, expected_event.local_id)
97        self.assertEqual(actual_event.error_trace, expected_event.error_trace)
98
99    def get_test_rdzv_event(self) -> RdzvEvent:
100        return RdzvEvent(
101            name="test_name",
102            run_id="test_run_id",
103            message="test_message",
104            hostname="test_hostname",
105            pid=1,
106            node_state=NodeState.RUNNING,
107            master_endpoint="test_master_endpoint",
108            rank=3,
109            local_id=4,
110            error_trace="test_error_trace",
111        )
112
113    def test_rdzv_event_created(self):
114        event = self.get_test_rdzv_event()
115        self.assertEqual(event.name, "test_name")
116        self.assertEqual(event.run_id, "test_run_id")
117        self.assertEqual(event.message, "test_message")
118        self.assertEqual(event.hostname, "test_hostname")
119        self.assertEqual(event.pid, 1)
120        self.assertEqual(event.node_state, NodeState.RUNNING)
121        self.assertEqual(event.master_endpoint, "test_master_endpoint")
122        self.assertEqual(event.rank, 3)
123        self.assertEqual(event.local_id, 4)
124        self.assertEqual(event.error_trace, "test_error_trace")
125
126    def test_rdzv_event_deserialize(self):
127        event = self.get_test_rdzv_event()
128        json_event = event.serialize()
129        deserialized_event = RdzvEvent.deserialize(json_event)
130        self.assert_rdzv_event(event, deserialized_event)
131        self.assert_rdzv_event(event, RdzvEvent.deserialize(event))
132
133    def test_rdzv_event_str(self):
134        event = self.get_test_rdzv_event()
135        self.assertEqual(str(event), json.dumps(asdict(event)))
136
137
# Script entry point: when executed directly, hand control to the
# run_tests helper imported from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()
140