# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops

"""
** How to incorporate a new op into the XNNPACK Partitioner? **

[1] When the new edge op being added is a direct descendant of a core-aten op
and is also supported* by XNNPACK, prefer partitioning it via the
SUPPORTED_OPS mechanism, e.g. torch.add.

[2] When the new op being added is not a core-aten op:

[2.1] If the original torch op is supported* by XNNPACK, prefer partitioning it
via SUPPORTED_MODULES. This requires "recomposing" the op before lowering it
to XNNPACK, e.g. torch.nn.Linear. Make sure to include all variants of the
module in the SUPPORTED_MODULES list (e.g. both torch.nn.Linear and
torch.nn.functional.linear).

[2.2] If the original torch op is not supported by XNNPACK, then it is assumed
that, of all the core-aten ops it decomposes into, only those in SUPPORTED_OPS
will be lowered to XNNPACK.

* Supported fully or partially. Partial support does not mean that only a few
ops from the decomposition are supported; it means that only some variants of
the op's "modes" (those possible with certain argument combinations) are
supported.
"""

SUPPORTED_OPS = [
    exir_ops.edge.aten.div.Tensor,
    exir_ops.edge.aten.add.Tensor,
    exir_ops.edge.aten.clamp.default,
    exir_ops.edge.aten.sub.Tensor,
    exir_ops.edge.aten.floor.default,
    exir_ops.edge.aten.maximum.default,
    exir_ops.edge.aten.minimum.default,
    exir_ops.edge.aten.mul.Tensor,
    exir_ops.edge.aten.constant_pad_nd.default,
    exir_ops.edge.aten.upsample_bilinear2d.default,
    exir_ops.edge.aten.mean.dim,
    exir_ops.edge.aten.max.dim,
    exir_ops.edge.aten.max_pool2d_with_indices.default,
    exir_ops.edge.aten.hardtanh.default,
    exir_ops.edge.aten.sqrt.default,
    exir_ops.edge.aten.ceil.default,
    exir_ops.edge.aten.hardswish.default,
    exir_ops.edge.aten.neg.default,
    exir_ops.edge.aten.pow.Tensor_Scalar,
    exir_ops.edge.aten.abs.default,
    exir_ops.edge.aten._prelu_kernel.default,
    exir_ops.edge.aten.slice_copy.Tensor,
    exir_ops.edge.aten.relu.default,
    exir_ops.edge.aten.permute_copy.default,
    exir_ops.edge.aten.sigmoid.default,
    exir_ops.edge.aten._softmax.default,
    exir_ops.edge.aten.cat.default,
    exir_ops.edge.aten.elu.default,
    exir_ops.edge.aten.avg_pool2d.default,
    exir_ops.edge.aten.leaky_relu.default,
    exir_ops.edge.aten.addmm.default,  # TODO(T163877189) add constraint for addmm
]

SUPPORTED_MODULES = [
    torch.nn.Conv1d,
    # TODO(T161981984) recompose hardswish into a single node
    torch.nn.Hardswish,  # we need to recompose
    torch.nn.Hardsigmoid,  # we can handle decomposition
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm1d,
    torch.nn.Conv2d,
    torch.nn.Linear,
    torch.nn.functional.linear,
    torch.nn.PReLU,  # Without this, the PReLU weight is not a get_attr node
]

# TODO: delete this and use SUPPORTED_OPS instead once fp32 and quant
# support are aligned.
SUPPORTED_QUANT_OPS = [
    exir_ops.edge.aten.add.Tensor,
    exir_ops.edge.aten.clamp.default,
    exir_ops.edge.aten.relu.default,
    exir_ops.edge.aten.sub.Tensor,
    exir_ops.edge.aten.mul.Tensor,
    exir_ops.edge.aten.mean.dim,
    exir_ops.edge.aten.hardtanh.default,
    exir_ops.edge.aten.slice_copy.Tensor,
    exir_ops.edge.aten.permute_copy.default,
    exir_ops.edge.aten.cat.default,
    exir_ops.edge.aten.max_pool2d_with_indices.default,
    exir_ops.edge.aten.max_pool2d.default,
    exir_ops.edge.aten.constant_pad_nd.default,
    exir_ops.edge.aten.elu.default,
    exir_ops.edge.aten.t_copy.default,
    exir_ops.edge.aten.leaky_relu.default,
    exir_ops.edge.aten.addmm.default,  # TODO(T163877189) add constraint for addmm
]

# This set is used to determine whether an op is a supported quantized op,
# which in turn determines whether a quantization op is implicit or explicit.
SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET = {
    op.name()
    for op in (
        SUPPORTED_QUANT_OPS
        + [
            exir_ops.edge.aten._to_copy.default,
            exir_ops.edge.aten.linear.default,
            exir_ops.edge.aten.convolution.default,
        ]
    )
}

UNSUPPORTED_QUANT_MODULES = [
    torch.nn.Hardswish,
    torch.nn.Hardsigmoid,
]

# TODO: delete this and use SUPPORTED_MODULES instead once fp32 and quant
# support are aligned.
SUPPORTED_QUANT_MODULES = [
    torch.nn.Linear,
    torch.nn.functional.linear,
    # TODO - T158982884
    # torch.ao.nn.quantized.reference.modules.linear.Linear,
    torch.nn.Conv1d,
    torch.nn.functional.conv1d,
    torch.ao.nn.quantized.reference.modules.conv.Conv1d,
    torch.nn.Conv2d,
    torch.nn.functional.conv2d,
    torch.ao.nn.quantized.reference.modules.conv.Conv2d,
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
]

SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET = set(SUPPORTED_QUANT_MODULES)

# Modules that support dynamic quantization.
# These already support dynamic shapes.
SUPPORTED_DYN_QUANT_LINEAR_MODULES = [
    torch.nn.Linear,
    torch.nn.functional.linear,
]

SUPPORTED_DYN_QUANT_MODULES = SUPPORTED_DYN_QUANT_LINEAR_MODULES

# XNNPACK supports most forms of shape dynamism; however, some ops are
# explicitly static, so we maintain a list here to exclude them from
# dynamic shape support.
STATIC_OPS = [
    exir_ops.edge.aten.cat.default,
    exir_ops.edge.aten.slice_copy.Tensor,
]

STATIC_MODULES = []
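

# The decision procedure described in the note at the top of this file can be
# sketched with the hypothetical helpers below. They are illustrative only and
# are NOT part of the XNNPACK partitioner API. They assume the caller already
# knows whether an op is core-aten and whether XNNPACK supports the original
# torch op (fully or partially); op names are plain strings for simplicity.


def _example_partition_mechanism(
    is_core_aten: bool, xnnpack_supported: bool
) -> str:
    """Hypothetical: name the mechanism a new op should use, per rules [1]-[2.2]."""
    if is_core_aten:
        # [1] A supported core-aten op goes in SUPPORTED_OPS. The note does not
        # cover unsupported core-aten ops; "not partitioned" is an assumption.
        return "SUPPORTED_OPS" if xnnpack_supported else "not partitioned"
    if xnnpack_supported:
        # [2.1] Recompose the original torch op and list it in SUPPORTED_MODULES.
        return "SUPPORTED_MODULES"
    # [2.2] Fall back to the op's core-aten decomposition.
    return "decomposition"


def _example_lowered_ops(decomposed_ops, supported_op_names):
    """Hypothetical: for rule [2.2], keep only the decomposed ops that are supported."""
    # Order is preserved so the result mirrors the decomposition sequence.
    return [op for op in decomposed_ops if op in supported_op_names]


# Usage, with made-up op names:
#   _example_partition_mechanism(is_core_aten=False, xnnpack_supported=True)
#   returns "SUPPORTED_MODULES"
#   _example_lowered_ops(["aten.mul.Tensor", "aten.foo"], {"aten.mul.Tensor"})
#   returns ["aten.mul.Tensor"]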