xref: /aosp_15_r20/external/pytorch/torch/_compile.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""
3APIs related to torch.compile which lazily import torch._dynamo to avoid
4circular dependencies.
5"""
6
7import functools
8
9
def _disable_dynamo(fn=None, recursive=True):
    """
    Wrap ``fn`` so that it runs through ``torch._dynamo.disable``, importing
    ``torch._dynamo`` lazily.

    This API should only be used inside torch; external users should still
    use torch._dynamo.disable. Its main goal is to avoid the circular import
    issues that are common when using _dynamo.disable inside torch itself.

    It avoids them by deferring the import of torch._dynamo from import time
    to the first invocation of the decorated function.

    Args:
        fn: the callable to wrap. When omitted (decorator usage such as
            ``@_disable_dynamo(recursive=False)``), a decorator that expects
            the function as its single argument is returned instead.
        recursive: forwarded as the second positional argument to
            ``torch._dynamo.disable``.

    Returns:
        The wrapped callable when ``fn`` is given, otherwise a
        ``functools.partial`` of this function bound to ``recursive``.
    """
    if fn is None:
        # decorator usage like @_disable_dynamo(recursive=False). The resulting
        # object expects the original decorated function as the arg.
        return functools.partial(_disable_dynamo, recursive=recursive)

    # Used only when ``fn`` cannot carry the cache attribute itself (e.g.
    # bound methods and builtins reject attribute assignment).
    fallback_disable_fn = None

    @functools.wraps(fn)
    def inner(*args, **kwargs):
        nonlocal fallback_disable_fn

        # cache this on the first invocation to avoid adding too much overhead.
        # NOTE: the attribute lives on ``fn``, so every wrapper of the same
        # ``fn`` shares it, whatever ``recursive`` value each was created with.
        disable_fn = getattr(fn, "__dynamo_disable", None)
        if disable_fn is None:
            disable_fn = fallback_disable_fn
        if disable_fn is None:
            import torch._dynamo

            disable_fn = torch._dynamo.disable(fn, recursive)
            try:
                fn.__dynamo_disable = disable_fn
            except (AttributeError, TypeError):
                # ``fn`` does not accept new attributes; keep the cache in
                # this closure so the call still succeeds and the wrapper
                # is only built once.
                fallback_disable_fn = disable_fn

        return disable_fn(*args, **kwargs)

    return inner
39