#!/usr/bin/env python3

import importlib
import logging
import os
import re
import subprocess
import sys
import warnings


try:
    from .common import BenchmarkRunner, download_retry_decorator, main
except ImportError:
    from common import BenchmarkRunner, download_retry_decorator, main

import torch
from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
from torch._dynamo.utils import clone_inputs


# Enable FX graph caching
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
    torch._inductor.config.fx_graph_cache = True


def pip_install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])


try:
    importlib.import_module("timm")
except ModuleNotFoundError:
    print("Installing PyTorch Image Models...")
    pip_install("git+https://github.com/rwightman/pytorch-image-models")
finally:
    from timm import __version__ as timmversion
    from timm.data import resolve_data_config
    from timm.models import create_model

TIMM_MODELS = {}
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")

with open(filename) as fh:
    lines = fh.readlines()
    lines = [line.rstrip() for line in lines]
    for line in lines:
        model_name, batch_size = line.split(" ")
        TIMM_MODELS[model_name] = int(batch_size)


# TODO - Figure out the reason for the cold start memory spike

BATCH_SIZE_DIVISORS = {
    "beit_base_patch16_224": 2,
    "convit_base": 2,
    "convmixer_768_32": 2,
    "convnext_base": 2,
    "cspdarknet53": 2,
    "deit_base_distilled_patch16_224": 2,
    "gluon_xception65": 2,
    "mobilevit_s": 2,
    "pnasnet5large": 2,
    "poolformer_m36": 2,
    "resnest101e": 2,
    "swin_base_patch4_window7_224": 2,
    "swsl_resnext101_32x16d": 2,
    "vit_base_patch16_224": 2,
    "volo_d1_224": 2,
    "jx_nest_base": 4,
}

REQUIRE_HIGHER_TOLERANCE = {
    "fbnetv3_b",
    "gmixer_24_224",
    "hrnet_w18",
    "inception_v3",
    "mixer_b16_224",
    "mobilenetv3_large_100",
    "sebotnet33ts_256",
    "selecsls42b",
}

REQUIRE_EVEN_HIGHER_TOLERANCE = {
    "levit_128",
    "sebotnet33ts_256",
    "beit_base_patch16_224",
    "cspdarknet53",
}

# These models need higher tolerance in MaxAutotune mode
REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE = {
    "gluon_inception_v3",
}

REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = {
    "adv_inception_v3",
    "botnet26t_256",
    "gluon_inception_v3",
    "selecsls42b",
    "swsl_resnext101_32x16d",
}

SCALED_COMPUTE_LOSS = {
    "ese_vovnet19b_dw",
    "fbnetc_100",
    "mnasnet_100",
    "mobilevit_s",
    "sebotnet33ts_256",
}

FORCE_AMP_FOR_FP16_BF16_MODELS = {
    "convit_base",
    "xcit_large_24_p8_224",
}

SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS = {
    "xcit_large_24_p8_224",
}

REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR = {
    "inception_v3",
    "mobilenetv3_large_100",
    "cspdarknet53",
}


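# refresh_model_names() is a maintenance helper that regenerates
# timm_models_list.txt from a local pytorch-image-models checkout (expected at
# ../pytorch-image-models, see the TODO below). It is not called during a
# benchmark run; invoke it manually, e.g. (module path assumed here, adjust to
# how you run the benchmarks):
#
#   python -c "from timm_models import refresh_model_names; refresh_model_names()"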
def refresh_model_names():
    import glob

    from timm.models import list_models

    def read_models_from_docs():
        models = set()
        # TODO - set the path to pytorch-image-models repo
        for fn in glob.glob("../pytorch-image-models/docs/models/*.md"):
            with open(fn) as f:
                while True:
                    line = f.readline()
                    if not line:
                        break
                    if not line.startswith("model = timm.create_model("):
                        continue

                    model = line.split("'")[1]
                    # print(model)
                    models.add(model)
        return models

    def get_family_name(name):
        known_families = [
            "darknet",
            "densenet",
            "dla",
            "dpn",
            "ecaresnet",
            "halo",
            "regnet",
            "efficientnet",
            "deit",
            "mobilevit",
            "mnasnet",
            "convnext",
            "resnet",
            "resnest",
            "resnext",
            "selecsls",
            "vgg",
            "xception",
        ]

        for known_family in known_families:
            if known_family in name:
                return known_family

        if name.startswith("gluon_"):
            return "gluon_" + name.split("_")[1]
        return name.split("_")[0]

    def populate_family(models):
        family = {}
        for model_name in models:
            family_name = get_family_name(model_name)
            if family_name not in family:
                family[family_name] = []
            family[family_name].append(model_name)
        return family

    docs_models = read_models_from_docs()
    all_models = list_models(pretrained=True, exclude_filters=["*in21k"])

    all_models_family = populate_family(all_models)
    docs_models_family = populate_family(docs_models)

    for key in docs_models_family:
        del all_models_family[key]

    chosen_models = set()
    chosen_models.update(value[0] for value in docs_models_family.values())

    chosen_models.update(value[0] for value in all_models_family.values())

    filename = "timm_models_list.txt"
    if os.path.exists("benchmarks"):
        filename = "benchmarks/" + filename
    with open(filename, "w") as fw:
        for model_name in sorted(chosen_models):
            fw.write(model_name + "\n")


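# TimmRunner plugs the timm suite into the shared dynamo benchmark harness
# (BenchmarkRunner in common.py): it downloads and constructs each model,
# builds a normalized random image batch as example inputs, and supplies the
# per-model tolerances used by the accuracy checks.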
class TimmRunner(BenchmarkRunner):
    def __init__(self):
        super().__init__()
        self.suite_name = "timm_models"

    @property
    def force_amp_for_fp16_bf16_models(self):
        return FORCE_AMP_FOR_FP16_BF16_MODELS

    @property
    def force_fp16_for_bf16_models(self):
        return set()

    @property
    def get_output_amp_train_process_func(self):
        return {}

    @property
    def skip_accuracy_check_as_eager_non_deterministic(self):
        if self.args.accuracy and self.args.training:
            return SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS
        return set()

    @property
    def guard_on_nn_module_models(self):
        return {
            "convit_base",
        }

    @property
    def inline_inbuilt_nn_modules_models(self):
        return {
            "lcnet_050",
        }

    @download_retry_decorator
    def _download_model(self, model_name):
        model = create_model(
            model_name,
            in_chans=3,
            scriptable=False,
            num_classes=None,
            drop_rate=0.0,
            drop_path_rate=None,
            drop_block_rate=None,
            pretrained=True,
        )
        return model

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
        if self.args.enable_activation_checkpointing:
            raise NotImplementedError(
                "Activation checkpointing not implemented for Timm models"
            )

        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode

        channels_last = self._args.channels_last
        model = self._download_model(model_name)

        if model is None:
            raise RuntimeError(f"Failed to load model '{model_name}'")
        model.to(
            device=device,
            memory_format=torch.channels_last if channels_last else None,
        )

        self.num_classes = model.num_classes

        data_config = resolve_data_config(
            vars(self._args) if timmversion >= "0.8.0" else self._args,
            model=model,
            use_test_size=not is_training,
        )
        input_size = data_config["input_size"]
        recorded_batch_size = TIMM_MODELS[model_name]

        if model_name in BATCH_SIZE_DIVISORS:
            recorded_batch_size = max(
                int(recorded_batch_size / BATCH_SIZE_DIVISORS[model_name]), 1
            )
        batch_size = batch_size or recorded_batch_size

        torch.manual_seed(1337)
        input_tensor = torch.randint(
            256, size=(batch_size,) + input_size, device=device
        ).to(dtype=torch.float32)
        mean = torch.mean(input_tensor)
        std_dev = torch.std(input_tensor)
        example_inputs = (input_tensor - mean) / std_dev

        if channels_last:
            example_inputs = example_inputs.contiguous(
                memory_format=torch.channels_last
            )
        example_inputs = [
            example_inputs,
        ]
        self.target = self._gen_target(batch_size, device)

        self.loss = torch.nn.CrossEntropyLoss().to(device)

        if model_name in SCALED_COMPUTE_LOSS:
            self.compute_loss = self.scaled_compute_loss

        if is_training and not use_eval_mode:
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)

        return device, model_name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        # for model_name in list_models(pretrained=True, exclude_filters=["*in21k"]):
        model_names = sorted(TIMM_MODELS.keys())
        start, end = self.get_benchmark_indices(len(model_names))
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not re.search("|".join(args.filter), model_name, re.IGNORECASE)
                or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue

            yield model_name

    def pick_grad(self, name, is_training):
        if is_training:
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def use_larger_multiplier_for_smaller_tensor(self, name):
        return name in REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        cosine = self.args.cosine
        tolerance = 1e-3

        if self.args.freezing and name in REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING:
            # the conv-batchnorm fusion used under freezing may cause relatively
            # large numerical differences. We need a larger tolerance.
            # Check https://github.com/pytorch/pytorch/issues/120545 for context
            tolerance = 8 * 1e-2

        if is_training:
            from torch._inductor import config as inductor_config

            if name in REQUIRE_EVEN_HIGHER_TOLERANCE or (
                inductor_config.max_autotune
                and name in REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE
            ):
                tolerance = 8 * 1e-2
            elif name in REQUIRE_HIGHER_TOLERANCE:
                tolerance = 4 * 1e-2
            else:
                tolerance = 1e-2
        return tolerance, cosine

    def _gen_target(self, batch_size, device):
        return torch.empty((batch_size,), device=device, dtype=torch.long).random_(
            self.num_classes
        )

    def compute_loss(self, pred):
        # High loss values make gradient checking harder, as small changes in
        # accumulation order upset accuracy checks.
        return reduce_to_scalar_loss(pred)

    def scaled_compute_loss(self, pred):
        # Loss values need to be scaled down further.
        return reduce_to_scalar_loss(pred) / 1000.0

    def forward_pass(self, mod, inputs, collect_outputs=True):
        with self.autocast(**self.autocast_arg):
            return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            pred = mod(*cloned_inputs)
            if isinstance(pred, tuple):
                pred = pred[0]
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None


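# Example invocation (a sketch; the script path and flags come from the shared
# harness in common.py and may change, so check `--help` before relying on them):
#
#   python benchmarks/dynamo/timm_models.py --performance --training --amp \
#       --inductor --only=convit_base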
def timm_main():
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    main(TimmRunner())


if __name__ == "__main__":
    timm_main()