xref: /aosp_15_r20/external/pytorch/benchmarks/fastrnns/runner.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom collections import namedtuple
2*da0073e9SAndroid Build Coastguard Workerfrom functools import partial
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torchvision.models as cnn
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerfrom .factory import (
9*da0073e9SAndroid Build Coastguard Worker    dropoutlstm_creator,
10*da0073e9SAndroid Build Coastguard Worker    imagenet_cnn_creator,
11*da0073e9SAndroid Build Coastguard Worker    layernorm_pytorch_lstm_creator,
12*da0073e9SAndroid Build Coastguard Worker    lnlstm_creator,
13*da0073e9SAndroid Build Coastguard Worker    lstm_creator,
14*da0073e9SAndroid Build Coastguard Worker    lstm_multilayer_creator,
15*da0073e9SAndroid Build Coastguard Worker    lstm_premul_bias_creator,
16*da0073e9SAndroid Build Coastguard Worker    lstm_premul_creator,
17*da0073e9SAndroid Build Coastguard Worker    lstm_simple_creator,
18*da0073e9SAndroid Build Coastguard Worker    pytorch_lstm_creator,
19*da0073e9SAndroid Build Coastguard Worker    varlen_lstm_creator,
20*da0073e9SAndroid Build Coastguard Worker    varlen_pytorch_lstm_creator,
21*da0073e9SAndroid Build Coastguard Worker)
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
class DisableCuDNN:
    """Context manager that switches cuDNN off for the body of a ``with`` block.

    On entry the current ``torch.backends.cudnn.enabled`` flag is remembered in
    ``self.saved`` and the flag is set to ``False``. On exit the remembered
    value is written back. ``__exit__`` returns ``None``, so an exception raised
    inside the block still propagates after the flag has been restored.
    """

    def __enter__(self):
        cudnn = torch.backends.cudnn
        # Read the old value and clear the flag in one tuple assignment
        # (the right-hand side is fully evaluated before either store).
        self.saved, cudnn.enabled = cudnn.enabled, False

    def __exit__(self, *args, **kwargs):
        cudnn = torch.backends.cudnn
        cudnn.enabled = self.saved
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker
class DummyContext:
    """No-op context manager for runners that need no special setup.

    Entering produces ``None`` and exiting does nothing; because ``__exit__``
    returns ``None``, exceptions raised in the ``with`` body propagate as-is.
    """

    def __enter__(self):
        return None

    def __exit__(self, *args, **kwargs):
        return None
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker
class AssertNoJIT:
    """Context manager that fails fast unless TorchScript is disabled.

    Entering inspects the ``PYTORCH_JIT`` environment variable and raises
    ``AssertionError`` if the JIT is considered enabled. An unset variable
    counts as enabled. Exiting does nothing.
    """

    # Spellings of PYTORCH_JIT that mean "disabled", compared case-insensitively.
    # NOTE(review): chosen to mirror the false-y values torch.jit's own env
    # parsing accepts ("0", "false", "no", "0v") -- confirm against the torch
    # version in use.
    _DISABLED_VALUES = frozenset({"0", "false", "no", "0v"})

    def __enter__(self):
        import os

        # Environment values are strings, so "0" is truthy; the value must be
        # parsed rather than tested for truthiness.
        value = os.environ.get("PYTORCH_JIT")
        enabled = value is None or value.lower() not in self._DISABLED_VALUES
        if enabled:
            # Explicit raise instead of ``assert`` so the check survives ``-O``.
            raise AssertionError(
                f"JIT must be disabled (set PYTORCH_JIT=0); got PYTORCH_JIT={value!r}"
            )

    def __exit__(self, *args, **kwargs):
        pass
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker
# One benchmark configuration, stored as an immutable record:
#   name    -- string identifier of the runner
#   creator -- callable that builds the benchmark (a factory function, possibly
#              wrapped in functools.partial to fix keyword arguments)
#   context -- a context-manager *class* (e.g. DummyContext, DisableCuDNN)
RNNRunner = namedtuple("RNNRunner", "name creator context")
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Worker
def get_nn_runners(*names):
    """Return the registered runners for ``names`` as a list, in argument order.

    Each name is looked up in the module-level ``nn_runners`` mapping; an
    unknown name raises ``KeyError``.
    """
    lookup = nn_runners.__getitem__
    return list(map(lookup, names))
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker
# Registry of benchmark runners keyed by runner name. Each row below is a
# (name, creator, context) triple; the comprehension wraps it in an RNNRunner
# whose ``name`` field is therefore always identical to its dictionary key.
# Rows keep their listed order, which is the dict's insertion order.
# NOTE: the imagenet_cnn_creator(...) calls in the last rows execute once,
# when this module is imported.
nn_runners = {
    runner_name: RNNRunner(runner_name, runner_creator, runner_context)
    for runner_name, runner_creator, runner_context in (
        # pytorch_lstm_creator / layernorm_pytorch_lstm_creator variants.
        ("cudnn", pytorch_lstm_creator, DummyContext),
        ("cudnn_dropout", partial(pytorch_lstm_creator, dropout=0.4), DummyContext),
        ("cudnn_layernorm", layernorm_pytorch_lstm_creator, DummyContext),
        # varlen_* creator variants.
        ("vl_cudnn", varlen_pytorch_lstm_creator, DummyContext),
        ("vl_jit", partial(varlen_lstm_creator, script=True), DummyContext),
        ("vl_py", varlen_lstm_creator, DummyContext),
        # Same creator as "cudnn", but run with cuDNN switched off.
        ("aten", pytorch_lstm_creator, DisableCuDNN),
        # Custom LSTM creators from .factory.
        ("jit", lstm_creator, DummyContext),
        ("jit_premul", lstm_premul_creator, DummyContext),
        ("jit_premul_bias", lstm_premul_bias_creator, DummyContext),
        ("jit_simple", lstm_simple_creator, DummyContext),
        ("jit_multilayer", lstm_multilayer_creator, DummyContext),
        ("jit_layernorm", lnlstm_creator, DummyContext),
        (
            "jit_layernorm_decom",
            partial(lnlstm_creator, decompose_layernorm=True),
            DummyContext,
        ),
        ("jit_dropout", dropoutlstm_creator, DummyContext),
        ("py", partial(lstm_creator, script=False), DummyContext),
        # torchvision CNNs; the *_jit rows use imagenet_cnn_creator's default jit.
        ("resnet18", imagenet_cnn_creator(cnn.resnet18, jit=False), DummyContext),
        ("resnet18_jit", imagenet_cnn_creator(cnn.resnet18), DummyContext),
        ("resnet50", imagenet_cnn_creator(cnn.resnet50, jit=False), DummyContext),
        ("resnet50_jit", imagenet_cnn_creator(cnn.resnet50), DummyContext),
    )
}
110