xref: /aosp_15_r20/external/pytorch/torch/_dynamo/variables/functions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import collections
4import copy
5import functools
6import inspect
7import itertools
8import types
9from typing import Dict, List, Optional, TYPE_CHECKING, Union
10
11import torch
12
13from .. import variables
14from ..bytecode_transformation import create_call_function, create_rot_n
15from ..exc import unimplemented, Unsupported
16from ..guards import GuardBuilder, install_guard
17from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
18from ..utils import check_constant_args, get_first_attr, identity, istype, make_cell
19from .base import MutableLocal, typestr, VariableTracker
20from .constant import ConstantVariable
21
22if TYPE_CHECKING:
23    from torch._guards import Source
24
25
26def wrap_bound_arg(tx, val, source=None):
27    # Source propagation is best effort since not every object we encounter has a source to begin with.
28    if isinstance(val, VariableTracker):
29        return val
30    elif not source:
31        from torch._dynamo.variables.builder import SourcelessBuilder
32
33        return SourcelessBuilder.create(tx, val)
34    else:
35        # Create a lazy variable to avoid guarding on __defaults__ unless really
36        # needed.
37        return variables.LazyVariableTracker.create(val, source)
38
39
40def wrap_args_kwargs(tx, result):
41    for k, v in list(result.items()):
42        if isinstance(v, (tuple, dict)):
43            # args/kwargs
44            result[k] = wrap_bound_arg(tx, v)
45
46
47def init_cellvars(parent, result, code):
48    closure_cells = dict()
49    side_effects = parent.output.side_effects
50
51    # for name in itertools.chain(code.co_cellvars, code.co_freevars):
52    for name in code.co_cellvars:
53        closure_cells[name] = side_effects.track_cell_new()
54        if name in result:
55            side_effects.store_cell(closure_cells[name], result.pop(name))
56
57    return closure_cells
58
59
60def _create_nested_fn(
61    code, f_globals, name, defaults, closure, kwdefaults, annotations
62):
63    from types import FunctionType
64
65    func = FunctionType(code, f_globals, name, defaults, closure)
66    func.__kwdefaults__ = kwdefaults
67
68    if isinstance(annotations, tuple):
69        from itertools import pairwise
70
71        annotations = dict(pairwise(annotations))
72
73    # TypeError: __annotations__ must be set to a dict object
74    assert annotations is None or isinstance(annotations, dict)
75    func.__annotations__ = annotations
76
77    return func
78
79
80class BaseUserFunctionVariable(VariableTracker):
81    def get_filename(self):
82        return self.get_code().co_filename
83
84    def get_name(self):
85        return self.get_code().co_name
86
87    def call_function(
88        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
89    ) -> "VariableTracker":
90        return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
91
92    def call_hasattr(self, tx, name: str) -> VariableTracker:
93        result = False
94
95        try:
96            result = hasattr(self.get_function(), name)
97        except NotImplementedError:
98            if name == "__name__" and isinstance(self, NestedUserFunctionVariable):
99                result = True
100        return variables.ConstantVariable.create(result)
101
102    def inspect_parameter_names(self):
103        return list(inspect.signature(self.get_function()).parameters)
104
105    def closure_vars(self, tx):
106        return {}
107
108
109class UserFunctionVariable(BaseUserFunctionVariable):
110    """Some unsupported user-defined global function"""
111
112    _nonvar_fields = {
113        "fn",
114        "is_constant",
115        *BaseUserFunctionVariable._nonvar_fields,
116    }
117
118    @classmethod
119    def create_with_source(cls, value, source):
120        install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
121        return cls(
122            value,
123            source=source,
124        )
125
126    def __init__(self, fn, is_constant=False, **kwargs):
127        super().__init__(**kwargs)
128        if getattr(fn, "_dynamo_marked_constant", False):
129            # This method should be treated as a constant for the purposes of compilation
130            self.is_constant = True
131        else:
132            self.is_constant = False
133
134        assert isinstance(
135            fn, (types.FunctionType, torch.jit.ScriptFunction)
136        ), f"expected FunctionType found {typestr(fn)} {fn}"
137        # unpack @torch._dynamo.optimize()(fn) wrapped function
138        fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
139        # unpack torch.jit.script_if_tracing
140        if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
141            fn = inspect.getattr_static(fn, "__original_fn", fn)
142        self.fn: types.FunctionType = fn
143
144    def as_python_constant(self):
145        if istype(self, UserFunctionVariable):
146            return self.fn
147        # subclasses (such as methods) usually aren't a constant
148        return super().as_python_constant()
149
150    def self_args(self):
151        return []
152
153    def get_function(self):
154        return self.fn
155
156    def get_code(self):
157        return self.fn.__code__
158
159    def python_type(self):
160        return types.FunctionType
161
162    def has_self(self):
163        return getattr(self.fn, "__self__", None) is not None
164
165    def get_globals(self):
166        return self.fn.__globals__
167
168    def bind_args(self, parent, args, kwargs):
169        assert not self.is_constant
170        tx = parent.output.root_tx
171        wrap = functools.partial(wrap_bound_arg, tx=tx)
172
173        fn: types.FunctionType = self.fn
174        defaults = fn.__defaults__ or []
175        defaults_sources = [
176            None if self.source is None else DefaultsSource(self.source, idx)
177            for idx, _ in enumerate(defaults)
178        ]
179        fake_func = types.FunctionType(
180            fn.__code__,
181            fn.__globals__,
182            fn.__name__,
183            tuple(
184                [
185                    wrap(val=arg, source=source)
186                    for arg, source in zip(defaults, defaults_sources)
187                ]
188            ),
189            fn.__closure__,
190        )
191        if fn.__kwdefaults__:
192            kwdefaults_sources = {
193                k: None
194                if self.source is None
195                else DefaultsSource(self.source, k, is_kw=True)
196                for k in fn.__kwdefaults__
197            }
198            fake_func.__kwdefaults__ = {
199                k: wrap(val=v, source=kwdefaults_sources[k])
200                for k, v in fn.__kwdefaults__.items()
201            }
202
203        bound = inspect.signature(fake_func).bind(*args, **kwargs)
204        bound.apply_defaults()
205        result = dict(bound.arguments.items())
206
207        wrap_args_kwargs(tx, result)
208        closure_cells = init_cellvars(parent, result, fn.__code__)
209        closure = self.fn.__closure__ or ()
210        assert len(closure) == len(self.fn.__code__.co_freevars)
211        for idx, name, cell in zip(
212            itertools.count(), self.fn.__code__.co_freevars, closure
213        ):
214            if name == "__class__":
215                source = AttrSource(self.source, "__class__") if self.source else None
216                result[name] = variables.UserDefinedClassVariable(
217                    cell.cell_contents,
218                    source=source,
219                )
220            else:
221                var = tx.match_nested_cell(name, cell)
222                if var is not None:
223                    # optimization for cleaner codegen
224                    result[name] = var
225                elif self.source:
226                    from .builder import VariableBuilder
227
228                    side_effects = parent.output.side_effects
229                    if cell in side_effects:
230                        out = side_effects[cell]
231                    else:
232                        closure_cell = GetItemSource(
233                            AttrSource(self.source, "__closure__"), idx
234                        )
235                        closure_cell_contents = AttrSource(
236                            closure_cell, "cell_contents"
237                        )
238                        try:
239                            contents_var = VariableBuilder(
240                                parent, closure_cell_contents
241                            )(cell.cell_contents)
242                        except ValueError:
243                            # Cell has not yet been assigned
244                            contents_var = variables.DeletedVariable()
245
246                        if (
247                            closure_cell_contents.name()
248                            not in tx.mutated_closure_cell_contents
249                        ):
250                            # Optimistically don't allocate the cell, to
251                            # reduce the number of side effects.  This is
252                            # important for cond, as without it, any accesses
253                            # to closures create side effects and cond doesn't
254                            # support side effects.  If we're wrong and this
255                            # closure cell gets written to, we will restart
256                            # the analysis with this cell's name in the
257                            # mutated list here
258                            result[name] = contents_var
259                            continue
260
261                        # cells are written to with "cell_contents",
262                        # so the source should just be the closure_cell, not its contents
263                        out = side_effects.track_cell_existing(closure_cell, cell)
264                        side_effects.store_cell(
265                            out,
266                            contents_var,
267                        )
268
269                    result[name] = out
270
271                else:
272                    from .builder import SourcelessBuilder
273
274                    result[name] = SourcelessBuilder.create(tx, cell.cell_contents)
275
276        return result, closure_cells
277
278    def export_freevars(self, parent, child):
279        pass
280
281    def call_hasattr(self, tx, name: str) -> VariableTracker:
282        result = hasattr(self.fn, name)
283        return variables.ConstantVariable.create(result)
284
285    def call_function(
286        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
287    ) -> "VariableTracker":
288        if self.is_constant:
289            return invoke_and_store_as_constant(
290                tx, self.fn, self.get_name(), args, kwargs
291            )
292
293        return super().call_function(tx, args, kwargs)
294
295
296class UserMethodVariable(UserFunctionVariable):
297    """Some unsupported user-defined method"""
298
299    def __init__(self, fn, obj, **kwargs):
300        super().__init__(fn=fn, **kwargs)
301        self.obj = obj
302
303    def __str__(self):
304        return f"{self.__class__.__name__}({self.fn}, {self.obj})"
305
306    def self_args(self):
307        return [self.obj]
308
309    def python_type(self):
310        return types.MethodType
311
312    def call_function(
313        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
314    ) -> "VariableTracker":
315        # For nn.Module methods, redirecting to NNModuleVariable.call_method for optimized solution
316        # rather than simple inlining. E.g, putting `call_method` op in FX graph for `forward` method
317        # since we ensure `forward` of allowed modules can be traced by AOT safely.
318        # Note this is not only for allowed modules, as user customized modules can extend from
319        # allowed modules but using parent's `forward` method, which is also covered by this branch.
320
321        # If we are tracing the higher order op, we want Dynamo to step inside
322        # the module call so that Dynamo can see the underlying parameters and
323        # buffers and raise them as inputs to the graph. The is_root_tracer
324        # check bypasses the if condition for non-root tracers and directly
325        # calls the super().call_function at the end, which is basically
326        # equivalent of inlining the method.
327        if tx.output.is_root_tracer() and isinstance(
328            self.obj, variables.NNModuleVariable
329        ):
330            module_attr = getattr(self.fn, "__module__", "")
331            # inline torch.nn.utils.parametrize
332            if (
333                module_attr is not None
334                and module_attr.startswith("torch.nn.")
335                and module_attr != "torch.nn.utils.parametrize"
336                or self.is_constant
337            ):
338                return self.obj.call_method(
339                    tx, self.fn.__name__, args, kwargs, constant=self.is_constant
340                )
341        if self.is_constant:
342            fn = getattr(self.obj.value, self.fn.__name__)
343            return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs)
344        return super().call_function(tx, args, kwargs)
345
346    def inspect_parameter_names(self):
347        return super().inspect_parameter_names()[1:]
348
349
350class WrappedUserMethodVariable(UserMethodVariable):
351    def __init__(self, wrapped, context, **kwargs):
352        kwargs.pop("fn", None)
353        kwargs.pop("obj", None)
354        super().__init__(wrapped.fn, wrapped.obj, **kwargs)
355        self.wrapped = wrapped
356        self.context = context
357
358    def call_function(
359        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
360    ) -> "VariableTracker":
361        self.context.enter(tx)
362        result = super().call_function(tx, args, kwargs)
363        self.context.exit(tx)
364        return result
365
366
367class WrappedUserFunctionVariable(UserFunctionVariable):
368    def __init__(self, wrapped, context, **kwargs):
369        kwargs.pop("fn", None)
370        kwargs.pop("obj", None)
371        super().__init__(wrapped.fn, **kwargs)
372        self.wrapped = wrapped
373        self.context = context
374
375    def call_function(
376        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
377    ) -> "VariableTracker":
378        self.context.enter(tx)
379        result = super().call_function(tx, args, kwargs)
380        self.context.exit(tx)
381        return result
382
383
384def invoke_and_store_as_constant(tx, fn, name, args, kwargs):
385    def convert(x):
386        if isinstance(x, variables.TensorVariable):
387            return x.get_real_value()
388        return x.as_python_constant()
389
390    args = [convert(x) for x in args]
391    kwargs = {k: convert(v) for k, v in kwargs.items()}
392    res = fn(*args, **kwargs)
393    return tx.output.register_attr_or_module(
394        res,
395        name,
396        source=ConstantSource(name),
397    )
398
399
400class NestedUserFunctionVariable(BaseUserFunctionVariable):
401    _nonvar_fields = {
402        "closure_scope",
403        "f_globals",
404        *BaseUserFunctionVariable._nonvar_fields,
405    }
406
407    def __init__(
408        self,
409        fn_name,
410        code,
411        f_globals,
412        defaults,
413        kwdefaults,
414        annotations,
415        closure,
416        closure_scope,
417        wrapped_reconstructible=None,
418        **kwargs,
419    ):
420        super().__init__(**kwargs)
421        assert isinstance(fn_name.as_python_constant(), str)
422        assert isinstance(code.as_python_constant(), types.CodeType)
423        assert isinstance(f_globals, dict)
424        self.fn_name = fn_name
425        self.code = code
426        self.f_globals = f_globals
427        self.defaults = defaults
428        self.kwdefaults = kwdefaults
429        self.annotations = annotations
430        self.closure = closure
431        if closure is None:
432            closure_scope = None
433        self.closure_scope = closure_scope
434        # Either a source or a VT with .can_reconstruct() == True
435        self.wrapped_reconstructible: Optional[
436            Union[Source, VariableTracker]
437        ] = wrapped_reconstructible
438
439    def self_args(self):
440        return []
441
442    def get_code(self):
443        return self.code.as_python_constant()
444
445    def get_function(self):
446        if self.closure:
447            raise NotImplementedError
448        func = types.FunctionType(
449            self.code.as_python_constant(),
450            self.f_globals,
451            self.fn_name.as_python_constant(),
452        )
453        if self.defaults:
454            func.__defaults__ = self.defaults.as_python_constant()
455        if self.kwdefaults:
456            func.__kwdefaults__ = self.kwdefaults.as_python_constant()
457        if self.annotations:
458            annotations = self.annotations.as_python_constant()
459            if isinstance(annotations, tuple):
460                from itertools import pairwise
461
462                annotations = dict(pairwise(annotations))
463
464            # TypeError: __annotations__ must be set to a dict object
465            assert isinstance(annotations, dict)
466            func.__annotations__ = annotations
467        return func
468
469    def has_closure(self):
470        return self.closure is not None
471
472    def has_self(self):
473        return False
474
475    def get_globals(self):
476        return self.f_globals
477
478    def bind_args(self, parent, args, kwargs):
479        from .misc import InlinedClosureVariable
480
481        code = self.get_code()
482        func = types.FunctionType(
483            code,
484            self.f_globals,
485            self.fn_name.as_python_constant(),
486            tuple(self.defaults.items) if self.defaults else None,
487            tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))),
488        )
489        if self.kwdefaults:
490            func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant()
491        bound = inspect.signature(func).bind(*args, **kwargs)
492        bound.apply_defaults()
493        result = dict(bound.arguments.items())
494        wrap_args_kwargs(parent.output.root_tx, result)
495        closure_cells = init_cellvars(parent, result, code)
496
497        for idx, name in enumerate(code.co_freevars):
498            cell = self.closure.items[idx]
499            assert getattr(cell, name, name) == name
500            assert name not in result
501            if isinstance(cell, InlinedClosureVariable):
502                # InlinedClosureVariable's are created from LOAD_CLOSURE's from
503                # InliningInstructionTranslators when the variable name is not found in closure_cells.
504                # They should remain outside of closure_cells, so that our callee (the
505                # InliningInstructionTranslator that traces `func`) handles
506                # the cell correctly - that is, the cell's contents are treated as if they
507                # are local variables, like in UserFunctionVariable's bind_args for freevars.
508                cand = parent
509                while cand and name not in cand.symbolic_locals:
510                    cand = cand.parent
511                if cand is None:
512                    raise RuntimeError(
513                        f"Couldn't find {name} in the symbolic_locals of the inline interpreter stack"
514                    )
515                result[name] = cand.symbolic_locals[name]
516            else:
517                closure_cells[name] = self.closure.items[idx]
518
519        return result, closure_cells
520
521    def export_freevars(self, parent, child):
522        code = self.get_code()
523        for var in code.co_freevars:
524            if var in child.symbolic_locals:
525                parent.symbolic_locals[var] = child.symbolic_locals[var]
526
527    def reconstruct(self, codegen):
528        codegen.load_import_from(__name__, "_create_nested_fn")
529        codegen(self.code)
530        codegen.extend_output([codegen._create_load_const(self.f_globals)])
531        codegen(ConstantVariable.create(self.code.value.co_name))
532
533        if self.defaults:
534            codegen(self.defaults)
535        else:
536            codegen.extend_output([codegen.create_load_const(None)])
537
538        if self.closure:
539            codegen(self.closure)
540        else:
541            codegen.extend_output([codegen.create_load_const(None)])
542
543        if self.kwdefaults:
544            codegen(self.kwdefaults)
545        else:
546            codegen.extend_output([codegen.create_load_const(None)])
547
548        if self.annotations:
549            try:
550                annotations = self.annotations.as_python_constant()
551                codegen.extend_output([codegen._create_load_const(annotations)])
552            except NotImplementedError:
553                codegen(self.annotations)
554        else:
555            codegen.extend_output([codegen.create_load_const(None)])
556
557        codegen.extend_output(create_call_function(7, push_null=True))
558
559        if self.wrapped_reconstructible:
560            codegen.load_import_from("functools", "wraps")
561            codegen(self.wrapped_reconstructible)
562            codegen.extend_output(create_call_function(1, True))
563            codegen.extend_output(create_rot_n(2))
564            codegen.extend_output(create_call_function(1, True))
565
566
567class SkipFunctionVariable(VariableTracker):
568    _nonvar_fields = {
569        "value",
570        "reason",
571        *VariableTracker._nonvar_fields,
572    }
573
574    def __init__(self, value, reason=None, **kwargs):
575        super().__init__(**kwargs)
576        self.value = value
577        self.reason = reason
578
579    def python_type(self):
580        return type(self.value)
581
582    def as_python_constant(self):
583        return self.value
584
585    @classmethod
586    def create_with_source(cls, value, source):
587        install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
588        return cls(
589            value,
590            source=source,
591        )
592
593    @staticmethod
594    @functools.lru_cache(None)
595    def fold_through_function_to_wrapper():
596        return {
597            collections.namedtuple: variables.UserDefinedClassVariable,
598        }
599
600    def call_function(
601        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
602    ) -> "VariableTracker":
603        if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
604            unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}")
605        # Fold through the functions(e.g, collections.namedtuple)
606        # that inputs & outputs are all python constants
607        elif (
608            self.value in self.fold_through_function_to_wrapper().keys()
609            and check_constant_args(args, kwargs)
610        ):
611            value = self.value(
612                *[x.as_python_constant() for x in args],
613                **{k: v.as_python_constant() for k, v in kwargs.items()},
614            )
615            return self.fold_through_function_to_wrapper().get(self.value)(
616                value, mutable_local=MutableLocal()
617            )
618        elif (
619            self.value is functools.wraps
620            and not kwargs
621            and len(args) == 1
622            and (
623                args[0].source is not None or args[0].can_reconstruct(tx.output.root_tx)
624            )
625        ):
626
627            def wraps(fn):
628                if isinstance(fn, variables.NestedUserFunctionVariable):
629                    if args[0].source:
630                        reconstructible = args[0].source
631                    else:
632                        reconstructible = args[0]
633                    return fn.clone(wrapped_reconstructible=reconstructible)
634                unimplemented(f"functools.wraps({fn})")
635
636            return variables.LambdaVariable(wraps)
637        else:
638            try:
639                path = inspect.getfile(self.value)
640                msg = f"'skip function {self.value.__qualname__} in file {path}'"
641            except TypeError:
642                known_python_builtin_modules = {"_abc", "_warnings"}
643                if self.value.__module__ in known_python_builtin_modules:
644                    msg = (
645                        f"Graph break due to unsupported Python builtin {self.value.__module__}.{self.value.__qualname__}. "
646                        f"Please file an issue on GitHub "
647                        f"so the PyTorch team can add support for it. "
648                    )
649                else:
650                    msg = (
651                        f"Graph break due to unsupported builtin {self.value.__module__}.{self.value.__qualname__}. "
652                        f"This function is either a Python builtin (e.g. _warnings.warn) "
653                        f"or a third-party C/C++ Python extension (perhaps created with pybind). "
654                        f"If it is a Python builtin, please file an issue on GitHub "
655                        f"so the PyTorch team can add support for it and see the next case for a workaround. "
656                        f"If it is a third-party C/C++ Python extension, please "
657                        f"either wrap it into a PyTorch-understood custom operator "
658                        f"(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html "
659                        f"for more details) or, if it is traceable, use "
660                        f"torch.compiler.allow_in_graph."
661                    )
662                    # also warn on it because most users won't see the graph break message
663                    torch._dynamo.utils.warn_once(msg)
664            msg += f"', {self.reason}'" if self.reason else ""
665            unimplemented(msg)
666
667
668def _traceable_collective_remaps():
669    # We can't rely on importing from distributed, since it's not always built
670    if torch.distributed.is_available():
671        from torch.distributed._functional_collectives import (
672            traceable_collective_remaps,
673        )
674
675        return traceable_collective_remaps
676    return {}
677
678
679def _traceable_collectives_source(tx, fn):
680    assert torch.distributed.is_available(), "Illegal invocation."
681    assert fn in _traceable_collective_remaps().values()
682
683    inner_name = fn.__name__
684    path_source = tx.import_source("torch.distributed._functional_collectives")
685    return AttrSource(path_source, inner_name)
686
687
688class CollectiveFunctionRewriteVariable(UserFunctionVariable):
689    """
690    Some of the torch.distributed.* collective APIs are possible to rewrite to 'traceable' collectives.
691
692    This class provides both a way to check if a function is remappable, and perform the remapping.
693
694    In the case that a function is 'remappable' but only for some combinations of call-time arguments,
695    we check the args at `call_function` time and fall back to graph-breaking if needed.  This is no worse
696    than status-quo as we currently graph-break on all distributed.* collectives.
697    """
698
699    def __init__(self, fn, *, replacement_var, **kwargs):
700        super().__init__(fn, **kwargs)
701        assert isinstance(replacement_var, UserFunctionVariable)
702        self.replacement_var = replacement_var
703
704    @staticmethod
705    def create(tx, old_fn, source, **options):
706        new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn)
707        return CollectiveFunctionRewriteVariable(
708            old_fn,
709            replacement_var=UserFunctionVariable(new_fn, source=new_source, **options),
710            source=source,
711            **options,
712        )
713
714    @staticmethod
715    def can_rewrite(variable):
716        return (
717            inspect.isfunction(variable) and variable in _traceable_collective_remaps()
718        )
719
720    @staticmethod
721    def rewrite(tx, fn):
722        new_fn = _traceable_collective_remaps()[fn]
723        return new_fn, _traceable_collectives_source(tx, new_fn)
724
725    def call_function(
726        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
727    ) -> "VariableTracker":
728        # call_function must check any unsupported arguments and graph-break.
729        # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn,
730        # since that's the contract for putting a mapping in `traceable_collective_remaps`
731        import torch.distributed as dist
732        from torch.distributed._functional_collectives import REDUCE_OP_TO_STR
733
734        # Merge args into kwargs so positional and keyword args
735        # can be processed the same way.
736        signature = inspect.signature(self.fn)
737        kwargs = dict(signature.bind(*args, **kwargs).arguments)
738        args = ()
739
740        if "async_op" in kwargs and kwargs["async_op"].as_python_constant():
741            unimplemented(
742                f"CollectiveFunctionRewriteVariable can't support async_op=True for {self.fn}"
743            )
744
745        if self.fn in (
746            dist.all_reduce,
747            dist.reduce_scatter_tensor,
748            dist._reduce_scatter_base,
749        ):
750            reduce_op_var = kwargs.get("op")
751            reduce_op = (
752                reduce_op_var.value
753                if reduce_op_var is not None
754                else signature.parameters["op"].default
755            )
756            if reduce_op not in REDUCE_OP_TO_STR:
757                raise ValueError(f"Unsupported all_reduce op: {reduce_op}")
758            kwargs["op"] = variables.ConstantVariable.create(
759                REDUCE_OP_TO_STR[reduce_op]
760            )
761        return self.replacement_var.call_function(tx, args, kwargs)
762
763
764class FunctoolsPartialVariable(VariableTracker):
765    def __init__(self, func: VariableTracker, args, keywords, **kwargs):
766        super().__init__(**kwargs)
767        self.func = func
768        assert isinstance(args, list)
769        self.args = args
770        assert isinstance(keywords, dict)
771        self.keywords = keywords
772
773    def reconstruct(self, codegen):
774        codegen.load_import_from("functools", "partial")
775        codegen(self.func)
776        if self.args:
777            codegen.foreach(self.args)
778        if not self.keywords:
779            codegen.extend_output(create_call_function(len(self.args) + 1, True))
780            return
781
782        codegen.foreach(self.keywords.values())
783        keys = tuple(self.keywords.keys())
784        codegen.extend_output(
785            codegen.create_call_function_kw(len(keys) + len(self.args) + 1, keys, True)
786        )
787
788    def get_function(self):
789        return self.as_python_constant()
790
791    def call_function(
792        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
793    ) -> "VariableTracker":
794        merged_args = self.args + args
795        merged_kwargs = {**self.keywords, **kwargs}
796        return self.func.call_function(tx, merged_args, merged_kwargs)
797
798    def call_hasattr(self, tx, name: str) -> VariableTracker:
799        # functools.partial uses slots, so attributes are constant
800        return variables.ConstantVariable.create(
801            hasattr(functools.partial(identity), name)
802        )
803
804    def as_python_constant(self):
805        return functools.partial(
806            self.func.as_python_constant(),
807            *[arg.as_python_constant() for arg in self.args],
808            **{k: v.as_python_constant() for k, v in self.keywords.items()},
809        )
810
811    def guard_as_python_constant(self):
812        """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants"""
813        return functools.partial(
814            self.func.guard_as_python_constant(),
815            *[v.guard_as_python_constant() for v in self.args],
816            **{k: v.guard_as_python_constant() for k, v in self.keywords.items()},
817        )
818
819
820class TritonKernelVariable(VariableTracker):
821    def __init__(self, kernel, kernel_idx, grid, **kwargs):
822        from triton.runtime.autotuner import Autotuner
823
824        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
825
826        super().__init__(**kwargs)
827
828        assert kernel is not None
829
830        self.kernel = kernel
831        self.kernel_idx = kernel_side_table.add_kernel(kernel)
832
833        assert kernel_idx is None or self.kernel_idx == kernel_idx
834
835        self.grid = grid
836
837        if isinstance(kernel, Autotuner):
838            # We only support configs and keys arguments of triton.autotune
839            # Make sure other arguments are defaulted
840            defaults = inspect.signature(Autotuner.__init__).parameters
841
842            # Newer version of triton change attribute name from warmup to num_warmup and rep to num_rep.
843            # The call to get_first_attr is to maintain backward-compatibility.
844            if (
845                (
846                    "warmup" in defaults
847                    and defaults["warmup"].default
848                    != get_first_attr(kernel, "num_warmups", "warmup")
849                )
850                or (
851                    "rep" in defaults
852                    and defaults["rep"].default
853                    != get_first_attr(kernel, "num_reps", "rep")
854                )
855                or (
856                    "prune_configs_by" in defaults
857                    and defaults["prune_configs_by"].default
858                    != kernel.early_config_prune
859                )
860                # Set via reset_to_zero argument
861                or len(kernel.reset_idx) != 0
862                or len(kernel.restore_idx) != 0
863            ):
864                raise Unsupported(
865                    "Only configs and keys are supported for triton.autotune"
866                )
867
868    def call_function(
869        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
870    ) -> "VariableTracker":
871        from triton.runtime.autotuner import autotune, Autotuner, Config
872
873        from .constant import ConstantVariable
874        from .dicts import ConstDictVariable
875        from .lists import BaseListVariable
876
877        if "num_ctas" in kwargs:
878            raise Unsupported(
879                "Passing num_ctas directly to the Triton kernel is not supported. "
880                "Please use a Config in @triton.autotune instead."
881            )
882
883        special_kwargs = {}
884        for name in ("num_warps", "num_stages"):
885            if name in kwargs:
886                # remove special kwargs from `kwargs`
887                val = kwargs.pop(name)
888                assert isinstance(val, ConstantVariable)
889                special_kwargs[name] = val.value
890
891        if special_kwargs:
892            if isinstance(self.kernel, Autotuner):
893                # if there is Autotuner already, set
894                # special kwargs to each of its configs
895                new_configs = copy.deepcopy(self.kernel.configs)
896                for config in new_configs:
897                    config.__dict__.update(special_kwargs)
898                new_kernel = autotune(configs=new_configs, key=[])(self.kernel.fn)
899            else:
900                # if there is no Autotuner, wrap the kernel into a
901                # new one with a single config with special kwargs
902                new_config = Config(kwargs={}, **special_kwargs)
903                new_kernel = autotune(configs=[new_config], key=[])(self.kernel)
904
905            # create a new variable to contain the new (wrapped) kernel;
906            # skip kernel_idx to get a new record in the kernel side table
907            new_var = TritonKernelVariable(new_kernel, None, self.grid)
908            return new_var.call_function(tx, args, kwargs)
909
910        if self.grid is None:
911            raise Unsupported("Triton kernels should always be called with a grid")
912
913        # Both for grid's meta as well as for the kernel, we need combined
914        # args and kwargs combined and normalized
915        combined_args_raw = {**dict(zip(self.kernel.arg_names, args)), **kwargs}
916        combined_args = {
917            variables.ConstantVariable.create(k): v
918            for k, v in combined_args_raw.items()
919        }
920
921        configs = (
922            [config.kwargs for config in self.kernel.configs]
923            if isinstance(self.kernel, Autotuner)
924            else [{}]
925        )
926        grids = []
927        for config_args in configs:
928            # If the grid is a function, then lets execute it and convert it to
929            # a list
930            grid = self.grid
931            if isinstance(grid, (NestedUserFunctionVariable, UserFunctionVariable)):
932                # Populate the special "meta" argument to call the grid function
933                config_args = {
934                    ConstantVariable.create(k): ConstantVariable.create(v)
935                    for k, v in config_args.items()
936                }
937                meta = ConstDictVariable({**combined_args, **config_args}, dict)
938                grid = grid.call_function(tx, [meta], {})
939
940            # Now, the grid must be a list either originally or through above
941            # modification
942            if isinstance(grid, BaseListVariable):
943                grids.append(grid.as_proxy())
944            else:
945                unimplemented(f"grid for the triton kernel is {type(grid)}")
946
947        for i in range(len(grids)):
948            if not isinstance(grids[i], tuple):
949                raise Unsupported("Only tuple grids are supported")
950            # inductor expects all grids to be 3-tuple so lets make it
951            if len(grids[i]) == 1:
952                grids[i] = (grids[i][0], 1, 1)
953            elif len(grids[i]) == 2:
954                grids[i] = (grids[i][0], grids[i][1], 1)
955            elif len(grids[i]) > 3:
956                raise Unsupported("Grid can have at most rank 3")
957
958        assert len(grids) != 0
959        if len(set(grids)) == 1:
960            # If there's only one unique grid, lets simplify
961            grids = [grids[0]]
962
963        from torch._higher_order_ops.triton_kernel_wrap import (
964            kernel_side_table,
965            triton_kernel_wrapper_mutation,
966        )
967
968        # Combine args and kwargs and pass as a dict so that if user defined triton
969        # kernel uses variables as 'grid' or 'kernel', it does not conflict with
970        # parameters of the wrapper function
971        constant_args = {
972            k: v.as_python_constant()
973            for k, v in combined_args_raw.items()
974            if isinstance(v, ConstantVariable)
975        }
976        non_constant_args = {
977            k: v
978            for k, v in combined_args.items()
979            if not isinstance(v, ConstantVariable)
980        }
981
982        constant_args_idx = kernel_side_table.add_constant_args(constant_args)
983        meta = ConstDictVariable(non_constant_args, dict)
984        tx.output.create_proxy(
985            "call_function",
986            triton_kernel_wrapper_mutation,
987            (),
988            {
989                "kernel_idx": self.kernel_idx,
990                "constant_args_idx": constant_args_idx,
991                "grid": grids,
992                "kwargs": meta.as_proxy(),
993            },
994        )
995
996        return variables.ConstantVariable(
997            None,
998        )
999
1000    def call_method(
1001        self,
1002        tx,
1003        name,
1004        args: "List[VariableTracker]",
1005        kwargs: "Dict[str, VariableTracker]",
1006    ) -> "VariableTracker":
1007        if name == "__getitem__":
1008            # __getitem__ should only be called if we don't already have a grid
1009            # Only grid needs to be passed
1010            if self.grid is not None or len(args) != 1:
1011                raise Unsupported(
1012                    "Triton kernels should be called with only a single grid"
1013                )
1014
1015            return TritonKernelVariable(
1016                kernel=self.kernel,
1017                kernel_idx=self.kernel_idx,
1018                grid=args[0],
1019            )
1020        elif name == "run":
1021            if "grid" not in kwargs:
1022                raise Unsupported("Triton kernel requires to be called with a grid")
1023            grid = kwargs.pop("grid")
1024            kwargs.pop("warmup", None)
1025            # rewrite kernel.run(*args, grid=grid) to kernel[grid](*args)
1026            return TritonKernelVariable(
1027                kernel=self.kernel, kernel_idx=self.kernel_idx, grid=grid
1028            ).call_function(tx, args, kwargs)
1029
1030        # Bail out to parent's implementation
1031        return super().call_method(tx, name, args, kwargs)
1032