xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/timm_models.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport importlib
4*da0073e9SAndroid Build Coastguard Workerimport logging
5*da0073e9SAndroid Build Coastguard Workerimport os
6*da0073e9SAndroid Build Coastguard Workerimport re
7*da0073e9SAndroid Build Coastguard Workerimport subprocess
8*da0073e9SAndroid Build Coastguard Workerimport sys
9*da0073e9SAndroid Build Coastguard Workerimport warnings
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workertry:
13*da0073e9SAndroid Build Coastguard Worker    from .common import BenchmarkRunner, download_retry_decorator, main
14*da0073e9SAndroid Build Coastguard Workerexcept ImportError:
15*da0073e9SAndroid Build Coastguard Worker    from common import BenchmarkRunner, download_retry_decorator, main
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Workerimport torch
18*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import collect_results, reduce_to_scalar_loss
19*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.utils import clone_inputs
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker# Enable FX graph caching
23*da0073e9SAndroid Build Coastguard Workerif "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
24*da0073e9SAndroid Build Coastguard Worker    torch._inductor.config.fx_graph_cache = True
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker
def pip_install(package):
    """Install ``package`` with pip, using the interpreter running this script.

    Runs ``<sys.executable> -m pip install <package>``. Raises
    ``subprocess.CalledProcessError`` if pip exits with a non-zero status.
    """
    command = [sys.executable, "-m", "pip", "install"]
    command.append(package)
    subprocess.check_call(command)
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker
# Make sure the ``timm`` package is importable. If it is missing, install it
# from GitHub on first use, then import the few names this module needs.
try:
    importlib.import_module("timm")
except ModuleNotFoundError:
    print("Installing PyTorch Image Models...")
    pip_install("git+https://github.com/rwightman/pytorch-image-models")
    # The import system caches directory contents. A package installed after
    # interpreter start-up may not be found until those caches are cleared.
    importlib.invalidate_caches()

# Imported unconditionally after the check, not inside a ``finally`` clause.
# That way a failed install surfaces its own error (e.g. pip's
# CalledProcessError) instead of a secondary ModuleNotFoundError.
from timm import __version__ as timmversion
from timm.data import resolve_data_config
from timm.models import create_model
40*da0073e9SAndroid Build Coastguard Worker
# Map of benchmarked timm model name -> recorded batch size. It is read from
# ``timm_models_list.txt`` next to this script, where each non-blank line has
# the form "<model_name> <batch_size>".
TIMM_MODELS = {}
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")

with open(filename, encoding="utf-8") as fh:
    for line in fh:
        # Split on any whitespace so tabs, repeated spaces and trailing
        # newlines are tolerated.
        fields = line.split()
        # Skip blank lines (e.g. a trailing empty line), which would otherwise
        # break the two-field unpack below.
        if not fields:
            continue
        model_name, batch_size = fields
        TIMM_MODELS[model_name] = int(batch_size)
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker
# Divisors applied to a model's recorded batch size in TimmRunner.load_model
# (the result is clamped to at least 1). They presumably work around the
# memory spike noted in the TODO below.
# TODO - Figure out the reason of cold start memory spike
BATCH_SIZE_DIVISORS = {
    **dict.fromkeys(
        (
            "beit_base_patch16_224",
            "convit_base",
            "convmixer_768_32",
            "convnext_base",
            "cspdarknet53",
            "deit_base_distilled_patch16_224",
            "gluon_xception65",
            "mobilevit_s",
            "pnasnet5large",
            "poolformer_m36",
            "resnest101e",
            "swin_base_patch4_window7_224",
            "swsl_resnext101_32x16d",
            "vit_base_patch16_224",
            "volo_d1_224",
        ),
        2,
    ),
    "jx_nest_base": 4,
}
72*da0073e9SAndroid Build Coastguard Worker
# Models whose training accuracy check uses a 4e-2 tolerance instead of the
# default 1e-2 (see TimmRunner.get_tolerance_and_cosine_flag). Note that
# sebotnet33ts_256 is also in REQUIRE_EVEN_HIGHER_TOLERANCE, which is checked
# first, so its entry here does not affect that method.
REQUIRE_HIGHER_TOLERANCE = {
    "fbnetv3_b",
    "gmixer_24_224",
    "hrnet_w18",
    "inception_v3",
    "mixer_b16_224",
    "mobilenetv3_large_100",
    "sebotnet33ts_256",
    "selecsls42b",
}

# Models whose training accuracy check uses an 8e-2 tolerance.
REQUIRE_EVEN_HIGHER_TOLERANCE = {
    "beit_base_patch16_224",
    "cspdarknet53",
    "levit_128",
    "sebotnet33ts_256",
}

# Models that need the 8e-2 training tolerance only when Inductor's
# max_autotune config is enabled.
REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE = {"gluon_inception_v3"}

# Models that get an 8e-2 tolerance when run with --freezing.
REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = {
    "adv_inception_v3",
    "botnet26t_256",
    "gluon_inception_v3",
    "selecsls42b",
    "swsl_resnext101_32x16d",
}

# Models whose loss is divided by 1000 (TimmRunner.scaled_compute_loss).
SCALED_COMPUTE_LOSS = {
    "ese_vovnet19b_dw",
    "fbnetc_100",
    "mnasnet_100",
    "mobilevit_s",
    "sebotnet33ts_256",
}

# Returned by TimmRunner.force_amp_for_fp16_bf16_models. The consuming logic
# lives in the shared benchmark runner (presumably forcing AMP over pure
# fp16/bf16 for these models — confirm in common.py).
FORCE_AMP_FOR_FP16_BF16_MODELS = {
    "convit_base",
    "xcit_large_24_p8_224",
}

# Accuracy checks are skipped for these models when running accuracy + training.
SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS = {"xcit_large_24_p8_224"}

# Membership is reported by TimmRunner.use_larger_multiplier_for_smaller_tensor.
# How the multiplier is applied is defined by the shared runner.
REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR = {
    "cspdarknet53",
    "inception_v3",
    "mobilenetv3_large_100",
}
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker
def refresh_model_names(
    docs_glob="../pytorch-image-models/docs/models/*.md",
    default_batch_size=1,
):
    """Regenerate ``timm_models_list.txt`` with one representative model per family.

    Model families are derived from model names by ``get_family_name``:
      * Each family mentioned in the timm docs contributes its first documented
        model (in sorted order).
      * Every remaining family among timm's pretrained models (excluding
        ``*in21k`` variants) contributes its first listed model.

    The file is written next to this script, which is the same path the
    module-level loader reads. It uses that loader's "<model_name> <batch_size>"
    format. Models already in ``TIMM_MODELS`` keep their recorded batch size.
    New models get ``default_batch_size`` as a placeholder to be tuned by hand.

    Args:
        docs_glob: glob pattern for the timm docs model pages. The default
            assumes a pytorch-image-models checkout next to the current
            working directory.
        default_batch_size: batch size written for models with no recorded
            value in ``TIMM_MODELS``.
    """
    import glob

    from timm.models import list_models

    # Matches doc lines such as: model = timm.create_model('resnet50', ...)
    # Both single- and double-quoted model names are accepted.
    create_model_call = re.compile(
        r"""^model = timm\.create_model\(\s*['"]([^'"]+)['"]"""
    )

    def read_models_from_docs():
        """Collect model names passed to ``timm.create_model`` in the docs pages."""
        models = set()
        for fn in glob.glob(docs_glob):
            with open(fn, encoding="utf-8") as f:
                for line in f:
                    match = create_model_call.match(line)
                    if match:
                        models.add(match.group(1))
        return models

    def get_family_name(name):
        """Map a model name to a coarse family key used for de-duplication."""
        known_families = [
            "darknet",
            "densenet",
            "dla",
            "dpn",
            "ecaresnet",
            "halo",
            "regnet",
            "efficientnet",
            "deit",
            "mobilevit",
            "mnasnet",
            "convnext",
            "resnet",
            "resnest",
            "resnext",
            "selecsls",
            "vgg",
            "xception",
        ]

        # First substring match wins, so list order matters (e.g. "ecaresnet"
        # must be checked before "resnet").
        for known_family in known_families:
            if known_family in name:
                return known_family

        if name.startswith("gluon_"):
            return "gluon_" + name.split("_")[1]
        return name.split("_")[0]

    def populate_family(models):
        """Group model names by family, preserving the input order within each family."""
        family = {}
        for model_name in models:
            family.setdefault(get_family_name(model_name), []).append(model_name)
        return family

    # Sort the docs models so the chosen representative is deterministic.
    # Iterating a set would otherwise depend on string hash randomization.
    docs_models_family = populate_family(sorted(read_models_from_docs()))
    all_models_family = populate_family(
        list_models(pretrained=True, exclude_filters=["*in21k"])
    )

    # Families covered by the docs are represented by their documented model.
    # ``pop`` with a default tolerates doc families that have no matching
    # pretrained model, where ``del`` would raise KeyError.
    for key in docs_models_family:
        all_models_family.pop(key, None)

    chosen_models = {models[0] for models in docs_models_family.values()}
    chosen_models.update(models[0] for models in all_models_family.values())

    filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")
    with open(filename, "w", encoding="utf-8") as fw:
        for model_name in sorted(chosen_models):
            batch_size = TIMM_MODELS.get(model_name, default_batch_size)
            fw.write(f"{model_name} {batch_size}\n")
209*da0073e9SAndroid Build Coastguard Worker
210*da0073e9SAndroid Build Coastguard Worker
class TimmRunner(BenchmarkRunner):
    """Benchmark runner for the timm (PyTorch Image Models) suite.

    Models are created with pretrained weights via ``timm.create_model``. Each
    is fed a seeded, normalized synthetic image batch whose shape comes from
    timm's data config for that model.
    """

    def __init__(self):
        super().__init__()
        self.suite_name = "timm_models"

    @property
    def force_amp_for_fp16_bf16_models(self):
        return FORCE_AMP_FOR_FP16_BF16_MODELS

    @property
    def force_fp16_for_bf16_models(self):
        return set()

    @property
    def get_output_amp_train_process_func(self):
        return {}

    @property
    def skip_accuracy_check_as_eager_non_deterministic(self):
        # Only relevant when checking accuracy in training mode.
        if self.args.accuracy and self.args.training:
            return SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS
        return set()

    @property
    def guard_on_nn_module_models(self):
        return {
            "convit_base",
        }

    @property
    def inline_inbuilt_nn_modules_models(self):
        return {
            "lcnet_050",
        }

    @staticmethod
    def _timm_version_at_least(minimum):
        """Return True if the installed timm version is >= ``minimum`` ("X.Y.Z").

        The comparison is numeric on the first three integer components. Plain
        string comparison would wrongly order e.g. "0.10.0" before "0.8.0".
        """

        def as_tuple(version):
            return tuple(int(part) for part in re.findall(r"\d+", version)[:3])

        return as_tuple(timmversion) >= as_tuple(minimum)

    @download_retry_decorator
    def _download_model(self, model_name):
        """Create ``model_name`` with pretrained weights.

        Retries are handled by ``download_retry_decorator`` (presumably for
        transient download failures — see its definition).
        """
        model = create_model(
            model_name,
            in_chans=3,
            scriptable=False,
            num_classes=None,
            drop_rate=0.0,
            drop_path_rate=None,
            drop_block_rate=None,
            pretrained=True,
        )
        return model

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
        """Create ``model_name`` and a synthetic input batch on ``device``.

        Args:
            device: target device for the model, inputs and loss.
            model_name: a timm model name.
            batch_size: explicit batch size. When falsy, the size recorded in
                ``TIMM_MODELS`` is used, divided by ``BATCH_SIZE_DIVISORS``
                where listed (never below 1).
            extra_args: unused; accepted for interface compatibility.

        Returns:
            ``(device, model_name, model, example_inputs, batch_size)``.

        Raises:
            NotImplementedError: activation checkpointing was requested.
            RuntimeError: the model could not be created.
            KeyError: no ``batch_size`` was given and ``model_name`` has no
                recorded batch size.
        """
        if self.args.enable_activation_checkpointing:
            raise NotImplementedError(
                "Activation checkpointing not implemented for Timm models"
            )

        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode
        channels_last = self._args.channels_last

        model = self._download_model(model_name)
        if model is None:
            raise RuntimeError(f"Failed to load model '{model_name}'")
        model.to(
            device=device,
            memory_format=torch.channels_last if channels_last else None,
        )

        self.num_classes = model.num_classes

        # timm >= 0.8.0 takes the args as a dict rather than a namespace.
        data_config = resolve_data_config(
            vars(self._args) if self._timm_version_at_least("0.8.0") else self._args,
            model=model,
            use_test_size=not is_training,
        )
        input_size = data_config["input_size"]

        if not batch_size:
            # The recorded size is only consulted when no explicit batch size
            # was given, so unlisted models can still run with --batch-size.
            divisor = BATCH_SIZE_DIVISORS.get(model_name, 1)
            batch_size = max(TIMM_MODELS[model_name] // divisor, 1)

        # Deterministic synthetic images: integer values in [0, 256) cast to
        # float32, then normalized by the batch's global mean and std.
        torch.manual_seed(1337)
        input_tensor = torch.randint(
            256, size=(batch_size,) + input_size, device=device
        ).to(dtype=torch.float32)
        mean = torch.mean(input_tensor)
        std_dev = torch.std(input_tensor)
        example_inputs = (input_tensor - mean) / std_dev

        if channels_last:
            example_inputs = example_inputs.contiguous(
                memory_format=torch.channels_last
            )
        example_inputs = [example_inputs]
        self.target = self._gen_target(batch_size, device)

        self.loss = torch.nn.CrossEntropyLoss().to(device)

        if model_name in SCALED_COMPUTE_LOSS:
            self.compute_loss = self.scaled_compute_loss
        else:
            # Drop any instance-level override left by a previously loaded
            # model. Otherwise the scaled loss would leak into later models.
            self.__dict__.pop("compute_loss", None)

        if is_training and not use_eval_mode:
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)

        return device, model_name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        """Yield the model names in this benchmark partition that pass the CLI filters.

        Names come from ``TIMM_MODELS`` in sorted order and are restricted to
        the ``[start, end)`` range from ``get_benchmark_indices``. They are then
        filtered by ``args.filter`` / ``args.exclude`` (case-insensitive regex
        alternations), ``args.exclude_exact`` and ``self.skip_models``.
        """
        model_names = sorted(TIMM_MODELS.keys())
        start, end = self.get_benchmark_indices(len(model_names))
        include = re.compile("|".join(args.filter), re.IGNORECASE)
        # An empty exclude list would join to the empty pattern, which matches
        # every name. Treat it as "exclude nothing" instead.
        exclude = (
            re.compile("|".join(args.exclude), re.IGNORECASE) if args.exclude else None
        )
        for model_name in model_names[start:end]:
            if (
                not include.search(model_name)
                or (exclude is not None and exclude.search(model_name))
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue

            yield model_name

    def pick_grad(self, name, is_training):
        """Return the grad-mode context manager for training or inference."""
        return torch.enable_grad() if is_training else torch.no_grad()

    def use_larger_multiplier_for_smaller_tensor(self, name):
        return name in REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        """Return ``(tolerance, cosine)`` for the accuracy check of ``name``.

        Inference defaults to 1e-3, or 8e-2 under --freezing for listed models.
        Training uses 8e-2, 4e-2 or 1e-2 depending on the REQUIRE_* sets; it
        overrides any freezing value. ``current_device`` is unused here.
        """
        cosine = self.args.cosine
        tolerance = 1e-3

        if self.args.freezing and name in REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING:
            # the conv-batchnorm fusion used under freezing may cause relatively
            # large numerical difference. We need a larger tolerance.
            # Check https://github.com/pytorch/pytorch/issues/120545 for context
            tolerance = 8 * 1e-2

        if is_training:
            from torch._inductor import config as inductor_config

            if name in REQUIRE_EVEN_HIGHER_TOLERANCE or (
                inductor_config.max_autotune
                and name in REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE
            ):
                tolerance = 8 * 1e-2
            elif name in REQUIRE_HIGHER_TOLERANCE:
                tolerance = 4 * 1e-2
            else:
                tolerance = 1e-2
        return tolerance, cosine

    def _gen_target(self, batch_size, device):
        """Random int64 class labels of shape ``(batch_size,)`` in ``[0, num_classes)``."""
        return torch.empty((batch_size,), device=device, dtype=torch.long).random_(
            self.num_classes
        )

    def compute_loss(self, pred):
        """Reduce the model output to a scalar loss."""
        return reduce_to_scalar_loss(pred)

    def scaled_compute_loss(self, pred):
        """Scalar loss divided by 1000, for models in SCALED_COMPUTE_LOSS.

        High loss values make gradient checking harder, as small changes in
        accumulation order upset accuracy checks.
        """
        return reduce_to_scalar_loss(pred) / 1000.0

    def forward_pass(self, mod, inputs, collect_outputs=True):
        with self.autocast(**self.autocast_arg):
            return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        """Run one training step on cloned inputs.

        Returns the collected results when ``collect_outputs`` is true,
        otherwise ``None``.
        """
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            pred = mod(*cloned_inputs)
            # Some models return a tuple; the first element is used for the loss.
            if isinstance(pred, tuple):
                pred = pred[0]
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None
413*da0073e9SAndroid Build Coastguard Worker
414*da0073e9SAndroid Build Coastguard Worker
def timm_main():
    """Command-line entry point that runs the timm benchmark suite.

    Configures root logging at WARNING (``basicConfig`` does nothing if logging
    is already configured) and silences Python warnings. It then hands a fresh
    ``TimmRunner`` to the shared ``main`` driver.
    """
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    runner = TimmRunner()
    main(runner)


if __name__ == "__main__":
    timm_main()
423