xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/events/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This module contains event-processing mechanisms that are integrated with the standard Python logging.

Example of usage:

::

  from torch.distributed.elastic import events
  event = events.Event(name="test_event", source=events.EventSource.WORKER, metadata={...})
  events.record(event, destination="console")

"""
21
import inspect
import logging
import os
import socket
import sys
import traceback
from typing import Dict, Optional

from torch.distributed.elastic.events.handlers import get_logging_handler

from .api import (  # noqa: F401
    Event,
    EventMetadataValue,
    EventSource,
    NodeState,
    RdzvEvent,
)
38
39
# Cache of the event loggers created so far, keyed by destination name.
_events_loggers: Dict[str, logging.Logger] = {}


def _get_or_create_logger(destination: str = "null") -> logging.Logger:
    """
    Return the events logger for ``destination``, creating it on first use.

    Loggers are cached per destination in ``_events_loggers``, so the level,
    propagation flag and handler are set up only the first time a destination
    is requested. Available destinations can be found in the ``handlers.py``
    file. The constructed logger does not propagate messages to the upper
    level loggers, e.g. the root logger. This makes sure that a single event
    is processed only once.

    Args:
        destination: The string representation of the event handler.
            Available handlers are found in the ``handlers`` module.

    Returns:
        The logger named ``torchelastic-events-{destination}``.
    """
    # No ``global`` statement is needed: the dictionary is only mutated,
    # the module-level name is never rebound.
    cached_logger = _events_loggers.get(destination)
    if cached_logger is not None:
        return cached_logger

    events_logger = logging.getLogger(f"torchelastic-events-{destination}")
    # The level comes from the ``LOGLEVEL`` environment variable (default ``INFO``).
    events_logger.setLevel(os.environ.get("LOGLEVEL", "INFO"))
    # Do not propagate messages to the root logger.
    events_logger.propagate = False
    events_logger.addHandler(get_logging_handler(destination))

    # Cache only after the logger is fully configured.
    _events_loggers[destination] = events_logger
    return events_logger
70
71
def record(event: Event, destination: str = "null") -> None:
    """
    Serialize ``event`` and log it at INFO level for the given destination.

    Args:
        event: The event to record; the result of ``event.serialize()`` is
            what gets logged.
        destination: Name of the event handler destination whose logger
            receives the event. Defaults to ``"null"``.
    """
    events_logger = _get_or_create_logger(destination)
    events_logger.info(event.serialize())
74
75
def record_rdzv_event(event: RdzvEvent) -> None:
    """
    Serialize a rendezvous event and log it at INFO level.

    The event always goes to the logger of the ``"dynamic_rendezvous"``
    destination.

    Args:
        event: The rendezvous event to record; the result of
            ``event.serialize()`` is what gets logged.
    """
    rdzv_logger = _get_or_create_logger("dynamic_rendezvous")
    rdzv_logger.info(event.serialize())
78
79
def construct_and_record_rdzv_event(
    run_id: str,
    message: str,
    node_state: NodeState,
    name: str = "",
    hostname: str = "",
    pid: Optional[int] = None,
    master_endpoint: str = "",
    local_id: Optional[int] = None,
    rank: Optional[int] = None,
) -> None:
    """
    Initialize rendezvous event object and record its operations.

    Nothing is computed or recorded when the ``dynamic_rendezvous`` logging
    handler is a ``logging.NullHandler``. Otherwise the event is recorded
    under the name ``"{filename}:{name}"``, where ``filename`` is the base
    name of the file that called this function.

    Args:
        run_id (str): The run id of the rendezvous.
        message (str): The message describing the event.
        node_state (NodeState): The state of the node (INIT, RUNNING, SUCCEEDED, FAILED).
        name (str): Event name. (E.g. Current action being performed).
            If empty, the name of the calling function is used.
        hostname (str): Hostname of the node. If empty, ``socket.getfqdn()`` is used.
        pid (Optional[int]): The process id of the node. If not given,
            ``os.getpid()`` is used.
        master_endpoint (str): The master endpoint for the rendezvous store, if known.
        local_id (Optional[int]):  The local_id of the node, if defined in dynamic_rendezvous.py
        rank (Optional[int]): The rank of the node, if known.
    Returns:
        None
    Example:
        >>> # See DynamicRendezvousHandler class
        >>> def _record(
        ...     self,
        ...     message: str,
        ...     node_state: NodeState = NodeState.RUNNING,
        ...     rank: Optional[int] = None,
        ... ) -> None:
        ...     construct_and_record_rdzv_event(
        ...         name=f"{self.__class__.__name__}.{get_method_name()}",
        ...         run_id=self._settings.run_id,
        ...         message=message,
        ...         node_state=node_state,
        ...         hostname=self._this_node.addr,
        ...         pid=self._this_node.pid,
        ...         local_id=self._this_node.local_id,
        ...         rank=rank,
        ...     )
    """
    # We don't want to perform an extra computation if not needed.
    if isinstance(get_logging_handler("dynamic_rendezvous"), logging.NullHandler):
        return

    # Set up parameters. Any falsy value ("" / None / 0) is replaced.
    if not hostname:
        hostname = socket.getfqdn()
    if not pid:
        pid = os.getpid()

    # Determine which file (and function) called this function. Only the
    # immediate caller's frame is inspected, without loading source context,
    # instead of materializing the whole call stack with ``inspect.stack()``.
    filename = "no_file"
    current_frame = inspect.currentframe()
    caller_frame = current_frame.f_back if current_frame is not None else None
    try:
        if caller_frame is not None:
            caller_info = inspect.getframeinfo(caller_frame, context=0)
            filename = os.path.basename(caller_info.filename)
            if not name:
                name = caller_info.function
    finally:
        # Drop the frame references. If kept, they can create reference
        # cycles that keep stack frame information alive longer than needed.
        del current_frame, caller_frame

    # Set up the error trace only if this is a failure reported while an
    # exception is being handled; without an active exception
    # ``traceback.format_exc()`` would just yield "NoneType: None".
    if node_state == NodeState.FAILED and sys.exc_info()[0] is not None:
        error_trace = traceback.format_exc()
    else:
        error_trace = ""

    # Initialize event object
    event = RdzvEvent(
        name=f"{filename}:{name}",
        run_id=run_id,
        message=message,
        hostname=hostname,
        pid=pid,
        node_state=node_state,
        master_endpoint=master_endpoint,
        rank=rank,
        local_id=local_id,
        error_trace=error_trace,
    )

    # Finally, record the event.
    record_rdzv_event(event)
171