# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import random
import tempfile
import zipfile

from collections import defaultdict
from pathlib import Path
from typing import Any, Optional, Tuple

import torch
from torch.nn.modules import Module
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


# Logger for reporting progress during longer-running evaluations.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def flatten_args(args) -> tuple | list:
    """Flatten one level of nesting from `args` into a flat sequence of tensors.

    A bare tensor is returned as a single-element list; any other input is
    treated as an iterable and flattened one level into a tuple.
    """
    flattened_args: list = []
    if isinstance(args, torch.Tensor):
        return [args]

    for arg in args:
        if isinstance(arg, (tuple, list)):
            flattened_args.extend(arg)
        else:
            flattened_args.append(arg)

    return tuple(flattened_args)


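# A minimal behavioral sketch of flatten_args (illustrative only, not part of
# the original file):
#
#   t = torch.ones(2)
#   flatten_args(t)               # -> [t]           (bare tensor: one-element list)
#   flatten_args((t, (t, t)))     # -> (t, t, t)     (one level of nesting removed)
#   flatten_args((t, (t, (t,))))  # -> (t, t, (t,))  (deeper nesting is preserved)

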
class GenericModelEvaluator:
    REQUIRES_CONFIG = False

    def __init__(
        self,
        model_name: str,
        fp32_model: torch.nn.Module,
        int8_model: torch.nn.Module,
        example_input: Tuple[torch.Tensor],
        tosa_output_path: Optional[str],
    ) -> None:
        self.model_name = model_name

        self.fp32_model = fp32_model
        self.int8_model = int8_model
        self.example_input = example_input

        # Normalize falsy values (e.g. an empty string) to None.
        self.tosa_output_path = tosa_output_path or None

    def get_model_error(self) -> defaultdict:
        """
        Returns a dict containing the following metrics between the outputs of the FP32 and INT8 model:
        - Maximum error
        - Maximum absolute error
        - Maximum percentage error
        - Mean absolute error
        """
        fp32_outputs = flatten_args(self.fp32_model(*self.example_input))
        int8_outputs = flatten_args(self.int8_model(*self.example_input))

        model_error_dict = defaultdict(list)

        for fp32_output, int8_output in zip(fp32_outputs, int8_outputs):
            difference = fp32_output - int8_output
            # Note: the percentage error is undefined (inf/nan) wherever the
            # FP32 output is exactly zero.
            percentage_error = torch.div(difference, fp32_output) * 100
            model_error_dict["max_error"].append(torch.max(difference).item())
            model_error_dict["max_absolute_error"].append(
                torch.max(torch.abs(difference)).item()
            )
            model_error_dict["max_percentage_error"].append(
                torch.max(percentage_error).item()
            )
            model_error_dict["mean_absolute_error"].append(
                torch.mean(torch.abs(difference).float()).item()
            )

        return model_error_dict

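    # Worked example (illustrative): if an FP32 output element is 2.0 and the
    # corresponding INT8 output is 1.9, the difference is 0.1, the absolute
    # error is 0.1, and the percentage error is 0.1 / 2.0 * 100 = 5%.
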
    def get_compression_ratio(self) -> float:
        """Compute the compression ratio (uncompressed size / DEFLATE-compressed
        size) of the generated TOSA flatbuffer."""
        with tempfile.NamedTemporaryFile(delete=True, suffix=".zip") as temp_zip:
            with zipfile.ZipFile(
                temp_zip.name, "w", compression=zipfile.ZIP_DEFLATED
            ) as f:
                f.write(self.tosa_output_path)

            compression_ratio = os.path.getsize(
                self.tosa_output_path
            ) / os.path.getsize(temp_zip.name)

        return compression_ratio

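    # Worked example (illustrative): a 4 MB flatbuffer that deflates to a
    # 1 MB zip yields a compression ratio of 4 MB / 1 MB = 4.0; ratios near
    # 1.0 indicate data with little redundancy left to compress.
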
    def evaluate(self) -> dict[str, Any]:
        model_error_dict = self.get_model_error()

        output_metrics = {"name": self.model_name, "metrics": dict(model_error_dict)}

        if self.tosa_output_path:
            # We know output_metrics["metrics"] is a dict since we just defined it, safe to ignore.
            # pyre-ignore[16]
            output_metrics["metrics"][
                "compression_ratio"
            ] = self.get_compression_ratio()

        return output_metrics


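# Hedged usage sketch (hypothetical names, not part of the original file):
# given an FP32 module, its quantized counterpart, and a matching example
# input, the evaluator is driven as
#
#   evaluator = GenericModelEvaluator(
#       "my_model", fp32_module, int8_module, (example_tensor,), "out.tosa"
#   )
#   metrics = evaluator.evaluate()
#
# where "out.tosa" stands for a path to a serialized TOSA flatbuffer produced
# by a prior lowering step; pass None to skip the compression-ratio metric.

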
class MobileNetV2Evaluator(GenericModelEvaluator):
    REQUIRES_CONFIG = True

    def __init__(
        self,
        model_name: str,
        fp32_model: Module,
        int8_model: Module,
        example_input: Tuple[torch.Tensor],
        tosa_output_path: str | None,
        batch_size: int,
        validation_dataset_path: str,
    ) -> None:
        super().__init__(
            model_name, fp32_model, int8_model, example_input, tosa_output_path
        )

        self.__batch_size = batch_size
        self.__validation_set_path = validation_dataset_path

    @staticmethod
    def __load_dataset(directory: str) -> datasets.ImageFolder:
        directory_path = Path(directory)
        if not directory_path.exists():
            raise FileNotFoundError(f"Directory: {directory} does not exist.")

        transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                # Standard ImageNet channel-wise normalization constants.
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        return datasets.ImageFolder(directory_path, transform=transform)

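    # Note: torchvision's ImageFolder expects the directory to contain one
    # sub-directory per class, e.g.
    #
    #   <directory>/n01440764/ILSVRC2012_val_00000293.JPEG
    #   <directory>/n01443537/ILSVRC2012_val_00000236.JPEG
    #
    # (file names here are illustrative of an ImageNet validation layout).
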
    @staticmethod
    def get_calibrator(training_dataset_path: str) -> DataLoader:
        dataset = MobileNetV2Evaluator.__load_dataset(training_dataset_path)
        # Assumes the dataset holds at least 1000 samples; random.sample raises
        # ValueError otherwise.
        rand_indices = random.sample(range(len(dataset)), k=1000)

        # Return a subset of the dataset to be used for calibration
        return DataLoader(
            torch.utils.data.Subset(dataset, rand_indices),
            batch_size=1,
            shuffle=False,
        )

    def __evaluate_mobilenet(self) -> Tuple[float, float]:
        dataset = MobileNetV2Evaluator.__load_dataset(self.__validation_set_path)
        loaded_dataset = DataLoader(
            dataset,
            batch_size=self.__batch_size,
            shuffle=False,
        )

        top1_correct = 0
        top5_correct = 0

        for i, (image, target) in enumerate(loaded_dataset):
            prediction = self.int8_model(image)
            top1_prediction = torch.topk(prediction, k=1, dim=1).indices
            top5_prediction = torch.topk(prediction, k=5, dim=1).indices

            top1_correct += (top1_prediction == target.view(-1, 1)).sum().item()
            top5_correct += (top5_prediction == target.view(-1, 1)).sum().item()

            # Running accuracy estimates; the denominator assumes full batches,
            # so the figures logged for a final partial batch are slightly low.
            logger.info("Images processed: {}".format((i + 1) * self.__batch_size))
            logger.info(
                "Top 1: {}".format(top1_correct / ((i + 1) * self.__batch_size))
            )
            logger.info(
                "Top 5: {}".format(top5_correct / ((i + 1) * self.__batch_size))
            )

        top1_accuracy = top1_correct / len(dataset)
        top5_accuracy = top5_correct / len(dataset)

        return top1_accuracy, top5_accuracy

    def evaluate(self) -> dict[str, Any]:
        top1_accuracy, top5_accuracy = self.__evaluate_mobilenet()
        output = super().evaluate()

        output["metrics"]["accuracy"] = {
            "top-1": top1_accuracy,
            "top-5": top5_accuracy,
        }
        return output

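
# A hedged, self-contained smoke test (not part of the original file). It
# exercises GenericModelEvaluator with a small FP32 linear layer and a
# rounded-weight copy standing in for a real INT8 model, since producing a
# true quantized model requires the full ExecuTorch quantization flow.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    fp32 = torch.nn.Linear(4, 2)
    int8_standin = torch.nn.Linear(4, 2)
    with torch.no_grad():
        # Quantize-then-dequantize the weights to 1/128 steps to mimic the
        # rounding an INT8 model would introduce.
        int8_standin.weight.copy_((fp32.weight * 128).round() / 128)
        int8_standin.bias.copy_(fp32.bias)

    example_input = (torch.randn(1, 4),)
    evaluator = GenericModelEvaluator(
        "smoke_test", fp32, int8_standin, example_input, tosa_output_path=None
    )
    print(evaluator.evaluate())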