/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_RELU_OP_H_
#define TENSORFLOW_CORE_KERNELS_RELU_OP_H_

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/relu_op_functor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

template <typename Device, typename T>
class ReluOp : public UnaryElementWiseOp<T, ReluOp<Device, T>> {
 public:
  using UnaryElementWiseOp<T, ReluOp<Device, T>>::UnaryElementWiseOp;

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    functor::Relu<Device, T> functor;
    functor(context->eigen_device<Device>(), input.flat<T>(),
            output->flat<T>());
  }
};
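
// Note: functor::Relu (declared in relu_op_functor.h) computes the standard
// rectified linear unit, Relu(x) = max(x, 0), applied elementwise to the
// flattened input. As a rough sketch (not the authoritative implementation),
// the CPU specialization is expected to reduce to an Eigen expression along
// the lines of:
//   activations.device(d) = features.cwiseMax(static_cast<T>(0));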

// Out-of-line check to save code space (we have this code once, rather
// than once for every NDIMS * NumTypes * Num_different_relu_variants
// functions).
struct ReluHelpers {
  static void ValidateSameSizeHelper(OpKernelContext* context, const Tensor& g,
                                     const Tensor& a) {
    OP_REQUIRES(context, a.IsSameSize(g),
                errors::InvalidArgument("g and a must be the same size"));
  }
  static bool ValidateSameSize(OpKernelContext* context, const Tensor& g,
                               const Tensor& a) {
    ValidateSameSizeHelper(context, g, a);
    return context->status().ok();
  }
};
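
// Note: on a size mismatch, OP_REQUIRES records an InvalidArgument status on
// the context and returns from ValidateSameSizeHelper; ValidateSameSize then
// surfaces that failure to its caller via context->status().ok().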

template <typename Device, typename T>
class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> {
 public:
  using BinaryElementWiseOp<T, ReluGradOp<Device, T>>::BinaryElementWiseOp;

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients
  //   a (inputs): either the inputs that were passed to ReluOp(), or its
  //               outputs (using either one yields the same result here).
  // OUTPUT:
  //   gradients to backprop
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, output);
  }
};

template <typename Device, typename T>
void ReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                                              const Tensor& g, const Tensor& a,
                                              Tensor* output) {
  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
  functor::ReluGrad<Device, T> functor;
  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
          output->flat<T>());
}
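
// Note: functor::ReluGrad is expected to compute the standard ReLU gradient,
//   output = g where a > 0, and 0 elsewhere.
// Because Relu(x) = max(x, 0) is positive exactly where x is positive, the
// mask (a > 0) is identical whether `a` holds the op's inputs or its outputs,
// which is why either may be passed (see the INPUTS comment above).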

template <typename Device, typename T>
class Relu6Op : public UnaryElementWiseOp<T, Relu6Op<Device, T>> {
 public:
  using UnaryElementWiseOp<T, Relu6Op<Device, T>>::UnaryElementWiseOp;

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    functor::Relu6<Device, T> functor;
    functor(context->eigen_device<Device>(), input.flat<T>(),
            output->flat<T>());
  }
};
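
// Note: functor::Relu6 is expected to compute Relu6(x) = min(max(x, 0), 6),
// i.e. a ReLU whose output is additionally capped at 6.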

template <typename Device, typename T>
class Relu6GradOp : public BinaryElementWiseOp<T, Relu6GradOp<Device, T>> {
 public:
  using BinaryElementWiseOp<T, Relu6GradOp<Device, T>>::BinaryElementWiseOp;

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients
  //   a (inputs): inputs that were passed to Relu6Op()
  // OUTPUT:
  //   gradients to backprop
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, output);
  }
};

template <typename Device, typename T>
void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                                               const Tensor& g, const Tensor& a,
                                               Tensor* output) {
  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
  functor::Relu6Grad<Device, T> functor;
  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
          output->flat<T>());
}
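
// Note: Relu6(x) has unit slope only for 0 < x < 6, so the expected gradient
// is
//   output = g where 0 < a < 6, and 0 elsewhere,
// evaluated against the original inputs `a` (see the INPUTS comment above).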

template <typename Device, typename T>
class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> {
 public:
  explicit LeakyReluOp(OpKernelConstruction* context)
      : UnaryElementWiseOp<T, LeakyReluOp<Device, T>>(context) {
    float alpha_tmp;
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
    alpha_ = T(alpha_tmp);
  }

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    functor::LeakyRelu<Device, T> functor;
    functor({context->eigen_device<Device>(), input.flat<T>(), alpha_,
             output->flat<T>()});
  }

 private:
  T alpha_;
};
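
// Note: functor::LeakyRelu is expected to compute
//   LeakyRelu(x) = x          for x >= 0
//                  alpha * x  for x <  0,
// where alpha comes from the op's "alpha" attribute (stored above as alpha_).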

template <typename Device, typename T>
class LeakyReluGradOp
    : public BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>> {
 public:
  explicit LeakyReluGradOp(OpKernelConstruction* context)
      : BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>>(context) {
    float alpha_tmp;
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
    alpha_ = T(alpha_tmp);
  }

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, T alpha, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients
  //   a (inputs): either the inputs that were passed to LeakyReluOp(), or its
  //               outputs (using either one yields the same result here).
  // OUTPUT:
  //   gradients to backprop
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, alpha_, output);
  }

 private:
  T alpha_;
};

template <typename Device, typename T>
void LeakyReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                                                   const Tensor& g,
                                                   const Tensor& a, T alpha,
                                                   Tensor* output) {
  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
  functor::LeakyReluGrad<Device, T> functor;
  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), alpha,
          output->flat<T>());
}
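
// Note: the expected LeakyRelu gradient is
//   output = g          where a > 0
//            alpha * g  elsewhere.
// For a non-negative alpha, LeakyRelu(x) > 0 exactly when x > 0, so the mask
// (a > 0) is the same for the op's inputs and outputs, matching the INPUTS
// comment above.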

template <typename Device, typename T>
class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
 public:
  using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp;

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    functor::Elu<Device, T> functor;
    functor(context->eigen_device<Device>(), input.flat<T>(),
            output->flat<T>());
  }
};
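
// Note: functor::Elu is expected to compute the exponential linear unit,
//   Elu(x) = x           for x >= 0
//            exp(x) - 1  for x <  0.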

template <typename Device, typename T>
class EluGradOp : public BinaryElementWiseOp<T, EluGradOp<Device, T>> {
 public:
  using BinaryElementWiseOp<T, EluGradOp<Device, T>>::BinaryElementWiseOp;

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients
  //   a (outputs): outputs of the EluOp()
  // OUTPUT:
  //   gradients to backprop
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, output);
  }
};

template <typename Device, typename T>
void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                                             const Tensor& g, const Tensor& a,
                                             Tensor* output) {
  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
  functor::EluGrad<Device, T> functor;
  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
          output->flat<T>());
}
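
// Note: d/dx Elu(x) is 1 for x > 0 and exp(x) = Elu(x) + 1 for x < 0, so the
// expected gradient can be written purely in terms of the op's outputs:
//   output = g            where a > 0
//            g * (a + 1)  elsewhere.
// This is why the INPUTS comment above requires `a` to be the EluOp outputs.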

template <typename Device, typename T>
class SeluOp : public UnaryElementWiseOp<T, SeluOp<Device, T>> {
 public:
  using UnaryElementWiseOp<T, SeluOp<Device, T>>::UnaryElementWiseOp;

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    functor::Selu<Device, T> functor;
    functor(context->eigen_device<Device>(), input.flat<T>(),
            output->flat<T>());
  }
};
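
// Note: functor::Selu is expected to compute the scaled exponential linear
// unit (Klambauer et al., 2017),
//   Selu(x) = scale * x                     for x >= 0
//             scale * alpha * (exp(x) - 1)  for x <  0,
// with the fixed constants scale ~= 1.0507 and alpha ~= 1.6733.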

template <typename Device, typename T>
class SeluGradOp : public BinaryElementWiseOp<T, SeluGradOp<Device, T>> {
 public:
  using BinaryElementWiseOp<T, SeluGradOp<Device, T>>::BinaryElementWiseOp;

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients
  //   a (outputs): outputs of the SeluOp()
  // OUTPUT:
  //   gradients to backprop
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, output);
  }
};

template <typename Device, typename T>
void SeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                                              const Tensor& g, const Tensor& a,
                                              Tensor* output) {
  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
  functor::SeluGrad<Device, T> functor;
  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
          output->flat<T>());
}
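
// Note: as with EluGrad, the expected Selu gradient can be written in terms of
// the op's outputs, since for x < 0
//   d/dx Selu(x) = scale * alpha * exp(x) = Selu(x) + scale * alpha:
//   output = g * scale                where a > 0
//            g * (a + scale * alpha)  elsewhere.
// Hence the INPUTS comment above requires `a` to be the SeluOp outputs.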

}  // namespace tensorflow

#undef EIGEN_USE_THREADS

#endif  // TENSORFLOW_CORE_KERNELS_RELU_OP_H_