# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Provides builders for LLM models. These builders help users build LLM
# eager models, apply source transformations and quantization, and export them
# to ExecuTorch.

# pyre-unsafe

import logging
from enum import Enum
from typing import Any, Callable, Dict, List, Optional

import torch
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.exir import EdgeProgramManager
from executorch.exir.backend.partitioner import Partitioner

from executorch.exir.backend.utils import format_delegated_graph
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig

from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

from executorch.extension.export_util.utils import export_to_edge, save_pte_program
from executorch.extension.llm.tokenizer.utils import get_tokenizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.export import export_for_training
from torch.nn.attention import SDPBackend

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


class DType(Enum):
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"

    def to_torch_dtype(self) -> torch.dtype:
        mapping = {
            DType.fp32: torch.float32,
            DType.fp16: torch.float16,
            DType.bf16: torch.bfloat16,
        }
        if self not in mapping:
            raise ValueError(f"Unsupported dtype {self}")
        return mapping[self]


class LLMEdgeManager:
    """
    Hosts a torch.nn.Module for an LLM model and facilitates exporting it to ExecuTorch.
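
    Example (an illustrative sketch, not a prescribed recipe; the model,
    example inputs, dtype, quantizers and partitioners are placeholders and
    depend on the target model and backend):

        manager = (
            LLMEdgeManager(
                model=model,
                modelname="llama",
                max_seq_len=128,
                dtype=DType.fp32,
                use_kv_cache=True,
                example_inputs=example_inputs,
            )
            .set_output_dir("./out")
            .to_dtype(DType.fp16)
            .export()
            .pt2e_quantize(quantizers=None)
            .export_to_edge()
            .to_backend(partitioners=None)
            .to_executorch()
        )
        manager.save_to_pte("llama")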
    """

    def __init__(
        self,
        model,
        modelname,
        max_seq_len,
        dtype,
        use_kv_cache,
        example_inputs,
        example_kwarg_inputs: Optional[Dict] = None,
        args: Optional[Any] = None,
        enable_dynamic_shape: bool = False,
        generate_full_logits: bool = False,
        calibration_tasks: Optional[List[str]] = None,
        calibration_limit: Optional[int] = None,
        calibration_seq_length: Optional[int] = None,
        calibration_data: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
        verbose: bool = False,
        metadata: Optional[dict] = None,
        dynamic_shapes: Optional[Any] = None,
    ):
        self.model = model
        # graph module returned from export()
        self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None
        self.modelname = modelname
        self.max_seq_len = max_seq_len
        self.dtype = dtype
        self.example_inputs = example_inputs
        self.example_kwarg_inputs = example_kwarg_inputs
        self.use_kv_cache = use_kv_cache
        self.generate_full_logits = generate_full_logits
        self.enable_dynamic_shape = enable_dynamic_shape
        self.verbose = verbose
        self.metadata = metadata
        self.applied_source_transforms = []
        self.edge_manager: Optional[EdgeProgramManager] = None
        self.export_program = None
        self.output_dir = "."
        self.dynamic_shapes = dynamic_shapes
        self._saved_pte_filename = None
        self.args = args
        self.calibration_tasks = calibration_tasks
        self.calibration_limit = calibration_limit
        self.calibration_seq_length = calibration_seq_length
        self.calibration_data = calibration_data
        self.tokenizer_path = tokenizer_path

    def set_output_dir(self, output_dir: str) -> "LLMEdgeManager":
        """
        Set the directory where the .pte file will be saved.
        Args:
            output_dir (str): The directory to store the .pte file.
        """
        self.output_dir = output_dir
        return self

    def to_dtype(self, dtype_override: Optional[DType]) -> "LLMEdgeManager":
        """
        Convert the model to the specified dtype.
        Args:
            dtype_override (Optional[DType]): Override the dtype of the model.
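
        Example (illustrative; any DType member works the same way):

            manager.to_dtype(DType.bf16)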
        """
        assert not dtype_override or isinstance(
            dtype_override, DType
        ), "Override dtype needs to be of type <DType>"
        if dtype_override is not None and dtype_override != self.dtype:
            torch_dtype = dtype_override.to_torch_dtype()
            logging.info(f"model.to {torch_dtype}")
            self.model = self.model.to(dtype=torch_dtype)
            self.dtype = dtype_override
        return self

    def source_transform(
        self, transforms: List[Callable[[torch.nn.Module], torch.nn.Module]]
    ) -> "LLMEdgeManager":
        """
        Apply source transforms to the model. The transforms are callables that
        take an nn.Module as input and return an nn.Module.
        Args:
            transforms (List[Callable[[torch.nn.Module], torch.nn.Module]]): A
                list of source transforms.
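
        Example (a minimal sketch; replace_attention is a hypothetical transform
        shown only to illustrate the expected signature):

            def replace_attention(module: torch.nn.Module) -> torch.nn.Module:
                # Mutate or rebuild the module in place, then return it.
                return module

            manager.source_transform([replace_attention])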
        """
        for transform in transforms:
            self.model = transform(self.model)
        self.applied_source_transforms.extend(transforms)

        if self.verbose:
            logging.info(f"Applied source transforms: {self.applied_source_transforms}")
        logging.info(f"Model after source transforms: {self.model}")
        return self

    def _get_dynamic_shape(self) -> Any:
        if self.dynamic_shapes:
            return self.dynamic_shapes

        dim = torch.export.Dim("token_dim", max=self.max_seq_len - 1)

        if not self.use_kv_cache:
            # Only one input argument: tokens
            self.dynamic_shapes = ({1: dim},)
        elif self.enable_dynamic_shape:
            # Two input arguments: tokens and input_pos, but input_pos has a static shape
            self.dynamic_shapes = ({1: dim}, {0: 1})
        else:
            # Two input arguments: tokens and input_pos, both with static shapes
            self.dynamic_shapes = None
        return self.dynamic_shapes

    def _get_edge_config(self) -> EdgeCompileConfig:
        edge_config = EdgeCompileConfig(
            _check_ir_validity=False,
            _skip_type_promotion=bool(self.dtype == DType.fp16),
            _skip_dim_order=True,
        )
        return edge_config

    def export(self) -> "LLMEdgeManager":
        """
        Export the model with torch.export (export_for_training, or
        torch.export.export when targeting QNN) and store the resulting
        pre-autograd graph module.
        """
        dynamic_shape = self._get_dynamic_shape()
        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
        # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
            if hasattr(self.args, "qnn") and self.args.qnn:
                # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
                # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
                exported_module = torch.export.export(
                    self.model,
                    self.example_inputs,
                    self.example_kwarg_inputs,
                    dynamic_shapes=dynamic_shape,
                    strict=True,
                )
            else:
                logging.info("Exporting with:")
                logging.info(f"inputs: {self.example_inputs}")
                logging.info(f"kwargs: {self.example_kwarg_inputs}")
                logging.info(f"dynamic shapes: {dynamic_shape}")
                exported_module = export_for_training(
                    self.model,
                    self.example_inputs,
                    kwargs=self.example_kwarg_inputs,
                    dynamic_shapes=dynamic_shape,
                )
            # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
            #  `Module`.
            self.pre_autograd_graph_module = exported_module.module()
            if hasattr(self.args, "export_only") and self.args.export_only:
                torch.export.save(exported_module, self.args.output_name)

        return self

    def pt2e_calibrate(
        self,
        prepared_module,
        calibration_tasks,
        calibration_limit,
        calibration_seq_length,
        calibration_data,
        tokenizer_path,
    ):
        """
        Calibrate a module prepared by prepare_pt2e: run the calibration prompt
        through the model token by token, then evaluate the given lm_eval tasks.
        """
        logging.info("Run calibration...")
        try:
            from executorch.examples.models.llama.eval_llama_lib import (
                GraphModuleEvalWrapper,
            )
            from lm_eval.evaluator import simple_evaluate
        except ImportError:
            raise ImportError(
                "Please install the llm eval dependency via examples/models/llama/install_requirements.sh"
            )

        tokenizer = get_tokenizer(tokenizer_path)

        def calibrate_template(
            module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
        ):
            # TODO: change criteria & support batch inputs if necessary
            pos = torch.tensor(0, dtype=torch.int64)
            token_list = tokenizer.encode(prompts, bos=True, eos=False)

            with torch.no_grad():
                while token_list[-1] != tokenizer.eos_id and pos < max_len:
                    logits = module(
                        torch.full((1, 1), token_list[pos]),
                        torch.tensor((pos,)),
                    )
                    pos += 1
                    if pos >= len(token_list):
                        if self.generate_full_logits:
                            token_list.append(
                                torch.argmax(logits[:, -1], dim=-1).item()
                            )
                        else:
                            token_list.append(torch.argmax(logits[:], dim=-1).item())

        calibrate_template(
            module=prepared_module,
            tokenizer=tokenizer,
            prompts=calibration_data,
            max_len=calibration_seq_length,
        )

        eval_wrapper = GraphModuleEvalWrapper(
            model=prepared_module,
            tokenizer=tokenizer,
            max_seq_length=calibration_seq_length,
            use_kv_cache=self.use_kv_cache,
            generate_full_logits=self.generate_full_logits,
            enable_dynamic_shape=self.enable_dynamic_shape,
        )

        # Evaluate the model
        with torch.no_grad():
            eval_results = simple_evaluate(
                model=eval_wrapper,
                tasks=calibration_tasks,
                limit=calibration_limit,
            )

        for task, res in eval_results["results"].items():
            print(f"{task}: {res}")
        logging.info("Calibration finished.")

    def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
        """
        Quantize the model via the pt2e flow and return the LLMEdgeManager holding the quantized model.
        Args:
            quantizers (Optional[List[Quantizer]]): A list of quantizers.
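
        Example (a minimal sketch; XNNPACKQuantizer with a dynamic symmetric
        config is just one possible choice, not a requirement of this API):

            from torch.ao.quantization.quantizer.xnnpack_quantizer import (
                XNNPACKQuantizer,
                get_symmetric_quantization_config,
            )

            quantizer = XNNPACKQuantizer()
            quantizer.set_global(get_symmetric_quantization_config(is_dynamic=True))
            manager.export().pt2e_quantize([quantizer])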
        """
        assert (
            self.edge_manager is None
        ), "export_to_edge is already called, please call pt2e_quantize before export_to_edge"
        logging.info(f"Using pt2e {quantizers} to quantize the model...")

        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
        # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
        if quantizers:
            with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
                if self.verbose:
                    logging.info(f"Applied quantizers: {quantizers}")
                composed_quantizer = ComposableQuantizer(quantizers)
                assert (
                    self.pre_autograd_graph_module is not None
                ), "Please run export() first"
                m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
                # Calibrate
                if (
                    self.calibration_tasks is not None
                    and self.calibration_limit is not None
                    and self.calibration_seq_length is not None
                    and self.calibration_data is not None
                    and self.tokenizer_path is not None
                ):
                    logging.info(
                        f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
                    )
                    self.pt2e_calibrate(
                        prepared_module=m,
                        calibration_tasks=self.calibration_tasks,
                        calibration_limit=self.calibration_limit,
                        calibration_seq_length=self.calibration_seq_length,
                        calibration_data=self.calibration_data,
                        tokenizer_path=self.tokenizer_path,
                    )
                else:
                    logging.info(
                        "No calibration provided, using dummy input to calibrate..."
                    )
                    m(*self.example_inputs)
                m = convert_pt2e(m)
                DuplicateDynamicQuantChainPass()(m)
                self.pre_autograd_graph_module = m
            return self
        else:
            logging.info("No quantizer provided, passing...")
            return self

    def export_to_edge(self) -> "LLMEdgeManager":
        """
        Export the model to Edge dialect and retrieve an LLMEdgeManager.
        """
        dynamic_shape = self._get_dynamic_shape()
        edge_config = self._get_edge_config()

        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
        # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
            if self.pre_autograd_graph_module is None:
                # Run export() if it didn't run
                self.export()
            self.edge_manager = export_to_edge(
                self.pre_autograd_graph_module,  # pyre-fixme[6]
                self.example_inputs,
                example_kwarg_inputs=self.example_kwarg_inputs,
                dynamic_shapes=dynamic_shape,
                edge_constant_methods=self.metadata,
                edge_compile_config=edge_config,
                verbose=self.verbose,
            )
        return self

    def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
        """
        Partition the model and lower to different backends. The signature is
        aligned with the signature of `to_backend` method of EdgeManager.
        Args:
            partitioners (Optional[List[Partitioner]]): One or more
                partitioners to be sent to EdgeManager.to_backend().
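
        Example (a sketch; XnnpackPartitioner is one possible partitioner, not
        the only option):

            from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
                XnnpackPartitioner,
            )

            manager.export_to_edge().to_backend([XnnpackPartitioner()])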
        """
        if partitioners is None:
            logging.info("No partitioner provided, passing...")
        else:
            for partitioner in partitioners:
                if partitioner is not None:
                    assert (
                        self.edge_manager is not None
                    ), "Need to run export_to_edge() first"
                    self.edge_manager = self.edge_manager.to_backend(partitioner)
                    if self.verbose:
                        logging.info(
                            format_delegated_graph(
                                self.edge_manager.exported_program().graph_module
                            )
                        )
                        logging.info(f"Applied partitioners: {partitioner}")
                else:
                    logging.info("No partitioner provided, passing...")
                    continue

        return self

    def to_executorch(self) -> "LLMEdgeManager":
        """
        Lower the model to ExecuTorch and get an ExecutorchProgram.
        """
        assert self.edge_manager, "Need to run export_to_edge() first"
        self.export_program = self.edge_manager.to_executorch(
            ExecutorchBackendConfig(
                extract_delegate_segments=True,
                passes=[
                    # If there are Linear operations left in the graph, let's execute
                    # them with the optimized op_linear rather than materializing a
                    # transpose followed by a regular op_mm.
                    ConvertToLinearPass(),
                    QuantFusionPass(),
                ],
                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
            )
        )
        logging.info(
            "Required memory for activation in bytes: {}".format(
                self.export_program._emitter_output.program.execution_plan[
                    0
                ].non_const_buffer_sizes
            ),
        )
        return self

    def save_to_pte(self, output_name: str) -> None:
        """
        Save the model to a .pte file.
        Args:
            output_name (str): The name of the .pte file.
        """
        assert output_name, "Need a valid output name"
        filename = save_pte_program(self.export_program, output_name, self.output_dir)
        self._saved_pte_filename = filename

    def get_saved_pte_filename(self) -> Optional[str]:
        """
        Return the filename of the most recently saved .pte file. Return None if the model is not saved.
        """
        return self._saved_pte_filename