# mypy: ignore-errors

import weakref
from typing import Dict, List, TYPE_CHECKING

import torch
from torch.utils._pytree import tree_map_only

from ..guards import GuardBuilder, install_guard
from ..source import (
    AttrSource,
    ConstDictKeySource,
    GetItemSource,
    GlobalWeakRefSource,
    GradSource,
)
from ..utils import GLOBAL_KEY_PREFIX
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator

    from .base import VariableTracker


class ArgMappingException(Exception):
    pass


class GuardInstallException(Exception):
    pass


class OptimizerVariable(UserDefinedObjectVariable):
    _nonvar_fields = {
        "grad_to_source",
        "tensor_to_source",
        "static_tensor_names",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    def __init__(
        self,
        value,
        grad_to_source=None,
        static_tensor_names=None,
        tensor_to_source=None,
        **kwargs,
    ) -> None:
        super().__init__(value, **kwargs)
        self.grad_to_source = grad_to_source or {}
        self.tensor_to_source = tensor_to_source or {}
        self.static_tensor_names = static_tensor_names or set()

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """This is an optimization to avoid tracing the very slow initialization of the optimizer"""
        if name == "_init_group":
            try:
                self.graph_break_if_pending_mutation(tx)
                self.move_step_if_cpu()
                py_args, py_kwargs = self.get_python_args(*args, **kwargs)
                ret_val = self.value._init_group(*py_args, **py_kwargs)
                self.map_sources_and_install_guards(tx)
                self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
                # stash a weak_ptr to optimizer to invalidate code
                # if the optimizer object dies
                mangled_name = f"__optimizer_{id(self.value)}"
                tx.store_global_weakref_by_id(mangled_name, self.value)
                self.create_finalizer(tx)

                # This is currently safe only because the only actual `ret_val`s returned
                # by the `_init_group` of existing optimizers are properties that are invariant
                # to the input tensors (e.g. dtype, layout). Changing these would trigger a
                # recompilation and hence never result in the wrong specialization of `ret_val`.
                return ConstantVariable.create(ret_val)
            except (ArgMappingException, GuardInstallException) as _:
                # trace normally if we can't map args or install guards correctly
                pass

        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx: "InstructionTranslator", name):
        # Note: this allows us to intercept the call in call_method.
        # In the typical case, we return a UserMethodVariable,
        # which will directly inline.
        if name in ("_init_group", "step"):
            return GetAttrVariable(self, name, source=AttrSource(self.source, name))

        if name == "param_groups":
            from ..decorators import mark_static_address

            for group in self.value.param_groups:
                for p in group["params"]:
                    mark_static_address(p)

            self._set_capturable(tx)

        return super().var_getattr(tx, name)

    def graph_break_if_pending_mutation(self, tx):
        # If there are pending mutations on a parameter (due to using a closure),
        # then we need to graph break to allow the Python version of the parameter
        # to update, so that running _init_group will initialize the states with
        # the correct values.
        for g in self.value.param_groups:
            for p in g["params"]:
                side_effects = tx.output.side_effects
                variable = side_effects.id_to_variable.get(id(p), None)
                if variable and side_effects.has_pending_mutation(variable):
                    from ..exc import Unsupported

                    raise Unsupported("Pending mutation on parameter")

    def _set_capturable(self, tx):
        from . import LazyVariableTracker
        from .builder import VariableBuilder

        # We only set capturable if the params are on a GPU (CUDA or XPU)
        # and the state is not initialized.
        def safe_to_set_capturable(group):
            all_uninitialized = True
            all_gpu = True

            for p in group.get("params", []):
                all_gpu &= p.is_cuda or p.is_xpu
                all_uninitialized &= p not in self.value.state

            return "capturable" in group and all_uninitialized and all_gpu

        # Track the indices not to set so that we don't need to realize the
        # whole state in the variable tracker; we handle guarding the state
        # specially.
        for ind, group in enumerate(self.value.param_groups):
            if safe_to_set_capturable(group):
                group["capturable"] = True

        param_groups_vt = LazyVariableTracker.realize_all(
            VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
                self.value.param_groups
            )
        )
        for ind, param_group_vt in enumerate(param_groups_vt.items):
            key = ConstDictVariable._HashableTracker(
                ConstantVariable.create("capturable")
            )
            param_group_vt.items[key] = ConstantVariable.create(True)

    def get_python_args(self, *args, **kwargs):
        """Get python values equivalent to the variable tracker args"""

        def map_arg(arg):
            if isinstance(arg, ConstantVariable):
                return arg.as_python_constant()
            elif isinstance(arg, ListVariable) and not arg.items:
                return []
            elif (
                isinstance(arg, ConstDictVariable)
                and isinstance(arg.source, GetItemSource)
                and isinstance(arg.source.base, AttrSource)
                and arg.source.base.member == "param_groups"
            ):
                return self.value.param_groups[arg.source.index]

            raise ArgMappingException

        new_args = [map_arg(arg) for arg in args]
        new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}

        return new_args, new_kwargs
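
    # For example, the mapping above turns an empty ListVariable into a fresh
    # Python list, and a ConstDictVariable sourced from
    # `<optimizer>.param_groups[i]` back into the corresponding entry of
    # self.value.param_groups; any other argument raises ArgMappingException,
    # and call_method falls back to tracing _init_group normally.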

    # If users load an old state dictionary,
    # it's possible that step could be on the CPU.
    # If this is the case, move it to the GPU
    # corresponding to the parameter;
    # in most cases this is a no-op because the state is empty.
    def move_step_if_cpu(self):
        for p, state in self.value.state.items():
            if "step" in state and state["step"].is_cpu:
                state["step"] = state["step"].to(p.device)

    def map_sources_and_install_guards(self, tx):
        from ..decorators import mark_static_address
        from .builder import VariableBuilder
        from .lazy import LazyVariableTracker

        self.grad_to_source = {}
        self.tensor_to_source = {}

        # Tracing _init_group is expensive, but we still have to install the
        # necessary guards for it, so we manually handle guard insertion. We
        # also want to mark all the tensors inside the state dict as static
        # addresses.

        # Mark all the tensors in the state dict as static addresses. This has
        # to be done first because the variable builder relies on the static
        # address annotation.
        def mark_static(x):
            mark_static_address(x)

        tree_map_only(torch.Tensor, mark_static, self.value.state)

        # Recursively realize the variable trackers for optim.state and
        # optim.param_groups, which recursively install the necessary guards.
        param_groups_vt = LazyVariableTracker.realize_all(
            VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
                self.value.param_groups
            )
        )

        state_vt = VariableBuilder(tx, AttrSource(self.source, "state"))(
            self.value.state
        )

        # We need to realize the top-level state dict to populate
        # the guard locals.
        state_vt.realize()

        # Populate self.grad_to_source and self.tensor_to_source so that we can
        # manually call update_list_args.
        for g_ind, (group, group_vt) in enumerate(
            zip(self.value.param_groups, param_groups_vt.items)
        ):
            # We assume here that all params within a param group
            # are initialized similarly.
            if len(group["params"]) > 0:
                for param in group["params"]:
                    if param.grad is not None:
                        key_index = None
                        for i, k in enumerate(self.value.state.keys()):
                            if k is param:
                                key_index = i
                                break
                        if key_index is not None:
                            state_source = AttrSource(self.source, "state")
                            LazyVariableTracker.realize_all(
                                VariableBuilder(
                                    tx,
                                    GetItemSource(
                                        state_source,
                                        ConstDictKeySource(state_source, key_index),
                                    ),
                                )(self.value.state[param])
                            )
                            break

            group_source = group_vt.source
            params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params"))
            for p_ind, (p, p_vt) in enumerate(
                zip(group["params"], params_vt.unpack_var_sequence(tx))
            ):
                param_source = p_vt.source
                self.tensor_to_source[p] = param_source
                grad_source = GradSource(
                    param_source,
                    "grad",
                )

                if p.grad is not None:
                    self.grad_to_source[p.grad] = grad_source
                else:
                    install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH))

        # We have to again iterate over the state dict to collect the
        # tensor_to_source dict. This is used for the finalizer.
        state_source = AttrSource(self.source, "state")
        for idx, (p, value) in enumerate(self.value.state.items()):
            p_state_source = GetItemSource(
                state_source, ConstDictKeySource(state_source, idx)
            )
            for k, v in value.items():
                if (
                    isinstance(v, torch.Tensor)
                    and v not in self.grad_to_source
                    and v not in self.tensor_to_source
                ):
                    self.tensor_to_source[v] = GetItemSource(p_state_source, k)

    def wrap_tensor(self, tx: "InstructionTranslator", tensor_value):
        """Wrap state tensor in a TensorVariable"""
        from ..decorators import mark_static_address
        from .builder import VariableBuilder

        # If we already have a source for the tensor, use it.
        # If we have not seen the tensor before, stash it and use a
        # global weakref source, since it must be an optimizer tensor
        # that we have missed.

        if tensor_value in self.tensor_to_source:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value)
            builder = VariableBuilder(tx, self.tensor_to_source[tensor_value])
            self.static_tensor_names.add(tx.output.module_key_name(builder.name))
        elif tensor_value in self.grad_to_source:
            builder = VariableBuilder(tx, self.grad_to_source[tensor_value])
        else:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value)

            global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
            builder = VariableBuilder(tx, GlobalWeakRefSource(global_name))
            self.static_tensor_names.add(tx.output.module_key_name(builder.name))

        result = builder(tensor_value)
        return result

    def update_list_args(
        self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs
    ):
        """Update the args and kwargs to the traced optimizer call"""
        for arg, py_arg in zip(args, py_args):
            if isinstance(arg, ListVariable):
                assert isinstance(
                    py_arg, list
                ), "py_arg should be a list in optimizer variable"
                for i, val in enumerate(py_arg):
                    tx.output.side_effects.mutation(arg)
                    if isinstance(val, torch.Tensor):
                        arg.items.append(self.wrap_tensor(tx, val))
                    else:
                        from .builder import SourcelessBuilder, VariableBuilder

                        if arg.source:
                            arg.items.append(
                                VariableBuilder(tx, GetItemSource(arg.source, i))(val)
                            )
                        else:
                            arg.items.append(SourcelessBuilder.create(tx, val))

    def create_finalizer(self, tx):
        names_to_delete = self.static_tensor_names
        value = self.value
        tc = tx.output.tracing_context

        def init_finalizer(gm):
            def clear_static_tensor_refs():
                for name in names_to_delete:
                    gm._buffers.pop(name, None)
                    gm._parameters.pop(name, None)
                if tc.params_flat:
                    tc.params_flat.clear()

            weakref.finalize(value, clear_static_tensor_refs)

        tx.output.add_graph_finalizer(init_finalizer)
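

# Illustrative sketch (not executed as part of this module; `model`, `opt`, and
# `step` below are hypothetical names): the fast path in OptimizerVariable is
# exercised when an optimizer step is compiled, e.g.
#
#     model = torch.nn.Linear(8, 8)
#     opt = torch.optim.Adam(model.parameters())
#
#     @torch.compile
#     def step():
#         opt.step()
#
#     model(torch.randn(2, 8)).sum().backward()
#     step()
#
# While Dynamo traces `opt.step()`, the call to `_init_group` is executed
# eagerly via `call_method` above (with guards installed manually by
# `map_sources_and_install_guards`) instead of being traced bytecode by
# bytecode.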