# mypy: allow-untyped-defs
"""
APIs related to torch.compile which lazily import torch._dynamo to avoid
circular dependencies.
"""

import functools


def _disable_dynamo(fn=None, recursive=True):
    """
    Wrap ``fn`` so that it runs through ``torch._dynamo.disable``, importing
    ``torch._dynamo`` lazily.

    This API should only be used inside torch; external users should still
    use torch._dynamo.disable. The main goal of this API is to avoid the
    circular import issues that are common when using _dynamo.disable inside
    torch itself.

    It avoids them by deferring the import of torch._dynamo from import time
    to the first invocation of the decorated function.

    Args:
        fn: The callable to wrap. When ``None``, a decorator is returned
            instead, enabling usage like ``@_disable_dynamo(recursive=False)``.
        recursive: Forwarded unchanged as the second argument of
            ``torch._dynamo.disable``.

    Returns:
        If ``fn`` is given, a wrapper (carrying ``fn``'s metadata via
        ``functools.wraps``) that forwards every call to the disabled version
        of ``fn``. Otherwise, a ``functools.partial`` of this function that
        expects the function to decorate as its argument.
    """
    if fn is None:
        # Decorator usage like @_disable_dynamo(recursive=False). The resulting
        # object expects the original decorated function as the arg.
        return functools.partial(_disable_dynamo, recursive=recursive)

    # Fallback cache, used only when ``fn`` does not accept new attributes
    # (e.g. bound methods or builtins), so such callables work instead of
    # raising on the first call.
    local_cache = None

    @functools.wraps(fn)
    def inner(*args, **kwargs):
        nonlocal local_cache

        # Cache the disabled callable on the first invocation to avoid adding
        # too much overhead to every call.
        disable_fn = getattr(fn, "__dynamo_disable", None)
        if disable_fn is None:
            disable_fn = local_cache
        if disable_fn is None:
            # Deferred import: this is what breaks the circular dependency.
            import torch._dynamo

            disable_fn = torch._dynamo.disable(fn, recursive)
            try:
                fn.__dynamo_disable = disable_fn
            except (AttributeError, TypeError):
                # ``fn`` rejects attribute assignment; keep the result in the
                # closure rather than failing the call.
                local_cache = disable_fn

        return disable_fn(*args, **kwargs)

    return inner