# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Provides builders for LLM models. These builders help users build LLM
# eager models, apply source transformations and quantization, and export
# them to ExecuTorch.

# pyre-unsafe

import logging
from enum import Enum
from typing import Any, Callable, Dict, List, Optional

import torch
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.exir import EdgeProgramManager
from executorch.exir.backend.partitioner import Partitioner

from executorch.exir.backend.utils import format_delegated_graph
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig

from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

from executorch.extension.export_util.utils import export_to_edge, save_pte_program
from executorch.extension.llm.tokenizer.utils import get_tokenizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.export import export_for_training
from torch.nn.attention import SDPBackend

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


class DType(Enum):
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"

    def to_torch_dtype(self) -> torch.dtype:
        mapping = {
            DType.fp32: torch.float32,
            DType.fp16: torch.float16,
            DType.bf16: torch.bfloat16,
        }
        if self not in mapping:
            raise ValueError(f"Unsupported dtype {self}")
        return mapping[self]


class LLMEdgeManager:
    """
    Hosts a torch.nn.Module for an LLM model and facilitates exporting it to
    ExecuTorch.
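
    A typical flow chains the builder methods. The sketch below is illustrative
    only; the model object and argument values are placeholders, not values
    defined in this module:

        manager = (
            LLMEdgeManager(
                model=my_llm,  # any torch.nn.Module LLM (placeholder)
                modelname="my_llm",
                max_seq_len=128,
                dtype=DType.fp32,
                use_kv_cache=False,
                example_inputs=(torch.ones(1, 8, dtype=torch.long),),
            )
            .set_output_dir("./out")
            .export()
            .pt2e_quantize(None)  # or a list of Quantizers
            .export_to_edge()
            .to_backend(None)  # or a list of Partitioners
            .to_executorch()
        )
        manager.save_to_pte("my_llm")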
63 """ 64 65 def __init__( 66 self, 67 model, 68 modelname, 69 max_seq_len, 70 dtype, 71 use_kv_cache, 72 example_inputs, 73 example_kwarg_inputs: Optional[Dict] = None, 74 args: Optional[Any] = None, 75 enable_dynamic_shape: bool = False, 76 generate_full_logits: bool = False, 77 calibration_tasks: Optional[List[str]] = None, 78 calibration_limit: Optional[int] = None, 79 calibration_seq_length: Optional[int] = None, 80 calibration_data: Optional[str] = None, 81 tokenizer_path: Optional[str] = None, 82 verbose: bool = False, 83 metadata: Optional[dict] = None, 84 dynamic_shapes: Optional[Any] = None, 85 ): 86 self.model = model 87 # graph module returned from export() 88 self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None 89 self.modelname = modelname 90 self.max_seq_len = max_seq_len 91 self.dtype = dtype 92 self.example_inputs = example_inputs 93 self.example_kwarg_inputs = example_kwarg_inputs 94 self.use_kv_cache = use_kv_cache 95 self.generate_full_logits = generate_full_logits 96 self.enable_dynamic_shape = enable_dynamic_shape 97 self.verbose = verbose 98 self.metadata = metadata 99 self.applied_source_transforms = [] 100 self.edge_manager: Optional[EdgeProgramManager] = None 101 self.export_program = None 102 self.output_dir = "." 103 self.dynamic_shapes = dynamic_shapes 104 self._saved_pte_filename = None 105 self.args = args 106 self.calibration_tasks = calibration_tasks 107 self.calibration_limit = calibration_limit 108 self.calibration_seq_length = calibration_seq_length 109 self.calibration_data = calibration_data 110 self.tokenizer_path = tokenizer_path 111 112 def set_output_dir(self, output_dir: str) -> "LLMEdgeManager": 113 """ 114 Set the directory where the .pte file will be saved. 115 Args: 116 output_dir (str): The directory to store the .pte file. 117 """ 118 self.output_dir = output_dir 119 return self 120 121 def to_dtype(self, dtype_override: Optional[DType]) -> "LLMEdgeManager": 122 """ 123 Convert the model to the specified dtype. 124 Args: 125 dtype_override (Optional[DType]): Override the dtype of the model. 126 """ 127 assert not dtype_override or isinstance( 128 dtype_override, DType 129 ), "Override dtype needs to be of type <DType>" 130 if dtype_override is not None and dtype_override != self.dtype: 131 torch_dtype = dtype_override.to_torch_dtype() 132 logging.info(f"model.to {torch_dtype}") 133 self.model = self.model.to(dtype=torch_dtype) 134 self.dtype = dtype_override 135 return self 136 137 def source_transform( 138 self, transforms: List[Callable[[torch.nn.Module], torch.nn.Module]] 139 ) -> "LLMEdgeManager": 140 """ 141 Apply source transforms to the model. The transforms are callables that 142 takes nn.Module as input and returns nn.Module. 143 Args: 144 transforms (List[Callable[[torch.nn.Module], torch.nn.Module]]): A 145 list of source transforms. 
146 """ 147 for transform in transforms: 148 self.model = transform(self.model) 149 self.applied_source_transforms.extend(transforms) 150 151 if self.verbose: 152 logging.info(f"Applied source transforms: {self.applied_source_transforms}") 153 logging.info(f"Model after source transforms: {self.model}") 154 return self 155 156 def _get_dynamic_shape(self) -> Any: 157 if self.dynamic_shapes: 158 return self.dynamic_shapes 159 160 dim = torch.export.Dim("token_dim", max=self.max_seq_len - 1) 161 162 if not self.use_kv_cache: 163 # Only one input argument: tokens 164 self.dynamic_shapes = ({1: dim},) 165 elif self.enable_dynamic_shape: 166 # Two input arguments: tokens and input_pos but input_pos is static shape 167 self.dynamic_shapes = ({1: dim}, {0: 1}) 168 else: 169 # Two input arguments: tokens and input_pos but both are of static shape 170 self.dynamic_shapes = None 171 return self.dynamic_shapes 172 173 def _get_edge_config(self) -> EdgeCompileConfig: 174 edge_config = EdgeCompileConfig( 175 _check_ir_validity=False, 176 _skip_type_promotion=bool(self.dtype == DType.fp16), 177 _skip_dim_order=True, 178 ) 179 return edge_config 180 181 def export(self) -> "LLMEdgeManager": 182 dynamic_shape = self._get_dynamic_shape() 183 # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing 184 # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up) 185 with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): 186 if hasattr(self.args, "qnn") and self.args.qnn: 187 # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a 188 # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details 189 exported_module = torch.export.export( 190 self.model, 191 self.example_inputs, 192 self.example_kwarg_inputs, 193 dynamic_shapes=dynamic_shape, 194 strict=True, 195 ) 196 else: 197 logging.info("Exporting with:") 198 logging.info(f"inputs: {self.example_inputs}") 199 logging.info(f"kwargs: {self.example_kwarg_inputs}") 200 logging.info(f"dynamic shapes: {dynamic_shape}") 201 exported_module = export_for_training( 202 self.model, 203 self.example_inputs, 204 kwargs=self.example_kwarg_inputs, 205 dynamic_shapes=dynamic_shape, 206 ) 207 # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as 208 # `Module`. 
            self.pre_autograd_graph_module = exported_module.module()
            if hasattr(self.args, "export_only") and self.args.export_only:
                torch.export.save(exported_module, self.args.output_name)

        return self

    def pt2e_calibrate(
        self,
        prepared_module,
        calibration_tasks,
        calibration_limit,
        calibration_seq_length,
        calibration_data,
        tokenizer_path,
    ):
        logging.info("Run calibration...")
        try:
            from executorch.examples.models.llama.eval_llama_lib import (
                GraphModuleEvalWrapper,
            )
            from lm_eval.evaluator import simple_evaluate
        except ImportError:
            raise ImportError(
                "Please install the llm eval dependency via examples/models/llama/install_requirements.sh"
            )

        tokenizer = get_tokenizer(tokenizer_path)

        def calibrate_template(
            module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
        ):
            # TODO: change criteria & support batch inputs if necessary
            pos = torch.tensor(0, dtype=torch.int64)
            token_list = tokenizer.encode(prompts, bos=True, eos=False)

            with torch.no_grad():
                while token_list[-1] != tokenizer.eos_id and pos < max_len:
                    logits = module(
                        torch.full((1, 1), token_list[pos]),
                        torch.tensor((pos,)),
                    )
                    pos += 1
                    if pos >= len(token_list):
                        if self.generate_full_logits:
                            token_list.append(
                                torch.argmax(logits[:, -1], dim=-1).item()
                            )
                        else:
                            token_list.append(torch.argmax(logits[:], dim=-1).item())

        calibrate_template(
            module=prepared_module,
            tokenizer=tokenizer,
            prompts=calibration_data,
            max_len=calibration_seq_length,
        )

        eval_wrapper = GraphModuleEvalWrapper(
            model=prepared_module,
            tokenizer=tokenizer,
            max_seq_length=calibration_seq_length,
            use_kv_cache=self.use_kv_cache,
            generate_full_logits=self.generate_full_logits,
            enable_dynamic_shape=self.enable_dynamic_shape,
        )

        # Evaluate the model
        with torch.no_grad():
            eval_results = simple_evaluate(
                model=eval_wrapper,
                tasks=calibration_tasks,
                limit=calibration_limit,
            )

        for task, res in eval_results["results"].items():
            print(f"{task}: {res}")
        logging.info("Calibration finished.")

    def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
        """
        Quantize the model via the pt2e flow and return the LLMEdgeManager holding the quantized model.
        Args:
            quantizers (Optional[List[Quantizer]]): A list of quantizers.
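
        Example (illustrative sketch; assumes the XNNPACKQuantizer shipped with
        torch.ao.quantization is available in your environment):

            from torch.ao.quantization.quantizer.xnnpack_quantizer import (
                XNNPACKQuantizer,
                get_symmetric_quantization_config,
            )

            quantizer = XNNPACKQuantizer()
            quantizer.set_global(get_symmetric_quantization_config(is_dynamic=True))
            # pt2e_quantize must run after export() and before export_to_edge().
            manager.export().pt2e_quantize([quantizer])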
        """
        assert (
            self.edge_manager is None
        ), "export_to_edge is already called; please call pt2e_quantize before export_to_edge"
        logging.info(f"Using pt2e {quantizers} to quantize the model...")

        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
        # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
        if quantizers:
            with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
                if self.verbose:
                    logging.info(f"Applied quantizers: {quantizers}")
                composed_quantizer = ComposableQuantizer(quantizers)
                assert (
                    self.pre_autograd_graph_module is not None
                ), "Please run export() first"
                m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
                # Calibrate
                if (
                    self.calibration_tasks is not None
                    and self.calibration_limit is not None
                    and self.calibration_seq_length is not None
                    and self.calibration_data is not None
                    and self.tokenizer_path is not None
                ):
                    logging.info(
                        f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
                    )
                    self.pt2e_calibrate(
                        prepared_module=m,
                        calibration_tasks=self.calibration_tasks,
                        calibration_limit=self.calibration_limit,
                        calibration_seq_length=self.calibration_seq_length,
                        calibration_data=self.calibration_data,
                        tokenizer_path=self.tokenizer_path,
                    )
                else:
                    logging.info(
                        "No calibration provided, using dummy input to calibrate..."
                    )
                    m(*self.example_inputs)
                m = convert_pt2e(m)
                DuplicateDynamicQuantChainPass()(m)
                self.pre_autograd_graph_module = m
            return self
        else:
            logging.info("No quantizer provided, passing...")
            return self

    def export_to_edge(self) -> "LLMEdgeManager":
        """
        Export the model to the Edge dialect and return the LLMEdgeManager.
        """
        dynamic_shape = self._get_dynamic_shape()
        edge_config = self._get_edge_config()

        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
        # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
            if self.pre_autograd_graph_module is None:
                # Run export() if it hasn't run yet
                self.export()
            self.edge_manager = export_to_edge(
                self.pre_autograd_graph_module,  # pyre-fixme[6]
                self.example_inputs,
                example_kwarg_inputs=self.example_kwarg_inputs,
                dynamic_shapes=dynamic_shape,
                edge_constant_methods=self.metadata,
                edge_compile_config=edge_config,
                verbose=self.verbose,
            )
        return self

    def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
        """
        Partition the model and lower it to different backends. The signature is
        aligned with the signature of the `to_backend` method of EdgeManager.
        Args:
            partitioners (Optional[List[Partitioner]]): One or more
                partitioners to be sent to EdgeManager.to_backend().
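
        Example (illustrative sketch; assumes the XNNPACK partitioner is
        available at this import path in your ExecuTorch install):

            from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
                XnnpackPartitioner,
            )

            # to_backend must run after export_to_edge().
            manager.export_to_edge().to_backend([XnnpackPartitioner()])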
375 """ 376 if partitioners is None: 377 logging.info("No partitioner provided, passing...") 378 else: 379 for partitioner in partitioners: 380 if partitioner is not None: 381 assert ( 382 self.edge_manager is not None 383 ), "Need to run export_to_edge() first" 384 self.edge_manager = self.edge_manager.to_backend(partitioner) 385 if self.verbose: 386 logging.info( 387 format_delegated_graph( 388 self.edge_manager.exported_program().graph_module 389 ) 390 ) 391 logging.info(f"Applied partitioners: {partitioner}") 392 else: 393 logging.info("No partitioner provided, passing...") 394 continue 395 396 return self 397 398 def to_executorch(self) -> "LLMEdgeManager": 399 """ 400 Lower the model to executorch and get an ExecutorchProgram. 401 """ 402 assert self.edge_manager, "Need to run export_to_edge() first" 403 self.export_program = self.edge_manager.to_executorch( 404 ExecutorchBackendConfig( 405 extract_delegate_segments=True, 406 passes=[ 407 # If there are Linear operations left in the graph, let's execute 408 # them with the optimized op_linear rather than materializing a 409 # transpose followed by a regular op_mm. 410 ConvertToLinearPass(), 411 QuantFusionPass(), 412 ], 413 memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), 414 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(), 415 ) 416 ) 417 logging.info( 418 "Required memory for activation in bytes: {}".format( 419 self.export_program._emitter_output.program.execution_plan[ 420 0 421 ].non_const_buffer_sizes 422 ), 423 ) 424 return self 425 426 def save_to_pte(self, output_name: str) -> None: 427 """ 428 Save the model to a .pte file. 429 Args: 430 output_name (Optional[str]): The name of the .pte file. 431 """ 432 assert output_name, "Need a valid output name" 433 filename = save_pte_program(self.export_program, output_name, self.output_dir) 434 self._saved_pte_filename = filename 435 436 def get_saved_pte_filename(self) -> Optional[str]: 437 """ 438 Return the filename of the most recenet saved .pte file. Return None if the model is not saved. 439 """ 440 return self._saved_pte_filename 441