xref: /aosp_15_r20/external/pytorch/torchgen/executorch/model.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Represents all kernels used by an Executorch model.
2# It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure.
3
4from __future__ import annotations
5
6import itertools
7from collections import defaultdict, namedtuple
8from dataclasses import dataclass
9from enum import IntEnum
10
11from torchgen.model import (
12    BackendIndex,
13    BackendMetadata,
14    DispatchKey,
15    NativeFunction,
16    NativeFunctionsGroup,
17    OperatorName,
18)
19from torchgen.utils import assert_never
20
21
# Version of the kernel key format. It is the default value of
# ETKernelKey.version and the number in the "v<N>/" prefix of the native
# string form of non-default kernel keys.
KERNEL_KEY_VERSION = 1
23
24
25# TODO: Duplicated Subset from codegen.tool.gen_oplist, remove declaration in codegen
class ScalarType(IntEnum):
    """Integer codes for the dtypes that can appear in a kernel key.

    ETKernelKeyOpArgMeta.to_native_string looks a member up by name and emits
    its integer value.
    """

    # NOTE(review): the numbering has gaps (5 and 8-10 are absent); presumably
    # this is a subset of a larger dtype enum defined elsewhere and the values
    # must stay in sync with it -- confirm before adding or renumbering.
    Byte = 0
    Char = 1
    Short = 2
    Int = 3
    Long = 4
    Float = 6
    Double = 7
    Bool = 11
35
36
# Two-field record: `native_functions` and `kernel_index`.
# NOTE(review): presumably the result of parsing an Executorch YAML file (the
# native functions plus the kernel index built for them) -- confirm against
# the code that constructs it.
ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "kernel_index"])
38
39
@dataclass(frozen=True)
class ETKernelKeyOpArgMeta:
    """Dtype and dimension order of a single operator argument in a kernel key."""

    # Name of the operator argument this entry describes. It is part of the
    # entry's identity (equality / hash) but is not emitted in the native
    # string.
    arg_name: str
    # Name of a ScalarType member (e.g. "Float"); resolved by name when the
    # entry is serialized.
    dtype: str
    # The order of the dimensions if entry is a Tensor
    dim_order: tuple[int, ...]

    def to_native_string(self) -> str:
        """Serialize this entry as ``"<dtype value>;<dim>,<dim>,..."``.

        The dtype is written as the integer value of the matching ScalarType
        member, followed by ``;`` and the comma-separated dimension order
        without spaces, e.g. ``"6;0,1,2"``. An empty dimension order yields
        ``"6;"``.

        Raises:
            KeyError: if ``dtype`` is not the name of a ScalarType member.
        """
        dtype_str = ScalarType[self.dtype].value
        # Join the entries explicitly instead of slicing str(tuple): the repr
        # of a one-element tuple is "(0,)", which would leak a trailing comma
        # into the key ("6;0," instead of "6;0").
        dim_str = ",".join(str(dim) for dim in self.dim_order)
        return f"{dtype_str};{dim_str}"
51
52
@dataclass(frozen=True)
class ETKernelKey:
    """Key identifying one kernel of an operator.

    A key is either the catch-all key (``default=True``) or a tuple of
    per-argument dtype / dim-order metadata.
    """

    # Per-argument metadata. Left at its empty default by the catch-all keys
    # (``ETKernelKey(default=True)``) created in this file.
    arg_meta: tuple[ETKernelKeyOpArgMeta, ...] = ()

    # Indicator for this kernel being used as a catch all
    default: bool = False

    # Version of the key format; emitted as the "v<N>/" prefix of the native
    # string.
    version: int = KERNEL_KEY_VERSION

    @staticmethod
    def gen_from_yaml(
        args: dict[str, tuple[str, str]],
        type_alias_map: dict[str, list[str]],  # TODO: Support unwrapped str val
        dim_order_alias_map: dict[str, list[int]],
    ) -> list[ETKernelKey]:
        """Generate ETKernelKeys from arg kernel specs
        Multiple ETKernelKeys are returned due to dtype permutations from utilizing
        type_alias_map (actualizing each potential type permutation as a KernelKey)

        Args:
            args: Mapping from argument name to kernel specs
                Kernel specs are a tuple of (dtype, dim_order).
                Currently tuple entries must be aliased via the alias map arguments
            type_alias_map: Mapping from type alias to potential type enums
                i.e { T0 : [Double, Int] } means T0 can be either Double or Int
                Used for lookup by args
            dim_order_alias_map: Mapping from alias to a list of dimension orders
                Used for lookup by args

        Returns:
            One ETKernelKey per combination of dtypes of the aliases used by
            ``args``. Aliases are combined in the order they first appear in
            ``args`` (the last alias varies fastest), so the result order is
            deterministic. With empty ``args`` a single key with empty
            ``arg_meta`` is returned.

        Raises:
            AssertionError: if a spec references an alias that is missing from
                ``type_alias_map`` or ``dim_order_alias_map``.
        """
        # Cast to dim order to int
        dim_orders = {
            alias: [int(dim) for dim in dims]
            for alias, dims in dim_order_alias_map.items()
        }

        # Get all used Dtype Alias, in first-use order. A dict is used instead
        # of a set because set iteration order of strings varies between
        # interpreter runs (hash randomization), which would make the order of
        # the generated keys nondeterministic.
        dtype_alias_used: dict[str, None] = {}
        for type_alias, dim_order_alias in args.values():
            # Enforce usage of alias initially
            # TODO: Support inlined arguments
            assert type_alias in type_alias_map, "Undefined type alias: " + str(
                type_alias
            )
            assert (
                dim_order_alias in dim_orders
            ), "Undefined dim_order alias: " + str(dim_order_alias)
            dtype_alias_used[type_alias] = None

        # For each alias, the list of (alias, dtype) choices; their cartesian
        # product enumerates every assignment of a concrete dtype to each alias.
        alias_dtypes = [
            [(alias, dtype) for dtype in type_alias_map[alias]]
            for alias in dtype_alias_used
        ]

        # Using each alias value permutation, generate kernel keys.
        # Identical (arg, dtype, dim_order) entries are shared between keys.
        op_arg_cache: dict[tuple[str, str, tuple[int, ...]], ETKernelKeyOpArgMeta] = {}
        kernel_keys = []
        for combination in itertools.product(*alias_dtypes):
            permutation = dict(combination)
            arg_list = []
            for arg_name, (type_alias, dim_order_alias) in args.items():
                cache_key = (
                    arg_name,
                    permutation[type_alias],
                    tuple(dim_orders[dim_order_alias]),
                )
                if cache_key not in op_arg_cache:
                    op_arg_cache[cache_key] = ETKernelKeyOpArgMeta(*cache_key)
                arg_list.append(op_arg_cache[cache_key])
            kernel_keys.append(ETKernelKey(tuple(arg_list)))

        return kernel_keys

    def to_native_string(self) -> str:
        """Serialize the key.

        Returns ``"default"`` for the catch-all key, otherwise
        ``"v<version>/<arg>|<arg>|..."`` where each ``<arg>`` is the native
        string of the corresponding ``arg_meta`` entry.
        """
        if self.default:
            return "default"
        # Use this key's own version (defaults to KERNEL_KEY_VERSION) so that
        # keys that differ only in version do not serialize identically.
        args_str = "|".join(arg.to_native_string() for arg in self.arg_meta)
        return f"v{self.version}/{args_str}"
137
138
@dataclass(frozen=True)
class ETKernelIndex:
    """Index of kernels per operator, keyed by ETKernelKey.

    The dataclass is frozen, but the ``index`` dict itself is mutable and is
    updated in place by ``grow``.
    """

    index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]]

    def has_kernels(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
        """Return True if at least one kernel is registered for ``g``."""
        # get_kernels never returns None (unknown operators give {}), so an
        # ``is not None`` check would always be True; test for emptiness.
        return bool(self.get_kernels(g))

    def get_kernels(
        self, g: NativeFunction | NativeFunctionsGroup
    ) -> dict[ETKernelKey, BackendMetadata]:
        """Return the kernels registered for ``g``, or ``{}`` if there are none.

        For a NativeFunctionsGroup the lookup uses its functional variant.
        """
        if isinstance(g, NativeFunction):
            f = g
        elif isinstance(g, NativeFunctionsGroup):
            f = g.functional
        else:
            assert_never(g)
        # .get() does not trigger the default factory when index is a
        # defaultdict, so a failed lookup leaves the index untouched.
        return self.index.get(f.func.name, {})

    @staticmethod
    def grow_from_backend_indices(
        kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]],
        backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
    ) -> None:
        """Add every backend kernel to ``kernel_index`` in place.

        Each operator's metadata is stored under the catch-all key
        ``ETKernelKey(default=True)``. If the same operator appears under
        several dispatch keys, the one iterated last wins.
        """
        for index in backend_indices.values():
            for op, backend_metadata in index.items():
                kernel_index.setdefault(op, {})[
                    ETKernelKey(default=True)
                ] = backend_metadata

    @staticmethod
    def from_backend_indices(
        backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
    ) -> ETKernelIndex:
        """Build a new ETKernelIndex holding the kernels of ``backend_indices``."""
        kernel_index: dict[
            OperatorName, dict[ETKernelKey, BackendMetadata]
        ] = defaultdict(dict)
        ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
        return ETKernelIndex(kernel_index)

    def grow(
        self, backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
    ) -> ETKernelIndex:
        """Add the kernels of ``backend_indices`` to this index in place.

        Returns ``self`` so calls can be chained.
        """
        ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
        return self

    def _to_backend_index(self) -> BackendIndex:
        """
        WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.

        Converts this index to a CPU BackendIndex. Every operator must have
        exactly one kernel (asserted). The catch-all kernel is used; if the
        single kernel is not registered under the catch-all key, an empty
        placeholder BackendMetadata is stored instead.
        """
        index: dict[OperatorName, BackendMetadata] = {}
        for op, kernel_dict in self.index.items():
            assert (
                len(kernel_dict) == 1
            ), f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}"
            index[op] = kernel_dict.get(
                ETKernelKey(default=True),
                BackendMetadata(kernel="", structured=False, cpp_namespace=""),
            )
        return BackendIndex(
            dispatch_key=DispatchKey.CPU,
            use_out_as_primary=False,
            device_guard=False,
            external=False,
            index=index,
        )

    # Note duplicate ETKernelKey from index_b will clobber the metadata from index_a
    @staticmethod
    def merge_indices(index_a: ETKernelIndex, index_b: ETKernelIndex) -> ETKernelIndex:
        """Return a new index containing the kernels of both arguments.

        Neither input is modified: the result is built from fresh per-operator
        dicts rather than from a shallow copy of ``index_a.index``, whose inner
        dicts would otherwise be shared with (and written through to)
        ``index_a``.
        """
        combined: dict[
            OperatorName, dict[ETKernelKey, BackendMetadata]
        ] = defaultdict(dict)
        # index_b is applied second so that its entries win on duplicate keys.
        for source in (index_a, index_b):
            for op, entry in source.index.items():
                combined[op].update(entry)

        return ETKernelIndex(combined)
221