xref: /aosp_15_r20/external/executorch/backends/vulkan/op_registry.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-unsafe
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport operator
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workerfrom typing import Callable, Dict, Optional, Set, Union
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Workerimport executorch.backends.vulkan.custom_ops_lib  # noqa
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Workerimport torch
16*523fa7a6SAndroid Build Coastguard Worker
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.vulkan.serialization.vulkan_graph_schema import (
18*523fa7a6SAndroid Build Coastguard Worker    VkMemoryLayout,
19*523fa7a6SAndroid Build Coastguard Worker    VkStorageType,
20*523fa7a6SAndroid Build Coastguard Worker)
21*523fa7a6SAndroid Build Coastguard Worker
22*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.vulkan.utils import (
23*523fa7a6SAndroid Build Coastguard Worker    all_memory_layouts,
24*523fa7a6SAndroid Build Coastguard Worker    all_packed_dims,
25*523fa7a6SAndroid Build Coastguard Worker    PackedDim,
26*523fa7a6SAndroid Build Coastguard Worker)
27*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects._ops import ops as exir_ops
28*523fa7a6SAndroid Build Coastguard Worker
29*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects.edge._ops import EdgeOpOverload
30*523fa7a6SAndroid Build Coastguard Workerfrom torch._subclasses.fake_tensor import FakeTensor
31*523fa7a6SAndroid Build Coastguard Worker
32*523fa7a6SAndroid Build Coastguard Worker######################
33*523fa7a6SAndroid Build Coastguard Worker## OpFeatures class ##
34*523fa7a6SAndroid Build Coastguard Worker######################
35*523fa7a6SAndroid Build Coastguard Worker
36*523fa7a6SAndroid Build Coastguard Worker
def allow_node(node: torch.fx.Node) -> bool:
    """
    Default node check: accept every node unconditionally.

    Used as the fallback value of OpFeatures.check_node_fn when an operator does not
    register a custom check function. The node argument is not inspected.
    """
    return True
39*523fa7a6SAndroid Build Coastguard Worker
40*523fa7a6SAndroid Build Coastguard Worker
class TextureImplFeatures:
    """
    Describes the capabilities of an operator's texture based implementation.
    """

    __slots__ = [
        # Set of PackedDim values that the texture implementation can operate on.
        "valid_packed_dims",
        # bool flag supplied at registration time. It is stored as-is and is not
        # interpreted by this class (presumably indicates whether the implementation
        # honors a tensor's axis map - confirm against the backend).
        "uses_axis_map",
    ]

    def __init__(
        self,
        uses_axis_map: bool = False,
        valid_packed_dims: Optional[Set[PackedDim]] = None,
    ):
        """
        Args:
            uses_axis_map: stored unchanged on the instance.
            valid_packed_dims: packed dims supported by the implementation. None is
                treated as the empty set.
        """
        self.uses_axis_map: bool = uses_axis_map
        # Copy instead of aliasing the argument: callers commonly pass a shared
        # module-level set (e.g. all_packed_dims), and an in-place mutation of one
        # instance's set must not leak into every other registration.
        self.valid_packed_dims: Set[PackedDim] = (
            set(valid_packed_dims) if valid_packed_dims is not None else set()
        )

    def valid_memory_layouts(self) -> Set[VkMemoryLayout]:
        """
        Derive the set of memory layouts supported by the texture implementation based
        on the valid packed dimensions. A new set is returned on every call.
        """
        packed_dim_to_layout = (
            (PackedDim.WIDTH, VkMemoryLayout.TENSOR_WIDTH_PACKED),
            (PackedDim.HEIGHT, VkMemoryLayout.TENSOR_HEIGHT_PACKED),
            (PackedDim.CHANNELS, VkMemoryLayout.TENSOR_CHANNELS_PACKED),
        )
        return {
            layout
            for packed_dim, layout in packed_dim_to_layout
            if packed_dim in self.valid_packed_dims
        }
74*523fa7a6SAndroid Build Coastguard Worker
75*523fa7a6SAndroid Build Coastguard Worker
class OpFeatures:
    """
    Describes which implementations (texture and/or buffer) an operator provides and
    which storage type / memory layout should be preferred for it.
    """

    __slots__ = [
        # None or TextureImplFeatures to specify implementation details of the texture
        # based operator implementation.
        "texture_impl",
        # bool indicating if the operator has a buffer based implementation.
        "buffer_impl",
        # bool indicating if the operator has a resize function, which allows it to
        # support dynamic shape tensors.
        "resize_fn",
        # Optional preferred storage type and memory layout for the operator. When
        # set, they are returned as-is by propose_storage_type() and
        # propose_memory_layout() respectively.
        "optimal_storage",
        "optimal_layout",
        # bool indicating if the operator handles its own prepacking. If this is True,
        # then the insert_prepack_nodes pass will not insert prepack nodes for the args
        # of the op.
        "handles_own_prepacking",
        # Set of ints; presumably the argument indices whose tensors are exempt from
        # the limits check performed outside this class - confirm against the
        # partitioner.
        "skip_limits_check",
        # Optional check function used during partitioning to determine if a node's
        # inputs are supported by the operator implementation.
        "check_node_fn",
    ]

    def __init__(
        self,
        texture_impl: Optional[TextureImplFeatures] = None,
        buffer_impl: bool = False,
        resize_fn: bool = False,
        optimal_storage: Optional[VkStorageType] = None,
        optimal_layout: Optional[VkMemoryLayout] = None,
        handles_own_prepacking: bool = False,
        skip_limits_check: Optional[Set[int]] = None,
        check_node_fn: Optional[Callable] = None,
    ):
        """
        Store the given feature flags. skip_limits_check defaults to an empty set and
        check_node_fn defaults to allow_node when they are not provided.
        """
        self.texture_impl: Optional[TextureImplFeatures] = texture_impl
        self.buffer_impl: bool = buffer_impl
        self.resize_fn: bool = resize_fn
        self.optimal_storage: Optional[VkStorageType] = optimal_storage
        self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout
        self.handles_own_prepacking: bool = handles_own_prepacking

        # Copy instead of aliasing the caller's set so that later in-place changes on
        # either side do not silently affect the other.
        self.skip_limits_check: Set[int] = (
            set(skip_limits_check) if skip_limits_check is not None else set()
        )

        self.check_node_fn: Callable = (
            check_node_fn if check_node_fn is not None else allow_node
        )

    def propose_storage_type(self) -> Optional[VkStorageType]:
        """
        Propose a storage type that should be used for this operator. A proposal can be
        made if one of the following is true:
        1. The operator specifies an optimal storage type
        2. Only one storage type is supported.

        If both storage types are supported and no optimal storage type is specified,
        then None is returned to indicate that there is no preference in storage type.
        None is also returned when neither implementation is available.
        """
        if self.optimal_storage is not None:
            return self.optimal_storage

        has_texture = self.texture_impl is not None
        if has_texture and not self.buffer_impl:
            return VkStorageType.TEXTURE_3D
        if self.buffer_impl and not has_texture:
            return VkStorageType.BUFFER

        return None

    def supported_storage_types(self) -> Set[VkStorageType]:
        """
        Return the set of storage types supported by this operator.
        """
        storage_types = set()
        if self.texture_impl is not None:
            storage_types.add(VkStorageType.TEXTURE_3D)
        if self.buffer_impl:
            storage_types.add(VkStorageType.BUFFER)

        return storage_types

    def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]:
        """
        Given a storage type as a precondition, propose a memory layout that should be
        used for this operator. A proposal can be made if one of the following is true:
        1. The operator specifies an optimal memory layout
        2. Only one memory layout is supported.

        If multiple memory layouts are supported and no optimal memory layout is
        specified then return None to indicate that the "best" memory layout for the
        operator is ambiguous. For non-texture storage, None is returned unless an
        optimal layout was specified.
        """
        if self.optimal_layout is not None:
            return self.optimal_layout

        if storage == VkStorageType.TEXTURE_3D:
            # Texture storage can only be requested for ops with a texture impl.
            assert self.texture_impl is not None
            possible_layouts = self.texture_impl.valid_memory_layouts()
            if len(possible_layouts) == 1:
                return next(iter(possible_layouts))

        return None

    def supported_memory_layouts(self, storage: VkStorageType) -> Set[VkMemoryLayout]:
        """
        Return the set of memory layouts supported by this operator for a given storage
        type. The returned set is always a new object that the caller may modify.
        """
        if storage == VkStorageType.TEXTURE_3D:
            # Texture storage can only be requested for ops with a texture impl.
            assert self.texture_impl is not None
            return self.texture_impl.valid_memory_layouts()

        # Hand out a copy so that callers cannot mutate the shared module-level
        # all_memory_layouts collection; this also matches the texture branch, which
        # builds a fresh set on every call.
        return set(all_memory_layouts)
191*523fa7a6SAndroid Build Coastguard Worker
192*523fa7a6SAndroid Build Coastguard Worker
193*523fa7a6SAndroid Build Coastguard Worker#######################
194*523fa7a6SAndroid Build Coastguard Worker## Operator Registry ##
195*523fa7a6SAndroid Build Coastguard Worker#######################
196*523fa7a6SAndroid Build Coastguard Worker
# Type of the keys accepted by the operator registry: a string, a torch OpOverload,
# or an EdgeOpOverload.
OpKey = Union[str, torch._ops.OpOverload, EdgeOpOverload]

# Registry mapping each registered operator key to its OpFeatures. Entries are added
# at import time through the update_features decorator.
vulkan_supported_ops: Dict[OpKey, OpFeatures] = {}
200*523fa7a6SAndroid Build Coastguard Worker
201*523fa7a6SAndroid Build Coastguard Worker
def update_features(aten_op):
    """
    Decorator factory that registers one or more operators in vulkan_supported_ops.

    The decorated function is called once per operator with a default-constructed
    OpFeatures instance, and the value it returns is stored in the registry under
    that operator's key. The decorated function itself is returned unchanged.

    Args:
        aten_op: a single operator key, or a list/tuple of operator keys that should
            all be registered with the decorated function.

    Raises:
        RuntimeError: at decoration time, if an operator key is already registered.
    """

    def features_decorator(fn: Callable):
        def update_features_impl(op: OpKey):
            if op in vulkan_supported_ops:
                raise RuntimeError(f"[Vulkan delegate] duplicate registration of {op}!")
            # Build the entry before touching the registry so that an exception raised
            # by fn does not leave a default-constructed placeholder registered.
            vulkan_supported_ops[op] = fn(OpFeatures())

        ops = aten_op if isinstance(aten_op, (list, tuple)) else [aten_op]
        for op in ops:
            update_features_impl(op)

        return fn

    return features_decorator
219*523fa7a6SAndroid Build Coastguard Worker
220*523fa7a6SAndroid Build Coastguard Worker
@update_features(
    [
        operator.getitem,
        # Quantization related ops will be fused via graph passes
        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
    ]
)
def register_ephemeral_op(features: OpFeatures):
    """
    Features for getitem and the quantize / dequantize ops: texture (all packed dims)
    and buffer implementations, with resize support.
    """
    texture_features = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
        uses_axis_map=True,
    )

    features.resize_fn = True
    features.buffer_impl = True
    features.texture_impl = texture_features
    return features
241*523fa7a6SAndroid Build Coastguard Worker
242*523fa7a6SAndroid Build Coastguard Worker
@update_features(
    [
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.minimum.default,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.div.Tensor,
        exir_ops.edge.aten.div.Tensor_mode,
        exir_ops.edge.aten.pow.Tensor_Tensor,
    ]
)
def register_binary_op(features: OpFeatures):
    """
    Features for elementwise binary ops: texture implementation only (all packed
    dims), with resize support. No buffer implementation is declared.
    """
    texture_features = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
        uses_axis_map=True,
    )

    features.resize_fn = True
    features.texture_impl = texture_features
    return features
261*523fa7a6SAndroid Build Coastguard Worker
262*523fa7a6SAndroid Build Coastguard Worker
@update_features(
    [
        exir_ops.edge.aten.abs.default,
        exir_ops.edge.aten.clamp.default,
        exir_ops.edge.aten.cos.default,
        exir_ops.edge.aten.exp.default,
        exir_ops.edge.aten.gelu.default,
        exir_ops.edge.aten.hardshrink.default,
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.aten.neg.default,
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten.sigmoid.default,
        exir_ops.edge.aten.sin.default,
        exir_ops.edge.aten.sqrt.default,
        exir_ops.edge.aten.rsqrt.default,
        exir_ops.edge.aten.tanh.default,
    ]
)
def register_unary_op(features: OpFeatures):
    """
    Features for elementwise unary ops: texture (all packed dims) and buffer
    implementations, with resize support.
    """
    texture_features = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
        uses_axis_map=True,
    )

    features.resize_fn = True
    features.buffer_impl = True
    features.texture_impl = texture_features
    return features
289*523fa7a6SAndroid Build Coastguard Worker
290*523fa7a6SAndroid Build Coastguard Worker
@update_features(exir_ops.edge.aten._to_copy.default)
def register_to_copy_op(features: OpFeatures):
    """
    Features for _to_copy: texture implementation only (all packed dims), with resize
    support. Only nodes converting between float16 / float32 tensors are accepted.
    """

    def check_to_copy_node(node: torch.fx.Node) -> bool:
        """
        Accept the node only if it has exactly one positional arg which is a graph
        node, and both the input and output values are FakeTensors with a float16 or
        float32 dtype.
        """
        supported_dtypes = (torch.float16, torch.float32)

        if len(node.args) != 1 or not isinstance(node.args[0], torch.fx.Node):
            return False

        src_val = node.args[0].meta.get("val", None)
        dst_val = node.meta.get("val", None)
        if not (isinstance(src_val, FakeTensor) and isinstance(dst_val, FakeTensor)):
            return False

        return dst_val.dtype in supported_dtypes and src_val.dtype in supported_dtypes

    features.check_node_fn = check_to_copy_node
    features.resize_fn = True
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
        uses_axis_map=True,
    )

    return features
321*523fa7a6SAndroid Build Coastguard Worker
322*523fa7a6SAndroid Build Coastguard Worker
@update_features(
    [
        exir_ops.edge.aten.bmm.default,
        exir_ops.edge.aten.mm.default,
        exir_ops.edge.aten.addmm.default,
        exir_ops.edge.aten.linear.default,
    ]
)
def register_mm_op(features: OpFeatures):
    """
    Features for matrix multiplication ops: texture (width or channels packed) and
    buffer implementations, with resize support. Width packed 3D textures are marked
    as optimal and the ops handle their own prepacking.
    """
    supported_packed_dims = {PackedDim.WIDTH, PackedDim.CHANNELS}

    features.handles_own_prepacking = True
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.resize_fn = True
    features.buffer_impl = True
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims=supported_packed_dims,
        uses_axis_map=True,
    )
    return features
345*523fa7a6SAndroid Build Coastguard Worker
346*523fa7a6SAndroid Build Coastguard Worker
@update_features(exir_ops.edge.aten._weight_int8pack_mm.default)
def register_int8_mm_op(features: OpFeatures):
    """
    Features for _weight_int8pack_mm: texture (width packed only, no axis map) and
    buffer implementations, with resize support. Width packed 3D textures are marked
    as optimal and the op handles its own prepacking.
    """
    texture_features = TextureImplFeatures(
        valid_packed_dims={PackedDim.WIDTH},
        uses_axis_map=False,
    )

    features.handles_own_prepacking = True
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.resize_fn = True
    features.buffer_impl = True
    features.texture_impl = texture_features
    return features
359*523fa7a6SAndroid Build Coastguard Worker
360*523fa7a6SAndroid Build Coastguard Worker
@update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
def register_int4_mm_op(features: OpFeatures):
    """
    Features for et_vk.linear_weight_int4: texture implementation only (width packed,
    no axis map), with resize support. Width packed 3D textures are marked as optimal
    and the op handles its own prepacking.
    """
    texture_features = TextureImplFeatures(
        valid_packed_dims={PackedDim.WIDTH},
        uses_axis_map=False,
    )

    features.handles_own_prepacking = True
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.resize_fn = True
    features.texture_impl = texture_features
    return features
372*523fa7a6SAndroid Build Coastguard Worker
373*523fa7a6SAndroid Build Coastguard Worker
@update_features(
    [
        exir_ops.edge.aten._log_softmax.default,
        exir_ops.edge.aten._softmax.default,
    ]
)
def register_softmax_op(features: OpFeatures):
    """
    Features for softmax / log_softmax: texture implementation only (all packed dims,
    default uses_axis_map), with resize support.
    """
    texture_features = TextureImplFeatures(valid_packed_dims=all_packed_dims)

    features.resize_fn = True
    features.texture_impl = texture_features
    return features
386*523fa7a6SAndroid Build Coastguard Worker
387*523fa7a6SAndroid Build Coastguard Worker
@update_features(
    [
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.sum.dim_IntList,
        exir_ops.edge.aten.amax.default,
        exir_ops.edge.aten.amin.default,
    ]
)
def register_reduce_op(features: OpFeatures):
    """
    Features for reduction ops: texture implementation only (all packed dims), with
    resize support. Only reductions over a single dim with keepdim=True are accepted.
    """
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
    )
    features.resize_fn = True

    def check_reduce_node(node: torch.fx.Node) -> bool:
        """
        Accept the node unless it reduces over a list of dims whose length is not 1,
        or keepdim is False.

        Arguments omitted from node.args are looked up in node.kwargs and otherwise
        take the operator defaults (dim=[] and keepdim=False), both of which lead to
        rejection instead of an IndexError.
        """
        args = node.args

        dim_list = args[1] if len(args) > 1 else node.kwargs.get("dim", [])
        if isinstance(dim_list, list) and len(dim_list) != 1:
            return False

        keepdim = args[2] if len(args) > 2 else node.kwargs.get("keepdim", False)
        if isinstance(keepdim, bool) and not keepdim:
            return False

        return True

    features.check_node_fn = check_reduce_node
    return features
415*523fa7a6SAndroid Build Coastguard Worker
416*523fa7a6SAndroid Build Coastguard Worker
@update_features(
    [
        exir_ops.edge.aten.avg_pool2d.default,
        exir_ops.edge.aten.max_pool2d_with_indices.default,
    ]
)
def register_2d_pool_op(features: OpFeatures):
    """
    Features for 2D pooling ops: texture implementation only (channels packed,
    default uses_axis_map), with resize support.
    """
    texture_features = TextureImplFeatures(valid_packed_dims={PackedDim.CHANNELS})

    features.resize_fn = True
    features.texture_impl = texture_features
    return features
429*523fa7a6SAndroid Build Coastguard Worker
430*523fa7a6SAndroid Build Coastguard Worker
@update_features(
    [
        exir_ops.edge.aten.convolution.default,
        exir_ops.edge.et_vk.conv_with_clamp.default,
    ]
)
def register_convolution_op(features: OpFeatures):
    """Populate the feature flags for convolution and convolution-with-clamp.

    Args:
        features: The OpFeatures entry to populate.

    Returns:
        The same ``features`` object, updated in place.
    """
    # Texture implementation: channels-packed tensors only.
    channels_only = {PackedDim.CHANNELS}
    features.texture_impl = TextureImplFeatures(valid_packed_dims=channels_only)
    features.resize_fn = True

    # Preferred tensor representation for these ops.
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED

    features.handles_own_prepacking = True
    # NOTE(review): presumably these are the positional argument indices that
    # are exempt from the limits check - confirm against the consumer of
    # skip_limits_check.
    features.skip_limits_check = {1, 2}
    return features
447*523fa7a6SAndroid Build Coastguard Worker
448*523fa7a6SAndroid Build Coastguard Worker
@update_features("llama::sdpa_with_kv_cache")
def register_sdpa_op(features: OpFeatures):
    """Populate the feature flags for the ``llama::sdpa_with_kv_cache`` op.

    Args:
        features: The OpFeatures entry to populate.

    Returns:
        The same ``features`` object, updated in place.
    """
    # Texture implementation: width-packed tensors only.
    width_only = {PackedDim.WIDTH}
    features.texture_impl = TextureImplFeatures(valid_packed_dims=width_only)
    features.resize_fn = True

    # Preferred tensor representation for this op.
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED

    features.handles_own_prepacking = True
    return features
459*523fa7a6SAndroid Build Coastguard Worker
460*523fa7a6SAndroid Build Coastguard Worker
@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
def register_rotary_emb_op(features: OpFeatures):
    """Populate the feature flags for the ``apply_rotary_emb`` op.

    The texture implementation is restricted to width-packed tensors and the
    op is marked as having a resize function.

    Args:
        features: The OpFeatures entry to populate.

    Returns:
        The same ``features`` object, updated in place.
    """
    width_only = {PackedDim.WIDTH}
    features.texture_impl = TextureImplFeatures(valid_packed_dims=width_only)
    features.resize_fn = True
    return features
468*523fa7a6SAndroid Build Coastguard Worker
469*523fa7a6SAndroid Build Coastguard Worker
@update_features(exir_ops.edge.aten.view_copy.default)
def register_view_op(features: OpFeatures):
    """Populate the feature flags for the ``view_copy`` op.

    The texture implementation accepts every packed dim and the op is marked
    as having a resize function.

    Args:
        features: The OpFeatures entry to populate.

    Returns:
        The same ``features`` object, updated in place.
    """
    texture_features = TextureImplFeatures(valid_packed_dims=all_packed_dims)
    features.texture_impl = texture_features
    features.resize_fn = True
    return features
477*523fa7a6SAndroid Build Coastguard Worker
478*523fa7a6SAndroid Build Coastguard Worker
# Ops ported from the PyTorch Vulkan backend. These ops commonly support only
# channels packed tensors and do not have a resize function.
@update_features(
    [
        # Shape Manipulation
        exir_ops.edge.aten.squeeze_copy.dims,
        exir_ops.edge.aten.unsqueeze_copy.default,
        exir_ops.edge.aten.permute_copy.default,
        exir_ops.edge.aten.t_copy.default,
        # Indexing and lookup
        exir_ops.edge.aten.flip.default,
        exir_ops.edge.aten.index_select.default,
        exir_ops.edge.aten.select_copy.int,
        exir_ops.edge.aten.slice_copy.Tensor,
        # Tensor combination
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.split_with_sizes_copy.default,
        exir_ops.edge.aten.split.Tensor,
        exir_ops.edge.aten.repeat.default,
        # Tensor creation
        exir_ops.edge.aten.arange.start_step,
        exir_ops.edge.aten.clone.default,
        exir_ops.edge.aten.constant_pad_nd.default,
        exir_ops.edge.aten.full.default,
        exir_ops.edge.aten.full_like.default,
        exir_ops.edge.aten.ones.default,
        exir_ops.edge.aten.ones_like.default,
        exir_ops.edge.aten.upsample_nearest2d.vec,
        exir_ops.edge.aten.zeros.default,
        exir_ops.edge.aten.zeros_like.default,
        exir_ops.edge.et_vk.grid_priors.default,
    ]
)
def register_ported_op(features: OpFeatures):
    """Populate the feature flags for the ported ops listed above.

    Only the texture implementation is set, restricted to channels-packed
    tensors; no other attribute of ``features`` is modified.

    Args:
        features: The OpFeatures entry to populate.

    Returns:
        The same ``features`` object, updated in place.
    """
    channels_only = {PackedDim.CHANNELS}
    features.texture_impl = TextureImplFeatures(valid_packed_dims=channels_only)
    return features
517*523fa7a6SAndroid Build Coastguard Worker
518*523fa7a6SAndroid Build Coastguard Worker
# Ported ops that handle their own prepacking.
@update_features(
    [
        exir_ops.edge.aten.embedding.default,
        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
        exir_ops.edge.aten.native_layer_norm.default,
    ]
)
def register_ported_ops_with_prepacking(features: OpFeatures):
    """Populate the feature flags for ported ops that do their own prepacking.

    The texture implementation is restricted to channels-packed tensors and
    ``handles_own_prepacking`` is switched on.

    Args:
        features: The OpFeatures entry to populate.

    Returns:
        The same ``features`` object, updated in place.
    """
    channels_only = {PackedDim.CHANNELS}
    features.texture_impl = TextureImplFeatures(valid_packed_dims=channels_only)
    features.handles_own_prepacking = True
    return features
533*523fa7a6SAndroid Build Coastguard Worker
534*523fa7a6SAndroid Build Coastguard Worker
535*523fa7a6SAndroid Build Coastguard Worker#######################
536*523fa7a6SAndroid Build Coastguard Worker## Utility functions ##
537*523fa7a6SAndroid Build Coastguard Worker#######################
538*523fa7a6SAndroid Build Coastguard Worker
539*523fa7a6SAndroid Build Coastguard Worker
def has_impl(target: OpKey) -> bool:
    """Report whether ``target`` has an entry in ``vulkan_supported_ops``.

    A target is looked up as-is first. A non-string target that is not found
    directly is looked up a second time under the value of ``target.name()``.

    Args:
        target: The op key to look up.

    Returns:
        True if an entry exists under either key, False otherwise.
    """
    if target in vulkan_supported_ops:
        return True
    if isinstance(target, str):
        # Strings have no alternate key to try.
        return False
    return target.name() in vulkan_supported_ops
547*523fa7a6SAndroid Build Coastguard Worker
548*523fa7a6SAndroid Build Coastguard Worker
def get_op_features(target: OpKey) -> OpFeatures:
    """Fetch the ``vulkan_supported_ops`` entry registered for ``target``.

    A non-string target that is not itself a key is looked up under the value
    of ``target.name()`` instead.

    Args:
        target: The op key to look up.

    Returns:
        The entry stored in ``vulkan_supported_ops`` for the resolved key.

    Raises:
        KeyError: If the resolved key is not in ``vulkan_supported_ops``.
    """
    lookup_key = target
    if not isinstance(target, str) and target not in vulkan_supported_ops:
        # Fall back to the op's name.
        lookup_key = target.name()
    return vulkan_supported_ops[lookup_key]
558*523fa7a6SAndroid Build Coastguard Worker
559*523fa7a6SAndroid Build Coastguard Worker
def handles_own_prepacking(target: OpKey) -> bool:
    """Return the ``handles_own_prepacking`` flag of the features for ``target``.

    The features are obtained through ``get_op_features``, so a target without
    a registered entry raises the same error that function does.
    """
    op_features = get_op_features(target)
    return op_features.handles_own_prepacking
562