from collections import namedtuple
from functools import partial

import torchvision.models as cnn

import torch

from .factory import (
    dropoutlstm_creator,
    imagenet_cnn_creator,
    layernorm_pytorch_lstm_creator,
    lnlstm_creator,
    lstm_creator,
    lstm_multilayer_creator,
    lstm_premul_bias_creator,
    lstm_premul_creator,
    lstm_simple_creator,
    pytorch_lstm_creator,
    varlen_lstm_creator,
    varlen_pytorch_lstm_creator,
)


class DisableCuDNN:
    """Context manager that switches cuDNN off for the body of a ``with`` block.

    On entry the current value of ``torch.backends.cudnn.enabled`` is remembered
    and the flag is set to ``False``. On exit the remembered value is written
    back. ``__exit__`` returns ``None``, so any exception raised in the body
    still propagates after the flag has been restored.
    """

    def __enter__(self):
        cudnn_backend = torch.backends.cudnn
        self.saved = cudnn_backend.enabled
        cudnn_backend.enabled = False

    def __exit__(self, *args, **kwargs):
        # Put back whatever the flag was before __enter__ ran.
        torch.backends.cudnn.enabled = self.saved
class DummyContext:
    """No-op context manager for runners that need no special environment."""

    def __enter__(self):
        pass

    def __exit__(self, *args, **kwargs):
        pass


class AssertNoJIT:
    """Context manager that refuses to run unless TorchScript is disabled.

    TorchScript is considered disabled only when the ``PYTORCH_JIT``
    environment variable is set to one of the "off" spellings in
    ``_DISABLED_VALUES`` (compared case-insensitively). An unset variable
    counts as enabled.

    Raises:
        AssertionError: from ``__enter__`` when ``PYTORCH_JIT`` is unset or not
            an "off" value.
    """

    # Environment variables are always strings, so the value must be compared
    # as text: a plain truthiness test treats "0" as true and would reject the
    # very setting that disables the JIT.
    # NOTE(review): intended to mirror the "off" spellings torch.jit accepts
    # for PYTORCH_JIT -- confirm against torch/jit/_state.py.
    _DISABLED_VALUES = frozenset({"0", "false", "no", "0v"})

    def __enter__(self):
        import os

        value = os.environ.get("PYTORCH_JIT")
        # Unset means the JIT is on (the original code defaulted to 1 here).
        enabled = value is None or value.lower() not in self._DISABLED_VALUES
        if enabled:
            # Explicit raise rather than ``assert`` so the check is not
            # stripped when Python runs with -O.
            raise AssertionError(
                f"AssertNoJIT requires PYTORCH_JIT=0, got PYTORCH_JIT={value!r}"
            )

    def __exit__(self, *args, **kwargs):
        pass


# One benchmark configuration:
#   name    -- the key this runner is registered under in ``nn_runners``
#   creator -- the factory stored for this runner (a function, a
#              functools.partial, or whatever imagenet_cnn_creator returns)
#   context -- a context-manager class, e.g. DummyContext or DisableCuDNN.
#              NOTE(review): presumably instantiated by the caller around the
#              benchmark -- confirm at the call site.
RNNRunner = namedtuple(
    "RNNRunner",
    [
        "name",
        "creator",
        "context",
    ],
)


def get_nn_runners(*names):
    """Return the registered ``RNNRunner`` for each of *names*, in order.

    Raises:
        KeyError: if any name is not a key of ``nn_runners``. The message lists
            every unknown name together with the valid choices.
    """
    unknown = [name for name in names if name not in nn_runners]
    if unknown:
        raise KeyError(
            f"unknown runner name(s) {unknown}; available: {sorted(nn_runners)}"
        )
    return [nn_runners[name] for name in names]


# Registry of benchmark runners, keyed by name. Each key equals the RNNRunner's
# own ``name`` field. "aten" reuses the "cudnn" creator but runs it under
# DisableCuDNN. The resnet entries call imagenet_cnn_creator(...) here, at
# import time, and store its return value as the creator.
nn_runners = {
    "cudnn": RNNRunner("cudnn", pytorch_lstm_creator, DummyContext),
    "cudnn_dropout": RNNRunner(
        "cudnn_dropout", partial(pytorch_lstm_creator, dropout=0.4), DummyContext
    ),
    "cudnn_layernorm": RNNRunner(
        "cudnn_layernorm", layernorm_pytorch_lstm_creator, DummyContext
    ),
    "vl_cudnn": RNNRunner("vl_cudnn", varlen_pytorch_lstm_creator, DummyContext),
    "vl_jit": RNNRunner(
        "vl_jit", partial(varlen_lstm_creator, script=True), DummyContext
    ),
    "vl_py": RNNRunner("vl_py", varlen_lstm_creator, DummyContext),
    "aten": RNNRunner("aten", pytorch_lstm_creator, DisableCuDNN),
    "jit": RNNRunner("jit", lstm_creator, DummyContext),
    "jit_premul": RNNRunner("jit_premul", lstm_premul_creator, DummyContext),
    "jit_premul_bias": RNNRunner(
        "jit_premul_bias", lstm_premul_bias_creator, DummyContext
    ),
    "jit_simple": RNNRunner("jit_simple", lstm_simple_creator, DummyContext),
    "jit_multilayer": RNNRunner(
        "jit_multilayer", lstm_multilayer_creator, DummyContext
    ),
    "jit_layernorm": RNNRunner("jit_layernorm", lnlstm_creator, DummyContext),
    "jit_layernorm_decom": RNNRunner(
        "jit_layernorm_decom",
        partial(lnlstm_creator, decompose_layernorm=True),
        DummyContext,
    ),
    "jit_dropout": RNNRunner("jit_dropout", dropoutlstm_creator, DummyContext),
    "py": RNNRunner("py", partial(lstm_creator, script=False), DummyContext),
    "resnet18": RNNRunner(
        "resnet18", imagenet_cnn_creator(cnn.resnet18, jit=False), DummyContext
    ),
    "resnet18_jit": RNNRunner(
        "resnet18_jit", imagenet_cnn_creator(cnn.resnet18), DummyContext
    ),
    "resnet50": RNNRunner(
        "resnet50", imagenet_cnn_creator(cnn.resnet50, jit=False), DummyContext
    ),
    "resnet50_jit": RNNRunner(
        "resnet50_jit", imagenet_cnn_creator(cnn.resnet50), DummyContext
    ),
}