# mypy: allow-untyped-defs
import functools
import inspect
import warnings
from collections.abc import MutableMapping
from typing import Any, Dict, List, Optional, Type, Union

import torch.nn

from . import utils, variables
from .bytecode_transformation import (
    bytecode_from_template,
    create_call_function,
    create_call_method,
    create_instruction,
)
from .codegen import PyCodegen
from .exc import unimplemented
from .source import GlobalSource, LocalSource, Source
from .utils import is_frozen_dataclass, nn_module_new, object_new
from .variables.base import (
    is_side_effect_safe,
    MutableLocalBase,
    MutableLocalSource,
    VariableTracker,
)
from .variables.user_defined import FrozenDataClassVariable


class MutableSideEffects(MutableLocalBase):
    """
    VariableTracker.mutable_local marker to indicate a list passed as
    an input that if we mutate we need to re-apply those mutations after
    the graph runs.
    """

    def __init__(self, source: Source, is_modified: bool = False):
        super().__init__(MutableLocalSource.Existing)
        # Where the original (pre-existing) object can be reloaded from.
        self.source = source
        # Set to True by SideEffects.mutation() once the object is mutated.
        self.is_modified = is_modified


class AttributeMutation(MutableLocalBase):
    """
    VariableTracker.mutable_local marker to track changes to attributes
    """

    def __init__(self, typ: MutableLocalSource, source: Optional[Source]):
        super().__init__(typ)
        self.source = source


class AttributeMutationExisting(AttributeMutation):
    """Attribute-mutation marker for an object that existed before tracing."""

    def __init__(self, source: Source):
        super().__init__(MutableLocalSource.Existing, source)
        self.source = source


class AttributeMutationNew(AttributeMutation):
    """Attribute-mutation marker for an object created during tracing.

    ``source`` starts out as ``None`` for objects created by
    ``SideEffects.track_object_new`` / ``track_cell_new`` and is filled in with
    a ``LocalSource`` by ``SideEffects.codegen_save_tempvars``. ``cls_source``
    is where the object's class can be loaded from when it is re-created.
    """

    def __init__(self, source: Optional[Source], cls_source: Optional[Source]):
        super().__init__(MutableLocalSource.Local, source)
        self.cls_source = cls_source


# NOTE: this function is used as a bytecode *template* by
# SideEffects.codegen_update_mutated (via bytecode_from_template), and its
# ``__code__.co_varnames`` are used to allocate fresh local names.  Do not
# rename its locals, change its body, or add a docstring.
def _manual_update_dict(dict_from, dict_to):
    for k, v in dict_from.items():
        dict_to[k] = v


class SideEffects:
    """
    Track side effects (list mutation, setattr, etc) that need to be
    applied after an FX graph is run.
    """

    # id(python object) -> the VariableTracker tracking it.
    id_to_variable: Dict[int, VariableTracker]
    # mutable_local marker -> {attribute name -> pending new value}.
    store_attr_mutations: Dict[MutableLocalBase, Dict[str, VariableTracker]]
    # Strong references to tracked objects so their id() values stay unique.
    keepalive: List[Any]

    def __init__(
        self,
        id_to_variable=None,
        store_attr_mutations=None,
        keepalive=None,
        save_for_backward=None,
        tensor_hooks=None,
    ):
        super().__init__()
        # NB: because of `or`, an *empty* container argument is replaced by a
        # fresh one, while a non-empty one is stored (aliased) as-is.
        self.id_to_variable = id_to_variable or {}
        self.store_attr_mutations = store_attr_mutations or {}
        self.keepalive = keepalive or []
        # List of (ctx, args) pairs recorded by track_save_for_backward().
        self.save_for_backward = save_for_backward or []
        # idx -> (tensor, hook, handle, name), populated by register_hook().
        self.tensor_hooks = tensor_hooks or {}
        # Track Compiled Autograd final callbacks that must be called at the end of Compiled Autograd backward graph.
        # Only applicable if this graph is created from Dynamo tracing in Compiled Autograd.
        self.ca_final_callbacks_var = None

    def __eq__(self, other: object) -> bool:
        """Compare all tracked state except ``keepalive``.

        Only other ``SideEffects`` instances may be compared (asserted).
        """
        assert isinstance(other, SideEffects)
        # NB: do NOT test keepalive
        return (
            self.id_to_variable == other.id_to_variable
            and self.store_attr_mutations == other.store_attr_mutations
            and self.save_for_backward == other.save_for_backward
            and self.tensor_hooks == other.tensor_hooks
        )

    def diff(self, other: "SideEffects") -> Optional[str]:
        """Return a short description of the first differing field, or None.

        Fields are checked in the same order as ``__eq__``.
        """
        if self.id_to_variable != other.id_to_variable:
            sk_itv = self.id_to_variable.keys()
            ok_itv = other.id_to_variable.keys()
            if sk_itv != ok_itv:
                return f"id_to_variable keys: {sk_itv} != {ok_itv}"
            # Feel free to augment this with more fancy diffing logic
            # if needed for debugging
            return "id_to_variable: unknown diff"
        elif self.store_attr_mutations != other.store_attr_mutations:
            sk_sam = self.store_attr_mutations.keys()
            ok_sam = other.store_attr_mutations.keys()
            if sk_sam != ok_sam:
                return f"store_attr_mutations keys: {sk_sam} != {ok_sam}"
            return "store_attr_mutations: unknown diff"
        elif self.save_for_backward != other.save_for_backward:
            return "save_for_backward"
        elif self.tensor_hooks != other.tensor_hooks:
            return "tensor_hooks"
        else:
            return None

    def clone(self):
        """Create a shallow copy.

        ``id_to_variable``, each inner dict of ``store_attr_mutations`` and
        ``keepalive`` are copied. ``save_for_backward`` and ``tensor_hooks`` are
        passed through by reference (when non-empty), so the clone shares them
        with ``self``.
        NOTE(review): confirm the sharing of those two is intentional.
        """
        return self.__class__(
            id_to_variable=dict(self.id_to_variable),
            store_attr_mutations={
                k: dict(v) for k, v in self.store_attr_mutations.items()
            },
            keepalive=list(self.keepalive),
            save_for_backward=self.save_for_backward,
            tensor_hooks=self.tensor_hooks,
        )

    def __contains__(self, item):
        """Return whether the Python object ``item`` (by identity) is tracked."""
        return id(item) in self.id_to_variable

    def __getitem__(self, item):
        """Return the VariableTracker tracking ``item``; KeyError if untracked."""
        return self.id_to_variable[id(item)]

    def check_allowed_side_effect(self, item):
        """Call ``unimplemented`` if mutating ``item`` is not side-effect safe.

        Returns True for AutogradFunctionContextVariable (always allowed) and
        None for any other item that passes the check.
        """
        from torch._dynamo.variables.misc import AutogradFunctionContextVariable

        # People do things like self.dim = dim inside autograd.Function.
        # These are benign.
        if isinstance(item, AutogradFunctionContextVariable):
            return True
        if not is_side_effect_safe(item.mutable_local):
            unimplemented(
                "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)"
            )

    def store_attr(self, item: VariableTracker, name: str, value: VariableTracker):
        """Record ``item.<name> = value`` as a pending attribute mutation."""
        assert self.is_attribute_mutation(item)
        self.check_allowed_side_effect(item)
        if item.mutable_local not in self.store_attr_mutations:
            self.store_attr_mutations[item.mutable_local] = {}
        self.store_attr_mutations[item.mutable_local][name] = value

    def load_attr(self, item, name, deleted_ok=False):
        """Return the pending value recorded for ``item.<name>``.

        Raises KeyError if no mutation was recorded for that attribute. A
        recorded deletion (DeletedVariable) triggers ``unimplemented`` unless
        ``deleted_ok`` is set, in which case the DeletedVariable is returned.
        """
        assert self.is_attribute_mutation(item)
        result = self.store_attr_mutations[item.mutable_local][name]
        if not deleted_ok and isinstance(result, variables.DeletedVariable):
            unimplemented("read deleted attribute")
        return result

    def store_cell(self, cellvar, value):
        """Record a write to a cell's ``cell_contents``."""
        assert isinstance(cellvar, variables.NewCellVariable)
        assert isinstance(value, variables.VariableTracker)
        self.store_attr(cellvar, "cell_contents", value)

    def load_cell(self, cellvar):
        """Return the pending ``cell_contents`` value of a cell."""
        assert isinstance(cellvar, variables.NewCellVariable)
        return self.load_attr(cellvar, "cell_contents")

    def load_global(self, gvar: VariableTracker, name: str):
        """Return the pending value of global ``name`` tracked on ``gvar``."""
        assert isinstance(gvar, variables.VariableTracker)
        return self.load_attr(gvar, name)

    def store_global(self, gvar: VariableTracker, name: str, value: VariableTracker):
        """Record a write to global ``name`` tracked on ``gvar``."""
        assert isinstance(gvar, variables.VariableTracker)
        assert isinstance(value, variables.VariableTracker)
        self.store_attr(gvar, name, value)

    @staticmethod
    def cls_supports_mutation_side_effects(cls):
        """Return True if ``cls`` does not override ``__getattribute__``.

        The check is static (``inspect.getattr_static``), so no descriptor or
        user code is executed.
        """
        return (
            inspect.getattr_static(cls, "__getattribute__", None)
            is object.__getattribute__
        )

    def is_attribute_mutation(self, item):
        """Return whether ``item`` carries an AttributeMutation marker."""
        return isinstance(item.mutable_local, AttributeMutation)

    def has_pending_mutation(self, item):
        """Return whether any attribute mutation is recorded for ``item``."""
        return self.is_attribute_mutation(item) and bool(
            self.store_attr_mutations.get(item.mutable_local)
        )

    def has_pending_mutation_of_attr(self, item, name):
        """Return whether a mutation of attribute ``name`` is recorded for ``item``."""
        return self.is_attribute_mutation(
            item
        ) and name in self.store_attr_mutations.get(item.mutable_local, ())

    def is_modified(self, item):
        """Return whether ``item`` needs to be materialized/updated after the graph.

        Objects created during tracing always count as modified; existing
        attribute-tracked objects count once any mutation is recorded; other
        markers report their own ``is_modified`` flag.
        """
        if isinstance(item.mutable_local, AttributeMutationNew):
            return True
        if self.is_attribute_mutation(item):
            return item.mutable_local in self.store_attr_mutations
        return item.mutable_local.is_modified

    def _track_obj(
        self,
        item: Any,
        variable: VariableTracker,
        mutable_cls=MutableSideEffects,
    ):
        """Start tracking a new variable for mutation"""
        assert variable.source is not None

        if id(item) in self.id_to_variable:
            raise AssertionError(
                f"{variable} is already tracked for mutation. This could be "
                "because you are not using VariableBuilder to construct "
                "the variable tracker. "
                f"Source of new object: {variable.source}. "
                f"Source of previously tracked object: {self.id_to_variable[id(item)].source}."
            )

        variable.mutable_local = mutable_cls(variable.source)
        self.id_to_variable[id(item)] = variable
        self.keepalive.append(item)
        return variable

    track_mutable = _track_obj

    def track_object_existing(
        self,
        item: Any,
        variable: VariableTracker,
    ):
        """Track attribute mutations on a pre-existing object ``item``."""
        return self._track_obj(item, variable, mutable_cls=AttributeMutationExisting)

    def track_object_new(
        self,
        cls_source: Source,
        user_cls: Any,
        variable_cls: Any,
        options,
    ):
        """Create a placeholder instance of ``user_cls`` and track it as new.

        The instance is allocated without running ``__init__`` (via
        ``nn_module_new`` / ``object_new``), except for FunctionCtx where a
        ``torch.autograd.Function()`` is built with warnings suppressed.
        Returns ``variable_cls(obj, mutable_local=..., **options)``.
        """
        if user_cls is torch.autograd.function.FunctionCtx:
            with warnings.catch_warnings(record=True):
                obj = torch.autograd.Function()
        elif issubclass(user_cls, torch.nn.Module):
            obj = nn_module_new(user_cls)
        else:
            obj = object_new(user_cls)
        variable = variable_cls(
            obj,
            mutable_local=AttributeMutationNew(None, cls_source),
            **options,
        )
        self.id_to_variable[id(obj)] = variable
        self.keepalive.append(obj)
        return variable

    def track_object_new_from_user_defined_class(
        self,
        cls_variable: "variables.UserDefinedClassVariable",
    ):
        """Track a new instance of the class wrapped by ``cls_variable``.

        Picks the VariableTracker subclass from the user class (nn.Module,
        MutableMapping, frozen dataclass, or generic user object).
        """
        cls_source = cls_variable.source
        user_cls = cls_variable.value

        # Find the variable class
        variable_cls: Type[
            variables.UserDefinedObjectVariable
        ] = variables.UserDefinedObjectVariable
        if issubclass(user_cls, torch.nn.Module):
            variable_cls = variables.UnspecializedNNModuleVariable
        elif issubclass(user_cls, MutableMapping):
            variable_cls = variables.MutableMappingVariable
        elif is_frozen_dataclass(user_cls):
            variable_cls = FrozenDataClassVariable
        else:
            variable_cls = variables.UserDefinedObjectVariable

        assert issubclass(variable_cls, variables.UserDefinedObjectVariable)

        variable_cls = functools.partial(variable_cls, cls_source=cls_source)

        return self.track_object_new(cls_source, user_cls, variable_cls, {})

    def track_cell_new(
        self,
    ):
        """Track a cell created during tracing; returns its NewCellVariable."""
        # A fresh sentinel gives the cell a unique id() key in id_to_variable.
        obj = object()
        variable = variables.NewCellVariable(
            mutable_local=AttributeMutationNew(None, None),
        )
        self.id_to_variable[id(obj)] = variable
        self.keepalive.append(obj)
        return variable

    def track_cell_existing(self, source: Source, item: Any):
        """Track a pre-existing cell ``item`` loadable from ``source``."""
        variable = variables.NewCellVariable(
            mutable_local=AttributeMutationExisting(source),
        )
        self.id_to_variable[id(item)] = variable
        self.keepalive.append(item)
        return variable

    def track_global_existing(self, source: Source, item: Any):
        """Track globals of ``item`` (loadable from ``source``) for mutation."""
        variable = variables.NewGlobalVariable(
            mutable_local=AttributeMutationExisting(source),
        )
        self.id_to_variable[id(item)] = variable
        self.keepalive.append(item)
        return variable

    def track_save_for_backward(self, ctx, args):
        """Record a ``ctx.save_for_backward(*args)`` call to replay in codegen."""
        assert isinstance(ctx, variables.AutogradFunctionContextVariable)
        self.save_for_backward.append((ctx, args))

    def track_tensor_variables_from_runahead_side_effects(self, other):
        """Adopt TensorVariables tracked in ``other`` but not yet in ``self``."""
        # In higher order ops we want to keep track of tensors seen in the
        # speculate_subgraph so that we don't lift them again as a new input in
        # other speculate_subgraph or in the root tracer.
        for other_item in other.keepalive:
            other_id = id(other_item)
            # prune_dead_object_new() drops entries from id_to_variable but
            # leaves keepalive untouched, so an item may no longer be tracked.
            other_variable = other.id_to_variable.get(other_id)
            if other_variable is None:
                continue
            if other_id not in self.id_to_variable and isinstance(
                other_variable, variables.TensorVariable
            ):
                self.track_object_existing(other_item, other_variable)

    def prune_dead_object_new(self, tx):
        """Drop new (AttributeMutationNew) objects that are no longer reachable.

        Reachability starts from ``tx.stack``, ``tx.symbolic_locals`` and every
        pre-existing tracked variable, and follows recorded attribute
        mutations. Unreachable new objects are removed from ``id_to_variable``
        and ``store_attr_mutations`` (``keepalive`` is not modified).
        """
        live_new_objects = set()

        # use this to avoid cycles in mutable_local (though I'm not sure if that
        # can actually happen).
        visited: Any = set()

        def visit(var: VariableTracker):
            mutable_local = var.mutable_local
            if mutable_local is None:
                return
            if mutable_local in visited:
                return
            visited.add(mutable_local)
            # Object may have been mutated, store this mutation.
            if isinstance(mutable_local, AttributeMutationNew):
                live_new_objects.add(mutable_local)
            # It's possible that we have mutated the value of this variable
            # to be another one. The new value is in store_attr_mutations.
            # Also recurse through the new value to detect alive AttributeMutationNew.
            if var.mutable_local in self.store_attr_mutations:
                VariableTracker.visit(
                    visit, self.store_attr_mutations[var.mutable_local]
                )

        def is_live(var: Union[MutableLocalBase, VariableTracker]):
            # Only AttributeMutationNew markers can be dead; everything else is kept.
            if isinstance(var, AttributeMutationNew):
                return var in live_new_objects
            if isinstance(var, VariableTracker):
                return is_live(var.mutable_local)
            return True

        pre_existing_vars = [
            var
            for var in self.id_to_variable.values()
            if not isinstance(var.mutable_local, AttributeMutationNew)
        ]

        # The only live side effects come from returns (tx.stack), any intermediates
        # during a graph break (tx.symbolic_locals), and mutation on pre-existing variables.
        # Recursively visit Variables and see if any of them have been mutated.
        VariableTracker.visit(visit, (tx.stack, tx.symbolic_locals, pre_existing_vars))

        # NB: cell variable handling is tricky.
        # cell variables must stay alive if any NestedUserFunctionVariable
        # are live. "visit"-ing the NestedUserFunctionVariable visits
        # the .closures field, from which we will see if we need to keep
        # any mutations to cell variables alive.

        self.id_to_variable = {
            k: v for k, v in self.id_to_variable.items() if is_live(v)
        }
        self.store_attr_mutations = {
            k: v for k, v in self.store_attr_mutations.items() if is_live(k)
        }

    def mutation(self, var):
        """Note that ``var`` was mutated in place (e.g. a list input)."""
        self.check_allowed_side_effect(var)
        if isinstance(var.mutable_local, MutableSideEffects):
            var.mutable_local = MutableSideEffects(var.mutable_local.source, True)

    def _get_modified_vars(self):
        """Return tracked variables for which ``is_modified`` is true."""
        return [var for var in self.id_to_variable.values() if self.is_modified(var)]

    def codegen_save_tempvars(self, cg: PyCodegen):
        """Emit bytecode that materializes new objects into temporaries.

        New/existing cells are created with ``utils.make_cell()``; other new
        objects with ``utils.object_new(cls)``. Each result is cached in
        ``cg.tempvars`` and, for new objects, the marker's ``source`` is set to
        that temporary. Afterwards, recorded ``save_for_backward`` calls are
        replayed.
        """
        for var in self._get_modified_vars():
            if isinstance(
                var.mutable_local, (AttributeMutationExisting, AttributeMutationNew)
            ) and isinstance(var, variables.NewCellVariable):
                cg.add_push_null(
                    lambda: cg.load_import_from(utils.__name__, "make_cell")
                )
                cg.extend_output(create_call_function(0, False))
                cg.add_cache(var)
                if isinstance(var.mutable_local, AttributeMutationNew):
                    var.mutable_local.source = LocalSource(cg.tempvars[var])  # type: ignore[attr-defined]
            elif isinstance(var.mutable_local, AttributeMutationNew):
                if isinstance(var, variables.AutogradFunctionContextVariable):
                    unimplemented("AutogradFunctionContextVariable escaped")
                cg.add_push_null(
                    lambda: cg.load_import_from(utils.__name__, "object_new")
                )
                cg(var.mutable_local.cls_source)
                cg.extend_output(create_call_function(1, False))
                cg.add_cache(var)
                var.mutable_local.source = LocalSource(cg.tempvars[var])
            elif var in cg.tempvars:
                assert cg.tempvars.get(var) is None
                # subsequent usage should point to the original variable
                cg(var.mutable_local.source)
                cg.add_cache(var)

        for ctx, args in self.save_for_backward:
            cg(ctx.source)
            cg.load_method("save_for_backward")
            for arg in args:
                cg(arg)
            cg.extend_output(
                [
                    *create_call_method(len(args)),
                    create_instruction("POP_TOP"),
                ]
            )

    def register_hook(self, tensor, hook, handle, name):
        """Record ``tensor.<name>(hook)`` and assign ``handle.idx``.

        The key is the smallest integer >= the current number of hooks that
        is not already in use.
        """
        assert isinstance(tensor, variables.TensorVariable)
        assert isinstance(hook, variables.VariableTracker)
        assert (
            isinstance(handle, variables.RemovableHandleVariable)
            and handle.mutable_local
        )
        assert hasattr(torch.Tensor, name)
        idx = len(self.tensor_hooks)
        # duplicate index possible because of self.remove_hook()
        while idx in self.tensor_hooks:
            idx += 1
        self.tensor_hooks[idx] = (tensor, hook, handle, name)
        assert not handle.idx
        handle.idx = idx

    def remove_hook(self, idx):
        """Forget the hook registered under ``idx`` (KeyError if absent)."""
        del self.tensor_hooks[idx]

    def codegen_hooks(self, cg):
        """Emit ``tensor.<name>(hook)`` calls for every recorded hook."""
        for (
            tensor,
            hook,
            handle,
            name,
        ) in self.tensor_hooks.values():
            # Note: [On tensor.register_hook]
            #
            # register_hook on a tensor, AKA backward hooks, have slightly nuanced differences in how they are implemented
            # when it comes to hooks on objects with sources (inputs, params) vs objects without sources (intermediaries).
            #
            # For tensors with a source, we bypass direct inclusion of register_hook calls in the graph.
            # Instead, these are tracked and stashed as a global variable, enabling their association with tensors in
            # the residuals. During dynamo's frame creation, these hooks are invoked seamlessly on known reconstructible/fetch-able
            # tensors. Because a source indicates knowledge of this object outside the torch compile region, and
            # because we are running residuals firmly before .backward() can be run, it is sound to invoke
            # `register_hook` on a known tensor.
            #
            # For tensors without a source, we support a limited subset of hooks. Global functions only, and
            # compiled_autograd must be enabled or we will graph break.
            #
            # Handling the Handle: When a user retains the register_hook result in a handle, we intercept the
            # STORE_FAST operation to record the user-designated local variable name. This ensures the reconstructed
            # bytecode retains this name. If no handle is defined, we simply pop the generated value to keep the
            # stack intact.
            #
            # Dynamo Tensor Hooks Workflow:
            # - Functions passed to register_hook are lifted globally.
            # - For tensors with sources:
            #   - In the "side_effects" phase of codegen, we iterate over tensors with hooks to:
            #     - Generate the tensor.
            #     - Issue a register_hook call on the tensor, linking to the globally stored function.
            #     - Incorporate a handle if one was established in the eager phase.
            # - For tensors without sources:
            #   - We don't generate any instructions for registering a hook.
            #   - Handles from intermediary hooks are NYI.
            #   - We produce a call function that utilizes the trace_wrapped higher order op, closing over it.
            #   - We then manually insert the call function above into the graph.
            # - The handle's exact user-specified name, "user_code_variable_name", is discerned and associated during STORE_FAST.
            assert tensor.source, "Hooks on non input tensors NYI - should not get here"

            def gen_fn():
                cg(tensor)
                cg.extend_output([cg.create_load_attr(name)])

            cg.add_push_null(gen_fn)
            cg(hook)
            cg.extend_output(create_call_function(1, False))

            # Adding the handle to the cache means RemovableHandleVariable().reconstruct() will
            # be associated with the return value of register_hook().  This consumes the top of stack.
            cg.add_cache(handle)

    def get_ca_final_callbacks_var(self):
        """Return the (lazily created) ListVariable of Compiled Autograd final callbacks."""
        from .variables.base import MutableLocal

        if self.ca_final_callbacks_var is None:
            self.ca_final_callbacks_var = variables.ListVariable(
                [], mutable_local=MutableLocal()
            )
        return self.ca_final_callbacks_var

    def codegen_update_mutated(self, cg: PyCodegen):
        """Emit bytecode that writes recorded mutations back to the real objects.

        Each branch pushes its operands immediately and appends the
        instructions that consume them to ``suffixes``; all suffixes are
        emitted at the end in reverse order so the stack operands line up.
        """
        suffixes = []
        for var in self._get_modified_vars():
            if isinstance(var, variables.ListVariable):
                # old[:] = new
                cg(var, allow_cache=False)
                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.extend_output(
                    [
                        cg.create_load_const(None),
                        cg.create_load_const(None),
                        create_instruction("BUILD_SLICE", arg=2),
                    ]
                )
                suffixes.append([create_instruction("STORE_SUBSCR")])
            elif isinstance(var, variables.CustomizedDictVariable):
                # need to update the dict manually since update method may be invalid
                varname_map = {}
                for name in _manual_update_dict.__code__.co_varnames:
                    varname_map[name] = cg.tx.output.new_var()

                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.extend_output(
                    [create_instruction("STORE_FAST", argval=varname_map["dict_to"])]
                )

                cg(var, allow_cache=False)
                cg.extend_output(
                    [create_instruction("STORE_FAST", argval=varname_map["dict_from"])]
                )

                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.load_method("clear")

                # unfortunately can't just use DICT_MERGE due to possible custom behaviors
                dict_update_insts = bytecode_from_template(
                    _manual_update_dict, varname_map=varname_map
                )

                suffixes.append(
                    [
                        *create_call_method(0),  # clear
                        create_instruction("POP_TOP"),
                        *dict_update_insts,
                        create_instruction("POP_TOP"),
                    ]
                )

            elif isinstance(var, variables.ConstDictVariable):
                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.load_method("update")
                cg(var, allow_cache=False)

                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.load_method("clear")

                suffixes.append(
                    [
                        *create_call_method(0),  # clear
                        create_instruction("POP_TOP"),
                        *create_call_method(1),  # update
                        create_instruction("POP_TOP"),
                    ]
                )
            elif isinstance(
                var, variables.torch_function.TorchFunctionModeStackVariable
            ):
                cg.add_push_null(
                    lambda: cg.load_import_from(
                        utils.__name__, "set_torch_function_mode_stack"
                    )
                )
                cg.foreach(var.symbolic_stack)
                cg.append_output(
                    create_instruction("BUILD_LIST", arg=len(var.symbolic_stack))
                )
                cg.call_function(1, False)
                cg.append_output(create_instruction("POP_TOP"))
            elif self.is_attribute_mutation(var):
                # Applying mutations involves two steps: 1) Push all
                # reconstructed objects onto the stack.  2) Call STORE_ATTR to
                # apply the mutations.
                #
                # Dynamo must ensure that mutations are applied in the same
                # order as in the original program. Therefore, two reverse
                # operations occur below.
                #
                # The first reverse operation concerns `suffixes`. We apply
                # suffixes in reverse order due to the way Python handles the
                # stack. In Step 1, we push all reconstructed objects onto the
                # stack, but the item at the top of the stack refers to the last
                # attribute in the mutation order. If not fixed, this will apply
                # the mutations of attributes in the reverse order.  To account
                # for this reversal, we iterate through the mutable attributes
                # in reverse order.
                for name, value in reversed(
                    self.store_attr_mutations.get(var.mutable_local, {}).items()
                ):
                    if isinstance(var, variables.NewGlobalVariable):
                        cg.tx.output.update_co_names(name)
                        cg(value)
                        assert isinstance(var.mutable_local.source, GlobalSource)  # type: ignore[attr-defined]
                        suffixes.append(
                            [create_instruction("STORE_GLOBAL", argval=name)]
                        )
                    elif isinstance(value, variables.DeletedVariable):
                        # Only emit `del obj.name` when the real object still
                        # has the attribute; otherwise there is nothing to delete.
                        if isinstance(
                            var.mutable_local, AttributeMutationExisting
                        ) and hasattr(getattr(var, "value", None), name):
                            cg.tx.output.update_co_names(name)
                            cg(var.mutable_local.source)
                            suffixes.append(
                                [create_instruction("DELETE_ATTR", argval=name)]
                            )
                    elif (
                        isinstance(var, variables.UserDefinedObjectVariable)
                        and var.needs_slow_setattr()
                    ):
                        # __setattr__ is defined on this object, so call object.__setattr__ directly
                        cg.load_import_from("builtins", "object")
                        cg.load_method("__setattr__")
                        cg(var.mutable_local.source)  # type: ignore[attr-defined]
                        cg(variables.ConstantVariable(name))
                        cg(value)
                        suffixes.append(
                            [*create_call_method(3), create_instruction("POP_TOP")]
                        )
                    else:
                        cg.tx.output.update_co_names(name)
                        cg(value)
                        cg(var.mutable_local.source)
                        suffixes.append([create_instruction("STORE_ATTR", argval=name)])
            elif isinstance(var, variables.TupleIteratorVariable):
                # Advance the real iterator by the number of items consumed
                # during tracing.
                for _ in range(var.index):
                    cg.add_push_null(
                        lambda: cg.load_import_from(utils.__name__, "iter_next")
                    )
                    cg(var.mutable_local.source)  # type: ignore[attr-defined]
                    cg.call_function(1, False)
                    cg.pop_top()
            elif isinstance(var, variables.RandomVariable):
                # set correct random seed state
                def gen_fn():
                    cg(var.mutable_local.source)  # type: ignore[attr-defined]
                    cg.load_attr("setstate")

                cg.add_push_null(gen_fn)
                cg(var.wrap_state(var.random.getstate()))

                suffixes.append(
                    [
                        *create_call_function(1, False),  # setstate
                        create_instruction("POP_TOP"),
                    ]
                )
            else:
                raise AssertionError(type(var))

        # do all the actual mutations at the very end to handle dependencies
        for suffix in reversed(suffixes):
            cg.extend_output(suffix)

    def is_empty(self):
        """Return True if there is nothing to replay after the graph runs."""
        return not (
            any(map(self.is_modified, self.id_to_variable.values()))
            or self.tensor_hooks
            or self.save_for_backward
        )

    def clear(self):
        """Drop all tracked objects.

        Only ``keepalive`` and ``id_to_variable`` are cleared; recorded attribute
        mutations, ``save_for_backward`` and ``tensor_hooks`` are left as-is.
        """
        self.keepalive.clear()
        self.id_to_variable.clear()