# mypy: ignore-errors

import collections
import copy
import functools
import inspect
import itertools
import types
from typing import Dict, List, Optional, TYPE_CHECKING, Union

import torch

from .. import variables
from ..bytecode_transformation import create_call_function, create_rot_n
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import check_constant_args, get_first_attr, identity, istype, make_cell
from .base import MutableLocal, typestr, VariableTracker
from .constant import ConstantVariable

if TYPE_CHECKING:
    from torch._guards import Source


def wrap_bound_arg(tx, val, source=None):
    # Source propagation is best effort since not every object we encounter has a source to begin with.
    if isinstance(val, VariableTracker):
        return val
    elif not source:
        from torch._dynamo.variables.builder import SourcelessBuilder

        return SourcelessBuilder.create(tx, val)
    else:
        # Create a lazy variable to avoid guarding on __defaults__ unless really
        # needed.
        return variables.LazyVariableTracker.create(val, source)


def wrap_args_kwargs(tx, result):
    for k, v in list(result.items()):
        if isinstance(v, (tuple, dict)):
            # args/kwargs
            result[k] = wrap_bound_arg(tx, v)


def init_cellvars(parent, result, code):
    closure_cells = dict()
    side_effects = parent.output.side_effects

    # for name in itertools.chain(code.co_cellvars, code.co_freevars):
    for name in code.co_cellvars:
        closure_cells[name] = side_effects.track_cell_new()
        if name in result:
            side_effects.store_cell(closure_cells[name], result.pop(name))

    return closure_cells


def _create_nested_fn(
    code, f_globals, name, defaults, closure, kwdefaults, annotations
):
    from types import FunctionType

    func = FunctionType(code, f_globals, name, defaults, closure)
    func.__kwdefaults__ = kwdefaults

    if isinstance(annotations, tuple):
        from itertools import pairwise

        annotations = dict(pairwise(annotations))

    # TypeError: __annotations__ must be set to a dict object
    assert annotations is None or isinstance(annotations, dict)
    func.__annotations__ = annotations

    return func


class BaseUserFunctionVariable(VariableTracker):
    def get_filename(self):
        return self.get_code().co_filename

    def get_name(self):
        return self.get_code().co_name

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)

    def call_hasattr(self, tx, name: str) -> VariableTracker:
        result = False

        try:
            result = hasattr(self.get_function(), name)
        except NotImplementedError:
            if name == "__name__" and isinstance(self, NestedUserFunctionVariable):
                result = True
        return variables.ConstantVariable.create(result)

    def inspect_parameter_names(self):
        return list(inspect.signature(self.get_function()).parameters)

    def closure_vars(self, tx):
        return {}


class UserFunctionVariable(BaseUserFunctionVariable):
    """Some unsupported user-defined global function"""

    _nonvar_fields = {
        "fn",
        "is_constant",
        *BaseUserFunctionVariable._nonvar_fields,
    }

    @classmethod
    def create_with_source(cls, value, source):
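        # Note (added): CLOSURE_MATCH guards (roughly) on the wrapped function's
        # __code__ object rather than on the function object itself, so closures
        # that are re-created on every call can still hit an existing compiled entry.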
        install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
        return cls(
            value,
            source=source,
        )

    def __init__(self, fn, is_constant=False, **kwargs):
        super().__init__(**kwargs)
        if getattr(fn, "_dynamo_marked_constant", False):
            # This method should be treated as a constant for the purposes of compilation
            self.is_constant = True
        else:
            self.is_constant = False

        assert isinstance(
            fn, (types.FunctionType, torch.jit.ScriptFunction)
        ), f"expected FunctionType found {typestr(fn)} {fn}"
        # unpack @torch._dynamo.optimize()(fn) wrapped function
        fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
        # unpack torch.jit.script_if_tracing
        if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
            fn = inspect.getattr_static(fn, "__original_fn", fn)
        self.fn: types.FunctionType = fn

    def as_python_constant(self):
        if istype(self, UserFunctionVariable):
            return self.fn
        # subclasses (such as methods) usually aren't a constant
        return super().as_python_constant()

    def self_args(self):
        return []

    def get_function(self):
        return self.fn

    def get_code(self):
        return self.fn.__code__

    def python_type(self):
        return types.FunctionType

    def has_self(self):
        return getattr(self.fn, "__self__", None) is not None

    def get_globals(self):
        return self.fn.__globals__

    def bind_args(self, parent, args, kwargs):
        assert not self.is_constant
        tx = parent.output.root_tx
        wrap = functools.partial(wrap_bound_arg, tx=tx)

        fn: types.FunctionType = self.fn
        defaults = fn.__defaults__ or []
        defaults_sources = [
            None if self.source is None else DefaultsSource(self.source, idx)
            for idx, _ in enumerate(defaults)
        ]
        fake_func = types.FunctionType(
            fn.__code__,
            fn.__globals__,
            fn.__name__,
            tuple(
                [
                    wrap(val=arg, source=source)
                    for arg, source in zip(defaults, defaults_sources)
                ]
            ),
            fn.__closure__,
        )
        if fn.__kwdefaults__:
            kwdefaults_sources = {
                k: None
                if self.source is None
                else DefaultsSource(self.source, k, is_kw=True)
                for k in fn.__kwdefaults__
            }
            fake_func.__kwdefaults__ = {
                k: wrap(val=v, source=kwdefaults_sources[k])
                for k, v in fn.__kwdefaults__.items()
            }

        bound = inspect.signature(fake_func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())

        wrap_args_kwargs(tx, result)
        closure_cells = init_cellvars(parent, result, fn.__code__)
        closure = self.fn.__closure__ or ()
        assert len(closure) == len(self.fn.__code__.co_freevars)
        for idx, name, cell in zip(
            itertools.count(), self.fn.__code__.co_freevars, closure
        ):
            if name == "__class__":
                source = AttrSource(self.source, "__class__") if self.source else None
                result[name] = variables.UserDefinedClassVariable(
                    cell.cell_contents,
                    source=source,
                )
            else:
                var = tx.match_nested_cell(name, cell)
                if var is not None:
                    # optimization for cleaner codegen
                    result[name] = var
                elif self.source:
                    from .builder import VariableBuilder

                    side_effects = parent.output.side_effects
                    if cell in side_effects:
                        out = side_effects[cell]
                    else:
                        closure_cell = GetItemSource(
                            AttrSource(self.source, "__closure__"), idx
                        )
                        closure_cell_contents = AttrSource(
                            closure_cell, "cell_contents"
                        )
                        try:
                            contents_var = VariableBuilder(
                                parent, closure_cell_contents
                            )(cell.cell_contents)
                        except ValueError:
                            # Cell has not yet been assigned
                            contents_var = variables.DeletedVariable()

                        if (
                            closure_cell_contents.name()
                            not in tx.mutated_closure_cell_contents
                        ):
                            # Optimistically don't allocate the cell, to
                            # reduce the number of side effects. This is
                            # important for cond, as without it, any accesses
                            # to closures create side effects and cond doesn't
                            # support side effects. If we're wrong and this
                            # closure cell gets written to, we will restart
                            # the analysis with this cell's name in the
                            # mutated list here
                            result[name] = contents_var
                            continue

                        # cells are written to with "cell_contents",
                        # so the source should just be the closure_cell, not its contents
                        out = side_effects.track_cell_existing(closure_cell, cell)
                        side_effects.store_cell(
                            out,
                            contents_var,
                        )

                    result[name] = out

                else:
                    from .builder import SourcelessBuilder

                    result[name] = SourcelessBuilder.create(tx, cell.cell_contents)

        return result, closure_cells

    def export_freevars(self, parent, child):
        pass

    def call_hasattr(self, tx, name: str) -> VariableTracker:
        result = hasattr(self.fn, name)
        return variables.ConstantVariable.create(result)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if self.is_constant:
            return invoke_and_store_as_constant(
                tx, self.fn, self.get_name(), args, kwargs
            )

        return super().call_function(tx, args, kwargs)


class UserMethodVariable(UserFunctionVariable):
    """Some unsupported user-defined method"""

    def __init__(self, fn, obj, **kwargs):
        super().__init__(fn=fn, **kwargs)
        self.obj = obj

    def __str__(self):
        return f"{self.__class__.__name__}({self.fn}, {self.obj})"

    def self_args(self):
        return [self.obj]

    def python_type(self):
        return types.MethodType

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # For nn.Module methods, redirect to NNModuleVariable.call_method for an optimized
        # solution rather than simple inlining - e.g., putting a `call_method` op in the FX
        # graph for the `forward` method, since we ensure `forward` of allowed modules can
        # be traced safely by AOT.
        # Note this is not only for allowed modules: user-customized modules can extend
        # allowed modules but use the parent's `forward` method, which is also covered by
        # this branch.

        # If we are tracing the higher order op, we want Dynamo to step inside
        # the module call so that Dynamo can see the underlying parameters and
        # buffers and raise them as inputs to the graph. The is_root_tracer
        # check bypasses the if condition for non-root tracers and directly
        # calls the super().call_function at the end, which is basically
        # equivalent to inlining the method.
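
        # Illustrative sketch (hypothetical user code, not part of this module):
        #
        #     class MyLinear(torch.nn.Linear):  # does not override forward
        #         pass
        #
        #     torch.compile(MyLinear(4, 4))(x)
        #
        # The bound method reaching this point is nn.Linear.forward, whose
        # __module__ starts with "torch.nn.", so the call is routed to
        # NNModuleVariable.call_method below instead of being inlined.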
        if tx.output.is_root_tracer() and isinstance(
            self.obj, variables.NNModuleVariable
        ):
            module_attr = getattr(self.fn, "__module__", "")
            # inline torch.nn.utils.parametrize
            if (
                module_attr is not None
                and module_attr.startswith("torch.nn.")
                and module_attr != "torch.nn.utils.parametrize"
                or self.is_constant
            ):
                return self.obj.call_method(
                    tx, self.fn.__name__, args, kwargs, constant=self.is_constant
                )
        if self.is_constant:
            fn = getattr(self.obj.value, self.fn.__name__)
            return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs)
        return super().call_function(tx, args, kwargs)

    def inspect_parameter_names(self):
        return super().inspect_parameter_names()[1:]


class WrappedUserMethodVariable(UserMethodVariable):
    def __init__(self, wrapped, context, **kwargs):
        kwargs.pop("fn", None)
        kwargs.pop("obj", None)
        super().__init__(wrapped.fn, wrapped.obj, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        self.context.enter(tx)
        result = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return result


class WrappedUserFunctionVariable(UserFunctionVariable):
    def __init__(self, wrapped, context, **kwargs):
        kwargs.pop("fn", None)
        kwargs.pop("obj", None)
        super().__init__(wrapped.fn, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        self.context.enter(tx)
        result = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return result


def invoke_and_store_as_constant(tx, fn, name, args, kwargs):
    def convert(x):
        if isinstance(x, variables.TensorVariable):
            return x.get_real_value()
        return x.as_python_constant()

    args = [convert(x) for x in args]
    kwargs = {k: convert(v) for k, v in kwargs.items()}
    res = fn(*args, **kwargs)
    return tx.output.register_attr_or_module(
        res,
        name,
        source=ConstantSource(name),
    )


class NestedUserFunctionVariable(BaseUserFunctionVariable):
    _nonvar_fields = {
        "closure_scope",
        "f_globals",
        *BaseUserFunctionVariable._nonvar_fields,
    }

    def __init__(
        self,
        fn_name,
        code,
        f_globals,
        defaults,
        kwdefaults,
        annotations,
        closure,
        closure_scope,
        wrapped_reconstructible=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert isinstance(fn_name.as_python_constant(), str)
        assert isinstance(code.as_python_constant(), types.CodeType)
        assert isinstance(f_globals, dict)
        self.fn_name = fn_name
        self.code = code
        self.f_globals = f_globals
        self.defaults = defaults
        self.kwdefaults = kwdefaults
        self.annotations = annotations
        self.closure = closure
        if closure is None:
            closure_scope = None
        self.closure_scope = closure_scope
        # Either a source or a VT with .can_reconstruct() == True
        self.wrapped_reconstructible: Optional[
            Union[Source, VariableTracker]
        ] = wrapped_reconstructible

    def self_args(self):
        return []

    def get_code(self):
        return self.code.as_python_constant()

    def get_function(self):
        if self.closure:
            raise NotImplementedError
        func = types.FunctionType(
            self.code.as_python_constant(),
            self.f_globals,
            self.fn_name.as_python_constant(),
        )
        if self.defaults:
            func.__defaults__ = self.defaults.as_python_constant()
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.as_python_constant()
        if self.annotations:
            annotations = self.annotations.as_python_constant()
            if isinstance(annotations, tuple):
                from itertools import pairwise

                annotations = dict(pairwise(annotations))

            # TypeError: __annotations__ must be set to a dict object
            assert isinstance(annotations, dict)
            func.__annotations__ = annotations
        return func

    def has_closure(self):
        return self.closure is not None

    def has_self(self):
        return False

    def get_globals(self):
        return self.f_globals

    def bind_args(self, parent, args, kwargs):
        from .misc import InlinedClosureVariable

        code = self.get_code()
        func = types.FunctionType(
            code,
            self.f_globals,
            self.fn_name.as_python_constant(),
            tuple(self.defaults.items) if self.defaults else None,
            tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))),
        )
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant()
        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())
        wrap_args_kwargs(parent.output.root_tx, result)
        closure_cells = init_cellvars(parent, result, code)

        for idx, name in enumerate(code.co_freevars):
            cell = self.closure.items[idx]
            assert getattr(cell, name, name) == name
            assert name not in result
            if isinstance(cell, InlinedClosureVariable):
                # InlinedClosureVariables are created by InliningInstructionTranslators from
                # LOAD_CLOSURE instructions when the variable name is not found in closure_cells.
                # They should remain outside of closure_cells, so that our callee (the
                # InliningInstructionTranslator that traces `func`) handles the cell
                # correctly - that is, the cell's contents are treated as if they are
                # local variables, like in UserFunctionVariable's bind_args for freevars.
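                #
                # Hypothetical sketch of the situation handled here (names made up):
                #
                #     def outer(x):
                #         def inner():
                #             return x + 1  # `x` is a freevar of `inner`
                #         return inner()
                #
                # While `inner` is being inlined, its cell for `x` shows up as an
                # InlinedClosureVariable; the actual value lives in the symbolic_locals
                # of whichever translator in the parent chain is tracing `outer`, so we
                # walk up that chain to find it.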
                cand = parent
                while cand and name not in cand.symbolic_locals:
                    cand = cand.parent
                if cand is None:
                    raise RuntimeError(
                        f"Couldn't find {name} in the symbolic_locals of the inline interpreter stack"
                    )
                result[name] = cand.symbolic_locals[name]
            else:
                closure_cells[name] = self.closure.items[idx]

        return result, closure_cells

    def export_freevars(self, parent, child):
        code = self.get_code()
        for var in code.co_freevars:
            if var in child.symbolic_locals:
                parent.symbolic_locals[var] = child.symbolic_locals[var]

    def reconstruct(self, codegen):
        codegen.load_import_from(__name__, "_create_nested_fn")
        codegen(self.code)
        codegen.extend_output([codegen._create_load_const(self.f_globals)])
        codegen(ConstantVariable.create(self.code.value.co_name))

        if self.defaults:
            codegen(self.defaults)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.closure:
            codegen(self.closure)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.kwdefaults:
            codegen(self.kwdefaults)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.annotations:
            try:
                annotations = self.annotations.as_python_constant()
                codegen.extend_output([codegen._create_load_const(annotations)])
            except NotImplementedError:
                codegen(self.annotations)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        codegen.extend_output(create_call_function(7, push_null=True))

        if self.wrapped_reconstructible:
            codegen.load_import_from("functools", "wraps")
            codegen(self.wrapped_reconstructible)
            codegen.extend_output(create_call_function(1, True))
            codegen.extend_output(create_rot_n(2))
            codegen.extend_output(create_call_function(1, True))


class SkipFunctionVariable(VariableTracker):
    _nonvar_fields = {
        "value",
        "reason",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, value, reason=None, **kwargs):
        super().__init__(**kwargs)
        self.value = value
        self.reason = reason

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    @classmethod
    def create_with_source(cls, value, source):
        install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
        return cls(
            value,
            source=source,
        )

    @staticmethod
    @functools.lru_cache(None)
    def fold_through_function_to_wrapper():
        return {
            collections.namedtuple: variables.UserDefinedClassVariable,
        }

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
            unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}")
        # Fold through functions (e.g., collections.namedtuple) whose
        # inputs & outputs are all python constants
        elif (
            self.value in self.fold_through_function_to_wrapper().keys()
            and check_constant_args(args, kwargs)
        ):
            value = self.value(
                *[x.as_python_constant() for x in args],
                **{k: v.as_python_constant() for k, v in kwargs.items()},
            )
            return self.fold_through_function_to_wrapper().get(self.value)(
                value, mutable_local=MutableLocal()
            )
        elif (
            self.value is functools.wraps
            and not kwargs
            and len(args) == 1
            and (
                args[0].source is not None or
                args[0].can_reconstruct(tx.output.root_tx)
            )
        ):

            def wraps(fn):
                if isinstance(fn, variables.NestedUserFunctionVariable):
                    if args[0].source:
                        reconstructible = args[0].source
                    else:
                        reconstructible = args[0]
                    return fn.clone(wrapped_reconstructible=reconstructible)
                unimplemented(f"functools.wraps({fn})")

            return variables.LambdaVariable(wraps)
        else:
            try:
                path = inspect.getfile(self.value)
                msg = f"'skip function {self.value.__qualname__} in file {path}'"
            except TypeError:
                known_python_builtin_modules = {"_abc", "_warnings"}
                if self.value.__module__ in known_python_builtin_modules:
                    msg = (
                        f"Graph break due to unsupported Python builtin {self.value.__module__}.{self.value.__qualname__}. "
                        f"Please file an issue on GitHub "
                        f"so the PyTorch team can add support for it. "
                    )
                else:
                    msg = (
                        f"Graph break due to unsupported builtin {self.value.__module__}.{self.value.__qualname__}. "
                        f"This function is either a Python builtin (e.g. _warnings.warn) "
                        f"or a third-party C/C++ Python extension (perhaps created with pybind). "
                        f"If it is a Python builtin, please file an issue on GitHub "
                        f"so the PyTorch team can add support for it and see the next case for a workaround. "
                        f"If it is a third-party C/C++ Python extension, please "
                        f"either wrap it into a PyTorch-understood custom operator "
                        f"(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html "
                        f"for more details) or, if it is traceable, use "
                        f"torch.compiler.allow_in_graph."
                    )
                    # also warn on it because most users won't see the graph break message
                    torch._dynamo.utils.warn_once(msg)
            msg += f"', {self.reason}'" if self.reason else ""
            unimplemented(msg)


def _traceable_collective_remaps():
    # We can't rely on importing from distributed, since it's not always built
    if torch.distributed.is_available():
        from torch.distributed._functional_collectives import (
            traceable_collective_remaps,
        )

        return traceable_collective_remaps
    return {}


def _traceable_collectives_source(tx, fn):
    assert torch.distributed.is_available(), "Illegal invocation."
    assert fn in _traceable_collective_remaps().values()

    inner_name = fn.__name__
    path_source = tx.import_source("torch.distributed._functional_collectives")
    return AttrSource(path_source, inner_name)


class CollectiveFunctionRewriteVariable(UserFunctionVariable):
    """
    Some of the torch.distributed.* collective APIs can be rewritten to 'traceable' collectives.

    This class provides both a way to check whether a function is remappable and a way to
    perform the remapping.

    In the case that a function is 'remappable' but only for some combinations of call-time
    arguments, we check the args at `call_function` time and fall back to graph-breaking if
    needed. This is no worse than the status quo, as we currently graph-break on all
    distributed.* collectives.
    """

    def __init__(self, fn, *, replacement_var, **kwargs):
        super().__init__(fn, **kwargs)
        assert isinstance(replacement_var, UserFunctionVariable)
        self.replacement_var = replacement_var

    @staticmethod
    def create(tx, old_fn, source, **options):
        new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn)
        return CollectiveFunctionRewriteVariable(
            old_fn,
            replacement_var=UserFunctionVariable(new_fn, source=new_source, **options),
            source=source,
            **options,
        )

    @staticmethod
    def can_rewrite(variable):
        return (
            inspect.isfunction(variable) and variable in _traceable_collective_remaps()
        )

    @staticmethod
    def rewrite(tx, fn):
        new_fn = _traceable_collective_remaps()[fn]
        return new_fn, _traceable_collectives_source(tx, new_fn)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # call_function must check any unsupported arguments and graph-break.
        # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn,
        # since that's the contract for putting a mapping in `traceable_collective_remaps`
        import torch.distributed as dist
        from torch.distributed._functional_collectives import REDUCE_OP_TO_STR

        # Merge args into kwargs so positional and keyword args
        # can be processed the same way.
        signature = inspect.signature(self.fn)
        kwargs = dict(signature.bind(*args, **kwargs).arguments)
        args = ()

        if "async_op" in kwargs and kwargs["async_op"].as_python_constant():
            unimplemented(
                f"CollectiveFunctionRewriteVariable can't support async_op=True for {self.fn}"
            )

        if self.fn in (
            dist.all_reduce,
            dist.reduce_scatter_tensor,
            dist._reduce_scatter_base,
        ):
            reduce_op_var = kwargs.get("op")
            reduce_op = (
                reduce_op_var.value
                if reduce_op_var is not None
                else signature.parameters["op"].default
            )
            if reduce_op not in REDUCE_OP_TO_STR:
                raise ValueError(f"Unsupported all_reduce op: {reduce_op}")
            kwargs["op"] = variables.ConstantVariable.create(
                REDUCE_OP_TO_STR[reduce_op]
            )
        return self.replacement_var.call_function(tx, args, kwargs)


class FunctoolsPartialVariable(VariableTracker):
    def __init__(self, func: VariableTracker, args, keywords, **kwargs):
        super().__init__(**kwargs)
        self.func = func
        assert isinstance(args, list)
        self.args = args
        assert isinstance(keywords, dict)
        self.keywords = keywords

    def reconstruct(self, codegen):
        codegen.load_import_from("functools", "partial")
        codegen(self.func)
        if self.args:
            codegen.foreach(self.args)
        if not self.keywords:
            codegen.extend_output(create_call_function(len(self.args) + 1, True))
            return

        codegen.foreach(self.keywords.values())
        keys = tuple(self.keywords.keys())
        codegen.extend_output(
            codegen.create_call_function_kw(len(keys) + len(self.args) + 1, keys, True)
        )

    def get_function(self):
        return self.as_python_constant()

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        merged_args = self.args + args
        merged_kwargs = {**self.keywords, **kwargs}
        return self.func.call_function(tx, merged_args, merged_kwargs)

    def call_hasattr(self, tx, name: str) -> VariableTracker:
        # functools.partial uses slots, so attributes are constant
        return variables.ConstantVariable.create(
            hasattr(functools.partial(identity), name)
        )

    def as_python_constant(self):
        return functools.partial(
            self.func.as_python_constant(),
            *[arg.as_python_constant() for arg in self.args],
            **{k: v.as_python_constant() for k, v in self.keywords.items()},
        )

    def guard_as_python_constant(self):
        """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants"""
        return functools.partial(
            self.func.guard_as_python_constant(),
            *[v.guard_as_python_constant() for v in self.args],
            **{k: v.guard_as_python_constant() for k, v in self.keywords.items()},
        )


class TritonKernelVariable(VariableTracker):
    def __init__(self, kernel, kernel_idx, grid, **kwargs):
        from triton.runtime.autotuner import Autotuner

        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

        super().__init__(**kwargs)

        assert kernel is not None

        self.kernel = kernel
        self.kernel_idx = kernel_side_table.add_kernel(kernel)

        assert kernel_idx is None or self.kernel_idx == kernel_idx

        self.grid = grid

        if isinstance(kernel, Autotuner):
            # We only support the configs and keys arguments of triton.autotune;
            # make sure the other arguments are left at their defaults
            defaults = inspect.signature(Autotuner.__init__).parameters

            # Newer versions of triton changed the attribute names from warmup to
            # num_warmups and rep to num_reps. The call to get_first_attr is to
            # maintain backward-compatibility.
            if (
                (
                    "warmup" in defaults
                    and defaults["warmup"].default
                    != get_first_attr(kernel, "num_warmups", "warmup")
                )
                or (
                    "rep" in defaults
                    and defaults["rep"].default
                    != get_first_attr(kernel, "num_reps", "rep")
                )
                or (
                    "prune_configs_by" in defaults
                    and defaults["prune_configs_by"].default
                    != kernel.early_config_prune
                )
                # Set via reset_to_zero argument
                or len(kernel.reset_idx) != 0
                or len(kernel.restore_idx) != 0
            ):
                raise Unsupported(
                    "Only configs and keys are supported for triton.autotune"
                )

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from triton.runtime.autotuner import autotune, Autotuner, Config

        from .constant import ConstantVariable
        from .dicts import ConstDictVariable
        from .lists import BaseListVariable

        if "num_ctas" in kwargs:
            raise Unsupported(
                "Passing num_ctas directly to the Triton kernel is not supported. "
                "Please use a Config in @triton.autotune instead."
            )

        special_kwargs = {}
        for name in ("num_warps", "num_stages"):
            if name in kwargs:
                # remove special kwargs from `kwargs`
                val = kwargs.pop(name)
                assert isinstance(val, ConstantVariable)
                special_kwargs[name] = val.value

        if special_kwargs:
            if isinstance(self.kernel, Autotuner):
                # if there is an Autotuner already, set the
                # special kwargs on each of its configs
                new_configs = copy.deepcopy(self.kernel.configs)
                for config in new_configs:
                    config.__dict__.update(special_kwargs)
                new_kernel = autotune(configs=new_configs, key=[])(self.kernel.fn)
            else:
                # if there is no Autotuner, wrap the kernel into a
                # new one with a single config carrying the special kwargs
                new_config = Config(kwargs={}, **special_kwargs)
                new_kernel = autotune(configs=[new_config], key=[])(self.kernel)

            # create a new variable to contain the new (wrapped) kernel;
            # skip kernel_idx to get a new record in the kernel side table
            new_var = TritonKernelVariable(new_kernel, None, self.grid)
            return new_var.call_function(tx, args, kwargs)

        if self.grid is None:
            raise Unsupported("Triton kernels should always be called with a grid")

        # Both for the grid's meta as well as for the kernel, we need the
        # args and kwargs combined and normalized
        combined_args_raw = {**dict(zip(self.kernel.arg_names, args)), **kwargs}
        combined_args = {
            variables.ConstantVariable.create(k): v
            for k, v in combined_args_raw.items()
        }

        configs = (
            [config.kwargs for config in self.kernel.configs]
            if isinstance(self.kernel, Autotuner)
            else [{}]
        )
        grids = []
        for config_args in configs:
            # If the grid is a function, let's execute it and convert it to
            # a list
            grid = self.grid
            if isinstance(grid, (NestedUserFunctionVariable, UserFunctionVariable)):
                # Populate the special "meta" argument to call the grid function
                config_args = {
                    ConstantVariable.create(k): ConstantVariable.create(v)
                    for k, v in config_args.items()
                }
                meta = ConstDictVariable({**combined_args, **config_args}, dict)
                grid = grid.call_function(tx, [meta], {})

            # Now, the grid must be a list either originally or through the above
            # modification
            if isinstance(grid, BaseListVariable):
                grids.append(grid.as_proxy())
            else:
                unimplemented(f"grid for the triton kernel is {type(grid)}")

        for i in range(len(grids)):
            if not isinstance(grids[i], tuple):
                raise Unsupported("Only tuple grids are supported")
            # inductor expects all grids to be 3-tuples, so let's make it one
            if len(grids[i]) == 1:
                grids[i] = (grids[i][0], 1, 1)
            elif len(grids[i]) == 2:
                grids[i] = (grids[i][0], grids[i][1], 1)
            elif len(grids[i]) > 3:
                raise Unsupported("Grid can have at most rank 3")

        assert len(grids) != 0
        if len(set(grids)) == 1:
            # If there's only one unique grid, let's simplify
            grids = [grids[0]]

        from torch._higher_order_ops.triton_kernel_wrap import (
            kernel_side_table,
            triton_kernel_wrapper_mutation,
        )

        # Combine args and kwargs and pass them as a dict so that if the user-defined
        # triton kernel uses variables named 'grid' or 'kernel', they do not conflict
        # with the parameters of the wrapper function
        constant_args = {
            k: v.as_python_constant()
            for k, v in combined_args_raw.items()
            if isinstance(v, ConstantVariable)
        }
        non_constant_args = {
            k: v
            for k, v in combined_args.items()
            if not isinstance(v, ConstantVariable)
        }

        constant_args_idx = kernel_side_table.add_constant_args(constant_args)
        meta = ConstDictVariable(non_constant_args, dict)
        tx.output.create_proxy(
            "call_function",
            triton_kernel_wrapper_mutation,
            (),
            {
                "kernel_idx": self.kernel_idx,
                "constant_args_idx": constant_args_idx,
                "grid": grids,
                "kwargs": meta.as_proxy(),
            },
        )

        return variables.ConstantVariable(
            None,
        )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__getitem__":
            # __getitem__ should only be called if we don't already have a grid
            # Only grid needs to be passed
            if self.grid is not None or len(args) != 1:
                raise Unsupported(
                    "Triton kernels should be called with only a single grid"
                )

            return TritonKernelVariable(
                kernel=self.kernel,
                kernel_idx=self.kernel_idx,
                grid=args[0],
            )
        elif name == "run":
            if "grid" not in kwargs:
                raise Unsupported("Triton kernel requires to be called with a grid")
            grid = kwargs.pop("grid")
            kwargs.pop("warmup", None)
            # rewrite kernel.run(*args, grid=grid) to kernel[grid](*args)
            return TritonKernelVariable(
                kernel=self.kernel, kernel_idx=self.kernel_idx, grid=grid
            ).call_function(tx, args, kwargs)

        # Bail out to parent's implementation
        return super().call_method(tx, name, args, kwargs)
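

# Illustrative sketch (added) of the two Triton call styles handled by
# TritonKernelVariable above. The kernel and tensor names are hypothetical and
# only show which paths user code takes when traced by torch.compile:
#
#     @triton.jit
#     def my_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
#         ...
#
#     def f(x, out):
#         grid = (triton.cdiv(x.numel(), 1024),)
#         my_kernel[grid](x, out, x.numel(), BLOCK=1024)           # call_method("__getitem__"), then call_function
#         my_kernel.run(x, out, x.numel(), BLOCK=1024, grid=grid)  # call_method("run")
#
# Both styles end up in call_function, which normalizes every grid to a
# 3-tuple and emits a single triton_kernel_wrapper_mutation node into the
# FX graph instead of tracing into the kernel body.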