xref: /aosp_15_r20/external/pytorch/torch/utils/mkldnn.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport torch
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Worker
class MkldnnLinear(torch.jit.ScriptModule):
    """Linear layer whose weight and bias are stored as MKLDNN (oneDNN) tensors."""

    def __init__(self, dense_module, dtype):
        super().__init__()
        # The weight may be lowered to a reduced-precision dtype (e.g. bf16).
        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))
        if dense_module.bias is None:
            # TODO: Remove this once ScriptModule supports registering None buffer
            zero_bias = torch.zeros([dense_module.weight.size(0)], dtype=torch.float)
            self.register_buffer('bias', zero_bias.to_mkldnn())
        else:
            # Bias can be fp32 or bf16 for the oneDNN bf16 path, but for good
            # accuracy it is kept fp32 (no `dtype` conversion here).
            self.register_buffer('bias', dense_module.bias.to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        # Serialize dense copies; MKLDNN tensors are not directly picklable.
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        # Inverse of __getstate__: (weight, bias, training).
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.training = state[2]

    @torch.jit.script_method
    def forward(self, x):
        # Accept either layout and mirror the input layout in the output.
        if x.is_mkldnn:
            return torch._C._nn.mkldnn_linear(x, self.weight, self.bias)
        return torch._C._nn.mkldnn_linear(x.to_mkldnn(), self.weight, self.bias).to_dense()
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker
class _MkldnnConvNd(torch.jit.ScriptModule):
    """Common base of MkldnnConv1d and MkldnnConv2d."""

    # Declared as TorchScript constants so the compiled methods can treat
    # them as compile-time values.
    __constants__ = ['stride', 'padding', 'dilation', 'groups']

    def __init__(self, dense_module):
        super().__init__()

        # Copy the convolution hyper-parameters from the dense module.
        self.stride = dense_module.stride
        self.padding = dense_module.padding
        self.dilation = dense_module.dilation
        self.groups = dense_module.groups

        if dense_module.bias is not None:
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # Bias can be fp32 or bf16 for OneDNN bf16 path, but for good accuracy,
            # we use fp32 dtype.
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        # Serialize dense copies; the `weight` buffer is registered by the
        # concrete subclass, not by this base class.
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def forward(self, x):
        # NOTE: mkldnn_convolution takes padding *before* stride.
        return torch.mkldnn_convolution(
            x,
            self.weight,
            self.bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker
class MkldnnConv1d(_MkldnnConvNd):
    """1d convolution with an MKLDNN-backed weight; hyper-parameters and the
    bias come from the _MkldnnConvNd base class."""

    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)
        # Unlike Conv2d/Conv3d, the 1d weight needs no special reordering.
        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))

    @torch.jit.script_method
    def __setstate__(self, state):
        # Inverse of the base-class __getstate__: (weight, bias, training).
        self.training = state[2]
        self.bias = state[1].to_mkldnn()
        self.weight = state[0].to_mkldnn()
87*da0073e9SAndroid Build Coastguard Worker
88*da0073e9SAndroid Build Coastguard Worker
class MkldnnConv2d(_MkldnnConvNd):
    """2d convolution whose weight is pre-reordered for oneDNN."""

    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)

        # Reorder the weight up front for the mkldnn 2d convolution kernel.
        # NOTE: argument order is padding *before* stride.
        self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv2d_weight(
            dense_module.weight.to_mkldnn(dtype),
            self.padding,
            self.stride,
            self.dilation,
            self.groups))

    @torch.jit.script_method
    def __setstate__(self, state):
        # `state` is (weight, bias, training) from the base-class __getstate__;
        # the dense weight must be reordered again after deserialization.
        self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()
        self.training = state[2]
110*da0073e9SAndroid Build Coastguard Worker
class MkldnnConv3d(_MkldnnConvNd):
    """3d convolution whose weight is pre-reordered for oneDNN."""

    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)

        # Reorder the weight up front for the mkldnn 3d convolution kernel.
        # NOTE: argument order is padding *before* stride.
        self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv3d_weight(
            dense_module.weight.to_mkldnn(dtype),
            self.padding,
            self.stride,
            self.dilation,
            self.groups))

    @torch.jit.script_method
    def __setstate__(self, state):
        # `state` is (weight, bias, training) from the base-class __getstate__;
        # the dense weight must be reordered again after deserialization.
        self.weight = torch._C._nn.mkldnn_reorder_conv3d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()
        self.training = state[2]
132*da0073e9SAndroid Build Coastguard Worker
133*da0073e9SAndroid Build Coastguard Worker
class MkldnnBatchNorm(torch.jit.ScriptModule):
    """Inference-only batch norm whose parameters and running statistics are
    stored as MKLDNN tensors."""

    __constants__ = ['exponential_average_factor', 'eps']

    def __init__(self, dense_module):
        super().__init__()

        # Only the inference configuration with affine parameters and tracked
        # running statistics is supported.
        assert not dense_module.training
        assert dense_module.track_running_stats
        assert dense_module.affine

        momentum = dense_module.momentum
        # A `None` momentum maps to factor 0.0; forward always runs with
        # training=False, so the factor is not used to update statistics.
        self.exponential_average_factor = 0.0 if momentum is None else momentum
        self.eps = dense_module.eps

        # For the batchnorm bf16 path, oneDNN requires fp32 weight and bias,
        # so no dtype conversion is applied to any of these buffers.
        for buffer_name in ['weight', 'bias', 'running_mean', 'running_var']:
            self.register_buffer(buffer_name,
                                 getattr(dense_module, buffer_name).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        # Serialize dense copies; MKLDNN tensors are not directly picklable.
        return (
            self.weight.to_dense(),
            self.bias.to_dense(),
            self.running_mean.to_dense(),
            self.running_var.to_dense(),
            self.training,
        )

    @torch.jit.script_method
    def __setstate__(self, state):
        # Inverse of __getstate__.
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.running_mean = state[2].to_mkldnn()
        self.running_var = state[3].to_mkldnn()
        self.training = state[4]

    @torch.jit.script_method
    def forward(self, x):
        return torch.batch_norm(
            x,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            False,  # training
            self.exponential_average_factor,
            self.eps,
            False,  # cuda_enabled
        )
184*da0073e9SAndroid Build Coastguard Worker
class MkldnnPrelu(torch.jit.ScriptModule):
    """PReLU whose learnable slope is stored as an MKLDNN tensor."""

    def __init__(self, dense_module, dtype):
        super().__init__()
        # The slope may be lowered to a reduced-precision dtype (e.g. bf16).
        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))

    @torch.jit.script_method
    def __getstate__(self):
        # Serialize a dense copy; MKLDNN tensors are not directly picklable.
        return (self.weight.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        # Inverse of __getstate__: (weight, training).
        self.weight = state[0].to_mkldnn()
        self.training = state[1]

    @torch.jit.script_method
    def forward(self, x):
        # Accept either layout and mirror the input layout in the output.
        if x.is_mkldnn:
            return torch.prelu(x, self.weight)
        return torch.prelu(x.to_mkldnn(), self.weight).to_dense()
205*da0073e9SAndroid Build Coastguard Worker
def to_mkldnn(module, dtype=torch.float):
    """Recursively convert the supported submodules of `module` into their
    MKLDNN-backed ScriptModule counterparts.

    Unsupported modules are returned unchanged, but their children are still
    converted in place via setattr. `dtype` selects the lowered weight dtype
    and must be float, bfloat16, or half.
    """
    assert dtype in [torch.float, torch.bfloat16, torch.half], \
        "MKLDNN only support float, bfloat16, and half path now"

    def _convert_one(m, d):
        # Checked in the same order as the original if/elif chain.
        dispatch = (
            (torch.nn.Linear, MkldnnLinear),
            (torch.nn.Conv1d, MkldnnConv1d),
            (torch.nn.Conv2d, MkldnnConv2d),
            (torch.nn.Conv3d, MkldnnConv3d),
        )
        for dense_cls, mkldnn_cls in dispatch:
            if isinstance(m, dense_cls):
                return mkldnn_cls(m, d)
        if isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
            # For batchnorm bf16 path, OneDNN requires weight and bias in fp32,
            # so MkldnnBatchNorm takes no dtype argument.
            return MkldnnBatchNorm(m)
        if isinstance(m, torch.nn.PReLU):
            return MkldnnPrelu(m, d)
        return m

    def _convert_tree(m, d):
        converted = _convert_one(m, d)
        for child_name, child in m.named_children():
            setattr(converted, child_name, _convert_tree(child, d))
        return converted

    return _convert_tree(module, dtype)
235