xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/utils/logging.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3
4# Copyright (c) Facebook, Inc. and its affiliates.
5# All rights reserved.
6#
7# This source code is licensed under the BSD-style license found in the
8# LICENSE file in the root directory of this source tree.
9
10import inspect
11import logging
12import os
13import warnings
14from typing import Optional
15
16from torch.distributed.elastic.utils.log_level import get_log_level
17
18
def get_logger(name: Optional[str] = None):
    """
    Return a logger whose level is taken from the ``LOGLEVEL`` environment
    variable, falling back to ``get_log_level()`` when it is not set.

    No handler is attached here, so where records end up depends on the
    application's logging configuration.

    Args:
        name: Name of the logger. When omitted (or empty), the module name of
              the code that called ``get_logger`` is derived from the call
              stack and used instead. If that derivation fails, the root
              logger is returned.
    """
    if not name:
        # Frame 0 is _derive_module_name itself and frame 1 is this function,
        # so depth=2 lands on whoever called get_logger.
        name = _derive_module_name(depth=2)
    return _setup_logger(name)
34
35
36def _setup_logger(name: Optional[str] = None):
37    logger = logging.getLogger(name)
38    logger.setLevel(os.environ.get("LOGLEVEL", get_log_level()))
39    return logger
40
41
42def _derive_module_name(depth: int = 1) -> Optional[str]:
43    """
44    Derives the name of the caller module from the stack frames.
45
46    Args:
47        depth: The position of the frame in the stack.
48    """
49    try:
50        stack = inspect.stack()
51        assert depth < len(stack)
52        # FrameInfo is just a named tuple: (frame, filename, lineno, function, code_context, index)
53        frame_info = stack[depth]
54
55        module = inspect.getmodule(frame_info[0])
56        if module:
57            module_name = module.__name__
58        else:
59            # inspect.getmodule(frame_info[0]) does NOT work (returns None) in
60            # binaries built with @mode/opt
61            # return the filename (minus the .py extension) as modulename
62            filename = frame_info[1]
63            module_name = os.path.splitext(os.path.basename(filename))[0]
64        return module_name
65    except Exception as e:
66        warnings.warn(
67            f"Error deriving logger module name, using <None>. Exception: {e}",
68            RuntimeWarning,
69        )
70        return None
71