# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional


def get_xnnpack_partitioner(dynamic_quant_only_partitioner: bool = True):
    """
    Returns the XNNPACK partitioner.

    @arg dynamic_quant_only_partitioner:
        This is enabled by default to preserve backward compatibility.
        If dynamic_quant_only_partitioner is True, only dynamically quantized
        linear layers are partitioned.
        Otherwise, anything that can be partitioned is partitioned greedily.
    """
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
        XnnpackDynamicallyQuantizedPartitioner,
        XnnpackPartitioner,
    )

    if dynamic_quant_only_partitioner:
        # Use the dynamically quantized partitioner here because:
        # 1. It is needed for both pt2e_quantize options as well as "qmode 8da4w",
        #    which also dynamically quantizes linear layers.
        # 2. The generic XNNPACK partitioner seems to result in a seg fault for
        #    non-dqlinear ops.
        return XnnpackDynamicallyQuantizedPartitioner()
    return XnnpackPartitioner()
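

# Illustrative usage sketch: the partitioner above is typically handed to the
# ExecuTorch lowering flow. `exported_program` is assumed to be a
# `torch.export.ExportedProgram` produced earlier in the export pipeline.
def _example_lower_with_xnnpack(exported_program):
    """Sketch: lower an exported program with the greedy XNNPACK partitioner."""
    from executorch.exir import to_edge_transform_and_lower

    edge = to_edge_transform_and_lower(
        exported_program,
        partitioner=[get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)],
    )
    # Produce the final ExecuTorch program (serializable to a .pte file).
    return edge.to_executorch()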


def get_vulkan_partitioner(
    dtype_override: Optional[str] = None, enable_dynamic_shape: bool = False
):
    assert (
        dtype_override == "fp32" or dtype_override is None
    ), "Vulkan backend does not support non-fp32 dtypes at the moment"
    from executorch.backends.vulkan.partitioner.vulkan_partitioner import (
        VulkanPartitioner,
    )

    return VulkanPartitioner({"require_dynamic_shapes": enable_dynamic_shape})
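

# Illustrative usage sketch: a Vulkan partitioner for a dynamically shaped fp32
# model, passed to the lowering flow as in `_example_lower_with_xnnpack` above.
def _example_vulkan_partitioner():
    """Sketch: construct a Vulkan partitioner with dynamic shapes enabled."""
    return get_vulkan_partitioner(dtype_override="fp32", enable_dynamic_shape=True)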


def get_mps_partitioner(use_kv_cache: bool = False):
    from executorch.exir.backend.backend_details import CompileSpec

    assert (
        use_kv_cache is True
    ), "The MPS backend currently supports only static shapes; use_kv_cache=True is the only way to enable them at the moment"
    try:
        # pyre-ignore Undefined import [21]: Could not find a module corresponding to import `executorch.backends.apple.mps.partition.mps_partitioner`.
        from executorch.backends.apple.mps.partition.mps_partitioner import (
            MPSPartitioner,
        )
    except ImportError:
        raise ImportError(
            "Please install the MPS backend following https://pytorch.org/executorch/main/build-run-mps.html"
        )

    # Run everything in fp16 on MPS.
    compile_specs = [CompileSpec("use_fp16", bytes([True]))]
    return MPSPartitioner(compile_specs)  # pyre-fixme[16]
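

# Illustrative usage sketch: the MPS backend requires static shapes, which the
# export flow currently guarantees only when a KV cache is used, so
# use_kv_cache=True is the only valid configuration.
def _example_mps_partitioner():
    """Sketch: construct an MPS partitioner (fp16, static shapes)."""
    return get_mps_partitioner(use_kv_cache=True)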


def get_coreml_partitioner(
    ios: int = 15,
    embedding_quantize: Optional[str] = None,
    pt2e_quantize: Optional[str] = None,
    coreml_quantize: Optional[str] = None,
):
    try:
        import coremltools as ct
        from executorch.backends.apple.coreml.compiler import (  # pyre-ignore
            CoreMLBackend,
        )
        from executorch.backends.apple.coreml.partition import (  # pyre-ignore
            CoreMLPartitioner,
        )
    except ImportError:
        raise ImportError(
            "Please install the CoreML backend following https://pytorch.org/executorch/main/build-run-coreml.html"
        )

    def _validate_ios_version() -> None:
        assert ios in (15, 16, 17, 18)

        if embedding_quantize is not None and ios < 18:
            raise ValueError(
                "In Core ML, per-block quantization was introduced in iOS 18"
            )

        use_quantization = pt2e_quantize is not None or coreml_quantize is not None
        if use_quantization and ios < 16:
            raise ValueError("In Core ML, quantization was introduced in iOS 16")

        use_8a = (pt2e_quantize is not None and "8a" in pt2e_quantize) or (
            coreml_quantize is not None and "8a" in coreml_quantize
        )
        if use_8a and ios < 17:
            raise ValueError(
                "In Core ML, 8-bit activation quantization was introduced in iOS 17"
            )

        use_4w = (pt2e_quantize is not None and "4w" in pt2e_quantize) or (
            coreml_quantize is not None and "4w" in coreml_quantize
        )
        if use_4w and ios < 18:
            raise ValueError(
                "In Core ML, 4-bit weight compression was introduced in iOS 18"
            )

    _validate_ios_version()

    minimum_deployment_target = {
        15: ct.target.iOS15,
        16: ct.target.iOS16,
        17: ct.target.iOS17,
        18: ct.target.iOS18,
    }[ios]
    op_linear_quantizer_config = None
    if coreml_quantize == "b4w":
        op_linear_quantizer_config = {
            "mode": "linear_symmetric",
            "dtype": "int4",
            "granularity": "per_block",
            "block_size": 32,
            "weight_threshold": 512,
        }
    compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
        minimum_deployment_target=minimum_deployment_target,
        compute_precision=ct.precision(ct.precision.FLOAT16.value),
        # Using `ComputeUnit.ALL` can increase model load time, so default to
        # `ComputeUnit.CPU_AND_GPU`.
        compute_unit=ct.ComputeUnit[ct.ComputeUnit.CPU_AND_GPU.name.upper()],
        model_type=CoreMLBackend.MODEL_TYPE.MODEL,  # pyre-fixme[16]
        op_linear_quantizer_config=op_linear_quantizer_config,
    )

    take_over_mutable_buffer = minimum_deployment_target >= ct.target.iOS18

    return CoreMLPartitioner(  # pyre-fixme[16]
        compile_specs=compile_specs,
        take_over_mutable_buffer=take_over_mutable_buffer,
    )
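

# Illustrative usage sketch: a Core ML partitioner targeting iOS 18 with 4-bit
# per-block weight quantization of linear layers ("b4w"); the result is passed
# to the lowering flow as in `_example_lower_with_xnnpack` above.
def _example_coreml_partitioner():
    """Sketch: construct a Core ML partitioner for an iOS 18 deployment."""
    return get_coreml_partitioner(ios=18, coreml_quantize="b4w")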


def get_qnn_partitioner(
    use_kv_cache: bool = False,
    pt2e_quantize: Optional[str] = None,
    num_sharding: int = 0,
    soc_model: str = "SM8650",  # default to SM8650
):
    assert (
        use_kv_cache is True
    ), "The Qualcomm backend currently supports only static shapes; use_kv_cache=True is the only way to enable them at the moment"
    try:
        # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.partition.qnn_partitioner`
        from executorch.backends.qualcomm.partition.qnn_partitioner import (
            QnnPartitioner,
        )

        # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.serialization.qc_schema`
        from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset

        # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
        from executorch.backends.qualcomm.utils.utils import (
            generate_htp_compiler_spec,
            generate_qnn_executorch_compiler_spec,
        )
    except ImportError:
        raise ImportError(
            "Please install the Qualcomm backend following https://pytorch.org/executorch/main/build-run-qualcomm-ai-engine-direct-backend.html"
        )

    # Run in fp16 unless the model has been quantized via pt2e.
    use_fp16 = True
    skip_node_op_set = {"llama.fallback.default"}
    if pt2e_quantize is not None:
        use_fp16 = False

    return QnnPartitioner(  # pyre-fixme[16]
        generate_qnn_executorch_compiler_spec(  # pyre-fixme[16]
            soc_model=getattr(QcomChipset, soc_model),  # pyre-fixme[16]
            # pyre-fixme[16]
            backend_options=generate_htp_compiler_spec(
                use_fp16=use_fp16,
                use_multi_contexts=num_sharding > 0,
            ),
            debug=False,
            saver=False,
        ),
        skip_node_id_set=set(),
        skip_node_op_set=skip_node_op_set,
    )
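

# Illustrative usage sketch: a QNN partitioner for a pt2e-quantized, two-shard
# model on an SM8650 SoC ("qnn_8a8w" is an assumed pt2e mode string; any
# non-None value disables fp16). The result is passed to the lowering flow as
# in `_example_lower_with_xnnpack` above.
def _example_qnn_partitioner():
    """Sketch: construct a Qualcomm HTP partitioner for a quantized model."""
    return get_qnn_partitioner(
        use_kv_cache=True,
        pt2e_quantize="qnn_8a8w",
        num_sharding=2,
        soc_model="SM8650",
    )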