1# mypy: allow-untyped-defs 2import collections 3import dataclasses 4import re 5import sys 6import types 7from typing import Counter, Dict, List, Optional 8 9import torch.nn 10 11from . import utils 12from .bytecode_transformation import ( 13 add_push_null, 14 add_push_null_call_function_ex, 15 create_call_function, 16 create_call_method, 17 create_dup_top, 18 create_instruction, 19 create_load_method, 20 create_rot_n, 21 Instruction, 22) 23from .exc import unimplemented 24from .source import AttrSource, Source 25from .utils import is_safe_constant, rot_n_helper 26from .variables.base import VariableTracker 27from .variables.nn_module import NNModuleVariable 28from .variables.tensor import ( 29 NumpyNdarrayVariable, 30 SymNodeVariable, 31 TensorVariable, 32 UnspecializedPythonVariable, 33) 34from .variables.torch_function import TensorWithTFOverrideVariable 35 36 37@dataclasses.dataclass 38class GraphOutputEntry: 39 index: int 40 variable: VariableTracker 41 42 43class PyCodegen: 44 """ 45 Helper class uses for constructing Python bytecode 46 """ 47 48 def __init__( 49 self, 50 tx=None, 51 root: Optional[torch.nn.Module] = None, 52 graph_output_var: Optional[str] = None, 53 tempvars=None, 54 ) -> None: 55 self.root = root 56 self.top_of_stack: Optional[VariableTracker] = None 57 self.uses: Counter[VariableTracker] = collections.Counter() 58 self.graph_outputs: Dict[int, GraphOutputEntry] = {} 59 self._output: List[Instruction] = [] 60 self.tempvars = tempvars or {} 61 self.tx = tx 62 self.graph_output_var = graph_output_var 63 self.code_options = self.tx.output.code_options 64 self.cell_and_freevars = self.tx.cell_and_freevars 65 self.new_var = self.tx.output.new_var 66 self.mutable_side_effects_from_source = False 67 self.value_from_source: bool = True 68 69 def restore_stack(self, stack_values, *, value_from_source=True): 70 prior = self.mutable_side_effects_from_source 71 self.mutable_side_effects_from_source = True 72 prev = self.value_from_source 73 self.value_from_source &= value_from_source 74 try: 75 self.foreach(stack_values) 76 finally: 77 self.mutable_side_effects_from_source = prior 78 self.value_from_source = prev 79 80 def graph_output_vars(self): 81 return [x.variable for x in self.graph_outputs.values()] 82 83 def call_reconstruct(self, value): 84 res = value.reconstruct(self) 85 assert res is None, f"reconstruct!=None {value}" 86 87 def add_push_null(self, gen_fn, call_function_ex=False): 88 """ 89 `gen_fn` generates instructions via PyCodegen methods 90 that push a single callable to the stack. 91 92 `add_push_null` pushes a NULL to the stack before or after the 93 instructions generated by `gen_fn`, depending on Python version. 94 95 Will attempt to use the NULL push bit for instructions 96 with such bits (LOAD_GLOBAL 3.11+, LOAD_ATTR 3.12+, LOAD_SUPER_ATTR). 97 """ 98 old_len = len(self._output) 99 if sys.version_info < (3, 13): 100 # gen_fn may DUP_TOP instead if TOS is not cleared. 101 # Will cause problems since NULL will be pushed right 102 # before the generated instructions in <= 3.12 103 self.clear_tos() 104 gen_fn() 105 # inplace modify self._output 106 added_insts = self._output[old_len:] 107 del self._output[old_len:] 108 if call_function_ex: 109 self._output.extend(add_push_null_call_function_ex(added_insts)) 110 else: 111 self._output.extend(add_push_null(added_insts)) 112 if sys.version_info >= (3, 13): 113 # NULL will be at top of stack 114 self.clear_tos() 115 116 def __call__(self, value, allow_cache=True): 117 """Generate code such that top-of-stack (TOS) is set to value""" 118 if isinstance(value, Source): 119 self.call_reconstruct(value) 120 self.clear_tos() 121 return 122 123 assert isinstance(value, VariableTracker) 124 output = self._output 125 graph_outputs = self.graph_outputs 126 127 if self.top_of_stack is value and allow_cache: 128 output.append(create_dup_top()) 129 return 130 131 if self.mutable_side_effects_from_source: 132 # this is needed to get aliasing relationships right 133 # value.mutable_local.source will get mutated to hold `value` 134 # mutable_side_effects_from_source=False is used to codegen the mutation 135 # mutable_side_effects_from_source=True is used to codegen a reference 136 from .side_effects import MutableSideEffects 137 138 if isinstance(value.mutable_local, MutableSideEffects): 139 self(value.mutable_local.source) 140 return 141 142 if allow_cache: 143 if value.mutable_local and value.mutable_local in self.tempvars: 144 output.append(self.create_load(self.tempvars[value.mutable_local])) 145 self.top_of_stack = value 146 return 147 if self.tempvars.get(value) is not None: 148 output.append(self.create_load(self.tempvars[value])) 149 self.top_of_stack = value 150 return 151 152 if value.source is not None and allow_cache and self.value_from_source: 153 self.call_reconstruct(value.source) 154 elif value.is_python_constant() and is_safe_constant( 155 value.as_python_constant() 156 ): 157 output.append(self.create_load_const(value.as_python_constant())) 158 elif isinstance(value, TensorWithTFOverrideVariable): 159 graph_outputs_key = self.add_graph_output(value) 160 161 self.add_push_null( 162 lambda: self.load_import_from(utils.__name__, "to_subclass") 163 ) 164 self.load_graph_output(graph_outputs[graph_outputs_key].index) 165 output.append( 166 self.create_load_global( 167 value.global_mangled_class_name(self.tx), add=True 168 ) 169 ) 170 output.extend(create_call_function(2, False)) 171 elif ( 172 isinstance(value, SymNodeVariable) 173 and value.python_type() == float 174 and not self.tx.export 175 ): 176 # This is a little unusual; force the output convention to be a 177 # Tensor here. Don't do this for export because this is 178 # apparently load bearing for export tests (but I am a bit 179 # doubtful it actually works in the real world) 180 # NB: It works to add_graph_output on a computed expression 181 # as_tensor here, because we memoize as_tensor calls on 182 # SymNodeVariable! 183 graph_outputs_key = self.add_graph_output(value.as_tensor(self.tx)) 184 185 def gen_fn(): 186 self.load_graph_output(graph_outputs[graph_outputs_key].index) 187 output.append(self.create_load_attr("item")) 188 189 self.add_push_null(gen_fn) 190 output.extend(create_call_function(0, False)) 191 elif isinstance( 192 value, 193 ( 194 TensorVariable, 195 SymNodeVariable, 196 UnspecializedPythonVariable, 197 NumpyNdarrayVariable, 198 ), 199 ): 200 graph_outputs_key = self.add_graph_output(value) 201 202 if isinstance(value, NumpyNdarrayVariable): 203 self.add_push_null( 204 lambda: self.load_import_from(utils.__name__, "to_numpy_helper") 205 ) 206 self.load_graph_output(graph_outputs[graph_outputs_key].index) 207 output.extend(create_call_function(1, False)) 208 elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap: 209 210 def gen_fn(): 211 self.load_graph_output(graph_outputs[graph_outputs_key].index) 212 output.append(self.create_load_attr("item")) 213 214 self.add_push_null(gen_fn) 215 output.extend(create_call_function(0, False)) 216 else: 217 self.load_graph_output(graph_outputs[graph_outputs_key].index) 218 elif isinstance(value, NNModuleVariable): 219 parts = value.module_key.split(".") 220 if parts[0] in self.code_options["co_varnames"]: 221 output.append(self.create_load(parts[0])) 222 parts = parts[1:] 223 else: 224 assert self.root is not None 225 output.append(self.create_load_output(self.root)) 226 for part in parts: 227 output.append(self.create_load_attr(part)) 228 else: 229 self.uses[value] += 1 230 try: 231 self.call_reconstruct(value) 232 except NotImplementedError: 233 unimplemented(f"reconstruct: {value}") 234 if allow_cache and value in self.tempvars: 235 self._output.append(create_dup_top()) 236 self.add_cache(value) 237 238 self.top_of_stack = value 239 240 def add_graph_output(self, value): 241 graph_outputs_key = id(value.as_proxy()) 242 if graph_outputs_key not in self.graph_outputs: 243 self.graph_outputs[graph_outputs_key] = GraphOutputEntry( 244 len(self.graph_outputs), value 245 ) 246 return graph_outputs_key 247 248 def load_graph_output(self, index): 249 output = self._output 250 output.append(self.create_load(self.graph_output_var)) 251 output.append(self._create_load_const(index)) 252 output.append(create_instruction("BINARY_SUBSCR")) 253 254 def add_cache(self, value): 255 var = self.new_var() 256 self.tempvars[value] = var 257 if value.mutable_local: 258 self.tempvars[value.mutable_local] = var 259 self._output.append(self.create_store(var)) 260 261 def foreach(self, items): 262 for i in items: 263 self(i) 264 265 def setup_globally_cached(self, name, value): 266 """Store value in a new global""" 267 name = re.sub(r"[^a-zA-Z0-9_]+", "_", name) 268 f_globals = self.tx.f_globals 269 if name in f_globals: 270 assert id(f_globals[name]) == id(value) 271 else: 272 f_globals[name] = value 273 return [self.create_load_global(name, add=True)] 274 275 def clear_tos(self): 276 self.top_of_stack = None 277 278 def append_output(self, inst): 279 assert isinstance(inst, Instruction) 280 self._output.append(inst) 281 self.clear_tos() 282 283 def extend_output(self, insts): 284 assert all(isinstance(x, Instruction) for x in insts) 285 self._output.extend(insts) 286 self.clear_tos() 287 288 def get_instructions(self) -> List[Instruction]: 289 return self._output 290 291 def create_load(self, name) -> Instruction: 292 if name in self.cell_and_freevars(): 293 return create_instruction("LOAD_DEREF", argval=name) 294 assert name in self.code_options["co_varnames"], f"{name} missing" 295 return create_instruction("LOAD_FAST", argval=name) 296 297 def create_load_closure(self, name) -> Instruction: 298 assert name in self.cell_and_freevars() 299 inst_name = "LOAD_FAST" if sys.version_info >= (3, 13) else "LOAD_CLOSURE" 300 return create_instruction(inst_name, argval=name) 301 302 def create_store(self, name) -> Instruction: 303 if name in self.cell_and_freevars(): 304 return create_instruction("STORE_DEREF", argval=name) 305 assert name in self.code_options["co_varnames"] 306 return create_instruction("STORE_FAST", argval=name) 307 308 def create_load_global(self, name, add=False) -> Instruction: 309 if add: 310 self.tx.output.update_co_names(name) 311 assert name in self.code_options["co_names"], f"{name} not in co_names" 312 return create_instruction("LOAD_GLOBAL", argval=name) 313 314 def create_load_const(self, value) -> Instruction: 315 assert is_safe_constant(value), f"unsafe constant {value}" 316 return self._create_load_const(value) 317 318 def _create_load_const(self, value) -> Instruction: 319 return create_instruction("LOAD_CONST", argval=value) 320 321 create_load_output = _create_load_const 322 323 def load_method(self, name): 324 self.tx.output.update_co_names(name) 325 self.append_output(create_load_method(name)) 326 327 def call_method(self, nargs): 328 self.extend_output(create_call_method(nargs)) 329 330 def create_load_attr(self, name) -> Instruction: 331 if name not in self.code_options["co_names"]: 332 self.code_options["co_names"] += (name,) 333 return create_instruction("LOAD_ATTR", argval=name) 334 335 def load_attr(self, name): 336 self.append_output(self.create_load_attr(name)) 337 338 def create_load_attrs(self, names): 339 return [self.create_load_attr(name) for name in names.split(".")] 340 341 def create_store_attr(self, name) -> Instruction: 342 if name not in self.code_options["co_names"]: 343 self.code_options["co_names"] += (name,) 344 return create_instruction("STORE_ATTR", argval=name) 345 346 def store_attr(self, name): 347 self.append_output(self.create_store_attr(name)) 348 349 def load_function_name(self, fn_name, push_null, num_on_stack=0): 350 """Load the global fn_name on the stack num_on_stack down""" 351 output = [] 352 if push_null and sys.version_info >= (3, 11): 353 output.extend(add_push_null(self.create_load_global(fn_name, add=True))) 354 if num_on_stack > 0: 355 output.extend( 356 [ 357 *self.rot_n(num_on_stack + 2), 358 *self.rot_n(num_on_stack + 2), 359 ] 360 ) 361 else: 362 output.extend( 363 [ 364 self.create_load_global(fn_name, add=True), 365 *self.rot_n(num_on_stack + 1), 366 ] 367 ) 368 return output 369 370 def rot_n(self, n): 371 try: 372 return create_rot_n(n) 373 except AttributeError: 374 # desired rotate bytecode doesn't exist, generate equivalent bytecode 375 return [ 376 create_instruction("BUILD_TUPLE", arg=n), 377 self._create_load_const(rot_n_helper(n)), 378 *create_rot_n(2), 379 create_instruction("CALL_FUNCTION_EX", arg=0), 380 create_instruction("UNPACK_SEQUENCE", arg=n), 381 ] 382 383 def pop_null(self): 384 # POP_TOP doesn't work for null, so we pop nulls by pushing in a 385 # nop function, calling it (which consumes the null), and popping the result. 386 assert sys.version_info >= (3, 11) 387 return [ 388 self._create_load_const(lambda: None), 389 # 3.13 swapped NULL and callable 390 *( 391 (create_instruction("SWAP", arg=2),) 392 if sys.version_info >= (3, 13) 393 else () 394 ), 395 *create_call_function(0, False), 396 create_instruction("POP_TOP"), 397 ] 398 399 def pop_top(self): 400 self.append_output(create_instruction("POP_TOP")) 401 402 def call_function(self, nargs: int, push_null: bool): 403 self.extend_output(create_call_function(nargs, push_null=push_null)) 404 405 def dup_top(self): 406 self.append_output(create_dup_top()) 407 408 def store(self, varname): 409 self.append_output(self.create_store(varname)) 410 411 def make_function_with_closure( 412 self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack=0 413 ): 414 freevars = code.co_freevars 415 assert freevars 416 output = self._output 417 418 def gen_fn(): 419 for var in freevars: 420 assert var in self.cell_and_freevars() 421 inst_name = ( 422 "LOAD_FAST" if sys.version_info >= (3, 13) else "LOAD_CLOSURE" 423 ) 424 output.append(create_instruction(inst_name, argval=var)) 425 output.append(create_instruction("BUILD_TUPLE", arg=len(freevars))) 426 output.append(self.create_load_const(code)) 427 if sys.version_info < (3, 11): 428 output.append(self.create_load_const(fn_name)) 429 if sys.version_info >= (3, 13): 430 output.extend( 431 [ 432 create_instruction("MAKE_FUNCTION"), 433 create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08), 434 ] 435 ) 436 else: 437 output.append(create_instruction("MAKE_FUNCTION", arg=0x08)) 438 439 if push_null and sys.version_info >= (3, 11): 440 self.add_push_null(gen_fn) 441 output.extend(self.rot_n(num_on_stack + 2)) 442 output.extend(self.rot_n(num_on_stack + 2)) 443 else: 444 gen_fn() 445 output.extend(self.rot_n(num_on_stack + 1)) 446 self.clear_tos() 447 448 def create_load_python_module(self, mod) -> Instruction: 449 """ 450 Generate a LOAD_GLOBAL instruction to fetch a given python module. 451 """ 452 output = self.tx.output 453 global_scope = output.global_scope 454 name = re.sub(r"^.*[.]", "", mod.__name__) 455 if global_scope.get(name, None) is mod: 456 return self.create_load_global(name, add=True) 457 prefix = f"___module_{name}" 458 global_name = self.tx.output.install_global_by_id(prefix, mod) 459 return self.create_load_global(global_name, add=True) 460 461 def make_call_generated_code(self, fn_name: str) -> None: 462 """Call the generated code function stored in fn_name""" 463 self.extend_output(self.load_function_name(fn_name, True)) 464 465 graphargs = self.tx.output.graphargs 466 for arg in graphargs: 467 if arg.pass_arg_as_tensor: 468 self.add_push_null( 469 lambda: self.extend_output( 470 [ 471 self.create_load_python_module(torch), 472 self.create_load_attr("as_tensor"), 473 ] 474 ) 475 ) 476 self.call_reconstruct(arg) 477 self.extend_output(create_call_function(1, False)) 478 else: 479 self.call_reconstruct(arg) 480 481 self.extend_output(create_call_function(len(graphargs), False)) 482 483 def load_import_from(self, module_name, object_name) -> None: 484 self(AttrSource(self.tx.import_source(module_name), object_name)) 485 486 def create_call_function_kw(self, nargs, kw_names, push_null) -> List[Instruction]: 487 if sys.version_info >= (3, 13): 488 output = create_call_function(nargs, push_null) 489 assert output[-1].opname == "CALL" 490 output.insert(-1, self.create_load_const(kw_names)) 491 output[-1] = create_instruction("CALL_KW", arg=nargs) 492 return output 493 elif sys.version_info >= (3, 11): 494 output = create_call_function(nargs, push_null) 495 if sys.version_info >= (3, 12): 496 idx = -1 497 expected_inst = "CALL" 498 else: 499 idx = -2 500 expected_inst = "PRECALL" 501 assert output[idx].opname == expected_inst 502 kw_names_inst = create_instruction("KW_NAMES", argval=kw_names) 503 output.insert(idx, kw_names_inst) 504 return output 505 return [ 506 self.create_load_const(kw_names), 507 create_instruction("CALL_FUNCTION_KW", arg=nargs), 508 ] 509 510 def create_delete(self, value) -> Instruction: 511 return create_instruction("DELETE_FAST", argval=value) 512