1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_ 17 #define TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_ 18 19 #define EIGEN_USE_THREADS 20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 21 #include "tensorflow/core/kernels/cwise_ops.h" 22 23 namespace Eigen { 24 namespace internal { 25 26 // Gradient for the tanh function 27 template <typename T> 28 struct scalar_tanh_gradient_op { 29 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operatorscalar_tanh_gradient_op30 operator()(const T& output, const T& output_gradient) const { 31 return output_gradient * (T(1) - output * output); 32 } 33 template <typename Packet> 34 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOpscalar_tanh_gradient_op35 packetOp(const Packet& output, const Packet& output_gradient) const { 36 return pmul(output_gradient, 37 psub(pset1<Packet>(T(1)), pmul(output, output))); 38 } 39 }; 40 template <typename T> 41 struct functor_traits<scalar_tanh_gradient_op<T>> { 42 enum { 43 Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost, 44 PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul, 45 }; 46 }; 47 48 // Gradient for the sigmoid function 49 template <typename T> 50 struct scalar_sigmoid_gradient_op { 51 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T 52 operator()(const T& output, const T& output_gradient) const { 53 return output_gradient * output * (T(1) - output); 54 } 55 template <typename Packet> 56 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 57 packetOp(const Packet& output, const Packet& output_gradient) const { 58 return pmul(output_gradient, 59 pmul(output, psub(pset1<Packet>(T(1)), output))); 60 } 61 }; 62 template <typename T> 63 struct functor_traits<scalar_sigmoid_gradient_op<T>> { 64 enum { 65 Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost, 66 PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul, 67 }; 68 }; 69 70 // Gradient for the inverse function 71 template <typename T> 72 struct scalar_inverse_gradient_op { 73 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T 74 operator()(const T& output, const T& output_gradient) const { 75 if (output_gradient == T(0)) { 76 return T(0); 77 } else { 78 const T out_conj = numext::conj(output); 79 return -out_conj * out_conj * output_gradient; 80 } 81 } 82 template <typename Packet> 83 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 84 packetOp(const Packet& output, const Packet& output_gradient) const { 85 const Packet out_conj = pconj(output); 86 return mul_no_nan_op<T>().packetOp(pnegate(pmul(out_conj, out_conj)), 87 output_gradient); 88 } 89 }; 90 template <typename T> 91 struct functor_traits<scalar_inverse_gradient_op<T>> { 92 enum { 93 Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost, 94 PacketAccess = packet_traits<T>::HasMul, 95 }; 96 }; 97 98 // Gradient for the sqrt function 99 template <typename T> 100 struct scalar_sqrt_gradient_op { 101 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T 102 operator()(const T& output, const T& output_gradient) const { 103 if (output_gradient == T(0)) { 104 return T(0); 105 } else { 106 const T out_conj = numext::conj(output); 107 return (static_cast<T>(0.5) * output_gradient) / out_conj; 108 } 109 } 110 template <typename Packet> 111 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 112 packetOp(const Packet& output, const Packet& output_gradient) const { 113 const Packet const_half = pset1<Packet>(static_cast<T>(0.5)); 114 const Packet out_conj = pconj(output); 115 return mul_no_nan_op<T>().packetOp(pdiv(const_half, out_conj), 116 output_gradient); 117 } 118 }; 119 template <typename T> 120 struct functor_traits<scalar_sqrt_gradient_op<T>> { 121 enum { 122 PacketAccess = packet_traits<T>::HasMul & packet_traits<T>::HasDiv, 123 Cost = NumTraits<T>::MulCost + scalar_div_cost<T, PacketAccess>::value, 124 }; 125 }; 126 127 // Gradient for the rsqrt function 128 template <typename T> 129 struct scalar_rsqrt_gradient_op { 130 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T 131 operator()(const T& output, const T& output_gradient) const { 132 if (output_gradient == T(0)) { 133 return T(0); 134 } else { 135 const T out_conj = numext::conj(output); 136 return static_cast<T>(-0.5) * (output_gradient * out_conj) * 137 (out_conj * out_conj); 138 } 139 } 140 template <typename Packet> 141 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet 142 packetOp(const Packet& output, const Packet& output_gradient) const { 143 const Packet const_half = pset1<Packet>(static_cast<T>(-0.5)); 144 const Packet out_conj = pconj(output); 145 auto safe_pmul = [](const Packet& a, const Packet& b) { 146 return mul_no_nan_op<T>().packetOp(a, b); 147 }; 148 return safe_pmul(pmul(const_half, pmul(out_conj, out_conj)), 149 safe_pmul(out_conj, output_gradient)); 150 } 151 }; 152 template <typename T> 153 struct functor_traits<scalar_rsqrt_gradient_op<T>> { 154 enum { 155 Cost = 4 * NumTraits<T>::MulCost, 156 PacketAccess = packet_traits<T>::HasMul, 157 }; 158 }; 159 160 } // end namespace internal 161 } // end namespace Eigen 162 163 namespace tensorflow { 164 165 namespace functor { 166 167 template <typename Device, typename Functor> 168 struct SimpleBinaryFunctor { 169 void operator()(const Device& d, typename Functor::tout_type out, 170 typename Functor::tin_type in0, 171 typename Functor::tin_type in1); 172 }; 173 174 // Partial specialization of BinaryFunctor for CPU devices 175 typedef Eigen::ThreadPoolDevice CPUDevice; 176 177 template <typename Functor> 178 struct SimpleBinaryFunctor<CPUDevice, Functor> { 179 void operator()(const CPUDevice& d, typename Functor::tout_type out, 180 typename Functor::tin_type in0, 181 typename Functor::tin_type in1) { 182 out.device(d) = in0.binaryExpr(in1, typename Functor::func()); 183 } 184 }; 185 186 187 template <typename T> 188 struct tanh_grad : base<T, Eigen::internal::scalar_tanh_gradient_op<T>> {}; 189 190 template <typename T> 191 struct sigmoid_grad : base<T, Eigen::internal::scalar_sigmoid_gradient_op<T>> { 192 }; 193 194 template <typename T> 195 struct inverse_grad : base<T, Eigen::internal::scalar_inverse_gradient_op<T>> { 196 }; 197 198 template <typename T> 199 struct sqrt_grad : base<T, Eigen::internal::scalar_sqrt_gradient_op<T>> {}; 200 201 template <typename T> 202 struct rsqrt_grad : base<T, Eigen::internal::scalar_rsqrt_gradient_op<T>> {}; 203 204 template <typename T> 205 struct igamma_grad_a : base<T, Eigen::internal::scalar_igamma_der_a_op<T>> {}; 206 207 } // end namespace functor 208 209 } // end namespace tensorflow 210 #endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_ 211