#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <tuple>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_batch_norm_with_update_native.h>
#include <ATen/ops/batch_norm_backward_native.h>
#include <ATen/ops/_native_batch_norm_legit_native.h>
#include <ATen/ops/_to_dense_native.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/native_batch_norm_backward_native.h>
#include <ATen/ops/native_batch_norm_native.h>
#endif
#include <ATen/native/mkldnn/Utils.h>

#if !AT_MKLDNN_ENABLED()
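// Fallback stubs: when ATen is built without MKLDNN support, these operators
// are still registered but fail loudly if ever called at runtime.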

namespace at {
namespace native {

std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm(
    const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    bool train,
    double momentum,
    double eps) {
  TORCH_CHECK(false, "mkldnn_batch_norm: ATen not compiled with MKLDNN support");
}

std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm_backward(
    const Tensor& grad_output,
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_invstd_opt,
    bool train,
    double eps,
    std::array<bool,3> grad_input_mask) {
  TORCH_CHECK(false, "mkldnn_batch_norm_backward: ATen not compiled with MKLDNN support");
}

std::tuple<Tensor, Tensor, Tensor> mkldnn_layer_norm_last_index_weight_bias_f32(
    const Tensor& input,
    IntArrayRef normalized_shape, const Tensor& weight, const Tensor& bias,
    double eps, bool inplace) {
  TORCH_CHECK(false, "mkldnn_layer_norm_last_index_weight_bias_f32: ATen not compiled with MKLDNN support");
}

std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var,
    bool train,
    double momentum,
    double eps) {
  TORCH_CHECK(false, "_mkldnn_batch_norm_legit: ATen not compiled with MKLDNN support");
}


std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit_no_stats(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    bool train,
    double momentum,
    double eps) {
  TORCH_CHECK(false, "_mkldnn_batch_norm_legit_no_stats: ATen not compiled with MKLDNN support");
}

std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mkldnn(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
  TORCH_CHECK(false, "_batch_norm_with_update_mkldnn: ATen not compiled with MKLDNN support");
}

std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
    const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
  TORCH_CHECK(false, "_new_batch_norm_backward_mkldnn: ATen not compiled with MKLDNN support");
}

} // namespace native
} // namespace at

#else // AT_MKLDNN_ENABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/layer_norm.h>
#include <ideep/abstract_types.hpp>

namespace at {
namespace native {

std::tuple<Tensor, Tensor, Tensor> mkldnn_layer_norm_last_index_weight_bias_f32(
    const Tensor& input,
    IntArrayRef normalized_shape, const Tensor& weight, const Tensor& bias,
    double eps, bool inplace) {

  TORCH_INTERNAL_ASSERT(normalized_shape.size() == 1, "only accepts normalization over the last dimension");
  TORCH_INTERNAL_ASSERT(input.scalar_type() == at::kFloat);
  auto M_N = at::native::_check_layer_norm_inputs(input, normalized_shape, weight, bias);
  auto M = M_N.first;
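  // M is the number of rows to normalize: the product of all dimensions of
  // `input` except the last (normalized) one, as computed by
  // _check_layer_norm_inputs.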

  auto mean = empty_mkldnn(
      {M},
      input.scalar_type(),
      input.options().layout_opt(),
      input.options().device_opt(),
      input.options().pinned_memory_opt());
  auto rstd = empty_mkldnn(
      {M},
      input.scalar_type(),
      input.options().layout_opt(),
      input.options().device_opt(),
      input.options().pinned_memory_opt());

  auto mean_it = at::native::itensor_from_mkldnn(mean);
  auto rstd_it = at::native::itensor_from_mkldnn(rstd);

  auto input_it = at::native::itensor_from_mkldnn(input);
  auto weight_it = at::native::itensor_from_mkldnn(weight);
  auto bias_it = at::native::itensor_from_mkldnn(bias);

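  // With inplace=true the output aliases the input's ideep tensor; otherwise
  // a fresh ideep tensor with the same descriptor is allocated.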
  auto out_it = inplace ? input_it : ideep::tensor(input_it.get_desc());
  ideep::layer_normalization_forward::compute(input_it, weight_it, bias_it, out_it, mean_it, rstd_it, static_cast<float>(eps));

  auto dst = at::native::new_with_itensor_mkldnn(
      std::move(out_it),
      optTypeMetaToScalarType(input.options().dtype_opt()),
      input.options().device_opt());

  return std::make_tuple(dst, mean, rstd);
}


std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    bool train,
    double momentum,
    double eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

  if (input.scalar_type() == ScalarType::BFloat16) {
    TORCH_CHECK(mkldnn_bf16_device_check(),
        "mkldnn_batch_norm: the bf16 path needs CPU support for avx512bw, avx512vl and avx512dq");
  }
  TORCH_CHECK(weight.defined() && bias.defined(),
      "mkldnn_batch_norm: currently mkldnn only supports the affine model");

  ideep::tensor& x = itensor_from_mkldnn(input);
  ideep::tensor w = itensor_from_tensor(weight);
  ideep::tensor b = itensor_from_tensor(bias);
  bool use_running_stat = (running_mean.defined() && running_var.defined());

  ideep::tensor y;

  if (train) {
    // TODO: enable 3d batchnorm.
    TORCH_CHECK(input.dim() == 4,
        "mkldnn_batch_norm: currently mkldnn training only supports 2d batchnorm");
    ideep::tensor saved_mean;
    ideep::tensor saved_var;
    ideep::batch_normalization_forward_training::compute(
        // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
        x, w, b, y, saved_mean, saved_var, momentum, eps);
    if (use_running_stat) {
      auto len = x.get_nelems() / w.get_nelems(); // n*h*w
      ideep::tensor m = itensor_from_tensor(running_mean);
      ideep::tensor v = itensor_from_tensor(running_var);
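      // Exponential moving average update of the running statistics:
      //   running = (1 - momentum) * running + momentum * batch_stat.
      // The batch variance is rescaled by len / (len - 1) to convert the
      // biased estimate into an unbiased one before accumulation.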
      const std::vector<float> scales_mean{static_cast<float>(1 - momentum),
                                           static_cast<float>(momentum)};
      const std::vector<float> scales_var{static_cast<float>(1 - momentum),
          // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
          static_cast<float>(momentum * len / (len - 1))};
      ideep::sum::compute(scales_mean, {m, saved_mean}, m);
      ideep::sum::compute(scales_var, {v, saved_var}, v);
    }
    return std::make_tuple(
        new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()),
                                input.options().device_opt()),
        new_with_itensor_mkldnn(std::move(saved_mean), optTypeMetaToScalarType(weight.options().dtype_opt()),
                                weight.options().device_opt()),
        new_with_itensor_mkldnn(std::move(saved_var), optTypeMetaToScalarType(weight.options().dtype_opt()),
                                weight.options().device_opt()));
  } else {
    TORCH_CHECK(input.dim() == 4 || input.dim() == 5,
        "mkldnn_batch_norm: currently mkldnn inference only supports 2d and 3d batchnorm");
    if (use_running_stat) {
      ideep::tensor m = itensor_from_tensor(running_mean);
      ideep::tensor v = itensor_from_tensor(running_var);
      ideep::batch_normalization_forward_inference::compute(
          // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
          x, m, v, w, b, y, eps);
    } else {
      // TODO: keep running estimates.
      TORCH_CHECK(false, "mkldnn_batch_norm: mkldnn inference does not keep running estimates.");
    }
    return std::make_tuple(
        new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()),
                                input.options().device_opt()),
        new_with_itensor_mkldnn(ideep::tensor{}, optTypeMetaToScalarType(weight.options().dtype_opt()),
                                weight.options().device_opt()),
        new_with_itensor_mkldnn(ideep::tensor{}, optTypeMetaToScalarType(weight.options().dtype_opt()),
                                weight.options().device_opt()));
  }
}


std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mkldnn(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
  auto [output, save_mean, save_var] =
      mkldnn_batch_norm(input, weight_opt, bias_opt, running_mean, running_var, /*train*/true, momentum, eps);
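  // The schema requires a reserve tensor to be returned; this backend has no
  // workspace to carry to the backward pass, so an empty placeholder is used.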
  Tensor reserve = empty_mkldnn({0}, input.scalar_type());
  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
}


std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var,
    bool train,
    double momentum,
    double eps) {
  return mkldnn_batch_norm(input, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps);
}


std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit_no_stats(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    bool train,
    double momentum,
    double eps) {
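  // Passing undefined tensors for the running stats selects the
  // batch-statistics-only path in mkldnn_batch_norm.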
  return mkldnn_batch_norm(input, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps);
}


std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
    const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
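  // The reserve tensor is unused here; `update` plays the role of the `train`
  // flag in the legacy backward below.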
  return mkldnn_batch_norm_backward(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
}


std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm_backward(const Tensor& grad_output,
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_invstd_opt,
    bool train,
    double eps,
    std::array<bool,3> grad_input_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
  const Tensor& save_invstd = c10::value_or_else(save_invstd_opt, [] {return Tensor();});

  TORCH_CHECK(train, "mkldnn_batch_norm_backward: currently mkldnn only supports training mode");
  ideep::tensor& grady = itensor_from_mkldnn(grad_output);
  ideep::tensor& x = itensor_from_mkldnn(input);
  ideep::tensor w = itensor_from_tensor(weight);
  ideep::tensor& m = itensor_from_mkldnn(save_mean);
  ideep::tensor& v = itensor_from_mkldnn(save_invstd);

  ideep::tensor gradx, gradw, gradb;
  ideep::batch_normalization_backward::compute(
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      x, m, v, grady, w, gradx, gradw, gradb, eps);

  return std::make_tuple(
      new_with_itensor_mkldnn(std::move(gradx), optTypeMetaToScalarType(input.options().dtype_opt()),
                              input.options().device_opt()),
      mkldnn_to_dense(new_with_itensor_mkldnn(std::move(gradw),
                                              optTypeMetaToScalarType(weight.options().dtype_opt()),
                                              weight.options().device_opt())),
      mkldnn_to_dense(new_with_itensor_mkldnn(std::move(gradb),
                                              optTypeMetaToScalarType(weight.options().dtype_opt()),
                                              weight.options().device_opt())));
}

} // namespace native
} // namespace at

#endif // AT_MKLDNN_ENABLED