xref: /aosp_15_r20/external/pytorch/torchgen/api/dispatcher.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import itertools
4from typing import Sequence
5
6from torchgen.api import cpp
7from torchgen.api.types import ArgName, Binding, CType, NamedCType
8from torchgen.model import (
9    Argument,
10    FunctionSchema,
11    Return,
12    SelfArgument,
13    TensorOptionsArguments,
14    Type,
15)
16from torchgen.utils import assert_never, concatMap
17
18
19# This file describes the translation of JIT schema to the dispatcher
20# API, the *unboxed* calling convention by which invocations through
21# the dispatcher are made.  Historically, the dispatcher API matched
22# the C++ API, but with the establishment of the boxed API, we've
# made changes to the dispatcher API so that the unboxed API
24# better aligns with the boxed API.  The dispatcher API hooks heavily
25# into our template based boxing/unboxing machinery, so changes
26# to this convention will usually need template updates too.
27#
28# Prominent characteristics of the dispatcher API:
29#
30#   - dtype, layout, device and pin_memory are represented as separate
31#     arguments.
32#
33
34
def name(func: FunctionSchema) -> str:
    """Return the dispatcher-API name for ``func``.

    The dispatcher currently uses exactly the same naming as the C++ API,
    so this simply delegates to ``cpp.name``.
    """
    cpp_api_name = cpp.name(func)
    return cpp_api_name
37
38
def argumenttype_type(
    t: Type,
    *,
    mutable: bool,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = True,
) -> NamedCType:
    """Translate a JIT schema type into the dispatcher-API C++ type.

    Args:
        t: the schema type to translate.
        mutable: forwarded to ``cpp.argumenttype_type`` (callers such as
            ``argument_type`` pass ``Argument.is_write`` here).
        binds: the name the resulting ``NamedCType`` is bound to.
        remove_non_owning_ref_types: forwarded unchanged; its exact effect is
            defined by ``cpp.argumenttype_type``.
        symint: forwarded unchanged to ``cpp.argumenttype_type``.

    Returns:
        Whatever ``cpp.argumenttype_type`` produces for the same inputs.
    """
    # This is a faux ami ("false friend"): it looks like a separate
    # dispatcher-specific translation, but today it just forwards to the C++
    # API. If it makes sense in the future to add more special cases here,
    # or invert things so cpp.argument_type calls this, or just completely
    # inline the function, please do it.
    return cpp.argumenttype_type(
        t,
        binds=binds,
        mutable=mutable,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
        symint=symint,
    )
58
59
def argument_type(
    a: Argument,
    *,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = True,
) -> NamedCType:
    """Compute the dispatcher-API C++ type for a schema ``Argument``.

    The argument's schema type is translated with ``argumenttype_type``, and
    the argument is treated as mutable exactly when ``a.is_write`` is true.
    ``binds``, ``remove_non_owning_ref_types`` and ``symint`` are passed
    through unchanged.
    """
    schema_type = a.type
    is_mutable = a.is_write
    return argumenttype_type(
        schema_type,
        mutable=is_mutable,
        binds=binds,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
        symint=symint,
    )
74
75
def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
    """Compute the dispatcher-API C++ return type for the schema returns ``rs``.

    At present the dispatcher API does not differ from the C++ API here, so
    this delegates to ``cpp.returns_type``. Keeping the wrapper gives a place
    to diverge later if needed.
    """
    ret_ctype = cpp.returns_type(rs, symint=symint)
    return ret_ctype
79
80
def jit_arguments(func: FunctionSchema) -> list[Argument]:
    """Flatten ``func``'s arguments into a plain list of ``Argument`` objects.

    Arguments are visited in the order positional, then kwarg-only, then out.
    Each visited entry is mapped as follows:

    * an ``Argument`` is kept as-is;
    * a ``SelfArgument`` is unwrapped to its ``.argument``;
    * a ``TensorOptionsArguments`` is expanded into its ``dtype``, ``layout``,
      ``device`` and ``pin_memory`` arguments, in that order (the dispatcher
      API keeps these as separate arguments);
    * anything else is rejected via ``assert_never``.
    """
    flattened: list[Argument] = []
    every_arg = itertools.chain(
        func.arguments.positional, func.arguments.kwarg_only, func.arguments.out
    )
    for entry in every_arg:
        if isinstance(entry, Argument):
            flattened.append(entry)
        elif isinstance(entry, SelfArgument):
            flattened.append(entry.argument)
        elif isinstance(entry, TensorOptionsArguments):
            flattened.extend(
                (entry.dtype, entry.layout, entry.device, entry.pin_memory)
            )
        else:
            assert_never(entry)
    return flattened
102
103
def argument(
    a: Argument, *, remove_non_owning_ref_types: bool = False, symint: bool = True
) -> Binding:
    """Build the dispatcher-API ``Binding`` for a single schema ``Argument``.

    The binding takes its name from ``a.name``, keeps a reference to ``a``,
    and carries the C++ type computed by ``argument_type``, bound under that
    same name. ``remove_non_owning_ref_types`` and ``symint`` are forwarded to
    ``argument_type``.
    """
    arg_name = a.name
    typed = argument_type(
        a,
        binds=arg_name,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
        symint=symint,
    )
    return Binding(nctype=typed, name=arg_name, argument=a)
117
118
def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]:
    """Return one dispatcher-API ``Binding`` per flattened JIT argument of ``func``.

    The order matches ``jit_arguments(func)``. ``symint`` is forwarded to
    ``argument`` for every entry.
    """
    bindings: list[Binding] = []
    for jit_arg in jit_arguments(func):
        bindings.append(argument(jit_arg, symint=symint))
    return bindings
120    return [argument(a, symint=symint) for a in jit_arguments(func)]
121