# mypy: allow-untyped-defs
import torch


class MkldnnLinear(torch.jit.ScriptModule):
    """Scripted replacement for ``torch.nn.Linear`` backed by MKLDNN tensors.

    The weight is stored as an MKLDNN tensor of the requested ``dtype``; the
    bias is always stored through ``to_mkldnn()`` with no dtype argument.
    """

    def __init__(self, dense_module, dtype):
        """Copy ``dense_module``'s parameters into MKLDNN buffers.

        Args:
            dense_module: module exposing ``weight`` and ``bias`` (``bias``
                may be ``None``).
            dtype: dtype passed to ``weight.to_mkldnn``.
        """
        super().__init__()
        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))
        if dense_module.bias is not None:
            # Bias can be fp32 or bf16 for OneDNN bf16 path, but for good accuracy,
            # we use fp32 dtype.
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        # Serialize as dense tensors: (weight, bias, training).
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        # Inverse of __getstate__: convert the dense tensors back to MKLDNN.
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.training = state[2]

    @torch.jit.script_method
    def forward(self, x):
        # Accept dense or MKLDNN input and return the same kind of tensor.
        x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()
        y_mkldnn = torch._C._nn.mkldnn_linear(x_mkldnn, self.weight, self.bias)
        y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense()
        return y


class _MkldnnConvNd(torch.jit.ScriptModule):
    """Common base of MkldnnConv1d and MkldnnConv2d."""

    __constants__ = ['stride', 'padding', 'dilation', 'groups']

    def __init__(self, dense_module):
        """Copy convolution hyper-parameters and the bias from ``dense_module``.

        The ``weight`` buffer is NOT registered here; each subclass registers
        it in its own ``__init__``.
        """
        super().__init__()

        self.stride = dense_module.stride
        self.padding = dense_module.padding
        self.dilation = dense_module.dilation
        self.groups = dense_module.groups

        # Bias can be fp32 or bf16 for OneDNN bf16 path, but for good accuracy,
        # we use fp32 dtype.
        if dense_module.bias is not None:
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        # Serialize as dense tensors: (weight, bias, training).
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def forward(self, x):
        # Unlike MkldnnLinear.forward, x is passed through without any
        # dense <-> MKLDNN conversion.
        return torch.mkldnn_convolution(
            x,
            self.weight,
            self.bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups)


class MkldnnConv1d(_MkldnnConvNd):
    """MKLDNN-backed replacement for ``torch.nn.Conv1d``."""

    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)

        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))

    @torch.jit.script_method
    def __setstate__(self, state):
        # state layout comes from _MkldnnConvNd.__getstate__.
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.training = state[2]


class MkldnnConv2d(_MkldnnConvNd):
    """MKLDNN-backed replacement for ``torch.nn.Conv2d``."""

    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)

        # The weight goes through mkldnn_reorder_conv2d_weight before being
        # stored (presumably a layout pre-pack for the conv primitive --
        # confirm against the op's documentation).
        self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv2d_weight(
            dense_module.weight.to_mkldnn(dtype),
            self.padding,
            self.stride,
            self.dilation,
            self.groups))

    @torch.jit.script_method
    def __setstate__(self, state):
        # state layout comes from _MkldnnConvNd.__getstate__; the weight is
        # reordered again exactly as in __init__.
        self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()
        self.training = state[2]


class MkldnnConv3d(_MkldnnConvNd):
    """MKLDNN-backed replacement for ``torch.nn.Conv3d``."""

    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)

        # Same scheme as MkldnnConv2d, using the conv3d reorder op.
        self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv3d_weight(
            dense_module.weight.to_mkldnn(dtype),
            self.padding,
            self.stride,
            self.dilation,
            self.groups))

    @torch.jit.script_method
    def __setstate__(self, state):
        # state layout comes from _MkldnnConvNd.__getstate__; the weight is
        # reordered again exactly as in __init__.
        self.weight = torch._C._nn.mkldnn_reorder_conv3d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()
        self.training = state[2]


class MkldnnBatchNorm(torch.jit.ScriptModule):
    """Inference-only batch norm operating on MKLDNN buffers.

    Construction asserts that the source module is in eval mode, tracks
    running statistics and is affine.
    """

    __constants__ = ['exponential_average_factor', 'eps']

    def __init__(self, dense_module):
        super().__init__()

        assert not dense_module.training
        assert dense_module.track_running_stats
        assert dense_module.affine

        if dense_module.momentum is None:
            self.exponential_average_factor = 0.0
        else:
            self.exponential_average_factor = dense_module.momentum
        self.eps = dense_module.eps

        self.register_buffer('weight', dense_module.weight.to_mkldnn())
        self.register_buffer('bias', dense_module.bias.to_mkldnn())
        self.register_buffer('running_mean', dense_module.running_mean.to_mkldnn())
        self.register_buffer('running_var', dense_module.running_var.to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        # Serialize as dense tensors:
        # (weight, bias, running_mean, running_var, training).
        weight = self.weight.to_dense()
        bias = self.bias.to_dense()
        running_mean = self.running_mean.to_dense()
        running_var = self.running_var.to_dense()
        return (weight, bias, running_mean, running_var, self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        # Inverse of __getstate__: convert the dense tensors back to MKLDNN.
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.running_mean = state[2].to_mkldnn()
        self.running_var = state[3].to_mkldnn()
        self.training = state[4]

    @torch.jit.script_method
    def forward(self, x):
        return torch.batch_norm(
            x,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            False,  # training
            self.exponential_average_factor,
            self.eps,
            False,  # cuda_enabled
        )


class MkldnnPrelu(torch.jit.ScriptModule):
    """Scripted replacement for ``torch.nn.PReLU`` with an MKLDNN weight."""

    def __init__(self, dense_module, dtype):
        super().__init__()
        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))

    @torch.jit.script_method
    def __getstate__(self):
        # Serialize as (dense weight, training).
        return (self.weight.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        # Inverse of __getstate__.
        self.weight = state[0].to_mkldnn()
        self.training = state[1]

    @torch.jit.script_method
    def forward(self, x):
        # Accept dense or MKLDNN input and return the same kind of tensor.
        x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()
        y_mkldnn = torch.prelu(x_mkldnn, self.weight)
        y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense()
        return y


def to_mkldnn(module, dtype=torch.float):
    """Replace supported layers of ``module`` with their MKLDNN counterparts.

    Linear, Conv1d/2d/3d, BatchNorm2d/3d and PReLU modules are swapped for the
    wrapper classes defined in this file; every other module is kept as-is.
    Modules that are kept have their children reassigned via ``setattr``, so
    container modules are modified in place.

    Args:
        module: root ``torch.nn.Module`` to convert.
        dtype: dtype for converted weights; one of ``torch.float``,
            ``torch.bfloat16`` or ``torch.half``.

    Returns:
        The converted module (``module`` itself when the root is not one of
        the supported layer types).

    Raises:
        AssertionError: if ``dtype`` is not supported.
    """
    # Raised explicitly (rather than via `assert`) so the check is not
    # stripped under `python -O`; the exception type is kept for callers.
    if dtype not in [torch.float, torch.bfloat16, torch.half]:
        raise AssertionError(
            "MKLDNN only support float, bfloat16, and half path now")

    def m_fn(m, d):
        # Map a single module to its MKLDNN wrapper, or return it unchanged.
        if isinstance(m, torch.nn.Linear):
            return MkldnnLinear(m, d)
        elif isinstance(m, torch.nn.Conv1d):
            return MkldnnConv1d(m, d)
        elif isinstance(m, torch.nn.Conv2d):
            return MkldnnConv2d(m, d)
        elif isinstance(m, torch.nn.Conv3d):
            return MkldnnConv3d(m, d)
        elif isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
            # For batchnorm bf16 path, OneDNN requires weight and bias need fp32 dtype.
            # so it doesn't need dtype argument.
            return MkldnnBatchNorm(m)
        elif isinstance(m, torch.nn.PReLU):
            return MkldnnPrelu(m, d)
        else:
            return m

    def m_fn_rec(m, d):
        # Convert m, then recurse over the ORIGINAL module's children and
        # attach the converted children to the (possibly new) module.
        new_m = m_fn(m, d)
        for name, sub_m in m.named_children():
            setattr(new_m, name, m_fn_rec(sub_m, d))
        return new_m

    return m_fn_rec(module, dtype)