xref: /aosp_15_r20/external/executorch/backends/xnnpack/partition/configs.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport torch
8*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects._ops import ops as exir_ops
9*523fa7a6SAndroid Build Coastguard Worker
10*523fa7a6SAndroid Build Coastguard Worker"""
11*523fa7a6SAndroid Build Coastguard Worker** How to incorporate a new op into the XNNPACK Partitioner? **
12*523fa7a6SAndroid Build Coastguard Worker
[1] When the new edge op being added is a direct descendant of a core-aten op,
and is also supported* by XNNPACK, prefer partitioning it via the SUPPORTED_OPS
mechanism, e.g. torch.add.
16*523fa7a6SAndroid Build Coastguard Worker
17*523fa7a6SAndroid Build Coastguard Worker[2] When the new op being added is not a core-aten op,
18*523fa7a6SAndroid Build Coastguard Worker
19*523fa7a6SAndroid Build Coastguard Worker[2.1] If the original torch op is supported* by XNNPACK, prefer partitioning it
20*523fa7a6SAndroid Build Coastguard Workervia SUPPORTED_MODULES. This will require "recomposing" the op before lowering
21*523fa7a6SAndroid Build Coastguard Workerit to XNNPACK e.g. torch.nn.Linear. Make sure to include all variants of the
22*523fa7a6SAndroid Build Coastguard Workermodules in the SUPPORTED_MODULES list.
23*523fa7a6SAndroid Build Coastguard Worker
24*523fa7a6SAndroid Build Coastguard Worker[2.2] If the original torch op is not supported by XNNPACK, then it is assumed
25*523fa7a6SAndroid Build Coastguard Workerthat out of all the decomposed core-aten ops, SUPPORTED_OPS will be lowered to
26*523fa7a6SAndroid Build Coastguard WorkerXNNPACK.
27*523fa7a6SAndroid Build Coastguard Worker
* - Supported fully or partially. Partial support does not mean that only a
few ops from the decomposition are supported; it means that only some variants
of the op ("modes" selected by particular arg combinations) are supported.
31*523fa7a6SAndroid Build Coastguard Worker"""
32*523fa7a6SAndroid Build Coastguard Worker
# Edge-dialect ops partitioned to XNNPACK through the SUPPORTED_OPS mechanism.
# Every op appears exactly once (a repeated hardtanh entry was removed); the
# relative order of the remaining entries is unchanged.
SUPPORTED_OPS = [
    exir_ops.edge.aten.div.Tensor,
    exir_ops.edge.aten.add.Tensor,
    exir_ops.edge.aten.clamp.default,
    exir_ops.edge.aten.sub.Tensor,
    exir_ops.edge.aten.floor.default,
    exir_ops.edge.aten.maximum.default,
    exir_ops.edge.aten.minimum.default,
    exir_ops.edge.aten.mul.Tensor,
    exir_ops.edge.aten.constant_pad_nd.default,
    exir_ops.edge.aten.upsample_bilinear2d.default,
    exir_ops.edge.aten.mean.dim,
    exir_ops.edge.aten.max.dim,
    exir_ops.edge.aten.max_pool2d_with_indices.default,
    exir_ops.edge.aten.hardtanh.default,
    exir_ops.edge.aten.sqrt.default,
    exir_ops.edge.aten.ceil.default,
    exir_ops.edge.aten.hardswish.default,
    exir_ops.edge.aten.neg.default,
    exir_ops.edge.aten.pow.Tensor_Scalar,
    exir_ops.edge.aten.abs.default,
    exir_ops.edge.aten._prelu_kernel.default,
    exir_ops.edge.aten.slice_copy.Tensor,
    exir_ops.edge.aten.relu.default,
    exir_ops.edge.aten.permute_copy.default,
    exir_ops.edge.aten.sigmoid.default,
    exir_ops.edge.aten._softmax.default,
    exir_ops.edge.aten.cat.default,
    exir_ops.edge.aten.elu.default,
    exir_ops.edge.aten.avg_pool2d.default,
    exir_ops.edge.aten.leaky_relu.default,
    exir_ops.edge.aten.addmm.default,  # TODO(T163877189) add constraint for addmm
]
67*523fa7a6SAndroid Build Coastguard Worker
# Source-level torch modules/functions partitioned to XNNPACK via the
# SUPPORTED_MODULES mechanism (ops that are not core-aten and therefore have
# to be recomposed before lowering). List every variant of a module here.
SUPPORTED_MODULES = [
    torch.nn.Conv1d,
    # TODO(T161981984): recompose hardswish into a single node.
    torch.nn.Hardswish,  # has to be recomposed
    torch.nn.Hardsigmoid,  # its decomposition can be handled
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm1d,
    torch.nn.Conv2d,
    torch.nn.Linear,
    torch.nn.functional.linear,
    # Without this entry the PReLU weight is no longer a get_attr.
    torch.nn.PReLU,
]
80*523fa7a6SAndroid Build Coastguard Worker
# TODO: delete this and use SUPPORTED_OPS instead once fp32 and quantized
# support are aligned.
# Edge-dialect ops supported on the quantized path. Every op appears exactly
# once (repeated hardtanh and mean.dim entries were removed); the order of
# first occurrences is unchanged.
SUPPORTED_QUANT_OPS = [
    exir_ops.edge.aten.add.Tensor,
    exir_ops.edge.aten.clamp.default,
    exir_ops.edge.aten.relu.default,
    exir_ops.edge.aten.sub.Tensor,
    exir_ops.edge.aten.mul.Tensor,
    exir_ops.edge.aten.mean.dim,
    exir_ops.edge.aten.hardtanh.default,
    exir_ops.edge.aten.slice_copy.Tensor,
    exir_ops.edge.aten.permute_copy.default,
    exir_ops.edge.aten.cat.default,
    exir_ops.edge.aten.max_pool2d_with_indices.default,
    exir_ops.edge.aten.max_pool2d.default,
    exir_ops.edge.aten.constant_pad_nd.default,
    exir_ops.edge.aten.elu.default,
    exir_ops.edge.aten.t_copy.default,
    exir_ops.edge.aten.leaky_relu.default,
    exir_ops.edge.aten.addmm.default,  # TODO(T163877189) add constraint for addmm
]
103*523fa7a6SAndroid Build Coastguard Worker
# Names (as returned by each op's .name()) of the ops regarded as supported
# quantized ops: everything in SUPPORTED_QUANT_OPS plus _to_copy, linear and
# convolution. The set is used to decide whether a quantization op is
# implicit or explicit.
SUPPORTED_IMPLICIT_Q_DQ_OP_NAMES_SET = {
    quant_op.name()
    for quant_op in [
        *SUPPORTED_QUANT_OPS,
        exir_ops.edge.aten._to_copy.default,
        exir_ops.edge.aten.linear.default,
        exir_ops.edge.aten.convolution.default,
    ]
}
117*523fa7a6SAndroid Build Coastguard Worker
# Module types listed as unsupported for quantization.
# NOTE(review): how consumers use this list is not visible in this file.
UNSUPPORTED_QUANT_MODULES = [torch.nn.Hardswish, torch.nn.Hardsigmoid]
122*523fa7a6SAndroid Build Coastguard Worker
# TODO: delete this and use SUPPORTED_MODULES instead once fp32 and quantized
# support are aligned.
# Source-level modules/functions supported on the quantized path, grouped as
# linear, 1-D convolution, 2-D convolution and batch norm.
SUPPORTED_QUANT_MODULES = [
    # Linear
    torch.nn.Linear,
    torch.nn.functional.linear,
    # TODO - T158982884
    # torch.ao.nn.quantized.reference.modules.linear.Linear,
    # 1-D convolution
    torch.nn.Conv1d,
    torch.nn.functional.conv1d,
    torch.ao.nn.quantized.reference.modules.conv.Conv1d,
    # 2-D convolution
    torch.nn.Conv2d,
    torch.nn.functional.conv2d,
    torch.ao.nn.quantized.reference.modules.conv.Conv2d,
    # Batch norm
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
]

# Set view of the same entries.
SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET = {*SUPPORTED_QUANT_MODULES}
140*523fa7a6SAndroid Build Coastguard Worker
# Linear modules/functions that support dynamic quantization.
# These already support dynamic shape.
SUPPORTED_DYN_QUANT_LINEAR_MODULES = [torch.nn.Linear, torch.nn.functional.linear]

# All modules that support dynamic quantization. This name is bound to the very
# same list object as SUPPORTED_DYN_QUANT_LINEAR_MODULES (an alias, not a copy).
SUPPORTED_DYN_QUANT_MODULES = SUPPORTED_DYN_QUANT_LINEAR_MODULES
149*523fa7a6SAndroid Build Coastguard Worker
# XNNPACK handles most shape dynamism, but a few ops are explicitly static.
# They are collected in this list so that they can be excluded from dynamic
# shape support.
STATIC_OPS = [exir_ops.edge.aten.cat.default, exir_ops.edge.aten.slice_copy.Tensor]

# Module-level counterpart of STATIC_OPS; currently empty.
STATIC_MODULES = list()
159