# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import argparse

from typing import Optional, Union

import torch
from executorch.examples.models.llama.export_llama_lib import (
    get_quantizer_and_quant_params,
)
from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken

from executorch.extension.llm.export.builder import LLMEdgeManager
from executorch.extension.llm.tokenizer.tokenizer import (
    Tokenizer as SentencePieceTokenizer,
)
from executorch.extension.llm.tokenizer.utils import get_tokenizer
from lm_eval.evaluator import simple_evaluate

from .evaluate.eager_eval import EagerEvalWrapper

from .export_llama_lib import (
    _prepare_for_llama_export,
    build_args_parser as _build_args_parser,
)


class GraphModuleEvalWrapper(EagerEvalWrapper):
    """
    A wrapper class for evaluating an exported torch.fx.GraphModule eagerly
    with the lm-evaluation-harness library.
    """

    def __init__(
        self,
        model: torch.fx.GraphModule,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        max_seq_length: Optional[int] = None,
        use_kv_cache: bool = False,
        generate_full_logits: bool = False,
        enable_dynamic_shape: bool = True,
    ):
        super().__init__(
            model=model, tokenizer=tokenizer, max_seq_length=max_seq_length
        )
        self._model = model.to(self.device)
        self._use_kv_cache = use_kv_cache
        self._generate_full_logits = generate_full_logits
        self._enable_dynamic_shape = enable_dynamic_shape

    def _model_call(self, inps):
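        # inps: Tensor of token ids, shape (1, seq_len).
        # Returns logits of shape (1, seq_len, vocab_size), mirroring the
        # convention documented in ETPybindEvalWrapper._model_call below.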
        if self._use_kv_cache:
            if not self._enable_dynamic_shape:
                # A graph module exported without dynamic shapes won't accept a
                # different input shape, so we have to do single-token prefill here.
                result_logits = []
                for pos in range(inps.shape[-1]):
                    pos_tensor = torch.tensor([pos], dtype=torch.int64)
                    logits = self._model(inps[:, pos : pos + 1], pos_tensor)
                    result_logits.append(logits)
                if self._generate_full_logits:
                    return torch.cat(result_logits, dim=1)
                else:
                    return torch.stack(result_logits, dim=1)
            else:
                pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
                # Batch process the whole sequence.
                logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
                return logits

        else:
            return self._model(inps)

    def _model_generate(self, context, max_length, eos_token_id):
        raise NotImplementedError("unimplemented")


class ETPybindEvalWrapper(EagerEvalWrapper):
    """
    A wrapper class for ExecuTorch py-binded integration with the
    lm-evaluation-harness library.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        max_seq_length: Optional[int] = None,
    ):
        super().__init__(None, tokenizer, max_seq_length)  # pyre-ignore
        self._model = model  # Expects model to be path to a .pte file

        from executorch.extension.pybindings.portable_lib import _load_for_executorch

        # Load custom ops and quantized ops.
        from executorch.extension.pybindings import portable_lib  # noqa # usort: skip

        # Note: import this after portable_lib
        from executorch.extension.llm.custom_ops import (  # noqa
            sdpa_with_kv_cache,  # usort: skip
        )
        from executorch.kernels import quantized  # noqa

        self._et_model = _load_for_executorch(self._model)
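        # The exported .pte program is expected to expose a "use_kv_cache"
        # method recording whether it was exported with a KV cache.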
        self._use_kv_cache = self._et_model.run_method("use_kv_cache")[0]  # pyre-ignore

    def _model_call(self, inps):
        # Given inps (tokens), return the logits from a single forward call
        # inps: Tensor of shape (1, max_seq_len - 1)
        # logits: Tensor of shape (1, max_seq_len - 1, vocab_size)
        result = []
        if self._use_kv_cache:
            pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
            result = self._et_model.forward(
                (inps[:, : self._max_seq_length], pos_tensor)
            )
        else:
            result = self._et_model.forward((inps,))
        if result[0].dim() != 3:
            raise ValueError(
                f"Dim of logits must be 3 for evaluation. Got {result[0].dim()} here. Add --generate_full_logits in export_llama to generate a pte file with full logits."
            )
        return result[0]


class ETRunnerEvalWrapper(EagerEvalWrapper):
    """
    A wrapper class for ExecuTorch Runtime integration with the
    lm-evaluation-harness library.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        tokenizer_bin: str,
        max_seq_length: Optional[int] = None,
    ):
        super().__init__(None, tokenizer, max_seq_length)  # pyre-ignore
        self._model = model
        self._tokenizer_bin = tokenizer_bin

    def _model_call(self, inps):
        # Given inps (tokens), return the logits from a single
        # forward call

        # Example:
        # inps: Tensor of shape (1, N)
        # logits: Tensor of shape (1, N, vocab_size)
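        # Evaluation through the ExecuTorch runtime runner is not implemented
        # yet; this method is currently a placeholder.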
        pass


def gen_eval_wrapper(
    model_name: str,
    args: argparse.ArgumentParser,
):
    """
    Generates a wrapper interface around the provided model and tokenizer for
    the lm-evaluation-harness library.

    Returns:
        eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library.
    """
    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore

    # ExecuTorch Binary Evaluation
    if (model := args.pte) is not None:  # pyre-ignore
        if (tokenizer_bin := args.tokenizer_bin) is not None:  # pyre-ignore
            # ETRunnerEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated at runtime
            return ETRunnerEvalWrapper(
                model=model,
                tokenizer=tokenizer,
                tokenizer_bin=tokenizer_bin,
                max_seq_length=args.max_seq_length,  # pyre-ignore
            )

        # ETPybindEvalWrapper: Create a wrapper around an ExecuTorch model, evaluated with pybindings
        return ETPybindEvalWrapper(
            model=model,
            tokenizer=tokenizer,
            # Exported model takes at most (max_seq_length - 1) tokens.
            # Note that the eager model takes at most max_seq_length tokens.
            max_seq_length=args.max_seq_length - 1,
        )

    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
    # GraphModuleEvalWrapper / EagerEvalWrapper: Create a wrapper around the model prepared for export
    manager: LLMEdgeManager = _prepare_for_llama_export(args)

    if len(quantizers) != 0:
        manager = manager.export().pt2e_quantize(quantizers)
        model = (
            manager.pre_autograd_graph_module.to(device="cuda")  # pyre-ignore
            if torch.cuda.is_available()
            else manager.pre_autograd_graph_module.to(device="cpu")
        )
        return GraphModuleEvalWrapper(
            model=model,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            use_kv_cache=args.use_kv_cache,  # pyre-ignore
            enable_dynamic_shape=args.enable_dynamic_shape,  # pyre-ignore
        )
    else:
        # TODO: use manager.pre_autograd_graph_module for the eval to remove the
        # if-else branch for quantizers. Currently export_for_training only works
        # with --kv_cache and fails without the kv_cache mode.
        model = (
            manager.model.eval().to(device="cuda")
            if torch.cuda.is_available()
            else manager.model.eval().to(device="cpu")
        )

        # Save the checkpoint after the eager model preparation is done.
        # This option exists because the checkpoint can then be used for
        # evaluations on other platforms, or with data that is not available
        # to eval_llama. Saving it here keeps things consistent with eval_llama,
        # so the accuracy results from eval_llama can serve as a reference for
        # those other evaluations.
        if args.output_eager_checkpoint_file is not None:  # pyre-ignore
            torch.save(model, args.output_eager_checkpoint_file)

        return EagerEvalWrapper(
            model=model,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            use_kv_cache=args.use_kv_cache,
        )


def build_args_parser() -> argparse.ArgumentParser:
    # Start with arg parser from export_llama_lib
    parser = _build_args_parser()

    # Add additional args specific to eval
    parser.add_argument(
        "--tasks",
        nargs="+",
        type=str,
        default=["wikitext"],
        help="List of lm-evaluation-harness tasks to evaluate. Usage: --tasks task1 task2",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Number of samples to evaluate. If not set, evaluate all samples",
    )
    parser.add_argument(
        "-f",
        "--num_fewshot",
        type=int,
        default=None,
        metavar="N",
        help="Number of examples in few-shot context",
    )
    # Add additional args specific to eval via an ET Runner
    # Note: For initial integration, the tokenizer.model is also required
    parser.add_argument(
        "--pte",
        type=str,
        default=None,
        help="[For ExecuTorch] Path to the ExecuTorch model being evaluated. If provided, don't go through the export flow",
    )
    parser.add_argument(
        "--tokenizer_bin",
        type=str,
        default=None,
        help="[For ExecuTorch] Path to the Tokenizer binary for evaluating ExecuTorch models via runtime",
    )
    parser.add_argument(
        "--output_eager_checkpoint_file",
        type=str,
        default=None,
        help="Save the checkpoint after source transformations, so other evaluation platforms can run the same checkpoint.",
    )

    return parser


def eval_llama(
    model_name: str,
    args: argparse.ArgumentParser,
) -> None:
    # Generate the eval wrapper
    eval_wrapper = gen_eval_wrapper(model_name, args)

    # Needed for loading mmlu dataset.
    # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
    # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
    if args.tasks and "mmlu" in args.tasks:
        import datasets

        datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

    # Evaluate the model
    with torch.no_grad():
        eval_results = simple_evaluate(
            model=eval_wrapper,
            tasks=args.tasks,
            num_fewshot=args.num_fewshot,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
            limit=args.limit,  # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
        )

    for task, res in eval_results["results"].items():
        print(f"{task}: {res}")
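

# Minimal usage sketch: build_args_parser() extends the export_llama_lib parser,
# so flags other than the ones added in this file (--tasks, --limit, --num_fewshot,
# --pte, --tokenizer_bin, --output_eager_checkpoint_file) are assumed to come from
# that parser and may vary between releases; the paths and model name below are
# placeholders.
#
#   parser = build_args_parser()
#   args = parser.parse_args(
#       [
#           "--checkpoint", "consolidated.00.pth",  # assumed export_llama flag
#           "--tokenizer_path", "tokenizer.model",  # consumed by get_tokenizer above
#           "--tasks", "wikitext",
#           "--limit", "5",
#       ]
#   )
#   eval_llama("llama2", args)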
312