1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3 2*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations 3*da0073e9SAndroid Build Coastguard Worker 4*da0073e9SAndroid Build Coastguard Workerimport abc 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport argparse 7*da0073e9SAndroid Build Coastguard Workerimport collections 8*da0073e9SAndroid Build Coastguard Workerimport contextlib 9*da0073e9SAndroid Build Coastguard Workerimport copy 10*da0073e9SAndroid Build Coastguard Workerimport csv 11*da0073e9SAndroid Build Coastguard Workerimport dataclasses 12*da0073e9SAndroid Build Coastguard Workerimport functools 13*da0073e9SAndroid Build Coastguard Workerimport importlib 14*da0073e9SAndroid Build Coastguard Workerimport itertools 15*da0073e9SAndroid Build Coastguard Workerimport logging 16*da0073e9SAndroid Build Coastguard Workerimport os 17*da0073e9SAndroid Build Coastguard Workerimport pathlib 18*da0073e9SAndroid Build Coastguard Workerimport shutil 19*da0073e9SAndroid Build Coastguard Workerimport signal 20*da0073e9SAndroid Build Coastguard Workerimport subprocess 21*da0073e9SAndroid Build Coastguard Workerimport sys 22*da0073e9SAndroid Build Coastguard Workerimport time 23*da0073e9SAndroid Build Coastguard Workerimport weakref 24*da0073e9SAndroid Build Coastguard Workerfrom contextlib import contextmanager 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Workerfrom typing import ( 27*da0073e9SAndroid Build Coastguard Worker Any, 28*da0073e9SAndroid Build Coastguard Worker Callable, 29*da0073e9SAndroid Build Coastguard Worker Generator, 30*da0073e9SAndroid Build Coastguard Worker List, 31*da0073e9SAndroid Build Coastguard Worker Mapping, 32*da0073e9SAndroid Build Coastguard Worker NamedTuple, 33*da0073e9SAndroid Build Coastguard Worker Optional, 34*da0073e9SAndroid Build Coastguard Worker Sequence, 35*da0073e9SAndroid Build Coastguard Worker Tuple, 36*da0073e9SAndroid Build Coastguard Worker Type, 37*da0073e9SAndroid Build Coastguard Worker TYPE_CHECKING, 38*da0073e9SAndroid Build Coastguard Worker) 39*da0073e9SAndroid Build Coastguard Workerfrom typing_extensions import Self 40*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import MagicMock 41*da0073e9SAndroid Build Coastguard Worker 42*da0073e9SAndroid Build Coastguard Workerimport numpy as np 43*da0073e9SAndroid Build Coastguard Workerimport pandas as pd 44*da0073e9SAndroid Build Coastguard Workerimport psutil 45*da0073e9SAndroid Build Coastguard Workerfrom scipy.stats import gmean, ttest_ind 46*da0073e9SAndroid Build Coastguard Workerfrom tqdm.auto import tqdm, trange 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Workerimport torch 49*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo 50*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.utils 51*da0073e9SAndroid Build Coastguard Workerimport torch._export 52*da0073e9SAndroid Build Coastguard Workerimport torch.distributed 53*da0073e9SAndroid Build Coastguard Workerimport torch.multiprocessing as mp 54*da0073e9SAndroid Build Coastguard Workerfrom torch._C import _has_cuda as HAS_CUDA, _has_xpu as HAS_XPU 55*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.profiler import fx_insert_profiling, Profiler 56*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import ( 57*da0073e9SAndroid Build Coastguard Worker dummy_fx_compile, 58*da0073e9SAndroid Build Coastguard Worker format_speedup, 59*da0073e9SAndroid Build Coastguard Worker reset_rng_state, 60*da0073e9SAndroid Build Coastguard Worker same, 61*da0073e9SAndroid Build Coastguard Worker) 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Workertry: 64*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.utils import ( 65*da0073e9SAndroid Build Coastguard Worker clone_inputs, 66*da0073e9SAndroid Build Coastguard Worker graph_break_reasons, 67*da0073e9SAndroid Build Coastguard Worker maybe_enable_compiled_autograd, 68*da0073e9SAndroid Build Coastguard Worker ) 69*da0073e9SAndroid Build Coastguard Worker from torch._inductor.utils import fresh_inductor_cache 70*da0073e9SAndroid Build Coastguard Workerexcept ImportError: 71*da0073e9SAndroid Build Coastguard Worker from _dynamo.utils import ( 72*da0073e9SAndroid Build Coastguard Worker clone_inputs, 73*da0073e9SAndroid Build Coastguard Worker graph_break_reasons, 74*da0073e9SAndroid Build Coastguard Worker maybe_enable_compiled_autograd, 75*da0073e9SAndroid Build Coastguard Worker ) 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Workerimport torch._functorch.config 78*da0073e9SAndroid Build Coastguard Workerfrom torch._functorch.aot_autograd import set_model_name 79*da0073e9SAndroid Build Coastguard Workerfrom torch._inductor import config as inductor_config, metrics 80*da0073e9SAndroid Build Coastguard Workerfrom torch._subclasses.fake_tensor import FakeTensorMode 81*da0073e9SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree 82*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._pytree import tree_map, tree_map_only 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Workertry: 85*da0073e9SAndroid Build Coastguard Worker import torch_xla 86*da0073e9SAndroid Build Coastguard Worker import torch_xla.core.xla_model as xm 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Worker # This is to woraround the backward issue https://github.com/pytorch/xla/issues/4174 89*da0073e9SAndroid Build Coastguard Worker torch_xla._XLAC._init_computation_client() 90*da0073e9SAndroid Build Coastguard Workerexcept ImportError: 91*da0073e9SAndroid Build Coastguard Worker # ignore the error if torch_xla is not installed 92*da0073e9SAndroid Build Coastguard Worker pass 93*da0073e9SAndroid Build Coastguard Worker 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Workerif TYPE_CHECKING: 96*da0073e9SAndroid Build Coastguard Worker from torch.onnx._internal.fx import diagnostics 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Workerlog = logging.getLogger(__name__) 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker# We are primarily interested in TF32 102*da0073e9SAndroid Build Coastguard Workertorch.backends.cuda.matmul.allow_tf32 = True 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker# Suppress torch.profiler spam 105*da0073e9SAndroid Build Coastguard Workeros.environ["KINETO_LOG_LEVEL"] = "5" 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Workercurrent_name = "" 108*da0073e9SAndroid Build Coastguard Workercurrent_device = "" 109*da0073e9SAndroid Build Coastguard Workercurrent_onnx_compiler = "" 110*da0073e9SAndroid Build Coastguard Workercurrent_batch_size = None 111*da0073e9SAndroid Build Coastguard Workeroutput_filename = None 112*da0073e9SAndroid Build Coastguard Workerdisable_output = False 113*da0073e9SAndroid Build Coastguard Worker 114*da0073e9SAndroid Build Coastguard WorkerMAX_DOWNLOAD_ATTEMPTS = 5 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Workerclass CI(NamedTuple): 118*da0073e9SAndroid Build Coastguard Worker backend: str # aot_eager or inductor 119*da0073e9SAndroid Build Coastguard Worker training: bool 120*da0073e9SAndroid Build Coastguard Worker dynamic: bool = False 121*da0073e9SAndroid Build Coastguard Worker device: str = "cuda" 122*da0073e9SAndroid Build Coastguard Worker 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard WorkerCI_SKIP_OPTIMIZER = { 125*da0073e9SAndroid Build Coastguard Worker # TIMM 126*da0073e9SAndroid Build Coastguard Worker "convmixer_768_32", # accuracy 127*da0073e9SAndroid Build Coastguard Worker "hrnet_w18", # Stack issue in fx 128*da0073e9SAndroid Build Coastguard Worker # HF 129*da0073e9SAndroid Build Coastguard Worker "pnasnet5large", # Stack issue in fx 130*da0073e9SAndroid Build Coastguard Worker "MobileBertForMaskedLM", # Stack issue in fx 131*da0073e9SAndroid Build Coastguard Worker "MobileBertForQuestionAnswering", # Stack issue in fx 132*da0073e9SAndroid Build Coastguard Worker "PegasusForConditionalGeneration", # OOM 133*da0073e9SAndroid Build Coastguard Worker} 134*da0073e9SAndroid Build Coastguard Worker 135*da0073e9SAndroid Build Coastguard WorkerCI_SKIP_DYNAMIC_BATCH_ONLY = { 136*da0073e9SAndroid Build Coastguard Worker "sam", 137*da0073e9SAndroid Build Coastguard Worker # See https://github.com/mindee/doctr/blob/f2114758d529ed8d3d0030581638f0520b6b98d8/doctr/models/detection/core.py#L89 138*da0073e9SAndroid Build Coastguard Worker # It iterates over the batch, which is dynamic, and dynamo chokes 139*da0073e9SAndroid Build Coastguard Worker # We should be able to graphbreak there. 140*da0073e9SAndroid Build Coastguard Worker "doctr_det_predictor", 141*da0073e9SAndroid Build Coastguard Worker "dlrm", 142*da0073e9SAndroid Build Coastguard Worker "pyhpc_isoneutral_mixing", 143*da0073e9SAndroid Build Coastguard Worker "pyhpc_equation_of_state", 144*da0073e9SAndroid Build Coastguard Worker "pyhpc_turbulent_kinetic_energy", 145*da0073e9SAndroid Build Coastguard Worker "detectron2_fcos_r_50_fpn", 146*da0073e9SAndroid Build Coastguard Worker "hf_T5_generate", 147*da0073e9SAndroid Build Coastguard Worker} 148*da0073e9SAndroid Build Coastguard Worker 149*da0073e9SAndroid Build Coastguard Worker# These models currently fail accuracy with eager Adam optimizer 150*da0073e9SAndroid Build Coastguard Worker# so we use SGD when running the full benchmarks 151*da0073e9SAndroid Build Coastguard Worker# https://github.com/pytorch/pytorch/issues/115966 152*da0073e9SAndroid Build Coastguard WorkerBENCHMARK_USE_SGD = { 153*da0073e9SAndroid Build Coastguard Worker # TorchBench 154*da0073e9SAndroid Build Coastguard Worker "BERT_pytorch", 155*da0073e9SAndroid Build Coastguard Worker "LearningToPaint", 156*da0073e9SAndroid Build Coastguard Worker "alexnet", 157*da0073e9SAndroid Build Coastguard Worker "dcgan", 158*da0073e9SAndroid Build Coastguard Worker "demucs", 159*da0073e9SAndroid Build Coastguard Worker "densenet121", 160*da0073e9SAndroid Build Coastguard Worker "dlrm", 161*da0073e9SAndroid Build Coastguard Worker "fastNLP_Bert", 162*da0073e9SAndroid Build Coastguard Worker "mobilenet_v2", 163*da0073e9SAndroid Build Coastguard Worker "phlippe_densenet", 164*da0073e9SAndroid Build Coastguard Worker "phlippe_resnet", 165*da0073e9SAndroid Build Coastguard Worker "pytorch_stargan", 166*da0073e9SAndroid Build Coastguard Worker "resnet18", 167*da0073e9SAndroid Build Coastguard Worker "shufflenet_v2_x1_0", 168*da0073e9SAndroid Build Coastguard Worker "speech_transformer", 169*da0073e9SAndroid Build Coastguard Worker "squeezenet1_1", 170*da0073e9SAndroid Build Coastguard Worker "stable_diffusion_text_encoder", 171*da0073e9SAndroid Build Coastguard Worker "timm_efficientdet", 172*da0073e9SAndroid Build Coastguard Worker "timm_nfnet", 173*da0073e9SAndroid Build Coastguard Worker "timm_regnet", 174*da0073e9SAndroid Build Coastguard Worker "timm_vision_transformer", 175*da0073e9SAndroid Build Coastguard Worker "timm_vovnet", 176*da0073e9SAndroid Build Coastguard Worker "vgg16", 177*da0073e9SAndroid Build Coastguard Worker "hf_T5", # Fails dynamic https://github.com/pytorch/pytorch/issues/115968 178*da0073e9SAndroid Build Coastguard Worker # HF 179*da0073e9SAndroid Build Coastguard Worker "AlbertForMaskedLM", 180*da0073e9SAndroid Build Coastguard Worker "BartForCausalLM", 181*da0073e9SAndroid Build Coastguard Worker "BartForConditionalGeneration", 182*da0073e9SAndroid Build Coastguard Worker "BlenderbotSmallForCausalLM", 183*da0073e9SAndroid Build Coastguard Worker "BlenderbotSmallForConditionalGeneration", 184*da0073e9SAndroid Build Coastguard Worker "DebertaV2ForQuestionAnswering", # eager OOM 185*da0073e9SAndroid Build Coastguard Worker "ElectraForCausalLM", 186*da0073e9SAndroid Build Coastguard Worker "M2M100ForConditionalGeneration", 187*da0073e9SAndroid Build Coastguard Worker "MBartForCausalLM", 188*da0073e9SAndroid Build Coastguard Worker "MBartForConditionalGeneration", 189*da0073e9SAndroid Build Coastguard Worker "OPTForCausalLM", 190*da0073e9SAndroid Build Coastguard Worker "PLBartForCausalLM", 191*da0073e9SAndroid Build Coastguard Worker "PLBartForConditionalGeneration", 192*da0073e9SAndroid Build Coastguard Worker "PegasusForCausalLM", 193*da0073e9SAndroid Build Coastguard Worker "Speech2Text2ForCausalLM", 194*da0073e9SAndroid Build Coastguard Worker "TrOCRForCausalLM", 195*da0073e9SAndroid Build Coastguard Worker "XGLMForCausalLM", 196*da0073e9SAndroid Build Coastguard Worker # TIMM 197*da0073e9SAndroid Build Coastguard Worker "adv_inception_v3", 198*da0073e9SAndroid Build Coastguard Worker "botnet26t_256", 199*da0073e9SAndroid Build Coastguard Worker "cait_m36_384", # OOM 200*da0073e9SAndroid Build Coastguard Worker "coat_lite_mini", 201*da0073e9SAndroid Build Coastguard Worker "convit_base", 202*da0073e9SAndroid Build Coastguard Worker "dpn107", 203*da0073e9SAndroid Build Coastguard Worker "fbnetv3_b", 204*da0073e9SAndroid Build Coastguard Worker "gernet_l", 205*da0073e9SAndroid Build Coastguard Worker "lcnet_050", 206*da0073e9SAndroid Build Coastguard Worker "mixnet_l", 207*da0073e9SAndroid Build Coastguard Worker "res2net101_26w_4s", 208*da0073e9SAndroid Build Coastguard Worker "res2net50_14w_8s", 209*da0073e9SAndroid Build Coastguard Worker "res2next50", 210*da0073e9SAndroid Build Coastguard Worker "resnest101e", 211*da0073e9SAndroid Build Coastguard Worker "sebotnet33ts_256", 212*da0073e9SAndroid Build Coastguard Worker "swsl_resnext101_32x16d", 213*da0073e9SAndroid Build Coastguard Worker "tf_efficientnet_b0", 214*da0073e9SAndroid Build Coastguard Worker "ghostnet_100", 215*da0073e9SAndroid Build Coastguard Worker "gmixer_24_224", 216*da0073e9SAndroid Build Coastguard Worker "tinynet_a", 217*da0073e9SAndroid Build Coastguard Worker} 218*da0073e9SAndroid Build Coastguard Worker 219*da0073e9SAndroid Build Coastguard Worker# These models OOM in CI 220*da0073e9SAndroid Build Coastguard Worker# due to the extra memory of Adam optimizer states, 221*da0073e9SAndroid Build Coastguard Worker# so we fall back to SGD in CI 222*da0073e9SAndroid Build Coastguard WorkerCI_USE_SGD = { 223*da0073e9SAndroid Build Coastguard Worker "torchrec_dlrm", 224*da0073e9SAndroid Build Coastguard Worker "demucs", 225*da0073e9SAndroid Build Coastguard Worker "detectron2_fasterrcnn_r_101_c4", 226*da0073e9SAndroid Build Coastguard Worker "detectron2_fasterrcnn_r_101_dc5", 227*da0073e9SAndroid Build Coastguard Worker "detectron2_fasterrcnn_r_101_fpn", 228*da0073e9SAndroid Build Coastguard Worker "detectron2_fasterrcnn_r_50_c4", 229*da0073e9SAndroid Build Coastguard Worker "detectron2_fasterrcnn_r_50_dc5", 230*da0073e9SAndroid Build Coastguard Worker "detectron2_fasterrcnn_r_50_fpn", 231*da0073e9SAndroid Build Coastguard Worker "detectron2_maskrcnn_r_101_c4", 232*da0073e9SAndroid Build Coastguard Worker "detectron2_maskrcnn_r_101_fpn", 233*da0073e9SAndroid Build Coastguard Worker "detectron2_maskrcnn_r_50_c4", 234*da0073e9SAndroid Build Coastguard Worker "detectron2_maskrcnn_r_50_fpn", 235*da0073e9SAndroid Build Coastguard Worker "hf_T5_base", 236*da0073e9SAndroid Build Coastguard Worker "hf_clip", 237*da0073e9SAndroid Build Coastguard Worker "llama_v2_7b_16h", 238*da0073e9SAndroid Build Coastguard Worker "mobilenet_v2_quantized_qat", 239*da0073e9SAndroid Build Coastguard Worker "phi_1_5 resnet50_quantized_qat", 240*da0073e9SAndroid Build Coastguard Worker "BlenderbotForCausalLM", 241*da0073e9SAndroid Build Coastguard Worker "cait_m36_384", 242*da0073e9SAndroid Build Coastguard Worker "DALLE2_pytorch", 243*da0073e9SAndroid Build Coastguard Worker "moco", 244*da0073e9SAndroid Build Coastguard Worker "timm_efficientdet", 245*da0073e9SAndroid Build Coastguard Worker "ghostnet_100", 246*da0073e9SAndroid Build Coastguard Worker "regnety_002", 247*da0073e9SAndroid Build Coastguard Worker "poolformer_m36", 248*da0073e9SAndroid Build Coastguard Worker "inception_v3", 249*da0073e9SAndroid Build Coastguard Worker "tinynet_a", 250*da0073e9SAndroid Build Coastguard Worker "selecsls42b", 251*da0073e9SAndroid Build Coastguard Worker "mobilevit_s", 252*da0073e9SAndroid Build Coastguard Worker "pytorch_CycleGAN_and_pix2pix", 253*da0073e9SAndroid Build Coastguard Worker "vision_maskrcnn", 254*da0073e9SAndroid Build Coastguard Worker "resmlp_12_224", 255*da0073e9SAndroid Build Coastguard Worker "dlrm", 256*da0073e9SAndroid Build Coastguard Worker "resnet50", 257*da0073e9SAndroid Build Coastguard Worker "dm_nfnet_f0", 258*da0073e9SAndroid Build Coastguard Worker "pit_b_224", 259*da0073e9SAndroid Build Coastguard Worker "tf_mixnet_l", 260*da0073e9SAndroid Build Coastguard Worker} 261*da0073e9SAndroid Build Coastguard Worker 262*da0073e9SAndroid Build Coastguard Worker 263*da0073e9SAndroid Build Coastguard WorkerDO_NOT_CAST_INPUTS = {"stable_diffusion"} 264*da0073e9SAndroid Build Coastguard Worker 265*da0073e9SAndroid Build Coastguard Worker 266*da0073e9SAndroid Build Coastguard Worker# Maps a benchmark model name to a list of status codes. For any listed entry, we'll 267*da0073e9SAndroid Build Coastguard Worker# capture TORCH_COMPILE_DEBUG logs in CI runs and preseve them (i.e., for upload) if 268*da0073e9SAndroid Build Coastguard Worker# the result status matches one listed. 269*da0073e9SAndroid Build Coastguard WorkerCI_PRESERVE_COMPILE_DEBUG = { 270*da0073e9SAndroid Build Coastguard Worker # For example: 271*da0073e9SAndroid Build Coastguard Worker # "mnasnet1_0": ["fail_accuracy"], 272*da0073e9SAndroid Build Coastguard Worker} 273*da0073e9SAndroid Build Coastguard Worker 274*da0073e9SAndroid Build Coastguard Worker 275*da0073e9SAndroid Build Coastguard Workerdef model_specified_by_path(path_and_class_str): 276*da0073e9SAndroid Build Coastguard Worker return ":" in path_and_class_str 277*da0073e9SAndroid Build Coastguard Worker 278*da0073e9SAndroid Build Coastguard Worker 279*da0073e9SAndroid Build Coastguard Workerdef load_model_from_path(path_and_class_str): 280*da0073e9SAndroid Build Coastguard Worker configs = {} 281*da0073e9SAndroid Build Coastguard Worker for kvstr in path_and_class_str.split(","): 282*da0073e9SAndroid Build Coastguard Worker k, v = kvstr.split(":") 283*da0073e9SAndroid Build Coastguard Worker configs[k] = v 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker for name in ["path", "class"]: 286*da0073e9SAndroid Build Coastguard Worker if name not in configs: 287*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 288*da0073e9SAndroid Build Coastguard Worker "Invalid --only arguments. Check help message for the correct format" 289*da0073e9SAndroid Build Coastguard Worker ) 290*da0073e9SAndroid Build Coastguard Worker 291*da0073e9SAndroid Build Coastguard Worker path = configs["path"] 292*da0073e9SAndroid Build Coastguard Worker class_name = configs["class"] 293*da0073e9SAndroid Build Coastguard Worker 294*da0073e9SAndroid Build Coastguard Worker if path[:1] != "/": 295*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 296*da0073e9SAndroid Build Coastguard Worker "Use absolute path since dynamo may change the current working directory which makes using relative path tricky" 297*da0073e9SAndroid Build Coastguard Worker ) 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker spec = importlib.util.spec_from_file_location("module_name", path) 300*da0073e9SAndroid Build Coastguard Worker module = importlib.util.module_from_spec(spec) 301*da0073e9SAndroid Build Coastguard Worker spec.loader.exec_module(module) 302*da0073e9SAndroid Build Coastguard Worker 303*da0073e9SAndroid Build Coastguard Worker model_class = getattr(module, class_name) 304*da0073e9SAndroid Build Coastguard Worker assert issubclass(model_class, torch.nn.Module) 305*da0073e9SAndroid Build Coastguard Worker model = model_class() 306*da0073e9SAndroid Build Coastguard Worker assert hasattr(model, "get_example_inputs") 307*da0073e9SAndroid Build Coastguard Worker inputs = model.get_example_inputs() 308*da0073e9SAndroid Build Coastguard Worker return model, inputs 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Worker 311*da0073e9SAndroid Build Coastguard Workerdef output_csv(filename, headers, row): 312*da0073e9SAndroid Build Coastguard Worker global disable_output 313*da0073e9SAndroid Build Coastguard Worker if disable_output: 314*da0073e9SAndroid Build Coastguard Worker return 315*da0073e9SAndroid Build Coastguard Worker if os.path.exists(filename): 316*da0073e9SAndroid Build Coastguard Worker with open(filename) as fd: 317*da0073e9SAndroid Build Coastguard Worker lines = list(csv.reader(fd)) or [[]] 318*da0073e9SAndroid Build Coastguard Worker if headers and len(headers) > len(lines[0]): 319*da0073e9SAndroid Build Coastguard Worker # if prior results failed the header might not be filled in yet 320*da0073e9SAndroid Build Coastguard Worker lines[0] = headers 321*da0073e9SAndroid Build Coastguard Worker else: 322*da0073e9SAndroid Build Coastguard Worker headers = lines[0] 323*da0073e9SAndroid Build Coastguard Worker else: 324*da0073e9SAndroid Build Coastguard Worker lines = [headers] 325*da0073e9SAndroid Build Coastguard Worker lines.append([(f"{x:.6f}" if isinstance(x, float) else x) for x in row]) 326*da0073e9SAndroid Build Coastguard Worker with open(filename, "w") as fd: 327*da0073e9SAndroid Build Coastguard Worker writer = csv.writer(fd, lineterminator="\n") 328*da0073e9SAndroid Build Coastguard Worker for line in lines: 329*da0073e9SAndroid Build Coastguard Worker writer.writerow(list(line) + ["0"] * (len(headers) - len(line))) 330*da0073e9SAndroid Build Coastguard Worker 331*da0073e9SAndroid Build Coastguard Worker 332*da0073e9SAndroid Build Coastguard Workerdef nothing(f): 333*da0073e9SAndroid Build Coastguard Worker return f 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker 336*da0073e9SAndroid Build Coastguard Worker@functools.lru_cache(None) 337*da0073e9SAndroid Build Coastguard Workerdef patch_torch_manual_seed(): 338*da0073e9SAndroid Build Coastguard Worker """Make torch manual seed deterministic. Helps with accuracy testing.""" 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker def deterministic_torch_manual_seed(*args, **kwargs): 341*da0073e9SAndroid Build Coastguard Worker from torch._C import default_generator 342*da0073e9SAndroid Build Coastguard Worker 343*da0073e9SAndroid Build Coastguard Worker seed = 1337 344*da0073e9SAndroid Build Coastguard Worker if HAS_CUDA: 345*da0073e9SAndroid Build Coastguard Worker import torch.cuda 346*da0073e9SAndroid Build Coastguard Worker 347*da0073e9SAndroid Build Coastguard Worker if not torch.cuda._is_in_bad_fork(): 348*da0073e9SAndroid Build Coastguard Worker torch.cuda.manual_seed_all(seed) 349*da0073e9SAndroid Build Coastguard Worker if HAS_XPU: 350*da0073e9SAndroid Build Coastguard Worker import torch.xpu 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker if not torch.xpu._is_in_bad_fork(): 353*da0073e9SAndroid Build Coastguard Worker torch.xpu.manual_seed_all(seed) 354*da0073e9SAndroid Build Coastguard Worker return default_generator.manual_seed(seed) 355*da0073e9SAndroid Build Coastguard Worker 356*da0073e9SAndroid Build Coastguard Worker torch.manual_seed = deterministic_torch_manual_seed 357*da0073e9SAndroid Build Coastguard Worker 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Workerdef empty_gpu_cache(device): 360*da0073e9SAndroid Build Coastguard Worker """ 361*da0073e9SAndroid Build Coastguard Worker Explicitly empty gpu cache to avoid OOM in subsequent run. 362*da0073e9SAndroid Build Coastguard Worker """ 363*da0073e9SAndroid Build Coastguard Worker 364*da0073e9SAndroid Build Coastguard Worker if device not in ["cuda", "xpu"]: 365*da0073e9SAndroid Build Coastguard Worker log.warning( 366*da0073e9SAndroid Build Coastguard Worker "Trying to call the empty_gpu_cache for device: %s, which is not in list [cuda, xpu]", 367*da0073e9SAndroid Build Coastguard Worker device, 368*da0073e9SAndroid Build Coastguard Worker ) 369*da0073e9SAndroid Build Coastguard Worker return 370*da0073e9SAndroid Build Coastguard Worker 371*da0073e9SAndroid Build Coastguard Worker if device == "cuda": 372*da0073e9SAndroid Build Coastguard Worker torch.cuda.empty_cache() 373*da0073e9SAndroid Build Coastguard Worker elif device == "xpu": 374*da0073e9SAndroid Build Coastguard Worker torch.xpu.empty_cache() 375*da0073e9SAndroid Build Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker 377*da0073e9SAndroid Build Coastguard Workerdef synchronize(): 378*da0073e9SAndroid Build Coastguard Worker pass 379*da0073e9SAndroid Build Coastguard Worker 380*da0073e9SAndroid Build Coastguard Worker 381*da0073e9SAndroid Build Coastguard Workerdef summarize_graph_break(filename): 382*da0073e9SAndroid Build Coastguard Worker """ 383*da0073e9SAndroid Build Coastguard Worker Sorts and de-dupes the graphs breaks on the reason string. Note that this 384*da0073e9SAndroid Build Coastguard Worker function is just a best effort to reduce the logging information. We could 385*da0073e9SAndroid Build Coastguard Worker miss some graph breaks because of de-duping. We can further refine this 386*da0073e9SAndroid Build Coastguard Worker function as need arises. 387*da0073e9SAndroid Build Coastguard Worker """ 388*da0073e9SAndroid Build Coastguard Worker log_file = f"{filename.rstrip('.csv')}_graph_breaks.csv" 389*da0073e9SAndroid Build Coastguard Worker if os.path.exists(log_file): 390*da0073e9SAndroid Build Coastguard Worker df = pd.read_csv(log_file) 391*da0073e9SAndroid Build Coastguard Worker df = df.sort_values("reason").drop_duplicates(subset="reason") 392*da0073e9SAndroid Build Coastguard Worker 393*da0073e9SAndroid Build Coastguard Worker # Specialize for multi tensor sgd as reason is not identical 394*da0073e9SAndroid Build Coastguard Worker multi_tensor_sgd_row = df.loc[df["reason"].str.contains("_multi_tensor_sgd")] 395*da0073e9SAndroid Build Coastguard Worker if len(multi_tensor_sgd_row): 396*da0073e9SAndroid Build Coastguard Worker df = df[ 397*da0073e9SAndroid Build Coastguard Worker ~df["reason"].str.contains("_multi_tensor_sgd") 398*da0073e9SAndroid Build Coastguard Worker ] # Drop all sgd rows 399*da0073e9SAndroid Build Coastguard Worker df = pd.concat( 400*da0073e9SAndroid Build Coastguard Worker [df, pd.DataFrame([multi_tensor_sgd_row.iloc[0]])], axis=0 401*da0073e9SAndroid Build Coastguard Worker ) # Add back a single row 402*da0073e9SAndroid Build Coastguard Worker df.to_csv(f"{log_file.rstrip('.csv')}_deduped.csv", index=False) 403*da0073e9SAndroid Build Coastguard Worker 404*da0073e9SAndroid Build Coastguard Worker 405*da0073e9SAndroid Build Coastguard Workerdef print_summary(filename, print_dataframe=False): 406*da0073e9SAndroid Build Coastguard Worker if not (filename and os.path.exists(filename)): 407*da0073e9SAndroid Build Coastguard Worker return 408*da0073e9SAndroid Build Coastguard Worker data = pd.read_csv(filename) 409*da0073e9SAndroid Build Coastguard Worker if "tag" in data.columns: 410*da0073e9SAndroid Build Coastguard Worker for tag in data.tag.unique(): 411*da0073e9SAndroid Build Coastguard Worker if tag == "0.0000": 412*da0073e9SAndroid Build Coastguard Worker continue # This happens for failed runs 413*da0073e9SAndroid Build Coastguard Worker print(f"\nSummary for tag={tag}:") 414*da0073e9SAndroid Build Coastguard Worker print_summary_table(data[data.tag == tag], print_dataframe=print_dataframe) 415*da0073e9SAndroid Build Coastguard Worker else: 416*da0073e9SAndroid Build Coastguard Worker print_summary_table(data, print_dataframe=print_dataframe) 417*da0073e9SAndroid Build Coastguard Worker summarize_graph_break(filename) 418*da0073e9SAndroid Build Coastguard Worker 419*da0073e9SAndroid Build Coastguard Worker 420*da0073e9SAndroid Build Coastguard Workerdef print_summary_table(data, print_dataframe=False): 421*da0073e9SAndroid Build Coastguard Worker if print_dataframe: 422*da0073e9SAndroid Build Coastguard Worker pd.options.display.max_rows = 1000 423*da0073e9SAndroid Build Coastguard Worker pd.options.display.max_columns = 1000 424*da0073e9SAndroid Build Coastguard Worker pd.options.display.width = 2000 425*da0073e9SAndroid Build Coastguard Worker print(data) 426*da0073e9SAndroid Build Coastguard Worker width = max(map(len, data.columns)) 427*da0073e9SAndroid Build Coastguard Worker for col in data.columns: 428*da0073e9SAndroid Build Coastguard Worker try: 429*da0073e9SAndroid Build Coastguard Worker if col in ("dev", "name", "batch_size", "tag"): 430*da0073e9SAndroid Build Coastguard Worker continue 431*da0073e9SAndroid Build Coastguard Worker elif col in ("pct_ops", "pct_time"): 432*da0073e9SAndroid Build Coastguard Worker print(col.ljust(width), f"{data[col].mean():.3%}") 433*da0073e9SAndroid Build Coastguard Worker elif col in ("graphs", "graph_calls", "captured_ops", "total_ops"): 434*da0073e9SAndroid Build Coastguard Worker print(col.ljust(width), f"{data[col].mean():.3f}") 435*da0073e9SAndroid Build Coastguard Worker elif col in ("compilation_latency"): 436*da0073e9SAndroid Build Coastguard Worker print(col.ljust(width), f"mean={data[col].mean():.3f} seconds") 437*da0073e9SAndroid Build Coastguard Worker elif col in ("compression_ratio"): 438*da0073e9SAndroid Build Coastguard Worker print(col.ljust(width), f"mean={data[col].mean():.3f}x") 439*da0073e9SAndroid Build Coastguard Worker elif col in ("accuracy"): 440*da0073e9SAndroid Build Coastguard Worker pass_rate = (data[col] == "pass").mean() 441*da0073e9SAndroid Build Coastguard Worker print(col.ljust(width), f"pass_rate={100*pass_rate:.2f}%") 442*da0073e9SAndroid Build Coastguard Worker else: 443*da0073e9SAndroid Build Coastguard Worker cdata = data[col] 444*da0073e9SAndroid Build Coastguard Worker print( 445*da0073e9SAndroid Build Coastguard Worker col.ljust(width), 446*da0073e9SAndroid Build Coastguard Worker f"gmean={gmean(cdata):.2f}x mean={cdata.mean():.3f}x", 447*da0073e9SAndroid Build Coastguard Worker ) 448*da0073e9SAndroid Build Coastguard Worker except Exception as e: 449*da0073e9SAndroid Build Coastguard Worker pass 450*da0073e9SAndroid Build Coastguard Worker 451*da0073e9SAndroid Build Coastguard Worker 452*da0073e9SAndroid Build Coastguard Workerdef tensor_is_on_xla(tensors): 453*da0073e9SAndroid Build Coastguard Worker def visit(x: torch.Tensor): 454*da0073e9SAndroid Build Coastguard Worker nonlocal result 455*da0073e9SAndroid Build Coastguard Worker if x.device.type == "xla": 456*da0073e9SAndroid Build Coastguard Worker result = True 457*da0073e9SAndroid Build Coastguard Worker 458*da0073e9SAndroid Build Coastguard Worker result = False 459*da0073e9SAndroid Build Coastguard Worker tree_map_only(torch.Tensor, visit, tensors) 460*da0073e9SAndroid Build Coastguard Worker return result 461*da0073e9SAndroid Build Coastguard Worker 462*da0073e9SAndroid Build Coastguard Worker 463*da0073e9SAndroid Build Coastguard Workerdef timed( 464*da0073e9SAndroid Build Coastguard Worker model, 465*da0073e9SAndroid Build Coastguard Worker model_iter_fn, 466*da0073e9SAndroid Build Coastguard Worker example_inputs, 467*da0073e9SAndroid Build Coastguard Worker times=1, 468*da0073e9SAndroid Build Coastguard Worker return_result=False, 469*da0073e9SAndroid Build Coastguard Worker collect_outputs=False, 470*da0073e9SAndroid Build Coastguard Worker): 471*da0073e9SAndroid Build Coastguard Worker use_xla = tensor_is_on_xla(example_inputs) 472*da0073e9SAndroid Build Coastguard Worker synchronize() 473*da0073e9SAndroid Build Coastguard Worker 474*da0073e9SAndroid Build Coastguard Worker if use_xla: 475*da0073e9SAndroid Build Coastguard Worker xm.mark_step() 476*da0073e9SAndroid Build Coastguard Worker xm.wait_device_ops() 477*da0073e9SAndroid Build Coastguard Worker 478*da0073e9SAndroid Build Coastguard Worker time_total = 0 479*da0073e9SAndroid Build Coastguard Worker # Dont collect outputs to correctly measure timing 480*da0073e9SAndroid Build Coastguard Worker for _ in range(times): 481*da0073e9SAndroid Build Coastguard Worker # Put this call inside the loop to reset the seed for each iteration. 482*da0073e9SAndroid Build Coastguard Worker # Don't include reset_rng_state() to correctly measure timing 483*da0073e9SAndroid Build Coastguard Worker reset_rng_state(use_xla) 484*da0073e9SAndroid Build Coastguard Worker t_iter_begin = time.perf_counter() 485*da0073e9SAndroid Build Coastguard Worker result = model_iter_fn(model, example_inputs, collect_outputs=collect_outputs) 486*da0073e9SAndroid Build Coastguard Worker 487*da0073e9SAndroid Build Coastguard Worker # instead of calling sync on result_list, we should call mark_step. 488*da0073e9SAndroid Build Coastguard Worker # In training case, result_list may be empty, but we want to 489*da0073e9SAndroid Build Coastguard Worker # send all the pending graphs for compilation. 490*da0073e9SAndroid Build Coastguard Worker if use_xla: 491*da0073e9SAndroid Build Coastguard Worker # For the model running on regular torchxla (baseline), we need the 492*da0073e9SAndroid Build Coastguard Worker # mark step to send the accumulated graph for compilation. 493*da0073e9SAndroid Build Coastguard Worker # 494*da0073e9SAndroid Build Coastguard Worker # For the model running with dynamo/torchxla bridge, in training case, 495*da0073e9SAndroid Build Coastguard Worker # we need the mark step to send the optimizer graph out for 496*da0073e9SAndroid Build Coastguard Worker # compilation. 497*da0073e9SAndroid Build Coastguard Worker xm.mark_step() 498*da0073e9SAndroid Build Coastguard Worker t_iter_end = time.perf_counter() 499*da0073e9SAndroid Build Coastguard Worker time_total += t_iter_end - t_iter_begin 500*da0073e9SAndroid Build Coastguard Worker 501*da0073e9SAndroid Build Coastguard Worker t_0 = time.perf_counter() 502*da0073e9SAndroid Build Coastguard Worker if use_xla: 503*da0073e9SAndroid Build Coastguard Worker xm.wait_device_ops() 504*da0073e9SAndroid Build Coastguard Worker synchronize() 505*da0073e9SAndroid Build Coastguard Worker t_1 = time.perf_counter() 506*da0073e9SAndroid Build Coastguard Worker time_total += t_1 - t_0 507*da0073e9SAndroid Build Coastguard Worker return (time_total, result) if return_result else time_total 508*da0073e9SAndroid Build Coastguard Worker 509*da0073e9SAndroid Build Coastguard Worker 510*da0073e9SAndroid Build Coastguard Workerdef _normalize_bench_inputs(example_inputs) -> Tuple[Tuple[Any], Mapping[str, Any]]: 511*da0073e9SAndroid Build Coastguard Worker # NOTE(bowbao): For huggingface benchmark, example_inputs are formatted as dictionary, 512*da0073e9SAndroid Build Coastguard Worker # and consumed like `model(**example_inputs)`. 513*da0073e9SAndroid Build Coastguard Worker # For other benchmarks, example_inputs are formatted as tuple and consumed 514*da0073e9SAndroid Build Coastguard Worker # like `model(*example_inputs)`. 515*da0073e9SAndroid Build Coastguard Worker if isinstance(example_inputs, dict): 516*da0073e9SAndroid Build Coastguard Worker return (), example_inputs 517*da0073e9SAndroid Build Coastguard Worker else: 518*da0073e9SAndroid Build Coastguard Worker return tuple(example_inputs), {} 519*da0073e9SAndroid Build Coastguard Worker 520*da0073e9SAndroid Build Coastguard Worker 521*da0073e9SAndroid Build Coastguard Workerdef _register_dataclass_output_as_pytree(example_outputs) -> None: 522*da0073e9SAndroid Build Coastguard Worker # NOTE(angelayi): For huggingface benchmark, some example outputs are 523*da0073e9SAndroid Build Coastguard Worker # formatted as a dataclass which pytree cannot consume. So we want 524*da0073e9SAndroid Build Coastguard Worker # to register the pytree implementation here 525*da0073e9SAndroid Build Coastguard Worker example_outputs_flat = pytree.tree_leaves(example_outputs) 526*da0073e9SAndroid Build Coastguard Worker output_dataclass_types = [ 527*da0073e9SAndroid Build Coastguard Worker type(out) for out in example_outputs_flat if dataclasses.is_dataclass(type(out)) 528*da0073e9SAndroid Build Coastguard Worker ] 529*da0073e9SAndroid Build Coastguard Worker for output_type in output_dataclass_types: 530*da0073e9SAndroid Build Coastguard Worker from torch._export.utils import register_dataclass_as_pytree_node 531*da0073e9SAndroid Build Coastguard Worker 532*da0073e9SAndroid Build Coastguard Worker register_dataclass_as_pytree_node( 533*da0073e9SAndroid Build Coastguard Worker output_type, 534*da0073e9SAndroid Build Coastguard Worker serialized_type_name=f"{output_type.__module__}.{output_type.__name__}", 535*da0073e9SAndroid Build Coastguard Worker ) 536*da0073e9SAndroid Build Coastguard Worker 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Workerclass Stats: 539*da0073e9SAndroid Build Coastguard Worker totals = collections.defaultdict(collections.Counter) 540*da0073e9SAndroid Build Coastguard Worker 541*da0073e9SAndroid Build Coastguard Worker @classmethod 542*da0073e9SAndroid Build Coastguard Worker def reset_counters(cls): 543*da0073e9SAndroid Build Coastguard Worker for k, v in torch._dynamo.utils.counters.items(): 544*da0073e9SAndroid Build Coastguard Worker cls.totals[k].update(v) 545*da0073e9SAndroid Build Coastguard Worker ok = torch._dynamo.utils.counters["frames"]["ok"] 546*da0073e9SAndroid Build Coastguard Worker total = torch._dynamo.utils.counters["frames"]["total"] 547*da0073e9SAndroid Build Coastguard Worker torch._dynamo.utils.counters.clear() 548*da0073e9SAndroid Build Coastguard Worker return ok, total 549*da0073e9SAndroid Build Coastguard Worker 550*da0073e9SAndroid Build Coastguard Worker @classmethod 551*da0073e9SAndroid Build Coastguard Worker def print_summary(cls): 552*da0073e9SAndroid Build Coastguard Worker for k, v in sorted(cls.totals.items()): 553*da0073e9SAndroid Build Coastguard Worker lines = "\n ".join(map(str, v.most_common(50))) 554*da0073e9SAndroid Build Coastguard Worker print(f"STATS {k}\n {lines}") 555*da0073e9SAndroid Build Coastguard Worker 556*da0073e9SAndroid Build Coastguard Worker @classmethod 557*da0073e9SAndroid Build Coastguard Worker def aot_summary(cls): 558*da0073e9SAndroid Build Coastguard Worker return [cls.totals["aot_autograd"]["total"], cls.totals["aot_autograd"]["ok"]] 559*da0073e9SAndroid Build Coastguard Worker 560*da0073e9SAndroid Build Coastguard Worker 561*da0073e9SAndroid Build Coastguard Workerdef coverage_experiment(args, model_iter_fn, model, example_inputs): 562*da0073e9SAndroid Build Coastguard Worker """ 563*da0073e9SAndroid Build Coastguard Worker Test operator/model coverage of TorchDynamo and record statistics 564*da0073e9SAndroid Build Coastguard Worker taken from a profiler. This target is mainly intended to check 565*da0073e9SAndroid Build Coastguard Worker correctness. 566*da0073e9SAndroid Build Coastguard Worker 567*da0073e9SAndroid Build Coastguard Worker Writes to ./coverage.csv 568*da0073e9SAndroid Build Coastguard Worker """ 569*da0073e9SAndroid Build Coastguard Worker profiler = Profiler() 570*da0073e9SAndroid Build Coastguard Worker frozen_model_iter_fn = torch._dynamo.run(model_iter_fn) 571*da0073e9SAndroid Build Coastguard Worker with profiler.prof: 572*da0073e9SAndroid Build Coastguard Worker frozen_model_iter_fn(model, example_inputs) 573*da0073e9SAndroid Build Coastguard Worker coverage_result = profiler.results() 574*da0073e9SAndroid Build Coastguard Worker output_csv( 575*da0073e9SAndroid Build Coastguard Worker output_filename, 576*da0073e9SAndroid Build Coastguard Worker ( 577*da0073e9SAndroid Build Coastguard Worker "dev", 578*da0073e9SAndroid Build Coastguard Worker "name", 579*da0073e9SAndroid Build Coastguard Worker "batch_size", 580*da0073e9SAndroid Build Coastguard Worker "graphs", 581*da0073e9SAndroid Build Coastguard Worker "graph_calls", 582*da0073e9SAndroid Build Coastguard Worker "captured_ops", 583*da0073e9SAndroid Build Coastguard Worker "total_ops", 584*da0073e9SAndroid Build Coastguard Worker "pct_ops", 585*da0073e9SAndroid Build Coastguard Worker "pct_time", 586*da0073e9SAndroid Build Coastguard Worker ), 587*da0073e9SAndroid Build Coastguard Worker [ 588*da0073e9SAndroid Build Coastguard Worker current_device, 589*da0073e9SAndroid Build Coastguard Worker current_name, 590*da0073e9SAndroid Build Coastguard Worker current_batch_size, 591*da0073e9SAndroid Build Coastguard Worker ] 592*da0073e9SAndroid Build Coastguard Worker + coverage_result.tocsv(), 593*da0073e9SAndroid Build Coastguard Worker ) 594*da0073e9SAndroid Build Coastguard Worker return coverage_result 595*da0073e9SAndroid Build Coastguard Worker 596*da0073e9SAndroid Build Coastguard Worker 597*da0073e9SAndroid Build Coastguard Workerdef speedup_experiment_fx2trt(args, model_iter_fn, model, example_inputs): 598*da0073e9SAndroid Build Coastguard Worker """ 599*da0073e9SAndroid Build Coastguard Worker Measure speedups over eager using the trt inference backend. TRT backend is based fx graph 600*da0073e9SAndroid Build Coastguard Worker generated by torch._dynamo. 601*da0073e9SAndroid Build Coastguard Worker Writes to ./speedups_fx2trt.csv 602*da0073e9SAndroid Build Coastguard Worker """ 603*da0073e9SAndroid Build Coastguard Worker return speedup_experiment(args, model_iter_fn, model, example_inputs) 604*da0073e9SAndroid Build Coastguard Worker 605*da0073e9SAndroid Build Coastguard Worker 606*da0073e9SAndroid Build Coastguard Workerdef recompile_profiler_experiment(args, model_iter_fn, model, example_inputs): 607*da0073e9SAndroid Build Coastguard Worker prof = torch._dynamo.utils.CompilerProfiler() 608*da0073e9SAndroid Build Coastguard Worker opt_model_iter_fn = torch._dynamo.optimize(prof, nopython=args.nopython)( 609*da0073e9SAndroid Build Coastguard Worker model_iter_fn 610*da0073e9SAndroid Build Coastguard Worker ) 611*da0073e9SAndroid Build Coastguard Worker opt_model_iter_fn(model, example_inputs) 612*da0073e9SAndroid Build Coastguard Worker output_csv( 613*da0073e9SAndroid Build Coastguard Worker output_filename, ["model", "profiler report"], [current_name, prof.report()] 614*da0073e9SAndroid Build Coastguard Worker ) 615*da0073e9SAndroid Build Coastguard Worker met = prof.get_metrics() 616*da0073e9SAndroid Build Coastguard Worker guard_failures = len(met["guard_failures"]) 617*da0073e9SAndroid Build Coastguard Worker return [guard_failures] 618*da0073e9SAndroid Build Coastguard Worker 619*da0073e9SAndroid Build Coastguard Worker 620*da0073e9SAndroid Build Coastguard Workerdef randomize_input(inputs): 621*da0073e9SAndroid Build Coastguard Worker if isinstance(inputs, (list, tuple)): 622*da0073e9SAndroid Build Coastguard Worker return type(inputs)([randomize_input(x) for x in inputs]) 623*da0073e9SAndroid Build Coastguard Worker elif isinstance(inputs, torch.Tensor): 624*da0073e9SAndroid Build Coastguard Worker if inputs.dtype in (torch.float32, torch.float64): 625*da0073e9SAndroid Build Coastguard Worker torch._dynamo.utils.counters["randomize_input"]["times"] += 1 626*da0073e9SAndroid Build Coastguard Worker return torch.randn_like(inputs) 627*da0073e9SAndroid Build Coastguard Worker elif inputs.dtype == torch.int64: 628*da0073e9SAndroid Build Coastguard Worker # Note: we can not simply tune integer tensors as follows 629*da0073e9SAndroid Build Coastguard Worker # `return torch.randint_like(inputs, high=inputs.max().item())` 630*da0073e9SAndroid Build Coastguard Worker # This may break some invariants between tensors. 631*da0073e9SAndroid Build Coastguard Worker # E.g. in embedding lookup case, one tensor is the length 632*da0073e9SAndroid Build Coastguard Worker # and another is an indices tensor. 633*da0073e9SAndroid Build Coastguard Worker return inputs 634*da0073e9SAndroid Build Coastguard Worker else: 635*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 636*da0073e9SAndroid Build Coastguard Worker f"randomize_input need support tensor of type {inputs.dtype}" 637*da0073e9SAndroid Build Coastguard Worker ) 638*da0073e9SAndroid Build Coastguard Worker else: 639*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 640*da0073e9SAndroid Build Coastguard Worker f"randomize_input can not handle input of type {type(inputs)}" 641*da0073e9SAndroid Build Coastguard Worker ) 642*da0073e9SAndroid Build Coastguard Worker 643*da0073e9SAndroid Build Coastguard Worker 644*da0073e9SAndroid Build Coastguard Workerdef maybe_mark_step(args): 645*da0073e9SAndroid Build Coastguard Worker if args.trace_on_xla: 646*da0073e9SAndroid Build Coastguard Worker xm.mark_step() 647*da0073e9SAndroid Build Coastguard Worker 648*da0073e9SAndroid Build Coastguard Worker 649*da0073e9SAndroid Build Coastguard Workerdef speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs): 650*da0073e9SAndroid Build Coastguard Worker """ 651*da0073e9SAndroid Build Coastguard Worker Measure speedups over eager. 652*da0073e9SAndroid Build Coastguard Worker 653*da0073e9SAndroid Build Coastguard Worker Writes to ./speedups.csv 654*da0073e9SAndroid Build Coastguard Worker """ 655*da0073e9SAndroid Build Coastguard Worker # if args.dynamic_shapes: 656*da0073e9SAndroid Build Coastguard Worker # return speedup_experiment_ds(args, model_iter_fn, model, example_inputs) 657*da0073e9SAndroid Build Coastguard Worker 658*da0073e9SAndroid Build Coastguard Worker timings = np.zeros((args.repeat, 2), np.float64) 659*da0073e9SAndroid Build Coastguard Worker # if we randomize the input, we should also check the result is correct 660*da0073e9SAndroid Build Coastguard Worker should_randomize_input = args.randomize_input 661*da0073e9SAndroid Build Coastguard Worker 662*da0073e9SAndroid Build Coastguard Worker import contextlib 663*da0073e9SAndroid Build Coastguard Worker 664*da0073e9SAndroid Build Coastguard Worker from torch._inductor.utils import maybe_profile 665*da0073e9SAndroid Build Coastguard Worker 666*da0073e9SAndroid Build Coastguard Worker @contextlib.contextmanager 667*da0073e9SAndroid Build Coastguard Worker def maybe_mark_profile(*args, **kwargs): 668*da0073e9SAndroid Build Coastguard Worker prof: torch.profiler.profile = kwargs.pop("p", None) 669*da0073e9SAndroid Build Coastguard Worker mark = kwargs.pop("mark", None) 670*da0073e9SAndroid Build Coastguard Worker if prof: 671*da0073e9SAndroid Build Coastguard Worker with torch.profiler.record_function(mark): 672*da0073e9SAndroid Build Coastguard Worker yield 673*da0073e9SAndroid Build Coastguard Worker else: 674*da0073e9SAndroid Build Coastguard Worker yield 675*da0073e9SAndroid Build Coastguard Worker 676*da0073e9SAndroid Build Coastguard Worker times = args.iterations_per_run 677*da0073e9SAndroid Build Coastguard Worker 678*da0073e9SAndroid Build Coastguard Worker # Use higher tolerance for XLA since XLA cause numerical unstability when 679*da0073e9SAndroid Build Coastguard Worker # graph size changes 680*da0073e9SAndroid Build Coastguard Worker tolerance = args.xla_tolerance if args.trace_on_xla else 1e-4 681*da0073e9SAndroid Build Coastguard Worker torch._dynamo.config.repro_tolerance = tolerance 682*da0073e9SAndroid Build Coastguard Worker 683*da0073e9SAndroid Build Coastguard Worker with maybe_profile(args.export_profiler_trace) as p: 684*da0073e9SAndroid Build Coastguard Worker if args.export_aot_inductor: 685*da0073e9SAndroid Build Coastguard Worker frozen_model_iter_fn = export_aot_inductor( 686*da0073e9SAndroid Build Coastguard Worker model, example_inputs, args.devices[0] 687*da0073e9SAndroid Build Coastguard Worker ) 688*da0073e9SAndroid Build Coastguard Worker else: 689*da0073e9SAndroid Build Coastguard Worker frozen_model_iter_fn = torch._dynamo.run(model_iter_fn) 690*da0073e9SAndroid Build Coastguard Worker 691*da0073e9SAndroid Build Coastguard Worker for rep in trange(args.repeat, desc="running benchmark"): 692*da0073e9SAndroid Build Coastguard Worker inputs = ( 693*da0073e9SAndroid Build Coastguard Worker randomize_input(copy.deepcopy(example_inputs)) 694*da0073e9SAndroid Build Coastguard Worker if should_randomize_input 695*da0073e9SAndroid Build Coastguard Worker else example_inputs 696*da0073e9SAndroid Build Coastguard Worker ) 697*da0073e9SAndroid Build Coastguard Worker # need call mark_step to perform the computation 698*da0073e9SAndroid Build Coastguard Worker # on randomize_input. Otherwise the first call using the 699*da0073e9SAndroid Build Coastguard Worker # inputs will incur high penalty then the next one. 700*da0073e9SAndroid Build Coastguard Worker maybe_mark_step(args) 701*da0073e9SAndroid Build Coastguard Worker 702*da0073e9SAndroid Build Coastguard Worker # interleave the runs to handle frequency scaling and load changes 703*da0073e9SAndroid Build Coastguard Worker with maybe_mark_profile(p=p, mark="expected"): 704*da0073e9SAndroid Build Coastguard Worker timings[rep, 0], expected_output = timed( 705*da0073e9SAndroid Build Coastguard Worker model, 706*da0073e9SAndroid Build Coastguard Worker model_iter_fn, 707*da0073e9SAndroid Build Coastguard Worker inputs, 708*da0073e9SAndroid Build Coastguard Worker return_result=True, 709*da0073e9SAndroid Build Coastguard Worker times=times, 710*da0073e9SAndroid Build Coastguard Worker collect_outputs=args.collect_outputs, 711*da0073e9SAndroid Build Coastguard Worker ) 712*da0073e9SAndroid Build Coastguard Worker 713*da0073e9SAndroid Build Coastguard Worker # call mark_step between the 2 calls to make the comparison fair. 714*da0073e9SAndroid Build Coastguard Worker maybe_mark_step(args) 715*da0073e9SAndroid Build Coastguard Worker 716*da0073e9SAndroid Build Coastguard Worker with maybe_mark_profile(p=p, mark="actual"), maybe_enable_compiled_autograd( 717*da0073e9SAndroid Build Coastguard Worker args.compiled_autograd 718*da0073e9SAndroid Build Coastguard Worker ): 719*da0073e9SAndroid Build Coastguard Worker timings[rep, 1], actual_output = timed( 720*da0073e9SAndroid Build Coastguard Worker model, 721*da0073e9SAndroid Build Coastguard Worker frozen_model_iter_fn, 722*da0073e9SAndroid Build Coastguard Worker inputs, 723*da0073e9SAndroid Build Coastguard Worker return_result=True, 724*da0073e9SAndroid Build Coastguard Worker times=times, 725*da0073e9SAndroid Build Coastguard Worker collect_outputs=args.collect_outputs, 726*da0073e9SAndroid Build Coastguard Worker ) 727*da0073e9SAndroid Build Coastguard Worker 728*da0073e9SAndroid Build Coastguard Worker if args.export_profiler_trace: 729*da0073e9SAndroid Build Coastguard Worker name = args.profiler_trace_name + "_" + model.name 730*da0073e9SAndroid Build Coastguard Worker if hasattr(args, "rank"): 731*da0073e9SAndroid Build Coastguard Worker name += f"_rank_{args.rank}" 732*da0073e9SAndroid Build Coastguard Worker name += ".json" 733*da0073e9SAndroid Build Coastguard Worker name = os.path.join(torch._dynamo.config.base_dir, name) 734*da0073e9SAndroid Build Coastguard Worker p.export_chrome_trace(name) 735*da0073e9SAndroid Build Coastguard Worker median = np.median(timings, axis=0) 736*da0073e9SAndroid Build Coastguard Worker speedup = median[0] / median[1] 737*da0073e9SAndroid Build Coastguard Worker if args.dump_raw_metrics: 738*da0073e9SAndroid Build Coastguard Worker np.save( 739*da0073e9SAndroid Build Coastguard Worker f"{output_filename[:-4]}-raw_timings-{current_name}-{current_device}.npy", 740*da0073e9SAndroid Build Coastguard Worker timings, 741*da0073e9SAndroid Build Coastguard Worker ) 742*da0073e9SAndroid Build Coastguard Worker 743*da0073e9SAndroid Build Coastguard Worker first_headers = ["dev", "name", "batch_size"] 744*da0073e9SAndroid Build Coastguard Worker first_fields = [current_device, current_name, current_batch_size] 745*da0073e9SAndroid Build Coastguard Worker if "tag" in kwargs: 746*da0073e9SAndroid Build Coastguard Worker first_headers.append("tag") 747*da0073e9SAndroid Build Coastguard Worker first_fields.append(kwargs["tag"]) 748*da0073e9SAndroid Build Coastguard Worker headers = first_headers + ["speedup", "abs_latency"] 749*da0073e9SAndroid Build Coastguard Worker row = first_fields + [float(speedup), median[1] * 1000] 750*da0073e9SAndroid Build Coastguard Worker msg = f"{speedup:.3f}x" 751*da0073e9SAndroid Build Coastguard Worker if args.baseline: 752*da0073e9SAndroid Build Coastguard Worker headers.extend( 753*da0073e9SAndroid Build Coastguard Worker [ 754*da0073e9SAndroid Build Coastguard Worker "baseline", 755*da0073e9SAndroid Build Coastguard Worker "speedup_vs_baseline", 756*da0073e9SAndroid Build Coastguard Worker ] 757*da0073e9SAndroid Build Coastguard Worker ) 758*da0073e9SAndroid Build Coastguard Worker df = pd.read_csv(args.baseline) 759*da0073e9SAndroid Build Coastguard Worker try: 760*da0073e9SAndroid Build Coastguard Worker baseline_speedup = df[df["name"] == current_name]["speedup"].item() 761*da0073e9SAndroid Build Coastguard Worker row.extend([baseline_speedup, speedup / baseline_speedup]) 762*da0073e9SAndroid Build Coastguard Worker msg = f"{baseline_speedup:.3f}x -> {speedup:.3f}x [{speedup / baseline_speedup:.3f}x]" 763*da0073e9SAndroid Build Coastguard Worker except (KeyError, ZeroDivisionError): 764*da0073e9SAndroid Build Coastguard Worker row.extend( 765*da0073e9SAndroid Build Coastguard Worker [ 766*da0073e9SAndroid Build Coastguard Worker 0.0, 767*da0073e9SAndroid Build Coastguard Worker 0.0, 768*da0073e9SAndroid Build Coastguard Worker ] 769*da0073e9SAndroid Build Coastguard Worker ) 770*da0073e9SAndroid Build Coastguard Worker if "compilation_latency" in kwargs: 771*da0073e9SAndroid Build Coastguard Worker headers += [ 772*da0073e9SAndroid Build Coastguard Worker "compilation_latency", 773*da0073e9SAndroid Build Coastguard Worker "compression_ratio", 774*da0073e9SAndroid Build Coastguard Worker "eager_peak_mem", 775*da0073e9SAndroid Build Coastguard Worker "dynamo_peak_mem", 776*da0073e9SAndroid Build Coastguard Worker ] 777*da0073e9SAndroid Build Coastguard Worker row.append(kwargs["compilation_latency"]) 778*da0073e9SAndroid Build Coastguard Worker row.append(kwargs["compression_ratio"]) 779*da0073e9SAndroid Build Coastguard Worker row.append(kwargs["eager_peak_mem"]) 780*da0073e9SAndroid Build Coastguard Worker row.append(kwargs["dynamo_peak_mem"]) 781*da0073e9SAndroid Build Coastguard Worker 782*da0073e9SAndroid Build Coastguard Worker if "cache_lookup_latency" in kwargs: 783*da0073e9SAndroid Build Coastguard Worker headers.append("cache_lookup_latency") 784*da0073e9SAndroid Build Coastguard Worker row.append(kwargs["cache_lookup_latency"]) 785*da0073e9SAndroid Build Coastguard Worker 786*da0073e9SAndroid Build Coastguard Worker if "dynamo_stats" in kwargs: 787*da0073e9SAndroid Build Coastguard Worker for k, v in kwargs["dynamo_stats"].items(): 788*da0073e9SAndroid Build Coastguard Worker headers.append(k) 789*da0073e9SAndroid Build Coastguard Worker row.append(v) 790*da0073e9SAndroid Build Coastguard Worker output_csv( 791*da0073e9SAndroid Build Coastguard Worker output_filename, 792*da0073e9SAndroid Build Coastguard Worker headers, 793*da0073e9SAndroid Build Coastguard Worker row, 794*da0073e9SAndroid Build Coastguard Worker ) 795*da0073e9SAndroid Build Coastguard Worker headers, data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True) 796*da0073e9SAndroid Build Coastguard Worker assert ( 797*da0073e9SAndroid Build Coastguard Worker output_filename.find(".csv") > 0 798*da0073e9SAndroid Build Coastguard Worker ), f"expected output_filename to be a .csv, but got {output_filename}" 799*da0073e9SAndroid Build Coastguard Worker output_csv( 800*da0073e9SAndroid Build Coastguard Worker output_filename[:-4] + "_compilation_metrics.csv", 801*da0073e9SAndroid Build Coastguard Worker first_headers + headers, 802*da0073e9SAndroid Build Coastguard Worker first_fields + data, 803*da0073e9SAndroid Build Coastguard Worker ) 804*da0073e9SAndroid Build Coastguard Worker return msg 805*da0073e9SAndroid Build Coastguard Worker 806*da0073e9SAndroid Build Coastguard Worker 807*da0073e9SAndroid Build Coastguard Workerdef speedup_experiment_ds(args, model_iter_fn, model, example_inputs): 808*da0073e9SAndroid Build Coastguard Worker """ 809*da0073e9SAndroid Build Coastguard Worker Run dynamic shapes benchmarks. 810*da0073e9SAndroid Build Coastguard Worker 811*da0073e9SAndroid Build Coastguard Worker Requires dynamic shape compatible models, which provide a list of example inputs. 812*da0073e9SAndroid Build Coastguard Worker 813*da0073e9SAndroid Build Coastguard Worker Warms up using the first input example and then iterates the inputs, 814*da0073e9SAndroid Build Coastguard Worker measuring (and expecting minimal) variance between the runtime for different examples. 815*da0073e9SAndroid Build Coastguard Worker 816*da0073e9SAndroid Build Coastguard Worker """ 817*da0073e9SAndroid Build Coastguard Worker timings = np.zeros((args.repeat, len(example_inputs), 2), np.float64) 818*da0073e9SAndroid Build Coastguard Worker 819*da0073e9SAndroid Build Coastguard Worker if args.repeat > 5: 820*da0073e9SAndroid Build Coastguard Worker print( 821*da0073e9SAndroid Build Coastguard Worker f"\ndynamic shapes experiments are slow, consider setting --repeat less than {args.repeat}\n" 822*da0073e9SAndroid Build Coastguard Worker ) 823*da0073e9SAndroid Build Coastguard Worker 824*da0073e9SAndroid Build Coastguard Worker nwarmup = 4 825*da0073e9SAndroid Build Coastguard Worker for rep in range(args.repeat): 826*da0073e9SAndroid Build Coastguard Worker # Start each rep fresh, e.g. only warmup on example 0 827*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 828*da0073e9SAndroid Build Coastguard Worker optimized_model_iter_fn = optimize_ctx(model_iter_fn) 829*da0073e9SAndroid Build Coastguard Worker for _ in range(nwarmup): 830*da0073e9SAndroid Build Coastguard Worker optimized_model_iter_fn(model, example_inputs[0]) 831*da0073e9SAndroid Build Coastguard Worker 832*da0073e9SAndroid Build Coastguard Worker for input_idx, inputs in enumerate(example_inputs): 833*da0073e9SAndroid Build Coastguard Worker # interleave the runs to handle frequency scaling and load changes 834*da0073e9SAndroid Build Coastguard Worker timings[rep, input_idx, 0] = timed( 835*da0073e9SAndroid Build Coastguard Worker model, model_iter_fn, inputs, return_result=False 836*da0073e9SAndroid Build Coastguard Worker ) 837*da0073e9SAndroid Build Coastguard Worker # different from regular speedup_experiment, we _DO_ want to allow recompilation 838*da0073e9SAndroid Build Coastguard Worker timings[rep, input_idx, 1] = timed( 839*da0073e9SAndroid Build Coastguard Worker model, optimized_model_iter_fn, inputs, return_result=False 840*da0073e9SAndroid Build Coastguard Worker ) 841*da0073e9SAndroid Build Coastguard Worker medians = np.median(timings, axis=0) 842*da0073e9SAndroid Build Coastguard Worker speedups = list(medians[:, 0] / medians[:, 1]) 843*da0073e9SAndroid Build Coastguard Worker speedups_mean = np.mean(speedups) 844*da0073e9SAndroid Build Coastguard Worker speedups_median = np.median(speedups) 845*da0073e9SAndroid Build Coastguard Worker speedups_var = np.var(speedups) 846*da0073e9SAndroid Build Coastguard Worker 847*da0073e9SAndroid Build Coastguard Worker # TODO this x[0] is not going to work in general but bert only has 1 input 848*da0073e9SAndroid Build Coastguard Worker shapes = [x[0].shape for x in example_inputs] 849*da0073e9SAndroid Build Coastguard Worker shape_keys = sorted(set(shapes)) 850*da0073e9SAndroid Build Coastguard Worker shape_speedups = { 851*da0073e9SAndroid Build Coastguard Worker shape: [ 852*da0073e9SAndroid Build Coastguard Worker it[1] for it in filter(lambda it: it[0] == shape, zip(shapes, speedups)) 853*da0073e9SAndroid Build Coastguard Worker ] 854*da0073e9SAndroid Build Coastguard Worker for shape in shape_keys 855*da0073e9SAndroid Build Coastguard Worker } 856*da0073e9SAndroid Build Coastguard Worker output_str = ( 857*da0073e9SAndroid Build Coastguard Worker f"mean: {speedups_mean:.3f}, median: {speedups_median:.3f}, var: {speedups_var:.3f}" 858*da0073e9SAndroid Build Coastguard Worker + "\nSpeedups by shape: " 859*da0073e9SAndroid Build Coastguard Worker + "\n".join( 860*da0073e9SAndroid Build Coastguard Worker [ 861*da0073e9SAndroid Build Coastguard Worker f"{shape}: " 862*da0073e9SAndroid Build Coastguard Worker + ", ".join([f"{speedup: .3g}" for speedup in shape_speedups[shape]]) 863*da0073e9SAndroid Build Coastguard Worker for shape in shape_keys 864*da0073e9SAndroid Build Coastguard Worker ] 865*da0073e9SAndroid Build Coastguard Worker ) 866*da0073e9SAndroid Build Coastguard Worker ) 867*da0073e9SAndroid Build Coastguard Worker output_csv( 868*da0073e9SAndroid Build Coastguard Worker output_filename, 869*da0073e9SAndroid Build Coastguard Worker ("dev", "name", "batch_size", "speedup mean", "speedup median", "speedup var"), 870*da0073e9SAndroid Build Coastguard Worker [ 871*da0073e9SAndroid Build Coastguard Worker current_device, 872*da0073e9SAndroid Build Coastguard Worker current_name, 873*da0073e9SAndroid Build Coastguard Worker current_batch_size, 874*da0073e9SAndroid Build Coastguard Worker speedups_mean, 875*da0073e9SAndroid Build Coastguard Worker speedups_median, 876*da0073e9SAndroid Build Coastguard Worker speedups_var, 877*da0073e9SAndroid Build Coastguard Worker ], 878*da0073e9SAndroid Build Coastguard Worker ) 879*da0073e9SAndroid Build Coastguard Worker return output_str 880*da0073e9SAndroid Build Coastguard Worker 881*da0073e9SAndroid Build Coastguard Worker 882*da0073e9SAndroid Build Coastguard Worker@contextlib.contextmanager 883*da0073e9SAndroid Build Coastguard Workerdef override_synchronize_with_onnx_iobinding(iobinding): 884*da0073e9SAndroid Build Coastguard Worker global synchronize 885*da0073e9SAndroid Build Coastguard Worker prev_synchrnoize = synchronize 886*da0073e9SAndroid Build Coastguard Worker try: 887*da0073e9SAndroid Build Coastguard Worker if iobinding is not None: 888*da0073e9SAndroid Build Coastguard Worker 889*da0073e9SAndroid Build Coastguard Worker def new_synchronize(): 890*da0073e9SAndroid Build Coastguard Worker iobinding.synchronize_inputs() 891*da0073e9SAndroid Build Coastguard Worker iobinding.synchronize_outputs() 892*da0073e9SAndroid Build Coastguard Worker 893*da0073e9SAndroid Build Coastguard Worker synchronize = new_synchronize 894*da0073e9SAndroid Build Coastguard Worker yield 895*da0073e9SAndroid Build Coastguard Worker finally: 896*da0073e9SAndroid Build Coastguard Worker synchronize = prev_synchrnoize 897*da0073e9SAndroid Build Coastguard Worker 898*da0073e9SAndroid Build Coastguard Worker 899*da0073e9SAndroid Build Coastguard Workerdef speedup_experiment_onnx( 900*da0073e9SAndroid Build Coastguard Worker args, 901*da0073e9SAndroid Build Coastguard Worker model_iter_fn, 902*da0073e9SAndroid Build Coastguard Worker onnx_model: OnnxModel, 903*da0073e9SAndroid Build Coastguard Worker model, 904*da0073e9SAndroid Build Coastguard Worker example_inputs, 905*da0073e9SAndroid Build Coastguard Worker **kwargs, 906*da0073e9SAndroid Build Coastguard Worker): 907*da0073e9SAndroid Build Coastguard Worker """ 908*da0073e9SAndroid Build Coastguard Worker Measure speedups over eager. 909*da0073e9SAndroid Build Coastguard Worker 910*da0073e9SAndroid Build Coastguard Worker This function is responsible for the following: 911*da0073e9SAndroid Build Coastguard Worker 1. Creating iobinding with OnnxModel if device is CUDA, which is essential for perf measurement. 912*da0073e9SAndroid Build Coastguard Worker 2. Running ORT with OnnxModel. 913*da0073e9SAndroid Build Coastguard Worker 914*da0073e9SAndroid Build Coastguard Worker Writes to ./{output_filename}, which should be 915*da0073e9SAndroid Build Coastguard Worker `pathlib.Path(self.output_dir) / f"{self.compiler}_{suite}_{self.dtype}_{self.mode}_{self.device}_{self.testing}.csv". 916*da0073e9SAndroid Build Coastguard Worker 917*da0073e9SAndroid Build Coastguard Worker TODO(bowbao): Record export time and export peak memory usage. 918*da0073e9SAndroid Build Coastguard Worker """ 919*da0073e9SAndroid Build Coastguard Worker timings = np.zeros((args.repeat, 2), np.float64) 920*da0073e9SAndroid Build Coastguard Worker is_correct = True 921*da0073e9SAndroid Build Coastguard Worker should_randomize_input = args.randomize_input 922*da0073e9SAndroid Build Coastguard Worker times = args.iterations_per_run 923*da0073e9SAndroid Build Coastguard Worker 924*da0073e9SAndroid Build Coastguard Worker def create_onnx_input_binded_fn(onnx_model: OnnxModel, pt_inputs, example_outputs): 925*da0073e9SAndroid Build Coastguard Worker # Goal is to move the iobinding creation outside of the timer function. 926*da0073e9SAndroid Build Coastguard Worker iobinding, outputs = onnx_model.create_iobinding(pt_inputs, example_outputs) 927*da0073e9SAndroid Build Coastguard Worker 928*da0073e9SAndroid Build Coastguard Worker def onnxrt_model_iter_fn(model, inputs, collect_outputs=True): 929*da0073e9SAndroid Build Coastguard Worker onnx_model.run_with_iobinding(iobinding, outputs) 930*da0073e9SAndroid Build Coastguard Worker if collect_outputs: 931*da0073e9SAndroid Build Coastguard Worker return outputs 932*da0073e9SAndroid Build Coastguard Worker 933*da0073e9SAndroid Build Coastguard Worker return onnxrt_model_iter_fn, iobinding 934*da0073e9SAndroid Build Coastguard Worker 935*da0073e9SAndroid Build Coastguard Worker def create_onnx_fn(onnx_model: OnnxModel, pt_inputs): 936*da0073e9SAndroid Build Coastguard Worker # NOTE: Making perf comparison fair by moving out the i/o adapting part. 937*da0073e9SAndroid Build Coastguard Worker # 1. Pre-adapt `pt_inputs` to `onnx_inputs` here. 938*da0073e9SAndroid Build Coastguard Worker # 2. Drop `onnx_outputs` to `pt_outputs` adapting. Output comparison is not part of perf measurement. 939*da0073e9SAndroid Build Coastguard Worker onnx_inputs = onnx_model.adapt_pt_inputs_to_onnx(pt_inputs) 940*da0073e9SAndroid Build Coastguard Worker 941*da0073e9SAndroid Build Coastguard Worker def onnxrt_model_iter_fn(model, inputs, collect_outputs=True): 942*da0073e9SAndroid Build Coastguard Worker return onnx_model.run_with_onnx_inputs(onnx_inputs) 943*da0073e9SAndroid Build Coastguard Worker 944*da0073e9SAndroid Build Coastguard Worker return onnxrt_model_iter_fn 945*da0073e9SAndroid Build Coastguard Worker 946*da0073e9SAndroid Build Coastguard Worker def timed_onnx(model, onnx_model: OnnxModel, inputs): 947*da0073e9SAndroid Build Coastguard Worker if current_device == "cpu" or onnx_model.is_cpu(): 948*da0073e9SAndroid Build Coastguard Worker onnxrt_model_iter_fn = create_onnx_fn(onnx_model, inputs) 949*da0073e9SAndroid Build Coastguard Worker iobinding = None 950*da0073e9SAndroid Build Coastguard Worker else: 951*da0073e9SAndroid Build Coastguard Worker onnxrt_model_iter_fn, iobinding = create_onnx_input_binded_fn( 952*da0073e9SAndroid Build Coastguard Worker onnx_model, inputs, expected_output 953*da0073e9SAndroid Build Coastguard Worker ) 954*da0073e9SAndroid Build Coastguard Worker with override_synchronize_with_onnx_iobinding(iobinding): 955*da0073e9SAndroid Build Coastguard Worker return timed( 956*da0073e9SAndroid Build Coastguard Worker model, 957*da0073e9SAndroid Build Coastguard Worker onnxrt_model_iter_fn, 958*da0073e9SAndroid Build Coastguard Worker inputs, 959*da0073e9SAndroid Build Coastguard Worker return_result=True, 960*da0073e9SAndroid Build Coastguard Worker times=times, 961*da0073e9SAndroid Build Coastguard Worker collect_outputs=args.collect_outputs, 962*da0073e9SAndroid Build Coastguard Worker ) 963*da0073e9SAndroid Build Coastguard Worker 964*da0073e9SAndroid Build Coastguard Worker # Insert ONNX warm-up 965*da0073e9SAndroid Build Coastguard Worker inputs = ( 966*da0073e9SAndroid Build Coastguard Worker randomize_input(copy.deepcopy(example_inputs)) 967*da0073e9SAndroid Build Coastguard Worker if should_randomize_input 968*da0073e9SAndroid Build Coastguard Worker else example_inputs 969*da0073e9SAndroid Build Coastguard Worker ) 970*da0073e9SAndroid Build Coastguard Worker _, expected_output = timed( 971*da0073e9SAndroid Build Coastguard Worker model, 972*da0073e9SAndroid Build Coastguard Worker model_iter_fn, 973*da0073e9SAndroid Build Coastguard Worker inputs, 974*da0073e9SAndroid Build Coastguard Worker return_result=True, 975*da0073e9SAndroid Build Coastguard Worker times=times, 976*da0073e9SAndroid Build Coastguard Worker collect_outputs=args.collect_outputs, 977*da0073e9SAndroid Build Coastguard Worker ) 978*da0073e9SAndroid Build Coastguard Worker for _ in range(2): 979*da0073e9SAndroid Build Coastguard Worker timed_onnx(model, onnx_model, inputs) 980*da0073e9SAndroid Build Coastguard Worker 981*da0073e9SAndroid Build Coastguard Worker for rep in range(args.repeat): 982*da0073e9SAndroid Build Coastguard Worker inputs = ( 983*da0073e9SAndroid Build Coastguard Worker randomize_input(copy.deepcopy(example_inputs)) 984*da0073e9SAndroid Build Coastguard Worker if should_randomize_input 985*da0073e9SAndroid Build Coastguard Worker else example_inputs 986*da0073e9SAndroid Build Coastguard Worker ) 987*da0073e9SAndroid Build Coastguard Worker if torch.cuda.device_count() > 1: 988*da0073e9SAndroid Build Coastguard Worker # Manually set correct torch.cuda.current_device to ensure torch.cuda.synchronize() works as intended. 989*da0073e9SAndroid Build Coastguard Worker # When there are more than 1 cuda devices, the first one is used for pytorch eager. 990*da0073e9SAndroid Build Coastguard Worker # The second one is used for onnx ort. 991*da0073e9SAndroid Build Coastguard Worker torch.cuda.set_device(0) 992*da0073e9SAndroid Build Coastguard Worker timings[rep, 0], expected_output = timed( 993*da0073e9SAndroid Build Coastguard Worker model, 994*da0073e9SAndroid Build Coastguard Worker model_iter_fn, 995*da0073e9SAndroid Build Coastguard Worker inputs, 996*da0073e9SAndroid Build Coastguard Worker return_result=True, 997*da0073e9SAndroid Build Coastguard Worker times=times, 998*da0073e9SAndroid Build Coastguard Worker collect_outputs=args.collect_outputs, 999*da0073e9SAndroid Build Coastguard Worker ) 1000*da0073e9SAndroid Build Coastguard Worker if torch.cuda.device_count() > 1: 1001*da0073e9SAndroid Build Coastguard Worker # Manually set correct torch.cuda.current_device to ensure torch.cuda.synchronize() works as intended. 1002*da0073e9SAndroid Build Coastguard Worker # When there are more than 1 cuda devices, the first one is used for pytorch eager. 1003*da0073e9SAndroid Build Coastguard Worker # The second one is used for onnx ort. 1004*da0073e9SAndroid Build Coastguard Worker torch.cuda.set_device(1) 1005*da0073e9SAndroid Build Coastguard Worker timings[rep, 1], actual_output = timed_onnx(model, onnx_model, inputs) 1006*da0073e9SAndroid Build Coastguard Worker 1007*da0073e9SAndroid Build Coastguard Worker pvalue = ttest_ind(timings[:, 0], timings[:, 1]).pvalue 1008*da0073e9SAndroid Build Coastguard Worker median = np.median(timings, axis=0) 1009*da0073e9SAndroid Build Coastguard Worker speedup = median[0] / median[1] 1010*da0073e9SAndroid Build Coastguard Worker if args.dump_raw_metrics: 1011*da0073e9SAndroid Build Coastguard Worker np.save( 1012*da0073e9SAndroid Build Coastguard Worker f"{output_filename[:-4]}-raw_timings-{current_name}-{current_device}.npy", 1013*da0073e9SAndroid Build Coastguard Worker timings, 1014*da0073e9SAndroid Build Coastguard Worker ) 1015*da0073e9SAndroid Build Coastguard Worker 1016*da0073e9SAndroid Build Coastguard Worker headers = ["dev", "name", "batch_size", "speedup", "abs_latency"] 1017*da0073e9SAndroid Build Coastguard Worker row = [ 1018*da0073e9SAndroid Build Coastguard Worker current_device, 1019*da0073e9SAndroid Build Coastguard Worker current_name, 1020*da0073e9SAndroid Build Coastguard Worker current_batch_size, 1021*da0073e9SAndroid Build Coastguard Worker float(speedup), 1022*da0073e9SAndroid Build Coastguard Worker median[1] * 1000, 1023*da0073e9SAndroid Build Coastguard Worker ] 1024*da0073e9SAndroid Build Coastguard Worker if "compilation_latency" in kwargs: 1025*da0073e9SAndroid Build Coastguard Worker headers = headers + ["compilation_latency", "compression_ratio"] 1026*da0073e9SAndroid Build Coastguard Worker row.append(kwargs["compilation_latency"]) 1027*da0073e9SAndroid Build Coastguard Worker row.append(kwargs["compression_ratio"]) 1028*da0073e9SAndroid Build Coastguard Worker 1029*da0073e9SAndroid Build Coastguard Worker output_csv( 1030*da0073e9SAndroid Build Coastguard Worker output_filename, 1031*da0073e9SAndroid Build Coastguard Worker headers, 1032*da0073e9SAndroid Build Coastguard Worker row, 1033*da0073e9SAndroid Build Coastguard Worker ) 1034*da0073e9SAndroid Build Coastguard Worker headers, data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True) 1035*da0073e9SAndroid Build Coastguard Worker assert ( 1036*da0073e9SAndroid Build Coastguard Worker output_filename.find(".csv") > 0 1037*da0073e9SAndroid Build Coastguard Worker ), f"expected output_filename to be a .csv, but got {output_filename}" 1038*da0073e9SAndroid Build Coastguard Worker output_csv( 1039*da0073e9SAndroid Build Coastguard Worker output_filename[:-4] + "_compilation_metrics.csv", 1040*da0073e9SAndroid Build Coastguard Worker ["dev", "name", "batch_size"] + headers, 1041*da0073e9SAndroid Build Coastguard Worker [current_device, current_name, current_batch_size] + data, 1042*da0073e9SAndroid Build Coastguard Worker ) 1043*da0073e9SAndroid Build Coastguard Worker return format_speedup(speedup, pvalue, is_correct=is_correct) 1044*da0073e9SAndroid Build Coastguard Worker 1045*da0073e9SAndroid Build Coastguard Worker 1046*da0073e9SAndroid Build Coastguard Workerdef overhead_experiment(*args, model_iter_fn): 1047*da0073e9SAndroid Build Coastguard Worker """ 1048*da0073e9SAndroid Build Coastguard Worker Measure overheads of TorchDynamo by running with no backend (only 1049*da0073e9SAndroid Build Coastguard Worker eager+FX), and reporting speedup/slowdown over eager. 1050*da0073e9SAndroid Build Coastguard Worker 1051*da0073e9SAndroid Build Coastguard Worker Writes to ./overheads.csv 1052*da0073e9SAndroid Build Coastguard Worker """ 1053*da0073e9SAndroid Build Coastguard Worker return speedup_experiment(*args, model_iter_fn) 1054*da0073e9SAndroid Build Coastguard Worker 1055*da0073e9SAndroid Build Coastguard Worker 1056*da0073e9SAndroid Build Coastguard Workerdef print_fx(gm, example_inputs): 1057*da0073e9SAndroid Build Coastguard Worker print(gm.graph) 1058*da0073e9SAndroid Build Coastguard Worker return gm 1059*da0073e9SAndroid Build Coastguard Worker 1060*da0073e9SAndroid Build Coastguard Worker 1061*da0073e9SAndroid Build Coastguard Workerdef print_aten_ops(gm, example_inputs): 1062*da0073e9SAndroid Build Coastguard Worker from functorch.compile import aot_module 1063*da0073e9SAndroid Build Coastguard Worker 1064*da0073e9SAndroid Build Coastguard Worker def trace_printer(gm, _): 1065*da0073e9SAndroid Build Coastguard Worker print(gm.graph) 1066*da0073e9SAndroid Build Coastguard Worker return gm 1067*da0073e9SAndroid Build Coastguard Worker 1068*da0073e9SAndroid Build Coastguard Worker return aot_module(gm, fw_compiler=trace_printer, bw_compiler=trace_printer) 1069*da0073e9SAndroid Build Coastguard Worker 1070*da0073e9SAndroid Build Coastguard Worker 1071*da0073e9SAndroid Build Coastguard Workerdef baselines(models, model_iter_fn, example_inputs, args): 1072*da0073e9SAndroid Build Coastguard Worker """ 1073*da0073e9SAndroid Build Coastguard Worker Common measurement code across all baseline experiments. 1074*da0073e9SAndroid Build Coastguard Worker """ 1075*da0073e9SAndroid Build Coastguard Worker models = list(models) 1076*da0073e9SAndroid Build Coastguard Worker for idx, (name, model) in enumerate(models): 1077*da0073e9SAndroid Build Coastguard Worker if idx == 0: 1078*da0073e9SAndroid Build Coastguard Worker result0 = model_iter_fn(model, example_inputs) 1079*da0073e9SAndroid Build Coastguard Worker elif model is not None: 1080*da0073e9SAndroid Build Coastguard Worker try: 1081*da0073e9SAndroid Build Coastguard Worker result = model_iter_fn(model, example_inputs) 1082*da0073e9SAndroid Build Coastguard Worker if same(result0, result): 1083*da0073e9SAndroid Build Coastguard Worker continue 1084*da0073e9SAndroid Build Coastguard Worker print(name, "is INCORRECT") 1085*da0073e9SAndroid Build Coastguard Worker except Exception: 1086*da0073e9SAndroid Build Coastguard Worker log.exception("error checking %s", name) 1087*da0073e9SAndroid Build Coastguard Worker models[idx] = (name, None) 1088*da0073e9SAndroid Build Coastguard Worker timings = np.zeros((args.repeat, len(models)), np.float64) 1089*da0073e9SAndroid Build Coastguard Worker timings.fill(1.0e10) 1090*da0073e9SAndroid Build Coastguard Worker for rep in range(args.repeat): 1091*da0073e9SAndroid Build Coastguard Worker for idx, (name, model) in enumerate(models): 1092*da0073e9SAndroid Build Coastguard Worker if model is not None: 1093*da0073e9SAndroid Build Coastguard Worker try: 1094*da0073e9SAndroid Build Coastguard Worker timings[rep, idx] = timed(model, model_iter_fn, example_inputs) 1095*da0073e9SAndroid Build Coastguard Worker except Exception: 1096*da0073e9SAndroid Build Coastguard Worker pass 1097*da0073e9SAndroid Build Coastguard Worker pvalue = [ 1098*da0073e9SAndroid Build Coastguard Worker ttest_ind(timings[:, 0], timings[:, i]).pvalue 1099*da0073e9SAndroid Build Coastguard Worker for i in range(1, timings.shape[1]) 1100*da0073e9SAndroid Build Coastguard Worker ] 1101*da0073e9SAndroid Build Coastguard Worker median = np.median(timings, axis=0) 1102*da0073e9SAndroid Build Coastguard Worker speedup = median[0] / median[1:] 1103*da0073e9SAndroid Build Coastguard Worker for idx, (name, model) in enumerate(models[1:]): 1104*da0073e9SAndroid Build Coastguard Worker if model is None: 1105*da0073e9SAndroid Build Coastguard Worker speedup[idx] = 0.0 1106*da0073e9SAndroid Build Coastguard Worker result = " ".join( 1107*da0073e9SAndroid Build Coastguard Worker [ 1108*da0073e9SAndroid Build Coastguard Worker format_speedup(s, p, m is not None) 1109*da0073e9SAndroid Build Coastguard Worker for s, p, m in zip(speedup, pvalue, [m for n, m in models[1:]]) 1110*da0073e9SAndroid Build Coastguard Worker ] 1111*da0073e9SAndroid Build Coastguard Worker ) 1112*da0073e9SAndroid Build Coastguard Worker output_csv( 1113*da0073e9SAndroid Build Coastguard Worker output_filename, 1114*da0073e9SAndroid Build Coastguard Worker ("dev", "name", "batch_size") + tuple(n for n, m in models[1:]), 1115*da0073e9SAndroid Build Coastguard Worker [current_device, current_name, current_batch_size] 1116*da0073e9SAndroid Build Coastguard Worker + [f"{x:.4f}" for x in speedup], 1117*da0073e9SAndroid Build Coastguard Worker ) 1118*da0073e9SAndroid Build Coastguard Worker return result 1119*da0073e9SAndroid Build Coastguard Worker 1120*da0073e9SAndroid Build Coastguard Worker 1121*da0073e9SAndroid Build Coastguard Workerdef xla(args, model_iter_fn, model, example_inputs): 1122*da0073e9SAndroid Build Coastguard Worker xla_dev = xm.xla_device(devkind=current_device) 1123*da0073e9SAndroid Build Coastguard Worker model_xla = copy.deepcopy(model).to("cpu").to(device=xla_dev) 1124*da0073e9SAndroid Build Coastguard Worker example_inputs_xla = tree_map_only( 1125*da0073e9SAndroid Build Coastguard Worker torch.Tensor, lambda x: x.to("cpu").to(device=xla_dev), example_inputs 1126*da0073e9SAndroid Build Coastguard Worker ) 1127*da0073e9SAndroid Build Coastguard Worker for _ in range(3): # warmup 1128*da0073e9SAndroid Build Coastguard Worker timed(model, model_iter_fn, example_inputs) 1129*da0073e9SAndroid Build Coastguard Worker timed(model_xla, model_iter_fn, example_inputs_xla) 1130*da0073e9SAndroid Build Coastguard Worker timings = np.zeros((args.repeat, 2), np.float64) 1131*da0073e9SAndroid Build Coastguard Worker timings.fill(1.0e10) 1132*da0073e9SAndroid Build Coastguard Worker for rep in range(args.repeat): 1133*da0073e9SAndroid Build Coastguard Worker timings[rep, 0] = timed(model, model_iter_fn, example_inputs) 1134*da0073e9SAndroid Build Coastguard Worker timings[rep, 1] = timed(model_xla, model_iter_fn, example_inputs_xla) 1135*da0073e9SAndroid Build Coastguard Worker 1136*da0073e9SAndroid Build Coastguard Worker pvalue = ttest_ind(timings[:, 0], timings[:, 1]).pvalue 1137*da0073e9SAndroid Build Coastguard Worker time_baseline, time_xla = np.median(timings, axis=0) 1138*da0073e9SAndroid Build Coastguard Worker speedup = time_baseline / time_xla 1139*da0073e9SAndroid Build Coastguard Worker output_csv( 1140*da0073e9SAndroid Build Coastguard Worker output_filename, 1141*da0073e9SAndroid Build Coastguard Worker ("dev", "name", "batch_size", "speedup", "time_baseline", "time_xla"), 1142*da0073e9SAndroid Build Coastguard Worker [ 1143*da0073e9SAndroid Build Coastguard Worker current_device, 1144*da0073e9SAndroid Build Coastguard Worker current_name, 1145*da0073e9SAndroid Build Coastguard Worker current_batch_size, 1146*da0073e9SAndroid Build Coastguard Worker speedup, 1147*da0073e9SAndroid Build Coastguard Worker time_baseline, 1148*da0073e9SAndroid Build Coastguard Worker time_xla, 1149*da0073e9SAndroid Build Coastguard Worker ], 1150*da0073e9SAndroid Build Coastguard Worker ) 1151*da0073e9SAndroid Build Coastguard Worker return format_speedup(speedup, pvalue) 1152*da0073e9SAndroid Build Coastguard Worker 1153*da0073e9SAndroid Build Coastguard Worker 1154*da0073e9SAndroid Build Coastguard Workerdef try_script(model, example_inputs): 1155*da0073e9SAndroid Build Coastguard Worker try: 1156*da0073e9SAndroid Build Coastguard Worker return torch.jit.script(model) 1157*da0073e9SAndroid Build Coastguard Worker except Exception: 1158*da0073e9SAndroid Build Coastguard Worker return None 1159*da0073e9SAndroid Build Coastguard Worker 1160*da0073e9SAndroid Build Coastguard Worker 1161*da0073e9SAndroid Build Coastguard Workerclass AOTInductorModelCache: 1162*da0073e9SAndroid Build Coastguard Worker cache = dict() 1163*da0073e9SAndroid Build Coastguard Worker 1164*da0073e9SAndroid Build Coastguard Worker @classmethod 1165*da0073e9SAndroid Build Coastguard Worker def load(cls, model, example_inputs, device): 1166*da0073e9SAndroid Build Coastguard Worker import torch._inductor 1167*da0073e9SAndroid Build Coastguard Worker import torch.export._trace 1168*da0073e9SAndroid Build Coastguard Worker 1169*da0073e9SAndroid Build Coastguard Worker key = weakref.ref(model) 1170*da0073e9SAndroid Build Coastguard Worker if key not in cls.cache: 1171*da0073e9SAndroid Build Coastguard Worker # Register the output dataclass to pytree 1172*da0073e9SAndroid Build Coastguard Worker example_args, example_kwargs = _normalize_bench_inputs(example_inputs) 1173*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 1174*da0073e9SAndroid Build Coastguard Worker # copy.deepcopy is required to prevent any surprising side-effect, 1175*da0073e9SAndroid Build Coastguard Worker # see https://github.com/pytorch/pytorch/issues/113029 1176*da0073e9SAndroid Build Coastguard Worker example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs) 1177*da0073e9SAndroid Build Coastguard Worker 1178*da0073e9SAndroid Build Coastguard Worker if pytree._is_namedtuple_instance(example_outputs): 1179*da0073e9SAndroid Build Coastguard Worker typ = type(example_outputs) 1180*da0073e9SAndroid Build Coastguard Worker pytree._register_namedtuple( 1181*da0073e9SAndroid Build Coastguard Worker typ, 1182*da0073e9SAndroid Build Coastguard Worker serialized_type_name=f"{typ.__module__}.{typ.__name__}", 1183*da0073e9SAndroid Build Coastguard Worker ) 1184*da0073e9SAndroid Build Coastguard Worker else: 1185*da0073e9SAndroid Build Coastguard Worker _register_dataclass_output_as_pytree(example_outputs) 1186*da0073e9SAndroid Build Coastguard Worker 1187*da0073e9SAndroid Build Coastguard Worker # TODO(angelayi): change this to predispatch 1188*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/127513 needs to be fixed before changing 1189*da0073e9SAndroid Build Coastguard Worker # to predispatch to avoid performance regressions 1190*da0073e9SAndroid Build Coastguard Worker gm = torch.export._trace._export_to_torch_ir( 1191*da0073e9SAndroid Build Coastguard Worker model, 1192*da0073e9SAndroid Build Coastguard Worker example_args, 1193*da0073e9SAndroid Build Coastguard Worker example_kwargs, 1194*da0073e9SAndroid Build Coastguard Worker ) 1195*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 1196*da0073e9SAndroid Build Coastguard Worker so_path = torch._inductor.aot_compile( 1197*da0073e9SAndroid Build Coastguard Worker gm, example_args, example_kwargs 1198*da0073e9SAndroid Build Coastguard Worker ) # type: ignore[arg-type] 1199*da0073e9SAndroid Build Coastguard Worker 1200*da0073e9SAndroid Build Coastguard Worker cls.cache[key] = torch._export.aot_load(so_path, device) 1201*da0073e9SAndroid Build Coastguard Worker 1202*da0073e9SAndroid Build Coastguard Worker return cls.cache[key] 1203*da0073e9SAndroid Build Coastguard Worker 1204*da0073e9SAndroid Build Coastguard Worker 1205*da0073e9SAndroid Build Coastguard Workerdef export(model, example_inputs): 1206*da0073e9SAndroid Build Coastguard Worker example_args, example_kwargs = _normalize_bench_inputs(example_inputs) 1207*da0073e9SAndroid Build Coastguard Worker example_outputs = model(*example_args, **example_kwargs) 1208*da0073e9SAndroid Build Coastguard Worker _register_dataclass_output_as_pytree(example_outputs) 1209*da0073e9SAndroid Build Coastguard Worker 1210*da0073e9SAndroid Build Coastguard Worker ep = torch.export.export(model, example_args, example_kwargs) 1211*da0073e9SAndroid Build Coastguard Worker 1212*da0073e9SAndroid Build Coastguard Worker def opt_export(_, example_inputs): 1213*da0073e9SAndroid Build Coastguard Worker example_args, example_kwargs = _normalize_bench_inputs(example_inputs) 1214*da0073e9SAndroid Build Coastguard Worker return ep(*example_args, **example_kwargs) 1215*da0073e9SAndroid Build Coastguard Worker 1216*da0073e9SAndroid Build Coastguard Worker return opt_export 1217*da0073e9SAndroid Build Coastguard Worker 1218*da0073e9SAndroid Build Coastguard Worker 1219*da0073e9SAndroid Build Coastguard Workerdef export_aot_inductor(model, example_inputs, device): 1220*da0073e9SAndroid Build Coastguard Worker optimized = AOTInductorModelCache.load(model, example_inputs, device) 1221*da0073e9SAndroid Build Coastguard Worker 1222*da0073e9SAndroid Build Coastguard Worker def opt_aot_inductor(_, example_inputs, collect_outputs=False): 1223*da0073e9SAndroid Build Coastguard Worker example_args, example_kwargs = _normalize_bench_inputs(example_inputs) 1224*da0073e9SAndroid Build Coastguard Worker return optimized(*example_args, **example_kwargs) 1225*da0073e9SAndroid Build Coastguard Worker 1226*da0073e9SAndroid Build Coastguard Worker return opt_aot_inductor 1227*da0073e9SAndroid Build Coastguard Worker 1228*da0073e9SAndroid Build Coastguard Worker 1229*da0073e9SAndroid Build Coastguard Workerdef download_retry_decorator(download_fn): 1230*da0073e9SAndroid Build Coastguard Worker """ 1231*da0073e9SAndroid Build Coastguard Worker Decorator function for applying retry logic to a download function. 1232*da0073e9SAndroid Build Coastguard Worker 1233*da0073e9SAndroid Build Coastguard Worker The wrapped function will be called up to 5 times and raises an exception if the function fails each time. 1234*da0073e9SAndroid Build Coastguard Worker After each unsuccessful attempt, there is a delay before the next attempt, which is increased linearly with the number of tries. 1235*da0073e9SAndroid Build Coastguard Worker 1236*da0073e9SAndroid Build Coastguard Worker Usage: 1237*da0073e9SAndroid Build Coastguard Worker @download_retry_decorator 1238*da0073e9SAndroid Build Coastguard Worker def download_function(model_name: str): 1239*da0073e9SAndroid Build Coastguard Worker # download logic goes here 1240*da0073e9SAndroid Build Coastguard Worker """ 1241*da0073e9SAndroid Build Coastguard Worker 1242*da0073e9SAndroid Build Coastguard Worker @functools.wraps(download_fn) 1243*da0073e9SAndroid Build Coastguard Worker def wrapper(self, *args, **kwargs) -> Any: 1244*da0073e9SAndroid Build Coastguard Worker tries = 0 1245*da0073e9SAndroid Build Coastguard Worker total_allowed_tries = MAX_DOWNLOAD_ATTEMPTS 1246*da0073e9SAndroid Build Coastguard Worker while tries <= total_allowed_tries: 1247*da0073e9SAndroid Build Coastguard Worker try: 1248*da0073e9SAndroid Build Coastguard Worker model = download_fn(self, *args, **kwargs) 1249*da0073e9SAndroid Build Coastguard Worker return model 1250*da0073e9SAndroid Build Coastguard Worker except Exception as e: 1251*da0073e9SAndroid Build Coastguard Worker tries += 1 1252*da0073e9SAndroid Build Coastguard Worker if tries <= total_allowed_tries: 1253*da0073e9SAndroid Build Coastguard Worker wait = tries * 30 1254*da0073e9SAndroid Build Coastguard Worker print( 1255*da0073e9SAndroid Build Coastguard Worker f"Failed to load model: {e}. Trying again ({tries}/{total_allowed_tries}) after {wait}s" 1256*da0073e9SAndroid Build Coastguard Worker ) 1257*da0073e9SAndroid Build Coastguard Worker time.sleep(wait) 1258*da0073e9SAndroid Build Coastguard Worker else: 1259*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( # noqa: B904 1260*da0073e9SAndroid Build Coastguard Worker f"Failed to load model '{args}' with following error(s): {str(e)}." 1261*da0073e9SAndroid Build Coastguard Worker ) 1262*da0073e9SAndroid Build Coastguard Worker 1263*da0073e9SAndroid Build Coastguard Worker return wrapper 1264*da0073e9SAndroid Build Coastguard Worker 1265*da0073e9SAndroid Build Coastguard Worker 1266*da0073e9SAndroid Build Coastguard Workerclass OnnxModel(abc.ABC): 1267*da0073e9SAndroid Build Coastguard Worker TORCH_TO_NUMPY_DTYPE = { 1268*da0073e9SAndroid Build Coastguard Worker torch.float16: np.float16, 1269*da0073e9SAndroid Build Coastguard Worker torch.float32: np.float32, 1270*da0073e9SAndroid Build Coastguard Worker torch.float64: np.float64, 1271*da0073e9SAndroid Build Coastguard Worker torch.uint8: np.uint8, 1272*da0073e9SAndroid Build Coastguard Worker torch.int8: np.int8, 1273*da0073e9SAndroid Build Coastguard Worker torch.int16: np.int16, 1274*da0073e9SAndroid Build Coastguard Worker torch.int32: np.int32, 1275*da0073e9SAndroid Build Coastguard Worker torch.int64: np.longlong, 1276*da0073e9SAndroid Build Coastguard Worker torch.bool: np.bool_, 1277*da0073e9SAndroid Build Coastguard Worker } 1278*da0073e9SAndroid Build Coastguard Worker 1279*da0073e9SAndroid Build Coastguard Worker _COMPILER_NAME: str 1280*da0073e9SAndroid Build Coastguard Worker 1281*da0073e9SAndroid Build Coastguard Worker def __init__( 1282*da0073e9SAndroid Build Coastguard Worker self, 1283*da0073e9SAndroid Build Coastguard Worker output_directory, 1284*da0073e9SAndroid Build Coastguard Worker model, 1285*da0073e9SAndroid Build Coastguard Worker example_inputs, 1286*da0073e9SAndroid Build Coastguard Worker dynamic_shapes: bool, 1287*da0073e9SAndroid Build Coastguard Worker copy_before_export: bool = False, 1288*da0073e9SAndroid Build Coastguard Worker ): 1289*da0073e9SAndroid Build Coastguard Worker model_name = current_name 1290*da0073e9SAndroid Build Coastguard Worker self.copy_before_export = copy_before_export 1291*da0073e9SAndroid Build Coastguard Worker self.model_dir = self._generate_onnx_model_directory( 1292*da0073e9SAndroid Build Coastguard Worker output_directory, self._COMPILER_NAME, model_name 1293*da0073e9SAndroid Build Coastguard Worker ) 1294*da0073e9SAndroid Build Coastguard Worker self.model_path = str( 1295*da0073e9SAndroid Build Coastguard Worker self.model_dir / f"{model_name}_{self._COMPILER_NAME}.onnx" 1296*da0073e9SAndroid Build Coastguard Worker ) 1297*da0073e9SAndroid Build Coastguard Worker 1298*da0073e9SAndroid Build Coastguard Worker def _determine_deepcopy_target_device(self): 1299*da0073e9SAndroid Build Coastguard Worker if current_device == "cpu": 1300*da0073e9SAndroid Build Coastguard Worker target_device = "cpu" 1301*da0073e9SAndroid Build Coastguard Worker else: 1302*da0073e9SAndroid Build Coastguard Worker if torch.cuda.device_count() > 1: 1303*da0073e9SAndroid Build Coastguard Worker # Copy to another cuda device to avoid OOM. 1304*da0073e9SAndroid Build Coastguard Worker target_device = "cuda:1" 1305*da0073e9SAndroid Build Coastguard Worker else: 1306*da0073e9SAndroid Build Coastguard Worker target_device = "cuda" 1307*da0073e9SAndroid Build Coastguard Worker return target_device 1308*da0073e9SAndroid Build Coastguard Worker 1309*da0073e9SAndroid Build Coastguard Worker def deepcopy_model_and_inputs_to_device(self, model, example_inputs, target_device): 1310*da0073e9SAndroid Build Coastguard Worker # Deepcopy model before export to avoid modification to baseline model. 1311*da0073e9SAndroid Build Coastguard Worker # To avoid OOM, the model is first moved to CPU. Both models are then moved to device. 1312*da0073e9SAndroid Build Coastguard Worker model_device = next(model.parameters()).device 1313*da0073e9SAndroid Build Coastguard Worker model.to("cpu") 1314*da0073e9SAndroid Build Coastguard Worker model_copy = copy.deepcopy(model).to(target_device) 1315*da0073e9SAndroid Build Coastguard Worker model.to(model_device) 1316*da0073e9SAndroid Build Coastguard Worker 1317*da0073e9SAndroid Build Coastguard Worker target_device_example_inputs = tree_map_only( 1318*da0073e9SAndroid Build Coastguard Worker torch.Tensor, lambda x: x.to(device=target_device), example_inputs 1319*da0073e9SAndroid Build Coastguard Worker ) 1320*da0073e9SAndroid Build Coastguard Worker 1321*da0073e9SAndroid Build Coastguard Worker return model_copy, target_device_example_inputs 1322*da0073e9SAndroid Build Coastguard Worker 1323*da0073e9SAndroid Build Coastguard Worker @classmethod 1324*da0073e9SAndroid Build Coastguard Worker def _generate_onnx_model_directory( 1325*da0073e9SAndroid Build Coastguard Worker cls, output_directory: str, compiler_name: str, model_name: str 1326*da0073e9SAndroid Build Coastguard Worker ) -> pathlib.Path: 1327*da0073e9SAndroid Build Coastguard Worker model_path = pathlib.Path( 1328*da0073e9SAndroid Build Coastguard Worker output_directory, 1329*da0073e9SAndroid Build Coastguard Worker ".onnx_models", 1330*da0073e9SAndroid Build Coastguard Worker model_name, 1331*da0073e9SAndroid Build Coastguard Worker compiler_name, 1332*da0073e9SAndroid Build Coastguard Worker ) 1333*da0073e9SAndroid Build Coastguard Worker if model_path.exists() and model_path.is_dir(): 1334*da0073e9SAndroid Build Coastguard Worker shutil.rmtree(model_path) 1335*da0073e9SAndroid Build Coastguard Worker model_path.mkdir(parents=True, exist_ok=True) 1336*da0073e9SAndroid Build Coastguard Worker return model_path 1337*da0073e9SAndroid Build Coastguard Worker 1338*da0073e9SAndroid Build Coastguard Worker @abc.abstractmethod 1339*da0073e9SAndroid Build Coastguard Worker def format_pt_inputs(self, pt_inputs: Any) -> Sequence[torch.Tensor]: 1340*da0073e9SAndroid Build Coastguard Worker ... 1341*da0073e9SAndroid Build Coastguard Worker 1342*da0073e9SAndroid Build Coastguard Worker @abc.abstractmethod 1343*da0073e9SAndroid Build Coastguard Worker def format_pt_outputs(self, pt_outputs: Any) -> Sequence[torch.Tensor]: 1344*da0073e9SAndroid Build Coastguard Worker ... 1345*da0073e9SAndroid Build Coastguard Worker 1346*da0073e9SAndroid Build Coastguard Worker def adapt_pt_inputs_to_onnx(self, pt_inputs) -> Mapping[str, np.ndarray]: 1347*da0073e9SAndroid Build Coastguard Worker pt_inputs = self.format_pt_inputs(pt_inputs) 1348*da0073e9SAndroid Build Coastguard Worker return { 1349*da0073e9SAndroid Build Coastguard Worker ort_input.name: pt_input.cpu().numpy() 1350*da0073e9SAndroid Build Coastguard Worker for ort_input, pt_input in zip(self.onnx_session.get_inputs(), pt_inputs) 1351*da0073e9SAndroid Build Coastguard Worker } 1352*da0073e9SAndroid Build Coastguard Worker 1353*da0073e9SAndroid Build Coastguard Worker def adapt_onnx_outputs_to_pt(self, onnx_outputs: List[np.ndarray]) -> Any: 1354*da0073e9SAndroid Build Coastguard Worker pt_outputs = [ 1355*da0073e9SAndroid Build Coastguard Worker torch.from_numpy(onnx_output).to(current_device) 1356*da0073e9SAndroid Build Coastguard Worker for onnx_output in onnx_outputs 1357*da0073e9SAndroid Build Coastguard Worker ] 1358*da0073e9SAndroid Build Coastguard Worker if len(pt_outputs) == 1: 1359*da0073e9SAndroid Build Coastguard Worker return pt_outputs[0] 1360*da0073e9SAndroid Build Coastguard Worker return pt_outputs 1361*da0073e9SAndroid Build Coastguard Worker 1362*da0073e9SAndroid Build Coastguard Worker def _init_ort_session(self, model_path: str): 1363*da0073e9SAndroid Build Coastguard Worker import onnxruntime 1364*da0073e9SAndroid Build Coastguard Worker 1365*da0073e9SAndroid Build Coastguard Worker if current_device == "cpu": 1366*da0073e9SAndroid Build Coastguard Worker ort_providers = ["CPUExecutionProvider"] 1367*da0073e9SAndroid Build Coastguard Worker else: 1368*da0073e9SAndroid Build Coastguard Worker # NOTE(bowbao): Reduce OOM by running ORT on another gpu. 1369*da0073e9SAndroid Build Coastguard Worker # TODO(bowbao): This works to avoid OOM, but performance is surprisingly very bad. 1370*da0073e9SAndroid Build Coastguard Worker cuda_provider_options = { 1371*da0073e9SAndroid Build Coastguard Worker "device_id": 1 if torch.cuda.device_count() > 1 else 0, 1372*da0073e9SAndroid Build Coastguard Worker } 1373*da0073e9SAndroid Build Coastguard Worker ort_providers = [("CUDAExecutionProvider", cuda_provider_options)] 1374*da0073e9SAndroid Build Coastguard Worker session_options = onnxruntime.SessionOptions() 1375*da0073e9SAndroid Build Coastguard Worker session_options.log_severity_level = 3 # Error 1376*da0073e9SAndroid Build Coastguard Worker 1377*da0073e9SAndroid Build Coastguard Worker ort_session = onnxruntime.InferenceSession( 1378*da0073e9SAndroid Build Coastguard Worker self.model_path, 1379*da0073e9SAndroid Build Coastguard Worker providers=ort_providers, 1380*da0073e9SAndroid Build Coastguard Worker sess_options=session_options, 1381*da0073e9SAndroid Build Coastguard Worker ) 1382*da0073e9SAndroid Build Coastguard Worker return ort_session 1383*da0073e9SAndroid Build Coastguard Worker 1384*da0073e9SAndroid Build Coastguard Worker def is_cpu(self) -> bool: 1385*da0073e9SAndroid Build Coastguard Worker return self.onnx_session.get_providers()[0] == "CPUExecutionProvider" 1386*da0073e9SAndroid Build Coastguard Worker 1387*da0073e9SAndroid Build Coastguard Worker def cpu(self) -> Self: 1388*da0073e9SAndroid Build Coastguard Worker self.onnx_session.set_providers(["CPUExecutionProvider"]) 1389*da0073e9SAndroid Build Coastguard Worker return self 1390*da0073e9SAndroid Build Coastguard Worker 1391*da0073e9SAndroid Build Coastguard Worker def create_outputs(self, *example_outputs): 1392*da0073e9SAndroid Build Coastguard Worker return tuple(torch.empty_like(x) for x in example_outputs) 1393*da0073e9SAndroid Build Coastguard Worker 1394*da0073e9SAndroid Build Coastguard Worker def create_iobinding(self, pt_inputs, example_outputs): 1395*da0073e9SAndroid Build Coastguard Worker pt_inputs = self.format_pt_inputs(pt_inputs) 1396*da0073e9SAndroid Build Coastguard Worker example_outputs = self.format_pt_outputs(example_outputs) 1397*da0073e9SAndroid Build Coastguard Worker 1398*da0073e9SAndroid Build Coastguard Worker iobinding = self.onnx_session.io_binding() 1399*da0073e9SAndroid Build Coastguard Worker args = [arg.contiguous() for arg in pt_inputs] 1400*da0073e9SAndroid Build Coastguard Worker for ort_input, arg in zip(self.onnx_session.get_inputs(), args): 1401*da0073e9SAndroid Build Coastguard Worker # NOTE: Run ORT on another cuda device to reduce OOM. 1402*da0073e9SAndroid Build Coastguard Worker if torch.cuda.device_count() > 1: 1403*da0073e9SAndroid Build Coastguard Worker arg = arg.detach().to("cuda:1") 1404*da0073e9SAndroid Build Coastguard Worker device = arg.device 1405*da0073e9SAndroid Build Coastguard Worker iobinding.bind_input( 1406*da0073e9SAndroid Build Coastguard Worker ort_input.name, 1407*da0073e9SAndroid Build Coastguard Worker device.type, 1408*da0073e9SAndroid Build Coastguard Worker device.index or 0, 1409*da0073e9SAndroid Build Coastguard Worker self.TORCH_TO_NUMPY_DTYPE[arg.dtype], 1410*da0073e9SAndroid Build Coastguard Worker arg.size(), 1411*da0073e9SAndroid Build Coastguard Worker arg.data_ptr(), 1412*da0073e9SAndroid Build Coastguard Worker ) 1413*da0073e9SAndroid Build Coastguard Worker 1414*da0073e9SAndroid Build Coastguard Worker outputs = self.create_outputs(*example_outputs) 1415*da0073e9SAndroid Build Coastguard Worker for ort_output, output in zip(self.onnx_session.get_outputs(), outputs): 1416*da0073e9SAndroid Build Coastguard Worker if torch.cuda.device_count() > 1: 1417*da0073e9SAndroid Build Coastguard Worker output = output.detach().to("cuda:1") 1418*da0073e9SAndroid Build Coastguard Worker device = output.device 1419*da0073e9SAndroid Build Coastguard Worker iobinding.bind_output( 1420*da0073e9SAndroid Build Coastguard Worker ort_output.name, 1421*da0073e9SAndroid Build Coastguard Worker device.type, 1422*da0073e9SAndroid Build Coastguard Worker device.index or 0, 1423*da0073e9SAndroid Build Coastguard Worker self.TORCH_TO_NUMPY_DTYPE[output.dtype], 1424*da0073e9SAndroid Build Coastguard Worker output.size(), 1425*da0073e9SAndroid Build Coastguard Worker output.data_ptr(), 1426*da0073e9SAndroid Build Coastguard Worker ) 1427*da0073e9SAndroid Build Coastguard Worker return iobinding, outputs 1428*da0073e9SAndroid Build Coastguard Worker 1429*da0073e9SAndroid Build Coastguard Worker def run_with_iobinding(self, iobinding, outputs): 1430*da0073e9SAndroid Build Coastguard Worker # 'outputs' are torch empty tensors binded to 'iobinding'. 1431*da0073e9SAndroid Build Coastguard Worker self.onnx_session.run_with_iobinding(iobinding) 1432*da0073e9SAndroid Build Coastguard Worker return outputs 1433*da0073e9SAndroid Build Coastguard Worker 1434*da0073e9SAndroid Build Coastguard Worker def run_with_onnx_inputs(self, onnx_inputs): 1435*da0073e9SAndroid Build Coastguard Worker return self.onnx_session.run(None, onnx_inputs) 1436*da0073e9SAndroid Build Coastguard Worker 1437*da0073e9SAndroid Build Coastguard Worker @classmethod 1438*da0073e9SAndroid Build Coastguard Worker def save_tensor_data(cls, numpy_tensor, output_path): 1439*da0073e9SAndroid Build Coastguard Worker from onnx import numpy_helper 1440*da0073e9SAndroid Build Coastguard Worker 1441*da0073e9SAndroid Build Coastguard Worker proto_tensor = numpy_helper.from_array(numpy_tensor) 1442*da0073e9SAndroid Build Coastguard Worker with open(output_path, "wb") as f: 1443*da0073e9SAndroid Build Coastguard Worker f.write(proto_tensor.SerializeToString()) 1444*da0073e9SAndroid Build Coastguard Worker 1445*da0073e9SAndroid Build Coastguard Worker def run_and_serialize_inputs_outputs(self, pt_inputs): 1446*da0073e9SAndroid Build Coastguard Worker test_data_dir = self.model_dir / "test_data_set_0" 1447*da0073e9SAndroid Build Coastguard Worker test_data_dir.mkdir(parents=True, exist_ok=True) 1448*da0073e9SAndroid Build Coastguard Worker 1449*da0073e9SAndroid Build Coastguard Worker onnx_inputs = self.adapt_pt_inputs_to_onnx(pt_inputs) 1450*da0073e9SAndroid Build Coastguard Worker for i, onnx_input in enumerate(onnx_inputs.values()): 1451*da0073e9SAndroid Build Coastguard Worker self.save_tensor_data(onnx_input, str(test_data_dir / f"input_{i}.pb")) 1452*da0073e9SAndroid Build Coastguard Worker 1453*da0073e9SAndroid Build Coastguard Worker onnx_outputs = self.run_with_onnx_inputs(onnx_inputs) 1454*da0073e9SAndroid Build Coastguard Worker 1455*da0073e9SAndroid Build Coastguard Worker for i, onnx_output in enumerate(onnx_outputs): 1456*da0073e9SAndroid Build Coastguard Worker self.save_tensor_data(onnx_output, str(test_data_dir / f"output_{i}.pb")) 1457*da0073e9SAndroid Build Coastguard Worker 1458*da0073e9SAndroid Build Coastguard Worker return self.adapt_onnx_outputs_to_pt(onnx_outputs) 1459*da0073e9SAndroid Build Coastguard Worker 1460*da0073e9SAndroid Build Coastguard Worker def run(self, pt_inputs): 1461*da0073e9SAndroid Build Coastguard Worker # NOTE: For CUDA performance testing, use `run_with_iobinding` to exclude memory 1462*da0073e9SAndroid Build Coastguard Worker # copying overhead for inputs/outputs between cpu and gpu. 1463*da0073e9SAndroid Build Coastguard Worker # Otherwise perf number is inaccurate. 1464*da0073e9SAndroid Build Coastguard Worker onnx_inputs = self.adapt_pt_inputs_to_onnx(pt_inputs) 1465*da0073e9SAndroid Build Coastguard Worker onnx_outputs = self.run_with_onnx_inputs(onnx_inputs) 1466*da0073e9SAndroid Build Coastguard Worker return self.adapt_onnx_outputs_to_pt(onnx_outputs) 1467*da0073e9SAndroid Build Coastguard Worker 1468*da0073e9SAndroid Build Coastguard Worker 1469*da0073e9SAndroid Build Coastguard Workerclass OnnxModelFromTorchScript(OnnxModel): 1470*da0073e9SAndroid Build Coastguard Worker """TorchScript based onnx export. `torch.onnx.export` 1471*da0073e9SAndroid Build Coastguard Worker 1472*da0073e9SAndroid Build Coastguard Worker TODO(bowbao): 1473*da0073e9SAndroid Build Coastguard Worker * large model export failed. 1474*da0073e9SAndroid Build Coastguard Worker Onnx Model is larger than 2GB, but exporter makes decision based pt model size, which is 1475*da0073e9SAndroid Build Coastguard Worker smaller than 2GB. 1476*da0073e9SAndroid Build Coastguard Worker * OOM on slightly larger model. 1477*da0073e9SAndroid Build Coastguard Worker Both pt model and ort inference session are on gpu. Attempt has been made to move ORT to 1478*da0073e9SAndroid Build Coastguard Worker cuda:1, however ORT perf drop significantly. 1479*da0073e9SAndroid Build Coastguard Worker For now running everything with batch_size 1 set in launch script. 1480*da0073e9SAndroid Build Coastguard Worker """ 1481*da0073e9SAndroid Build Coastguard Worker 1482*da0073e9SAndroid Build Coastguard Worker _COMPILER_NAME = "torchscript" 1483*da0073e9SAndroid Build Coastguard Worker 1484*da0073e9SAndroid Build Coastguard Worker def __init__( 1485*da0073e9SAndroid Build Coastguard Worker self, output_directory, model, example_inputs, dynamic_shapes: bool, **kwargs 1486*da0073e9SAndroid Build Coastguard Worker ): 1487*da0073e9SAndroid Build Coastguard Worker if dynamic_shapes: 1488*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError("NYI dynamic shapes for OnnxModelFromTorchScript") 1489*da0073e9SAndroid Build Coastguard Worker super().__init__( 1490*da0073e9SAndroid Build Coastguard Worker output_directory, model, example_inputs, dynamic_shapes, **kwargs 1491*da0073e9SAndroid Build Coastguard Worker ) 1492*da0073e9SAndroid Build Coastguard Worker self._export( 1493*da0073e9SAndroid Build Coastguard Worker model, 1494*da0073e9SAndroid Build Coastguard Worker example_inputs, 1495*da0073e9SAndroid Build Coastguard Worker self.model_path, 1496*da0073e9SAndroid Build Coastguard Worker opset_version=17, 1497*da0073e9SAndroid Build Coastguard Worker do_constant_folding=False, 1498*da0073e9SAndroid Build Coastguard Worker verbose=False, 1499*da0073e9SAndroid Build Coastguard Worker ) 1500*da0073e9SAndroid Build Coastguard Worker self.onnx_session = self._init_ort_session(self.model_path) 1501*da0073e9SAndroid Build Coastguard Worker 1502*da0073e9SAndroid Build Coastguard Worker def _export(self, model, example_inputs, output_path: str, /, **kwargs) -> None: 1503*da0073e9SAndroid Build Coastguard Worker if self.copy_before_export: 1504*da0073e9SAndroid Build Coastguard Worker # Deepcopy model before export to avoid modification to baseline model. 1505*da0073e9SAndroid Build Coastguard Worker model, example_inputs = self.deepcopy_model_and_inputs_to_device( 1506*da0073e9SAndroid Build Coastguard Worker model, example_inputs, self._determine_deepcopy_target_device() 1507*da0073e9SAndroid Build Coastguard Worker ) 1508*da0073e9SAndroid Build Coastguard Worker 1509*da0073e9SAndroid Build Coastguard Worker # Hack for huggingface models (kwargs only). 1510*da0073e9SAndroid Build Coastguard Worker if isinstance(example_inputs, dict): 1511*da0073e9SAndroid Build Coastguard Worker 1512*da0073e9SAndroid Build Coastguard Worker class WrapperModel(torch.nn.Module): 1513*da0073e9SAndroid Build Coastguard Worker def __init__(self, model, keys): 1514*da0073e9SAndroid Build Coastguard Worker super().__init__() 1515*da0073e9SAndroid Build Coastguard Worker self.model = model 1516*da0073e9SAndroid Build Coastguard Worker self.keys = keys 1517*da0073e9SAndroid Build Coastguard Worker 1518*da0073e9SAndroid Build Coastguard Worker def forward(self, *args): 1519*da0073e9SAndroid Build Coastguard Worker return self.model(**dict(zip(self.keys, args))) 1520*da0073e9SAndroid Build Coastguard Worker 1521*da0073e9SAndroid Build Coastguard Worker model = WrapperModel(model, list(example_inputs.keys())) 1522*da0073e9SAndroid Build Coastguard Worker 1523*da0073e9SAndroid Build Coastguard Worker torch.onnx.export( 1524*da0073e9SAndroid Build Coastguard Worker model, 1525*da0073e9SAndroid Build Coastguard Worker self.format_pt_inputs(example_inputs), 1526*da0073e9SAndroid Build Coastguard Worker output_path, 1527*da0073e9SAndroid Build Coastguard Worker **kwargs, 1528*da0073e9SAndroid Build Coastguard Worker ) 1529*da0073e9SAndroid Build Coastguard Worker 1530*da0073e9SAndroid Build Coastguard Worker def format_pt_inputs(self, pt_inputs): 1531*da0073e9SAndroid Build Coastguard Worker # NOTE(bowbao): For huggingface benchmark, pt_inputs are formatted as dictionary, 1532*da0073e9SAndroid Build Coastguard Worker # and consumed like `model(**pt_inputs)`. 1533*da0073e9SAndroid Build Coastguard Worker # For other benchmarks, pt_inputs are formatted as tuple and consumed 1534*da0073e9SAndroid Build Coastguard Worker # like `model(*pt_inputs)`. 1535*da0073e9SAndroid Build Coastguard Worker if isinstance(pt_inputs, dict): 1536*da0073e9SAndroid Build Coastguard Worker pt_inputs = list(pt_inputs.values()) 1537*da0073e9SAndroid Build Coastguard Worker if isinstance(pt_inputs, torch.Tensor): 1538*da0073e9SAndroid Build Coastguard Worker pt_inputs = (pt_inputs,) 1539*da0073e9SAndroid Build Coastguard Worker return tuple(arg.contiguous() for arg in pt_inputs) 1540*da0073e9SAndroid Build Coastguard Worker 1541*da0073e9SAndroid Build Coastguard Worker def format_pt_outputs(self, pt_outputs): 1542*da0073e9SAndroid Build Coastguard Worker if isinstance(pt_outputs, torch.Tensor): 1543*da0073e9SAndroid Build Coastguard Worker pt_outputs = (pt_outputs,) 1544*da0073e9SAndroid Build Coastguard Worker 1545*da0073e9SAndroid Build Coastguard Worker pt_outputs = pytree.tree_leaves(pt_outputs) 1546*da0073e9SAndroid Build Coastguard Worker 1547*da0073e9SAndroid Build Coastguard Worker # Hack for huggingface model outputs 1548*da0073e9SAndroid Build Coastguard Worker try: 1549*da0073e9SAndroid Build Coastguard Worker from transformers import modeling_outputs 1550*da0073e9SAndroid Build Coastguard Worker except ImportError: 1551*da0073e9SAndroid Build Coastguard Worker pass 1552*da0073e9SAndroid Build Coastguard Worker else: 1553*da0073e9SAndroid Build Coastguard Worker 1554*da0073e9SAndroid Build Coastguard Worker def _to_tuple(x): 1555*da0073e9SAndroid Build Coastguard Worker if isinstance(x, modeling_outputs.ModelOutput): 1556*da0073e9SAndroid Build Coastguard Worker return x.to_tuple() 1557*da0073e9SAndroid Build Coastguard Worker return x 1558*da0073e9SAndroid Build Coastguard Worker 1559*da0073e9SAndroid Build Coastguard Worker pt_outputs = pytree.tree_map(_to_tuple, pt_outputs) 1560*da0073e9SAndroid Build Coastguard Worker pt_outputs = pytree.tree_leaves(pt_outputs) 1561*da0073e9SAndroid Build Coastguard Worker 1562*da0073e9SAndroid Build Coastguard Worker return pt_outputs 1563*da0073e9SAndroid Build Coastguard Worker 1564*da0073e9SAndroid Build Coastguard Worker 1565*da0073e9SAndroid Build Coastguard Workerclass OnnxModelFromDynamo(OnnxModel): 1566*da0073e9SAndroid Build Coastguard Worker """Dynamo and Fx based export. `torch.onnx.dynamo_export`.""" 1567*da0073e9SAndroid Build Coastguard Worker 1568*da0073e9SAndroid Build Coastguard Worker _COMPILER_NAME = "dynamo" 1569*da0073e9SAndroid Build Coastguard Worker 1570*da0073e9SAndroid Build Coastguard Worker def __init__( 1571*da0073e9SAndroid Build Coastguard Worker self, output_directory, model, example_inputs, dynamic_shapes: bool, **kwargs 1572*da0073e9SAndroid Build Coastguard Worker ): 1573*da0073e9SAndroid Build Coastguard Worker super().__init__( 1574*da0073e9SAndroid Build Coastguard Worker output_directory, model, example_inputs, dynamic_shapes, **kwargs 1575*da0073e9SAndroid Build Coastguard Worker ) 1576*da0073e9SAndroid Build Coastguard Worker self._dynamic_shapes = dynamic_shapes 1577*da0073e9SAndroid Build Coastguard Worker self._onnx_program = self._export(model, example_inputs, self.model_path) 1578*da0073e9SAndroid Build Coastguard Worker # Clear the model proto to save memory. 1579*da0073e9SAndroid Build Coastguard Worker # The model proto is saved to disk and no longer needed from `onnx_program`. 1580*da0073e9SAndroid Build Coastguard Worker # `onnx_program` is kept for i/o adapter usage. 1581*da0073e9SAndroid Build Coastguard Worker self._onnx_program.model_proto.Clear() 1582*da0073e9SAndroid Build Coastguard Worker self.onnx_session = self._init_ort_session(self.model_path) 1583*da0073e9SAndroid Build Coastguard Worker 1584*da0073e9SAndroid Build Coastguard Worker def _export( 1585*da0073e9SAndroid Build Coastguard Worker self, model, example_inputs, output_path: str 1586*da0073e9SAndroid Build Coastguard Worker ) -> torch.onnx.ONNXProgram: 1587*da0073e9SAndroid Build Coastguard Worker if self.copy_before_export: 1588*da0073e9SAndroid Build Coastguard Worker # Deepcopy model before export to avoid modification to baseline model. 1589*da0073e9SAndroid Build Coastguard Worker model, example_inputs = self.deepcopy_model_and_inputs_to_device( 1590*da0073e9SAndroid Build Coastguard Worker model, example_inputs, self._determine_deepcopy_target_device() 1591*da0073e9SAndroid Build Coastguard Worker ) 1592*da0073e9SAndroid Build Coastguard Worker 1593*da0073e9SAndroid Build Coastguard Worker example_args, example_kwargs = _normalize_bench_inputs(example_inputs) 1594*da0073e9SAndroid Build Coastguard Worker options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes) 1595*da0073e9SAndroid Build Coastguard Worker onnx_program = torch.onnx.dynamo_export( 1596*da0073e9SAndroid Build Coastguard Worker model, *example_args, **example_kwargs, export_options=options 1597*da0073e9SAndroid Build Coastguard Worker ) 1598*da0073e9SAndroid Build Coastguard Worker 1599*da0073e9SAndroid Build Coastguard Worker onnx_program.save(output_path) 1600*da0073e9SAndroid Build Coastguard Worker return onnx_program 1601*da0073e9SAndroid Build Coastguard Worker 1602*da0073e9SAndroid Build Coastguard Worker def format_pt_inputs(self, pt_inputs): 1603*da0073e9SAndroid Build Coastguard Worker pt_args, pt_kwargs = _normalize_bench_inputs(pt_inputs) 1604*da0073e9SAndroid Build Coastguard Worker return self._onnx_program.adapt_torch_inputs_to_onnx(*pt_args, **pt_kwargs) 1605*da0073e9SAndroid Build Coastguard Worker 1606*da0073e9SAndroid Build Coastguard Worker def format_pt_outputs(self, pt_outputs): 1607*da0073e9SAndroid Build Coastguard Worker return self._onnx_program.adapt_torch_outputs_to_onnx(pt_outputs) 1608*da0073e9SAndroid Build Coastguard Worker 1609*da0073e9SAndroid Build Coastguard Worker 1610*da0073e9SAndroid Build Coastguard Workerclass OnnxModelFromDynamoAotInline(OnnxModelFromDynamo): 1611*da0073e9SAndroid Build Coastguard Worker """Dynamo and Fx based export, with AOT inline post export. `torch.onnx.dynamo_export`.""" 1612*da0073e9SAndroid Build Coastguard Worker 1613*da0073e9SAndroid Build Coastguard Worker _COMPILER_NAME = "dynamo_aot_inline" 1614*da0073e9SAndroid Build Coastguard Worker 1615*da0073e9SAndroid Build Coastguard Worker def _export( 1616*da0073e9SAndroid Build Coastguard Worker self, model, example_inputs, output_path: str 1617*da0073e9SAndroid Build Coastguard Worker ) -> torch.onnx.ONNXProgram: 1618*da0073e9SAndroid Build Coastguard Worker if self.copy_before_export: 1619*da0073e9SAndroid Build Coastguard Worker # Deepcopy model before export to avoid modification to baseline model. 1620*da0073e9SAndroid Build Coastguard Worker model, example_inputs = self.deepcopy_model_and_inputs_to_device( 1621*da0073e9SAndroid Build Coastguard Worker model, example_inputs, self._determine_deepcopy_target_device() 1622*da0073e9SAndroid Build Coastguard Worker ) 1623*da0073e9SAndroid Build Coastguard Worker 1624*da0073e9SAndroid Build Coastguard Worker example_args, example_kwargs = _normalize_bench_inputs(example_inputs) 1625*da0073e9SAndroid Build Coastguard Worker options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes) 1626*da0073e9SAndroid Build Coastguard Worker onnx_program = torch.onnx.dynamo_export( 1627*da0073e9SAndroid Build Coastguard Worker model, *example_args, **example_kwargs, export_options=options 1628*da0073e9SAndroid Build Coastguard Worker ) 1629*da0073e9SAndroid Build Coastguard Worker # Apply AOT inline post export. 1630*da0073e9SAndroid Build Coastguard Worker # Requires onnx >= 1.15 1631*da0073e9SAndroid Build Coastguard Worker import onnx 1632*da0073e9SAndroid Build Coastguard Worker import onnx.inliner 1633*da0073e9SAndroid Build Coastguard Worker 1634*da0073e9SAndroid Build Coastguard Worker # Workaround for inliner not supporting with models larger than 2GB. 1635*da0073e9SAndroid Build Coastguard Worker # Save model to disk first separating out external data, 1636*da0073e9SAndroid Build Coastguard Worker # and load back without external data for inliner to work on. 1637*da0073e9SAndroid Build Coastguard Worker model_proto = onnx_program.model_proto 1638*da0073e9SAndroid Build Coastguard Worker onnx.save_model(model_proto, output_path, save_as_external_data=True) 1639*da0073e9SAndroid Build Coastguard Worker model_proto = onnx.load(output_path, load_external_data=False) 1640*da0073e9SAndroid Build Coastguard Worker model_proto = onnx.inliner.inline_local_functions(model_proto) 1641*da0073e9SAndroid Build Coastguard Worker onnx.save_model(model_proto, output_path) 1642*da0073e9SAndroid Build Coastguard Worker return onnx_program 1643*da0073e9SAndroid Build Coastguard Worker 1644*da0073e9SAndroid Build Coastguard Worker 1645*da0073e9SAndroid Build Coastguard Workerclass OnnxModelFromDynamoAotOptimize(OnnxModelFromDynamo): 1646*da0073e9SAndroid Build Coastguard Worker """Dynamo and Fx based export, with AOT optimize post export. `torch.onnx.dynamo_export`.""" 1647*da0073e9SAndroid Build Coastguard Worker 1648*da0073e9SAndroid Build Coastguard Worker _COMPILER_NAME = "dynamo_aot_optimize" 1649*da0073e9SAndroid Build Coastguard Worker 1650*da0073e9SAndroid Build Coastguard Worker def _export( 1651*da0073e9SAndroid Build Coastguard Worker self, model, example_inputs, output_path: str 1652*da0073e9SAndroid Build Coastguard Worker ) -> torch.onnx.ONNXProgram: 1653*da0073e9SAndroid Build Coastguard Worker if self.copy_before_export: 1654*da0073e9SAndroid Build Coastguard Worker # Deepcopy model before export to avoid modification to baseline model. 1655*da0073e9SAndroid Build Coastguard Worker model, example_inputs = self.deepcopy_model_and_inputs_to_device( 1656*da0073e9SAndroid Build Coastguard Worker model, example_inputs, self._determine_deepcopy_target_device() 1657*da0073e9SAndroid Build Coastguard Worker ) 1658*da0073e9SAndroid Build Coastguard Worker 1659*da0073e9SAndroid Build Coastguard Worker example_args, example_kwargs = _normalize_bench_inputs(example_inputs) 1660*da0073e9SAndroid Build Coastguard Worker options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes) 1661*da0073e9SAndroid Build Coastguard Worker export_output = torch.onnx.dynamo_export( 1662*da0073e9SAndroid Build Coastguard Worker model, *example_args, **example_kwargs, export_options=options 1663*da0073e9SAndroid Build Coastguard Worker ) 1664*da0073e9SAndroid Build Coastguard Worker 1665*da0073e9SAndroid Build Coastguard Worker import onnx 1666*da0073e9SAndroid Build Coastguard Worker from onnxscript.rewriter.onnxruntime import rewrite 1667*da0073e9SAndroid Build Coastguard Worker 1668*da0073e9SAndroid Build Coastguard Worker model_proto = rewrite(export_output.model_proto) 1669*da0073e9SAndroid Build Coastguard Worker onnx.save_model( 1670*da0073e9SAndroid Build Coastguard Worker model_proto, 1671*da0073e9SAndroid Build Coastguard Worker output_path, 1672*da0073e9SAndroid Build Coastguard Worker save_as_external_data=True, 1673*da0073e9SAndroid Build Coastguard Worker all_tensors_to_one_file=True, 1674*da0073e9SAndroid Build Coastguard Worker ) 1675*da0073e9SAndroid Build Coastguard Worker 1676*da0073e9SAndroid Build Coastguard Worker return export_output 1677*da0073e9SAndroid Build Coastguard Worker 1678*da0073e9SAndroid Build Coastguard Worker 1679*da0073e9SAndroid Build Coastguard Workerclass _OnnxPatch: 1680*da0073e9SAndroid Build Coastguard Worker @classmethod 1681*da0073e9SAndroid Build Coastguard Worker def patch_non_tensor_outputs(cls, correct_result, new_result, fp64_outputs): 1682*da0073e9SAndroid Build Coastguard Worker """Patch non-tensor outputs to make them comparable with the correct result. 1683*da0073e9SAndroid Build Coastguard Worker 1684*da0073e9SAndroid Build Coastguard Worker ONNX model always returns a flat tuple of tensors, but the PyTorch model outputs 1685*da0073e9SAndroid Build Coastguard Worker `correct_result` and `fp64_outputs` can be arbitrary types. This function normalizes 1686*da0073e9SAndroid Build Coastguard Worker the outputs to make them comparable with the ONNX model output. 1687*da0073e9SAndroid Build Coastguard Worker """ 1688*da0073e9SAndroid Build Coastguard Worker try: 1689*da0073e9SAndroid Build Coastguard Worker from transformers import modeling_outputs 1690*da0073e9SAndroid Build Coastguard Worker except ImportError: 1691*da0073e9SAndroid Build Coastguard Worker has_transformers = False 1692*da0073e9SAndroid Build Coastguard Worker else: 1693*da0073e9SAndroid Build Coastguard Worker has_transformers = True 1694*da0073e9SAndroid Build Coastguard Worker 1695*da0073e9SAndroid Build Coastguard Worker if has_transformers and isinstance( 1696*da0073e9SAndroid Build Coastguard Worker correct_result, modeling_outputs.ModelOutput 1697*da0073e9SAndroid Build Coastguard Worker ): 1698*da0073e9SAndroid Build Coastguard Worker correct_result = correct_result.to_tuple() 1699*da0073e9SAndroid Build Coastguard Worker fp64_outputs = fp64_outputs.to_tuple() if fp64_outputs is not None else None 1700*da0073e9SAndroid Build Coastguard Worker elif type(correct_result).__name__ in ( 1701*da0073e9SAndroid Build Coastguard Worker "MaskedLMOutput", 1702*da0073e9SAndroid Build Coastguard Worker "Seq2SeqLMOutput", 1703*da0073e9SAndroid Build Coastguard Worker "CausalLMOutputWithCrossAttentions", 1704*da0073e9SAndroid Build Coastguard Worker "LongformerMaskedLMOutput", 1705*da0073e9SAndroid Build Coastguard Worker "Instances", 1706*da0073e9SAndroid Build Coastguard Worker "SquashedNormal", 1707*da0073e9SAndroid Build Coastguard Worker "Boxes", 1708*da0073e9SAndroid Build Coastguard Worker "Normal", 1709*da0073e9SAndroid Build Coastguard Worker "TanhTransform", 1710*da0073e9SAndroid Build Coastguard Worker "Foo", 1711*da0073e9SAndroid Build Coastguard Worker "Variable", 1712*da0073e9SAndroid Build Coastguard Worker ): 1713*da0073e9SAndroid Build Coastguard Worker # Copied from `same` function in `torch._dynamo.utils` 1714*da0073e9SAndroid Build Coastguard Worker correct_result = [ 1715*da0073e9SAndroid Build Coastguard Worker value 1716*da0073e9SAndroid Build Coastguard Worker for key in correct_result.__dict__.keys() 1717*da0073e9SAndroid Build Coastguard Worker if (value := getattr(correct_result, key)) is not None 1718*da0073e9SAndroid Build Coastguard Worker ] 1719*da0073e9SAndroid Build Coastguard Worker fp64_outputs = ( 1720*da0073e9SAndroid Build Coastguard Worker [ 1721*da0073e9SAndroid Build Coastguard Worker value 1722*da0073e9SAndroid Build Coastguard Worker for key in fp64_outputs.__dict__.keys() 1723*da0073e9SAndroid Build Coastguard Worker if (value := getattr(fp64_outputs, key)) is not None 1724*da0073e9SAndroid Build Coastguard Worker ] 1725*da0073e9SAndroid Build Coastguard Worker if fp64_outputs is not None 1726*da0073e9SAndroid Build Coastguard Worker else None 1727*da0073e9SAndroid Build Coastguard Worker ) 1728*da0073e9SAndroid Build Coastguard Worker 1729*da0073e9SAndroid Build Coastguard Worker # Flatten nested tuple of tensors, i.e. past_key_values 1730*da0073e9SAndroid Build Coastguard Worker correct_result = pytree.tree_leaves(correct_result) 1731*da0073e9SAndroid Build Coastguard Worker # Hack to put results from different runs on same device. 1732*da0073e9SAndroid Build Coastguard Worker # This is needed for ONNX CPU fallback benchmark, where PyTorch eager is run on GPU. 1733*da0073e9SAndroid Build Coastguard Worker # Assuming outputs from a single run are always on same device! 1734*da0073e9SAndroid Build Coastguard Worker devices = [x.device for x in correct_result if isinstance(x, torch.Tensor)] 1735*da0073e9SAndroid Build Coastguard Worker assert devices and all( 1736*da0073e9SAndroid Build Coastguard Worker x == devices[0] for x in devices 1737*da0073e9SAndroid Build Coastguard Worker ), "All tensors must be on same device!" 1738*da0073e9SAndroid Build Coastguard Worker device = devices[0] 1739*da0073e9SAndroid Build Coastguard Worker new_result = pytree.tree_leaves(new_result) 1740*da0073e9SAndroid Build Coastguard Worker new_result = pytree.tree_map( 1741*da0073e9SAndroid Build Coastguard Worker lambda x: x.to(device=device) if isinstance(x, torch.Tensor) else x, 1742*da0073e9SAndroid Build Coastguard Worker new_result, 1743*da0073e9SAndroid Build Coastguard Worker ) 1744*da0073e9SAndroid Build Coastguard Worker fp64_outputs = pytree.tree_leaves(fp64_outputs) 1745*da0073e9SAndroid Build Coastguard Worker 1746*da0073e9SAndroid Build Coastguard Worker return correct_result, new_result, fp64_outputs 1747*da0073e9SAndroid Build Coastguard Worker 1748*da0073e9SAndroid Build Coastguard Worker 1749*da0073e9SAndroid Build Coastguard Worker@dataclasses.dataclass 1750*da0073e9SAndroid Build Coastguard Workerclass OnnxExportErrorRow: 1751*da0073e9SAndroid Build Coastguard Worker device: str 1752*da0073e9SAndroid Build Coastguard Worker model_name: str 1753*da0073e9SAndroid Build Coastguard Worker batch_size: int 1754*da0073e9SAndroid Build Coastguard Worker rule_id: Optional[str] = None 1755*da0073e9SAndroid Build Coastguard Worker rule_name: Optional[str] = None 1756*da0073e9SAndroid Build Coastguard Worker diagnostic_level: Optional[str] = None 1757*da0073e9SAndroid Build Coastguard Worker diagnostic_message: Optional[str] = None 1758*da0073e9SAndroid Build Coastguard Worker exception_type_name: Optional[str] = None 1759*da0073e9SAndroid Build Coastguard Worker exception_message: Optional[str] = None 1760*da0073e9SAndroid Build Coastguard Worker 1761*da0073e9SAndroid Build Coastguard Worker def __post_init__(self): 1762*da0073e9SAndroid Build Coastguard Worker assert ( 1763*da0073e9SAndroid Build Coastguard Worker self.rule_id is not None 1764*da0073e9SAndroid Build Coastguard Worker and self.rule_name is not None 1765*da0073e9SAndroid Build Coastguard Worker and self.diagnostic_level is not None 1766*da0073e9SAndroid Build Coastguard Worker and self.diagnostic_message is not None 1767*da0073e9SAndroid Build Coastguard Worker ) or self.exception_type_name, ( 1768*da0073e9SAndroid Build Coastguard Worker "Either rule_id, rule_name, diagnostic_level and diagnostic_message " 1769*da0073e9SAndroid Build Coastguard Worker "must be set or exception_type_name must be set" 1770*da0073e9SAndroid Build Coastguard Worker ) 1771*da0073e9SAndroid Build Coastguard Worker 1772*da0073e9SAndroid Build Coastguard Worker @property 1773*da0073e9SAndroid Build Coastguard Worker def headers(self) -> List[str]: 1774*da0073e9SAndroid Build Coastguard Worker return [field.name for field in dataclasses.fields(self)] 1775*da0073e9SAndroid Build Coastguard Worker 1776*da0073e9SAndroid Build Coastguard Worker @property 1777*da0073e9SAndroid Build Coastguard Worker def row(self) -> List[str]: 1778*da0073e9SAndroid Build Coastguard Worker return [getattr(self, field.name) for field in dataclasses.fields(self)] 1779*da0073e9SAndroid Build Coastguard Worker 1780*da0073e9SAndroid Build Coastguard Worker 1781*da0073e9SAndroid Build Coastguard Workerclass OnnxExportErrorParser: 1782*da0073e9SAndroid Build Coastguard Worker def __init__(self, device: str, model_name: str, batch_size: int): 1783*da0073e9SAndroid Build Coastguard Worker self.device = device 1784*da0073e9SAndroid Build Coastguard Worker self.model_name = model_name 1785*da0073e9SAndroid Build Coastguard Worker self.batch_size = batch_size 1786*da0073e9SAndroid Build Coastguard Worker 1787*da0073e9SAndroid Build Coastguard Worker def _qualified_exception_class_name(self, exception: Exception) -> str: 1788*da0073e9SAndroid Build Coastguard Worker if exception.__class__.__module__ == "builtins": 1789*da0073e9SAndroid Build Coastguard Worker return exception.__class__.__name__ 1790*da0073e9SAndroid Build Coastguard Worker return f"{exception.__class__.__module__}.{exception.__class__.__name__}" 1791*da0073e9SAndroid Build Coastguard Worker 1792*da0073e9SAndroid Build Coastguard Worker def parse_diagnostic_context( 1793*da0073e9SAndroid Build Coastguard Worker self, 1794*da0073e9SAndroid Build Coastguard Worker diagnostic_context: diagnostics.DiagnosticContext, 1795*da0073e9SAndroid Build Coastguard Worker ) -> Generator[OnnxExportErrorRow, Any, Any]: 1796*da0073e9SAndroid Build Coastguard Worker from torch.onnx._internal.fx import diagnostics 1797*da0073e9SAndroid Build Coastguard Worker 1798*da0073e9SAndroid Build Coastguard Worker for diagnostic in diagnostic_context.diagnostics: 1799*da0073e9SAndroid Build Coastguard Worker if diagnostic.level >= diagnostics.levels.ERROR: 1800*da0073e9SAndroid Build Coastguard Worker yield OnnxExportErrorRow( 1801*da0073e9SAndroid Build Coastguard Worker device=self.device, 1802*da0073e9SAndroid Build Coastguard Worker model_name=self.model_name, 1803*da0073e9SAndroid Build Coastguard Worker batch_size=self.batch_size, 1804*da0073e9SAndroid Build Coastguard Worker rule_id=diagnostic.rule.id, 1805*da0073e9SAndroid Build Coastguard Worker rule_name=diagnostic.rule.name, 1806*da0073e9SAndroid Build Coastguard Worker diagnostic_level=diagnostic.level.name, 1807*da0073e9SAndroid Build Coastguard Worker diagnostic_message=diagnostic.message, 1808*da0073e9SAndroid Build Coastguard Worker ) 1809*da0073e9SAndroid Build Coastguard Worker 1810*da0073e9SAndroid Build Coastguard Worker def parse_exception(self, exception: Exception) -> OnnxExportErrorRow: 1811*da0073e9SAndroid Build Coastguard Worker return OnnxExportErrorRow( 1812*da0073e9SAndroid Build Coastguard Worker device=self.device, 1813*da0073e9SAndroid Build Coastguard Worker model_name=self.model_name, 1814*da0073e9SAndroid Build Coastguard Worker batch_size=self.batch_size, 1815*da0073e9SAndroid Build Coastguard Worker exception_type_name=self._qualified_exception_class_name(exception), 1816*da0073e9SAndroid Build Coastguard Worker exception_message=str(exception), 1817*da0073e9SAndroid Build Coastguard Worker ) 1818*da0073e9SAndroid Build Coastguard Worker 1819*da0073e9SAndroid Build Coastguard Worker 1820*da0073e9SAndroid Build Coastguard Worker@dataclasses.dataclass 1821*da0073e9SAndroid Build Coastguard Workerclass OnnxContext: 1822*da0073e9SAndroid Build Coastguard Worker onnx_model: Optional[OnnxModel] = None 1823*da0073e9SAndroid Build Coastguard Worker 1824*da0073e9SAndroid Build Coastguard Worker 1825*da0073e9SAndroid Build Coastguard Workerdef optimize_onnx_ctx( 1826*da0073e9SAndroid Build Coastguard Worker output_directory: str, 1827*da0073e9SAndroid Build Coastguard Worker onnx_model_cls: Type[OnnxModel], 1828*da0073e9SAndroid Build Coastguard Worker run_n_iterations: Callable, 1829*da0073e9SAndroid Build Coastguard Worker dynamic_shapes: bool = False, 1830*da0073e9SAndroid Build Coastguard Worker copy_before_export: bool = False, 1831*da0073e9SAndroid Build Coastguard Worker) -> Callable: 1832*da0073e9SAndroid Build Coastguard Worker # NOTE(bowbao): This function creates and returns the onnx version of 'run_n_iterations', 1833*da0073e9SAndroid Build Coastguard Worker # which does the following: 1834*da0073e9SAndroid Build Coastguard Worker # 1. Export and cache model. 1835*da0073e9SAndroid Build Coastguard Worker # 2. Create iobinding for ORT. 1836*da0073e9SAndroid Build Coastguard Worker # 3. Run ORT for n iterations. 1837*da0073e9SAndroid Build Coastguard Worker # The cached model is stored in 'context' under the returned callable. 1838*da0073e9SAndroid Build Coastguard Worker context = OnnxContext() 1839*da0073e9SAndroid Build Coastguard Worker test_data_dumped = False 1840*da0073e9SAndroid Build Coastguard Worker 1841*da0073e9SAndroid Build Coastguard Worker def run_n_iterations_onnx(model, inputs, n=2): 1842*da0073e9SAndroid Build Coastguard Worker from torch.onnx._internal import exporter 1843*da0073e9SAndroid Build Coastguard Worker from torch.onnx._internal.fx import diagnostics 1844*da0073e9SAndroid Build Coastguard Worker 1845*da0073e9SAndroid Build Coastguard Worker # NOTE(bowbao): Capture all export & ort errors and diagnostics. 1846*da0073e9SAndroid Build Coastguard Worker # Serialize to csv, to be parsed and summarized later by '._onnx/reporter.py'. 1847*da0073e9SAndroid Build Coastguard Worker # TODO: Accuracy mismatch is not reported here in csv. 1848*da0073e9SAndroid Build Coastguard Worker assert ( 1849*da0073e9SAndroid Build Coastguard Worker output_filename.find(".csv") > 0 1850*da0073e9SAndroid Build Coastguard Worker ), f"expected output_filename to be a .csv, but got {output_filename}" 1851*da0073e9SAndroid Build Coastguard Worker output_error_filename = output_filename[:-4] + "_export_error.csv" 1852*da0073e9SAndroid Build Coastguard Worker parser = OnnxExportErrorParser(current_device, current_name, current_batch_size) 1853*da0073e9SAndroid Build Coastguard Worker try: 1854*da0073e9SAndroid Build Coastguard Worker nonlocal context 1855*da0073e9SAndroid Build Coastguard Worker if context.onnx_model is None: 1856*da0073e9SAndroid Build Coastguard Worker context.onnx_model = onnx_model_cls( 1857*da0073e9SAndroid Build Coastguard Worker output_directory, 1858*da0073e9SAndroid Build Coastguard Worker model, 1859*da0073e9SAndroid Build Coastguard Worker copy.deepcopy(inputs), 1860*da0073e9SAndroid Build Coastguard Worker dynamic_shapes=dynamic_shapes, 1861*da0073e9SAndroid Build Coastguard Worker copy_before_export=copy_before_export, 1862*da0073e9SAndroid Build Coastguard Worker ) 1863*da0073e9SAndroid Build Coastguard Worker onnx_model = context.onnx_model 1864*da0073e9SAndroid Build Coastguard Worker 1865*da0073e9SAndroid Build Coastguard Worker for _ in range(n): 1866*da0073e9SAndroid Build Coastguard Worker nonlocal test_data_dumped 1867*da0073e9SAndroid Build Coastguard Worker if not test_data_dumped: 1868*da0073e9SAndroid Build Coastguard Worker # Serializes inputs and outputs to .pb files for further offline analysis. 1869*da0073e9SAndroid Build Coastguard Worker # Due to this, this function is not and should not be used for perf measurement. 1870*da0073e9SAndroid Build Coastguard Worker outputs = onnx_model.run_and_serialize_inputs_outputs(inputs) 1871*da0073e9SAndroid Build Coastguard Worker test_data_dumped = True 1872*da0073e9SAndroid Build Coastguard Worker else: 1873*da0073e9SAndroid Build Coastguard Worker outputs = onnx_model.run(inputs) 1874*da0073e9SAndroid Build Coastguard Worker return outputs 1875*da0073e9SAndroid Build Coastguard Worker except exporter.OnnxExporterError as e: 1876*da0073e9SAndroid Build Coastguard Worker # `torch.onnx.dynamo_export` raises error that encloses diagnostics. 1877*da0073e9SAndroid Build Coastguard Worker diagnostic_context = e.onnx_program.diagnostic_context 1878*da0073e9SAndroid Build Coastguard Worker for parsed_error in parser.parse_diagnostic_context(diagnostic_context): 1879*da0073e9SAndroid Build Coastguard Worker output_csv( 1880*da0073e9SAndroid Build Coastguard Worker output_error_filename, parsed_error.headers, parsed_error.row 1881*da0073e9SAndroid Build Coastguard Worker ) 1882*da0073e9SAndroid Build Coastguard Worker if context.onnx_model is not None: 1883*da0073e9SAndroid Build Coastguard Worker e.onnx_program.save_diagnostics( 1884*da0073e9SAndroid Build Coastguard Worker f"{context.onnx_model.model_dir}/" 1885*da0073e9SAndroid Build Coastguard Worker f"{current_onnx_compiler}_{current_name}_{current_device}.sarif" 1886*da0073e9SAndroid Build Coastguard Worker ) 1887*da0073e9SAndroid Build Coastguard Worker 1888*da0073e9SAndroid Build Coastguard Worker # Check also the raw exception that caused export failure. 1889*da0073e9SAndroid Build Coastguard Worker # Skip if it is already analyzed by diagnostics. 1890*da0073e9SAndroid Build Coastguard Worker cause_of_exception = e.__cause__ 1891*da0073e9SAndroid Build Coastguard Worker if not isinstance( 1892*da0073e9SAndroid Build Coastguard Worker cause_of_exception, diagnostics.RuntimeErrorWithDiagnostic 1893*da0073e9SAndroid Build Coastguard Worker ): 1894*da0073e9SAndroid Build Coastguard Worker parsed_error = parser.parse_exception(cause_of_exception) 1895*da0073e9SAndroid Build Coastguard Worker output_csv( 1896*da0073e9SAndroid Build Coastguard Worker output_error_filename, parsed_error.headers, parsed_error.row 1897*da0073e9SAndroid Build Coastguard Worker ) 1898*da0073e9SAndroid Build Coastguard Worker raise 1899*da0073e9SAndroid Build Coastguard Worker except Exception as e: 1900*da0073e9SAndroid Build Coastguard Worker # `torch.onnx.export` errors. 1901*da0073e9SAndroid Build Coastguard Worker # ORT errors. 1902*da0073e9SAndroid Build Coastguard Worker parsed_error = parser.parse_exception(e) 1903*da0073e9SAndroid Build Coastguard Worker output_csv(output_error_filename, parsed_error.headers, parsed_error.row) 1904*da0073e9SAndroid Build Coastguard Worker raise 1905*da0073e9SAndroid Build Coastguard Worker 1906*da0073e9SAndroid Build Coastguard Worker run_n_iterations_onnx.context = context 1907*da0073e9SAndroid Build Coastguard Worker 1908*da0073e9SAndroid Build Coastguard Worker return run_n_iterations_onnx 1909*da0073e9SAndroid Build Coastguard Worker 1910*da0073e9SAndroid Build Coastguard Worker 1911*da0073e9SAndroid Build Coastguard Workerdef read_batch_size_from_file(args, filename, model_name): 1912*da0073e9SAndroid Build Coastguard Worker batch_size = None 1913*da0073e9SAndroid Build Coastguard Worker if os.path.exists("benchmarks"): 1914*da0073e9SAndroid Build Coastguard Worker filename = os.path.join("benchmarks", filename) 1915*da0073e9SAndroid Build Coastguard Worker assert os.path.exists(filename), filename 1916*da0073e9SAndroid Build Coastguard Worker with open(filename) as f: 1917*da0073e9SAndroid Build Coastguard Worker lines = f.readlines() 1918*da0073e9SAndroid Build Coastguard Worker lines = [i.split(",") for i in lines if len(i.strip()) > 0] 1919*da0073e9SAndroid Build Coastguard Worker for val in lines: 1920*da0073e9SAndroid Build Coastguard Worker cur_name, b = val 1921*da0073e9SAndroid Build Coastguard Worker if model_name == cur_name: 1922*da0073e9SAndroid Build Coastguard Worker batch_size = int(b) 1923*da0073e9SAndroid Build Coastguard Worker if batch_size is None: 1924*da0073e9SAndroid Build Coastguard Worker log.warning("Could not find batch size for %s", model_name) 1925*da0073e9SAndroid Build Coastguard Worker elif batch_size == -1: 1926*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 1927*da0073e9SAndroid Build Coastguard Worker f"Batch size is unset for {model_name} in {args.batch_size_file}" 1928*da0073e9SAndroid Build Coastguard Worker ) 1929*da0073e9SAndroid Build Coastguard Worker print(f"batch size: {batch_size}") 1930*da0073e9SAndroid Build Coastguard Worker return batch_size 1931*da0073e9SAndroid Build Coastguard Worker 1932*da0073e9SAndroid Build Coastguard Worker 1933*da0073e9SAndroid Build Coastguard Workerclass TimeOutException(Exception): 1934*da0073e9SAndroid Build Coastguard Worker pass 1935*da0073e9SAndroid Build Coastguard Worker 1936*da0073e9SAndroid Build Coastguard Worker 1937*da0073e9SAndroid Build Coastguard Workerdef alarm_handler(signum, frame): 1938*da0073e9SAndroid Build Coastguard Worker raise TimeOutException 1939*da0073e9SAndroid Build Coastguard Worker 1940*da0073e9SAndroid Build Coastguard Worker 1941*da0073e9SAndroid Build Coastguard Workerdef exit_after(s): 1942*da0073e9SAndroid Build Coastguard Worker """ 1943*da0073e9SAndroid Build Coastguard Worker Decorator to raise TimeoutException if the fn is taking more than s seconds 1944*da0073e9SAndroid Build Coastguard Worker to run. 1945*da0073e9SAndroid Build Coastguard Worker """ 1946*da0073e9SAndroid Build Coastguard Worker 1947*da0073e9SAndroid Build Coastguard Worker def outer(fn): 1948*da0073e9SAndroid Build Coastguard Worker def inner(*args, **kwargs): 1949*da0073e9SAndroid Build Coastguard Worker signal.signal(signal.SIGALRM, alarm_handler) 1950*da0073e9SAndroid Build Coastguard Worker signal.alarm(s) 1951*da0073e9SAndroid Build Coastguard Worker try: 1952*da0073e9SAndroid Build Coastguard Worker result = fn(*args, **kwargs) 1953*da0073e9SAndroid Build Coastguard Worker finally: 1954*da0073e9SAndroid Build Coastguard Worker signal.alarm(0) 1955*da0073e9SAndroid Build Coastguard Worker return result 1956*da0073e9SAndroid Build Coastguard Worker 1957*da0073e9SAndroid Build Coastguard Worker return inner 1958*da0073e9SAndroid Build Coastguard Worker 1959*da0073e9SAndroid Build Coastguard Worker return outer 1960*da0073e9SAndroid Build Coastguard Worker 1961*da0073e9SAndroid Build Coastguard Worker 1962*da0073e9SAndroid Build Coastguard Workerdef get_peak_memory(): 1963*da0073e9SAndroid Build Coastguard Worker return torch.cuda.max_memory_allocated() / 10**9 1964*da0073e9SAndroid Build Coastguard Worker 1965*da0073e9SAndroid Build Coastguard Worker 1966*da0073e9SAndroid Build Coastguard Workerdef null_experiment(args, model_iter_fn, model, example_inputs): 1967*da0073e9SAndroid Build Coastguard Worker """ 1968*da0073e9SAndroid Build Coastguard Worker A no-op experiment useful for making sure TorchBenchark alone works properly. 1969*da0073e9SAndroid Build Coastguard Worker """ 1970*da0073e9SAndroid Build Coastguard Worker 1971*da0073e9SAndroid Build Coastguard Worker return [] 1972*da0073e9SAndroid Build Coastguard Worker 1973*da0073e9SAndroid Build Coastguard Worker 1974*da0073e9SAndroid Build Coastguard Workerdef cast_to(dtype, model, inputs): 1975*da0073e9SAndroid Build Coastguard Worker # cast model and inputs to fp16 1976*da0073e9SAndroid Build Coastguard Worker if dtype == torch.float16: 1977*da0073e9SAndroid Build Coastguard Worker model = model.half() 1978*da0073e9SAndroid Build Coastguard Worker else: 1979*da0073e9SAndroid Build Coastguard Worker model = model.to(dtype) 1980*da0073e9SAndroid Build Coastguard Worker 1981*da0073e9SAndroid Build Coastguard Worker inputs = tree_map( 1982*da0073e9SAndroid Build Coastguard Worker lambda x: x.to(dtype) 1983*da0073e9SAndroid Build Coastguard Worker if isinstance(x, torch.Tensor) and x.is_floating_point() 1984*da0073e9SAndroid Build Coastguard Worker else x, 1985*da0073e9SAndroid Build Coastguard Worker inputs, 1986*da0073e9SAndroid Build Coastguard Worker ) 1987*da0073e9SAndroid Build Coastguard Worker return model, inputs 1988*da0073e9SAndroid Build Coastguard Worker 1989*da0073e9SAndroid Build Coastguard Worker 1990*da0073e9SAndroid Build Coastguard Workerdef cast_to_bf16(model, inputs): 1991*da0073e9SAndroid Build Coastguard Worker return cast_to(torch.bfloat16, model, inputs) 1992*da0073e9SAndroid Build Coastguard Worker 1993*da0073e9SAndroid Build Coastguard Worker 1994*da0073e9SAndroid Build Coastguard Workerdef cast_to_fp16(model, inputs): 1995*da0073e9SAndroid Build Coastguard Worker return cast_to(torch.float16, model, inputs) 1996*da0073e9SAndroid Build Coastguard Worker 1997*da0073e9SAndroid Build Coastguard Worker 1998*da0073e9SAndroid Build Coastguard Workerdef cast_to_fp64(model, inputs): 1999*da0073e9SAndroid Build Coastguard Worker return cast_to(torch.float64, model, inputs) 2000*da0073e9SAndroid Build Coastguard Worker 2001*da0073e9SAndroid Build Coastguard Worker 2002*da0073e9SAndroid Build Coastguard Workerdef cast_to_fp32(model, inputs): 2003*da0073e9SAndroid Build Coastguard Worker return cast_to(torch.float32, model, inputs) 2004*da0073e9SAndroid Build Coastguard Worker 2005*da0073e9SAndroid Build Coastguard Worker 2006*da0073e9SAndroid Build Coastguard Workerclass DummyGradScaler: 2007*da0073e9SAndroid Build Coastguard Worker def scale(self, loss): 2008*da0073e9SAndroid Build Coastguard Worker return loss 2009*da0073e9SAndroid Build Coastguard Worker 2010*da0073e9SAndroid Build Coastguard Worker 2011*da0073e9SAndroid Build Coastguard Workerdef get_dynamo_stats(): 2012*da0073e9SAndroid Build Coastguard Worker # TODO: consider deepcopy'ing the entire counters struct and 2013*da0073e9SAndroid Build Coastguard Worker # adding a helper to do subtraction on it 2014*da0073e9SAndroid Build Coastguard Worker return collections.Counter( 2015*da0073e9SAndroid Build Coastguard Worker { 2016*da0073e9SAndroid Build Coastguard Worker "calls_captured": torch._dynamo.utils.counters["stats"]["calls_captured"], 2017*da0073e9SAndroid Build Coastguard Worker "unique_graphs": torch._dynamo.utils.counters["stats"]["unique_graphs"], 2018*da0073e9SAndroid Build Coastguard Worker "graph_breaks": sum(torch._dynamo.utils.counters["graph_break"].values()), 2019*da0073e9SAndroid Build Coastguard Worker # NB: The plus removes zero counts 2020*da0073e9SAndroid Build Coastguard Worker "unique_graph_breaks": len(+torch._dynamo.utils.counters["graph_break"]), 2021*da0073e9SAndroid Build Coastguard Worker "autograd_captures": torch._dynamo.utils.counters["compiled_autograd"][ 2022*da0073e9SAndroid Build Coastguard Worker "captures" 2023*da0073e9SAndroid Build Coastguard Worker ], 2024*da0073e9SAndroid Build Coastguard Worker "autograd_compiles": torch._dynamo.utils.counters["compiled_autograd"][ 2025*da0073e9SAndroid Build Coastguard Worker "compiles" 2026*da0073e9SAndroid Build Coastguard Worker ], 2027*da0073e9SAndroid Build Coastguard Worker "cudagraph_skips": torch._dynamo.utils.counters["inductor"][ 2028*da0073e9SAndroid Build Coastguard Worker "cudagraph_skips" 2029*da0073e9SAndroid Build Coastguard Worker ], 2030*da0073e9SAndroid Build Coastguard Worker } 2031*da0073e9SAndroid Build Coastguard Worker ) 2032*da0073e9SAndroid Build Coastguard Worker 2033*da0073e9SAndroid Build Coastguard Worker 2034*da0073e9SAndroid Build Coastguard Worker@contextmanager 2035*da0073e9SAndroid Build Coastguard Workerdef maybe_init_distributed(should_init_distributed, rank, world_size, port="6789"): 2036*da0073e9SAndroid Build Coastguard Worker try: 2037*da0073e9SAndroid Build Coastguard Worker if should_init_distributed: 2038*da0073e9SAndroid Build Coastguard Worker torch.cuda.set_device(rank) 2039*da0073e9SAndroid Build Coastguard Worker os.environ["MASTER_ADDR"] = "localhost" 2040*da0073e9SAndroid Build Coastguard Worker os.environ["MASTER_PORT"] = port 2041*da0073e9SAndroid Build Coastguard Worker torch.distributed.init_process_group( 2042*da0073e9SAndroid Build Coastguard Worker "nccl", rank=rank, world_size=world_size 2043*da0073e9SAndroid Build Coastguard Worker ) 2044*da0073e9SAndroid Build Coastguard Worker yield 2045*da0073e9SAndroid Build Coastguard Worker finally: 2046*da0073e9SAndroid Build Coastguard Worker if should_init_distributed: 2047*da0073e9SAndroid Build Coastguard Worker torch.distributed.destroy_process_group() 2048*da0073e9SAndroid Build Coastguard Worker 2049*da0073e9SAndroid Build Coastguard Worker 2050*da0073e9SAndroid Build Coastguard Worker@contextmanager 2051*da0073e9SAndroid Build Coastguard Workerdef maybe_snapshot_memory(should_snapshot_memory, suffix): 2052*da0073e9SAndroid Build Coastguard Worker # Enables Memory Snapshot tool for memory deep dives: 2053*da0073e9SAndroid Build Coastguard Worker # https://pytorch.org/blog/understanding-gpu-memory-1/ 2054*da0073e9SAndroid Build Coastguard Worker try: 2055*da0073e9SAndroid Build Coastguard Worker if should_snapshot_memory: 2056*da0073e9SAndroid Build Coastguard Worker torch.cuda.memory._record_memory_history(max_entries=100000) 2057*da0073e9SAndroid Build Coastguard Worker yield 2058*da0073e9SAndroid Build Coastguard Worker finally: 2059*da0073e9SAndroid Build Coastguard Worker if should_snapshot_memory: 2060*da0073e9SAndroid Build Coastguard Worker try: 2061*da0073e9SAndroid Build Coastguard Worker torch.cuda.memory._dump_snapshot( 2062*da0073e9SAndroid Build Coastguard Worker os.path.join( 2063*da0073e9SAndroid Build Coastguard Worker torch._dynamo.config.base_dir, 2064*da0073e9SAndroid Build Coastguard Worker f"{output_filename.rstrip('.csv')}_{suffix}.pickle", 2065*da0073e9SAndroid Build Coastguard Worker ) 2066*da0073e9SAndroid Build Coastguard Worker ) 2067*da0073e9SAndroid Build Coastguard Worker except Exception as e: 2068*da0073e9SAndroid Build Coastguard Worker logging.error("Failed to save memory snapshot, %s", e) 2069*da0073e9SAndroid Build Coastguard Worker 2070*da0073e9SAndroid Build Coastguard Worker torch.cuda.memory._record_memory_history(enabled=None) 2071*da0073e9SAndroid Build Coastguard Worker 2072*da0073e9SAndroid Build Coastguard Worker 2073*da0073e9SAndroid Build Coastguard Workerclass BenchmarkRunner: 2074*da0073e9SAndroid Build Coastguard Worker def __init__(self): 2075*da0073e9SAndroid Build Coastguard Worker self.model_iter_fn = None 2076*da0073e9SAndroid Build Coastguard Worker self.grad_scaler = DummyGradScaler() 2077*da0073e9SAndroid Build Coastguard Worker self.autocast = contextlib.nullcontext 2078*da0073e9SAndroid Build Coastguard Worker self.autocast_arg = {} 2079*da0073e9SAndroid Build Coastguard Worker self.optimizer = None 2080*da0073e9SAndroid Build Coastguard Worker self._args = None 2081*da0073e9SAndroid Build Coastguard Worker 2082*da0073e9SAndroid Build Coastguard Worker def setup_amp(self, current_device=None): 2083*da0073e9SAndroid Build Coastguard Worker if self.args.only in self.fp32_only_models: 2084*da0073e9SAndroid Build Coastguard Worker return 2085*da0073e9SAndroid Build Coastguard Worker 2086*da0073e9SAndroid Build Coastguard Worker devices = [current_device] if current_device else self.args.devices 2087*da0073e9SAndroid Build Coastguard Worker if self.args.amp: 2088*da0073e9SAndroid Build Coastguard Worker # AMP training can lead to small loss values which can undeflow 2089*da0073e9SAndroid Build Coastguard Worker # gradient values returning in zero gradients. To solve this 2090*da0073e9SAndroid Build Coastguard Worker # problem, PyTorch introduces GradScaler. GradScaler is a stateful 2091*da0073e9SAndroid Build Coastguard Worker # structure, that scales the loss values to prevent underflow. Loss 2092*da0073e9SAndroid Build Coastguard Worker # values are big at the beginning of training (therefore not 2093*da0073e9SAndroid Build Coastguard Worker # requiring scaling), while loss value tends to be small as network 2094*da0073e9SAndroid Build Coastguard Worker # starts getting better (requiring scaling). GradScaler manages all 2095*da0073e9SAndroid Build Coastguard Worker # of this fine tuning, checking the gradients are turning to inf, 2096*da0073e9SAndroid Build Coastguard Worker # discarding such batches. 2097*da0073e9SAndroid Build Coastguard Worker 2098*da0073e9SAndroid Build Coastguard Worker # Since we are not running a long iteration, default value of 2099*da0073e9SAndroid Build Coastguard Worker # init_scale 65536 is going to turn all gradients to inf. Therefore, 2100*da0073e9SAndroid Build Coastguard Worker # we just use a init_scale of 2.0 for benchmarking purpose. 2101*da0073e9SAndroid Build Coastguard Worker 2102*da0073e9SAndroid Build Coastguard Worker # Disabling Gradscaler because 2103*da0073e9SAndroid Build Coastguard Worker # 1) Benchmark setup runs 2 iterations of fwd-bwd. So, not useful. 2104*da0073e9SAndroid Build Coastguard Worker # 2) Current setup shares grad_scaler for eager and dynamo model, 2105*da0073e9SAndroid Build Coastguard Worker # which is bad as Gradscaler has state and can adjust the scaling 2106*da0073e9SAndroid Build Coastguard Worker # factor between eager and dynamo run, making accuracy check 2107*da0073e9SAndroid Build Coastguard Worker # harder. 2108*da0073e9SAndroid Build Coastguard Worker # self.grad_scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0) 2109*da0073e9SAndroid Build Coastguard Worker self.autocast = functools.partial( 2110*da0073e9SAndroid Build Coastguard Worker torch.amp.autocast, device_type=devices[0] 2111*da0073e9SAndroid Build Coastguard Worker ) 2112*da0073e9SAndroid Build Coastguard Worker if self.args.amp_dtype: 2113*da0073e9SAndroid Build Coastguard Worker amp_dtype = ( 2114*da0073e9SAndroid Build Coastguard Worker torch.float16 2115*da0073e9SAndroid Build Coastguard Worker if self.args.amp_dtype == "float16" 2116*da0073e9SAndroid Build Coastguard Worker else torch.bfloat16 2117*da0073e9SAndroid Build Coastguard Worker ) 2118*da0073e9SAndroid Build Coastguard Worker self.autocast_arg["dtype"] = amp_dtype 2119*da0073e9SAndroid Build Coastguard Worker 2120*da0073e9SAndroid Build Coastguard Worker def init_optimizer(self, name, device, params): 2121*da0073e9SAndroid Build Coastguard Worker if device == "cuda" and self.args.training and name not in CI_SKIP_OPTIMIZER: 2122*da0073e9SAndroid Build Coastguard Worker if (name in CI_USE_SGD and self.args.ci) or name in BENCHMARK_USE_SGD: 2123*da0073e9SAndroid Build Coastguard Worker self.optimizer = torch.optim.SGD(params, lr=0.01, foreach=True) 2124*da0073e9SAndroid Build Coastguard Worker # Disable multi_tensor_sgd for benchmarking, there isn't a large performance benefit (~1%) to compiling 2125*da0073e9SAndroid Build Coastguard Worker # this optimizer because it is a single foreach add, and increases compile time. 2126*da0073e9SAndroid Build Coastguard Worker # After autotuning and fake tensor caching lands, we can enable, becuase the compile time impact will be lower. 2127*da0073e9SAndroid Build Coastguard Worker # Fake Tensor caching: https://github.com/pytorch/pytorch/pull/113873 2128*da0073e9SAndroid Build Coastguard Worker # Autotuning: https://github.com/pytorch/pytorch/issues/117447 2129*da0073e9SAndroid Build Coastguard Worker self.optimizer.step = torch._dynamo.disable(self.optimizer.step) 2130*da0073e9SAndroid Build Coastguard Worker else: 2131*da0073e9SAndroid Build Coastguard Worker self.optimizer = torch.optim.Adam( 2132*da0073e9SAndroid Build Coastguard Worker params, lr=0.01, capturable=True, foreach=True 2133*da0073e9SAndroid Build Coastguard Worker ) 2134*da0073e9SAndroid Build Coastguard Worker else: 2135*da0073e9SAndroid Build Coastguard Worker self.optimizer = None 2136*da0073e9SAndroid Build Coastguard Worker 2137*da0073e9SAndroid Build Coastguard Worker @property 2138*da0073e9SAndroid Build Coastguard Worker def args(self): 2139*da0073e9SAndroid Build Coastguard Worker return self._args 2140*da0073e9SAndroid Build Coastguard Worker 2141*da0073e9SAndroid Build Coastguard Worker @args.setter 2142*da0073e9SAndroid Build Coastguard Worker def args(self, args): 2143*da0073e9SAndroid Build Coastguard Worker self._args = args 2144*da0073e9SAndroid Build Coastguard Worker 2145*da0073e9SAndroid Build Coastguard Worker @property 2146*da0073e9SAndroid Build Coastguard Worker def skip_models(self): 2147*da0073e9SAndroid Build Coastguard Worker return set() 2148*da0073e9SAndroid Build Coastguard Worker 2149*da0073e9SAndroid Build Coastguard Worker @property 2150*da0073e9SAndroid Build Coastguard Worker def skip_models_for_cuda(self): 2151*da0073e9SAndroid Build Coastguard Worker return set() 2152*da0073e9SAndroid Build Coastguard Worker 2153*da0073e9SAndroid Build Coastguard Worker @property 2154*da0073e9SAndroid Build Coastguard Worker def skip_models_for_cpu(self): 2155*da0073e9SAndroid Build Coastguard Worker return set() 2156*da0073e9SAndroid Build Coastguard Worker 2157*da0073e9SAndroid Build Coastguard Worker @property 2158*da0073e9SAndroid Build Coastguard Worker def skip_models_for_freezing(self): 2159*da0073e9SAndroid Build Coastguard Worker return set() 2160*da0073e9SAndroid Build Coastguard Worker 2161*da0073e9SAndroid Build Coastguard Worker @property 2162*da0073e9SAndroid Build Coastguard Worker def slow_models(self): 2163*da0073e9SAndroid Build Coastguard Worker return set() 2164*da0073e9SAndroid Build Coastguard Worker 2165*da0073e9SAndroid Build Coastguard Worker @property 2166*da0073e9SAndroid Build Coastguard Worker def very_slow_models(self): 2167*da0073e9SAndroid Build Coastguard Worker return set() 2168*da0073e9SAndroid Build Coastguard Worker 2169*da0073e9SAndroid Build Coastguard Worker @property 2170*da0073e9SAndroid Build Coastguard Worker def non_deterministic_models(self): 2171*da0073e9SAndroid Build Coastguard Worker return set() 2172*da0073e9SAndroid Build Coastguard Worker 2173*da0073e9SAndroid Build Coastguard Worker @property 2174*da0073e9SAndroid Build Coastguard Worker def fp32_only_models(self): 2175*da0073e9SAndroid Build Coastguard Worker return set() 2176*da0073e9SAndroid Build Coastguard Worker 2177*da0073e9SAndroid Build Coastguard Worker @property 2178*da0073e9SAndroid Build Coastguard Worker def force_amp_for_fp16_bf16_models(self): 2179*da0073e9SAndroid Build Coastguard Worker return set() 2180*da0073e9SAndroid Build Coastguard Worker 2181*da0073e9SAndroid Build Coastguard Worker @property 2182*da0073e9SAndroid Build Coastguard Worker def force_fp16_for_bf16_models(self): 2183*da0073e9SAndroid Build Coastguard Worker return set() 2184*da0073e9SAndroid Build Coastguard Worker 2185*da0073e9SAndroid Build Coastguard Worker @property 2186*da0073e9SAndroid Build Coastguard Worker def skip_not_suitable_for_training_models(self): 2187*da0073e9SAndroid Build Coastguard Worker return set() 2188*da0073e9SAndroid Build Coastguard Worker 2189*da0073e9SAndroid Build Coastguard Worker @property 2190*da0073e9SAndroid Build Coastguard Worker def failing_torchinductor_models(self): 2191*da0073e9SAndroid Build Coastguard Worker return set() 2192*da0073e9SAndroid Build Coastguard Worker 2193*da0073e9SAndroid Build Coastguard Worker @property 2194*da0073e9SAndroid Build Coastguard Worker def failing_fx2trt_models(self): 2195*da0073e9SAndroid Build Coastguard Worker return set() 2196*da0073e9SAndroid Build Coastguard Worker 2197*da0073e9SAndroid Build Coastguard Worker @property 2198*da0073e9SAndroid Build Coastguard Worker def skip_accuracy_checks_large_models_dashboard(self): 2199*da0073e9SAndroid Build Coastguard Worker return set() 2200*da0073e9SAndroid Build Coastguard Worker 2201*da0073e9SAndroid Build Coastguard Worker @property 2202*da0073e9SAndroid Build Coastguard Worker def skip_accuracy_check_as_eager_non_deterministic(self): 2203*da0073e9SAndroid Build Coastguard Worker return set() 2204*da0073e9SAndroid Build Coastguard Worker 2205*da0073e9SAndroid Build Coastguard Worker @property 2206*da0073e9SAndroid Build Coastguard Worker def skip_multiprocess_models(self): 2207*da0073e9SAndroid Build Coastguard Worker return set() 2208*da0073e9SAndroid Build Coastguard Worker 2209*da0073e9SAndroid Build Coastguard Worker @property 2210*da0073e9SAndroid Build Coastguard Worker def skip_models_due_to_control_flow(self): 2211*da0073e9SAndroid Build Coastguard Worker return set() 2212*da0073e9SAndroid Build Coastguard Worker 2213*da0073e9SAndroid Build Coastguard Worker @property 2214*da0073e9SAndroid Build Coastguard Worker def guard_on_nn_module_models(self): 2215*da0073e9SAndroid Build Coastguard Worker return set() 2216*da0073e9SAndroid Build Coastguard Worker 2217*da0073e9SAndroid Build Coastguard Worker def get_tolerance_and_cosine_flag(self, is_training, current_device, name): 2218*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 2219*da0073e9SAndroid Build Coastguard Worker 2220*da0073e9SAndroid Build Coastguard Worker @property 2221*da0073e9SAndroid Build Coastguard Worker def equal_nan(self): 2222*da0073e9SAndroid Build Coastguard Worker equal_nan = True 2223*da0073e9SAndroid Build Coastguard Worker if self.args.float32: 2224*da0073e9SAndroid Build Coastguard Worker equal_nan = False 2225*da0073e9SAndroid Build Coastguard Worker return equal_nan 2226*da0073e9SAndroid Build Coastguard Worker 2227*da0073e9SAndroid Build Coastguard Worker def iter_models(self, args): 2228*da0073e9SAndroid Build Coastguard Worker for model_name in self.iter_model_names(args): 2229*da0073e9SAndroid Build Coastguard Worker for device in args.devices: 2230*da0073e9SAndroid Build Coastguard Worker try: 2231*da0073e9SAndroid Build Coastguard Worker yield self.load_model( 2232*da0073e9SAndroid Build Coastguard Worker device, 2233*da0073e9SAndroid Build Coastguard Worker model_name, 2234*da0073e9SAndroid Build Coastguard Worker batch_size=args.batch_size, 2235*da0073e9SAndroid Build Coastguard Worker ) 2236*da0073e9SAndroid Build Coastguard Worker except NotImplementedError: 2237*da0073e9SAndroid Build Coastguard Worker continue # bad benchmark implementation 2238*da0073e9SAndroid Build Coastguard Worker 2239*da0073e9SAndroid Build Coastguard Worker def deepcopy_model(self, model): 2240*da0073e9SAndroid Build Coastguard Worker return copy.deepcopy(model) 2241*da0073e9SAndroid Build Coastguard Worker 2242*da0073e9SAndroid Build Coastguard Worker def cast_based_on_args(self, model, example_inputs): 2243*da0073e9SAndroid Build Coastguard Worker if self.args.float32 or self.args.only in self.fp32_only_models: 2244*da0073e9SAndroid Build Coastguard Worker if not self.args.float32: 2245*da0073e9SAndroid Build Coastguard Worker log.warning("Model %s supports float32 only", self.args.only) 2246*da0073e9SAndroid Build Coastguard Worker model, example_inputs = cast_to_fp32(model, example_inputs) 2247*da0073e9SAndroid Build Coastguard Worker elif self.args.float16: 2248*da0073e9SAndroid Build Coastguard Worker if self.args.only in self.force_amp_for_fp16_bf16_models: 2249*da0073e9SAndroid Build Coastguard Worker log.warning( 2250*da0073e9SAndroid Build Coastguard Worker "Model %s does not support float16, running with amp instead", 2251*da0073e9SAndroid Build Coastguard Worker self.args.only, 2252*da0073e9SAndroid Build Coastguard Worker ) 2253*da0073e9SAndroid Build Coastguard Worker self.args.amp = True 2254*da0073e9SAndroid Build Coastguard Worker self.setup_amp() 2255*da0073e9SAndroid Build Coastguard Worker else: 2256*da0073e9SAndroid Build Coastguard Worker model, example_inputs = cast_to_fp16(model, example_inputs) 2257*da0073e9SAndroid Build Coastguard Worker elif self.args.bfloat16: 2258*da0073e9SAndroid Build Coastguard Worker if self.args.only in self.force_amp_for_fp16_bf16_models: 2259*da0073e9SAndroid Build Coastguard Worker log.warning( 2260*da0073e9SAndroid Build Coastguard Worker "Model %s does not support bfloat16, running with amp instead", 2261*da0073e9SAndroid Build Coastguard Worker self.args.only, 2262*da0073e9SAndroid Build Coastguard Worker ) 2263*da0073e9SAndroid Build Coastguard Worker self.args.amp = True 2264*da0073e9SAndroid Build Coastguard Worker self.setup_amp() 2265*da0073e9SAndroid Build Coastguard Worker elif self.args.only in self.force_fp16_for_bf16_models: 2266*da0073e9SAndroid Build Coastguard Worker log.warning( 2267*da0073e9SAndroid Build Coastguard Worker "Model %s does not support bfloat16, running with float16 instead", 2268*da0073e9SAndroid Build Coastguard Worker self.args.only, 2269*da0073e9SAndroid Build Coastguard Worker ) 2270*da0073e9SAndroid Build Coastguard Worker model, example_inputs = cast_to_fp16(model, example_inputs) 2271*da0073e9SAndroid Build Coastguard Worker else: 2272*da0073e9SAndroid Build Coastguard Worker model, example_inputs = cast_to_bf16(model, example_inputs) 2273*da0073e9SAndroid Build Coastguard Worker 2274*da0073e9SAndroid Build Coastguard Worker return model, example_inputs 2275*da0073e9SAndroid Build Coastguard Worker 2276*da0073e9SAndroid Build Coastguard Worker def validate_model(self, model, example_inputs): 2277*da0073e9SAndroid Build Coastguard Worker """ 2278*da0073e9SAndroid Build Coastguard Worker Runs the eager model with example inputs to ensure that eager passes. 2279*da0073e9SAndroid Build Coastguard Worker """ 2280*da0073e9SAndroid Build Coastguard Worker model = self.deepcopy_model(model) 2281*da0073e9SAndroid Build Coastguard Worker example_inputs = clone_inputs(example_inputs) 2282*da0073e9SAndroid Build Coastguard Worker model, example_inputs = self.cast_based_on_args(model, example_inputs) 2283*da0073e9SAndroid Build Coastguard Worker try: 2284*da0073e9SAndroid Build Coastguard Worker self.model_iter_fn(model, example_inputs) 2285*da0073e9SAndroid Build Coastguard Worker except Exception as e: 2286*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Eager run failed") from e 2287*da0073e9SAndroid Build Coastguard Worker 2288*da0073e9SAndroid Build Coastguard Worker def maybe_cast(self, model, example_inputs): 2289*da0073e9SAndroid Build Coastguard Worker model, example_inputs = self.cast_based_on_args(model, example_inputs) 2290*da0073e9SAndroid Build Coastguard Worker return model, example_inputs 2291*da0073e9SAndroid Build Coastguard Worker 2292*da0073e9SAndroid Build Coastguard Worker def decay_batch_exp(self, batch_size, factor=0.5, divisor=2): 2293*da0073e9SAndroid Build Coastguard Worker out_batch_size = batch_size * factor 2294*da0073e9SAndroid Build Coastguard Worker if out_batch_size > divisor: 2295*da0073e9SAndroid Build Coastguard Worker out_batch_size = (out_batch_size + 1) // divisor * divisor 2296*da0073e9SAndroid Build Coastguard Worker else: 2297*da0073e9SAndroid Build Coastguard Worker out_batch_size = batch_size - 1 2298*da0073e9SAndroid Build Coastguard Worker return max(0, int(out_batch_size)) 2299*da0073e9SAndroid Build Coastguard Worker 2300*da0073e9SAndroid Build Coastguard Worker def batch_size_finder(self, device, model_name, initial_batch_size=1024): 2301*da0073e9SAndroid Build Coastguard Worker batch_size = initial_batch_size 2302*da0073e9SAndroid Build Coastguard Worker while batch_size >= 1: 2303*da0073e9SAndroid Build Coastguard Worker empty_gpu_cache(current_device) 2304*da0073e9SAndroid Build Coastguard Worker try: 2305*da0073e9SAndroid Build Coastguard Worker device, name, model, example_inputs, _ = self.load_model( 2306*da0073e9SAndroid Build Coastguard Worker device, 2307*da0073e9SAndroid Build Coastguard Worker model_name, 2308*da0073e9SAndroid Build Coastguard Worker batch_size, 2309*da0073e9SAndroid Build Coastguard Worker ) 2310*da0073e9SAndroid Build Coastguard Worker self.model_iter_fn(model, example_inputs) 2311*da0073e9SAndroid Build Coastguard Worker return batch_size 2312*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 2313*da0073e9SAndroid Build Coastguard Worker error_str = str(e) 2314*da0073e9SAndroid Build Coastguard Worker if "channels_last" in error_str: 2315*da0073e9SAndroid Build Coastguard Worker break 2316*da0073e9SAndroid Build Coastguard Worker batch_size = self.decay_batch_exp(batch_size) 2317*da0073e9SAndroid Build Coastguard Worker return 1 2318*da0073e9SAndroid Build Coastguard Worker 2319*da0073e9SAndroid Build Coastguard Worker def run_n_iterations(self, mod, inputs): 2320*da0073e9SAndroid Build Coastguard Worker n = self.args.iterations 2321*da0073e9SAndroid Build Coastguard Worker for _ in range(n - 1): 2322*da0073e9SAndroid Build Coastguard Worker self.model_iter_fn(mod, inputs, collect_outputs=False) 2323*da0073e9SAndroid Build Coastguard Worker return self.model_iter_fn(mod, inputs, collect_outputs=True) 2324*da0073e9SAndroid Build Coastguard Worker 2325*da0073e9SAndroid Build Coastguard Worker @torch._disable_dynamo(recursive=True) 2326*da0073e9SAndroid Build Coastguard Worker def optimizer_zero_grad(self, mod): 2327*da0073e9SAndroid Build Coastguard Worker if self.optimizer is not None: 2328*da0073e9SAndroid Build Coastguard Worker self.optimizer.zero_grad(True) 2329*da0073e9SAndroid Build Coastguard Worker else: 2330*da0073e9SAndroid Build Coastguard Worker mod.zero_grad(True) 2331*da0073e9SAndroid Build Coastguard Worker 2332*da0073e9SAndroid Build Coastguard Worker def optimizer_step(self): 2333*da0073e9SAndroid Build Coastguard Worker if self.optimizer is not None: 2334*da0073e9SAndroid Build Coastguard Worker self.optimizer.step() 2335*da0073e9SAndroid Build Coastguard Worker 2336*da0073e9SAndroid Build Coastguard Worker def get_benchmark_indices(self, length): 2337*da0073e9SAndroid Build Coastguard Worker start = self._args.partition_id * (length // self._args.total_partitions) 2338*da0073e9SAndroid Build Coastguard Worker end = ( 2339*da0073e9SAndroid Build Coastguard Worker (self._args.partition_id + 1) * (length // self._args.total_partitions) 2340*da0073e9SAndroid Build Coastguard Worker if self._args.partition_id < self._args.total_partitions - 1 2341*da0073e9SAndroid Build Coastguard Worker else length 2342*da0073e9SAndroid Build Coastguard Worker ) 2343*da0073e9SAndroid Build Coastguard Worker return start, end 2344*da0073e9SAndroid Build Coastguard Worker 2345*da0073e9SAndroid Build Coastguard Worker def get_fsdp_auto_wrap_policy(self, model_name: str): 2346*da0073e9SAndroid Build Coastguard Worker from diffusers.models.transformer_2d import Transformer2DModel 2347*da0073e9SAndroid Build Coastguard Worker from torchbenchmark.models.nanogpt.model import Block 2348*da0073e9SAndroid Build Coastguard Worker from transformers.models.llama.modeling_llama import LlamaDecoderLayer 2349*da0073e9SAndroid Build Coastguard Worker 2350*da0073e9SAndroid Build Coastguard Worker from transformers.models.t5.modeling_t5 import T5Block 2351*da0073e9SAndroid Build Coastguard Worker from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer 2352*da0073e9SAndroid Build Coastguard Worker 2353*da0073e9SAndroid Build Coastguard Worker from torch.distributed.fsdp.wrap import ( 2354*da0073e9SAndroid Build Coastguard Worker ModuleWrapPolicy, 2355*da0073e9SAndroid Build Coastguard Worker size_based_auto_wrap_policy, 2356*da0073e9SAndroid Build Coastguard Worker ) 2357*da0073e9SAndroid Build Coastguard Worker 2358*da0073e9SAndroid Build Coastguard Worker # handcrafted wrap policy 2359*da0073e9SAndroid Build Coastguard Worker MODEL_FSDP_WRAP = { 2360*da0073e9SAndroid Build Coastguard Worker "stable_diffusion_unet": (Transformer2DModel,), 2361*da0073e9SAndroid Build Coastguard Worker "hf_T5": (T5Block,), 2362*da0073e9SAndroid Build Coastguard Worker "hf_T5_base": (T5Block,), 2363*da0073e9SAndroid Build Coastguard Worker "hf_T5_large": (T5Block,), 2364*da0073e9SAndroid Build Coastguard Worker "hf_Whisper": (WhisperEncoderLayer,), 2365*da0073e9SAndroid Build Coastguard Worker "llama_v2_7b_16h": (LlamaDecoderLayer,), 2366*da0073e9SAndroid Build Coastguard Worker "nanogpt": (Block,), 2367*da0073e9SAndroid Build Coastguard Worker } 2368*da0073e9SAndroid Build Coastguard Worker 2369*da0073e9SAndroid Build Coastguard Worker if model_name not in MODEL_FSDP_WRAP: 2370*da0073e9SAndroid Build Coastguard Worker # default to using wrap policy based on module size 2371*da0073e9SAndroid Build Coastguard Worker return functools.partial( 2372*da0073e9SAndroid Build Coastguard Worker size_based_auto_wrap_policy, recurse=True, min_num_params=int(1e5) 2373*da0073e9SAndroid Build Coastguard Worker ) 2374*da0073e9SAndroid Build Coastguard Worker 2375*da0073e9SAndroid Build Coastguard Worker return ModuleWrapPolicy(MODEL_FSDP_WRAP[model_name]) 2376*da0073e9SAndroid Build Coastguard Worker 2377*da0073e9SAndroid Build Coastguard Worker def deepcopy_and_maybe_parallelize(self, model): 2378*da0073e9SAndroid Build Coastguard Worker model = self.deepcopy_model(model) 2379*da0073e9SAndroid Build Coastguard Worker if self.args.ddp: 2380*da0073e9SAndroid Build Coastguard Worker assert ( 2381*da0073e9SAndroid Build Coastguard Worker torch.distributed.is_available() 2382*da0073e9SAndroid Build Coastguard Worker ), "Can't use DDP without a distributed enabled build" 2383*da0073e9SAndroid Build Coastguard Worker from torch.nn.parallel import DistributedDataParallel as DDP 2384*da0073e9SAndroid Build Coastguard Worker 2385*da0073e9SAndroid Build Coastguard Worker model = DDP(model, find_unused_parameters=True) 2386*da0073e9SAndroid Build Coastguard Worker elif self.args.fsdp: 2387*da0073e9SAndroid Build Coastguard Worker assert ( 2388*da0073e9SAndroid Build Coastguard Worker torch.distributed.is_available() 2389*da0073e9SAndroid Build Coastguard Worker ), "Can't use FSDP without a distributed enabled build" 2390*da0073e9SAndroid Build Coastguard Worker from torch.distributed.fsdp import ( 2391*da0073e9SAndroid Build Coastguard Worker FullyShardedDataParallel as FSDP, 2392*da0073e9SAndroid Build Coastguard Worker MixedPrecision, 2393*da0073e9SAndroid Build Coastguard Worker ) 2394*da0073e9SAndroid Build Coastguard Worker 2395*da0073e9SAndroid Build Coastguard Worker if self.args.float16: 2396*da0073e9SAndroid Build Coastguard Worker dtype = torch.float16 2397*da0073e9SAndroid Build Coastguard Worker elif self.args.bfloat16: 2398*da0073e9SAndroid Build Coastguard Worker dtype = torch.bfloat16 2399*da0073e9SAndroid Build Coastguard Worker else: 2400*da0073e9SAndroid Build Coastguard Worker dtype = torch.float32 2401*da0073e9SAndroid Build Coastguard Worker 2402*da0073e9SAndroid Build Coastguard Worker mp_policy = MixedPrecision( 2403*da0073e9SAndroid Build Coastguard Worker param_dtype=dtype, 2404*da0073e9SAndroid Build Coastguard Worker # Gradient communication precision. 2405*da0073e9SAndroid Build Coastguard Worker reduce_dtype=dtype, 2406*da0073e9SAndroid Build Coastguard Worker # Buffer precision. 2407*da0073e9SAndroid Build Coastguard Worker buffer_dtype=dtype, 2408*da0073e9SAndroid Build Coastguard Worker ) 2409*da0073e9SAndroid Build Coastguard Worker 2410*da0073e9SAndroid Build Coastguard Worker model = FSDP( 2411*da0073e9SAndroid Build Coastguard Worker model, 2412*da0073e9SAndroid Build Coastguard Worker use_orig_params=True, 2413*da0073e9SAndroid Build Coastguard Worker device_id=torch.cuda.current_device() 2414*da0073e9SAndroid Build Coastguard Worker if self.args.devices[-1] == "cuda" 2415*da0073e9SAndroid Build Coastguard Worker else None, 2416*da0073e9SAndroid Build Coastguard Worker mixed_precision=mp_policy, 2417*da0073e9SAndroid Build Coastguard Worker limit_all_gathers=True, 2418*da0073e9SAndroid Build Coastguard Worker auto_wrap_policy=self.get_fsdp_auto_wrap_policy(self.args.only), 2419*da0073e9SAndroid Build Coastguard Worker ) 2420*da0073e9SAndroid Build Coastguard Worker return model 2421*da0073e9SAndroid Build Coastguard Worker 2422*da0073e9SAndroid Build Coastguard Worker def check_accuracy( 2423*da0073e9SAndroid Build Coastguard Worker self, name, model, example_inputs, optimize_ctx, experiment, tag 2424*da0073e9SAndroid Build Coastguard Worker ): 2425*da0073e9SAndroid Build Coastguard Worker """ 2426*da0073e9SAndroid Build Coastguard Worker Checks accuracy. 2427*da0073e9SAndroid Build Coastguard Worker 1) Collect the outputs with fp64 datatype. This is useful for error checking. 2428*da0073e9SAndroid Build Coastguard Worker 2) Checks if eager itself has variations. 2429*da0073e9SAndroid Build Coastguard Worker """ 2430*da0073e9SAndroid Build Coastguard Worker start_stats = get_dynamo_stats() 2431*da0073e9SAndroid Build Coastguard Worker 2432*da0073e9SAndroid Build Coastguard Worker def record_status(accuracy_status, dynamo_start_stats): 2433*da0073e9SAndroid Build Coastguard Worker """ 2434*da0073e9SAndroid Build Coastguard Worker Records the status in the csv file 2435*da0073e9SAndroid Build Coastguard Worker """ 2436*da0073e9SAndroid Build Coastguard Worker if current_name in self.non_deterministic_models: 2437*da0073e9SAndroid Build Coastguard Worker if accuracy_status in ( 2438*da0073e9SAndroid Build Coastguard Worker "pass", 2439*da0073e9SAndroid Build Coastguard Worker "eager_two_runs_differ", 2440*da0073e9SAndroid Build Coastguard Worker "fail_accuracy", 2441*da0073e9SAndroid Build Coastguard Worker ): 2442*da0073e9SAndroid Build Coastguard Worker accuracy_status = "pass" 2443*da0073e9SAndroid Build Coastguard Worker 2444*da0073e9SAndroid Build Coastguard Worker headers = ["dev", "name", "batch_size", "accuracy"] 2445*da0073e9SAndroid Build Coastguard Worker fields = [current_device, current_name, current_batch_size, accuracy_status] 2446*da0073e9SAndroid Build Coastguard Worker 2447*da0073e9SAndroid Build Coastguard Worker if tag is not None: 2448*da0073e9SAndroid Build Coastguard Worker headers.insert(3, "tag") 2449*da0073e9SAndroid Build Coastguard Worker fields.insert(3, tag) 2450*da0073e9SAndroid Build Coastguard Worker 2451*da0073e9SAndroid Build Coastguard Worker dynamo_stats = get_dynamo_stats() 2452*da0073e9SAndroid Build Coastguard Worker dynamo_stats.subtract(dynamo_start_stats) 2453*da0073e9SAndroid Build Coastguard Worker for k, v in dynamo_stats.items(): 2454*da0073e9SAndroid Build Coastguard Worker headers.append(k) 2455*da0073e9SAndroid Build Coastguard Worker fields.append(v) 2456*da0073e9SAndroid Build Coastguard Worker 2457*da0073e9SAndroid Build Coastguard Worker output_csv(output_filename, headers, fields) 2458*da0073e9SAndroid Build Coastguard Worker return accuracy_status 2459*da0073e9SAndroid Build Coastguard Worker 2460*da0073e9SAndroid Build Coastguard Worker if name in self.skip_accuracy_checks_large_models_dashboard: 2461*da0073e9SAndroid Build Coastguard Worker return record_status("pass_due_to_skip", dynamo_start_stats=start_stats) 2462*da0073e9SAndroid Build Coastguard Worker 2463*da0073e9SAndroid Build Coastguard Worker with self.pick_grad(name, self.args.training): 2464*da0073e9SAndroid Build Coastguard Worker # Collect the fp64 reference outputs to be used later for accuracy checking. 2465*da0073e9SAndroid Build Coastguard Worker fp64_outputs = None 2466*da0073e9SAndroid Build Coastguard Worker model_fp64 = None 2467*da0073e9SAndroid Build Coastguard Worker inputs_fp64 = None 2468*da0073e9SAndroid Build Coastguard Worker try: 2469*da0073e9SAndroid Build Coastguard Worker model_fp64, inputs_fp64 = cast_to_fp64( 2470*da0073e9SAndroid Build Coastguard Worker self.deepcopy_and_maybe_parallelize(model), 2471*da0073e9SAndroid Build Coastguard Worker clone_inputs(example_inputs), 2472*da0073e9SAndroid Build Coastguard Worker ) 2473*da0073e9SAndroid Build Coastguard Worker self.init_optimizer(name, current_device, model_fp64.parameters()) 2474*da0073e9SAndroid Build Coastguard Worker fp64_outputs = self.run_n_iterations(model_fp64, inputs_fp64) 2475*da0073e9SAndroid Build Coastguard Worker fp64_outputs = tree_map( 2476*da0073e9SAndroid Build Coastguard Worker lambda x: x.to(torch.float64) 2477*da0073e9SAndroid Build Coastguard Worker if isinstance(x, torch.Tensor) and x.is_floating_point() 2478*da0073e9SAndroid Build Coastguard Worker else x, 2479*da0073e9SAndroid Build Coastguard Worker fp64_outputs, 2480*da0073e9SAndroid Build Coastguard Worker ) 2481*da0073e9SAndroid Build Coastguard Worker except Exception: 2482*da0073e9SAndroid Build Coastguard Worker log.warning( 2483*da0073e9SAndroid Build Coastguard Worker "fp64 golden ref were not generated for %s. Setting accuracy check to cosine", 2484*da0073e9SAndroid Build Coastguard Worker name, 2485*da0073e9SAndroid Build Coastguard Worker ) 2486*da0073e9SAndroid Build Coastguard Worker self.args.cosine = True 2487*da0073e9SAndroid Build Coastguard Worker fp64_outputs = None 2488*da0073e9SAndroid Build Coastguard Worker finally: 2489*da0073e9SAndroid Build Coastguard Worker del model_fp64, inputs_fp64 2490*da0073e9SAndroid Build Coastguard Worker empty_gpu_cache(current_device) 2491*da0073e9SAndroid Build Coastguard Worker 2492*da0073e9SAndroid Build Coastguard Worker tolerance, cos_similarity = self.get_tolerance_and_cosine_flag( 2493*da0073e9SAndroid Build Coastguard Worker self.args.training, current_device, name 2494*da0073e9SAndroid Build Coastguard Worker ) 2495*da0073e9SAndroid Build Coastguard Worker 2496*da0073e9SAndroid Build Coastguard Worker # Cast the model to float16/float32 as necessary 2497*da0073e9SAndroid Build Coastguard Worker model, example_inputs = self.maybe_cast(model, example_inputs) 2498*da0073e9SAndroid Build Coastguard Worker accuracy_status = "pass" 2499*da0073e9SAndroid Build Coastguard Worker 2500*da0073e9SAndroid Build Coastguard Worker # Get results of native pytorch 2501*da0073e9SAndroid Build Coastguard Worker reset_rng_state() 2502*da0073e9SAndroid Build Coastguard Worker model_copy = None 2503*da0073e9SAndroid Build Coastguard Worker try: 2504*da0073e9SAndroid Build Coastguard Worker model_copy = self.deepcopy_and_maybe_parallelize(model) 2505*da0073e9SAndroid Build Coastguard Worker self.init_optimizer(name, current_device, model_copy.parameters()) 2506*da0073e9SAndroid Build Coastguard Worker correct_result = self.run_n_iterations( 2507*da0073e9SAndroid Build Coastguard Worker model_copy, clone_inputs(example_inputs) 2508*da0073e9SAndroid Build Coastguard Worker ) 2509*da0073e9SAndroid Build Coastguard Worker except Exception as e: 2510*da0073e9SAndroid Build Coastguard Worker accuracy_status = ( 2511*da0073e9SAndroid Build Coastguard Worker "eager_1st_run_OOM" 2512*da0073e9SAndroid Build Coastguard Worker if isinstance(e, torch.cuda.OutOfMemoryError) 2513*da0073e9SAndroid Build Coastguard Worker else "eager_1st_run_fail" 2514*da0073e9SAndroid Build Coastguard Worker ) 2515*da0073e9SAndroid Build Coastguard Worker log.exception("") 2516*da0073e9SAndroid Build Coastguard Worker return record_status(accuracy_status, dynamo_start_stats=start_stats) 2517*da0073e9SAndroid Build Coastguard Worker finally: 2518*da0073e9SAndroid Build Coastguard Worker del model_copy 2519*da0073e9SAndroid Build Coastguard Worker empty_gpu_cache(current_device) 2520*da0073e9SAndroid Build Coastguard Worker 2521*da0073e9SAndroid Build Coastguard Worker # Rerun native pytorch 2522*da0073e9SAndroid Build Coastguard Worker reset_rng_state() 2523*da0073e9SAndroid Build Coastguard Worker model_copy = None 2524*da0073e9SAndroid Build Coastguard Worker try: 2525*da0073e9SAndroid Build Coastguard Worker model_copy = self.deepcopy_and_maybe_parallelize(model) 2526*da0073e9SAndroid Build Coastguard Worker self.init_optimizer(name, current_device, model_copy.parameters()) 2527*da0073e9SAndroid Build Coastguard Worker correct_rerun_result = self.run_n_iterations( 2528*da0073e9SAndroid Build Coastguard Worker model_copy, clone_inputs(example_inputs) 2529*da0073e9SAndroid Build Coastguard Worker ) 2530*da0073e9SAndroid Build Coastguard Worker except Exception as e: 2531*da0073e9SAndroid Build Coastguard Worker accuracy_status = ( 2532*da0073e9SAndroid Build Coastguard Worker "eager_2nd_run_OOM" 2533*da0073e9SAndroid Build Coastguard Worker if isinstance(e, torch.cuda.OutOfMemoryError) 2534*da0073e9SAndroid Build Coastguard Worker else "eager_2nd_run_fail" 2535*da0073e9SAndroid Build Coastguard Worker ) 2536*da0073e9SAndroid Build Coastguard Worker log.exception("") 2537*da0073e9SAndroid Build Coastguard Worker return record_status(accuracy_status, dynamo_start_stats=start_stats) 2538*da0073e9SAndroid Build Coastguard Worker finally: 2539*da0073e9SAndroid Build Coastguard Worker del model_copy 2540*da0073e9SAndroid Build Coastguard Worker empty_gpu_cache(current_device) 2541*da0073e9SAndroid Build Coastguard Worker 2542*da0073e9SAndroid Build Coastguard Worker # Two eager runs should have exactly same result 2543*da0073e9SAndroid Build Coastguard Worker is_same = True 2544*da0073e9SAndroid Build Coastguard Worker try: 2545*da0073e9SAndroid Build Coastguard Worker if ( 2546*da0073e9SAndroid Build Coastguard Worker name not in self.skip_accuracy_check_as_eager_non_deterministic 2547*da0073e9SAndroid Build Coastguard Worker and not same( 2548*da0073e9SAndroid Build Coastguard Worker correct_result, 2549*da0073e9SAndroid Build Coastguard Worker correct_rerun_result, 2550*da0073e9SAndroid Build Coastguard Worker fp64_ref=None, 2551*da0073e9SAndroid Build Coastguard Worker cos_similarity=False, 2552*da0073e9SAndroid Build Coastguard Worker tol=0, 2553*da0073e9SAndroid Build Coastguard Worker equal_nan=self.equal_nan, 2554*da0073e9SAndroid Build Coastguard Worker ) 2555*da0073e9SAndroid Build Coastguard Worker ): 2556*da0073e9SAndroid Build Coastguard Worker is_same = False 2557*da0073e9SAndroid Build Coastguard Worker except Exception as e: 2558*da0073e9SAndroid Build Coastguard Worker # Sometimes torch.allclose may throw RuntimeError 2559*da0073e9SAndroid Build Coastguard Worker is_same = False 2560*da0073e9SAndroid Build Coastguard Worker 2561*da0073e9SAndroid Build Coastguard Worker if not is_same: 2562*da0073e9SAndroid Build Coastguard Worker accuracy_status = "eager_two_runs_differ" 2563*da0073e9SAndroid Build Coastguard Worker return record_status(accuracy_status, dynamo_start_stats=start_stats) 2564*da0073e9SAndroid Build Coastguard Worker 2565*da0073e9SAndroid Build Coastguard Worker correct_rerun_result = None 2566*da0073e9SAndroid Build Coastguard Worker 2567*da0073e9SAndroid Build Coastguard Worker # Run with Dynamo 2568*da0073e9SAndroid Build Coastguard Worker reset_rng_state() 2569*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 2570*da0073e9SAndroid Build Coastguard Worker model_copy = None 2571*da0073e9SAndroid Build Coastguard Worker try: 2572*da0073e9SAndroid Build Coastguard Worker model_copy = self.deepcopy_and_maybe_parallelize(model) 2573*da0073e9SAndroid Build Coastguard Worker self.init_optimizer(name, current_device, model_copy.parameters()) 2574*da0073e9SAndroid Build Coastguard Worker if self.args.export or self.args.export_aot_inductor: 2575*da0073e9SAndroid Build Coastguard Worker # apply export on module directly 2576*da0073e9SAndroid Build Coastguard Worker # no need for n iterations 2577*da0073e9SAndroid Build Coastguard Worker # the logic should be the same to self.model_iter_fn (forward_pass) 2578*da0073e9SAndroid Build Coastguard Worker with self.autocast(**self.autocast_arg): 2579*da0073e9SAndroid Build Coastguard Worker optimized_model_iter_fn = optimize_ctx( 2580*da0073e9SAndroid Build Coastguard Worker model_copy, example_inputs 2581*da0073e9SAndroid Build Coastguard Worker ) 2582*da0073e9SAndroid Build Coastguard Worker new_result = optimized_model_iter_fn(model_copy, example_inputs) 2583*da0073e9SAndroid Build Coastguard Worker else: 2584*da0073e9SAndroid Build Coastguard Worker optimized_model_iter_fn = optimize_ctx(self.run_n_iterations) 2585*da0073e9SAndroid Build Coastguard Worker with maybe_enable_compiled_autograd(self.args.compiled_autograd): 2586*da0073e9SAndroid Build Coastguard Worker new_result = optimized_model_iter_fn(model_copy, example_inputs) 2587*da0073e9SAndroid Build Coastguard Worker except Exception as e: 2588*da0073e9SAndroid Build Coastguard Worker log.exception("") 2589*da0073e9SAndroid Build Coastguard Worker print( 2590*da0073e9SAndroid Build Coastguard Worker "TorchDynamo optimized model failed to run because of following error" 2591*da0073e9SAndroid Build Coastguard Worker ) 2592*da0073e9SAndroid Build Coastguard Worker accuracy_status = ( 2593*da0073e9SAndroid Build Coastguard Worker "OOM" 2594*da0073e9SAndroid Build Coastguard Worker if isinstance(e, torch.cuda.OutOfMemoryError) 2595*da0073e9SAndroid Build Coastguard Worker else "fail_to_run" 2596*da0073e9SAndroid Build Coastguard Worker ) 2597*da0073e9SAndroid Build Coastguard Worker return record_status(accuracy_status, dynamo_start_stats=start_stats) 2598*da0073e9SAndroid Build Coastguard Worker finally: 2599*da0073e9SAndroid Build Coastguard Worker del model_copy 2600*da0073e9SAndroid Build Coastguard Worker 2601*da0073e9SAndroid Build Coastguard Worker if name in self.skip_accuracy_check_as_eager_non_deterministic: 2602*da0073e9SAndroid Build Coastguard Worker return record_status("pass_due_to_skip", dynamo_start_stats=start_stats) 2603*da0073e9SAndroid Build Coastguard Worker 2604*da0073e9SAndroid Build Coastguard Worker if ( 2605*da0073e9SAndroid Build Coastguard Worker current_onnx_compiler == "torchscript" 2606*da0073e9SAndroid Build Coastguard Worker or current_onnx_compiler == "dynamo" 2607*da0073e9SAndroid Build Coastguard Worker ): 2608*da0073e9SAndroid Build Coastguard Worker # Workaround for ONNX for non-tensor outputs 2609*da0073e9SAndroid Build Coastguard Worker ( 2610*da0073e9SAndroid Build Coastguard Worker correct_result, 2611*da0073e9SAndroid Build Coastguard Worker new_result, 2612*da0073e9SAndroid Build Coastguard Worker fp64_outputs, 2613*da0073e9SAndroid Build Coastguard Worker ) = _OnnxPatch.patch_non_tensor_outputs( 2614*da0073e9SAndroid Build Coastguard Worker correct_result, new_result, fp64_outputs 2615*da0073e9SAndroid Build Coastguard Worker ) 2616*da0073e9SAndroid Build Coastguard Worker # Relax tolerance for ONNX cuda 2617*da0073e9SAndroid Build Coastguard Worker if current_device == "cuda": 2618*da0073e9SAndroid Build Coastguard Worker tolerance = 1e-2 2619*da0073e9SAndroid Build Coastguard Worker 2620*da0073e9SAndroid Build Coastguard Worker # TODO: store correct_result into the dumped file for offline onnx model validation. 2621*da0073e9SAndroid Build Coastguard Worker # The downside and potential problem, is that the output formats may be different. 2622*da0073e9SAndroid Build Coastguard Worker # E.g., the output order might not match, None might be part of output, etc. 2623*da0073e9SAndroid Build Coastguard Worker 2624*da0073e9SAndroid Build Coastguard Worker try: 2625*da0073e9SAndroid Build Coastguard Worker if self.args.training and self.args.amp: 2626*da0073e9SAndroid Build Coastguard Worker if process_fn := self.get_output_amp_train_process_func.get( 2627*da0073e9SAndroid Build Coastguard Worker name, None 2628*da0073e9SAndroid Build Coastguard Worker ): 2629*da0073e9SAndroid Build Coastguard Worker correct_result = process_fn(correct_result) 2630*da0073e9SAndroid Build Coastguard Worker new_result = process_fn(new_result) 2631*da0073e9SAndroid Build Coastguard Worker fp64_outputs = process_fn(fp64_outputs) 2632*da0073e9SAndroid Build Coastguard Worker 2633*da0073e9SAndroid Build Coastguard Worker if not same( 2634*da0073e9SAndroid Build Coastguard Worker correct_result, 2635*da0073e9SAndroid Build Coastguard Worker new_result, 2636*da0073e9SAndroid Build Coastguard Worker fp64_outputs, 2637*da0073e9SAndroid Build Coastguard Worker equal_nan=self.equal_nan, 2638*da0073e9SAndroid Build Coastguard Worker cos_similarity=cos_similarity, 2639*da0073e9SAndroid Build Coastguard Worker tol=tolerance, 2640*da0073e9SAndroid Build Coastguard Worker ): 2641*da0073e9SAndroid Build Coastguard Worker is_same = False 2642*da0073e9SAndroid Build Coastguard Worker except Exception as e: 2643*da0073e9SAndroid Build Coastguard Worker # Sometimes torch.allclose may throw RuntimeError 2644*da0073e9SAndroid Build Coastguard Worker is_same = False 2645*da0073e9SAndroid Build Coastguard Worker 2646*da0073e9SAndroid Build Coastguard Worker if not is_same: 2647*da0073e9SAndroid Build Coastguard Worker if self.args.skip_accuracy_check: 2648*da0073e9SAndroid Build Coastguard Worker accuracy_status = "pass_due_to_skip" 2649*da0073e9SAndroid Build Coastguard Worker else: 2650*da0073e9SAndroid Build Coastguard Worker accuracy_status = "fail_accuracy" 2651*da0073e9SAndroid Build Coastguard Worker return record_status(accuracy_status, dynamo_start_stats=start_stats) 2652*da0073e9SAndroid Build Coastguard Worker 2653*da0073e9SAndroid Build Coastguard Worker return record_status(accuracy_status, dynamo_start_stats=start_stats) 2654*da0073e9SAndroid Build Coastguard Worker 2655*da0073e9SAndroid Build Coastguard Worker def check_tolerance( 2656*da0073e9SAndroid Build Coastguard Worker self, name, model, example_inputs, optimize_ctx, base_device="cpu" 2657*da0073e9SAndroid Build Coastguard Worker ): 2658*da0073e9SAndroid Build Coastguard Worker """ 2659*da0073e9SAndroid Build Coastguard Worker Checks tolerance based on https://pytorch.org/docs/stable/generated/torch.allclose.html. 2660*da0073e9SAndroid Build Coastguard Worker """ 2661*da0073e9SAndroid Build Coastguard Worker tolerance_status = "pass" 2662*da0073e9SAndroid Build Coastguard Worker if name in self.skip_accuracy_checks_large_models_dashboard: 2663*da0073e9SAndroid Build Coastguard Worker tolerance_status = "pass_due_to_skip" 2664*da0073e9SAndroid Build Coastguard Worker return tolerance_status 2665*da0073e9SAndroid Build Coastguard Worker # Cast the model to float16/float32 as necessary 2666*da0073e9SAndroid Build Coastguard Worker model, example_inputs = self.maybe_cast(model, example_inputs) 2667*da0073e9SAndroid Build Coastguard Worker 2668*da0073e9SAndroid Build Coastguard Worker with self.pick_grad(name, self.args.training): 2669*da0073e9SAndroid Build Coastguard Worker # Get results of native pytorch 2670*da0073e9SAndroid Build Coastguard Worker reset_rng_state() 2671*da0073e9SAndroid Build Coastguard Worker model_copy = copy.deepcopy(model) 2672*da0073e9SAndroid Build Coastguard Worker model_copy = model_copy.to(base_device) 2673*da0073e9SAndroid Build Coastguard Worker example_inputs_copy = copy.deepcopy(example_inputs) 2674*da0073e9SAndroid Build Coastguard Worker example_inputs_copy = tree_map( 2675*da0073e9SAndroid Build Coastguard Worker lambda x: x.to(base_device), example_inputs_copy 2676*da0073e9SAndroid Build Coastguard Worker ) 2677*da0073e9SAndroid Build Coastguard Worker self.init_optimizer(name, base_device, model_copy.parameters()) 2678*da0073e9SAndroid Build Coastguard Worker correct_result = self.run_n_iterations(model_copy, example_inputs_copy) 2679*da0073e9SAndroid Build Coastguard Worker 2680*da0073e9SAndroid Build Coastguard Worker # Run with Dynamo 2681*da0073e9SAndroid Build Coastguard Worker # Sometime CI fails with random triton compilation failure which will be skipped for now 2682*da0073e9SAndroid Build Coastguard Worker # TODO: revisit this after switching to new Triton runtime 2683*da0073e9SAndroid Build Coastguard Worker reset_rng_state() 2684*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 2685*da0073e9SAndroid Build Coastguard Worker try: 2686*da0073e9SAndroid Build Coastguard Worker self.init_optimizer(name, current_device, model.parameters()) 2687*da0073e9SAndroid Build Coastguard Worker optimized_model_iter_fn = optimize_ctx(self.run_n_iterations) 2688*da0073e9SAndroid Build Coastguard Worker new_result = optimized_model_iter_fn(model, example_inputs) 2689*da0073e9SAndroid Build Coastguard Worker except Exception as e: 2690*da0073e9SAndroid Build Coastguard Worker log.exception("") 2691*da0073e9SAndroid Build Coastguard Worker print( 2692*da0073e9SAndroid Build Coastguard Worker "TorchDynamo optimized model failed to run because of following error" 2693*da0073e9SAndroid Build Coastguard Worker ) 2694*da0073e9SAndroid Build Coastguard Worker return "fail_to_run" 2695*da0073e9SAndroid Build Coastguard Worker 2696*da0073e9SAndroid Build Coastguard Worker def dump_max_mean_values(tol, ref, res): 2697*da0073e9SAndroid Build Coastguard Worker if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)): 2698*da0073e9SAndroid Build Coastguard Worker for refi, resi in zip(ref, res): 2699*da0073e9SAndroid Build Coastguard Worker dump_max_mean_values(tol, refi, resi) 2700*da0073e9SAndroid Build Coastguard Worker elif isinstance(ref, dict): 2701*da0073e9SAndroid Build Coastguard Worker for k in ref.keys(): 2702*da0073e9SAndroid Build Coastguard Worker dump_max_mean_values(tol, ref[k], res[k]) 2703*da0073e9SAndroid Build Coastguard Worker elif isinstance(ref, torch.Tensor): 2704*da0073e9SAndroid Build Coastguard Worker res = res.to(base_device) 2705*da0073e9SAndroid Build Coastguard Worker t = torch.abs(ref - res) / (1 + torch.abs(ref)) 2706*da0073e9SAndroid Build Coastguard Worker tol.append(t.flatten().to(torch.float32)) 2707*da0073e9SAndroid Build Coastguard Worker return tol 2708*da0073e9SAndroid Build Coastguard Worker 2709*da0073e9SAndroid Build Coastguard Worker tol = [] 2710*da0073e9SAndroid Build Coastguard Worker dump_max_mean_values(tol, correct_result, new_result) 2711*da0073e9SAndroid Build Coastguard Worker tol = torch.cat(tol) 2712*da0073e9SAndroid Build Coastguard Worker tol = torch.tensor(tol) 2713*da0073e9SAndroid Build Coastguard Worker max = torch.max(tol) 2714*da0073e9SAndroid Build Coastguard Worker mean = torch.mean(tol) 2715*da0073e9SAndroid Build Coastguard Worker div = torch.std(tol) 2716*da0073e9SAndroid Build Coastguard Worker headers = ["dev", "name", "batch_size", "max", "mean", "std"] 2717*da0073e9SAndroid Build Coastguard Worker fields = [ 2718*da0073e9SAndroid Build Coastguard Worker current_device, 2719*da0073e9SAndroid Build Coastguard Worker current_name, 2720*da0073e9SAndroid Build Coastguard Worker current_batch_size, 2721*da0073e9SAndroid Build Coastguard Worker max.item(), 2722*da0073e9SAndroid Build Coastguard Worker mean.item(), 2723*da0073e9SAndroid Build Coastguard Worker div.item(), 2724*da0073e9SAndroid Build Coastguard Worker ] 2725*da0073e9SAndroid Build Coastguard Worker output_csv(output_filename, headers, fields) 2726*da0073e9SAndroid Build Coastguard Worker return tolerance_status 2727*da0073e9SAndroid Build Coastguard Worker 2728*da0073e9SAndroid Build Coastguard Worker def run_performance_test( 2729*da0073e9SAndroid Build Coastguard Worker self, name, model, example_inputs, optimize_ctx, experiment, tag=None 2730*da0073e9SAndroid Build Coastguard Worker ): 2731*da0073e9SAndroid Build Coastguard Worker if self.args.xla: 2732*da0073e9SAndroid Build Coastguard Worker with self.pick_grad(name, self.args.training): 2733*da0073e9SAndroid Build Coastguard Worker return experiment(*self.maybe_cast(model, example_inputs)) 2734*da0073e9SAndroid Build Coastguard Worker 2735*da0073e9SAndroid Build Coastguard Worker def warmup(fn, model, example_inputs, mode, niters=5): 2736*da0073e9SAndroid Build Coastguard Worker peak_mem = 0 2737*da0073e9SAndroid Build Coastguard Worker start_stats = get_dynamo_stats() 2738*da0073e9SAndroid Build Coastguard Worker try: 2739*da0073e9SAndroid Build Coastguard Worker if current_device == "cuda": 2740*da0073e9SAndroid Build Coastguard Worker torch.cuda.reset_peak_memory_stats() 2741*da0073e9SAndroid Build Coastguard Worker empty_gpu_cache(current_device) 2742*da0073e9SAndroid Build Coastguard Worker t0 = time.perf_counter() 2743*da0073e9SAndroid Build Coastguard Worker for _ in range(niters): 2744*da0073e9SAndroid Build Coastguard Worker fn(model, example_inputs) 2745*da0073e9SAndroid Build Coastguard Worker t1 = time.perf_counter() 2746*da0073e9SAndroid Build Coastguard Worker latency = t1 - t0 2747*da0073e9SAndroid Build Coastguard Worker if current_device == "cuda": 2748*da0073e9SAndroid Build Coastguard Worker peak_mem = get_peak_memory() 2749*da0073e9SAndroid Build Coastguard Worker elif current_device == "cpu": 2750*da0073e9SAndroid Build Coastguard Worker total = psutil.virtual_memory().total 2751*da0073e9SAndroid Build Coastguard Worker percentage = psutil.Process(os.getpid()).memory_percent() 2752*da0073e9SAndroid Build Coastguard Worker peak_mem = percentage * total / 10**9 2753*da0073e9SAndroid Build Coastguard Worker except Exception: 2754*da0073e9SAndroid Build Coastguard Worker log.exception("Backend %s failed in warmup()", mode) 2755*da0073e9SAndroid Build Coastguard Worker return sys.exit(-1) 2756*da0073e9SAndroid Build Coastguard Worker dynamo_stats = get_dynamo_stats() 2757*da0073e9SAndroid Build Coastguard Worker dynamo_stats.subtract(start_stats) 2758*da0073e9SAndroid Build Coastguard Worker return latency, peak_mem, dynamo_stats 2759*da0073e9SAndroid Build Coastguard Worker 2760*da0073e9SAndroid Build Coastguard Worker # Cast the model to float16/float32 as necessary 2761*da0073e9SAndroid Build Coastguard Worker model, example_inputs = self.maybe_cast(model, example_inputs) 2762*da0073e9SAndroid Build Coastguard Worker 2763*da0073e9SAndroid Build Coastguard Worker # Use distributed wrapping as necessary 2764*da0073e9SAndroid Build Coastguard Worker model = self.deepcopy_and_maybe_parallelize(model) 2765*da0073e9SAndroid Build Coastguard Worker 2766*da0073e9SAndroid Build Coastguard Worker self.init_optimizer(name, current_device, model.parameters()) 2767*da0073e9SAndroid Build Coastguard Worker 2768*da0073e9SAndroid Build Coastguard Worker # The self.autocast context is needed for the model we export with aot_compile, 2769*da0073e9SAndroid Build Coastguard Worker # similar to what we do in the check_accuracy function 2770*da0073e9SAndroid Build Coastguard Worker ctx = ( 2771*da0073e9SAndroid Build Coastguard Worker self.autocast(**self.autocast_arg) 2772*da0073e9SAndroid Build Coastguard Worker if self.args.export_aot_inductor 2773*da0073e9SAndroid Build Coastguard Worker else contextlib.nullcontext() 2774*da0073e9SAndroid Build Coastguard Worker ) 2775*da0073e9SAndroid Build Coastguard Worker 2776*da0073e9SAndroid Build Coastguard Worker with self.pick_grad(name, self.args.training), ctx: 2777*da0073e9SAndroid Build Coastguard Worker ok, total = Stats.reset_counters() 2778*da0073e9SAndroid Build Coastguard Worker experiment_kwargs = {} 2779*da0073e9SAndroid Build Coastguard Worker if tag is not None: 2780*da0073e9SAndroid Build Coastguard Worker experiment_kwargs["tag"] = tag 2781*da0073e9SAndroid Build Coastguard Worker results = [] 2782*da0073e9SAndroid Build Coastguard Worker with maybe_snapshot_memory( 2783*da0073e9SAndroid Build Coastguard Worker self.args.snapshot_memory, f"eager_{self.args.only}" 2784*da0073e9SAndroid Build Coastguard Worker ): 2785*da0073e9SAndroid Build Coastguard Worker eager_latency, eager_peak_mem, _ = warmup( 2786*da0073e9SAndroid Build Coastguard Worker self.model_iter_fn, model, example_inputs, "eager" 2787*da0073e9SAndroid Build Coastguard Worker ) 2788*da0073e9SAndroid Build Coastguard Worker if self.args.use_warm_peak_memory: 2789*da0073e9SAndroid Build Coastguard Worker _, eager_peak_mem, _ = warmup( 2790*da0073e9SAndroid Build Coastguard Worker self.model_iter_fn, model, example_inputs, "eager", niters=1 2791*da0073e9SAndroid Build Coastguard Worker ) 2792*da0073e9SAndroid Build Coastguard Worker 2793*da0073e9SAndroid Build Coastguard Worker if self.args.export_aot_inductor: 2794*da0073e9SAndroid Build Coastguard Worker t_0 = time.perf_counter() 2795*da0073e9SAndroid Build Coastguard Worker optimized_model_iter_fn = optimize_ctx 2796*da0073e9SAndroid Build Coastguard Worker t_1 = time.perf_counter() 2797*da0073e9SAndroid Build Coastguard Worker aot_compilation_time = t_1 - t_0 2798*da0073e9SAndroid Build Coastguard Worker else: 2799*da0073e9SAndroid Build Coastguard Worker optimized_model_iter_fn = optimize_ctx(self.model_iter_fn) 2800*da0073e9SAndroid Build Coastguard Worker aot_compilation_time = 0 2801*da0073e9SAndroid Build Coastguard Worker 2802*da0073e9SAndroid Build Coastguard Worker with maybe_enable_compiled_autograd( 2803*da0073e9SAndroid Build Coastguard Worker self.args.compiled_autograd 2804*da0073e9SAndroid Build Coastguard Worker ), maybe_snapshot_memory( 2805*da0073e9SAndroid Build Coastguard Worker self.args.snapshot_memory, f"compiled_{self.args.only}" 2806*da0073e9SAndroid Build Coastguard Worker ): 2807*da0073e9SAndroid Build Coastguard Worker dynamo_latency, dynamo_peak_mem, dynamo_stats = warmup( 2808*da0073e9SAndroid Build Coastguard Worker optimized_model_iter_fn, model, example_inputs, "dynamo" 2809*da0073e9SAndroid Build Coastguard Worker ) 2810*da0073e9SAndroid Build Coastguard Worker if self.args.use_warm_peak_memory: 2811*da0073e9SAndroid Build Coastguard Worker _, dynamo_peak_mem, _ = warmup( 2812*da0073e9SAndroid Build Coastguard Worker optimized_model_iter_fn, 2813*da0073e9SAndroid Build Coastguard Worker model, 2814*da0073e9SAndroid Build Coastguard Worker example_inputs, 2815*da0073e9SAndroid Build Coastguard Worker "dynamo", 2816*da0073e9SAndroid Build Coastguard Worker niters=1, 2817*da0073e9SAndroid Build Coastguard Worker ) 2818*da0073e9SAndroid Build Coastguard Worker 2819*da0073e9SAndroid Build Coastguard Worker if self.args.profile_dynamo_cache_lookup: 2820*da0073e9SAndroid Build Coastguard Worker with torch.profiler.profile( 2821*da0073e9SAndroid Build Coastguard Worker activities=[torch.profiler.ProfilerActivity.CPU] 2822*da0073e9SAndroid Build Coastguard Worker ) as prof: 2823*da0073e9SAndroid Build Coastguard Worker with maybe_enable_compiled_autograd(self.args.compiled_autograd): 2824*da0073e9SAndroid Build Coastguard Worker warmup(optimized_model_iter_fn, model, example_inputs, "dynamo") 2825*da0073e9SAndroid Build Coastguard Worker 2826*da0073e9SAndroid Build Coastguard Worker events = list( 2827*da0073e9SAndroid Build Coastguard Worker filter( 2828*da0073e9SAndroid Build Coastguard Worker lambda event: "TorchDynamo Cache Lookup" in event.key, 2829*da0073e9SAndroid Build Coastguard Worker prof.key_averages(), 2830*da0073e9SAndroid Build Coastguard Worker ) 2831*da0073e9SAndroid Build Coastguard Worker ) 2832*da0073e9SAndroid Build Coastguard Worker dynamo_cache_lookup_latency = events[0].self_cpu_time_total 2833*da0073e9SAndroid Build Coastguard Worker 2834*da0073e9SAndroid Build Coastguard Worker compilation_time = dynamo_latency - eager_latency + aot_compilation_time 2835*da0073e9SAndroid Build Coastguard Worker compression_ratio = ( 2836*da0073e9SAndroid Build Coastguard Worker eager_peak_mem / dynamo_peak_mem if dynamo_peak_mem else 0.0 2837*da0073e9SAndroid Build Coastguard Worker ) 2838*da0073e9SAndroid Build Coastguard Worker if self.args.print_memory: 2839*da0073e9SAndroid Build Coastguard Worker print( 2840*da0073e9SAndroid Build Coastguard Worker f"memory: eager: {eager_peak_mem:.2f} GB, " 2841*da0073e9SAndroid Build Coastguard Worker f"dynamo: {dynamo_peak_mem:.2f} GB, " 2842*da0073e9SAndroid Build Coastguard Worker f"ratio: {compression_ratio:.2f}" 2843*da0073e9SAndroid Build Coastguard Worker ) 2844*da0073e9SAndroid Build Coastguard Worker 2845*da0073e9SAndroid Build Coastguard Worker if self.args.print_compilation_time: 2846*da0073e9SAndroid Build Coastguard Worker print(f"Compilation time: {compilation_time:.2f}") 2847*da0073e9SAndroid Build Coastguard Worker 2848*da0073e9SAndroid Build Coastguard Worker if experiment.func is speedup_experiment: 2849*da0073e9SAndroid Build Coastguard Worker experiment_kwargs["compilation_latency"] = compilation_time 2850*da0073e9SAndroid Build Coastguard Worker experiment_kwargs["compression_ratio"] = compression_ratio 2851*da0073e9SAndroid Build Coastguard Worker experiment_kwargs["eager_peak_mem"] = eager_peak_mem 2852*da0073e9SAndroid Build Coastguard Worker experiment_kwargs["dynamo_peak_mem"] = dynamo_peak_mem 2853*da0073e9SAndroid Build Coastguard Worker experiment_kwargs["dynamo_stats"] = dynamo_stats 2854*da0073e9SAndroid Build Coastguard Worker if self.args.profile_dynamo_cache_lookup: 2855*da0073e9SAndroid Build Coastguard Worker experiment_kwargs[ 2856*da0073e9SAndroid Build Coastguard Worker "cache_lookup_latency" 2857*da0073e9SAndroid Build Coastguard Worker ] = dynamo_cache_lookup_latency 2858*da0073e9SAndroid Build Coastguard Worker 2859*da0073e9SAndroid Build Coastguard Worker if experiment.func is coverage_experiment: 2860*da0073e9SAndroid Build Coastguard Worker ok, total = Stats.reset_counters() 2861*da0073e9SAndroid Build Coastguard Worker results = [] 2862*da0073e9SAndroid Build Coastguard Worker # run with torch._dynamo few times to populate the cache 2863*da0073e9SAndroid Build Coastguard Worker for _ in range(3): 2864*da0073e9SAndroid Build Coastguard Worker optimized_model_iter_fn(model, example_inputs) 2865*da0073e9SAndroid Build Coastguard Worker _, frames_second_pass = Stats.reset_counters() # should be 0 2866*da0073e9SAndroid Build Coastguard Worker if frames_second_pass > 0: 2867*da0073e9SAndroid Build Coastguard Worker optimized_model_iter_fn(model, example_inputs) 2868*da0073e9SAndroid Build Coastguard Worker _, frames_third_pass = Stats.reset_counters() # should be 0 2869*da0073e9SAndroid Build Coastguard Worker else: 2870*da0073e9SAndroid Build Coastguard Worker frames_third_pass = 0 2871*da0073e9SAndroid Build Coastguard Worker 2872*da0073e9SAndroid Build Coastguard Worker results.append( 2873*da0073e9SAndroid Build Coastguard Worker f"{ok:3}/{total:3} +{frames_third_pass} frames {compilation_time:3.0f}s" 2874*da0073e9SAndroid Build Coastguard Worker ) 2875*da0073e9SAndroid Build Coastguard Worker 2876*da0073e9SAndroid Build Coastguard Worker if experiment.func is speedup_experiment_onnx: 2877*da0073e9SAndroid Build Coastguard Worker experiment = functools.partial( 2878*da0073e9SAndroid Build Coastguard Worker experiment, optimized_model_iter_fn.context.onnx_model 2879*da0073e9SAndroid Build Coastguard Worker ) 2880*da0073e9SAndroid Build Coastguard Worker 2881*da0073e9SAndroid Build Coastguard Worker if not hasattr(model, name): 2882*da0073e9SAndroid Build Coastguard Worker model.name = name 2883*da0073e9SAndroid Build Coastguard Worker results.append(experiment(model, example_inputs, **experiment_kwargs)) 2884*da0073e9SAndroid Build Coastguard Worker return " ".join(map(str, results)) 2885*da0073e9SAndroid Build Coastguard Worker 2886*da0073e9SAndroid Build Coastguard Worker def minify_model( 2887*da0073e9SAndroid Build Coastguard Worker self, 2888*da0073e9SAndroid Build Coastguard Worker name, 2889*da0073e9SAndroid Build Coastguard Worker model, 2890*da0073e9SAndroid Build Coastguard Worker example_inputs, 2891*da0073e9SAndroid Build Coastguard Worker optimize_ctx, 2892*da0073e9SAndroid Build Coastguard Worker experiment, 2893*da0073e9SAndroid Build Coastguard Worker tag, 2894*da0073e9SAndroid Build Coastguard Worker ): 2895*da0073e9SAndroid Build Coastguard Worker logging.info("Minifying %s...", name) 2896*da0073e9SAndroid Build Coastguard Worker os.environ["TORCH_COMPILE_DEBUG"] = "1" 2897*da0073e9SAndroid Build Coastguard Worker os.environ["TORCHDYNAMO_REPRO_AFTER"] = "dynamo" 2898*da0073e9SAndroid Build Coastguard Worker os.environ["TORCHDYNAMO_REPRO_LEVEL"] = "4" 2899*da0073e9SAndroid Build Coastguard Worker 2900*da0073e9SAndroid Build Coastguard Worker self.check_accuracy(name, model, example_inputs, optimize_ctx, experiment, tag) 2901*da0073e9SAndroid Build Coastguard Worker 2902*da0073e9SAndroid Build Coastguard Worker if self.args.output_directory: 2903*da0073e9SAndroid Build Coastguard Worker repro_dir = self.args.output_directory 2904*da0073e9SAndroid Build Coastguard Worker else: 2905*da0073e9SAndroid Build Coastguard Worker repro_dir = torch._dynamo.config.base_dir 2906*da0073e9SAndroid Build Coastguard Worker 2907*da0073e9SAndroid Build Coastguard Worker try: 2908*da0073e9SAndroid Build Coastguard Worker shutil.move("repro.py", f"{repro_dir}/{name}_repro.py") 2909*da0073e9SAndroid Build Coastguard Worker except OSError as e: 2910*da0073e9SAndroid Build Coastguard Worker logging.error("Could not find repro script for model %s", name) 2911*da0073e9SAndroid Build Coastguard Worker else: 2912*da0073e9SAndroid Build Coastguard Worker logging.info( 2913*da0073e9SAndroid Build Coastguard Worker "Repro script for model %s with minified graph saved to %s", 2914*da0073e9SAndroid Build Coastguard Worker name, 2915*da0073e9SAndroid Build Coastguard Worker repro_dir, 2916*da0073e9SAndroid Build Coastguard Worker ) 2917*da0073e9SAndroid Build Coastguard Worker 2918*da0073e9SAndroid Build Coastguard Worker def maybe_preserve_compile_debug(self, name, status): 2919*da0073e9SAndroid Build Coastguard Worker if ( 2920*da0073e9SAndroid Build Coastguard Worker name in CI_PRESERVE_COMPILE_DEBUG 2921*da0073e9SAndroid Build Coastguard Worker and status in CI_PRESERVE_COMPILE_DEBUG[name] 2922*da0073e9SAndroid Build Coastguard Worker ): 2923*da0073e9SAndroid Build Coastguard Worker src_dir = torch._dynamo.utils.get_debug_dir() 2924*da0073e9SAndroid Build Coastguard Worker if os.path.isdir(src_dir): 2925*da0073e9SAndroid Build Coastguard Worker dbg_dir = os.path.join( 2926*da0073e9SAndroid Build Coastguard Worker os.getcwd(), "test", "debug", "torch_compile_debug" 2927*da0073e9SAndroid Build Coastguard Worker ) 2928*da0073e9SAndroid Build Coastguard Worker dst_dir = os.path.join(dbg_dir, os.path.basename(src_dir)) 2929*da0073e9SAndroid Build Coastguard Worker try: 2930*da0073e9SAndroid Build Coastguard Worker os.makedirs(dbg_dir, exist_ok=True) 2931*da0073e9SAndroid Build Coastguard Worker os.rename(src_dir, dst_dir) 2932*da0073e9SAndroid Build Coastguard Worker log.warning("Moved %s to %s", src_dir, dst_dir) 2933*da0073e9SAndroid Build Coastguard Worker except OSError: 2934*da0073e9SAndroid Build Coastguard Worker log.exception("Failed to preserve %s", src_dir) 2935*da0073e9SAndroid Build Coastguard Worker 2936*da0073e9SAndroid Build Coastguard Worker def run_one_model( 2937*da0073e9SAndroid Build Coastguard Worker self, 2938*da0073e9SAndroid Build Coastguard Worker name, 2939*da0073e9SAndroid Build Coastguard Worker model, 2940*da0073e9SAndroid Build Coastguard Worker example_inputs, 2941*da0073e9SAndroid Build Coastguard Worker optimize_ctx, 2942*da0073e9SAndroid Build Coastguard Worker experiment, 2943*da0073e9SAndroid Build Coastguard Worker explain=False, 2944*da0073e9SAndroid Build Coastguard Worker tag=None, 2945*da0073e9SAndroid Build Coastguard Worker ): 2946*da0073e9SAndroid Build Coastguard Worker mode = "train" if self.args.training else "eval" 2947*da0073e9SAndroid Build Coastguard Worker msg = f"{current_device:4} {mode:5} {current_name:34} " 2948*da0073e9SAndroid Build Coastguard Worker if tag: 2949*da0073e9SAndroid Build Coastguard Worker msg += f" {tag:26}" 2950*da0073e9SAndroid Build Coastguard Worker print(msg, flush=True) 2951*da0073e9SAndroid Build Coastguard Worker 2952*da0073e9SAndroid Build Coastguard Worker start_stats = get_dynamo_stats() 2953*da0073e9SAndroid Build Coastguard Worker 2954*da0073e9SAndroid Build Coastguard Worker if self.args.accuracy: 2955*da0073e9SAndroid Build Coastguard Worker status = self.check_accuracy( 2956*da0073e9SAndroid Build Coastguard Worker name, model, example_inputs, optimize_ctx, experiment, tag 2957*da0073e9SAndroid Build Coastguard Worker ) 2958*da0073e9SAndroid Build Coastguard Worker print(status) 2959*da0073e9SAndroid Build Coastguard Worker if status == "fail_accuracy" and self.args.minify: 2960*da0073e9SAndroid Build Coastguard Worker self.minify_model( 2961*da0073e9SAndroid Build Coastguard Worker name, model, example_inputs, optimize_ctx, experiment, tag 2962*da0073e9SAndroid Build Coastguard Worker ) 2963*da0073e9SAndroid Build Coastguard Worker elif self.args.tolerance: 2964*da0073e9SAndroid Build Coastguard Worker status = self.check_tolerance(name, model, example_inputs, optimize_ctx) 2965*da0073e9SAndroid Build Coastguard Worker print(status) 2966*da0073e9SAndroid Build Coastguard Worker elif self.args.performance: 2967*da0073e9SAndroid Build Coastguard Worker status = self.run_performance_test( 2968*da0073e9SAndroid Build Coastguard Worker name, model, example_inputs, optimize_ctx, experiment, tag 2969*da0073e9SAndroid Build Coastguard Worker ) 2970*da0073e9SAndroid Build Coastguard Worker print(status) 2971*da0073e9SAndroid Build Coastguard Worker empty_gpu_cache(current_device) 2972*da0073e9SAndroid Build Coastguard Worker 2973*da0073e9SAndroid Build Coastguard Worker self.maybe_preserve_compile_debug(name, status) 2974*da0073e9SAndroid Build Coastguard Worker 2975*da0073e9SAndroid Build Coastguard Worker if self.args.timing: 2976*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.utils import op_count, print_time_report 2977*da0073e9SAndroid Build Coastguard Worker from torch.utils._stats import simple_call_counter 2978*da0073e9SAndroid Build Coastguard Worker 2979*da0073e9SAndroid Build Coastguard Worker print_time_report() 2980*da0073e9SAndroid Build Coastguard Worker stats = "STATS: " 2981*da0073e9SAndroid Build Coastguard Worker stats = stats + " | ".join( 2982*da0073e9SAndroid Build Coastguard Worker itertools.chain( 2983*da0073e9SAndroid Build Coastguard Worker [f"call_* op count: {op_count}"], 2984*da0073e9SAndroid Build Coastguard Worker (f"{key}:{value}" for key, value in simple_call_counter.items()), 2985*da0073e9SAndroid Build Coastguard Worker ) 2986*da0073e9SAndroid Build Coastguard Worker ) 2987*da0073e9SAndroid Build Coastguard Worker print(stats) 2988*da0073e9SAndroid Build Coastguard Worker stats = get_dynamo_stats() 2989*da0073e9SAndroid Build Coastguard Worker stats.subtract(start_stats) 2990*da0073e9SAndroid Build Coastguard Worker 2991*da0073e9SAndroid Build Coastguard Worker if explain: 2992*da0073e9SAndroid Build Coastguard Worker print( 2993*da0073e9SAndroid Build Coastguard Worker f"Dynamo produced {stats['unique_graphs']} graphs " 2994*da0073e9SAndroid Build Coastguard Worker f"covering {stats['calls_captured']} ops with " 2995*da0073e9SAndroid Build Coastguard Worker f"{stats['graph_breaks']} graph breaks ({stats['unique_graph_breaks']} unique)" 2996*da0073e9SAndroid Build Coastguard Worker ) 2997*da0073e9SAndroid Build Coastguard Worker 2998*da0073e9SAndroid Build Coastguard Worker if explain or self.args.log_graph_breaks or self.args.print_graph_breaks: 2999*da0073e9SAndroid Build Coastguard Worker filename = f"{output_filename.rstrip('.csv')}_graph_breaks.csv" 3000*da0073e9SAndroid Build Coastguard Worker 3001*da0073e9SAndroid Build Coastguard Worker def add_double_quotes(x): 3002*da0073e9SAndroid Build Coastguard Worker # Delimiter because reason could have comma 3003*da0073e9SAndroid Build Coastguard Worker return f'"{x}"' 3004*da0073e9SAndroid Build Coastguard Worker 3005*da0073e9SAndroid Build Coastguard Worker for graph_break in graph_break_reasons: 3006*da0073e9SAndroid Build Coastguard Worker reason = add_double_quotes(graph_break.reason) 3007*da0073e9SAndroid Build Coastguard Worker user_stack = add_double_quotes( 3008*da0073e9SAndroid Build Coastguard Worker ", ".join([str(x) for x in graph_break.user_stack]) 3009*da0073e9SAndroid Build Coastguard Worker ) 3010*da0073e9SAndroid Build Coastguard Worker output_csv( 3011*da0073e9SAndroid Build Coastguard Worker filename, 3012*da0073e9SAndroid Build Coastguard Worker ["model", "reason", "user_stack"], 3013*da0073e9SAndroid Build Coastguard Worker [current_name, reason, user_stack], 3014*da0073e9SAndroid Build Coastguard Worker ) 3015*da0073e9SAndroid Build Coastguard Worker 3016*da0073e9SAndroid Build Coastguard Worker if self.args.stats: 3017*da0073e9SAndroid Build Coastguard Worker Stats.print_summary() 3018*da0073e9SAndroid Build Coastguard Worker 3019*da0073e9SAndroid Build Coastguard Worker 3020*da0073e9SAndroid Build Coastguard Workerdef help(fn): 3021*da0073e9SAndroid Build Coastguard Worker return fn.__doc__ 3022*da0073e9SAndroid Build Coastguard Worker 3023*da0073e9SAndroid Build Coastguard Worker 3024*da0073e9SAndroid Build Coastguard Workerdiff_branch_default = "DIFF-BRANCH-DEFAULT" 3025*da0073e9SAndroid Build Coastguard Worker 3026*da0073e9SAndroid Build Coastguard Worker 3027*da0073e9SAndroid Build Coastguard Workerdef should_diff_branch(args): 3028*da0073e9SAndroid Build Coastguard Worker return args.diff_branch != diff_branch_default 3029*da0073e9SAndroid Build Coastguard Worker 3030*da0073e9SAndroid Build Coastguard Worker 3031*da0073e9SAndroid Build Coastguard Workerdef parse_args(args=None): 3032*da0073e9SAndroid Build Coastguard Worker parser = argparse.ArgumentParser() 3033*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3034*da0073e9SAndroid Build Coastguard Worker "--filter", "-k", action="append", help="filter benchmarks with regexp" 3035*da0073e9SAndroid Build Coastguard Worker ) 3036*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3037*da0073e9SAndroid Build Coastguard Worker "--exclude", "-x", action="append", help="filter benchmarks with regexp" 3038*da0073e9SAndroid Build Coastguard Worker ) 3039*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3040*da0073e9SAndroid Build Coastguard Worker "--exclude-exact", action="append", help="filter benchmarks with exact match" 3041*da0073e9SAndroid Build Coastguard Worker ) 3042*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3043*da0073e9SAndroid Build Coastguard Worker "--total-partitions", 3044*da0073e9SAndroid Build Coastguard Worker type=int, 3045*da0073e9SAndroid Build Coastguard Worker default=1, 3046*da0073e9SAndroid Build Coastguard Worker choices=range(1, 10), 3047*da0073e9SAndroid Build Coastguard Worker help="Total number of partitions we want to divide the benchmark suite into", 3048*da0073e9SAndroid Build Coastguard Worker ) 3049*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3050*da0073e9SAndroid Build Coastguard Worker "--partition-id", 3051*da0073e9SAndroid Build Coastguard Worker type=int, 3052*da0073e9SAndroid Build Coastguard Worker default=0, 3053*da0073e9SAndroid Build Coastguard Worker help="ID of the benchmark suite partition to be run. Used to divide CI tasks", 3054*da0073e9SAndroid Build Coastguard Worker ) 3055*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3056*da0073e9SAndroid Build Coastguard Worker "--devices", "--device", "-d", action="append", help="cpu or cuda" 3057*da0073e9SAndroid Build Coastguard Worker ) 3058*da0073e9SAndroid Build Coastguard Worker parser.add_argument("--device-index", help="CUDA device index") 3059*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3060*da0073e9SAndroid Build Coastguard Worker "--repeat", "-n", type=int, default=30, help="number of timing runs" 3061*da0073e9SAndroid Build Coastguard Worker ) 3062*da0073e9SAndroid Build Coastguard Worker iterations_per_run_help = """ 3063*da0073e9SAndroid Build Coastguard Worker Run this may iterations for each time measurement. This is mainly used for 3064*da0073e9SAndroid Build Coastguard Worker XLA training. We want to run multiple iterations per measurement so the 3065*da0073e9SAndroid Build Coastguard Worker tracing and computation for different iteartions can overlap with each 3066*da0073e9SAndroid Build Coastguard Worker other. This makes sure we have an accurate xla baseline. 3067*da0073e9SAndroid Build Coastguard Worker """ 3068*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3069*da0073e9SAndroid Build Coastguard Worker "--iterations-per-run", type=int, default=1, help=iterations_per_run_help 3070*da0073e9SAndroid Build Coastguard Worker ) 3071*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3072*da0073e9SAndroid Build Coastguard Worker "--randomize-input", 3073*da0073e9SAndroid Build Coastguard Worker action="store_true", 3074*da0073e9SAndroid Build Coastguard Worker help="Whether to randomize the input values. Dimensions will be kept the same.", 3075*da0073e9SAndroid Build Coastguard Worker ) 3076*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3077*da0073e9SAndroid Build Coastguard Worker "--threads", 3078*da0073e9SAndroid Build Coastguard Worker "-t", 3079*da0073e9SAndroid Build Coastguard Worker type=int, 3080*da0073e9SAndroid Build Coastguard Worker help="number of threads to use for eager and inductor", 3081*da0073e9SAndroid Build Coastguard Worker ) 3082*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3083*da0073e9SAndroid Build Coastguard Worker "--nopython", action="store_true", help="Turn graph breaks into errors" 3084*da0073e9SAndroid Build Coastguard Worker ) 3085*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3086*da0073e9SAndroid Build Coastguard Worker "--no-skip", 3087*da0073e9SAndroid Build Coastguard Worker action="store_true", 3088*da0073e9SAndroid Build Coastguard Worker help="run models that are in the global SKIP list", 3089*da0073e9SAndroid Build Coastguard Worker ) 3090*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3091*da0073e9SAndroid Build Coastguard Worker "--prims-nvfuser", action="store_true", help="user prims + nvfuser backend" 3092*da0073e9SAndroid Build Coastguard Worker ) 3093*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3094*da0073e9SAndroid Build Coastguard Worker "--dump-raw-metrics", 3095*da0073e9SAndroid Build Coastguard Worker action="store_true", 3096*da0073e9SAndroid Build Coastguard Worker help="dump raw timing metrics from speedup experiment", 3097*da0073e9SAndroid Build Coastguard Worker ) 3098*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3099*da0073e9SAndroid Build Coastguard Worker "--log-operator-inputs", 3100*da0073e9SAndroid Build Coastguard Worker action="store_true", 3101*da0073e9SAndroid Build Coastguard Worker default=False, 3102*da0073e9SAndroid Build Coastguard Worker ) 3103*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3104*da0073e9SAndroid Build Coastguard Worker "--channels-last", 3105*da0073e9SAndroid Build Coastguard Worker action="store_true", 3106*da0073e9SAndroid Build Coastguard Worker default=False, 3107*da0073e9SAndroid Build Coastguard Worker help="use channels last format", 3108*da0073e9SAndroid Build Coastguard Worker ) 3109*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3110*da0073e9SAndroid Build Coastguard Worker "--batch-size", "--batch_size", type=int, help="batch size for benchmarking" 3111*da0073e9SAndroid Build Coastguard Worker ) 3112*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3113*da0073e9SAndroid Build Coastguard Worker "--iterations", type=int, default=2, help="how many iterations to run" 3114*da0073e9SAndroid Build Coastguard Worker ) 3115*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3116*da0073e9SAndroid Build Coastguard Worker "--batch-size-file", type=str, help="String to load batch size from" 3117*da0073e9SAndroid Build Coastguard Worker ) 3118*da0073e9SAndroid Build Coastguard Worker parser.add_argument("--cosine", action="store_true", help="use cosine similarity") 3119*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3120*da0073e9SAndroid Build Coastguard Worker "--freezing", action="store_true", help="turn on freezing", default=False 3121*da0073e9SAndroid Build Coastguard Worker ) 3122*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3123*da0073e9SAndroid Build Coastguard Worker "--ci", action="store_true", help="Flag to tell that its a CI run" 3124*da0073e9SAndroid Build Coastguard Worker ) 3125*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3126*da0073e9SAndroid Build Coastguard Worker "--dashboard", action="store_true", help="Flag to tell that its a Dashboard run" 3127*da0073e9SAndroid Build Coastguard Worker ) 3128*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3129*da0073e9SAndroid Build Coastguard Worker "--skip-fp64-check", action="store_true", help="skip accuracy check using fp64" 3130*da0073e9SAndroid Build Coastguard Worker ) 3131*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3132*da0073e9SAndroid Build Coastguard Worker "--fast", "-f", action="store_true", help="skip slow benchmarks" 3133*da0073e9SAndroid Build Coastguard Worker ) 3134*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3135*da0073e9SAndroid Build Coastguard Worker "--only", 3136*da0073e9SAndroid Build Coastguard Worker help="""Run just one model from torchbench. Or 3137*da0073e9SAndroid Build Coastguard Worker specify the path and class name of the model in format like: 3138*da0073e9SAndroid Build Coastguard Worker --only=path:<MODEL_FILE_PATH>,class:<CLASS_NAME> 3139*da0073e9SAndroid Build Coastguard Worker 3140*da0073e9SAndroid Build Coastguard Worker Due to the fact that dynamo changes current working directory, 3141*da0073e9SAndroid Build Coastguard Worker the path should be an absolute path. 3142*da0073e9SAndroid Build Coastguard Worker 3143*da0073e9SAndroid Build Coastguard Worker The class should have a method get_example_inputs to return the inputs 3144*da0073e9SAndroid Build Coastguard Worker for the model. An example looks like 3145*da0073e9SAndroid Build Coastguard Worker ``` 3146*da0073e9SAndroid Build Coastguard Worker class LinearModel(nn.Module): 3147*da0073e9SAndroid Build Coastguard Worker def __init__(self): 3148*da0073e9SAndroid Build Coastguard Worker super().__init__() 3149*da0073e9SAndroid Build Coastguard Worker self.linear = nn.Linear(10, 10) 3150*da0073e9SAndroid Build Coastguard Worker 3151*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3152*da0073e9SAndroid Build Coastguard Worker return self.linear(x) 3153*da0073e9SAndroid Build Coastguard Worker 3154*da0073e9SAndroid Build Coastguard Worker def get_example_inputs(self): 3155*da0073e9SAndroid Build Coastguard Worker return (torch.randn(2, 10),) 3156*da0073e9SAndroid Build Coastguard Worker ``` 3157*da0073e9SAndroid Build Coastguard Worker """, 3158*da0073e9SAndroid Build Coastguard Worker ) 3159*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3160*da0073e9SAndroid Build Coastguard Worker "--multiprocess", 3161*da0073e9SAndroid Build Coastguard Worker action="store_true", 3162*da0073e9SAndroid Build Coastguard Worker help="Create n processes based on the number of devices (distributed use case).", 3163*da0073e9SAndroid Build Coastguard Worker ) 3164*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3165*da0073e9SAndroid Build Coastguard Worker "--ddp", 3166*da0073e9SAndroid Build Coastguard Worker action="store_true", 3167*da0073e9SAndroid Build Coastguard Worker help="Wraps model in DDP before running it, and uses dynamo DDPOptmizer (graph breaks) by default.", 3168*da0073e9SAndroid Build Coastguard Worker ) 3169*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3170*da0073e9SAndroid Build Coastguard Worker "--fsdp", 3171*da0073e9SAndroid Build Coastguard Worker action="store_true", 3172*da0073e9SAndroid Build Coastguard Worker help="""Wraps model in FSDP before running it. 3173*da0073e9SAndroid Build Coastguard Worker Doesn't recursively wrap, mainly useful for checking dynamo UnspecNNModule compatibility 3174*da0073e9SAndroid Build Coastguard Worker """, 3175*da0073e9SAndroid Build Coastguard Worker ) 3176*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3177*da0073e9SAndroid Build Coastguard Worker "--optimize-ddp-mode", 3178*da0073e9SAndroid Build Coastguard Worker type=str, 3179*da0073e9SAndroid Build Coastguard Worker default="ddp_optimizer", 3180*da0073e9SAndroid Build Coastguard Worker help="Specify the DDP optimization mode -- the value of torch._dynamo.config.optimize_ddp.", 3181*da0073e9SAndroid Build Coastguard Worker ) 3182*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3183*da0073e9SAndroid Build Coastguard Worker "--distributed-master-port", 3184*da0073e9SAndroid Build Coastguard Worker default="6789", 3185*da0073e9SAndroid Build Coastguard Worker help="Port to bind for for torch.distributed. Use the default unless it's conflicting with another user", 3186*da0073e9SAndroid Build Coastguard Worker ) 3187*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3188*da0073e9SAndroid Build Coastguard Worker "--dynamic-shapes", 3189*da0073e9SAndroid Build Coastguard Worker action="store_true", 3190*da0073e9SAndroid Build Coastguard Worker help="Runs a dynamic shapes version of the benchmark, if available.", 3191*da0073e9SAndroid Build Coastguard Worker ) 3192*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3193*da0073e9SAndroid Build Coastguard Worker "--propagate-real-tensors", 3194*da0073e9SAndroid Build Coastguard Worker action="store_true", 3195*da0073e9SAndroid Build Coastguard Worker help="Capture as much data dependent as you can by unsoundly propagating real tensors", 3196*da0073e9SAndroid Build Coastguard Worker ) 3197*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3198*da0073e9SAndroid Build Coastguard Worker "--dynamic-batch-only", 3199*da0073e9SAndroid Build Coastguard Worker action="store_true", 3200*da0073e9SAndroid Build Coastguard Worker help="Only assume batch dimension is dynamic. Implies --dynamic-shapes", 3201*da0073e9SAndroid Build Coastguard Worker ) 3202*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3203*da0073e9SAndroid Build Coastguard Worker "--specialize-int", action="store_true", help="Run with specialize_int=True." 3204*da0073e9SAndroid Build Coastguard Worker ) 3205*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3206*da0073e9SAndroid Build Coastguard Worker "--use-eval-mode", 3207*da0073e9SAndroid Build Coastguard Worker action="store_true", 3208*da0073e9SAndroid Build Coastguard Worker help="sets model.eval() to reduce randomness", 3209*da0073e9SAndroid Build Coastguard Worker ) 3210*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3211*da0073e9SAndroid Build Coastguard Worker "--skip-accuracy-check", 3212*da0073e9SAndroid Build Coastguard Worker action="store_true", 3213*da0073e9SAndroid Build Coastguard Worker help="keeps running even when accuracy fails", 3214*da0073e9SAndroid Build Coastguard Worker ) 3215*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3216*da0073e9SAndroid Build Coastguard Worker "--generate-aot-autograd-stats", 3217*da0073e9SAndroid Build Coastguard Worker action="store_true", 3218*da0073e9SAndroid Build Coastguard Worker help="Generates AOT Autograd stats like how mnay graphs are sent to AOT", 3219*da0073e9SAndroid Build Coastguard Worker ) 3220*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3221*da0073e9SAndroid Build Coastguard Worker "--inductor-settings", 3222*da0073e9SAndroid Build Coastguard Worker action="store_true", 3223*da0073e9SAndroid Build Coastguard Worker help="Use same settings as --inductor for baseline comparisons", 3224*da0073e9SAndroid Build Coastguard Worker ) 3225*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3226*da0073e9SAndroid Build Coastguard Worker "--suppress-errors", 3227*da0073e9SAndroid Build Coastguard Worker action="store_true", 3228*da0073e9SAndroid Build Coastguard Worker help="Suppress errors instead of raising them", 3229*da0073e9SAndroid Build Coastguard Worker ) 3230*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3231*da0073e9SAndroid Build Coastguard Worker "--output", 3232*da0073e9SAndroid Build Coastguard Worker help="Overrides the output filename", 3233*da0073e9SAndroid Build Coastguard Worker ) 3234*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3235*da0073e9SAndroid Build Coastguard Worker "--output-directory", 3236*da0073e9SAndroid Build Coastguard Worker help="Overrides the directory to place output files.", 3237*da0073e9SAndroid Build Coastguard Worker ) 3238*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3239*da0073e9SAndroid Build Coastguard Worker "--disable-output", 3240*da0073e9SAndroid Build Coastguard Worker action="store_true", 3241*da0073e9SAndroid Build Coastguard Worker help="Disable writing of output files, e.g., for warm-up runs", 3242*da0073e9SAndroid Build Coastguard Worker ) 3243*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3244*da0073e9SAndroid Build Coastguard Worker "--baseline", 3245*da0073e9SAndroid Build Coastguard Worker help="Compare with a prior --output", 3246*da0073e9SAndroid Build Coastguard Worker ) 3247*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3248*da0073e9SAndroid Build Coastguard Worker "--part", 3249*da0073e9SAndroid Build Coastguard Worker default=None, 3250*da0073e9SAndroid Build Coastguard Worker help="Specify the part of the model to run.", 3251*da0073e9SAndroid Build Coastguard Worker ) 3252*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3253*da0073e9SAndroid Build Coastguard Worker "--export-profiler-trace", 3254*da0073e9SAndroid Build Coastguard Worker action="store_true", 3255*da0073e9SAndroid Build Coastguard Worker help="exports trace of kineto profiler", 3256*da0073e9SAndroid Build Coastguard Worker ) 3257*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3258*da0073e9SAndroid Build Coastguard Worker "--profiler-trace-name", 3259*da0073e9SAndroid Build Coastguard Worker "--profiler_trace_name", 3260*da0073e9SAndroid Build Coastguard Worker help="Overwrites exported trace name", 3261*da0073e9SAndroid Build Coastguard Worker ) 3262*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3263*da0073e9SAndroid Build Coastguard Worker "--diff-branch", 3264*da0073e9SAndroid Build Coastguard Worker default=diff_branch_default, 3265*da0073e9SAndroid Build Coastguard Worker help="delta current branch against given branch.", 3266*da0073e9SAndroid Build Coastguard Worker ) 3267*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3268*da0073e9SAndroid Build Coastguard Worker "--tag", default=None, help="Specify a tag to be included in csv files." 3269*da0073e9SAndroid Build Coastguard Worker ) 3270*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3271*da0073e9SAndroid Build Coastguard Worker "--explain", 3272*da0073e9SAndroid Build Coastguard Worker action="store_true", 3273*da0073e9SAndroid Build Coastguard Worker help="print some graph/op statistics during the run, similar to .explain()", 3274*da0073e9SAndroid Build Coastguard Worker ) 3275*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3276*da0073e9SAndroid Build Coastguard Worker "--stats", 3277*da0073e9SAndroid Build Coastguard Worker action="store_true", 3278*da0073e9SAndroid Build Coastguard Worker help="print graph counter stats", 3279*da0073e9SAndroid Build Coastguard Worker ) 3280*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3281*da0073e9SAndroid Build Coastguard Worker "--use-warm-peak-memory", 3282*da0073e9SAndroid Build Coastguard Worker "--use_warm_peak_memory", 3283*da0073e9SAndroid Build Coastguard Worker action="store_true", 3284*da0073e9SAndroid Build Coastguard Worker help="Measure peak memory using a warm run to reduce autotuning noise", 3285*da0073e9SAndroid Build Coastguard Worker ) 3286*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3287*da0073e9SAndroid Build Coastguard Worker "--print-memory", 3288*da0073e9SAndroid Build Coastguard Worker action="store_true", 3289*da0073e9SAndroid Build Coastguard Worker help="print extra memory statistics", 3290*da0073e9SAndroid Build Coastguard Worker ) 3291*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3292*da0073e9SAndroid Build Coastguard Worker "--print-compilation-time", 3293*da0073e9SAndroid Build Coastguard Worker action="store_true", 3294*da0073e9SAndroid Build Coastguard Worker help="print compilation latency", 3295*da0073e9SAndroid Build Coastguard Worker ) 3296*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3297*da0073e9SAndroid Build Coastguard Worker "--print-dataframe-summary", 3298*da0073e9SAndroid Build Coastguard Worker action="store_true", 3299*da0073e9SAndroid Build Coastguard Worker help="print dataframe result used for calculating accuracy", 3300*da0073e9SAndroid Build Coastguard Worker ) 3301*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3302*da0073e9SAndroid Build Coastguard Worker "--disable-cudagraphs", 3303*da0073e9SAndroid Build Coastguard Worker action="store_true", 3304*da0073e9SAndroid Build Coastguard Worker help="Disables cudagraphs for Inductor", 3305*da0073e9SAndroid Build Coastguard Worker ) 3306*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3307*da0073e9SAndroid Build Coastguard Worker "--disable-split-reductions", 3308*da0073e9SAndroid Build Coastguard Worker action="store_true", 3309*da0073e9SAndroid Build Coastguard Worker help="Disables split reductions for Inductor", 3310*da0073e9SAndroid Build Coastguard Worker ) 3311*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3312*da0073e9SAndroid Build Coastguard Worker "--disable-persistent-reductions", 3313*da0073e9SAndroid Build Coastguard Worker action="store_true", 3314*da0073e9SAndroid Build Coastguard Worker help="Disables split reductions for Inductor", 3315*da0073e9SAndroid Build Coastguard Worker ) 3316*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3317*da0073e9SAndroid Build Coastguard Worker "--disable-divisible-by-16", 3318*da0073e9SAndroid Build Coastguard Worker action="store_true", 3319*da0073e9SAndroid Build Coastguard Worker help="Disables divisible by 16 hint to Triton for Inductor", 3320*da0073e9SAndroid Build Coastguard Worker ) 3321*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3322*da0073e9SAndroid Build Coastguard Worker "--inductor-compile-mode", 3323*da0073e9SAndroid Build Coastguard Worker default=None, 3324*da0073e9SAndroid Build Coastguard Worker help="torch.compile mode argument for inductor runs.", 3325*da0073e9SAndroid Build Coastguard Worker ) 3326*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3327*da0073e9SAndroid Build Coastguard Worker "--print-graph-breaks", 3328*da0073e9SAndroid Build Coastguard Worker action="store_true", 3329*da0073e9SAndroid Build Coastguard Worker help="Show a warning whenever graph break", 3330*da0073e9SAndroid Build Coastguard Worker ) 3331*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3332*da0073e9SAndroid Build Coastguard Worker "--log-graph-breaks", 3333*da0073e9SAndroid Build Coastguard Worker action="store_true", 3334*da0073e9SAndroid Build Coastguard Worker help="log graph breaks in a file", 3335*da0073e9SAndroid Build Coastguard Worker ) 3336*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3337*da0073e9SAndroid Build Coastguard Worker "--trace-on-xla", 3338*da0073e9SAndroid Build Coastguard Worker action="store_true", 3339*da0073e9SAndroid Build Coastguard Worker help="Whether to trace the model on XLA or on eager device", 3340*da0073e9SAndroid Build Coastguard Worker ) 3341*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3342*da0073e9SAndroid Build Coastguard Worker "--xla-tolerance", 3343*da0073e9SAndroid Build Coastguard Worker type=float, 3344*da0073e9SAndroid Build Coastguard Worker default=1e-2, 3345*da0073e9SAndroid Build Coastguard Worker help="XLA needs a loose tolerance to pass the correctness check", 3346*da0073e9SAndroid Build Coastguard Worker ) 3347*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3348*da0073e9SAndroid Build Coastguard Worker "--collect-outputs", 3349*da0073e9SAndroid Build Coastguard Worker action="store_true", 3350*da0073e9SAndroid Build Coastguard Worker help="""Whether to collect outputs for training. Set this to true if we 3351*da0073e9SAndroid Build Coastguard Worker want to verify the numerical correctness of graidents. But that may 3352*da0073e9SAndroid Build Coastguard Worker cause time measurement not accurate""", 3353*da0073e9SAndroid Build Coastguard Worker ) 3354*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3355*da0073e9SAndroid Build Coastguard Worker "--enable-activation-checkpointing", 3356*da0073e9SAndroid Build Coastguard Worker action="store_true", 3357*da0073e9SAndroid Build Coastguard Worker help="Enables activation checkpointing for HF models", 3358*da0073e9SAndroid Build Coastguard Worker ) 3359*da0073e9SAndroid Build Coastguard Worker parser.add_argument("--timing", action="store_true", help="Emits phase timing") 3360*da0073e9SAndroid Build Coastguard Worker 3361*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3362*da0073e9SAndroid Build Coastguard Worker "--progress", 3363*da0073e9SAndroid Build Coastguard Worker action="store_true", 3364*da0073e9SAndroid Build Coastguard Worker help="Print n/k models message between each model run.", 3365*da0073e9SAndroid Build Coastguard Worker ) 3366*da0073e9SAndroid Build Coastguard Worker 3367*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3368*da0073e9SAndroid Build Coastguard Worker "--timeout", 3369*da0073e9SAndroid Build Coastguard Worker type=int, 3370*da0073e9SAndroid Build Coastguard Worker default=2000, 3371*da0073e9SAndroid Build Coastguard Worker help="timeout (second) for benchmarking.", 3372*da0073e9SAndroid Build Coastguard Worker ) 3373*da0073e9SAndroid Build Coastguard Worker 3374*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3375*da0073e9SAndroid Build Coastguard Worker "--per_process_memory_fraction", 3376*da0073e9SAndroid Build Coastguard Worker type=float, 3377*da0073e9SAndroid Build Coastguard Worker default=1, 3378*da0073e9SAndroid Build Coastguard Worker help="Set per-process GPU memory fraction (limit) for reducing usable size and reproducing OOMs", 3379*da0073e9SAndroid Build Coastguard Worker ) 3380*da0073e9SAndroid Build Coastguard Worker 3381*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3382*da0073e9SAndroid Build Coastguard Worker "--no-translation-validation", 3383*da0073e9SAndroid Build Coastguard Worker action="store_true", 3384*da0073e9SAndroid Build Coastguard Worker help="Disable translation validation for accuracy builds.", 3385*da0073e9SAndroid Build Coastguard Worker ) 3386*da0073e9SAndroid Build Coastguard Worker 3387*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3388*da0073e9SAndroid Build Coastguard Worker "--minify", 3389*da0073e9SAndroid Build Coastguard Worker action="store_true", 3390*da0073e9SAndroid Build Coastguard Worker help="Enable minification when failure is below tolerance. Save repro script for each model.", 3391*da0073e9SAndroid Build Coastguard Worker ) 3392*da0073e9SAndroid Build Coastguard Worker 3393*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3394*da0073e9SAndroid Build Coastguard Worker "--compiled-autograd", 3395*da0073e9SAndroid Build Coastguard Worker action="store_true", 3396*da0073e9SAndroid Build Coastguard Worker help="Enables compiled autograd on compiled benchmark", 3397*da0073e9SAndroid Build Coastguard Worker ) 3398*da0073e9SAndroid Build Coastguard Worker 3399*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3400*da0073e9SAndroid Build Coastguard Worker "--profile_dynamo_cache_lookup", 3401*da0073e9SAndroid Build Coastguard Worker "--profile-dynamo-cache-lookup", 3402*da0073e9SAndroid Build Coastguard Worker action="store_true", 3403*da0073e9SAndroid Build Coastguard Worker help="profiles TorchDynamo cache lookup", 3404*da0073e9SAndroid Build Coastguard Worker ) 3405*da0073e9SAndroid Build Coastguard Worker 3406*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3407*da0073e9SAndroid Build Coastguard Worker "--snapshot-memory", 3408*da0073e9SAndroid Build Coastguard Worker "--snapshot_memory", 3409*da0073e9SAndroid Build Coastguard Worker action="store_true", 3410*da0073e9SAndroid Build Coastguard Worker help="Enables Memory Snapshot tool for memory deep dives: https://pytorch.org/blog/understanding-gpu-memory-1/", 3411*da0073e9SAndroid Build Coastguard Worker ) 3412*da0073e9SAndroid Build Coastguard Worker 3413*da0073e9SAndroid Build Coastguard Worker group_latency = parser.add_mutually_exclusive_group() 3414*da0073e9SAndroid Build Coastguard Worker group_latency.add_argument( 3415*da0073e9SAndroid Build Coastguard Worker "--cold-start-latency", 3416*da0073e9SAndroid Build Coastguard Worker "--cold_start_latency", 3417*da0073e9SAndroid Build Coastguard Worker action="store_true", 3418*da0073e9SAndroid Build Coastguard Worker help="Use a fresh triton cachedir when running each model, to force cold-start compile.", 3419*da0073e9SAndroid Build Coastguard Worker ) 3420*da0073e9SAndroid Build Coastguard Worker group_latency.add_argument( 3421*da0073e9SAndroid Build Coastguard Worker "--warm-start-latency", 3422*da0073e9SAndroid Build Coastguard Worker "--warm_start_latency", 3423*da0073e9SAndroid Build Coastguard Worker action="store_true", 3424*da0073e9SAndroid Build Coastguard Worker help="Run model(s) twice and preseve caches in between to enable a 'warm start' on the 2nd run", 3425*da0073e9SAndroid Build Coastguard Worker ) 3426*da0073e9SAndroid Build Coastguard Worker 3427*da0073e9SAndroid Build Coastguard Worker group_fuser = parser.add_mutually_exclusive_group() 3428*da0073e9SAndroid Build Coastguard Worker # --nvfuser is now the default, keep the option to not break scripts 3429*da0073e9SAndroid Build Coastguard Worker group_fuser.add_argument("--nvfuser", action="store_true", help=argparse.SUPPRESS) 3430*da0073e9SAndroid Build Coastguard Worker group_fuser.add_argument("--nnc", action="store_true", help="enable NNC for GPUs") 3431*da0073e9SAndroid Build Coastguard Worker 3432*da0073e9SAndroid Build Coastguard Worker group_prec = parser.add_mutually_exclusive_group() 3433*da0073e9SAndroid Build Coastguard Worker group_prec.add_argument("--float16", action="store_true", help="cast model to fp16") 3434*da0073e9SAndroid Build Coastguard Worker group_prec.add_argument( 3435*da0073e9SAndroid Build Coastguard Worker "--bfloat16", action="store_true", help="cast model to bf16" 3436*da0073e9SAndroid Build Coastguard Worker ) 3437*da0073e9SAndroid Build Coastguard Worker group_prec.add_argument("--float32", action="store_true", help="cast model to fp32") 3438*da0073e9SAndroid Build Coastguard Worker group_prec.add_argument( 3439*da0073e9SAndroid Build Coastguard Worker "--amp", action="store_true", help="use automatic mixed precision" 3440*da0073e9SAndroid Build Coastguard Worker ) 3441*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 3442*da0073e9SAndroid Build Coastguard Worker "--amp-dtype", 3443*da0073e9SAndroid Build Coastguard Worker choices=("bfloat16", "float16"), 3444*da0073e9SAndroid Build Coastguard Worker help="the data type used with automatic mixed precision", 3445*da0073e9SAndroid Build Coastguard Worker ) 3446*da0073e9SAndroid Build Coastguard Worker group_printout = parser.add_mutually_exclusive_group() 3447*da0073e9SAndroid Build Coastguard Worker group_printout.add_argument( 3448*da0073e9SAndroid Build Coastguard Worker "--verbose", "-v", action="store_true", help="enable verbose debug printouts" 3449*da0073e9SAndroid Build Coastguard Worker ) 3450*da0073e9SAndroid Build Coastguard Worker group_printout.add_argument( 3451*da0073e9SAndroid Build Coastguard Worker "--quiet", "-q", action="store_true", help="suppress debug printouts" 3452*da0073e9SAndroid Build Coastguard Worker ) 3453*da0073e9SAndroid Build Coastguard Worker 3454*da0073e9SAndroid Build Coastguard Worker group = parser.add_mutually_exclusive_group() 3455*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3456*da0073e9SAndroid Build Coastguard Worker "--coverage", action="store_true", help="(default) " + help(coverage_experiment) 3457*da0073e9SAndroid Build Coastguard Worker ) 3458*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3459*da0073e9SAndroid Build Coastguard Worker "--overhead", action="store_true", help=help(overhead_experiment) 3460*da0073e9SAndroid Build Coastguard Worker ) 3461*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3462*da0073e9SAndroid Build Coastguard Worker "--speedup-dynamo-ts", 3463*da0073e9SAndroid Build Coastguard Worker action="store_true", 3464*da0073e9SAndroid Build Coastguard Worker help="TorchDynamo frontend with torchscript backend", 3465*da0073e9SAndroid Build Coastguard Worker ) 3466*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3467*da0073e9SAndroid Build Coastguard Worker "--speedup-fx2trt", action="store_true", help=help(speedup_experiment_fx2trt) 3468*da0073e9SAndroid Build Coastguard Worker ) 3469*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3470*da0073e9SAndroid Build Coastguard Worker "--speedup-fx2trt-fp16", 3471*da0073e9SAndroid Build Coastguard Worker action="store_true", 3472*da0073e9SAndroid Build Coastguard Worker help=help(speedup_experiment_fx2trt), 3473*da0073e9SAndroid Build Coastguard Worker ) 3474*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3475*da0073e9SAndroid Build Coastguard Worker "--print-fx", 3476*da0073e9SAndroid Build Coastguard Worker action="store_true", 3477*da0073e9SAndroid Build Coastguard Worker help="Print fx traces captured from model", 3478*da0073e9SAndroid Build Coastguard Worker ) 3479*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3480*da0073e9SAndroid Build Coastguard Worker "--print-aten-ops", 3481*da0073e9SAndroid Build Coastguard Worker action="store_true", 3482*da0073e9SAndroid Build Coastguard Worker help="Print traces of aten ops captured by AOT autograd", 3483*da0073e9SAndroid Build Coastguard Worker ) 3484*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3485*da0073e9SAndroid Build Coastguard Worker "--inductor", 3486*da0073e9SAndroid Build Coastguard Worker action="store_true", 3487*da0073e9SAndroid Build Coastguard Worker help="Measure speedup with TorchInductor", 3488*da0073e9SAndroid Build Coastguard Worker ) 3489*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3490*da0073e9SAndroid Build Coastguard Worker "--quantization", 3491*da0073e9SAndroid Build Coastguard Worker choices=[ 3492*da0073e9SAndroid Build Coastguard Worker "int8dynamic", 3493*da0073e9SAndroid Build Coastguard Worker "int8weightonly", 3494*da0073e9SAndroid Build Coastguard Worker "int4weightonly", 3495*da0073e9SAndroid Build Coastguard Worker "autoquant", 3496*da0073e9SAndroid Build Coastguard Worker "noquant", 3497*da0073e9SAndroid Build Coastguard Worker ], 3498*da0073e9SAndroid Build Coastguard Worker default=None, 3499*da0073e9SAndroid Build Coastguard Worker help="Measure speedup of torchao quantization with TorchInductor baseline", 3500*da0073e9SAndroid Build Coastguard Worker ) 3501*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3502*da0073e9SAndroid Build Coastguard Worker "--export", 3503*da0073e9SAndroid Build Coastguard Worker action="store_true", 3504*da0073e9SAndroid Build Coastguard Worker help="Measure pass rate with export", 3505*da0073e9SAndroid Build Coastguard Worker ) 3506*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3507*da0073e9SAndroid Build Coastguard Worker "--export-aot-inductor", 3508*da0073e9SAndroid Build Coastguard Worker action="store_true", 3509*da0073e9SAndroid Build Coastguard Worker help="Measure pass rate with Export+AOTInductor", 3510*da0073e9SAndroid Build Coastguard Worker ) 3511*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3512*da0073e9SAndroid Build Coastguard Worker "--xla", action="store_true", help="Compare TorchXLA to eager PyTorch" 3513*da0073e9SAndroid Build Coastguard Worker ) 3514*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3515*da0073e9SAndroid Build Coastguard Worker "--torchscript-onnx", 3516*da0073e9SAndroid Build Coastguard Worker "--torchscript_onnx", 3517*da0073e9SAndroid Build Coastguard Worker action="store_true", 3518*da0073e9SAndroid Build Coastguard Worker help="Measure speedup with TorchScript ONNX, i.e. `torch.onnx.export`", 3519*da0073e9SAndroid Build Coastguard Worker ) 3520*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3521*da0073e9SAndroid Build Coastguard Worker "--dynamo-onnx", 3522*da0073e9SAndroid Build Coastguard Worker "--dynamo_onnx", 3523*da0073e9SAndroid Build Coastguard Worker action="store_true", 3524*da0073e9SAndroid Build Coastguard Worker help="Measure speedup with Dynamo ONNX, i.e. `torch.onnx.dynamo_export`", 3525*da0073e9SAndroid Build Coastguard Worker ) 3526*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3527*da0073e9SAndroid Build Coastguard Worker "--dynamo-onnx-aot-inline", 3528*da0073e9SAndroid Build Coastguard Worker "--dynamo_onnx_aot_inline", 3529*da0073e9SAndroid Build Coastguard Worker action="store_true", 3530*da0073e9SAndroid Build Coastguard Worker help="Measure speedup with Dynamo ONNX AOT Inline, i.e. `torch.onnx.dynamo_export`", 3531*da0073e9SAndroid Build Coastguard Worker ) 3532*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3533*da0073e9SAndroid Build Coastguard Worker "--dynamo-onnx-aot-optimize", 3534*da0073e9SAndroid Build Coastguard Worker "--dynamo_onnx_aot_optimize", 3535*da0073e9SAndroid Build Coastguard Worker action="store_true", 3536*da0073e9SAndroid Build Coastguard Worker help="Measure speedup with Dynamo ONNX w/ ort fusions, i.e. `torch.onnx.dynamo_export`", 3537*da0073e9SAndroid Build Coastguard Worker ) 3538*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3539*da0073e9SAndroid Build Coastguard Worker "--backend", 3540*da0073e9SAndroid Build Coastguard Worker choices=torch._dynamo.list_backends(exclude_tags=None), 3541*da0073e9SAndroid Build Coastguard Worker help="measure speedup with a given backend", 3542*da0073e9SAndroid Build Coastguard Worker ) 3543*da0073e9SAndroid Build Coastguard Worker group.add_argument("--nothing", action="store_true", help=help(null_experiment)) 3544*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3545*da0073e9SAndroid Build Coastguard Worker "--log-conv-args", 3546*da0073e9SAndroid Build Coastguard Worker action="store_true", 3547*da0073e9SAndroid Build Coastguard Worker help="Dump convolution input/weight/bias's shape/stride/dtype and other options to json", 3548*da0073e9SAndroid Build Coastguard Worker ) 3549*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3550*da0073e9SAndroid Build Coastguard Worker "--recompile-profiler", 3551*da0073e9SAndroid Build Coastguard Worker "--recompile_profiler", 3552*da0073e9SAndroid Build Coastguard Worker action="store_true", 3553*da0073e9SAndroid Build Coastguard Worker help="Run the dynamo recompilation profiler on each model.", 3554*da0073e9SAndroid Build Coastguard Worker ) 3555*da0073e9SAndroid Build Coastguard Worker group.add_argument( 3556*da0073e9SAndroid Build Coastguard Worker "--find-batch-sizes", 3557*da0073e9SAndroid Build Coastguard Worker action="store_true", 3558*da0073e9SAndroid Build Coastguard Worker help="finds the largest batch size that could fit on GPUs", 3559*da0073e9SAndroid Build Coastguard Worker ) 3560*da0073e9SAndroid Build Coastguard Worker 3561*da0073e9SAndroid Build Coastguard Worker mode_group = parser.add_mutually_exclusive_group(required=True) 3562*da0073e9SAndroid Build Coastguard Worker mode_group.add_argument( 3563*da0073e9SAndroid Build Coastguard Worker "--accuracy", 3564*da0073e9SAndroid Build Coastguard Worker action="store_true", 3565*da0073e9SAndroid Build Coastguard Worker help="Checks accuracy with small batch size and eval mode", 3566*da0073e9SAndroid Build Coastguard Worker ) 3567*da0073e9SAndroid Build Coastguard Worker mode_group.add_argument( 3568*da0073e9SAndroid Build Coastguard Worker "--performance", action="store_true", help="Measures performance speedup" 3569*da0073e9SAndroid Build Coastguard Worker ) 3570*da0073e9SAndroid Build Coastguard Worker mode_group.add_argument( 3571*da0073e9SAndroid Build Coastguard Worker "--tolerance", 3572*da0073e9SAndroid Build Coastguard Worker action="store_true", 3573*da0073e9SAndroid Build Coastguard Worker help="extracts the tolerance for each model with small batch size and eval mode", 3574*da0073e9SAndroid Build Coastguard Worker ) 3575*da0073e9SAndroid Build Coastguard Worker run_mode_group = parser.add_mutually_exclusive_group(required=True) 3576*da0073e9SAndroid Build Coastguard Worker run_mode_group.add_argument( 3577*da0073e9SAndroid Build Coastguard Worker "--training", 3578*da0073e9SAndroid Build Coastguard Worker action="store_true", 3579*da0073e9SAndroid Build Coastguard Worker help="Performs training", 3580*da0073e9SAndroid Build Coastguard Worker ) 3581*da0073e9SAndroid Build Coastguard Worker run_mode_group.add_argument( 3582*da0073e9SAndroid Build Coastguard Worker "--inference", action="store_true", help="Performs inference" 3583*da0073e9SAndroid Build Coastguard Worker ) 3584*da0073e9SAndroid Build Coastguard Worker return parser.parse_args(args) 3585*da0073e9SAndroid Build Coastguard Worker 3586*da0073e9SAndroid Build Coastguard Worker 3587*da0073e9SAndroid Build Coastguard Workerdef process_entry(rank, runner, original_dir, args): 3588*da0073e9SAndroid Build Coastguard Worker args.rank = rank 3589*da0073e9SAndroid Build Coastguard Worker with maybe_init_distributed( 3590*da0073e9SAndroid Build Coastguard Worker args.init_distributed, 3591*da0073e9SAndroid Build Coastguard Worker rank=rank, 3592*da0073e9SAndroid Build Coastguard Worker world_size=args.world_size, 3593*da0073e9SAndroid Build Coastguard Worker port=args.distributed_master_port, 3594*da0073e9SAndroid Build Coastguard Worker ): 3595*da0073e9SAndroid Build Coastguard Worker return run(runner, args, original_dir) 3596*da0073e9SAndroid Build Coastguard Worker 3597*da0073e9SAndroid Build Coastguard Worker 3598*da0073e9SAndroid Build Coastguard Workerdef maybe_fresh_cache(args): 3599*da0073e9SAndroid Build Coastguard Worker cache_dir_assigned = "TORCHINDUCTOR_CACHE_DIR" in os.environ 3600*da0073e9SAndroid Build Coastguard Worker if not cache_dir_assigned and ( 3601*da0073e9SAndroid Build Coastguard Worker args.cold_start_latency or args.warm_start_latency or args.ci 3602*da0073e9SAndroid Build Coastguard Worker ): 3603*da0073e9SAndroid Build Coastguard Worker return fresh_inductor_cache() 3604*da0073e9SAndroid Build Coastguard Worker else: 3605*da0073e9SAndroid Build Coastguard Worker return contextlib.nullcontext() 3606*da0073e9SAndroid Build Coastguard Worker 3607*da0073e9SAndroid Build Coastguard Worker 3608*da0073e9SAndroid Build Coastguard Workerdef main(runner, original_dir=None, args=None): 3609*da0073e9SAndroid Build Coastguard Worker if original_dir: 3610*da0073e9SAndroid Build Coastguard Worker os.chdir(original_dir) 3611*da0073e9SAndroid Build Coastguard Worker args = parse_args() if not args else parse_args(args) 3612*da0073e9SAndroid Build Coastguard Worker if args.baseline: 3613*da0073e9SAndroid Build Coastguard Worker args.baseline = os.path.abspath(args.baseline) 3614*da0073e9SAndroid Build Coastguard Worker 3615*da0073e9SAndroid Build Coastguard Worker if should_diff_branch(args): 3616*da0073e9SAndroid Build Coastguard Worker import git 3617*da0073e9SAndroid Build Coastguard Worker 3618*da0073e9SAndroid Build Coastguard Worker # We do this here so we error out earlier if there's an issue 3619*da0073e9SAndroid Build Coastguard Worker repo = git.Repo() 3620*da0073e9SAndroid Build Coastguard Worker if repo.is_dirty(): 3621*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 3622*da0073e9SAndroid Build Coastguard Worker "--diff-branch called on dirty branch. Commit, stash, or reset." 3623*da0073e9SAndroid Build Coastguard Worker ) 3624*da0073e9SAndroid Build Coastguard Worker main_branch = repo.active_branch.name 3625*da0073e9SAndroid Build Coastguard Worker if main_branch == args.diff_branch: 3626*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 3627*da0073e9SAndroid Build Coastguard Worker f"--diff-branch: current branch is same as {args.diff_branch} branch, what are you diffing?" 3628*da0073e9SAndroid Build Coastguard Worker ) 3629*da0073e9SAndroid Build Coastguard Worker 3630*da0073e9SAndroid Build Coastguard Worker with maybe_fresh_cache(args): 3631*da0073e9SAndroid Build Coastguard Worker args.init_distributed = args.only and args.multiprocess 3632*da0073e9SAndroid Build Coastguard Worker if args.init_distributed: 3633*da0073e9SAndroid Build Coastguard Worker # NB: Do NOT query device count before CUDA initialization; we're 3634*da0073e9SAndroid Build Coastguard Worker # going to overwrite CUDA_VISIBLE_DEVICES and this will result in 3635*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/107300 3636*da0073e9SAndroid Build Coastguard Worker device_count = torch.cuda.device_count() 3637*da0073e9SAndroid Build Coastguard Worker if device_count <= 1: 3638*da0073e9SAndroid Build Coastguard Worker log.warning( 3639*da0073e9SAndroid Build Coastguard Worker "The use multiprocess flag is set but there are <= 1 devices available." 3640*da0073e9SAndroid Build Coastguard Worker ) 3641*da0073e9SAndroid Build Coastguard Worker # multiprocess path 3642*da0073e9SAndroid Build Coastguard Worker args.world_size = device_count 3643*da0073e9SAndroid Build Coastguard Worker mp.spawn( 3644*da0073e9SAndroid Build Coastguard Worker process_entry, args=(runner, original_dir, args), nprocs=device_count 3645*da0073e9SAndroid Build Coastguard Worker ) 3646*da0073e9SAndroid Build Coastguard Worker elif args.only and args.warm_start_latency: 3647*da0073e9SAndroid Build Coastguard Worker # Warm start mode. Enable FX graph caching and perform back-to-back runs in 3648*da0073e9SAndroid Build Coastguard Worker # separate processes (but ensure the inductor cache is preserved across runs). 3649*da0073e9SAndroid Build Coastguard Worker env = os.environ.copy() 3650*da0073e9SAndroid Build Coastguard Worker env["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1" 3651*da0073e9SAndroid Build Coastguard Worker cmd = [sys.executable] + sys.argv 3652*da0073e9SAndroid Build Coastguard Worker cmd.remove("--warm-start-latency") 3653*da0073e9SAndroid Build Coastguard Worker 3654*da0073e9SAndroid Build Coastguard Worker print(f"Performing cold-start run for {args.only}") 3655*da0073e9SAndroid Build Coastguard Worker warmup_cmd = cmd + ["--repeat=1", "--disable-output"] 3656*da0073e9SAndroid Build Coastguard Worker subprocess.check_call(warmup_cmd, timeout=args.timeout, env=env) 3657*da0073e9SAndroid Build Coastguard Worker 3658*da0073e9SAndroid Build Coastguard Worker print(f"Performing warm-start run for {args.only}") 3659*da0073e9SAndroid Build Coastguard Worker subprocess.check_call(cmd, timeout=args.timeout, env=env) 3660*da0073e9SAndroid Build Coastguard Worker else: 3661*da0073e9SAndroid Build Coastguard Worker # single process path just uses the main process 3662*da0073e9SAndroid Build Coastguard Worker args.world_size = 1 3663*da0073e9SAndroid Build Coastguard Worker process_entry(0, runner, original_dir, args) 3664*da0073e9SAndroid Build Coastguard Worker 3665*da0073e9SAndroid Build Coastguard Worker 3666*da0073e9SAndroid Build Coastguard Workerdef write_csv_when_exception(args, name: str, status: str, device=None): 3667*da0073e9SAndroid Build Coastguard Worker print(status) 3668*da0073e9SAndroid Build Coastguard Worker placeholder_batch_size = 0 3669*da0073e9SAndroid Build Coastguard Worker devices = [device] if device is not None else args.devices 3670*da0073e9SAndroid Build Coastguard Worker if args.accuracy: 3671*da0073e9SAndroid Build Coastguard Worker headers = ["dev", "name", "batch_size", "accuracy"] 3672*da0073e9SAndroid Build Coastguard Worker rows = [[device, name, placeholder_batch_size, status] for device in devices] 3673*da0073e9SAndroid Build Coastguard Worker elif args.performance: 3674*da0073e9SAndroid Build Coastguard Worker headers = ["dev", "name", "batch_size", "speedup", "abs_latency"] 3675*da0073e9SAndroid Build Coastguard Worker rows = [[device, name, placeholder_batch_size, 0.0, 0.0] for device in devices] 3676*da0073e9SAndroid Build Coastguard Worker else: 3677*da0073e9SAndroid Build Coastguard Worker headers = [] 3678*da0073e9SAndroid Build Coastguard Worker rows = [[device, name, placeholder_batch_size, 0.0] for device in devices] 3679*da0073e9SAndroid Build Coastguard Worker 3680*da0073e9SAndroid Build Coastguard Worker for row in rows: 3681*da0073e9SAndroid Build Coastguard Worker output_csv(output_filename, headers, row) 3682*da0073e9SAndroid Build Coastguard Worker 3683*da0073e9SAndroid Build Coastguard Worker 3684*da0073e9SAndroid Build Coastguard Workerdef run(runner, args, original_dir=None): 3685*da0073e9SAndroid Build Coastguard Worker # Pass the parsed args object to benchmark runner object 3686*da0073e9SAndroid Build Coastguard Worker runner.args = args 3687*da0073e9SAndroid Build Coastguard Worker 3688*da0073e9SAndroid Build Coastguard Worker args.filter = args.filter or [r"."] 3689*da0073e9SAndroid Build Coastguard Worker args.exclude = args.exclude or [r"^$"] 3690*da0073e9SAndroid Build Coastguard Worker args.exclude_exact = args.exclude_exact or [] 3691*da0073e9SAndroid Build Coastguard Worker 3692*da0073e9SAndroid Build Coastguard Worker if args.inductor: 3693*da0073e9SAndroid Build Coastguard Worker assert args.backend is None 3694*da0073e9SAndroid Build Coastguard Worker args.backend = "inductor" 3695*da0073e9SAndroid Build Coastguard Worker if args.quantization: 3696*da0073e9SAndroid Build Coastguard Worker assert args.backend is None 3697*da0073e9SAndroid Build Coastguard Worker args.backend = "torchao" 3698*da0073e9SAndroid Build Coastguard Worker if args.dynamic_batch_only: 3699*da0073e9SAndroid Build Coastguard Worker args.dynamic_shapes = True 3700*da0073e9SAndroid Build Coastguard Worker torch._dynamo.config.assume_static_by_default = True 3701*da0073e9SAndroid Build Coastguard Worker if args.dynamic_shapes: 3702*da0073e9SAndroid Build Coastguard Worker if not args.dynamic_batch_only: 3703*da0073e9SAndroid Build Coastguard Worker torch._dynamo.config.assume_static_by_default = False 3704*da0073e9SAndroid Build Coastguard Worker if args.propagate_real_tensors: 3705*da0073e9SAndroid Build Coastguard Worker # TODO: Separate flag for data dependent 3706*da0073e9SAndroid Build Coastguard Worker torch._dynamo.config.capture_scalar_outputs = True 3707*da0073e9SAndroid Build Coastguard Worker torch._dynamo.config.capture_dynamic_output_shape_ops = True 3708*da0073e9SAndroid Build Coastguard Worker torch._functorch.config.fake_tensor_propagate_real_tensors = True 3709*da0073e9SAndroid Build Coastguard Worker if args.specialize_int: 3710*da0073e9SAndroid Build Coastguard Worker torch._dynamo.config.specialize_int = True 3711*da0073e9SAndroid Build Coastguard Worker if args.ci: 3712*da0073e9SAndroid Build Coastguard Worker if args.accuracy: 3713*da0073e9SAndroid Build Coastguard Worker # Run fewer iterations when checking accuracy 3714*da0073e9SAndroid Build Coastguard Worker args.repeat = min(args.repeat, 2) 3715*da0073e9SAndroid Build Coastguard Worker 3716*da0073e9SAndroid Build Coastguard Worker # Set translation validation on by default on CI accuracy runs. 3717*da0073e9SAndroid Build Coastguard Worker torch.fx.experimental._config.translation_validation = True 3718*da0073e9SAndroid Build Coastguard Worker 3719*da0073e9SAndroid Build Coastguard Worker ci = functools.partial( 3720*da0073e9SAndroid Build Coastguard Worker CI, args.backend, training=args.training, dynamic=args.dynamic_shapes 3721*da0073e9SAndroid Build Coastguard Worker ) 3722*da0073e9SAndroid Build Coastguard Worker if args.ddp: 3723*da0073e9SAndroid Build Coastguard Worker assert args.training, "DDP benchmark requires --training mode" 3724*da0073e9SAndroid Build Coastguard Worker torch._dynamo.config.optimize_ddp = args.optimize_ddp_mode 3725*da0073e9SAndroid Build Coastguard Worker if args.only == "dlrm": 3726*da0073e9SAndroid Build Coastguard Worker log.error( 3727*da0073e9SAndroid Build Coastguard Worker "DLRM+DDP is unsupported as it requires sharding the embedding layer separately from DDP" 3728*da0073e9SAndroid Build Coastguard Worker ) 3729*da0073e9SAndroid Build Coastguard Worker return sys.exit(-1) 3730*da0073e9SAndroid Build Coastguard Worker if args.accuracy: 3731*da0073e9SAndroid Build Coastguard Worker # Use small batch size. We use >1 batch size to ensure we test 3732*da0073e9SAndroid Build Coastguard Worker # batch_norm type of operators that work on batch dims. 3733*da0073e9SAndroid Build Coastguard Worker # TODO - Go through the failures for batch size = 2 3734*da0073e9SAndroid Build Coastguard Worker if args.batch_size is None: 3735*da0073e9SAndroid Build Coastguard Worker if runner.suite_name == "huggingface": 3736*da0073e9SAndroid Build Coastguard Worker args.batch_size = 1 3737*da0073e9SAndroid Build Coastguard Worker elif runner.suite_name == "torchbench": 3738*da0073e9SAndroid Build Coastguard Worker args.batch_size = 4 3739*da0073e9SAndroid Build Coastguard Worker else: 3740*da0073e9SAndroid Build Coastguard Worker # Larger batch size of TIMM models to have stable batch_norm 3741*da0073e9SAndroid Build Coastguard Worker assert runner.suite_name == "timm_models" 3742*da0073e9SAndroid Build Coastguard Worker args.batch_size = 8 3743*da0073e9SAndroid Build Coastguard Worker 3744*da0073e9SAndroid Build Coastguard Worker # Remove sources of randomness 3745*da0073e9SAndroid Build Coastguard Worker if runner.suite_name not in ("timm_models", "huggingface"): 3746*da0073e9SAndroid Build Coastguard Worker # TODO - Using train mode for timm_models and HF models. Move to train mode for Torchbench as well. 3747*da0073e9SAndroid Build Coastguard Worker args.use_eval_mode = True 3748*da0073e9SAndroid Build Coastguard Worker inductor_config.fallback_random = True 3749*da0073e9SAndroid Build Coastguard Worker if args.only is not None and args.only not in { 3750*da0073e9SAndroid Build Coastguard Worker "alexnet", 3751*da0073e9SAndroid Build Coastguard Worker "Background_Matting", 3752*da0073e9SAndroid Build Coastguard Worker "pytorch_CycleGAN_and_pix2pix", 3753*da0073e9SAndroid Build Coastguard Worker "pytorch_unet", 3754*da0073e9SAndroid Build Coastguard Worker "Super_SloMo", 3755*da0073e9SAndroid Build Coastguard Worker "vgg16", 3756*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/96724 3757*da0073e9SAndroid Build Coastguard Worker "Wav2Vec2ForCTC", 3758*da0073e9SAndroid Build Coastguard Worker "Wav2Vec2ForPreTraining", 3759*da0073e9SAndroid Build Coastguard Worker "sam", 3760*da0073e9SAndroid Build Coastguard Worker "sam_fast", 3761*da0073e9SAndroid Build Coastguard Worker "resnet50_quantized_qat", 3762*da0073e9SAndroid Build Coastguard Worker "mobilenet_v2_quantized_qat", 3763*da0073e9SAndroid Build Coastguard Worker }: 3764*da0073e9SAndroid Build Coastguard Worker # some of the models do not support use_deterministic_algorithms 3765*da0073e9SAndroid Build Coastguard Worker torch.use_deterministic_algorithms(True) 3766*da0073e9SAndroid Build Coastguard Worker os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 3767*da0073e9SAndroid Build Coastguard Worker torch.backends.cudnn.deterministic = True 3768*da0073e9SAndroid Build Coastguard Worker torch.backends.cudnn.allow_tf32 = False 3769*da0073e9SAndroid Build Coastguard Worker torch.backends.cudnn.benchmark = False 3770*da0073e9SAndroid Build Coastguard Worker torch.backends.cuda.matmul.allow_tf32 = False 3771*da0073e9SAndroid Build Coastguard Worker 3772*da0073e9SAndroid Build Coastguard Worker # Remove randomeness when torch manual seed is called 3773*da0073e9SAndroid Build Coastguard Worker patch_torch_manual_seed() 3774*da0073e9SAndroid Build Coastguard Worker 3775*da0073e9SAndroid Build Coastguard Worker # Some models e.g. yolov3 assert batch size on n_gpus 3776*da0073e9SAndroid Build Coastguard Worker if "CUDA_VISIBLE_DEVICES" not in os.environ and not args.multiprocess: 3777*da0073e9SAndroid Build Coastguard Worker args.device_index = "0" 3778*da0073e9SAndroid Build Coastguard Worker 3779*da0073e9SAndroid Build Coastguard Worker # Stricter check to disable fallbacks 3780*da0073e9SAndroid Build Coastguard Worker args.suppress_errors = False 3781*da0073e9SAndroid Build Coastguard Worker 3782*da0073e9SAndroid Build Coastguard Worker if args.device_index is not None: 3783*da0073e9SAndroid Build Coastguard Worker if args.multiprocess: 3784*da0073e9SAndroid Build Coastguard Worker print("Cannot specify both --device_index and --multiprocess") 3785*da0073e9SAndroid Build Coastguard Worker return sys.exit(-1) 3786*da0073e9SAndroid Build Coastguard Worker os.environ["CUDA_VISIBLE_DEVICES"] = args.device_index 3787*da0073e9SAndroid Build Coastguard Worker 3788*da0073e9SAndroid Build Coastguard Worker elif args.performance: 3789*da0073e9SAndroid Build Coastguard Worker # Ensure that we test on real scenarios 3790*da0073e9SAndroid Build Coastguard Worker args.use_eval_mode = False 3791*da0073e9SAndroid Build Coastguard Worker 3792*da0073e9SAndroid Build Coastguard Worker if args.partition_id > args.total_partitions or args.partition_id < 0: 3793*da0073e9SAndroid Build Coastguard Worker print("Invalid partition id") 3794*da0073e9SAndroid Build Coastguard Worker return sys.exit(-1) 3795*da0073e9SAndroid Build Coastguard Worker 3796*da0073e9SAndroid Build Coastguard Worker if not args.devices: 3797*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 3798*da0073e9SAndroid Build Coastguard Worker args.devices = ["cuda"] 3799*da0073e9SAndroid Build Coastguard Worker else: 3800*da0073e9SAndroid Build Coastguard Worker log.warning("torch.cuda.is_available() == False, using CPU") 3801*da0073e9SAndroid Build Coastguard Worker args.devices = ["cpu"] 3802*da0073e9SAndroid Build Coastguard Worker 3803*da0073e9SAndroid Build Coastguard Worker if args.devices != ["cpu"] and (HAS_CUDA or HAS_XPU): 3804*da0073e9SAndroid Build Coastguard Worker global synchronize 3805*da0073e9SAndroid Build Coastguard Worker synchronize = torch.cuda.synchronize if HAS_CUDA else torch.xpu.synchronize 3806*da0073e9SAndroid Build Coastguard Worker 3807*da0073e9SAndroid Build Coastguard Worker if ( 3808*da0073e9SAndroid Build Coastguard Worker args.devices == ["cuda"] 3809*da0073e9SAndroid Build Coastguard Worker and torch.cuda.get_device_properties(0).total_memory < 25 * 2**30 3810*da0073e9SAndroid Build Coastguard Worker ): 3811*da0073e9SAndroid Build Coastguard Worker # OOM errors on an RTX 3090 with 24gb RAM 3812*da0073e9SAndroid Build Coastguard Worker runner.skip_models.update( 3813*da0073e9SAndroid Build Coastguard Worker { 3814*da0073e9SAndroid Build Coastguard Worker # torchbench 3815*da0073e9SAndroid Build Coastguard Worker "hf_Longformer", 3816*da0073e9SAndroid Build Coastguard Worker "timm_nfnet", 3817*da0073e9SAndroid Build Coastguard Worker "timm_efficientdet", 3818*da0073e9SAndroid Build Coastguard Worker } 3819*da0073e9SAndroid Build Coastguard Worker ) 3820*da0073e9SAndroid Build Coastguard Worker if args.training: 3821*da0073e9SAndroid Build Coastguard Worker runner.skip_models.add("hf_T5") 3822*da0073e9SAndroid Build Coastguard Worker 3823*da0073e9SAndroid Build Coastguard Worker if args.nnc: 3824*da0073e9SAndroid Build Coastguard Worker torch._C._jit_override_can_fuse_on_cpu(True) 3825*da0073e9SAndroid Build Coastguard Worker torch._C._jit_override_can_fuse_on_gpu(True) 3826*da0073e9SAndroid Build Coastguard Worker torch._C._jit_set_texpr_fuser_enabled(True) 3827*da0073e9SAndroid Build Coastguard Worker torch._C._jit_set_nvfuser_enabled(False) 3828*da0073e9SAndroid Build Coastguard Worker 3829*da0073e9SAndroid Build Coastguard Worker if args.threads: 3830*da0073e9SAndroid Build Coastguard Worker torch.set_num_threads(args.threads) 3831*da0073e9SAndroid Build Coastguard Worker 3832*da0073e9SAndroid Build Coastguard Worker if args.verbose: 3833*da0073e9SAndroid Build Coastguard Worker torch._logging.set_logs(dynamo=logging.DEBUG) 3834*da0073e9SAndroid Build Coastguard Worker 3835*da0073e9SAndroid Build Coastguard Worker if args.print_graph_breaks: 3836*da0073e9SAndroid Build Coastguard Worker torch._logging.set_logs(graph_breaks=True) 3837*da0073e9SAndroid Build Coastguard Worker 3838*da0073e9SAndroid Build Coastguard Worker if args.quiet: 3839*da0073e9SAndroid Build Coastguard Worker torch._logging.set_logs(dynamo=logging.ERROR) 3840*da0073e9SAndroid Build Coastguard Worker 3841*da0073e9SAndroid Build Coastguard Worker torch._dynamo.config.suppress_errors = args.suppress_errors 3842*da0073e9SAndroid Build Coastguard Worker 3843*da0073e9SAndroid Build Coastguard Worker if args.training: 3844*da0073e9SAndroid Build Coastguard Worker runner.model_iter_fn = runner.forward_and_backward_pass 3845*da0073e9SAndroid Build Coastguard Worker runner.skip_models.update(runner.skip_not_suitable_for_training_models) 3846*da0073e9SAndroid Build Coastguard Worker else: 3847*da0073e9SAndroid Build Coastguard Worker runner.model_iter_fn = runner.forward_pass 3848*da0073e9SAndroid Build Coastguard Worker 3849*da0073e9SAndroid Build Coastguard Worker if args.fast: 3850*da0073e9SAndroid Build Coastguard Worker runner.skip_models.update(runner.slow_models) 3851*da0073e9SAndroid Build Coastguard Worker 3852*da0073e9SAndroid Build Coastguard Worker if args.devices == ["cpu"]: 3853*da0073e9SAndroid Build Coastguard Worker runner.skip_models.update(runner.very_slow_models) 3854*da0073e9SAndroid Build Coastguard Worker runner.skip_models.update(runner.skip_models_for_cpu) 3855*da0073e9SAndroid Build Coastguard Worker elif args.devices == ["cuda"]: 3856*da0073e9SAndroid Build Coastguard Worker runner.skip_models.update(runner.skip_models_for_cuda) 3857*da0073e9SAndroid Build Coastguard Worker 3858*da0073e9SAndroid Build Coastguard Worker if not args.multiprocess: 3859*da0073e9SAndroid Build Coastguard Worker runner.skip_models.update(runner.skip_multiprocess_models) 3860*da0073e9SAndroid Build Coastguard Worker 3861*da0073e9SAndroid Build Coastguard Worker if args.freezing: 3862*da0073e9SAndroid Build Coastguard Worker runner.skip_models.update(runner.skip_models_for_freezing) 3863*da0073e9SAndroid Build Coastguard Worker 3864*da0073e9SAndroid Build Coastguard Worker if args.no_skip: 3865*da0073e9SAndroid Build Coastguard Worker runner.skip_models.clear() 3866*da0073e9SAndroid Build Coastguard Worker 3867*da0073e9SAndroid Build Coastguard Worker experiment = null_experiment 3868*da0073e9SAndroid Build Coastguard Worker global current_name, current_device, current_batch_size, output_filename, disable_output, optimize_ctx, current_onnx_compiler 3869*da0073e9SAndroid Build Coastguard Worker optimize_ctx = contextlib.nullcontext() 3870*da0073e9SAndroid Build Coastguard Worker 3871*da0073e9SAndroid Build Coastguard Worker if args.disable_output: 3872*da0073e9SAndroid Build Coastguard Worker disable_output = True 3873*da0073e9SAndroid Build Coastguard Worker 3874*da0073e9SAndroid Build Coastguard Worker if args.overhead: 3875*da0073e9SAndroid Build Coastguard Worker optimize_ctx = torch._dynamo.optimize(dummy_fx_compile, nopython=args.nopython) 3876*da0073e9SAndroid Build Coastguard Worker experiment = speedup_experiment 3877*da0073e9SAndroid Build Coastguard Worker output_filename = "overheads.csv" 3878*da0073e9SAndroid Build Coastguard Worker elif args.inductor: 3879*da0073e9SAndroid Build Coastguard Worker inductor_config.debug = args.verbose 3880*da0073e9SAndroid Build Coastguard Worker if args.threads: 3881*da0073e9SAndroid Build Coastguard Worker inductor_config.cpp.threads = args.threads 3882*da0073e9SAndroid Build Coastguard Worker 3883*da0073e9SAndroid Build Coastguard Worker optimize_ctx = functools.partial( 3884*da0073e9SAndroid Build Coastguard Worker torch.compile, 3885*da0073e9SAndroid Build Coastguard Worker backend="inductor", 3886*da0073e9SAndroid Build Coastguard Worker fullgraph=args.nopython, 3887*da0073e9SAndroid Build Coastguard Worker mode=args.inductor_compile_mode, 3888*da0073e9SAndroid Build Coastguard Worker ) 3889*da0073e9SAndroid Build Coastguard Worker experiment = speedup_experiment 3890*da0073e9SAndroid Build Coastguard Worker output_filename = "inductor.csv" 3891*da0073e9SAndroid Build Coastguard Worker elif args.export: 3892*da0073e9SAndroid Build Coastguard Worker optimize_ctx = export 3893*da0073e9SAndroid Build Coastguard Worker experiment = speedup_experiment 3894*da0073e9SAndroid Build Coastguard Worker output_filename = "export.csv" 3895*da0073e9SAndroid Build Coastguard Worker elif args.xla: 3896*da0073e9SAndroid Build Coastguard Worker (dev,) = args.devices 3897*da0073e9SAndroid Build Coastguard Worker os.environ["PJRT_DEVICE"] = {"cuda": "GPU", "cpu": "CPU"}[dev] 3898*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic = MagicMock() 3899*da0073e9SAndroid Build Coastguard Worker experiment = xla 3900*da0073e9SAndroid Build Coastguard Worker output_filename = "xla.csv" 3901*da0073e9SAndroid Build Coastguard Worker elif args.torchscript_onnx: 3902*da0073e9SAndroid Build Coastguard Worker optimize_ctx = functools.partial( 3903*da0073e9SAndroid Build Coastguard Worker optimize_onnx_ctx, 3904*da0073e9SAndroid Build Coastguard Worker args.output_directory or ".", 3905*da0073e9SAndroid Build Coastguard Worker OnnxModelFromTorchScript, 3906*da0073e9SAndroid Build Coastguard Worker copy_before_export=args.performance, # Accuarcy bench already did deepcopy 3907*da0073e9SAndroid Build Coastguard Worker ) 3908*da0073e9SAndroid Build Coastguard Worker experiment = speedup_experiment_onnx 3909*da0073e9SAndroid Build Coastguard Worker output_filename = "torchscript_onnx.csv" 3910*da0073e9SAndroid Build Coastguard Worker current_onnx_compiler = "torchscript" 3911*da0073e9SAndroid Build Coastguard Worker elif args.dynamo_onnx: 3912*da0073e9SAndroid Build Coastguard Worker optimize_ctx = functools.partial( 3913*da0073e9SAndroid Build Coastguard Worker optimize_onnx_ctx, 3914*da0073e9SAndroid Build Coastguard Worker args.output_directory or ".", 3915*da0073e9SAndroid Build Coastguard Worker OnnxModelFromDynamo, 3916*da0073e9SAndroid Build Coastguard Worker dynamic_shapes=args.dynamic_shapes, 3917*da0073e9SAndroid Build Coastguard Worker copy_before_export=args.performance, 3918*da0073e9SAndroid Build Coastguard Worker ) 3919*da0073e9SAndroid Build Coastguard Worker experiment = speedup_experiment_onnx 3920*da0073e9SAndroid Build Coastguard Worker output_filename = "dynamo_onnx.csv" 3921*da0073e9SAndroid Build Coastguard Worker current_onnx_compiler = "dynamo" 3922*da0073e9SAndroid Build Coastguard Worker elif args.dynamo_onnx_aot_inline: 3923*da0073e9SAndroid Build Coastguard Worker optimize_ctx = functools.partial( 3924*da0073e9SAndroid Build Coastguard Worker optimize_onnx_ctx, 3925*da0073e9SAndroid Build Coastguard Worker args.output_directory or ".", 3926*da0073e9SAndroid Build Coastguard Worker OnnxModelFromDynamoAotInline, 3927*da0073e9SAndroid Build Coastguard Worker dynamic_shapes=args.dynamic_shapes, 3928*da0073e9SAndroid Build Coastguard Worker copy_before_export=args.performance, 3929*da0073e9SAndroid Build Coastguard Worker ) 3930*da0073e9SAndroid Build Coastguard Worker experiment = speedup_experiment_onnx 3931*da0073e9SAndroid Build Coastguard Worker output_filename = "dynamo_onnx_aot_inline.csv" 3932*da0073e9SAndroid Build Coastguard Worker current_onnx_compiler = "dynamo" 3933*da0073e9SAndroid Build Coastguard Worker elif args.dynamo_onnx_aot_optimize: 3934*da0073e9SAndroid Build Coastguard Worker optimize_ctx = functools.partial( 3935*da0073e9SAndroid Build Coastguard Worker optimize_onnx_ctx, 3936*da0073e9SAndroid Build Coastguard Worker args.output_directory or ".", 3937*da0073e9SAndroid Build Coastguard Worker OnnxModelFromDynamoAotOptimize, 3938*da0073e9SAndroid Build Coastguard Worker dynamic_shapes=args.dynamic_shapes, 3939*da0073e9SAndroid Build Coastguard Worker copy_before_export=args.performance, 3940*da0073e9SAndroid Build Coastguard Worker ) 3941*da0073e9SAndroid Build Coastguard Worker experiment = speedup_experiment_onnx 3942*da0073e9SAndroid Build Coastguard Worker output_filename = "dynamo_onnx_aot_optimize.csv" 3943*da0073e9SAndroid Build Coastguard Worker current_onnx_compiler = "dynamo" 3944*da0073e9SAndroid Build Coastguard Worker elif args.speedup_dynamo_ts: 3945*da0073e9SAndroid Build Coastguard Worker optimize_ctx = torch._dynamo.optimize("ts", nopython=args.nopython) 3946*da0073e9SAndroid Build Coastguard Worker experiment = speedup_experiment 3947*da0073e9SAndroid Build Coastguard Worker output_filename = "speedup_dynamo_ts.csv" 3948*da0073e9SAndroid Build Coastguard Worker elif args.prims_nvfuser: 3949*da0073e9SAndroid Build Coastguard Worker optimize_ctx = torch._dynamo.optimize("prims_nvfuser", nopython=args.nopython) 3950*da0073e9SAndroid Build Coastguard Worker experiment = speedup_experiment 3951*da0073e9SAndroid Build Coastguard Worker backend_str = "prims_nvfuser" 3952*da0073e9SAndroid Build Coastguard Worker output_filename = f"accuracy_aot_{backend_str}.csv" 3953*da0073e9SAndroid Build Coastguard Worker elif args.print_fx: 3954*da0073e9SAndroid Build Coastguard Worker optimize_ctx = torch._dynamo.optimize( 3955*da0073e9SAndroid Build Coastguard Worker print_fx, 3956*da0073e9SAndroid Build Coastguard Worker nopython=args.nopython, 3957*da0073e9SAndroid Build Coastguard Worker ) 3958*da0073e9SAndroid Build Coastguard Worker elif args.print_aten_ops: 3959*da0073e9SAndroid Build Coastguard Worker optimize_ctx = torch._dynamo.optimize( 3960*da0073e9SAndroid Build Coastguard Worker print_aten_ops, 3961*da0073e9SAndroid Build Coastguard Worker nopython=args.nopython, 3962*da0073e9SAndroid Build Coastguard Worker ) 3963*da0073e9SAndroid Build Coastguard Worker elif args.nothing: 3964*da0073e9SAndroid Build Coastguard Worker optimize_ctx = nothing 3965*da0073e9SAndroid Build Coastguard Worker experiment = speedup_experiment 3966*da0073e9SAndroid Build Coastguard Worker output_filename = "nothing.csv" 3967*da0073e9SAndroid Build Coastguard Worker elif args.backend or args.export_aot_inductor: 3968*da0073e9SAndroid Build Coastguard Worker if args.export_aot_inductor: 3969*da0073e9SAndroid Build Coastguard Worker assert not args.training, "AOTInductor only supports inference" 3970*da0073e9SAndroid Build Coastguard Worker optimize_ctx = functools.partial( 3971*da0073e9SAndroid Build Coastguard Worker export_aot_inductor, device=args.devices[0] 3972*da0073e9SAndroid Build Coastguard Worker ) 3973*da0073e9SAndroid Build Coastguard Worker 3974*da0073e9SAndroid Build Coastguard Worker # AOTInductor doesn't support control flow yet 3975*da0073e9SAndroid Build Coastguard Worker runner.skip_models.update(runner.skip_models_due_to_control_flow) 3976*da0073e9SAndroid Build Coastguard Worker elif args.backend == "torchao": 3977*da0073e9SAndroid Build Coastguard Worker assert "cuda" in args.devices, "Quantization requires CUDA device." 3978*da0073e9SAndroid Build Coastguard Worker assert args.bfloat16, "Quantization requires dtype bfloat16." 3979*da0073e9SAndroid Build Coastguard Worker try: 3980*da0073e9SAndroid Build Coastguard Worker from torchao_backend import setup_baseline, torchao_optimize_ctx 3981*da0073e9SAndroid Build Coastguard Worker except ImportError: 3982*da0073e9SAndroid Build Coastguard Worker from userbenchmark.dynamo.dynamobench.torchao_backend import ( 3983*da0073e9SAndroid Build Coastguard Worker setup_baseline, 3984*da0073e9SAndroid Build Coastguard Worker torchao_optimize_ctx, 3985*da0073e9SAndroid Build Coastguard Worker ) 3986*da0073e9SAndroid Build Coastguard Worker 3987*da0073e9SAndroid Build Coastguard Worker setup_baseline() 3988*da0073e9SAndroid Build Coastguard Worker baseline_ctx = functools.partial( 3989*da0073e9SAndroid Build Coastguard Worker torch.compile, 3990*da0073e9SAndroid Build Coastguard Worker backend="inductor", 3991*da0073e9SAndroid Build Coastguard Worker fullgraph=args.nopython, 3992*da0073e9SAndroid Build Coastguard Worker mode=args.inductor_compile_mode, 3993*da0073e9SAndroid Build Coastguard Worker ) 3994*da0073e9SAndroid Build Coastguard Worker runner.model_iter_fn = baseline_ctx(runner.model_iter_fn) 3995*da0073e9SAndroid Build Coastguard Worker optimize_ctx = torchao_optimize_ctx(args.quantization) 3996*da0073e9SAndroid Build Coastguard Worker else: 3997*da0073e9SAndroid Build Coastguard Worker optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython) 3998*da0073e9SAndroid Build Coastguard Worker experiment = speedup_experiment 3999*da0073e9SAndroid Build Coastguard Worker if args.accuracy: 4000*da0073e9SAndroid Build Coastguard Worker output_filename = f"accuracy_{args.backend}.csv" 4001*da0073e9SAndroid Build Coastguard Worker elif args.tolerance: 4002*da0073e9SAndroid Build Coastguard Worker output_filename = f"tolerance_{args.backend}.csv" 4003*da0073e9SAndroid Build Coastguard Worker else: 4004*da0073e9SAndroid Build Coastguard Worker output_filename = f"speedup_{args.backend}.csv" 4005*da0073e9SAndroid Build Coastguard Worker elif args.recompile_profiler: 4006*da0073e9SAndroid Build Coastguard Worker output_filename = "recompile_profiler_log.csv" 4007*da0073e9SAndroid Build Coastguard Worker experiment = recompile_profiler_experiment 4008*da0073e9SAndroid Build Coastguard Worker else: 4009*da0073e9SAndroid Build Coastguard Worker optimize_ctx = torch._dynamo.optimize( 4010*da0073e9SAndroid Build Coastguard Worker fx_insert_profiling, nopython=args.nopython 4011*da0073e9SAndroid Build Coastguard Worker ) 4012*da0073e9SAndroid Build Coastguard Worker experiment = coverage_experiment 4013*da0073e9SAndroid Build Coastguard Worker output_filename = "coverage.csv" 4014*da0073e9SAndroid Build Coastguard Worker 4015*da0073e9SAndroid Build Coastguard Worker if args.inductor or args.backend == "inductor" or args.export_aot_inductor: 4016*da0073e9SAndroid Build Coastguard Worker inductor_config.triton.cudagraphs = not args.disable_cudagraphs 4017*da0073e9SAndroid Build Coastguard Worker inductor_config.triton.persistent_reductions = ( 4018*da0073e9SAndroid Build Coastguard Worker not args.disable_persistent_reductions 4019*da0073e9SAndroid Build Coastguard Worker ) 4020*da0073e9SAndroid Build Coastguard Worker inductor_config.split_reductions = not args.disable_split_reductions 4021*da0073e9SAndroid Build Coastguard Worker inductor_config.triton.divisible_by_16 = not args.disable_divisible_by_16 4022*da0073e9SAndroid Build Coastguard Worker if args.inference: 4023*da0073e9SAndroid Build Coastguard Worker inductor_config.freezing = args.freezing 4024*da0073e9SAndroid Build Coastguard Worker 4025*da0073e9SAndroid Build Coastguard Worker runner.setup_amp() 4026*da0073e9SAndroid Build Coastguard Worker 4027*da0073e9SAndroid Build Coastguard Worker if args.output: 4028*da0073e9SAndroid Build Coastguard Worker output_filename = args.output 4029*da0073e9SAndroid Build Coastguard Worker 4030*da0073e9SAndroid Build Coastguard Worker if output_filename: 4031*da0073e9SAndroid Build Coastguard Worker if args.output_directory: 4032*da0073e9SAndroid Build Coastguard Worker output_filename = os.path.join(args.output_directory, output_filename) 4033*da0073e9SAndroid Build Coastguard Worker else: 4034*da0073e9SAndroid Build Coastguard Worker output_filename = os.path.join( 4035*da0073e9SAndroid Build Coastguard Worker torch._dynamo.config.base_dir, output_filename 4036*da0073e9SAndroid Build Coastguard Worker ) 4037*da0073e9SAndroid Build Coastguard Worker 4038*da0073e9SAndroid Build Coastguard Worker if args.find_batch_sizes and args.only: 4039*da0073e9SAndroid Build Coastguard Worker for device in args.devices: 4040*da0073e9SAndroid Build Coastguard Worker batch_size = runner.batch_size_finder(device, args.only) 4041*da0073e9SAndroid Build Coastguard Worker print(args.only, batch_size) 4042*da0073e9SAndroid Build Coastguard Worker output_csv(output_filename, [], [args.only, batch_size]) 4043*da0073e9SAndroid Build Coastguard Worker return 4044*da0073e9SAndroid Build Coastguard Worker 4045*da0073e9SAndroid Build Coastguard Worker if args.export_profiler_trace: 4046*da0073e9SAndroid Build Coastguard Worker if args.profiler_trace_name is None: 4047*da0073e9SAndroid Build Coastguard Worker if args.backend: 4048*da0073e9SAndroid Build Coastguard Worker args.profiler_trace_name = args.backend 4049*da0073e9SAndroid Build Coastguard Worker elif args.inductor: 4050*da0073e9SAndroid Build Coastguard Worker args.profiler_trace_name = "inductor" 4051*da0073e9SAndroid Build Coastguard Worker else: 4052*da0073e9SAndroid Build Coastguard Worker args.profiler_trace_name = "profile" 4053*da0073e9SAndroid Build Coastguard Worker else: 4054*da0073e9SAndroid Build Coastguard Worker args.profiler_trace_name = args.profiler_trace_name 4055*da0073e9SAndroid Build Coastguard Worker 4056*da0073e9SAndroid Build Coastguard Worker if args.no_translation_validation: 4057*da0073e9SAndroid Build Coastguard Worker # Overwrite 'translation_validation' config, if specified. 4058*da0073e9SAndroid Build Coastguard Worker torch.fx.experimental._config.translation_validation = False 4059*da0073e9SAndroid Build Coastguard Worker 4060*da0073e9SAndroid Build Coastguard Worker experiment = functools.partial(experiment, args, runner.model_iter_fn) 4061*da0073e9SAndroid Build Coastguard Worker 4062*da0073e9SAndroid Build Coastguard Worker if args.only and should_diff_branch(args): 4063*da0073e9SAndroid Build Coastguard Worker import git 4064*da0073e9SAndroid Build Coastguard Worker 4065*da0073e9SAndroid Build Coastguard Worker repo = git.Repo() 4066*da0073e9SAndroid Build Coastguard Worker main_branch = repo.active_branch.name 4067*da0073e9SAndroid Build Coastguard Worker try: 4068*da0073e9SAndroid Build Coastguard Worker # Adding diff-branch again to the args will override previous value 4069*da0073e9SAndroid Build Coastguard Worker call_args = ( 4070*da0073e9SAndroid Build Coastguard Worker [sys.executable] + sys.argv + [f"--diff-branch={diff_branch_default}"] 4071*da0073e9SAndroid Build Coastguard Worker ) 4072*da0073e9SAndroid Build Coastguard Worker # Run for main branch 4073*da0073e9SAndroid Build Coastguard Worker subprocess.check_call(call_args + [f"--tag={main_branch}"]) 4074*da0073e9SAndroid Build Coastguard Worker # Run for comparison branch 4075*da0073e9SAndroid Build Coastguard Worker repo.git.checkout(args.diff_branch) 4076*da0073e9SAndroid Build Coastguard Worker subprocess.check_call(call_args + [f"--tag={args.diff_branch}"]) 4077*da0073e9SAndroid Build Coastguard Worker finally: 4078*da0073e9SAndroid Build Coastguard Worker # Go back to main branch 4079*da0073e9SAndroid Build Coastguard Worker repo.git.checkout(main_branch) 4080*da0073e9SAndroid Build Coastguard Worker elif args.only: 4081*da0073e9SAndroid Build Coastguard Worker model_name = args.only 4082*da0073e9SAndroid Build Coastguard Worker for device in args.devices: 4083*da0073e9SAndroid Build Coastguard Worker batch_size = args.batch_size 4084*da0073e9SAndroid Build Coastguard Worker if args.batch_size_file: 4085*da0073e9SAndroid Build Coastguard Worker batch_size = read_batch_size_from_file( 4086*da0073e9SAndroid Build Coastguard Worker args, args.batch_size_file, model_name 4087*da0073e9SAndroid Build Coastguard Worker ) 4088*da0073e9SAndroid Build Coastguard Worker if model_specified_by_path(args.only): 4089*da0073e9SAndroid Build Coastguard Worker model, example_inputs = load_model_from_path(args.only) 4090*da0073e9SAndroid Build Coastguard Worker name = model.__class__.__name__ 4091*da0073e9SAndroid Build Coastguard Worker model = model.to(device=device) 4092*da0073e9SAndroid Build Coastguard Worker example_inputs = tree_map_only( 4093*da0073e9SAndroid Build Coastguard Worker torch.Tensor, lambda x: x.to(device=device), example_inputs 4094*da0073e9SAndroid Build Coastguard Worker ) 4095*da0073e9SAndroid Build Coastguard Worker else: 4096*da0073e9SAndroid Build Coastguard Worker name = model_name 4097*da0073e9SAndroid Build Coastguard Worker try: 4098*da0073e9SAndroid Build Coastguard Worker with tqdm(desc="loading model"): 4099*da0073e9SAndroid Build Coastguard Worker extra_args = [] 4100*da0073e9SAndroid Build Coastguard Worker if hasattr(args, "rank") and hasattr(args, "world_size"): 4101*da0073e9SAndroid Build Coastguard Worker extra_args += [ 4102*da0073e9SAndroid Build Coastguard Worker "--rank", 4103*da0073e9SAndroid Build Coastguard Worker str(args.rank), 4104*da0073e9SAndroid Build Coastguard Worker "--world_size", 4105*da0073e9SAndroid Build Coastguard Worker str(args.world_size), 4106*da0073e9SAndroid Build Coastguard Worker ] 4107*da0073e9SAndroid Build Coastguard Worker 4108*da0073e9SAndroid Build Coastguard Worker if args.part: 4109*da0073e9SAndroid Build Coastguard Worker ( 4110*da0073e9SAndroid Build Coastguard Worker device, 4111*da0073e9SAndroid Build Coastguard Worker name, 4112*da0073e9SAndroid Build Coastguard Worker model, 4113*da0073e9SAndroid Build Coastguard Worker example_inputs, 4114*da0073e9SAndroid Build Coastguard Worker batch_size, 4115*da0073e9SAndroid Build Coastguard Worker ) = runner.load_model( 4116*da0073e9SAndroid Build Coastguard Worker device, 4117*da0073e9SAndroid Build Coastguard Worker model_name, 4118*da0073e9SAndroid Build Coastguard Worker batch_size=batch_size, 4119*da0073e9SAndroid Build Coastguard Worker part=args.part, 4120*da0073e9SAndroid Build Coastguard Worker extra_args=extra_args, 4121*da0073e9SAndroid Build Coastguard Worker ) 4122*da0073e9SAndroid Build Coastguard Worker else: 4123*da0073e9SAndroid Build Coastguard Worker if args.fsdp: 4124*da0073e9SAndroid Build Coastguard Worker # Always load model on cpu for fsdp 4125*da0073e9SAndroid Build Coastguard Worker # When initializing FSDP, we will use the cuda device if args.cuda is set 4126*da0073e9SAndroid Build Coastguard Worker ( 4127*da0073e9SAndroid Build Coastguard Worker _, 4128*da0073e9SAndroid Build Coastguard Worker name, 4129*da0073e9SAndroid Build Coastguard Worker model, 4130*da0073e9SAndroid Build Coastguard Worker example_inputs, 4131*da0073e9SAndroid Build Coastguard Worker batch_size, 4132*da0073e9SAndroid Build Coastguard Worker ) = runner.load_model( 4133*da0073e9SAndroid Build Coastguard Worker "cpu", 4134*da0073e9SAndroid Build Coastguard Worker model_name, 4135*da0073e9SAndroid Build Coastguard Worker batch_size=batch_size, 4136*da0073e9SAndroid Build Coastguard Worker extra_args=extra_args, 4137*da0073e9SAndroid Build Coastguard Worker ) 4138*da0073e9SAndroid Build Coastguard Worker else: 4139*da0073e9SAndroid Build Coastguard Worker ( 4140*da0073e9SAndroid Build Coastguard Worker device, 4141*da0073e9SAndroid Build Coastguard Worker name, 4142*da0073e9SAndroid Build Coastguard Worker model, 4143*da0073e9SAndroid Build Coastguard Worker example_inputs, 4144*da0073e9SAndroid Build Coastguard Worker batch_size, 4145*da0073e9SAndroid Build Coastguard Worker ) = runner.load_model( 4146*da0073e9SAndroid Build Coastguard Worker device, 4147*da0073e9SAndroid Build Coastguard Worker model_name, 4148*da0073e9SAndroid Build Coastguard Worker batch_size=batch_size, 4149*da0073e9SAndroid Build Coastguard Worker extra_args=extra_args, 4150*da0073e9SAndroid Build Coastguard Worker ) 4151*da0073e9SAndroid Build Coastguard Worker except Exception as e: 4152*da0073e9SAndroid Build Coastguard Worker import traceback 4153*da0073e9SAndroid Build Coastguard Worker 4154*da0073e9SAndroid Build Coastguard Worker mode = "train" if args.training else "eval" 4155*da0073e9SAndroid Build Coastguard Worker print(f"{device:4} {mode:5} {name:34} ") 4156*da0073e9SAndroid Build Coastguard Worker print(traceback.format_exc()) 4157*da0073e9SAndroid Build Coastguard Worker status = ( 4158*da0073e9SAndroid Build Coastguard Worker "model_fail_to_load" 4159*da0073e9SAndroid Build Coastguard Worker if isinstance(e, NotImplementedError) 4160*da0073e9SAndroid Build Coastguard Worker else "eager_fail_to_run" 4161*da0073e9SAndroid Build Coastguard Worker ) 4162*da0073e9SAndroid Build Coastguard Worker write_csv_when_exception(args, name, status, device) 4163*da0073e9SAndroid Build Coastguard Worker continue # bad benchmark implementation 4164*da0073e9SAndroid Build Coastguard Worker 4165*da0073e9SAndroid Build Coastguard Worker if args.trace_on_xla: 4166*da0073e9SAndroid Build Coastguard Worker xla_dev = xm.xla_device() 4167*da0073e9SAndroid Build Coastguard Worker model = model.to(device=xla_dev) 4168*da0073e9SAndroid Build Coastguard Worker example_inputs = tree_map_only( 4169*da0073e9SAndroid Build Coastguard Worker torch.Tensor, lambda x: x.to(device=xla_dev), example_inputs 4170*da0073e9SAndroid Build Coastguard Worker ) 4171*da0073e9SAndroid Build Coastguard Worker 4172*da0073e9SAndroid Build Coastguard Worker current_name = name 4173*da0073e9SAndroid Build Coastguard Worker current_device = device 4174*da0073e9SAndroid Build Coastguard Worker current_batch_size = batch_size 4175*da0073e9SAndroid Build Coastguard Worker set_model_name(name) 4176*da0073e9SAndroid Build Coastguard Worker 4177*da0073e9SAndroid Build Coastguard Worker # Look for stuff that looks like batch size, and mark it dynamic. 4178*da0073e9SAndroid Build Coastguard Worker # Better integration would integrate directly with benchmark suite 4179*da0073e9SAndroid Build Coastguard Worker # but cannot conveniently do this 4180*da0073e9SAndroid Build Coastguard Worker # NB: This must be done late enough so that we don't do more 4181*da0073e9SAndroid Build Coastguard Worker # conversions on the inputs 4182*da0073e9SAndroid Build Coastguard Worker # NB: Assumes only the first batch-y like dimension is the batch 4183*da0073e9SAndroid Build Coastguard Worker marked = False 4184*da0073e9SAndroid Build Coastguard Worker 4185*da0073e9SAndroid Build Coastguard Worker def detect_and_mark_batch(t): 4186*da0073e9SAndroid Build Coastguard Worker nonlocal marked 4187*da0073e9SAndroid Build Coastguard Worker for i, s in enumerate(t.size()): 4188*da0073e9SAndroid Build Coastguard Worker if s == batch_size: 4189*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(t, i) 4190*da0073e9SAndroid Build Coastguard Worker marked = True 4191*da0073e9SAndroid Build Coastguard Worker break 4192*da0073e9SAndroid Build Coastguard Worker 4193*da0073e9SAndroid Build Coastguard Worker if ( 4194*da0073e9SAndroid Build Coastguard Worker args.dynamic_batch_only 4195*da0073e9SAndroid Build Coastguard Worker and batch_size > 1 4196*da0073e9SAndroid Build Coastguard Worker and model_name not in CI_SKIP_DYNAMIC_BATCH_ONLY 4197*da0073e9SAndroid Build Coastguard Worker ): 4198*da0073e9SAndroid Build Coastguard Worker tree_map_only(torch.Tensor, detect_and_mark_batch, example_inputs) 4199*da0073e9SAndroid Build Coastguard Worker assert marked, f"nothing in example_inputs had a dim with {batch_size}" 4200*da0073e9SAndroid Build Coastguard Worker 4201*da0073e9SAndroid Build Coastguard Worker if args.log_operator_inputs: 4202*da0073e9SAndroid Build Coastguard Worker log_operator_inputs( 4203*da0073e9SAndroid Build Coastguard Worker model, example_inputs, runner.model_iter_fn, name, args 4204*da0073e9SAndroid Build Coastguard Worker ) 4205*da0073e9SAndroid Build Coastguard Worker continue 4206*da0073e9SAndroid Build Coastguard Worker 4207*da0073e9SAndroid Build Coastguard Worker if args.per_process_memory_fraction != 1: 4208*da0073e9SAndroid Build Coastguard Worker torch.cuda.set_per_process_memory_fraction( 4209*da0073e9SAndroid Build Coastguard Worker args.per_process_memory_fraction 4210*da0073e9SAndroid Build Coastguard Worker ) 4211*da0073e9SAndroid Build Coastguard Worker if model_name in DO_NOT_CAST_INPUTS: 4212*da0073e9SAndroid Build Coastguard Worker model, _ = runner.cast_based_on_args(model, example_inputs) 4213*da0073e9SAndroid Build Coastguard Worker 4214*da0073e9SAndroid Build Coastguard Worker else: 4215*da0073e9SAndroid Build Coastguard Worker model, example_inputs = runner.cast_based_on_args(model, example_inputs) 4216*da0073e9SAndroid Build Coastguard Worker runner.setup_amp(current_device) 4217*da0073e9SAndroid Build Coastguard Worker guard_ctx = contextlib.nullcontext() 4218*da0073e9SAndroid Build Coastguard Worker if name in runner.guard_on_nn_module_models: 4219*da0073e9SAndroid Build Coastguard Worker guard_ctx = torch._dynamo.config.patch(guard_nn_modules=True) 4220*da0073e9SAndroid Build Coastguard Worker 4221*da0073e9SAndroid Build Coastguard Worker with guard_ctx: 4222*da0073e9SAndroid Build Coastguard Worker runner.run_one_model( 4223*da0073e9SAndroid Build Coastguard Worker name, 4224*da0073e9SAndroid Build Coastguard Worker model, 4225*da0073e9SAndroid Build Coastguard Worker example_inputs, 4226*da0073e9SAndroid Build Coastguard Worker optimize_ctx, 4227*da0073e9SAndroid Build Coastguard Worker experiment, 4228*da0073e9SAndroid Build Coastguard Worker explain=args.explain, 4229*da0073e9SAndroid Build Coastguard Worker tag=args.tag, 4230*da0073e9SAndroid Build Coastguard Worker ) 4231*da0073e9SAndroid Build Coastguard Worker if args.generate_aot_autograd_stats: 4232*da0073e9SAndroid Build Coastguard Worker stats_file = output_filename.split(".csv")[0] + "_stats.csv" 4233*da0073e9SAndroid Build Coastguard Worker output_csv( 4234*da0073e9SAndroid Build Coastguard Worker stats_file, 4235*da0073e9SAndroid Build Coastguard Worker ("dev", "name", "batch_size", "total_aot_graphs", "ok_aot_graphs"), 4236*da0073e9SAndroid Build Coastguard Worker [ 4237*da0073e9SAndroid Build Coastguard Worker current_device, 4238*da0073e9SAndroid Build Coastguard Worker current_name, 4239*da0073e9SAndroid Build Coastguard Worker current_batch_size, 4240*da0073e9SAndroid Build Coastguard Worker *Stats.aot_summary(), 4241*da0073e9SAndroid Build Coastguard Worker ], 4242*da0073e9SAndroid Build Coastguard Worker ) 4243*da0073e9SAndroid Build Coastguard Worker else: 4244*da0073e9SAndroid Build Coastguard Worker metrics.purge_old_log_files() 4245*da0073e9SAndroid Build Coastguard Worker if output_filename and os.path.exists(output_filename): 4246*da0073e9SAndroid Build Coastguard Worker os.unlink(output_filename) 4247*da0073e9SAndroid Build Coastguard Worker if original_dir: 4248*da0073e9SAndroid Build Coastguard Worker os.chdir(original_dir) 4249*da0073e9SAndroid Build Coastguard Worker model_names = list(runner.iter_model_names(args)) 4250*da0073e9SAndroid Build Coastguard Worker nmodels = len(model_names) 4251*da0073e9SAndroid Build Coastguard Worker for i, name in enumerate(model_names): 4252*da0073e9SAndroid Build Coastguard Worker current_name = name 4253*da0073e9SAndroid Build Coastguard Worker if args.progress: 4254*da0073e9SAndroid Build Coastguard Worker print(f"Running model {i+1}/{nmodels}", flush=True) 4255*da0073e9SAndroid Build Coastguard Worker 4256*da0073e9SAndroid Build Coastguard Worker try: 4257*da0073e9SAndroid Build Coastguard Worker timeout = args.timeout 4258*da0073e9SAndroid Build Coastguard Worker if should_diff_branch(args): 4259*da0073e9SAndroid Build Coastguard Worker timeout *= 2 4260*da0073e9SAndroid Build Coastguard Worker env = os.environ.copy() 4261*da0073e9SAndroid Build Coastguard Worker if args.ci and name in CI_PRESERVE_COMPILE_DEBUG: 4262*da0073e9SAndroid Build Coastguard Worker env["TORCH_COMPILE_DEBUG"] = "1" 4263*da0073e9SAndroid Build Coastguard Worker subprocess.check_call( 4264*da0073e9SAndroid Build Coastguard Worker [sys.executable] + sys.argv + [f"--only={name}"], 4265*da0073e9SAndroid Build Coastguard Worker timeout=timeout, 4266*da0073e9SAndroid Build Coastguard Worker env=env, 4267*da0073e9SAndroid Build Coastguard Worker ) 4268*da0073e9SAndroid Build Coastguard Worker except subprocess.TimeoutExpired: 4269*da0073e9SAndroid Build Coastguard Worker write_csv_when_exception(args, name, "timeout") 4270*da0073e9SAndroid Build Coastguard Worker except subprocess.CalledProcessError as e: 4271*da0073e9SAndroid Build Coastguard Worker print("Run failed with return code: ", e.returncode, file=sys.stderr) 4272*da0073e9SAndroid Build Coastguard Worker print("Output: ", e.output, file=sys.stderr) 4273*da0073e9SAndroid Build Coastguard Worker print("Error: ", e.stderr, file=sys.stderr) 4274*da0073e9SAndroid Build Coastguard Worker print_summary(output_filename, print_dataframe=args.print_dataframe_summary) 4275*da0073e9SAndroid Build Coastguard Worker 4276*da0073e9SAndroid Build Coastguard Worker 4277*da0073e9SAndroid Build Coastguard Workerdef log_operator_inputs(model, example_inputs, model_iter_fn, name, args): 4278*da0073e9SAndroid Build Coastguard Worker mode = "training" if args.training else "eval" 4279*da0073e9SAndroid Build Coastguard Worker output = os.path.join(os.path.dirname(args.output), f"{name}_{mode}.txt") 4280*da0073e9SAndroid Build Coastguard Worker 4281*da0073e9SAndroid Build Coastguard Worker # TODO - add option for coalescing inputs over multiple runs 4282*da0073e9SAndroid Build Coastguard Worker if os.path.exists(output): 4283*da0073e9SAndroid Build Coastguard Worker print(f"Skipping {name}, {output} already exists") 4284*da0073e9SAndroid Build Coastguard Worker return 4285*da0073e9SAndroid Build Coastguard Worker 4286*da0073e9SAndroid Build Coastguard Worker print(f"Running {name}") 4287*da0073e9SAndroid Build Coastguard Worker try: 4288*da0073e9SAndroid Build Coastguard Worker from .microbenchmarks.operator_inp_utils import OperatorInputsMode 4289*da0073e9SAndroid Build Coastguard Worker except ImportError: 4290*da0073e9SAndroid Build Coastguard Worker from microbenchmarks.operator_inp_utils import OperatorInputsMode 4291*da0073e9SAndroid Build Coastguard Worker 4292*da0073e9SAndroid Build Coastguard Worker operator_mode = OperatorInputsMode() 4293*da0073e9SAndroid Build Coastguard Worker fake_tensor_mode = FakeTensorMode() 4294*da0073e9SAndroid Build Coastguard Worker 4295*da0073e9SAndroid Build Coastguard Worker with torch._subclasses.fake_tensor.FakeCopyMode(fake_tensor_mode): 4296*da0073e9SAndroid Build Coastguard Worker model_fake = copy.deepcopy(model) 4297*da0073e9SAndroid Build Coastguard Worker example_inputs_fake = copy.deepcopy(example_inputs) 4298*da0073e9SAndroid Build Coastguard Worker try: 4299*da0073e9SAndroid Build Coastguard Worker with fake_tensor_mode, operator_mode: 4300*da0073e9SAndroid Build Coastguard Worker model_iter_fn(model_fake, example_inputs_fake, collect_outputs=False) 4301*da0073e9SAndroid Build Coastguard Worker except Exception as e: 4302*da0073e9SAndroid Build Coastguard Worker print(f"{name} failed to run with fake tensors, trying real. Exception: {e}") 4303*da0073e9SAndroid Build Coastguard Worker operator_mode = OperatorInputsMode() 4304*da0073e9SAndroid Build Coastguard Worker try: 4305*da0073e9SAndroid Build Coastguard Worker with operator_mode: 4306*da0073e9SAndroid Build Coastguard Worker model_iter_fn(model, example_inputs, collect_outputs=False) 4307*da0073e9SAndroid Build Coastguard Worker except Exception as e2: 4308*da0073e9SAndroid Build Coastguard Worker print(f"{name} failed to run with real. Exception: {e2}") 4309*da0073e9SAndroid Build Coastguard Worker raise 4310*da0073e9SAndroid Build Coastguard Worker 4311*da0073e9SAndroid Build Coastguard Worker print(f"Writing output to {output}") 4312*da0073e9SAndroid Build Coastguard Worker operator_mode.log_to_file(output) 4313*da0073e9SAndroid Build Coastguard Worker 4314*da0073e9SAndroid Build Coastguard Worker 4315*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 4316*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 4317*da0073e9SAndroid Build Coastguard Worker f"You shouldn't run {sys.argv[0]} directly, instead try timm_model.py, torchbench.py or huggingface.py" 4318*da0073e9SAndroid Build Coastguard Worker ) 4319