#!/usr/bin/env python3

import importlib
import logging
import os
import re
import subprocess
import sys
import warnings


try:
    from .common import (
        BenchmarkRunner,
        download_retry_decorator,
        load_yaml_file,
        main,
        reset_rng_state,
    )
except ImportError:
    from common import (
        BenchmarkRunner,
        download_retry_decorator,
        load_yaml_file,
        main,
        reset_rng_state,
    )

import torch
from torch._dynamo.testing import collect_results
from torch._dynamo.utils import clone_inputs


log = logging.getLogger(__name__)

# Enable FX graph caching, unless TORCHINDUCTOR_FX_GRAPH_CACHE is set explicitly
37if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
38    torch._inductor.config.fx_graph_cache = True
39
40
41def pip_install(package):
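    """Install the given package into the current interpreter's environment via pip."""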
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])


# Suppress the flake8 warnings triggered by the imports below. Flake8 does not
# provide a way to disable a single warning for an entire file, so disable
# flake8 entirely.
# flake8: noqa
imports = [
    "AlbertForPreTraining",
    "AutoConfig",
    "AutoModelForCausalLM",
    "AutoModelForMaskedLM",
    "AutoModelForSeq2SeqLM",
    "BigBirdConfig",
    "BlenderbotForConditionalGeneration",
    "BlenderbotModel",
    "BlenderbotSmallForConditionalGeneration",
    "BlenderbotSmallModel",
    "CLIPModel",
    "CLIPVisionModel",
    "ElectraForPreTraining",
    "GPT2ForSequenceClassification",
    "GPTJForSequenceClassification",
    "GPTNeoForSequenceClassification",
    "HubertForSequenceClassification",
    "LxmertForPreTraining",
    "LxmertForQuestionAnswering",
    "MarianForCausalLM",
    "MarianModel",
    "MarianMTModel",
    "PegasusForConditionalGeneration",
    "PegasusModel",
    "ReformerConfig",
    "ViTForImageClassification",
    "ViTForMaskedImageModeling",
    "ViTModel",
]


def process_hf_reformer_output(out):
    assert isinstance(out, list)
    # second output is unstable
    return [elem for i, elem in enumerate(out) if i != 1]


try:
    mod = importlib.import_module("transformers")
    for cls in imports:
        if not hasattr(mod, cls):
            raise ModuleNotFoundError
except ModuleNotFoundError:
    print("Installing HuggingFace Transformers...")
    pip_install("git+https://github.com/huggingface/transformers.git#egg=transformers")
finally:
    for cls in imports:
        exec(f"from transformers import {cls}")

# This dict holds the models present in huggingface_models_list.txt. It is a
# combination of models supported by the HF Fx parser and some manually
# supplied models. For each of these models, we already know the largest
# batch size that can fit on an A100 GPU (40 GB).
BATCH_SIZE_KNOWN_MODELS = {}


# TODO(sdym): use batch-size-file parameter of common.main, like torchbench.py
# Get the list of models and their batch sizes
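# Each line in that file has the form "<ModelName>,<batch_size>".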
MODELS_FILENAME = os.path.join(os.path.dirname(__file__), "huggingface_models_list.txt")
assert os.path.exists(MODELS_FILENAME)
with open(MODELS_FILENAME, "r") as fh:
    lines = fh.readlines()
    lines = [line.rstrip() for line in lines]
    for line in lines:
        model_name, batch_size = line.split(",")
        batch_size = int(batch_size)
        BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size
assert len(BATCH_SIZE_KNOWN_MODELS)


def get_module_cls_by_model_name(model_cls_name):
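    """Resolve a model class name to its class. Most classes live at the top
    level of ``transformers``; decoder-only classes must be pulled from their
    specific modeling module."""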
    _module_by_model_name = {
        "Speech2Text2Decoder": "transformers.models.speech_to_text_2.modeling_speech_to_text_2",
        "TrOCRDecoder": "transformers.models.trocr.modeling_trocr",
    }
    module_name = _module_by_model_name.get(model_cls_name, "transformers")
    module = importlib.import_module(module_name)
    return getattr(module, model_cls_name)


def get_sequence_length(model_cls, model_name):
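    """Pick a benchmark sequence length appropriate for the model family."""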
    if model_name.startswith(("Blenderbot",)):
        seq_length = 128
    elif model_name.startswith(("GPT2", "Bart", "T5", "PLBart", "MBart")):
        seq_length = 1024
    elif model_name in ("AllenaiLongformerBase", "BigBird"):
        seq_length = 1024
    elif model_name.startswith("OPT"):
        seq_length = 2048
    elif "Reformer" in model_name:
        seq_length = 4096
    elif model_name.startswith(
        (
            "Albert",
            "Deberta",
            "Layout",
            "Electra",
            "XLNet",
            "MegatronBert",
            "Bert",
            "Roberta",
        )
    ) or model_name in ("DistillGPT2", "GoogleFnet", "YituTechConvBert", "CamemBert"):
        seq_length = 512
    elif model_name == "TrOCRForCausalLM":
        seq_length = 256
    elif model_name.startswith("MobileBert"):
        seq_length = 128
    elif model_name.startswith("Wav2Vec2"):
        # If too short, will fail with something like
        # ValueError: `mask_length` has to be smaller than `sequence_length`,
        # but got `mask_length`: 10 and `sequence_length`: 9`
        seq_length = 10000  # NB: a more realistic size is 155136
    else:
        log.info(
            f"Sequence Length not defined for {model_name}. Choosing 128 arbitrarily"
        )
        seq_length = 128
    return seq_length


def generate_inputs_for_model(
    model_cls, model, model_name, bs, device, include_loss_args=False
):
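    """Construct a dict of randomized example inputs (and, optionally, labels)
    sized appropriately for the given model class, batch size, and device."""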
    # TODO - Check if following values are representative
    num_choices = 3
    num_visual_features = 42
    seq_length = get_sequence_length(model_cls, model_name)
    vocab_size = model.config.vocab_size

    if model_name.startswith("Wav2Vec2"):
        # TODO: If we add more input_values style models, try to work this
        # into the overall control flow
        target_length = 100
        return {
            "input_values": torch.randn((bs, seq_length), device=device),
            # Added because that's what the example training script has
            "attention_mask": rand_int_tensor(device, 0, 2, (bs, seq_length)),
            "labels": rand_int_tensor(device, 0, vocab_size, (bs, target_length)),
        }

    if model_name.endswith("MultipleChoice"):
        input = rand_int_tensor(device, 0, vocab_size, (bs, num_choices, seq_length))
    elif model_name.startswith("Roberta"):
        input = rand_int_tensor(device, 0, 1, (bs, seq_length))
    else:
        input = rand_int_tensor(device, 0, vocab_size, (bs, seq_length))

    if "Bart" in model_name:
        input[:, -1] = model.config.eos_token_id

    input_dict = {"input_ids": input}

    if (
        model_name.startswith("T5")
        or model_name.startswith("M2M100")
        or model_name.startswith("MT5")
        or model_cls
        in [
            BlenderbotModel,
            BlenderbotSmallModel,
            BlenderbotForConditionalGeneration,
            BlenderbotSmallForConditionalGeneration,
            PegasusModel,
            PegasusForConditionalGeneration,
            MarianModel,
            MarianMTModel,
        ]
    ):
        input_dict["decoder_input_ids"] = input

    if model_name.startswith("Lxmert"):
        visual_feat_dim, visual_pos_dim = (
            model.config.visual_feat_dim,
            model.config.visual_pos_dim,
        )
        input_dict["visual_feats"] = torch.randn(
            bs, num_visual_features, visual_feat_dim
        )
        input_dict["visual_pos"] = torch.randn(bs, num_visual_features, visual_pos_dim)

    if include_loss_args:
        if model_name.endswith("PreTraining"):
            if model_cls in [ElectraForPreTraining, LxmertForPreTraining]:
                input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs, seq_length))
            else:
                label_name = (
                    "sentence_order_label"
                    if model_cls in [AlbertForPreTraining]
                    else "next_sentence_label"
                )
                input_dict["labels"] = rand_int_tensor(
                    device, 0, vocab_size, (bs, seq_length)
                )
                input_dict[label_name] = rand_int_tensor(device, 0, 1, (bs,))
        elif model_name.endswith("QuestionAnswering"):
            input_dict["start_positions"] = rand_int_tensor(
                device, 0, seq_length, (bs,)
            )
            input_dict["end_positions"] = rand_int_tensor(device, 0, seq_length, (bs,))
        elif (
            model_name.endswith("MaskedLM")
            or model_name.endswith("HeadModel")
            or model_name.endswith("CausalLM")
            or model_name.endswith("DoubleHeadsModel")
        ):
            input_dict["labels"] = rand_int_tensor(
                device, 0, vocab_size, (bs, seq_length)
            )
        elif model_name.endswith("TokenClassification"):
            input_dict["labels"] = rand_int_tensor(
                device, 0, model.config.num_labels - 1, (bs, seq_length)
            )
        elif model_name.endswith("MultipleChoice"):
            input_dict["labels"] = rand_int_tensor(device, 0, num_choices, (bs,))
        elif model_name.endswith("SequenceClassification"):
            input_dict["labels"] = rand_int_tensor(
                device, 0, model.config.num_labels - 1, (bs,)
            )
        elif model_name.endswith("NextSentencePrediction"):
            input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs,))
        elif model_name.endswith("ForConditionalGeneration"):
            input_dict["labels"] = rand_int_tensor(
                device, 0, vocab_size - 1, (bs, seq_length)
            )
        elif model_name in EXTRA_MODELS:
            input_dict["labels"] = rand_int_tensor(
                device, 0, vocab_size, (bs, seq_length)
            )
        else:
            raise NotImplementedError(
                f"Class {model_name} unsupported for training test "
            )

    return input_dict


def rand_int_tensor(device, low, high, shape):
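    """Uniformly sample an int64 tensor in [low, high); high is exclusive,
    matching torch.randint semantics."""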
    return torch.randint(
        low,
        high,
        shape,
        device=device,
        dtype=torch.int64,
        requires_grad=False,
    )


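# Hand-picked models that are instantiated from an explicit (config, model
# class) pair instead of being resolved by class name.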
EXTRA_MODELS = {
    "AllenaiLongformerBase": (
        AutoConfig.from_pretrained("allenai/longformer-base-4096"),
        AutoModelForMaskedLM,
    ),
    "Reformer": (
        ReformerConfig(),
        AutoModelForMaskedLM,
    ),
    "T5Small": (
        AutoConfig.from_pretrained("t5-small"),
        AutoModelForSeq2SeqLM,
    ),
    # "BigBird": (
    #     BigBirdConfig(attention_type="block_sparse"),
    #     AutoModelForMaskedLM,
    # ),
    "DistillGPT2": (
        AutoConfig.from_pretrained("distilgpt2"),
        AutoModelForCausalLM,
    ),
    "GoogleFnet": (
        AutoConfig.from_pretrained("google/fnet-base"),
        AutoModelForMaskedLM,
    ),
    "YituTechConvBert": (
        AutoConfig.from_pretrained("YituTech/conv-bert-base"),
        AutoModelForMaskedLM,
    ),
    "CamemBert": (
        AutoConfig.from_pretrained("camembert-base"),
        AutoModelForMaskedLM,
    ),
}


class HuggingfaceRunner(BenchmarkRunner):
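    """BenchmarkRunner specialization for the HuggingFace (transformers) suite."""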
    def __init__(self):
        super().__init__()
        self.suite_name = "huggingface"

    @property
    def _config(self):
        return load_yaml_file("huggingface.yaml")

    @property
    def _skip(self):
        return self._config["skip"]

    @property
    def _accuracy(self):
        return self._config["accuracy"]

    @property
    def skip_models(self):
        return self._skip["all"]

    @property
    def skip_models_for_cpu(self):
        return self._skip["device"]["cpu"]

    @property
    def fp32_only_models(self):
        return self._config["only_fp32"]

    @property
    def skip_models_due_to_control_flow(self):
        return self._skip["control_flow"]

    def _get_model_cls_and_config(self, model_name):
        if model_name not in EXTRA_MODELS:
            model_cls = get_module_cls_by_model_name(model_name)
            config_cls = model_cls.config_class
            config = config_cls()

            # NB: some models need a pad token defined to handle BS > 1
            if (
                model_cls
                in [
                    GPT2ForSequenceClassification,
                    GPTNeoForSequenceClassification,
                    GPTJForSequenceClassification,
                ]
                or model_cls.__name__.startswith("Roberta")
                or model_cls.__name__.startswith("Marian")
            ):
                config.pad_token_id = 0

        else:
            config, model_cls = EXTRA_MODELS[model_name]

        return model_cls, config

    @download_retry_decorator
    def _download_model(self, model_name):
        model_cls, config = self._get_model_cls_and_config(model_name)
        if "auto" in model_cls.__module__:
            # Handle auto classes
            model = model_cls.from_config(config)
        else:
            model = model_cls(config)
        return model

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
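        """Construct the model and example inputs, choosing a batch size from
        BATCH_SIZE_KNOWN_MODELS unless one is passed in explicitly."""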
        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode
        dtype = torch.float32
        reset_rng_state()
        model_cls, config = self._get_model_cls_and_config(model_name)
        model = self._download_model(model_name)
        model = model.to(device, dtype=dtype)
        if self.args.enable_activation_checkpointing:
            model.gradient_checkpointing_enable()
        if model_name in BATCH_SIZE_KNOWN_MODELS:
            batch_size_default = BATCH_SIZE_KNOWN_MODELS[model_name]
        elif batch_size is None:
            batch_size_default = 16
            log.info(
                f"Batch size not specified for {model_name}. Setting batch_size=16"
            )

        if batch_size is None:
            batch_size = batch_size_default
            batch_size_divisors = self._config["batch_size"]["divisors"]
            if model_name in batch_size_divisors:
                batch_size = max(int(batch_size / batch_size_divisors[model_name]), 1)
                log.info(
                    f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}"
                )

        example_inputs = generate_inputs_for_model(
            model_cls, model, model_name, batch_size, device, include_loss_args=True
        )

        # Use a tiny nonzero dropout probability so we can check for correct
        # gradients without eliminating the dropout computation from the graph
        for attr in dir(config):
            if "drop" in attr and isinstance(getattr(config, attr), float):
                setattr(config, attr, 1e-30)

        if (
            is_training
            and not use_eval_mode
            and not (
                self.args.accuracy and model_name in self._config["only_inference"]
            )
        ):
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)
        return device, model_name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        model_names = list(BATCH_SIZE_KNOWN_MODELS.keys()) + list(EXTRA_MODELS.keys())
        model_names = set(model_names)
        model_names = sorted(model_names)

        start, end = self.get_benchmark_indices(len(model_names))
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not re.search("|".join(args.filter), model_name, re.I)
                or re.search("|".join(args.exclude), model_name, re.I)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue
            yield model_name

    @property
    def skip_accuracy_checks_large_models_dashboard(self):
        if self.args.dashboard or self.args.accuracy:
            return self._accuracy["skip"]["large_models"]
        return set()

    @property
    def get_output_amp_train_process_func(self):
        return {}

    def pick_grad(self, name, is_training):
        if is_training:
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        cosine = self.args.cosine
        if is_training:
            from torch._inductor import config as inductor_config

            if (name in self._config["tolerance"]["higher_training"]) or (
                inductor_config.max_autotune
                and name in self._config["tolerance"]["higher_max_autotune_training"]
            ):
                return 2e-2, cosine
            else:
                return 1e-2, cosine
        else:
            if name in self._config["tolerance"]["higher_inference"]:
                return 4e-3, cosine
            if (
                current_device == "cpu"
                and name in self._config["tolerance"]["higher_inference_cpu"]
            ):
                return 4e-3, cosine
        return 1e-3, cosine

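    # When labels are supplied, HF models return the loss as the first element
    # of the model output.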
    def compute_loss(self, pred):
        return pred[0]

    def forward_pass(self, mod, inputs, collect_outputs=True):
        with self.autocast(**self.autocast_arg):
            return mod(**inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
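        # Clone the inputs so that any in-place mutation by the model cannot
        # corrupt the shared example inputs across runs.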
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            pred = mod(**cloned_inputs)
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None


def refresh_model_names_and_batch_sizes():
    """
    Read the models supported by the HF Fx tracer and find the largest batch
    size for each that fits on the GPU with PyTorch eager.

    The resulting data is written to huggingface_models_list.txt.

    Note - we only need to run this function if we believe the HF Fx tracer now
    supports more models.
    """
    import transformers.utils.fx as hf_fx

    family = {}
    lm_seen = set()
    family_seen = set()
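    # Selection strategy: keep at most one *MaskedLM/*CausalLM variant per
    # model family, at most one other task head per family, plus every image
    # classification model.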
    for cls_name in hf_fx._SUPPORTED_MODELS:
        if "For" not in cls_name:
            continue

        model_cls = get_module_cls_by_model_name(cls_name)

        # TODO: AttributeError: '*Config' object has no attribute 'vocab_size'
        if model_cls in [
            CLIPModel,
            CLIPVisionModel,
            # SwinForImageClassification,
            # SwinForMaskedImageModeling,
            # SwinModel,
            ViTForImageClassification,
            ViTForMaskedImageModeling,
            ViTModel,
        ]:
            continue

        # TODO: AssertionError: Padding_idx must be within num_embeddings
        if model_cls in [MarianForCausalLM, MarianMTModel, MarianModel]:
            continue

        # TODO: "model is not supported yet" from HFTracer
        if model_cls in [HubertForSequenceClassification]:
            continue

        # TODO: shape mismatch in loss calculation
        if model_cls in [LxmertForQuestionAnswering]:
            continue

        family_name = cls_name.split("For")[0]
        if family_name not in family:
            family[family_name] = []
        if cls_name.endswith(("MaskedLM", "CausalLM")) and family_name not in lm_seen:
            family[family_name].append(cls_name)
            lm_seen.add(family_name)
        elif (
            cls_name.endswith(
                ("SequenceClassification", "ConditionalGeneration", "QuestionAnswering")
            )
            and family_name not in family_seen
        ):
            family[family_name].append(cls_name)
            family_seen.add(family_name)
        elif cls_name.endswith("ImageClassification"):
            family[family_name].append(cls_name)

    chosen_models = set()
    for members in family.values():
        chosen_models.update(set(members))

    # Add the EXTRA_MODELS
    chosen_models.update(set(EXTRA_MODELS.keys()))

    for model_name in sorted(chosen_models):
        try:
            subprocess.check_call(
                [sys.executable]
                + sys.argv
                + ["--find-batch-sizes"]
                + [f"--only={model_name}"]
                + [f"--output={MODELS_FILENAME}"]
            )
        except subprocess.SubprocessError:
            log.warning(f"Failed to find suitable batch size for {model_name}")


def huggingface_main():
    # Code to refresh model names and batch sizes
    # if "--find-batch-sizes" not in sys.argv:
    #     refresh_model_names_and_batch_sizes()
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    main(HuggingfaceRunner())


if __name__ == "__main__":
    huggingface_main()