# mypy: allow-untyped-defs
import torch


class MkldnnLinear(torch.jit.ScriptModule):
    def __init__(self, dense_module, dtype):
        super().__init__()
        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))
        if dense_module.bias is not None:
            # For the OneDNN bf16 path the bias may be fp32 or bf16; we keep it
            # fp32 for better accuracy.
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

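    # MKL-DNN tensors are opaque and cannot be pickled directly, so
    # serialization below round-trips through dense tensors.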
    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.training = state[2]

    @torch.jit.script_method
    def forward(self, x):
        x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()
        y_mkldnn = torch._C._nn.mkldnn_linear(x_mkldnn, self.weight, self.bias)
        y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense()
        return y

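# A minimal usage sketch (hypothetical shapes; dense input is converted to
# MKL-DNN layout inside forward, so callers see dense in, dense out):
#
#   linear = torch.nn.Linear(128, 64)
#   mkldnn_linear = MkldnnLinear(linear, torch.float)
#   y = mkldnn_linear(torch.randn(32, 128))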

class _MkldnnConvNd(torch.jit.ScriptModule):
    """Common base of MkldnnConv1d, MkldnnConv2d and MkldnnConv3d."""

    __constants__ = ['stride', 'padding', 'dilation', 'groups']

    def __init__(self, dense_module):
        super().__init__()

        self.stride = dense_module.stride
        self.padding = dense_module.padding
        self.dilation = dense_module.dilation
        self.groups = dense_module.groups

        if dense_module.bias is not None:
            self.register_buffer('bias', dense_module.bias.to_mkldnn())
        else:
            # For the OneDNN bf16 path the bias may be fp32 or bf16; we keep it
            # fp32 for better accuracy.
            # TODO: Remove this once ScriptModule supports registering None buffer
            self.register_buffer(
                'bias',
                torch.zeros([dense_module.weight.size(0)], dtype=torch.float).to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.bias.to_dense(), self.training)

    @torch.jit.script_method
    def forward(self, x):
        return torch.mkldnn_convolution(
            x,
            self.weight,
            self.bias,
            self.padding,
            self.stride,
            self.dilation,
            self.groups)

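# Note: unlike MkldnnLinear.forward, the conv forward above does not convert a
# dense input, so callers are expected to hand it a tensor that is already in
# MKL-DNN layout. A minimal sketch (hypothetical shapes):
#
#   conv = torch.nn.Conv2d(3, 16, kernel_size=3).eval()
#   mkldnn_conv = MkldnnConv2d(conv, torch.float)
#   y = mkldnn_conv(torch.randn(1, 3, 32, 32).to_mkldnn()).to_dense()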

class MkldnnConv1d(_MkldnnConvNd):
    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)

        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.training = state[2]


class MkldnnConv2d(_MkldnnConvNd):
    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)

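        # Pre-pack the weight into the blocked layout that OneDNN prefers for
        # this conv, so the reorder cost is paid once at conversion time rather
        # than on every forward call.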
        self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv2d_weight(
            dense_module.weight.to_mkldnn(dtype),
            self.padding,
            self.stride,
            self.dilation,
            self.groups))

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()
        self.training = state[2]


class MkldnnConv3d(_MkldnnConvNd):
    def __init__(self, dense_module, dtype):
        super().__init__(dense_module)

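        # Same pre-packing idea as MkldnnConv2d, using the conv3d reorder.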
        self.register_buffer('weight', torch._C._nn.mkldnn_reorder_conv3d_weight(
            dense_module.weight.to_mkldnn(dtype),
            self.padding,
            self.stride,
            self.dilation,
            self.groups))

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = torch._C._nn.mkldnn_reorder_conv3d_weight(
            state[0].to_mkldnn(),
            self.padding,
            self.stride,
            self.dilation,
            self.groups)
        self.bias = state[1].to_mkldnn()
        self.training = state[2]


class MkldnnBatchNorm(torch.jit.ScriptModule):
    __constants__ = ['exponential_average_factor', 'eps']

    def __init__(self, dense_module):
        super().__init__()

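        # Conversion is inference-only: the module must not be in training
        # mode, and must be affine with tracked running stats, since forward()
        # always normalizes with the stored statistics.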
        assert not dense_module.training
        assert dense_module.track_running_stats
        assert dense_module.affine

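        # nn.BatchNorm* uses momentum=None to mean a cumulative moving average;
        # map that to a factor of 0.0 here. forward() always calls
        # torch.batch_norm with training=False, so the running stats are never
        # updated and this factor should be effectively inert.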
        if dense_module.momentum is None:
            self.exponential_average_factor = 0.0
        else:
            self.exponential_average_factor = dense_module.momentum
        self.eps = dense_module.eps

        self.register_buffer('weight', dense_module.weight.to_mkldnn())
        self.register_buffer('bias', dense_module.bias.to_mkldnn())
        self.register_buffer('running_mean', dense_module.running_mean.to_mkldnn())
        self.register_buffer('running_var', dense_module.running_var.to_mkldnn())

    @torch.jit.script_method
    def __getstate__(self):
        weight = self.weight.to_dense()
        bias = self.bias.to_dense()
        running_mean = self.running_mean.to_dense()
        running_var = self.running_var.to_dense()
        return (weight, bias, running_mean, running_var, self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.bias = state[1].to_mkldnn()
        self.running_mean = state[2].to_mkldnn()
        self.running_var = state[3].to_mkldnn()
        self.training = state[4]

    @torch.jit.script_method
    def forward(self, x):
        return torch.batch_norm(
            x,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            False,  # training
            self.exponential_average_factor,
            self.eps,
            False,  # cuda_enabled
        )


class MkldnnPrelu(torch.jit.ScriptModule):
    def __init__(self, dense_module, dtype):
        super().__init__()
        self.register_buffer('weight', dense_module.weight.to_mkldnn(dtype))

    @torch.jit.script_method
    def __getstate__(self):
        return (self.weight.to_dense(), self.training)

    @torch.jit.script_method
    def __setstate__(self, state):
        self.weight = state[0].to_mkldnn()
        self.training = state[1]

    @torch.jit.script_method
    def forward(self, x):
        x_mkldnn = x if x.is_mkldnn else x.to_mkldnn()
        y_mkldnn = torch.prelu(x_mkldnn, self.weight)
        y = y_mkldnn if x.is_mkldnn else y_mkldnn.to_dense()
        return y


def to_mkldnn(module, dtype=torch.float):
    assert dtype in [torch.float, torch.bfloat16, torch.half], \
        "MKLDNN only supports the float, bfloat16, and half paths for now"

    def m_fn(m, d):
        if isinstance(m, torch.nn.Linear):
            return MkldnnLinear(m, d)
        elif isinstance(m, torch.nn.Conv1d):
            return MkldnnConv1d(m, d)
        elif isinstance(m, torch.nn.Conv2d):
            return MkldnnConv2d(m, d)
        elif isinstance(m, torch.nn.Conv3d):
            return MkldnnConv3d(m, d)
        elif isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
            # For the batchnorm bf16 path, OneDNN requires fp32 weight and bias,
            # so MkldnnBatchNorm does not take a dtype argument.
            return MkldnnBatchNorm(m)
        elif isinstance(m, torch.nn.PReLU):
            return MkldnnPrelu(m, d)
        else:
            return m

    def m_fn_rec(m, d):
        new_m = m_fn(m, d)
        for name, sub_m in m.named_children():
            setattr(new_m, name, m_fn_rec(sub_m, d))
        return new_m

    return m_fn_rec(module, dtype)
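

# A minimal end-to-end sketch (hypothetical model; inference only, and the
# caller handles the MKL-DNN layout conversions at the boundary):
#
#   model = torch.nn.Sequential(
#       torch.nn.Conv2d(3, 16, kernel_size=3),
#       torch.nn.BatchNorm2d(16),
#   ).eval()
#   mkldnn_model = to_mkldnn(model)
#   y = mkldnn_model(torch.randn(1, 3, 32, 32).to_mkldnn()).to_dense()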