#!/usr/bin/env python3

import importlib
import logging
import os
import re
import subprocess
import sys
import warnings


try:
    from .common import (
        BenchmarkRunner,
        download_retry_decorator,
        load_yaml_file,
        main,
        reset_rng_state,
    )
except ImportError:
    from common import (
        BenchmarkRunner,
        download_retry_decorator,
        load_yaml_file,
        main,
        reset_rng_state,
    )

import torch
from torch._dynamo.testing import collect_results
from torch._dynamo.utils import clone_inputs


log = logging.getLogger(__name__)

# Enable FX graph caching
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
    torch._inductor.config.fx_graph_cache = True


def pip_install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])


# Disable the flake warnings for the imports. Flake8 does not provide a way to
# disable just a warning for the entire file. Disabling flake8 entirely.
# flake8: noqa
imports = [
    "AlbertForPreTraining",
    "AutoConfig",
    "AutoModelForCausalLM",
    "AutoModelForMaskedLM",
    "AutoModelForSeq2SeqLM",
    "BigBirdConfig",
    "BlenderbotForConditionalGeneration",
    "BlenderbotModel",
    "BlenderbotSmallForConditionalGeneration",
    "BlenderbotSmallModel",
    "CLIPModel",
    "CLIPVisionModel",
    "ElectraForPreTraining",
    "GPT2ForSequenceClassification",
    "GPTJForSequenceClassification",
    "GPTNeoForSequenceClassification",
    "HubertForSequenceClassification",
    "LxmertForPreTraining",
    "LxmertForQuestionAnswering",
    "MarianForCausalLM",
    "MarianModel",
    "MarianMTModel",
    "PegasusForConditionalGeneration",
    "PegasusModel",
    "ReformerConfig",
    "ViTForImageClassification",
    "ViTForMaskedImageModeling",
    "ViTModel",
]


def process_hf_reformer_output(out):
    assert isinstance(out, list)
    # second output is unstable
    return [elem for i, elem in enumerate(out) if i != 1]


try:
    mod = importlib.import_module("transformers")
    for cls in imports:
        if not hasattr(mod, cls):
            raise ModuleNotFoundError
except ModuleNotFoundError:
    print("Installing HuggingFace Transformers...")
    pip_install("git+https://github.com/huggingface/transformers.git#egg=transformers")
finally:
    for cls in imports:
        exec(f"from transformers import {cls}")


# BATCH_SIZE_KNOWN_MODELS holds the models present in huggingface_models_list.txt.
# It is a combination of models supported by the HF Fx parser and some manually
# supplied models. For these models, we already know the largest batch size that
# can fit on an A100 GPU (40 GB).
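# Each line of huggingface_models_list.txt is expected to be
# "<ModelClassName>,<batch_size>" (comma-separated; see the parsing loop below).
# For example, a line might read "BertForMaskedLM,16" (illustrative values only).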
BATCH_SIZE_KNOWN_MODELS = {}


# TODO(sdym): use batch-size-file parameter of common.main, like torchbench.py
# Get the list of models and their batch sizes
MODELS_FILENAME = os.path.join(os.path.dirname(__file__), "huggingface_models_list.txt")
assert os.path.exists(MODELS_FILENAME)
with open(MODELS_FILENAME, "r") as fh:
    lines = fh.readlines()
    lines = [line.rstrip() for line in lines]
    for line in lines:
        model_name, batch_size = line.split(",")
        batch_size = int(batch_size)
        BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size
assert len(BATCH_SIZE_KNOWN_MODELS)


def get_module_cls_by_model_name(model_cls_name):
    _module_by_model_name = {
        "Speech2Text2Decoder": "transformers.models.speech_to_text_2.modeling_speech_to_text_2",
        "TrOCRDecoder": "transformers.models.trocr.modeling_trocr",
    }
    module_name = _module_by_model_name.get(model_cls_name, "transformers")
    module = importlib.import_module(module_name)
    return getattr(module, model_cls_name)


def get_sequence_length(model_cls, model_name):
    if model_name.startswith(("Blenderbot",)):
        seq_length = 128
    elif model_name.startswith(("GPT2", "Bart", "T5", "PLBart", "MBart")):
        seq_length = 1024
    elif model_name in ("AllenaiLongformerBase", "BigBird"):
        seq_length = 1024
    elif model_name.startswith("OPT"):
        seq_length = 2048
    elif "Reformer" in model_name:
        seq_length = 4096
    elif model_name.startswith(
        (
            "Albert",
            "Deberta",
            "Layout",
            "Electra",
            "XLNet",
            "MegatronBert",
            "Bert",
            "Roberta",
        )
    ) or model_name in ("DistillGPT2", "GoogleFnet", "YituTechConvBert", "CamemBert"):
        seq_length = 512
    elif model_name in ("TrOCRForCausalLM",):
        seq_length = 256
    elif model_name.startswith("MobileBert"):
        seq_length = 128
    elif model_name.startswith("Wav2Vec2"):
        # If too short, will fail with something like
        # ValueError: `mask_length` has to be smaller than `sequence_length`,
        # but got `mask_length`: 10 and `sequence_length`: 9`
        seq_length = 10000  # NB: a more realistic size is 155136
    else:
        log.info(
            f"Sequence Length not defined for {model_name}. Choosing 128 arbitrarily"
        )
        seq_length = 128
    return seq_length

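# Note: the inputs generated below are synthetic - random token ids (or random
# feature tensors for Wav2Vec2/Lxmert) shaped by get_sequence_length(); labels
# and other loss arguments are only attached when include_loss_args=True.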
def generate_inputs_for_model(
    model_cls, model, model_name, bs, device, include_loss_args=False
):
    # TODO - Check if following values are representative
    num_choices = 3
    num_visual_features = 42
    seq_length = get_sequence_length(model_cls, model_name)
    vocab_size = model.config.vocab_size

    if model_name.startswith("Wav2Vec2"):
        # TODO: If we add more input_values style models, try to work this
        # into the overall control flow
        target_length = 100
        return {
            "input_values": torch.randn((bs, seq_length), device=device),
            # Added because that's what the example training script has
            "attention_mask": rand_int_tensor(device, 0, 2, (bs, seq_length)),
            "labels": rand_int_tensor(device, 0, vocab_size, (bs, target_length)),
        }

    if model_name.endswith("MultipleChoice"):
        input = rand_int_tensor(device, 0, vocab_size, (bs, num_choices, seq_length))
    elif model_name.startswith("Roberta"):
        input = rand_int_tensor(device, 0, 1, (bs, seq_length))
    else:
        input = rand_int_tensor(device, 0, vocab_size, (bs, seq_length))

    if "Bart" in model_name:
        input[:, -1] = model.config.eos_token_id

    input_dict = {"input_ids": input}

    if (
        model_name.startswith("T5")
        or model_name.startswith("M2M100")
        or model_name.startswith("MT5")
        or model_cls
        in [
            BlenderbotModel,
            BlenderbotSmallModel,
            BlenderbotForConditionalGeneration,
            BlenderbotSmallForConditionalGeneration,
            PegasusModel,
            PegasusForConditionalGeneration,
            MarianModel,
            MarianMTModel,
        ]
    ):
        input_dict["decoder_input_ids"] = input

    if model_name.startswith("Lxmert"):
        visual_feat_dim, visual_pos_dim = (
            model.config.visual_feat_dim,
            model.config.visual_pos_dim,
        )
        input_dict["visual_feats"] = torch.randn(
            bs, num_visual_features, visual_feat_dim
        )
        input_dict["visual_pos"] = torch.randn(bs, num_visual_features, visual_pos_dim)

    if include_loss_args:
        if model_name.endswith("PreTraining"):
            if model_cls in [ElectraForPreTraining, LxmertForPreTraining]:
                input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs, seq_length))
            else:
                label_name = (
                    "sentence_order_label"
                    if model_cls in [AlbertForPreTraining]
                    else "next_sentence_label"
                )
                input_dict["labels"] = rand_int_tensor(
                    device, 0, vocab_size, (bs, seq_length)
                )
                input_dict[label_name] = rand_int_tensor(device, 0, 1, (bs,))
        elif model_name.endswith("QuestionAnswering"):
            input_dict["start_positions"] = rand_int_tensor(
                device, 0, seq_length, (bs,)
            )
            input_dict["end_positions"] = rand_int_tensor(device, 0, seq_length, (bs,))
        elif (
            model_name.endswith("MaskedLM")
            or model_name.endswith("HeadModel")
            or model_name.endswith("CausalLM")
            or model_name.endswith("DoubleHeadsModel")
        ):
            input_dict["labels"] = rand_int_tensor(
                device, 0, vocab_size, (bs, seq_length)
            )
        elif model_name.endswith("TokenClassification"):
            input_dict["labels"] = rand_int_tensor(
                device, 0, model.config.num_labels - 1, (bs, seq_length)
            )
        elif model_name.endswith("MultipleChoice"):
            input_dict["labels"] = rand_int_tensor(device, 0, num_choices, (bs,))
        elif model_name.endswith("SequenceClassification"):
            input_dict["labels"] = rand_int_tensor(
                device, 0, model.config.num_labels - 1, (bs,)
            )
        elif model_name.endswith("NextSentencePrediction"):
            input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs,))
        elif model_name.endswith("ForConditionalGeneration"):
            input_dict["labels"] = rand_int_tensor(
                device, 0, vocab_size - 1, (bs, seq_length)
            )
        elif model_name in EXTRA_MODELS:
            input_dict["labels"] = rand_int_tensor(
                device, 0, vocab_size, (bs, seq_length)
            )
        else:
            raise NotImplementedError(
                f"Class {model_name} unsupported for training test "
            )

    return input_dict


def rand_int_tensor(device, low, high, shape):
    return torch.randint(
        low,
        high,
        shape,
        device=device,
        dtype=torch.int64,
        requires_grad=False,
    )


EXTRA_MODELS = {
    "AllenaiLongformerBase": (
        AutoConfig.from_pretrained("allenai/longformer-base-4096"),
        AutoModelForMaskedLM,
    ),
    "Reformer": (
        ReformerConfig(),
        AutoModelForMaskedLM,
    ),
    "T5Small": (
        AutoConfig.from_pretrained("t5-small"),
        AutoModelForSeq2SeqLM,
    ),
    # "BigBird": (
    #     BigBirdConfig(attention_type="block_sparse"),
    #     AutoModelForMaskedLM,
    # ),
    "DistillGPT2": (
        AutoConfig.from_pretrained("distilgpt2"),
        AutoModelForCausalLM,
    ),
    "GoogleFnet": (
        AutoConfig.from_pretrained("google/fnet-base"),
        AutoModelForMaskedLM,
    ),
    "YituTechConvBert": (
        AutoConfig.from_pretrained("YituTech/conv-bert-base"),
        AutoModelForMaskedLM,
    ),
    "CamemBert": (
        AutoConfig.from_pretrained("camembert-base"),
        AutoModelForMaskedLM,
    ),
}


class HuggingfaceRunner(BenchmarkRunner):
    def __init__(self):
        super().__init__()
        self.suite_name = "huggingface"

    @property
    def _config(self):
        return load_yaml_file("huggingface.yaml")

    @property
    def _skip(self):
        return self._config["skip"]

    @property
    def _accuracy(self):
        return self._config["accuracy"]

    @property
    def skip_models(self):
        return self._skip["all"]

    @property
    def skip_models_for_cpu(self):
        return self._skip["device"]["cpu"]

    @property
    def fp32_only_models(self):
        return self._config["only_fp32"]

    @property
    def skip_models_due_to_control_flow(self):
        return self._skip["control_flow"]

    def _get_model_cls_and_config(self, model_name):
        if model_name not in EXTRA_MODELS:
            model_cls = get_module_cls_by_model_name(model_name)
            config_cls = model_cls.config_class
            config = config_cls()

            # NB: some models need a pad token defined to handle BS > 1
            if (
                model_cls
                in [
                    GPT2ForSequenceClassification,
                    GPTNeoForSequenceClassification,
                    GPTJForSequenceClassification,
                ]
                or model_cls.__name__.startswith("Roberta")
                or model_cls.__name__.startswith("Marian")
            ):
                config.pad_token_id = 0

        else:
            config, model_cls = EXTRA_MODELS[model_name]

        return model_cls, config

    @download_retry_decorator
    def _download_model(self, model_name):
        model_cls, config = self._get_model_cls_and_config(model_name)
        if "auto" in model_cls.__module__:
            # Handle auto classes
            model = model_cls.from_config(config)
        else:
            model = model_cls(config)
        return model

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode
        dtype = torch.float32
        reset_rng_state()
        model_cls, config = self._get_model_cls_and_config(model_name)
        model = self._download_model(model_name)
        model = model.to(device, dtype=dtype)
        if self.args.enable_activation_checkpointing:
            model.gradient_checkpointing_enable()
        if model_name in BATCH_SIZE_KNOWN_MODELS:
            batch_size_default = BATCH_SIZE_KNOWN_MODELS[model_name]
        elif batch_size is None:
            batch_size_default = 16
            log.info(
                f"Batch size not specified for {model_name}. Setting batch_size=16"
            )

        if batch_size is None:
            batch_size = batch_size_default
            batch_size_divisors = self._config["batch_size"]["divisors"]
            if model_name in batch_size_divisors:
                batch_size = max(int(batch_size / batch_size_divisors[model_name]), 1)
                log.info(
                    f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}"
                )

        example_inputs = generate_inputs_for_model(
            model_cls, model, model_name, batch_size, device, include_loss_args=True
        )

        # So we can check for correct gradients without eliminating the dropout computation
        for attr in dir(config):
            if "drop" in attr and isinstance(getattr(config, attr), float):
                setattr(config, attr, 1e-30)

        if (
            is_training
            and not use_eval_mode
            and not (
                self.args.accuracy and model_name in self._config["only_inference"]
            )
        ):
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)
        return device, model_name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        model_names = list(BATCH_SIZE_KNOWN_MODELS.keys()) + list(EXTRA_MODELS.keys())
        model_names = set(model_names)
        model_names = sorted(model_names)

        start, end = self.get_benchmark_indices(len(model_names))
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not re.search("|".join(args.filter), model_name, re.I)
                or re.search("|".join(args.exclude), model_name, re.I)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue
            yield model_name

    @property
    def skip_accuracy_checks_large_models_dashboard(self):
        if self.args.dashboard or self.args.accuracy:
            return self._accuracy["skip"]["large_models"]
        return set()

    @property
    def get_output_amp_train_process_func(self):
        return {}

    def pick_grad(self, name, is_training):
        if is_training:
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        cosine = self.args.cosine
        if is_training:
            from torch._inductor import config as inductor_config

            if (name in self._config["tolerance"]["higher_training"]) or (
                inductor_config.max_autotune
                and name in self._config["tolerance"]["higher_max_autotune_training"]
            ):
                return 2e-2, cosine
            else:
                return 1e-2, cosine
        else:
            if name in self._config["tolerance"]["higher_inference"]:
                return 4e-3, cosine
            if (
                current_device == "cpu"
                and name in self._config["tolerance"]["higher_inference_cpu"]
            ):
                return 4e-3, cosine
            return 1e-3, cosine

    def compute_loss(self, pred):
        return pred[0]

    def forward_pass(self, mod, inputs, collect_outputs=True):
        with self.autocast(**self.autocast_arg):
            return mod(**inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        cloned_inputs = clone_inputs(inputs)
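        # Training step: zero grads, run the forward pass under autocast,
        # scale the loss with the runner's grad_scaler, backward, then step
        # the optimizer (these helpers come from the BenchmarkRunner base class).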
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            pred = mod(**cloned_inputs)
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None


def refresh_model_names_and_batch_sizes():
    """
    This function reads the models supported by the HF Fx tracer and finds the
    largest batch size that fits on the GPU with PyTorch eager.

    The resulting data is written to huggingface_models_list.txt.

    Note - We only need to run this function if we believe that the HF Fx tracer
    now supports more models.
    """
    import transformers.utils.fx as hf_fx

    family = {}
    lm_seen = set()
    family_seen = set()
    for cls_name in hf_fx._SUPPORTED_MODELS:
        if "For" not in cls_name:
            continue

        model_cls = get_module_cls_by_model_name(cls_name)

        # TODO: AttributeError: '*Config' object has no attribute 'vocab_size'
        if model_cls in [
            CLIPModel,
            CLIPVisionModel,
            # SwinForImageClassification,
            # SwinForMaskedImageModeling,
            # SwinModel,
            ViTForImageClassification,
            ViTForMaskedImageModeling,
            ViTModel,
        ]:
            continue

        # TODO: AssertionError: Padding_idx must be within num_embeddings
        if model_cls in [MarianForCausalLM, MarianMTModel, MarianModel]:
            continue

        # TODO: "model is not supported yet" from HFTracer
        if model_cls in [HubertForSequenceClassification]:
            continue

        # TODO: shape mismatch in loss calculation
        if model_cls in [LxmertForQuestionAnswering]:
            continue

        family_name = cls_name.split("For")[0]
        if family_name not in family:
            family[family_name] = []
        if cls_name.endswith(("MaskedLM", "CausalLM")) and family_name not in lm_seen:
            family[family_name].append(cls_name)
            lm_seen.add(family_name)
        elif (
            cls_name.endswith(
                ("SequenceClassification", "ConditionalGeneration", "QuestionAnswering")
            )
            and family_name not in family_seen
        ):
            family[family_name].append(cls_name)
            family_seen.add(family_name)
        elif cls_name.endswith("ImageClassification"):
            family[family_name].append(cls_name)

    chosen_models = set()
    for members in family.values():
        chosen_models.update(set(members))

    # Add the EXTRA_MODELS
    chosen_models.update(set(EXTRA_MODELS.keys()))

    for model_name in sorted(chosen_models):
        try:
            subprocess.check_call(
                [sys.executable]
                + sys.argv
                + ["--find-batch-sizes"]
                + [f"--only={model_name}"]
                + [f"--output={MODELS_FILENAME}"]
            )
        except subprocess.SubprocessError:
            log.warning(f"Failed to find suitable batch size for {model_name}")


def huggingface_main():
    # Code to refresh model names and batch sizes
    # if "--find-batch-sizes" not in sys.argv:
    #     refresh_model_names_and_batch_sizes()
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    main(HuggingfaceRunner())


if __name__ == "__main__":
    huggingface_main()