xref: /aosp_15_r20/external/pytorch/benchmarks/fastrnns/conftest.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport pytest  # noqa: F401
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker
# RNN network names included in the benchmark parametrization by default.
default_rnns = [
    "cudnn", "aten",
    "jit", "jit_premul", "jit_premul_bias", "jit_simple", "jit_multilayer",
    "py",
]

# CNN network names included in the benchmark parametrization by default.
default_cnns = [
    "resnet18",
    "resnet18_jit",
    "resnet50",
    "resnet50_jit",
]

# Every network name, RNNs first and then CNNs; this is the list fed to the
# "net_name" parameter in pytest_generate_tests.
all_nets = [*default_rnns, *default_cnns]
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Worker
def pytest_generate_tests(metafunc):
    """Parametrize the tests of the ``TestBenchNetwork`` class.

    For test functions that belong to a class named ``TestBenchNetwork``
    this adds three class-scoped parameters:

    * ``net_name`` -- one value per entry of ``all_nets``;
    * ``executor`` -- the single value of the ``--executor`` option;
    * ``fuser`` -- the single value of the ``--fuser`` option.

    Any other test function is left untouched.

    Args:
        metafunc: the pytest ``Metafunc`` object for the test function
            currently being collected.
    """
    test_cls = metafunc.cls
    # ``metafunc.cls`` is None for test functions defined outside a class;
    # bail out early instead of failing on ``None.__name__``.
    if test_cls is None or test_cls.__name__ != "TestBenchNetwork":
        return

    # This creates lists of tests to generate, can be customized
    metafunc.parametrize("net_name", all_nets, scope="class")
    # Each command-line option becomes a one-value parameter of the same name.
    for option_name in ("executor", "fuser"):
        metafunc.parametrize(
            option_name, [metafunc.config.getoption(option_name)], scope="class"
        )
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker
def pytest_addoption(parser):
    """Register the benchmark command-line options on ``parser``.

    Adds ``--fuser`` (default ``"old"``) and ``--executor`` (default
    ``"legacy"``), in that order.

    Args:
        parser: the pytest option parser handed to this hook.
    """
    # (flag, default value, help text) for each option, in registration order.
    option_table = (
        ("--fuser", "old", "fuser to use for benchmarks"),
        ("--executor", "legacy", "executor to use for benchmarks"),
    )
    for flag, fallback, description in option_table:
        parser.addoption(flag, default=fallback, help=description)
35