xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/cwise_ops_gradients.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_
17 #define TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_
18 
19 #define EIGEN_USE_THREADS
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/core/kernels/cwise_ops.h"
22 
23 namespace Eigen {
24 namespace internal {
25 
26 // Gradient for the tanh function
27 template <typename T>
28 struct scalar_tanh_gradient_op {
29   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
operatorscalar_tanh_gradient_op30   operator()(const T& output, const T& output_gradient) const {
31     return output_gradient * (T(1) - output * output);
32   }
33   template <typename Packet>
34   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
packetOpscalar_tanh_gradient_op35   packetOp(const Packet& output, const Packet& output_gradient) const {
36     return pmul(output_gradient,
37                 psub(pset1<Packet>(T(1)), pmul(output, output)));
38   }
39 };
40 template <typename T>
41 struct functor_traits<scalar_tanh_gradient_op<T>> {
42   enum {
43     Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
44     PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul,
45   };
46 };
47 
48 // Gradient for the sigmoid function
49 template <typename T>
50 struct scalar_sigmoid_gradient_op {
51   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
52   operator()(const T& output, const T& output_gradient) const {
53     return output_gradient * output * (T(1) - output);
54   }
55   template <typename Packet>
56   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
57   packetOp(const Packet& output, const Packet& output_gradient) const {
58     return pmul(output_gradient,
59                 pmul(output, psub(pset1<Packet>(T(1)), output)));
60   }
61 };
62 template <typename T>
63 struct functor_traits<scalar_sigmoid_gradient_op<T>> {
64   enum {
65     Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
66     PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul,
67   };
68 };
69 
70 // Gradient for the inverse function
71 template <typename T>
72 struct scalar_inverse_gradient_op {
73   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
74   operator()(const T& output, const T& output_gradient) const {
75     if (output_gradient == T(0)) {
76       return T(0);
77     } else {
78       const T out_conj = numext::conj(output);
79       return -out_conj * out_conj * output_gradient;
80     }
81   }
82   template <typename Packet>
83   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
84   packetOp(const Packet& output, const Packet& output_gradient) const {
85     const Packet out_conj = pconj(output);
86     return mul_no_nan_op<T>().packetOp(pnegate(pmul(out_conj, out_conj)),
87                                        output_gradient);
88   }
89 };
90 template <typename T>
91 struct functor_traits<scalar_inverse_gradient_op<T>> {
92   enum {
93     Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
94     PacketAccess = packet_traits<T>::HasMul,
95   };
96 };
97 
98 // Gradient for the sqrt function
99 template <typename T>
100 struct scalar_sqrt_gradient_op {
101   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
102   operator()(const T& output, const T& output_gradient) const {
103     if (output_gradient == T(0)) {
104       return T(0);
105     } else {
106       const T out_conj = numext::conj(output);
107       return (static_cast<T>(0.5) * output_gradient) / out_conj;
108     }
109   }
110   template <typename Packet>
111   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
112   packetOp(const Packet& output, const Packet& output_gradient) const {
113     const Packet const_half = pset1<Packet>(static_cast<T>(0.5));
114     const Packet out_conj = pconj(output);
115     return mul_no_nan_op<T>().packetOp(pdiv(const_half, out_conj),
116                                        output_gradient);
117   }
118 };
119 template <typename T>
120 struct functor_traits<scalar_sqrt_gradient_op<T>> {
121   enum {
122     PacketAccess = packet_traits<T>::HasMul & packet_traits<T>::HasDiv,
123     Cost = NumTraits<T>::MulCost + scalar_div_cost<T, PacketAccess>::value,
124   };
125 };
126 
127 // Gradient for the rsqrt function
128 template <typename T>
129 struct scalar_rsqrt_gradient_op {
130   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
131   operator()(const T& output, const T& output_gradient) const {
132     if (output_gradient == T(0)) {
133       return T(0);
134     } else {
135       const T out_conj = numext::conj(output);
136       return static_cast<T>(-0.5) * (output_gradient * out_conj) *
137              (out_conj * out_conj);
138     }
139   }
140   template <typename Packet>
141   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
142   packetOp(const Packet& output, const Packet& output_gradient) const {
143     const Packet const_half = pset1<Packet>(static_cast<T>(-0.5));
144     const Packet out_conj = pconj(output);
145     auto safe_pmul = [](const Packet& a, const Packet& b) {
146       return mul_no_nan_op<T>().packetOp(a, b);
147     };
148     return safe_pmul(pmul(const_half, pmul(out_conj, out_conj)),
149                      safe_pmul(out_conj, output_gradient));
150   }
151 };
152 template <typename T>
153 struct functor_traits<scalar_rsqrt_gradient_op<T>> {
154   enum {
155     Cost = 4 * NumTraits<T>::MulCost,
156     PacketAccess = packet_traits<T>::HasMul,
157   };
158 };
159 
160 }  // end namespace internal
161 }  // end namespace Eigen
162 
163 namespace tensorflow {
164 
165 namespace functor {
166 
// Element-wise binary functor dispatcher: computes
//   out(i) = Functor::func()(in0(i), in1(i))
// on the given device. Only declared here; per-device partial
// specializations provide the implementation.
template <typename Device, typename Functor>
struct SimpleBinaryFunctor {
  void operator()(const Device& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1);
};
173 
174 // Partial specialization of BinaryFunctor for CPU devices
175 typedef Eigen::ThreadPoolDevice CPUDevice;
176 
177 template <typename Functor>
178 struct SimpleBinaryFunctor<CPUDevice, Functor> {
179   void operator()(const CPUDevice& d, typename Functor::tout_type out,
180                   typename Functor::tin_type in0,
181                   typename Functor::tin_type in1) {
182     out.device(d) = in0.binaryExpr(in1, typename Functor::func());
183   }
184 };
185 
186 
187 template <typename T>
188 struct tanh_grad : base<T, Eigen::internal::scalar_tanh_gradient_op<T>> {};
189 
190 template <typename T>
191 struct sigmoid_grad : base<T, Eigen::internal::scalar_sigmoid_gradient_op<T>> {
192 };
193 
194 template <typename T>
195 struct inverse_grad : base<T, Eigen::internal::scalar_inverse_gradient_op<T>> {
196 };
197 
198 template <typename T>
199 struct sqrt_grad : base<T, Eigen::internal::scalar_sqrt_gradient_op<T>> {};
200 
201 template <typename T>
202 struct rsqrt_grad : base<T, Eigen::internal::scalar_rsqrt_gradient_op<T>> {};
203 
204 template <typename T>
205 struct igamma_grad_a : base<T, Eigen::internal::scalar_igamma_der_a_op<T>> {};
206 
207 }  // end namespace functor
208 
209 }  // end namespace tensorflow
210 #endif  // TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_
211