// aten/src/ATen/native/mkldnn/Prelu.cpp
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>


#if !AT_MKLDNN_ENABLED()

namespace at { namespace native {

Tensor mkldnn_prelu(const Tensor& input, const Tensor& weight) {
  TORCH_CHECK(false, "mkldnn_prelu: ATen not compiled with MKLDNN support");
}

std::tuple<Tensor, Tensor> mkldnn_prelu_backward(const Tensor& grad_output, const Tensor& input, const Tensor& weight) {
  TORCH_CHECK(false, "mkldnn_prelu_backward: ATen not compiled with MKLDNN support");
}

}}

#else // AT_MKLDNN_ENABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>

namespace at { namespace native {

Tensor mkldnn_prelu(const Tensor& input, const Tensor& weight) {
  if (input.scalar_type() == ScalarType::BFloat16) {
    TORCH_CHECK(mkldnn_bf16_device_check(),
        "mkldnn_prelu: bf16 path needs the cpu to support avx512bw, avx512vl and avx512dq");
  }

  const ideep::tensor& x = itensor_from_mkldnn(input);
  const ideep::tensor& w = itensor_from_tensor(weight);

  ideep::tensor y;
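  // Note: the primitive is created with prop_kind::forward_training (rather
  // than forward_inference) so it remains usable when a backward pass follows.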
  ideep::prelu_forward::compute(
      x, w, y, ideep::prop_kind::forward_training);
  return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()),
                                 input.options().device_opt());
}
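
// Rough usage sketch for the forward path above (an assumption about how the
// op is typically reached, not part of this file): the input must already be
// in MKLDNN layout, e.g. via Tensor::to_mkldnn(), and the PReLU op is expected
// to dispatch to this kernel for MKLDNN tensors.
//
//   at::Tensor x = at::randn({2, 3, 8, 8}).to_mkldnn();
//   at::Tensor w = at::ones({3});           // one weight per channel
//   at::Tensor y = at::prelu(x, w);         // result stays in MKLDNN layout
//   at::Tensor y_dense = y.to_dense();      // convert back for inspection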

std::tuple<Tensor, Tensor> mkldnn_prelu_backward(const Tensor& grad_output, const Tensor& input, const Tensor& weight) {
  const ideep::tensor& x = itensor_from_mkldnn(input);
  const ideep::tensor& w = itensor_from_tensor(weight);
  const ideep::tensor grady = itensor_from_mkldnn(grad_output);
  ideep::tensor gradx;
  ideep::tensor gradw;

  ideep::prelu_backward::compute(
      x, w, grady, gradx, gradw, ideep::prop_kind::backward);
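  // gradx follows the MKLDNN layout of grad_output. gradw is kept in MKLDNN
  // layout only when the incoming weight was itself an MKLDNN tensor;
  // otherwise it is converted back to a dense tensor so its layout matches
  // the original weight.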
  if (weight.is_mkldnn()) {
    return std::make_tuple(
        new_with_itensor_mkldnn(std::move(gradx),
                                optTypeMetaToScalarType(grad_output.options().dtype_opt()),
                                grad_output.options().device_opt()),
        new_with_itensor_mkldnn(std::move(gradw),
                                optTypeMetaToScalarType(weight.options().dtype_opt()),
                                weight.options().device_opt()));
  } else {
    return std::make_tuple(
        new_with_itensor_mkldnn(std::move(gradx),
                                optTypeMetaToScalarType(grad_output.options().dtype_opt()),
                                grad_output.options().device_opt()),
        mkldnn_to_dense(new_with_itensor_mkldnn(std::move(gradw),
                                                optTypeMetaToScalarType(weight.options().dtype_opt()),
                                                weight.options().device_opt())));
  }
}
}}

#endif // AT_MKLDNN_ENABLED