1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport gc 4*da0073e9SAndroid Build Coastguard Workerimport importlib 5*da0073e9SAndroid Build Coastguard Workerimport logging 6*da0073e9SAndroid Build Coastguard Workerimport os 7*da0073e9SAndroid Build Coastguard Workerimport re 8*da0073e9SAndroid Build Coastguard Workerimport sys 9*da0073e9SAndroid Build Coastguard Workerimport warnings 10*da0073e9SAndroid Build Coastguard Workerfrom collections import namedtuple 11*da0073e9SAndroid Build Coastguard Workerfrom os.path import abspath, exists 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Workerimport torch 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Workertry: 17*da0073e9SAndroid Build Coastguard Worker from .common import BenchmarkRunner, load_yaml_file, main 18*da0073e9SAndroid Build Coastguard Workerexcept ImportError: 19*da0073e9SAndroid Build Coastguard Worker from common import BenchmarkRunner, load_yaml_file, main 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import collect_results, reduce_to_scalar_loss 22*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.utils import clone_inputs 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker# We are primarily interested in tf32 datatype 26*da0073e9SAndroid Build Coastguard Workertorch.backends.cuda.matmul.allow_tf32 = True 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker# Enable FX graph caching 29*da0073e9SAndroid Build Coastguard Workerif "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ: 30*da0073e9SAndroid Build Coastguard Worker torch._inductor.config.fx_graph_cache = True 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Workerdef _reassign_parameters(model): 34*da0073e9SAndroid Build Coastguard Worker # torch_geometric models register parameter as tensors due to 35*da0073e9SAndroid Build Coastguard Worker # https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/dense/linear.py#L158-L168 36*da0073e9SAndroid Build Coastguard Worker # Since it is unusual thing to do, we just reassign them to parameters 37*da0073e9SAndroid Build Coastguard Worker def state_dict_hook(module, destination, prefix, local_metadata): 38*da0073e9SAndroid Build Coastguard Worker for name, param in module.named_parameters(): 39*da0073e9SAndroid Build Coastguard Worker if isinstance(destination[name], torch.Tensor) and not isinstance( 40*da0073e9SAndroid Build Coastguard Worker destination[name], torch.nn.Parameter 41*da0073e9SAndroid Build Coastguard Worker ): 42*da0073e9SAndroid Build Coastguard Worker destination[name] = torch.nn.Parameter(destination[name]) 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker model._register_state_dict_hook(state_dict_hook) 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Workerdef setup_torchbench_cwd(): 48*da0073e9SAndroid Build Coastguard Worker original_dir = abspath(os.getcwd()) 49*da0073e9SAndroid Build Coastguard Worker 50*da0073e9SAndroid Build Coastguard Worker os.environ["KALDI_ROOT"] = "/tmp" # avoids some spam 51*da0073e9SAndroid Build Coastguard Worker for torchbench_dir in ( 52*da0073e9SAndroid Build Coastguard Worker "./torchbenchmark", 53*da0073e9SAndroid Build Coastguard Worker "../torchbenchmark", 54*da0073e9SAndroid Build Coastguard Worker "../torchbench", 55*da0073e9SAndroid Build Coastguard Worker "../benchmark", 56*da0073e9SAndroid Build Coastguard Worker "../../torchbenchmark", 57*da0073e9SAndroid Build Coastguard Worker "../../torchbench", 58*da0073e9SAndroid Build Coastguard Worker "../../benchmark", 59*da0073e9SAndroid Build Coastguard Worker "../../../torchbenchmark", 60*da0073e9SAndroid Build Coastguard Worker "../../../torchbench", 61*da0073e9SAndroid Build Coastguard Worker "../../../benchmark", 62*da0073e9SAndroid Build Coastguard Worker ): 63*da0073e9SAndroid Build Coastguard Worker if exists(torchbench_dir): 64*da0073e9SAndroid Build Coastguard Worker break 65*da0073e9SAndroid Build Coastguard Worker 66*da0073e9SAndroid Build Coastguard Worker if exists(torchbench_dir): 67*da0073e9SAndroid Build Coastguard Worker torchbench_dir = abspath(torchbench_dir) 68*da0073e9SAndroid Build Coastguard Worker os.chdir(torchbench_dir) 69*da0073e9SAndroid Build Coastguard Worker sys.path.append(torchbench_dir) 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Worker return original_dir 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Workerdef process_hf_reformer_output(out): 75*da0073e9SAndroid Build Coastguard Worker assert isinstance(out, list) 76*da0073e9SAndroid Build Coastguard Worker # second output is unstable 77*da0073e9SAndroid Build Coastguard Worker return [elem for i, elem in enumerate(out) if i != 1] 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Workerdef process_hf_whisper_output(out): 81*da0073e9SAndroid Build Coastguard Worker out_ret = [] 82*da0073e9SAndroid Build Coastguard Worker for i, elem in enumerate(out): 83*da0073e9SAndroid Build Coastguard Worker if i == 0: 84*da0073e9SAndroid Build Coastguard Worker assert isinstance(elem, dict) 85*da0073e9SAndroid Build Coastguard Worker out_ret.append({k: v for k, v in elem.items() if k != "logits"}) 86*da0073e9SAndroid Build Coastguard Worker elif i != 1: 87*da0073e9SAndroid Build Coastguard Worker out_ret.append(elem) 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Worker return out_ret 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker 92*da0073e9SAndroid Build Coastguard Workerprocess_train_model_output = { 93*da0073e9SAndroid Build Coastguard Worker "hf_Reformer": process_hf_reformer_output, 94*da0073e9SAndroid Build Coastguard Worker "hf_Whisper": process_hf_whisper_output, 95*da0073e9SAndroid Build Coastguard Worker} 96*da0073e9SAndroid Build Coastguard Worker 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Workerclass TorchBenchmarkRunner(BenchmarkRunner): 99*da0073e9SAndroid Build Coastguard Worker def __init__(self): 100*da0073e9SAndroid Build Coastguard Worker super().__init__() 101*da0073e9SAndroid Build Coastguard Worker self.suite_name = "torchbench" 102*da0073e9SAndroid Build Coastguard Worker self.optimizer = None 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker @property 105*da0073e9SAndroid Build Coastguard Worker def _config(self): 106*da0073e9SAndroid Build Coastguard Worker return load_yaml_file("torchbench.yaml") 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker @property 109*da0073e9SAndroid Build Coastguard Worker def _skip(self): 110*da0073e9SAndroid Build Coastguard Worker return self._config["skip"] 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Worker @property 113*da0073e9SAndroid Build Coastguard Worker def _batch_size(self): 114*da0073e9SAndroid Build Coastguard Worker return self._config["batch_size"] 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker @property 117*da0073e9SAndroid Build Coastguard Worker def _tolerance(self): 118*da0073e9SAndroid Build Coastguard Worker return self._config["tolerance"] 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker @property 121*da0073e9SAndroid Build Coastguard Worker def _require_larger_multiplier_for_smaller_tensor(self): 122*da0073e9SAndroid Build Coastguard Worker return self._config["require_larger_multiplier_for_smaller_tensor"] 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker @property 125*da0073e9SAndroid Build Coastguard Worker def _accuracy(self): 126*da0073e9SAndroid Build Coastguard Worker return self._config["accuracy"] 127*da0073e9SAndroid Build Coastguard Worker 128*da0073e9SAndroid Build Coastguard Worker @property 129*da0073e9SAndroid Build Coastguard Worker def skip_models(self): 130*da0073e9SAndroid Build Coastguard Worker return self._skip["all"] 131*da0073e9SAndroid Build Coastguard Worker 132*da0073e9SAndroid Build Coastguard Worker @property 133*da0073e9SAndroid Build Coastguard Worker def skip_models_for_cpu(self): 134*da0073e9SAndroid Build Coastguard Worker return self._skip["device"]["cpu"] 135*da0073e9SAndroid Build Coastguard Worker 136*da0073e9SAndroid Build Coastguard Worker @property 137*da0073e9SAndroid Build Coastguard Worker def skip_models_for_cuda(self): 138*da0073e9SAndroid Build Coastguard Worker return self._skip["device"]["cuda"] 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker @property 141*da0073e9SAndroid Build Coastguard Worker def skip_models_for_freezing_cuda(self): 142*da0073e9SAndroid Build Coastguard Worker return self._skip["freezing"]["cuda"] 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker @property 145*da0073e9SAndroid Build Coastguard Worker def skip_models_for_freezing_cpu(self): 146*da0073e9SAndroid Build Coastguard Worker return self._skip["freezing"]["cpu"] 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker @property 149*da0073e9SAndroid Build Coastguard Worker def slow_models(self): 150*da0073e9SAndroid Build Coastguard Worker return self._config["slow"] 151*da0073e9SAndroid Build Coastguard Worker 152*da0073e9SAndroid Build Coastguard Worker @property 153*da0073e9SAndroid Build Coastguard Worker def very_slow_models(self): 154*da0073e9SAndroid Build Coastguard Worker return self._config["very_slow"] 155*da0073e9SAndroid Build Coastguard Worker 156*da0073e9SAndroid Build Coastguard Worker @property 157*da0073e9SAndroid Build Coastguard Worker def non_deterministic_models(self): 158*da0073e9SAndroid Build Coastguard Worker return self._config["non_deterministic"] 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker @property 161*da0073e9SAndroid Build Coastguard Worker def get_output_amp_train_process_func(self): 162*da0073e9SAndroid Build Coastguard Worker return process_train_model_output 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Worker @property 165*da0073e9SAndroid Build Coastguard Worker def skip_not_suitable_for_training_models(self): 166*da0073e9SAndroid Build Coastguard Worker return self._skip["test"]["training"] 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker @property 169*da0073e9SAndroid Build Coastguard Worker def failing_fx2trt_models(self): 170*da0073e9SAndroid Build Coastguard Worker return self._config["trt_not_yet_working"] 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker @property 173*da0073e9SAndroid Build Coastguard Worker def force_amp_for_fp16_bf16_models(self): 174*da0073e9SAndroid Build Coastguard Worker return self._config["dtype"]["force_amp_for_fp16_bf16_models"] 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard Worker @property 177*da0073e9SAndroid Build Coastguard Worker def force_fp16_for_bf16_models(self): 178*da0073e9SAndroid Build Coastguard Worker return self._config["dtype"]["force_fp16_for_bf16_models"] 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Worker @property 181*da0073e9SAndroid Build Coastguard Worker def skip_accuracy_checks_large_models_dashboard(self): 182*da0073e9SAndroid Build Coastguard Worker if self.args.dashboard or self.args.accuracy: 183*da0073e9SAndroid Build Coastguard Worker return self._accuracy["skip"]["large_models"] 184*da0073e9SAndroid Build Coastguard Worker return set() 185*da0073e9SAndroid Build Coastguard Worker 186*da0073e9SAndroid Build Coastguard Worker @property 187*da0073e9SAndroid Build Coastguard Worker def skip_accuracy_check_as_eager_non_deterministic(self): 188*da0073e9SAndroid Build Coastguard Worker if self.args.accuracy and self.args.training: 189*da0073e9SAndroid Build Coastguard Worker return self._accuracy["skip"]["eager_not_deterministic"] 190*da0073e9SAndroid Build Coastguard Worker return set() 191*da0073e9SAndroid Build Coastguard Worker 192*da0073e9SAndroid Build Coastguard Worker @property 193*da0073e9SAndroid Build Coastguard Worker def skip_multiprocess_models(self): 194*da0073e9SAndroid Build Coastguard Worker return self._skip["multiprocess"] 195*da0073e9SAndroid Build Coastguard Worker 196*da0073e9SAndroid Build Coastguard Worker @property 197*da0073e9SAndroid Build Coastguard Worker def skip_models_due_to_control_flow(self): 198*da0073e9SAndroid Build Coastguard Worker return self._skip["control_flow"] 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Worker @property 201*da0073e9SAndroid Build Coastguard Worker def guard_on_nn_module_models(self): 202*da0073e9SAndroid Build Coastguard Worker return { 203*da0073e9SAndroid Build Coastguard Worker "vision_maskrcnn", 204*da0073e9SAndroid Build Coastguard Worker } 205*da0073e9SAndroid Build Coastguard Worker 206*da0073e9SAndroid Build Coastguard Worker @property 207*da0073e9SAndroid Build Coastguard Worker def inline_inbuilt_nn_modules_models(self): 208*da0073e9SAndroid Build Coastguard Worker return { 209*da0073e9SAndroid Build Coastguard Worker "basic_gnn_edgecnn", 210*da0073e9SAndroid Build Coastguard Worker "drq", 211*da0073e9SAndroid Build Coastguard Worker "hf_Reformer", 212*da0073e9SAndroid Build Coastguard Worker "DALLE2_pytorch", 213*da0073e9SAndroid Build Coastguard Worker "hf_BigBird", 214*da0073e9SAndroid Build Coastguard Worker "detectron2_maskrcnn_r_50_fpn", 215*da0073e9SAndroid Build Coastguard Worker "detectron2_maskrcnn_r_101_fpn", 216*da0073e9SAndroid Build Coastguard Worker "vision_maskrcnn", 217*da0073e9SAndroid Build Coastguard Worker "doctr_reco_predictor", 218*da0073e9SAndroid Build Coastguard Worker "hf_T5_generate", 219*da0073e9SAndroid Build Coastguard Worker } 220*da0073e9SAndroid Build Coastguard Worker 221*da0073e9SAndroid Build Coastguard Worker def load_model( 222*da0073e9SAndroid Build Coastguard Worker self, 223*da0073e9SAndroid Build Coastguard Worker device, 224*da0073e9SAndroid Build Coastguard Worker model_name, 225*da0073e9SAndroid Build Coastguard Worker batch_size=None, 226*da0073e9SAndroid Build Coastguard Worker part=None, 227*da0073e9SAndroid Build Coastguard Worker extra_args=None, 228*da0073e9SAndroid Build Coastguard Worker ): 229*da0073e9SAndroid Build Coastguard Worker if self.args.enable_activation_checkpointing: 230*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError( 231*da0073e9SAndroid Build Coastguard Worker "Activation checkpointing not implemented for Torchbench models" 232*da0073e9SAndroid Build Coastguard Worker ) 233*da0073e9SAndroid Build Coastguard Worker is_training = self.args.training 234*da0073e9SAndroid Build Coastguard Worker use_eval_mode = self.args.use_eval_mode 235*da0073e9SAndroid Build Coastguard Worker dynamic_shapes = self.args.dynamic_shapes 236*da0073e9SAndroid Build Coastguard Worker candidates = [ 237*da0073e9SAndroid Build Coastguard Worker f"torchbenchmark.models.{model_name}", 238*da0073e9SAndroid Build Coastguard Worker f"torchbenchmark.canary_models.{model_name}", 239*da0073e9SAndroid Build Coastguard Worker f"torchbenchmark.models.fb.{model_name}", 240*da0073e9SAndroid Build Coastguard Worker ] 241*da0073e9SAndroid Build Coastguard Worker for c in candidates: 242*da0073e9SAndroid Build Coastguard Worker try: 243*da0073e9SAndroid Build Coastguard Worker module = importlib.import_module(c) 244*da0073e9SAndroid Build Coastguard Worker break 245*da0073e9SAndroid Build Coastguard Worker except ModuleNotFoundError as e: 246*da0073e9SAndroid Build Coastguard Worker if e.name != c: 247*da0073e9SAndroid Build Coastguard Worker raise 248*da0073e9SAndroid Build Coastguard Worker else: 249*da0073e9SAndroid Build Coastguard Worker raise ImportError(f"could not import any of {candidates}") 250*da0073e9SAndroid Build Coastguard Worker benchmark_cls = getattr(module, "Model", None) 251*da0073e9SAndroid Build Coastguard Worker if benchmark_cls is None: 252*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError(f"{model_name}.Model is None") 253*da0073e9SAndroid Build Coastguard Worker 254*da0073e9SAndroid Build Coastguard Worker if not hasattr(benchmark_cls, "name"): 255*da0073e9SAndroid Build Coastguard Worker benchmark_cls.name = model_name 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker cant_change_batch_size = ( 258*da0073e9SAndroid Build Coastguard Worker not getattr(benchmark_cls, "ALLOW_CUSTOMIZE_BSIZE", True) 259*da0073e9SAndroid Build Coastguard Worker or model_name in self._config["dont_change_batch_size"] 260*da0073e9SAndroid Build Coastguard Worker ) 261*da0073e9SAndroid Build Coastguard Worker if cant_change_batch_size: 262*da0073e9SAndroid Build Coastguard Worker batch_size = None 263*da0073e9SAndroid Build Coastguard Worker if ( 264*da0073e9SAndroid Build Coastguard Worker batch_size is None 265*da0073e9SAndroid Build Coastguard Worker and is_training 266*da0073e9SAndroid Build Coastguard Worker and model_name in self._batch_size["training"] 267*da0073e9SAndroid Build Coastguard Worker ): 268*da0073e9SAndroid Build Coastguard Worker batch_size = self._batch_size["training"][model_name] 269*da0073e9SAndroid Build Coastguard Worker elif ( 270*da0073e9SAndroid Build Coastguard Worker batch_size is None 271*da0073e9SAndroid Build Coastguard Worker and not is_training 272*da0073e9SAndroid Build Coastguard Worker and model_name in self._batch_size["inference"] 273*da0073e9SAndroid Build Coastguard Worker ): 274*da0073e9SAndroid Build Coastguard Worker batch_size = self._batch_size["inference"][model_name] 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker # Control the memory footprint for few models 277*da0073e9SAndroid Build Coastguard Worker if self.args.accuracy and model_name in self._accuracy["max_batch_size"]: 278*da0073e9SAndroid Build Coastguard Worker batch_size = min(batch_size, self._accuracy["max_batch_size"][model_name]) 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags" 281*da0073e9SAndroid Build Coastguard Worker torch.backends.__allow_nonbracketed_mutation_flag = True 282*da0073e9SAndroid Build Coastguard Worker if extra_args is None: 283*da0073e9SAndroid Build Coastguard Worker extra_args = [] 284*da0073e9SAndroid Build Coastguard Worker if part: 285*da0073e9SAndroid Build Coastguard Worker extra_args += ["--part", part] 286*da0073e9SAndroid Build Coastguard Worker 287*da0073e9SAndroid Build Coastguard Worker # sam_fast only runs with amp 288*da0073e9SAndroid Build Coastguard Worker if model_name == "sam_fast": 289*da0073e9SAndroid Build Coastguard Worker self.args.amp = True 290*da0073e9SAndroid Build Coastguard Worker self.setup_amp() 291*da0073e9SAndroid Build Coastguard Worker 292*da0073e9SAndroid Build Coastguard Worker if model_name == "vision_maskrcnn" and is_training: 293*da0073e9SAndroid Build Coastguard Worker # Output of vision_maskrcnn model is a list of bounding boxes, 294*da0073e9SAndroid Build Coastguard Worker # sorted on the basis of their scores. This makes accuracy 295*da0073e9SAndroid Build Coastguard Worker # comparison hard with torch.compile. torch.compile can cause minor 296*da0073e9SAndroid Build Coastguard Worker # divergences in the output because of how fusion works for amp in 297*da0073e9SAndroid Build Coastguard Worker # TorchInductor compared to eager. Therefore, instead of looking at 298*da0073e9SAndroid Build Coastguard Worker # all the bounding boxes, we compare only top 4. 299*da0073e9SAndroid Build Coastguard Worker model_kwargs = {"box_detections_per_img": 4} 300*da0073e9SAndroid Build Coastguard Worker benchmark = benchmark_cls( 301*da0073e9SAndroid Build Coastguard Worker test="train", 302*da0073e9SAndroid Build Coastguard Worker device=device, 303*da0073e9SAndroid Build Coastguard Worker batch_size=batch_size, 304*da0073e9SAndroid Build Coastguard Worker extra_args=extra_args, 305*da0073e9SAndroid Build Coastguard Worker model_kwargs=model_kwargs, 306*da0073e9SAndroid Build Coastguard Worker ) 307*da0073e9SAndroid Build Coastguard Worker use_eval_mode = True 308*da0073e9SAndroid Build Coastguard Worker elif is_training: 309*da0073e9SAndroid Build Coastguard Worker benchmark = benchmark_cls( 310*da0073e9SAndroid Build Coastguard Worker test="train", 311*da0073e9SAndroid Build Coastguard Worker device=device, 312*da0073e9SAndroid Build Coastguard Worker batch_size=batch_size, 313*da0073e9SAndroid Build Coastguard Worker extra_args=extra_args, 314*da0073e9SAndroid Build Coastguard Worker ) 315*da0073e9SAndroid Build Coastguard Worker else: 316*da0073e9SAndroid Build Coastguard Worker benchmark = benchmark_cls( 317*da0073e9SAndroid Build Coastguard Worker test="eval", 318*da0073e9SAndroid Build Coastguard Worker device=device, 319*da0073e9SAndroid Build Coastguard Worker batch_size=batch_size, 320*da0073e9SAndroid Build Coastguard Worker extra_args=extra_args, 321*da0073e9SAndroid Build Coastguard Worker ) 322*da0073e9SAndroid Build Coastguard Worker model, example_inputs = benchmark.get_module() 323*da0073e9SAndroid Build Coastguard Worker if model_name in [ 324*da0073e9SAndroid Build Coastguard Worker "basic_gnn_edgecnn", 325*da0073e9SAndroid Build Coastguard Worker "basic_gnn_gcn", 326*da0073e9SAndroid Build Coastguard Worker "basic_gnn_sage", 327*da0073e9SAndroid Build Coastguard Worker "basic_gnn_gin", 328*da0073e9SAndroid Build Coastguard Worker ]: 329*da0073e9SAndroid Build Coastguard Worker _reassign_parameters(model) 330*da0073e9SAndroid Build Coastguard Worker 331*da0073e9SAndroid Build Coastguard Worker # Models that must be in train mode while training 332*da0073e9SAndroid Build Coastguard Worker if is_training and ( 333*da0073e9SAndroid Build Coastguard Worker not use_eval_mode or model_name in self._config["only_training"] 334*da0073e9SAndroid Build Coastguard Worker ): 335*da0073e9SAndroid Build Coastguard Worker model.train() 336*da0073e9SAndroid Build Coastguard Worker else: 337*da0073e9SAndroid Build Coastguard Worker model.eval() 338*da0073e9SAndroid Build Coastguard Worker gc.collect() 339*da0073e9SAndroid Build Coastguard Worker batch_size = benchmark.batch_size 340*da0073e9SAndroid Build Coastguard Worker if model_name == "torchrec_dlrm": 341*da0073e9SAndroid Build Coastguard Worker batch_namedtuple = namedtuple( 342*da0073e9SAndroid Build Coastguard Worker "Batch", "dense_features sparse_features labels" 343*da0073e9SAndroid Build Coastguard Worker ) 344*da0073e9SAndroid Build Coastguard Worker example_inputs = tuple( 345*da0073e9SAndroid Build Coastguard Worker batch_namedtuple( 346*da0073e9SAndroid Build Coastguard Worker dense_features=batch.dense_features, 347*da0073e9SAndroid Build Coastguard Worker sparse_features=batch.sparse_features, 348*da0073e9SAndroid Build Coastguard Worker labels=batch.labels, 349*da0073e9SAndroid Build Coastguard Worker ) 350*da0073e9SAndroid Build Coastguard Worker for batch in example_inputs 351*da0073e9SAndroid Build Coastguard Worker ) 352*da0073e9SAndroid Build Coastguard Worker # Torchbench has quite different setup for yolov3, so directly passing 353*da0073e9SAndroid Build Coastguard Worker # the right example_inputs 354*da0073e9SAndroid Build Coastguard Worker if model_name == "yolov3": 355*da0073e9SAndroid Build Coastguard Worker example_inputs = (torch.rand(batch_size, 3, 384, 512).to(device),) 356*da0073e9SAndroid Build Coastguard Worker # See https://github.com/pytorch/benchmark/issues/1561 357*da0073e9SAndroid Build Coastguard Worker if model_name == "maml_omniglot": 358*da0073e9SAndroid Build Coastguard Worker batch_size = 5 359*da0073e9SAndroid Build Coastguard Worker assert example_inputs[0].shape[0] == batch_size 360*da0073e9SAndroid Build Coastguard Worker if model_name == "vision_maskrcnn": 361*da0073e9SAndroid Build Coastguard Worker batch_size = 1 362*da0073e9SAndroid Build Coastguard Worker # global current_name, current_device 363*da0073e9SAndroid Build Coastguard Worker # current_device = device 364*da0073e9SAndroid Build Coastguard Worker # current_name = benchmark.name 365*da0073e9SAndroid Build Coastguard Worker 366*da0073e9SAndroid Build Coastguard Worker if self.args.trace_on_xla: 367*da0073e9SAndroid Build Coastguard Worker # work around for: https://github.com/pytorch/xla/issues/4174 368*da0073e9SAndroid Build Coastguard Worker import torch_xla # noqa: F401 369*da0073e9SAndroid Build Coastguard Worker self.validate_model(model, example_inputs) 370*da0073e9SAndroid Build Coastguard Worker return device, benchmark.name, model, example_inputs, batch_size 371*da0073e9SAndroid Build Coastguard Worker 372*da0073e9SAndroid Build Coastguard Worker def iter_model_names(self, args): 373*da0073e9SAndroid Build Coastguard Worker from torchbenchmark import _list_canary_model_paths, _list_model_paths 374*da0073e9SAndroid Build Coastguard Worker 375*da0073e9SAndroid Build Coastguard Worker models = _list_model_paths() 376*da0073e9SAndroid Build Coastguard Worker models += [ 377*da0073e9SAndroid Build Coastguard Worker f 378*da0073e9SAndroid Build Coastguard Worker for f in _list_canary_model_paths() 379*da0073e9SAndroid Build Coastguard Worker if os.path.basename(f) in self._config["canary_models"] 380*da0073e9SAndroid Build Coastguard Worker ] 381*da0073e9SAndroid Build Coastguard Worker models.sort() 382*da0073e9SAndroid Build Coastguard Worker 383*da0073e9SAndroid Build Coastguard Worker start, end = self.get_benchmark_indices(len(models)) 384*da0073e9SAndroid Build Coastguard Worker for index, model_path in enumerate(models): 385*da0073e9SAndroid Build Coastguard Worker if index < start or index >= end: 386*da0073e9SAndroid Build Coastguard Worker continue 387*da0073e9SAndroid Build Coastguard Worker 388*da0073e9SAndroid Build Coastguard Worker model_name = os.path.basename(model_path) 389*da0073e9SAndroid Build Coastguard Worker if ( 390*da0073e9SAndroid Build Coastguard Worker not re.search("|".join(args.filter), model_name, re.IGNORECASE) 391*da0073e9SAndroid Build Coastguard Worker or re.search("|".join(args.exclude), model_name, re.IGNORECASE) 392*da0073e9SAndroid Build Coastguard Worker or model_name in args.exclude_exact 393*da0073e9SAndroid Build Coastguard Worker or model_name in self.skip_models 394*da0073e9SAndroid Build Coastguard Worker ): 395*da0073e9SAndroid Build Coastguard Worker continue 396*da0073e9SAndroid Build Coastguard Worker 397*da0073e9SAndroid Build Coastguard Worker yield model_name 398*da0073e9SAndroid Build Coastguard Worker 399*da0073e9SAndroid Build Coastguard Worker def pick_grad(self, name, is_training): 400*da0073e9SAndroid Build Coastguard Worker if is_training or name in ("maml",): 401*da0073e9SAndroid Build Coastguard Worker return torch.enable_grad() 402*da0073e9SAndroid Build Coastguard Worker else: 403*da0073e9SAndroid Build Coastguard Worker return torch.no_grad() 404*da0073e9SAndroid Build Coastguard Worker 405*da0073e9SAndroid Build Coastguard Worker def use_larger_multiplier_for_smaller_tensor(self, name): 406*da0073e9SAndroid Build Coastguard Worker return name in self._require_larger_multiplier_for_smaller_tensor 407*da0073e9SAndroid Build Coastguard Worker 408*da0073e9SAndroid Build Coastguard Worker def get_tolerance_and_cosine_flag(self, is_training, current_device, name): 409*da0073e9SAndroid Build Coastguard Worker tolerance = 1e-4 410*da0073e9SAndroid Build Coastguard Worker cosine = self.args.cosine 411*da0073e9SAndroid Build Coastguard Worker # Increase the tolerance for torch allclose 412*da0073e9SAndroid Build Coastguard Worker if self.args.float16 or self.args.amp: 413*da0073e9SAndroid Build Coastguard Worker if name in self._tolerance["higher_fp16"]: 414*da0073e9SAndroid Build Coastguard Worker return 1e-2, cosine 415*da0073e9SAndroid Build Coastguard Worker elif name in self._tolerance["even_higher"]: 416*da0073e9SAndroid Build Coastguard Worker return 8 * 1e-2, cosine 417*da0073e9SAndroid Build Coastguard Worker return 1e-3, cosine 418*da0073e9SAndroid Build Coastguard Worker 419*da0073e9SAndroid Build Coastguard Worker if self.args.bfloat16: 420*da0073e9SAndroid Build Coastguard Worker if name in self._tolerance["higher_bf16"]: 421*da0073e9SAndroid Build Coastguard Worker return 1e-2, cosine 422*da0073e9SAndroid Build Coastguard Worker 423*da0073e9SAndroid Build Coastguard Worker if is_training and (current_device == "cuda" or current_device == "xpu"): 424*da0073e9SAndroid Build Coastguard Worker tolerance = 1e-3 425*da0073e9SAndroid Build Coastguard Worker if name in self._tolerance["cosine"]: 426*da0073e9SAndroid Build Coastguard Worker cosine = True 427*da0073e9SAndroid Build Coastguard Worker elif name in self._tolerance["higher"]: 428*da0073e9SAndroid Build Coastguard Worker tolerance = 1e-3 429*da0073e9SAndroid Build Coastguard Worker elif name in self._tolerance["even_higher"]: 430*da0073e9SAndroid Build Coastguard Worker tolerance = 8 * 1e-2 431*da0073e9SAndroid Build Coastguard Worker return tolerance, cosine 432*da0073e9SAndroid Build Coastguard Worker 433*da0073e9SAndroid Build Coastguard Worker def compute_loss(self, pred): 434*da0073e9SAndroid Build Coastguard Worker return reduce_to_scalar_loss(pred) 435*da0073e9SAndroid Build Coastguard Worker 436*da0073e9SAndroid Build Coastguard Worker def forward_pass(self, mod, inputs, collect_outputs=True): 437*da0073e9SAndroid Build Coastguard Worker with self.autocast(**self.autocast_arg): 438*da0073e9SAndroid Build Coastguard Worker if isinstance(inputs, dict): 439*da0073e9SAndroid Build Coastguard Worker return mod(**inputs) 440*da0073e9SAndroid Build Coastguard Worker else: 441*da0073e9SAndroid Build Coastguard Worker return mod(*inputs) 442*da0073e9SAndroid Build Coastguard Worker 443*da0073e9SAndroid Build Coastguard Worker def forward_and_backward_pass(self, mod, inputs, collect_outputs=True): 444*da0073e9SAndroid Build Coastguard Worker cloned_inputs = clone_inputs(inputs) 445*da0073e9SAndroid Build Coastguard Worker self.optimizer_zero_grad(mod) 446*da0073e9SAndroid Build Coastguard Worker with self.autocast(**self.autocast_arg): 447*da0073e9SAndroid Build Coastguard Worker if isinstance(cloned_inputs, dict): 448*da0073e9SAndroid Build Coastguard Worker pred = mod(**cloned_inputs) 449*da0073e9SAndroid Build Coastguard Worker else: 450*da0073e9SAndroid Build Coastguard Worker pred = mod(*cloned_inputs) 451*da0073e9SAndroid Build Coastguard Worker loss = self.compute_loss(pred) 452*da0073e9SAndroid Build Coastguard Worker self.grad_scaler.scale(loss).backward() 453*da0073e9SAndroid Build Coastguard Worker self.optimizer_step() 454*da0073e9SAndroid Build Coastguard Worker if collect_outputs: 455*da0073e9SAndroid Build Coastguard Worker return collect_results(mod, pred, loss, cloned_inputs) 456*da0073e9SAndroid Build Coastguard Worker return None 457*da0073e9SAndroid Build Coastguard Worker 458*da0073e9SAndroid Build Coastguard Worker 459*da0073e9SAndroid Build Coastguard Workerdef torchbench_main(): 460*da0073e9SAndroid Build Coastguard Worker original_dir = setup_torchbench_cwd() 461*da0073e9SAndroid Build Coastguard Worker logging.basicConfig(level=logging.WARNING) 462*da0073e9SAndroid Build Coastguard Worker warnings.filterwarnings("ignore") 463*da0073e9SAndroid Build Coastguard Worker main(TorchBenchmarkRunner(), original_dir) 464*da0073e9SAndroid Build Coastguard Worker 465*da0073e9SAndroid Build Coastguard Worker 466*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 467*da0073e9SAndroid Build Coastguard Worker torchbench_main() 468