1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_
17 #define TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_
18 
19 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
20 
21 #define EIGEN_USE_GPU
22 
23 #include <sstream>
24 
25 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
26 #include "tensorflow/core/kernels/gpu_prim.h"
27 #include "tensorflow/core/kernels/reduction_ops.h"
28 #include "tensorflow/core/lib/core/bits.h"
29 #include "tensorflow/core/util/gpu_device_functions.h"
30 #include "tensorflow/core/util/gpu_kernel_helper.h"
31 #include "tensorflow/core/util/permutation_input_iterator.h"
32 #include "tensorflow/core/util/transform_output_iterator.h"
33 
34 namespace tensorflow {
35 namespace functor {
36 
37 typedef Eigen::GpuDevice GPUDevice;
38 
39 template <typename T>
40 struct SqrtOfReal {
41   __host__ __device__ T operator()(const T& a) const {
42     return T(Eigen::numext::sqrt(Eigen::numext::real(a)));
43   }
44 };
45 
46 template <typename T>
47 struct Sum {
48   __host__ __device__ T operator()(const T& a, const T& b) const {
49     return a + b;
50   }
51 };
52 
53 template <typename T>
54 struct Prod {
55   __host__ __device__ T operator()(const T& a, const T& b) const {
56     return a * b;
57   }
58 };
59 
60 template <typename T>
61 struct Square {
62   __host__ __device__ T operator()(const T& a) const {
63     return Prod<T>()(a, Eigen::numext::conj(a));
64   }
65 };
66 
67 template <typename T, typename OUT_T = T>
68 struct DividesBy {
69   T divisor;
70 
71   __host__ __device__ explicit DividesBy(T divisor) : divisor(divisor) {}
72 
73   __host__ __device__ OUT_T operator()(const T& x) const { return x / divisor; }
74 };
75 
76 struct MaxPropagateNaN {
77   template <typename T>
78   __host__ __device__ inline T operator()(const T& a, const T& b) const {
79     return (a != a ? a : (a > b ? a : b));
80   }
81 };
82 
83 struct MinPropagateNaN {
84   template <typename T>
85   __host__ __device__ inline T operator()(const T& a, const T& b) const {
86     return (a != a ? a : (a < b ? a : b));
87   }
88 };
89 
90 #if GOOGLE_CUDA
91 // TODO(rocm): enable this once the ROCm platform has support for complex datatypes.
92 //
93 // These specializations are needed to work around a compiler bug in nvcc: it
94 // doesn't seem to like the overloaded operators for std::complex.
95 template <>
96 struct DividesBy<std::complex<float>> {
97   cuFloatComplex divisor;
98 
99   __host__ __device__ explicit DividesBy(std::complex<float> divisor)
100       : divisor(make_cuComplex(divisor.real(), divisor.imag())) {}
101 
102   // Implements x / divisor using CUDA complex arithmetic.
103   __host__ __device__ std::complex<float> operator()(
104       const std::complex<float>& x) const {
105     auto result = cuCdivf(make_cuComplex(x.real(), x.imag()), divisor);
106     return std::complex<float>(result.x, result.y);
107   }
108 };
109 
110 template <>
111 struct DividesBy<std::complex<double>> {
112   cuDoubleComplex divisor;
113 
114   __host__ __device__ explicit DividesBy(std::complex<double> divisor)
115       : divisor(make_cuDoubleComplex(divisor.real(), divisor.imag())) {}
116 
117   // Implements x / divisor using CUDA complex arithmetic.
118   __host__ __device__ std::complex<double> operator()(
119       const std::complex<double>& x) const {
120     auto result = cuCdiv(make_cuDoubleComplex(x.real(), x.imag()), divisor);
121     return std::complex<double>(result.x, result.y);
122   }
123 };
124 #endif  // GOOGLE_CUDA
125 
126 template <>
127 struct DividesBy<float, Eigen::half> {
128   float divisor;
129 
130   __host__ __device__ explicit DividesBy(float divisor) : divisor(divisor) {}
131 
132   __host__ __device__ Eigen::half operator()(const float& x) const {
133     return Eigen::half(x / divisor);
134   }
135 };
136 
137 struct HalfToFloat {
138   __host__ __device__ float operator()(const Eigen::half& x) const {
139     return static_cast<float>(x);
140   }
141 };
142 
143 struct FloatToHalf {
144   __host__ __device__ Eigen::half operator()(const float& x) const {
145     return static_cast<Eigen::half>(x);
146   }
147 };
148 
149 struct And {
150   __host__ __device__ bool operator()(const bool& a, const bool& b) const {
151     return a && b;
152   }
153 };
154 
155 struct Or {
156   __host__ __device__ bool operator()(const bool& a, const bool& b) const {
157     return a || b;
158   }
159 };
160 
161 // Each block does a grid-strided loop and reduces its values locally.
162 // The single-block case is used for low-latency reductions of small inputs to scalars.
163 template <typename T, typename OUT_T, int num_threads, typename Op>
164 __global__ __launch_bounds__(1024) void BlockReduceKernel(
165     T in, OUT_T out, int num_elems, Op op,
166     typename std::iterator_traits<T>::value_type initVal) {
167   const int bid = blockIdx.x;
168   const int tid = threadIdx.x;
169 
170   const int gid = bid * blockDim.x + tid;
171   const int stride = blockDim.x * gridDim.x;
172 
173   typedef typename std::iterator_traits<T>::value_type value_type;
174 
175   value_type sum = initVal;
176   if (gid < num_elems) {
177     sum = in[gid];
178     for (int pos = gid + stride; pos < num_elems; pos += stride) {
179       sum = op(sum, in[pos]);
180     }
181   }
182 
183   typedef gpuprim::BlockReduce<value_type, num_threads> BlockReduce;
184 
185   __shared__ typename BlockReduce::TempStorage temp_storage;
186 
187   // only include input values in the reduction
188   //
189   // elements: -----------------
190   // grid:     |====|====|====|====|====|
191   const int num_elements_to_reduce =
192       max(min(num_elems - bid * blockDim.x, num_threads), 0);
193 
194   sum = BlockReduce(temp_storage).Reduce(sum, op, num_elements_to_reduce);
195 
196   if (tid == 0) out[bid] = sum;
197 }
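// Illustration of the grid-stride access pattern above (example sizes chosen
// arbitrarily): with gridDim.x = 2, blockDim.x = 256 and num_elems = 1000,
// thread 3 of block 1 has gid = 259 and stride = 512, so it locally combines
// in[259] and in[771] before the block-wide reduce writes that block's
// partial result to out[1].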
198 
199 // maps a warp to each row
200 template <typename T, typename OUT_T, typename Op>
201 __global__ __launch_bounds__(1024) void RowReduceKernel(
202     T in, OUT_T out, int num_rows, int num_cols, Op op,
203     typename std::iterator_traits<T>::value_type initVal) {
204   typedef typename std::iterator_traits<T>::value_type value_type;
205   // Defensive index computation to avoid integer overflow.
206   assert(blockDim.x % TF_RED_WARPSIZE == 0);
207   int warps_per_block = blockDim.x / TF_RED_WARPSIZE;
208   int warp_index = threadIdx.x / TF_RED_WARPSIZE;
209   const int row = blockIdx.x * warps_per_block + warp_index;
210   const int lane = threadIdx.x % TF_RED_WARPSIZE;
211 
212   if (num_cols == 1) {
213     int gid = threadIdx.x + blockIdx.x * blockDim.x;
214     if (gid < num_rows) out[gid] = in[gid];
215     return;
216   }
217 
218   value_type sum = initVal;
219   int col = lane;
220 
221   if (row < num_rows && col < num_cols) {
222     sum = in[row * num_cols + col];
223     col += TF_RED_WARPSIZE;
224     for (; col < num_cols; col += TF_RED_WARPSIZE) {
225       sum = op(sum, in[row * num_cols + col]);
226     }
227   }
228 
229   typedef gpuprim::WarpReduce<value_type> WarpReduce;
230 
231   __shared__ typename WarpReduce::TempStorage temp_storage;
232 
233   sum =
234       WarpReduce(temp_storage).Reduce(sum, op, min(num_cols, TF_RED_WARPSIZE));
235 
236   if (row < num_rows && lane == 0) out[row] = sum;
237 }
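// Illustration of the warp-per-row mapping (assuming TF_RED_WARPSIZE == 32):
// with blockDim.x == 128 there are 4 warps per block, so thread 69 of block 2
// is lane 5 of warp 2, handles row 2 * 4 + 2 = 10, and accumulates columns
// 5, 37, 69, ... of that row before the warp-wide reduce writes out[10].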
238 
239 template <typename T1>
240 struct storage_type {
241   T1 val;
242   __host__ __device__ storage_type() {}
243   __host__ __device__ operator T1() { return val; }
244   __host__ __device__ storage_type<T1>& operator=(const T1& in) {
245     val = in;
246     return *this;
247   }
248 };
249 
250 template <typename T2>
251 struct storage_type<std::complex<T2>> {
252   T2 real;
253   T2 imag;
254   __host__ __device__ storage_type() {}
255   __host__ __device__ operator std::complex<T2>() {
256     return std::complex<T2>(real, imag);
257   }
258   __host__ __device__ storage_type<std::complex<T2>>& operator=(
259       const std::complex<T2>& in) {
260     real = in.real();
261     imag = in.imag();
262     return *this;
263   }
264 };
265 
266 // Works only if there are <= 16 columns.
267 // Each warp sums over multiple rows at once.
268 template <typename T, typename OUT_T, typename Op>
269 __global__ __launch_bounds__(1024) void ColumnReduceMax16ColumnsKernel(
270     T in, OUT_T out, int num_rows, int num_cols, Op op,
271     typename std::iterator_traits<T>::value_type initVal) {
272   typedef typename std::iterator_traits<T>::value_type value_type;
273   int rows_per_warp = TF_RED_WARPSIZE / num_cols;
274 
275   const int lane = threadIdx.x % TF_RED_WARPSIZE;
276   const int lane_row = lane / num_cols;
277 
278   const int start_row_warp =
279       rows_per_warp * (blockIdx.y * blockDim.y + threadIdx.y);
280   const int start_row_lane = start_row_warp + lane_row;
281   int row = start_row_lane;
282   int col = lane % num_cols;
283 
284   value_type sum = initVal;
285   if (row * num_cols + col < num_rows * num_cols)
286     sum = in[row * num_cols + col];
287 
288     // 1D array necessary due to bug in CUDA 9 compiler.
289     // TODO(nluehr) revert to 2D array when compiler is ready.
290     // This is to mimic the following, but without any constructors:
291     //   __shared__ storage_type<value_type> partial_sums[TF_RED_WARPSIZE *
292     //   (TF_RED_WARPSIZE+1)];
293 #if GOOGLE_CUDA
294   __shared__ __align__(alignof(value_type)) char
295       partial_sums_raw[TF_RED_WARPSIZE * (TF_RED_WARPSIZE + 1) *
296                        sizeof(value_type)];
297   value_type* partial_sums = reinterpret_cast<value_type*>(partial_sums_raw);
298 #elif TENSORFLOW_USE_ROCM
299   __shared__ storage_type<value_type>
300       partial_sums[TF_RED_WARPSIZE * (TF_RED_WARPSIZE + 1)];
301 #endif
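  // Note: the (TF_RED_WARPSIZE + 1) pitch pads each logical row of
  // partial_sums by one element, a common trick to avoid shared-memory bank
  // conflicts when a column of partial_sums is read below.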
302 
303   row += rows_per_warp * gridDim.y * blockDim.y;
304   for (; row < num_rows; row += rows_per_warp * gridDim.y * blockDim.y) {
305     int global_pos = row * num_cols + col;
306     if (global_pos < (num_rows * num_cols))
307       sum = op(sum, in[row * num_cols + col]);
308   }
309 
310   const int rows_in_this_warp = min(rows_per_warp, num_rows - start_row_warp);
311   // not the most efficient way to do this sum
312   for (int i = 1; i < rows_in_this_warp; ++i) {
313     value_type tmp = gpuprim::ShuffleIndex<TF_RED_WARPSIZE, value_type>(
314         sum, static_cast<int>(threadIdx.x + i * num_cols), 0xffffffff);
315     if (lane < num_cols) sum = op(sum, tmp);
316   }
317 
318   if (lane < num_cols)
319     partial_sums[lane * (TF_RED_WARPSIZE + 1) + threadIdx.y] = sum;
320 
321   __syncthreads();
322 
323   if (threadIdx.y == 0 && threadIdx.x < num_cols) {
324     value_type s = partial_sums[threadIdx.x * (TF_RED_WARPSIZE + 1)];
325 
326     if (blockDim.y > 1) {
327       for (int row = 1; row < blockDim.y; ++row) {
328         value_type t = partial_sums[threadIdx.x * (TF_RED_WARPSIZE + 1) + row];
329         s = op(s, t);
330       }
331     }
332 
333     out[col * gridDim.y + blockIdx.y] = s;
334   }
335 }
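// Worked example for the lane mapping in ColumnReduceMax16ColumnsKernel
// (illustrative numbers): with num_cols == 4 and TF_RED_WARPSIZE == 32,
// rows_per_warp == 8; lane 13 gets lane_row == 3 and col == 1, and the
// shuffle loop then folds the warp's 8 rows down onto lanes 0..3 before the
// partial sums are staged through shared memory.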
336 
337 // Maps each block to a column range TF_RED_WARPSIZE wide
338 template <typename T, typename OUT_T, typename Op>
339 __global__ __launch_bounds__(1024) void ColumnReduceKernel(
340     T in, OUT_T out, int num_rows, int num_cols, Op op,
341     typename std::iterator_traits<T>::value_type initVal) {
342   typedef typename std::iterator_traits<T>::value_type value_type;
343   int row = blockIdx.y * blockDim.y + threadIdx.y;
344   int col = blockIdx.x * TF_RED_WARPSIZE + threadIdx.x;
345 
346   value_type sum = initVal;
347   if (row < num_rows && col < num_cols) sum = in[row * num_cols + col];
348 
349     // 1D array necessary due to bug in CUDA 9 compiler.
350     // TODO(nluehr) revert to 2D array when compiler is ready.
351     // This is to mimic the following, but without constructors:
352     //     __shared__ storage_type<value_type> partial_sums[TF_RED_WARPSIZE *
353     //     (TF_RED_WARPSIZE + 1)];
354 #if GOOGLE_CUDA
355   __shared__ __align__(alignof(value_type)) char
356       partial_sums_raw[TF_RED_WARPSIZE * (TF_RED_WARPSIZE + 1) *
357                        sizeof(value_type)];
358   value_type* partial_sums = reinterpret_cast<value_type*>(partial_sums_raw);
359 #elif TENSORFLOW_USE_ROCM
360   __shared__ storage_type<value_type>
361       partial_sums[TF_RED_WARPSIZE * (TF_RED_WARPSIZE + 1)];
362 #endif
363 
364   row += gridDim.y * blockDim.y;
365 
366   if (col < num_cols) {
367     for (; row < num_rows; row += gridDim.y * blockDim.y) {
368       sum = op(sum, in[row * num_cols + col]);
369     }
370   }
371 
372   partial_sums[threadIdx.x * (TF_RED_WARPSIZE + 1) + threadIdx.y] = sum;
373 
374   __syncthreads();
375 
376   if (threadIdx.y == 0 && col < num_cols) {
377     value_type s = partial_sums[threadIdx.x * (TF_RED_WARPSIZE + 1)];
378 
379     // only include input values in the reduction
380     // elem   block_rows
381     //  -         =
382     //  -         =
383     //  #         #  block boundary
384     //  -         =
385     //  -         =
386     //  #         #  block boundary
387     //  -         =
388     //            =
389     const int numRowsThisBlock =
390         min(static_cast<int>(blockDim.y), num_rows - blockIdx.y * blockDim.y);
391 
392     for (int row = 1; row < numRowsThisBlock; ++row) {
393       value_type t = partial_sums[threadIdx.x * (TF_RED_WARPSIZE + 1) + row];
394       s = op(s, t);
395     }
396 
397     out[col * gridDim.y + blockIdx.y] = s;
398   }
399 }
400 
401 // Does multiple warp-sized segmented reductions in parallel.
402 // Segments cannot cross warp boundaries (mainly used for reducing the segments
403 // that come from the Max16Columns column reduction kernel).
404 template <typename T, typename OUT_T, typename Op>
405 __global__ __launch_bounds__(1024) void CleanupSegments(
406     T partial_sums, OUT_T out, int num_rows, int num_cols, int segment_size,
407     Op op, typename std::iterator_traits<T>::value_type initVal) {
408   typedef typename std::iterator_traits<T>::value_type value_type;
409   const int tid = threadIdx.x + blockIdx.x * blockDim.x;
410 
411   value_type val = initVal;
412   if (tid < segment_size * num_cols) val = partial_sums[tid];
413 
414   typedef gpuprim::WarpReduce<value_type> WarpReduce;
415 
416   __shared__ typename WarpReduce::TempStorage temp_storage;
417 
418   const bool head_flag = (threadIdx.x % segment_size) == 0;
419   value_type sum =
420       WarpReduce(temp_storage).HeadSegmentedReduce(val, head_flag, op);
421 
422   if (head_flag && tid < segment_size * num_cols) {
423     out[tid / segment_size] = sum;
424   }
425 }
426 
427 // assigns one thread to a column
428 template <typename T, typename OUT_T, typename Op>
429 __global__ __launch_bounds__(1024) void ColumnReduceSimpleKernel(
430     T in, OUT_T out, int num_planes, int num_rows, int num_cols, Op op) {
431   typedef typename std::iterator_traits<T>::value_type value_type;
432   const int gid = threadIdx.x + blockIdx.x * blockDim.x;
433   const int elems_per_plane = num_rows * num_cols;
434 
435   const int plane = gid / num_cols;
436   const int col = gid % num_cols;
437 
438   if (plane >= num_planes) return;
439 
440   if (num_rows == 1) {
441     out[plane * elems_per_plane + col] = in[plane * elems_per_plane + col];
442     return;
443   }
444 
445   value_type sum = op(in[plane * elems_per_plane + col],
446                       in[plane * elems_per_plane + num_cols + col]);
447   for (int row = 2; row < num_rows; ++row) {
448     sum = op(sum, in[plane * elems_per_plane + row * num_cols + col]);
449   }
450 
451   out[plane * num_cols + col] = sum;
452 }
453 
454 namespace {
455 constexpr int kUnroll = 8;
456 }
457 
458 template <typename T, typename IN_T, typename Op>
459 __device__ __inline__ T ComputeSum(IN_T in_, const int plane,
460                                    const int num_out_rows, int num_rows,
461                                    int num_cols, const int col, Op op) {
462   const int out_rows = num_rows / (2 * kUnroll);
463   const int num_rem_rows = num_rows % (2 * kUnroll);
464   const int elems_per_plane = num_rows * num_cols;
465   T reg[2 * kUnroll];
466   T sum;
467   int offset = 0;
468   if (out_rows != 0) {
469     for (int i = 0; i < 2 * kUnroll; i++) {
470       reg[i] =
471           in_[plane * elems_per_plane + i * (num_out_rows * num_cols) + col];
472     }
473     sum = reg[0];
474     for (int i = 1; i < 2 * kUnroll; i++) {
475       sum = op(sum, reg[i]);
476     }
477     offset = 2 * kUnroll * (num_out_rows * num_cols);
478   }
479 
480   if (col < num_cols && num_rem_rows > 0) {
481     reg[0] = in_[plane * elems_per_plane + offset + 0 * num_cols + col];
482     if (out_rows != 0) {
483       sum = op(sum, reg[0]);
484     } else {
485       sum = reg[0];
486     }
487     for (int i = 1; i < num_rem_rows; i++) {
488       reg[0] = in_[plane * elems_per_plane + offset + i * num_cols + col];
489       sum = op(sum, reg[0]);
490     }
491   }
492   return sum;
493 }
494 
495 template <typename IN_T, typename Op>
496 __global__ __launch_bounds__(1024) void ColumnReduceInToTempKernel(
497     void* __restrict__ temp, int temp_in_offset, int temp_out_offset, IN_T in,
498     int num_planes, int num_rows, int num_cols, Op op) {
499   typedef typename std::iterator_traits<IN_T>::value_type value_type;
500 
501   value_type* t = (value_type*)temp;
502   value_type* out_ = t + temp_out_offset;
503 
504   const int gid = threadIdx.x + blockIdx.x * blockDim.x;
505   const int num_out_rows = max(1, num_rows / (2 * kUnroll));
506   const int plane = gid / (num_out_rows * num_cols);
507   const int col = gid % (num_out_rows * num_cols);
508 
509   if (plane >= num_planes) return;
510 
511   value_type sum;
512   if (temp_in_offset == -1) {
513     auto in_ = in;
514     sum = ComputeSum<value_type, IN_T, Op>(in_, plane, num_out_rows, num_rows,
515                                            num_cols, col, op);
516   } else {
517     auto in_ = t + temp_in_offset;
518     sum = ComputeSum<value_type, value_type*, Op>(in_, plane, num_out_rows,
519                                                   num_rows, num_cols, col, op);
520   }
521   out_[plane * num_out_rows * num_cols + col] = sum;
522 }
523 
524 template <typename T, typename OUT_T, typename Op>
525 __global__ __launch_bounds__(1024) void ColumnReduceTempToOutKernel(
526     void* __restrict__ temp, int temp_in_offset, T in, OUT_T out,
527     int num_planes, int num_rows, int num_cols, Op op) {
528   typedef typename std::iterator_traits<T>::value_type value_type;
529   value_type* t = (value_type*)temp;
530   const int tid = threadIdx.x;
531   const int gid = threadIdx.x + blockIdx.x * blockDim.x;
532   int elems_per_plane = num_rows * num_cols;
533 
534   if (num_rows == 1) {
535     if (gid >= num_planes * num_cols) return;
536     if (temp_in_offset == -1) {
537       auto in_ = in;
538       out[gid] = in_[gid];
539     } else {
540       auto in_ = t + temp_in_offset;
541       out[gid] = in_[gid];
542     }
543     return;
544   }
545 
546   const int planes_per_block = 1;
547   const int plane = blockIdx.x * planes_per_block + tid / elems_per_plane;
548   // A thread block contains one or more planes,
549   // i.e. num_rows * num_cols <= blockDim.x.
550   const int col = tid % elems_per_plane;
551   const int local_plane = plane % planes_per_block;
552 
553   if (tid >= planes_per_block * elems_per_plane || plane >= num_planes) return;
554 
555   GPU_DYNAMIC_SHARED_MEM_DECL(8, char, ss);
556   value_type* const smem = reinterpret_cast<value_type*>(ss);
557 
558   if (temp_in_offset == -1) {
559     auto in_ = in;
560     smem[local_plane * elems_per_plane + col] =
561         in_[plane * elems_per_plane + col];
562   } else {
563     auto in_ = t + temp_in_offset;
564     smem[local_plane * elems_per_plane + col] =
565         in_[plane * elems_per_plane + col];
566   }
567   __syncthreads();
568 
569   int num_in_rows = num_rows;
570   int num_out_rows;
571   int num_rem_rows;
572 
573   int in_offset = 0;
574   int out_offset = blockDim.x;
575 
576   int in_elems_per_plane = elems_per_plane;
577   int out_elems_per_plane;
578 
579   while (num_in_rows > 1) {
580     num_out_rows = num_in_rows / 2;
581     num_rem_rows = num_in_rows % 2;
582     out_elems_per_plane = num_out_rows * num_cols;
583 
584     if (col < out_elems_per_plane) {
585       value_type sum;
586       sum = op(smem[in_offset + local_plane * in_elems_per_plane + col],
587                smem[in_offset + local_plane * in_elems_per_plane +
588                     out_elems_per_plane + col]);
589       if (num_rem_rows == 1 && col < num_cols) {
590         sum = op(sum, smem[in_offset + local_plane * in_elems_per_plane +
591                            2 * out_elems_per_plane + col]);
592       }
593       smem[out_offset + local_plane * out_elems_per_plane + col] = sum;
594     }
595 
596     num_in_rows = num_out_rows;
597     in_elems_per_plane = out_elems_per_plane;
598     int t_offset = in_offset;
599     in_offset = out_offset;
600     out_offset = t_offset;
601     __syncthreads();
602   }
603 
604   if (col < num_cols) {
605     out[plane * num_cols + col] =
606         smem[in_offset + local_plane * out_elems_per_plane + col];
607   }
608 }
609 
610 struct RowOffset {
611   __host__ __device__ explicit RowOffset(const int& cols) : cols_(cols) {}
612 
613   __host__ __device__ int operator()(const int& x) const { return cols_ * x; }
614 
615   int cols_;
616 };
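// For example, RowOffset(num_cols)(i) == i * num_cols, so applying it to a
// counting iterator yields the segment offsets 0, num_cols, 2 * num_cols, ...,
// i.e. segment i covers [i * num_cols, (i + 1) * num_cols).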
617 
618 struct GatherOp {
619   __host__ __device__ GatherOp(const int& extent_x, const int& extent_y,
620                                const int& extent_z, bool kOne)
621       : extent_x_(extent_x),
622         extent_y_(extent_y),
623         extent_z_(extent_z),
624         kOne_(kOne) {
625     if (kOne_)
626       group_size_ = extent_y_;
627     else
628       group_size_ = extent_x_ * extent_z_;
629   }
630 
631   __host__ __device__ int operator()(const int& ind) const {
632     const int group = kOne_ ? ind / group_size_ : ind % group_size_;
633     const int offset = kOne_ ? ind % group_size_ : ind / group_size_;
634 
635     const int x = group / extent_z_;
636     const int z = group % extent_z_;
637 
638     return x * extent_y_ * extent_z_ + z + offset * extent_z_;
639   }
640 
641   int extent_x_;
642   int extent_y_;
643   int extent_z_;
644   bool kOne_;
645   int group_size_;
646 };
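// Worked example for GatherOp (illustrative extents): with extent_x = 2,
// extent_y = 3, extent_z = 4 and kOne = false, group_size_ = 8 and
// operator()(5) = 1 * 3 * 4 + 1 + 0 = 13, i.e. element (x = 1, y = 0, z = 1)
// of the row-major input; each run of 8 consecutive output indices enumerates
// every (x, z) pair for one y, which is what lets Launch3DXZReduction reduce
// the x and z axes as contiguous segments.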
647 
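// Dispatches a full reduction to a scalar. Roughly: a single 256-thread block
// for very small inputs, a two-kernel pass (block partials followed by
// CleanupSegments) for medium inputs, and gpuprim::DeviceReduce otherwise.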
648 template <typename T, typename Op, typename OUT_T, typename IN_T>
649 void LaunchScalarReduction(OpKernelContext* ctx, OUT_T out, IN_T in,
650                            int in_size, Op op, T init,
651                            const gpuStream_t& cu_stream) {
652   // Handles small, latency-sensitive reductions better than CUB does.
653   if (in_size <= 4096) {
654     const int num_blocks = 1;
655     const int num_threads = 256;
656     TF_CHECK_OK(GpuLaunchKernel(BlockReduceKernel<IN_T, OUT_T, num_threads, Op>,
657                                 num_blocks, num_threads, 0, cu_stream, in, out,
658                                 in_size, op, init));
659     return;
660   } else if (in_size <= 1 << 18) {
661     const int num_threads = 256;
662     const int num_blocks =
663         std::min(TF_RED_WARPSIZE, Eigen::divup(in_size, num_threads));
664     // it seems like tailoring this to the GPU
665     // would be more effective, but all attempts
666     // at making this a multiple of the number of
667     // multiprocessors have led to lower perf
668     // in general
669     // TODO(eriche) investigate this more
670 
671     Tensor temp_storage;
672     OP_REQUIRES_OK(
673         ctx, ctx->allocate_temp(
674                  DT_INT8,
675                  TensorShape({static_cast<int64_t>(num_blocks * sizeof(T))}),
676                  &temp_storage));
677 
678     TF_CHECK_OK(GpuLaunchKernel(BlockReduceKernel<IN_T, T*, num_threads, Op>,
679                                 num_blocks, num_threads, 0, cu_stream, in,
680                                 (T*)temp_storage.flat<int8_t>().data(), in_size,
681                                 op, init));
682 
683     // take care that we only reduce blocks that had some valid elements in them
684     // TODO(eriche): CUB currently has a bug in HeadSegmentedReduce that
685     // requires it to be used with a full warp.  Can reduce TF_RED_WARPSIZE ->
686     // num_blocks when this is fixed.
687     TF_CHECK_OK(GpuLaunchKernel(CleanupSegments<T*, OUT_T, Op>, 1,
688                                 TF_RED_WARPSIZE, 0, cu_stream,
689                                 (T*)temp_storage.flat<int8_t>().data(), out, 1,
690                                 1, num_blocks, op, init));
691     return;
692   }
693 
694   size_t temp_storage_bytes = 0;
695   auto reduce = [&](void* temp_storage_ptr) {
696     auto success =
697         gpuprim::DeviceReduce::Reduce(temp_storage_ptr, temp_storage_bytes, in,
698                                       out, in_size, op, init, cu_stream);
699 
700     OP_REQUIRES(
701         ctx, success == 0,
702         errors::Internal("CUB reduce error ", GpuGetErrorString(success)));
703   };
704 
705   reduce(nullptr);  // Get required amount of temp storage.
706 
707   Tensor temp_storage;
708   OP_REQUIRES_OK(
709       ctx, ctx->allocate_temp(
710                DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
711                &temp_storage));
712 
713   reduce(temp_storage.flat<int8_t>().data());  // Do reduction.
714 }
715 
716 template <typename T, typename Op, typename OUT_T, typename IN_T>
717 void LaunchRowReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int num_rows,
718                         int num_cols, Op op, T init,
719                         const gpuStream_t& cu_stream) {
720   if (num_cols < 1024) {
721     const int threads_per_block = 128;
722     const int warps_per_block = threads_per_block / TF_RED_WARPSIZE;
723     int num_blocks = (num_rows + warps_per_block - 1) / warps_per_block;
724 
725     TF_CHECK_OK(GpuLaunchKernel(RowReduceKernel<IN_T, OUT_T, Op>, num_blocks,
726                                 threads_per_block, 0, cu_stream, in, out,
727                                 num_rows, num_cols, op, init));
728     return;
729   }
730 
731   // Set up segment offsets with a counting and a transform iterator.
732   RowOffset row_offset_op(num_cols);
733   gpuprim::CountingInputIterator<int> counting_iter(0);
734   gpuprim::TransformInputIterator<int, RowOffset,
735                                   gpuprim::CountingInputIterator<int>>
736       transform_iter(counting_iter, row_offset_op);
737 
738   size_t temp_storage_bytes = 0;
739   auto reduce = [&](void* temp_storage_ptr) {
740     auto success = gpuprim::DeviceSegmentedReduce::Reduce(
741         temp_storage_ptr, temp_storage_bytes, in, out, num_rows, transform_iter,
742         transform_iter + 1, op, init, cu_stream);
743 
744     OP_REQUIRES(ctx, success == 0,
745                 errors::Internal("CUB segmented reduce error",
746                                  GpuGetErrorString(success)));
747   };
748 
749   reduce(nullptr);  // Get required amount of temp storage.
750 
751   Tensor temp_storage;
752   OP_REQUIRES_OK(
753       ctx, ctx->allocate_temp(
754                DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
755                &temp_storage));
756 
757   reduce(temp_storage.flat<int8_t>().data());  // Do reduction.
758 }
759 
760 template <typename T, typename Op, typename OUT_T, typename IN_T>
761 void LaunchColumnReduction_LTE16Cols(OpKernelContext* ctx, OUT_T out, IN_T in,
762                                      int extent_x, int extent_y, Op op, T init,
763                                      const gpuStream_t& cu_stream) {
764   int rows_per_warp = TF_RED_WARPSIZE / extent_y;
765   dim3 block_dim(
766       TF_RED_WARPSIZE,
767       std::min(Eigen::divup(extent_x, rows_per_warp), (1024 / TF_RED_WARPSIZE)),
768       1);
769   dim3 grid_dim(1,
770                 Eigen::divup(static_cast<unsigned int>(extent_x),
771                              rows_per_warp * block_dim.y),
772                 1);
773 
774   grid_dim.y = std::min((int)grid_dim.y, TF_RED_WARPSIZE);
775 
776   if (grid_dim.y > 2 && grid_dim.y < TF_RED_WARPSIZE) {
777     int log2 = Log2Floor(grid_dim.y);
778     grid_dim.y = 1 << log2;
779   }
780 
781   if (grid_dim.y == 1) {
782     TF_CHECK_OK(GpuLaunchKernel(ColumnReduceMax16ColumnsKernel<IN_T, OUT_T, Op>,
783                                 grid_dim, block_dim, 0, cu_stream, in, out,
784                                 extent_x, extent_y, op, init));
785   } else {
786     Tensor temp_storage;
787     OP_REQUIRES_OK(ctx,
788                    ctx->allocate_temp(DT_INT8,
789                                       TensorShape({static_cast<int64_t>(
790                                           sizeof(T) * extent_y * grid_dim.y)}),
791                                       &temp_storage));
792     TF_CHECK_OK(GpuLaunchKernel(ColumnReduceMax16ColumnsKernel<IN_T, T*, Op>,
793                                 grid_dim, block_dim, 0, cu_stream, in,
794                                 (T*)temp_storage.flat<int8_t>().data(),
795                                 extent_x, extent_y, op, init));
796 
797     dim3 new_grid_dim(
798         (grid_dim.y * extent_y + (TF_RED_WARPSIZE - 1)) / TF_RED_WARPSIZE, 1,
799         1);
800     dim3 num_threads(128, 1, 1);
801     TF_CHECK_OK(GpuLaunchKernel(CleanupSegments<T*, OUT_T, Op>, new_grid_dim,
802                                 num_threads, 0, cu_stream,
803                                 (T*)temp_storage.flat<int8_t>().data(), out,
804                                 extent_x, extent_y, grid_dim.y, op, init));
805   }
806 }
807 
808 template <typename T, typename Op, typename OUT_T, typename IN_T>
809 void LaunchColumnReduction_LTE4096Cols(OpKernelContext* ctx, OUT_T out, IN_T in,
810                                        int extent_x, int extent_y, Op op,
811                                        T init, const gpuStream_t& cu_stream) {
812   dim3 block_dim(TF_RED_WARPSIZE, std::min(extent_x, (1024 / TF_RED_WARPSIZE)),
813                  1);
814   dim3 grid_dim((extent_y + (TF_RED_WARPSIZE - 1)) / TF_RED_WARPSIZE, 1, 1);
815 
816   if (grid_dim.x < 16)
817     grid_dim.y = std::min((extent_x + (TF_RED_WARPSIZE - 1)) / TF_RED_WARPSIZE,
818                           TF_RED_WARPSIZE);
819 
820   if (grid_dim.y > 2 && grid_dim.y < TF_RED_WARPSIZE) {
821     int log2 = Log2Floor(grid_dim.y);
822     grid_dim.y = 1 << log2;
823   }
824 
825   if (grid_dim.y == 1) {
826     TF_CHECK_OK(GpuLaunchKernel(ColumnReduceKernel<IN_T, OUT_T, Op>, grid_dim,
827                                 block_dim, 0, cu_stream, in, out, extent_x,
828                                 extent_y, op, init));
829   } else {
830     Tensor temp_storage;
831     OP_REQUIRES_OK(ctx,
832                    ctx->allocate_temp(DT_INT8,
833                                       TensorShape({static_cast<int64_t>(
834                                           sizeof(T) * extent_y * grid_dim.y)}),
835                                       &temp_storage));
836 
837     TF_CHECK_OK(GpuLaunchKernel(
838         ColumnReduceKernel<IN_T, T*, Op>, grid_dim, block_dim, 0, cu_stream, in,
839         (T*)temp_storage.flat<int8_t>().data(), extent_x, extent_y, op, init));
840 
841     dim3 new_grid_dim(
842         (grid_dim.y * extent_y + (TF_RED_WARPSIZE - 1)) / TF_RED_WARPSIZE, 1,
843         1);
844     TF_CHECK_OK(GpuLaunchKernel(CleanupSegments<T*, OUT_T, Op>, new_grid_dim,
845                                 block_dim, 0, cu_stream,
846                                 (T*)temp_storage.flat<int8_t>().data(), out,
847                                 extent_x, extent_y, grid_dim.y, op, init));
848   }
849 }
850 
851 template <typename T, typename Op, typename OUT_T, typename IN_T>
852 void LaunchColumnReduction(OpKernelContext* ctx, OUT_T out, IN_T in,
853                            int extent_x, int extent_y, Op op, T init,
854                            const gpuStream_t& cu_stream) {
855   if (extent_y <= 16) {
856     LaunchColumnReduction_LTE16Cols(ctx, out, in, extent_x, extent_y, op, init,
857                                     cu_stream);
858   } else if (extent_y <= 4096) {
859     LaunchColumnReduction_LTE4096Cols(ctx, out, in, extent_x, extent_y, op,
860                                       init, cu_stream);
861   } else {
862     int threads_per_block = 128;
863     int num_blocks = Eigen::divup(extent_y, threads_per_block);
864 
865     TF_CHECK_OK(GpuLaunchKernel(ColumnReduceSimpleKernel<IN_T, OUT_T, Op>,
866                                 num_blocks, threads_per_block, 0, cu_stream, in,
867                                 out, 1, extent_x, extent_y, op));
868   }
869 }
870 
871 template <typename T, typename Op, typename OUT_T, typename IN_T>
872 void Launch3DYReductionSimple(OpKernelContext* ctx, OUT_T out, IN_T in,
873                               int extent_x, int extent_y, int extent_z, Op op,
874                               T init, const gpuStream_t& cu_stream) {
875   int threads_per_block = 128;
876   int num_blocks =
877       (extent_x * extent_z + threads_per_block - 1) / threads_per_block;
878 
879   // TODO(eriche): this won't be very good in the case of small x
880   //                small z and large y.
881   TF_CHECK_OK(GpuLaunchKernel(ColumnReduceSimpleKernel<IN_T, OUT_T, Op>,
882                               num_blocks, threads_per_block, 0, cu_stream, in,
883                               out, extent_x, extent_y, extent_z, op));
884 }
885 
886 template <typename T, typename Op, typename OUT_T, typename IN_T>
887 void Launch3DYReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x,
888                         int extent_y, int extent_z, Op op, T init,
889                         const gpuStream_t& cu_stream) {
890   int threads_per_block = 128;
891 
892   int n_group_in = extent_y;
893   int n_size = extent_z;
894 
895   // Calculate and allocate temporary space
896   std::size_t temp_storage_bytes = 0;
897   // A plane's size is n_group_in * n_size. We make sure no single plane crosses
898   // more than one thread block, meaning a thread block will handle one whole
899   // plane or multiple planes in the second stage. Also, it may handle a partial
900   // plane when n_size is too large and the while-loop will stop at
901   // n_group_in = 1, where we directly copy the temp to output in the next
902   // stage.
903   while (n_group_in >= 2 && n_group_in * n_size > threads_per_block) {
904     int n_group_out = std::max(1, n_group_in / (2 * kUnroll));
905     temp_storage_bytes += n_group_out * n_size;
906     n_group_in = n_group_out;
907   }
908   temp_storage_bytes *= extent_x * sizeof(T);
909   Tensor temp_storage;
910   OP_REQUIRES_OK(
911       ctx, ctx->allocate_temp(
912                DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
913                &temp_storage));
914 
915   // Reduction
916   n_group_in = extent_y;
917   int temp_in_offset = -1;
918   int temp_out_offset = 0;
919   int num_blocks;
920   while (n_group_in >= 2 && n_group_in * n_size > threads_per_block) {
921     int n_group_out = std::max(1, n_group_in / (2 * kUnroll));
922     num_blocks =
923         Eigen::divup(extent_x * n_group_out * n_size, threads_per_block);
924     TF_CHECK_OK(GpuLaunchKernel(
925         ColumnReduceInToTempKernel<IN_T, Op>, num_blocks, threads_per_block, 0,
926         cu_stream, (void*)(temp_storage.flat<int8_t>().data()), temp_in_offset,
927         temp_out_offset, in, extent_x, n_group_in, extent_z, op));
928 
929     n_group_in = n_group_out;
930     temp_in_offset = temp_out_offset;
931     temp_out_offset = temp_in_offset + extent_x * n_group_out * n_size;
932   }
933 
934   if (n_group_in * n_size <= threads_per_block) {
935     num_blocks = extent_x;
936   } else {
937     DCHECK_EQ(1, n_group_in);
938     num_blocks = Eigen::divup(extent_x * n_size, threads_per_block);
939   }
940 
941   TF_CHECK_OK(GpuLaunchKernel(
942       ColumnReduceTempToOutKernel<IN_T, OUT_T, Op>, num_blocks,
943       threads_per_block, 2 * sizeof(T) * threads_per_block, cu_stream,
944       (void*)(temp_storage.flat<int8_t>().data()), temp_in_offset, in, out,
945       extent_x, n_group_in, extent_z, op));
946 }
947 
948 template <typename T, typename Op, typename OUT_T, typename IN_T>
949 void Launch3DXZReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x,
950                          int extent_y, int extent_z, Op op, T init,
951                          const gpuStream_t& cu_stream) {
952   // Set up segment offsets with a counting and a transform iterator.
953   RowOffset row_offset_op(extent_x * extent_z);
954   gpuprim::CountingInputIterator<int> counting_iter(0);
955   gpuprim::TransformInputIterator<int, RowOffset,
956                                   gpuprim::CountingInputIterator<int>>
957       transform_iter(counting_iter, row_offset_op);
958 
959   GatherOp gather_op(extent_x, extent_y, extent_z, false);
960   typedef gpuprim::TransformInputIterator<int, GatherOp,
961                                           gpuprim::CountingInputIterator<int>>
962       gatherIterType;
963   gatherIterType gather_iter(counting_iter, gather_op);
964 
965   PermutationInputIterator<T, IN_T, gatherIterType> permute_iter(in,
966                                                                  gather_iter);
967 
968   std::size_t temp_storage_bytes = 0;
969   auto reduce = [&](void* temp_storage_ptr) {
970     auto success = gpuprim::DeviceSegmentedReduce::Reduce(
971         temp_storage_ptr, temp_storage_bytes, permute_iter, out, extent_y,
972         transform_iter, transform_iter + 1, op, init, cu_stream);
973 
974     OP_REQUIRES(ctx, success == 0,
975                 errors::Internal("CUB segmented reduce error",
976                                  GpuGetErrorString(success)));
977   };
978 
979   reduce(nullptr);  // Get required amount of temp storage.
980 
981   Tensor temp_storage;
982   OP_REQUIRES_OK(
983       ctx, ctx->allocate_temp(
984                DT_INT8, TensorShape({static_cast<int64_t>(temp_storage_bytes)}),
985                &temp_storage));
986 
987   reduce(temp_storage.flat<int8_t>().data());  // Do reduction.
988 }
989 
990 namespace reduction_op_helper {
991 
992 template <typename T, typename Op>
993 struct IsSum {
994   constexpr static bool value =
995       (std::is_same<Op, gpuprim::Sum>::value ||
996        std::is_same<Op, Eigen::internal::SumReducer<T>>::value ||
997        std::is_same<Op, Sum<T>>::value);
998 };
999 
1000 template <typename T, typename Op>
1001 struct IsMax {
1002   constexpr static bool value =
1003       (std::is_same<Op, MaxPropagateNaN>::value ||
1004        std::is_same<Op, gpuprim::Max>::value ||
1005        std::is_same<
1006            Op, Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>>::value);
1007 };
1008 
1009 template <typename T, typename Op>
1010 struct IsMin {
1011   constexpr static bool value =
1012       (std::is_same<Op, MinPropagateNaN>::value ||
1013        std::is_same<Op, gpuprim::Min>::value ||
1014        std::is_same<
1015            Op, Eigen::internal::MinReducer<T, Eigen::PropagateNaN>>::value);
1016 };
1017 
1018 template <typename T, typename Op>
1019 struct IsProd {
1020   constexpr static bool value =
1021       (std::is_same<Op, Prod<T>>::value ||
1022        std::is_same<Op, Eigen::internal::ProdReducer<T>>::value);
1023 };
1024 
1025 template <typename T, typename Op>
1026 struct IdentityValue {
1027   static_assert(IsSum<T, Op>::value || IsMax<T, Op>::value ||
1028                     IsMin<T, Op>::value || IsProd<T, Op>::value ||
1029                     std::is_same<Op, And>::value || std::is_same<Op, Or>::value,
1030                 "IdentityValue not yet defined for this type");
1031 
1032   template <typename U = T, typename OpCopy = Op>
1033   U operator()(
1034       typename std::enable_if<IsSum<U, OpCopy>::value, U>::type t = U(0)) {
1035     return t;
1036   }
1037 
1038   template <typename U = T, typename OpCopy = Op>
1039   U operator()(typename std::enable_if<IsMax<U, OpCopy>::value, U>::type t =
1040                    Eigen::NumTraits<U>::lowest()) {
1041     return t;
1042   }
1043 
1044   template <typename U = T, typename OpCopy = Op>
1045   U operator()(typename std::enable_if<IsMin<U, OpCopy>::value, U>::type t =
1046                    Eigen::NumTraits<U>::highest()) {
1047     return t;
1048   }
1049 
1050   template <typename U = T, typename OpCopy = Op>
1051   U operator()(
1052       typename std::enable_if<IsProd<U, OpCopy>::value, U>::type t = U(1)) {
1053     return t;
1054   }
1055 
1056   template <typename U = T, typename OpCopy = Op>
1057   U operator()(typename std::enable_if<std::is_same<OpCopy, And>::value,
1058                                        bool>::type t = true) {
1059     return t;
1060   }
1061 
1062   template <typename U = T, typename OpCopy = Op>
1063   U operator()(typename std::enable_if<std::is_same<OpCopy, Or>::value,
1064                                        bool>::type t = false) {
1065     return t;
1066   }
1067 };
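// For example, IdentityValue<float, Sum<float>>()() yields 0.0f and
// IdentityValue<float, MaxPropagateNaN>()() yields
// Eigen::NumTraits<float>::lowest(); ReduceImpl below passes this value to the
// kernels as the initial accumulator.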
1068 
1069 }  // namespace reduction_op_helper
1070 
1071 template <typename T, typename Op, typename OUT_T, typename IN_T,
1072           typename ReductionAxes>
1073 void ReduceImpl(OpKernelContext* ctx, OUT_T out, IN_T in, int in_rank,
1074                 int in_dim0, int in_dim1, int in_dim2, int out_rank,
1075                 const ReductionAxes& reduction_axes, Op op) {
1076   T init = reduction_op_helper::IdentityValue<T, Op>()();
1077   const gpuStream_t& cu_stream = GetGpuStream(ctx);
1078   if (out_rank == 0) {
1079     const int in_size = in_dim0 * in_dim1 * in_dim2;
1080     LaunchScalarReduction(ctx, out, in, in_size, op, init, cu_stream);
1081   } else if (in_rank == 2 && out_rank == 1 &&
1082              reduction_axes[0] == 1) {  // row reduction
1083     LaunchRowReduction(ctx, out, in, in_dim0, in_dim1, op, init, cu_stream);
1084   } else if (in_rank == 2 && out_rank == 1 &&
1085              reduction_axes[0] == 0) {  // column reduction
1086     LaunchColumnReduction(ctx, out, in, in_dim0, in_dim1, op, init, cu_stream);
1087   } else if (in_rank == 3 && out_rank == 2 && reduction_axes[0] == 1) {
1088     int elems_per_thread = in_dim1 / (in_dim0 * in_dim2);
1089     if (elems_per_thread >= 16) {
1090       Launch3DYReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init,
1091                          cu_stream);
1092     } else {
1093       Launch3DYReductionSimple(ctx, out, in, in_dim0, in_dim1, in_dim2, op,
1094                                init, cu_stream);
1095     }
1096   } else if (in_rank == 3 && out_rank == 1 && reduction_axes[0] == 0 &&
1097              reduction_axes[1] == 2) {
1098     Launch3DXZReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init,
1099                         cu_stream);
1100   } else {
1101     std::stringstream ss;
1102     ss << "Invalid reduction requested: in_rank, out_rank, axes " << in_rank
1103        << " " << out_rank;
1104     if (out_rank == 1) ss << " " << reduction_axes[0];
1105     if (out_rank == 2) ss << " " << reduction_axes[1];
1106     LOG(FATAL) << ss.str();
1107   }
1108 }
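// Dispatch example (illustrative shapes): summing a [32, 100, 8] input over
// axis 1 gives elems_per_thread = 100 / (32 * 8) == 0 and takes the
// Launch3DYReductionSimple path, while a [4, 4096, 8] input gives
// elems_per_thread == 128 and uses the multi-pass Launch3DYReduction.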
1109 
1110 template <typename Reducer>
1111 struct ReduceFunctor<GPUDevice, Reducer> {
1112   template <typename OUT_T, typename IN_T, typename ReductionAxes>
1113   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
1114                      const ReductionAxes& reduction_axes,
1115                      const Reducer& reducer);
1116 };
1117 
1118 template <typename T>
1119 struct ReduceFunctor<GPUDevice, Eigen::internal::SumReducer<T>> {
1120   template <typename OUT_T, typename IN_T, typename ReductionAxes>
1121   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
1122                      const ReductionAxes& reduction_axes,
1123                      const Eigen::internal::SumReducer<T>& reducer) {
1124     ReduceImpl<T, Sum<T>, T*, T*, ReductionAxes>(
1125         ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
1126         in.rank() >= 2 ? in.dimension(1) : 1,
1127         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
1128         Sum<T>());
1129   }
1130 
1131   template <typename OUT_T>
1132   static void FillIdentity(const GPUDevice& d, OUT_T out,
1133                            const Eigen::internal::SumReducer<T>& reducer) {
1134     FillIdentityEigenImpl(d, out, reducer);
1135   }
1136 };
1137 
1138 // TODO(rmlarsen): Specialize for float16.
1139 template <typename T>
1140 struct ReduceFunctor<GPUDevice, functor::EuclideanNormReducer<T>> {
1141   template <typename OUT_T, typename IN_T, typename ReductionAxes>
1142   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
1143                      const ReductionAxes& reduction_axes,
1144                      const functor::EuclideanNormReducer<T>& reducer) {
1145     typedef gpuprim::TransformInputIterator<T, Square<T>, T*> inputIterType;
1146     inputIterType input_itr((T*)in.data(), Square<T>());
1147     typedef TransformOutputIterator<T, T, SqrtOfReal<T>> outputIterType;
1148     outputIterType output_itr((T*)out.data(), SqrtOfReal<T>());
1149     ReduceImpl<T, Sum<T>, outputIterType, inputIterType, ReductionAxes>(
1150         ctx, output_itr, input_itr, in.rank(), in.dimension(0),
1151         in.rank() >= 2 ? in.dimension(1) : 1,
1152         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
1153         Sum<T>());
1154   }
1155 
1156   template <typename OUT_T>
1157   static void FillIdentity(const GPUDevice& d, OUT_T out,
1158                            const functor::EuclideanNormReducer<T>& reducer) {
1159     FillIdentityEigenImpl(d, out, reducer);
1160   }
1161 };
1162 
1163 template <typename T>
1164 struct ReduceFunctor<GPUDevice, functor::MeanReducer<T>> {
1165   template <typename OUT_T, typename IN_T, typename ReductionAxes>
1166   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
1167                      const ReductionAxes& reduction_axes,
1168                      const functor::MeanReducer<T>& reducer) {
1169     int divisor = 1;
1170     if (out.rank() == 0)
1171       divisor = in.size();
1172     else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 0)
1173       divisor = in.dimension(0);
1174     else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 1)
1175       divisor = in.dimension(1);
1176     else if (out.rank() == 1 && in.rank() == 3 && reduction_axes[0] == 0 &&
1177              reduction_axes[1] == 2)
1178       divisor = in.dimension(0) * in.dimension(2);
1179     else if (out.rank() == 2 && in.rank() == 3 && reduction_axes[0] == 1)
1180       divisor = in.dimension(1);
1181 
1182     DividesBy<T> div_op(static_cast<T>(divisor));
1183     TransformOutputIterator<T, T, DividesBy<T>> itr((T*)out.data(), div_op);
1184     ReduceImpl<T, Sum<T>, TransformOutputIterator<T, T, DividesBy<T>>, T*,
1185                ReductionAxes>(ctx, itr, (T*)in.data(), in.rank(),
1186                               in.dimension(0),
1187                               in.rank() >= 2 ? in.dimension(1) : 1,
1188                               in.rank() >= 3 ? in.dimension(2) : 1, out.rank(),
1189                               reduction_axes, Sum<T>());
1190   }
1191 
1192   template <typename OUT_T>
1193   static void FillIdentity(const GPUDevice& d, OUT_T out,
1194                            const functor::MeanReducer<T>& reducer) {
1195     FillIdentityEigenImpl(d, out, reducer);
1196   }
1197 };
1198 
1199 template <>
1200 struct ReduceFunctor<GPUDevice, functor::MeanReducer<Eigen::half>> {
1201   template <typename OUT_T, typename IN_T, typename ReductionAxes>
1202   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
1203                      const ReductionAxes& reduction_axes,
1204                      const functor::MeanReducer<Eigen::half>& reducer) {
1205     float divisor = 1.f;
1206     if (out.rank() == 0)
1207       divisor = in.size();
1208     else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 0)
1209       divisor = in.dimension(0);
1210     else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 1)
1211       divisor = in.dimension(1);
1212     else if (out.rank() == 1 && in.rank() == 3 && reduction_axes[0] == 0 &&
1213              reduction_axes[1] == 2)
1214       divisor = in.dimension(0) * in.dimension(2);
1215     else if (out.rank() == 2 && in.rank() == 3 && reduction_axes[0] == 1)
1216       divisor = in.dimension(1);
1217     DividesBy<float, Eigen::half> div_op(divisor);
1218 
1219     typedef gpuprim::TransformInputIterator<float, HalfToFloat, Eigen::half*>
1220         inputIterType;
1221     inputIterType input_itr((Eigen::half*)in.data(), HalfToFloat());
1222 
1223     typedef TransformOutputIterator<Eigen::half, float,
1224                                     DividesBy<float, Eigen::half>>
1225         outputIterType;
1226     outputIterType itr((Eigen::half*)out.data(), div_op);
1227 
1228     ReduceImpl<float, gpuprim::Sum, outputIterType, inputIterType,
1229                ReductionAxes>(ctx, itr, input_itr, in.rank(), in.dimension(0),
1230                               in.rank() >= 2 ? in.dimension(1) : 1,
1231                               in.rank() >= 3 ? in.dimension(2) : 1, out.rank(),
1232                               reduction_axes, gpuprim::Sum());
1233   }
1234 
1235   template <typename OUT_T>
1236   static void FillIdentity(const GPUDevice& d, OUT_T out,
1237                            const functor::MeanReducer<Eigen::half>& reducer) {
1238     FillIdentityEigenImpl(d, out, reducer);
1239   }
1240 };
1241 
1242 template <typename T>
1243 struct ReduceFunctor<GPUDevice,
1244                      Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>> {
1245   template <typename OUT_T, typename IN_T, typename ReductionAxes>
1246   static void Reduce(
1247       OpKernelContext* ctx, OUT_T out, IN_T in,
1248       const ReductionAxes& reduction_axes,
1249       const Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>& reducer) {
1250     ReduceImpl<T, MaxPropagateNaN, T*, T*, ReductionAxes>(
1251         ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
1252         in.rank() >= 2 ? in.dimension(1) : 1,
1253         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
1254         MaxPropagateNaN());
1255   }
1256 
1257   template <typename OUT_T>
1258   static void FillIdentity(
1259       const GPUDevice& d, OUT_T out,
1260       const Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>& reducer) {
1261     FillIdentityEigenImpl(d, out, reducer);
1262   }
1263 };
1264 
1265 template <typename T>
1266 struct ReduceFunctor<GPUDevice,
1267                      Eigen::internal::MinReducer<T, Eigen::PropagateNaN>> {
1268   template <typename OUT_T, typename IN_T, typename ReductionAxes>
1269   static void Reduce(
1270       OpKernelContext* ctx, OUT_T out, IN_T in,
1271       const ReductionAxes& reduction_axes,
1272       const Eigen::internal::MinReducer<T, Eigen::PropagateNaN>& reducer) {
1273     ReduceImpl<T, MinPropagateNaN, T*, T*, ReductionAxes>(
1274         ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
1275         in.rank() >= 2 ? in.dimension(1) : 1,
1276         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
1277         MinPropagateNaN());
1278   }
1279 
1280   template <typename OUT_T>
1281   static void FillIdentity(
1282       const GPUDevice& d, OUT_T out,
1283       const Eigen::internal::MinReducer<T, Eigen::PropagateNaN>& reducer) {
1284     FillIdentityEigenImpl(d, out, reducer);
1285   }
1286 };
1287 
1288 template <typename T>
1289 struct ReduceFunctor<GPUDevice, Eigen::internal::ProdReducer<T>> {
1290   template <typename OUT_T, typename IN_T, typename ReductionAxes>
1291   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
1292                      const ReductionAxes& reduction_axes,
1293                      const Eigen::internal::ProdReducer<T>& reducer) {
1294     ReduceImpl<T, Prod<T>, T*, T*, ReductionAxes>(
1295         ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
1296         in.rank() >= 2 ? in.dimension(1) : 1,
1297         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
1298         Prod<T>());
1299   }
1300 
1301   template <typename OUT_T>
1302   static void FillIdentity(const GPUDevice& d, OUT_T out,
1303                            const Eigen::internal::ProdReducer<T>& reducer) {
1304     FillIdentityEigenImpl(d, out, reducer);
1305   }
1306 };
1307 
1308 template <>
1309 struct ReduceFunctor<GPUDevice, Eigen::internal::AndReducer> {
1310   template <typename OUT_T, typename IN_T, typename ReductionAxes>
1311   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
1312                      const ReductionAxes& reduction_axes,
1313                      const Eigen::internal::AndReducer& reducer) {
1314     ReduceImpl<bool, And, bool*, bool*, ReductionAxes>(
1315         ctx, (bool*)out.data(), (bool*)in.data(), in.rank(), in.dimension(0),
1316         in.rank() >= 2 ? in.dimension(1) : 1,
1317         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
1318         And());
1319   }
1320 
1321   template <typename OUT_T>
1322   static void FillIdentity(const GPUDevice& d, OUT_T out,
1323                            const Eigen::internal::AndReducer& reducer) {
1324     FillIdentityEigenImpl(d, out, reducer);
1325   }
1326 };
1327 
1328 template <>
1329 struct ReduceFunctor<GPUDevice, Eigen::internal::OrReducer> {
1330   template <typename OUT_T, typename IN_T, typename ReductionAxes>
1331   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
1332                      const ReductionAxes& reduction_axes,
1333                      const Eigen::internal::OrReducer& reducer) {
1334     ReduceImpl<bool, Or, bool*, bool*, ReductionAxes>(
1335         ctx, (bool*)out.data(), (bool*)in.data(), in.rank(), in.dimension(0),
1336         in.rank() >= 2 ? in.dimension(1) : 1,
1337         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, Or());
1338   }
1339 
1340   template <typename OUT_T>
1341   static void FillIdentity(const GPUDevice& d, OUT_T out,
1342                            const Eigen::internal::OrReducer& reducer) {
1343     FillIdentityEigenImpl(d, out, reducer);
1344   }
1345 };
1346 
1347 }  // namespace functor
1348 }  // namespace tensorflow
1349 
1350 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1351 
1352 #endif  // TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_
1353