# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import argparse

from typing import Optional, Union

import torch
from executorch.examples.models.llama.export_llama_lib import (
    get_quantizer_and_quant_params,
)
from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken

from executorch.extension.llm.export.builder import LLMEdgeManager
from executorch.extension.llm.tokenizer.tokenizer import (
    Tokenizer as SentencePieceTokenizer,
)
from executorch.extension.llm.tokenizer.utils import get_tokenizer
from lm_eval.evaluator import simple_evaluate

from .evaluate.eager_eval import EagerEvalWrapper

from .export_llama_lib import (
    _prepare_for_llama_export,
    build_args_parser as _build_args_parser,
)


class GraphModuleEvalWrapper(EagerEvalWrapper):
    """
    A wrapper class for integrating an exported torch.fx.GraphModule with the
    lm-evaluation-harness library.
    """

    def __init__(
        self,
        model: torch.fx.GraphModule,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        max_seq_length: Optional[int] = None,
        use_kv_cache: bool = False,
        generate_full_logits: bool = False,
        enable_dynamic_shape: bool = True,
    ):
        super().__init__(
            model=model, tokenizer=tokenizer, max_seq_length=max_seq_length
        )
        self._model = model.to(self.device)
        self._use_kv_cache = use_kv_cache
        self._generate_full_logits = generate_full_logits
        self._enable_dynamic_shape = enable_dynamic_shape

    def _model_call(self, inps):
        if self._use_kv_cache:
            if not self._enable_dynamic_shape:
                # A graph module exported without dynamic shapes cannot run with a
                # different input shape, so we have to prefill one token at a time.
                result_logits = []
                for pos in range(inps.shape[-1]):
                    pos_tensor = torch.tensor([pos], dtype=torch.int64)
                    logits = self._model(inps[:, pos : pos + 1], pos_tensor)
                    result_logits.append(logits)
                if self._generate_full_logits:
                    return torch.cat(result_logits, dim=1)
                else:
                    return torch.stack(result_logits, dim=1)
            else:
                pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
                # Batch process the whole sequence.
                logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
                return logits

        else:
            return self._model(inps)

    def _model_generate(self, context, max_length, eos_token_id):
        raise NotImplementedError(
            "_model_generate is not implemented for GraphModuleEvalWrapper"
        )


class ETPybindEvalWrapper(EagerEvalWrapper):
    """
    A wrapper class for ExecuTorch py-binded integration with the
    lm-evaluation-harness library.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        max_seq_length: Optional[int] = None,
    ):
        super().__init__(None, tokenizer, max_seq_length)  # pyre-ignore
        self._model = model  # Expects model to be path to a .pte file

        from executorch.extension.pybindings.portable_lib import _load_for_executorch

        # Load custom ops and quantized ops.
        from executorch.extension.pybindings import portable_lib  # noqa # usort: skip

        # Note: import this after portable_lib
        from executorch.extension.llm.custom_ops import (  # noqa
            sdpa_with_kv_cache,  # usort: skip
        )
        from executorch.kernels import quantized  # noqa

        self._et_model = _load_for_executorch(self._model)
        self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0]  # pyre-ignore

    def _model_call(self, inps):
        # Given inps (tokens), return the logits from a single forward call
        # inps: Tensor of shape (1, max_seq_len - 1)
        # logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
        result = []
        if self._use_kv_cache:
            pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
            result = self._et_model.forward(
                (inps[:, : self._max_seq_length], pos_tensor)
            )
        else:
            result = self._et_model.forward((inps,))
        if result[0].dim() != 3:
            raise ValueError(
                f"Dim of logits must be 3 for evaluation. Got {result[0].dim()} here. "
                "Add --generate_full_logits in export_llama to generate a pte file with full logits."
            )
        return result[0]


class ETRunnerEvalWrapper(EagerEvalWrapper):
    """
    A wrapper class for ExecuTorch Runtime integration with the
    lm-evaluation-harness library.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        tokenizer_bin: str,
        max_seq_length: Optional[int] = None,
    ):
        super().__init__(None, tokenizer, max_seq_length)  # pyre-ignore
        self._model = model
        self._tokenizer_bin = tokenizer_bin

    def _model_call(self, inps):
        # Given inps (tokens), return the logits from a single
        # forward call

        # Example:
        # inps: Tensor of shape (1, N)
        # logits: Tensor of shape (1, N, vocab_size)
        pass


def gen_eval_wrapper(
    model_name: str,
    args: argparse.ArgumentParser,
):
    """
    Generates a wrapper interface around the provided model and tokenizer for
    the lm-evaluation-harness library.

    Returns:
        eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
    """
    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore

    # ExecuTorch Binary Evaluation
    if (model := args.pte) is not None:  # pyre-ignore
        if (tokenizer_bin := args.tokenizer_bin) is not None:  # pyre-ignore
            # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
            return ETRunnerEvalWrapper(
                model=model,
                tokenizer=tokenizer,
                tokenizer_bin=tokenizer_bin,
                max_seq_length=args.max_seq_length,  # pyre-ignore
            )

        # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
        return ETPybindEvalWrapper(
            model=model,
            tokenizer=tokenizer,
            # Exported model takes at most (max_seq_length - 1) tokens.
            # Note that the eager model takes at most max_seq_length tokens.
            max_seq_length=args.max_seq_length - 1,
        )

    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
    # GraphModuleEvalWrapper / EagerEvalWrapper: Create a wrapper around a pre-exported model
    manager: LLMEdgeManager = _prepare_for_llama_export(args)

    if len(quantizers) != 0:
        manager = manager.export().pt2e_quantize(quantizers)
        model = (
            manager.pre_autograd_graph_module.to(device="cuda")  # pyre-ignore
            if torch.cuda.is_available()
            else manager.pre_autograd_graph_module.to(device="cpu")
        )
        return GraphModuleEvalWrapper(
            model=model,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            use_kv_cache=args.use_kv_cache,  # pyre-ignore
            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
        )
    else:
        # TODO: use manager.pre_autograd_graph_module for the eval to remove the
        # if-else branch for quantizers. Currently, export_for_training only works
        # with the KV cache enabled (--use_kv_cache) and fails without it.
        model = (
            manager.model.eval().to(device="cuda")
            if torch.cuda.is_available()
            else manager.model.eval().to(device="cpu")
        )

        # Save the checkpoint after the eager model preparation is done.
        # This option exists so the checkpoint can be used for evaluation on other
        # platforms, or with data that is not available to eval_llama; the accuracy
        # results from eval_llama can then serve as a reference for those evaluations.
        if args.output_eager_checkpoint_file is not None:  # pyre-ignore
            torch.save(model, args.output_eager_checkpoint_file)

        return EagerEvalWrapper(
            model=model,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            use_kv_cache=args.use_kv_cache,
        )


def build_args_parser() -> argparse.ArgumentParser:
    # Start with the arg parser from export_llama_lib
    parser = _build_args_parser()

    # Add additional args specific to eval
    parser.add_argument(
        "--tasks",
        nargs="+",
        type=str,
        default=["wikitext"],
        help="list of lm-evaluation-harness tasks to evaluate; usage: --tasks task1 task2",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="number of samples to evaluate. If not set, evaluate all samples",
    )
    parser.add_argument(
        "-f",
        "--num_fewshot",
        type=int,
        default=None,
        metavar="N",
        help="Number of examples in few-shot context",
    )
    # Add additional args specific to eval via an ET Runner
    # Note: For initial integration, the tokenizer.model is also required
    parser.add_argument(
        "--pte",
        type=str,
        default=None,
        help="[For ExecuTorch] Path to the ExecuTorch model being evaluated. If provided, don't go through the export flow",
    )
    parser.add_argument(
        "--tokenizer_bin",
        type=str,
        default=None,
        help="[For ExecuTorch] Path to the tokenizer binary for evaluating ExecuTorch models via the runtime",
    )
    parser.add_argument(
        "--output_eager_checkpoint_file",
        type=str,
        default=None,
        help="Save the checkpoint after source transformations, so other evaluation platforms can run the same checkpoint.",
    )

    return parser


def eval_llama(
    model_name: str,
    args: argparse.ArgumentParser,
) -> None:
    # Generate the eval wrapper
    eval_wrapper = gen_eval_wrapper(model_name, args)

    # Needed for loading the mmlu dataset.
    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
    if args.tasks and "mmlu" in args.tasks:
        import datasets

        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

    # Evaluate the model
    with torch.no_grad():
        eval_results = simple_evaluate(
            model=eval_wrapper,
            tasks=args.tasks,
            num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
            limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
        )

    for task, res in eval_results["results"].items():
        print(f"{task}: {res}")
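

# Example usage: a minimal sketch (not part of the original module) showing how
# `build_args_parser` and `eval_llama` above might be wired together from the
# command line. The guarded main block, the model name "llama2", and the fixed
# seed are illustrative assumptions rather than values required by this module.
if __name__ == "__main__":
    # Seed for reproducible evaluation runs (assumed, not required).
    torch.manual_seed(42)
    parser = build_args_parser()
    cli_args = parser.parse_args()
    eval_llama("llama2", cli_args)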