# mypy: allow-untyped-defs
# mypy: disable-error-code="method-assign"

import functools
import weakref

import torch.nn
from torch.nn import Module

from . import config
from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks


# ``nn.Module.__init__`` as it was when this module was imported;
# install_generation_tagging_init() later replaces ``Module.__init__``.
unpatched_nn_module_init = torch.nn.Module.__init__


class MutationTracker:
    """Per-object mutation counter plus the guarded code that depends on it.

    Instances are stored in ``MutationTracker.db``, keyed by the watched object.
    """

    # Maps each watched object to its MutationTracker.
    db = ExactWeakKeyDictionary()

    def __init__(self):
        self.mutation_count = 0
        # Weak references to guarded-code objects to invalidate on mutation.
        self.watchers = []

    def on_mutation(self, name):
        """Record a mutation and invalidate every still-alive watcher.

        Args:
            name: the attribute being set; currently unused.
        """
        self.mutation_count += 1
        # Detach the current list before notifying: each registered watcher is
        # notified at most once, and any track() call made during invalidation
        # lands in the fresh list rather than the one being iterated.
        tmp = self.watchers
        self.watchers = []
        for ref in tmp:
            guarded = ref()
            if guarded is not None:  # referent may already be collected
                guarded.invalidate(ref)

    def track(self, guarded_code):
        """Register ``guarded_code`` (held only weakly) for invalidation."""
        self.watchers.append(weakref.ref(guarded_code))


def watch(obj, guarded_code):
    """Invalidate ``guarded_code`` when ``obj`` is mutated.

    Ensures ``type(obj).__setattr__`` is patched, then registers a weak
    reference to ``guarded_code`` on the tracker for ``obj`` (creating the
    tracker on first use).
    """
    ensure_patched(type(obj))

    if obj not in MutationTracker.db:
        MutationTracker.db[obj] = MutationTracker()
    tracker = MutationTracker.db[obj]
    tracker.track(guarded_code)


def ensure_patched(cls):
    """Wrap ``cls.__setattr__`` so attribute writes notify MutationTracker.

    The wrapper is installed at most once, guarded by the
    ``___needs_mutation_patch`` marker. The marker is read with getattr(), so
    a subclass of an already-patched class is treated as patched as well.
    NOTE(review): a subclass that overrides ``__setattr__`` without calling
    super() would therefore go untracked -- confirm this is intended.
    """
    if getattr(cls, "___needs_mutation_patch", True):
        cls.___needs_mutation_patch = False
        original_setattr = cls.__setattr__

        @functools.wraps(original_setattr)
        def custom_setattr(self, key, value):
            # Only the lookup is guarded: an object without a tracker is simply
            # not being watched. A KeyError raised while invalidating watchers
            # must not be mistaken for "not watched" -- on_mutation() has
            # already dropped its watcher list by then, so swallowing it would
            # silently leave the remaining guarded code stale.
            try:
                tracker = MutationTracker.db[self]
            except KeyError:
                pass
            else:
                tracker.on_mutation(key)
            return original_setattr(self, key, value)

        cls.__setattr__ = custom_setattr


class GenerationTracker:
    """Tags nn.Module instances with the generation in which they were tagged.

    ``generation`` is incremented by install_generation_tagging_init(). Objects
    are tagged from the patched ``Module.__init__`` / ``Module.__setstate__``.
    An object whose tag equals the current ``generation`` was therefore
    tagged after the most recent increment.
    """

    generation = 0
    # Classes explicitly marked dynamic via mark_class_dynamic().
    dynamic_classes = ExactWeakKeyDictionary()
    # Maps each tagged object to the generation it was tagged in.
    generation_values = ExactWeakKeyDictionary()

    @classmethod
    def tag(cls, obj):
        """Record the current generation for ``obj``."""
        cls.generation_values[obj] = cls.generation

    @staticmethod
    def mark_class_dynamic(cls):
        """Mark nn.Module subclass ``cls`` so its instances count as dynamic."""
        assert issubclass(cls, torch.nn.Module)
        GenerationTracker.dynamic_classes[cls] = True

    @classmethod
    def get_generation_value(cls, obj):
        """Return the generation ``obj`` was tagged in, or -1 if untagged."""
        if obj not in cls.generation_values:
            return -1
        return cls.generation_values[obj]

    @classmethod
    def check(cls, obj):
        """Return True iff ``obj`` was tagged in the current generation."""
        return (
            obj in cls.generation_values
            and cls.generation_values[obj] == cls.generation
        )

    @classmethod
    def clear(cls):
        """Reset the generation counter and drop all tags and dynamic marks."""
        cls.generation = 0
        cls.dynamic_classes = ExactWeakKeyDictionary()
        cls.generation_values = ExactWeakKeyDictionary()


def is_dynamic_nn_module(obj, is_export):
    """Check for nn.Modules() created dynamically or mutated.

    Checks, in order:
      1. an nn.Module with an instance-level ``forward`` -> True;
      2. an object with a ``torchdynamo_force_dynamic`` attribute -> its value;
      3. a lazy module -> False;
      4. an nn.Module when ``config.inline_inbuilt_nn_modules`` is set and
         this is not export -> True;
      5. an nn.Module while global nn.Module hooks are present -> True;
      6. otherwise, whether the class was marked dynamic or the instance was
         tagged in the current generation.

    The result is truthy/falsy rather than strictly a bool (case 2 returns the
    attribute value as-is).
    """
    if isinstance(obj, torch.nn.Module) and "forward" in obj.__dict__:
        # A monkey-patched `.forward` on the instance indicates something
        # unusual is going on, so treat the module as dynamic.
        return True
    if hasattr(obj, "torchdynamo_force_dynamic"):
        return obj.torchdynamo_force_dynamic
    if is_lazy_module(obj):
        return False
    # For export, we will have to fix
    # 1) Input signature problem because params are lifted as inputs
    # 2) nn module stack info changes
    # 3) adjust failing tests
    if (
        isinstance(obj, torch.nn.Module)
        and config.inline_inbuilt_nn_modules
        and not is_export
    ):
        return True

    if isinstance(obj, torch.nn.Module) and nn_module_has_global_hooks():
        return True
    dyn = GenerationTracker.dynamic_classes.get(type(obj)) or GenerationTracker.check(
        obj
    )
    return dyn


def install_generation_tagging_init():
    """
    Monkey patch torch.nn.Module.__init__ and torch.nn.Module.__setstate__
    so we can detect nn.Module instances created dynamically inside forward methods.

    The patch is applied only once (guarded by the
    ``___needs_generation_tag_patch`` marker on Module), but every call
    increments ``GenerationTracker.generation``.
    """

    if getattr(Module, "___needs_generation_tag_patch", True):
        init = Module.__init__

        def patched_init(self, *args, **kwargs):
            init(self, *args, **kwargs)
            GenerationTracker.tag(self)

        Module.__init__ = patched_init

        setstate = Module.__setstate__

        def patched_setstate(self, state):
            setstate(self, state)
            GenerationTracker.tag(self)

        Module.__setstate__ = patched_setstate

        Module.___needs_generation_tag_patch = False  # type: ignore[attr-defined]

    GenerationTracker.generation += 1