xref: /aosp_15_r20/external/pytorch/torch/_dynamo/variables/optimizer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: ignore-errors

import weakref
from typing import Dict, List, TYPE_CHECKING

import torch
from torch.utils._pytree import tree_map_only

from ..guards import GuardBuilder, install_guard
from ..source import (
    AttrSource,
    ConstDictKeySource,
    GetItemSource,
    GlobalWeakRefSource,
    GradSource,
)
from ..utils import GLOBAL_KEY_PREFIX
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator

    from .base import VariableTracker


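# These exceptions signal that the _init_group fast path in
# OptimizerVariable.call_method cannot be used; the caller catches them and
# falls back to tracing the optimizer normally.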
class ArgMappingException(Exception):
    pass


class GuardInstallException(Exception):
    pass


class OptimizerVariable(UserDefinedObjectVariable):
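    # Plain python bookkeeping (source maps and static tensor names) rather
    # than VariableTrackers; listed here so the VariableTracker machinery
    # skips these fields when it walks or clones this variable.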
    _nonvar_fields = {
        "grad_to_source",
        "tensor_to_source",
        "static_tensor_names",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    def __init__(
        self,
        value,
        grad_to_source=None,
        static_tensor_names=None,
        tensor_to_source=None,
        **kwargs,
    ) -> None:
        super().__init__(value, **kwargs)
        self.grad_to_source = grad_to_source or {}
        self.tensor_to_source = tensor_to_source or {}
        self.static_tensor_names = static_tensor_names or set()

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """This is an optimization to avoid tracing the very slow initialization of the optimizer"""
        if name == "_init_group":
            try:
                self.graph_break_if_pending_mutation(tx)
                self.move_step_if_cpu()
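                # Map the traced args back to the underlying python objects and
                # run the real _init_group eagerly instead of tracing it.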
                py_args, py_kwargs = self.get_python_args(*args, **kwargs)
                ret_val = self.value._init_group(*py_args, **py_kwargs)
                self.map_sources_and_install_guards(tx)
                self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
                # stash a weak ref to the optimizer to invalidate the compiled
                # code if the optimizer object dies
                mangled_name = f"__optimizer_{id(self.value)}"
                tx.store_global_weakref_by_id(mangled_name, self.value)
                self.create_finalizer(tx)

                # This is currently safe only because the only actual `ret_val`s returned
                # by the `_init_group` of existing optimizers are properties that are invariant
                # to the input tensors (e.g. dtype, layout). Changing these would trigger a
                # recompilation and hence never result in the wrong specialization of `ret_val`.
                return ConstantVariable.create(ret_val)
            except (ArgMappingException, GuardInstallException) as _:
                # trace normally if we can't map args or install guards correctly
                pass

        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx: "InstructionTranslator", name):
        # Note: returning a GetAttrVariable here lets call_method intercept the
        # call; in the typical case var_getattr returns a UserMethodVariable,
        # which would be inlined directly.
        if name in ("_init_group", "step"):
            return GetAttrVariable(self, name, source=AttrSource(self.source, name))

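        # Accessing param_groups marks every param as a static address (for
        # cudagraphs) and flips eligible groups to capturable; see
        # _set_capturable below.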
        if name == "param_groups":
            from ..decorators import mark_static_address

            for group in self.value.param_groups:
                for p in group["params"]:
                    mark_static_address(p)

            self._set_capturable(tx)

        return super().var_getattr(tx, name)

    def graph_break_if_pending_mutation(self, tx):
        # If there are pending mutations on a parameter (due to using a closure),
        # we need to graph break to allow the python version of the parameter
        # to update, so that running _init_group will initialize the states with
        # the correct values
        for g in self.value.param_groups:
            for p in g["params"]:
                side_effects = tx.output.side_effects
                variable = side_effects.id_to_variable.get(id(p), None)
                if variable and side_effects.has_pending_mutation(variable):
                    from ..exc import Unsupported

                    raise Unsupported("Pending mutation on parameter")

    def _set_capturable(self, tx):
        from . import LazyVariableTracker
        from .builder import VariableBuilder

        # We only set capturable if all params are on a GPU (cuda or xpu)
        # and their state is not yet initialized
        def safe_to_set_capturable(group):
            all_uninitialized = True
            all_gpu = True

            for p in group.get("params", []):
                all_gpu &= p.is_cuda or p.is_xpu
                all_uninitialized &= p not in self.value.state

            return "capturable" in group and all_uninitialized and all_gpu

        # Set capturable only on groups where it is safe, so we don't need to
        # realize the whole state in the variable tracker; we handle guarding
        # the state specially
        for ind, group in enumerate(self.value.param_groups):
            if safe_to_set_capturable(group):
                group["capturable"] = True

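        # Realize the traced param_groups and overwrite "capturable" in each
        # traced group dict so tracing sees the updated value.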
        param_groups_vt = LazyVariableTracker.realize_all(
            VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
                self.value.param_groups
            )
        )
        for ind, param_group_vt in enumerate(param_groups_vt.items):
            key = ConstDictVariable._HashableTracker(
                ConstantVariable.create("capturable")
            )
            param_group_vt.items[key] = ConstantVariable.create(True)

    def get_python_args(self, *args, **kwargs):
        """Get python values equivalent to the variable tracker args"""

        def map_arg(arg):
            if isinstance(arg, ConstantVariable):
                return arg.as_python_constant()
            elif isinstance(arg, ListVariable) and not arg.items:
                return []
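            # A param_group dict is recognized purely by its source
            # (optimizer.param_groups[i]) and mapped back to the real dict.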
            elif (
                isinstance(arg, ConstDictVariable)
                and isinstance(arg.source, GetItemSource)
                and isinstance(arg.source.base, AttrSource)
                and arg.source.base.member == "param_groups"
            ):
                return self.value.param_groups[arg.source.index]

            raise ArgMappingException

        new_args = [map_arg(arg) for arg in args]
        new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}

        return new_args, new_kwargs

    # If users load an old state dictionary, it's possible that step is on the
    # CPU. If this is the case, move it to the GPU corresponding to the
    # parameter. In most cases this is a no-op because the state is empty.
    def move_step_if_cpu(self):
        for p, state in self.value.state.items():
            if "step" in state and state["step"].is_cpu:
                state["step"] = state["step"].to(p.device)

    def map_sources_and_install_guards(self, tx):
        from ..decorators import mark_static_address
        from .builder import VariableBuilder
        from .lazy import LazyVariableTracker

        self.grad_to_source = {}
        self.tensor_to_source = {}

        # Tracing _init_group is expensive, but we still have to insert the
        # necessary guards for it. So we manually handle insertion of guards.
        # We also want to mark all the tensors inside the state dict as static
        # addresses.

        # Mark all the tensors in the state dict as static addresses. This has
        # to be done first because the variable builder relies on the static
        # address annotation.
        def mark_static(x):
            mark_static_address(x)

        tree_map_only(torch.Tensor, mark_static, self.value.state)

        # Recursively realize the variable trackers for optim.state and
        # optim.param_groups, which recursively install the necessary guards.
        param_groups_vt = LazyVariableTracker.realize_all(
            VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
                self.value.param_groups
            )
        )

        state_vt = VariableBuilder(tx, AttrSource(self.source, "state"))(
            self.value.state
        )

        # We need to realize the top level state dict to populate
        # the guard locals
        state_vt.realize()

        # Populate self.grad_to_source and self.tensor_to_source so that
        # update_list_args can be done manually
        for g_ind, (group, group_vt) in enumerate(
            zip(self.value.param_groups, param_groups_vt.items)
        ):
            # we assume here that all params within a param group
            # are initialized similarly
            if len(group["params"]) > 0:
                for param in group["params"]:
                    if param.grad is not None:
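                        # Locate this param's position in the state dict so a
                        # ConstDictKeySource can be built for its state entry.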
                        key_index = None
                        for i, k in enumerate(self.value.state.keys()):
                            if k is param:
                                key_index = i
                                break
                        if key_index is not None:
                            state_source = AttrSource(self.source, "state")
                            LazyVariableTracker.realize_all(
                                VariableBuilder(
                                    tx,
                                    GetItemSource(
                                        state_source,
                                        ConstDictKeySource(state_source, key_index),
                                    ),
                                )(self.value.state[param])
                            )
                            break

            group_source = group_vt.source
            params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params"))
            for p_ind, (p, p_vt) in enumerate(
                zip(group["params"], params_vt.unpack_var_sequence(tx))
            ):
                param_source = p_vt.source
                self.tensor_to_source[p] = param_source
                grad_source = GradSource(
                    param_source,
                    "grad",
                )

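                # When the grad is None, guard that it stays None so that a
                # grad appearing later forces a recompilation.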
                if p.grad is not None:
                    self.grad_to_source[p.grad] = grad_source
                else:
                    install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH))

        # We have to iterate over the state dict again to populate the
        # tensor_to_source dict. This is used for the finalizer.
        state_source = AttrSource(self.source, "state")
        for idx, (p, value) in enumerate(self.value.state.items()):
            p_state_source = GetItemSource(
                state_source, ConstDictKeySource(state_source, idx)
            )
            for k, v in value.items():
                if (
                    isinstance(v, torch.Tensor)
                    and v not in self.grad_to_source
                    and v not in self.tensor_to_source
                ):
                    self.tensor_to_source[v] = GetItemSource(p_state_source, k)

    def wrap_tensor(self, tx: "InstructionTranslator", tensor_value):
        """Wrap state tensor in a TensorVariable"""
        from ..decorators import mark_static_address
        from .builder import VariableBuilder

        # If we already have a source for a tensor, use it. If we have not seen
        # the tensor before, stash it and use a global weak ref source, since it
        # must be an optimizer tensor that we have missed.

        if tensor_value in self.tensor_to_source:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value)
            builder = VariableBuilder(tx, self.tensor_to_source[tensor_value])
            self.static_tensor_names.add(tx.output.module_key_name(builder.name))
        elif tensor_value in self.grad_to_source:
            builder = VariableBuilder(tx, self.grad_to_source[tensor_value])
        else:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value)

            global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
            builder = VariableBuilder(tx, GlobalWeakRefSource(global_name))
            self.static_tensor_names.add(tx.output.module_key_name(builder.name))

        result = builder(tensor_value)
        return result

    def update_list_args(
        self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs
    ):
        """Update the args and kwargs to the traced optimizer call"""
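        # _init_group already filled the python lists eagerly; mirror those
        # entries into the corresponding traced ListVariables.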
        for arg, py_arg in zip(args, py_args):
            if isinstance(arg, ListVariable):
                assert isinstance(
                    py_arg, list
                ), "py_arg should be a list in optimizer variable"
                for i, val in enumerate(py_arg):
                    tx.output.side_effects.mutation(arg)
                    if isinstance(val, torch.Tensor):
                        arg.items.append(self.wrap_tensor(tx, val))
                    else:
                        from .builder import SourcelessBuilder, VariableBuilder

                        if arg.source:
                            arg.items.append(
                                VariableBuilder(tx, GetItemSource(arg.source, i))(val)
                            )
                        else:
                            arg.items.append(SourcelessBuilder.create(tx, val))

    def create_finalizer(self, tx):
        names_to_delete = self.static_tensor_names
        value = self.value
        tc = tx.output.tracing_context

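        # When the optimizer is garbage collected, drop the static tensors that
        # were lifted onto the GraphModule so they are not kept alive.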
        def init_finalizer(gm):
            def clear_static_tensor_refs():
                for name in names_to_delete:
                    gm._buffers.pop(name, None)
                    gm._parameters.pop(name, None)
                    if tc.params_flat:
                        tc.params_flat.clear()

            weakref.finalize(value, clear_static_tensor_refs)

        tx.output.add_graph_finalizer(init_finalizer)