xref: /aosp_15_r20/external/pytorch/torch/_dynamo/mutation_guard.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# mypy: disable-error-code="method-assign"
3
4import functools
5import weakref
6
7import torch.nn
8from torch.nn import Module
9
10from . import config
11from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks
12
13
# Reference to torch.nn.Module.__init__ as it is bound when this module is
# imported. install_generation_tagging_init() later rebinds Module.__init__
# to a tagging wrapper, so this name keeps the earlier function reachable.
unpatched_nn_module_init = torch.nn.Module.__init__
15
16
class MutationTracker:
    """Per-object record of mutations and of the guarded code watching it."""

    # Registry of trackers keyed by the watched object (populated by watch()).
    db = ExactWeakKeyDictionary()

    def __init__(self):
        # Number of times on_mutation() has fired for the tracked object.
        self.mutation_count = 0
        # Weak references to guarded code objects awaiting invalidation.
        self.watchers = []

    def on_mutation(self, name):
        """Count a mutation and invalidate every watcher that is still alive.

        ``name`` is the attribute being set; it is accepted but not used here.
        """
        self.mutation_count += 1
        # Detach the current watcher list first so that each watcher is
        # notified at most once and later mutations start from an empty list.
        pending, self.watchers = self.watchers, []
        for weak_guarded in pending:
            target = weak_guarded()
            if target is None:
                # Referent has already been collected; nothing to invalidate.
                continue
            target.invalidate(weak_guarded)

    def track(self, guarded_code):
        """Register ``guarded_code`` (held weakly) for the next mutation."""
        self.watchers.append(weakref.ref(guarded_code))
35
36
def watch(obj, guarded_code):
    """invalidate guarded_code when obj is mutated"""
    # The class of obj must have its __setattr__ hooked before tracking helps.
    ensure_patched(type(obj))

    registry = MutationTracker.db
    if obj not in registry:
        # First watcher for this object: create its tracker lazily.
        registry[obj] = MutationTracker()
    registry[obj].track(guarded_code)
45
46
def ensure_patched(cls):
    """Wrap ``cls.__setattr__`` once so attribute writes notify MutationTracker.

    The ``___needs_mutation_patch`` attribute on the class marks it as already
    patched; later calls for that class then return without doing anything.
    """
    if not getattr(cls, "___needs_mutation_patch", True):
        return

    cls.___needs_mutation_patch = False
    wrapped_setattr = cls.__setattr__

    @functools.wraps(wrapped_setattr)
    def custom_setattr(self, key, value):
        # Notify the tracker (if any) before the attribute is actually set.
        # A KeyError means this instance is not being watched.
        try:
            MutationTracker.db[self].on_mutation(key)
        except KeyError:
            pass
        return wrapped_setattr(self, key, value)

    cls.__setattr__ = custom_setattr
61
62
class GenerationTracker:
    """Class-level bookkeeping of the generation in which objects were tagged."""

    # Current generation; bumped by install_generation_tagging_init().
    generation = 0
    # Classes explicitly marked dynamic via mark_class_dynamic().
    dynamic_classes = ExactWeakKeyDictionary()
    # Maps a tagged object to the generation that was current when tagged.
    generation_values = ExactWeakKeyDictionary()

    @classmethod
    def tag(cls, obj):
        """Record that ``obj`` belongs to the current generation."""
        cls.generation_values[obj] = cls.generation

    @staticmethod
    def mark_class_dynamic(cls):
        """Flag an nn.Module subclass so its instances are treated as dynamic."""
        assert issubclass(cls, torch.nn.Module)
        GenerationTracker.dynamic_classes[cls] = True

    @classmethod
    def get_generation_value(cls, obj):
        """Return the generation ``obj`` was tagged with, or -1 if untagged."""
        tagged = cls.generation_values
        return tagged[obj] if obj in tagged else -1

    @classmethod
    def check(cls, obj):
        """Return True only if ``obj`` was tagged in the current generation."""
        tagged = cls.generation_values
        if obj not in tagged:
            return False
        return tagged[obj] == cls.generation

    @classmethod
    def clear(cls):
        """Reset the generation counter and drop all recorded state."""
        cls.generation_values = ExactWeakKeyDictionary()
        cls.dynamic_classes = ExactWeakKeyDictionary()
        cls.generation = 0
95
96
def is_dynamic_nn_module(obj, is_export):
    """Check for nn.Modules() created dynamically or mutated"""
    is_module = isinstance(obj, torch.nn.Module)

    # A monkey patched `.forward` indicates something wacky is going on
    if is_module and "forward" in obj.__dict__:
        return True

    # An explicit per-object override takes precedence over the checks below.
    if hasattr(obj, "torchdynamo_force_dynamic"):
        return obj.torchdynamo_force_dynamic

    if is_lazy_module(obj):
        return False

    if is_module:
        # For export, we will have to fix
        # 1) Input signature problem because params are lifted as inputs
        # 2) nn module stack info changes
        # 3) adjust failing tests
        if config.inline_inbuilt_nn_modules and not is_export:
            return True
        if nn_module_has_global_hooks():
            return True

    # Otherwise: dynamic if the class was marked dynamic, or if the instance
    # was tagged during the current generation.
    class_marked_dynamic = GenerationTracker.dynamic_classes.get(type(obj))
    return class_marked_dynamic or GenerationTracker.check(obj)
123
124
def install_generation_tagging_init():
    """
    Monkey patch torch.nn.Module.__init__ and torch.nn.Module.__setstate__
    so we can detect nn.Module instances created dynamically inside forward methods.

    The patch is applied at most once; every call advances the generation.
    """
    needs_patch = getattr(Module, "___needs_generation_tag_patch", True)

    if needs_patch:
        original_init = Module.__init__
        original_setstate = Module.__setstate__

        def patched_init(self, *args, **kwargs):
            # Run the real constructor, then stamp the current generation.
            original_init(self, *args, **kwargs)
            GenerationTracker.tag(self)

        def patched_setstate(self, state):
            # Restoring state also stamps the module with the current generation.
            original_setstate(self, state)
            GenerationTracker.tag(self)

        Module.__init__ = patched_init
        Module.__setstate__ = patched_setstate

        # Marker consulted on later calls so the wrappers are not stacked.
        Module.___needs_generation_tag_patch = False  # type: ignore[attr-defined]

    GenerationTracker.generation += 1
151