#!/usr/bin/env python3

import importlib
import logging
import os
import re
import subprocess
import sys
import warnings


try:
    from .common import BenchmarkRunner, download_retry_decorator, main
except ImportError:
    from common import BenchmarkRunner, download_retry_decorator, main

import torch
from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
from torch._dynamo.utils import clone_inputs


# Enable FX graph caching
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
    torch._inductor.config.fx_graph_cache = True


def pip_install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])


try:
    importlib.import_module("timm")
except ModuleNotFoundError:
    print("Installing PyTorch Image Models...")
    pip_install("git+https://github.com/rwightman/pytorch-image-models")
finally:
    from timm import __version__ as timmversion
    from timm.data import resolve_data_config
    from timm.models import create_model

TIMM_MODELS = {}
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")

with open(filename) as fh:
    lines = fh.readlines()
    lines = [line.rstrip() for line in lines]
    for line in lines:
        model_name, batch_size = line.split(" ")
        TIMM_MODELS[model_name] = int(batch_size)


# TODO - Figure out the reason of cold start memory spike
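
# Descriptive note (not in the upstream source): load_model() divides the
# recorded batch size of each model below by its divisor.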
BATCH_SIZE_DIVISORS = {
    "beit_base_patch16_224": 2,
    "convit_base": 2,
    "convmixer_768_32": 2,
    "convnext_base": 2,
    "cspdarknet53": 2,
    "deit_base_distilled_patch16_224": 2,
    "gluon_xception65": 2,
    "mobilevit_s": 2,
    "pnasnet5large": 2,
    "poolformer_m36": 2,
    "resnest101e": 2,
    "swin_base_patch4_window7_224": 2,
    "swsl_resnext101_32x16d": 2,
    "vit_base_patch16_224": 2,
    "volo_d1_224": 2,
    "jx_nest_base": 4,
}

REQUIRE_HIGHER_TOLERANCE = {
    "fbnetv3_b",
    "gmixer_24_224",
    "hrnet_w18",
    "inception_v3",
    "mixer_b16_224",
    "mobilenetv3_large_100",
    "sebotnet33ts_256",
    "selecsls42b",
}

REQUIRE_EVEN_HIGHER_TOLERANCE = {
    "levit_128",
    "sebotnet33ts_256",
    "beit_base_patch16_224",
    "cspdarknet53",
}

# These models need higher tolerance in MaxAutotune mode
REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE = {
    "gluon_inception_v3",
}

REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = {
    "adv_inception_v3",
    "botnet26t_256",
    "gluon_inception_v3",
    "selecsls42b",
    "swsl_resnext101_32x16d",
}

SCALED_COMPUTE_LOSS = {
    "ese_vovnet19b_dw",
    "fbnetc_100",
    "mnasnet_100",
    "mobilevit_s",
"sebotnet33ts_256", 110*da0073e9SAndroid Build Coastguard Worker} 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard WorkerFORCE_AMP_FOR_FP16_BF16_MODELS = { 113*da0073e9SAndroid Build Coastguard Worker "convit_base", 114*da0073e9SAndroid Build Coastguard Worker "xcit_large_24_p8_224", 115*da0073e9SAndroid Build Coastguard Worker} 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard WorkerSKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS = { 118*da0073e9SAndroid Build Coastguard Worker "xcit_large_24_p8_224", 119*da0073e9SAndroid Build Coastguard Worker} 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard WorkerREQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR = { 122*da0073e9SAndroid Build Coastguard Worker "inception_v3", 123*da0073e9SAndroid Build Coastguard Worker "mobilenetv3_large_100", 124*da0073e9SAndroid Build Coastguard Worker "cspdarknet53", 125*da0073e9SAndroid Build Coastguard Worker} 126*da0073e9SAndroid Build Coastguard Worker 127*da0073e9SAndroid Build Coastguard Worker 128*da0073e9SAndroid Build Coastguard Workerdef refresh_model_names(): 129*da0073e9SAndroid Build Coastguard Worker import glob 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker from timm.models import list_models 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker def read_models_from_docs(): 134*da0073e9SAndroid Build Coastguard Worker models = set() 135*da0073e9SAndroid Build Coastguard Worker # TODO - set the path to pytorch-image-models repo 136*da0073e9SAndroid Build Coastguard Worker for fn in glob.glob("../pytorch-image-models/docs/models/*.md"): 137*da0073e9SAndroid Build Coastguard Worker with open(fn) as f: 138*da0073e9SAndroid Build Coastguard Worker while True: 139*da0073e9SAndroid Build Coastguard Worker line = f.readline() 140*da0073e9SAndroid Build Coastguard Worker if not line: 141*da0073e9SAndroid Build Coastguard Worker break 142*da0073e9SAndroid Build Coastguard Worker if not line.startswith("model = timm.create_model("): 143*da0073e9SAndroid Build Coastguard Worker continue 144*da0073e9SAndroid Build Coastguard Worker 145*da0073e9SAndroid Build Coastguard Worker model = line.split("'")[1] 146*da0073e9SAndroid Build Coastguard Worker # print(model) 147*da0073e9SAndroid Build Coastguard Worker models.add(model) 148*da0073e9SAndroid Build Coastguard Worker return models 149*da0073e9SAndroid Build Coastguard Worker 150*da0073e9SAndroid Build Coastguard Worker def get_family_name(name): 151*da0073e9SAndroid Build Coastguard Worker known_families = [ 152*da0073e9SAndroid Build Coastguard Worker "darknet", 153*da0073e9SAndroid Build Coastguard Worker "densenet", 154*da0073e9SAndroid Build Coastguard Worker "dla", 155*da0073e9SAndroid Build Coastguard Worker "dpn", 156*da0073e9SAndroid Build Coastguard Worker "ecaresnet", 157*da0073e9SAndroid Build Coastguard Worker "halo", 158*da0073e9SAndroid Build Coastguard Worker "regnet", 159*da0073e9SAndroid Build Coastguard Worker "efficientnet", 160*da0073e9SAndroid Build Coastguard Worker "deit", 161*da0073e9SAndroid Build Coastguard Worker "mobilevit", 162*da0073e9SAndroid Build Coastguard Worker "mnasnet", 163*da0073e9SAndroid Build Coastguard Worker "convnext", 164*da0073e9SAndroid Build Coastguard Worker "resnet", 165*da0073e9SAndroid Build Coastguard Worker "resnest", 166*da0073e9SAndroid Build Coastguard Worker "resnext", 167*da0073e9SAndroid Build Coastguard Worker "selecsls", 
            "vgg",
            "xception",
        ]

        for known_family in known_families:
            if known_family in name:
                return known_family

        if name.startswith("gluon_"):
            return "gluon_" + name.split("_")[1]
        return name.split("_")[0]

    def populate_family(models):
        family = {}
        for model_name in models:
            family_name = get_family_name(model_name)
            if family_name not in family:
                family[family_name] = []
            family[family_name].append(model_name)
        return family

    docs_models = read_models_from_docs()
    all_models = list_models(pretrained=True, exclude_filters=["*in21k"])

    all_models_family = populate_family(all_models)
    docs_models_family = populate_family(docs_models)

    for key in docs_models_family:
        del all_models_family[key]

    chosen_models = set()
    chosen_models.update(value[0] for value in docs_models_family.values())

    chosen_models.update(value[0] for key, value in all_models_family.items())

    filename = "timm_models_list.txt"
    if os.path.exists("benchmarks"):
        filename = "benchmarks/" + filename
    with open(filename, "w") as fw:
        for model_name in sorted(chosen_models):
            fw.write(model_name + "\n")


class TimmRunner(BenchmarkRunner):
    def __init__(self):
        super().__init__()
        self.suite_name = "timm_models"

    @property
    def force_amp_for_fp16_bf16_models(self):
        return FORCE_AMP_FOR_FP16_BF16_MODELS

    @property
    def force_fp16_for_bf16_models(self):
        return set()

    @property
    def get_output_amp_train_process_func(self):
        return {}

    @property
    def skip_accuracy_check_as_eager_non_deterministic(self):
        if self.args.accuracy and self.args.training:
            return SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS
        return set()

    @property
    def guard_on_nn_module_models(self):
        return {
            "convit_base",
        }

    @property
    def inline_inbuilt_nn_modules_models(self):
        return {
            "lcnet_050",
        }

    @download_retry_decorator
    def _download_model(self, model_name):
        model = create_model(
            model_name,
            in_chans=3,
            scriptable=False,
            num_classes=None,
            drop_rate=0.0,
            drop_path_rate=None,
            drop_block_rate=None,
            pretrained=True,
        )
        return model

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
        if self.args.enable_activation_checkpointing:
            raise NotImplementedError(
                "Activation checkpointing not implemented for Timm models"
            )

        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode

        channels_last = self._args.channels_last
        model = self._download_model(model_name)

        if model is None:
            raise RuntimeError(f"Failed to load model '{model_name}'")
        model.to(
            device=device,
            memory_format=torch.channels_last if channels_last else None,
        )

        self.num_classes = model.num_classes

        data_config = resolve_data_config(
            vars(self._args) if timmversion >= "0.8.0" else self._args,
            model=model,
            use_test_size=not is_training,
        )
        input_size = data_config["input_size"]
        recorded_batch_size = TIMM_MODELS[model_name]

        if model_name in BATCH_SIZE_DIVISORS:
            recorded_batch_size = max(
                int(recorded_batch_size / BATCH_SIZE_DIVISORS[model_name]), 1
            )
        batch_size = batch_size or recorded_batch_size

        torch.manual_seed(1337)
        input_tensor = torch.randint(
            256, size=(batch_size,) + input_size, device=device
        ).to(dtype=torch.float32)
        mean = torch.mean(input_tensor)
        std_dev = torch.std(input_tensor)
        example_inputs = (input_tensor - mean) / std_dev

        if channels_last:
            example_inputs = example_inputs.contiguous(
                memory_format=torch.channels_last
            )
        example_inputs = [
            example_inputs,
        ]
        self.target = self._gen_target(batch_size, device)

        self.loss = torch.nn.CrossEntropyLoss().to(device)

        if model_name in SCALED_COMPUTE_LOSS:
            self.compute_loss = self.scaled_compute_loss

        if is_training and not use_eval_mode:
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)

        return device, model_name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        # for model_name in list_models(pretrained=True, exclude_filters=["*in21k"]):
        model_names = sorted(TIMM_MODELS.keys())
        start, end = self.get_benchmark_indices(len(model_names))
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not re.search("|".join(args.filter), model_name, re.IGNORECASE)
                or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue

            yield model_name

    def pick_grad(self, name, is_training):
        if is_training:
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def use_larger_multiplier_for_smaller_tensor(self, name):
        return name in REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        cosine = self.args.cosine
        tolerance = 1e-3

        if self.args.freezing and name in REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING:
            # The conv-batchnorm fusion used under freezing may cause relatively
            # large numerical differences, so we need a larger tolerance.
            # Check https://github.com/pytorch/pytorch/issues/120545 for context.
            tolerance = 8 * 1e-2

        if is_training:
            from torch._inductor import config as inductor_config

            if name in REQUIRE_EVEN_HIGHER_TOLERANCE or (
                inductor_config.max_autotune
                and name in REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE
            ):
                tolerance = 8 * 1e-2
            elif name in REQUIRE_HIGHER_TOLERANCE:
                tolerance = 4 * 1e-2
            else:
                tolerance = 1e-2
        return tolerance, cosine

    def _gen_target(self, batch_size, device):
        return torch.empty((batch_size,) + (), device=device, dtype=torch.long).random_(
            self.num_classes
        )

    def compute_loss(self, pred):
        # High loss values make gradient checking harder, as small changes in
        # accumulation order upset accuracy checks.
        return reduce_to_scalar_loss(pred)

    def scaled_compute_loss(self, pred):
        # Loss values for these models need to be scaled down further.
        return reduce_to_scalar_loss(pred) / 1000.0

    def forward_pass(self, mod, inputs, collect_outputs=True):
        with self.autocast(**self.autocast_arg):
            return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            pred = mod(*cloned_inputs)
            if isinstance(pred, tuple):
                pred = pred[0]
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None


def timm_main():
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    main(TimmRunner())


if __name__ == "__main__":
    timm_main()
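

# Usage note (not in the upstream source): the command-line flags are defined by
# the shared parser in common.py and may differ across PyTorch versions. An
# assumed typical invocation looks like:
#   python timm_models.py --performance --inductor --only vit_base_patch16_224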