#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <tuple>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_batch_norm_with_update_native.h>
#include <ATen/ops/batch_norm_backward_native.h>
#include <ATen/ops/_native_batch_norm_legit_native.h>
#include <ATen/ops/_to_dense_native.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/native_batch_norm_backward_native.h>
#include <ATen/ops/native_batch_norm_native.h>
#endif
#include <ATen/native/mkldnn/Utils.h>

#if !AT_MKLDNN_ENABLED()

namespace at {
namespace native {

std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm(
    const Tensor& self, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    bool train,
    double momentum,
    double eps) {
  TORCH_CHECK(false, "mkldnn_batch_norm: ATen not compiled with MKLDNN support");
}

std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm_backward(
    const Tensor& grad_output,
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_invstd_opt,
    bool train,
    double eps,
    std::array<bool,3> grad_input_mask) {
  TORCH_CHECK(false, "mkldnn_batch_norm_backward: ATen not compiled with MKLDNN support");
}

std::tuple<Tensor, Tensor, Tensor> mkldnn_layer_norm_last_index_weight_bias_f32(
    const Tensor& input,
    IntArrayRef normalized_shape, const Tensor& weight, const Tensor& bias,
    double eps, bool inplace) {
  TORCH_CHECK(false, "mkldnn_layer_norm_last_index_weight_bias_f32: ATen not compiled with MKLDNN support");
}

std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var,
    bool train,
    double momentum,
    double eps) {
  TORCH_CHECK(false, "_mkldnn_batch_norm_legit: ATen not compiled with MKLDNN support");
}


std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit_no_stats(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    bool train,
    double momentum,
    double eps) {
  TORCH_CHECK(false, "_mkldnn_batch_norm_legit_no_stats: ATen not compiled with MKLDNN support");
}

std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mkldnn(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
  TORCH_CHECK(false, "_batch_norm_with_update_mkldnn: ATen not compiled with MKLDNN support");
}

std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
    const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
  TORCH_CHECK(false, "_new_batch_norm_backward_mkldnn: ATen not compiled with MKLDNN support");
}

} // namespace native
} // namespace at

#else // AT_MKLDNN_ENABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/layer_norm.h>
#include <ideep/abstract_types.hpp>

namespace at {
namespace native {

std::tuple<Tensor, Tensor, Tensor> mkldnn_layer_norm_last_index_weight_bias_f32(
    const Tensor& input,
    IntArrayRef normalized_shape, const Tensor& weight, const Tensor& bias,
    double eps, bool inplace) {

  TORCH_INTERNAL_ASSERT(normalized_shape.size() == 1, "only normalizing over the last dimension is supported");
  TORCH_INTERNAL_ASSERT(input.scalar_type() == at::kFloat);
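  // _check_layer_norm_inputs returns (M, N): M is the product of the leading
  // (non-normalized) dimensions and N the size of the normalized last
  // dimension. Mean and rstd hold one statistic per row, hence shape {M}.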
  auto M_N = at::native::_check_layer_norm_inputs(input, normalized_shape, weight, bias);
  auto M = M_N.first;

  auto mean = empty_mkldnn(
        {M},
        input.scalar_type(),
        input.options().layout_opt(),
        input.options().device_opt(),
        input.options().pinned_memory_opt());
  auto rstd = empty_mkldnn(
        {M},
        input.scalar_type(),
        input.options().layout_opt(),
        input.options().device_opt(),
        input.options().pinned_memory_opt());

  auto mean_it = at::native::itensor_from_mkldnn(mean);
  auto rstd_it = at::native::itensor_from_mkldnn(rstd);

  auto input_it = at::native::itensor_from_mkldnn(input);
  auto weight_it = at::native::itensor_from_mkldnn(weight);
  auto bias_it = at::native::itensor_from_mkldnn(bias);

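  // In-place mode writes the result into the input's ideep buffer (ideep
  // tensors share storage on copy); otherwise allocate a fresh destination
  // with the same descriptor.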
  auto out_it = inplace ? input_it : ideep::tensor(input_it.get_desc());
  ideep::layer_normalization_forward::compute(input_it, weight_it, bias_it, out_it, mean_it, rstd_it, static_cast<float>(eps));

  auto dst = at::native::new_with_itensor_mkldnn(
      std::move(out_it),
      optTypeMetaToScalarType(input.options().dtype_opt()),
      input.options().device_opt());

  return std::make_tuple(dst, mean, rstd);
}


std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    bool train,
    double momentum,
    double eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
  const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
  const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();});

  if (input.scalar_type() == ScalarType::BFloat16) {
    TORCH_CHECK(mkldnn_bf16_device_check(),
        "mkldnn_batch_norm: the bf16 path requires CPU support for avx512bw, avx512vl and avx512dq");
  }
  TORCH_CHECK(weight.defined() && bias.defined(),
             "mkldnn_batch_norm: currently mkldnn only supports affine batch norm (weight and bias must be defined)");

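  // input is an MKLDNN-layout tensor, so its ideep buffer can be aliased
  // directly; weight/bias may be dense (strided), which itensor_from_tensor
  // also handles.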
  ideep::tensor& x = itensor_from_mkldnn(input);
  ideep::tensor w = itensor_from_tensor(weight);
  ideep::tensor b = itensor_from_tensor(bias);
  bool use_running_stat = (running_mean.defined() && running_var.defined());

  ideep::tensor y;

  if (train) {
    // TODO: enable 3d batchnorm.
    TORCH_CHECK(input.dim() == 4,
        "mkldnn_batch_norm: currently mkldnn training only supports 2d batchnorm");
    ideep::tensor saved_mean;
    ideep::tensor saved_var;
    ideep::batch_normalization_forward_training::compute(
        // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
        x, w, b, y, saved_mean, saved_var, momentum, eps);
    if (use_running_stat) {
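      // Fold the batch statistics into the running estimates with an
      // exponential moving average:
      //   running = (1 - momentum) * running + momentum * batch_stat
      // The variance term is rescaled by len / (len - 1) (Bessel's
      // correction) so the running variance tracks the unbiased estimate.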
      auto len = x.get_nelems() / w.get_nelems(); // n*h*w
      ideep::tensor m = itensor_from_tensor(running_mean);
      ideep::tensor v = itensor_from_tensor(running_var);
      const std::vector<float> scales_mean{static_cast<float>(1 - momentum),
                                           static_cast<float>(momentum)};
      const std::vector<float> scales_var{static_cast<float>(1 - momentum),
                                          // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
                                          static_cast<float>(momentum * len / (len - 1))};
      ideep::sum::compute(scales_mean, {m, saved_mean}, m);
      ideep::sum::compute(scales_var, {v, saved_var}, v);
    }
    return std::make_tuple(
         new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()),
                                 input.options().device_opt()),
         new_with_itensor_mkldnn(std::move(saved_mean), optTypeMetaToScalarType(weight.options().dtype_opt()),
                                 weight.options().device_opt()),
         new_with_itensor_mkldnn(std::move(saved_var), optTypeMetaToScalarType(weight.options().dtype_opt()),
                                 weight.options().device_opt()));
  } else {
    TORCH_CHECK(input.dim() == 4 || input.dim() == 5,
        "mkldnn_batch_norm: currently mkldnn inference only supports 2d and 3d batchnorm");
    if (use_running_stat) {
      ideep::tensor m = itensor_from_tensor(running_mean);
      ideep::tensor v = itensor_from_tensor(running_var);
      ideep::batch_normalization_forward_inference::compute(
          // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
          x, m, v, w, b, y, eps);
    } else {
      // TODO: keep running estimates.
      TORCH_CHECK(false, "mkldnn_batch_norm: mkldnn inference without running stats is not supported.");
    }
    return std::make_tuple(
        new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()),
                                input.options().device_opt()),
        new_with_itensor_mkldnn(ideep::tensor{}, optTypeMetaToScalarType(weight.options().dtype_opt()),
                                weight.options().device_opt()),
        new_with_itensor_mkldnn(ideep::tensor{}, optTypeMetaToScalarType(weight.options().dtype_opt()),
                                weight.options().device_opt()));
  }
}


std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mkldnn(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    Tensor& running_mean, Tensor& running_var, double momentum, double eps) {
  auto [output, save_mean, save_var] =
    mkldnn_batch_norm(input, weight_opt, bias_opt, running_mean, running_var, /*train*/true, momentum, eps);
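  // MKLDNN keeps no workspace between forward and backward, so return an
  // empty reserve tensor purely to satisfy the _batch_norm_with_update
  // signature, which reserves space for backends that need one.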
  Tensor reserve = empty_mkldnn({0}, input.scalar_type());
  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);
}


std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt, Tensor& running_mean, Tensor& running_var,
    bool train,
    double momentum,
    double eps) {
  return mkldnn_batch_norm(input, weight_opt, bias_opt, running_mean, running_var, train, momentum, eps);
}


std::tuple<Tensor, Tensor, Tensor> _mkldnn_batch_norm_legit_no_stats(
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
    bool train,
    double momentum,
    double eps) {
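  // Undefined running_mean/running_var tensors make mkldnn_batch_norm skip
  // the running-stat update (use_running_stat is false in that case).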
  return mkldnn_batch_norm(input, weight_opt, bias_opt, Tensor(), Tensor(), train, momentum, eps);
}


std::tuple<Tensor, Tensor, Tensor> _new_batch_norm_backward_mkldnn(
    const Tensor& grad_output, const Tensor& input, const Tensor& weight,
    const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
    const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_var_opt,
    bool update, double eps, std::array<bool,3> grad_input_mask, const Tensor& reserve) {
  return mkldnn_batch_norm_backward(grad_output, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_var_opt, update, eps, grad_input_mask);
}


std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm_backward(const Tensor& grad_output,
    const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt, const std::optional<Tensor>& save_mean_opt, const std::optional<Tensor>& save_invstd_opt,
    bool train,
    double eps,
    std::array<bool,3> grad_input_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();});
  const Tensor& save_invstd = c10::value_or_else(save_invstd_opt, [] {return Tensor();});

  TORCH_CHECK(train, "mkldnn_batch_norm_backward: currently mkldnn only supports training mode");
  ideep::tensor& grady = itensor_from_mkldnn(grad_output);
  ideep::tensor& x = itensor_from_mkldnn(input);
  ideep::tensor w = itensor_from_tensor(weight);
  ideep::tensor& m = itensor_from_mkldnn(save_mean);
  ideep::tensor& v = itensor_from_mkldnn(save_invstd);

  ideep::tensor gradx, gradw, gradb;
  ideep::batch_normalization_backward::compute(
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      x, m, v, grady, w, gradx, gradw, gradb, eps);

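  // gradx stays in MKLDNN layout to match the input; the small 1-D weight
  // and bias gradients are converted back to dense (strided) tensors.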
  return std::make_tuple(
      new_with_itensor_mkldnn(std::move(gradx), optTypeMetaToScalarType(input.options().dtype_opt()),
                              input.options().device_opt()),
      mkldnn_to_dense(new_with_itensor_mkldnn(std::move(gradw),
                                              optTypeMetaToScalarType(weight.options().dtype_opt()),
                                              weight.options().device_opt())),
      mkldnn_to_dense(new_with_itensor_mkldnn(std::move(gradb),
                                              optTypeMetaToScalarType(weight.options().dtype_opt()),
                                              weight.options().device_opt())));
}

} // namespace native
} // namespace at

#endif // AT_MKLDNN_ENABLED