# mypy: allow-untyped-defs
import functools
import inspect
import warnings
from collections.abc import MutableMapping
from typing import Any, Dict, List, Optional, Type, Union

import torch.nn

from . import utils, variables
from .bytecode_transformation import (
    bytecode_from_template,
    create_call_function,
    create_call_method,
    create_instruction,
)
from .codegen import PyCodegen
from .exc import unimplemented
from .source import GlobalSource, LocalSource, Source
from .utils import is_frozen_dataclass, nn_module_new, object_new
from .variables.base import (
    is_side_effect_safe,
    MutableLocalBase,
    MutableLocalSource,
    VariableTracker,
)
from .variables.user_defined import FrozenDataClassVariable


class MutableSideEffects(MutableLocalBase):
    """
    VariableTracker.mutable_local marker for a list passed in as an input:
    if we mutate the list, we need to re-apply those mutations after the
    graph runs.
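
    For illustration, a minimal sketch of the user pattern this marker
    supports (hypothetical code, not part of this module):

        def f(xs):
            xs.append(1)  # mutation on a graph input list
            return xs[0]

    Dynamo traces f without mutating the real xs, then replays xs.append(1)
    in the generated bytecode after the FX graph runs.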
35    """
36
37    def __init__(self, source: Source, is_modified: bool = False):
38        super().__init__(MutableLocalSource.Existing)
39        self.source = source
40        self.is_modified = is_modified
41
42
43class AttributeMutation(MutableLocalBase):
44    """
45    VariableTracker.mutable_local marker to track changes to attributes
46    """
47
48    def __init__(self, typ: MutableLocalSource, source: Optional[Source]):
49        super().__init__(typ)
50        self.source = source
51
52
53class AttributeMutationExisting(AttributeMutation):
54    def __init__(self, source: Source):
55        super().__init__(MutableLocalSource.Existing, source)
56        self.source = source
57
58
59class AttributeMutationNew(AttributeMutation):
60    def __init__(self, source: Optional[Source], cls_source: Optional[Source]):
61        super().__init__(MutableLocalSource.Local, source)
62        self.cls_source = cls_source
63
64
def _manual_update_dict(dict_from, dict_to):
    for k, v in dict_from.items():
        dict_to[k] = v
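
# Hedged usage note: this helper is consumed as a bytecode template via
# bytecode_from_template() in codegen_update_mutated below; semantically the
# generated code performs (with hypothetical names)
#   _manual_update_dict(traced_dict, original_dict)
# i.e. an explicit item-by-item copy, because dict.update / DICT_MERGE may be
# shadowed or customized on dict subclasses.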


class SideEffects:
    """
    Track side effects (list mutation, setattr, etc.) that need to be
    applied after an FX graph is run.
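
    A hedged sketch of the kind of program this class supports (hypothetical
    user code, for illustration only):

        def f(mod, xs):
            mod.counter += 1  # setattr side effect, replayed after the graph
            xs.append(1)      # list mutation, replayed after the graph
            return xs[-1]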
74    """
75
76    id_to_variable: Dict[int, VariableTracker]
77    store_attr_mutations: Dict[MutableLocalBase, Dict[str, VariableTracker]]
78    keepalive: List[Any]
79
80    def __init__(
81        self,
82        id_to_variable=None,
83        store_attr_mutations=None,
84        keepalive=None,
85        save_for_backward=None,
86        tensor_hooks=None,
87    ):
88        super().__init__()
89        self.id_to_variable = id_to_variable or {}
90        self.store_attr_mutations = store_attr_mutations or {}
91        self.keepalive = keepalive or []
92        self.save_for_backward = save_for_backward or []
93        self.tensor_hooks = tensor_hooks or {}
94        # Track Compiled Autograd final callbacks that must be called at the end of Compiled Autograd backward graph.
95        # Only applicable if this graph is created from Dynamo tracing in Compiled Autograd.
96        self.ca_final_callbacks_var = None

    def __eq__(self, other: object) -> bool:
        assert isinstance(other, SideEffects)
        # NB: do NOT test keepalive
        return (
            self.id_to_variable == other.id_to_variable
            and self.store_attr_mutations == other.store_attr_mutations
            and self.save_for_backward == other.save_for_backward
            and self.tensor_hooks == other.tensor_hooks
        )

    def diff(self, other: "SideEffects") -> Optional[str]:
        if self.id_to_variable != other.id_to_variable:
            sk_itv = self.id_to_variable.keys()
            ok_itv = other.id_to_variable.keys()
            if sk_itv != ok_itv:
                return f"id_to_variable keys: {sk_itv} != {ok_itv}"
            # Feel free to augment this with more fancy diffing logic
            # if needed for debugging
            return "id_to_variable: unknown diff"
        elif self.store_attr_mutations != other.store_attr_mutations:
            sk_sam = self.store_attr_mutations.keys()
            ok_sam = other.store_attr_mutations.keys()
            if sk_sam != ok_sam:
                return f"store_attr_mutations keys: {sk_sam} != {ok_sam}"
            return "store_attr_mutations: unknown diff"
        elif self.save_for_backward != other.save_for_backward:
            return "save_for_backward"
        elif self.tensor_hooks != other.tensor_hooks:
            return "tensor_hooks"
        else:
            return None

    def clone(self):
        """Create a shallow copy"""
        return self.__class__(
            id_to_variable=dict(self.id_to_variable),
            store_attr_mutations={
                k: dict(v) for k, v in self.store_attr_mutations.items()
            },
            keepalive=list(self.keepalive),
            save_for_backward=self.save_for_backward,
            tensor_hooks=self.tensor_hooks,
        )

    def __contains__(self, item):
        return id(item) in self.id_to_variable

    def __getitem__(self, item):
        return self.id_to_variable[id(item)]

    def check_allowed_side_effect(self, item):
        from torch._dynamo.variables.misc import AutogradFunctionContextVariable

        # People do things like self.dim = dim inside autograd.Function.
        # These are benign.
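        # A hedged illustrative sketch (hypothetical code):
        #   class Scale(torch.autograd.Function):
        #       @staticmethod
        #       def forward(ctx, x, dim):
        #           ctx.dim = dim  # benign attribute store on the ctx object
        #           return x * 2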
        if isinstance(item, AutogradFunctionContextVariable):
            return True
        if not is_side_effect_safe(item.mutable_local):
            unimplemented(
                "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)"
            )

    def store_attr(self, item: VariableTracker, name: str, value: VariableTracker):
        assert self.is_attribute_mutation(item)
        self.check_allowed_side_effect(item)
        if item.mutable_local not in self.store_attr_mutations:
            self.store_attr_mutations[item.mutable_local] = {}
        self.store_attr_mutations[item.mutable_local][name] = value

    def load_attr(self, item, name, deleted_ok=False):
        assert self.is_attribute_mutation(item)
        result = self.store_attr_mutations[item.mutable_local][name]
        if not deleted_ok and isinstance(result, variables.DeletedVariable):
            unimplemented("read deleted attribute")
        return result

    def store_cell(self, cellvar, value):
        assert isinstance(cellvar, variables.NewCellVariable)
        assert isinstance(value, variables.VariableTracker)
        self.store_attr(cellvar, "cell_contents", value)

    def load_cell(self, cellvar):
        assert isinstance(cellvar, variables.NewCellVariable)
        return self.load_attr(cellvar, "cell_contents")

    def load_global(self, gvar: VariableTracker, name: str):
        assert isinstance(gvar, variables.VariableTracker)
        return self.load_attr(gvar, name)

    def store_global(self, gvar: VariableTracker, name: str, value: VariableTracker):
        assert isinstance(gvar, variables.VariableTracker)
        assert isinstance(value, variables.VariableTracker)
        self.store_attr(gvar, name, value)

    @staticmethod
    def cls_supports_mutation_side_effects(cls):
        return (
            inspect.getattr_static(cls, "__getattribute__", None)
            is object.__getattribute__
        )

    def is_attribute_mutation(self, item):
        return isinstance(item.mutable_local, AttributeMutation)

    def has_pending_mutation(self, item):
        return self.is_attribute_mutation(item) and bool(
            self.store_attr_mutations.get(item.mutable_local)
        )

    def has_pending_mutation_of_attr(self, item, name):
        return self.is_attribute_mutation(
            item
        ) and name in self.store_attr_mutations.get(item.mutable_local, ())

    def is_modified(self, item):
        if isinstance(item.mutable_local, AttributeMutationNew):
            return True
        if self.is_attribute_mutation(item):
            return item.mutable_local in self.store_attr_mutations
        return item.mutable_local.is_modified

    def _track_obj(
        self,
        item: Any,
        variable: VariableTracker,
        mutable_cls=MutableSideEffects,
    ):
        """Start tracking a new variable for mutation"""
        assert variable.source is not None

        if id(item) in self.id_to_variable:
            raise AssertionError(
                f"{variable} is already tracked for mutation. This could be "
                "because you are not using VariableBuilder to construct "
                "the variable tracker. "
                f"Source of new object: {variable.source}. "
                f"Source of previously tracked object: {self.id_to_variable[id(item)].source}."
            )

        variable.mutable_local = mutable_cls(variable.source)
        self.id_to_variable[id(item)] = variable
        self.keepalive.append(item)
        return variable

    track_mutable = _track_obj

    def track_object_existing(
        self,
        item: Any,
        variable: VariableTracker,
    ):
        return self._track_obj(item, variable, mutable_cls=AttributeMutationExisting)

    def track_object_new(
        self,
        cls_source: Source,
        user_cls: Any,
        variable_cls: Any,
        options,
    ):
        if user_cls is torch.autograd.function.FunctionCtx:
            with warnings.catch_warnings(record=True):
                obj = torch.autograd.Function()
        elif issubclass(user_cls, torch.nn.Module):
            obj = nn_module_new(user_cls)
        else:
            obj = object_new(user_cls)
        variable = variable_cls(
            obj,
            mutable_local=AttributeMutationNew(None, cls_source),
            **options,
        )
        self.id_to_variable[id(obj)] = variable
        self.keepalive.append(obj)
        return variable

    def track_object_new_from_user_defined_class(
        self,
        cls_variable: "variables.UserDefinedClassVariable",
    ):
        cls_source = cls_variable.source
        user_cls = cls_variable.value

        # Find the variable class
        variable_cls: Type[
            variables.UserDefinedObjectVariable
        ] = variables.UserDefinedObjectVariable
        if issubclass(user_cls, torch.nn.Module):
            variable_cls = variables.UnspecializedNNModuleVariable
        elif issubclass(user_cls, MutableMapping):
            variable_cls = variables.MutableMappingVariable
        elif is_frozen_dataclass(user_cls):
            variable_cls = FrozenDataClassVariable
        else:
            variable_cls = variables.UserDefinedObjectVariable

        assert issubclass(variable_cls, variables.UserDefinedObjectVariable)

        variable_cls = functools.partial(variable_cls, cls_source=cls_source)

        return self.track_object_new(cls_source, user_cls, variable_cls, {})

    def track_cell_new(
        self,
    ):
        obj = object()
        variable = variables.NewCellVariable(
            mutable_local=AttributeMutationNew(None, None),
        )
        self.id_to_variable[id(obj)] = variable
        self.keepalive.append(obj)
        return variable

    def track_cell_existing(self, source: Source, item: Any):
        variable = variables.NewCellVariable(
            mutable_local=AttributeMutationExisting(source),
        )
        self.id_to_variable[id(item)] = variable
        self.keepalive.append(item)
        return variable

    def track_global_existing(self, source: Source, item: Any):
        variable = variables.NewGlobalVariable(
            mutable_local=AttributeMutationExisting(source),
        )
        self.id_to_variable[id(item)] = variable
        self.keepalive.append(item)
        return variable

    def track_save_for_backward(self, ctx, args):
        assert isinstance(ctx, variables.AutogradFunctionContextVariable)
        self.save_for_backward.append((ctx, args))

    def track_tensor_variables_from_runahead_side_effects(self, other):
        # In higher order ops we want to keep track of tensors seen in the
        # speculate_subgraph so that we don't lift them again as a new input in
        # other speculate_subgraph or in the root tracer.
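        # Hedged illustration (hypothetical): if one traced subgraph of a
        # higher order op already lifted tensor t as an input, tracking it here
        # lets a later subgraph reuse the same TensorVariable for t instead of
        # lifting it a second time.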
        for other_item in other.keepalive:
            other_id = id(other_item)
            other_variable = other.id_to_variable[other_id]
            if other_id not in self.id_to_variable and isinstance(
                other_variable, variables.TensorVariable
            ):
                self.track_object_existing(other_item, other_variable)

    def prune_dead_object_new(self, tx):
        live_new_objects = set()

        # use this to avoid cycles in mutable_local (though I'm not sure if that
        # can actually happen).
        visited: Any = set()

        def visit(var: VariableTracker):
            mutable_local = var.mutable_local
            if mutable_local is None:
                return
            if mutable_local in visited:
                return
            visited.add(mutable_local)
            # This object is reachable; if it was created during tracing,
            # mark it (and thus its pending mutations) live.
            if isinstance(mutable_local, AttributeMutationNew):
                live_new_objects.add(mutable_local)
            # It's possible that we have mutated the value of this variable
            # to be another one. The new value is in store_attr_mutations.
            # Also recurse through the new value to detect live
            # AttributeMutationNew objects.
            if var.mutable_local in self.store_attr_mutations:
                VariableTracker.visit(
                    visit, self.store_attr_mutations[var.mutable_local]
                )

        def is_live(var: Union[MutableLocalBase, VariableTracker]):
            if isinstance(var, AttributeMutationNew):
                return var in live_new_objects
            if isinstance(var, VariableTracker):
                return is_live(var.mutable_local)
            return True

        pre_existing_vars = [
            var
            for var in self.id_to_variable.values()
            if not isinstance(var.mutable_local, AttributeMutationNew)
        ]

        # The only live side effects come from returns (tx.stack), any intermediates
        # during a graph break (tx.symbolic_locals), and mutation on pre-existing variables.
        # Recursively visit Variables and see if any of them have been mutated.
        VariableTracker.visit(visit, (tx.stack, tx.symbolic_locals, pre_existing_vars))

        # NB: cell variable handling is tricky.
        # Cell variables must stay alive if any NestedUserFunctionVariable
        # is live. "visit"-ing the NestedUserFunctionVariable visits
        # the .closures field, from which we will see if we need to keep
        # any mutations to cell variables alive.
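        # A hedged example (hypothetical): if the traced frame returns a
        # closure `lambda: x` over a newly created cell for x, visiting the
        # returned function's closures keeps that cell's AttributeMutationNew,
        # and hence its pending write, alive.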

        self.id_to_variable = {
            k: v for k, v in self.id_to_variable.items() if is_live(v)
        }
        self.store_attr_mutations = {
            k: v for k, v in self.store_attr_mutations.items() if is_live(k)
        }

    def mutation(self, var):
        self.check_allowed_side_effect(var)
        if isinstance(var.mutable_local, MutableSideEffects):
            var.mutable_local = MutableSideEffects(var.mutable_local.source, True)

    def _get_modified_vars(self):
        return [var for var in self.id_to_variable.values() if self.is_modified(var)]

    def codegen_save_tempvars(self, cg: PyCodegen):
        for var in self._get_modified_vars():
            if isinstance(
                var.mutable_local, (AttributeMutationExisting, AttributeMutationNew)
            ) and isinstance(var, variables.NewCellVariable):
                cg.add_push_null(
                    lambda: cg.load_import_from(utils.__name__, "make_cell")
                )
                cg.extend_output(create_call_function(0, False))
                cg.add_cache(var)
                if isinstance(var.mutable_local, AttributeMutationNew):
                    var.mutable_local.source = LocalSource(cg.tempvars[var])  # type: ignore[attr-defined]
            elif isinstance(var.mutable_local, AttributeMutationNew):
                if isinstance(var, variables.AutogradFunctionContextVariable):
                    unimplemented("AutogradFunctionContextVariable escaped")
                cg.add_push_null(
                    lambda: cg.load_import_from(utils.__name__, "object_new")
                )
                cg(var.mutable_local.cls_source)
                cg.extend_output(create_call_function(1, False))
                cg.add_cache(var)
                var.mutable_local.source = LocalSource(cg.tempvars[var])
            elif var in cg.tempvars:
                assert cg.tempvars.get(var) is None
                # subsequent usage should point to the original variable
                cg(var.mutable_local.source)
                cg.add_cache(var)

        for ctx, args in self.save_for_backward:
            cg(ctx.source)
            cg.load_method("save_for_backward")
            for arg in args:
                cg(arg)
            cg.extend_output(
                [
                    *create_call_method(len(args)),
                    create_instruction("POP_TOP"),
                ]
            )

    def register_hook(self, tensor, hook, handle, name):
        assert isinstance(tensor, variables.TensorVariable)
        assert isinstance(hook, variables.VariableTracker)
        assert (
            isinstance(handle, variables.RemovableHandleVariable)
            and handle.mutable_local
        )
        assert hasattr(torch.Tensor, name)
        idx = len(self.tensor_hooks)
        # duplicate index possible because of self.remove_hook()
        while idx in self.tensor_hooks:
            idx += 1
        self.tensor_hooks[idx] = (tensor, hook, handle, name)
        assert not handle.idx
        handle.idx = idx

    def remove_hook(self, idx):
        del self.tensor_hooks[idx]

    def codegen_hooks(self, cg):
        for (
            tensor,
            hook,
            handle,
            name,
        ) in self.tensor_hooks.values():
            # Note: [On tensor.register_hook]
            #
            # register_hook calls on a tensor (AKA backward hooks) differ slightly in how they are implemented
            # for objects with sources (inputs, params) vs objects without sources (intermediaries).
            #
            # For tensors with a source, we bypass direct inclusion of register_hook calls in the graph.
            # Instead, these are tracked and stashed as a global variable, enabling their association with tensors in
            # the residuals. During dynamo's frame creation, these hooks are invoked seamlessly on known reconstructible/fetch-able
            # tensors. Because a source indicates knowledge of this object outside the torch compile region, and
            # because we are running residuals strictly before .backward() can be run, it is sound to invoke
            # `register_hook` on a known tensor.
            #
            # For tensors without a source, we support a limited subset of hooks: global functions only, and
            # compiled_autograd must be enabled or we will graph break.
            #
            # Handling the Handle: When a user retains the register_hook result in a handle, we intercept the
            # STORE_FAST operation to record the user-designated local variable name. This ensures the reconstructed
            # bytecode retains this name. If no handle is defined, we simply pop the generated value to keep the
            # stack intact.
            #
            # Dynamo Tensor Hooks Workflow:
            # - Functions passed to register_hook are lifted globally.
            # - For tensors with sources:
            #   - In the "side_effects" phase of codegen, we iterate over tensors with hooks to:
            #     - Generate the tensor.
            #     - Issue a register_hook call on the tensor, linking to the globally stored function.
            #     - Incorporate a handle if one was established in the eager phase.
            # - For tensors without sources:
            #   - We don't generate any instructions for registering a hook.
            #   - Handles from intermediary hooks are NYI.
            #   - We produce a call function that utilizes the trace_wrapped higher order op, closing over it.
            #   - We then manually insert the call function above into the graph.
            # - The handle's exact user-specified name, "user_code_variable_name", is discerned and associated during STORE_FAST.
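            # A hedged sketch of the sourced-tensor pattern handled below
            # (hypothetical user code):
            #
            #   x = torch.randn(3, requires_grad=True)    # input: has a source
            #   handle = x.register_hook(lambda g: g * 2)
            #
            # The register_hook call is replayed in the residual bytecode, and
            # the handle is cached so user code can still call handle.remove().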
            assert tensor.source, "Hooks on non input tensors NYI - should not get here"

            def gen_fn():
                cg(tensor)
                cg.extend_output([cg.create_load_attr(name)])

            cg.add_push_null(gen_fn)
            cg(hook)
            cg.extend_output(create_call_function(1, False))

            # Adding the handle to the cache means RemovableHandleVariable().reconstruct() will
            # be associated with the return value of register_hook().  This consumes the top of stack.
            cg.add_cache(handle)

    def get_ca_final_callbacks_var(self):
        from .variables.base import MutableLocal

        if self.ca_final_callbacks_var is None:
            self.ca_final_callbacks_var = variables.ListVariable(
                [], mutable_local=MutableLocal()
            )
        return self.ca_final_callbacks_var

    def codegen_update_mutated(self, cg: PyCodegen):
        suffixes = []
        for var in self._get_modified_vars():
            if isinstance(var, variables.ListVariable):
                # old[:] = new
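                # Equivalent Python for the bytecode emitted below (with
                # hypothetical names): original_list[slice(None, None)] =
                # reconstructed_list, which overwrites the pre-existing list
                # object in place.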
                cg(var, allow_cache=False)
                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.extend_output(
                    [
                        cg.create_load_const(None),
                        cg.create_load_const(None),
                        create_instruction("BUILD_SLICE", arg=2),
                    ]
                )
                suffixes.append([create_instruction("STORE_SUBSCR")])
            elif isinstance(var, variables.CustomizedDictVariable):
                # need to update the dict manually since the update method may be overridden
                varname_map = {}
                for name in _manual_update_dict.__code__.co_varnames:
                    varname_map[name] = cg.tx.output.new_var()

                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.extend_output(
                    [create_instruction("STORE_FAST", argval=varname_map["dict_to"])]
                )

                cg(var, allow_cache=False)
                cg.extend_output(
                    [create_instruction("STORE_FAST", argval=varname_map["dict_from"])]
                )

                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.load_method("clear")

                # unfortunately can't just use DICT_MERGE due to possible custom behaviors
                dict_update_insts = bytecode_from_template(
                    _manual_update_dict, varname_map=varname_map
                )

                suffixes.append(
                    [
                        *create_call_method(0),  # clear
                        create_instruction("POP_TOP"),
                        *dict_update_insts,
                        create_instruction("POP_TOP"),
                    ]
                )

            elif isinstance(var, variables.ConstDictVariable):
                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.load_method("update")
                cg(var, allow_cache=False)

                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.load_method("clear")

                suffixes.append(
                    [
                        *create_call_method(0),  # clear
                        create_instruction("POP_TOP"),
                        *create_call_method(1),  # update
                        create_instruction("POP_TOP"),
                    ]
                )
            elif isinstance(
                var, variables.torch_function.TorchFunctionModeStackVariable
            ):
                cg.add_push_null(
                    lambda: cg.load_import_from(
                        utils.__name__, "set_torch_function_mode_stack"
                    )
                )
                cg.foreach(var.symbolic_stack)
                cg.append_output(
                    create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))
                )
                cg.call_function(1, False)
                cg.append_output(create_instruction("POP_TOP"))
            elif self.is_attribute_mutation(var):
                # Applying mutations involves two steps: 1) Push all
                # reconstructed objects onto the stack.  2) Call STORE_ATTR to
                # apply the mutations.
                #
                # Dynamo must ensure that mutations are applied in the same
                # order as in the original program. Therefore, two reverse
                # operations occur below.
                #
                # The first reverse operation concerns `suffixes`. We apply
                # suffixes in reverse order due to the way Python handles the
                # stack. In Step 1, we push all reconstructed objects onto the
                # stack, but the item at the top of the stack refers to the last
                # attribute in the mutation order. If not fixed, this would apply
                # the attribute mutations in reverse order. To account for this
                # reversal (the second reverse operation), we iterate through the
                # mutable attributes in reverse order.
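                # A hedged worked example (hypothetical): for `obj.a = 1; obj.b = 2`
                # we push the value/object pair for `b` first and for `a` second,
                # so the STORE_ATTR suffixes, replayed in reverse at the end of
                # this function, apply `a` then `b`, matching program order.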
                for name, value in reversed(
                    self.store_attr_mutations.get(var.mutable_local, {}).items()
                ):
                    if isinstance(var, variables.NewGlobalVariable):
                        cg.tx.output.update_co_names(name)
                        cg(value)
                        assert isinstance(var.mutable_local.source, GlobalSource)  # type: ignore[attr-defined]
                        suffixes.append(
                            [create_instruction("STORE_GLOBAL", argval=name)]
                        )
                    elif isinstance(value, variables.DeletedVariable):
                        if isinstance(
                            var.mutable_local, AttributeMutationExisting
                        ) and hasattr(getattr(var, "value", None), name):
                            cg.tx.output.update_co_names(name)
                            cg(var.mutable_local.source)
                            suffixes.append(
                                [create_instruction("DELETE_ATTR", argval=name)]
                            )
                    elif (
                        isinstance(var, variables.UserDefinedObjectVariable)
                        and var.needs_slow_setattr()
                    ):
                        # __setattr__ is defined on this object, so call object.__setattr__ directly
                        cg.load_import_from("builtins", "object")
                        cg.load_method("__setattr__")
                        cg(var.mutable_local.source)  # type: ignore[attr-defined]
                        cg(variables.ConstantVariable(name))
                        cg(value)
                        suffixes.append(
                            [*create_call_method(3), create_instruction("POP_TOP")]
                        )
                    else:
                        cg.tx.output.update_co_names(name)
                        cg(value)
                        cg(var.mutable_local.source)
                        suffixes.append([create_instruction("STORE_ATTR", argval=name)])
            elif isinstance(var, variables.TupleIteratorVariable):
                for _ in range(var.index):
                    cg.add_push_null(
                        lambda: cg.load_import_from(utils.__name__, "iter_next")
                    )
                    cg(var.mutable_local.source)  # type: ignore[attr-defined]
                    cg.call_function(1, False)
                    cg.pop_top()
            elif isinstance(var, variables.RandomVariable):
                # set correct random seed state
                def gen_fn():
                    cg(var.mutable_local.source)  # type: ignore[attr-defined]
                    cg.load_attr("setstate")

                cg.add_push_null(gen_fn)
                cg(var.wrap_state(var.random.getstate()))

                suffixes.append(
                    [
                        *create_call_function(1, False),  # setstate
                        create_instruction("POP_TOP"),
                    ]
                )
            else:
                raise AssertionError(type(var))

        # do all the actual mutations at the very end to handle dependencies
        for suffix in reversed(suffixes):
            cg.extend_output(suffix)

    def is_empty(self):
        return not (
            any(map(self.is_modified, self.id_to_variable.values()))
            or self.tensor_hooks
            or self.save_for_backward
        )

    def clear(self):
        self.keepalive.clear()
        self.id_to_variable.clear()
702