xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport os
2*da0073e9SAndroid Build Coastguard Workerimport unittest
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerfrom .common import parse_args, run
5*da0073e9SAndroid Build Coastguard Workerfrom .torchbench import setup_torchbench_cwd, TorchBenchmarkRunner
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker
try:
    # The sanitizer-status helper only exists inside fbcode.
    from aiplatform.utils.sanitizer_status import is_asan_or_tsan
except ImportError:

    def is_asan_or_tsan() -> bool:
        """Fallback used when the fbcode sanitizer helper cannot be imported.

        Always reports that no ASAN/TSAN build is active, so tests gated on
        it are never skipped outside fbcode.
        """
        sanitizer_active = False
        return sanitizer_active

16*da0073e9SAndroid Build Coastguard Worker
class TestDynamoBenchmark(unittest.TestCase):
    """Smoke tests for the dynamo TorchBench benchmark harness."""

    @unittest.skipIf(is_asan_or_tsan(), "ASAN/TSAN not supported")
    def test_benchmark_infra_runs(self) -> None:
        """
        Run one tiny TorchBench configuration end to end.

        The configuration is BERT_pytorch on CPU with inductor, in training
        and performance mode, with batch size 1. The test exists mainly to
        catch breakage in the fbcode setup.

        If this fails with errors about missing CPP headers, the resources
        list in the following target probably needs updating:
            //caffe2:inductor
        """
        cli_flags = [
            "-dcpu",
            "--inductor",
            "--training",
            "--performance",
            "--only=BERT_pytorch",
            "-n1",
            "--batch-size=1",
        ]
        # setup_torchbench_cwd() switches directories and hands back the one
        # we started in; always return there, even if the run raises.
        saved_cwd = setup_torchbench_cwd()
        try:
            run(TorchBenchmarkRunner(), parse_args(cli_flags), saved_cwd)
        finally:
            os.chdir(saved_cwd)

46