# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import operator

from typing import Callable, Dict, Optional, Set, Union

import executorch.backends.vulkan.custom_ops_lib  # noqa

import torch

from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)

from executorch.backends.vulkan.utils import (
    all_memory_layouts,
    all_packed_dims,
    PackedDim,
)
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch._subclasses.fake_tensor import FakeTensor

######################
## OpFeatures class ##
######################


def allow_node(node: torch.fx.Node) -> bool:
    """Default node check used by OpFeatures: accept every node."""
    return True


class TextureImplFeatures:
    """
    Describes the texture based implementation of an operator.

    Attributes:
        uses_axis_map: bool flag stored exactly as passed in.
            NOTE(review): presumably indicates that the implementation honours
            the tensor axis map - confirm against the shader implementations.
        valid_packed_dims: set of PackedDim values accepted by the
            implementation. Empty when not provided.
    """

    __slots__ = [
        "valid_packed_dims",
        "uses_axis_map",
    ]

    def __init__(
        self,
        uses_axis_map: bool = False,
        valid_packed_dims: Optional[Set[PackedDim]] = None,
    ):
        self.uses_axis_map: bool = uses_axis_map
        self.valid_packed_dims = set()
        if valid_packed_dims is not None:
            # The caller's set is stored by reference, not copied.
            self.valid_packed_dims = valid_packed_dims

    def valid_memory_layouts(self) -> Set[VkMemoryLayout]:
        """
        Derive the set of memory layouts supported by the texture implementation based
        on the valid packed dimensions.
        """
        layouts = set()

        if PackedDim.WIDTH in self.valid_packed_dims:
            layouts.add(VkMemoryLayout.TENSOR_WIDTH_PACKED)

        if PackedDim.HEIGHT in self.valid_packed_dims:
            layouts.add(VkMemoryLayout.TENSOR_HEIGHT_PACKED)

        if PackedDim.CHANNELS in self.valid_packed_dims:
            layouts.add(VkMemoryLayout.TENSOR_CHANNELS_PACKED)

        return layouts


class OpFeatures:
    """
    Collection of properties describing how an operator is implemented, which is
    stored in the operator registry for each supported operator.
    """

    __slots__ = [
        # None or TextureImplFeatures to specify implementation details of the texture
        # based operator implementation.
        "texture_impl",
        # bool indicating if the operator has a buffer based implementation.
        "buffer_impl",
        # bool indicating if the operator has a resize function, which allows it to
        # support dynamic shape tensors.
        "resize_fn",
        # Optional preferred storage type and memory layout. When set, these are
        # returned as-is by propose_storage_type() / propose_memory_layout().
        "optimal_storage",
        "optimal_layout",
        # bool indicating if the operator handles its own prepacking. If this is True,
        # then the insert_prepack_nodes pass will not insert prepack nodes for the args
        # of the op.
        "handles_own_prepacking",
        # Set of ints (empty by default).
        # NOTE(review): presumably the argument indices for which the texture
        # limits check should be skipped - confirm against the partitioner.
        "skip_limits_check",
        # Optional check function used during partitioning to determine if a node's
        # inputs are supported by the operator implementation.
        "check_node_fn",
    ]

    def __init__(
        self,
        texture_impl: Optional[TextureImplFeatures] = None,
        buffer_impl: bool = False,
        resize_fn: bool = False,
        optimal_storage: Optional[VkStorageType] = None,
        optimal_layout: Optional[VkMemoryLayout] = None,
        handles_own_prepacking: bool = False,
        skip_limits_check: Optional[Set[int]] = None,
        check_node_fn: Optional[Callable] = None,
    ):
        self.texture_impl: Optional[TextureImplFeatures] = texture_impl
        self.buffer_impl: bool = buffer_impl
        self.resize_fn: bool = resize_fn
        self.optimal_storage: Optional[VkStorageType] = optimal_storage
        self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout
        self.handles_own_prepacking: bool = handles_own_prepacking

        self.skip_limits_check: Set[int] = set()
        if skip_limits_check is not None:
            self.skip_limits_check = skip_limits_check

        # Fall back to the accept-everything check when none is supplied.
        self.check_node_fn: Callable = allow_node
        if check_node_fn is not None:
            self.check_node_fn = check_node_fn

    def propose_storage_type(self) -> Optional[VkStorageType]:
        """
        Propose a storage type that should be used for this operator. A proposal can be
        made if one of the following is true:
        1. The operator specifies an optimal storage type
        2. Only one storage type is supported.

        If both storage types are supported and no optimal storage type is specified,
        then None is returned to indicate that there is no preference in storage type.
        None is also returned when neither implementation is available.
        """
        if self.optimal_storage is not None:
            return self.optimal_storage

        if self.texture_impl is not None and not self.buffer_impl:
            return VkStorageType.TEXTURE_3D
        elif self.buffer_impl and self.texture_impl is None:
            return VkStorageType.BUFFER

        return None

    def supported_storage_types(self) -> Set[VkStorageType]:
        """
        Return the set of storage types supported by this operator.
        """
        storage_types = set()
        if self.texture_impl is not None:
            storage_types.add(VkStorageType.TEXTURE_3D)
        if self.buffer_impl:
            storage_types.add(VkStorageType.BUFFER)

        return storage_types

    def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]:
        """
        Given a storage type as a precondition, propose a memory layout that should be
        used for this operator. A proposal can be made if one of the following is true:
        1. The operator specifies an optimal memory layout
        2. Only one memory layout is supported.

        If multiple memory layouts are supported and no optimal memory layout is
        specified then return None to indicate that the "best" memory layout for the
        operator is ambiguous. For any storage type other than TEXTURE_3D, None is
        returned unless an optimal layout is specified.
        """
        if self.optimal_layout is not None:
            return self.optimal_layout

        if storage == VkStorageType.TEXTURE_3D:
            assert self.texture_impl is not None
            possible_layouts = self.texture_impl.valid_memory_layouts()
            if len(possible_layouts) == 1:
                return next(iter(possible_layouts))

        return None

    def supported_memory_layouts(self, storage: VkStorageType) -> Set[VkMemoryLayout]:
        """
        Return the set of memory layouts supported by this operator for a given storage
        type. Any storage type other than TEXTURE_3D yields all_memory_layouts.
        """
        if storage == VkStorageType.TEXTURE_3D:
            assert self.texture_impl is not None
            return self.texture_impl.valid_memory_layouts()
        else:
            return all_memory_layouts


#######################
## Operator Registry ##
#######################

OpKey = Union[str, torch._ops.OpOverload, EdgeOpOverload]

vulkan_supported_ops: Dict[OpKey, OpFeatures] = {}


def update_features(aten_op):
    """
    Decorator factory that registers operators in vulkan_supported_ops.

    Args:
        aten_op: a single op key, or a list of op keys, to register.

    The decorated function receives a freshly constructed OpFeatures for each op
    key and must return the OpFeatures to store for that key. The decorated
    function itself is returned unchanged.

    Raises:
        RuntimeError: at decoration time, if an op key is already registered.
    """

    def features_decorator(fn: Callable):
        def update_features_impl(op: OpKey):
            if op in vulkan_supported_ops:
                raise RuntimeError(f"[Vulkan delegate] duplicate registration of {op}!")
            # Build the features before inserting so that a failing `fn` does not
            # leave a default-constructed entry behind in the registry.
            vulkan_supported_ops[op] = fn(OpFeatures())

        if isinstance(aten_op, list):
            for op in aten_op:
                update_features_impl(op)
        else:
            update_features_impl(aten_op)

        return fn

    return features_decorator


@update_features(
    [
        operator.getitem,
        # Quantization related ops will be fused via graph passes
        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
    ]
)
def register_ephemeral_op(features: OpFeatures):
    """
    Register getitem and the quantize / dequantize ops: texture implementation with
    all packed dims and axis map, buffer implementation, and resize support.
    """
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=True,
        valid_packed_dims=all_packed_dims,
    )
    features.buffer_impl = True
    features.resize_fn = True
    return features


@update_features(
    [
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.minimum.default,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.div.Tensor,
        exir_ops.edge.aten.div.Tensor_mode,
        exir_ops.edge.aten.pow.Tensor_Tensor,
    ]
)
def register_binary_op(features: OpFeatures):
    """
    Register binary elementwise ops: texture implementation with all packed dims
    and axis map, plus resize support. No buffer implementation is declared.
    """
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=True,
        valid_packed_dims=all_packed_dims,
    )
    features.resize_fn = True
    return features


@update_features(
    [
        exir_ops.edge.aten.abs.default,
        exir_ops.edge.aten.clamp.default,
        exir_ops.edge.aten.cos.default,
        exir_ops.edge.aten.exp.default,
        exir_ops.edge.aten.gelu.default,
        exir_ops.edge.aten.hardshrink.default,
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.aten.neg.default,
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten.sigmoid.default,
        exir_ops.edge.aten.sin.default,
        exir_ops.edge.aten.sqrt.default,
        exir_ops.edge.aten.rsqrt.default,
        exir_ops.edge.aten.tanh.default,
    ]
)
def register_unary_op(features: OpFeatures):
    """
    Register unary elementwise ops: texture implementation with all packed dims
    and axis map, buffer implementation, and resize support.
    """
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=True,
        valid_packed_dims=all_packed_dims,
    )
    features.buffer_impl = True
    features.resize_fn = True
    return features


@update_features(exir_ops.edge.aten._to_copy.default)
def register_to_copy_op(features: OpFeatures):
    """
    Register _to_copy: texture implementation with all packed dims and axis map,
    resize support, and a node check restricting it to float16/float32 casts.
    """
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=True,
        valid_packed_dims=all_packed_dims,
    )
    features.resize_fn = True

    def check_to_copy_node(node: torch.fx.Node) -> bool:
        """
        Accept the node only if it has exactly one positional arg that is a node,
        and both input and output "val" metadata are FakeTensors with a float16 or
        float32 dtype.
        """
        float_dtypes = [torch.float16, torch.float32]

        if len(node.args) != 1:
            return False

        in_arg = node.args[0]
        if not isinstance(in_arg, torch.fx.Node):
            return False

        in_tensor = in_arg.meta.get("val", None)
        out_tensor = node.meta.get("val", None)

        if isinstance(in_tensor, FakeTensor) and isinstance(out_tensor, FakeTensor):
            if out_tensor.dtype in float_dtypes and in_tensor.dtype in float_dtypes:
                return True

        return False

    features.check_node_fn = check_to_copy_node

    return features


@update_features(
    [
        exir_ops.edge.aten.bmm.default,
        exir_ops.edge.aten.mm.default,
        exir_ops.edge.aten.addmm.default,
        exir_ops.edge.aten.linear.default,
    ]
)
def register_mm_op(features: OpFeatures):
    """
    Register matrix multiplication ops: width / channels packed texture
    implementation with axis map, buffer implementation, resize support, own
    prepacking, and a preference for width packed 3D textures.
    """
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=True,
        valid_packed_dims={
            PackedDim.WIDTH,
            PackedDim.CHANNELS,
        },
    )
    features.buffer_impl = True
    features.resize_fn = True
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.handles_own_prepacking = True
    return features


@update_features(exir_ops.edge.aten._weight_int8pack_mm.default)
def register_int8_mm_op(features: OpFeatures):
    """
    Register int8 weight matmul: width packed texture implementation (no axis map),
    buffer implementation, resize support, own prepacking, and a preference for
    width packed 3D textures.
    """
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=False,
        valid_packed_dims={PackedDim.WIDTH},
    )
    features.buffer_impl = True
    features.resize_fn = True
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.handles_own_prepacking = True
    return features


@update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
def register_int4_mm_op(features: OpFeatures):
    """
    Register int4 weight linear: width packed texture implementation (no axis map),
    resize support, own prepacking, and a preference for width packed 3D textures.
    No buffer implementation is declared.
    """
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=False,
        valid_packed_dims={PackedDim.WIDTH},
    )
    features.resize_fn = True
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.handles_own_prepacking = True
    return features


@update_features(
    [
        exir_ops.edge.aten._log_softmax.default,
        exir_ops.edge.aten._softmax.default,
    ]
)
def register_softmax_op(features: OpFeatures):
    """
    Register softmax ops: texture implementation with all packed dims, plus resize
    support.
    """
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
    )
    features.resize_fn = True
    return features


@update_features(
    [
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.sum.dim_IntList,
        exir_ops.edge.aten.amax.default,
        exir_ops.edge.aten.amin.default,
    ]
)
def register_reduce_op(features: OpFeatures):
    """
    Register reduction ops: texture implementation with all packed dims, resize
    support, and a node check restricting them to single-dim, keepdim=True
    reductions.
    """
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
    )
    features.resize_fn = True

    def check_reduce_node(node: torch.fx.Node) -> bool:
        """
        Reject the node if its dim list does not contain exactly one dim, or if
        keepdim is False. Omitted arguments are looked up in kwargs and then fall
        back to the ATen schema defaults instead of raising IndexError.
        """
        # An omitted dim list (amax / amin default to an empty list) is treated as
        # an empty list, which is rejected by the length check below.
        dim_list = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", [])
        if isinstance(dim_list, list) and len(dim_list) != 1:
            return False

        # keepdim defaults to False in the ATen schemas of these ops, so an omitted
        # keepdim is rejected just like an explicit False.
        keepdim = (
            node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
        )
        if isinstance(keepdim, bool) and not keepdim:
            return False

        return True

    features.check_node_fn = check_reduce_node
    return features


@update_features(
    [
        exir_ops.edge.aten.avg_pool2d.default,
        exir_ops.edge.aten.max_pool2d_with_indices.default,
    ]
)
def register_2d_pool_op(features: OpFeatures):
    """
    Register 2D pooling ops: channels packed texture implementation, plus resize
    support.
    """
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.CHANNELS},
    )
    features.resize_fn = True
    return features


@update_features(
    [
        exir_ops.edge.aten.convolution.default,
        exir_ops.edge.et_vk.conv_with_clamp.default,
    ]
)
def register_convolution_op(features: OpFeatures):
    """
    Register convolution ops: channels packed texture implementation, resize
    support, own prepacking, a preference for channels packed 3D textures, and
    skip_limits_check set to {1, 2}.
    """
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.CHANNELS},
    )
    features.resize_fn = True
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
    features.handles_own_prepacking = True
    # NOTE(review): presumably args 1 and 2 are the weight and bias tensors, which
    # the op prepacks itself - confirm against the convolution schema.
    features.skip_limits_check = {1, 2}
    return features


@update_features("llama::sdpa_with_kv_cache")
def register_sdpa_op(features: OpFeatures):
    """
    Register SDPA with KV cache (keyed by string name): width packed texture
    implementation, resize support, own prepacking, and a preference for width
    packed 3D textures.
    """
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.WIDTH},
    )
    features.resize_fn = True
    features.optimal_storage = VkStorageType.TEXTURE_3D
    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
    features.handles_own_prepacking = True
    return features


@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
def register_rotary_emb_op(features: OpFeatures):
    """
    Register apply_rotary_emb: width packed texture implementation, plus resize
    support.
    """
    features.texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.WIDTH},
    )
    features.resize_fn = True
    return features


@update_features(exir_ops.edge.aten.view_copy.default)
def register_view_op(features: OpFeatures):
    """
    Fill in the features entry for `view_copy`.

    The texture implementation accepts every packed dim in `all_packed_dims`
    (the shared set is passed through as-is, not copied) and `resize_fn` is set
    to True. The `features` object that was passed in is modified in place and
    returned.
    """
    view_texture_impl = TextureImplFeatures(
        valid_packed_dims=all_packed_dims,
    )
    features.texture_impl = view_texture_impl
    features.resize_fn = True
    return features


# Ops ported from PyTorch Vulkan backend. These ops commonly support channels
# packed tensors only and do not have a resize function.
@update_features(
    [
        # Shape Manipulation
        exir_ops.edge.aten.squeeze_copy.dims,
        exir_ops.edge.aten.unsqueeze_copy.default,
        exir_ops.edge.aten.permute_copy.default,
        exir_ops.edge.aten.t_copy.default,
        # Indexing and lookup
        exir_ops.edge.aten.flip.default,
        exir_ops.edge.aten.index_select.default,
        exir_ops.edge.aten.select_copy.int,
        exir_ops.edge.aten.slice_copy.Tensor,
        # Tensor combination
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.split_with_sizes_copy.default,
        exir_ops.edge.aten.split.Tensor,
        exir_ops.edge.aten.repeat.default,
        # Tensor creation
        exir_ops.edge.aten.arange.start_step,
        exir_ops.edge.aten.clone.default,
        exir_ops.edge.aten.constant_pad_nd.default,
        exir_ops.edge.aten.full.default,
        exir_ops.edge.aten.full_like.default,
        exir_ops.edge.aten.ones.default,
        exir_ops.edge.aten.ones_like.default,
        exir_ops.edge.aten.upsample_nearest2d.vec,
        exir_ops.edge.aten.zeros.default,
        exir_ops.edge.aten.zeros_like.default,
        exir_ops.edge.et_vk.grid_priors.default,
    ]
)
def register_ported_op(features: OpFeatures):
    """
    Fill in the features entry shared by the ported ops listed in the
    decorator.

    Only the texture implementation is set (channels packed dim only); no other
    attribute of `features` is touched, so `resize_fn` keeps whatever value it
    already had. The `features` object is modified in place and returned.
    """
    ported_texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.CHANNELS},
    )
    features.texture_impl = ported_texture_impl
    return features


# Ported ops that support their own prepacking.
@update_features(
    [
        exir_ops.edge.aten.embedding.default,
        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
        exir_ops.edge.aten.native_layer_norm.default,
    ]
)
def register_ported_ops_with_prepacking(features: OpFeatures):
    """
    Fill in the features entry for the ops listed in the decorator.

    The texture implementation is limited to the channels packed dim and the op
    is marked as handling its own prepacking. The `features` object that was
    passed in is modified in place and returned.
    """
    prepacking_texture_impl = TextureImplFeatures(
        valid_packed_dims={PackedDim.CHANNELS},
    )
    features.texture_impl = prepacking_texture_impl
    features.handles_own_prepacking = True
    return features


#######################
## Utility functions ##
#######################


def has_impl(target: OpKey) -> bool:
    """
    Report whether `target` has an entry in `vulkan_supported_ops`.

    A target matches when it is itself a key of the registry. A non-string
    target that is not a key is looked up a second time under the string
    returned by its `name()` method.

    Args:
        target: a string key, or an operator object.

    Returns:
        True if the target (or, for non-string targets, its `name()`) is
        registered; False otherwise. Non-string targets that expose no callable
        `name` attribute are reported as unsupported instead of raising
        AttributeError.
    """
    if target in vulkan_supported_ops:
        return True
    if isinstance(target, str):
        # A string is only ever matched as-is; there is no fallback key.
        return False

    # Not every possible target is an operator overload: plain Python callables
    # have no `name()` method, and a membership predicate should answer False
    # for them rather than fail.
    name_fn = getattr(target, "name", None)
    if not callable(name_fn):
        return False
    return name_fn() in vulkan_supported_ops
def get_op_features(target: OpKey) -> OpFeatures:
    """
    Look up the entry stored for `target` in `vulkan_supported_ops`.

    String targets, and non-string targets that are themselves keys of the
    registry, are looked up directly. Any other target is looked up under the
    string returned by its `name()` method.

    Raises:
        KeyError: if neither lookup finds an entry.
        AttributeError: if an unregistered non-string target has no `name()`.
    """
    direct_lookup = isinstance(target, str) or target in vulkan_supported_ops
    if direct_lookup:
        return vulkan_supported_ops[target]

    # Try the op's name
    op_name = target.name()
    return vulkan_supported_ops[op_name]


def handles_own_prepacking(target: OpKey) -> bool:
    """
    Return the `handles_own_prepacking` attribute of the entry registered for
    `target`. Lookup failures from `get_op_features` propagate unchanged.
    """
    op_features = get_op_features(target)
    return op_features.handles_own_prepacking