# mypy: allow-untyped-defs
import itertools
import logging

from torch.hub import _Faketqdm, tqdm


# Disable progress bar by default, not in dynamo config because otherwise get a circular import
disable_progress = True


# Return all loggers that torchdynamo/torchinductor is responsible for
def get_loggers():
    """Return the loggers that torchdynamo/torchinductor is responsible for.

    Returns:
        A list containing the ``torch.fx.experimental.symbolic_shapes``,
        ``torch._dynamo`` and ``torch._inductor`` loggers, in that order.
    """
    return [
        logging.getLogger("torch.fx.experimental.symbolic_shapes"),
        logging.getLogger("torch._dynamo"),
        logging.getLogger("torch._inductor"),
    ]


# get_step_logger creates a logging function that prepends a step number to each message.
# It should be called lazily (i.e. at runtime, not at module-load time) so that
# step numbers are assigned in the order the phases actually run, e.g.:
#
# @functools.lru_cache(None)
# def _step_logger():
#     return get_step_logger(logging.getLogger(...))
#
# def fn():
#     _step_logger()(logging.INFO, "msg")

_step_counter = itertools.count(1)

# Update num_steps if more phases are added: Dynamo, AOT, Backend
# This is very inductor centric
# _inductor.utils.has_triton() gives a circular import error here


def _count_steps():
    """Return the expected number of compile phases: 3 if triton imports, else 2."""
    try:
        import triton  # noqa: F401
    except ImportError:
        return 2
    return 3


def _get_pbar():
    """Return the shared ``torch.compile()`` progress bar, creating it on first use.

    The bar is created lazily. That way, setting ``disable_progress = False``
    after this module has been imported works, instead of failing with a
    ``NameError`` on the missing module-level ``pbar``. The module-level names
    ``pbar`` and ``num_steps`` are still published for backward compatibility.
    """
    global pbar, num_steps
    bar = globals().get("pbar")
    if bar is None:
        num_steps = _count_steps()
        bar = pbar = tqdm(total=num_steps, desc="torch.compile()", delay=0)
    return bar


# Preserve the original import-time behaviour when progress is enabled up front.
if not disable_progress:
    _get_pbar()


def get_step_logger(logger):
    """Reserve the next step number and return a logger bound to it.

    If ``disable_progress`` is false, this also advances the shared progress
    bar by one and, for a real tqdm bar, shows ``logger.name`` as its postfix.

    Args:
        logger: The ``logging.Logger`` the returned function writes to.

    Returns:
        A function ``log(level, msg, **kwargs)``. It forwards to
        ``logger.log(level, "Step %s: %s", step, msg, **kwargs)``, where
        ``step`` is the number reserved by this call.
    """
    if not disable_progress:
        bar = _get_pbar()
        bar.update(1)
        # The fallback _Faketqdm is not given a postfix here; only real tqdm is.
        if not isinstance(bar, _Faketqdm):
            bar.set_postfix_str(f"{logger.name}")

    step = next(_step_counter)

    def log(level, msg, **kwargs):
        logger.log(level, "Step %s: %s", step, msg, **kwargs)

    return log