# mypy: allow-untyped-defs
import torch


class MkldnnLinear(torch.jit.ScriptModule):
    def __init__(self, dense_module, dtype):
        super().__init__()
        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))
        if dense_module.bias is not None:
            # For the OneDNN bf16 path the bias may be fp32 or bf16; we keep it
            # in fp32 for better accuracy.
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.training = state[2]

    @torch.jit.script_method
    def forward(self, x):
        x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()
        y_mkldnn = torch._C._nn.mkldnn_linear(x_mkldnn, self.weight, self.bias)
        y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense()
        return y


class _MkldnnConvNd(torch.jit.ScriptModule):
    """Common base of MkldnnConv1d and MkldnnConv2d."""

    __constants__ = ['stride', 'padding', 'dilation', 'groups']

    def __init__(self, dense_module):
        super().__init__()

        self.stride = dense_module.stride
        self.padding = dense_module.padding
        self.dilation = dense_module.dilation
        self.groups = dense_module.groups

        if dense_module.bias is not None:
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # For the OneDNN bf16 path the bias may be fp32 or bf16; we keep it
            # in fp32 for better accuracy.
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def forward(self, x):
        return torch.mkldnn_convolution(
            x,
            self.weight,
            self.bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups)


class MkldnnConv1d(_MkldnnConvNd):
    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)

        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.training = state[2]


class MkldnnConv2d(_MkldnnConvNd):
    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)

        self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv2d_weight(
            dense_module.weight.to_mkldnn(dtype),
            self.padding,
            self.stride,
            self.dilation,
            self.groups))

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()
        self.training = state[2]


class MkldnnConv3d(_MkldnnConvNd):
    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)

        self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv3d_weight(
            dense_module.weight.to_mkldnn(dtype),
            self.padding,
            self.stride,
            self.dilation,
            self.groups))

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = torch._C._nn.mkldnn_reorder_conv3d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()
        self.training = state[2]


class MkldnnBatchNorm(torch.jit.ScriptModule):
    __constants__ = ['exponential_average_factor', 'eps']

    def __init__(self, dense_module):
        super().__init__()

        assert not dense_module.training
        assert dense_module.track_running_stats
        assert dense_module.affine

        if dense_module.momentum is None:
            self.exponential_average_factor = 0.0
        else:
            self.exponential_average_factor = dense_module.momentum
        self.eps = dense_module.eps

        self.register_buffer('weight', dense_module.weight.to_mkldnn())
        self.register_buffer('bias', dense_module.bias.to_mkldnn())
        self.register_buffer('running_mean', dense_module.running_mean.to_mkldnn())
        self.register_buffer('running_var', dense_module.running_var.to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        weight = self.weight.to_dense()
        bias = self.bias.to_dense()
        running_mean = self.running_mean.to_dense()
        running_var = self.running_var.to_dense()
        return (weight, bias, running_mean, running_var, self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.running_mean = state[2].to_mkldnn()
        self.running_var = state[3].to_mkldnn()
        self.training = state[4]

    @torch.jit.script_method
    def forward(self, x):
        return torch.batch_norm(
            x,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            False,  # training
            self.exponential_average_factor,
            self.eps,
            False,  # cuda_enabled
        )


class MkldnnPrelu(torch.jit.ScriptModule):
    def __init__(self, dense_module, dtype):
        super().__init__()
        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.training = state[1]

    @torch.jit.script_method
    def forward(self, x):
        x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()
        y_mkldnn = torch.prelu(x_mkldnn, self.weight)
        y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense()
        return y


def to_mkldnn(module, dtype=torch.float):
    assert dtype in [torch.float, torch.bfloat16, torch.half], \
        "MKLDNN only supports float, bfloat16, and half now"

    def m_fn(m, d):
        if isinstance(m, torch.nn.Linear):
            return MkldnnLinear(m, d)
        elif isinstance(m, torch.nn.Conv1d):
            return MkldnnConv1d(m, d)
        elif isinstance(m, torch.nn.Conv2d):
            return MkldnnConv2d(m, d)
        elif isinstance(m, torch.nn.Conv3d):
            return MkldnnConv3d(m, d)
        elif isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
            # For the batchnorm bf16 path, OneDNN requires fp32 weight and bias,
            # so no dtype argument is needed here.
            return MkldnnBatchNorm(m)
        elif isinstance(m, torch.nn.PReLU):
            return MkldnnPrelu(m, d)
        else:
            return m

    def m_fn_rec(m, d):
        new_m = m_fn(m, d)
        for name, sub_m in m.named_children():
            setattr(new_m, name, m_fn_rec(sub_m, d))
        return new_m

    return m_fn_rec(module, dtype)
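

# Example usage (a minimal sketch, not exercised by this module itself): convert
# an eval-mode float model and run inference through the MKL-DNN blocked layout.
# Assumes an MKL-DNN-enabled build (torch.backends.mkldnn.is_available()) and
# that the input is converted with Tensor.to_mkldnn() before the forward pass,
# since the converted conv modules expect MKL-DNN-layout tensors:
#
#     model = torch.nn.Conv2d(3, 8, kernel_size=3).eval()
#     mkldnn_model = to_mkldnn(model)
#     with torch.no_grad():
#         y = mkldnn_model(torch.randn(1, 3, 32, 32).to_mkldnn()).to_dense()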