# mypy: allow-untyped-defs
import inspect
import itertools
import logging
from typing import Optional

from torch._logging import warning_once
from torch._ops import HigherOrderOperator
from torch.types import _dtype
from torch.utils.checkpoint import checkpoint, CheckpointPolicy


log = logging.getLogger(__name__)

uid = itertools.count(1)


# Used for testing the HigherOrderOperator mechanism
class Wrap(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("wrap")

    def __call__(self, func, *args, **kwargs):
        # Dynamo already traces the body of the HigherOrderOp beforehand when it
        # encounters this op, so there is no need to trace into it here.
        import torch._dynamo  # noqa: F401
        from torch._dynamo import disable

        @disable
        def wrapper():
            result = func(*args, **kwargs)
            return result

        return wrapper()


wrap = Wrap()
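
# Illustrative usage sketch (comment only, not executed at import time): `wrap`
# simply calls the wrapped callable eagerly; its purpose is to exercise the
# HigherOrderOperator dispatch machinery, e.g. under torch.compile.
#
#   import torch
#
#   def body(x, y):
#       return x * y + y
#
#   out = wrap(body, torch.randn(3), torch.randn(3))  # same result as body(x, y)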


class WrapWithSetGradEnabled(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("wrap_with_set_grad_enabled")

    def __call__(self, enable_grad, wrapped_func, *args, **kwargs):
        # Dynamo already traces the body of the HigherOrderOp beforehand when it
        # encounters this op, so there is no need to trace into it here.
        import torch._dynamo  # noqa: F401
        from torch._dynamo import disable

        @disable
        def wrapper():
            with torch.set_grad_enabled(enable_grad):
                return wrapped_func(*args, **kwargs)

        return wrapper()


wrap_with_set_grad_enabled = WrapWithSetGradEnabled()
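
# Illustrative usage sketch (comment only): the first argument is the desired
# grad-mode flag, followed by the wrapped callable and its arguments.
#
#   import torch
#
#   x = torch.randn(3, requires_grad=True)
#   y = wrap_with_set_grad_enabled(False, lambda t: t.sin(), x)
#   # y was computed under torch.set_grad_enabled(False), so y.requires_grad is False.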


class WrapWithAutocast(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("wrap_with_autocast")

    def __call__(
        self,
        device_type: str,
        dtype: Optional[_dtype],
        enabled: bool,
        cache_enabled: Optional[bool],
        wrapped_func,
        *args,
        **kwargs,
    ):
        # Dynamo already traces the body of the HigherOrderOp beforehand when it
        # encounters this op, so there is no need to trace into it here.
        import torch._dynamo  # noqa: F401
        from torch._dynamo import disable

        @disable
        def wrapper():
            with torch.autocast(device_type, dtype, enabled, cache_enabled):
                return wrapped_func(*args, **kwargs)

        return wrapper()


wrap_with_autocast = WrapWithAutocast()
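
# Illustrative usage sketch (comment only): the leading arguments mirror
# torch.autocast (device_type, dtype, enabled, cache_enabled), followed by the
# wrapped callable and its arguments.
#
#   import torch
#
#   a = torch.randn(4, 4)
#   b = torch.randn(4, 4)
#   out = wrap_with_autocast("cpu", torch.bfloat16, True, False, torch.mm, a, b)
#   # out is computed inside torch.autocast("cpu", torch.bfloat16, True, False).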


class WrapActivationCheckpoint(HigherOrderOperator):
    """
    This operator wraps torch.utils.checkpoint. It prevents TorchDynamo from
    looking into saved tensor hooks and passes control directly to AOT Autograd,
    which is fine with tracing saved tensor hooks. As a result of AOT tracing
    the torch.utils.checkpoint code, we obtain a backward graph with recomputed
    forward nodes.

    However, we might deprecate this operator soon. The difficulty arises in the
    functionalization of rng ops. Today, there are two different
    functionalizations of rng ops - one at AOT Autograd and the other at
    Inductor - and they are difficult to map to each other. The rng states also
    complicate pattern matching in Inductor. For ease of implementation, we are
    currently inclined towards functionalization at the Inductor level, which
    means that duplication/recomputation is done as a compiler pass in the
    partitioners. See TagActivationCheckpoint for more information.
    """

    def __init__(self) -> None:
        super().__init__("wrap_activation_checkpoint")

    def __call__(self, function, *args, **kwargs):
        # use_reentrant is set to False because this op is going to be traced,
        # and we ensure that AOT Autograd traces through the non-reentrant
        # version of checkpointing.
        import torch.fx.traceback as fx_traceback
        from torch.fx import Interpreter

        kwargs["use_reentrant"] = False
        kwargs["preserve_rng_state"] = False
        # Using the interpreter allows preservation of metadata through the torch.compile stack.
        with fx_traceback.preserve_node_meta():
            return checkpoint(Interpreter(function).run, *args, **kwargs)


wrap_activation_checkpoint = WrapActivationCheckpoint()
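
# Illustrative usage sketch (comment only): the wrapped `function` is expected
# to be an fx.GraphModule, since it is re-run through fx.Interpreter. `body`,
# `gm`, and `x` below are hypothetical names for the sketch.
#
#   import torch
#   from torch.fx import symbolic_trace
#
#   def body(x):
#       return torch.sin(x).cos()
#
#   gm = symbolic_trace(body)
#   x = torch.randn(3, requires_grad=True)
#   out = wrap_activation_checkpoint(gm, x)  # runs gm under non-reentrant checkpoint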


class TagActivationCheckpoint(HigherOrderOperator):
    """
    This operator is meant to be used only with the torch.compile stack. It
    accepts an FX graph module that needs to be checkpointed and adds a
    "recomputable" tag to the nodes of the FX graph that should be recomputed.

    The goal is to:
    1. Avoid using Dynamo to trace through saved tensor hooks.
    2. For the selective checkpointing case, let AOTAutograd trace through
       saved tensor hooks, but use special logic with a TorchDispatchMode to
       override the usual saved_tensor_hooks fn logic in order to tag the nodes.
    3. Rely on the partitioners to actually duplicate the nodes.
    This sits well in the torch.compile stack, because by the time the graph
    reaches the partitioner, Inductor has already run its functionalization of
    rng ops (by setting a fixed seed for each random op, see
    `replace_random_passes`). Therefore, the duplication of nodes, by design,
    respects the rng states in the forward and the recomputed forward in the
    backward.
    """

    def __init__(self) -> None:
        super().__init__("tag_activation_checkpoint")

    @staticmethod
    def divide_kwargs(kwargs):
        """
        The checkpoint fn can receive kwargs that are mixed between the
        checkpointed fn and the checkpoint fn itself. For example
        >> def gn(x, y, z=None):
        >>     a = torch.matmul(x, y)
        >>     if z is not None:
        >>         return torch.matmul(a, z)
        >>     return a
        >> def fn(x, y, z):
        >>     return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z))
        In the above case, z belongs to the checkpointed function gn, while
        use_reentrant belongs to the checkpoint function. This function splits
        the kwargs into checkpoint_kwargs and gmod_kwargs (or
        checkpointed_fn_kwargs). Keeping the split deterministic gives the same
        graph from run to run for better debuggability; it is not required for
        correctness.
        """
        ckpt_signature = inspect.signature(checkpoint)
        checkpoint_keys = set()
        for name in ckpt_signature.parameters:
            if name in ("function", "args", "kwargs"):
                continue
            checkpoint_keys.add(name)

        # `preserve_rng_state` is not a regular kwarg
        checkpoint_keys.add("preserve_rng_state")

        checkpoint_kwargs = {
            name: kwargs[name] for name in kwargs.keys() if name in checkpoint_keys
        }
        gmod_kwargs = {
            name: kwargs[name] for name in kwargs.keys() if name not in checkpoint_keys
        }
        return checkpoint_kwargs, gmod_kwargs
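
    # Illustrative sketch of the split above (comment only), using the kwargs
    # from the docstring example:
    #
    #   mixed = {"z": z, "use_reentrant": False, "preserve_rng_state": False}
    #   checkpoint_kwargs, gmod_kwargs = TagActivationCheckpoint.divide_kwargs(mixed)
    #   # checkpoint_kwargs == {"use_reentrant": False, "preserve_rng_state": False}
    #   # gmod_kwargs == {"z": z}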

    def tag_nodes(self, gmod, is_sac):
        unique_graph_id = next(uid)
        for node in gmod.graph.nodes:
            if node.op in ("call_function", "call_method", "call_module"):
                node.meta["ac_graph_id"] = unique_graph_id
                if is_sac:
                    # For selective checkpointing, we will populate this tag later in _CachingTorchDispatchMode.
                    node.meta["recompute"] = None
                else:
                    # Under vanilla activation checkpointing, all nodes should be recomputed.
                    node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
        return gmod
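
    # Resulting node metadata (comment sketch): after tag_nodes, every
    # call_function/call_method/call_module node carries
    #
    #   node.meta["ac_graph_id"] -> integer id shared by all nodes of this checkpointed region
    #   node.meta["recompute"]   -> CheckpointPolicy.PREFER_RECOMPUTE for vanilla AC (is_sac=False),
    #                               or None for selective AC (is_sac=True), to be filled in later
    #                               by _CachingTorchDispatchMode.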

    def __call__(self, gmod, *args, **kwargs):
        import torch.fx.traceback as fx_traceback
        from torch.fx import Interpreter

        if "_checkpoint_context_fn" in gmod.meta:
            warning_once(
                log,
                """
Detected that context_fn is passed to torch.utils.checkpoint under torch.compile.
Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_).
""",
            )
            # use_reentrant is set to False because this op is going to be traced,
            # and we ensure that AOT Autograd traces through the non-reentrant
            # version of checkpointing.
            kwargs["use_reentrant"] = False
            # preserve_rng_state is set to False because we want to prevent AOTAutograd from tracing through
            # the `torch.random.fork_rng` op (which is not yet supported under CUDA).
            # This doesn't mean that we don't preserve RNG state. Instead, we will always preserve RNG state
            # regardless of this flag (by doing RNG functionalization via `replace_random_passes` in Inductor
            # instead of in AOTAutograd).
            kwargs["preserve_rng_state"] = False
            kwargs["context_fn"] = gmod.meta["_checkpoint_context_fn"]
            # We first tag all nodes as "recompute" in this graph, and then we undo the "recompute" tag
            # for specific nodes in _CachingTorchDispatchMode in torch/utils/checkpoint.py.
            gmod = self.tag_nodes(gmod, is_sac=True)
            # Using the interpreter allows preservation of metadata through the torch.compile stack.
            with fx_traceback.preserve_node_meta():
                return checkpoint(Interpreter(gmod).run, *args, **kwargs)
        else:
            gmod = self.tag_nodes(gmod, is_sac=False)
            # Using the interpreter allows preservation of metadata through the torch.compile stack.
            # TODO: We want to use the same `checkpoint(Interpreter(gmod).run, *args, **kwargs)` here
            # as in the `context_fn != None` case, but that depends on in-place op support in TorchDispatchMode + torch.compile.
            # (For details on the in-place op issue, run the `test_compile_selective_checkpoint_inplace_op` unit test.)
            with fx_traceback.preserve_node_meta():
                return Interpreter(gmod).run(*args)


tag_activation_checkpoint = TagActivationCheckpoint()
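
# Illustrative end-to-end sketch (comment only): under torch.compile, calls to
# torch.utils.checkpoint.checkpoint are rewritten by Dynamo into this operator;
# user code does not invoke it directly. `gn`, `fn`, `x`, and `y` below are
# hypothetical names for the sketch.
#
#   import torch
#   from torch.utils.checkpoint import checkpoint
#
#   def gn(x, y):
#       return torch.sigmoid(torch.matmul(x, y))
#
#   @torch.compile
#   def fn(x, y):
#       return checkpoint(gn, x, y, use_reentrant=False)
#
#   x = torch.randn(4, 4, requires_grad=True)
#   y = torch.randn(4, 4, requires_grad=True)
#   fn(x, y).sum().backward()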