xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/timm_models.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2
3import importlib
4import logging
5import os
6import re
7import subprocess
8import sys
9import warnings
10
11
12try:
13    from .common import BenchmarkRunner, download_retry_decorator, main
14except ImportError:
15    from common import BenchmarkRunner, download_retry_decorator, main
16
17import torch
18from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
19from torch._dynamo.utils import clone_inputs
20
21
22# Enable FX graph caching
23if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
24    torch._inductor.config.fx_graph_cache = True
25
26
def pip_install(package):
    """Install ``package`` with pip, run through the current Python executable.

    Raises ``subprocess.CalledProcessError`` when pip exits with a non-zero
    status.
    """
    command = [sys.executable, "-m", "pip", "install", package]
    subprocess.check_call(command)
29
30
# Make sure timm is importable, installing it from GitHub on first use.
try:
    importlib.import_module("timm")
except ModuleNotFoundError:
    print("Installing PyTorch Image Models...")
    pip_install("git+https://github.com/rwightman/pytorch-image-models")
    # The import system caches what it found on disk; refresh those caches so
    # the package that pip just installed is visible to the imports below.
    importlib.invalidate_caches()

# Import only once the check/installation above finished without raising. With
# these imports in a ``finally`` clause, a failed installation was replaced by
# a secondary ModuleNotFoundError raised from the clause itself.
from timm import __version__ as timmversion  # noqa: E402
from timm.data import resolve_data_config  # noqa: E402
from timm.models import create_model  # noqa: E402
40
# Maps each benchmarked timm model name to the batch size recorded for it in
# ``timm_models_list.txt``, which sits next to this module and holds one
# "<model_name> <batch_size>" entry per line.
TIMM_MODELS = {}
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")

with open(filename) as fh:
    for line in fh:
        fields = line.split()
        if not fields:
            # Tolerate blank lines (for example an extra newline at the end of
            # the file) instead of failing to unpack them.
            continue
        # Any other line must hold exactly two fields; a malformed entry still
        # raises ValueError so that a broken list file is noticed.
        model_name, batch_size = fields
        TIMM_MODELS[model_name] = int(batch_size)
50
51
52# TODO - Figure out the reason of cold start memory spike
53
# Per-model divisor that TimmRunner.load_model applies to the batch size
# recorded in TIMM_MODELS. The quotient is truncated to an int and never drops
# below 1; it is only used when the caller passes no explicit batch size.
BATCH_SIZE_DIVISORS = {
    "beit_base_patch16_224": 2,
    "convit_base": 2,
    "convmixer_768_32": 2,
    "convnext_base": 2,
    "cspdarknet53": 2,
    "deit_base_distilled_patch16_224": 2,
    "gluon_xception65": 2,
    "mobilevit_s": 2,
    "pnasnet5large": 2,
    "poolformer_m36": 2,
    "resnest101e": 2,
    "swin_base_patch4_window7_224": 2,
    "swsl_resnext101_32x16d": 2,
    "vit_base_patch16_224": 2,
    "volo_d1_224": 2,
    "jx_nest_base": 4,
}
72
# Models whose training accuracy check uses a tolerance of 4e-2 instead of the
# default training tolerance of 1e-2 (see
# TimmRunner.get_tolerance_and_cosine_flag).
REQUIRE_HIGHER_TOLERANCE = {
    "fbnetv3_b",
    "gmixer_24_224",
    "hrnet_w18",
    "inception_v3",
    "mixer_b16_224",
    "mobilenetv3_large_100",
    "sebotnet33ts_256",
    "selecsls42b",
}
83
# Models whose training accuracy check uses a tolerance of 8e-2. Membership
# here is tested before REQUIRE_HIGHER_TOLERANCE in
# TimmRunner.get_tolerance_and_cosine_flag, so it wins for a model that is
# listed in both sets.
REQUIRE_EVEN_HIGHER_TOLERANCE = {
    "levit_128",
    "sebotnet33ts_256",
    "beit_base_patch16_224",
    "cspdarknet53",
}
90
# These models need higher tolerance in MaxAutotune mode: they get the 8e-2
# training tolerance only while inductor's ``max_autotune`` config is enabled.
REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE = {
    "gluon_inception_v3",
}
95
# Models that get a tolerance of 8e-2 when ``args.freezing`` is set. For
# training runs TimmRunner.get_tolerance_and_cosine_flag subsequently replaces
# that value with the training tolerance.
REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = {
    "adv_inception_v3",
    "botnet26t_256",
    "gluon_inception_v3",
    "selecsls42b",
    "swsl_resnext101_32x16d",
}
103
# Models for which TimmRunner.load_model switches the loss computation to
# TimmRunner.scaled_compute_loss, i.e. the reduced scalar loss divided by 1000.
SCALED_COMPUTE_LOSS = {
    "ese_vovnet19b_dw",
    "fbnetc_100",
    "mnasnet_100",
    "mobilevit_s",
    "sebotnet33ts_256",
}
111
# Returned by the TimmRunner.force_amp_for_fp16_bf16_models property.
# NOTE(review): presumably BenchmarkRunner runs these models under AMP when
# fp16/bf16 is requested - confirm against common.py.
FORCE_AMP_FOR_FP16_BF16_MODELS = {
    "convit_base",
    "xcit_large_24_p8_224",
}
116
# Returned by TimmRunner.skip_accuracy_check_as_eager_non_deterministic when
# both ``args.accuracy`` and ``args.training`` are set; otherwise that property
# returns an empty set.
SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS = {
    "xcit_large_24_p8_224",
}
120
# Models for which TimmRunner.use_larger_multiplier_for_smaller_tensor returns
# True.
REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR = {
    "inception_v3",
    "mobilenetv3_large_100",
    "cspdarknet53",
}
126
127
def refresh_model_names():
    """Regenerate ``timm_models_list.txt`` with one model per model family.

    Model names come from two sources: the ``model = timm.create_model(...)``
    snippets in the pytorch-image-models docs and ``timm.models.list_models``.
    Names are grouped into families. The first model of every docs family is
    kept, plus the first model of every family that only ``list_models``
    reports. The sorted result is written one name per line to
    ``benchmarks/timm_models_list.txt`` when a ``benchmarks`` entry exists in
    the current working directory, otherwise to ``timm_models_list.txt``.

    NOTE: only model names are written. The loader at the top of this module
    unpacks "<model_name> <batch_size>" from every line, so batch sizes have
    to be added to the generated file before it can be read back.
    """
    import glob

    from timm.models import list_models

    def read_models_from_docs():
        """Collect the model names used by ``timm.create_model`` in the docs."""
        models = set()
        # TODO - set the path to pytorch-image-models repo
        for fn in glob.glob("../pytorch-image-models/docs/models/*.md"):
            with open(fn) as f:
                for line in f:
                    if line.startswith("model = timm.create_model("):
                        # The name is the text between the first pair of
                        # single quotes on the line.
                        models.add(line.split("'")[1])
        return models

    def get_family_name(name):
        """Return the family that the model called ``name`` is grouped under."""
        # Checked in order; the first entry contained in ``name`` wins.
        known_families = [
            "darknet",
            "densenet",
            "dla",
            "dpn",
            "ecaresnet",
            "halo",
            "regnet",
            "efficientnet",
            "deit",
            "mobilevit",
            "mnasnet",
            "convnext",
            "resnet",
            "resnest",
            "resnext",
            "selecsls",
            "vgg",
            "xception",
        ]

        for known_family in known_families:
            if known_family in name:
                return known_family

        # Otherwise fall back to the name's prefix; "gluon_" models keep the
        # component after the prefix as part of their family.
        if name.startswith("gluon_"):
            return "gluon_" + name.split("_")[1]
        return name.split("_")[0]

    def populate_family(models):
        """Group ``models`` by family, preserving their order inside a family."""
        family = {}
        for model_name in models:
            family.setdefault(get_family_name(model_name), []).append(model_name)
        return family

    docs_models = read_models_from_docs()
    all_models = list_models(pretrained=True, exclude_filters=["*in21k"])

    all_models_family = populate_family(all_models)
    # ``docs_models`` is a set; sort it so that the representative picked for
    # each docs family is the same on every run.
    docs_models_family = populate_family(sorted(docs_models))

    for key in docs_models_family:
        # Families covered by the docs are represented by their docs model. A
        # docs family may have no counterpart in ``list_models``, so a missing
        # key is not an error.
        all_models_family.pop(key, None)

    chosen_models = {members[0] for members in docs_models_family.values()}
    chosen_models.update(members[0] for members in all_models_family.values())

    filename = "timm_models_list.txt"
    if os.path.exists("benchmarks"):
        filename = "benchmarks/" + filename
    with open(filename, "w") as fw:
        fw.writelines(model_name + "\n" for model_name in sorted(chosen_models))
209
210
class TimmRunner(BenchmarkRunner):
    """Benchmark runner for the ``timm_models`` suite.

    Provides model creation, example-input generation, accuracy tolerances and
    the forward/backward passes for timm models. Anything not defined here
    (argument handling, autocast, optimizer and grad-scaler helpers, ...) is
    inherited from ``BenchmarkRunner``.
    """

    def __init__(self):
        super().__init__()
        self.suite_name = "timm_models"

    @staticmethod
    def _version_tuple(version):
        """Return the leading numeric components of ``version`` as ints.

        Parsing stops at the first dot-separated component that does not start
        with a digit, e.g. ``"0.9.16"`` -> ``(0, 9, 16)`` and
        ``"1.0.0.dev0"`` -> ``(1, 0, 0)``.
        """
        parts = []
        for piece in version.split("."):
            digits = re.match(r"\d+", piece)
            if digits is None:
                break
            parts.append(int(digits.group()))
        return tuple(parts)

    @property
    def force_amp_for_fp16_bf16_models(self):
        """The ``FORCE_AMP_FOR_FP16_BF16_MODELS`` set."""
        return FORCE_AMP_FOR_FP16_BF16_MODELS

    @property
    def force_fp16_for_bf16_models(self):
        """Always an empty set for this suite."""
        return set()

    @property
    def get_output_amp_train_process_func(self):
        """Always an empty dict for this suite."""
        return {}

    @property
    def skip_accuracy_check_as_eager_non_deterministic(self):
        """Models excluded from the accuracy check.

        Non-empty only when both ``args.accuracy`` and ``args.training`` are
        set.
        """
        if self.args.accuracy and self.args.training:
            return SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS
        return set()

    @property
    def guard_on_nn_module_models(self):
        """Fixed set containing ``"convit_base"``."""
        return {
            "convit_base",
        }

    @property
    def inline_inbuilt_nn_modules_models(self):
        """Fixed set containing ``"lcnet_050"``."""
        return {
            "lcnet_050",
        }

    @download_retry_decorator
    def _download_model(self, model_name):
        """Create the pretrained timm model called ``model_name``."""
        model = create_model(
            model_name,
            in_chans=3,
            scriptable=False,
            num_classes=None,
            drop_rate=0.0,
            drop_path_rate=None,
            drop_block_rate=None,
            pretrained=True,
        )
        return model

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
        """Create ``model_name`` together with a batch of example inputs.

        Besides returning the model, this stores ``num_classes``, ``target``
        and ``loss`` on the runner and selects the loss function used by
        ``forward_and_backward_pass``.

        Args:
            device: Device the model, inputs, targets and loss are placed on.
            model_name: timm model name; must be a key of ``TIMM_MODELS``.
            batch_size: Batch size to use. When falsy, the size recorded in
                ``TIMM_MODELS`` is used, divided by the model's entry in
                ``BATCH_SIZE_DIVISORS`` if it has one (never below 1).
            extra_args: Unused.

        Returns:
            ``(device, model_name, model, example_inputs, batch_size)`` where
            ``example_inputs`` is a one-element list holding the input tensor.

        Raises:
            NotImplementedError: If activation checkpointing was requested.
            RuntimeError: If no model was created.
            KeyError: If ``model_name`` is not in ``TIMM_MODELS``.
        """
        if self.args.enable_activation_checkpointing:
            raise NotImplementedError(
                "Activation checkpointing not implemented for Timm models"
            )

        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode

        channels_last = self._args.channels_last
        model = self._download_model(model_name)

        if model is None:
            raise RuntimeError(f"Failed to load model '{model_name}'")
        model.to(
            device=device,
            memory_format=torch.channels_last if channels_last else None,
        )

        self.num_classes = model.num_classes

        # timm >= 0.8.0 is given the arguments as a dict, older versions the
        # namespace itself. Compare versions numerically: comparing the raw
        # strings would order e.g. "0.10.0" before "0.8.0".
        pass_args_as_dict = self._version_tuple(timmversion) >= (0, 8, 0)
        data_config = resolve_data_config(
            vars(self._args) if pass_args_as_dict else self._args,
            model=model,
            use_test_size=not is_training,
        )
        input_size = data_config["input_size"]
        recorded_batch_size = TIMM_MODELS[model_name]

        if model_name in BATCH_SIZE_DIVISORS:
            recorded_batch_size = max(
                int(recorded_batch_size / BATCH_SIZE_DIVISORS[model_name]), 1
            )
        batch_size = batch_size or recorded_batch_size

        # Fixed seed so that every run sees the same example inputs. The
        # inputs are random integers in [0, 256) normalized to zero mean and
        # unit standard deviation.
        torch.manual_seed(1337)
        input_tensor = torch.randint(
            256, size=(batch_size,) + input_size, device=device
        ).to(dtype=torch.float32)
        mean = torch.mean(input_tensor)
        std_dev = torch.std(input_tensor)
        example_inputs = (input_tensor - mean) / std_dev

        if channels_last:
            example_inputs = example_inputs.contiguous(
                memory_format=torch.channels_last
            )
        example_inputs = [
            example_inputs,
        ]
        self.target = self._gen_target(batch_size, device)

        self.loss = torch.nn.CrossEntropyLoss().to(device)

        if model_name in SCALED_COMPUTE_LOSS:
            self.compute_loss = self.scaled_compute_loss
        else:
            # Drop an override installed for a previously loaded model so that
            # this model goes back to the class's ``compute_loss``.
            self.__dict__.pop("compute_loss", None)

        if is_training and not use_eval_mode:
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)

        return device, model_name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        """Yield the names from ``TIMM_MODELS`` selected by ``args``.

        Names are visited in sorted order, restricted to the index range given
        by ``get_benchmark_indices``. A name is yielded when it matches one of
        the ``args.filter`` patterns, matches none of the ``args.exclude``
        patterns (both case-insensitive) and is neither in
        ``args.exclude_exact`` nor in ``self.skip_models``.
        """
        # The combined patterns do not depend on the model, build them once.
        include_pattern = "|".join(args.filter)
        exclude_pattern = "|".join(args.exclude)

        model_names = sorted(TIMM_MODELS.keys())
        start, end = self.get_benchmark_indices(len(model_names))
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not re.search(include_pattern, model_name, re.IGNORECASE)
                or re.search(exclude_pattern, model_name, re.IGNORECASE)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue

            yield model_name

    def pick_grad(self, name, is_training):
        """Return ``torch.enable_grad()`` for training, else ``torch.no_grad()``."""
        if is_training:
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def use_larger_multiplier_for_smaller_tensor(self, name):
        """Whether ``name`` is in ``REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR``."""
        return name in REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        """Return ``(tolerance, cosine)`` for the accuracy check of model ``name``.

        ``cosine`` is ``args.cosine`` unchanged. The tolerance is 1e-3 by
        default and 8e-2 for the freezing list when ``args.freezing`` is set.
        In training it is always replaced by 8e-2, 4e-2 or 1e-2 depending on
        which tolerance set, if any, contains ``name``. ``current_device`` is
        unused.
        """
        cosine = self.args.cosine
        tolerance = 1e-3

        if self.args.freezing and name in REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING:
            # the conv-batchnorm fusion used under freezing may cause relatively
            # large numerical difference. We need a larger tolerance.
            # Check https://github.com/pytorch/pytorch/issues/120545 for context
            tolerance = 8 * 1e-2

        if is_training:
            from torch._inductor import config as inductor_config

            if name in REQUIRE_EVEN_HIGHER_TOLERANCE or (
                inductor_config.max_autotune
                and name in REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE
            ):
                tolerance = 8 * 1e-2
            elif name in REQUIRE_HIGHER_TOLERANCE:
                tolerance = 4 * 1e-2
            else:
                tolerance = 1e-2
        return tolerance, cosine

    def _gen_target(self, batch_size, device):
        """Return ``batch_size`` random class indices in ``[0, num_classes)``."""
        return torch.empty((batch_size,), device=device, dtype=torch.long).random_(
            self.num_classes
        )

    def compute_loss(self, pred):
        """Reduce ``pred`` to a scalar loss."""
        # High loss values make gradient checking harder, as small changes in
        # accumulation order upsets accuracy checks.
        return reduce_to_scalar_loss(pred)

    def scaled_compute_loss(self, pred):
        """Reduce ``pred`` to a scalar loss and divide it by 1000."""
        # Loss values need zoom out further.
        return reduce_to_scalar_loss(pred) / 1000.0

    def forward_pass(self, mod, inputs, collect_outputs=True):
        """Run ``mod`` on ``inputs`` under autocast and return its output.

        ``collect_outputs`` is accepted but not used.
        """
        with self.autocast(**self.autocast_arg):
            return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        """Run one training step of ``mod`` on a clone of ``inputs``.

        Gradients are zeroed, the forward pass and loss are computed under
        autocast, the scaled loss is back-propagated and the optimizer is
        stepped. Returns the collected results when ``collect_outputs`` is
        true, otherwise ``None``.
        """
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            pred = mod(*cloned_inputs)
            if isinstance(pred, tuple):
                # Only the first element of a tuple output feeds the loss.
                pred = pred[0]
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None
413
414
def timm_main():
    """Silence warnings, set logging to WARNING and run the timm benchmarks."""
    warnings.filterwarnings("ignore")
    logging.basicConfig(level=logging.WARNING)
    runner = TimmRunner()
    main(runner)


if __name__ == "__main__":
    timm_main()
423