#!/usr/bin/env python3
# benchmarks/dynamo/torchbench.py

import gc
import importlib
import logging
import os
import re
import sys
import warnings
from collections import namedtuple
from os.path import abspath, exists

import torch


try:
    from .common import BenchmarkRunner, load_yaml_file, main
except ImportError:
    from common import BenchmarkRunner, load_yaml_file, main

from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
from torch._dynamo.utils import clone_inputs


# We are primarily interested in the tf32 datatype
torch.backends.cuda.matmul.allow_tf32 = True

# Enable FX graph caching
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
    torch._inductor.config.fx_graph_cache = True
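    # Note: an explicit environment setting always wins here. This default is
    # applied only when TORCHINDUCTOR_FX_GRAPH_CACHE is absent from os.environ,
    # so e.g. TORCHINDUCTOR_FX_GRAPH_CACHE=0 still disables the cache.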


def _reassign_parameters(model):
    # torch_geometric models register parameters as plain tensors due to
    # https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/dense/linear.py#L158-L168
    # Since this is an unusual thing to do, we just reassign them to parameters
    def state_dict_hook(module, destination, prefix, local_metadata):
        for name, param in module.named_parameters():
            if isinstance(destination[name], torch.Tensor) and not isinstance(
                destination[name], torch.nn.Parameter
            ):
                destination[name] = torch.nn.Parameter(destination[name])

    model._register_state_dict_hook(state_dict_hook)
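    # Effect, in brief: after _reassign_parameters(model), entries in
    # model.state_dict() that correspond to named parameters but would have
    # been plain tensors come back wrapped as torch.nn.Parameter.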


def setup_torchbench_cwd():
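    """Find a nearby torchbench checkout, chdir into it, and add it to
    sys.path; return the original working directory so the caller can
    restore it later."""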
    original_dir = abspath(os.getcwd())

    os.environ["KALDI_ROOT"] = "/tmp"  # avoids some spam
    for torchbench_dir in (
        "./torchbenchmark",
        "../torchbenchmark",
        "../torchbench",
        "../benchmark",
        "../../torchbenchmark",
        "../../torchbench",
        "../../benchmark",
        "../../../torchbenchmark",
        "../../../torchbench",
        "../../../benchmark",
    ):
        if exists(torchbench_dir):
            break

    if exists(torchbench_dir):
        torchbench_dir = abspath(torchbench_dir)
        os.chdir(torchbench_dir)
        sys.path.append(torchbench_dir)

    return original_dir


def process_hf_reformer_output(out):
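    """Drop the second element of the output list, whose value is unstable
    and would produce spurious mismatches in accuracy comparison."""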
    assert isinstance(out, list)
    # second output is unstable
    return [elem for i, elem in enumerate(out) if i != 1]


def process_hf_whisper_output(out):
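    """Drop the unstable pieces of the whisper output: the "logits" entry of
    the first element and the entire second element."""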
    out_ret = []
    for i, elem in enumerate(out):
        if i == 0:
            assert isinstance(elem, dict)
            out_ret.append({k: v for k, v in elem.items() if k != "logits"})
        elif i != 1:
            out_ret.append(elem)

    return out_ret


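# Per-model post-processing applied to training outputs (under amp) before
# accuracy comparison; exposed below via get_output_amp_train_process_func.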
process_train_model_output = {
    "hf_Reformer": process_hf_reformer_output,
    "hf_Whisper": process_hf_whisper_output,
}


class TorchBenchmarkRunner(BenchmarkRunner):
    def __init__(self):
        super().__init__()
        self.suite_name = "torchbench"
        self.optimizer = None

    @property
    def _config(self):
        return load_yaml_file("torchbench.yaml")

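    # Rough shape of torchbench.yaml, as implied by the accessors below
    # (a sketch, not the complete file):
    #
    #   skip:
    #     all: [...]
    #     device: {cpu: [...], cuda: [...]}
    #     freezing: {cpu: [...], cuda: [...]}
    #     test: {training: [...]}
    #     multiprocess: [...]
    #     control_flow: [...]
    #   batch_size: {training: {...}, inference: {...}}
    #   tolerance: {higher_fp16: [...], higher_bf16: [...], higher: [...],
    #               even_higher: [...], cosine: [...]}
    #   accuracy:
    #     skip: {large_models: [...], eager_not_deterministic: [...]}
    #     max_batch_size: {...}
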
    @property
    def _skip(self):
        return self._config["skip"]

    @property
    def _batch_size(self):
        return self._config["batch_size"]

    @property
    def _tolerance(self):
        return self._config["tolerance"]

    @property
    def _require_larger_multiplier_for_smaller_tensor(self):
        return self._config["require_larger_multiplier_for_smaller_tensor"]

    @property
    def _accuracy(self):
        return self._config["accuracy"]

    @property
    def skip_models(self):
        return self._skip["all"]

    @property
    def skip_models_for_cpu(self):
        return self._skip["device"]["cpu"]

    @property
    def skip_models_for_cuda(self):
        return self._skip["device"]["cuda"]

    @property
    def skip_models_for_freezing_cuda(self):
        return self._skip["freezing"]["cuda"]

    @property
    def skip_models_for_freezing_cpu(self):
        return self._skip["freezing"]["cpu"]

    @property
    def slow_models(self):
        return self._config["slow"]

    @property
    def very_slow_models(self):
        return self._config["very_slow"]

    @property
    def non_deterministic_models(self):
        return self._config["non_deterministic"]

    @property
    def get_output_amp_train_process_func(self):
        return process_train_model_output

    @property
    def skip_not_suitable_for_training_models(self):
        return self._skip["test"]["training"]

    @property
    def failing_fx2trt_models(self):
        return self._config["trt_not_yet_working"]

    @property
    def force_amp_for_fp16_bf16_models(self):
        return self._config["dtype"]["force_amp_for_fp16_bf16_models"]

    @property
    def force_fp16_for_bf16_models(self):
        return self._config["dtype"]["force_fp16_for_bf16_models"]

    @property
    def skip_accuracy_checks_large_models_dashboard(self):
        if self.args.dashboard or self.args.accuracy:
            return self._accuracy["skip"]["large_models"]
        return set()

    @property
    def skip_accuracy_check_as_eager_non_deterministic(self):
        if self.args.accuracy and self.args.training:
            return self._accuracy["skip"]["eager_not_deterministic"]
        return set()

    @property
    def skip_multiprocess_models(self):
        return self._skip["multiprocess"]

    @property
    def skip_models_due_to_control_flow(self):
        return self._skip["control_flow"]

    @property
    def guard_on_nn_module_models(self):
        return {
            "vision_maskrcnn",
        }

    @property
    def inline_inbuilt_nn_modules_models(self):
        return {
            "basic_gnn_edgecnn",
            "drq",
            "hf_Reformer",
            "DALLE2_pytorch",
            "hf_BigBird",
            "detectron2_maskrcnn_r_50_fpn",
            "detectron2_maskrcnn_r_101_fpn",
            "vision_maskrcnn",
            "doctr_reco_predictor",
            "hf_T5_generate",
        }

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        part=None,
        extra_args=None,
    ):
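        """Import the named torchbench model, resolve its batch size, build
        the benchmark in train or eval mode, and return the tuple
        (device, name, model, example_inputs, batch_size)."""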
        if self.args.enable_activation_checkpointing:
            raise NotImplementedError(
                "Activation checkpointing not implemented for Torchbench models"
            )
        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode
        dynamic_shapes = self.args.dynamic_shapes
        candidates = [
            f"torchbenchmark.models.{model_name}",
            f"torchbenchmark.canary_models.{model_name}",
            f"torchbenchmark.models.fb.{model_name}",
        ]
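        # Try each candidate namespace in order; the for/else below raises
        # only if no candidate imported (i.e. the loop never hit `break`).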
        for c in candidates:
            try:
                module = importlib.import_module(c)
                break
            except ModuleNotFoundError as e:
                if e.name != c:
                    raise
        else:
            raise ImportError(f"could not import any of {candidates}")
        benchmark_cls = getattr(module, "Model", None)
        if benchmark_cls is None:
            raise NotImplementedError(f"{model_name}.Model is None")

        if not hasattr(benchmark_cls, "name"):
            benchmark_cls.name = model_name

        cant_change_batch_size = (
            not getattr(benchmark_cls, "ALLOW_CUSTOMIZE_BSIZE", True)
            or model_name in self._config["dont_change_batch_size"]
        )
        if cant_change_batch_size:
            batch_size = None
        if (
            batch_size is None
            and is_training
            and model_name in self._batch_size["training"]
        ):
            batch_size = self._batch_size["training"][model_name]
        elif (
            batch_size is None
            and not is_training
            and model_name in self._batch_size["inference"]
        ):
            batch_size = self._batch_size["inference"][model_name]

        # Control the memory footprint for a few models
        if self.args.accuracy and model_name in self._accuracy["max_batch_size"]:
            batch_size = min(batch_size, self._accuracy["max_batch_size"][model_name])

        # Workaround for "RuntimeError: not allowed to set torch.backends.cudnn flags"
        torch.backends.__allow_nonbracketed_mutation_flag = True
        if extra_args is None:
            extra_args = []
        if part:
            extra_args += ["--part", part]

        # sam_fast only runs with amp
        if model_name == "sam_fast":
            self.args.amp = True
            self.setup_amp()

        if model_name == "vision_maskrcnn" and is_training:
            # The output of vision_maskrcnn is a list of bounding boxes sorted
            # by score, which makes accuracy comparison hard with torch.compile:
            # compilation can cause minor divergences in the output because of
            # how fusion works for amp in TorchInductor compared to eager.
            # Therefore, instead of looking at all the bounding boxes, we
            # compare only the top 4.
            model_kwargs = {"box_detections_per_img": 4}
            benchmark = benchmark_cls(
                test="train",
                device=device,
                batch_size=batch_size,
                extra_args=extra_args,
                model_kwargs=model_kwargs,
            )
            use_eval_mode = True
        elif is_training:
            benchmark = benchmark_cls(
                test="train",
                device=device,
                batch_size=batch_size,
                extra_args=extra_args,
            )
        else:
            benchmark = benchmark_cls(
                test="eval",
                device=device,
                batch_size=batch_size,
                extra_args=extra_args,
            )
        model, example_inputs = benchmark.get_module()
        if model_name in [
            "basic_gnn_edgecnn",
            "basic_gnn_gcn",
            "basic_gnn_sage",
            "basic_gnn_gin",
        ]:
            _reassign_parameters(model)

        # Models that must be in train mode while training
        if is_training and (
            not use_eval_mode or model_name in self._config["only_training"]
        ):
            model.train()
        else:
            model.eval()
        gc.collect()
        batch_size = benchmark.batch_size
        if model_name == "torchrec_dlrm":
            batch_namedtuple = namedtuple(
                "Batch", "dense_features sparse_features labels"
            )
            example_inputs = tuple(
                batch_namedtuple(
                    dense_features=batch.dense_features,
                    sparse_features=batch.sparse_features,
                    labels=batch.labels,
                )
                for batch in example_inputs
            )
        # Torchbench has a quite different setup for yolov3, so we directly
        # pass the right example_inputs
        if model_name == "yolov3":
            example_inputs = (torch.rand(batch_size, 3, 384, 512).to(device),)
        # See https://github.com/pytorch/benchmark/issues/1561
        if model_name == "maml_omniglot":
            batch_size = 5
            assert example_inputs[0].shape[0] == batch_size
        if model_name == "vision_maskrcnn":
            batch_size = 1
        # global current_name, current_device
        # current_device = device
        # current_name = benchmark.name

        if self.args.trace_on_xla:
            # workaround for: https://github.com/pytorch/xla/issues/4174
            import torch_xla  # noqa: F401
        self.validate_model(model, example_inputs)
        return device, benchmark.name, model, example_inputs, batch_size

    def iter_model_names(self, args):
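        """Yield the model names for this shard of the torchbench suite (plus
        any configured canary models), after applying the CLI filter/exclude
        options and the skip list."""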
        from torchbenchmark import _list_canary_model_paths, _list_model_paths

        models = _list_model_paths()
        models += [
            f
            for f in _list_canary_model_paths()
            if os.path.basename(f) in self._config["canary_models"]
        ]
        models.sort()

        start, end = self.get_benchmark_indices(len(models))
        for index, model_path in enumerate(models):
            if index < start or index >= end:
                continue

            model_name = os.path.basename(model_path)
            if (
                not re.search("|".join(args.filter), model_name, re.IGNORECASE)
                or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue

            yield model_name

    def pick_grad(self, name, is_training):
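        # maml adapts its parameters in an inner loop even at inference time,
        # so it needs gradients regardless of is_training.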
        if is_training or name in ("maml",):
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def use_larger_multiplier_for_smaller_tensor(self, name):
        return name in self._require_larger_multiplier_for_smaller_tensor

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
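        """Return the (tolerance, cosine) pair used for accuracy comparison,
        widened per dtype and per model via the tolerance section of the
        YAML config."""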
        tolerance = 1e-4
        cosine = self.args.cosine
        # Increase the tolerance for torch.allclose
        if self.args.float16 or self.args.amp:
            if name in self._tolerance["higher_fp16"]:
                return 1e-2, cosine
            elif name in self._tolerance["even_higher"]:
                return 8 * 1e-2, cosine
            return 1e-3, cosine

        if self.args.bfloat16:
            if name in self._tolerance["higher_bf16"]:
                return 1e-2, cosine

        if is_training and (current_device == "cuda" or current_device == "xpu"):
            tolerance = 1e-3
            if name in self._tolerance["cosine"]:
                cosine = True
            elif name in self._tolerance["higher"]:
                tolerance = 1e-3
            elif name in self._tolerance["even_higher"]:
                tolerance = 8 * 1e-2
        return tolerance, cosine

    def compute_loss(self, pred):
        return reduce_to_scalar_loss(pred)

    def forward_pass(self, mod, inputs, collect_outputs=True):
        with self.autocast(**self.autocast_arg):
            if isinstance(inputs, dict):
                return mod(**inputs)
            else:
                return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
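        # Clone the inputs first: some models mutate their inputs in place,
        # and the originals are needed unchanged for comparison runs.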
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            if isinstance(cloned_inputs, dict):
                pred = mod(**cloned_inputs)
            else:
                pred = mod(*cloned_inputs)
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None


def torchbench_main():
    original_dir = setup_torchbench_cwd()
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    main(TorchBenchmarkRunner(), original_dir)


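# Typical invocation, run from benchmarks/dynamo with a torchbench checkout in
# a sibling directory (the flags are parsed in common.py; the exact spelling
# here is an assumption):
#
#   python torchbench.py --performance --inductor --only hf_Bert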
if __name__ == "__main__":
    torchbench_main()