xref: /aosp_15_r20/external/pytorch/torch/_dynamo/logging.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import itertools
3import logging
4
5from torch.hub import _Faketqdm, tqdm
6
7
# Whether the torch.compile() progress bar is suppressed; it is off by default.
# This flag lives in this module rather than in the dynamo config because
# defining it there would create a circular import.
disable_progress: bool = True
10
11
def get_loggers():
    """Return every logger that torchdynamo/torchinductor is responsible for.

    Returns:
        A list of ``logging.Logger`` objects for
        ``torch.fx.experimental.symbolic_shapes``, ``torch._dynamo`` and
        ``torch._inductor``, in that order.
    """
    owned_names = (
        "torch.fx.experimental.symbolic_shapes",
        "torch._dynamo",
        "torch._inductor",
    )
    return [logging.getLogger(owned) for owned in owned_names]
19
20
# get_step_logger creates a logging function that prefixes each message with a
# step number. Call get_step_logger lazily (at runtime, not at module-load time)
# so that step numbers are assigned in the order the steps actually run, e.g.:
#
#   @functools.lru_cache(None)
#   def _step_logger():
#       return get_step_logger(logging.getLogger(...))
#
#   def fn():
#       _step_logger()(logging.INFO, "msg")
31
# Global step counter shared by every step logger; numbering starts at 1.
_step_counter = itertools.count(1)

# Progress-bar length: the number of compile phases (Dynamo, AOT, Backend).
# This is very inductor-centric, so update it if more phases are added.
# Triton is probed with a plain import because calling
# _inductor.utils.has_triton() here would cause a circular import.
if not disable_progress:
    try:
        import triton  # noqa: F401
    except ImportError:
        num_steps = 2
    else:
        num_steps = 3
    pbar = tqdm(total=num_steps, desc="torch.compile()", delay=0)
46
47
def _make_progress_bar():
    """Create the ``torch.compile()`` progress bar.

    The bar's total is 3 when ``triton`` can be imported and 2 otherwise.
    """
    try:
        import triton  # noqa: F401
    except ImportError:
        total = 2
    else:
        total = 3
    return tqdm(total=total, desc="torch.compile()", delay=0)


def get_step_logger(logger):
    """Return a function that logs through ``logger`` with a step number prepended.

    Each call takes the next number from the module-wide ``_step_counter``.
    Call this lazily (at runtime, not at import time) so that steps are
    numbered in the order they run. When ``disable_progress`` is False, the
    progress bar is also advanced by one. Unless it is a ``_Faketqdm``, its
    postfix is set to ``logger.name``.

    Args:
        logger: The ``logging.Logger`` that messages are emitted through.

    Returns:
        A function ``log(level, msg, **kwargs)`` that calls
        ``logger.log(level, "Step %s: %s", step, msg, **kwargs)``.
    """
    global pbar
    if not disable_progress:
        # ``pbar`` is only created at import time if progress was enabled then.
        # Create it on first use so that enabling progress afterwards (by
        # setting ``disable_progress = False`` on this module) does not raise
        # NameError.
        if "pbar" not in globals():
            pbar = _make_progress_bar()
        pbar.update(1)
        if not isinstance(pbar, _Faketqdm):
            pbar.set_postfix_str(f"{logger.name}")

    # The step number is bound now, so every message from the returned
    # function carries the same number.
    step = next(_step_counter)

    def log(level, msg, **kwargs):
        logger.log(level, "Step %s: %s", step, msg, **kwargs)

    return log
60