# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import random
import tempfile
import zipfile

from collections import defaultdict
from pathlib import Path
from typing import Any, Optional, Tuple

import torch
from torch.nn.modules import Module
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


# Logger for reporting progress during longer-running evaluations
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def flatten_args(args) -> tuple | list:
    """Flatten one level of nested tuples/lists so FP32 and INT8 outputs can be compared pairwise."""
    flattened_args: list = []
    if isinstance(args, torch.Tensor):
        return [args]

    for arg in args:
        if isinstance(arg, (tuple, list)):
            flattened_args.extend(arg)
        else:
            flattened_args.append(arg)

    return tuple(flattened_args)


class GenericModelEvaluator:
    REQUIRES_CONFIG = False

    def __init__(
        self,
        model_name: str,
        fp32_model: torch.nn.Module,
        int8_model: torch.nn.Module,
        example_input: Tuple[torch.Tensor],
        tosa_output_path: Optional[str],
    ) -> None:
        self.model_name = model_name

        self.fp32_model = fp32_model
        self.int8_model = int8_model
        self.example_input = example_input

        if tosa_output_path:
            self.tosa_output_path = tosa_output_path
        else:
            self.tosa_output_path = None

    def get_model_error(self) -> defaultdict:
        """
        Returns a dict containing the following metrics between the outputs of the FP32 and INT8 models:
        - Maximum error
        - Maximum absolute error
        - Maximum percentage error
        - Mean absolute error
        """
        fp32_outputs = flatten_args(self.fp32_model(*self.example_input))
        int8_outputs = flatten_args(self.int8_model(*self.example_input))

        model_error_dict = defaultdict(list)

        for fp32_output, int8_output in zip(fp32_outputs, int8_outputs):
            difference = fp32_output - int8_output
            percentage_error = torch.div(difference, fp32_output) * 100
            model_error_dict["max_error"].append(torch.max(difference).item())
            model_error_dict["max_absolute_error"].append(
                torch.max(torch.abs(difference)).item()
            )
            model_error_dict["max_percentage_error"].append(
                torch.max(percentage_error).item()
            )
            model_error_dict["mean_absolute_error"].append(
                torch.mean(torch.abs(difference).float()).item()
            )

        return model_error_dict

    def get_compression_ratio(self) -> float:
        """Compute the compression ratio of the outputted TOSA flatbuffer."""
        with tempfile.NamedTemporaryFile(delete=True, suffix=".zip") as temp_zip:
            with zipfile.ZipFile(
                temp_zip.name, "w", compression=zipfile.ZIP_DEFLATED
            ) as f:
                f.write(self.tosa_output_path)

            compression_ratio = os.path.getsize(
                self.tosa_output_path
            ) / os.path.getsize(temp_zip.name)

        return compression_ratio

    def evaluate(self) -> dict[str, Any]:
        model_error_dict = self.get_model_error()

        output_metrics = {"name": self.model_name, "metrics": dict(model_error_dict)}

        if self.tosa_output_path:
            # We know output_metrics["metrics"] is a dict since we just defined it, safe to ignore.
            # pyre-ignore[16]
            output_metrics["metrics"][
                "compression_ratio"
            ] = self.get_compression_ratio()

        return output_metrics


class MobileNetV2Evaluator(GenericModelEvaluator):
    REQUIRES_CONFIG = True

    def __init__(
        self,
        model_name: str,
        fp32_model: Module,
        int8_model: Module,
        example_input: Tuple[torch.Tensor],
        tosa_output_path: str | None,
        batch_size: int,
        validation_dataset_path: str,
    ) -> None:
        super().__init__(
            model_name, fp32_model, int8_model, example_input, tosa_output_path
        )

        self.__batch_size = batch_size
        self.__validation_set_path = validation_dataset_path

    @staticmethod
    def __load_dataset(directory: str) -> datasets.ImageFolder:
        directory_path = Path(directory)
        if not directory_path.exists():
            raise FileNotFoundError(f"Directory: {directory} does not exist.")

        transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.484, 0.454, 0.403], std=[0.225, 0.220, 0.220]
                ),
            ]
        )
        return datasets.ImageFolder(directory_path, transform=transform)

    @staticmethod
    def get_calibrator(training_dataset_path: str) -> DataLoader:
        dataset = MobileNetV2Evaluator.__load_dataset(training_dataset_path)
        rand_indices = random.sample(range(len(dataset)), k=1000)

        # Return a random subset of the dataset to be used for calibration
        return torch.utils.data.DataLoader(
            torch.utils.data.Subset(dataset, rand_indices),
            batch_size=1,
            shuffle=False,
        )

    def __evaluate_mobilenet(self) -> Tuple[float, float]:
        dataset = MobileNetV2Evaluator.__load_dataset(self.__validation_set_path)
        loaded_dataset = DataLoader(
            dataset,
            batch_size=self.__batch_size,
            shuffle=False,
        )

        top1_correct = 0
        top5_correct = 0

        for i, (image, target) in enumerate(loaded_dataset):
            prediction = self.int8_model(image)
            top1_prediction = torch.topk(prediction, k=1, dim=1).indices
            top5_prediction = torch.topk(prediction, k=5, dim=1).indices

            top1_correct += (top1_prediction == target.view(-1, 1)).sum().item()
            top5_correct += (top5_prediction == target.view(-1, 1)).sum().item()

            logger.info("Iteration: {}".format((i + 1) * self.__batch_size))
            logger.info(
                "Top 1: {}".format(top1_correct / ((i + 1) * self.__batch_size))
            )
            logger.info(
                "Top 5: {}".format(top5_correct / ((i + 1) * self.__batch_size))
            )

        top1_accuracy = top1_correct / len(dataset)
        top5_accuracy = top5_correct / len(dataset)

        return top1_accuracy, top5_accuracy

    def evaluate(self) -> dict[str, Any]:
        top1_accuracy, top5_accuracy = self.__evaluate_mobilenet()
        output = super().evaluate()

        output["metrics"]["accuracy"] = {
            "top-1": top1_accuracy,
            "top-5": top5_accuracy,
        }
        return output
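

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the evaluator API above).
# _ToyModel and the example input below are hypothetical; the FP32 weights are
# reused as a stand-in for a quantized model so the snippet runs standalone and
# simply demonstrates how GenericModelEvaluator is driven.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _ToyModel(torch.nn.Module):
        """Tiny stand-in model used only for this example."""

        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x)

    fp32 = _ToyModel()
    # A real flow would pass an INT8-converted model here; reusing the FP32
    # weights keeps the example self-contained (reported errors will be zero).
    int8 = _ToyModel()
    int8.load_state_dict(fp32.state_dict())

    evaluator = GenericModelEvaluator(
        model_name="toy_model",
        fp32_model=fp32,
        int8_model=int8,
        example_input=(torch.randn(1, 4),),
        tosa_output_path=None,  # No TOSA flatbuffer, so compression ratio is skipped.
    )
    print(evaluator.evaluate())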