/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_RELU_OP_H_
#define TENSORFLOW_CORE_KERNELS_RELU_OP_H_

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/relu_op_functor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

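// Computes the rectified linear unit, Relu(x) = max(x, 0), elementwise.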
template <typename Device, typename T>
class ReluOp : public UnaryElementWiseOp<T, ReluOp<Device, T>> {
 public:
  using UnaryElementWiseOp<T, ReluOp<Device, T>>::UnaryElementWiseOp;

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    functor::Relu<Device, T> functor;
    functor(context->eigen_device<Device>(), input.flat<T>(),
            output->flat<T>());
  }
};

// Out-of-line check to save code space (we have this code once, rather
// than once for every NDIMS * NumTypes * Num_different_relu_variants
// functions).
struct ReluHelpers {
  static void ValidateSameSizeHelper(OpKernelContext* context, const Tensor& g,
                                     const Tensor& a) {
    OP_REQUIRES(context, a.IsSameSize(g),
                errors::InvalidArgument("g and a must be the same size"));
  }
  static bool ValidateSameSize(OpKernelContext* context, const Tensor& g,
                               const Tensor& a) {
    ValidateSameSizeHelper(context, g, a);
    return context->status().ok();
  }
};

template <typename Device, typename T>
class ReluGradOp : public BinaryElementWiseOp<T, ReluGradOp<Device, T>> {
 public:
  using BinaryElementWiseOp<T, ReluGradOp<Device, T>>::BinaryElementWiseOp;

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients
  //   a (inputs): either the inputs that were passed to ReluOp(), or its
  //               outputs (using either one yields the same result here).
  // OUTPUT:
  //   gradients to backprop
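  //   (elementwise: output = g where a > 0, and 0 otherwise)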
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, output);
  }
};

template <typename Device, typename T>
void ReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                                              const Tensor& g, const Tensor& a,
                                              Tensor* output) {
  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
  functor::ReluGrad<Device, T> functor;
  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
          output->flat<T>());
}

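// Computes Relu6(x) = min(max(x, 0), 6) elementwise: a rectifier clipped at 6.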
template <typename Device, typename T>
class Relu6Op : public UnaryElementWiseOp<T, Relu6Op<Device, T>> {
 public:
  using UnaryElementWiseOp<T, Relu6Op<Device, T>>::UnaryElementWiseOp;

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    functor::Relu6<Device, T> functor;
    functor(context->eigen_device<Device>(), input.flat<T>(),
            output->flat<T>());
  }
};

template <typename Device, typename T>
class Relu6GradOp : public BinaryElementWiseOp<T, Relu6GradOp<Device, T>> {
 public:
  using BinaryElementWiseOp<T, Relu6GradOp<Device, T>>::BinaryElementWiseOp;

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients
  //   a (inputs): inputs that were passed to Relu6Op()
  // OUTPUT:
  //   gradients to backprop
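  //   (elementwise: output = g where 0 < a < 6, and 0 otherwise)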
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, output);
  }
};

template <typename Device, typename T>
void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                                               const Tensor& g,
                                               const Tensor& a,
                                               Tensor* output) {
  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
  functor::Relu6Grad<Device, T> functor;
  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
          output->flat<T>());
}

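// Computes LeakyRelu(x) = x for x > 0 and alpha * x otherwise, where alpha
// is read from the op's "alpha" attr at construction time.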
template <typename Device, typename T>
class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> {
 public:
  explicit LeakyReluOp(OpKernelConstruction* context)
      : UnaryElementWiseOp<T, LeakyReluOp<Device, T>>(context) {
    float alpha_tmp;
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
    alpha_ = T(alpha_tmp);
  }

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    functor::LeakyRelu<Device, T> functor;
    functor({context->eigen_device<Device>(), input.flat<T>(), alpha_,
             output->flat<T>()});
  }

 private:
  T alpha_;
};

template <typename Device, typename T>
class LeakyReluGradOp
    : public BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>> {
 public:
  explicit LeakyReluGradOp(OpKernelConstruction* context)
      : BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>>(context) {
    float alpha_tmp;
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
    alpha_ = T(alpha_tmp);
  }

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, T alpha, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients
  //   a (inputs): either the inputs that were passed to LeakyReluOp(), or its
  //               outputs (using either one yields the same result here).
  // OUTPUT:
  //   gradients to backprop
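  //   (elementwise: output = g where a > 0, and alpha * g otherwise)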
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, alpha_, output);
  }

 private:
  T alpha_;
};

template <typename Device, typename T>
void LeakyReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                                                   const Tensor& g,
                                                   const Tensor& a, T alpha,
                                                   Tensor* output) {
  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
  functor::LeakyReluGrad<Device, T> functor;
  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), alpha,
          output->flat<T>());
}

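// Computes the exponential linear unit, Elu(x) = x for x > 0 and
// exp(x) - 1 otherwise, elementwise.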
template <typename Device, typename T>
class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
 public:
  using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp;

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    functor::Elu<Device, T> functor;
    functor(context->eigen_device<Device>(), input.flat<T>(),
            output->flat<T>());
  }
};

template <typename Device, typename T>
class EluGradOp : public BinaryElementWiseOp<T, EluGradOp<Device, T>> {
 public:
  using BinaryElementWiseOp<T, EluGradOp<Device, T>>::BinaryElementWiseOp;

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients
  //   a (outputs): outputs of the EluOp()
  // OUTPUT:
  //   gradients to backprop
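  //   (elementwise: output = (a + 1) * g where a < 0, and g otherwise,
  //   since d/dx Elu(x) = exp(x) = Elu(x) + 1 for x < 0)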
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, output);
  }
};

template <typename Device, typename T>
void EluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                                             const Tensor& g, const Tensor& a,
                                             Tensor* output) {
  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
  functor::EluGrad<Device, T> functor;
  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
          output->flat<T>());
}

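// Computes the scaled exponential linear unit,
// Selu(x) = scale * x for x > 0 and scale * alpha * (exp(x) - 1) otherwise,
// with the fixed constants alpha ~= 1.6733 and scale ~= 1.0507.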
template <typename Device, typename T>
class SeluOp : public UnaryElementWiseOp<T, SeluOp<Device, T>> {
 public:
  using UnaryElementWiseOp<T, SeluOp<Device, T>>::UnaryElementWiseOp;

  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
    functor::Selu<Device, T> functor;
    functor(context->eigen_device<Device>(), input.flat<T>(),
            output->flat<T>());
  }
};

template <typename Device, typename T>
class SeluGradOp : public BinaryElementWiseOp<T, SeluGradOp<Device, T>> {
 public:
  using BinaryElementWiseOp<T, SeluGradOp<Device, T>>::BinaryElementWiseOp;

  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                         const Tensor& a, Tensor* output);

  // INPUTS:
  //   g (gradients): backpropagated gradients
  //   a (outputs): outputs of the SeluOp()
  // OUTPUT:
  //   gradients to backprop
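  //   (elementwise: output = (a + scale * alpha) * g where a < 0, and
  //   scale * g otherwise, since a + scale * alpha = scale * alpha * exp(x)
  //   there)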
  template <int NDIMS>
  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
               Tensor* output) {
    OperateNoTemplate(context, g, a, output);
  }
};

template <typename Device, typename T>
void SeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
                                              const Tensor& g, const Tensor& a,
                                              Tensor* output) {
  if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
  functor::SeluGrad<Device, T> functor;
  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
          output->flat<T>());
}

}  // namespace tensorflow

#undef EIGEN_USE_THREADS

#endif  // TENSORFLOW_CORE_KERNELS_RELU_OP_H_