xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/common.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3
2*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport abc
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport argparse
7*da0073e9SAndroid Build Coastguard Workerimport collections
8*da0073e9SAndroid Build Coastguard Workerimport contextlib
9*da0073e9SAndroid Build Coastguard Workerimport copy
10*da0073e9SAndroid Build Coastguard Workerimport csv
11*da0073e9SAndroid Build Coastguard Workerimport dataclasses
12*da0073e9SAndroid Build Coastguard Workerimport functools
13*da0073e9SAndroid Build Coastguard Workerimport importlib
14*da0073e9SAndroid Build Coastguard Workerimport itertools
15*da0073e9SAndroid Build Coastguard Workerimport logging
16*da0073e9SAndroid Build Coastguard Workerimport os
17*da0073e9SAndroid Build Coastguard Workerimport pathlib
18*da0073e9SAndroid Build Coastguard Workerimport shutil
19*da0073e9SAndroid Build Coastguard Workerimport signal
20*da0073e9SAndroid Build Coastguard Workerimport subprocess
21*da0073e9SAndroid Build Coastguard Workerimport sys
22*da0073e9SAndroid Build Coastguard Workerimport time
23*da0073e9SAndroid Build Coastguard Workerimport weakref
24*da0073e9SAndroid Build Coastguard Workerfrom contextlib import contextmanager
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Workerfrom typing import (
27*da0073e9SAndroid Build Coastguard Worker    Any,
28*da0073e9SAndroid Build Coastguard Worker    Callable,
29*da0073e9SAndroid Build Coastguard Worker    Generator,
30*da0073e9SAndroid Build Coastguard Worker    List,
31*da0073e9SAndroid Build Coastguard Worker    Mapping,
32*da0073e9SAndroid Build Coastguard Worker    NamedTuple,
33*da0073e9SAndroid Build Coastguard Worker    Optional,
34*da0073e9SAndroid Build Coastguard Worker    Sequence,
35*da0073e9SAndroid Build Coastguard Worker    Tuple,
36*da0073e9SAndroid Build Coastguard Worker    Type,
37*da0073e9SAndroid Build Coastguard Worker    TYPE_CHECKING,
38*da0073e9SAndroid Build Coastguard Worker)
39*da0073e9SAndroid Build Coastguard Workerfrom typing_extensions import Self
40*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import MagicMock
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Workerimport numpy as np
43*da0073e9SAndroid Build Coastguard Workerimport pandas as pd
44*da0073e9SAndroid Build Coastguard Workerimport psutil
45*da0073e9SAndroid Build Coastguard Workerfrom scipy.stats import gmean, ttest_ind
46*da0073e9SAndroid Build Coastguard Workerfrom tqdm.auto import tqdm, trange
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Workerimport torch
49*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo
50*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.utils
51*da0073e9SAndroid Build Coastguard Workerimport torch._export
52*da0073e9SAndroid Build Coastguard Workerimport torch.distributed
53*da0073e9SAndroid Build Coastguard Workerimport torch.multiprocessing as mp
54*da0073e9SAndroid Build Coastguard Workerfrom torch._C import _has_cuda as HAS_CUDA, _has_xpu as HAS_XPU
55*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.profiler import fx_insert_profiling, Profiler
56*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import (
57*da0073e9SAndroid Build Coastguard Worker    dummy_fx_compile,
58*da0073e9SAndroid Build Coastguard Worker    format_speedup,
59*da0073e9SAndroid Build Coastguard Worker    reset_rng_state,
60*da0073e9SAndroid Build Coastguard Worker    same,
61*da0073e9SAndroid Build Coastguard Worker)
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Workertry:
64*da0073e9SAndroid Build Coastguard Worker    from torch._dynamo.utils import (
65*da0073e9SAndroid Build Coastguard Worker        clone_inputs,
66*da0073e9SAndroid Build Coastguard Worker        graph_break_reasons,
67*da0073e9SAndroid Build Coastguard Worker        maybe_enable_compiled_autograd,
68*da0073e9SAndroid Build Coastguard Worker    )
69*da0073e9SAndroid Build Coastguard Worker    from torch._inductor.utils import fresh_inductor_cache
70*da0073e9SAndroid Build Coastguard Workerexcept ImportError:
71*da0073e9SAndroid Build Coastguard Worker    from _dynamo.utils import (
72*da0073e9SAndroid Build Coastguard Worker        clone_inputs,
73*da0073e9SAndroid Build Coastguard Worker        graph_break_reasons,
74*da0073e9SAndroid Build Coastguard Worker        maybe_enable_compiled_autograd,
75*da0073e9SAndroid Build Coastguard Worker    )
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Workerimport torch._functorch.config
78*da0073e9SAndroid Build Coastguard Workerfrom torch._functorch.aot_autograd import set_model_name
79*da0073e9SAndroid Build Coastguard Workerfrom torch._inductor import config as inductor_config, metrics
80*da0073e9SAndroid Build Coastguard Workerfrom torch._subclasses.fake_tensor import FakeTensorMode
81*da0073e9SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree
82*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._pytree import tree_map, tree_map_only
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Workertry:
85*da0073e9SAndroid Build Coastguard Worker    import torch_xla
86*da0073e9SAndroid Build Coastguard Worker    import torch_xla.core.xla_model as xm
87*da0073e9SAndroid Build Coastguard Worker
    # This is to workaround the backward issue https://github.com/pytorch/xla/issues/4174
89*da0073e9SAndroid Build Coastguard Worker    torch_xla._XLAC._init_computation_client()
90*da0073e9SAndroid Build Coastguard Workerexcept ImportError:
91*da0073e9SAndroid Build Coastguard Worker    # ignore the error if torch_xla is not installed
92*da0073e9SAndroid Build Coastguard Worker    pass
93*da0073e9SAndroid Build Coastguard Worker
94*da0073e9SAndroid Build Coastguard Worker
95*da0073e9SAndroid Build Coastguard Workerif TYPE_CHECKING:
96*da0073e9SAndroid Build Coastguard Worker    from torch.onnx._internal.fx import diagnostics
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker
log = logging.getLogger(__name__)

# We are primarily interested in TF32
torch.backends.cuda.matmul.allow_tf32 = True

# Suppress torch.profiler spam
os.environ["KINETO_LOG_LEVEL"] = "5"

# Run-wide mutable state, updated as each benchmark model is processed.
current_name = ""
current_device = ""
current_onnx_compiler = ""
current_batch_size = None
output_filename = None
disable_output = False

# Retry budget used when fetching model artifacts.
MAX_DOWNLOAD_ATTEMPTS = 5
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker
class CI(NamedTuple):
    """Configuration tuple describing one CI benchmark job."""

    backend: str  # aot_eager or inductor
    training: bool
    dynamic: bool = False
    device: str = "cuda"
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker
# Models for which the optimizer step is skipped in CI (with the reason noted).
CI_SKIP_OPTIMIZER = {
    # TIMM
    "convmixer_768_32",  # accuracy
    "hrnet_w18",  # Stack issue in fx
    # HF
    "pnasnet5large",  # Stack issue in fx
    "MobileBertForMaskedLM",  # Stack issue in fx
    "MobileBertForQuestionAnswering",  # Stack issue in fx
    "PegasusForConditionalGeneration",  # OOM
}
134*da0073e9SAndroid Build Coastguard Worker
# Models that only support a dynamic *batch* dimension, not fully dynamic shapes.
CI_SKIP_DYNAMIC_BATCH_ONLY = {
    "sam",
    # See https://github.com/mindee/doctr/blob/f2114758d529ed8d3d0030581638f0520b6b98d8/doctr/models/detection/core.py#L89
    # It iterates over the batch, which is dynamic, and dynamo chokes
    # We should be able to graphbreak there.
    "doctr_det_predictor",
    "dlrm",
    "pyhpc_isoneutral_mixing",
    "pyhpc_equation_of_state",
    "pyhpc_turbulent_kinetic_energy",
    "detectron2_fcos_r_50_fpn",
    "hf_T5_generate",
}
148*da0073e9SAndroid Build Coastguard Worker
# These models currently fail accuracy with eager Adam optimizer
# so we use SGD when running the full benchmarks
# https://github.com/pytorch/pytorch/issues/115966
BENCHMARK_USE_SGD = {
    # TorchBench
    "BERT_pytorch",
    "LearningToPaint",
    "alexnet",
    "dcgan",
    "demucs",
    "densenet121",
    "dlrm",
    "fastNLP_Bert",
    "mobilenet_v2",
    "phlippe_densenet",
    "phlippe_resnet",
    "pytorch_stargan",
    "resnet18",
    "shufflenet_v2_x1_0",
    "speech_transformer",
    "squeezenet1_1",
    "stable_diffusion_text_encoder",
    "timm_efficientdet",
    "timm_nfnet",
    "timm_regnet",
    "timm_vision_transformer",
    "timm_vovnet",
    "vgg16",
    "hf_T5",  # Fails dynamic https://github.com/pytorch/pytorch/issues/115968
    # HF
    "AlbertForMaskedLM",
    "BartForCausalLM",
    "BartForConditionalGeneration",
    "BlenderbotSmallForCausalLM",
    "BlenderbotSmallForConditionalGeneration",
    "DebertaV2ForQuestionAnswering",  # eager OOM
    "ElectraForCausalLM",
    "M2M100ForConditionalGeneration",
    "MBartForCausalLM",
    "MBartForConditionalGeneration",
    "OPTForCausalLM",
    "PLBartForCausalLM",
    "PLBartForConditionalGeneration",
    "PegasusForCausalLM",
    "Speech2Text2ForCausalLM",
    "TrOCRForCausalLM",
    "XGLMForCausalLM",
    # TIMM
    "adv_inception_v3",
    "botnet26t_256",
    "cait_m36_384",  # OOM
    "coat_lite_mini",
    "convit_base",
    "dpn107",
    "fbnetv3_b",
    "gernet_l",
    "lcnet_050",
    "mixnet_l",
    "res2net101_26w_4s",
    "res2net50_14w_8s",
    "res2next50",
    "resnest101e",
    "sebotnet33ts_256",
    "swsl_resnext101_32x16d",
    "tf_efficientnet_b0",
    "ghostnet_100",
    "gmixer_24_224",
    "tinynet_a",
}
218*da0073e9SAndroid Build Coastguard Worker
# These models OOM in CI
# due to the extra memory of Adam optimizer states,
# so we fall back to SGD in CI
CI_USE_SGD = {
    "torchrec_dlrm",
    "demucs",
    "detectron2_fasterrcnn_r_101_c4",
    "detectron2_fasterrcnn_r_101_dc5",
    "detectron2_fasterrcnn_r_101_fpn",
    "detectron2_fasterrcnn_r_50_c4",
    "detectron2_fasterrcnn_r_50_dc5",
    "detectron2_fasterrcnn_r_50_fpn",
    "detectron2_maskrcnn_r_101_c4",
    "detectron2_maskrcnn_r_101_fpn",
    "detectron2_maskrcnn_r_50_c4",
    "detectron2_maskrcnn_r_50_fpn",
    "hf_T5_base",
    "hf_clip",
    "llama_v2_7b_16h",
    "mobilenet_v2_quantized_qat",
    # BUGFIX: these two names were previously fused into the single string
    # "phi_1_5 resnet50_quantized_qat" (missing separator), so neither model
    # ever matched this skip set.
    "phi_1_5",
    "resnet50_quantized_qat",
    "BlenderbotForCausalLM",
    "cait_m36_384",
    "DALLE2_pytorch",
    "moco",
    "timm_efficientdet",
    "ghostnet_100",
    "regnety_002",
    "poolformer_m36",
    "inception_v3",
    "tinynet_a",
    "selecsls42b",
    "mobilevit_s",
    "pytorch_CycleGAN_and_pix2pix",
    "vision_maskrcnn",
    "resmlp_12_224",
    "dlrm",
    "resnet50",
    "dm_nfnet_f0",
    "pit_b_224",
    "tf_mixnet_l",
}
261*da0073e9SAndroid Build Coastguard Worker
262*da0073e9SAndroid Build Coastguard Worker
# Models whose example inputs must be left in their original dtype.
DO_NOT_CAST_INPUTS = {"stable_diffusion"}


# Maps a benchmark model name to a list of status codes. For any listed entry, we'll
# capture TORCH_COMPILE_DEBUG logs in CI runs and preserve them (i.e., for upload) if
# the result status matches one listed.
CI_PRESERVE_COMPILE_DEBUG = {
    # For example:
    # "mnasnet1_0": ["fail_accuracy"],
}
273*da0073e9SAndroid Build Coastguard Worker
274*da0073e9SAndroid Build Coastguard Worker
def model_specified_by_path(path_and_class_str):
    """Return True when the --only argument looks like a "path:...,class:..." spec."""
    return path_and_class_str.count(":") > 0
277*da0073e9SAndroid Build Coastguard Worker
278*da0073e9SAndroid Build Coastguard Worker
def load_model_from_path(path_and_class_str):
    """Load a user-supplied model from a "path:<abs path>,class:<name>" spec.

    The spec is a comma-separated list of key:value pairs; "path" and "class"
    are both required. The module at *path* is imported, *class* is looked up
    in it, instantiated with no arguments, and its ``get_example_inputs()`` is
    called to produce the example inputs.

    Returns:
        (model, example_inputs) tuple.

    Raises:
        RuntimeError: if "path"/"class" are missing or the path is not absolute.
    """
    # Needed for spec_from_file_location; `import importlib` alone does not
    # guarantee the `importlib.util` submodule is bound.
    import importlib.util

    configs = {}
    for kvstr in path_and_class_str.split(","):
        # Split only on the first colon so values may themselves contain ":".
        k, v = kvstr.split(":", 1)
        configs[k] = v

    for name in ["path", "class"]:
        if name not in configs:
            raise RuntimeError(
                "Invalid --only arguments. Check help message for the correct format"
            )

    path = configs["path"]
    class_name = configs["class"]

    if path[:1] != "/":
        raise RuntimeError(
            "Use absolute path since dynamo may change the current working directory which makes using relative path tricky"
        )

    spec = importlib.util.spec_from_file_location("module_name", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    model_class = getattr(module, class_name)
    assert issubclass(model_class, torch.nn.Module)
    model = model_class()
    assert hasattr(model, "get_example_inputs")
    inputs = model.get_example_inputs()
    return model, inputs
309*da0073e9SAndroid Build Coastguard Worker
310*da0073e9SAndroid Build Coastguard Worker
def output_csv(filename, headers, row):
    """Append *row* to the CSV at *filename*, rewriting the whole file.

    If the file already exists, its header is widened to *headers* when the
    new header is longer (prior failed runs may have written a short header);
    otherwise the existing header wins. Every row is right-padded with "0"
    to the header width. Floats are formatted with six decimal places.
    No-op when the module-level ``disable_output`` flag is set.
    """
    if disable_output:
        return
    if os.path.exists(filename):
        with open(filename) as fd:
            lines = list(csv.reader(fd)) or [[]]
        if headers and len(headers) > len(lines[0]):
            # if prior results failed the header might not be filled in yet
            lines[0] = headers
        else:
            headers = lines[0]
    else:
        lines = [headers]
    formatted = [f"{item:.6f}" if isinstance(item, float) else item for item in row]
    lines.append(formatted)
    with open(filename, "w") as fd:
        writer = csv.writer(fd, lineterminator="\n")
        for entry in lines:
            padding = ["0"] * (len(headers) - len(entry))
            writer.writerow(list(entry) + padding)
330*da0073e9SAndroid Build Coastguard Worker
331*da0073e9SAndroid Build Coastguard Worker
def nothing(f):
    """Identity decorator: return *f* unchanged."""
    return f
334*da0073e9SAndroid Build Coastguard Worker
335*da0073e9SAndroid Build Coastguard Worker
336*da0073e9SAndroid Build Coastguard Worker@functools.lru_cache(None)
337*da0073e9SAndroid Build Coastguard Workerdef patch_torch_manual_seed():
338*da0073e9SAndroid Build Coastguard Worker    """Make torch manual seed deterministic. Helps with accuracy testing."""
339*da0073e9SAndroid Build Coastguard Worker
340*da0073e9SAndroid Build Coastguard Worker    def deterministic_torch_manual_seed(*args, **kwargs):
341*da0073e9SAndroid Build Coastguard Worker        from torch._C import default_generator
342*da0073e9SAndroid Build Coastguard Worker
343*da0073e9SAndroid Build Coastguard Worker        seed = 1337
344*da0073e9SAndroid Build Coastguard Worker        if HAS_CUDA:
345*da0073e9SAndroid Build Coastguard Worker            import torch.cuda
346*da0073e9SAndroid Build Coastguard Worker
347*da0073e9SAndroid Build Coastguard Worker            if not torch.cuda._is_in_bad_fork():
348*da0073e9SAndroid Build Coastguard Worker                torch.cuda.manual_seed_all(seed)
349*da0073e9SAndroid Build Coastguard Worker        if HAS_XPU:
350*da0073e9SAndroid Build Coastguard Worker            import torch.xpu
351*da0073e9SAndroid Build Coastguard Worker
352*da0073e9SAndroid Build Coastguard Worker            if not torch.xpu._is_in_bad_fork():
353*da0073e9SAndroid Build Coastguard Worker                torch.xpu.manual_seed_all(seed)
354*da0073e9SAndroid Build Coastguard Worker        return default_generator.manual_seed(seed)
355*da0073e9SAndroid Build Coastguard Worker
356*da0073e9SAndroid Build Coastguard Worker    torch.manual_seed = deterministic_torch_manual_seed
357*da0073e9SAndroid Build Coastguard Worker
358*da0073e9SAndroid Build Coastguard Worker
def empty_gpu_cache(device):
    """
    Explicitly empty gpu cache to avoid OOM in subsequent run.

    Only "cuda" and "xpu" are supported; any other device just logs a
    warning and returns.
    """
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "xpu":
        torch.xpu.empty_cache()
    else:
        log.warning(
            "Trying to call the empty_gpu_cache for device: %s, which is not in list [cuda, xpu]",
            device,
        )
375*da0073e9SAndroid Build Coastguard Worker
376*da0073e9SAndroid Build Coastguard Worker
def synchronize():
    """Device synchronization hook; the default implementation is a no-op.

    NOTE(review): presumably rebound to a device-specific sync (e.g.
    torch.cuda.synchronize) by the harness before timing — confirm.
    """
379*da0073e9SAndroid Build Coastguard Worker
380*da0073e9SAndroid Build Coastguard Worker
def summarize_graph_break(filename):
    """
    Sorts and de-dupes the graphs breaks on the reason string. Note that this
    function is just a best effort to reduce the logging information. We could
    miss some graph breaks because of de-duping. We can further refine this
    function as need arises.

    Reads "<filename minus .csv>_graph_breaks.csv" (if present) and writes a
    "..._deduped.csv" next to it.
    """
    # BUGFIX: previously used rstrip('.csv'), which strips a *character set*
    # (any trailing '.', 'c', 's', 'v') rather than the suffix — e.g.
    # "metrics.csv" became "metri". removesuffix strips exactly ".csv".
    log_file = f"{filename.removesuffix('.csv')}_graph_breaks.csv"
    if os.path.exists(log_file):
        df = pd.read_csv(log_file)
        df = df.sort_values("reason").drop_duplicates(subset="reason")

        # Specialize for multi tensor sgd as reason is not identical
        multi_tensor_sgd_row = df.loc[df["reason"].str.contains("_multi_tensor_sgd")]
        if len(multi_tensor_sgd_row):
            df = df[
                ~df["reason"].str.contains("_multi_tensor_sgd")
            ]  # Drop all sgd rows
            df = pd.concat(
                [df, pd.DataFrame([multi_tensor_sgd_row.iloc[0]])], axis=0
            )  # Add back a single row
        df.to_csv(f"{log_file.removesuffix('.csv')}_deduped.csv", index=False)
403*da0073e9SAndroid Build Coastguard Worker
404*da0073e9SAndroid Build Coastguard Worker
def print_summary(filename, print_dataframe=False):
    """Print per-tag summary tables for *filename* and de-dupe its graph breaks.

    Silently returns when *filename* is falsy or does not exist.
    """
    if not filename or not os.path.exists(filename):
        return
    data = pd.read_csv(filename)
    if "tag" in data.columns:
        for tag in data.tag.unique():
            # "0.0000" is the tag produced by failed runs; skip it.
            if tag == "0.0000":
                continue
            print(f"\nSummary for tag={tag}:")
            print_summary_table(data[data.tag == tag], print_dataframe=print_dataframe)
    else:
        print_summary_table(data, print_dataframe=print_dataframe)
    summarize_graph_break(filename)
418*da0073e9SAndroid Build Coastguard Worker
419*da0073e9SAndroid Build Coastguard Worker
def print_summary_table(data, print_dataframe=False):
    """Print a one-line aggregate per column of the results DataFrame.

    Columns are dispatched by name: identity columns are skipped, percentage /
    count / latency / ratio / accuracy columns get dedicated formats, and any
    other column is treated as a speedup (gmean + mean). Set *print_dataframe*
    to also dump the raw DataFrame.
    """
    if print_dataframe:
        pd.options.display.max_rows = 1000
        pd.options.display.max_columns = 1000
        pd.options.display.width = 2000
        print(data)
    width = max(map(len, data.columns))
    for col in data.columns:
        try:
            if col in ("dev", "name", "batch_size", "tag"):
                continue
            elif col in ("pct_ops", "pct_time"):
                print(col.ljust(width), f"{data[col].mean():.3%}")
            elif col in ("graphs", "graph_calls", "captured_ops", "total_ops"):
                print(col.ljust(width), f"{data[col].mean():.3f}")
            # BUGFIX: these three used `col in ("...")` — a parenthesized
            # string, not a tuple — which is a *substring* test, so e.g. a
            # column named "latency" would have matched "compilation_latency".
            elif col == "compilation_latency":
                print(col.ljust(width), f"mean={data[col].mean():.3f} seconds")
            elif col == "compression_ratio":
                print(col.ljust(width), f"mean={data[col].mean():.3f}x")
            elif col == "accuracy":
                pass_rate = (data[col] == "pass").mean()
                print(col.ljust(width), f"pass_rate={100*pass_rate:.2f}%")
            else:
                cdata = data[col]
                print(
                    col.ljust(width),
                    f"gmean={gmean(cdata):.2f}x mean={cdata.mean():.3f}x",
                )
        except Exception as e:
            # Best effort: a malformed column shouldn't abort the summary,
            # but don't swallow the failure silently either.
            log.warning("failed to summarize column %s: %s", col, e)
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Worker
452*da0073e9SAndroid Build Coastguard Workerdef tensor_is_on_xla(tensors):
453*da0073e9SAndroid Build Coastguard Worker    def visit(x: torch.Tensor):
454*da0073e9SAndroid Build Coastguard Worker        nonlocal result
455*da0073e9SAndroid Build Coastguard Worker        if x.device.type == "xla":
456*da0073e9SAndroid Build Coastguard Worker            result = True
457*da0073e9SAndroid Build Coastguard Worker
458*da0073e9SAndroid Build Coastguard Worker    result = False
459*da0073e9SAndroid Build Coastguard Worker    tree_map_only(torch.Tensor, visit, tensors)
460*da0073e9SAndroid Build Coastguard Worker    return result
461*da0073e9SAndroid Build Coastguard Worker
462*da0073e9SAndroid Build Coastguard Worker
463*da0073e9SAndroid Build Coastguard Workerdef timed(
464*da0073e9SAndroid Build Coastguard Worker    model,
465*da0073e9SAndroid Build Coastguard Worker    model_iter_fn,
466*da0073e9SAndroid Build Coastguard Worker    example_inputs,
467*da0073e9SAndroid Build Coastguard Worker    times=1,
468*da0073e9SAndroid Build Coastguard Worker    return_result=False,
469*da0073e9SAndroid Build Coastguard Worker    collect_outputs=False,
470*da0073e9SAndroid Build Coastguard Worker):
471*da0073e9SAndroid Build Coastguard Worker    use_xla = tensor_is_on_xla(example_inputs)
472*da0073e9SAndroid Build Coastguard Worker    synchronize()
473*da0073e9SAndroid Build Coastguard Worker
474*da0073e9SAndroid Build Coastguard Worker    if use_xla:
475*da0073e9SAndroid Build Coastguard Worker        xm.mark_step()
476*da0073e9SAndroid Build Coastguard Worker        xm.wait_device_ops()
477*da0073e9SAndroid Build Coastguard Worker
478*da0073e9SAndroid Build Coastguard Worker    time_total = 0
479*da0073e9SAndroid Build Coastguard Worker    # Dont collect outputs to correctly measure timing
480*da0073e9SAndroid Build Coastguard Worker    for _ in range(times):
481*da0073e9SAndroid Build Coastguard Worker        # Put this call inside the loop to reset the seed for each iteration.
482*da0073e9SAndroid Build Coastguard Worker        # Don't include reset_rng_state() to correctly measure timing
483*da0073e9SAndroid Build Coastguard Worker        reset_rng_state(use_xla)
484*da0073e9SAndroid Build Coastguard Worker        t_iter_begin = time.perf_counter()
485*da0073e9SAndroid Build Coastguard Worker        result = model_iter_fn(model, example_inputs, collect_outputs=collect_outputs)
486*da0073e9SAndroid Build Coastguard Worker
487*da0073e9SAndroid Build Coastguard Worker        # instead of calling sync on result_list, we should call mark_step.
488*da0073e9SAndroid Build Coastguard Worker        # In training case, result_list may be empty, but we want to
489*da0073e9SAndroid Build Coastguard Worker        # send all the pending graphs for compilation.
490*da0073e9SAndroid Build Coastguard Worker        if use_xla:
491*da0073e9SAndroid Build Coastguard Worker            # For the model running on regular torchxla (baseline), we need the
492*da0073e9SAndroid Build Coastguard Worker            # mark step to send the accumulated graph for compilation.
493*da0073e9SAndroid Build Coastguard Worker            #
494*da0073e9SAndroid Build Coastguard Worker            # For the model running with dynamo/torchxla bridge, in training case,
495*da0073e9SAndroid Build Coastguard Worker            # we need the mark step to send the optimizer graph out for
496*da0073e9SAndroid Build Coastguard Worker            # compilation.
497*da0073e9SAndroid Build Coastguard Worker            xm.mark_step()
498*da0073e9SAndroid Build Coastguard Worker        t_iter_end = time.perf_counter()
499*da0073e9SAndroid Build Coastguard Worker        time_total += t_iter_end - t_iter_begin
500*da0073e9SAndroid Build Coastguard Worker
501*da0073e9SAndroid Build Coastguard Worker    t_0 = time.perf_counter()
502*da0073e9SAndroid Build Coastguard Worker    if use_xla:
503*da0073e9SAndroid Build Coastguard Worker        xm.wait_device_ops()
504*da0073e9SAndroid Build Coastguard Worker    synchronize()
505*da0073e9SAndroid Build Coastguard Worker    t_1 = time.perf_counter()
506*da0073e9SAndroid Build Coastguard Worker    time_total += t_1 - t_0
507*da0073e9SAndroid Build Coastguard Worker    return (time_total, result) if return_result else time_total
508*da0073e9SAndroid Build Coastguard Worker
509*da0073e9SAndroid Build Coastguard Worker
510*da0073e9SAndroid Build Coastguard Workerdef _normalize_bench_inputs(example_inputs) -> Tuple[Tuple[Any], Mapping[str, Any]]:
511*da0073e9SAndroid Build Coastguard Worker    # NOTE(bowbao): For huggingface benchmark, example_inputs are formatted as dictionary,
512*da0073e9SAndroid Build Coastguard Worker    # and consumed like `model(**example_inputs)`.
513*da0073e9SAndroid Build Coastguard Worker    # For other benchmarks, example_inputs are formatted as tuple and consumed
514*da0073e9SAndroid Build Coastguard Worker    # like `model(*example_inputs)`.
515*da0073e9SAndroid Build Coastguard Worker    if isinstance(example_inputs, dict):
516*da0073e9SAndroid Build Coastguard Worker        return (), example_inputs
517*da0073e9SAndroid Build Coastguard Worker    else:
518*da0073e9SAndroid Build Coastguard Worker        return tuple(example_inputs), {}
519*da0073e9SAndroid Build Coastguard Worker
520*da0073e9SAndroid Build Coastguard Worker
521*da0073e9SAndroid Build Coastguard Workerdef _register_dataclass_output_as_pytree(example_outputs) -> None:
522*da0073e9SAndroid Build Coastguard Worker    # NOTE(angelayi): For huggingface benchmark, some example outputs are
523*da0073e9SAndroid Build Coastguard Worker    # formatted as a dataclass which pytree cannot consume. So we want
524*da0073e9SAndroid Build Coastguard Worker    # to register the pytree implementation here
525*da0073e9SAndroid Build Coastguard Worker    example_outputs_flat = pytree.tree_leaves(example_outputs)
526*da0073e9SAndroid Build Coastguard Worker    output_dataclass_types = [
527*da0073e9SAndroid Build Coastguard Worker        type(out) for out in example_outputs_flat if dataclasses.is_dataclass(type(out))
528*da0073e9SAndroid Build Coastguard Worker    ]
529*da0073e9SAndroid Build Coastguard Worker    for output_type in output_dataclass_types:
530*da0073e9SAndroid Build Coastguard Worker        from torch._export.utils import register_dataclass_as_pytree_node
531*da0073e9SAndroid Build Coastguard Worker
532*da0073e9SAndroid Build Coastguard Worker        register_dataclass_as_pytree_node(
533*da0073e9SAndroid Build Coastguard Worker            output_type,
534*da0073e9SAndroid Build Coastguard Worker            serialized_type_name=f"{output_type.__module__}.{output_type.__name__}",
535*da0073e9SAndroid Build Coastguard Worker        )
536*da0073e9SAndroid Build Coastguard Worker
537*da0073e9SAndroid Build Coastguard Worker
538*da0073e9SAndroid Build Coastguard Workerclass Stats:
539*da0073e9SAndroid Build Coastguard Worker    totals = collections.defaultdict(collections.Counter)
540*da0073e9SAndroid Build Coastguard Worker
541*da0073e9SAndroid Build Coastguard Worker    @classmethod
542*da0073e9SAndroid Build Coastguard Worker    def reset_counters(cls):
543*da0073e9SAndroid Build Coastguard Worker        for k, v in torch._dynamo.utils.counters.items():
544*da0073e9SAndroid Build Coastguard Worker            cls.totals[k].update(v)
545*da0073e9SAndroid Build Coastguard Worker        ok = torch._dynamo.utils.counters["frames"]["ok"]
546*da0073e9SAndroid Build Coastguard Worker        total = torch._dynamo.utils.counters["frames"]["total"]
547*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.utils.counters.clear()
548*da0073e9SAndroid Build Coastguard Worker        return ok, total
549*da0073e9SAndroid Build Coastguard Worker
550*da0073e9SAndroid Build Coastguard Worker    @classmethod
551*da0073e9SAndroid Build Coastguard Worker    def print_summary(cls):
552*da0073e9SAndroid Build Coastguard Worker        for k, v in sorted(cls.totals.items()):
553*da0073e9SAndroid Build Coastguard Worker            lines = "\n  ".join(map(str, v.most_common(50)))
554*da0073e9SAndroid Build Coastguard Worker            print(f"STATS {k}\n  {lines}")
555*da0073e9SAndroid Build Coastguard Worker
556*da0073e9SAndroid Build Coastguard Worker    @classmethod
557*da0073e9SAndroid Build Coastguard Worker    def aot_summary(cls):
558*da0073e9SAndroid Build Coastguard Worker        return [cls.totals["aot_autograd"]["total"], cls.totals["aot_autograd"]["ok"]]
559*da0073e9SAndroid Build Coastguard Worker
560*da0073e9SAndroid Build Coastguard Worker
561*da0073e9SAndroid Build Coastguard Workerdef coverage_experiment(args, model_iter_fn, model, example_inputs):
562*da0073e9SAndroid Build Coastguard Worker    """
563*da0073e9SAndroid Build Coastguard Worker    Test operator/model coverage of TorchDynamo and record statistics
564*da0073e9SAndroid Build Coastguard Worker    taken from a profiler.  This target is mainly intended to check
565*da0073e9SAndroid Build Coastguard Worker    correctness.
566*da0073e9SAndroid Build Coastguard Worker
567*da0073e9SAndroid Build Coastguard Worker    Writes to ./coverage.csv
568*da0073e9SAndroid Build Coastguard Worker    """
569*da0073e9SAndroid Build Coastguard Worker    profiler = Profiler()
570*da0073e9SAndroid Build Coastguard Worker    frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
571*da0073e9SAndroid Build Coastguard Worker    with profiler.prof:
572*da0073e9SAndroid Build Coastguard Worker        frozen_model_iter_fn(model, example_inputs)
573*da0073e9SAndroid Build Coastguard Worker    coverage_result = profiler.results()
574*da0073e9SAndroid Build Coastguard Worker    output_csv(
575*da0073e9SAndroid Build Coastguard Worker        output_filename,
576*da0073e9SAndroid Build Coastguard Worker        (
577*da0073e9SAndroid Build Coastguard Worker            "dev",
578*da0073e9SAndroid Build Coastguard Worker            "name",
579*da0073e9SAndroid Build Coastguard Worker            "batch_size",
580*da0073e9SAndroid Build Coastguard Worker            "graphs",
581*da0073e9SAndroid Build Coastguard Worker            "graph_calls",
582*da0073e9SAndroid Build Coastguard Worker            "captured_ops",
583*da0073e9SAndroid Build Coastguard Worker            "total_ops",
584*da0073e9SAndroid Build Coastguard Worker            "pct_ops",
585*da0073e9SAndroid Build Coastguard Worker            "pct_time",
586*da0073e9SAndroid Build Coastguard Worker        ),
587*da0073e9SAndroid Build Coastguard Worker        [
588*da0073e9SAndroid Build Coastguard Worker            current_device,
589*da0073e9SAndroid Build Coastguard Worker            current_name,
590*da0073e9SAndroid Build Coastguard Worker            current_batch_size,
591*da0073e9SAndroid Build Coastguard Worker        ]
592*da0073e9SAndroid Build Coastguard Worker        + coverage_result.tocsv(),
593*da0073e9SAndroid Build Coastguard Worker    )
594*da0073e9SAndroid Build Coastguard Worker    return coverage_result
595*da0073e9SAndroid Build Coastguard Worker
596*da0073e9SAndroid Build Coastguard Worker
597*da0073e9SAndroid Build Coastguard Workerdef speedup_experiment_fx2trt(args, model_iter_fn, model, example_inputs):
598*da0073e9SAndroid Build Coastguard Worker    """
599*da0073e9SAndroid Build Coastguard Worker    Measure speedups over eager using the trt inference backend. TRT backend is based fx graph
600*da0073e9SAndroid Build Coastguard Worker    generated by torch._dynamo.
601*da0073e9SAndroid Build Coastguard Worker    Writes to ./speedups_fx2trt.csv
602*da0073e9SAndroid Build Coastguard Worker    """
603*da0073e9SAndroid Build Coastguard Worker    return speedup_experiment(args, model_iter_fn, model, example_inputs)
604*da0073e9SAndroid Build Coastguard Worker
605*da0073e9SAndroid Build Coastguard Worker
606*da0073e9SAndroid Build Coastguard Workerdef recompile_profiler_experiment(args, model_iter_fn, model, example_inputs):
607*da0073e9SAndroid Build Coastguard Worker    prof = torch._dynamo.utils.CompilerProfiler()
608*da0073e9SAndroid Build Coastguard Worker    opt_model_iter_fn = torch._dynamo.optimize(prof, nopython=args.nopython)(
609*da0073e9SAndroid Build Coastguard Worker        model_iter_fn
610*da0073e9SAndroid Build Coastguard Worker    )
611*da0073e9SAndroid Build Coastguard Worker    opt_model_iter_fn(model, example_inputs)
612*da0073e9SAndroid Build Coastguard Worker    output_csv(
613*da0073e9SAndroid Build Coastguard Worker        output_filename, ["model", "profiler report"], [current_name, prof.report()]
614*da0073e9SAndroid Build Coastguard Worker    )
615*da0073e9SAndroid Build Coastguard Worker    met = prof.get_metrics()
616*da0073e9SAndroid Build Coastguard Worker    guard_failures = len(met["guard_failures"])
617*da0073e9SAndroid Build Coastguard Worker    return [guard_failures]
618*da0073e9SAndroid Build Coastguard Worker
619*da0073e9SAndroid Build Coastguard Worker
620*da0073e9SAndroid Build Coastguard Workerdef randomize_input(inputs):
621*da0073e9SAndroid Build Coastguard Worker    if isinstance(inputs, (list, tuple)):
622*da0073e9SAndroid Build Coastguard Worker        return type(inputs)([randomize_input(x) for x in inputs])
623*da0073e9SAndroid Build Coastguard Worker    elif isinstance(inputs, torch.Tensor):
624*da0073e9SAndroid Build Coastguard Worker        if inputs.dtype in (torch.float32, torch.float64):
625*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.utils.counters["randomize_input"]["times"] += 1
626*da0073e9SAndroid Build Coastguard Worker            return torch.randn_like(inputs)
627*da0073e9SAndroid Build Coastguard Worker        elif inputs.dtype == torch.int64:
628*da0073e9SAndroid Build Coastguard Worker            # Note: we can not simply tune integer tensors as follows
629*da0073e9SAndroid Build Coastguard Worker            #   `return torch.randint_like(inputs, high=inputs.max().item())`
630*da0073e9SAndroid Build Coastguard Worker            # This may break some invariants between tensors.
631*da0073e9SAndroid Build Coastguard Worker            # E.g. in embedding lookup case, one tensor is the length
632*da0073e9SAndroid Build Coastguard Worker            # and another is an indices tensor.
633*da0073e9SAndroid Build Coastguard Worker            return inputs
634*da0073e9SAndroid Build Coastguard Worker        else:
635*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
636*da0073e9SAndroid Build Coastguard Worker                f"randomize_input need support tensor of type {inputs.dtype}"
637*da0073e9SAndroid Build Coastguard Worker            )
638*da0073e9SAndroid Build Coastguard Worker    else:
639*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(
640*da0073e9SAndroid Build Coastguard Worker            f"randomize_input can not handle input of type {type(inputs)}"
641*da0073e9SAndroid Build Coastguard Worker        )
642*da0073e9SAndroid Build Coastguard Worker
643*da0073e9SAndroid Build Coastguard Worker
644*da0073e9SAndroid Build Coastguard Workerdef maybe_mark_step(args):
645*da0073e9SAndroid Build Coastguard Worker    if args.trace_on_xla:
646*da0073e9SAndroid Build Coastguard Worker        xm.mark_step()
647*da0073e9SAndroid Build Coastguard Worker
648*da0073e9SAndroid Build Coastguard Worker
649*da0073e9SAndroid Build Coastguard Workerdef speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
650*da0073e9SAndroid Build Coastguard Worker    """
651*da0073e9SAndroid Build Coastguard Worker    Measure speedups over eager.
652*da0073e9SAndroid Build Coastguard Worker
653*da0073e9SAndroid Build Coastguard Worker    Writes to ./speedups.csv
654*da0073e9SAndroid Build Coastguard Worker    """
655*da0073e9SAndroid Build Coastguard Worker    # if args.dynamic_shapes:
656*da0073e9SAndroid Build Coastguard Worker    #     return speedup_experiment_ds(args, model_iter_fn, model, example_inputs)
657*da0073e9SAndroid Build Coastguard Worker
658*da0073e9SAndroid Build Coastguard Worker    timings = np.zeros((args.repeat, 2), np.float64)
659*da0073e9SAndroid Build Coastguard Worker    # if we randomize the input, we should also check the result is correct
660*da0073e9SAndroid Build Coastguard Worker    should_randomize_input = args.randomize_input
661*da0073e9SAndroid Build Coastguard Worker
662*da0073e9SAndroid Build Coastguard Worker    import contextlib
663*da0073e9SAndroid Build Coastguard Worker
664*da0073e9SAndroid Build Coastguard Worker    from torch._inductor.utils import maybe_profile
665*da0073e9SAndroid Build Coastguard Worker
666*da0073e9SAndroid Build Coastguard Worker    @contextlib.contextmanager
667*da0073e9SAndroid Build Coastguard Worker    def maybe_mark_profile(*args, **kwargs):
668*da0073e9SAndroid Build Coastguard Worker        prof: torch.profiler.profile = kwargs.pop("p", None)
669*da0073e9SAndroid Build Coastguard Worker        mark = kwargs.pop("mark", None)
670*da0073e9SAndroid Build Coastguard Worker        if prof:
671*da0073e9SAndroid Build Coastguard Worker            with torch.profiler.record_function(mark):
672*da0073e9SAndroid Build Coastguard Worker                yield
673*da0073e9SAndroid Build Coastguard Worker        else:
674*da0073e9SAndroid Build Coastguard Worker            yield
675*da0073e9SAndroid Build Coastguard Worker
676*da0073e9SAndroid Build Coastguard Worker    times = args.iterations_per_run
677*da0073e9SAndroid Build Coastguard Worker
678*da0073e9SAndroid Build Coastguard Worker    # Use higher tolerance for XLA since XLA cause numerical unstability when
679*da0073e9SAndroid Build Coastguard Worker    # graph size changes
680*da0073e9SAndroid Build Coastguard Worker    tolerance = args.xla_tolerance if args.trace_on_xla else 1e-4
681*da0073e9SAndroid Build Coastguard Worker    torch._dynamo.config.repro_tolerance = tolerance
682*da0073e9SAndroid Build Coastguard Worker
683*da0073e9SAndroid Build Coastguard Worker    with maybe_profile(args.export_profiler_trace) as p:
684*da0073e9SAndroid Build Coastguard Worker        if args.export_aot_inductor:
685*da0073e9SAndroid Build Coastguard Worker            frozen_model_iter_fn = export_aot_inductor(
686*da0073e9SAndroid Build Coastguard Worker                model, example_inputs, args.devices[0]
687*da0073e9SAndroid Build Coastguard Worker            )
688*da0073e9SAndroid Build Coastguard Worker        else:
689*da0073e9SAndroid Build Coastguard Worker            frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
690*da0073e9SAndroid Build Coastguard Worker
691*da0073e9SAndroid Build Coastguard Worker        for rep in trange(args.repeat, desc="running benchmark"):
692*da0073e9SAndroid Build Coastguard Worker            inputs = (
693*da0073e9SAndroid Build Coastguard Worker                randomize_input(copy.deepcopy(example_inputs))
694*da0073e9SAndroid Build Coastguard Worker                if should_randomize_input
695*da0073e9SAndroid Build Coastguard Worker                else example_inputs
696*da0073e9SAndroid Build Coastguard Worker            )
697*da0073e9SAndroid Build Coastguard Worker            # need call mark_step to perform the computation
698*da0073e9SAndroid Build Coastguard Worker            # on randomize_input. Otherwise the first call using the
699*da0073e9SAndroid Build Coastguard Worker            # inputs will incur high penalty then the next one.
700*da0073e9SAndroid Build Coastguard Worker            maybe_mark_step(args)
701*da0073e9SAndroid Build Coastguard Worker
702*da0073e9SAndroid Build Coastguard Worker            # interleave the runs to handle frequency scaling and load changes
703*da0073e9SAndroid Build Coastguard Worker            with maybe_mark_profile(p=p, mark="expected"):
704*da0073e9SAndroid Build Coastguard Worker                timings[rep, 0], expected_output = timed(
705*da0073e9SAndroid Build Coastguard Worker                    model,
706*da0073e9SAndroid Build Coastguard Worker                    model_iter_fn,
707*da0073e9SAndroid Build Coastguard Worker                    inputs,
708*da0073e9SAndroid Build Coastguard Worker                    return_result=True,
709*da0073e9SAndroid Build Coastguard Worker                    times=times,
710*da0073e9SAndroid Build Coastguard Worker                    collect_outputs=args.collect_outputs,
711*da0073e9SAndroid Build Coastguard Worker                )
712*da0073e9SAndroid Build Coastguard Worker
713*da0073e9SAndroid Build Coastguard Worker            # call mark_step between the 2 calls to make the comparison fair.
714*da0073e9SAndroid Build Coastguard Worker            maybe_mark_step(args)
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker            with maybe_mark_profile(p=p, mark="actual"), maybe_enable_compiled_autograd(
717*da0073e9SAndroid Build Coastguard Worker                args.compiled_autograd
718*da0073e9SAndroid Build Coastguard Worker            ):
719*da0073e9SAndroid Build Coastguard Worker                timings[rep, 1], actual_output = timed(
720*da0073e9SAndroid Build Coastguard Worker                    model,
721*da0073e9SAndroid Build Coastguard Worker                    frozen_model_iter_fn,
722*da0073e9SAndroid Build Coastguard Worker                    inputs,
723*da0073e9SAndroid Build Coastguard Worker                    return_result=True,
724*da0073e9SAndroid Build Coastguard Worker                    times=times,
725*da0073e9SAndroid Build Coastguard Worker                    collect_outputs=args.collect_outputs,
726*da0073e9SAndroid Build Coastguard Worker                )
727*da0073e9SAndroid Build Coastguard Worker
728*da0073e9SAndroid Build Coastguard Worker    if args.export_profiler_trace:
729*da0073e9SAndroid Build Coastguard Worker        name = args.profiler_trace_name + "_" + model.name
730*da0073e9SAndroid Build Coastguard Worker        if hasattr(args, "rank"):
731*da0073e9SAndroid Build Coastguard Worker            name += f"_rank_{args.rank}"
732*da0073e9SAndroid Build Coastguard Worker        name += ".json"
733*da0073e9SAndroid Build Coastguard Worker        name = os.path.join(torch._dynamo.config.base_dir, name)
734*da0073e9SAndroid Build Coastguard Worker        p.export_chrome_trace(name)
735*da0073e9SAndroid Build Coastguard Worker    median = np.median(timings, axis=0)
736*da0073e9SAndroid Build Coastguard Worker    speedup = median[0] / median[1]
737*da0073e9SAndroid Build Coastguard Worker    if args.dump_raw_metrics:
738*da0073e9SAndroid Build Coastguard Worker        np.save(
739*da0073e9SAndroid Build Coastguard Worker            f"{output_filename[:-4]}-raw_timings-{current_name}-{current_device}.npy",
740*da0073e9SAndroid Build Coastguard Worker            timings,
741*da0073e9SAndroid Build Coastguard Worker        )
742*da0073e9SAndroid Build Coastguard Worker
743*da0073e9SAndroid Build Coastguard Worker    first_headers = ["dev", "name", "batch_size"]
744*da0073e9SAndroid Build Coastguard Worker    first_fields = [current_device, current_name, current_batch_size]
745*da0073e9SAndroid Build Coastguard Worker    if "tag" in kwargs:
746*da0073e9SAndroid Build Coastguard Worker        first_headers.append("tag")
747*da0073e9SAndroid Build Coastguard Worker        first_fields.append(kwargs["tag"])
748*da0073e9SAndroid Build Coastguard Worker    headers = first_headers + ["speedup", "abs_latency"]
749*da0073e9SAndroid Build Coastguard Worker    row = first_fields + [float(speedup), median[1] * 1000]
750*da0073e9SAndroid Build Coastguard Worker    msg = f"{speedup:.3f}x"
751*da0073e9SAndroid Build Coastguard Worker    if args.baseline:
752*da0073e9SAndroid Build Coastguard Worker        headers.extend(
753*da0073e9SAndroid Build Coastguard Worker            [
754*da0073e9SAndroid Build Coastguard Worker                "baseline",
755*da0073e9SAndroid Build Coastguard Worker                "speedup_vs_baseline",
756*da0073e9SAndroid Build Coastguard Worker            ]
757*da0073e9SAndroid Build Coastguard Worker        )
758*da0073e9SAndroid Build Coastguard Worker        df = pd.read_csv(args.baseline)
759*da0073e9SAndroid Build Coastguard Worker        try:
760*da0073e9SAndroid Build Coastguard Worker            baseline_speedup = df[df["name"] == current_name]["speedup"].item()
761*da0073e9SAndroid Build Coastguard Worker            row.extend([baseline_speedup, speedup / baseline_speedup])
762*da0073e9SAndroid Build Coastguard Worker            msg = f"{baseline_speedup:.3f}x -> {speedup:.3f}x [{speedup / baseline_speedup:.3f}x]"
763*da0073e9SAndroid Build Coastguard Worker        except (KeyError, ZeroDivisionError):
764*da0073e9SAndroid Build Coastguard Worker            row.extend(
765*da0073e9SAndroid Build Coastguard Worker                [
766*da0073e9SAndroid Build Coastguard Worker                    0.0,
767*da0073e9SAndroid Build Coastguard Worker                    0.0,
768*da0073e9SAndroid Build Coastguard Worker                ]
769*da0073e9SAndroid Build Coastguard Worker            )
770*da0073e9SAndroid Build Coastguard Worker    if "compilation_latency" in kwargs:
771*da0073e9SAndroid Build Coastguard Worker        headers += [
772*da0073e9SAndroid Build Coastguard Worker            "compilation_latency",
773*da0073e9SAndroid Build Coastguard Worker            "compression_ratio",
774*da0073e9SAndroid Build Coastguard Worker            "eager_peak_mem",
775*da0073e9SAndroid Build Coastguard Worker            "dynamo_peak_mem",
776*da0073e9SAndroid Build Coastguard Worker        ]
777*da0073e9SAndroid Build Coastguard Worker        row.append(kwargs["compilation_latency"])
778*da0073e9SAndroid Build Coastguard Worker        row.append(kwargs["compression_ratio"])
779*da0073e9SAndroid Build Coastguard Worker        row.append(kwargs["eager_peak_mem"])
780*da0073e9SAndroid Build Coastguard Worker        row.append(kwargs["dynamo_peak_mem"])
781*da0073e9SAndroid Build Coastguard Worker
782*da0073e9SAndroid Build Coastguard Worker    if "cache_lookup_latency" in kwargs:
783*da0073e9SAndroid Build Coastguard Worker        headers.append("cache_lookup_latency")
784*da0073e9SAndroid Build Coastguard Worker        row.append(kwargs["cache_lookup_latency"])
785*da0073e9SAndroid Build Coastguard Worker
786*da0073e9SAndroid Build Coastguard Worker    if "dynamo_stats" in kwargs:
787*da0073e9SAndroid Build Coastguard Worker        for k, v in kwargs["dynamo_stats"].items():
788*da0073e9SAndroid Build Coastguard Worker            headers.append(k)
789*da0073e9SAndroid Build Coastguard Worker            row.append(v)
790*da0073e9SAndroid Build Coastguard Worker    output_csv(
791*da0073e9SAndroid Build Coastguard Worker        output_filename,
792*da0073e9SAndroid Build Coastguard Worker        headers,
793*da0073e9SAndroid Build Coastguard Worker        row,
794*da0073e9SAndroid Build Coastguard Worker    )
795*da0073e9SAndroid Build Coastguard Worker    headers, data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True)
796*da0073e9SAndroid Build Coastguard Worker    assert (
797*da0073e9SAndroid Build Coastguard Worker        output_filename.find(".csv") > 0
798*da0073e9SAndroid Build Coastguard Worker    ), f"expected output_filename to be a .csv, but got {output_filename}"
799*da0073e9SAndroid Build Coastguard Worker    output_csv(
800*da0073e9SAndroid Build Coastguard Worker        output_filename[:-4] + "_compilation_metrics.csv",
801*da0073e9SAndroid Build Coastguard Worker        first_headers + headers,
802*da0073e9SAndroid Build Coastguard Worker        first_fields + data,
803*da0073e9SAndroid Build Coastguard Worker    )
804*da0073e9SAndroid Build Coastguard Worker    return msg
805*da0073e9SAndroid Build Coastguard Worker
806*da0073e9SAndroid Build Coastguard Worker
807*da0073e9SAndroid Build Coastguard Workerdef speedup_experiment_ds(args, model_iter_fn, model, example_inputs):
808*da0073e9SAndroid Build Coastguard Worker    """
809*da0073e9SAndroid Build Coastguard Worker    Run dynamic shapes benchmarks.
810*da0073e9SAndroid Build Coastguard Worker
811*da0073e9SAndroid Build Coastguard Worker    Requires dynamic shape compatible models, which provide a list of example inputs.
812*da0073e9SAndroid Build Coastguard Worker
813*da0073e9SAndroid Build Coastguard Worker    Warms up using the first input example and then iterates the inputs,
814*da0073e9SAndroid Build Coastguard Worker    measuring (and expecting minimal) variance between the runtime for different examples.
815*da0073e9SAndroid Build Coastguard Worker
816*da0073e9SAndroid Build Coastguard Worker    """
817*da0073e9SAndroid Build Coastguard Worker    timings = np.zeros((args.repeat, len(example_inputs), 2), np.float64)
818*da0073e9SAndroid Build Coastguard Worker
819*da0073e9SAndroid Build Coastguard Worker    if args.repeat > 5:
820*da0073e9SAndroid Build Coastguard Worker        print(
821*da0073e9SAndroid Build Coastguard Worker            f"\ndynamic shapes experiments are slow, consider setting --repeat less than {args.repeat}\n"
822*da0073e9SAndroid Build Coastguard Worker        )
823*da0073e9SAndroid Build Coastguard Worker
824*da0073e9SAndroid Build Coastguard Worker    nwarmup = 4
825*da0073e9SAndroid Build Coastguard Worker    for rep in range(args.repeat):
826*da0073e9SAndroid Build Coastguard Worker        # Start each rep fresh, e.g. only warmup on example 0
827*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
828*da0073e9SAndroid Build Coastguard Worker        optimized_model_iter_fn = optimize_ctx(model_iter_fn)
829*da0073e9SAndroid Build Coastguard Worker        for _ in range(nwarmup):
830*da0073e9SAndroid Build Coastguard Worker            optimized_model_iter_fn(model, example_inputs[0])
831*da0073e9SAndroid Build Coastguard Worker
832*da0073e9SAndroid Build Coastguard Worker        for input_idx, inputs in enumerate(example_inputs):
833*da0073e9SAndroid Build Coastguard Worker            # interleave the runs to handle frequency scaling and load changes
834*da0073e9SAndroid Build Coastguard Worker            timings[rep, input_idx, 0] = timed(
835*da0073e9SAndroid Build Coastguard Worker                model, model_iter_fn, inputs, return_result=False
836*da0073e9SAndroid Build Coastguard Worker            )
837*da0073e9SAndroid Build Coastguard Worker            # different from regular speedup_experiment, we _DO_ want to allow recompilation
838*da0073e9SAndroid Build Coastguard Worker            timings[rep, input_idx, 1] = timed(
839*da0073e9SAndroid Build Coastguard Worker                model, optimized_model_iter_fn, inputs, return_result=False
840*da0073e9SAndroid Build Coastguard Worker            )
841*da0073e9SAndroid Build Coastguard Worker    medians = np.median(timings, axis=0)
842*da0073e9SAndroid Build Coastguard Worker    speedups = list(medians[:, 0] / medians[:, 1])
843*da0073e9SAndroid Build Coastguard Worker    speedups_mean = np.mean(speedups)
844*da0073e9SAndroid Build Coastguard Worker    speedups_median = np.median(speedups)
845*da0073e9SAndroid Build Coastguard Worker    speedups_var = np.var(speedups)
846*da0073e9SAndroid Build Coastguard Worker
847*da0073e9SAndroid Build Coastguard Worker    # TODO this x[0] is not going to work in general but bert only has 1 input
848*da0073e9SAndroid Build Coastguard Worker    shapes = [x[0].shape for x in example_inputs]
849*da0073e9SAndroid Build Coastguard Worker    shape_keys = sorted(set(shapes))
850*da0073e9SAndroid Build Coastguard Worker    shape_speedups = {
851*da0073e9SAndroid Build Coastguard Worker        shape: [
852*da0073e9SAndroid Build Coastguard Worker            it[1] for it in filter(lambda it: it[0] == shape, zip(shapes, speedups))
853*da0073e9SAndroid Build Coastguard Worker        ]
854*da0073e9SAndroid Build Coastguard Worker        for shape in shape_keys
855*da0073e9SAndroid Build Coastguard Worker    }
856*da0073e9SAndroid Build Coastguard Worker    output_str = (
857*da0073e9SAndroid Build Coastguard Worker        f"mean: {speedups_mean:.3f}, median: {speedups_median:.3f}, var: {speedups_var:.3f}"
858*da0073e9SAndroid Build Coastguard Worker        + "\nSpeedups by shape: "
859*da0073e9SAndroid Build Coastguard Worker        + "\n".join(
860*da0073e9SAndroid Build Coastguard Worker            [
861*da0073e9SAndroid Build Coastguard Worker                f"{shape}: "
862*da0073e9SAndroid Build Coastguard Worker                + ", ".join([f"{speedup: .3g}" for speedup in shape_speedups[shape]])
863*da0073e9SAndroid Build Coastguard Worker                for shape in shape_keys
864*da0073e9SAndroid Build Coastguard Worker            ]
865*da0073e9SAndroid Build Coastguard Worker        )
866*da0073e9SAndroid Build Coastguard Worker    )
867*da0073e9SAndroid Build Coastguard Worker    output_csv(
868*da0073e9SAndroid Build Coastguard Worker        output_filename,
869*da0073e9SAndroid Build Coastguard Worker        ("dev", "name", "batch_size", "speedup mean", "speedup median", "speedup var"),
870*da0073e9SAndroid Build Coastguard Worker        [
871*da0073e9SAndroid Build Coastguard Worker            current_device,
872*da0073e9SAndroid Build Coastguard Worker            current_name,
873*da0073e9SAndroid Build Coastguard Worker            current_batch_size,
874*da0073e9SAndroid Build Coastguard Worker            speedups_mean,
875*da0073e9SAndroid Build Coastguard Worker            speedups_median,
876*da0073e9SAndroid Build Coastguard Worker            speedups_var,
877*da0073e9SAndroid Build Coastguard Worker        ],
878*da0073e9SAndroid Build Coastguard Worker    )
879*da0073e9SAndroid Build Coastguard Worker    return output_str
880*da0073e9SAndroid Build Coastguard Worker
881*da0073e9SAndroid Build Coastguard Worker
882*da0073e9SAndroid Build Coastguard Worker@contextlib.contextmanager
883*da0073e9SAndroid Build Coastguard Workerdef override_synchronize_with_onnx_iobinding(iobinding):
884*da0073e9SAndroid Build Coastguard Worker    global synchronize
885*da0073e9SAndroid Build Coastguard Worker    prev_synchrnoize = synchronize
886*da0073e9SAndroid Build Coastguard Worker    try:
887*da0073e9SAndroid Build Coastguard Worker        if iobinding is not None:
888*da0073e9SAndroid Build Coastguard Worker
889*da0073e9SAndroid Build Coastguard Worker            def new_synchronize():
890*da0073e9SAndroid Build Coastguard Worker                iobinding.synchronize_inputs()
891*da0073e9SAndroid Build Coastguard Worker                iobinding.synchronize_outputs()
892*da0073e9SAndroid Build Coastguard Worker
893*da0073e9SAndroid Build Coastguard Worker            synchronize = new_synchronize
894*da0073e9SAndroid Build Coastguard Worker        yield
895*da0073e9SAndroid Build Coastguard Worker    finally:
896*da0073e9SAndroid Build Coastguard Worker        synchronize = prev_synchrnoize
897*da0073e9SAndroid Build Coastguard Worker
898*da0073e9SAndroid Build Coastguard Worker
899*da0073e9SAndroid Build Coastguard Workerdef speedup_experiment_onnx(
900*da0073e9SAndroid Build Coastguard Worker    args,
901*da0073e9SAndroid Build Coastguard Worker    model_iter_fn,
902*da0073e9SAndroid Build Coastguard Worker    onnx_model: OnnxModel,
903*da0073e9SAndroid Build Coastguard Worker    model,
904*da0073e9SAndroid Build Coastguard Worker    example_inputs,
905*da0073e9SAndroid Build Coastguard Worker    **kwargs,
906*da0073e9SAndroid Build Coastguard Worker):
907*da0073e9SAndroid Build Coastguard Worker    """
908*da0073e9SAndroid Build Coastguard Worker    Measure speedups over eager.
909*da0073e9SAndroid Build Coastguard Worker
910*da0073e9SAndroid Build Coastguard Worker    This function is responsible for the following:
911*da0073e9SAndroid Build Coastguard Worker        1. Creating iobinding with OnnxModel if device is CUDA, which is essential for perf measurement.
912*da0073e9SAndroid Build Coastguard Worker        2. Running ORT with OnnxModel.
913*da0073e9SAndroid Build Coastguard Worker
914*da0073e9SAndroid Build Coastguard Worker    Writes to ./{output_filename}, which should be
915*da0073e9SAndroid Build Coastguard Worker        `pathlib.Path(self.output_dir) / f"{self.compiler}_{suite}_{self.dtype}_{self.mode}_{self.device}_{self.testing}.csv".
916*da0073e9SAndroid Build Coastguard Worker
917*da0073e9SAndroid Build Coastguard Worker    TODO(bowbao): Record export time and export peak memory usage.
918*da0073e9SAndroid Build Coastguard Worker    """
919*da0073e9SAndroid Build Coastguard Worker    timings = np.zeros((args.repeat, 2), np.float64)
920*da0073e9SAndroid Build Coastguard Worker    is_correct = True
921*da0073e9SAndroid Build Coastguard Worker    should_randomize_input = args.randomize_input
922*da0073e9SAndroid Build Coastguard Worker    times = args.iterations_per_run
923*da0073e9SAndroid Build Coastguard Worker
924*da0073e9SAndroid Build Coastguard Worker    def create_onnx_input_binded_fn(onnx_model: OnnxModel, pt_inputs, example_outputs):
925*da0073e9SAndroid Build Coastguard Worker        # Goal is to move the iobinding creation outside of the timer function.
926*da0073e9SAndroid Build Coastguard Worker        iobinding, outputs = onnx_model.create_iobinding(pt_inputs, example_outputs)
927*da0073e9SAndroid Build Coastguard Worker
928*da0073e9SAndroid Build Coastguard Worker        def onnxrt_model_iter_fn(model, inputs, collect_outputs=True):
929*da0073e9SAndroid Build Coastguard Worker            onnx_model.run_with_iobinding(iobinding, outputs)
930*da0073e9SAndroid Build Coastguard Worker            if collect_outputs:
931*da0073e9SAndroid Build Coastguard Worker                return outputs
932*da0073e9SAndroid Build Coastguard Worker
933*da0073e9SAndroid Build Coastguard Worker        return onnxrt_model_iter_fn, iobinding
934*da0073e9SAndroid Build Coastguard Worker
935*da0073e9SAndroid Build Coastguard Worker    def create_onnx_fn(onnx_model: OnnxModel, pt_inputs):
936*da0073e9SAndroid Build Coastguard Worker        # NOTE: Making perf comparison fair by moving out the i/o adapting part.
937*da0073e9SAndroid Build Coastguard Worker        # 1. Pre-adapt `pt_inputs` to `onnx_inputs` here.
938*da0073e9SAndroid Build Coastguard Worker        # 2. Drop `onnx_outputs` to `pt_outputs` adapting. Output comparison is not part of perf measurement.
939*da0073e9SAndroid Build Coastguard Worker        onnx_inputs = onnx_model.adapt_pt_inputs_to_onnx(pt_inputs)
940*da0073e9SAndroid Build Coastguard Worker
941*da0073e9SAndroid Build Coastguard Worker        def onnxrt_model_iter_fn(model, inputs, collect_outputs=True):
942*da0073e9SAndroid Build Coastguard Worker            return onnx_model.run_with_onnx_inputs(onnx_inputs)
943*da0073e9SAndroid Build Coastguard Worker
944*da0073e9SAndroid Build Coastguard Worker        return onnxrt_model_iter_fn
945*da0073e9SAndroid Build Coastguard Worker
946*da0073e9SAndroid Build Coastguard Worker    def timed_onnx(model, onnx_model: OnnxModel, inputs):
947*da0073e9SAndroid Build Coastguard Worker        if current_device == "cpu" or onnx_model.is_cpu():
948*da0073e9SAndroid Build Coastguard Worker            onnxrt_model_iter_fn = create_onnx_fn(onnx_model, inputs)
949*da0073e9SAndroid Build Coastguard Worker            iobinding = None
950*da0073e9SAndroid Build Coastguard Worker        else:
951*da0073e9SAndroid Build Coastguard Worker            onnxrt_model_iter_fn, iobinding = create_onnx_input_binded_fn(
952*da0073e9SAndroid Build Coastguard Worker                onnx_model, inputs, expected_output
953*da0073e9SAndroid Build Coastguard Worker            )
954*da0073e9SAndroid Build Coastguard Worker        with override_synchronize_with_onnx_iobinding(iobinding):
955*da0073e9SAndroid Build Coastguard Worker            return timed(
956*da0073e9SAndroid Build Coastguard Worker                model,
957*da0073e9SAndroid Build Coastguard Worker                onnxrt_model_iter_fn,
958*da0073e9SAndroid Build Coastguard Worker                inputs,
959*da0073e9SAndroid Build Coastguard Worker                return_result=True,
960*da0073e9SAndroid Build Coastguard Worker                times=times,
961*da0073e9SAndroid Build Coastguard Worker                collect_outputs=args.collect_outputs,
962*da0073e9SAndroid Build Coastguard Worker            )
963*da0073e9SAndroid Build Coastguard Worker
964*da0073e9SAndroid Build Coastguard Worker    # Insert ONNX warm-up
965*da0073e9SAndroid Build Coastguard Worker    inputs = (
966*da0073e9SAndroid Build Coastguard Worker        randomize_input(copy.deepcopy(example_inputs))
967*da0073e9SAndroid Build Coastguard Worker        if should_randomize_input
968*da0073e9SAndroid Build Coastguard Worker        else example_inputs
969*da0073e9SAndroid Build Coastguard Worker    )
970*da0073e9SAndroid Build Coastguard Worker    _, expected_output = timed(
971*da0073e9SAndroid Build Coastguard Worker        model,
972*da0073e9SAndroid Build Coastguard Worker        model_iter_fn,
973*da0073e9SAndroid Build Coastguard Worker        inputs,
974*da0073e9SAndroid Build Coastguard Worker        return_result=True,
975*da0073e9SAndroid Build Coastguard Worker        times=times,
976*da0073e9SAndroid Build Coastguard Worker        collect_outputs=args.collect_outputs,
977*da0073e9SAndroid Build Coastguard Worker    )
978*da0073e9SAndroid Build Coastguard Worker    for _ in range(2):
979*da0073e9SAndroid Build Coastguard Worker        timed_onnx(model, onnx_model, inputs)
980*da0073e9SAndroid Build Coastguard Worker
981*da0073e9SAndroid Build Coastguard Worker    for rep in range(args.repeat):
982*da0073e9SAndroid Build Coastguard Worker        inputs = (
983*da0073e9SAndroid Build Coastguard Worker            randomize_input(copy.deepcopy(example_inputs))
984*da0073e9SAndroid Build Coastguard Worker            if should_randomize_input
985*da0073e9SAndroid Build Coastguard Worker            else example_inputs
986*da0073e9SAndroid Build Coastguard Worker        )
987*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.device_count() > 1:
988*da0073e9SAndroid Build Coastguard Worker            # Manually set correct torch.cuda.current_device to ensure torch.cuda.synchronize() works as intended.
989*da0073e9SAndroid Build Coastguard Worker            # When there are more than 1 cuda devices, the first one is used for pytorch eager.
990*da0073e9SAndroid Build Coastguard Worker            # The second one is used for onnx ort.
991*da0073e9SAndroid Build Coastguard Worker            torch.cuda.set_device(0)
992*da0073e9SAndroid Build Coastguard Worker        timings[rep, 0], expected_output = timed(
993*da0073e9SAndroid Build Coastguard Worker            model,
994*da0073e9SAndroid Build Coastguard Worker            model_iter_fn,
995*da0073e9SAndroid Build Coastguard Worker            inputs,
996*da0073e9SAndroid Build Coastguard Worker            return_result=True,
997*da0073e9SAndroid Build Coastguard Worker            times=times,
998*da0073e9SAndroid Build Coastguard Worker            collect_outputs=args.collect_outputs,
999*da0073e9SAndroid Build Coastguard Worker        )
1000*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.device_count() > 1:
1001*da0073e9SAndroid Build Coastguard Worker            # Manually set correct torch.cuda.current_device to ensure torch.cuda.synchronize() works as intended.
1002*da0073e9SAndroid Build Coastguard Worker            # When there are more than 1 cuda devices, the first one is used for pytorch eager.
1003*da0073e9SAndroid Build Coastguard Worker            # The second one is used for onnx ort.
1004*da0073e9SAndroid Build Coastguard Worker            torch.cuda.set_device(1)
1005*da0073e9SAndroid Build Coastguard Worker        timings[rep, 1], actual_output = timed_onnx(model, onnx_model, inputs)
1006*da0073e9SAndroid Build Coastguard Worker
1007*da0073e9SAndroid Build Coastguard Worker    pvalue = ttest_ind(timings[:, 0], timings[:, 1]).pvalue
1008*da0073e9SAndroid Build Coastguard Worker    median = np.median(timings, axis=0)
1009*da0073e9SAndroid Build Coastguard Worker    speedup = median[0] / median[1]
1010*da0073e9SAndroid Build Coastguard Worker    if args.dump_raw_metrics:
1011*da0073e9SAndroid Build Coastguard Worker        np.save(
1012*da0073e9SAndroid Build Coastguard Worker            f"{output_filename[:-4]}-raw_timings-{current_name}-{current_device}.npy",
1013*da0073e9SAndroid Build Coastguard Worker            timings,
1014*da0073e9SAndroid Build Coastguard Worker        )
1015*da0073e9SAndroid Build Coastguard Worker
1016*da0073e9SAndroid Build Coastguard Worker    headers = ["dev", "name", "batch_size", "speedup", "abs_latency"]
1017*da0073e9SAndroid Build Coastguard Worker    row = [
1018*da0073e9SAndroid Build Coastguard Worker        current_device,
1019*da0073e9SAndroid Build Coastguard Worker        current_name,
1020*da0073e9SAndroid Build Coastguard Worker        current_batch_size,
1021*da0073e9SAndroid Build Coastguard Worker        float(speedup),
1022*da0073e9SAndroid Build Coastguard Worker        median[1] * 1000,
1023*da0073e9SAndroid Build Coastguard Worker    ]
1024*da0073e9SAndroid Build Coastguard Worker    if "compilation_latency" in kwargs:
1025*da0073e9SAndroid Build Coastguard Worker        headers = headers + ["compilation_latency", "compression_ratio"]
1026*da0073e9SAndroid Build Coastguard Worker        row.append(kwargs["compilation_latency"])
1027*da0073e9SAndroid Build Coastguard Worker        row.append(kwargs["compression_ratio"])
1028*da0073e9SAndroid Build Coastguard Worker
1029*da0073e9SAndroid Build Coastguard Worker    output_csv(
1030*da0073e9SAndroid Build Coastguard Worker        output_filename,
1031*da0073e9SAndroid Build Coastguard Worker        headers,
1032*da0073e9SAndroid Build Coastguard Worker        row,
1033*da0073e9SAndroid Build Coastguard Worker    )
1034*da0073e9SAndroid Build Coastguard Worker    headers, data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True)
1035*da0073e9SAndroid Build Coastguard Worker    assert (
1036*da0073e9SAndroid Build Coastguard Worker        output_filename.find(".csv") > 0
1037*da0073e9SAndroid Build Coastguard Worker    ), f"expected output_filename to be a .csv, but got {output_filename}"
1038*da0073e9SAndroid Build Coastguard Worker    output_csv(
1039*da0073e9SAndroid Build Coastguard Worker        output_filename[:-4] + "_compilation_metrics.csv",
1040*da0073e9SAndroid Build Coastguard Worker        ["dev", "name", "batch_size"] + headers,
1041*da0073e9SAndroid Build Coastguard Worker        [current_device, current_name, current_batch_size] + data,
1042*da0073e9SAndroid Build Coastguard Worker    )
1043*da0073e9SAndroid Build Coastguard Worker    return format_speedup(speedup, pvalue, is_correct=is_correct)
1044*da0073e9SAndroid Build Coastguard Worker
1045*da0073e9SAndroid Build Coastguard Worker
1046*da0073e9SAndroid Build Coastguard Workerdef overhead_experiment(*args, model_iter_fn):
1047*da0073e9SAndroid Build Coastguard Worker    """
1048*da0073e9SAndroid Build Coastguard Worker    Measure overheads of TorchDynamo by running with no backend (only
1049*da0073e9SAndroid Build Coastguard Worker    eager+FX), and reporting speedup/slowdown over eager.
1050*da0073e9SAndroid Build Coastguard Worker
1051*da0073e9SAndroid Build Coastguard Worker    Writes to ./overheads.csv
1052*da0073e9SAndroid Build Coastguard Worker    """
1053*da0073e9SAndroid Build Coastguard Worker    return speedup_experiment(*args, model_iter_fn)
1054*da0073e9SAndroid Build Coastguard Worker
1055*da0073e9SAndroid Build Coastguard Worker
1056*da0073e9SAndroid Build Coastguard Workerdef print_fx(gm, example_inputs):
1057*da0073e9SAndroid Build Coastguard Worker    print(gm.graph)
1058*da0073e9SAndroid Build Coastguard Worker    return gm
1059*da0073e9SAndroid Build Coastguard Worker
1060*da0073e9SAndroid Build Coastguard Worker
1061*da0073e9SAndroid Build Coastguard Workerdef print_aten_ops(gm, example_inputs):
1062*da0073e9SAndroid Build Coastguard Worker    from functorch.compile import aot_module
1063*da0073e9SAndroid Build Coastguard Worker
1064*da0073e9SAndroid Build Coastguard Worker    def trace_printer(gm, _):
1065*da0073e9SAndroid Build Coastguard Worker        print(gm.graph)
1066*da0073e9SAndroid Build Coastguard Worker        return gm
1067*da0073e9SAndroid Build Coastguard Worker
1068*da0073e9SAndroid Build Coastguard Worker    return aot_module(gm, fw_compiler=trace_printer, bw_compiler=trace_printer)
1069*da0073e9SAndroid Build Coastguard Worker
1070*da0073e9SAndroid Build Coastguard Worker
1071*da0073e9SAndroid Build Coastguard Workerdef baselines(models, model_iter_fn, example_inputs, args):
1072*da0073e9SAndroid Build Coastguard Worker    """
1073*da0073e9SAndroid Build Coastguard Worker    Common measurement code across all baseline experiments.
1074*da0073e9SAndroid Build Coastguard Worker    """
1075*da0073e9SAndroid Build Coastguard Worker    models = list(models)
1076*da0073e9SAndroid Build Coastguard Worker    for idx, (name, model) in enumerate(models):
1077*da0073e9SAndroid Build Coastguard Worker        if idx == 0:
1078*da0073e9SAndroid Build Coastguard Worker            result0 = model_iter_fn(model, example_inputs)
1079*da0073e9SAndroid Build Coastguard Worker        elif model is not None:
1080*da0073e9SAndroid Build Coastguard Worker            try:
1081*da0073e9SAndroid Build Coastguard Worker                result = model_iter_fn(model, example_inputs)
1082*da0073e9SAndroid Build Coastguard Worker                if same(result0, result):
1083*da0073e9SAndroid Build Coastguard Worker                    continue
1084*da0073e9SAndroid Build Coastguard Worker                print(name, "is INCORRECT")
1085*da0073e9SAndroid Build Coastguard Worker            except Exception:
1086*da0073e9SAndroid Build Coastguard Worker                log.exception("error checking %s", name)
1087*da0073e9SAndroid Build Coastguard Worker            models[idx] = (name, None)
1088*da0073e9SAndroid Build Coastguard Worker    timings = np.zeros((args.repeat, len(models)), np.float64)
1089*da0073e9SAndroid Build Coastguard Worker    timings.fill(1.0e10)
1090*da0073e9SAndroid Build Coastguard Worker    for rep in range(args.repeat):
1091*da0073e9SAndroid Build Coastguard Worker        for idx, (name, model) in enumerate(models):
1092*da0073e9SAndroid Build Coastguard Worker            if model is not None:
1093*da0073e9SAndroid Build Coastguard Worker                try:
1094*da0073e9SAndroid Build Coastguard Worker                    timings[rep, idx] = timed(model, model_iter_fn, example_inputs)
1095*da0073e9SAndroid Build Coastguard Worker                except Exception:
1096*da0073e9SAndroid Build Coastguard Worker                    pass
1097*da0073e9SAndroid Build Coastguard Worker    pvalue = [
1098*da0073e9SAndroid Build Coastguard Worker        ttest_ind(timings[:, 0], timings[:, i]).pvalue
1099*da0073e9SAndroid Build Coastguard Worker        for i in range(1, timings.shape[1])
1100*da0073e9SAndroid Build Coastguard Worker    ]
1101*da0073e9SAndroid Build Coastguard Worker    median = np.median(timings, axis=0)
1102*da0073e9SAndroid Build Coastguard Worker    speedup = median[0] / median[1:]
1103*da0073e9SAndroid Build Coastguard Worker    for idx, (name, model) in enumerate(models[1:]):
1104*da0073e9SAndroid Build Coastguard Worker        if model is None:
1105*da0073e9SAndroid Build Coastguard Worker            speedup[idx] = 0.0
1106*da0073e9SAndroid Build Coastguard Worker    result = " ".join(
1107*da0073e9SAndroid Build Coastguard Worker        [
1108*da0073e9SAndroid Build Coastguard Worker            format_speedup(s, p, m is not None)
1109*da0073e9SAndroid Build Coastguard Worker            for s, p, m in zip(speedup, pvalue, [m for n, m in models[1:]])
1110*da0073e9SAndroid Build Coastguard Worker        ]
1111*da0073e9SAndroid Build Coastguard Worker    )
1112*da0073e9SAndroid Build Coastguard Worker    output_csv(
1113*da0073e9SAndroid Build Coastguard Worker        output_filename,
1114*da0073e9SAndroid Build Coastguard Worker        ("dev", "name", "batch_size") + tuple(n for n, m in models[1:]),
1115*da0073e9SAndroid Build Coastguard Worker        [current_device, current_name, current_batch_size]
1116*da0073e9SAndroid Build Coastguard Worker        + [f"{x:.4f}" for x in speedup],
1117*da0073e9SAndroid Build Coastguard Worker    )
1118*da0073e9SAndroid Build Coastguard Worker    return result
1119*da0073e9SAndroid Build Coastguard Worker
1120*da0073e9SAndroid Build Coastguard Worker
1121*da0073e9SAndroid Build Coastguard Workerdef xla(args, model_iter_fn, model, example_inputs):
1122*da0073e9SAndroid Build Coastguard Worker    xla_dev = xm.xla_device(devkind=current_device)
1123*da0073e9SAndroid Build Coastguard Worker    model_xla = copy.deepcopy(model).to("cpu").to(device=xla_dev)
1124*da0073e9SAndroid Build Coastguard Worker    example_inputs_xla = tree_map_only(
1125*da0073e9SAndroid Build Coastguard Worker        torch.Tensor, lambda x: x.to("cpu").to(device=xla_dev), example_inputs
1126*da0073e9SAndroid Build Coastguard Worker    )
1127*da0073e9SAndroid Build Coastguard Worker    for _ in range(3):  # warmup
1128*da0073e9SAndroid Build Coastguard Worker        timed(model, model_iter_fn, example_inputs)
1129*da0073e9SAndroid Build Coastguard Worker        timed(model_xla, model_iter_fn, example_inputs_xla)
1130*da0073e9SAndroid Build Coastguard Worker    timings = np.zeros((args.repeat, 2), np.float64)
1131*da0073e9SAndroid Build Coastguard Worker    timings.fill(1.0e10)
1132*da0073e9SAndroid Build Coastguard Worker    for rep in range(args.repeat):
1133*da0073e9SAndroid Build Coastguard Worker        timings[rep, 0] = timed(model, model_iter_fn, example_inputs)
1134*da0073e9SAndroid Build Coastguard Worker        timings[rep, 1] = timed(model_xla, model_iter_fn, example_inputs_xla)
1135*da0073e9SAndroid Build Coastguard Worker
1136*da0073e9SAndroid Build Coastguard Worker    pvalue = ttest_ind(timings[:, 0], timings[:, 1]).pvalue
1137*da0073e9SAndroid Build Coastguard Worker    time_baseline, time_xla = np.median(timings, axis=0)
1138*da0073e9SAndroid Build Coastguard Worker    speedup = time_baseline / time_xla
1139*da0073e9SAndroid Build Coastguard Worker    output_csv(
1140*da0073e9SAndroid Build Coastguard Worker        output_filename,
1141*da0073e9SAndroid Build Coastguard Worker        ("dev", "name", "batch_size", "speedup", "time_baseline", "time_xla"),
1142*da0073e9SAndroid Build Coastguard Worker        [
1143*da0073e9SAndroid Build Coastguard Worker            current_device,
1144*da0073e9SAndroid Build Coastguard Worker            current_name,
1145*da0073e9SAndroid Build Coastguard Worker            current_batch_size,
1146*da0073e9SAndroid Build Coastguard Worker            speedup,
1147*da0073e9SAndroid Build Coastguard Worker            time_baseline,
1148*da0073e9SAndroid Build Coastguard Worker            time_xla,
1149*da0073e9SAndroid Build Coastguard Worker        ],
1150*da0073e9SAndroid Build Coastguard Worker    )
1151*da0073e9SAndroid Build Coastguard Worker    return format_speedup(speedup, pvalue)
1152*da0073e9SAndroid Build Coastguard Worker
1153*da0073e9SAndroid Build Coastguard Worker
1154*da0073e9SAndroid Build Coastguard Workerdef try_script(model, example_inputs):
1155*da0073e9SAndroid Build Coastguard Worker    try:
1156*da0073e9SAndroid Build Coastguard Worker        return torch.jit.script(model)
1157*da0073e9SAndroid Build Coastguard Worker    except Exception:
1158*da0073e9SAndroid Build Coastguard Worker        return None
1159*da0073e9SAndroid Build Coastguard Worker
1160*da0073e9SAndroid Build Coastguard Worker
1161*da0073e9SAndroid Build Coastguard Workerclass AOTInductorModelCache:
1162*da0073e9SAndroid Build Coastguard Worker    cache = dict()
1163*da0073e9SAndroid Build Coastguard Worker
1164*da0073e9SAndroid Build Coastguard Worker    @classmethod
1165*da0073e9SAndroid Build Coastguard Worker    def load(cls, model, example_inputs, device):
1166*da0073e9SAndroid Build Coastguard Worker        import torch._inductor
1167*da0073e9SAndroid Build Coastguard Worker        import torch.export._trace
1168*da0073e9SAndroid Build Coastguard Worker
1169*da0073e9SAndroid Build Coastguard Worker        key = weakref.ref(model)
1170*da0073e9SAndroid Build Coastguard Worker        if key not in cls.cache:
1171*da0073e9SAndroid Build Coastguard Worker            # Register the output dataclass to pytree
1172*da0073e9SAndroid Build Coastguard Worker            example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
1173*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
1174*da0073e9SAndroid Build Coastguard Worker                # copy.deepcopy is required to prevent any surprising side-effect,
1175*da0073e9SAndroid Build Coastguard Worker                # see https://github.com/pytorch/pytorch/issues/113029
1176*da0073e9SAndroid Build Coastguard Worker                example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)
1177*da0073e9SAndroid Build Coastguard Worker
1178*da0073e9SAndroid Build Coastguard Worker            if pytree._is_namedtuple_instance(example_outputs):
1179*da0073e9SAndroid Build Coastguard Worker                typ = type(example_outputs)
1180*da0073e9SAndroid Build Coastguard Worker                pytree._register_namedtuple(
1181*da0073e9SAndroid Build Coastguard Worker                    typ,
1182*da0073e9SAndroid Build Coastguard Worker                    serialized_type_name=f"{typ.__module__}.{typ.__name__}",
1183*da0073e9SAndroid Build Coastguard Worker                )
1184*da0073e9SAndroid Build Coastguard Worker            else:
1185*da0073e9SAndroid Build Coastguard Worker                _register_dataclass_output_as_pytree(example_outputs)
1186*da0073e9SAndroid Build Coastguard Worker
1187*da0073e9SAndroid Build Coastguard Worker            # TODO(angelayi): change this to predispatch
1188*da0073e9SAndroid Build Coastguard Worker            # https://github.com/pytorch/pytorch/issues/127513 needs to be fixed before changing
1189*da0073e9SAndroid Build Coastguard Worker            # to predispatch to avoid performance regressions
1190*da0073e9SAndroid Build Coastguard Worker            gm = torch.export._trace._export_to_torch_ir(
1191*da0073e9SAndroid Build Coastguard Worker                model,
1192*da0073e9SAndroid Build Coastguard Worker                example_args,
1193*da0073e9SAndroid Build Coastguard Worker                example_kwargs,
1194*da0073e9SAndroid Build Coastguard Worker            )
1195*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
1196*da0073e9SAndroid Build Coastguard Worker                so_path = torch._inductor.aot_compile(
1197*da0073e9SAndroid Build Coastguard Worker                    gm, example_args, example_kwargs
1198*da0073e9SAndroid Build Coastguard Worker                )  # type: ignore[arg-type]
1199*da0073e9SAndroid Build Coastguard Worker
1200*da0073e9SAndroid Build Coastguard Worker            cls.cache[key] = torch._export.aot_load(so_path, device)
1201*da0073e9SAndroid Build Coastguard Worker
1202*da0073e9SAndroid Build Coastguard Worker        return cls.cache[key]
1203*da0073e9SAndroid Build Coastguard Worker
1204*da0073e9SAndroid Build Coastguard Worker
1205*da0073e9SAndroid Build Coastguard Workerdef export(model, example_inputs):
1206*da0073e9SAndroid Build Coastguard Worker    example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
1207*da0073e9SAndroid Build Coastguard Worker    example_outputs = model(*example_args, **example_kwargs)
1208*da0073e9SAndroid Build Coastguard Worker    _register_dataclass_output_as_pytree(example_outputs)
1209*da0073e9SAndroid Build Coastguard Worker
1210*da0073e9SAndroid Build Coastguard Worker    ep = torch.export.export(model, example_args, example_kwargs)
1211*da0073e9SAndroid Build Coastguard Worker
1212*da0073e9SAndroid Build Coastguard Worker    def opt_export(_, example_inputs):
1213*da0073e9SAndroid Build Coastguard Worker        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
1214*da0073e9SAndroid Build Coastguard Worker        return ep(*example_args, **example_kwargs)
1215*da0073e9SAndroid Build Coastguard Worker
1216*da0073e9SAndroid Build Coastguard Worker    return opt_export
1217*da0073e9SAndroid Build Coastguard Worker
1218*da0073e9SAndroid Build Coastguard Worker
1219*da0073e9SAndroid Build Coastguard Workerdef export_aot_inductor(model, example_inputs, device):
1220*da0073e9SAndroid Build Coastguard Worker    optimized = AOTInductorModelCache.load(model, example_inputs, device)
1221*da0073e9SAndroid Build Coastguard Worker
1222*da0073e9SAndroid Build Coastguard Worker    def opt_aot_inductor(_, example_inputs, collect_outputs=False):
1223*da0073e9SAndroid Build Coastguard Worker        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
1224*da0073e9SAndroid Build Coastguard Worker        return optimized(*example_args, **example_kwargs)
1225*da0073e9SAndroid Build Coastguard Worker
1226*da0073e9SAndroid Build Coastguard Worker    return opt_aot_inductor
1227*da0073e9SAndroid Build Coastguard Worker
1228*da0073e9SAndroid Build Coastguard Worker
1229*da0073e9SAndroid Build Coastguard Workerdef download_retry_decorator(download_fn):
1230*da0073e9SAndroid Build Coastguard Worker    """
1231*da0073e9SAndroid Build Coastguard Worker    Decorator function for applying retry logic to a download function.
1232*da0073e9SAndroid Build Coastguard Worker
1233*da0073e9SAndroid Build Coastguard Worker    The wrapped function will be called up to 5 times and raises an exception if the function fails each time.
1234*da0073e9SAndroid Build Coastguard Worker    After each unsuccessful attempt, there is a delay before the next attempt, which is increased linearly with the number of tries.
1235*da0073e9SAndroid Build Coastguard Worker
1236*da0073e9SAndroid Build Coastguard Worker    Usage:
1237*da0073e9SAndroid Build Coastguard Worker    @download_retry_decorator
1238*da0073e9SAndroid Build Coastguard Worker    def download_function(model_name: str):
1239*da0073e9SAndroid Build Coastguard Worker        # download logic goes here
1240*da0073e9SAndroid Build Coastguard Worker    """
1241*da0073e9SAndroid Build Coastguard Worker
1242*da0073e9SAndroid Build Coastguard Worker    @functools.wraps(download_fn)
1243*da0073e9SAndroid Build Coastguard Worker    def wrapper(self, *args, **kwargs) -> Any:
1244*da0073e9SAndroid Build Coastguard Worker        tries = 0
1245*da0073e9SAndroid Build Coastguard Worker        total_allowed_tries = MAX_DOWNLOAD_ATTEMPTS
1246*da0073e9SAndroid Build Coastguard Worker        while tries <= total_allowed_tries:
1247*da0073e9SAndroid Build Coastguard Worker            try:
1248*da0073e9SAndroid Build Coastguard Worker                model = download_fn(self, *args, **kwargs)
1249*da0073e9SAndroid Build Coastguard Worker                return model
1250*da0073e9SAndroid Build Coastguard Worker            except Exception as e:
1251*da0073e9SAndroid Build Coastguard Worker                tries += 1
1252*da0073e9SAndroid Build Coastguard Worker                if tries <= total_allowed_tries:
1253*da0073e9SAndroid Build Coastguard Worker                    wait = tries * 30
1254*da0073e9SAndroid Build Coastguard Worker                    print(
1255*da0073e9SAndroid Build Coastguard Worker                        f"Failed to load model: {e}. Trying again ({tries}/{total_allowed_tries}) after {wait}s"
1256*da0073e9SAndroid Build Coastguard Worker                    )
1257*da0073e9SAndroid Build Coastguard Worker                    time.sleep(wait)
1258*da0073e9SAndroid Build Coastguard Worker                else:
1259*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError(  # noqa: B904
1260*da0073e9SAndroid Build Coastguard Worker                        f"Failed to load model '{args}' with following error(s): {str(e)}."
1261*da0073e9SAndroid Build Coastguard Worker                    )
1262*da0073e9SAndroid Build Coastguard Worker
1263*da0073e9SAndroid Build Coastguard Worker    return wrapper
1264*da0073e9SAndroid Build Coastguard Worker
1265*da0073e9SAndroid Build Coastguard Worker
1266*da0073e9SAndroid Build Coastguard Workerclass OnnxModel(abc.ABC):
1267*da0073e9SAndroid Build Coastguard Worker    TORCH_TO_NUMPY_DTYPE = {
1268*da0073e9SAndroid Build Coastguard Worker        torch.float16: np.float16,
1269*da0073e9SAndroid Build Coastguard Worker        torch.float32: np.float32,
1270*da0073e9SAndroid Build Coastguard Worker        torch.float64: np.float64,
1271*da0073e9SAndroid Build Coastguard Worker        torch.uint8: np.uint8,
1272*da0073e9SAndroid Build Coastguard Worker        torch.int8: np.int8,
1273*da0073e9SAndroid Build Coastguard Worker        torch.int16: np.int16,
1274*da0073e9SAndroid Build Coastguard Worker        torch.int32: np.int32,
1275*da0073e9SAndroid Build Coastguard Worker        torch.int64: np.longlong,
1276*da0073e9SAndroid Build Coastguard Worker        torch.bool: np.bool_,
1277*da0073e9SAndroid Build Coastguard Worker    }
1278*da0073e9SAndroid Build Coastguard Worker
1279*da0073e9SAndroid Build Coastguard Worker    _COMPILER_NAME: str
1280*da0073e9SAndroid Build Coastguard Worker
1281*da0073e9SAndroid Build Coastguard Worker    def __init__(
1282*da0073e9SAndroid Build Coastguard Worker        self,
1283*da0073e9SAndroid Build Coastguard Worker        output_directory,
1284*da0073e9SAndroid Build Coastguard Worker        model,
1285*da0073e9SAndroid Build Coastguard Worker        example_inputs,
1286*da0073e9SAndroid Build Coastguard Worker        dynamic_shapes: bool,
1287*da0073e9SAndroid Build Coastguard Worker        copy_before_export: bool = False,
1288*da0073e9SAndroid Build Coastguard Worker    ):
1289*da0073e9SAndroid Build Coastguard Worker        model_name = current_name
1290*da0073e9SAndroid Build Coastguard Worker        self.copy_before_export = copy_before_export
1291*da0073e9SAndroid Build Coastguard Worker        self.model_dir = self._generate_onnx_model_directory(
1292*da0073e9SAndroid Build Coastguard Worker            output_directory, self._COMPILER_NAME, model_name
1293*da0073e9SAndroid Build Coastguard Worker        )
1294*da0073e9SAndroid Build Coastguard Worker        self.model_path = str(
1295*da0073e9SAndroid Build Coastguard Worker            self.model_dir / f"{model_name}_{self._COMPILER_NAME}.onnx"
1296*da0073e9SAndroid Build Coastguard Worker        )
1297*da0073e9SAndroid Build Coastguard Worker
1298*da0073e9SAndroid Build Coastguard Worker    def _determine_deepcopy_target_device(self):
1299*da0073e9SAndroid Build Coastguard Worker        if current_device == "cpu":
1300*da0073e9SAndroid Build Coastguard Worker            target_device = "cpu"
1301*da0073e9SAndroid Build Coastguard Worker        else:
1302*da0073e9SAndroid Build Coastguard Worker            if torch.cuda.device_count() > 1:
1303*da0073e9SAndroid Build Coastguard Worker                # Copy to another cuda device to avoid OOM.
1304*da0073e9SAndroid Build Coastguard Worker                target_device = "cuda:1"
1305*da0073e9SAndroid Build Coastguard Worker            else:
1306*da0073e9SAndroid Build Coastguard Worker                target_device = "cuda"
1307*da0073e9SAndroid Build Coastguard Worker        return target_device
1308*da0073e9SAndroid Build Coastguard Worker
1309*da0073e9SAndroid Build Coastguard Worker    def deepcopy_model_and_inputs_to_device(self, model, example_inputs, target_device):
1310*da0073e9SAndroid Build Coastguard Worker        # Deepcopy model before export to avoid modification to baseline model.
1311*da0073e9SAndroid Build Coastguard Worker        # To avoid OOM, the model is first moved to CPU. Both models are then moved to device.
1312*da0073e9SAndroid Build Coastguard Worker        model_device = next(model.parameters()).device
1313*da0073e9SAndroid Build Coastguard Worker        model.to("cpu")
1314*da0073e9SAndroid Build Coastguard Worker        model_copy = copy.deepcopy(model).to(target_device)
1315*da0073e9SAndroid Build Coastguard Worker        model.to(model_device)
1316*da0073e9SAndroid Build Coastguard Worker
1317*da0073e9SAndroid Build Coastguard Worker        target_device_example_inputs = tree_map_only(
1318*da0073e9SAndroid Build Coastguard Worker            torch.Tensor, lambda x: x.to(device=target_device), example_inputs
1319*da0073e9SAndroid Build Coastguard Worker        )
1320*da0073e9SAndroid Build Coastguard Worker
1321*da0073e9SAndroid Build Coastguard Worker        return model_copy, target_device_example_inputs
1322*da0073e9SAndroid Build Coastguard Worker
1323*da0073e9SAndroid Build Coastguard Worker    @classmethod
1324*da0073e9SAndroid Build Coastguard Worker    def _generate_onnx_model_directory(
1325*da0073e9SAndroid Build Coastguard Worker        cls, output_directory: str, compiler_name: str, model_name: str
1326*da0073e9SAndroid Build Coastguard Worker    ) -> pathlib.Path:
1327*da0073e9SAndroid Build Coastguard Worker        model_path = pathlib.Path(
1328*da0073e9SAndroid Build Coastguard Worker            output_directory,
1329*da0073e9SAndroid Build Coastguard Worker            ".onnx_models",
1330*da0073e9SAndroid Build Coastguard Worker            model_name,
1331*da0073e9SAndroid Build Coastguard Worker            compiler_name,
1332*da0073e9SAndroid Build Coastguard Worker        )
1333*da0073e9SAndroid Build Coastguard Worker        if model_path.exists() and model_path.is_dir():
1334*da0073e9SAndroid Build Coastguard Worker            shutil.rmtree(model_path)
1335*da0073e9SAndroid Build Coastguard Worker        model_path.mkdir(parents=True, exist_ok=True)
1336*da0073e9SAndroid Build Coastguard Worker        return model_path
1337*da0073e9SAndroid Build Coastguard Worker
1338*da0073e9SAndroid Build Coastguard Worker    @abc.abstractmethod
1339*da0073e9SAndroid Build Coastguard Worker    def format_pt_inputs(self, pt_inputs: Any) -> Sequence[torch.Tensor]:
1340*da0073e9SAndroid Build Coastguard Worker        ...
1341*da0073e9SAndroid Build Coastguard Worker
1342*da0073e9SAndroid Build Coastguard Worker    @abc.abstractmethod
1343*da0073e9SAndroid Build Coastguard Worker    def format_pt_outputs(self, pt_outputs: Any) -> Sequence[torch.Tensor]:
1344*da0073e9SAndroid Build Coastguard Worker        ...
1345*da0073e9SAndroid Build Coastguard Worker
1346*da0073e9SAndroid Build Coastguard Worker    def adapt_pt_inputs_to_onnx(self, pt_inputs) -> Mapping[str, np.ndarray]:
1347*da0073e9SAndroid Build Coastguard Worker        pt_inputs = self.format_pt_inputs(pt_inputs)
1348*da0073e9SAndroid Build Coastguard Worker        return {
1349*da0073e9SAndroid Build Coastguard Worker            ort_input.name: pt_input.cpu().numpy()
1350*da0073e9SAndroid Build Coastguard Worker            for ort_input, pt_input in zip(self.onnx_session.get_inputs(), pt_inputs)
1351*da0073e9SAndroid Build Coastguard Worker        }
1352*da0073e9SAndroid Build Coastguard Worker
1353*da0073e9SAndroid Build Coastguard Worker    def adapt_onnx_outputs_to_pt(self, onnx_outputs: List[np.ndarray]) -> Any:
1354*da0073e9SAndroid Build Coastguard Worker        pt_outputs = [
1355*da0073e9SAndroid Build Coastguard Worker            torch.from_numpy(onnx_output).to(current_device)
1356*da0073e9SAndroid Build Coastguard Worker            for onnx_output in onnx_outputs
1357*da0073e9SAndroid Build Coastguard Worker        ]
1358*da0073e9SAndroid Build Coastguard Worker        if len(pt_outputs) == 1:
1359*da0073e9SAndroid Build Coastguard Worker            return pt_outputs[0]
1360*da0073e9SAndroid Build Coastguard Worker        return pt_outputs
1361*da0073e9SAndroid Build Coastguard Worker
1362*da0073e9SAndroid Build Coastguard Worker    def _init_ort_session(self, model_path: str):
1363*da0073e9SAndroid Build Coastguard Worker        import onnxruntime
1364*da0073e9SAndroid Build Coastguard Worker
1365*da0073e9SAndroid Build Coastguard Worker        if current_device == "cpu":
1366*da0073e9SAndroid Build Coastguard Worker            ort_providers = ["CPUExecutionProvider"]
1367*da0073e9SAndroid Build Coastguard Worker        else:
1368*da0073e9SAndroid Build Coastguard Worker            # NOTE(bowbao): Reduce OOM by running ORT on another gpu.
1369*da0073e9SAndroid Build Coastguard Worker            # TODO(bowbao): This works to avoid OOM, but performance is surprisingly very bad.
1370*da0073e9SAndroid Build Coastguard Worker            cuda_provider_options = {
1371*da0073e9SAndroid Build Coastguard Worker                "device_id": 1 if torch.cuda.device_count() > 1 else 0,
1372*da0073e9SAndroid Build Coastguard Worker            }
1373*da0073e9SAndroid Build Coastguard Worker            ort_providers = [("CUDAExecutionProvider", cuda_provider_options)]
1374*da0073e9SAndroid Build Coastguard Worker        session_options = onnxruntime.SessionOptions()
1375*da0073e9SAndroid Build Coastguard Worker        session_options.log_severity_level = 3  # Error
1376*da0073e9SAndroid Build Coastguard Worker
1377*da0073e9SAndroid Build Coastguard Worker        ort_session = onnxruntime.InferenceSession(
1378*da0073e9SAndroid Build Coastguard Worker            self.model_path,
1379*da0073e9SAndroid Build Coastguard Worker            providers=ort_providers,
1380*da0073e9SAndroid Build Coastguard Worker            sess_options=session_options,
1381*da0073e9SAndroid Build Coastguard Worker        )
1382*da0073e9SAndroid Build Coastguard Worker        return ort_session
1383*da0073e9SAndroid Build Coastguard Worker
1384*da0073e9SAndroid Build Coastguard Worker    def is_cpu(self) -> bool:
1385*da0073e9SAndroid Build Coastguard Worker        return self.onnx_session.get_providers()[0] == "CPUExecutionProvider"
1386*da0073e9SAndroid Build Coastguard Worker
1387*da0073e9SAndroid Build Coastguard Worker    def cpu(self) -> Self:
1388*da0073e9SAndroid Build Coastguard Worker        self.onnx_session.set_providers(["CPUExecutionProvider"])
1389*da0073e9SAndroid Build Coastguard Worker        return self
1390*da0073e9SAndroid Build Coastguard Worker
1391*da0073e9SAndroid Build Coastguard Worker    def create_outputs(self, *example_outputs):
1392*da0073e9SAndroid Build Coastguard Worker        return tuple(torch.empty_like(x) for x in example_outputs)
1393*da0073e9SAndroid Build Coastguard Worker
1394*da0073e9SAndroid Build Coastguard Worker    def create_iobinding(self, pt_inputs, example_outputs):
1395*da0073e9SAndroid Build Coastguard Worker        pt_inputs = self.format_pt_inputs(pt_inputs)
1396*da0073e9SAndroid Build Coastguard Worker        example_outputs = self.format_pt_outputs(example_outputs)
1397*da0073e9SAndroid Build Coastguard Worker
1398*da0073e9SAndroid Build Coastguard Worker        iobinding = self.onnx_session.io_binding()
1399*da0073e9SAndroid Build Coastguard Worker        args = [arg.contiguous() for arg in pt_inputs]
1400*da0073e9SAndroid Build Coastguard Worker        for ort_input, arg in zip(self.onnx_session.get_inputs(), args):
1401*da0073e9SAndroid Build Coastguard Worker            # NOTE: Run ORT on another cuda device to reduce OOM.
1402*da0073e9SAndroid Build Coastguard Worker            if torch.cuda.device_count() > 1:
1403*da0073e9SAndroid Build Coastguard Worker                arg = arg.detach().to("cuda:1")
1404*da0073e9SAndroid Build Coastguard Worker            device = arg.device
1405*da0073e9SAndroid Build Coastguard Worker            iobinding.bind_input(
1406*da0073e9SAndroid Build Coastguard Worker                ort_input.name,
1407*da0073e9SAndroid Build Coastguard Worker                device.type,
1408*da0073e9SAndroid Build Coastguard Worker                device.index or 0,
1409*da0073e9SAndroid Build Coastguard Worker                self.TORCH_TO_NUMPY_DTYPE[arg.dtype],
1410*da0073e9SAndroid Build Coastguard Worker                arg.size(),
1411*da0073e9SAndroid Build Coastguard Worker                arg.data_ptr(),
1412*da0073e9SAndroid Build Coastguard Worker            )
1413*da0073e9SAndroid Build Coastguard Worker
1414*da0073e9SAndroid Build Coastguard Worker        outputs = self.create_outputs(*example_outputs)
1415*da0073e9SAndroid Build Coastguard Worker        for ort_output, output in zip(self.onnx_session.get_outputs(), outputs):
1416*da0073e9SAndroid Build Coastguard Worker            if torch.cuda.device_count() > 1:
1417*da0073e9SAndroid Build Coastguard Worker                output = output.detach().to("cuda:1")
1418*da0073e9SAndroid Build Coastguard Worker            device = output.device
1419*da0073e9SAndroid Build Coastguard Worker            iobinding.bind_output(
1420*da0073e9SAndroid Build Coastguard Worker                ort_output.name,
1421*da0073e9SAndroid Build Coastguard Worker                device.type,
1422*da0073e9SAndroid Build Coastguard Worker                device.index or 0,
1423*da0073e9SAndroid Build Coastguard Worker                self.TORCH_TO_NUMPY_DTYPE[output.dtype],
1424*da0073e9SAndroid Build Coastguard Worker                output.size(),
1425*da0073e9SAndroid Build Coastguard Worker                output.data_ptr(),
1426*da0073e9SAndroid Build Coastguard Worker            )
1427*da0073e9SAndroid Build Coastguard Worker        return iobinding, outputs
1428*da0073e9SAndroid Build Coastguard Worker
1429*da0073e9SAndroid Build Coastguard Worker    def run_with_iobinding(self, iobinding, outputs):
1430*da0073e9SAndroid Build Coastguard Worker        # 'outputs' are torch empty tensors binded to 'iobinding'.
1431*da0073e9SAndroid Build Coastguard Worker        self.onnx_session.run_with_iobinding(iobinding)
1432*da0073e9SAndroid Build Coastguard Worker        return outputs
1433*da0073e9SAndroid Build Coastguard Worker
1434*da0073e9SAndroid Build Coastguard Worker    def run_with_onnx_inputs(self, onnx_inputs):
1435*da0073e9SAndroid Build Coastguard Worker        return self.onnx_session.run(None, onnx_inputs)
1436*da0073e9SAndroid Build Coastguard Worker
1437*da0073e9SAndroid Build Coastguard Worker    @classmethod
1438*da0073e9SAndroid Build Coastguard Worker    def save_tensor_data(cls, numpy_tensor, output_path):
1439*da0073e9SAndroid Build Coastguard Worker        from onnx import numpy_helper
1440*da0073e9SAndroid Build Coastguard Worker
1441*da0073e9SAndroid Build Coastguard Worker        proto_tensor = numpy_helper.from_array(numpy_tensor)
1442*da0073e9SAndroid Build Coastguard Worker        with open(output_path, "wb") as f:
1443*da0073e9SAndroid Build Coastguard Worker            f.write(proto_tensor.SerializeToString())
1444*da0073e9SAndroid Build Coastguard Worker
1445*da0073e9SAndroid Build Coastguard Worker    def run_and_serialize_inputs_outputs(self, pt_inputs):
1446*da0073e9SAndroid Build Coastguard Worker        test_data_dir = self.model_dir / "test_data_set_0"
1447*da0073e9SAndroid Build Coastguard Worker        test_data_dir.mkdir(parents=True, exist_ok=True)
1448*da0073e9SAndroid Build Coastguard Worker
1449*da0073e9SAndroid Build Coastguard Worker        onnx_inputs = self.adapt_pt_inputs_to_onnx(pt_inputs)
1450*da0073e9SAndroid Build Coastguard Worker        for i, onnx_input in enumerate(onnx_inputs.values()):
1451*da0073e9SAndroid Build Coastguard Worker            self.save_tensor_data(onnx_input, str(test_data_dir / f"input_{i}.pb"))
1452*da0073e9SAndroid Build Coastguard Worker
1453*da0073e9SAndroid Build Coastguard Worker        onnx_outputs = self.run_with_onnx_inputs(onnx_inputs)
1454*da0073e9SAndroid Build Coastguard Worker
1455*da0073e9SAndroid Build Coastguard Worker        for i, onnx_output in enumerate(onnx_outputs):
1456*da0073e9SAndroid Build Coastguard Worker            self.save_tensor_data(onnx_output, str(test_data_dir / f"output_{i}.pb"))
1457*da0073e9SAndroid Build Coastguard Worker
1458*da0073e9SAndroid Build Coastguard Worker        return self.adapt_onnx_outputs_to_pt(onnx_outputs)
1459*da0073e9SAndroid Build Coastguard Worker
1460*da0073e9SAndroid Build Coastguard Worker    def run(self, pt_inputs):
1461*da0073e9SAndroid Build Coastguard Worker        # NOTE: For CUDA performance testing, use `run_with_iobinding` to exclude memory
1462*da0073e9SAndroid Build Coastguard Worker        # copying overhead for inputs/outputs between cpu and gpu.
1463*da0073e9SAndroid Build Coastguard Worker        # Otherwise perf number is inaccurate.
1464*da0073e9SAndroid Build Coastguard Worker        onnx_inputs = self.adapt_pt_inputs_to_onnx(pt_inputs)
1465*da0073e9SAndroid Build Coastguard Worker        onnx_outputs = self.run_with_onnx_inputs(onnx_inputs)
1466*da0073e9SAndroid Build Coastguard Worker        return self.adapt_onnx_outputs_to_pt(onnx_outputs)
1467*da0073e9SAndroid Build Coastguard Worker
1468*da0073e9SAndroid Build Coastguard Worker
1469*da0073e9SAndroid Build Coastguard Workerclass OnnxModelFromTorchScript(OnnxModel):
1470*da0073e9SAndroid Build Coastguard Worker    """TorchScript based onnx export. `torch.onnx.export`
1471*da0073e9SAndroid Build Coastguard Worker
1472*da0073e9SAndroid Build Coastguard Worker    TODO(bowbao):
1473*da0073e9SAndroid Build Coastguard Worker    * large model export failed.
1474*da0073e9SAndroid Build Coastguard Worker          Onnx Model is larger than 2GB, but exporter makes decision based pt model size, which is
1475*da0073e9SAndroid Build Coastguard Worker          smaller than 2GB.
1476*da0073e9SAndroid Build Coastguard Worker    * OOM on slightly larger model.
1477*da0073e9SAndroid Build Coastguard Worker          Both pt model and ort inference session are on gpu. Attempt has been made to move ORT to
1478*da0073e9SAndroid Build Coastguard Worker          cuda:1, however ORT perf drop significantly.
1479*da0073e9SAndroid Build Coastguard Worker          For now running everything with batch_size 1 set in launch script.
1480*da0073e9SAndroid Build Coastguard Worker    """
1481*da0073e9SAndroid Build Coastguard Worker
1482*da0073e9SAndroid Build Coastguard Worker    _COMPILER_NAME = "torchscript"
1483*da0073e9SAndroid Build Coastguard Worker
1484*da0073e9SAndroid Build Coastguard Worker    def __init__(
1485*da0073e9SAndroid Build Coastguard Worker        self, output_directory, model, example_inputs, dynamic_shapes: bool, **kwargs
1486*da0073e9SAndroid Build Coastguard Worker    ):
1487*da0073e9SAndroid Build Coastguard Worker        if dynamic_shapes:
1488*da0073e9SAndroid Build Coastguard Worker            raise NotImplementedError("NYI dynamic shapes for OnnxModelFromTorchScript")
1489*da0073e9SAndroid Build Coastguard Worker        super().__init__(
1490*da0073e9SAndroid Build Coastguard Worker            output_directory, model, example_inputs, dynamic_shapes, **kwargs
1491*da0073e9SAndroid Build Coastguard Worker        )
1492*da0073e9SAndroid Build Coastguard Worker        self._export(
1493*da0073e9SAndroid Build Coastguard Worker            model,
1494*da0073e9SAndroid Build Coastguard Worker            example_inputs,
1495*da0073e9SAndroid Build Coastguard Worker            self.model_path,
1496*da0073e9SAndroid Build Coastguard Worker            opset_version=17,
1497*da0073e9SAndroid Build Coastguard Worker            do_constant_folding=False,
1498*da0073e9SAndroid Build Coastguard Worker            verbose=False,
1499*da0073e9SAndroid Build Coastguard Worker        )
1500*da0073e9SAndroid Build Coastguard Worker        self.onnx_session = self._init_ort_session(self.model_path)
1501*da0073e9SAndroid Build Coastguard Worker
1502*da0073e9SAndroid Build Coastguard Worker    def _export(self, model, example_inputs, output_path: str, /, **kwargs) -> None:
1503*da0073e9SAndroid Build Coastguard Worker        if self.copy_before_export:
1504*da0073e9SAndroid Build Coastguard Worker            # Deepcopy model before export to avoid modification to baseline model.
1505*da0073e9SAndroid Build Coastguard Worker            model, example_inputs = self.deepcopy_model_and_inputs_to_device(
1506*da0073e9SAndroid Build Coastguard Worker                model, example_inputs, self._determine_deepcopy_target_device()
1507*da0073e9SAndroid Build Coastguard Worker            )
1508*da0073e9SAndroid Build Coastguard Worker
1509*da0073e9SAndroid Build Coastguard Worker        # Hack for huggingface models (kwargs only).
1510*da0073e9SAndroid Build Coastguard Worker        if isinstance(example_inputs, dict):
1511*da0073e9SAndroid Build Coastguard Worker
1512*da0073e9SAndroid Build Coastguard Worker            class WrapperModel(torch.nn.Module):
1513*da0073e9SAndroid Build Coastguard Worker                def __init__(self, model, keys):
1514*da0073e9SAndroid Build Coastguard Worker                    super().__init__()
1515*da0073e9SAndroid Build Coastguard Worker                    self.model = model
1516*da0073e9SAndroid Build Coastguard Worker                    self.keys = keys
1517*da0073e9SAndroid Build Coastguard Worker
1518*da0073e9SAndroid Build Coastguard Worker                def forward(self, *args):
1519*da0073e9SAndroid Build Coastguard Worker                    return self.model(**dict(zip(self.keys, args)))
1520*da0073e9SAndroid Build Coastguard Worker
1521*da0073e9SAndroid Build Coastguard Worker            model = WrapperModel(model, list(example_inputs.keys()))
1522*da0073e9SAndroid Build Coastguard Worker
1523*da0073e9SAndroid Build Coastguard Worker        torch.onnx.export(
1524*da0073e9SAndroid Build Coastguard Worker            model,
1525*da0073e9SAndroid Build Coastguard Worker            self.format_pt_inputs(example_inputs),
1526*da0073e9SAndroid Build Coastguard Worker            output_path,
1527*da0073e9SAndroid Build Coastguard Worker            **kwargs,
1528*da0073e9SAndroid Build Coastguard Worker        )
1529*da0073e9SAndroid Build Coastguard Worker
1530*da0073e9SAndroid Build Coastguard Worker    def format_pt_inputs(self, pt_inputs):
1531*da0073e9SAndroid Build Coastguard Worker        # NOTE(bowbao): For huggingface benchmark, pt_inputs are formatted as dictionary,
1532*da0073e9SAndroid Build Coastguard Worker        # and consumed like `model(**pt_inputs)`.
1533*da0073e9SAndroid Build Coastguard Worker        # For other benchmarks, pt_inputs are formatted as tuple and consumed
1534*da0073e9SAndroid Build Coastguard Worker        # like `model(*pt_inputs)`.
1535*da0073e9SAndroid Build Coastguard Worker        if isinstance(pt_inputs, dict):
1536*da0073e9SAndroid Build Coastguard Worker            pt_inputs = list(pt_inputs.values())
1537*da0073e9SAndroid Build Coastguard Worker        if isinstance(pt_inputs, torch.Tensor):
1538*da0073e9SAndroid Build Coastguard Worker            pt_inputs = (pt_inputs,)
1539*da0073e9SAndroid Build Coastguard Worker        return tuple(arg.contiguous() for arg in pt_inputs)
1540*da0073e9SAndroid Build Coastguard Worker
1541*da0073e9SAndroid Build Coastguard Worker    def format_pt_outputs(self, pt_outputs):
1542*da0073e9SAndroid Build Coastguard Worker        if isinstance(pt_outputs, torch.Tensor):
1543*da0073e9SAndroid Build Coastguard Worker            pt_outputs = (pt_outputs,)
1544*da0073e9SAndroid Build Coastguard Worker
1545*da0073e9SAndroid Build Coastguard Worker        pt_outputs = pytree.tree_leaves(pt_outputs)
1546*da0073e9SAndroid Build Coastguard Worker
1547*da0073e9SAndroid Build Coastguard Worker        # Hack for huggingface model outputs
1548*da0073e9SAndroid Build Coastguard Worker        try:
1549*da0073e9SAndroid Build Coastguard Worker            from transformers import modeling_outputs
1550*da0073e9SAndroid Build Coastguard Worker        except ImportError:
1551*da0073e9SAndroid Build Coastguard Worker            pass
1552*da0073e9SAndroid Build Coastguard Worker        else:
1553*da0073e9SAndroid Build Coastguard Worker
1554*da0073e9SAndroid Build Coastguard Worker            def _to_tuple(x):
1555*da0073e9SAndroid Build Coastguard Worker                if isinstance(x, modeling_outputs.ModelOutput):
1556*da0073e9SAndroid Build Coastguard Worker                    return x.to_tuple()
1557*da0073e9SAndroid Build Coastguard Worker                return x
1558*da0073e9SAndroid Build Coastguard Worker
1559*da0073e9SAndroid Build Coastguard Worker            pt_outputs = pytree.tree_map(_to_tuple, pt_outputs)
1560*da0073e9SAndroid Build Coastguard Worker            pt_outputs = pytree.tree_leaves(pt_outputs)
1561*da0073e9SAndroid Build Coastguard Worker
1562*da0073e9SAndroid Build Coastguard Worker        return pt_outputs
1563*da0073e9SAndroid Build Coastguard Worker
1564*da0073e9SAndroid Build Coastguard Worker
1565*da0073e9SAndroid Build Coastguard Workerclass OnnxModelFromDynamo(OnnxModel):
1566*da0073e9SAndroid Build Coastguard Worker    """Dynamo and Fx based export. `torch.onnx.dynamo_export`."""
1567*da0073e9SAndroid Build Coastguard Worker
1568*da0073e9SAndroid Build Coastguard Worker    _COMPILER_NAME = "dynamo"
1569*da0073e9SAndroid Build Coastguard Worker
1570*da0073e9SAndroid Build Coastguard Worker    def __init__(
1571*da0073e9SAndroid Build Coastguard Worker        self, output_directory, model, example_inputs, dynamic_shapes: bool, **kwargs
1572*da0073e9SAndroid Build Coastguard Worker    ):
1573*da0073e9SAndroid Build Coastguard Worker        super().__init__(
1574*da0073e9SAndroid Build Coastguard Worker            output_directory, model, example_inputs, dynamic_shapes, **kwargs
1575*da0073e9SAndroid Build Coastguard Worker        )
1576*da0073e9SAndroid Build Coastguard Worker        self._dynamic_shapes = dynamic_shapes
1577*da0073e9SAndroid Build Coastguard Worker        self._onnx_program = self._export(model, example_inputs, self.model_path)
1578*da0073e9SAndroid Build Coastguard Worker        # Clear the model proto to save memory.
1579*da0073e9SAndroid Build Coastguard Worker        # The model proto is saved to disk and no longer needed from `onnx_program`.
1580*da0073e9SAndroid Build Coastguard Worker        # `onnx_program` is kept for i/o adapter usage.
1581*da0073e9SAndroid Build Coastguard Worker        self._onnx_program.model_proto.Clear()
1582*da0073e9SAndroid Build Coastguard Worker        self.onnx_session = self._init_ort_session(self.model_path)
1583*da0073e9SAndroid Build Coastguard Worker
1584*da0073e9SAndroid Build Coastguard Worker    def _export(
1585*da0073e9SAndroid Build Coastguard Worker        self, model, example_inputs, output_path: str
1586*da0073e9SAndroid Build Coastguard Worker    ) -> torch.onnx.ONNXProgram:
1587*da0073e9SAndroid Build Coastguard Worker        if self.copy_before_export:
1588*da0073e9SAndroid Build Coastguard Worker            # Deepcopy model before export to avoid modification to baseline model.
1589*da0073e9SAndroid Build Coastguard Worker            model, example_inputs = self.deepcopy_model_and_inputs_to_device(
1590*da0073e9SAndroid Build Coastguard Worker                model, example_inputs, self._determine_deepcopy_target_device()
1591*da0073e9SAndroid Build Coastguard Worker            )
1592*da0073e9SAndroid Build Coastguard Worker
1593*da0073e9SAndroid Build Coastguard Worker        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
1594*da0073e9SAndroid Build Coastguard Worker        options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
1595*da0073e9SAndroid Build Coastguard Worker        onnx_program = torch.onnx.dynamo_export(
1596*da0073e9SAndroid Build Coastguard Worker            model, *example_args, **example_kwargs, export_options=options
1597*da0073e9SAndroid Build Coastguard Worker        )
1598*da0073e9SAndroid Build Coastguard Worker
1599*da0073e9SAndroid Build Coastguard Worker        onnx_program.save(output_path)
1600*da0073e9SAndroid Build Coastguard Worker        return onnx_program
1601*da0073e9SAndroid Build Coastguard Worker
1602*da0073e9SAndroid Build Coastguard Worker    def format_pt_inputs(self, pt_inputs):
1603*da0073e9SAndroid Build Coastguard Worker        pt_args, pt_kwargs = _normalize_bench_inputs(pt_inputs)
1604*da0073e9SAndroid Build Coastguard Worker        return self._onnx_program.adapt_torch_inputs_to_onnx(*pt_args, **pt_kwargs)
1605*da0073e9SAndroid Build Coastguard Worker
1606*da0073e9SAndroid Build Coastguard Worker    def format_pt_outputs(self, pt_outputs):
1607*da0073e9SAndroid Build Coastguard Worker        return self._onnx_program.adapt_torch_outputs_to_onnx(pt_outputs)
1608*da0073e9SAndroid Build Coastguard Worker
1609*da0073e9SAndroid Build Coastguard Worker
1610*da0073e9SAndroid Build Coastguard Workerclass OnnxModelFromDynamoAotInline(OnnxModelFromDynamo):
1611*da0073e9SAndroid Build Coastguard Worker    """Dynamo and Fx based export, with AOT inline post export. `torch.onnx.dynamo_export`."""
1612*da0073e9SAndroid Build Coastguard Worker
1613*da0073e9SAndroid Build Coastguard Worker    _COMPILER_NAME = "dynamo_aot_inline"
1614*da0073e9SAndroid Build Coastguard Worker
1615*da0073e9SAndroid Build Coastguard Worker    def _export(
1616*da0073e9SAndroid Build Coastguard Worker        self, model, example_inputs, output_path: str
1617*da0073e9SAndroid Build Coastguard Worker    ) -> torch.onnx.ONNXProgram:
1618*da0073e9SAndroid Build Coastguard Worker        if self.copy_before_export:
1619*da0073e9SAndroid Build Coastguard Worker            # Deepcopy model before export to avoid modification to baseline model.
1620*da0073e9SAndroid Build Coastguard Worker            model, example_inputs = self.deepcopy_model_and_inputs_to_device(
1621*da0073e9SAndroid Build Coastguard Worker                model, example_inputs, self._determine_deepcopy_target_device()
1622*da0073e9SAndroid Build Coastguard Worker            )
1623*da0073e9SAndroid Build Coastguard Worker
1624*da0073e9SAndroid Build Coastguard Worker        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
1625*da0073e9SAndroid Build Coastguard Worker        options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
1626*da0073e9SAndroid Build Coastguard Worker        onnx_program = torch.onnx.dynamo_export(
1627*da0073e9SAndroid Build Coastguard Worker            model, *example_args, **example_kwargs, export_options=options
1628*da0073e9SAndroid Build Coastguard Worker        )
1629*da0073e9SAndroid Build Coastguard Worker        # Apply AOT inline post export.
1630*da0073e9SAndroid Build Coastguard Worker        # Requires onnx >= 1.15
1631*da0073e9SAndroid Build Coastguard Worker        import onnx
1632*da0073e9SAndroid Build Coastguard Worker        import onnx.inliner
1633*da0073e9SAndroid Build Coastguard Worker
1634*da0073e9SAndroid Build Coastguard Worker        # Workaround for inliner not supporting with models larger than 2GB.
1635*da0073e9SAndroid Build Coastguard Worker        # Save model to disk first separating out external data,
1636*da0073e9SAndroid Build Coastguard Worker        # and load back without external data for inliner to work on.
1637*da0073e9SAndroid Build Coastguard Worker        model_proto = onnx_program.model_proto
1638*da0073e9SAndroid Build Coastguard Worker        onnx.save_model(model_proto, output_path, save_as_external_data=True)
1639*da0073e9SAndroid Build Coastguard Worker        model_proto = onnx.load(output_path, load_external_data=False)
1640*da0073e9SAndroid Build Coastguard Worker        model_proto = onnx.inliner.inline_local_functions(model_proto)
1641*da0073e9SAndroid Build Coastguard Worker        onnx.save_model(model_proto, output_path)
1642*da0073e9SAndroid Build Coastguard Worker        return onnx_program
1643*da0073e9SAndroid Build Coastguard Worker
1644*da0073e9SAndroid Build Coastguard Worker
1645*da0073e9SAndroid Build Coastguard Workerclass OnnxModelFromDynamoAotOptimize(OnnxModelFromDynamo):
1646*da0073e9SAndroid Build Coastguard Worker    """Dynamo and Fx based export, with AOT optimize post export. `torch.onnx.dynamo_export`."""
1647*da0073e9SAndroid Build Coastguard Worker
1648*da0073e9SAndroid Build Coastguard Worker    _COMPILER_NAME = "dynamo_aot_optimize"
1649*da0073e9SAndroid Build Coastguard Worker
1650*da0073e9SAndroid Build Coastguard Worker    def _export(
1651*da0073e9SAndroid Build Coastguard Worker        self, model, example_inputs, output_path: str
1652*da0073e9SAndroid Build Coastguard Worker    ) -> torch.onnx.ONNXProgram:
1653*da0073e9SAndroid Build Coastguard Worker        if self.copy_before_export:
1654*da0073e9SAndroid Build Coastguard Worker            # Deepcopy model before export to avoid modification to baseline model.
1655*da0073e9SAndroid Build Coastguard Worker            model, example_inputs = self.deepcopy_model_and_inputs_to_device(
1656*da0073e9SAndroid Build Coastguard Worker                model, example_inputs, self._determine_deepcopy_target_device()
1657*da0073e9SAndroid Build Coastguard Worker            )
1658*da0073e9SAndroid Build Coastguard Worker
1659*da0073e9SAndroid Build Coastguard Worker        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
1660*da0073e9SAndroid Build Coastguard Worker        options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
1661*da0073e9SAndroid Build Coastguard Worker        export_output = torch.onnx.dynamo_export(
1662*da0073e9SAndroid Build Coastguard Worker            model, *example_args, **example_kwargs, export_options=options
1663*da0073e9SAndroid Build Coastguard Worker        )
1664*da0073e9SAndroid Build Coastguard Worker
1665*da0073e9SAndroid Build Coastguard Worker        import onnx
1666*da0073e9SAndroid Build Coastguard Worker        from onnxscript.rewriter.onnxruntime import rewrite
1667*da0073e9SAndroid Build Coastguard Worker
1668*da0073e9SAndroid Build Coastguard Worker        model_proto = rewrite(export_output.model_proto)
1669*da0073e9SAndroid Build Coastguard Worker        onnx.save_model(
1670*da0073e9SAndroid Build Coastguard Worker            model_proto,
1671*da0073e9SAndroid Build Coastguard Worker            output_path,
1672*da0073e9SAndroid Build Coastguard Worker            save_as_external_data=True,
1673*da0073e9SAndroid Build Coastguard Worker            all_tensors_to_one_file=True,
1674*da0073e9SAndroid Build Coastguard Worker        )
1675*da0073e9SAndroid Build Coastguard Worker
1676*da0073e9SAndroid Build Coastguard Worker        return export_output
1677*da0073e9SAndroid Build Coastguard Worker
1678*da0073e9SAndroid Build Coastguard Worker
1679*da0073e9SAndroid Build Coastguard Workerclass _OnnxPatch:
1680*da0073e9SAndroid Build Coastguard Worker    @classmethod
1681*da0073e9SAndroid Build Coastguard Worker    def patch_non_tensor_outputs(cls, correct_result, new_result, fp64_outputs):
1682*da0073e9SAndroid Build Coastguard Worker        """Patch non-tensor outputs to make them comparable with the correct result.
1683*da0073e9SAndroid Build Coastguard Worker
1684*da0073e9SAndroid Build Coastguard Worker        ONNX model always returns a flat tuple of tensors, but the PyTorch model outputs
1685*da0073e9SAndroid Build Coastguard Worker        `correct_result` and `fp64_outputs` can be arbitrary types. This function normalizes
1686*da0073e9SAndroid Build Coastguard Worker        the outputs to make them comparable with the ONNX model output.
1687*da0073e9SAndroid Build Coastguard Worker        """
1688*da0073e9SAndroid Build Coastguard Worker        try:
1689*da0073e9SAndroid Build Coastguard Worker            from transformers import modeling_outputs
1690*da0073e9SAndroid Build Coastguard Worker        except ImportError:
1691*da0073e9SAndroid Build Coastguard Worker            has_transformers = False
1692*da0073e9SAndroid Build Coastguard Worker        else:
1693*da0073e9SAndroid Build Coastguard Worker            has_transformers = True
1694*da0073e9SAndroid Build Coastguard Worker
1695*da0073e9SAndroid Build Coastguard Worker        if has_transformers and isinstance(
1696*da0073e9SAndroid Build Coastguard Worker            correct_result, modeling_outputs.ModelOutput
1697*da0073e9SAndroid Build Coastguard Worker        ):
1698*da0073e9SAndroid Build Coastguard Worker            correct_result = correct_result.to_tuple()
1699*da0073e9SAndroid Build Coastguard Worker            fp64_outputs = fp64_outputs.to_tuple() if fp64_outputs is not None else None
1700*da0073e9SAndroid Build Coastguard Worker        elif type(correct_result).__name__ in (
1701*da0073e9SAndroid Build Coastguard Worker            "MaskedLMOutput",
1702*da0073e9SAndroid Build Coastguard Worker            "Seq2SeqLMOutput",
1703*da0073e9SAndroid Build Coastguard Worker            "CausalLMOutputWithCrossAttentions",
1704*da0073e9SAndroid Build Coastguard Worker            "LongformerMaskedLMOutput",
1705*da0073e9SAndroid Build Coastguard Worker            "Instances",
1706*da0073e9SAndroid Build Coastguard Worker            "SquashedNormal",
1707*da0073e9SAndroid Build Coastguard Worker            "Boxes",
1708*da0073e9SAndroid Build Coastguard Worker            "Normal",
1709*da0073e9SAndroid Build Coastguard Worker            "TanhTransform",
1710*da0073e9SAndroid Build Coastguard Worker            "Foo",
1711*da0073e9SAndroid Build Coastguard Worker            "Variable",
1712*da0073e9SAndroid Build Coastguard Worker        ):
1713*da0073e9SAndroid Build Coastguard Worker            # Copied from `same` function in `torch._dynamo.utils`
1714*da0073e9SAndroid Build Coastguard Worker            correct_result = [
1715*da0073e9SAndroid Build Coastguard Worker                value
1716*da0073e9SAndroid Build Coastguard Worker                for key in correct_result.__dict__.keys()
1717*da0073e9SAndroid Build Coastguard Worker                if (value := getattr(correct_result, key)) is not None
1718*da0073e9SAndroid Build Coastguard Worker            ]
1719*da0073e9SAndroid Build Coastguard Worker            fp64_outputs = (
1720*da0073e9SAndroid Build Coastguard Worker                [
1721*da0073e9SAndroid Build Coastguard Worker                    value
1722*da0073e9SAndroid Build Coastguard Worker                    for key in fp64_outputs.__dict__.keys()
1723*da0073e9SAndroid Build Coastguard Worker                    if (value := getattr(fp64_outputs, key)) is not None
1724*da0073e9SAndroid Build Coastguard Worker                ]
1725*da0073e9SAndroid Build Coastguard Worker                if fp64_outputs is not None
1726*da0073e9SAndroid Build Coastguard Worker                else None
1727*da0073e9SAndroid Build Coastguard Worker            )
1728*da0073e9SAndroid Build Coastguard Worker
1729*da0073e9SAndroid Build Coastguard Worker        # Flatten nested tuple of tensors, i.e. past_key_values
1730*da0073e9SAndroid Build Coastguard Worker        correct_result = pytree.tree_leaves(correct_result)
1731*da0073e9SAndroid Build Coastguard Worker        # Hack to put results from different runs on same device.
1732*da0073e9SAndroid Build Coastguard Worker        # This is needed for ONNX CPU fallback benchmark, where PyTorch eager is run on GPU.
1733*da0073e9SAndroid Build Coastguard Worker        # Assuming outputs from a single run are always on same device!
1734*da0073e9SAndroid Build Coastguard Worker        devices = [x.device for x in correct_result if isinstance(x, torch.Tensor)]
1735*da0073e9SAndroid Build Coastguard Worker        assert devices and all(
1736*da0073e9SAndroid Build Coastguard Worker            x == devices[0] for x in devices
1737*da0073e9SAndroid Build Coastguard Worker        ), "All tensors must be on same device!"
1738*da0073e9SAndroid Build Coastguard Worker        device = devices[0]
1739*da0073e9SAndroid Build Coastguard Worker        new_result = pytree.tree_leaves(new_result)
1740*da0073e9SAndroid Build Coastguard Worker        new_result = pytree.tree_map(
1741*da0073e9SAndroid Build Coastguard Worker            lambda x: x.to(device=device) if isinstance(x, torch.Tensor) else x,
1742*da0073e9SAndroid Build Coastguard Worker            new_result,
1743*da0073e9SAndroid Build Coastguard Worker        )
1744*da0073e9SAndroid Build Coastguard Worker        fp64_outputs = pytree.tree_leaves(fp64_outputs)
1745*da0073e9SAndroid Build Coastguard Worker
1746*da0073e9SAndroid Build Coastguard Worker        return correct_result, new_result, fp64_outputs
1747*da0073e9SAndroid Build Coastguard Worker
1748*da0073e9SAndroid Build Coastguard Worker
1749*da0073e9SAndroid Build Coastguard Worker@dataclasses.dataclass
1750*da0073e9SAndroid Build Coastguard Workerclass OnnxExportErrorRow:
1751*da0073e9SAndroid Build Coastguard Worker    device: str
1752*da0073e9SAndroid Build Coastguard Worker    model_name: str
1753*da0073e9SAndroid Build Coastguard Worker    batch_size: int
1754*da0073e9SAndroid Build Coastguard Worker    rule_id: Optional[str] = None
1755*da0073e9SAndroid Build Coastguard Worker    rule_name: Optional[str] = None
1756*da0073e9SAndroid Build Coastguard Worker    diagnostic_level: Optional[str] = None
1757*da0073e9SAndroid Build Coastguard Worker    diagnostic_message: Optional[str] = None
1758*da0073e9SAndroid Build Coastguard Worker    exception_type_name: Optional[str] = None
1759*da0073e9SAndroid Build Coastguard Worker    exception_message: Optional[str] = None
1760*da0073e9SAndroid Build Coastguard Worker
1761*da0073e9SAndroid Build Coastguard Worker    def __post_init__(self):
1762*da0073e9SAndroid Build Coastguard Worker        assert (
1763*da0073e9SAndroid Build Coastguard Worker            self.rule_id is not None
1764*da0073e9SAndroid Build Coastguard Worker            and self.rule_name is not None
1765*da0073e9SAndroid Build Coastguard Worker            and self.diagnostic_level is not None
1766*da0073e9SAndroid Build Coastguard Worker            and self.diagnostic_message is not None
1767*da0073e9SAndroid Build Coastguard Worker        ) or self.exception_type_name, (
1768*da0073e9SAndroid Build Coastguard Worker            "Either rule_id, rule_name, diagnostic_level and diagnostic_message "
1769*da0073e9SAndroid Build Coastguard Worker            "must be set or exception_type_name must be set"
1770*da0073e9SAndroid Build Coastguard Worker        )
1771*da0073e9SAndroid Build Coastguard Worker
1772*da0073e9SAndroid Build Coastguard Worker    @property
1773*da0073e9SAndroid Build Coastguard Worker    def headers(self) -> List[str]:
1774*da0073e9SAndroid Build Coastguard Worker        return [field.name for field in dataclasses.fields(self)]
1775*da0073e9SAndroid Build Coastguard Worker
1776*da0073e9SAndroid Build Coastguard Worker    @property
1777*da0073e9SAndroid Build Coastguard Worker    def row(self) -> List[str]:
1778*da0073e9SAndroid Build Coastguard Worker        return [getattr(self, field.name) for field in dataclasses.fields(self)]
1779*da0073e9SAndroid Build Coastguard Worker
1780*da0073e9SAndroid Build Coastguard Worker
1781*da0073e9SAndroid Build Coastguard Workerclass OnnxExportErrorParser:
1782*da0073e9SAndroid Build Coastguard Worker    def __init__(self, device: str, model_name: str, batch_size: int):
1783*da0073e9SAndroid Build Coastguard Worker        self.device = device
1784*da0073e9SAndroid Build Coastguard Worker        self.model_name = model_name
1785*da0073e9SAndroid Build Coastguard Worker        self.batch_size = batch_size
1786*da0073e9SAndroid Build Coastguard Worker
1787*da0073e9SAndroid Build Coastguard Worker    def _qualified_exception_class_name(self, exception: Exception) -> str:
1788*da0073e9SAndroid Build Coastguard Worker        if exception.__class__.__module__ == "builtins":
1789*da0073e9SAndroid Build Coastguard Worker            return exception.__class__.__name__
1790*da0073e9SAndroid Build Coastguard Worker        return f"{exception.__class__.__module__}.{exception.__class__.__name__}"
1791*da0073e9SAndroid Build Coastguard Worker
1792*da0073e9SAndroid Build Coastguard Worker    def parse_diagnostic_context(
1793*da0073e9SAndroid Build Coastguard Worker        self,
1794*da0073e9SAndroid Build Coastguard Worker        diagnostic_context: diagnostics.DiagnosticContext,
1795*da0073e9SAndroid Build Coastguard Worker    ) -> Generator[OnnxExportErrorRow, Any, Any]:
1796*da0073e9SAndroid Build Coastguard Worker        from torch.onnx._internal.fx import diagnostics
1797*da0073e9SAndroid Build Coastguard Worker
1798*da0073e9SAndroid Build Coastguard Worker        for diagnostic in diagnostic_context.diagnostics:
1799*da0073e9SAndroid Build Coastguard Worker            if diagnostic.level >= diagnostics.levels.ERROR:
1800*da0073e9SAndroid Build Coastguard Worker                yield OnnxExportErrorRow(
1801*da0073e9SAndroid Build Coastguard Worker                    device=self.device,
1802*da0073e9SAndroid Build Coastguard Worker                    model_name=self.model_name,
1803*da0073e9SAndroid Build Coastguard Worker                    batch_size=self.batch_size,
1804*da0073e9SAndroid Build Coastguard Worker                    rule_id=diagnostic.rule.id,
1805*da0073e9SAndroid Build Coastguard Worker                    rule_name=diagnostic.rule.name,
1806*da0073e9SAndroid Build Coastguard Worker                    diagnostic_level=diagnostic.level.name,
1807*da0073e9SAndroid Build Coastguard Worker                    diagnostic_message=diagnostic.message,
1808*da0073e9SAndroid Build Coastguard Worker                )
1809*da0073e9SAndroid Build Coastguard Worker
1810*da0073e9SAndroid Build Coastguard Worker    def parse_exception(self, exception: Exception) -> OnnxExportErrorRow:
1811*da0073e9SAndroid Build Coastguard Worker        return OnnxExportErrorRow(
1812*da0073e9SAndroid Build Coastguard Worker            device=self.device,
1813*da0073e9SAndroid Build Coastguard Worker            model_name=self.model_name,
1814*da0073e9SAndroid Build Coastguard Worker            batch_size=self.batch_size,
1815*da0073e9SAndroid Build Coastguard Worker            exception_type_name=self._qualified_exception_class_name(exception),
1816*da0073e9SAndroid Build Coastguard Worker            exception_message=str(exception),
1817*da0073e9SAndroid Build Coastguard Worker        )
1818*da0073e9SAndroid Build Coastguard Worker
1819*da0073e9SAndroid Build Coastguard Worker
1820*da0073e9SAndroid Build Coastguard Worker@dataclasses.dataclass
1821*da0073e9SAndroid Build Coastguard Workerclass OnnxContext:
1822*da0073e9SAndroid Build Coastguard Worker    onnx_model: Optional[OnnxModel] = None
1823*da0073e9SAndroid Build Coastguard Worker
1824*da0073e9SAndroid Build Coastguard Worker
1825*da0073e9SAndroid Build Coastguard Workerdef optimize_onnx_ctx(
1826*da0073e9SAndroid Build Coastguard Worker    output_directory: str,
1827*da0073e9SAndroid Build Coastguard Worker    onnx_model_cls: Type[OnnxModel],
1828*da0073e9SAndroid Build Coastguard Worker    run_n_iterations: Callable,
1829*da0073e9SAndroid Build Coastguard Worker    dynamic_shapes: bool = False,
1830*da0073e9SAndroid Build Coastguard Worker    copy_before_export: bool = False,
1831*da0073e9SAndroid Build Coastguard Worker) -> Callable:
1832*da0073e9SAndroid Build Coastguard Worker    # NOTE(bowbao): This function creates and returns the onnx version of 'run_n_iterations',
1833*da0073e9SAndroid Build Coastguard Worker    # which does the following:
1834*da0073e9SAndroid Build Coastguard Worker    #   1. Export and cache model.
1835*da0073e9SAndroid Build Coastguard Worker    #   2. Create iobinding for ORT.
1836*da0073e9SAndroid Build Coastguard Worker    #   3. Run ORT for n iterations.
1837*da0073e9SAndroid Build Coastguard Worker    # The cached model is stored in 'context' under the returned callable.
1838*da0073e9SAndroid Build Coastguard Worker    context = OnnxContext()
1839*da0073e9SAndroid Build Coastguard Worker    test_data_dumped = False
1840*da0073e9SAndroid Build Coastguard Worker
1841*da0073e9SAndroid Build Coastguard Worker    def run_n_iterations_onnx(model, inputs, n=2):
1842*da0073e9SAndroid Build Coastguard Worker        from torch.onnx._internal import exporter
1843*da0073e9SAndroid Build Coastguard Worker        from torch.onnx._internal.fx import diagnostics
1844*da0073e9SAndroid Build Coastguard Worker
1845*da0073e9SAndroid Build Coastguard Worker        # NOTE(bowbao): Capture all export & ort errors and diagnostics.
1846*da0073e9SAndroid Build Coastguard Worker        # Serialize to csv, to be parsed and summarized later by '._onnx/reporter.py'.
1847*da0073e9SAndroid Build Coastguard Worker        # TODO: Accuracy mismatch is not reported here in csv.
1848*da0073e9SAndroid Build Coastguard Worker        assert (
1849*da0073e9SAndroid Build Coastguard Worker            output_filename.find(".csv") > 0
1850*da0073e9SAndroid Build Coastguard Worker        ), f"expected output_filename to be a .csv, but got {output_filename}"
1851*da0073e9SAndroid Build Coastguard Worker        output_error_filename = output_filename[:-4] + "_export_error.csv"
1852*da0073e9SAndroid Build Coastguard Worker        parser = OnnxExportErrorParser(current_device, current_name, current_batch_size)
1853*da0073e9SAndroid Build Coastguard Worker        try:
1854*da0073e9SAndroid Build Coastguard Worker            nonlocal context
1855*da0073e9SAndroid Build Coastguard Worker            if context.onnx_model is None:
1856*da0073e9SAndroid Build Coastguard Worker                context.onnx_model = onnx_model_cls(
1857*da0073e9SAndroid Build Coastguard Worker                    output_directory,
1858*da0073e9SAndroid Build Coastguard Worker                    model,
1859*da0073e9SAndroid Build Coastguard Worker                    copy.deepcopy(inputs),
1860*da0073e9SAndroid Build Coastguard Worker                    dynamic_shapes=dynamic_shapes,
1861*da0073e9SAndroid Build Coastguard Worker                    copy_before_export=copy_before_export,
1862*da0073e9SAndroid Build Coastguard Worker                )
1863*da0073e9SAndroid Build Coastguard Worker            onnx_model = context.onnx_model
1864*da0073e9SAndroid Build Coastguard Worker
1865*da0073e9SAndroid Build Coastguard Worker            for _ in range(n):
1866*da0073e9SAndroid Build Coastguard Worker                nonlocal test_data_dumped
1867*da0073e9SAndroid Build Coastguard Worker                if not test_data_dumped:
1868*da0073e9SAndroid Build Coastguard Worker                    # Serializes inputs and outputs to .pb files for further offline analysis.
1869*da0073e9SAndroid Build Coastguard Worker                    # Due to this, this function is not and should not be used for perf measurement.
1870*da0073e9SAndroid Build Coastguard Worker                    outputs = onnx_model.run_and_serialize_inputs_outputs(inputs)
1871*da0073e9SAndroid Build Coastguard Worker                    test_data_dumped = True
1872*da0073e9SAndroid Build Coastguard Worker                else:
1873*da0073e9SAndroid Build Coastguard Worker                    outputs = onnx_model.run(inputs)
1874*da0073e9SAndroid Build Coastguard Worker            return outputs
1875*da0073e9SAndroid Build Coastguard Worker        except exporter.OnnxExporterError as e:
1876*da0073e9SAndroid Build Coastguard Worker            # `torch.onnx.dynamo_export` raises error that encloses diagnostics.
1877*da0073e9SAndroid Build Coastguard Worker            diagnostic_context = e.onnx_program.diagnostic_context
1878*da0073e9SAndroid Build Coastguard Worker            for parsed_error in parser.parse_diagnostic_context(diagnostic_context):
1879*da0073e9SAndroid Build Coastguard Worker                output_csv(
1880*da0073e9SAndroid Build Coastguard Worker                    output_error_filename, parsed_error.headers, parsed_error.row
1881*da0073e9SAndroid Build Coastguard Worker                )
1882*da0073e9SAndroid Build Coastguard Worker            if context.onnx_model is not None:
1883*da0073e9SAndroid Build Coastguard Worker                e.onnx_program.save_diagnostics(
1884*da0073e9SAndroid Build Coastguard Worker                    f"{context.onnx_model.model_dir}/"
1885*da0073e9SAndroid Build Coastguard Worker                    f"{current_onnx_compiler}_{current_name}_{current_device}.sarif"
1886*da0073e9SAndroid Build Coastguard Worker                )
1887*da0073e9SAndroid Build Coastguard Worker
1888*da0073e9SAndroid Build Coastguard Worker            # Check also the raw exception that caused export failure.
1889*da0073e9SAndroid Build Coastguard Worker            # Skip if it is already analyzed by diagnostics.
1890*da0073e9SAndroid Build Coastguard Worker            cause_of_exception = e.__cause__
1891*da0073e9SAndroid Build Coastguard Worker            if not isinstance(
1892*da0073e9SAndroid Build Coastguard Worker                cause_of_exception, diagnostics.RuntimeErrorWithDiagnostic
1893*da0073e9SAndroid Build Coastguard Worker            ):
1894*da0073e9SAndroid Build Coastguard Worker                parsed_error = parser.parse_exception(cause_of_exception)
1895*da0073e9SAndroid Build Coastguard Worker                output_csv(
1896*da0073e9SAndroid Build Coastguard Worker                    output_error_filename, parsed_error.headers, parsed_error.row
1897*da0073e9SAndroid Build Coastguard Worker                )
1898*da0073e9SAndroid Build Coastguard Worker            raise
1899*da0073e9SAndroid Build Coastguard Worker        except Exception as e:
1900*da0073e9SAndroid Build Coastguard Worker            # `torch.onnx.export` errors.
1901*da0073e9SAndroid Build Coastguard Worker            # ORT errors.
1902*da0073e9SAndroid Build Coastguard Worker            parsed_error = parser.parse_exception(e)
1903*da0073e9SAndroid Build Coastguard Worker            output_csv(output_error_filename, parsed_error.headers, parsed_error.row)
1904*da0073e9SAndroid Build Coastguard Worker            raise
1905*da0073e9SAndroid Build Coastguard Worker
1906*da0073e9SAndroid Build Coastguard Worker    run_n_iterations_onnx.context = context
1907*da0073e9SAndroid Build Coastguard Worker
1908*da0073e9SAndroid Build Coastguard Worker    return run_n_iterations_onnx
1909*da0073e9SAndroid Build Coastguard Worker
1910*da0073e9SAndroid Build Coastguard Worker
1911*da0073e9SAndroid Build Coastguard Workerdef read_batch_size_from_file(args, filename, model_name):
1912*da0073e9SAndroid Build Coastguard Worker    batch_size = None
1913*da0073e9SAndroid Build Coastguard Worker    if os.path.exists("benchmarks"):
1914*da0073e9SAndroid Build Coastguard Worker        filename = os.path.join("benchmarks", filename)
1915*da0073e9SAndroid Build Coastguard Worker    assert os.path.exists(filename), filename
1916*da0073e9SAndroid Build Coastguard Worker    with open(filename) as f:
1917*da0073e9SAndroid Build Coastguard Worker        lines = f.readlines()
1918*da0073e9SAndroid Build Coastguard Worker        lines = [i.split(",") for i in lines if len(i.strip()) > 0]
1919*da0073e9SAndroid Build Coastguard Worker        for val in lines:
1920*da0073e9SAndroid Build Coastguard Worker            cur_name, b = val
1921*da0073e9SAndroid Build Coastguard Worker            if model_name == cur_name:
1922*da0073e9SAndroid Build Coastguard Worker                batch_size = int(b)
1923*da0073e9SAndroid Build Coastguard Worker    if batch_size is None:
1924*da0073e9SAndroid Build Coastguard Worker        log.warning("Could not find batch size for %s", model_name)
1925*da0073e9SAndroid Build Coastguard Worker    elif batch_size == -1:
1926*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(
1927*da0073e9SAndroid Build Coastguard Worker            f"Batch size is unset for {model_name} in {args.batch_size_file}"
1928*da0073e9SAndroid Build Coastguard Worker        )
1929*da0073e9SAndroid Build Coastguard Worker    print(f"batch size: {batch_size}")
1930*da0073e9SAndroid Build Coastguard Worker    return batch_size
1931*da0073e9SAndroid Build Coastguard Worker
1932*da0073e9SAndroid Build Coastguard Worker
1933*da0073e9SAndroid Build Coastguard Workerclass TimeOutException(Exception):
1934*da0073e9SAndroid Build Coastguard Worker    pass
1935*da0073e9SAndroid Build Coastguard Worker
1936*da0073e9SAndroid Build Coastguard Worker
1937*da0073e9SAndroid Build Coastguard Workerdef alarm_handler(signum, frame):
1938*da0073e9SAndroid Build Coastguard Worker    raise TimeOutException
1939*da0073e9SAndroid Build Coastguard Worker
1940*da0073e9SAndroid Build Coastguard Worker
1941*da0073e9SAndroid Build Coastguard Workerdef exit_after(s):
1942*da0073e9SAndroid Build Coastguard Worker    """
1943*da0073e9SAndroid Build Coastguard Worker    Decorator to raise TimeoutException if the fn is taking more than s seconds
1944*da0073e9SAndroid Build Coastguard Worker    to run.
1945*da0073e9SAndroid Build Coastguard Worker    """
1946*da0073e9SAndroid Build Coastguard Worker
1947*da0073e9SAndroid Build Coastguard Worker    def outer(fn):
1948*da0073e9SAndroid Build Coastguard Worker        def inner(*args, **kwargs):
1949*da0073e9SAndroid Build Coastguard Worker            signal.signal(signal.SIGALRM, alarm_handler)
1950*da0073e9SAndroid Build Coastguard Worker            signal.alarm(s)
1951*da0073e9SAndroid Build Coastguard Worker            try:
1952*da0073e9SAndroid Build Coastguard Worker                result = fn(*args, **kwargs)
1953*da0073e9SAndroid Build Coastguard Worker            finally:
1954*da0073e9SAndroid Build Coastguard Worker                signal.alarm(0)
1955*da0073e9SAndroid Build Coastguard Worker            return result
1956*da0073e9SAndroid Build Coastguard Worker
1957*da0073e9SAndroid Build Coastguard Worker        return inner
1958*da0073e9SAndroid Build Coastguard Worker
1959*da0073e9SAndroid Build Coastguard Worker    return outer
1960*da0073e9SAndroid Build Coastguard Worker
1961*da0073e9SAndroid Build Coastguard Worker
1962*da0073e9SAndroid Build Coastguard Workerdef get_peak_memory():
1963*da0073e9SAndroid Build Coastguard Worker    return torch.cuda.max_memory_allocated() / 10**9
1964*da0073e9SAndroid Build Coastguard Worker
1965*da0073e9SAndroid Build Coastguard Worker
1966*da0073e9SAndroid Build Coastguard Workerdef null_experiment(args, model_iter_fn, model, example_inputs):
1967*da0073e9SAndroid Build Coastguard Worker    """
1968*da0073e9SAndroid Build Coastguard Worker    A no-op experiment useful for making sure TorchBenchark alone works properly.
1969*da0073e9SAndroid Build Coastguard Worker    """
1970*da0073e9SAndroid Build Coastguard Worker
1971*da0073e9SAndroid Build Coastguard Worker    return []
1972*da0073e9SAndroid Build Coastguard Worker
1973*da0073e9SAndroid Build Coastguard Worker
1974*da0073e9SAndroid Build Coastguard Workerdef cast_to(dtype, model, inputs):
1975*da0073e9SAndroid Build Coastguard Worker    # cast model and inputs to fp16
1976*da0073e9SAndroid Build Coastguard Worker    if dtype == torch.float16:
1977*da0073e9SAndroid Build Coastguard Worker        model = model.half()
1978*da0073e9SAndroid Build Coastguard Worker    else:
1979*da0073e9SAndroid Build Coastguard Worker        model = model.to(dtype)
1980*da0073e9SAndroid Build Coastguard Worker
1981*da0073e9SAndroid Build Coastguard Worker    inputs = tree_map(
1982*da0073e9SAndroid Build Coastguard Worker        lambda x: x.to(dtype)
1983*da0073e9SAndroid Build Coastguard Worker        if isinstance(x, torch.Tensor) and x.is_floating_point()
1984*da0073e9SAndroid Build Coastguard Worker        else x,
1985*da0073e9SAndroid Build Coastguard Worker        inputs,
1986*da0073e9SAndroid Build Coastguard Worker    )
1987*da0073e9SAndroid Build Coastguard Worker    return model, inputs
1988*da0073e9SAndroid Build Coastguard Worker
1989*da0073e9SAndroid Build Coastguard Worker
1990*da0073e9SAndroid Build Coastguard Workerdef cast_to_bf16(model, inputs):
1991*da0073e9SAndroid Build Coastguard Worker    return cast_to(torch.bfloat16, model, inputs)
1992*da0073e9SAndroid Build Coastguard Worker
1993*da0073e9SAndroid Build Coastguard Worker
1994*da0073e9SAndroid Build Coastguard Workerdef cast_to_fp16(model, inputs):
1995*da0073e9SAndroid Build Coastguard Worker    return cast_to(torch.float16, model, inputs)
1996*da0073e9SAndroid Build Coastguard Worker
1997*da0073e9SAndroid Build Coastguard Worker
1998*da0073e9SAndroid Build Coastguard Workerdef cast_to_fp64(model, inputs):
1999*da0073e9SAndroid Build Coastguard Worker    return cast_to(torch.float64, model, inputs)
2000*da0073e9SAndroid Build Coastguard Worker
2001*da0073e9SAndroid Build Coastguard Worker
2002*da0073e9SAndroid Build Coastguard Workerdef cast_to_fp32(model, inputs):
2003*da0073e9SAndroid Build Coastguard Worker    return cast_to(torch.float32, model, inputs)
2004*da0073e9SAndroid Build Coastguard Worker
2005*da0073e9SAndroid Build Coastguard Worker
2006*da0073e9SAndroid Build Coastguard Workerclass DummyGradScaler:
2007*da0073e9SAndroid Build Coastguard Worker    def scale(self, loss):
2008*da0073e9SAndroid Build Coastguard Worker        return loss
2009*da0073e9SAndroid Build Coastguard Worker
2010*da0073e9SAndroid Build Coastguard Worker
2011*da0073e9SAndroid Build Coastguard Workerdef get_dynamo_stats():
2012*da0073e9SAndroid Build Coastguard Worker    # TODO: consider deepcopy'ing the entire counters struct and
2013*da0073e9SAndroid Build Coastguard Worker    # adding a helper to do subtraction on it
2014*da0073e9SAndroid Build Coastguard Worker    return collections.Counter(
2015*da0073e9SAndroid Build Coastguard Worker        {
2016*da0073e9SAndroid Build Coastguard Worker            "calls_captured": torch._dynamo.utils.counters["stats"]["calls_captured"],
2017*da0073e9SAndroid Build Coastguard Worker            "unique_graphs": torch._dynamo.utils.counters["stats"]["unique_graphs"],
2018*da0073e9SAndroid Build Coastguard Worker            "graph_breaks": sum(torch._dynamo.utils.counters["graph_break"].values()),
2019*da0073e9SAndroid Build Coastguard Worker            # NB: The plus removes zero counts
2020*da0073e9SAndroid Build Coastguard Worker            "unique_graph_breaks": len(+torch._dynamo.utils.counters["graph_break"]),
2021*da0073e9SAndroid Build Coastguard Worker            "autograd_captures": torch._dynamo.utils.counters["compiled_autograd"][
2022*da0073e9SAndroid Build Coastguard Worker                "captures"
2023*da0073e9SAndroid Build Coastguard Worker            ],
2024*da0073e9SAndroid Build Coastguard Worker            "autograd_compiles": torch._dynamo.utils.counters["compiled_autograd"][
2025*da0073e9SAndroid Build Coastguard Worker                "compiles"
2026*da0073e9SAndroid Build Coastguard Worker            ],
2027*da0073e9SAndroid Build Coastguard Worker            "cudagraph_skips": torch._dynamo.utils.counters["inductor"][
2028*da0073e9SAndroid Build Coastguard Worker                "cudagraph_skips"
2029*da0073e9SAndroid Build Coastguard Worker            ],
2030*da0073e9SAndroid Build Coastguard Worker        }
2031*da0073e9SAndroid Build Coastguard Worker    )
2032*da0073e9SAndroid Build Coastguard Worker
2033*da0073e9SAndroid Build Coastguard Worker
2034*da0073e9SAndroid Build Coastguard Worker@contextmanager
2035*da0073e9SAndroid Build Coastguard Workerdef maybe_init_distributed(should_init_distributed, rank, world_size, port="6789"):
2036*da0073e9SAndroid Build Coastguard Worker    try:
2037*da0073e9SAndroid Build Coastguard Worker        if should_init_distributed:
2038*da0073e9SAndroid Build Coastguard Worker            torch.cuda.set_device(rank)
2039*da0073e9SAndroid Build Coastguard Worker            os.environ["MASTER_ADDR"] = "localhost"
2040*da0073e9SAndroid Build Coastguard Worker            os.environ["MASTER_PORT"] = port
2041*da0073e9SAndroid Build Coastguard Worker            torch.distributed.init_process_group(
2042*da0073e9SAndroid Build Coastguard Worker                "nccl", rank=rank, world_size=world_size
2043*da0073e9SAndroid Build Coastguard Worker            )
2044*da0073e9SAndroid Build Coastguard Worker        yield
2045*da0073e9SAndroid Build Coastguard Worker    finally:
2046*da0073e9SAndroid Build Coastguard Worker        if should_init_distributed:
2047*da0073e9SAndroid Build Coastguard Worker            torch.distributed.destroy_process_group()
2048*da0073e9SAndroid Build Coastguard Worker
2049*da0073e9SAndroid Build Coastguard Worker
2050*da0073e9SAndroid Build Coastguard Worker@contextmanager
2051*da0073e9SAndroid Build Coastguard Workerdef maybe_snapshot_memory(should_snapshot_memory, suffix):
2052*da0073e9SAndroid Build Coastguard Worker    # Enables Memory Snapshot tool for memory deep dives:
2053*da0073e9SAndroid Build Coastguard Worker    # https://pytorch.org/blog/understanding-gpu-memory-1/
2054*da0073e9SAndroid Build Coastguard Worker    try:
2055*da0073e9SAndroid Build Coastguard Worker        if should_snapshot_memory:
2056*da0073e9SAndroid Build Coastguard Worker            torch.cuda.memory._record_memory_history(max_entries=100000)
2057*da0073e9SAndroid Build Coastguard Worker        yield
2058*da0073e9SAndroid Build Coastguard Worker    finally:
2059*da0073e9SAndroid Build Coastguard Worker        if should_snapshot_memory:
2060*da0073e9SAndroid Build Coastguard Worker            try:
2061*da0073e9SAndroid Build Coastguard Worker                torch.cuda.memory._dump_snapshot(
2062*da0073e9SAndroid Build Coastguard Worker                    os.path.join(
2063*da0073e9SAndroid Build Coastguard Worker                        torch._dynamo.config.base_dir,
2064*da0073e9SAndroid Build Coastguard Worker                        f"{output_filename.rstrip('.csv')}_{suffix}.pickle",
2065*da0073e9SAndroid Build Coastguard Worker                    )
2066*da0073e9SAndroid Build Coastguard Worker                )
2067*da0073e9SAndroid Build Coastguard Worker            except Exception as e:
2068*da0073e9SAndroid Build Coastguard Worker                logging.error("Failed to save memory snapshot, %s", e)
2069*da0073e9SAndroid Build Coastguard Worker
2070*da0073e9SAndroid Build Coastguard Worker            torch.cuda.memory._record_memory_history(enabled=None)
2071*da0073e9SAndroid Build Coastguard Worker
2072*da0073e9SAndroid Build Coastguard Worker
2073*da0073e9SAndroid Build Coastguard Workerclass BenchmarkRunner:
2074*da0073e9SAndroid Build Coastguard Worker    def __init__(self):
2075*da0073e9SAndroid Build Coastguard Worker        self.model_iter_fn = None
2076*da0073e9SAndroid Build Coastguard Worker        self.grad_scaler = DummyGradScaler()
2077*da0073e9SAndroid Build Coastguard Worker        self.autocast = contextlib.nullcontext
2078*da0073e9SAndroid Build Coastguard Worker        self.autocast_arg = {}
2079*da0073e9SAndroid Build Coastguard Worker        self.optimizer = None
2080*da0073e9SAndroid Build Coastguard Worker        self._args = None
2081*da0073e9SAndroid Build Coastguard Worker
2082*da0073e9SAndroid Build Coastguard Worker    def setup_amp(self, current_device=None):
2083*da0073e9SAndroid Build Coastguard Worker        if self.args.only in self.fp32_only_models:
2084*da0073e9SAndroid Build Coastguard Worker            return
2085*da0073e9SAndroid Build Coastguard Worker
2086*da0073e9SAndroid Build Coastguard Worker        devices = [current_device] if current_device else self.args.devices
2087*da0073e9SAndroid Build Coastguard Worker        if self.args.amp:
2088*da0073e9SAndroid Build Coastguard Worker            # AMP training can lead to small loss values which can undeflow
2089*da0073e9SAndroid Build Coastguard Worker            # gradient values returning in zero gradients. To solve this
2090*da0073e9SAndroid Build Coastguard Worker            # problem, PyTorch introduces GradScaler. GradScaler is a stateful
2091*da0073e9SAndroid Build Coastguard Worker            # structure, that scales the loss values to prevent underflow. Loss
2092*da0073e9SAndroid Build Coastguard Worker            # values are big at the beginning of training (therefore not
2093*da0073e9SAndroid Build Coastguard Worker            # requiring scaling), while loss value tends to be small as network
2094*da0073e9SAndroid Build Coastguard Worker            # starts getting better (requiring scaling). GradScaler manages all
2095*da0073e9SAndroid Build Coastguard Worker            # of this fine tuning, checking the gradients are turning to inf,
2096*da0073e9SAndroid Build Coastguard Worker            # discarding such batches.
2097*da0073e9SAndroid Build Coastguard Worker
2098*da0073e9SAndroid Build Coastguard Worker            # Since we are not running a long iteration, default value of
2099*da0073e9SAndroid Build Coastguard Worker            # init_scale 65536 is going to turn all gradients to inf. Therefore,
2100*da0073e9SAndroid Build Coastguard Worker            # we just use a init_scale of 2.0 for benchmarking purpose.
2101*da0073e9SAndroid Build Coastguard Worker
2102*da0073e9SAndroid Build Coastguard Worker            # Disabling Gradscaler because
2103*da0073e9SAndroid Build Coastguard Worker            #  1) Benchmark setup runs 2 iterations of fwd-bwd. So, not useful.
2104*da0073e9SAndroid Build Coastguard Worker            #  2) Current setup shares grad_scaler for eager and dynamo model,
2105*da0073e9SAndroid Build Coastguard Worker            #  which is bad as Gradscaler has state and can adjust the scaling
2106*da0073e9SAndroid Build Coastguard Worker            #  factor between eager and dynamo run, making accuracy check
2107*da0073e9SAndroid Build Coastguard Worker            #  harder.
2108*da0073e9SAndroid Build Coastguard Worker            # self.grad_scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
2109*da0073e9SAndroid Build Coastguard Worker            self.autocast = functools.partial(
2110*da0073e9SAndroid Build Coastguard Worker                torch.amp.autocast, device_type=devices[0]
2111*da0073e9SAndroid Build Coastguard Worker            )
2112*da0073e9SAndroid Build Coastguard Worker            if self.args.amp_dtype:
2113*da0073e9SAndroid Build Coastguard Worker                amp_dtype = (
2114*da0073e9SAndroid Build Coastguard Worker                    torch.float16
2115*da0073e9SAndroid Build Coastguard Worker                    if self.args.amp_dtype == "float16"
2116*da0073e9SAndroid Build Coastguard Worker                    else torch.bfloat16
2117*da0073e9SAndroid Build Coastguard Worker                )
2118*da0073e9SAndroid Build Coastguard Worker                self.autocast_arg["dtype"] = amp_dtype
2119*da0073e9SAndroid Build Coastguard Worker
2120*da0073e9SAndroid Build Coastguard Worker    def init_optimizer(self, name, device, params):
2121*da0073e9SAndroid Build Coastguard Worker        if device == "cuda" and self.args.training and name not in CI_SKIP_OPTIMIZER:
2122*da0073e9SAndroid Build Coastguard Worker            if (name in CI_USE_SGD and self.args.ci) or name in BENCHMARK_USE_SGD:
2123*da0073e9SAndroid Build Coastguard Worker                self.optimizer = torch.optim.SGD(params, lr=0.01, foreach=True)
2124*da0073e9SAndroid Build Coastguard Worker                # Disable multi_tensor_sgd for benchmarking, there isn't a large performance benefit (~1%) to compiling
2125*da0073e9SAndroid Build Coastguard Worker                # this optimizer because it is a single foreach add, and increases compile time.
2126*da0073e9SAndroid Build Coastguard Worker                # After autotuning and fake tensor caching lands, we can enable, becuase the compile time impact will be lower.
2127*da0073e9SAndroid Build Coastguard Worker                # Fake Tensor caching: https://github.com/pytorch/pytorch/pull/113873
2128*da0073e9SAndroid Build Coastguard Worker                # Autotuning: https://github.com/pytorch/pytorch/issues/117447
2129*da0073e9SAndroid Build Coastguard Worker                self.optimizer.step = torch._dynamo.disable(self.optimizer.step)
2130*da0073e9SAndroid Build Coastguard Worker            else:
2131*da0073e9SAndroid Build Coastguard Worker                self.optimizer = torch.optim.Adam(
2132*da0073e9SAndroid Build Coastguard Worker                    params, lr=0.01, capturable=True, foreach=True
2133*da0073e9SAndroid Build Coastguard Worker                )
2134*da0073e9SAndroid Build Coastguard Worker        else:
2135*da0073e9SAndroid Build Coastguard Worker            self.optimizer = None
2136*da0073e9SAndroid Build Coastguard Worker
2137*da0073e9SAndroid Build Coastguard Worker    @property
2138*da0073e9SAndroid Build Coastguard Worker    def args(self):
2139*da0073e9SAndroid Build Coastguard Worker        return self._args
2140*da0073e9SAndroid Build Coastguard Worker
2141*da0073e9SAndroid Build Coastguard Worker    @args.setter
2142*da0073e9SAndroid Build Coastguard Worker    def args(self, args):
2143*da0073e9SAndroid Build Coastguard Worker        self._args = args
2144*da0073e9SAndroid Build Coastguard Worker
2145*da0073e9SAndroid Build Coastguard Worker    @property
2146*da0073e9SAndroid Build Coastguard Worker    def skip_models(self):
2147*da0073e9SAndroid Build Coastguard Worker        return set()
2148*da0073e9SAndroid Build Coastguard Worker
2149*da0073e9SAndroid Build Coastguard Worker    @property
2150*da0073e9SAndroid Build Coastguard Worker    def skip_models_for_cuda(self):
2151*da0073e9SAndroid Build Coastguard Worker        return set()
2152*da0073e9SAndroid Build Coastguard Worker
2153*da0073e9SAndroid Build Coastguard Worker    @property
2154*da0073e9SAndroid Build Coastguard Worker    def skip_models_for_cpu(self):
2155*da0073e9SAndroid Build Coastguard Worker        return set()
2156*da0073e9SAndroid Build Coastguard Worker
2157*da0073e9SAndroid Build Coastguard Worker    @property
2158*da0073e9SAndroid Build Coastguard Worker    def skip_models_for_freezing(self):
2159*da0073e9SAndroid Build Coastguard Worker        return set()
2160*da0073e9SAndroid Build Coastguard Worker
2161*da0073e9SAndroid Build Coastguard Worker    @property
2162*da0073e9SAndroid Build Coastguard Worker    def slow_models(self):
2163*da0073e9SAndroid Build Coastguard Worker        return set()
2164*da0073e9SAndroid Build Coastguard Worker
2165*da0073e9SAndroid Build Coastguard Worker    @property
2166*da0073e9SAndroid Build Coastguard Worker    def very_slow_models(self):
2167*da0073e9SAndroid Build Coastguard Worker        return set()
2168*da0073e9SAndroid Build Coastguard Worker
2169*da0073e9SAndroid Build Coastguard Worker    @property
2170*da0073e9SAndroid Build Coastguard Worker    def non_deterministic_models(self):
2171*da0073e9SAndroid Build Coastguard Worker        return set()
2172*da0073e9SAndroid Build Coastguard Worker
2173*da0073e9SAndroid Build Coastguard Worker    @property
2174*da0073e9SAndroid Build Coastguard Worker    def fp32_only_models(self):
2175*da0073e9SAndroid Build Coastguard Worker        return set()
2176*da0073e9SAndroid Build Coastguard Worker
2177*da0073e9SAndroid Build Coastguard Worker    @property
2178*da0073e9SAndroid Build Coastguard Worker    def force_amp_for_fp16_bf16_models(self):
2179*da0073e9SAndroid Build Coastguard Worker        return set()
2180*da0073e9SAndroid Build Coastguard Worker
2181*da0073e9SAndroid Build Coastguard Worker    @property
2182*da0073e9SAndroid Build Coastguard Worker    def force_fp16_for_bf16_models(self):
2183*da0073e9SAndroid Build Coastguard Worker        return set()
2184*da0073e9SAndroid Build Coastguard Worker
2185*da0073e9SAndroid Build Coastguard Worker    @property
2186*da0073e9SAndroid Build Coastguard Worker    def skip_not_suitable_for_training_models(self):
2187*da0073e9SAndroid Build Coastguard Worker        return set()
2188*da0073e9SAndroid Build Coastguard Worker
2189*da0073e9SAndroid Build Coastguard Worker    @property
2190*da0073e9SAndroid Build Coastguard Worker    def failing_torchinductor_models(self):
2191*da0073e9SAndroid Build Coastguard Worker        return set()
2192*da0073e9SAndroid Build Coastguard Worker
2193*da0073e9SAndroid Build Coastguard Worker    @property
2194*da0073e9SAndroid Build Coastguard Worker    def failing_fx2trt_models(self):
2195*da0073e9SAndroid Build Coastguard Worker        return set()
2196*da0073e9SAndroid Build Coastguard Worker
2197*da0073e9SAndroid Build Coastguard Worker    @property
2198*da0073e9SAndroid Build Coastguard Worker    def skip_accuracy_checks_large_models_dashboard(self):
2199*da0073e9SAndroid Build Coastguard Worker        return set()
2200*da0073e9SAndroid Build Coastguard Worker
2201*da0073e9SAndroid Build Coastguard Worker    @property
2202*da0073e9SAndroid Build Coastguard Worker    def skip_accuracy_check_as_eager_non_deterministic(self):
2203*da0073e9SAndroid Build Coastguard Worker        return set()
2204*da0073e9SAndroid Build Coastguard Worker
2205*da0073e9SAndroid Build Coastguard Worker    @property
2206*da0073e9SAndroid Build Coastguard Worker    def skip_multiprocess_models(self):
2207*da0073e9SAndroid Build Coastguard Worker        return set()
2208*da0073e9SAndroid Build Coastguard Worker
2209*da0073e9SAndroid Build Coastguard Worker    @property
2210*da0073e9SAndroid Build Coastguard Worker    def skip_models_due_to_control_flow(self):
2211*da0073e9SAndroid Build Coastguard Worker        return set()
2212*da0073e9SAndroid Build Coastguard Worker
2213*da0073e9SAndroid Build Coastguard Worker    @property
2214*da0073e9SAndroid Build Coastguard Worker    def guard_on_nn_module_models(self):
2215*da0073e9SAndroid Build Coastguard Worker        return set()
2216*da0073e9SAndroid Build Coastguard Worker
2217*da0073e9SAndroid Build Coastguard Worker    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
2218*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
2219*da0073e9SAndroid Build Coastguard Worker
2220*da0073e9SAndroid Build Coastguard Worker    @property
2221*da0073e9SAndroid Build Coastguard Worker    def equal_nan(self):
2222*da0073e9SAndroid Build Coastguard Worker        equal_nan = True
2223*da0073e9SAndroid Build Coastguard Worker        if self.args.float32:
2224*da0073e9SAndroid Build Coastguard Worker            equal_nan = False
2225*da0073e9SAndroid Build Coastguard Worker        return equal_nan
2226*da0073e9SAndroid Build Coastguard Worker
2227*da0073e9SAndroid Build Coastguard Worker    def iter_models(self, args):
2228*da0073e9SAndroid Build Coastguard Worker        for model_name in self.iter_model_names(args):
2229*da0073e9SAndroid Build Coastguard Worker            for device in args.devices:
2230*da0073e9SAndroid Build Coastguard Worker                try:
2231*da0073e9SAndroid Build Coastguard Worker                    yield self.load_model(
2232*da0073e9SAndroid Build Coastguard Worker                        device,
2233*da0073e9SAndroid Build Coastguard Worker                        model_name,
2234*da0073e9SAndroid Build Coastguard Worker                        batch_size=args.batch_size,
2235*da0073e9SAndroid Build Coastguard Worker                    )
2236*da0073e9SAndroid Build Coastguard Worker                except NotImplementedError:
2237*da0073e9SAndroid Build Coastguard Worker                    continue  # bad benchmark implementation
2238*da0073e9SAndroid Build Coastguard Worker
2239*da0073e9SAndroid Build Coastguard Worker    def deepcopy_model(self, model):
2240*da0073e9SAndroid Build Coastguard Worker        return copy.deepcopy(model)
2241*da0073e9SAndroid Build Coastguard Worker
2242*da0073e9SAndroid Build Coastguard Worker    def cast_based_on_args(self, model, example_inputs):
2243*da0073e9SAndroid Build Coastguard Worker        if self.args.float32 or self.args.only in self.fp32_only_models:
2244*da0073e9SAndroid Build Coastguard Worker            if not self.args.float32:
2245*da0073e9SAndroid Build Coastguard Worker                log.warning("Model %s supports float32 only", self.args.only)
2246*da0073e9SAndroid Build Coastguard Worker            model, example_inputs = cast_to_fp32(model, example_inputs)
2247*da0073e9SAndroid Build Coastguard Worker        elif self.args.float16:
2248*da0073e9SAndroid Build Coastguard Worker            if self.args.only in self.force_amp_for_fp16_bf16_models:
2249*da0073e9SAndroid Build Coastguard Worker                log.warning(
2250*da0073e9SAndroid Build Coastguard Worker                    "Model %s does not support float16, running with amp instead",
2251*da0073e9SAndroid Build Coastguard Worker                    self.args.only,
2252*da0073e9SAndroid Build Coastguard Worker                )
2253*da0073e9SAndroid Build Coastguard Worker                self.args.amp = True
2254*da0073e9SAndroid Build Coastguard Worker                self.setup_amp()
2255*da0073e9SAndroid Build Coastguard Worker            else:
2256*da0073e9SAndroid Build Coastguard Worker                model, example_inputs = cast_to_fp16(model, example_inputs)
2257*da0073e9SAndroid Build Coastguard Worker        elif self.args.bfloat16:
2258*da0073e9SAndroid Build Coastguard Worker            if self.args.only in self.force_amp_for_fp16_bf16_models:
2259*da0073e9SAndroid Build Coastguard Worker                log.warning(
2260*da0073e9SAndroid Build Coastguard Worker                    "Model %s does not support bfloat16, running with amp instead",
2261*da0073e9SAndroid Build Coastguard Worker                    self.args.only,
2262*da0073e9SAndroid Build Coastguard Worker                )
2263*da0073e9SAndroid Build Coastguard Worker                self.args.amp = True
2264*da0073e9SAndroid Build Coastguard Worker                self.setup_amp()
2265*da0073e9SAndroid Build Coastguard Worker            elif self.args.only in self.force_fp16_for_bf16_models:
2266*da0073e9SAndroid Build Coastguard Worker                log.warning(
2267*da0073e9SAndroid Build Coastguard Worker                    "Model %s does not support bfloat16, running with float16 instead",
2268*da0073e9SAndroid Build Coastguard Worker                    self.args.only,
2269*da0073e9SAndroid Build Coastguard Worker                )
2270*da0073e9SAndroid Build Coastguard Worker                model, example_inputs = cast_to_fp16(model, example_inputs)
2271*da0073e9SAndroid Build Coastguard Worker            else:
2272*da0073e9SAndroid Build Coastguard Worker                model, example_inputs = cast_to_bf16(model, example_inputs)
2273*da0073e9SAndroid Build Coastguard Worker
2274*da0073e9SAndroid Build Coastguard Worker        return model, example_inputs
2275*da0073e9SAndroid Build Coastguard Worker
2276*da0073e9SAndroid Build Coastguard Worker    def validate_model(self, model, example_inputs):
2277*da0073e9SAndroid Build Coastguard Worker        """
2278*da0073e9SAndroid Build Coastguard Worker        Runs the eager model with example inputs to ensure that eager passes.
2279*da0073e9SAndroid Build Coastguard Worker        """
2280*da0073e9SAndroid Build Coastguard Worker        model = self.deepcopy_model(model)
2281*da0073e9SAndroid Build Coastguard Worker        example_inputs = clone_inputs(example_inputs)
2282*da0073e9SAndroid Build Coastguard Worker        model, example_inputs = self.cast_based_on_args(model, example_inputs)
2283*da0073e9SAndroid Build Coastguard Worker        try:
2284*da0073e9SAndroid Build Coastguard Worker            self.model_iter_fn(model, example_inputs)
2285*da0073e9SAndroid Build Coastguard Worker        except Exception as e:
2286*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("Eager run failed") from e
2287*da0073e9SAndroid Build Coastguard Worker
2288*da0073e9SAndroid Build Coastguard Worker    def maybe_cast(self, model, example_inputs):
2289*da0073e9SAndroid Build Coastguard Worker        model, example_inputs = self.cast_based_on_args(model, example_inputs)
2290*da0073e9SAndroid Build Coastguard Worker        return model, example_inputs
2291*da0073e9SAndroid Build Coastguard Worker
2292*da0073e9SAndroid Build Coastguard Worker    def decay_batch_exp(self, batch_size, factor=0.5, divisor=2):
2293*da0073e9SAndroid Build Coastguard Worker        out_batch_size = batch_size * factor
2294*da0073e9SAndroid Build Coastguard Worker        if out_batch_size > divisor:
2295*da0073e9SAndroid Build Coastguard Worker            out_batch_size = (out_batch_size + 1) // divisor * divisor
2296*da0073e9SAndroid Build Coastguard Worker        else:
2297*da0073e9SAndroid Build Coastguard Worker            out_batch_size = batch_size - 1
2298*da0073e9SAndroid Build Coastguard Worker        return max(0, int(out_batch_size))
2299*da0073e9SAndroid Build Coastguard Worker
2300*da0073e9SAndroid Build Coastguard Worker    def batch_size_finder(self, device, model_name, initial_batch_size=1024):
2301*da0073e9SAndroid Build Coastguard Worker        batch_size = initial_batch_size
2302*da0073e9SAndroid Build Coastguard Worker        while batch_size >= 1:
2303*da0073e9SAndroid Build Coastguard Worker            empty_gpu_cache(current_device)
2304*da0073e9SAndroid Build Coastguard Worker            try:
2305*da0073e9SAndroid Build Coastguard Worker                device, name, model, example_inputs, _ = self.load_model(
2306*da0073e9SAndroid Build Coastguard Worker                    device,
2307*da0073e9SAndroid Build Coastguard Worker                    model_name,
2308*da0073e9SAndroid Build Coastguard Worker                    batch_size,
2309*da0073e9SAndroid Build Coastguard Worker                )
2310*da0073e9SAndroid Build Coastguard Worker                self.model_iter_fn(model, example_inputs)
2311*da0073e9SAndroid Build Coastguard Worker                return batch_size
2312*da0073e9SAndroid Build Coastguard Worker            except RuntimeError as e:
2313*da0073e9SAndroid Build Coastguard Worker                error_str = str(e)
2314*da0073e9SAndroid Build Coastguard Worker                if "channels_last" in error_str:
2315*da0073e9SAndroid Build Coastguard Worker                    break
2316*da0073e9SAndroid Build Coastguard Worker            batch_size = self.decay_batch_exp(batch_size)
2317*da0073e9SAndroid Build Coastguard Worker        return 1
2318*da0073e9SAndroid Build Coastguard Worker
2319*da0073e9SAndroid Build Coastguard Worker    def run_n_iterations(self, mod, inputs):
2320*da0073e9SAndroid Build Coastguard Worker        n = self.args.iterations
2321*da0073e9SAndroid Build Coastguard Worker        for _ in range(n - 1):
2322*da0073e9SAndroid Build Coastguard Worker            self.model_iter_fn(mod, inputs, collect_outputs=False)
2323*da0073e9SAndroid Build Coastguard Worker        return self.model_iter_fn(mod, inputs, collect_outputs=True)
2324*da0073e9SAndroid Build Coastguard Worker
2325*da0073e9SAndroid Build Coastguard Worker    @torch._disable_dynamo(recursive=True)
2326*da0073e9SAndroid Build Coastguard Worker    def optimizer_zero_grad(self, mod):
2327*da0073e9SAndroid Build Coastguard Worker        if self.optimizer is not None:
2328*da0073e9SAndroid Build Coastguard Worker            self.optimizer.zero_grad(True)
2329*da0073e9SAndroid Build Coastguard Worker        else:
2330*da0073e9SAndroid Build Coastguard Worker            mod.zero_grad(True)
2331*da0073e9SAndroid Build Coastguard Worker
2332*da0073e9SAndroid Build Coastguard Worker    def optimizer_step(self):
2333*da0073e9SAndroid Build Coastguard Worker        if self.optimizer is not None:
2334*da0073e9SAndroid Build Coastguard Worker            self.optimizer.step()
2335*da0073e9SAndroid Build Coastguard Worker
2336*da0073e9SAndroid Build Coastguard Worker    def get_benchmark_indices(self, length):
2337*da0073e9SAndroid Build Coastguard Worker        start = self._args.partition_id * (length // self._args.total_partitions)
2338*da0073e9SAndroid Build Coastguard Worker        end = (
2339*da0073e9SAndroid Build Coastguard Worker            (self._args.partition_id + 1) * (length // self._args.total_partitions)
2340*da0073e9SAndroid Build Coastguard Worker            if self._args.partition_id < self._args.total_partitions - 1
2341*da0073e9SAndroid Build Coastguard Worker            else length
2342*da0073e9SAndroid Build Coastguard Worker        )
2343*da0073e9SAndroid Build Coastguard Worker        return start, end
2344*da0073e9SAndroid Build Coastguard Worker
2345*da0073e9SAndroid Build Coastguard Worker    def get_fsdp_auto_wrap_policy(self, model_name: str):
2346*da0073e9SAndroid Build Coastguard Worker        from diffusers.models.transformer_2d import Transformer2DModel
2347*da0073e9SAndroid Build Coastguard Worker        from torchbenchmark.models.nanogpt.model import Block
2348*da0073e9SAndroid Build Coastguard Worker        from transformers.models.llama.modeling_llama import LlamaDecoderLayer
2349*da0073e9SAndroid Build Coastguard Worker
2350*da0073e9SAndroid Build Coastguard Worker        from transformers.models.t5.modeling_t5 import T5Block
2351*da0073e9SAndroid Build Coastguard Worker        from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer
2352*da0073e9SAndroid Build Coastguard Worker
2353*da0073e9SAndroid Build Coastguard Worker        from torch.distributed.fsdp.wrap import (
2354*da0073e9SAndroid Build Coastguard Worker            ModuleWrapPolicy,
2355*da0073e9SAndroid Build Coastguard Worker            size_based_auto_wrap_policy,
2356*da0073e9SAndroid Build Coastguard Worker        )
2357*da0073e9SAndroid Build Coastguard Worker
2358*da0073e9SAndroid Build Coastguard Worker        # handcrafted wrap policy
2359*da0073e9SAndroid Build Coastguard Worker        MODEL_FSDP_WRAP = {
2360*da0073e9SAndroid Build Coastguard Worker            "stable_diffusion_unet": (Transformer2DModel,),
2361*da0073e9SAndroid Build Coastguard Worker            "hf_T5": (T5Block,),
2362*da0073e9SAndroid Build Coastguard Worker            "hf_T5_base": (T5Block,),
2363*da0073e9SAndroid Build Coastguard Worker            "hf_T5_large": (T5Block,),
2364*da0073e9SAndroid Build Coastguard Worker            "hf_Whisper": (WhisperEncoderLayer,),
2365*da0073e9SAndroid Build Coastguard Worker            "llama_v2_7b_16h": (LlamaDecoderLayer,),
2366*da0073e9SAndroid Build Coastguard Worker            "nanogpt": (Block,),
2367*da0073e9SAndroid Build Coastguard Worker        }
2368*da0073e9SAndroid Build Coastguard Worker
2369*da0073e9SAndroid Build Coastguard Worker        if model_name not in MODEL_FSDP_WRAP:
2370*da0073e9SAndroid Build Coastguard Worker            # default to using wrap policy based on module size
2371*da0073e9SAndroid Build Coastguard Worker            return functools.partial(
2372*da0073e9SAndroid Build Coastguard Worker                size_based_auto_wrap_policy, recurse=True, min_num_params=int(1e5)
2373*da0073e9SAndroid Build Coastguard Worker            )
2374*da0073e9SAndroid Build Coastguard Worker
2375*da0073e9SAndroid Build Coastguard Worker        return ModuleWrapPolicy(MODEL_FSDP_WRAP[model_name])
2376*da0073e9SAndroid Build Coastguard Worker
2377*da0073e9SAndroid Build Coastguard Worker    def deepcopy_and_maybe_parallelize(self, model):
2378*da0073e9SAndroid Build Coastguard Worker        model = self.deepcopy_model(model)
2379*da0073e9SAndroid Build Coastguard Worker        if self.args.ddp:
2380*da0073e9SAndroid Build Coastguard Worker            assert (
2381*da0073e9SAndroid Build Coastguard Worker                torch.distributed.is_available()
2382*da0073e9SAndroid Build Coastguard Worker            ), "Can't use DDP without a distributed enabled build"
2383*da0073e9SAndroid Build Coastguard Worker            from torch.nn.parallel import DistributedDataParallel as DDP
2384*da0073e9SAndroid Build Coastguard Worker
2385*da0073e9SAndroid Build Coastguard Worker            model = DDP(model, find_unused_parameters=True)
2386*da0073e9SAndroid Build Coastguard Worker        elif self.args.fsdp:
2387*da0073e9SAndroid Build Coastguard Worker            assert (
2388*da0073e9SAndroid Build Coastguard Worker                torch.distributed.is_available()
2389*da0073e9SAndroid Build Coastguard Worker            ), "Can't use FSDP without a distributed enabled build"
2390*da0073e9SAndroid Build Coastguard Worker            from torch.distributed.fsdp import (
2391*da0073e9SAndroid Build Coastguard Worker                FullyShardedDataParallel as FSDP,
2392*da0073e9SAndroid Build Coastguard Worker                MixedPrecision,
2393*da0073e9SAndroid Build Coastguard Worker            )
2394*da0073e9SAndroid Build Coastguard Worker
2395*da0073e9SAndroid Build Coastguard Worker            if self.args.float16:
2396*da0073e9SAndroid Build Coastguard Worker                dtype = torch.float16
2397*da0073e9SAndroid Build Coastguard Worker            elif self.args.bfloat16:
2398*da0073e9SAndroid Build Coastguard Worker                dtype = torch.bfloat16
2399*da0073e9SAndroid Build Coastguard Worker            else:
2400*da0073e9SAndroid Build Coastguard Worker                dtype = torch.float32
2401*da0073e9SAndroid Build Coastguard Worker
2402*da0073e9SAndroid Build Coastguard Worker            mp_policy = MixedPrecision(
2403*da0073e9SAndroid Build Coastguard Worker                param_dtype=dtype,
2404*da0073e9SAndroid Build Coastguard Worker                # Gradient communication precision.
2405*da0073e9SAndroid Build Coastguard Worker                reduce_dtype=dtype,
2406*da0073e9SAndroid Build Coastguard Worker                # Buffer precision.
2407*da0073e9SAndroid Build Coastguard Worker                buffer_dtype=dtype,
2408*da0073e9SAndroid Build Coastguard Worker            )
2409*da0073e9SAndroid Build Coastguard Worker
2410*da0073e9SAndroid Build Coastguard Worker            model = FSDP(
2411*da0073e9SAndroid Build Coastguard Worker                model,
2412*da0073e9SAndroid Build Coastguard Worker                use_orig_params=True,
2413*da0073e9SAndroid Build Coastguard Worker                device_id=torch.cuda.current_device()
2414*da0073e9SAndroid Build Coastguard Worker                if self.args.devices[-1] == "cuda"
2415*da0073e9SAndroid Build Coastguard Worker                else None,
2416*da0073e9SAndroid Build Coastguard Worker                mixed_precision=mp_policy,
2417*da0073e9SAndroid Build Coastguard Worker                limit_all_gathers=True,
2418*da0073e9SAndroid Build Coastguard Worker                auto_wrap_policy=self.get_fsdp_auto_wrap_policy(self.args.only),
2419*da0073e9SAndroid Build Coastguard Worker            )
2420*da0073e9SAndroid Build Coastguard Worker        return model
2421*da0073e9SAndroid Build Coastguard Worker
2422*da0073e9SAndroid Build Coastguard Worker    def check_accuracy(
2423*da0073e9SAndroid Build Coastguard Worker        self, name, model, example_inputs, optimize_ctx, experiment, tag
2424*da0073e9SAndroid Build Coastguard Worker    ):
2425*da0073e9SAndroid Build Coastguard Worker        """
2426*da0073e9SAndroid Build Coastguard Worker        Checks accuracy.
2427*da0073e9SAndroid Build Coastguard Worker        1) Collect the outputs with fp64 datatype. This is useful for error checking.
2428*da0073e9SAndroid Build Coastguard Worker        2) Checks if eager itself has variations.
2429*da0073e9SAndroid Build Coastguard Worker        """
2430*da0073e9SAndroid Build Coastguard Worker        start_stats = get_dynamo_stats()
2431*da0073e9SAndroid Build Coastguard Worker
2432*da0073e9SAndroid Build Coastguard Worker        def record_status(accuracy_status, dynamo_start_stats):
2433*da0073e9SAndroid Build Coastguard Worker            """
2434*da0073e9SAndroid Build Coastguard Worker            Records the status in the csv file
2435*da0073e9SAndroid Build Coastguard Worker            """
2436*da0073e9SAndroid Build Coastguard Worker            if current_name in self.non_deterministic_models:
2437*da0073e9SAndroid Build Coastguard Worker                if accuracy_status in (
2438*da0073e9SAndroid Build Coastguard Worker                    "pass",
2439*da0073e9SAndroid Build Coastguard Worker                    "eager_two_runs_differ",
2440*da0073e9SAndroid Build Coastguard Worker                    "fail_accuracy",
2441*da0073e9SAndroid Build Coastguard Worker                ):
2442*da0073e9SAndroid Build Coastguard Worker                    accuracy_status = "pass"
2443*da0073e9SAndroid Build Coastguard Worker
2444*da0073e9SAndroid Build Coastguard Worker            headers = ["dev", "name", "batch_size", "accuracy"]
2445*da0073e9SAndroid Build Coastguard Worker            fields = [current_device, current_name, current_batch_size, accuracy_status]
2446*da0073e9SAndroid Build Coastguard Worker
2447*da0073e9SAndroid Build Coastguard Worker            if tag is not None:
2448*da0073e9SAndroid Build Coastguard Worker                headers.insert(3, "tag")
2449*da0073e9SAndroid Build Coastguard Worker                fields.insert(3, tag)
2450*da0073e9SAndroid Build Coastguard Worker
2451*da0073e9SAndroid Build Coastguard Worker            dynamo_stats = get_dynamo_stats()
2452*da0073e9SAndroid Build Coastguard Worker            dynamo_stats.subtract(dynamo_start_stats)
2453*da0073e9SAndroid Build Coastguard Worker            for k, v in dynamo_stats.items():
2454*da0073e9SAndroid Build Coastguard Worker                headers.append(k)
2455*da0073e9SAndroid Build Coastguard Worker                fields.append(v)
2456*da0073e9SAndroid Build Coastguard Worker
2457*da0073e9SAndroid Build Coastguard Worker            output_csv(output_filename, headers, fields)
2458*da0073e9SAndroid Build Coastguard Worker            return accuracy_status
2459*da0073e9SAndroid Build Coastguard Worker
2460*da0073e9SAndroid Build Coastguard Worker        if name in self.skip_accuracy_checks_large_models_dashboard:
2461*da0073e9SAndroid Build Coastguard Worker            return record_status("pass_due_to_skip", dynamo_start_stats=start_stats)
2462*da0073e9SAndroid Build Coastguard Worker
2463*da0073e9SAndroid Build Coastguard Worker        with self.pick_grad(name, self.args.training):
2464*da0073e9SAndroid Build Coastguard Worker            # Collect the fp64 reference outputs to be used later for accuracy checking.
2465*da0073e9SAndroid Build Coastguard Worker            fp64_outputs = None
2466*da0073e9SAndroid Build Coastguard Worker            model_fp64 = None
2467*da0073e9SAndroid Build Coastguard Worker            inputs_fp64 = None
2468*da0073e9SAndroid Build Coastguard Worker            try:
2469*da0073e9SAndroid Build Coastguard Worker                model_fp64, inputs_fp64 = cast_to_fp64(
2470*da0073e9SAndroid Build Coastguard Worker                    self.deepcopy_and_maybe_parallelize(model),
2471*da0073e9SAndroid Build Coastguard Worker                    clone_inputs(example_inputs),
2472*da0073e9SAndroid Build Coastguard Worker                )
2473*da0073e9SAndroid Build Coastguard Worker                self.init_optimizer(name, current_device, model_fp64.parameters())
2474*da0073e9SAndroid Build Coastguard Worker                fp64_outputs = self.run_n_iterations(model_fp64, inputs_fp64)
2475*da0073e9SAndroid Build Coastguard Worker                fp64_outputs = tree_map(
2476*da0073e9SAndroid Build Coastguard Worker                    lambda x: x.to(torch.float64)
2477*da0073e9SAndroid Build Coastguard Worker                    if isinstance(x, torch.Tensor) and x.is_floating_point()
2478*da0073e9SAndroid Build Coastguard Worker                    else x,
2479*da0073e9SAndroid Build Coastguard Worker                    fp64_outputs,
2480*da0073e9SAndroid Build Coastguard Worker                )
2481*da0073e9SAndroid Build Coastguard Worker            except Exception:
2482*da0073e9SAndroid Build Coastguard Worker                log.warning(
2483*da0073e9SAndroid Build Coastguard Worker                    "fp64 golden ref were not generated for %s. Setting accuracy check to cosine",
2484*da0073e9SAndroid Build Coastguard Worker                    name,
2485*da0073e9SAndroid Build Coastguard Worker                )
2486*da0073e9SAndroid Build Coastguard Worker                self.args.cosine = True
2487*da0073e9SAndroid Build Coastguard Worker                fp64_outputs = None
2488*da0073e9SAndroid Build Coastguard Worker            finally:
2489*da0073e9SAndroid Build Coastguard Worker                del model_fp64, inputs_fp64
2490*da0073e9SAndroid Build Coastguard Worker                empty_gpu_cache(current_device)
2491*da0073e9SAndroid Build Coastguard Worker
2492*da0073e9SAndroid Build Coastguard Worker            tolerance, cos_similarity = self.get_tolerance_and_cosine_flag(
2493*da0073e9SAndroid Build Coastguard Worker                self.args.training, current_device, name
2494*da0073e9SAndroid Build Coastguard Worker            )
2495*da0073e9SAndroid Build Coastguard Worker
2496*da0073e9SAndroid Build Coastguard Worker            # Cast the model to float16/float32 as necessary
2497*da0073e9SAndroid Build Coastguard Worker            model, example_inputs = self.maybe_cast(model, example_inputs)
2498*da0073e9SAndroid Build Coastguard Worker            accuracy_status = "pass"
2499*da0073e9SAndroid Build Coastguard Worker
2500*da0073e9SAndroid Build Coastguard Worker            # Get results of native pytorch
2501*da0073e9SAndroid Build Coastguard Worker            reset_rng_state()
2502*da0073e9SAndroid Build Coastguard Worker            model_copy = None
2503*da0073e9SAndroid Build Coastguard Worker            try:
2504*da0073e9SAndroid Build Coastguard Worker                model_copy = self.deepcopy_and_maybe_parallelize(model)
2505*da0073e9SAndroid Build Coastguard Worker                self.init_optimizer(name, current_device, model_copy.parameters())
2506*da0073e9SAndroid Build Coastguard Worker                correct_result = self.run_n_iterations(
2507*da0073e9SAndroid Build Coastguard Worker                    model_copy, clone_inputs(example_inputs)
2508*da0073e9SAndroid Build Coastguard Worker                )
2509*da0073e9SAndroid Build Coastguard Worker            except Exception as e:
2510*da0073e9SAndroid Build Coastguard Worker                accuracy_status = (
2511*da0073e9SAndroid Build Coastguard Worker                    "eager_1st_run_OOM"
2512*da0073e9SAndroid Build Coastguard Worker                    if isinstance(e, torch.cuda.OutOfMemoryError)
2513*da0073e9SAndroid Build Coastguard Worker                    else "eager_1st_run_fail"
2514*da0073e9SAndroid Build Coastguard Worker                )
2515*da0073e9SAndroid Build Coastguard Worker                log.exception("")
2516*da0073e9SAndroid Build Coastguard Worker                return record_status(accuracy_status, dynamo_start_stats=start_stats)
2517*da0073e9SAndroid Build Coastguard Worker            finally:
2518*da0073e9SAndroid Build Coastguard Worker                del model_copy
2519*da0073e9SAndroid Build Coastguard Worker                empty_gpu_cache(current_device)
2520*da0073e9SAndroid Build Coastguard Worker
2521*da0073e9SAndroid Build Coastguard Worker            # Rerun native pytorch
2522*da0073e9SAndroid Build Coastguard Worker            reset_rng_state()
2523*da0073e9SAndroid Build Coastguard Worker            model_copy = None
2524*da0073e9SAndroid Build Coastguard Worker            try:
2525*da0073e9SAndroid Build Coastguard Worker                model_copy = self.deepcopy_and_maybe_parallelize(model)
2526*da0073e9SAndroid Build Coastguard Worker                self.init_optimizer(name, current_device, model_copy.parameters())
2527*da0073e9SAndroid Build Coastguard Worker                correct_rerun_result = self.run_n_iterations(
2528*da0073e9SAndroid Build Coastguard Worker                    model_copy, clone_inputs(example_inputs)
2529*da0073e9SAndroid Build Coastguard Worker                )
2530*da0073e9SAndroid Build Coastguard Worker            except Exception as e:
2531*da0073e9SAndroid Build Coastguard Worker                accuracy_status = (
2532*da0073e9SAndroid Build Coastguard Worker                    "eager_2nd_run_OOM"
2533*da0073e9SAndroid Build Coastguard Worker                    if isinstance(e, torch.cuda.OutOfMemoryError)
2534*da0073e9SAndroid Build Coastguard Worker                    else "eager_2nd_run_fail"
2535*da0073e9SAndroid Build Coastguard Worker                )
2536*da0073e9SAndroid Build Coastguard Worker                log.exception("")
2537*da0073e9SAndroid Build Coastguard Worker                return record_status(accuracy_status, dynamo_start_stats=start_stats)
2538*da0073e9SAndroid Build Coastguard Worker            finally:
2539*da0073e9SAndroid Build Coastguard Worker                del model_copy
2540*da0073e9SAndroid Build Coastguard Worker                empty_gpu_cache(current_device)
2541*da0073e9SAndroid Build Coastguard Worker
2542*da0073e9SAndroid Build Coastguard Worker            # Two eager runs should have exactly same result
2543*da0073e9SAndroid Build Coastguard Worker            is_same = True
2544*da0073e9SAndroid Build Coastguard Worker            try:
2545*da0073e9SAndroid Build Coastguard Worker                if (
2546*da0073e9SAndroid Build Coastguard Worker                    name not in self.skip_accuracy_check_as_eager_non_deterministic
2547*da0073e9SAndroid Build Coastguard Worker                    and not same(
2548*da0073e9SAndroid Build Coastguard Worker                        correct_result,
2549*da0073e9SAndroid Build Coastguard Worker                        correct_rerun_result,
2550*da0073e9SAndroid Build Coastguard Worker                        fp64_ref=None,
2551*da0073e9SAndroid Build Coastguard Worker                        cos_similarity=False,
2552*da0073e9SAndroid Build Coastguard Worker                        tol=0,
2553*da0073e9SAndroid Build Coastguard Worker                        equal_nan=self.equal_nan,
2554*da0073e9SAndroid Build Coastguard Worker                    )
2555*da0073e9SAndroid Build Coastguard Worker                ):
2556*da0073e9SAndroid Build Coastguard Worker                    is_same = False
2557*da0073e9SAndroid Build Coastguard Worker            except Exception as e:
2558*da0073e9SAndroid Build Coastguard Worker                # Sometimes torch.allclose may throw RuntimeError
2559*da0073e9SAndroid Build Coastguard Worker                is_same = False
2560*da0073e9SAndroid Build Coastguard Worker
2561*da0073e9SAndroid Build Coastguard Worker            if not is_same:
2562*da0073e9SAndroid Build Coastguard Worker                accuracy_status = "eager_two_runs_differ"
2563*da0073e9SAndroid Build Coastguard Worker                return record_status(accuracy_status, dynamo_start_stats=start_stats)
2564*da0073e9SAndroid Build Coastguard Worker
2565*da0073e9SAndroid Build Coastguard Worker            correct_rerun_result = None
2566*da0073e9SAndroid Build Coastguard Worker
2567*da0073e9SAndroid Build Coastguard Worker            # Run with Dynamo
2568*da0073e9SAndroid Build Coastguard Worker            reset_rng_state()
2569*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.reset()
2570*da0073e9SAndroid Build Coastguard Worker            model_copy = None
2571*da0073e9SAndroid Build Coastguard Worker            try:
2572*da0073e9SAndroid Build Coastguard Worker                model_copy = self.deepcopy_and_maybe_parallelize(model)
2573*da0073e9SAndroid Build Coastguard Worker                self.init_optimizer(name, current_device, model_copy.parameters())
2574*da0073e9SAndroid Build Coastguard Worker                if self.args.export or self.args.export_aot_inductor:
2575*da0073e9SAndroid Build Coastguard Worker                    # apply export on module directly
2576*da0073e9SAndroid Build Coastguard Worker                    # no need for n iterations
2577*da0073e9SAndroid Build Coastguard Worker                    # the logic should be the same to self.model_iter_fn (forward_pass)
2578*da0073e9SAndroid Build Coastguard Worker                    with self.autocast(**self.autocast_arg):
2579*da0073e9SAndroid Build Coastguard Worker                        optimized_model_iter_fn = optimize_ctx(
2580*da0073e9SAndroid Build Coastguard Worker                            model_copy, example_inputs
2581*da0073e9SAndroid Build Coastguard Worker                        )
2582*da0073e9SAndroid Build Coastguard Worker                        new_result = optimized_model_iter_fn(model_copy, example_inputs)
2583*da0073e9SAndroid Build Coastguard Worker                else:
2584*da0073e9SAndroid Build Coastguard Worker                    optimized_model_iter_fn = optimize_ctx(self.run_n_iterations)
2585*da0073e9SAndroid Build Coastguard Worker                    with maybe_enable_compiled_autograd(self.args.compiled_autograd):
2586*da0073e9SAndroid Build Coastguard Worker                        new_result = optimized_model_iter_fn(model_copy, example_inputs)
2587*da0073e9SAndroid Build Coastguard Worker            except Exception as e:
2588*da0073e9SAndroid Build Coastguard Worker                log.exception("")
2589*da0073e9SAndroid Build Coastguard Worker                print(
2590*da0073e9SAndroid Build Coastguard Worker                    "TorchDynamo optimized model failed to run because of following error"
2591*da0073e9SAndroid Build Coastguard Worker                )
2592*da0073e9SAndroid Build Coastguard Worker                accuracy_status = (
2593*da0073e9SAndroid Build Coastguard Worker                    "OOM"
2594*da0073e9SAndroid Build Coastguard Worker                    if isinstance(e, torch.cuda.OutOfMemoryError)
2595*da0073e9SAndroid Build Coastguard Worker                    else "fail_to_run"
2596*da0073e9SAndroid Build Coastguard Worker                )
2597*da0073e9SAndroid Build Coastguard Worker                return record_status(accuracy_status, dynamo_start_stats=start_stats)
2598*da0073e9SAndroid Build Coastguard Worker            finally:
2599*da0073e9SAndroid Build Coastguard Worker                del model_copy
2600*da0073e9SAndroid Build Coastguard Worker
2601*da0073e9SAndroid Build Coastguard Worker            if name in self.skip_accuracy_check_as_eager_non_deterministic:
2602*da0073e9SAndroid Build Coastguard Worker                return record_status("pass_due_to_skip", dynamo_start_stats=start_stats)
2603*da0073e9SAndroid Build Coastguard Worker
2604*da0073e9SAndroid Build Coastguard Worker            if (
2605*da0073e9SAndroid Build Coastguard Worker                current_onnx_compiler == "torchscript"
2606*da0073e9SAndroid Build Coastguard Worker                or current_onnx_compiler == "dynamo"
2607*da0073e9SAndroid Build Coastguard Worker            ):
2608*da0073e9SAndroid Build Coastguard Worker                # Workaround for ONNX for non-tensor outputs
2609*da0073e9SAndroid Build Coastguard Worker                (
2610*da0073e9SAndroid Build Coastguard Worker                    correct_result,
2611*da0073e9SAndroid Build Coastguard Worker                    new_result,
2612*da0073e9SAndroid Build Coastguard Worker                    fp64_outputs,
2613*da0073e9SAndroid Build Coastguard Worker                ) = _OnnxPatch.patch_non_tensor_outputs(
2614*da0073e9SAndroid Build Coastguard Worker                    correct_result, new_result, fp64_outputs
2615*da0073e9SAndroid Build Coastguard Worker                )
2616*da0073e9SAndroid Build Coastguard Worker                # Relax tolerance for ONNX cuda
2617*da0073e9SAndroid Build Coastguard Worker                if current_device == "cuda":
2618*da0073e9SAndroid Build Coastguard Worker                    tolerance = 1e-2
2619*da0073e9SAndroid Build Coastguard Worker
2620*da0073e9SAndroid Build Coastguard Worker                # TODO: store correct_result into the dumped file for offline onnx model validation.
2621*da0073e9SAndroid Build Coastguard Worker                # The downside and potential problem, is that the output formats may be different.
2622*da0073e9SAndroid Build Coastguard Worker                # E.g., the output order might not match, None might be part of output, etc.
2623*da0073e9SAndroid Build Coastguard Worker
2624*da0073e9SAndroid Build Coastguard Worker            try:
2625*da0073e9SAndroid Build Coastguard Worker                if self.args.training and self.args.amp:
2626*da0073e9SAndroid Build Coastguard Worker                    if process_fn := self.get_output_amp_train_process_func.get(
2627*da0073e9SAndroid Build Coastguard Worker                        name, None
2628*da0073e9SAndroid Build Coastguard Worker                    ):
2629*da0073e9SAndroid Build Coastguard Worker                        correct_result = process_fn(correct_result)
2630*da0073e9SAndroid Build Coastguard Worker                        new_result = process_fn(new_result)
2631*da0073e9SAndroid Build Coastguard Worker                        fp64_outputs = process_fn(fp64_outputs)
2632*da0073e9SAndroid Build Coastguard Worker
2633*da0073e9SAndroid Build Coastguard Worker                if not same(
2634*da0073e9SAndroid Build Coastguard Worker                    correct_result,
2635*da0073e9SAndroid Build Coastguard Worker                    new_result,
2636*da0073e9SAndroid Build Coastguard Worker                    fp64_outputs,
2637*da0073e9SAndroid Build Coastguard Worker                    equal_nan=self.equal_nan,
2638*da0073e9SAndroid Build Coastguard Worker                    cos_similarity=cos_similarity,
2639*da0073e9SAndroid Build Coastguard Worker                    tol=tolerance,
2640*da0073e9SAndroid Build Coastguard Worker                ):
2641*da0073e9SAndroid Build Coastguard Worker                    is_same = False
2642*da0073e9SAndroid Build Coastguard Worker            except Exception as e:
2643*da0073e9SAndroid Build Coastguard Worker                # Sometimes torch.allclose may throw RuntimeError
2644*da0073e9SAndroid Build Coastguard Worker                is_same = False
2645*da0073e9SAndroid Build Coastguard Worker
2646*da0073e9SAndroid Build Coastguard Worker            if not is_same:
2647*da0073e9SAndroid Build Coastguard Worker                if self.args.skip_accuracy_check:
2648*da0073e9SAndroid Build Coastguard Worker                    accuracy_status = "pass_due_to_skip"
2649*da0073e9SAndroid Build Coastguard Worker                else:
2650*da0073e9SAndroid Build Coastguard Worker                    accuracy_status = "fail_accuracy"
2651*da0073e9SAndroid Build Coastguard Worker                return record_status(accuracy_status, dynamo_start_stats=start_stats)
2652*da0073e9SAndroid Build Coastguard Worker
2653*da0073e9SAndroid Build Coastguard Worker        return record_status(accuracy_status, dynamo_start_stats=start_stats)
2654*da0073e9SAndroid Build Coastguard Worker
2655*da0073e9SAndroid Build Coastguard Worker    def check_tolerance(
2656*da0073e9SAndroid Build Coastguard Worker        self, name, model, example_inputs, optimize_ctx, base_device="cpu"
2657*da0073e9SAndroid Build Coastguard Worker    ):
2658*da0073e9SAndroid Build Coastguard Worker        """
2659*da0073e9SAndroid Build Coastguard Worker        Checks tolerance based on https://pytorch.org/docs/stable/generated/torch.allclose.html.
2660*da0073e9SAndroid Build Coastguard Worker        """
2661*da0073e9SAndroid Build Coastguard Worker        tolerance_status = "pass"
2662*da0073e9SAndroid Build Coastguard Worker        if name in self.skip_accuracy_checks_large_models_dashboard:
2663*da0073e9SAndroid Build Coastguard Worker            tolerance_status = "pass_due_to_skip"
2664*da0073e9SAndroid Build Coastguard Worker            return tolerance_status
2665*da0073e9SAndroid Build Coastguard Worker        # Cast the model to float16/float32 as necessary
2666*da0073e9SAndroid Build Coastguard Worker        model, example_inputs = self.maybe_cast(model, example_inputs)
2667*da0073e9SAndroid Build Coastguard Worker
2668*da0073e9SAndroid Build Coastguard Worker        with self.pick_grad(name, self.args.training):
2669*da0073e9SAndroid Build Coastguard Worker            # Get results of native pytorch
2670*da0073e9SAndroid Build Coastguard Worker            reset_rng_state()
2671*da0073e9SAndroid Build Coastguard Worker            model_copy = copy.deepcopy(model)
2672*da0073e9SAndroid Build Coastguard Worker            model_copy = model_copy.to(base_device)
2673*da0073e9SAndroid Build Coastguard Worker            example_inputs_copy = copy.deepcopy(example_inputs)
2674*da0073e9SAndroid Build Coastguard Worker            example_inputs_copy = tree_map(
2675*da0073e9SAndroid Build Coastguard Worker                lambda x: x.to(base_device), example_inputs_copy
2676*da0073e9SAndroid Build Coastguard Worker            )
2677*da0073e9SAndroid Build Coastguard Worker            self.init_optimizer(name, base_device, model_copy.parameters())
2678*da0073e9SAndroid Build Coastguard Worker            correct_result = self.run_n_iterations(model_copy, example_inputs_copy)
2679*da0073e9SAndroid Build Coastguard Worker
2680*da0073e9SAndroid Build Coastguard Worker            # Run with Dynamo
2681*da0073e9SAndroid Build Coastguard Worker            # Sometime CI fails with random triton compilation failure which will be skipped for now
2682*da0073e9SAndroid Build Coastguard Worker            # TODO: revisit this after switching to new Triton runtime
2683*da0073e9SAndroid Build Coastguard Worker            reset_rng_state()
2684*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.reset()
2685*da0073e9SAndroid Build Coastguard Worker            try:
2686*da0073e9SAndroid Build Coastguard Worker                self.init_optimizer(name, current_device, model.parameters())
2687*da0073e9SAndroid Build Coastguard Worker                optimized_model_iter_fn = optimize_ctx(self.run_n_iterations)
2688*da0073e9SAndroid Build Coastguard Worker                new_result = optimized_model_iter_fn(model, example_inputs)
2689*da0073e9SAndroid Build Coastguard Worker            except Exception as e:
2690*da0073e9SAndroid Build Coastguard Worker                log.exception("")
2691*da0073e9SAndroid Build Coastguard Worker                print(
2692*da0073e9SAndroid Build Coastguard Worker                    "TorchDynamo optimized model failed to run because of following error"
2693*da0073e9SAndroid Build Coastguard Worker                )
2694*da0073e9SAndroid Build Coastguard Worker                return "fail_to_run"
2695*da0073e9SAndroid Build Coastguard Worker
2696*da0073e9SAndroid Build Coastguard Worker            def dump_max_mean_values(tol, ref, res):
2697*da0073e9SAndroid Build Coastguard Worker                if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)):
2698*da0073e9SAndroid Build Coastguard Worker                    for refi, resi in zip(ref, res):
2699*da0073e9SAndroid Build Coastguard Worker                        dump_max_mean_values(tol, refi, resi)
2700*da0073e9SAndroid Build Coastguard Worker                elif isinstance(ref, dict):
2701*da0073e9SAndroid Build Coastguard Worker                    for k in ref.keys():
2702*da0073e9SAndroid Build Coastguard Worker                        dump_max_mean_values(tol, ref[k], res[k])
2703*da0073e9SAndroid Build Coastguard Worker                elif isinstance(ref, torch.Tensor):
2704*da0073e9SAndroid Build Coastguard Worker                    res = res.to(base_device)
2705*da0073e9SAndroid Build Coastguard Worker                    t = torch.abs(ref - res) / (1 + torch.abs(ref))
2706*da0073e9SAndroid Build Coastguard Worker                    tol.append(t.flatten().to(torch.float32))
2707*da0073e9SAndroid Build Coastguard Worker                return tol
2708*da0073e9SAndroid Build Coastguard Worker
2709*da0073e9SAndroid Build Coastguard Worker            tol = []
2710*da0073e9SAndroid Build Coastguard Worker            dump_max_mean_values(tol, correct_result, new_result)
2711*da0073e9SAndroid Build Coastguard Worker            tol = torch.cat(tol)
2712*da0073e9SAndroid Build Coastguard Worker            tol = torch.tensor(tol)
2713*da0073e9SAndroid Build Coastguard Worker            max = torch.max(tol)
2714*da0073e9SAndroid Build Coastguard Worker            mean = torch.mean(tol)
2715*da0073e9SAndroid Build Coastguard Worker            div = torch.std(tol)
2716*da0073e9SAndroid Build Coastguard Worker            headers = ["dev", "name", "batch_size", "max", "mean", "std"]
2717*da0073e9SAndroid Build Coastguard Worker            fields = [
2718*da0073e9SAndroid Build Coastguard Worker                current_device,
2719*da0073e9SAndroid Build Coastguard Worker                current_name,
2720*da0073e9SAndroid Build Coastguard Worker                current_batch_size,
2721*da0073e9SAndroid Build Coastguard Worker                max.item(),
2722*da0073e9SAndroid Build Coastguard Worker                mean.item(),
2723*da0073e9SAndroid Build Coastguard Worker                div.item(),
2724*da0073e9SAndroid Build Coastguard Worker            ]
2725*da0073e9SAndroid Build Coastguard Worker            output_csv(output_filename, headers, fields)
2726*da0073e9SAndroid Build Coastguard Worker        return tolerance_status
2727*da0073e9SAndroid Build Coastguard Worker
2728*da0073e9SAndroid Build Coastguard Worker    def run_performance_test(
2729*da0073e9SAndroid Build Coastguard Worker        self, name, model, example_inputs, optimize_ctx, experiment, tag=None
2730*da0073e9SAndroid Build Coastguard Worker    ):
2731*da0073e9SAndroid Build Coastguard Worker        if self.args.xla:
2732*da0073e9SAndroid Build Coastguard Worker            with self.pick_grad(name, self.args.training):
2733*da0073e9SAndroid Build Coastguard Worker                return experiment(*self.maybe_cast(model, example_inputs))
2734*da0073e9SAndroid Build Coastguard Worker
2735*da0073e9SAndroid Build Coastguard Worker        def warmup(fn, model, example_inputs, mode, niters=5):
2736*da0073e9SAndroid Build Coastguard Worker            peak_mem = 0
2737*da0073e9SAndroid Build Coastguard Worker            start_stats = get_dynamo_stats()
2738*da0073e9SAndroid Build Coastguard Worker            try:
2739*da0073e9SAndroid Build Coastguard Worker                if current_device == "cuda":
2740*da0073e9SAndroid Build Coastguard Worker                    torch.cuda.reset_peak_memory_stats()
2741*da0073e9SAndroid Build Coastguard Worker                    empty_gpu_cache(current_device)
2742*da0073e9SAndroid Build Coastguard Worker                t0 = time.perf_counter()
2743*da0073e9SAndroid Build Coastguard Worker                for _ in range(niters):
2744*da0073e9SAndroid Build Coastguard Worker                    fn(model, example_inputs)
2745*da0073e9SAndroid Build Coastguard Worker                t1 = time.perf_counter()
2746*da0073e9SAndroid Build Coastguard Worker                latency = t1 - t0
2747*da0073e9SAndroid Build Coastguard Worker                if current_device == "cuda":
2748*da0073e9SAndroid Build Coastguard Worker                    peak_mem = get_peak_memory()
2749*da0073e9SAndroid Build Coastguard Worker                elif current_device == "cpu":
2750*da0073e9SAndroid Build Coastguard Worker                    total = psutil.virtual_memory().total
2751*da0073e9SAndroid Build Coastguard Worker                    percentage = psutil.Process(os.getpid()).memory_percent()
2752*da0073e9SAndroid Build Coastguard Worker                    peak_mem = percentage * total / 10**9
2753*da0073e9SAndroid Build Coastguard Worker            except Exception:
2754*da0073e9SAndroid Build Coastguard Worker                log.exception("Backend %s failed in warmup()", mode)
2755*da0073e9SAndroid Build Coastguard Worker                return sys.exit(-1)
2756*da0073e9SAndroid Build Coastguard Worker            dynamo_stats = get_dynamo_stats()
2757*da0073e9SAndroid Build Coastguard Worker            dynamo_stats.subtract(start_stats)
2758*da0073e9SAndroid Build Coastguard Worker            return latency, peak_mem, dynamo_stats
2759*da0073e9SAndroid Build Coastguard Worker
2760*da0073e9SAndroid Build Coastguard Worker        # Cast the model to float16/float32 as necessary
2761*da0073e9SAndroid Build Coastguard Worker        model, example_inputs = self.maybe_cast(model, example_inputs)
2762*da0073e9SAndroid Build Coastguard Worker
2763*da0073e9SAndroid Build Coastguard Worker        # Use distributed wrapping as necessary
2764*da0073e9SAndroid Build Coastguard Worker        model = self.deepcopy_and_maybe_parallelize(model)
2765*da0073e9SAndroid Build Coastguard Worker
2766*da0073e9SAndroid Build Coastguard Worker        self.init_optimizer(name, current_device, model.parameters())
2767*da0073e9SAndroid Build Coastguard Worker
2768*da0073e9SAndroid Build Coastguard Worker        # The self.autocast context is needed for the model we export with aot_compile,
2769*da0073e9SAndroid Build Coastguard Worker        # similar to what we do in the check_accuracy function
2770*da0073e9SAndroid Build Coastguard Worker        ctx = (
2771*da0073e9SAndroid Build Coastguard Worker            self.autocast(**self.autocast_arg)
2772*da0073e9SAndroid Build Coastguard Worker            if self.args.export_aot_inductor
2773*da0073e9SAndroid Build Coastguard Worker            else contextlib.nullcontext()
2774*da0073e9SAndroid Build Coastguard Worker        )
2775*da0073e9SAndroid Build Coastguard Worker
2776*da0073e9SAndroid Build Coastguard Worker        with self.pick_grad(name, self.args.training), ctx:
2777*da0073e9SAndroid Build Coastguard Worker            ok, total = Stats.reset_counters()
2778*da0073e9SAndroid Build Coastguard Worker            experiment_kwargs = {}
2779*da0073e9SAndroid Build Coastguard Worker            if tag is not None:
2780*da0073e9SAndroid Build Coastguard Worker                experiment_kwargs["tag"] = tag
2781*da0073e9SAndroid Build Coastguard Worker            results = []
2782*da0073e9SAndroid Build Coastguard Worker            with maybe_snapshot_memory(
2783*da0073e9SAndroid Build Coastguard Worker                self.args.snapshot_memory, f"eager_{self.args.only}"
2784*da0073e9SAndroid Build Coastguard Worker            ):
2785*da0073e9SAndroid Build Coastguard Worker                eager_latency, eager_peak_mem, _ = warmup(
2786*da0073e9SAndroid Build Coastguard Worker                    self.model_iter_fn, model, example_inputs, "eager"
2787*da0073e9SAndroid Build Coastguard Worker                )
2788*da0073e9SAndroid Build Coastguard Worker                if self.args.use_warm_peak_memory:
2789*da0073e9SAndroid Build Coastguard Worker                    _, eager_peak_mem, _ = warmup(
2790*da0073e9SAndroid Build Coastguard Worker                        self.model_iter_fn, model, example_inputs, "eager", niters=1
2791*da0073e9SAndroid Build Coastguard Worker                    )
2792*da0073e9SAndroid Build Coastguard Worker
2793*da0073e9SAndroid Build Coastguard Worker            if self.args.export_aot_inductor:
2794*da0073e9SAndroid Build Coastguard Worker                t_0 = time.perf_counter()
2795*da0073e9SAndroid Build Coastguard Worker                optimized_model_iter_fn = optimize_ctx
2796*da0073e9SAndroid Build Coastguard Worker                t_1 = time.perf_counter()
2797*da0073e9SAndroid Build Coastguard Worker                aot_compilation_time = t_1 - t_0
2798*da0073e9SAndroid Build Coastguard Worker            else:
2799*da0073e9SAndroid Build Coastguard Worker                optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
2800*da0073e9SAndroid Build Coastguard Worker                aot_compilation_time = 0
2801*da0073e9SAndroid Build Coastguard Worker
2802*da0073e9SAndroid Build Coastguard Worker            with maybe_enable_compiled_autograd(
2803*da0073e9SAndroid Build Coastguard Worker                self.args.compiled_autograd
2804*da0073e9SAndroid Build Coastguard Worker            ), maybe_snapshot_memory(
2805*da0073e9SAndroid Build Coastguard Worker                self.args.snapshot_memory, f"compiled_{self.args.only}"
2806*da0073e9SAndroid Build Coastguard Worker            ):
2807*da0073e9SAndroid Build Coastguard Worker                dynamo_latency, dynamo_peak_mem, dynamo_stats = warmup(
2808*da0073e9SAndroid Build Coastguard Worker                    optimized_model_iter_fn, model, example_inputs, "dynamo"
2809*da0073e9SAndroid Build Coastguard Worker                )
2810*da0073e9SAndroid Build Coastguard Worker                if self.args.use_warm_peak_memory:
2811*da0073e9SAndroid Build Coastguard Worker                    _, dynamo_peak_mem, _ = warmup(
2812*da0073e9SAndroid Build Coastguard Worker                        optimized_model_iter_fn,
2813*da0073e9SAndroid Build Coastguard Worker                        model,
2814*da0073e9SAndroid Build Coastguard Worker                        example_inputs,
2815*da0073e9SAndroid Build Coastguard Worker                        "dynamo",
2816*da0073e9SAndroid Build Coastguard Worker                        niters=1,
2817*da0073e9SAndroid Build Coastguard Worker                    )
2818*da0073e9SAndroid Build Coastguard Worker
2819*da0073e9SAndroid Build Coastguard Worker            if self.args.profile_dynamo_cache_lookup:
2820*da0073e9SAndroid Build Coastguard Worker                with torch.profiler.profile(
2821*da0073e9SAndroid Build Coastguard Worker                    activities=[torch.profiler.ProfilerActivity.CPU]
2822*da0073e9SAndroid Build Coastguard Worker                ) as prof:
2823*da0073e9SAndroid Build Coastguard Worker                    with maybe_enable_compiled_autograd(self.args.compiled_autograd):
2824*da0073e9SAndroid Build Coastguard Worker                        warmup(optimized_model_iter_fn, model, example_inputs, "dynamo")
2825*da0073e9SAndroid Build Coastguard Worker
2826*da0073e9SAndroid Build Coastguard Worker                events = list(
2827*da0073e9SAndroid Build Coastguard Worker                    filter(
2828*da0073e9SAndroid Build Coastguard Worker                        lambda event: "TorchDynamo Cache Lookup" in event.key,
2829*da0073e9SAndroid Build Coastguard Worker                        prof.key_averages(),
2830*da0073e9SAndroid Build Coastguard Worker                    )
2831*da0073e9SAndroid Build Coastguard Worker                )
2832*da0073e9SAndroid Build Coastguard Worker                dynamo_cache_lookup_latency = events[0].self_cpu_time_total
2833*da0073e9SAndroid Build Coastguard Worker
2834*da0073e9SAndroid Build Coastguard Worker            compilation_time = dynamo_latency - eager_latency + aot_compilation_time
2835*da0073e9SAndroid Build Coastguard Worker            compression_ratio = (
2836*da0073e9SAndroid Build Coastguard Worker                eager_peak_mem / dynamo_peak_mem if dynamo_peak_mem else 0.0
2837*da0073e9SAndroid Build Coastguard Worker            )
2838*da0073e9SAndroid Build Coastguard Worker            if self.args.print_memory:
2839*da0073e9SAndroid Build Coastguard Worker                print(
2840*da0073e9SAndroid Build Coastguard Worker                    f"memory: eager: {eager_peak_mem:.2f} GB, "
2841*da0073e9SAndroid Build Coastguard Worker                    f"dynamo: {dynamo_peak_mem:.2f} GB, "
2842*da0073e9SAndroid Build Coastguard Worker                    f"ratio: {compression_ratio:.2f}"
2843*da0073e9SAndroid Build Coastguard Worker                )
2844*da0073e9SAndroid Build Coastguard Worker
2845*da0073e9SAndroid Build Coastguard Worker            if self.args.print_compilation_time:
2846*da0073e9SAndroid Build Coastguard Worker                print(f"Compilation time: {compilation_time:.2f}")
2847*da0073e9SAndroid Build Coastguard Worker
2848*da0073e9SAndroid Build Coastguard Worker            if experiment.func is speedup_experiment:
2849*da0073e9SAndroid Build Coastguard Worker                experiment_kwargs["compilation_latency"] = compilation_time
2850*da0073e9SAndroid Build Coastguard Worker                experiment_kwargs["compression_ratio"] = compression_ratio
2851*da0073e9SAndroid Build Coastguard Worker                experiment_kwargs["eager_peak_mem"] = eager_peak_mem
2852*da0073e9SAndroid Build Coastguard Worker                experiment_kwargs["dynamo_peak_mem"] = dynamo_peak_mem
2853*da0073e9SAndroid Build Coastguard Worker                experiment_kwargs["dynamo_stats"] = dynamo_stats
2854*da0073e9SAndroid Build Coastguard Worker                if self.args.profile_dynamo_cache_lookup:
2855*da0073e9SAndroid Build Coastguard Worker                    experiment_kwargs[
2856*da0073e9SAndroid Build Coastguard Worker                        "cache_lookup_latency"
2857*da0073e9SAndroid Build Coastguard Worker                    ] = dynamo_cache_lookup_latency
2858*da0073e9SAndroid Build Coastguard Worker
2859*da0073e9SAndroid Build Coastguard Worker            if experiment.func is coverage_experiment:
2860*da0073e9SAndroid Build Coastguard Worker                ok, total = Stats.reset_counters()
2861*da0073e9SAndroid Build Coastguard Worker                results = []
2862*da0073e9SAndroid Build Coastguard Worker                # run with torch._dynamo few times to populate the cache
2863*da0073e9SAndroid Build Coastguard Worker                for _ in range(3):
2864*da0073e9SAndroid Build Coastguard Worker                    optimized_model_iter_fn(model, example_inputs)
2865*da0073e9SAndroid Build Coastguard Worker                _, frames_second_pass = Stats.reset_counters()  # should be 0
2866*da0073e9SAndroid Build Coastguard Worker                if frames_second_pass > 0:
2867*da0073e9SAndroid Build Coastguard Worker                    optimized_model_iter_fn(model, example_inputs)
2868*da0073e9SAndroid Build Coastguard Worker                    _, frames_third_pass = Stats.reset_counters()  # should be 0
2869*da0073e9SAndroid Build Coastguard Worker                else:
2870*da0073e9SAndroid Build Coastguard Worker                    frames_third_pass = 0
2871*da0073e9SAndroid Build Coastguard Worker
2872*da0073e9SAndroid Build Coastguard Worker                results.append(
2873*da0073e9SAndroid Build Coastguard Worker                    f"{ok:3}/{total:3} +{frames_third_pass} frames {compilation_time:3.0f}s"
2874*da0073e9SAndroid Build Coastguard Worker                )
2875*da0073e9SAndroid Build Coastguard Worker
2876*da0073e9SAndroid Build Coastguard Worker            if experiment.func is speedup_experiment_onnx:
2877*da0073e9SAndroid Build Coastguard Worker                experiment = functools.partial(
2878*da0073e9SAndroid Build Coastguard Worker                    experiment, optimized_model_iter_fn.context.onnx_model
2879*da0073e9SAndroid Build Coastguard Worker                )
2880*da0073e9SAndroid Build Coastguard Worker
2881*da0073e9SAndroid Build Coastguard Worker            if not hasattr(model, name):
2882*da0073e9SAndroid Build Coastguard Worker                model.name = name
2883*da0073e9SAndroid Build Coastguard Worker            results.append(experiment(model, example_inputs, **experiment_kwargs))
2884*da0073e9SAndroid Build Coastguard Worker            return " ".join(map(str, results))
2885*da0073e9SAndroid Build Coastguard Worker
2886*da0073e9SAndroid Build Coastguard Worker    def minify_model(
2887*da0073e9SAndroid Build Coastguard Worker        self,
2888*da0073e9SAndroid Build Coastguard Worker        name,
2889*da0073e9SAndroid Build Coastguard Worker        model,
2890*da0073e9SAndroid Build Coastguard Worker        example_inputs,
2891*da0073e9SAndroid Build Coastguard Worker        optimize_ctx,
2892*da0073e9SAndroid Build Coastguard Worker        experiment,
2893*da0073e9SAndroid Build Coastguard Worker        tag,
2894*da0073e9SAndroid Build Coastguard Worker    ):
2895*da0073e9SAndroid Build Coastguard Worker        logging.info("Minifying %s...", name)
2896*da0073e9SAndroid Build Coastguard Worker        os.environ["TORCH_COMPILE_DEBUG"] = "1"
2897*da0073e9SAndroid Build Coastguard Worker        os.environ["TORCHDYNAMO_REPRO_AFTER"] = "dynamo"
2898*da0073e9SAndroid Build Coastguard Worker        os.environ["TORCHDYNAMO_REPRO_LEVEL"] = "4"
2899*da0073e9SAndroid Build Coastguard Worker
2900*da0073e9SAndroid Build Coastguard Worker        self.check_accuracy(name, model, example_inputs, optimize_ctx, experiment, tag)
2901*da0073e9SAndroid Build Coastguard Worker
2902*da0073e9SAndroid Build Coastguard Worker        if self.args.output_directory:
2903*da0073e9SAndroid Build Coastguard Worker            repro_dir = self.args.output_directory
2904*da0073e9SAndroid Build Coastguard Worker        else:
2905*da0073e9SAndroid Build Coastguard Worker            repro_dir = torch._dynamo.config.base_dir
2906*da0073e9SAndroid Build Coastguard Worker
2907*da0073e9SAndroid Build Coastguard Worker        try:
2908*da0073e9SAndroid Build Coastguard Worker            shutil.move("repro.py", f"{repro_dir}/{name}_repro.py")
2909*da0073e9SAndroid Build Coastguard Worker        except OSError as e:
2910*da0073e9SAndroid Build Coastguard Worker            logging.error("Could not find repro script for model %s", name)
2911*da0073e9SAndroid Build Coastguard Worker        else:
2912*da0073e9SAndroid Build Coastguard Worker            logging.info(
2913*da0073e9SAndroid Build Coastguard Worker                "Repro script for model %s with minified graph saved to %s",
2914*da0073e9SAndroid Build Coastguard Worker                name,
2915*da0073e9SAndroid Build Coastguard Worker                repro_dir,
2916*da0073e9SAndroid Build Coastguard Worker            )
2917*da0073e9SAndroid Build Coastguard Worker
2918*da0073e9SAndroid Build Coastguard Worker    def maybe_preserve_compile_debug(self, name, status):
2919*da0073e9SAndroid Build Coastguard Worker        if (
2920*da0073e9SAndroid Build Coastguard Worker            name in CI_PRESERVE_COMPILE_DEBUG
2921*da0073e9SAndroid Build Coastguard Worker            and status in CI_PRESERVE_COMPILE_DEBUG[name]
2922*da0073e9SAndroid Build Coastguard Worker        ):
2923*da0073e9SAndroid Build Coastguard Worker            src_dir = torch._dynamo.utils.get_debug_dir()
2924*da0073e9SAndroid Build Coastguard Worker            if os.path.isdir(src_dir):
2925*da0073e9SAndroid Build Coastguard Worker                dbg_dir = os.path.join(
2926*da0073e9SAndroid Build Coastguard Worker                    os.getcwd(), "test", "debug", "torch_compile_debug"
2927*da0073e9SAndroid Build Coastguard Worker                )
2928*da0073e9SAndroid Build Coastguard Worker                dst_dir = os.path.join(dbg_dir, os.path.basename(src_dir))
2929*da0073e9SAndroid Build Coastguard Worker                try:
2930*da0073e9SAndroid Build Coastguard Worker                    os.makedirs(dbg_dir, exist_ok=True)
2931*da0073e9SAndroid Build Coastguard Worker                    os.rename(src_dir, dst_dir)
2932*da0073e9SAndroid Build Coastguard Worker                    log.warning("Moved %s to %s", src_dir, dst_dir)
2933*da0073e9SAndroid Build Coastguard Worker                except OSError:
2934*da0073e9SAndroid Build Coastguard Worker                    log.exception("Failed to preserve %s", src_dir)
2935*da0073e9SAndroid Build Coastguard Worker
2936*da0073e9SAndroid Build Coastguard Worker    def run_one_model(
2937*da0073e9SAndroid Build Coastguard Worker        self,
2938*da0073e9SAndroid Build Coastguard Worker        name,
2939*da0073e9SAndroid Build Coastguard Worker        model,
2940*da0073e9SAndroid Build Coastguard Worker        example_inputs,
2941*da0073e9SAndroid Build Coastguard Worker        optimize_ctx,
2942*da0073e9SAndroid Build Coastguard Worker        experiment,
2943*da0073e9SAndroid Build Coastguard Worker        explain=False,
2944*da0073e9SAndroid Build Coastguard Worker        tag=None,
2945*da0073e9SAndroid Build Coastguard Worker    ):
2946*da0073e9SAndroid Build Coastguard Worker        mode = "train" if self.args.training else "eval"
2947*da0073e9SAndroid Build Coastguard Worker        msg = f"{current_device:4} {mode:5} {current_name:34} "
2948*da0073e9SAndroid Build Coastguard Worker        if tag:
2949*da0073e9SAndroid Build Coastguard Worker            msg += f" {tag:26}"
2950*da0073e9SAndroid Build Coastguard Worker        print(msg, flush=True)
2951*da0073e9SAndroid Build Coastguard Worker
2952*da0073e9SAndroid Build Coastguard Worker        start_stats = get_dynamo_stats()
2953*da0073e9SAndroid Build Coastguard Worker
2954*da0073e9SAndroid Build Coastguard Worker        if self.args.accuracy:
2955*da0073e9SAndroid Build Coastguard Worker            status = self.check_accuracy(
2956*da0073e9SAndroid Build Coastguard Worker                name, model, example_inputs, optimize_ctx, experiment, tag
2957*da0073e9SAndroid Build Coastguard Worker            )
2958*da0073e9SAndroid Build Coastguard Worker            print(status)
2959*da0073e9SAndroid Build Coastguard Worker            if status == "fail_accuracy" and self.args.minify:
2960*da0073e9SAndroid Build Coastguard Worker                self.minify_model(
2961*da0073e9SAndroid Build Coastguard Worker                    name, model, example_inputs, optimize_ctx, experiment, tag
2962*da0073e9SAndroid Build Coastguard Worker                )
2963*da0073e9SAndroid Build Coastguard Worker        elif self.args.tolerance:
2964*da0073e9SAndroid Build Coastguard Worker            status = self.check_tolerance(name, model, example_inputs, optimize_ctx)
2965*da0073e9SAndroid Build Coastguard Worker            print(status)
2966*da0073e9SAndroid Build Coastguard Worker        elif self.args.performance:
2967*da0073e9SAndroid Build Coastguard Worker            status = self.run_performance_test(
2968*da0073e9SAndroid Build Coastguard Worker                name, model, example_inputs, optimize_ctx, experiment, tag
2969*da0073e9SAndroid Build Coastguard Worker            )
2970*da0073e9SAndroid Build Coastguard Worker            print(status)
2971*da0073e9SAndroid Build Coastguard Worker        empty_gpu_cache(current_device)
2972*da0073e9SAndroid Build Coastguard Worker
2973*da0073e9SAndroid Build Coastguard Worker        self.maybe_preserve_compile_debug(name, status)
2974*da0073e9SAndroid Build Coastguard Worker
2975*da0073e9SAndroid Build Coastguard Worker        if self.args.timing:
2976*da0073e9SAndroid Build Coastguard Worker            from torch._dynamo.utils import op_count, print_time_report
2977*da0073e9SAndroid Build Coastguard Worker            from torch.utils._stats import simple_call_counter
2978*da0073e9SAndroid Build Coastguard Worker
2979*da0073e9SAndroid Build Coastguard Worker            print_time_report()
2980*da0073e9SAndroid Build Coastguard Worker            stats = "STATS: "
2981*da0073e9SAndroid Build Coastguard Worker            stats = stats + " | ".join(
2982*da0073e9SAndroid Build Coastguard Worker                itertools.chain(
2983*da0073e9SAndroid Build Coastguard Worker                    [f"call_* op count: {op_count}"],
2984*da0073e9SAndroid Build Coastguard Worker                    (f"{key}:{value}" for key, value in simple_call_counter.items()),
2985*da0073e9SAndroid Build Coastguard Worker                )
2986*da0073e9SAndroid Build Coastguard Worker            )
2987*da0073e9SAndroid Build Coastguard Worker            print(stats)
2988*da0073e9SAndroid Build Coastguard Worker        stats = get_dynamo_stats()
2989*da0073e9SAndroid Build Coastguard Worker        stats.subtract(start_stats)
2990*da0073e9SAndroid Build Coastguard Worker
2991*da0073e9SAndroid Build Coastguard Worker        if explain:
2992*da0073e9SAndroid Build Coastguard Worker            print(
2993*da0073e9SAndroid Build Coastguard Worker                f"Dynamo produced {stats['unique_graphs']} graphs "
2994*da0073e9SAndroid Build Coastguard Worker                f"covering {stats['calls_captured']} ops with "
2995*da0073e9SAndroid Build Coastguard Worker                f"{stats['graph_breaks']} graph breaks ({stats['unique_graph_breaks']} unique)"
2996*da0073e9SAndroid Build Coastguard Worker            )
2997*da0073e9SAndroid Build Coastguard Worker
2998*da0073e9SAndroid Build Coastguard Worker        if explain or self.args.log_graph_breaks or self.args.print_graph_breaks:
2999*da0073e9SAndroid Build Coastguard Worker            filename = f"{output_filename.rstrip('.csv')}_graph_breaks.csv"
3000*da0073e9SAndroid Build Coastguard Worker
3001*da0073e9SAndroid Build Coastguard Worker            def add_double_quotes(x):
3002*da0073e9SAndroid Build Coastguard Worker                # Delimiter because reason could have comma
3003*da0073e9SAndroid Build Coastguard Worker                return f'"{x}"'
3004*da0073e9SAndroid Build Coastguard Worker
3005*da0073e9SAndroid Build Coastguard Worker            for graph_break in graph_break_reasons:
3006*da0073e9SAndroid Build Coastguard Worker                reason = add_double_quotes(graph_break.reason)
3007*da0073e9SAndroid Build Coastguard Worker                user_stack = add_double_quotes(
3008*da0073e9SAndroid Build Coastguard Worker                    ", ".join([str(x) for x in graph_break.user_stack])
3009*da0073e9SAndroid Build Coastguard Worker                )
3010*da0073e9SAndroid Build Coastguard Worker                output_csv(
3011*da0073e9SAndroid Build Coastguard Worker                    filename,
3012*da0073e9SAndroid Build Coastguard Worker                    ["model", "reason", "user_stack"],
3013*da0073e9SAndroid Build Coastguard Worker                    [current_name, reason, user_stack],
3014*da0073e9SAndroid Build Coastguard Worker                )
3015*da0073e9SAndroid Build Coastguard Worker
3016*da0073e9SAndroid Build Coastguard Worker        if self.args.stats:
3017*da0073e9SAndroid Build Coastguard Worker            Stats.print_summary()
3018*da0073e9SAndroid Build Coastguard Worker
3019*da0073e9SAndroid Build Coastguard Worker
3020*da0073e9SAndroid Build Coastguard Workerdef help(fn):
3021*da0073e9SAndroid Build Coastguard Worker    return fn.__doc__
3022*da0073e9SAndroid Build Coastguard Worker
3023*da0073e9SAndroid Build Coastguard Worker
3024*da0073e9SAndroid Build Coastguard Workerdiff_branch_default = "DIFF-BRANCH-DEFAULT"
3025*da0073e9SAndroid Build Coastguard Worker
3026*da0073e9SAndroid Build Coastguard Worker
3027*da0073e9SAndroid Build Coastguard Workerdef should_diff_branch(args):
3028*da0073e9SAndroid Build Coastguard Worker    return args.diff_branch != diff_branch_default
3029*da0073e9SAndroid Build Coastguard Worker
3030*da0073e9SAndroid Build Coastguard Worker
3031*da0073e9SAndroid Build Coastguard Workerdef parse_args(args=None):
3032*da0073e9SAndroid Build Coastguard Worker    parser = argparse.ArgumentParser()
3033*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3034*da0073e9SAndroid Build Coastguard Worker        "--filter", "-k", action="append", help="filter benchmarks with regexp"
3035*da0073e9SAndroid Build Coastguard Worker    )
3036*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3037*da0073e9SAndroid Build Coastguard Worker        "--exclude", "-x", action="append", help="filter benchmarks with regexp"
3038*da0073e9SAndroid Build Coastguard Worker    )
3039*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3040*da0073e9SAndroid Build Coastguard Worker        "--exclude-exact", action="append", help="filter benchmarks with exact match"
3041*da0073e9SAndroid Build Coastguard Worker    )
3042*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3043*da0073e9SAndroid Build Coastguard Worker        "--total-partitions",
3044*da0073e9SAndroid Build Coastguard Worker        type=int,
3045*da0073e9SAndroid Build Coastguard Worker        default=1,
3046*da0073e9SAndroid Build Coastguard Worker        choices=range(1, 10),
3047*da0073e9SAndroid Build Coastguard Worker        help="Total number of partitions we want to divide the benchmark suite into",
3048*da0073e9SAndroid Build Coastguard Worker    )
3049*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3050*da0073e9SAndroid Build Coastguard Worker        "--partition-id",
3051*da0073e9SAndroid Build Coastguard Worker        type=int,
3052*da0073e9SAndroid Build Coastguard Worker        default=0,
3053*da0073e9SAndroid Build Coastguard Worker        help="ID of the benchmark suite partition to be run. Used to divide CI tasks",
3054*da0073e9SAndroid Build Coastguard Worker    )
3055*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3056*da0073e9SAndroid Build Coastguard Worker        "--devices", "--device", "-d", action="append", help="cpu or cuda"
3057*da0073e9SAndroid Build Coastguard Worker    )
3058*da0073e9SAndroid Build Coastguard Worker    parser.add_argument("--device-index", help="CUDA device index")
3059*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3060*da0073e9SAndroid Build Coastguard Worker        "--repeat", "-n", type=int, default=30, help="number of timing runs"
3061*da0073e9SAndroid Build Coastguard Worker    )
3062*da0073e9SAndroid Build Coastguard Worker    iterations_per_run_help = """
3063*da0073e9SAndroid Build Coastguard Worker        Run this may iterations for each time measurement. This is mainly used for
3064*da0073e9SAndroid Build Coastguard Worker        XLA training. We want to run multiple iterations per measurement so the
3065*da0073e9SAndroid Build Coastguard Worker        tracing and computation for different iteartions can overlap with each
3066*da0073e9SAndroid Build Coastguard Worker        other. This makes sure we have an accurate xla baseline.
3067*da0073e9SAndroid Build Coastguard Worker    """
3068*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3069*da0073e9SAndroid Build Coastguard Worker        "--iterations-per-run", type=int, default=1, help=iterations_per_run_help
3070*da0073e9SAndroid Build Coastguard Worker    )
3071*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3072*da0073e9SAndroid Build Coastguard Worker        "--randomize-input",
3073*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3074*da0073e9SAndroid Build Coastguard Worker        help="Whether to randomize the input values. Dimensions will be kept the same.",
3075*da0073e9SAndroid Build Coastguard Worker    )
3076*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3077*da0073e9SAndroid Build Coastguard Worker        "--threads",
3078*da0073e9SAndroid Build Coastguard Worker        "-t",
3079*da0073e9SAndroid Build Coastguard Worker        type=int,
3080*da0073e9SAndroid Build Coastguard Worker        help="number of threads to use for eager and inductor",
3081*da0073e9SAndroid Build Coastguard Worker    )
3082*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3083*da0073e9SAndroid Build Coastguard Worker        "--nopython", action="store_true", help="Turn graph breaks into errors"
3084*da0073e9SAndroid Build Coastguard Worker    )
3085*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3086*da0073e9SAndroid Build Coastguard Worker        "--no-skip",
3087*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3088*da0073e9SAndroid Build Coastguard Worker        help="run models that are in the global SKIP list",
3089*da0073e9SAndroid Build Coastguard Worker    )
3090*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3091*da0073e9SAndroid Build Coastguard Worker        "--prims-nvfuser", action="store_true", help="user prims + nvfuser backend"
3092*da0073e9SAndroid Build Coastguard Worker    )
3093*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3094*da0073e9SAndroid Build Coastguard Worker        "--dump-raw-metrics",
3095*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3096*da0073e9SAndroid Build Coastguard Worker        help="dump raw timing metrics from speedup experiment",
3097*da0073e9SAndroid Build Coastguard Worker    )
3098*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3099*da0073e9SAndroid Build Coastguard Worker        "--log-operator-inputs",
3100*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3101*da0073e9SAndroid Build Coastguard Worker        default=False,
3102*da0073e9SAndroid Build Coastguard Worker    )
3103*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3104*da0073e9SAndroid Build Coastguard Worker        "--channels-last",
3105*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3106*da0073e9SAndroid Build Coastguard Worker        default=False,
3107*da0073e9SAndroid Build Coastguard Worker        help="use channels last format",
3108*da0073e9SAndroid Build Coastguard Worker    )
3109*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3110*da0073e9SAndroid Build Coastguard Worker        "--batch-size", "--batch_size", type=int, help="batch size for benchmarking"
3111*da0073e9SAndroid Build Coastguard Worker    )
3112*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3113*da0073e9SAndroid Build Coastguard Worker        "--iterations", type=int, default=2, help="how many iterations to run"
3114*da0073e9SAndroid Build Coastguard Worker    )
3115*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3116*da0073e9SAndroid Build Coastguard Worker        "--batch-size-file", type=str, help="String to load batch size from"
3117*da0073e9SAndroid Build Coastguard Worker    )
3118*da0073e9SAndroid Build Coastguard Worker    parser.add_argument("--cosine", action="store_true", help="use cosine similarity")
3119*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3120*da0073e9SAndroid Build Coastguard Worker        "--freezing", action="store_true", help="turn on freezing", default=False
3121*da0073e9SAndroid Build Coastguard Worker    )
3122*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3123*da0073e9SAndroid Build Coastguard Worker        "--ci", action="store_true", help="Flag to tell that its a CI run"
3124*da0073e9SAndroid Build Coastguard Worker    )
3125*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3126*da0073e9SAndroid Build Coastguard Worker        "--dashboard", action="store_true", help="Flag to tell that its a Dashboard run"
3127*da0073e9SAndroid Build Coastguard Worker    )
3128*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3129*da0073e9SAndroid Build Coastguard Worker        "--skip-fp64-check", action="store_true", help="skip accuracy check using fp64"
3130*da0073e9SAndroid Build Coastguard Worker    )
3131*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3132*da0073e9SAndroid Build Coastguard Worker        "--fast", "-f", action="store_true", help="skip slow benchmarks"
3133*da0073e9SAndroid Build Coastguard Worker    )
3134*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3135*da0073e9SAndroid Build Coastguard Worker        "--only",
3136*da0073e9SAndroid Build Coastguard Worker        help="""Run just one model from torchbench. Or
3137*da0073e9SAndroid Build Coastguard Worker        specify the path and class name of the model in format like:
3138*da0073e9SAndroid Build Coastguard Worker        --only=path:<MODEL_FILE_PATH>,class:<CLASS_NAME>
3139*da0073e9SAndroid Build Coastguard Worker
3140*da0073e9SAndroid Build Coastguard Worker        Due to the fact that dynamo changes current working directory,
3141*da0073e9SAndroid Build Coastguard Worker        the path should be an absolute path.
3142*da0073e9SAndroid Build Coastguard Worker
3143*da0073e9SAndroid Build Coastguard Worker        The class should have a method get_example_inputs to return the inputs
3144*da0073e9SAndroid Build Coastguard Worker        for the model. An example looks like
3145*da0073e9SAndroid Build Coastguard Worker        ```
3146*da0073e9SAndroid Build Coastguard Worker        class LinearModel(nn.Module):
3147*da0073e9SAndroid Build Coastguard Worker            def __init__(self):
3148*da0073e9SAndroid Build Coastguard Worker                super().__init__()
3149*da0073e9SAndroid Build Coastguard Worker                self.linear = nn.Linear(10, 10)
3150*da0073e9SAndroid Build Coastguard Worker
3151*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3152*da0073e9SAndroid Build Coastguard Worker                return self.linear(x)
3153*da0073e9SAndroid Build Coastguard Worker
3154*da0073e9SAndroid Build Coastguard Worker            def get_example_inputs(self):
3155*da0073e9SAndroid Build Coastguard Worker                return (torch.randn(2, 10),)
3156*da0073e9SAndroid Build Coastguard Worker        ```
3157*da0073e9SAndroid Build Coastguard Worker    """,
3158*da0073e9SAndroid Build Coastguard Worker    )
3159*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3160*da0073e9SAndroid Build Coastguard Worker        "--multiprocess",
3161*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3162*da0073e9SAndroid Build Coastguard Worker        help="Create n processes based on the number of devices (distributed use case).",
3163*da0073e9SAndroid Build Coastguard Worker    )
3164*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3165*da0073e9SAndroid Build Coastguard Worker        "--ddp",
3166*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3167*da0073e9SAndroid Build Coastguard Worker        help="Wraps model in DDP before running it, and uses dynamo DDPOptmizer (graph breaks) by default.",
3168*da0073e9SAndroid Build Coastguard Worker    )
3169*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3170*da0073e9SAndroid Build Coastguard Worker        "--fsdp",
3171*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3172*da0073e9SAndroid Build Coastguard Worker        help="""Wraps model in FSDP before running it.
3173*da0073e9SAndroid Build Coastguard Worker        Doesn't recursively wrap, mainly useful for checking dynamo UnspecNNModule compatibility
3174*da0073e9SAndroid Build Coastguard Worker    """,
3175*da0073e9SAndroid Build Coastguard Worker    )
3176*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3177*da0073e9SAndroid Build Coastguard Worker        "--optimize-ddp-mode",
3178*da0073e9SAndroid Build Coastguard Worker        type=str,
3179*da0073e9SAndroid Build Coastguard Worker        default="ddp_optimizer",
3180*da0073e9SAndroid Build Coastguard Worker        help="Specify the DDP optimization mode -- the value of torch._dynamo.config.optimize_ddp.",
3181*da0073e9SAndroid Build Coastguard Worker    )
3182*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3183*da0073e9SAndroid Build Coastguard Worker        "--distributed-master-port",
3184*da0073e9SAndroid Build Coastguard Worker        default="6789",
3185*da0073e9SAndroid Build Coastguard Worker        help="Port to bind for for torch.distributed.  Use the default unless it's conflicting with another user",
3186*da0073e9SAndroid Build Coastguard Worker    )
3187*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3188*da0073e9SAndroid Build Coastguard Worker        "--dynamic-shapes",
3189*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3190*da0073e9SAndroid Build Coastguard Worker        help="Runs a dynamic shapes version of the benchmark, if available.",
3191*da0073e9SAndroid Build Coastguard Worker    )
3192*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3193*da0073e9SAndroid Build Coastguard Worker        "--propagate-real-tensors",
3194*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3195*da0073e9SAndroid Build Coastguard Worker        help="Capture as much data dependent as you can by unsoundly propagating real tensors",
3196*da0073e9SAndroid Build Coastguard Worker    )
3197*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3198*da0073e9SAndroid Build Coastguard Worker        "--dynamic-batch-only",
3199*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3200*da0073e9SAndroid Build Coastguard Worker        help="Only assume batch dimension is dynamic.  Implies --dynamic-shapes",
3201*da0073e9SAndroid Build Coastguard Worker    )
3202*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3203*da0073e9SAndroid Build Coastguard Worker        "--specialize-int", action="store_true", help="Run with specialize_int=True."
3204*da0073e9SAndroid Build Coastguard Worker    )
3205*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3206*da0073e9SAndroid Build Coastguard Worker        "--use-eval-mode",
3207*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3208*da0073e9SAndroid Build Coastguard Worker        help="sets model.eval() to reduce randomness",
3209*da0073e9SAndroid Build Coastguard Worker    )
3210*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3211*da0073e9SAndroid Build Coastguard Worker        "--skip-accuracy-check",
3212*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3213*da0073e9SAndroid Build Coastguard Worker        help="keeps running even when accuracy fails",
3214*da0073e9SAndroid Build Coastguard Worker    )
3215*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3216*da0073e9SAndroid Build Coastguard Worker        "--generate-aot-autograd-stats",
3217*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3218*da0073e9SAndroid Build Coastguard Worker        help="Generates AOT Autograd stats like how mnay graphs are sent to AOT",
3219*da0073e9SAndroid Build Coastguard Worker    )
3220*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3221*da0073e9SAndroid Build Coastguard Worker        "--inductor-settings",
3222*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3223*da0073e9SAndroid Build Coastguard Worker        help="Use same settings as --inductor for baseline comparisons",
3224*da0073e9SAndroid Build Coastguard Worker    )
3225*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3226*da0073e9SAndroid Build Coastguard Worker        "--suppress-errors",
3227*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3228*da0073e9SAndroid Build Coastguard Worker        help="Suppress errors instead of raising them",
3229*da0073e9SAndroid Build Coastguard Worker    )
3230*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3231*da0073e9SAndroid Build Coastguard Worker        "--output",
3232*da0073e9SAndroid Build Coastguard Worker        help="Overrides the output filename",
3233*da0073e9SAndroid Build Coastguard Worker    )
3234*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3235*da0073e9SAndroid Build Coastguard Worker        "--output-directory",
3236*da0073e9SAndroid Build Coastguard Worker        help="Overrides the directory to place output files.",
3237*da0073e9SAndroid Build Coastguard Worker    )
3238*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3239*da0073e9SAndroid Build Coastguard Worker        "--disable-output",
3240*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3241*da0073e9SAndroid Build Coastguard Worker        help="Disable writing of output files, e.g., for warm-up runs",
3242*da0073e9SAndroid Build Coastguard Worker    )
3243*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3244*da0073e9SAndroid Build Coastguard Worker        "--baseline",
3245*da0073e9SAndroid Build Coastguard Worker        help="Compare with a prior --output",
3246*da0073e9SAndroid Build Coastguard Worker    )
3247*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3248*da0073e9SAndroid Build Coastguard Worker        "--part",
3249*da0073e9SAndroid Build Coastguard Worker        default=None,
3250*da0073e9SAndroid Build Coastguard Worker        help="Specify the part of the model to run.",
3251*da0073e9SAndroid Build Coastguard Worker    )
3252*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3253*da0073e9SAndroid Build Coastguard Worker        "--export-profiler-trace",
3254*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3255*da0073e9SAndroid Build Coastguard Worker        help="exports trace of kineto profiler",
3256*da0073e9SAndroid Build Coastguard Worker    )
3257*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3258*da0073e9SAndroid Build Coastguard Worker        "--profiler-trace-name",
3259*da0073e9SAndroid Build Coastguard Worker        "--profiler_trace_name",
3260*da0073e9SAndroid Build Coastguard Worker        help="Overwrites exported trace name",
3261*da0073e9SAndroid Build Coastguard Worker    )
3262*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3263*da0073e9SAndroid Build Coastguard Worker        "--diff-branch",
3264*da0073e9SAndroid Build Coastguard Worker        default=diff_branch_default,
3265*da0073e9SAndroid Build Coastguard Worker        help="delta current branch against given branch.",
3266*da0073e9SAndroid Build Coastguard Worker    )
3267*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3268*da0073e9SAndroid Build Coastguard Worker        "--tag", default=None, help="Specify a tag to be included in csv files."
3269*da0073e9SAndroid Build Coastguard Worker    )
3270*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3271*da0073e9SAndroid Build Coastguard Worker        "--explain",
3272*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3273*da0073e9SAndroid Build Coastguard Worker        help="print some graph/op statistics during the run, similar to .explain()",
3274*da0073e9SAndroid Build Coastguard Worker    )
3275*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3276*da0073e9SAndroid Build Coastguard Worker        "--stats",
3277*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3278*da0073e9SAndroid Build Coastguard Worker        help="print graph counter stats",
3279*da0073e9SAndroid Build Coastguard Worker    )
3280*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3281*da0073e9SAndroid Build Coastguard Worker        "--use-warm-peak-memory",
3282*da0073e9SAndroid Build Coastguard Worker        "--use_warm_peak_memory",
3283*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3284*da0073e9SAndroid Build Coastguard Worker        help="Measure peak memory using a warm run to reduce autotuning noise",
3285*da0073e9SAndroid Build Coastguard Worker    )
3286*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3287*da0073e9SAndroid Build Coastguard Worker        "--print-memory",
3288*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3289*da0073e9SAndroid Build Coastguard Worker        help="print extra memory statistics",
3290*da0073e9SAndroid Build Coastguard Worker    )
3291*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3292*da0073e9SAndroid Build Coastguard Worker        "--print-compilation-time",
3293*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3294*da0073e9SAndroid Build Coastguard Worker        help="print compilation latency",
3295*da0073e9SAndroid Build Coastguard Worker    )
3296*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3297*da0073e9SAndroid Build Coastguard Worker        "--print-dataframe-summary",
3298*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3299*da0073e9SAndroid Build Coastguard Worker        help="print dataframe result used for calculating accuracy",
3300*da0073e9SAndroid Build Coastguard Worker    )
3301*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3302*da0073e9SAndroid Build Coastguard Worker        "--disable-cudagraphs",
3303*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3304*da0073e9SAndroid Build Coastguard Worker        help="Disables cudagraphs for Inductor",
3305*da0073e9SAndroid Build Coastguard Worker    )
3306*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3307*da0073e9SAndroid Build Coastguard Worker        "--disable-split-reductions",
3308*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3309*da0073e9SAndroid Build Coastguard Worker        help="Disables split reductions for Inductor",
3310*da0073e9SAndroid Build Coastguard Worker    )
3311*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3312*da0073e9SAndroid Build Coastguard Worker        "--disable-persistent-reductions",
3313*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3314*da0073e9SAndroid Build Coastguard Worker        help="Disables split reductions for Inductor",
3315*da0073e9SAndroid Build Coastguard Worker    )
3316*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3317*da0073e9SAndroid Build Coastguard Worker        "--disable-divisible-by-16",
3318*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3319*da0073e9SAndroid Build Coastguard Worker        help="Disables divisible by 16 hint to Triton for Inductor",
3320*da0073e9SAndroid Build Coastguard Worker    )
3321*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3322*da0073e9SAndroid Build Coastguard Worker        "--inductor-compile-mode",
3323*da0073e9SAndroid Build Coastguard Worker        default=None,
3324*da0073e9SAndroid Build Coastguard Worker        help="torch.compile mode argument for inductor runs.",
3325*da0073e9SAndroid Build Coastguard Worker    )
3326*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3327*da0073e9SAndroid Build Coastguard Worker        "--print-graph-breaks",
3328*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3329*da0073e9SAndroid Build Coastguard Worker        help="Show a warning whenever graph break",
3330*da0073e9SAndroid Build Coastguard Worker    )
3331*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3332*da0073e9SAndroid Build Coastguard Worker        "--log-graph-breaks",
3333*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3334*da0073e9SAndroid Build Coastguard Worker        help="log graph breaks in a file",
3335*da0073e9SAndroid Build Coastguard Worker    )
3336*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3337*da0073e9SAndroid Build Coastguard Worker        "--trace-on-xla",
3338*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3339*da0073e9SAndroid Build Coastguard Worker        help="Whether to trace the model on XLA or on eager device",
3340*da0073e9SAndroid Build Coastguard Worker    )
3341*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3342*da0073e9SAndroid Build Coastguard Worker        "--xla-tolerance",
3343*da0073e9SAndroid Build Coastguard Worker        type=float,
3344*da0073e9SAndroid Build Coastguard Worker        default=1e-2,
3345*da0073e9SAndroid Build Coastguard Worker        help="XLA needs a loose tolerance to pass the correctness check",
3346*da0073e9SAndroid Build Coastguard Worker    )
3347*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3348*da0073e9SAndroid Build Coastguard Worker        "--collect-outputs",
3349*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3350*da0073e9SAndroid Build Coastguard Worker        help="""Whether to collect outputs for training. Set this to true if we
3351*da0073e9SAndroid Build Coastguard Worker        want to verify the numerical correctness of graidents. But that may
3352*da0073e9SAndroid Build Coastguard Worker        cause time measurement not accurate""",
3353*da0073e9SAndroid Build Coastguard Worker    )
3354*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3355*da0073e9SAndroid Build Coastguard Worker        "--enable-activation-checkpointing",
3356*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3357*da0073e9SAndroid Build Coastguard Worker        help="Enables activation checkpointing for HF models",
3358*da0073e9SAndroid Build Coastguard Worker    )
3359*da0073e9SAndroid Build Coastguard Worker    parser.add_argument("--timing", action="store_true", help="Emits phase timing")
3360*da0073e9SAndroid Build Coastguard Worker
3361*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3362*da0073e9SAndroid Build Coastguard Worker        "--progress",
3363*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3364*da0073e9SAndroid Build Coastguard Worker        help="Print n/k models message between each model run.",
3365*da0073e9SAndroid Build Coastguard Worker    )
3366*da0073e9SAndroid Build Coastguard Worker
3367*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3368*da0073e9SAndroid Build Coastguard Worker        "--timeout",
3369*da0073e9SAndroid Build Coastguard Worker        type=int,
3370*da0073e9SAndroid Build Coastguard Worker        default=2000,
3371*da0073e9SAndroid Build Coastguard Worker        help="timeout (second) for benchmarking.",
3372*da0073e9SAndroid Build Coastguard Worker    )
3373*da0073e9SAndroid Build Coastguard Worker
3374*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3375*da0073e9SAndroid Build Coastguard Worker        "--per_process_memory_fraction",
3376*da0073e9SAndroid Build Coastguard Worker        type=float,
3377*da0073e9SAndroid Build Coastguard Worker        default=1,
3378*da0073e9SAndroid Build Coastguard Worker        help="Set per-process GPU memory fraction (limit) for reducing usable size and reproducing OOMs",
3379*da0073e9SAndroid Build Coastguard Worker    )
3380*da0073e9SAndroid Build Coastguard Worker
3381*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3382*da0073e9SAndroid Build Coastguard Worker        "--no-translation-validation",
3383*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3384*da0073e9SAndroid Build Coastguard Worker        help="Disable translation validation for accuracy builds.",
3385*da0073e9SAndroid Build Coastguard Worker    )
3386*da0073e9SAndroid Build Coastguard Worker
3387*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3388*da0073e9SAndroid Build Coastguard Worker        "--minify",
3389*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3390*da0073e9SAndroid Build Coastguard Worker        help="Enable minification when failure is below tolerance. Save repro script for each model.",
3391*da0073e9SAndroid Build Coastguard Worker    )
3392*da0073e9SAndroid Build Coastguard Worker
3393*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3394*da0073e9SAndroid Build Coastguard Worker        "--compiled-autograd",
3395*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3396*da0073e9SAndroid Build Coastguard Worker        help="Enables compiled autograd on compiled benchmark",
3397*da0073e9SAndroid Build Coastguard Worker    )
3398*da0073e9SAndroid Build Coastguard Worker
3399*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3400*da0073e9SAndroid Build Coastguard Worker        "--profile_dynamo_cache_lookup",
3401*da0073e9SAndroid Build Coastguard Worker        "--profile-dynamo-cache-lookup",
3402*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3403*da0073e9SAndroid Build Coastguard Worker        help="profiles TorchDynamo cache lookup",
3404*da0073e9SAndroid Build Coastguard Worker    )
3405*da0073e9SAndroid Build Coastguard Worker
3406*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3407*da0073e9SAndroid Build Coastguard Worker        "--snapshot-memory",
3408*da0073e9SAndroid Build Coastguard Worker        "--snapshot_memory",
3409*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3410*da0073e9SAndroid Build Coastguard Worker        help="Enables Memory Snapshot tool for memory deep dives: https://pytorch.org/blog/understanding-gpu-memory-1/",
3411*da0073e9SAndroid Build Coastguard Worker    )
3412*da0073e9SAndroid Build Coastguard Worker
3413*da0073e9SAndroid Build Coastguard Worker    group_latency = parser.add_mutually_exclusive_group()
3414*da0073e9SAndroid Build Coastguard Worker    group_latency.add_argument(
3415*da0073e9SAndroid Build Coastguard Worker        "--cold-start-latency",
3416*da0073e9SAndroid Build Coastguard Worker        "--cold_start_latency",
3417*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3418*da0073e9SAndroid Build Coastguard Worker        help="Use a fresh triton cachedir when running each model, to force cold-start compile.",
3419*da0073e9SAndroid Build Coastguard Worker    )
3420*da0073e9SAndroid Build Coastguard Worker    group_latency.add_argument(
3421*da0073e9SAndroid Build Coastguard Worker        "--warm-start-latency",
3422*da0073e9SAndroid Build Coastguard Worker        "--warm_start_latency",
3423*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3424*da0073e9SAndroid Build Coastguard Worker        help="Run model(s) twice and preseve caches in between to enable a 'warm start' on the 2nd run",
3425*da0073e9SAndroid Build Coastguard Worker    )
3426*da0073e9SAndroid Build Coastguard Worker
3427*da0073e9SAndroid Build Coastguard Worker    group_fuser = parser.add_mutually_exclusive_group()
3428*da0073e9SAndroid Build Coastguard Worker    # --nvfuser is now the default, keep the option to not break scripts
3429*da0073e9SAndroid Build Coastguard Worker    group_fuser.add_argument("--nvfuser", action="store_true", help=argparse.SUPPRESS)
3430*da0073e9SAndroid Build Coastguard Worker    group_fuser.add_argument("--nnc", action="store_true", help="enable NNC for GPUs")
3431*da0073e9SAndroid Build Coastguard Worker
3432*da0073e9SAndroid Build Coastguard Worker    group_prec = parser.add_mutually_exclusive_group()
3433*da0073e9SAndroid Build Coastguard Worker    group_prec.add_argument("--float16", action="store_true", help="cast model to fp16")
3434*da0073e9SAndroid Build Coastguard Worker    group_prec.add_argument(
3435*da0073e9SAndroid Build Coastguard Worker        "--bfloat16", action="store_true", help="cast model to bf16"
3436*da0073e9SAndroid Build Coastguard Worker    )
3437*da0073e9SAndroid Build Coastguard Worker    group_prec.add_argument("--float32", action="store_true", help="cast model to fp32")
3438*da0073e9SAndroid Build Coastguard Worker    group_prec.add_argument(
3439*da0073e9SAndroid Build Coastguard Worker        "--amp", action="store_true", help="use automatic mixed precision"
3440*da0073e9SAndroid Build Coastguard Worker    )
3441*da0073e9SAndroid Build Coastguard Worker    parser.add_argument(
3442*da0073e9SAndroid Build Coastguard Worker        "--amp-dtype",
3443*da0073e9SAndroid Build Coastguard Worker        choices=("bfloat16", "float16"),
3444*da0073e9SAndroid Build Coastguard Worker        help="the data type used with automatic mixed precision",
3445*da0073e9SAndroid Build Coastguard Worker    )
3446*da0073e9SAndroid Build Coastguard Worker    group_printout = parser.add_mutually_exclusive_group()
3447*da0073e9SAndroid Build Coastguard Worker    group_printout.add_argument(
3448*da0073e9SAndroid Build Coastguard Worker        "--verbose", "-v", action="store_true", help="enable verbose debug printouts"
3449*da0073e9SAndroid Build Coastguard Worker    )
3450*da0073e9SAndroid Build Coastguard Worker    group_printout.add_argument(
3451*da0073e9SAndroid Build Coastguard Worker        "--quiet", "-q", action="store_true", help="suppress debug printouts"
3452*da0073e9SAndroid Build Coastguard Worker    )
3453*da0073e9SAndroid Build Coastguard Worker
3454*da0073e9SAndroid Build Coastguard Worker    group = parser.add_mutually_exclusive_group()
3455*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3456*da0073e9SAndroid Build Coastguard Worker        "--coverage", action="store_true", help="(default) " + help(coverage_experiment)
3457*da0073e9SAndroid Build Coastguard Worker    )
3458*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3459*da0073e9SAndroid Build Coastguard Worker        "--overhead", action="store_true", help=help(overhead_experiment)
3460*da0073e9SAndroid Build Coastguard Worker    )
3461*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3462*da0073e9SAndroid Build Coastguard Worker        "--speedup-dynamo-ts",
3463*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3464*da0073e9SAndroid Build Coastguard Worker        help="TorchDynamo frontend with torchscript backend",
3465*da0073e9SAndroid Build Coastguard Worker    )
3466*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3467*da0073e9SAndroid Build Coastguard Worker        "--speedup-fx2trt", action="store_true", help=help(speedup_experiment_fx2trt)
3468*da0073e9SAndroid Build Coastguard Worker    )
3469*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3470*da0073e9SAndroid Build Coastguard Worker        "--speedup-fx2trt-fp16",
3471*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3472*da0073e9SAndroid Build Coastguard Worker        help=help(speedup_experiment_fx2trt),
3473*da0073e9SAndroid Build Coastguard Worker    )
3474*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3475*da0073e9SAndroid Build Coastguard Worker        "--print-fx",
3476*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3477*da0073e9SAndroid Build Coastguard Worker        help="Print fx traces captured from model",
3478*da0073e9SAndroid Build Coastguard Worker    )
3479*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3480*da0073e9SAndroid Build Coastguard Worker        "--print-aten-ops",
3481*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3482*da0073e9SAndroid Build Coastguard Worker        help="Print traces of aten ops captured by AOT autograd",
3483*da0073e9SAndroid Build Coastguard Worker    )
3484*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3485*da0073e9SAndroid Build Coastguard Worker        "--inductor",
3486*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3487*da0073e9SAndroid Build Coastguard Worker        help="Measure speedup with TorchInductor",
3488*da0073e9SAndroid Build Coastguard Worker    )
3489*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3490*da0073e9SAndroid Build Coastguard Worker        "--quantization",
3491*da0073e9SAndroid Build Coastguard Worker        choices=[
3492*da0073e9SAndroid Build Coastguard Worker            "int8dynamic",
3493*da0073e9SAndroid Build Coastguard Worker            "int8weightonly",
3494*da0073e9SAndroid Build Coastguard Worker            "int4weightonly",
3495*da0073e9SAndroid Build Coastguard Worker            "autoquant",
3496*da0073e9SAndroid Build Coastguard Worker            "noquant",
3497*da0073e9SAndroid Build Coastguard Worker        ],
3498*da0073e9SAndroid Build Coastguard Worker        default=None,
3499*da0073e9SAndroid Build Coastguard Worker        help="Measure speedup of torchao quantization with TorchInductor baseline",
3500*da0073e9SAndroid Build Coastguard Worker    )
3501*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3502*da0073e9SAndroid Build Coastguard Worker        "--export",
3503*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3504*da0073e9SAndroid Build Coastguard Worker        help="Measure pass rate with export",
3505*da0073e9SAndroid Build Coastguard Worker    )
3506*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3507*da0073e9SAndroid Build Coastguard Worker        "--export-aot-inductor",
3508*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3509*da0073e9SAndroid Build Coastguard Worker        help="Measure pass rate with Export+AOTInductor",
3510*da0073e9SAndroid Build Coastguard Worker    )
3511*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3512*da0073e9SAndroid Build Coastguard Worker        "--xla", action="store_true", help="Compare TorchXLA to eager PyTorch"
3513*da0073e9SAndroid Build Coastguard Worker    )
3514*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3515*da0073e9SAndroid Build Coastguard Worker        "--torchscript-onnx",
3516*da0073e9SAndroid Build Coastguard Worker        "--torchscript_onnx",
3517*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3518*da0073e9SAndroid Build Coastguard Worker        help="Measure speedup with TorchScript ONNX, i.e. `torch.onnx.export`",
3519*da0073e9SAndroid Build Coastguard Worker    )
3520*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3521*da0073e9SAndroid Build Coastguard Worker        "--dynamo-onnx",
3522*da0073e9SAndroid Build Coastguard Worker        "--dynamo_onnx",
3523*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3524*da0073e9SAndroid Build Coastguard Worker        help="Measure speedup with Dynamo ONNX, i.e. `torch.onnx.dynamo_export`",
3525*da0073e9SAndroid Build Coastguard Worker    )
3526*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3527*da0073e9SAndroid Build Coastguard Worker        "--dynamo-onnx-aot-inline",
3528*da0073e9SAndroid Build Coastguard Worker        "--dynamo_onnx_aot_inline",
3529*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3530*da0073e9SAndroid Build Coastguard Worker        help="Measure speedup with Dynamo ONNX AOT Inline, i.e. `torch.onnx.dynamo_export`",
3531*da0073e9SAndroid Build Coastguard Worker    )
3532*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3533*da0073e9SAndroid Build Coastguard Worker        "--dynamo-onnx-aot-optimize",
3534*da0073e9SAndroid Build Coastguard Worker        "--dynamo_onnx_aot_optimize",
3535*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3536*da0073e9SAndroid Build Coastguard Worker        help="Measure speedup with Dynamo ONNX w/ ort fusions, i.e. `torch.onnx.dynamo_export`",
3537*da0073e9SAndroid Build Coastguard Worker    )
3538*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3539*da0073e9SAndroid Build Coastguard Worker        "--backend",
3540*da0073e9SAndroid Build Coastguard Worker        choices=torch._dynamo.list_backends(exclude_tags=None),
3541*da0073e9SAndroid Build Coastguard Worker        help="measure speedup with a given backend",
3542*da0073e9SAndroid Build Coastguard Worker    )
3543*da0073e9SAndroid Build Coastguard Worker    group.add_argument("--nothing", action="store_true", help=help(null_experiment))
3544*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3545*da0073e9SAndroid Build Coastguard Worker        "--log-conv-args",
3546*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3547*da0073e9SAndroid Build Coastguard Worker        help="Dump convolution input/weight/bias's shape/stride/dtype and other options to json",
3548*da0073e9SAndroid Build Coastguard Worker    )
3549*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3550*da0073e9SAndroid Build Coastguard Worker        "--recompile-profiler",
3551*da0073e9SAndroid Build Coastguard Worker        "--recompile_profiler",
3552*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3553*da0073e9SAndroid Build Coastguard Worker        help="Run the dynamo recompilation profiler on each model.",
3554*da0073e9SAndroid Build Coastguard Worker    )
3555*da0073e9SAndroid Build Coastguard Worker    group.add_argument(
3556*da0073e9SAndroid Build Coastguard Worker        "--find-batch-sizes",
3557*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3558*da0073e9SAndroid Build Coastguard Worker        help="finds the largest batch size that could fit on GPUs",
3559*da0073e9SAndroid Build Coastguard Worker    )
3560*da0073e9SAndroid Build Coastguard Worker
3561*da0073e9SAndroid Build Coastguard Worker    mode_group = parser.add_mutually_exclusive_group(required=True)
3562*da0073e9SAndroid Build Coastguard Worker    mode_group.add_argument(
3563*da0073e9SAndroid Build Coastguard Worker        "--accuracy",
3564*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3565*da0073e9SAndroid Build Coastguard Worker        help="Checks accuracy with small batch size and eval mode",
3566*da0073e9SAndroid Build Coastguard Worker    )
3567*da0073e9SAndroid Build Coastguard Worker    mode_group.add_argument(
3568*da0073e9SAndroid Build Coastguard Worker        "--performance", action="store_true", help="Measures performance speedup"
3569*da0073e9SAndroid Build Coastguard Worker    )
3570*da0073e9SAndroid Build Coastguard Worker    mode_group.add_argument(
3571*da0073e9SAndroid Build Coastguard Worker        "--tolerance",
3572*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3573*da0073e9SAndroid Build Coastguard Worker        help="extracts the tolerance for each model with small batch size and eval mode",
3574*da0073e9SAndroid Build Coastguard Worker    )
3575*da0073e9SAndroid Build Coastguard Worker    run_mode_group = parser.add_mutually_exclusive_group(required=True)
3576*da0073e9SAndroid Build Coastguard Worker    run_mode_group.add_argument(
3577*da0073e9SAndroid Build Coastguard Worker        "--training",
3578*da0073e9SAndroid Build Coastguard Worker        action="store_true",
3579*da0073e9SAndroid Build Coastguard Worker        help="Performs training",
3580*da0073e9SAndroid Build Coastguard Worker    )
3581*da0073e9SAndroid Build Coastguard Worker    run_mode_group.add_argument(
3582*da0073e9SAndroid Build Coastguard Worker        "--inference", action="store_true", help="Performs inference"
3583*da0073e9SAndroid Build Coastguard Worker    )
3584*da0073e9SAndroid Build Coastguard Worker    return parser.parse_args(args)
3585*da0073e9SAndroid Build Coastguard Worker
3586*da0073e9SAndroid Build Coastguard Worker
3587*da0073e9SAndroid Build Coastguard Workerdef process_entry(rank, runner, original_dir, args):
3588*da0073e9SAndroid Build Coastguard Worker    args.rank = rank
3589*da0073e9SAndroid Build Coastguard Worker    with maybe_init_distributed(
3590*da0073e9SAndroid Build Coastguard Worker        args.init_distributed,
3591*da0073e9SAndroid Build Coastguard Worker        rank=rank,
3592*da0073e9SAndroid Build Coastguard Worker        world_size=args.world_size,
3593*da0073e9SAndroid Build Coastguard Worker        port=args.distributed_master_port,
3594*da0073e9SAndroid Build Coastguard Worker    ):
3595*da0073e9SAndroid Build Coastguard Worker        return run(runner, args, original_dir)
3596*da0073e9SAndroid Build Coastguard Worker
3597*da0073e9SAndroid Build Coastguard Worker
3598*da0073e9SAndroid Build Coastguard Workerdef maybe_fresh_cache(args):
3599*da0073e9SAndroid Build Coastguard Worker    cache_dir_assigned = "TORCHINDUCTOR_CACHE_DIR" in os.environ
3600*da0073e9SAndroid Build Coastguard Worker    if not cache_dir_assigned and (
3601*da0073e9SAndroid Build Coastguard Worker        args.cold_start_latency or args.warm_start_latency or args.ci
3602*da0073e9SAndroid Build Coastguard Worker    ):
3603*da0073e9SAndroid Build Coastguard Worker        return fresh_inductor_cache()
3604*da0073e9SAndroid Build Coastguard Worker    else:
3605*da0073e9SAndroid Build Coastguard Worker        return contextlib.nullcontext()
3606*da0073e9SAndroid Build Coastguard Worker
3607*da0073e9SAndroid Build Coastguard Worker
def main(runner, original_dir=None, args=None):
    """Benchmark entry point: parse CLI args, validate preconditions, and
    dispatch to one of three execution modes under an (optionally fresh)
    inductor cache:

    1. multiprocess/distributed: spawn one process per CUDA device;
    2. warm-start latency: run the same command twice in subprocesses with
       FX graph caching enabled, measuring the second (warm) run;
    3. default: run ``process_entry`` directly in this process.

    Args:
        runner: benchmark suite runner object (e.g. torchbench/huggingface).
        original_dir: if given, chdir here first so relative paths resolve
            against the caller's original working directory.
        args: pre-split argv list; when None, sys.argv is parsed.
    """
    if original_dir:
        os.chdir(original_dir)
    args = parse_args() if not args else parse_args(args)
    if args.baseline:
        args.baseline = os.path.abspath(args.baseline)

    if should_diff_branch(args):
        import git

        # We do this here so we error out earlier if there's an issue
        repo = git.Repo()
        if repo.is_dirty():
            raise RuntimeError(
                "--diff-branch called on dirty branch. Commit, stash, or reset."
            )
        main_branch = repo.active_branch.name
        if main_branch == args.diff_branch:
            raise RuntimeError(
                f"--diff-branch: current branch is same as {args.diff_branch} branch, what are you diffing?"
            )

    with maybe_fresh_cache(args):
        # Distributed mode requires both a single model (--only) and --multiprocess.
        args.init_distributed = args.only and args.multiprocess
        if args.init_distributed:
            # NB: Do NOT query device count before CUDA initialization; we're
            # going to overwrite CUDA_VISIBLE_DEVICES and this will result in
            # https://github.com/pytorch/pytorch/issues/107300
            device_count = torch.cuda.device_count()
            if device_count <= 1:
                log.warning(
                    "The use multiprocess flag is set but there are <= 1 devices available."
                )
            # multiprocess path
            args.world_size = device_count
            mp.spawn(
                process_entry, args=(runner, original_dir, args), nprocs=device_count
            )
        elif args.only and args.warm_start_latency:
            # Warm start mode. Enable FX graph caching and perform back-to-back runs in
            # separate processes (but ensure the inductor cache is preserved across runs).
            env = os.environ.copy()
            env["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
            # Re-invoke this same script, minus the flag that got us here,
            # so the child processes take the single-run code path.
            cmd = [sys.executable] + sys.argv
            cmd.remove("--warm-start-latency")

            print(f"Performing cold-start run for {args.only}")
            # Cold run only populates the cache; suppress its CSV output.
            warmup_cmd = cmd + ["--repeat=1", "--disable-output"]
            subprocess.check_call(warmup_cmd, timeout=args.timeout, env=env)

            print(f"Performing warm-start run for {args.only}")
            subprocess.check_call(cmd, timeout=args.timeout, env=env)
        else:
            # single process path just uses the main process
            args.world_size = 1
            process_entry(0, runner, original_dir, args)
3664*da0073e9SAndroid Build Coastguard Worker
3665*da0073e9SAndroid Build Coastguard Worker
def write_csv_when_exception(args, name: str, status: str, device=None):
    """Emit a placeholder CSV row per device when a benchmark fails.

    Prints ``status`` and appends one row per target device to the global
    output CSV so downstream tooling still sees an entry for ``name``. The
    row shape depends on whether this is an accuracy or performance run.
    When ``device`` is None, all of ``args.devices`` are covered.
    """
    print(status)
    placeholder_batch_size = 0
    target_devices = args.devices if device is None else [device]
    if args.accuracy:
        headers = ["dev", "name", "batch_size", "accuracy"]
        tail = [status]
    elif args.performance:
        headers = ["dev", "name", "batch_size", "speedup", "abs_latency"]
        tail = [0.0, 0.0]
    else:
        headers = []
        tail = [0.0]

    for dev in target_devices:
        output_csv(output_filename, headers, [dev, name, placeholder_batch_size] + tail)
3682*da0073e9SAndroid Build Coastguard Worker
3683*da0073e9SAndroid Build Coastguard Worker
3684*da0073e9SAndroid Build Coastguard Workerdef run(runner, args, original_dir=None):
3685*da0073e9SAndroid Build Coastguard Worker    # Pass the parsed args object to benchmark runner object
3686*da0073e9SAndroid Build Coastguard Worker    runner.args = args
3687*da0073e9SAndroid Build Coastguard Worker
3688*da0073e9SAndroid Build Coastguard Worker    args.filter = args.filter or [r"."]
3689*da0073e9SAndroid Build Coastguard Worker    args.exclude = args.exclude or [r"^$"]
3690*da0073e9SAndroid Build Coastguard Worker    args.exclude_exact = args.exclude_exact or []
3691*da0073e9SAndroid Build Coastguard Worker
3692*da0073e9SAndroid Build Coastguard Worker    if args.inductor:
3693*da0073e9SAndroid Build Coastguard Worker        assert args.backend is None
3694*da0073e9SAndroid Build Coastguard Worker        args.backend = "inductor"
3695*da0073e9SAndroid Build Coastguard Worker    if args.quantization:
3696*da0073e9SAndroid Build Coastguard Worker        assert args.backend is None
3697*da0073e9SAndroid Build Coastguard Worker        args.backend = "torchao"
3698*da0073e9SAndroid Build Coastguard Worker    if args.dynamic_batch_only:
3699*da0073e9SAndroid Build Coastguard Worker        args.dynamic_shapes = True
3700*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.config.assume_static_by_default = True
3701*da0073e9SAndroid Build Coastguard Worker    if args.dynamic_shapes:
3702*da0073e9SAndroid Build Coastguard Worker        if not args.dynamic_batch_only:
3703*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.config.assume_static_by_default = False
3704*da0073e9SAndroid Build Coastguard Worker    if args.propagate_real_tensors:
3705*da0073e9SAndroid Build Coastguard Worker        # TODO: Separate flag for data dependent
3706*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.config.capture_scalar_outputs = True
3707*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.config.capture_dynamic_output_shape_ops = True
3708*da0073e9SAndroid Build Coastguard Worker        torch._functorch.config.fake_tensor_propagate_real_tensors = True
3709*da0073e9SAndroid Build Coastguard Worker    if args.specialize_int:
3710*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.config.specialize_int = True
3711*da0073e9SAndroid Build Coastguard Worker    if args.ci:
3712*da0073e9SAndroid Build Coastguard Worker        if args.accuracy:
3713*da0073e9SAndroid Build Coastguard Worker            # Run fewer iterations when checking accuracy
3714*da0073e9SAndroid Build Coastguard Worker            args.repeat = min(args.repeat, 2)
3715*da0073e9SAndroid Build Coastguard Worker
3716*da0073e9SAndroid Build Coastguard Worker            # Set translation validation on by default on CI accuracy runs.
3717*da0073e9SAndroid Build Coastguard Worker            torch.fx.experimental._config.translation_validation = True
3718*da0073e9SAndroid Build Coastguard Worker
3719*da0073e9SAndroid Build Coastguard Worker        ci = functools.partial(
3720*da0073e9SAndroid Build Coastguard Worker            CI, args.backend, training=args.training, dynamic=args.dynamic_shapes
3721*da0073e9SAndroid Build Coastguard Worker        )
3722*da0073e9SAndroid Build Coastguard Worker    if args.ddp:
3723*da0073e9SAndroid Build Coastguard Worker        assert args.training, "DDP benchmark requires --training mode"
3724*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.config.optimize_ddp = args.optimize_ddp_mode
3725*da0073e9SAndroid Build Coastguard Worker        if args.only == "dlrm":
3726*da0073e9SAndroid Build Coastguard Worker            log.error(
3727*da0073e9SAndroid Build Coastguard Worker                "DLRM+DDP is unsupported as it requires sharding the embedding layer separately from DDP"
3728*da0073e9SAndroid Build Coastguard Worker            )
3729*da0073e9SAndroid Build Coastguard Worker            return sys.exit(-1)
3730*da0073e9SAndroid Build Coastguard Worker    if args.accuracy:
3731*da0073e9SAndroid Build Coastguard Worker        # Use small batch size. We use >1 batch size to ensure we test
3732*da0073e9SAndroid Build Coastguard Worker        # batch_norm type of operators that work on batch dims.
3733*da0073e9SAndroid Build Coastguard Worker        # TODO - Go through the failures for batch size = 2
3734*da0073e9SAndroid Build Coastguard Worker        if args.batch_size is None:
3735*da0073e9SAndroid Build Coastguard Worker            if runner.suite_name == "huggingface":
3736*da0073e9SAndroid Build Coastguard Worker                args.batch_size = 1
3737*da0073e9SAndroid Build Coastguard Worker            elif runner.suite_name == "torchbench":
3738*da0073e9SAndroid Build Coastguard Worker                args.batch_size = 4
3739*da0073e9SAndroid Build Coastguard Worker            else:
3740*da0073e9SAndroid Build Coastguard Worker                # Larger batch size of TIMM models to have stable batch_norm
3741*da0073e9SAndroid Build Coastguard Worker                assert runner.suite_name == "timm_models"
3742*da0073e9SAndroid Build Coastguard Worker                args.batch_size = 8
3743*da0073e9SAndroid Build Coastguard Worker
3744*da0073e9SAndroid Build Coastguard Worker        # Remove sources of randomness
3745*da0073e9SAndroid Build Coastguard Worker        if runner.suite_name not in ("timm_models", "huggingface"):
3746*da0073e9SAndroid Build Coastguard Worker            # TODO - Using train mode for timm_models and HF models. Move to train mode for Torchbench as well.
3747*da0073e9SAndroid Build Coastguard Worker            args.use_eval_mode = True
3748*da0073e9SAndroid Build Coastguard Worker        inductor_config.fallback_random = True
3749*da0073e9SAndroid Build Coastguard Worker        if args.only is not None and args.only not in {
3750*da0073e9SAndroid Build Coastguard Worker            "alexnet",
3751*da0073e9SAndroid Build Coastguard Worker            "Background_Matting",
3752*da0073e9SAndroid Build Coastguard Worker            "pytorch_CycleGAN_and_pix2pix",
3753*da0073e9SAndroid Build Coastguard Worker            "pytorch_unet",
3754*da0073e9SAndroid Build Coastguard Worker            "Super_SloMo",
3755*da0073e9SAndroid Build Coastguard Worker            "vgg16",
3756*da0073e9SAndroid Build Coastguard Worker            # https://github.com/pytorch/pytorch/issues/96724
3757*da0073e9SAndroid Build Coastguard Worker            "Wav2Vec2ForCTC",
3758*da0073e9SAndroid Build Coastguard Worker            "Wav2Vec2ForPreTraining",
3759*da0073e9SAndroid Build Coastguard Worker            "sam",
3760*da0073e9SAndroid Build Coastguard Worker            "sam_fast",
3761*da0073e9SAndroid Build Coastguard Worker            "resnet50_quantized_qat",
3762*da0073e9SAndroid Build Coastguard Worker            "mobilenet_v2_quantized_qat",
3763*da0073e9SAndroid Build Coastguard Worker        }:
3764*da0073e9SAndroid Build Coastguard Worker            # some of the models do not support use_deterministic_algorithms
3765*da0073e9SAndroid Build Coastguard Worker            torch.use_deterministic_algorithms(True)
3766*da0073e9SAndroid Build Coastguard Worker        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
3767*da0073e9SAndroid Build Coastguard Worker        torch.backends.cudnn.deterministic = True
3768*da0073e9SAndroid Build Coastguard Worker        torch.backends.cudnn.allow_tf32 = False
3769*da0073e9SAndroid Build Coastguard Worker        torch.backends.cudnn.benchmark = False
3770*da0073e9SAndroid Build Coastguard Worker        torch.backends.cuda.matmul.allow_tf32 = False
3771*da0073e9SAndroid Build Coastguard Worker
3772*da0073e9SAndroid Build Coastguard Worker        # Remove randomeness when torch manual seed is called
3773*da0073e9SAndroid Build Coastguard Worker        patch_torch_manual_seed()
3774*da0073e9SAndroid Build Coastguard Worker
3775*da0073e9SAndroid Build Coastguard Worker        # Some models e.g. yolov3 assert batch size on n_gpus
3776*da0073e9SAndroid Build Coastguard Worker        if "CUDA_VISIBLE_DEVICES" not in os.environ and not args.multiprocess:
3777*da0073e9SAndroid Build Coastguard Worker            args.device_index = "0"
3778*da0073e9SAndroid Build Coastguard Worker
3779*da0073e9SAndroid Build Coastguard Worker        # Stricter check to disable fallbacks
3780*da0073e9SAndroid Build Coastguard Worker        args.suppress_errors = False
3781*da0073e9SAndroid Build Coastguard Worker
3782*da0073e9SAndroid Build Coastguard Worker    if args.device_index is not None:
3783*da0073e9SAndroid Build Coastguard Worker        if args.multiprocess:
3784*da0073e9SAndroid Build Coastguard Worker            print("Cannot specify both --device_index and --multiprocess")
3785*da0073e9SAndroid Build Coastguard Worker            return sys.exit(-1)
3786*da0073e9SAndroid Build Coastguard Worker        os.environ["CUDA_VISIBLE_DEVICES"] = args.device_index
3787*da0073e9SAndroid Build Coastguard Worker
3788*da0073e9SAndroid Build Coastguard Worker    elif args.performance:
3789*da0073e9SAndroid Build Coastguard Worker        # Ensure that we test on real scenarios
3790*da0073e9SAndroid Build Coastguard Worker        args.use_eval_mode = False
3791*da0073e9SAndroid Build Coastguard Worker
3792*da0073e9SAndroid Build Coastguard Worker    if args.partition_id > args.total_partitions or args.partition_id < 0:
3793*da0073e9SAndroid Build Coastguard Worker        print("Invalid partition id")
3794*da0073e9SAndroid Build Coastguard Worker        return sys.exit(-1)
3795*da0073e9SAndroid Build Coastguard Worker
3796*da0073e9SAndroid Build Coastguard Worker    if not args.devices:
3797*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
3798*da0073e9SAndroid Build Coastguard Worker            args.devices = ["cuda"]
3799*da0073e9SAndroid Build Coastguard Worker        else:
3800*da0073e9SAndroid Build Coastguard Worker            log.warning("torch.cuda.is_available() == False, using CPU")
3801*da0073e9SAndroid Build Coastguard Worker            args.devices = ["cpu"]
3802*da0073e9SAndroid Build Coastguard Worker
3803*da0073e9SAndroid Build Coastguard Worker    if args.devices != ["cpu"] and (HAS_CUDA or HAS_XPU):
3804*da0073e9SAndroid Build Coastguard Worker        global synchronize
3805*da0073e9SAndroid Build Coastguard Worker        synchronize = torch.cuda.synchronize if HAS_CUDA else torch.xpu.synchronize
3806*da0073e9SAndroid Build Coastguard Worker
3807*da0073e9SAndroid Build Coastguard Worker    if (
3808*da0073e9SAndroid Build Coastguard Worker        args.devices == ["cuda"]
3809*da0073e9SAndroid Build Coastguard Worker        and torch.cuda.get_device_properties(0).total_memory < 25 * 2**30
3810*da0073e9SAndroid Build Coastguard Worker    ):
3811*da0073e9SAndroid Build Coastguard Worker        # OOM errors on an RTX 3090 with 24gb RAM
3812*da0073e9SAndroid Build Coastguard Worker        runner.skip_models.update(
3813*da0073e9SAndroid Build Coastguard Worker            {
3814*da0073e9SAndroid Build Coastguard Worker                # torchbench
3815*da0073e9SAndroid Build Coastguard Worker                "hf_Longformer",
3816*da0073e9SAndroid Build Coastguard Worker                "timm_nfnet",
3817*da0073e9SAndroid Build Coastguard Worker                "timm_efficientdet",
3818*da0073e9SAndroid Build Coastguard Worker            }
3819*da0073e9SAndroid Build Coastguard Worker        )
3820*da0073e9SAndroid Build Coastguard Worker        if args.training:
3821*da0073e9SAndroid Build Coastguard Worker            runner.skip_models.add("hf_T5")
3822*da0073e9SAndroid Build Coastguard Worker
3823*da0073e9SAndroid Build Coastguard Worker    if args.nnc:
3824*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_override_can_fuse_on_cpu(True)
3825*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_override_can_fuse_on_gpu(True)
3826*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_set_texpr_fuser_enabled(True)
3827*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_set_nvfuser_enabled(False)
3828*da0073e9SAndroid Build Coastguard Worker
3829*da0073e9SAndroid Build Coastguard Worker    if args.threads:
3830*da0073e9SAndroid Build Coastguard Worker        torch.set_num_threads(args.threads)
3831*da0073e9SAndroid Build Coastguard Worker
3832*da0073e9SAndroid Build Coastguard Worker    if args.verbose:
3833*da0073e9SAndroid Build Coastguard Worker        torch._logging.set_logs(dynamo=logging.DEBUG)
3834*da0073e9SAndroid Build Coastguard Worker
3835*da0073e9SAndroid Build Coastguard Worker    if args.print_graph_breaks:
3836*da0073e9SAndroid Build Coastguard Worker        torch._logging.set_logs(graph_breaks=True)
3837*da0073e9SAndroid Build Coastguard Worker
3838*da0073e9SAndroid Build Coastguard Worker    if args.quiet:
3839*da0073e9SAndroid Build Coastguard Worker        torch._logging.set_logs(dynamo=logging.ERROR)
3840*da0073e9SAndroid Build Coastguard Worker
3841*da0073e9SAndroid Build Coastguard Worker    torch._dynamo.config.suppress_errors = args.suppress_errors
3842*da0073e9SAndroid Build Coastguard Worker
3843*da0073e9SAndroid Build Coastguard Worker    if args.training:
3844*da0073e9SAndroid Build Coastguard Worker        runner.model_iter_fn = runner.forward_and_backward_pass
3845*da0073e9SAndroid Build Coastguard Worker        runner.skip_models.update(runner.skip_not_suitable_for_training_models)
3846*da0073e9SAndroid Build Coastguard Worker    else:
3847*da0073e9SAndroid Build Coastguard Worker        runner.model_iter_fn = runner.forward_pass
3848*da0073e9SAndroid Build Coastguard Worker
3849*da0073e9SAndroid Build Coastguard Worker    if args.fast:
3850*da0073e9SAndroid Build Coastguard Worker        runner.skip_models.update(runner.slow_models)
3851*da0073e9SAndroid Build Coastguard Worker
3852*da0073e9SAndroid Build Coastguard Worker    if args.devices == ["cpu"]:
3853*da0073e9SAndroid Build Coastguard Worker        runner.skip_models.update(runner.very_slow_models)
3854*da0073e9SAndroid Build Coastguard Worker        runner.skip_models.update(runner.skip_models_for_cpu)
3855*da0073e9SAndroid Build Coastguard Worker    elif args.devices == ["cuda"]:
3856*da0073e9SAndroid Build Coastguard Worker        runner.skip_models.update(runner.skip_models_for_cuda)
3857*da0073e9SAndroid Build Coastguard Worker
3858*da0073e9SAndroid Build Coastguard Worker    if not args.multiprocess:
3859*da0073e9SAndroid Build Coastguard Worker        runner.skip_models.update(runner.skip_multiprocess_models)
3860*da0073e9SAndroid Build Coastguard Worker
3861*da0073e9SAndroid Build Coastguard Worker    if args.freezing:
3862*da0073e9SAndroid Build Coastguard Worker        runner.skip_models.update(runner.skip_models_for_freezing)
3863*da0073e9SAndroid Build Coastguard Worker
3864*da0073e9SAndroid Build Coastguard Worker    if args.no_skip:
3865*da0073e9SAndroid Build Coastguard Worker        runner.skip_models.clear()
3866*da0073e9SAndroid Build Coastguard Worker
3867*da0073e9SAndroid Build Coastguard Worker    experiment = null_experiment
3868*da0073e9SAndroid Build Coastguard Worker    global current_name, current_device, current_batch_size, output_filename, disable_output, optimize_ctx, current_onnx_compiler
3869*da0073e9SAndroid Build Coastguard Worker    optimize_ctx = contextlib.nullcontext()
3870*da0073e9SAndroid Build Coastguard Worker
3871*da0073e9SAndroid Build Coastguard Worker    if args.disable_output:
3872*da0073e9SAndroid Build Coastguard Worker        disable_output = True
3873*da0073e9SAndroid Build Coastguard Worker
3874*da0073e9SAndroid Build Coastguard Worker    if args.overhead:
3875*da0073e9SAndroid Build Coastguard Worker        optimize_ctx = torch._dynamo.optimize(dummy_fx_compile, nopython=args.nopython)
3876*da0073e9SAndroid Build Coastguard Worker        experiment = speedup_experiment
3877*da0073e9SAndroid Build Coastguard Worker        output_filename = "overheads.csv"
3878*da0073e9SAndroid Build Coastguard Worker    elif args.inductor:
3879*da0073e9SAndroid Build Coastguard Worker        inductor_config.debug = args.verbose
3880*da0073e9SAndroid Build Coastguard Worker        if args.threads:
3881*da0073e9SAndroid Build Coastguard Worker            inductor_config.cpp.threads = args.threads
3882*da0073e9SAndroid Build Coastguard Worker
3883*da0073e9SAndroid Build Coastguard Worker        optimize_ctx = functools.partial(
3884*da0073e9SAndroid Build Coastguard Worker            torch.compile,
3885*da0073e9SAndroid Build Coastguard Worker            backend="inductor",
3886*da0073e9SAndroid Build Coastguard Worker            fullgraph=args.nopython,
3887*da0073e9SAndroid Build Coastguard Worker            mode=args.inductor_compile_mode,
3888*da0073e9SAndroid Build Coastguard Worker        )
3889*da0073e9SAndroid Build Coastguard Worker        experiment = speedup_experiment
3890*da0073e9SAndroid Build Coastguard Worker        output_filename = "inductor.csv"
3891*da0073e9SAndroid Build Coastguard Worker    elif args.export:
3892*da0073e9SAndroid Build Coastguard Worker        optimize_ctx = export
3893*da0073e9SAndroid Build Coastguard Worker        experiment = speedup_experiment
3894*da0073e9SAndroid Build Coastguard Worker        output_filename = "export.csv"
3895*da0073e9SAndroid Build Coastguard Worker    elif args.xla:
3896*da0073e9SAndroid Build Coastguard Worker        (dev,) = args.devices
3897*da0073e9SAndroid Build Coastguard Worker        os.environ["PJRT_DEVICE"] = {"cuda": "GPU", "cpu": "CPU"}[dev]
3898*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.mark_dynamic = MagicMock()
3899*da0073e9SAndroid Build Coastguard Worker        experiment = xla
3900*da0073e9SAndroid Build Coastguard Worker        output_filename = "xla.csv"
3901*da0073e9SAndroid Build Coastguard Worker    elif args.torchscript_onnx:
3902*da0073e9SAndroid Build Coastguard Worker        optimize_ctx = functools.partial(
3903*da0073e9SAndroid Build Coastguard Worker            optimize_onnx_ctx,
3904*da0073e9SAndroid Build Coastguard Worker            args.output_directory or ".",
3905*da0073e9SAndroid Build Coastguard Worker            OnnxModelFromTorchScript,
3906*da0073e9SAndroid Build Coastguard Worker            copy_before_export=args.performance,  # Accuarcy bench already did deepcopy
3907*da0073e9SAndroid Build Coastguard Worker        )
3908*da0073e9SAndroid Build Coastguard Worker        experiment = speedup_experiment_onnx
3909*da0073e9SAndroid Build Coastguard Worker        output_filename = "torchscript_onnx.csv"
3910*da0073e9SAndroid Build Coastguard Worker        current_onnx_compiler = "torchscript"
3911*da0073e9SAndroid Build Coastguard Worker    elif args.dynamo_onnx:
3912*da0073e9SAndroid Build Coastguard Worker        optimize_ctx = functools.partial(
3913*da0073e9SAndroid Build Coastguard Worker            optimize_onnx_ctx,
3914*da0073e9SAndroid Build Coastguard Worker            args.output_directory or ".",
3915*da0073e9SAndroid Build Coastguard Worker            OnnxModelFromDynamo,
3916*da0073e9SAndroid Build Coastguard Worker            dynamic_shapes=args.dynamic_shapes,
3917*da0073e9SAndroid Build Coastguard Worker            copy_before_export=args.performance,
3918*da0073e9SAndroid Build Coastguard Worker        )
3919*da0073e9SAndroid Build Coastguard Worker        experiment = speedup_experiment_onnx
3920*da0073e9SAndroid Build Coastguard Worker        output_filename = "dynamo_onnx.csv"
3921*da0073e9SAndroid Build Coastguard Worker        current_onnx_compiler = "dynamo"
3922*da0073e9SAndroid Build Coastguard Worker    elif args.dynamo_onnx_aot_inline:
3923*da0073e9SAndroid Build Coastguard Worker        optimize_ctx = functools.partial(
3924*da0073e9SAndroid Build Coastguard Worker            optimize_onnx_ctx,
3925*da0073e9SAndroid Build Coastguard Worker            args.output_directory or ".",
3926*da0073e9SAndroid Build Coastguard Worker            OnnxModelFromDynamoAotInline,
3927*da0073e9SAndroid Build Coastguard Worker            dynamic_shapes=args.dynamic_shapes,
3928*da0073e9SAndroid Build Coastguard Worker            copy_before_export=args.performance,
3929*da0073e9SAndroid Build Coastguard Worker        )
3930*da0073e9SAndroid Build Coastguard Worker        experiment = speedup_experiment_onnx
3931*da0073e9SAndroid Build Coastguard Worker        output_filename = "dynamo_onnx_aot_inline.csv"
3932*da0073e9SAndroid Build Coastguard Worker        current_onnx_compiler = "dynamo"
3933*da0073e9SAndroid Build Coastguard Worker    elif args.dynamo_onnx_aot_optimize:
3934*da0073e9SAndroid Build Coastguard Worker        optimize_ctx = functools.partial(
3935*da0073e9SAndroid Build Coastguard Worker            optimize_onnx_ctx,
3936*da0073e9SAndroid Build Coastguard Worker            args.output_directory or ".",
3937*da0073e9SAndroid Build Coastguard Worker            OnnxModelFromDynamoAotOptimize,
3938*da0073e9SAndroid Build Coastguard Worker            dynamic_shapes=args.dynamic_shapes,
3939*da0073e9SAndroid Build Coastguard Worker            copy_before_export=args.performance,
3940*da0073e9SAndroid Build Coastguard Worker        )
3941*da0073e9SAndroid Build Coastguard Worker        experiment = speedup_experiment_onnx
3942*da0073e9SAndroid Build Coastguard Worker        output_filename = "dynamo_onnx_aot_optimize.csv"
3943*da0073e9SAndroid Build Coastguard Worker        current_onnx_compiler = "dynamo"
3944*da0073e9SAndroid Build Coastguard Worker    elif args.speedup_dynamo_ts:
3945*da0073e9SAndroid Build Coastguard Worker        optimize_ctx = torch._dynamo.optimize("ts", nopython=args.nopython)
3946*da0073e9SAndroid Build Coastguard Worker        experiment = speedup_experiment
3947*da0073e9SAndroid Build Coastguard Worker        output_filename = "speedup_dynamo_ts.csv"
3948*da0073e9SAndroid Build Coastguard Worker    elif args.prims_nvfuser:
3949*da0073e9SAndroid Build Coastguard Worker        optimize_ctx = torch._dynamo.optimize("prims_nvfuser", nopython=args.nopython)
3950*da0073e9SAndroid Build Coastguard Worker        experiment = speedup_experiment
3951*da0073e9SAndroid Build Coastguard Worker        backend_str = "prims_nvfuser"
3952*da0073e9SAndroid Build Coastguard Worker        output_filename = f"accuracy_aot_{backend_str}.csv"
3953*da0073e9SAndroid Build Coastguard Worker    elif args.print_fx:
3954*da0073e9SAndroid Build Coastguard Worker        optimize_ctx = torch._dynamo.optimize(
3955*da0073e9SAndroid Build Coastguard Worker            print_fx,
3956*da0073e9SAndroid Build Coastguard Worker            nopython=args.nopython,
3957*da0073e9SAndroid Build Coastguard Worker        )
3958*da0073e9SAndroid Build Coastguard Worker    elif args.print_aten_ops:
3959*da0073e9SAndroid Build Coastguard Worker        optimize_ctx = torch._dynamo.optimize(
3960*da0073e9SAndroid Build Coastguard Worker            print_aten_ops,
3961*da0073e9SAndroid Build Coastguard Worker            nopython=args.nopython,
3962*da0073e9SAndroid Build Coastguard Worker        )
3963*da0073e9SAndroid Build Coastguard Worker    elif args.nothing:
3964*da0073e9SAndroid Build Coastguard Worker        optimize_ctx = nothing
3965*da0073e9SAndroid Build Coastguard Worker        experiment = speedup_experiment
3966*da0073e9SAndroid Build Coastguard Worker        output_filename = "nothing.csv"
3967*da0073e9SAndroid Build Coastguard Worker    elif args.backend or args.export_aot_inductor:
3968*da0073e9SAndroid Build Coastguard Worker        if args.export_aot_inductor:
3969*da0073e9SAndroid Build Coastguard Worker            assert not args.training, "AOTInductor only supports inference"
3970*da0073e9SAndroid Build Coastguard Worker            optimize_ctx = functools.partial(
3971*da0073e9SAndroid Build Coastguard Worker                export_aot_inductor, device=args.devices[0]
3972*da0073e9SAndroid Build Coastguard Worker            )
3973*da0073e9SAndroid Build Coastguard Worker
3974*da0073e9SAndroid Build Coastguard Worker            # AOTInductor doesn't support control flow yet
3975*da0073e9SAndroid Build Coastguard Worker            runner.skip_models.update(runner.skip_models_due_to_control_flow)
3976*da0073e9SAndroid Build Coastguard Worker        elif args.backend == "torchao":
3977*da0073e9SAndroid Build Coastguard Worker            assert "cuda" in args.devices, "Quantization requires CUDA device."
3978*da0073e9SAndroid Build Coastguard Worker            assert args.bfloat16, "Quantization requires dtype bfloat16."
3979*da0073e9SAndroid Build Coastguard Worker            try:
3980*da0073e9SAndroid Build Coastguard Worker                from torchao_backend import setup_baseline, torchao_optimize_ctx
3981*da0073e9SAndroid Build Coastguard Worker            except ImportError:
3982*da0073e9SAndroid Build Coastguard Worker                from userbenchmark.dynamo.dynamobench.torchao_backend import (
3983*da0073e9SAndroid Build Coastguard Worker                    setup_baseline,
3984*da0073e9SAndroid Build Coastguard Worker                    torchao_optimize_ctx,
3985*da0073e9SAndroid Build Coastguard Worker                )
3986*da0073e9SAndroid Build Coastguard Worker
3987*da0073e9SAndroid Build Coastguard Worker            setup_baseline()
3988*da0073e9SAndroid Build Coastguard Worker            baseline_ctx = functools.partial(
3989*da0073e9SAndroid Build Coastguard Worker                torch.compile,
3990*da0073e9SAndroid Build Coastguard Worker                backend="inductor",
3991*da0073e9SAndroid Build Coastguard Worker                fullgraph=args.nopython,
3992*da0073e9SAndroid Build Coastguard Worker                mode=args.inductor_compile_mode,
3993*da0073e9SAndroid Build Coastguard Worker            )
3994*da0073e9SAndroid Build Coastguard Worker            runner.model_iter_fn = baseline_ctx(runner.model_iter_fn)
3995*da0073e9SAndroid Build Coastguard Worker            optimize_ctx = torchao_optimize_ctx(args.quantization)
3996*da0073e9SAndroid Build Coastguard Worker        else:
3997*da0073e9SAndroid Build Coastguard Worker            optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
3998*da0073e9SAndroid Build Coastguard Worker        experiment = speedup_experiment
3999*da0073e9SAndroid Build Coastguard Worker        if args.accuracy:
4000*da0073e9SAndroid Build Coastguard Worker            output_filename = f"accuracy_{args.backend}.csv"
4001*da0073e9SAndroid Build Coastguard Worker        elif args.tolerance:
4002*da0073e9SAndroid Build Coastguard Worker            output_filename = f"tolerance_{args.backend}.csv"
4003*da0073e9SAndroid Build Coastguard Worker        else:
4004*da0073e9SAndroid Build Coastguard Worker            output_filename = f"speedup_{args.backend}.csv"
4005*da0073e9SAndroid Build Coastguard Worker    elif args.recompile_profiler:
4006*da0073e9SAndroid Build Coastguard Worker        output_filename = "recompile_profiler_log.csv"
4007*da0073e9SAndroid Build Coastguard Worker        experiment = recompile_profiler_experiment
4008*da0073e9SAndroid Build Coastguard Worker    else:
4009*da0073e9SAndroid Build Coastguard Worker        optimize_ctx = torch._dynamo.optimize(
4010*da0073e9SAndroid Build Coastguard Worker            fx_insert_profiling, nopython=args.nopython
4011*da0073e9SAndroid Build Coastguard Worker        )
4012*da0073e9SAndroid Build Coastguard Worker        experiment = coverage_experiment
4013*da0073e9SAndroid Build Coastguard Worker        output_filename = "coverage.csv"
4014*da0073e9SAndroid Build Coastguard Worker
4015*da0073e9SAndroid Build Coastguard Worker    if args.inductor or args.backend == "inductor" or args.export_aot_inductor:
4016*da0073e9SAndroid Build Coastguard Worker        inductor_config.triton.cudagraphs = not args.disable_cudagraphs
4017*da0073e9SAndroid Build Coastguard Worker        inductor_config.triton.persistent_reductions = (
4018*da0073e9SAndroid Build Coastguard Worker            not args.disable_persistent_reductions
4019*da0073e9SAndroid Build Coastguard Worker        )
4020*da0073e9SAndroid Build Coastguard Worker        inductor_config.split_reductions = not args.disable_split_reductions
4021*da0073e9SAndroid Build Coastguard Worker        inductor_config.triton.divisible_by_16 = not args.disable_divisible_by_16
4022*da0073e9SAndroid Build Coastguard Worker        if args.inference:
4023*da0073e9SAndroid Build Coastguard Worker            inductor_config.freezing = args.freezing
4024*da0073e9SAndroid Build Coastguard Worker
4025*da0073e9SAndroid Build Coastguard Worker    runner.setup_amp()
4026*da0073e9SAndroid Build Coastguard Worker
4027*da0073e9SAndroid Build Coastguard Worker    if args.output:
4028*da0073e9SAndroid Build Coastguard Worker        output_filename = args.output
4029*da0073e9SAndroid Build Coastguard Worker
4030*da0073e9SAndroid Build Coastguard Worker    if output_filename:
4031*da0073e9SAndroid Build Coastguard Worker        if args.output_directory:
4032*da0073e9SAndroid Build Coastguard Worker            output_filename = os.path.join(args.output_directory, output_filename)
4033*da0073e9SAndroid Build Coastguard Worker        else:
4034*da0073e9SAndroid Build Coastguard Worker            output_filename = os.path.join(
4035*da0073e9SAndroid Build Coastguard Worker                torch._dynamo.config.base_dir, output_filename
4036*da0073e9SAndroid Build Coastguard Worker            )
4037*da0073e9SAndroid Build Coastguard Worker
4038*da0073e9SAndroid Build Coastguard Worker    if args.find_batch_sizes and args.only:
4039*da0073e9SAndroid Build Coastguard Worker        for device in args.devices:
4040*da0073e9SAndroid Build Coastguard Worker            batch_size = runner.batch_size_finder(device, args.only)
4041*da0073e9SAndroid Build Coastguard Worker            print(args.only, batch_size)
4042*da0073e9SAndroid Build Coastguard Worker            output_csv(output_filename, [], [args.only, batch_size])
4043*da0073e9SAndroid Build Coastguard Worker        return
4044*da0073e9SAndroid Build Coastguard Worker
4045*da0073e9SAndroid Build Coastguard Worker    if args.export_profiler_trace:
4046*da0073e9SAndroid Build Coastguard Worker        if args.profiler_trace_name is None:
4047*da0073e9SAndroid Build Coastguard Worker            if args.backend:
4048*da0073e9SAndroid Build Coastguard Worker                args.profiler_trace_name = args.backend
4049*da0073e9SAndroid Build Coastguard Worker            elif args.inductor:
4050*da0073e9SAndroid Build Coastguard Worker                args.profiler_trace_name = "inductor"
4051*da0073e9SAndroid Build Coastguard Worker            else:
4052*da0073e9SAndroid Build Coastguard Worker                args.profiler_trace_name = "profile"
4053*da0073e9SAndroid Build Coastguard Worker        else:
4054*da0073e9SAndroid Build Coastguard Worker            args.profiler_trace_name = args.profiler_trace_name
4055*da0073e9SAndroid Build Coastguard Worker
4056*da0073e9SAndroid Build Coastguard Worker    if args.no_translation_validation:
4057*da0073e9SAndroid Build Coastguard Worker        # Overwrite 'translation_validation' config, if specified.
4058*da0073e9SAndroid Build Coastguard Worker        torch.fx.experimental._config.translation_validation = False
4059*da0073e9SAndroid Build Coastguard Worker
4060*da0073e9SAndroid Build Coastguard Worker    experiment = functools.partial(experiment, args, runner.model_iter_fn)
4061*da0073e9SAndroid Build Coastguard Worker
4062*da0073e9SAndroid Build Coastguard Worker    if args.only and should_diff_branch(args):
4063*da0073e9SAndroid Build Coastguard Worker        import git
4064*da0073e9SAndroid Build Coastguard Worker
4065*da0073e9SAndroid Build Coastguard Worker        repo = git.Repo()
4066*da0073e9SAndroid Build Coastguard Worker        main_branch = repo.active_branch.name
4067*da0073e9SAndroid Build Coastguard Worker        try:
4068*da0073e9SAndroid Build Coastguard Worker            # Adding diff-branch again to the args will override previous value
4069*da0073e9SAndroid Build Coastguard Worker            call_args = (
4070*da0073e9SAndroid Build Coastguard Worker                [sys.executable] + sys.argv + [f"--diff-branch={diff_branch_default}"]
4071*da0073e9SAndroid Build Coastguard Worker            )
4072*da0073e9SAndroid Build Coastguard Worker            # Run for main branch
4073*da0073e9SAndroid Build Coastguard Worker            subprocess.check_call(call_args + [f"--tag={main_branch}"])
4074*da0073e9SAndroid Build Coastguard Worker            # Run for comparison branch
4075*da0073e9SAndroid Build Coastguard Worker            repo.git.checkout(args.diff_branch)
4076*da0073e9SAndroid Build Coastguard Worker            subprocess.check_call(call_args + [f"--tag={args.diff_branch}"])
4077*da0073e9SAndroid Build Coastguard Worker        finally:
4078*da0073e9SAndroid Build Coastguard Worker            # Go back to main branch
4079*da0073e9SAndroid Build Coastguard Worker            repo.git.checkout(main_branch)
4080*da0073e9SAndroid Build Coastguard Worker    elif args.only:
4081*da0073e9SAndroid Build Coastguard Worker        model_name = args.only
4082*da0073e9SAndroid Build Coastguard Worker        for device in args.devices:
4083*da0073e9SAndroid Build Coastguard Worker            batch_size = args.batch_size
4084*da0073e9SAndroid Build Coastguard Worker            if args.batch_size_file:
4085*da0073e9SAndroid Build Coastguard Worker                batch_size = read_batch_size_from_file(
4086*da0073e9SAndroid Build Coastguard Worker                    args, args.batch_size_file, model_name
4087*da0073e9SAndroid Build Coastguard Worker                )
4088*da0073e9SAndroid Build Coastguard Worker            if model_specified_by_path(args.only):
4089*da0073e9SAndroid Build Coastguard Worker                model, example_inputs = load_model_from_path(args.only)
4090*da0073e9SAndroid Build Coastguard Worker                name = model.__class__.__name__
4091*da0073e9SAndroid Build Coastguard Worker                model = model.to(device=device)
4092*da0073e9SAndroid Build Coastguard Worker                example_inputs = tree_map_only(
4093*da0073e9SAndroid Build Coastguard Worker                    torch.Tensor, lambda x: x.to(device=device), example_inputs
4094*da0073e9SAndroid Build Coastguard Worker                )
4095*da0073e9SAndroid Build Coastguard Worker            else:
4096*da0073e9SAndroid Build Coastguard Worker                name = model_name
4097*da0073e9SAndroid Build Coastguard Worker                try:
4098*da0073e9SAndroid Build Coastguard Worker                    with tqdm(desc="loading model"):
4099*da0073e9SAndroid Build Coastguard Worker                        extra_args = []
4100*da0073e9SAndroid Build Coastguard Worker                        if hasattr(args, "rank") and hasattr(args, "world_size"):
4101*da0073e9SAndroid Build Coastguard Worker                            extra_args += [
4102*da0073e9SAndroid Build Coastguard Worker                                "--rank",
4103*da0073e9SAndroid Build Coastguard Worker                                str(args.rank),
4104*da0073e9SAndroid Build Coastguard Worker                                "--world_size",
4105*da0073e9SAndroid Build Coastguard Worker                                str(args.world_size),
4106*da0073e9SAndroid Build Coastguard Worker                            ]
4107*da0073e9SAndroid Build Coastguard Worker
4108*da0073e9SAndroid Build Coastguard Worker                        if args.part:
4109*da0073e9SAndroid Build Coastguard Worker                            (
4110*da0073e9SAndroid Build Coastguard Worker                                device,
4111*da0073e9SAndroid Build Coastguard Worker                                name,
4112*da0073e9SAndroid Build Coastguard Worker                                model,
4113*da0073e9SAndroid Build Coastguard Worker                                example_inputs,
4114*da0073e9SAndroid Build Coastguard Worker                                batch_size,
4115*da0073e9SAndroid Build Coastguard Worker                            ) = runner.load_model(
4116*da0073e9SAndroid Build Coastguard Worker                                device,
4117*da0073e9SAndroid Build Coastguard Worker                                model_name,
4118*da0073e9SAndroid Build Coastguard Worker                                batch_size=batch_size,
4119*da0073e9SAndroid Build Coastguard Worker                                part=args.part,
4120*da0073e9SAndroid Build Coastguard Worker                                extra_args=extra_args,
4121*da0073e9SAndroid Build Coastguard Worker                            )
4122*da0073e9SAndroid Build Coastguard Worker                        else:
4123*da0073e9SAndroid Build Coastguard Worker                            if args.fsdp:
4124*da0073e9SAndroid Build Coastguard Worker                                # Always load model on cpu for fsdp
4125*da0073e9SAndroid Build Coastguard Worker                                # When initializing FSDP, we will use the cuda device if args.cuda is set
4126*da0073e9SAndroid Build Coastguard Worker                                (
4127*da0073e9SAndroid Build Coastguard Worker                                    _,
4128*da0073e9SAndroid Build Coastguard Worker                                    name,
4129*da0073e9SAndroid Build Coastguard Worker                                    model,
4130*da0073e9SAndroid Build Coastguard Worker                                    example_inputs,
4131*da0073e9SAndroid Build Coastguard Worker                                    batch_size,
4132*da0073e9SAndroid Build Coastguard Worker                                ) = runner.load_model(
4133*da0073e9SAndroid Build Coastguard Worker                                    "cpu",
4134*da0073e9SAndroid Build Coastguard Worker                                    model_name,
4135*da0073e9SAndroid Build Coastguard Worker                                    batch_size=batch_size,
4136*da0073e9SAndroid Build Coastguard Worker                                    extra_args=extra_args,
4137*da0073e9SAndroid Build Coastguard Worker                                )
4138*da0073e9SAndroid Build Coastguard Worker                            else:
4139*da0073e9SAndroid Build Coastguard Worker                                (
4140*da0073e9SAndroid Build Coastguard Worker                                    device,
4141*da0073e9SAndroid Build Coastguard Worker                                    name,
4142*da0073e9SAndroid Build Coastguard Worker                                    model,
4143*da0073e9SAndroid Build Coastguard Worker                                    example_inputs,
4144*da0073e9SAndroid Build Coastguard Worker                                    batch_size,
4145*da0073e9SAndroid Build Coastguard Worker                                ) = runner.load_model(
4146*da0073e9SAndroid Build Coastguard Worker                                    device,
4147*da0073e9SAndroid Build Coastguard Worker                                    model_name,
4148*da0073e9SAndroid Build Coastguard Worker                                    batch_size=batch_size,
4149*da0073e9SAndroid Build Coastguard Worker                                    extra_args=extra_args,
4150*da0073e9SAndroid Build Coastguard Worker                                )
4151*da0073e9SAndroid Build Coastguard Worker                except Exception as e:
4152*da0073e9SAndroid Build Coastguard Worker                    import traceback
4153*da0073e9SAndroid Build Coastguard Worker
4154*da0073e9SAndroid Build Coastguard Worker                    mode = "train" if args.training else "eval"
4155*da0073e9SAndroid Build Coastguard Worker                    print(f"{device:4} {mode:5} {name:34} ")
4156*da0073e9SAndroid Build Coastguard Worker                    print(traceback.format_exc())
4157*da0073e9SAndroid Build Coastguard Worker                    status = (
4158*da0073e9SAndroid Build Coastguard Worker                        "model_fail_to_load"
4159*da0073e9SAndroid Build Coastguard Worker                        if isinstance(e, NotImplementedError)
4160*da0073e9SAndroid Build Coastguard Worker                        else "eager_fail_to_run"
4161*da0073e9SAndroid Build Coastguard Worker                    )
4162*da0073e9SAndroid Build Coastguard Worker                    write_csv_when_exception(args, name, status, device)
4163*da0073e9SAndroid Build Coastguard Worker                    continue  # bad benchmark implementation
4164*da0073e9SAndroid Build Coastguard Worker
4165*da0073e9SAndroid Build Coastguard Worker            if args.trace_on_xla:
4166*da0073e9SAndroid Build Coastguard Worker                xla_dev = xm.xla_device()
4167*da0073e9SAndroid Build Coastguard Worker                model = model.to(device=xla_dev)
4168*da0073e9SAndroid Build Coastguard Worker                example_inputs = tree_map_only(
4169*da0073e9SAndroid Build Coastguard Worker                    torch.Tensor, lambda x: x.to(device=xla_dev), example_inputs
4170*da0073e9SAndroid Build Coastguard Worker                )
4171*da0073e9SAndroid Build Coastguard Worker
4172*da0073e9SAndroid Build Coastguard Worker            current_name = name
4173*da0073e9SAndroid Build Coastguard Worker            current_device = device
4174*da0073e9SAndroid Build Coastguard Worker            current_batch_size = batch_size
4175*da0073e9SAndroid Build Coastguard Worker            set_model_name(name)
4176*da0073e9SAndroid Build Coastguard Worker
4177*da0073e9SAndroid Build Coastguard Worker            # Look for stuff that looks like batch size, and mark it dynamic.
4178*da0073e9SAndroid Build Coastguard Worker            # Better integration would integrate directly with benchmark suite
4179*da0073e9SAndroid Build Coastguard Worker            # but cannot conveniently do this
4180*da0073e9SAndroid Build Coastguard Worker            # NB: This must be done late enough so that we don't do more
4181*da0073e9SAndroid Build Coastguard Worker            # conversions on the inputs
4182*da0073e9SAndroid Build Coastguard Worker            # NB: Assumes only the first batch-y like dimension is the batch
4183*da0073e9SAndroid Build Coastguard Worker            marked = False
4184*da0073e9SAndroid Build Coastguard Worker
4185*da0073e9SAndroid Build Coastguard Worker            def detect_and_mark_batch(t):
4186*da0073e9SAndroid Build Coastguard Worker                nonlocal marked
4187*da0073e9SAndroid Build Coastguard Worker                for i, s in enumerate(t.size()):
4188*da0073e9SAndroid Build Coastguard Worker                    if s == batch_size:
4189*da0073e9SAndroid Build Coastguard Worker                        torch._dynamo.mark_dynamic(t, i)
4190*da0073e9SAndroid Build Coastguard Worker                        marked = True
4191*da0073e9SAndroid Build Coastguard Worker                        break
4192*da0073e9SAndroid Build Coastguard Worker
4193*da0073e9SAndroid Build Coastguard Worker            if (
4194*da0073e9SAndroid Build Coastguard Worker                args.dynamic_batch_only
4195*da0073e9SAndroid Build Coastguard Worker                and batch_size > 1
4196*da0073e9SAndroid Build Coastguard Worker                and model_name not in CI_SKIP_DYNAMIC_BATCH_ONLY
4197*da0073e9SAndroid Build Coastguard Worker            ):
4198*da0073e9SAndroid Build Coastguard Worker                tree_map_only(torch.Tensor, detect_and_mark_batch, example_inputs)
4199*da0073e9SAndroid Build Coastguard Worker                assert marked, f"nothing in example_inputs had a dim with {batch_size}"
4200*da0073e9SAndroid Build Coastguard Worker
4201*da0073e9SAndroid Build Coastguard Worker            if args.log_operator_inputs:
4202*da0073e9SAndroid Build Coastguard Worker                log_operator_inputs(
4203*da0073e9SAndroid Build Coastguard Worker                    model, example_inputs, runner.model_iter_fn, name, args
4204*da0073e9SAndroid Build Coastguard Worker                )
4205*da0073e9SAndroid Build Coastguard Worker                continue
4206*da0073e9SAndroid Build Coastguard Worker
4207*da0073e9SAndroid Build Coastguard Worker            if args.per_process_memory_fraction != 1:
4208*da0073e9SAndroid Build Coastguard Worker                torch.cuda.set_per_process_memory_fraction(
4209*da0073e9SAndroid Build Coastguard Worker                    args.per_process_memory_fraction
4210*da0073e9SAndroid Build Coastguard Worker                )
4211*da0073e9SAndroid Build Coastguard Worker            if model_name in DO_NOT_CAST_INPUTS:
4212*da0073e9SAndroid Build Coastguard Worker                model, _ = runner.cast_based_on_args(model, example_inputs)
4213*da0073e9SAndroid Build Coastguard Worker
4214*da0073e9SAndroid Build Coastguard Worker            else:
4215*da0073e9SAndroid Build Coastguard Worker                model, example_inputs = runner.cast_based_on_args(model, example_inputs)
4216*da0073e9SAndroid Build Coastguard Worker            runner.setup_amp(current_device)
4217*da0073e9SAndroid Build Coastguard Worker            guard_ctx = contextlib.nullcontext()
4218*da0073e9SAndroid Build Coastguard Worker            if name in runner.guard_on_nn_module_models:
4219*da0073e9SAndroid Build Coastguard Worker                guard_ctx = torch._dynamo.config.patch(guard_nn_modules=True)
4220*da0073e9SAndroid Build Coastguard Worker
4221*da0073e9SAndroid Build Coastguard Worker            with guard_ctx:
4222*da0073e9SAndroid Build Coastguard Worker                runner.run_one_model(
4223*da0073e9SAndroid Build Coastguard Worker                    name,
4224*da0073e9SAndroid Build Coastguard Worker                    model,
4225*da0073e9SAndroid Build Coastguard Worker                    example_inputs,
4226*da0073e9SAndroid Build Coastguard Worker                    optimize_ctx,
4227*da0073e9SAndroid Build Coastguard Worker                    experiment,
4228*da0073e9SAndroid Build Coastguard Worker                    explain=args.explain,
4229*da0073e9SAndroid Build Coastguard Worker                    tag=args.tag,
4230*da0073e9SAndroid Build Coastguard Worker                )
4231*da0073e9SAndroid Build Coastguard Worker        if args.generate_aot_autograd_stats:
4232*da0073e9SAndroid Build Coastguard Worker            stats_file = output_filename.split(".csv")[0] + "_stats.csv"
4233*da0073e9SAndroid Build Coastguard Worker            output_csv(
4234*da0073e9SAndroid Build Coastguard Worker                stats_file,
4235*da0073e9SAndroid Build Coastguard Worker                ("dev", "name", "batch_size", "total_aot_graphs", "ok_aot_graphs"),
4236*da0073e9SAndroid Build Coastguard Worker                [
4237*da0073e9SAndroid Build Coastguard Worker                    current_device,
4238*da0073e9SAndroid Build Coastguard Worker                    current_name,
4239*da0073e9SAndroid Build Coastguard Worker                    current_batch_size,
4240*da0073e9SAndroid Build Coastguard Worker                    *Stats.aot_summary(),
4241*da0073e9SAndroid Build Coastguard Worker                ],
4242*da0073e9SAndroid Build Coastguard Worker            )
4243*da0073e9SAndroid Build Coastguard Worker    else:
4244*da0073e9SAndroid Build Coastguard Worker        metrics.purge_old_log_files()
4245*da0073e9SAndroid Build Coastguard Worker        if output_filename and os.path.exists(output_filename):
4246*da0073e9SAndroid Build Coastguard Worker            os.unlink(output_filename)
4247*da0073e9SAndroid Build Coastguard Worker        if original_dir:
4248*da0073e9SAndroid Build Coastguard Worker            os.chdir(original_dir)
4249*da0073e9SAndroid Build Coastguard Worker        model_names = list(runner.iter_model_names(args))
4250*da0073e9SAndroid Build Coastguard Worker        nmodels = len(model_names)
4251*da0073e9SAndroid Build Coastguard Worker        for i, name in enumerate(model_names):
4252*da0073e9SAndroid Build Coastguard Worker            current_name = name
4253*da0073e9SAndroid Build Coastguard Worker            if args.progress:
4254*da0073e9SAndroid Build Coastguard Worker                print(f"Running model {i+1}/{nmodels}", flush=True)
4255*da0073e9SAndroid Build Coastguard Worker
4256*da0073e9SAndroid Build Coastguard Worker            try:
4257*da0073e9SAndroid Build Coastguard Worker                timeout = args.timeout
4258*da0073e9SAndroid Build Coastguard Worker                if should_diff_branch(args):
4259*da0073e9SAndroid Build Coastguard Worker                    timeout *= 2
4260*da0073e9SAndroid Build Coastguard Worker                env = os.environ.copy()
4261*da0073e9SAndroid Build Coastguard Worker                if args.ci and name in CI_PRESERVE_COMPILE_DEBUG:
4262*da0073e9SAndroid Build Coastguard Worker                    env["TORCH_COMPILE_DEBUG"] = "1"
4263*da0073e9SAndroid Build Coastguard Worker                subprocess.check_call(
4264*da0073e9SAndroid Build Coastguard Worker                    [sys.executable] + sys.argv + [f"--only={name}"],
4265*da0073e9SAndroid Build Coastguard Worker                    timeout=timeout,
4266*da0073e9SAndroid Build Coastguard Worker                    env=env,
4267*da0073e9SAndroid Build Coastguard Worker                )
4268*da0073e9SAndroid Build Coastguard Worker            except subprocess.TimeoutExpired:
4269*da0073e9SAndroid Build Coastguard Worker                write_csv_when_exception(args, name, "timeout")
4270*da0073e9SAndroid Build Coastguard Worker            except subprocess.CalledProcessError as e:
4271*da0073e9SAndroid Build Coastguard Worker                print("Run failed with return code: ", e.returncode, file=sys.stderr)
4272*da0073e9SAndroid Build Coastguard Worker                print("Output: ", e.output, file=sys.stderr)
4273*da0073e9SAndroid Build Coastguard Worker                print("Error: ", e.stderr, file=sys.stderr)
4274*da0073e9SAndroid Build Coastguard Worker        print_summary(output_filename, print_dataframe=args.print_dataframe_summary)
4275*da0073e9SAndroid Build Coastguard Worker
4276*da0073e9SAndroid Build Coastguard Worker
4277*da0073e9SAndroid Build Coastguard Workerdef log_operator_inputs(model, example_inputs, model_iter_fn, name, args):
4278*da0073e9SAndroid Build Coastguard Worker    mode = "training" if args.training else "eval"
4279*da0073e9SAndroid Build Coastguard Worker    output = os.path.join(os.path.dirname(args.output), f"{name}_{mode}.txt")
4280*da0073e9SAndroid Build Coastguard Worker
4281*da0073e9SAndroid Build Coastguard Worker    # TODO - add option for coalescing inputs over multiple runs
4282*da0073e9SAndroid Build Coastguard Worker    if os.path.exists(output):
4283*da0073e9SAndroid Build Coastguard Worker        print(f"Skipping {name}, {output} already exists")
4284*da0073e9SAndroid Build Coastguard Worker        return
4285*da0073e9SAndroid Build Coastguard Worker
4286*da0073e9SAndroid Build Coastguard Worker    print(f"Running {name}")
4287*da0073e9SAndroid Build Coastguard Worker    try:
4288*da0073e9SAndroid Build Coastguard Worker        from .microbenchmarks.operator_inp_utils import OperatorInputsMode
4289*da0073e9SAndroid Build Coastguard Worker    except ImportError:
4290*da0073e9SAndroid Build Coastguard Worker        from microbenchmarks.operator_inp_utils import OperatorInputsMode
4291*da0073e9SAndroid Build Coastguard Worker
4292*da0073e9SAndroid Build Coastguard Worker    operator_mode = OperatorInputsMode()
4293*da0073e9SAndroid Build Coastguard Worker    fake_tensor_mode = FakeTensorMode()
4294*da0073e9SAndroid Build Coastguard Worker
4295*da0073e9SAndroid Build Coastguard Worker    with torch._subclasses.fake_tensor.FakeCopyMode(fake_tensor_mode):
4296*da0073e9SAndroid Build Coastguard Worker        model_fake = copy.deepcopy(model)
4297*da0073e9SAndroid Build Coastguard Worker        example_inputs_fake = copy.deepcopy(example_inputs)
4298*da0073e9SAndroid Build Coastguard Worker    try:
4299*da0073e9SAndroid Build Coastguard Worker        with fake_tensor_mode, operator_mode:
4300*da0073e9SAndroid Build Coastguard Worker            model_iter_fn(model_fake, example_inputs_fake, collect_outputs=False)
4301*da0073e9SAndroid Build Coastguard Worker    except Exception as e:
4302*da0073e9SAndroid Build Coastguard Worker        print(f"{name} failed to run with fake tensors, trying real. Exception: {e}")
4303*da0073e9SAndroid Build Coastguard Worker        operator_mode = OperatorInputsMode()
4304*da0073e9SAndroid Build Coastguard Worker        try:
4305*da0073e9SAndroid Build Coastguard Worker            with operator_mode:
4306*da0073e9SAndroid Build Coastguard Worker                model_iter_fn(model, example_inputs, collect_outputs=False)
4307*da0073e9SAndroid Build Coastguard Worker        except Exception as e2:
4308*da0073e9SAndroid Build Coastguard Worker            print(f"{name} failed to run with real. Exception: {e2}")
4309*da0073e9SAndroid Build Coastguard Worker            raise
4310*da0073e9SAndroid Build Coastguard Worker
4311*da0073e9SAndroid Build Coastguard Worker    print(f"Writing output to {output}")
4312*da0073e9SAndroid Build Coastguard Worker    operator_mode.log_to_file(output)
4313*da0073e9SAndroid Build Coastguard Worker
4314*da0073e9SAndroid Build Coastguard Worker
4315*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
4316*da0073e9SAndroid Build Coastguard Worker    raise RuntimeError(
4317*da0073e9SAndroid Build Coastguard Worker        f"You shouldn't run {sys.argv[0]} directly, instead try timm_model.py, torchbench.py or huggingface.py"
4318*da0073e9SAndroid Build Coastguard Worker    )
4319