xref: /aosp_15_r20/external/pytorch/test/run_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2
3import argparse
4import copy
5import glob
6import json
7import os
8import re
9import shutil
10import signal
11import subprocess
12import sys
13import tempfile
14import time
15from collections import defaultdict
16from contextlib import ExitStack
17from datetime import datetime
18from pathlib import Path
19from typing import Any, cast, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
20
21import pkg_resources
22
23import torch
24import torch.distributed as dist
25from torch.multiprocessing import current_process, get_context
26from torch.testing._internal.common_utils import (
27    get_report_path,
28    IS_CI,
29    IS_MACOS,
30    IS_WINDOWS,
31    retry_shell,
32    set_cwd,
33    shell,
34    TEST_CUDA,
35    TEST_WITH_ASAN,
36    TEST_WITH_CROSSREF,
37    TEST_WITH_ROCM,
38    TEST_WITH_SLOW_GRADCHECK,
39)
40
41
42# Add the repo root to sys.path so the helpers under tools/ (used to optimize the test run) can be imported.
43REPO_ROOT = Path(__file__).resolve().parent.parent
44sys.path.insert(0, str(REPO_ROOT))
45
46from tools.stats.import_test_stats import (
47    ADDITIONAL_CI_FILES_FOLDER,
48    TEST_CLASS_TIMES_FILE,
49    TEST_TIMES_FILE,
50)
51from tools.stats.upload_metrics import add_global_metric, emit_metric
52from tools.testing.discover_tests import (
53    CPP_TEST_PATH,
54    CPP_TEST_PREFIX,
55    CPP_TESTS_DIR,
56    parse_test_module,
57    TESTS,
58)
59from tools.testing.do_target_determination_for_s3 import import_results
60from tools.testing.target_determination.gen_artifact import gen_ci_artifact
61from tools.testing.target_determination.heuristics.previously_failed_in_pr import (
62    gen_additional_test_failures_file,
63)
64from tools.testing.target_determination.heuristics.utils import get_pr_number
65from tools.testing.test_run import TestRun
66from tools.testing.test_selections import (
67    calculate_shards,
68    get_test_case_configs,
69    NUM_PROCS,
70    ShardedTest,
71    THRESHOLD,
72)
73
74
75# Make sure to remove REPO_ROOT from sys.path after the imports are done
76sys.path.remove(str(REPO_ROOT))
77
78
79HAVE_TEST_SELECTION_TOOLS = True
80TEST_CONFIG = os.getenv("TEST_CONFIG", "")
81BUILD_ENVIRONMENT = os.getenv("BUILD_ENVIRONMENT", "")
82RERUN_DISABLED_TESTS = os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1"
83DISTRIBUTED_TEST_PREFIX = "distributed"
84INDUCTOR_TEST_PREFIX = "inductor"
85IS_SLOW = "slow" in TEST_CONFIG or "slow" in BUILD_ENVIRONMENT
86
87
88# Note [ROCm parallel CI testing]
89# https://github.com/pytorch/pytorch/pull/85770 added file-granularity parallel testing.
90# In .ci/pytorch/test.sh, when TEST_CONFIG == "default", both CUDA_VISIBLE_DEVICES and HIP_VISIBLE_DEVICES are set to 0.
91# This results in multiple test files sharing the same GPU.
92# This should be a supported use case for ROCm, but it exposed issues in the kernel driver resulting in hangs.
93# See https://github.com/pytorch/pytorch/issues/90940.
94#
95# Further, ROCm self-hosted runners have up to 4 GPUs.
96# Device visibility was set to 0 to match CUDA test behavior, but this was wasting available GPU resources.
97# Assigning each Pool worker their own dedicated GPU avoids the ROCm oversubscription issues.
98# This should also result in better overall wall clock time since all GPUs can be utilized.
99def maybe_set_hip_visible_devies():
100    # Special handling of ROCm GHA runners for parallel (file granularity) tests.
101    if torch.version.hip:
102        p = current_process()
103        if p.name != "MainProcess":
104            # this is a Process from a parallel Pool, not the MainProcess
105            os.environ["HIP_VISIBLE_DEVICES"] = str(p._identity[0] % NUM_PROCS)
106
107
108def strtobool(s):
109    return s.lower() not in {"", "0", "false", "off"}
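    # Behavior sketch: strtobool("1"), strtobool("ON") and strtobool("yes") are all True,
    # while strtobool(""), strtobool("0"), strtobool("false") and strtobool("off") are False.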
110
111
112class TestChoices(list):
113    def __init__(self, *args, **kwargs):
114        super().__init__(args[0])
115
116    def __contains__(self, item):
117        return list.__contains__(self, parse_test_module(item))
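# Note: TestChoices lets the argparse `choices=` check accept any value that
# parse_test_module() normalizes to a known test module (e.g. a path-like argument),
# rather than requiring an exact string match against TESTS.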
118
119
120FSDP_TEST = [test for test in TESTS if test.startswith("distributed/fsdp")]
121
122WINDOWS_BLOCKLIST = [
123    "distributed/nn/jit/test_instantiator",
124    "distributed/rpc/test_faulty_agent",
125    "distributed/rpc/test_tensorpipe_agent",
126    "distributed/rpc/test_share_memory",
127    "distributed/rpc/cuda/test_tensorpipe_agent",
128    "distributed/pipeline/sync/skip/test_api",
129    "distributed/pipeline/sync/skip/test_gpipe",
130    "distributed/pipeline/sync/skip/test_inspect_skip_layout",
131    "distributed/pipeline/sync/skip/test_leak",
132    "distributed/pipeline/sync/skip/test_portal",
133    "distributed/pipeline/sync/skip/test_stash_pop",
134    "distributed/pipeline/sync/skip/test_tracker",
135    "distributed/pipeline/sync/skip/test_verify_skippables",
136    "distributed/pipeline/sync/test_balance",
137    "distributed/pipeline/sync/test_bugs",
138    "distributed/pipeline/sync/test_checkpoint",
139    "distributed/pipeline/sync/test_copy",
140    "distributed/pipeline/sync/test_deferred_batch_norm",
141    "distributed/pipeline/sync/test_dependency",
142    "distributed/pipeline/sync/test_inplace",
143    "distributed/pipeline/sync/test_microbatch",
144    "distributed/pipeline/sync/test_phony",
145    "distributed/pipeline/sync/test_pipe",
146    "distributed/pipeline/sync/test_pipeline",
147    "distributed/pipeline/sync/test_stream",
148    "distributed/pipeline/sync/test_transparency",
149    "distributed/pipeline/sync/test_worker",
150    "distributed/elastic/agent/server/test/api_test",
151    "distributed/elastic/multiprocessing/api_test",
152    "distributed/_shard/checkpoint/test_checkpoint"
153    "distributed/_shard/checkpoint/test_file_system_checkpoint"
154    "distributed/_shard/sharding_spec/test_sharding_spec",
155    "distributed/_shard/sharding_plan/test_sharding_plan",
156    "distributed/_shard/sharded_tensor/test_sharded_tensor",
157    "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
158    "distributed/_shard/sharded_tensor/ops/test_embedding",
159    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
160    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
161    "distributed/_shard/sharded_tensor/ops/test_init",
162    "distributed/_shard/sharded_optim/test_sharded_optim",
163] + FSDP_TEST
164
165ROCM_BLOCKLIST = [
166    "distributed/rpc/test_faulty_agent",
167    "distributed/rpc/test_tensorpipe_agent",
168    "distributed/rpc/test_share_memory",
169    "distributed/rpc/cuda/test_tensorpipe_agent",
170    "distributed/_shard/checkpoint/test_checkpoint"
171    "distributed/_shard/checkpoint/test_file_system_checkpoint"
172    "distributed/_shard/sharding_spec/test_sharding_spec",
173    "distributed/_shard/sharding_plan/test_sharding_plan",
174    "distributed/_shard/sharded_tensor/test_sharded_tensor",
175    "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
176    "distributed/_shard/sharded_tensor/ops/test_embedding",
177    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
178    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
179    "distributed/_shard/sharded_tensor/ops/test_init",
180    "distributed/_shard/sharded_optim/test_sharded_optim",
181    "test_determination",
182    "test_jit_legacy",
183    "test_cuda_nvml_based_avail",
184    "test_jit_cuda_fuser",
185    "distributed/_tensor/test_attention",
186]
187
188XPU_BLOCKLIST = [
189    "test_autograd",
190    "profiler/test_cpp_thread",
191    "profiler/test_execution_trace",
192    "profiler/test_memory_profiler",
193    "profiler/test_profiler",
194    "profiler/test_profiler_tree",
195    "profiler/test_record_function",
196    "profiler/test_torch_tidy",
197]
198
199XPU_TEST = [
200    "test_xpu",
201]
202
203# The tests inside these files should never be run in parallel with each other
204RUN_PARALLEL_BLOCKLIST = [
205    "test_cpp_extensions_jit",
206    "test_cpp_extensions_open_device_registration",
207    "test_cpp_extensions_stream_and_event",
208    "test_cpp_extensions_mtia_backend",
209    "test_jit_disabled",
210    "test_mobile_optimizer",
211    "test_multiprocessing",
212    "test_multiprocessing_spawn",
213    "test_namedtuple_return_api",
214    "test_overrides",
215    "test_show_pickle",
216    "test_tensorexpr",
217    "test_cuda_primary_ctx",
218    "test_cuda_trace",
219    "inductor/test_benchmark_fusion",
220    "test_cuda_nvml_based_avail",
221    # temporarily sets a global config
222    "test_autograd_fallback",
223] + FSDP_TEST
224
225# Test files that should always be run serially with other test files,
226# but it's okay if the tests inside them are run in parallel with each other.
227CI_SERIAL_LIST = [
228    "test_nn",
229    "test_fake_tensor",
230    "test_cpp_api_parity",
231    "test_reductions",
232    "test_fx_backends",
233    "test_cpp_extensions_jit",
234    "test_torch",
235    "test_tensor_creation_ops",
236    "test_dispatch",
237    "test_python_dispatch",  # torch.library creation and deletion must be serialized
238    "test_spectral_ops",  # Cause CUDA illegal memory access https://github.com/pytorch/pytorch/issues/88916
239    "nn/test_pooling",
240    "nn/test_convolution",  # Doesn't respect set_per_process_memory_fraction, results in OOM for other tests in slow gradcheck
241    "distributions/test_distributions",
242    "test_fx",  # gets SIGKILL
243    "functorch/test_memory_efficient_fusion",  # Cause CUDA OOM on ROCm
244    "test_utils",  # OOM
245    "test_sort_and_select",  # OOM
246    "test_backward_compatible_arguments",  # OOM
247    "test_autocast",  # OOM
248    "test_native_mha",  # OOM
249    "test_module_hooks",  # OOM
250    "inductor/test_max_autotune",
251    "inductor/test_cutlass_backend",  # slow due to many nvcc compilation steps,
252    "inductor/test_flex_attention",  # OOM
253]
254# A subset of onnx tests that cannot run in parallel due to high memory usage.
255ONNX_SERIAL_LIST = [
256    "onnx/test_models",
257    "onnx/test_models_quantized_onnxruntime",
258    "onnx/test_models_onnxruntime",
259    "onnx/test_custom_ops",
260    "onnx/test_utility_funs",
261]
262
263# A subset of the TESTS list that validates that PyTorch's ops, modules, and autograd work as expected
264CORE_TEST_LIST = [
265    "test_autograd",
266    "test_autograd_fallback",
267    "test_modules",
268    "test_nn",
269    "test_ops",
270    "test_ops_gradients",
271    "test_ops_fwd_gradients",
272    "test_ops_jit",
273    "test_torch",
274]
275
276
277# if a test file takes longer than 5 min, we add it to TARGET_DET_LIST
278SLOW_TEST_THRESHOLD = 300
279
280DISTRIBUTED_TESTS_CONFIG = {}
281
282
283if dist.is_available():
284    DISTRIBUTED_TESTS_CONFIG["test"] = {"WORLD_SIZE": "1"}
285    if not TEST_WITH_ROCM and dist.is_mpi_available():
286        DISTRIBUTED_TESTS_CONFIG["mpi"] = {
287            "WORLD_SIZE": "3",
288            "TEST_REPORT_SOURCE_OVERRIDE": "dist-mpi",
289        }
290    if dist.is_nccl_available():
291        DISTRIBUTED_TESTS_CONFIG["nccl"] = {
292            "WORLD_SIZE": "2" if torch.cuda.device_count() == 2 else "3",
293            "TEST_REPORT_SOURCE_OVERRIDE": "dist-nccl",
294        }
295    if dist.is_gloo_available():
296        DISTRIBUTED_TESTS_CONFIG["gloo"] = {
297            "WORLD_SIZE": "2" if torch.cuda.device_count() == 2 else "3",
298            "TEST_REPORT_SOURCE_OVERRIDE": "dist-gloo",
299        }
300    if dist.is_ucc_available():
301        DISTRIBUTED_TESTS_CONFIG["ucc"] = {
302            "WORLD_SIZE": "2" if torch.cuda.device_count() == 2 else "3",
303            "TEST_REPORT_SOURCE_OVERRIDE": "dist-ucc",
304            "UCX_TLS": "tcp,cuda",
305            "UCC_TLS": "nccl,ucp,cuda",
306            "UCC_TL_UCP_TUNE": "cuda:0",  # don't use UCP TL on CUDA as it is not well supported
307            "UCC_EC_CUDA_USE_COOPERATIVE_LAUNCH": "n",  # CI nodes (M60) fail if it is on
308        }
309
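# For illustration, on a typical 2-GPU CI machine with NCCL, Gloo and UCC available (and no MPI),
# the resulting mapping looks roughly like:
#   {"test": {"WORLD_SIZE": "1"},
#    "nccl": {"WORLD_SIZE": "2", "TEST_REPORT_SOURCE_OVERRIDE": "dist-nccl"},
#    "gloo": {"WORLD_SIZE": "2", "TEST_REPORT_SOURCE_OVERRIDE": "dist-gloo"},
#    "ucc": {"WORLD_SIZE": "2", "TEST_REPORT_SOURCE_OVERRIDE": "dist-ucc", ...}}
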
310# https://stackoverflow.com/questions/2549939/get-signal-names-from-numbers-in-python
311SIGNALS_TO_NAMES_DICT = {
312    getattr(signal, n): n for n in dir(signal) if n.startswith("SIG") and "_" not in n
313}
314
315CPP_EXTENSIONS_ERROR = """
316Ninja (https://ninja-build.org) is required for some of the C++ extensions
317tests, but it could not be found. Install ninja with `pip install ninja`
318or `conda install ninja`. Alternatively, disable said tests with
319`run_test.py --exclude test_cpp_extensions_aot_ninja test_cpp_extensions_jit`.
320"""
321
322PYTORCH_COLLECT_COVERAGE = bool(os.environ.get("PYTORCH_COLLECT_COVERAGE"))
323
324JIT_EXECUTOR_TESTS = [
325    "test_jit_profiling",
326    "test_jit_legacy",
327    "test_jit_fuser_legacy",
328]
329
330INDUCTOR_TESTS = [test for test in TESTS if test.startswith(INDUCTOR_TEST_PREFIX)]
331DISTRIBUTED_TESTS = [test for test in TESTS if test.startswith(DISTRIBUTED_TEST_PREFIX)]
332TORCH_EXPORT_TESTS = [test for test in TESTS if test.startswith("export")]
333FUNCTORCH_TESTS = [test for test in TESTS if test.startswith("functorch")]
334ONNX_TESTS = [test for test in TESTS if test.startswith("onnx")]
335CPP_TESTS = [test for test in TESTS if test.startswith(CPP_TEST_PREFIX)]
336
337TESTS_REQUIRING_LAPACK = [
338    "distributions/test_constraints",
339    "distributions/test_distributions",
340]
341
342# These are just the slowest ones; this isn't an exhaustive list.
343TESTS_NOT_USING_GRADCHECK = [
344    # Note that you should use skipIfSlowGradcheckEnv if you do not wish to
345    # skip all the tests in that file, e.g. test_mps
346    "doctests",
347    "test_meta",
348    "test_hub",
349    "test_fx",
350    "test_decomp",
351    "test_cpp_extensions_jit",
352    "test_jit",
353    "test_ops",
354    "test_ops_jit",
355    "dynamo/test_recompile_ux",
356    "inductor/test_smoke",
357    "test_quantization",
358]
359
360
361def print_to_stderr(message):
362    print(message, file=sys.stderr)
363
364
365def get_executable_command(options, disable_coverage=False, is_cpp_test=False):
366    if options.coverage and not disable_coverage:
367        if not is_cpp_test:
368            executable = ["coverage", "run", "--parallel-mode", "--source=torch"]
369        else:
370            # TODO: C++ with coverage is not yet supported
371            executable = []
372    else:
373        if not is_cpp_test:
374            executable = [sys.executable, "-bb"]
375        else:
376            executable = ["pytest"]
377
378    return executable
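    # Depending on the options this yields, for example, ["coverage", "run", "--parallel-mode",
    # "--source=torch"], [sys.executable, "-bb"], ["pytest"], or [] when coverage is requested
    # for a C++ test (which is not supported).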
379
380
381def run_test(
382    test_module: ShardedTest,
383    test_directory,
384    options,
385    launcher_cmd=None,
386    extra_unittest_args=None,
387    env=None,
388    print_log=True,
389) -> int:
390    scribe_token = os.getenv("SCRIBE_GRAPHQL_ACCESS_TOKEN", "")
391    if scribe_token:
392        print_to_stderr("SCRIBE_GRAPHQL_ACCESS_TOKEN is set")
393    else:
394        print_to_stderr("SCRIBE_GRAPHQL_ACCESS_TOKEN is NOT set")
395
396    env = env or os.environ.copy()
397    maybe_set_hip_visible_devies()
398    unittest_args = options.additional_args.copy()
399    test_file = test_module.name
400    stepcurrent_key = test_file
401
402    is_distributed_test = test_file.startswith(DISTRIBUTED_TEST_PREFIX)
403    is_cpp_test = test_file.startswith(CPP_TEST_PREFIX)
404    # NB: Rerunning disabled tests depends on pytest-flakefinder, which doesn't work with
405    # pytest-cpp at the moment. We also don't have support for disabling C++ tests yet, so
406    # it's ok to just return successfully here
407    if is_cpp_test and RERUN_DISABLED_TESTS:
408        print_to_stderr(
409            "Skipping C++ tests when running under RERUN_DISABLED_TESTS mode"
410        )
411        return 0
412
413    if is_cpp_test:
414        stepcurrent_key = f"{test_file}_{os.urandom(8).hex()}"
415    else:
416        unittest_args.extend(
417            [
418                f"--shard-id={test_module.shard}",
419                f"--num-shards={test_module.num_shards}",
420            ]
421        )
422        stepcurrent_key = f"{test_file}_{test_module.shard}_{os.urandom(8).hex()}"
423
424    if options.verbose:
425        unittest_args.append(f'-{"v" * options.verbose}')  # in case of pytest
426
427    if test_file in RUN_PARALLEL_BLOCKLIST:
428        unittest_args = [
429            arg for arg in unittest_args if not arg.startswith("--run-parallel")
430        ]
431
432    if extra_unittest_args:
433        assert isinstance(extra_unittest_args, list)
434        unittest_args.extend(extra_unittest_args)
435
436    # If using pytest, replace -f with equivalent -x
437    if options.pytest:
438        unittest_args.extend(
439            get_pytest_args(
440                options,
441                is_cpp_test=is_cpp_test,
442                is_distributed_test=is_distributed_test,
443            )
444        )
445        unittest_args.extend(test_module.get_pytest_args())
446        replacement = {"-f": "-x"}
447        unittest_args = [replacement.get(arg, arg) for arg in unittest_args]
448
449    if options.showlocals:
450        if options.pytest:
451            unittest_args.extend(["--showlocals", "--tb=long", "--color=yes"])
452        else:
453            unittest_args.append("--locals")
454
455    # NB: These features are not available for C++ tests, but there is little incentive
456    # to implement them because we have never seen a flaky C++ test before.
457    if IS_CI and not is_cpp_test:
458        ci_args = ["--import-slow-tests", "--import-disabled-tests"]
459        if RERUN_DISABLED_TESTS:
460            ci_args.append("--rerun-disabled-tests")
461        # use the downloaded test cases configuration, not supported in pytest
462        unittest_args.extend(ci_args)
463
464    if test_file in PYTEST_SKIP_RETRIES:
465        if not options.pytest:
466            raise RuntimeError(
467                "A test running without pytest cannot skip retries using "
468                "the PYTEST_SKIP_RETRIES set."
469            )
470        unittest_args = [arg for arg in unittest_args if "--reruns" not in arg]
471
472    # Extra arguments are not supported with pytest
473    executable = get_executable_command(options, is_cpp_test=is_cpp_test)
474    if not executable:
475        # If no eligible executable is returned here, it means an unsupported case
476        # such as coverage for a C++ test, so just returning ok makes sense
477        return 0
478
479    if test_file.startswith(CPP_TEST_PREFIX):
480        # C++ tests are not in the regular test directory
481        if CPP_TESTS_DIR:
482            cpp_test = os.path.join(
483                CPP_TESTS_DIR,
484                test_file.replace(f"{CPP_TEST_PREFIX}/", ""),
485            )
486        else:
487            cpp_test = os.path.join(
488                Path(test_directory).parent,
489                CPP_TEST_PATH,
490                test_file.replace(f"{CPP_TEST_PREFIX}/", ""),
491            )
492
493        argv = [
494            cpp_test if sys.platform != "win32" else cpp_test + ".exe"
495        ] + unittest_args
496    else:
497        # Can't call `python -m unittest test_*` here because it doesn't run code
498        # in `if __name__ == '__main__': `. So call `python test_*.py` instead.
499        argv = [test_file + ".py"] + unittest_args
500
501    os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
502    if options.pipe_logs:
503        log_fd, log_path = tempfile.mkstemp(
504            dir=REPO_ROOT / "test" / "test-reports",
505            prefix=f"{sanitize_file_name(str(test_module))}_",
506            suffix="_toprint.log",
507        )
508        os.close(log_fd)
509
510    command = (launcher_cmd or []) + executable + argv
511    should_retry = (
512        "--subprocess" not in command
513        and not RERUN_DISABLED_TESTS
514        and not is_cpp_test
515        and "-n" not in command
516    )
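    # Timeout selection when --enable-timeout is set: 6x THRESHOLD for slow configs,
    # 3x THRESHOLD for retryable sharded tests with a known time and for C++ tests,
    # and no timeout otherwise.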
517    timeout = (
518        None
519        if not options.enable_timeout
520        else THRESHOLD * 6
521        if IS_SLOW
522        else THRESHOLD * 3
523        if should_retry
524        and isinstance(test_module, ShardedTest)
525        and test_module.time is not None
526        else THRESHOLD * 3
527        if is_cpp_test
528        else None
529    )
530    print_to_stderr(f"Executing {command} ... [{datetime.now()}]")
531
532    with ExitStack() as stack:
533        output = None
534        if options.pipe_logs:
535            output = stack.enter_context(open(log_path, "w"))
536
537        if should_retry:
538            ret_code, was_rerun = run_test_retries(
539                command,
540                test_directory,
541                env,
542                timeout,
543                stepcurrent_key,
544                output,
545                options.continue_through_error,
546            )
547        else:
548            command.extend([f"--sc={stepcurrent_key}", "--print-items"])
549            ret_code, was_rerun = retry_shell(
550                command,
551                test_directory,
552                stdout=output,
553                stderr=output,
554                env=env,
555                timeout=timeout,
556                retries=0,
557            )
558
559            # Pytest return code 5 means no test is collected. Exit code 4 is
560            # returned when the binary is not a C++ test executable, but 4 can
561            # also be returned if the file fails before running any tests. All
562            # binary files under build/bin that are not C++ test at the time of
563            # this writing have been excluded and new ones should be added to
564            # the list of exclusions in tools/testing/discover_tests.py
565            ret_code = 0 if ret_code == 5 else ret_code
566
567    if options.pipe_logs and print_log:
568        handle_log_file(
569            test_module, log_path, failed=(ret_code != 0), was_rerun=was_rerun
570        )
571    return ret_code
572
573
574def try_set_cpp_stack_traces(env, command, set=True):
575    # Print full c++ stack traces during retries
576    env = env or {}
577    env["TORCH_SHOW_CPP_STACKTRACES"] = "1" if set else "0"
578    return env
579
580
581def run_test_retries(
582    command,
583    test_directory,
584    env,
585    timeout,
586    stepcurrent_key,
587    output,
588    continue_through_error,
589):
590    # Run the test with -x to stop at first failure.  Rerun the test by itself.
591    # If it succeeds, move on to the rest of the tests in a new process.  If it
592    # still fails, see below
593    #
594    # If continue through error is not set, then we fail fast.
595    #
596    # If continue through error is set, then we skip that test, and keep going.
597    # Basically if the same test fails 3 times in a row, skip the test on the
598    # next run, but still fail in the end. I take advantage of the value saved
599    # in stepcurrent to keep track of the most recently run test (which is the
600    # one that failed if there was a failure).
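    # As used in the loop below, the stepcurrent flags are: --sc=<key> records the most
    # recently run test under <key>, --rs=<key> restarts from that recorded (failing) test in a
    # new process, and --scs=<key> skips past it and continues with the remaining tests.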
601
602    def print_to_file(s):
603        print(s, file=output, flush=True)
604
605    num_failures = defaultdict(int)
606
607    print_items = ["--print-items"]
608    sc_command = f"--sc={stepcurrent_key}"
609    while True:
610        ret_code, _ = retry_shell(
611            command + [sc_command] + print_items,
612            test_directory,
613            stdout=output,
614            stderr=output,
615            env=env,
616            timeout=timeout,
617            retries=0,  # no retries here, we handle them ourselves; retry_shell is still used because it handles timeout exceptions well
618        )
619        ret_code = 0 if ret_code == 5 else ret_code
620        if ret_code == 0 and not sc_command.startswith("--rs="):
621            break  # Got to the end of the test suite successfully
622        signal_name = f" ({SIGNALS_TO_NAMES_DICT[-ret_code]})" if ret_code < 0 else ""
623        print_to_file(f"Got exit code {ret_code}{signal_name}")
624
625        # Read what just failed/ran
626        try:
627            with open(
628                REPO_ROOT / ".pytest_cache/v/cache/stepcurrent" / stepcurrent_key
629            ) as f:
630                current_failure = f.read()
631        except FileNotFoundError:
632            print_to_file(
633                "No stepcurrent file found. Either pytest didn't get to run (e.g. import error)"
634                + " or file got deleted (contact dev infra)"
635            )
636            break
637
638        env = try_set_cpp_stack_traces(env, command, set=False)
639        if ret_code != 0:
640            num_failures[current_failure] += 1
641
642        if ret_code == 0:
643            # Rerunning the previously failing test succeeded, so now we can
644            # skip it and move on
645            sc_command = f"--scs={stepcurrent_key}"
646            print_to_file(
647                "Test succeeeded in new process, continuing with the rest of the tests"
648            )
649        elif num_failures[current_failure] >= 3:
650            if not continue_through_error:
651                print_to_file("Stopping at first consistent failure")
652                break
653            sc_command = f"--scs={stepcurrent_key}"
654            print_to_file(
655                "Test failed consistently, "
656                "continuing with the rest of the tests due to continue-through-error being set"
657            )
658        else:
659            env = try_set_cpp_stack_traces(env, command, set=True)
660            sc_command = f"--rs={stepcurrent_key}"
661            print_to_file("Retrying single test...")
662        print_items = []  # do not continue printing them, massive waste of space
663
664    consistent_failures = [x[1:-1] for x in num_failures.keys() if num_failures[x] >= 3]
665    flaky_failures = [x[1:-1] for x in num_failures.keys() if 0 < num_failures[x] < 3]
666    if len(flaky_failures) > 0:
667        print_to_file(
668            "The following tests failed and then succeeded when run in a new process"
669            + f"{flaky_failures}",
670        )
671    if len(consistent_failures) > 0:
672        print_to_file(f"The following tests failed consistently: {consistent_failures}")
673        return 1, True
674    return ret_code, any(x > 0 for x in num_failures.values())
675
676
677def run_test_with_subprocess(test_module, test_directory, options):
678    return run_test(
679        test_module, test_directory, options, extra_unittest_args=["--subprocess"]
680    )
681
682
683def _test_cpp_extensions_aot(test_directory, options, use_ninja):
684    if use_ninja:
685        try:
686            from torch.utils import cpp_extension
687
688            cpp_extension.verify_ninja_availability()
689        except RuntimeError:
690            print_to_stderr(CPP_EXTENSIONS_ERROR)
691            return 1
692
693    # Wipe the build folder, if it exists already
694    cpp_extensions_test_dir = os.path.join(test_directory, "cpp_extensions")
695    cpp_extensions_test_build_dir = os.path.join(cpp_extensions_test_dir, "build")
696    if os.path.exists(cpp_extensions_test_build_dir):
697        shutil.rmtree(cpp_extensions_test_build_dir)
698
699    # Build the test cpp extensions modules
700    shell_env = os.environ.copy()
701    shell_env["USE_NINJA"] = str(1 if use_ninja else 0)
702    cmd = [sys.executable, "setup.py", "install", "--root", "./install"]
703    return_code = shell(cmd, cwd=cpp_extensions_test_dir, env=shell_env)
704    if return_code != 0:
705        return return_code
706    if sys.platform != "win32":
707        return_code = shell(
708            cmd,
709            cwd=os.path.join(cpp_extensions_test_dir, "no_python_abi_suffix_test"),
710            env=shell_env,
711        )
712        if return_code != 0:
713            return return_code
714
715    # "install" the test modules and run tests
716    python_path = os.environ.get("PYTHONPATH", "")
717    from shutil import copyfile
718
719    os.environ["USE_NINJA"] = shell_env["USE_NINJA"]
720    test_module = "test_cpp_extensions_aot" + ("_ninja" if use_ninja else "_no_ninja")
721    copyfile(
722        test_directory + "/test_cpp_extensions_aot.py",
723        test_directory + "/" + test_module + ".py",
724    )
725    try:
726        cpp_extensions = os.path.join(test_directory, "cpp_extensions")
727        install_directory = ""
728        # install directory is the one that is named site-packages
729        for root, directories, _ in os.walk(os.path.join(cpp_extensions, "install")):
730            for directory in directories:
731                if "-packages" in directory:
732                    install_directory = os.path.join(root, directory)
733
734        assert install_directory, "install_directory must not be empty"
735        os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path])
736        return run_test(ShardedTest(test_module, 1, 1), test_directory, options)
737    finally:
738        os.environ["PYTHONPATH"] = python_path
739        if os.path.exists(test_directory + "/" + test_module + ".py"):
740            os.remove(test_directory + "/" + test_module + ".py")
741        os.environ.pop("USE_NINJA")
742
743
744def test_cpp_extensions_aot_ninja(test_module, test_directory, options):
745    return _test_cpp_extensions_aot(test_directory, options, use_ninja=True)
746
747
748def test_cpp_extensions_aot_no_ninja(test_module, test_directory, options):
749    return _test_cpp_extensions_aot(test_directory, options, use_ninja=False)
750
751
752def test_autoload_enable(test_module, test_directory, options):
753    return _test_autoload(test_directory, options, enable=True)
754
755
756def test_autoload_disable(test_module, test_directory, options):
757    return _test_autoload(test_directory, options, enable=False)
758
759
760def _test_autoload(test_directory, options, enable=True):
761    # Wipe the build folder, if it exists already
762    cpp_extensions_test_dir = os.path.join(test_directory, "cpp_extensions")
763    cpp_extensions_test_build_dir = os.path.join(cpp_extensions_test_dir, "build")
764    if os.path.exists(cpp_extensions_test_build_dir):
765        shutil.rmtree(cpp_extensions_test_build_dir)
766
767    # Build the test cpp extensions modules
768    cmd = [sys.executable, "setup.py", "install", "--root", "./install"]
769    return_code = shell(cmd, cwd=cpp_extensions_test_dir, env=os.environ)
770    if return_code != 0:
771        return return_code
772
773    # "install" the test modules and run tests
774    python_path = os.environ.get("PYTHONPATH", "")
775
776    try:
777        cpp_extensions = os.path.join(test_directory, "cpp_extensions")
778        install_directory = ""
779        # install directory is the one that is named site-packages
780        for root, directories, _ in os.walk(os.path.join(cpp_extensions, "install")):
781            for directory in directories:
782                if "-packages" in directory:
783                    install_directory = os.path.join(root, directory)
784
785        assert install_directory, "install_directory must not be empty"
786        os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path])
787        os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = str(int(enable))
788
789        cmd = [sys.executable, "test_autoload.py"]
790        return_code = shell(cmd, cwd=test_directory, env=os.environ)
791        return return_code
792    finally:
793        os.environ["PYTHONPATH"] = python_path
794        os.environ.pop("TORCH_DEVICE_BACKEND_AUTOLOAD")
795
796
797def test_distributed(test_module, test_directory, options):
798    # MPI tests are broken with Python 3.9 and newer
799    mpi_available = subprocess.call(
800        "command -v mpiexec", shell=True
801    ) == 0 and sys.version_info < (3, 9)
802    if options.verbose and not mpi_available:
803        print_to_stderr("MPI not available -- MPI backend tests will be skipped")
804
805    config = DISTRIBUTED_TESTS_CONFIG
806    for backend, env_vars in config.items():
807        if sys.platform == "win32" and backend != "gloo":
808            continue
809        if backend == "mpi" and not mpi_available:
810            continue
811        for with_init_file in {True, False}:
812            if sys.platform == "win32" and not with_init_file:
813                continue
814            tmp_dir = tempfile.mkdtemp()
815            if options.verbose:
816                init_str = "with {} init_method"
817                with_init = init_str.format("file" if with_init_file else "env")
818                print_to_stderr(
819                    f"Running distributed tests for the {backend} backend {with_init}"
820                )
821            old_environ = dict(os.environ)
822            os.environ["TEMP_DIR"] = tmp_dir
823            os.environ["BACKEND"] = backend
824            os.environ.update(env_vars)
825            try:
826                os.mkdir(os.path.join(tmp_dir, "barrier"))
827                os.mkdir(os.path.join(tmp_dir, "test_dir"))
828                if backend == "mpi":
829                    # test mpiexec for --noprefix option
830                    with open(os.devnull, "w") as devnull:
831                        allowrunasroot_opt = (
832                            "--allow-run-as-root"
833                            if subprocess.call(
834                                'mpiexec --allow-run-as-root -n 1 bash -c ""',
835                                shell=True,
836                                stdout=devnull,
837                                stderr=subprocess.STDOUT,
838                            )
839                            == 0
840                            else ""
841                        )
842                        noprefix_opt = (
843                            "--noprefix"
844                            if subprocess.call(
845                                f'mpiexec {allowrunasroot_opt} -n 1 --noprefix bash -c ""',
846                                shell=True,
847                                stdout=devnull,
848                                stderr=subprocess.STDOUT,
849                            )
850                            == 0
851                            else ""
852                        )
853
854                    mpiexec = ["mpiexec", "-n", "3", noprefix_opt, allowrunasroot_opt]
855
856                    return_code = run_test(
857                        test_module, test_directory, options, launcher_cmd=mpiexec
858                    )
859                else:
860                    return_code = run_test(
861                        test_module,
862                        test_directory,
863                        options,
864                        extra_unittest_args=["--subprocess"],
865                    )
866                if return_code != 0:
867                    return return_code
868            finally:
869                shutil.rmtree(tmp_dir)
870                os.environ.clear()
871                os.environ.update(old_environ)
872    return 0
873
874
875def run_doctests(test_module, test_directory, options):
876    """
877    Assumes the incoming test module is called doctests, and simply executes the
878    xdoctest runner on the torch library itself.
879    """
880    import xdoctest
881
882    pkgpath = Path(torch.__file__).parent
883
884    exclude_module_list = ["torch._vendor.*"]
885    enabled = {
886        # TODO: expose these options to the user
887        # For now disable all feature-conditional tests
888        # 'lapack': 'auto',
889        # 'cuda': 'auto',
890        # 'cuda1': 'auto',
891        # 'qengine': 'auto',
892        "lapack": 0,
893        "cuda": 0,
894        "cuda1": 0,
895        "qengine": 0,
896        "autograd_profiler": 0,
897        "cpp_ext": 0,
898        "monitor": 0,
899        "onnx": "auto",
900    }
901
902    # Resolve "auto" based on a test to determine if the feature is available.
903    if enabled["cuda"] == "auto" and torch.cuda.is_available():
904        enabled["cuda"] = True
905
906    if (
907        enabled["cuda1"] == "auto"
908        and torch.cuda.is_available()
909        and torch.cuda.device_count() > 1
910    ):
911        enabled["cuda1"] = True
912
913    if enabled["lapack"] == "auto" and torch._C.has_lapack:
914        enabled["lapack"] = True
915
916    if enabled["qengine"] == "auto":
917        try:
918            # Is there a better check if quantization is enabled?
919            import torch.ao.nn.quantized as nnq  # NOQA: F401
920
921            torch.backends.quantized.engine = "qnnpack"
922            torch.backends.quantized.engine = "fbgemm"
923        except (ImportError, RuntimeError):
924            ...
925        else:
926            enabled["qengine"] = True
927
928    if enabled["onnx"] == "auto":
929        try:
930            import onnx  # NOQA: F401
931            import onnxruntime  # NOQA: F401
932            import onnxscript  # NOQA: F401
933        except ImportError:
934            exclude_module_list.append("torch.onnx.*")
935            enabled["onnx"] = False
936        else:
937            enabled["onnx"] = True
938
939    # Set doctest environment variables
940    if enabled["cuda"]:
941        os.environ["TORCH_DOCTEST_CUDA"] = "1"
942
943    if enabled["cuda1"]:
944        os.environ["TORCH_DOCTEST_CUDA1"] = "1"
945
946    if enabled["lapack"]:
947        os.environ["TORCH_DOCTEST_LAPACK"] = "1"
948
949    if enabled["qengine"]:
950        os.environ["TORCH_DOCTEST_QENGINE"] = "1"
951
952    if enabled["autograd_profiler"]:
953        os.environ["TORCH_DOCTEST_AUTOGRAD_PROFILER"] = "1"
954
955    if enabled["cpp_ext"]:
956        os.environ["TORCH_DOCTEST_CPP_EXT"] = "1"
957
958    if enabled["monitor"]:
959        os.environ["TORCH_DOCTEST_MONITOR"] = "1"
960
961    if enabled["onnx"]:
962        os.environ["TORCH_DOCTEST_ONNX"] = "1"
963
964    if 0:
965        # TODO: could try to enable some of these
966        os.environ["TORCH_DOCTEST_QUANTIZED_DYNAMIC"] = "1"
967        os.environ["TORCH_DOCTEST_ANOMALY"] = "1"
968        os.environ["TORCH_DOCTEST_AUTOGRAD"] = "1"
969        os.environ["TORCH_DOCTEST_HUB"] = "1"
970        os.environ["TORCH_DOCTEST_DATALOADER"] = "1"
971        os.environ["TORCH_DOCTEST_FUTURES"] = "1"
972
973    pkgpath = os.path.dirname(torch.__file__)
974
975    xdoctest_config = {
976        "global_exec": r"\n".join(
977            [
978                "from torch import nn",
979                "import torch.nn.functional as F",
980                "import torch",
981            ]
982        ),
983        "analysis": "static",  # set to "auto" to test doctests in compiled modules
984        "style": "google",
985        "options": "+IGNORE_WHITESPACE",
986    }
987    xdoctest_verbose = max(1, options.verbose)
988    run_summary = xdoctest.runner.doctest_module(
989        os.fspath(pkgpath),
990        config=xdoctest_config,
991        verbose=xdoctest_verbose,
992        command=options.xdoctest_command,
993        argv=[],
994        exclude=exclude_module_list,
995    )
996    result = 1 if run_summary.get("n_failed", 0) else 0
997    return result
998
999
1000def sanitize_file_name(file: str):
1001    return file.replace("\\", ".").replace("/", ".").replace(" ", "_")
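    # For example, sanitize_file_name("distributed/test_c10d_common 1/2") returns
    # "distributed.test_c10d_common_1.2".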
1002
1003
1004def handle_log_file(
1005    test: ShardedTest, file_path: str, failed: bool, was_rerun: bool
1006) -> None:
1007    test = str(test)
1008    with open(file_path, errors="ignore") as f:
1009        full_text = f.read()
1010
1011    new_file = "test/test-reports/" + sanitize_file_name(
1012        f"{test}_{os.urandom(8).hex()}_.log"
1013    )
1014    os.rename(file_path, REPO_ROOT / new_file)
1015
1016    if not failed and not was_rerun and "=== RERUNS ===" not in full_text:
1017        # If success + no retries (idk how else to check for test level retries
1018        # other than reparse xml), print only what tests ran
1019        print_to_stderr(
1020            f"\n{test} was successful, full logs can be found in artifacts with path {new_file}"
1021        )
1022        for line in full_text.splitlines():
1023            if re.search("Running .* items in this shard:", line):
1024                print_to_stderr(line.rstrip())
1025        print_to_stderr("")
1026        return
1027
1028    # otherwise: print entire file
1029    print_to_stderr(f"\nPRINTING LOG FILE of {test} ({new_file})")
1030    print_to_stderr(full_text)
1031    print_to_stderr(f"FINISHED PRINTING LOG FILE of {test} ({new_file})\n")
1032
1033
1034def get_pytest_args(options, is_cpp_test=False, is_distributed_test=False):
1035    if RERUN_DISABLED_TESTS:
1036        # Distributed tests are too slow, so running them x50 will cause the jobs to time out after
1037        # 3+ hours. So, let's opt for fewer reruns. We need at least 150 instances of the
1038        # test every 2 weeks to satisfy the Rockset query (15 x 14 = 210). The same logic applies
1039        # to ASAN, which is also slow
1040        count = 15 if is_distributed_test or TEST_WITH_ASAN else 50
1041        # When under rerun-disabled-tests mode, run the same tests multiple times to determine their
1042        # flakiness status. Default to 50 re-runs
1043        rerun_options = ["--flake-finder", f"--flake-runs={count}"]
1044    else:
1045        # When under the normal mode, retry a failed test 2 more times. -x means stop at the first
1046        # failure
1047        rerun_options = ["-x", "--reruns=2"]
1048
1049    pytest_args = [
1050        "-vv",
1051        "-rfEX",
1052    ]
1053    if not is_cpp_test:
1054        # These args only apply to Python tests; C++ tests are run with pytest directly, not via python
1055        # We have a custom pytest shard plugin that conflicts with the normal plugin
1056        pytest_args.extend(["-p", "no:xdist", "--use-pytest"])
1057    else:
1058        # Use pytest-xdist to run C++ tests in parallel as running them sequentially using run_test
1059        # is much slower than running them directly
1060        pytest_args.extend(["-n", str(NUM_PROCS)])
1061
1062        if IS_CI:
1063            # Add the option to generate the XML test report here as C++ tests
1064            # don't go through common_utils
1065            test_report_path = get_report_path(pytest=True)
1066            pytest_args.extend(["--junit-xml-reruns", test_report_path])
1067
1068    if options.pytest_k_expr:
1069        pytest_args.extend(["-k", options.pytest_k_expr])
1070
1071    pytest_args.extend(rerun_options)
1072    return pytest_args
1073
1074
1075def run_ci_sanity_check(test: ShardedTest, test_directory, options):
1076    assert (
1077        test.name == "test_ci_sanity_check_fail"
1078    ), f"This handler only works for test_ci_sanity_check_fail, got {test.name}"
1079    ret_code = run_test(test, test_directory, options, print_log=False)
1080    # This test should fail
1081    if ret_code != 1:
1082        return 1
1083    test_reports_dir = str(REPO_ROOT / "test/test-reports")
1084    # Delete the log files and xmls generated by the test
1085    for file in glob.glob(f"{test_reports_dir}/{test.name}*.log"):
1086        os.remove(file)
1087    for dirname in glob.glob(f"{test_reports_dir}/**/{test.name}"):
1088        shutil.rmtree(dirname)
1089    return 0
1090
1091
1092CUSTOM_HANDLERS = {
1093    "test_cuda_primary_ctx": run_test_with_subprocess,
1094    "test_cuda_nvml_based_avail": run_test_with_subprocess,
1095    "test_cuda_trace": run_test_with_subprocess,
1096    "test_cpp_extensions_aot_no_ninja": test_cpp_extensions_aot_no_ninja,
1097    "test_cpp_extensions_aot_ninja": test_cpp_extensions_aot_ninja,
1098    "distributed/test_distributed_spawn": test_distributed,
1099    "distributed/algorithms/quantization/test_quantization": test_distributed,
1100    "distributed/test_c10d_nccl": run_test_with_subprocess,
1101    "distributed/test_c10d_gloo": run_test_with_subprocess,
1102    "distributed/test_c10d_ucc": run_test_with_subprocess,
1103    "distributed/test_c10d_common": run_test_with_subprocess,
1104    "distributed/test_c10d_spawn_gloo": run_test_with_subprocess,
1105    "distributed/test_c10d_spawn_nccl": run_test_with_subprocess,
1106    "distributed/test_c10d_spawn_ucc": run_test_with_subprocess,
1107    "distributed/test_store": run_test_with_subprocess,
1108    "distributed/test_pg_wrapper": run_test_with_subprocess,
1109    "distributed/rpc/test_faulty_agent": run_test_with_subprocess,
1110    "distributed/rpc/test_tensorpipe_agent": run_test_with_subprocess,
1111    "distributed/rpc/test_share_memory": run_test_with_subprocess,
1112    "distributed/rpc/cuda/test_tensorpipe_agent": run_test_with_subprocess,
1113    "doctests": run_doctests,
1114    "test_ci_sanity_check_fail": run_ci_sanity_check,
1115    "test_autoload_enable": test_autoload_enable,
1116    "test_autoload_disable": test_autoload_disable,
1117}
1118
1119
1120PYTEST_SKIP_RETRIES = {"test_public_bindings"}
1121
1122
1123def parse_args():
1124    parser = argparse.ArgumentParser(
1125        description="Run the PyTorch unit test suite",
1126        epilog="where TESTS is any of: {}".format(", ".join(TESTS)),
1127        formatter_class=argparse.RawTextHelpFormatter,
1128    )
1129    parser.add_argument(
1130        "-v",
1131        "--verbose",
1132        action="count",
1133        default=0,
1134        help="Print verbose information and test-by-test results",
1135    )
1136    if sys.version_info >= (3, 9):
1137        parser.add_argument(
1138            "--showlocals",
1139            action=argparse.BooleanOptionalAction,
1140            default=strtobool(os.environ.get("TEST_SHOWLOCALS", "False")),
1141            help="Show local variables in tracebacks (default: True)",
1142        )
1143    else:
1144        parser.add_argument(
1145            "--showlocals",
1146            action="store_true",
1147            default=strtobool(os.environ.get("TEST_SHOWLOCALS", "False")),
1148            help="Show local variables in tracebacks (default: True)",
1149        )
1150        parser.add_argument("--no-showlocals", dest="showlocals", action="store_false")
1151    parser.add_argument("--jit", "--jit", action="store_true", help="run all jit tests")
1152    parser.add_argument(
1153        "--distributed-tests",
1154        "--distributed-tests",
1155        action="store_true",
1156        help="Run all distributed tests",
1157    )
1158    parser.add_argument(
1159        "--functorch",
1160        "--functorch",
1161        action="store_true",
1162        help=(
1163            "If this flag is present, we will only run functorch tests. "
1164            "If this flag is not present, we will run all tests "
1165            "(including functorch tests)."
1166        ),
1167    )
1168    parser.add_argument(
1169        "--mps",
1170        "--mps",
1171        action="store_true",
1172        help=("If this flag is present, we will only run test_mps and test_metal"),
1173    )
1174    parser.add_argument(
1175        "--xpu",
1176        "--xpu",
1177        action="store_true",
1178        help=("If this flag is present, we will run xpu tests except XPU_BLOCK_LIST"),
1179    )
1180    parser.add_argument(
1181        "--cpp",
1182        "--cpp",
1183        action="store_true",
1184        help=("If this flag is present, we will only run C++ tests"),
1185    )
1186    parser.add_argument(
1187        "-core",
1188        "--core",
1189        action="store_true",
1190        help="Only run core tests, or tests that validate PyTorch's ops, modules,"
1191        "and autograd. They are defined by CORE_TEST_LIST.",
1192    )
1193    parser.add_argument(
1194        "--onnx",
1195        "--onnx",
1196        action="store_true",
1197        help=(
1198            "Only run ONNX tests, or tests that validate PyTorch's ONNX export. "
1199            "If this flag is not present, we will exclude ONNX tests."
1200        ),
1201    )
1202    parser.add_argument(
1203        "-k",
1204        "--pytest-k-expr",
1205        default="",
1206        help="Pass to pytest as its -k expr argument",
1207    )
1208    parser.add_argument(
1209        "-c",
1210        "--coverage",
1211        action="store_true",
1212        help="enable coverage",
1213        default=PYTORCH_COLLECT_COVERAGE,
1214    )
1215    parser.add_argument(
1216        "-i",
1217        "--include",
1218        nargs="+",
1219        choices=TestChoices(TESTS),
1220        default=TESTS,
1221        metavar="TESTS",
1222        help="select a set of tests to include (defaults to ALL tests)."
1223        " tests must be a part of the TESTS list defined in run_test.py",
1224    )
1225    parser.add_argument(
1226        "-x",
1227        "--exclude",
1228        nargs="+",
1229        choices=TESTS,
1230        metavar="TESTS",
1231        default=[],
1232        help="select a set of tests to exclude",
1233    )
1234    parser.add_argument(
1235        "--ignore-win-blocklist",
1236        action="store_true",
1237        help="always run blocklisted windows tests",
1238    )
1239    # NS: Disable target determination until it can be made more reliable
1240    # parser.add_argument(
1241    #     "--determine-from",
1242    #     help="File of affected source filenames to determine which tests to run.",
1243    # )
1244    parser.add_argument(
1245        "--continue-through-error",
1246        "--keep-going",
1247        action="store_true",
1248        help="Runs the full test suite despite one of the tests failing",
1249        default=strtobool(os.environ.get("CONTINUE_THROUGH_ERROR", "False")),
1250    )
1251    parser.add_argument(
1252        "--pipe-logs",
1253        action="store_true",
1254        help="Print logs to output file while running tests.  True if in CI and env var is not set",
1255        default=IS_CI and not strtobool(os.environ.get("VERBOSE_TEST_LOGS", "False")),
1256    )
1257    parser.add_argument(
1258        "--enable-timeout",
1259        action="store_true",
1260        help="Set a timeout based on the test times json file.  Only works if there are test times available",
1261        default=IS_CI and not strtobool(os.environ.get("NO_TEST_TIMEOUT", "False")),
1262    )
1263    parser.add_argument(
1264        "--enable-td",
1265        action="store_true",
1266        help="Enables removing tests based on TD",
1267        default=IS_CI
1268        and (
1269            TEST_WITH_CROSSREF
1270            or TEST_WITH_ASAN
1271            or (TEST_CONFIG == "distributed" and TEST_CUDA)
1272            or (IS_WINDOWS and not TEST_CUDA)
1273            or TEST_CONFIG == "nogpu_AVX512"
1274            or TEST_CONFIG == "nogpu_NO_AVX2"
1275            or TEST_CONFIG == "default"
1276        )
1277        and get_pr_number() is not None
1278        and not strtobool(os.environ.get("NO_TD", "False"))
1279        and not TEST_WITH_ROCM
1280        and not IS_MACOS
1281        and "xpu" not in BUILD_ENVIRONMENT
1282        and "onnx" not in BUILD_ENVIRONMENT
1283        and os.environ.get("GITHUB_WORKFLOW", "slow") in ("trunk", "pull"),
1284    )
1285    parser.add_argument(
1286        "--shard",
1287        nargs=2,
1288        type=int,
1289        help="runs a shard of the tests (taking into account other selections), e.g., "
1290        "--shard 2 3 will break up the selected tests into 3 shards and run the tests "
1291        "in the 2nd shard (the first number should not exceed the second)",
1292    )
1293    parser.add_argument(
1294        "--exclude-jit-executor",
1295        action="store_true",
1296        help="exclude tests that are run for a specific jit config",
1297    )
1298    parser.add_argument(
1299        "--exclude-torch-export-tests",
1300        action="store_true",
1301        help="exclude torch export tests",
1302    )
1303    parser.add_argument(
1304        "--exclude-distributed-tests",
1305        action="store_true",
1306        help="exclude distributed tests",
1307    )
1308    parser.add_argument(
1309        "--exclude-inductor-tests",
1310        action="store_true",
1311        help="exclude inductor tests",
1312    )
1313    parser.add_argument(
1314        "--dry-run",
1315        action="store_true",
1316        help="Only list the test that will run.",
1317    )
1318    parser.add_argument(
1319        "--xdoctest-command",
1320        default="all",
1321        help=(
1322            "Control the specific doctest action. "
1323            "Use 'list' to simply parse doctests and check syntax. "
1324            "Use 'all' to execute all doctests or specify a specific "
1325            "doctest to run"
1326        ),
1327    )
1328    parser.add_argument(
1329        "--no-translation-validation",
1330        action="store_false",
1331        help="Run tests without translation validation.",
1332    )
1333
1334    group = parser.add_mutually_exclusive_group()
1335    group.add_argument(
1336        "--dynamo",
1337        action="store_true",
1338        help="Run tests with TorchDynamo+EagerBackend turned on",
1339    )
1340    group.add_argument(
1341        "--inductor",
1342        action="store_true",
1343        help="Run tests with TorchInductor turned on",
1344    )
1345
1346    args, extra = parser.parse_known_args()
1347    if "--" in extra:
1348        extra.remove("--")
1349    args.additional_args = extra
1350    return args
1351
1352
1353def exclude_tests(
1354    exclude_list, selected_tests, exclude_message=None, exact_match=False
1355):
1356    for exclude_test in exclude_list:
1357        tests_copy = selected_tests[:]
1358        for test in tests_copy:
1359            if (
1360                not exact_match and test.startswith(exclude_test)
1361            ) or test == exclude_test:
1362                if exclude_message is not None:
1363                    print_to_stderr(f"Excluding {test} {exclude_message}")
1364                selected_tests.remove(test)
1365    return selected_tests
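    # For example, exclude_tests(["distributed"], selected_tests, "on this platform") removes
    # every selected test whose name starts with "distributed" (a prefix match); with
    # exact_match=True only a test named exactly "distributed" would be removed.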
1366
1367
1368def must_serial(file: Union[str, ShardedTest]) -> bool:
1369    if isinstance(file, ShardedTest):
1370        file = file.name
1371    return (
1372        os.getenv("PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL", "0") == "1"
1373        or DISTRIBUTED_TEST_PREFIX in os.getenv("TEST_CONFIG", "")
1374        or DISTRIBUTED_TEST_PREFIX in file
1375        or file in CUSTOM_HANDLERS
1376        or file in RUN_PARALLEL_BLOCKLIST
1377        or file in CI_SERIAL_LIST
1378        or file in JIT_EXECUTOR_TESTS
1379        or file in ONNX_SERIAL_LIST
1380        or NUM_PROCS == 1
1381    )
1382
1383
1384def can_run_in_pytest(test):
1385    return os.getenv("PYTORCH_TEST_DO_NOT_USE_PYTEST", "0") == "0"
1386
1387
1388def get_selected_tests(options) -> List[str]:
1389    selected_tests = options.include
1390
1391    # filter if the JIT-only or distributed-only test options are given
1392    if options.jit:
1393        selected_tests = list(
1394            filter(lambda test_name: "jit" in test_name, selected_tests)
1395        )
1396
1397    if options.distributed_tests:
1398        selected_tests = list(
1399            filter(lambda test_name: test_name in DISTRIBUTED_TESTS, selected_tests)
1400        )
1401
1402    # Filter to only run core tests when --core option is specified
1403    if options.core:
1404        selected_tests = list(
1405            filter(lambda test_name: test_name in CORE_TEST_LIST, selected_tests)
1406        )
1407
1408    # Filter to only run functorch tests when --functorch option is specified
1409    if options.functorch:
1410        selected_tests = [tname for tname in selected_tests if tname in FUNCTORCH_TESTS]
1411
1412    if options.cpp:
1413        selected_tests = [tname for tname in selected_tests if tname in CPP_TESTS]
1414    else:
1415        # Exclude all C++ tests otherwise as they are still handled differently
1416        # than Python tests at the moment
1417        options.exclude.extend(CPP_TESTS)
1418
1419    if options.mps:
1420        selected_tests = ["test_mps", "test_metal", "test_modules", "test_nn"]
1421    else:
1422        # Exclude all mps tests otherwise
1423        options.exclude.extend(["test_mps", "test_metal"])
1424
1425    if options.xpu:
1426        selected_tests = exclude_tests(XPU_BLOCKLIST, selected_tests, "on XPU")
1427    else:
1428        # Exclude all XPU-specific tests otherwise
1429        options.exclude.extend(XPU_TEST)
1430
1431    # Filter to only run onnx tests when --onnx option is specified
1432    onnx_tests = [tname for tname in selected_tests if tname in ONNX_TESTS]
1433    if options.onnx:
1434        selected_tests = onnx_tests
1435    else:
1436        # Exclude all onnx tests otherwise
1437        options.exclude.extend(onnx_tests)
1438
1439    # process exclusion
1440    if options.exclude_jit_executor:
1441        options.exclude.extend(JIT_EXECUTOR_TESTS)
1442
1443    if options.exclude_distributed_tests:
1444        options.exclude.extend(DISTRIBUTED_TESTS)
1445
1446    if options.exclude_inductor_tests:
1447        options.exclude.extend(INDUCTOR_TESTS)
1448
1449    if options.exclude_torch_export_tests:
1450        options.exclude.extend(TORCH_EXPORT_TESTS)
1451
1452    # These tests fail with CUDA 11.6; temporarily disabling them. See https://github.com/pytorch/pytorch/issues/75375
1453    if torch.version.cuda is not None:
1454        options.exclude.extend(["distributions/test_constraints"])
1455
1456    # These tests fail on Python 3.12; temporarily disabling them
1457    if sys.version_info >= (3, 12):
1458        options.exclude.extend(
1459            [
1460                "functorch/test_dims",
1461                "functorch/test_rearrange",
1462                "functorch/test_parsing",
1463                "functorch/test_memory_efficient_fusion",
1464                "torch_np/numpy_tests/core/test_multiarray",
1465            ]
1466        )
1467
1468    selected_tests = exclude_tests(options.exclude, selected_tests)
1469
1470    if sys.platform == "win32" and not options.ignore_win_blocklist:
1471        target_arch = os.environ.get("VSCMD_ARG_TGT_ARCH")
1472        if target_arch != "x64":
1473            WINDOWS_BLOCKLIST.append("cpp_extensions_aot_no_ninja")
1474            WINDOWS_BLOCKLIST.append("cpp_extensions_aot_ninja")
1475            WINDOWS_BLOCKLIST.append("cpp_extensions_jit")
1476            WINDOWS_BLOCKLIST.append("jit")
1477            WINDOWS_BLOCKLIST.append("jit_fuser")
1478
1479        selected_tests = exclude_tests(WINDOWS_BLOCKLIST, selected_tests, "on Windows")
1480
1481    elif TEST_WITH_ROCM:
1482        selected_tests = exclude_tests(ROCM_BLOCKLIST, selected_tests, "on ROCm")
1483
1484    # skip all distributed tests if distributed package is not available.
1485    if not dist.is_available():
1486        selected_tests = exclude_tests(
1487            DISTRIBUTED_TESTS,
1488            selected_tests,
1489            "PyTorch is built without distributed support.",
1490        )
1491
1492    # skip tests that require LAPACK when it's not available
1493    if not torch._C.has_lapack:
1494        selected_tests = exclude_tests(
1495            TESTS_REQUIRING_LAPACK,
1496            selected_tests,
1497            "PyTorch is built without LAPACK support.",
1498        )
1499
1500    if TEST_WITH_SLOW_GRADCHECK:
1501        selected_tests = exclude_tests(
1502            TESTS_NOT_USING_GRADCHECK,
1503            selected_tests,
1504            "Running in slow gradcheck mode, skipping tests "
1505            "that don't use gradcheck.",
1506            exact_match=True,
1507        )
1508
1509    selected_tests = [parse_test_module(x) for x in selected_tests]
1510    return selected_tests
1511
1512
1513def load_test_times_from_file(file: str) -> Dict[str, Any]:
1514    # Load previous test times to make sharding decisions
1515    path = os.path.join(str(REPO_ROOT), file)
1516    if not os.path.exists(path):
1517        print_to_stderr(
1518            f"::warning:: Failed to find test times file `{path}`. Using round-robin sharding."
1519        )
1520        return {}
1521
1522    with open(path) as f:
1523        test_times_file = cast(Dict[str, Any], json.load(f))
1524    build_environment = os.environ.get("BUILD_ENVIRONMENT")
1525    test_config = os.environ.get("TEST_CONFIG")
1526    if test_config in test_times_file.get(build_environment, {}):
1527        print_to_stderr("Found test times from artifacts")
1528        return test_times_file[build_environment][test_config]
1529    elif test_config in test_times_file["default"]:
1530        print_to_stderr(
1531            f"::warning:: Gathered no stats from artifacts for {build_environment} build env"
1532            f" and {test_config} test config. Using default build env and {test_config} test config instead."
1533        )
1534        return test_times_file["default"][test_config]
1535    else:
1536        print_to_stderr(
1537            f"::warning:: Gathered no stats from artifacts for {build_environment} build env"
1538            f" and {test_config} test config. Using default build env and default test config instead."
1539        )
1540        return test_times_file["default"]["default"]
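

# Hypothetical shape of the test-times JSON consumed above: build environment
# -> test config -> per-file times in seconds (load_test_class_times expects
# one more level of nesting, test file -> test class -> seconds). Lookups fall
# back from (BUILD_ENVIRONMENT, TEST_CONFIG) to ("default", TEST_CONFIG) and
# finally to ("default", "default"). The names and numbers below are made up.
def _example_test_times_shape() -> Dict[str, Any]:
    return {
        "some-build-env": {"default": {"test_nn": 123.4, "test_ops": 4567.8}},
        "default": {"default": {"test_nn": 120.0, "test_ops": 4500.0}},
    }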
1541
1542
1543def load_test_file_times(
1544    file: str = ADDITIONAL_CI_FILES_FOLDER / TEST_TIMES_FILE,
1545) -> Dict[str, float]:
1546    return cast(Dict[str, float], load_test_times_from_file(file))
1547
1548
1549def load_test_class_times(
1550    file: str = ADDITIONAL_CI_FILES_FOLDER / TEST_CLASS_TIMES_FILE,
1551) -> Dict[str, Dict[str, float]]:
1552    return cast(Dict[str, Dict[str, float]], load_test_times_from_file(file))
1553
1554
1555def get_sharding_opts(options) -> Tuple[int, int]:
1556    which_shard, num_shards = 1, 1
1557    if options.shard:
1558        assert len(options.shard) == 2, "Unexpected shard format"
1559        assert min(options.shard) > 0, "Shards must be positive numbers"
1560        which_shard, num_shards = options.shard
1561        assert (
1562            which_shard <= num_shards
1563        ), "Selected shard must be less than or equal to total number of shards"
1564
1565    return (which_shard, num_shards)
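

# Minimal sketch of get_sharding_opts using a stand-in options object:
# options.shard is interpreted as the 1-indexed pair (which_shard, num_shards),
# and its absence means everything runs as a single shard.
def _example_get_sharding_opts() -> None:
    assert get_sharding_opts(argparse.Namespace(shard=[2, 3])) == (2, 3)
    assert get_sharding_opts(argparse.Namespace(shard=None)) == (1, 1)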
1566
1567
1568def do_sharding(
1569    options,
1570    selected_tests: Sequence[TestRun],
1571    test_file_times: Dict[str, float],
1572    test_class_times: Dict[str, Dict[str, float]],
1573    sort_by_time: bool = True,
1574) -> Tuple[float, List[ShardedTest]]:
1575    which_shard, num_shards = get_sharding_opts(options)
1576
1577    # Do sharding
1578    shards = calculate_shards(
1579        num_shards,
1580        selected_tests,
1581        test_file_times,
1582        test_class_times=test_class_times,
1583        must_serial=must_serial,
1584        sort_by_time=sort_by_time,
1585    )
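    # calculate_shards produces one (estimated time, sharded tests) entry per
    # shard; keep only the entry for the shard this job is responsible for.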
1586    return shards[which_shard - 1]
1587
1588
1589class TestFailure(NamedTuple):
1590    test: TestRun
1591    message: str
1592
1593
1594def run_test_module(
1595    test: ShardedTest, test_directory: str, options
1596) -> Optional[TestFailure]:
1597    try:
1598        maybe_set_hip_visible_devies()
1599
1600        test_name = test.name
1601
1602        # Printing the date here can help diagnose which tests are slow
1603        print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]")
1604        handler = CUSTOM_HANDLERS.get(test_name, run_test)
1605        return_code = handler(test, test_directory, options)
1606        assert isinstance(return_code, int) and not isinstance(
1607            return_code, bool
1608        ), f"While running {str(test)} got non-integer return code {return_code}"
1609        if return_code == 0:
1610            return None
1611
1612        message = f"{str(test)} failed!"
1613        if return_code < 0:
1614            # subprocess.Popen returns the child process' exit signal as
1615            # return code -N, where N is the signal number.
1616            signal_name = SIGNALS_TO_NAMES_DICT[-return_code]
1617            message += f" Received signal: {signal_name}"
1618        return TestFailure(test.test, message)
1619    except Exception as e:
1620        return TestFailure(test.test, f"{str(test)} failed! {e}")
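

# Sketch of the signal decoding used in run_test_module, assuming
# SIGNALS_TO_NAMES_DICT maps signal numbers to signal names (as the lookup
# above does): a child killed by SIGSEGV surfaces as return code -11, which is
# translated back to the string "SIGSEGV" for the failure message.
def _example_signal_name(return_code: int = -signal.SIGSEGV) -> str:
    return SIGNALS_TO_NAMES_DICT[-return_code]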
1621
1622
1623def run_tests(
1624    selected_tests: List[ShardedTest],
1625    test_directory: str,
1626    options,
1627    failures: List[TestFailure],
1628) -> None:
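    # Tests are run in three passes: serial-only files first, then the test
    # cases marked `serial` inside the parallel files, and finally the
    # remaining test cases of the parallel files on a NUM_PROCS-wide pool.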
1629    if len(selected_tests) == 0:
1630        return
1631
1632    # parallel = in parallel with other files
1633    # serial = this file on its own.  The file might still be run in parallel with itself (e.g. test_ops)
1634    selected_tests_parallel = [x for x in selected_tests if not must_serial(x)]
1635    selected_tests_serial = [
1636        x for x in selected_tests if x not in selected_tests_parallel
1637    ]
1638
1639    # See Note [ROCm parallel CI testing]
1640    pool = get_context("spawn").Pool(
1641        NUM_PROCS, maxtasksperchild=None if torch.version.hip else 1
1642    )
1643
1644    # NB: This is a hack to make conftest.py and the files it depends on available
1645    # in CPP_TESTS_DIR. We should see if the file could be turned into a
1646    # full-fledged pytest plugin instead
1647    conftest_files = [
1648        "conftest.py",
1649        "pytest_shard_custom.py",
1650    ]
1651    for conftest_file in conftest_files:
1652        cpp_file = os.path.join(CPP_TESTS_DIR, conftest_file)
1653        if (
1654            options.cpp
1655            and os.path.exists(CPP_TESTS_DIR)
1656            and os.path.isdir(CPP_TESTS_DIR)
1657            and not os.path.exists(cpp_file)
1658        ):
1659            shutil.copy(os.path.join(test_directory, conftest_file), cpp_file)
1660
1661    def handle_error_messages(failure: Optional[TestFailure]):
1662        if failure is None:
1663            return False
1664        failures.append(failure)
1665        print_to_stderr(failure.message)
1666        return True
1667
1668    def parallel_test_completion_callback(failure):
1669        test_failed = handle_error_messages(failure)
1670        if (
1671            test_failed
1672            and not options.continue_through_error
1673            and not RERUN_DISABLED_TESTS
1674        ):
1675            pool.terminate()
1676
1677    keep_going_message = (
1678        "\n\nTip: You can keep running tests even on failure by passing --keep-going to run_test.py.\n"
1679        "If running on CI, add the 'keep-going' label to your PR and rerun your jobs."
1680    )
1681
1682    try:
1683        for test in selected_tests_serial:
1684            options_clone = copy.deepcopy(options)
1685            if can_run_in_pytest(test):
1686                options_clone.pytest = True
1687            failure = run_test_module(test, test_directory, options_clone)
1688            test_failed = handle_error_messages(failure)
1689            if (
1690                test_failed
1691                and not options.continue_through_error
1692                and not RERUN_DISABLED_TESTS
1693            ):
1694                raise RuntimeError(failure.message + keep_going_message)
1695
1696        # Within the parallel test files, run the test cases marked as serial first
1697        for test in selected_tests_parallel:
1698            options_clone = copy.deepcopy(options)
1699            if can_run_in_pytest(test):
1700                options_clone.pytest = True
1701            options_clone.additional_args.extend(["-m", "serial"])
1702            failure = run_test_module(test, test_directory, options_clone)
1703            test_failed = handle_error_messages(failure)
1704            if (
1705                test_failed
1706                and not options.continue_through_error
1707                and not RERUN_DISABLED_TESTS
1708            ):
1709                raise RuntimeError(failure.message + keep_going_message)
1710
1711        os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS)
1712        for test in selected_tests_parallel:
1713            options_clone = copy.deepcopy(options)
1714            if can_run_in_pytest(test):
1715                options_clone.pytest = True
1716            options_clone.additional_args.extend(["-m", "not serial"])
1717            pool.apply_async(
1718                run_test_module,
1719                args=(test, test_directory, options_clone),
1720                callback=parallel_test_completion_callback,
1721            )
1722        pool.close()
1723        pool.join()
1724        del os.environ["NUM_PARALLEL_PROCS"]
1725
1726    finally:
1727        pool.terminate()
1728        pool.join()
1729
1730    return
1731
1732
1733def check_pip_packages() -> None:
1734    packages = [
1735        "pytest-rerunfailures",
1736        "pytest-flakefinder",
1737        "pytest-xdist",
1738    ]
1739    installed_packages = [i.key for i in pkg_resources.working_set]
1740    for package in packages:
1741        if package not in installed_packages:
1742            print_to_stderr(
1743                f"Missing pip dependency: {package}; please run `pip install -r .ci/docker/requirements-ci.txt`"
1744            )
1745            sys.exit(1)
1746
1747
1748def main():
1749    check_pip_packages()
1750
1751    options = parse_args()
1752
1753    # Include sharding info in all metrics
1754    which_shard, num_shards = get_sharding_opts(options)
1755    add_global_metric("shard", which_shard)
1756    add_global_metric("num_shards", num_shards)
1757
1758    test_directory = str(REPO_ROOT / "test")
1759    selected_tests = get_selected_tests(options)
1760
1761    test_prioritizations = import_results()
1762    test_prioritizations.amend_tests(selected_tests)
1763
1764    os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
1765
1766    if options.coverage and not PYTORCH_COLLECT_COVERAGE:
1767        shell(["coverage", "erase"])
1768
1769    if IS_CI:
1770        # Download the test case configuration to the local environment
1771        get_test_case_configs(dirpath=test_directory)
1772
1773    test_file_times_dict = load_test_file_times()
1774    test_class_times_dict = load_test_class_times()
1775
1776    class TestBatch:
1777        """Defines a set of tests with similar priority that should be run together on the current shard"""
1778
1779        name: str
1780        sharded_tests: List[ShardedTest]
1781        failures: List[TestFailure]
1782
1783        def __init__(
1784            self, name: str, raw_tests: Sequence[TestRun], should_sort_shard: bool
1785        ):
1786            self.name = name
1787            self.failures = []
1788            self.time, self.sharded_tests = do_sharding(
1789                options,
1790                raw_tests,
1791                test_file_times_dict,
1792                test_class_times_dict,
1793                sort_by_time=should_sort_shard,
1794            )
1795
1796        def __str__(self):
1797            s = f"Name: {self.name} (est. time: {round(self.time / 60, 2)}min)\n"
1798            serial = [test for test in self.sharded_tests if must_serial(test)]
1799            parallel = [test for test in self.sharded_tests if not must_serial(test)]
1800            s += f"  Serial tests ({len(serial)}):\n"
1801            s += "".join(f"    {test}\n" for test in serial)
1802            s += f"  Parallel tests ({len(parallel)}):\n"
1803            s += "".join(f"    {test}\n" for test in parallel)
1804            return s.strip()
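
    # The batch summary above renders roughly like this (names and times are
    # illustrative):
    #
    #   Name: tests to run (est. time: 42.0min)
    #     Serial tests (1):
    #       test_cpp_extensions_jit
    #     Parallel tests (2):
    #       test_nn
    #       test_ops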
1805
1806    percent_to_run = 25 if options.enable_td else 100
1807    print_to_stderr(
1808        f"Running {percent_to_run}% of tests based on TD"
1809        if options.enable_td
1810        else "Running all tests"
1811    )
1812    include, exclude = test_prioritizations.get_top_per_tests(percent_to_run)
1813
1814    test_batch = TestBatch("tests to run", include, False)
1815    test_batch_exclude = TestBatch("excluded", exclude, True)
1816    if IS_CI:
1817        gen_ci_artifact([x.to_json() for x in include], [x.to_json() for x in exclude])
1818
1819    print_to_stderr(f"Running parallel tests on {NUM_PROCS} processes")
1820    print_to_stderr(test_batch)
1821    print_to_stderr(test_batch_exclude)
1822
1823    if options.dry_run:
1824        return
1825
1826    if options.dynamo:
1827        os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1"
1828
1829    elif options.inductor:
1830        os.environ["PYTORCH_TEST_WITH_INDUCTOR"] = "1"
1831
1832    if not options.no_translation_validation:
1833        os.environ["PYTORCH_TEST_WITH_TV"] = "1"
1834
1835    try:
1836        # Actually run the tests
1837        start_time = time.time()
1838        run_tests(
1839            test_batch.sharded_tests, test_directory, options, test_batch.failures
1840        )
1841        elapsed_time = time.time() - start_time
1842        print_to_stderr(
1843            f"Running test batch '{test_batch.name}' cost {round(elapsed_time, 2)} seconds"
1844        )
1845
1846    finally:
1847        if options.coverage:
1848            from coverage import Coverage
1849
1850            with set_cwd(test_directory):
1851                cov = Coverage()
1852                if PYTORCH_COLLECT_COVERAGE:
1853                    cov.load()
1854                cov.combine(strict=False)
1855                cov.save()
1856                if not PYTORCH_COLLECT_COVERAGE:
1857                    cov.html_report()
1858
1859        all_failures = test_batch.failures
1860
1861        if IS_CI:
1862            for test, _ in all_failures:
1863                test_stats = test_prioritizations.get_test_stats(test)
1864                print_to_stderr("Emitting td_test_failure_stats_v2")
1865                emit_metric(
1866                    "td_test_failure_stats_v2",
1867                    {
1868                        "selected_tests": selected_tests,
1869                        "failure": str(test),
1870                        **test_stats,
1871                    },
1872                )
1873            gen_additional_test_failures_file(
1874                [test.test_file for test, _ in all_failures]
1875            )
1876
1877    if len(all_failures):
1878        for _, err in all_failures:
1879            print_to_stderr(err)
1880
1881        # A disabled test is expected to fail, so there is no need to report a failure here
1882        if not RERUN_DISABLED_TESTS:
1883            sys.exit(1)
1884
1885
1886if __name__ == "__main__":
1887    main()
1888