1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
17 #define TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
18 
19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20 
21 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
22 #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
23 #endif
24 
25 #include "tensorflow/core/kernels/eigen_convolution_helpers.h"
26 
27 namespace Eigen {
28 
29 namespace internal {
30 
31 #if !EIGEN_ALTIVEC_USE_CUSTOM_PACK
32 // WARNING: Most of the code here implicitly assumes that the matrix is in
33 // ColMajor layout. This is guaranteed by the tensor contraction (see
34 // TensorContraction.h).
35 //
36 // Inside Eigen a tensor contraction is represented by a matrix multiplication.
37 // We don't want to actually extract volume patches and reshape the result into
38 // a matrix (this involves allocating huge extra memory), so the patch
39 // extraction and reshape operations are implicit.
40 //
41 // TensorContractionInputMapper takes a matrix index and returns the coefficient
42 // (or the packet) of the "virtual tensor" that would be at that index if we
43 // were to actually reshape the result of patch extraction.
44 //
45 // TensorContractionSubMapper provides a similar view into the "virtual matrix"
46 // at the given vertical and horizontal offsets.
47 //
48 // "Virtual matrix" dimensions:
49 //   *0: kernelChannels * kernelPlanes * kernelRows * kernelCols
50 //    1: out_planes * out_height * out_width * OTHERS (e.g. batches, etc...)
51 //
52 // *) extracted patches are contiguous in memory (innermost dimension assuming
53 //    col major layout)
54 //
55 // With these dimensions:
56 //   row - offset within a single patch (in code: patchId)
57 //   col - index of the extracted patch (in code: patchIndex)
58 //         patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
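//
// For example (sizes chosen only for illustration), a 3x3x3 kernel over
// 4 input channels with a 2x2x2 output and a batch of 8 gives
//   rows = 4 * 3 * 3 * 3 = 108 and cols = 2 * 2 * 2 * 8 = 64,
// and a row index decomposes as
//   patchId = depth + patch_depth * (planeOffset
//                   + patch_planes * (rowOffset + patch_rows * colOffset))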
59 //
60 template <typename NewDimension, Index Planes, Index Rows, Index Cols,
61           typename ArgType, typename Device, typename Scalar_, typename Index,
62           typename nocontract_t, typename contract_t, int Side, int packet_size,
63           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
64 class TensorContractionInputMapper<
65     Scalar_, Index, Side,
66     TensorEvaluator<const TensorReshapingOp<NewDimension,
67                                             const TensorVolumePatchOp<
68                                                 Planes, Rows, Cols, ArgType> >,
69                     Device>,
70     nocontract_t, contract_t, packet_size, inner_dim_contiguous,
71     inner_dim_reordered, Alignment> {
72  public:
73   typedef Scalar_ Scalar;
74   typedef TensorContractionInputMapper<
75       Scalar, Index, Side,
76       TensorEvaluator<const TensorReshapingOp<
77                           NewDimension, const TensorVolumePatchOp<
78                                             Planes, Rows, Cols, ArgType> >,
79                       Device>,
80       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
81       inner_dim_reordered, Alignment>
82       Self;
83   typedef TensorContractionSubMapper<
84       Scalar, Index, Side,
85       TensorEvaluator<const TensorReshapingOp<
86                           NewDimension, const TensorVolumePatchOp<
87                                             Planes, Rows, Cols, ArgType> >,
88                       Device>,
89       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
90       inner_dim_reordered, Alignment>
91       SubMapper;
92   typedef SubMapper VectorMapper;
93   typedef SubMapper LinearMapper;
94   typedef typename packet_traits<Scalar>::type Packet;
95 
96   EIGEN_DEVICE_FUNC
97   TensorContractionInputMapper(
98       const TensorEvaluator<
99           const TensorReshapingOp<
100               NewDimension,
101               const TensorVolumePatchOp<Planes, Rows, Cols, ArgType> >,
102           Device>& tensor,
103       const nocontract_t&, const nocontract_t&, const contract_t&,
104       const contract_t&)
105       : m_impl(tensor.impl().impl()) {
106     if (internal::traits<ArgType>::Layout == ColMajor) {
107       m_patch_depth = tensor.impl().dimensions()[0];
108       m_patch_planes = tensor.impl().dimensions()[1];
109       m_patch_rows = tensor.impl().dimensions()[2];
110       m_patch_cols = tensor.impl().dimensions()[3];
111       m_num_patches = tensor.impl().dimensions()[4];
112     } else {
113       const int NumDims = tensor.impl().dimensions().size();
114       m_patch_depth = tensor.impl().dimensions()[NumDims - 1];
115       m_patch_planes = tensor.impl().dimensions()[NumDims - 2];
116       m_patch_rows = tensor.impl().dimensions()[NumDims - 3];
117       m_patch_cols = tensor.impl().dimensions()[NumDims - 4];
118       m_num_patches = tensor.impl().dimensions()[NumDims - 5];
119     }
120 
121     // Strides for navigating through the single patch.
122     m_patch_plane_stride = m_patch_depth;
123     m_patch_row_stride = m_patch_planes * m_patch_plane_stride;
124     m_patch_col_stride = m_patch_rows * m_patch_row_stride;
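    // For example (sizes chosen only for illustration), a 4-channel 3x3x3
    // patch yields plane/row/col strides of 4, 12 and 36 respectively.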
125 
126     // Strides for the output tensor.
127     // IMPORTANT: These strides are used to locate an element in a patch at
128     // depth zero (channel), which is not quite the same as the "traditional"
129     // stride.
130     m_rowStride = m_patch_planes;
131     m_colStride = m_patch_rows * m_rowStride;
132     m_patchStride = m_colStride * m_patch_cols * m_patch_depth;
133     m_otherStride = m_patchStride * m_num_patches;
134 
135     m_outputPlanes = tensor.impl().outputPlanes();
136     m_outputRows = tensor.impl().outputRows();
137     m_outputCols = tensor.impl().outputCols();
138 
139     m_outputPlanesRows = m_outputPlanes * m_outputRows;
140 
141     m_plane_strides = tensor.impl().userPlaneStride();
142     m_row_strides = tensor.impl().userRowStride();
143     m_col_strides = tensor.impl().userColStride();
144 
145     m_in_plane_strides = tensor.impl().userInPlaneStride();
146     m_in_row_strides = tensor.impl().userInRowStride();
147     m_in_col_strides = tensor.impl().userInColStride();
148 
149     m_patch_plane_inflate_strides = tensor.impl().planeInflateStride();
150     m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
151     m_patch_col_inflate_strides = tensor.impl().colInflateStride();
152 
153     if (internal::traits<ArgType>::Layout == ColMajor) {
154       m_inputDepth = tensor.impl().impl().dimensions()[0];
155       m_inputPlanes = tensor.impl().impl().dimensions()[1];
156       m_inputRows = tensor.impl().impl().dimensions()[2];
157       m_inputCols = tensor.impl().impl().dimensions()[3];
158     } else {
159       const int NumDims = tensor.impl().impl().dimensions().size();
160       m_inputDepth = tensor.impl().impl().dimensions()[NumDims - 1];
161       m_inputPlanes = tensor.impl().impl().dimensions()[NumDims - 2];
162       m_inputRows = tensor.impl().impl().dimensions()[NumDims - 3];
163       m_inputCols = tensor.impl().impl().dimensions()[NumDims - 4];
164     }
165 
166     // Strides for navigating through the input tensor.
167     m_planeInputStride = m_inputDepth;
168     m_rowInputStride = m_inputDepth * m_inputPlanes;
169     m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
170     m_patchInputStride =
171         m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;
172 
173     m_planePaddingTop = tensor.impl().planePaddingTop();
174     m_rowPaddingTop = tensor.impl().rowPaddingTop();
175     m_colPaddingLeft = tensor.impl().colPaddingLeft();
176 
177     m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
178 
179     m_fastPatchPlaneStride =
180         internal::TensorIntDivisor<Index>(m_patch_plane_stride);
181     m_fastPatchRowStride =
182         internal::TensorIntDivisor<Index>(m_patch_row_stride);
183     m_fastPatchColStride =
184         internal::TensorIntDivisor<Index>(m_patch_col_stride);
185 
186     m_fastInputPlaneStride =
187         internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides);
188     m_fastInputRowStride =
189         internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
190     m_fastInputColStride =
191         internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
192 
193     m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
194     m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
195 
196     m_fastDimZero = internal::TensorIntDivisor<Index>(m_patch_depth);
197     m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
198     m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
199     m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
200     m_fastOutputCols = internal::TensorIntDivisor<Index>(m_outputCols);
201 
202     m_fastOutputPlanesRows =
203         internal::TensorIntDivisor<Index>(m_outputPlanesRows);
204   }
205 
206   EIGEN_DEVICE_FUNC
207   TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
208       : m_impl(base_mapper.m_impl) {
209     m_patch_depth = base_mapper.m_patch_depth;
210     m_patch_planes = base_mapper.m_patch_planes;
211     m_patch_rows = base_mapper.m_patch_rows;
212     m_patch_cols = base_mapper.m_patch_cols;
213     m_num_patches = base_mapper.m_num_patches;
214 
215     m_patch_plane_stride = base_mapper.m_patch_plane_stride;
216     m_patch_row_stride = base_mapper.m_patch_row_stride;
217     m_patch_col_stride = base_mapper.m_patch_col_stride;
218 
219     m_rowStride = base_mapper.m_rowStride;
220     m_colStride = base_mapper.m_colStride;
221     m_patchStride = base_mapper.m_patchStride;
222     m_otherStride = base_mapper.m_otherStride;
223 
224     m_planeInputStride = base_mapper.m_planeInputStride;
225     m_rowInputStride = base_mapper.m_rowInputStride;
226     m_colInputStride = base_mapper.m_colInputStride;
227     m_patchInputStride = base_mapper.m_patchInputStride;
228     m_otherInputStride = base_mapper.m_otherInputStride;
229 
230     m_inputDepth = base_mapper.m_inputDepth;
231     m_inputPlanes = base_mapper.m_inputPlanes;
232     m_inputRows = base_mapper.m_inputRows;
233     m_inputCols = base_mapper.m_inputCols;
234 
235     m_outputPlanes = base_mapper.m_outputPlanes;
236     m_outputRows = base_mapper.m_outputRows;
237     m_outputCols = base_mapper.m_outputCols;
238 
239     m_plane_strides = base_mapper.m_plane_strides;
240     m_row_strides = base_mapper.m_row_strides;
241     m_col_strides = base_mapper.m_col_strides;
242 
243     m_in_plane_strides = base_mapper.m_in_plane_strides;
244     m_in_row_strides = base_mapper.m_in_row_strides;
245     m_in_col_strides = base_mapper.m_in_col_strides;
246 
247     m_patch_plane_inflate_strides = base_mapper.m_patch_plane_inflate_strides;
248     m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
249     m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
250 
251     m_planePaddingTop = base_mapper.m_planePaddingTop;
252     m_rowPaddingTop = base_mapper.m_rowPaddingTop;
253     m_colPaddingLeft = base_mapper.m_colPaddingLeft;
254 
255     m_outputPlanesRows = base_mapper.m_outputPlanesRows;
256 
257     m_fastNumPatches = base_mapper.m_fastNumPatches;
258     m_fastPatchPlaneStride = base_mapper.m_fastPatchPlaneStride;
259     m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
260     m_fastPatchColStride = base_mapper.m_fastPatchColStride;
261     m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride;
262     m_fastInputRowStride = base_mapper.m_fastInputRowStride;
263     m_fastInputColStride = base_mapper.m_fastInputColStride;
264     m_fastRowStride = base_mapper.m_fastRowStride;
265     m_fastColStride = base_mapper.m_fastColStride;
266     m_fastOutputPlanes = base_mapper.m_fastOutputPlanes;
267     m_fastOutputRows = base_mapper.m_fastOutputRows;
268     m_fastOutputCols = base_mapper.m_fastOutputCols;
269     m_fastDimZero = base_mapper.m_fastDimZero;
270     m_fastOutputPlanesRows = base_mapper.m_fastOutputPlanesRows;
271   }
272 
273   // If true, turns off some optimizations for loading packets, since the image
274   // patches are "non-standard": there are non-trivial strides or inflations
275   // in the input.
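  // For example, an atrous convolution (m_in_row_strides != 1) or an inflated
  // patch (m_patch_row_inflate_strides != 1) falls back to the slower
  // element-by-element paths (loadCoeff / packetWithPossibleZero) below.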
276   EIGEN_DEVICE_FUNC
277   EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
278     return m_in_plane_strides != 1 || m_in_row_strides != 1 ||
279            m_in_col_strides != 1 || m_patch_plane_inflate_strides != 1 ||
280            m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
281   }
282 
283   EIGEN_DEVICE_FUNC
284   EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
285     return SubMapper(*this, i, j);
286   }
287 
288   EIGEN_DEVICE_FUNC
289   EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
290     return LinearMapper(*this, i, j);
291   }
292 
293   EIGEN_DEVICE_FUNC
294   EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
295     Index planeIndex, rowIndex, colIndex, otherIndex;
296     computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
297     return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
298   }
299 
300   // Load the coefficient at the patchIndex location instead of the usual
301   // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
302   // gpu code.
303   EIGEN_DEVICE_FUNC
304   EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
305     Index planeIndex, rowIndex, colIndex, otherIndex;
306     computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
307     return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
308   }
309 
310   EIGEN_DEVICE_FUNC
311   EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
312     Index planeIndex, rowIndex, colIndex, otherIndex;
313     computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
314     return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
315   }
316 
317   // Load the packet at the patchIndex location instead of the usual m_rowIndex,
318   // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
319   EIGEN_DEVICE_FUNC
320   EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
321     Index planeIndex, rowIndex, colIndex, otherIndex;
322     computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
323     return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
324   }
325 
326   EIGEN_DEVICE_FUNC
327   EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
328     return m_impl;
329   }
330 
331   EIGEN_DEVICE_FUNC
332   EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_planeInputStride; }
333   EIGEN_DEVICE_FUNC
334   EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_rowStride; }
335   EIGEN_DEVICE_FUNC
336   EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; }
337   EIGEN_DEVICE_FUNC
338   EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
339 
340  private:
341   friend class TensorContractionSubMapper<
342       Scalar, Index, Side,
343       TensorEvaluator<const TensorReshapingOp<
344                           NewDimension, const TensorVolumePatchOp<
345                                             Planes, Rows, Cols, ArgType> >,
346                       Device>,
347       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
348       inner_dim_reordered, Alignment>;
349 
350   // Load coefficient from a patch specified by the "within patch offset"
351   // (patchId) and the precomputed indices of the first element of the patch.
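  // For example, with patch_depth = 4, patch_planes = 3 and patch_rows = 3
  // (so m_rowStride = 3 and m_colStride = 9), patchId = 50 decomposes into
  //   patchOffset = 50 / 4 = 12, colOffset = 12 / 9 = 1,
  //   rowOffset = (12 - 9) / 3 = 1, planeOffset = 12 - 9 - 3 = 0,
  //   depth = 50 - 12 * 4 = 2 (sizes chosen only for illustration).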
352   EIGEN_DEVICE_FUNC
353   EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index planeIndex,
354                                        Index rowIndex, Index colIndex,
355                                        Index otherIndex) const {
356     // Find the offset of the element wrt the location of the first element.
357     const Index patchOffset = patchId / m_fastDimZero;
358 
359     const Index colOffset = patchOffset / m_fastColStride;
360     const Index inputCol = colIndex + colOffset * m_in_col_strides;
361     const Index origInputCol =
362         (m_patch_col_inflate_strides == 1)
363             ? inputCol
364             : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
365 
366     const Index rowOffset =
367         (patchOffset - colOffset * m_colStride) / m_fastRowStride;
368     const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
369     const Index origInputRow =
370         (m_patch_row_inflate_strides == 1)
371             ? inputRow
372             : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
373 
374     const Index planeOffset =
375         patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
376     const Index inputPlane = planeIndex + planeOffset * m_in_plane_strides;
377     const Index origInputPlane =
378         (m_patch_plane_inflate_strides == 1)
379             ? inputPlane
380             : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
381 
382     if (origInputCol < 0 || origInputRow < 0 || origInputPlane < 0 ||
383         origInputCol >= m_inputCols || origInputRow >= m_inputRows ||
384         origInputPlane >= m_inputPlanes ||
385         (inputCol != origInputCol * m_patch_col_inflate_strides) ||
386         (inputRow != origInputRow * m_patch_row_inflate_strides) ||
387         (inputPlane != origInputPlane * m_patch_plane_inflate_strides)) {
388       return Scalar(0);
389     }
390 
391     const Index depth = patchId - patchOffset * patchDepth();
392     const Index inputIndex = depth + origInputPlane * m_planeInputStride +
393                              origInputRow * m_rowInputStride +
394                              origInputCol * m_colInputStride + otherIndex;
395 
396     return m_impl.coeff(inputIndex);
397   }
398 
399   // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
400   // and `in_strides` equal to 1 (template specialization without templates).
401   EIGEN_DEVICE_FUNC
402   EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index planeIndex,
403                                                Index rowIndex, Index colIndex,
404                                                Index otherIndex) const {
405     eigen_assert(!nonStandardPatches());
406 
407     // Find the offset of the element wrt the location of the first element.
408     const Index patchOffset = patchId / m_fastDimZero;
409 
410     const Index colOffset = patchOffset / m_fastColStride;
411     const Index rowOffset =
412         (patchOffset - colOffset * m_colStride) / m_fastRowStride;
413     const Index planeOffset =
414         patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
415 
416     const Index inputCol = colIndex + colOffset;
417     const Index inputRow = rowIndex + rowOffset;
418     const Index inputPlane = planeIndex + planeOffset;
419 
420     if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
421         inputRow >= m_inputRows || inputPlane < 0 ||
422         inputPlane >= m_inputPlanes) {
423       return Scalar(0);
424     }
425 
426     const Index depth = patchId - patchOffset * patchDepth();
427     const Index inputIndex = depth + inputPlane * m_planeInputStride +
428                              inputRow * m_rowInputStride +
429                              inputCol * m_colInputStride + otherIndex;
430 
431     return m_impl.coeff(inputIndex);
432   }
433 
434   // Load packet from a patch specified by the "within patch offset"
435   // (patchId) and the precomputed indices of the first element of the patch.
436   EIGEN_DEVICE_FUNC
437   EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index planeIndex,
438                                         Index rowIndex, Index colIndex,
439                                         Index otherIndex) const {
440     const Index packetSize = internal::unpacket_traits<Packet>::size;
441 
442     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
443     eigen_assert(patchId <
444                  patchDepth() * patchPlanes() * patchRows() * patchCols());
445 
446     if (nonStandardPatches()) {
447       return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
448                                     otherIndex);
449     }
450     typedef decltype(m_impl) TensorEvaluatorT;
451     return loadPacketStandard<Packet, TensorEvaluatorT>(
452         patchId, planeIndex, rowIndex, colIndex, otherIndex);
453   }
454 
455   // Helper function to load a 'partial' packet - this is the single row part of
456   // a packet that is split across two rows (but single column). In the
457   // 'partial' packet, the elements corresponding to the row (specified through
458   // rowOffset) are loaded and the rest of the elements are zero-filled into the
459   // 'partial' packet. This function is called from
460   // loadPacketStandardFromSingleColumnTwoRows(). This code path is exercised
461   // only when the packet type supports masked load and when the partial packet
462   // load is available in the TensorEvaluator.
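  // For example, with packetSize = 8 and a row split after the third element,
  // the first partial packet is loaded with span = {0, 2} (elements 0..2
  // filled, the rest zero) and the second with span = {3, 7}.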
463   EIGEN_DEVICE_FUNC
464   EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard(
465       Index planeIndex, Index rowIndex, Index colIndex, Index otherIndex,
466       Index patchId, const Index span[], const Index patchOffsets[],
467       Index colOffset, Index rowOffset) const {
468     const Index inputCol = colIndex + colOffset;
469     const Index inputRow = rowIndex + rowOffset;
470     const Index planeOffsets[2] = {
471         patchOffsets[0] - colOffset * m_colStride - rowOffset * m_rowStride,
472         patchOffsets[1] - colOffset * m_colStride - rowOffset * m_rowStride};
473     const Index inputPlanes[2] = {planeIndex + planeOffsets[0],
474                                   planeIndex + planeOffsets[1]};
475 
476     if (inputRow >= m_inputRows || inputRow < 0 || inputCol >= m_inputCols ||
477         inputCol < 0 || inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) {
478       // Partial packet is all zeros
479       return internal::pset1<Packet>(Scalar(0));
480     } else if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
481       // From inputIndex-span[0], we need to load elements starting from index
482       // span[0] all the way up to (and including) span[1].
483       const Index depth = patchId - patchOffsets[0] * patchDepth();
484       const Index inputIndex = depth + inputPlanes[0] * m_planeInputStride +
485                                inputRow * m_rowInputStride +
486                                inputCol * m_colInputStride + otherIndex;
487       return m_impl.template partialPacket<Packet>(
488           inputIndex - span[0], mask<Packet>(span[0], span[1] + 1));
489     } else {
490       // Using slow path for this partial packet.
491       // We need to load elements starting from index span[0] all the way up to
492       // (and including) span[1]. We split this load into 3 parts:
493       // 0 : span[0]-1 - Zeros will be loaded for these indices
494       // span[0] : span[1] - Elements will be loaded here for these indices
495       // span[1]+1 : packetSize-1 - Zeros will be loaded for these indices
496       const Index packetSize = internal::unpacket_traits<Packet>::size;
497       EIGEN_ALIGN_MAX
498       std::remove_const_t<Scalar> values[packetSize];
499       for (int i = 0; i < span[0]; ++i) values[i] = Scalar(0);
500       for (int i = span[0]; i < span[1] + 1; ++i)
501         values[i] = loadCoeff(patchId - span[0] + i, planeIndex, rowIndex,
502                               colIndex, otherIndex);
503       for (int i = span[1] + 1; i < packetSize; ++i) values[i] = Scalar(0);
504       return internal::pload<Packet>(values);
505     }
506   }
507 
508   // Helper function to load a packet that is split across two rows (but single
509   // column). If required, this function is called from loadPacketStandard()
510   // when the packet type supports masked load and when the partial packet load
511   // is available in the TensorEvaluator.
512   EIGEN_DEVICE_FUNC
513   EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumnTwoRows(
514       Index patchId, Index planeIndex, Index rowIndex, Index colIndex,
515       Index otherIndex, const Index patchOffsets[], const Index colOffsets[],
516       const Index rowOffsets[]) const {
517     eigen_assert(colOffsets[1] == colOffsets[0] &&
518                  rowOffsets[1] == rowOffsets[0] + 1);
519     const Index packetSize = internal::unpacket_traits<Packet>::size;
520 
521     // Packet to load will be split into 2 parts where each part spans a single
522     // row and both the parts span the same column.
523     // First determine where to split.
524     const Index patchIdSplit =
525         (((rowOffsets[1] * m_rowStride) + (colOffsets[0] * m_colStride)) *
526          m_patch_depth) -
527         1;
528     const Index patchOffsetSplit = patchIdSplit / m_fastDimZero;
529 
530     // patchIds[i]:          patchId corresponding to partial packet i
531     // spans[i]:             Start and end indices corresponding to the elements
532     //                       to be loaded for partial packet i
533     // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i
534     const Index patchIds[2] = {patchId, patchIdSplit + 1};
535     const Index spans[2][2] = {{0, patchIdSplit - patchId},
536                                {patchIdSplit - patchId + 1, packetSize - 1}};
537     const Index patchOffsets2Cols[2][2] = {
538         {patchOffsets[0], patchOffsetSplit},
539         {patchOffsetSplit + 1, patchOffsets[1]}};
540 
541     // Load partial packets and do bit-wise OR to generate required packet
542     return internal::por<Packet>(
543         loadPartialPacketStandard(planeIndex, rowIndex, colIndex, otherIndex,
544                                   patchIds[0], spans[0], patchOffsets2Cols[0],
545                                   colOffsets[0], rowOffsets[0]),
546         loadPartialPacketStandard(planeIndex, rowIndex, colIndex, otherIndex,
547                                   patchIds[1], spans[1], patchOffsets2Cols[1],
548                                   colOffsets[1], rowOffsets[1]));
549   }
550 
551   // Helper function to load a packet that is present in a single column and
552   // row. If required, this function is called from loadPacketStandard().
553   EIGEN_DEVICE_FUNC
554   EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumnSingleRow(
555       Index patchId, Index planeIndex, Index rowIndex, Index colIndex,
556       Index otherIndex, const Index patchOffsets[], const Index colOffsets[],
557       const Index rowOffsets[], const Index inputCols[],
558       const Index inputRows[]) const {
559     eigen_assert(colOffsets[1] == colOffsets[0] &&
560                  rowOffsets[1] == rowOffsets[0]);
561     const Index planeOffsets[2] = {
562         patchOffsets[0] - colOffsets[0] * m_colStride -
563             rowOffsets[0] * m_rowStride,
564         patchOffsets[1] - colOffsets[1] * m_colStride -
565             rowOffsets[1] * m_rowStride};
566     eigen_assert(planeOffsets[0] <= planeOffsets[1]);
567     const Index inputPlanes[2] = {planeIndex + planeOffsets[0],
568                                   planeIndex + planeOffsets[1]};
569 
570     if (inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) {
571       return internal::pset1<Packet>(Scalar(0));
572     }
573     if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
574       const Index depth = patchId - patchOffsets[0] * patchDepth();
575       const Index inputIndex = depth + inputPlanes[0] * m_planeInputStride +
576                                inputRows[0] * m_rowInputStride +
577                                inputCols[0] * m_colInputStride + otherIndex;
578       return m_impl.template packet<Unaligned>(inputIndex);
579     }
580     return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
581                                   otherIndex);
582   }
583 
584   // Load standard packet from a patch specified by the "within patch offset"
585   // (patchId) and the precomputed indices of the first element of the patch.
586   // This function will be called if partial packet loading is not available
587   // for the TensorEvaluator or if the packet type does not support masked
588   // load.
589   template <typename PacketT, typename TensorEvaluatorT>
590   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
591       !TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
592       PacketT>::type
593   loadPacketStandard(Index patchId, Index planeIndex, Index rowIndex,
594                      Index colIndex, Index otherIndex) const {
595     const Index packetSize = internal::unpacket_traits<Packet>::size;
596     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
597     eigen_assert(patchId <
598                  patchDepth() * patchPlanes() * patchRows() * patchCols());
599     eigen_assert(!nonStandardPatches());
600 
601     if ((patchDepth() % packetSize) == 0) {
602       return loadPacketFast(patchId, planeIndex, rowIndex, colIndex,
603                             otherIndex);
604     } else {
605       // Offsets and input calculation here are identical to
606       // loadCoeffStandard(...), but repeated twice.
607 
608       const Index patchOffsets[2] = {
609           patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
610 
611       const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
612                                    patchOffsets[1] / m_fastColStride};
613       eigen_assert(colOffsets[0] <= colOffsets[1]);
614 
615       const Index inputCols[2] = {colIndex + colOffsets[0],
616                                   colIndex + colOffsets[1]};
617       if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
618         return internal::pset1<Packet>(Scalar(0));
619       }
620 
621       if (inputCols[0] == inputCols[1]) {
622         const Index rowOffsets[2] = {
623             (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
624             (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
625         eigen_assert(rowOffsets[0] <= rowOffsets[1]);
626         const Index inputRows[2] = {rowIndex + rowOffsets[0],
627                                     rowIndex + rowOffsets[1]};
628 
629         if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
630           return internal::pset1<Packet>(Scalar(0));
631         }
632 
633         if (inputRows[0] == inputRows[1]) {
634           return loadPacketStandardFromSingleColumnSingleRow(
635               patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets,
636               colOffsets, rowOffsets, inputCols, inputRows);
637         }
638       }
639     }
640 
641     return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
642                                   otherIndex);
643   }
644 
645   // Load standard packet from a patch specified by the "within patch offset"
646   // (patchId) and the precomputed indices of the first element of the patch.
647   // This function will be called if partial packet loading is available for
648   // the TensorEvaluator and if the packet type supports masked load.
649   // The only difference between this and the other case is that if the packet
650   // to load is split across two rows (but in the same column), then in this case
651   // instead of going to the slow (element-by-element) load, we load two packets
652   // - each containing elements from one of the rows (rest of the elements of
653   // the packets are zeroes), and then combine these two packets to generate the
654   // required packet. The idea is to enable fast load (if possible) of these
655   // 'partial' packets.
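  // For example, with patchDepth() = 6, patchPlanes() = 2 and packetSize = 4
  // (sizes chosen only for illustration), the packet starting at patchId = 10
  // covers offsets 10..13, which straddle two consecutive patch rows of the
  // same column (rowOffsets {0, 1}), so it is assembled from two partial
  // packets combined with a bitwise OR.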
656   template <typename PacketT, typename TensorEvaluatorT>
657   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
658       TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
659       PacketT>::type
660   loadPacketStandard(Index patchId, Index planeIndex, Index rowIndex,
661                      Index colIndex, Index otherIndex) const {
662     const Index packetSize = internal::unpacket_traits<Packet>::size;
663     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
664     eigen_assert(patchId <
665                  patchDepth() * patchPlanes() * patchRows() * patchCols());
666     eigen_assert(!nonStandardPatches());
667 
668     if ((patchDepth() % packetSize) == 0) {
669       return loadPacketFast(patchId, planeIndex, rowIndex, colIndex,
670                             otherIndex);
671     } else {
672       // Offsets and input calculation here are identical to
673       // loadCoeffStandard(...), but repeated twice.
674 
675       const Index patchOffsets[2] = {
676           patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
677 
678       const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
679                                    patchOffsets[1] / m_fastColStride};
680       eigen_assert(colOffsets[0] <= colOffsets[1]);
681 
682       const Index inputCols[2] = {colIndex + colOffsets[0],
683                                   colIndex + colOffsets[1]};
684       if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
685         return internal::pset1<Packet>(Scalar(0));
686       }
687 
688       if (inputCols[0] == inputCols[1]) {
689         const Index rowOffsets[2] = {
690             (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
691             (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
692         eigen_assert(rowOffsets[0] <= rowOffsets[1]);
693         const Index inputRows[2] = {rowIndex + rowOffsets[0],
694                                     rowIndex + rowOffsets[1]};
695 
696         if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
697           return internal::pset1<Packet>(Scalar(0));
698         }
699 
700         if (inputRows[0] == inputRows[1]) {
701           return loadPacketStandardFromSingleColumnSingleRow(
702               patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets,
703               colOffsets, rowOffsets, inputCols, inputRows);
704         }
705         if (inputRows[0] + 1 == inputRows[1]) {
706           return loadPacketStandardFromSingleColumnTwoRows(
707               patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets,
708               colOffsets, rowOffsets);
709         }
710       }
711     }
712 
713     return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
714                                   otherIndex);
715   }
716 
717   EIGEN_DEVICE_FUNC
718   EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index planeIndex,
719                                             Index rowIndex, Index colIndex,
720                                             Index otherIndex) const {
721     const Index packetSize = internal::unpacket_traits<Packet>::size;
722     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
723     eigen_assert(patchId <
724                  patchDepth() * patchPlanes() * patchRows() * patchCols());
725 
726     eigen_assert(!nonStandardPatches());
727     eigen_assert((patchDepth() % packetSize) == 0);
728 
729     // Find the offset of the element wrt the location of the first element.
730     const Index patchOffset = patchId / m_fastDimZero;
731     eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
732 
733     const Index colOffset = patchOffset / m_fastColStride;
734     const Index rowOffset =
735         (patchOffset - colOffset * m_colStride) / m_fastRowStride;
736     const Index planeOffset =
737         patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
738 
739     const Index inputCol = colIndex + colOffset;
740     const Index inputRow = rowIndex + rowOffset;
741     const Index inputPlane = planeIndex + planeOffset;
742 
743     if (inputCol < 0 || inputRow < 0 || inputPlane < 0 ||
744         inputCol >= m_inputCols || inputRow >= m_inputRows ||
745         inputPlane >= m_inputPlanes) {
746       return internal::pset1<Packet>(Scalar(0));
747     }
748 
749     const Index depth = patchId - patchOffset * patchDepth();
750     const Index inputIndex = depth + inputPlane * m_planeInputStride +
751                              inputRow * m_rowInputStride +
752                              inputCol * m_colInputStride + otherIndex;
753     return m_impl.template packet<Unaligned>(inputIndex);
754   }
755 
756   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
757   packetWithPossibleZero(Index patchId, Index planeIndex, Index rowIndex,
758                          Index colIndex, Index otherIndex) const {
759     const int packetSize = internal::unpacket_traits<Packet>::size;
760     EIGEN_ALIGN_MAX
761     std::remove_const_t<Scalar> values[packetSize];
762     for (int i = 0; i < packetSize; ++i) {
763       values[i] =
764           loadCoeff(patchId + i, planeIndex, rowIndex, colIndex, otherIndex);
765     }
766     Packet rslt = internal::pload<Packet>(values);
767     return rslt;
768   }
769 
770   // Precompute the indices (plane, row, col, other) of the first element of
771   // the given patch index, within the output tensor of the TensorVolumePatchOp.
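  // For example, with m_outputPlanes = 2 and m_outputRows = 3
  // (m_outputPlanesRows = 6), patch3DIndex = 10 yields colIndex = 10 / 6 = 1,
  // rowIndex = (10 - 6) / 2 = 2 and planeIndex = 10 - (1 * 3 + 2) * 2 = 0,
  // before the user strides and paddings are applied.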
772   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
773       Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex,
774       Index& otherIndex) const {
775     const size_t NumInputDims = array_size<
776         typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
777 
778     // Check if patchIndex might contain batch and other dimensions.
779     otherIndex = (NumInputDims == 4) ? 0 : patchIndex / m_fastNumPatches;
780 
781     // Compute index of the patch within the batch (and other dimensions).
782     const Index patch3DIndex = (NumInputDims == 4)
783                                    ? patchIndex
784                                    : (patchIndex - otherIndex * m_num_patches);
785 
786     otherIndex *= m_patchInputStride;
787 
788     colIndex = patch3DIndex / m_fastOutputPlanesRows;
789     rowIndex =
790         (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
791     planeIndex =
792         patch3DIndex - (colIndex * m_outputRows + rowIndex) * m_outputPlanes;
793 
794     colIndex = colIndex * m_col_strides - m_colPaddingLeft;
795     rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
796     planeIndex = planeIndex * m_plane_strides - m_planePaddingTop;
797   }
798 
799   Index m_patch_depth;   // number of channels in the patch
800   Index m_patch_planes;  // number of planes in the patch
801   Index m_patch_rows;    // number of rows in the patch
802   Index m_patch_cols;    // number of columns in the patch
803   Index m_num_patches;   // number of patches to extract
804 
805   // Strides for navigating through the single patch.
806   Index m_patch_plane_stride;
807   Index m_patch_row_stride;
808   Index m_patch_col_stride;
809 
810   // Strides for the output tensor (depth is not part of the stride).
811   Index m_rowStride;
812   Index m_colStride;
813   Index m_patchStride;
814   Index m_otherStride;
815 
816   Index m_planeInputStride;  // Plane stride in the input tensor
817   Index m_rowInputStride;    // Row stride in the input tensor
818   Index m_colInputStride;    // Col stride in the input tensor
819   Index m_patchInputStride;  // Patch stride in the input tensor
820   Index m_otherInputStride;
821 
822   Index m_inputDepth;   // Depth of the input tensor
823   Index m_inputPlanes;  // Number of planes in the input tensor
824   Index m_inputRows;    // Number of rows in the input tensor
825   Index m_inputCols;    // Number of cols in the input tensor
826 
827   Index m_outputPlanes;      // Number of output planes
828   Index m_outputRows;        // Number of output rows
829   Index m_outputCols;        // Number of output cols
830   Index m_outputPlanesRows;  // Cached outputPlanes * outputRows.
831 
832   Index m_plane_strides;  // User specified plane stride
833   Index m_row_strides;    // User specified row stride
834   Index m_col_strides;    // User specified col stride
835 
836   // User specified plane/row/col atrous convolution strides.
837   Index m_in_plane_strides;
838   Index m_in_row_strides;
839   Index m_in_col_strides;
840 
841   // User specified plane/row/col inflation strides in the image patch.
842   Index m_patch_plane_inflate_strides;
843   Index m_patch_row_inflate_strides;
844   Index m_patch_col_inflate_strides;
845 
846   Index m_planePaddingTop;  // Plane padding
847   Index m_rowPaddingTop;    // Row padding
848   Index m_colPaddingLeft;   // Column padding
849 
850   // Fast representation of various divisors.
851   internal::TensorIntDivisor<Index> m_fastNumPatches;
852 
853   internal::TensorIntDivisor<Index> m_fastPatchPlaneStride;
854   internal::TensorIntDivisor<Index> m_fastPatchRowStride;
855   internal::TensorIntDivisor<Index> m_fastPatchColStride;
856 
857   internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
858   internal::TensorIntDivisor<Index> m_fastInputRowStride;
859   internal::TensorIntDivisor<Index> m_fastInputColStride;
860 
861   internal::TensorIntDivisor<Index> m_fastRowStride;
862   internal::TensorIntDivisor<Index> m_fastColStride;
863 
864   internal::TensorIntDivisor<Index> m_fastDimZero;  // aka output depth
865   internal::TensorIntDivisor<Index> m_fastOutputPlanes;
866   internal::TensorIntDivisor<Index> m_fastOutputRows;
867   internal::TensorIntDivisor<Index> m_fastOutputCols;
868   internal::TensorIntDivisor<Index> m_fastOutputPlanesRows;
869 
870   const TensorEvaluator<ArgType, Device> m_impl;
871 };
872 
873 template <typename NewDimension, Index Planes, Index Rows, Index Cols,
874           typename ArgType, typename Device, typename Scalar, typename Index,
875           typename nocontract_t, typename contract_t, int Side, int packet_size,
876           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
877 class TensorContractionSubMapper<
878     Scalar, Index, Side,
879     TensorEvaluator<const TensorReshapingOp<NewDimension,
880                                             const TensorVolumePatchOp<
881                                                 Planes, Rows, Cols, ArgType> >,
882                     Device>,
883     nocontract_t, contract_t, packet_size, inner_dim_contiguous,
884     inner_dim_reordered, Alignment> {
885  public:
886   typedef typename packet_traits<Scalar>::type Packet;
887   typedef typename packet_traits<Scalar>::half HalfPacket;
888 
889   typedef TensorContractionInputMapper<
890       Scalar, Index, Side,
891       TensorEvaluator<const TensorReshapingOp<
892                           NewDimension, const TensorVolumePatchOp<
893                                             Planes, Rows, Cols, ArgType> >,
894                       Device>,
895       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
896       inner_dim_reordered, Alignment>
897       ParentMapper;
898   typedef TensorContractionSubMapper<
899       Scalar, Index, Side,
900       TensorEvaluator<const TensorReshapingOp<
901                           NewDimension, const TensorVolumePatchOp<
902                                             Planes, Rows, Cols, ArgType> >,
903                       Device>,
904       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
905       inner_dim_reordered, Alignment>
906       Self;
907   typedef Self LinearMapper;
908 
909   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
910       const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
911       : m_base_mapper(base_mapper),
912         m_depth_offset(vert_offset),
913         m_col_offset(horiz_offset) {
914     m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
915                                      m_colIndex, m_otherIndex);
916   }
917   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
918       const Self& base_mapper, Index vert_offset, Index horiz_offset)
919       : m_base_mapper(base_mapper.m_base_mapper),
920         m_depth_offset(vert_offset + base_mapper.m_depth_offset),
921         m_col_offset(horiz_offset + base_mapper.m_col_offset) {
922     m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
923                                      m_colIndex, m_otherIndex);
924   }
925   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
926     return m_base_mapper.loadCoeff(i + m_depth_offset, m_planeIndex, m_rowIndex,
927                                    m_colIndex, m_otherIndex);
928   }
929   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
930                                                           Index j) const {
931     return m_base_mapper(i + m_depth_offset, j + m_col_offset);
932   }
933 
934   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
935     return m_base_mapper.loadPacket(i + m_depth_offset, m_planeIndex,
936                                     m_rowIndex, m_colIndex, m_otherIndex);
937   }
938   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
939                                                           Index j) const {
940     return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
941                                                         j + m_col_offset);
942   }
943 
944   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
945   loadCoeffStandard(Index i) const {
946     return m_base_mapper.loadCoeffStandard(
947         i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
948   }
949 
950   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
951     return m_base_mapper.loadPacketFast(i + m_depth_offset, m_planeIndex,
952                                         m_rowIndex, m_colIndex, m_otherIndex);
953   }
954   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
955   loadPacketStandard(Index i) const {
956     typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT;
957     return m_base_mapper.template loadPacketStandard<Packet, TensorEvaluatorT>(
958         i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
959   }
960   template <typename Packet>
961   EIGEN_DEVICE_FUNC bool aligned(Index) const {
962     return false;
963   }
964 
965   EIGEN_DEVICE_FUNC
966   EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
967     return m_base_mapper.nonStandardPatches();
968   }
969 
970   // Max(Col|Row|Plane|Depth): compute the upper limit for the column, row,
971   // plane and depth index respectively that fits into the peeled_k elements
972   // starting at m_depth_offset.
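  // For example, with patchDepth() = 4, patchPlanes() = 3 and patchRows() = 3
  // (patchRowStride() = 12, patchColStride() = 36), m_depth_offset = 0 and
  // peeled_k = 48 give maxCol = min(1 + 48 / 36, patchCols()) = min(2,
  // patchCols()).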
973 
974   EIGEN_DEVICE_FUNC
975   EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
976     const Index max_col =
977         fastPatchColStride().divide(m_depth_offset + peeled_k);
978     return std::min<Index>(1 + max_col, patchCols());
979   }
980 
981   EIGEN_DEVICE_FUNC
982   EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
983                                    const Index col) const {
984     const Index max_row = fastPatchRowStride().divide(
985         m_depth_offset + peeled_k - col * patchColStride());
986     return std::min<Index>(1 + max_row, patchRows());
987   }
988 
989   EIGEN_DEVICE_FUNC
990   EIGEN_ALWAYS_INLINE Index maxPlane(const Index peeled_k, const Index col,
991                                      const Index row) const {
992     const Index max_plane = fastPatchPlaneStride().divide(
993         m_depth_offset + peeled_k - col * patchColStride() -
994         row * patchRowStride());
995     return std::min<Index>(1 + max_plane, patchPlanes());
996   }
997 
998   // MaxDepth uses only the remaining number of elements in the peeled_k.
999   EIGEN_DEVICE_FUNC
1000   EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
1001                                      const Index start_depth) const {
1002     return std::min<Index>(start_depth + num_elements, patchDepth());
1003   }
1004 
1005   // Every register matters in this code, so sometimes to prevent register
1006   // spilling, instead of the variable that you would expect to see, we use
1007   // another one, that is guaranteed to have the same value. E.g. patch depth is
1008   // always the same as input depth, and it's also the same as input plane
1009   // stride. A number of other parameters have similar relations.
1010 
1011   typedef internal::TensorIntDivisor<Index> IndexDivisor;
1012 
1013   EIGEN_DEVICE_FUNC
1014   EIGEN_ALWAYS_INLINE Index patchDepth() const {
1015     eigen_assert(m_base_mapper.m_patch_depth ==
1016                      m_base_mapper.m_planeInputStride &&
1017                  "Patch depth must be equal to plane input stride.");
1018     return m_base_mapper.m_planeInputStride;
1019   }
1020 
1021   EIGEN_DEVICE_FUNC
1022   EIGEN_ALWAYS_INLINE Index patchPlanes() const {
1023     eigen_assert(m_base_mapper.m_patch_planes == m_base_mapper.m_rowStride &&
1024                  "Patch planes must be equal to row stride.");
1025     return m_base_mapper.m_rowStride;
1026   }
1027   EIGEN_DEVICE_FUNC
1028   EIGEN_ALWAYS_INLINE Index patchRows() const {
1029     return m_base_mapper.m_patch_rows;
1030   }
1031   EIGEN_DEVICE_FUNC
1032   EIGEN_ALWAYS_INLINE Index patchCols() const {
1033     return m_base_mapper.m_patch_cols;
1034   }
1035 
1036   EIGEN_DEVICE_FUNC
1037   EIGEN_ALWAYS_INLINE Index patchPlaneStride() const {
1038     eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
1039                  "Patch depth must be equal to patch plane stride.");
1040     return patchDepth();
1041   }
1042   EIGEN_DEVICE_FUNC
1043   EIGEN_ALWAYS_INLINE Index patchRowStride() const {
1044     return m_base_mapper.m_patch_row_stride;
1045   }
1046   EIGEN_DEVICE_FUNC
1047   EIGEN_ALWAYS_INLINE Index patchColStride() const {
1048     return m_base_mapper.m_patch_col_stride;
1049   }
1050 
1051   EIGEN_DEVICE_FUNC
1052   EIGEN_ALWAYS_INLINE IndexDivisor fastPatchPlaneStride() const {
1053     eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
1054                  "Patch depth must be equal to patch plane stride.");
1055     return m_base_mapper.m_fastDimZero;  // patch_depth
1056   }
1057   EIGEN_DEVICE_FUNC
1058   EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
1059     return m_base_mapper.m_fastPatchRowStride;
1060   }
1061   EIGEN_DEVICE_FUNC
1062   EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
1063     return m_base_mapper.m_fastPatchColStride;
1064   }
1065 
1066   EIGEN_DEVICE_FUNC
1067   EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
1068                                              const Index baseIndex) const {
1069     const Index inputIndex = depth + baseIndex;
1070     return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
1071   }
1072   EIGEN_DEVICE_FUNC
1073   EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth,
1074                                             const Index baseIndex) const {
1075     const Index inputIndex = depth + baseIndex;
1076     return m_base_mapper.m_impl.coeff(inputIndex);
1077   }
1078 
1079   EIGEN_DEVICE_FUNC
1080   EIGEN_ALWAYS_INLINE bool padPlane(const Index plane) const {
1081     const Index p = m_planeIndex + plane;
1082     return p < 0 || p >= m_base_mapper.m_inputPlanes;
1083   }
1084   EIGEN_DEVICE_FUNC
1085   EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
1086     const Index r = m_rowIndex + row;
1087     return r < 0 || r >= m_base_mapper.m_inputRows;
1088   }
1089   EIGEN_DEVICE_FUNC
1090   EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
1091     const Index c = m_colIndex + col;
1092     return c < 0 || c >= m_base_mapper.m_inputCols;
1093   }
1094   EIGEN_DEVICE_FUNC
1095   EIGEN_ALWAYS_INLINE Index baseIndex(const Index plane, const Index row,
1096                                       const Index col) const {
1097     const Index p = m_planeIndex + plane;
1098     const Index r = m_rowIndex + row;
1099     const Index c = m_colIndex + col;
1100     return p * m_base_mapper.m_planeInputStride +
1101            r * m_base_mapper.m_rowInputStride +
1102            c * m_base_mapper.m_colInputStride + m_otherIndex;
1103   }
1104 
1105   EIGEN_DEVICE_FUNC
1106   EIGEN_ALWAYS_INLINE Index planeOffset() const {
1107     const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
1108     const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
1109     const Index rowOffset =
1110         (patchOffset - colOffset * m_base_mapper.m_colStride) /
1111         m_base_mapper.m_fastRowStride;
1112     const Index planeOffset = patchOffset -
1113                               colOffset * m_base_mapper.m_colStride -
1114                               rowOffset * m_base_mapper.m_rowStride;
1115     return planeOffset;
1116   }
1117 
1118   EIGEN_DEVICE_FUNC
1119   EIGEN_ALWAYS_INLINE Index rowOffset() const {
1120     const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
1121     const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
1122     const Index rowOffset =
1123         (patchOffset - colOffset * m_base_mapper.m_colStride) /
1124         m_base_mapper.m_fastRowStride;
1125     return rowOffset;
1126   }
1127 
1128   EIGEN_DEVICE_FUNC
1129   EIGEN_ALWAYS_INLINE Index colOffset() const {
1130     const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
1131     const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
1132     return colOffset;
1133   }
1134 
1135   EIGEN_DEVICE_FUNC
1136   EIGEN_ALWAYS_INLINE Index depthOffset() const {
1137     return m_depth_offset % patchDepth();
1138   }
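
  // Illustration of the offset decomposition above (hypothetical sizes, and
  // assuming the base mapper defines m_rowStride = patch_planes and
  // m_colStride = patch_planes * patch_rows): with patch_depth = 8,
  // patch_planes = 3 and patch_rows = 4, a depth offset of 246 splits into
  //   patchOffset = 246 / 8 = 30
  //   colOffset   = 30 / 12 = 2
  //   rowOffset   = (30 - 2 * 12) / 3 = 2
  //   planeOffset = 30 - 2 * 12 - 2 * 3 = 0
  //   depthOffset = 246 % 8 = 6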
1139 
1140   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
1141   getLinearMapper(Index i, Index j) const {
1142     return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
1143   }
1144 
1145  private:
1146   const ParentMapper m_base_mapper;  // Keeping a copy instead of a reference
1147                                      // performs better in benchmarks.
1148 
1149   Index m_depth_offset;  // First row in the input matrix
1150   Index m_col_offset;    // First col in the input matrix
1151 
1152   // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
1153   // indices for the first element in a patch specified by col_offset
1154   // (see computeBaseIndices(...) for details).
1155   Index m_planeIndex;
1156   Index m_rowIndex;
1157   Index m_colIndex;
1158   Index m_otherIndex;
1159 };
1160 
1161 // Arrange a block of the right input matrix (in our case it's always a "virtual
1162 // matrix" constructed from extracted volume patches) in contiguous memory.
1163 //
1164 // Given column major input (A0 beside A1 in memory):
1165 // A0 B0 C0 D0  E0 F0 G0 H0 ... Z0
1166 // A1 B1 C1 D1  E1 F1 G1 H1 ... Z1
1167 // A2 B2 C2 D2  E2 F2 G2 H2 ... Z2
1168 // A3 B3 C3 D3  E3 F3 G3 H3 ... Z3
1169 // A4 B4 C4 D4  E4 F4 G4 H4 ... Z4
1170 // A5 B5 C5 D5  E5 F5 G5 H5 ... Z5
1171 // A6 B6 C6 D6  E6 F6 G6 H6 ... Z6
1172 // A7 B7 C7 D7  E7 F7 G7 H7 ... Z7
1173 // A8 ...
1174 // ...
1175 //
1176 // *) A, B, C, ... - patches extracted from the original input.
1177 // *) A0, A1, A2 ... - values from the same patch at different offsets.
1178 //
1179 // The traversal (packed rhs memory) order (B0 beside A0 in memory):
1180 // A0 B0 C0 D0 A1 B1 C1 D1 ...
1181 // E0 F0 G0 H0 E1 F1 G1 H1 ...
1182 // ...
1183 // Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
1184 //
1185 // This traversal order must be the same as in default gemm_pack_rhs defined in
1186 // GeneralBlockPanelKernel.h.
1187 //
1188 // *) nr - number of registers along the 'n' dimension.
1189 //    See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
1190 //    Multiplication" paper.
1191 //
1192 // TODO(ezhulenev): Add support for squeezing reads along two innermost
1193 // dimensions (see eigen_spatial_convolutions).
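//
// For concreteness (assuming nr = 4 and packet_size = 4): each iteration of
// the innermost depth loop below loads a packet of 4 consecutive depth values
// from each of the 4 patches A, B, C and D, transposes the resulting 4x4
// PacketBlock, and stores it so that the packed block reads
//   A0 B0 C0 D0 A1 B1 C1 D1 A2 B2 C2 D2 A3 B3 C3 D3
// which matches the traversal order described above.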
1194 template <typename NewDimension, Index Planes, Index Rows, Index Cols,
1195           typename ArgType, typename Device, typename Scalar, typename Index,
1196           typename nocontract_t, typename contract_t, int packet_size,
1197           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
1198           int nr>
1199 struct gemm_pack_rhs<
1200     Scalar, Index,
1201     TensorContractionSubMapper<
1202         Scalar, Index, Rhs,
1203         TensorEvaluator<const TensorReshapingOp<
1204                             NewDimension, const TensorVolumePatchOp<
1205                                               Planes, Rows, Cols, ArgType> >,
1206                         Device>,
1207         nocontract_t, contract_t, packet_size, inner_dim_contiguous,
1208         inner_dim_reordered, Alignment>,
1209     nr, ColMajor, false, false> {
1210   typedef TensorContractionSubMapper<
1211       Scalar, Index, Rhs,
1212       TensorEvaluator<const TensorReshapingOp<
1213                           NewDimension, const TensorVolumePatchOp<
1214                                             Planes, Rows, Cols, ArgType> >,
1215                       Device>,
1216       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
1217       inner_dim_reordered, Alignment>
1218       SubMapper;
1219 
1220   typedef SubMapper DataMapper;
1221   typedef typename packet_traits<Scalar>::type Packet;
1222 
1223   EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
1224 
1225   EIGEN_DEVICE_FUNC
1226   EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
1227                                     Index depth, Index cols, Index stride = 0,
1228                                     Index offset = 0) const {
1229     eigen_assert(stride == 0);
1230     eigen_assert(offset == 0);
1231 
1232     const Index packet_cols4 = (cols / 4) * 4;
1233     const Index peeled_k = (depth / packet_size) * packet_size;
1234     const bool non_standard_patches = rhs.nonStandardPatches();
1235 
1236     for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
1237       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1238       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1239       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1240       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1241 
1242       Index k = 0;
1243       if ((packet_size % 4) == 0 && !non_standard_patches) {
1244         // FAST PATH:
1245         // Iterate over patch columns, rows and planes if we know that a single
1246         // packet does not span across multiple planes, rows or columns.
1247         if ((rhs.patchDepth() % packet_size) == 0) {
1248           const Index start_col = rhs.colOffset();
1249           const Index max_col = rhs.maxCol(peeled_k);
1250 
1251           for (Index c = start_col; c < max_col; ++c) {
1252             eigen_assert(k <= peeled_k);
1253 
1254             const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1255             const Index max_row = rhs.maxRow(peeled_k, c);
1256 
1257             const bool pad_col0 = dm0.padCol(c);
1258             const bool pad_col1 = dm1.padCol(c);
1259             const bool pad_col2 = dm2.padCol(c);
1260             const bool pad_col3 = dm3.padCol(c);
1261 
1262             for (Index r = start_row; r < max_row; ++r) {
1263               eigen_assert(k <= peeled_k);
1264 
1265               const Index start_plane = ((c == start_col) && (r == start_row))
1266                                             ? rhs.planeOffset()
1267                                             : 0;
1268               const Index max_plane = rhs.maxPlane(peeled_k, c, r);
1269 
1270               const bool pad_row0 = pad_col0 || dm0.padRow(r);
1271               const bool pad_row1 = pad_col1 || dm1.padRow(r);
1272               const bool pad_row2 = pad_col2 || dm2.padRow(r);
1273               const bool pad_row3 = pad_col3 || dm3.padRow(r);
1274 
1275               for (Index p = start_plane; p < max_plane; ++p) {
1276                 eigen_assert(k <= peeled_k);
1277 
1278                 const bool pad0 = pad_row0 || dm0.padPlane(p);
1279                 const bool pad1 = pad_row1 || dm1.padPlane(p);
1280                 const bool pad2 = pad_row2 || dm2.padPlane(p);
1281                 const bool pad3 = pad_row3 || dm3.padPlane(p);
1282 
1283                 const Index idx0 = dm0.baseIndex(p, r, c);
1284                 const Index idx1 = dm1.baseIndex(p, r, c);
1285                 const Index idx2 = dm2.baseIndex(p, r, c);
1286                 const Index idx3 = dm3.baseIndex(p, r, c);
1287 
1288                 const Index start_depth =
1289                     ((c == start_col) && (r == start_row) && (p == start_plane))
1290                         ? rhs.depthOffset()
1291                         : 0;
1292                 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1293                 eigen_assert((max_depth - start_depth) % packet_size == 0);
1294 
1295                 for (Index d = start_depth; d < max_depth; d += packet_size) {
1296                   eigen_assert(k < peeled_k);
1297                   PacketBlock<Packet, 4> kernel;
1298                   kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
1299                                           : rhs.packetNoPadding(d, idx0);
1300                   kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
1301                                           : rhs.packetNoPadding(d, idx1);
1302                   kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
1303                                           : rhs.packetNoPadding(d, idx2);
1304                   kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
1305                                           : rhs.packetNoPadding(d, idx3);
1306                   ptranspose(kernel);
1307                   pstoreu(block + 0 * packet_size, kernel.packet[0]);
1308                   pstoreu(block + 1 * packet_size, kernel.packet[1]);
1309                   pstoreu(block + 2 * packet_size, kernel.packet[2]);
1310                   pstoreu(block + 3 * packet_size, kernel.packet[3]);
1311                   block += 4 * packet_size;
1312                   k += packet_size;
1313                 }
1314               }
1315             }
1316           }
1317 
1318           // The loop above should fill peeled_k elements.
1319           eigen_assert(peeled_k == k);
1320 
1321         } else {
1322           // Packet can span multiple planes, rows or columns, so we have to go
1323           // through the slower "standard" path.
1324           for (; k < peeled_k; k += packet_size) {
1325             PacketBlock<Packet, 4> kernel;
1326             kernel.packet[0] = dm0.loadPacketStandard(k);
1327             kernel.packet[1] = dm1.loadPacketStandard(k);
1328             kernel.packet[2] = dm2.loadPacketStandard(k);
1329             kernel.packet[3] = dm3.loadPacketStandard(k);
1330             ptranspose(kernel);
1331             pstoreu(block + 0 * packet_size, kernel.packet[0]);
1332             pstoreu(block + 1 * packet_size, kernel.packet[1]);
1333             pstoreu(block + 2 * packet_size, kernel.packet[2]);
1334             pstoreu(block + 3 * packet_size, kernel.packet[3]);
1335             block += 4 * packet_size;
1336           }
1337         }
1338       }
1339 
1340       // Copy the remaining coefficients of the column block after the peeled_k.
1341       if (!non_standard_patches) {
1342         for (; k < depth; k++) {
1343           block[0] = dm0.loadCoeffStandard(k);
1344           block[1] = dm1.loadCoeffStandard(k);
1345           block[2] = dm2.loadCoeffStandard(k);
1346           block[3] = dm3.loadCoeffStandard(k);
1347           block += 4;
1348         }
1349       } else {
1350         for (; k < depth; k++) {
1351           block[0] = dm0(k);
1352           block[1] = dm1(k);
1353           block[2] = dm2(k);
1354           block[3] = dm3(k);
1355           block += 4;
1356         }
1357       }
1358     }
1359 
1360     // Copy the remaining columns one at a time (nr==1).
1361 #if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
1362     // Remaining columns are handled differently for PPC.
1363     for (Index k = 0; k < depth; k++) {
1364       for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1365         *block = rhs(k, j2);
1366         block += 1;
1367       }
1368     }
1369 #else
1370     for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1371       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1372       for (Index k = 0; k < depth; k++) {
1373         *block = dm0(k);
1374         block += 1;
1375       }
1376     }
1377 #endif
1378   }
1379 };
1380 
1381 // Template specialization for packet_size = 2. We must special-case packet
1382 // blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
1383 //
1384 // TODO(ezhulenev): Add support for squeezing reads along two innermost
1385 // dimensions (see eigen_spatial_convolutions).
1386 template <typename NewDimension, Index Planes, Index Rows, Index Cols,
1387           typename ArgType, typename Device, typename Scalar, typename Index,
1388           typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
1389           bool inner_dim_reordered, int Alignment, int nr>
1390 struct gemm_pack_rhs<
1391     Scalar, Index,
1392     TensorContractionSubMapper<
1393         Scalar, Index, Rhs,
1394         TensorEvaluator<const TensorReshapingOp<
1395                             NewDimension, const TensorVolumePatchOp<
1396                                               Planes, Rows, Cols, ArgType> >,
1397                         Device>,
1398         nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
1399         inner_dim_reordered, Alignment>,
1400     nr, ColMajor, false, false> {
1401   typedef TensorContractionSubMapper<
1402       Scalar, Index, Rhs,
1403       TensorEvaluator<const TensorReshapingOp<
1404                           NewDimension, const TensorVolumePatchOp<
1405                                             Planes, Rows, Cols, ArgType> >,
1406                       Device>,
1407       nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
1408       inner_dim_reordered, Alignment>
1409       SubMapper;
1410   typedef SubMapper DataMapper;
1411   typedef typename packet_traits<Scalar>::type Packet;
1412 
1413   EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
1414 
1415   EIGEN_DEVICE_FUNC
1416   EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
1417                                     Index depth, Index cols, Index stride = 0,
1418                                     Index offset = 0) const {
1419     eigen_assert(stride == 0);
1420     eigen_assert(offset == 0);
1421 
1422     const int packet_size = 2;
1423 
1424     const Index packet_cols4 = (cols / 4) * 4;
1425     const Index peeled_k = (depth / packet_size) * packet_size;
1426     const bool non_standard_patches = rhs.nonStandardPatches();
1427 
1428     for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
1429       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1430       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1431       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1432       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1433 
1434       Index k = 0;
1435       if (!non_standard_patches) {
1436         // FAST PATH:
1437         // Iterate over patch columns, rows and planes if we know that a single
1438         // packet does not span across multiple planes, rows or columns.
1439         if ((rhs.patchDepth() % packet_size) == 0) {
1440           const Index start_col = rhs.colOffset();
1441           const Index max_col = rhs.maxCol(peeled_k);
1442 
1443           for (Index c = start_col; c < max_col; ++c) {
1444             eigen_assert(k <= peeled_k);
1445 
1446             const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1447             const Index max_row = rhs.maxRow(peeled_k, c);
1448 
1449             const bool pad_col0 = dm0.padCol(c);
1450             const bool pad_col1 = dm1.padCol(c);
1451             const bool pad_col2 = dm2.padCol(c);
1452             const bool pad_col3 = dm3.padCol(c);
1453 
1454             for (Index r = start_row; r < max_row; ++r) {
1455               eigen_assert(k <= peeled_k);
1456 
1457               const Index start_plane = ((c == start_col) && (r == start_row))
1458                                             ? rhs.planeOffset()
1459                                             : 0;
1460               const Index max_plane = rhs.maxPlane(peeled_k, c, r);
1461 
1462               const bool pad_row0 = dm0.padRow(r);
1463               const bool pad_row1 = dm1.padRow(r);
1464               const bool pad_row2 = dm2.padRow(r);
1465               const bool pad_row3 = dm3.padRow(r);
1466 
1467               for (Index p = start_plane; p < max_plane; ++p) {
1468                 eigen_assert(k <= peeled_k);
1469 
1470                 const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
1471                 const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
1472                 const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
1473                 const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);
1474 
1475                 const Index idx0 = dm0.baseIndex(p, r, c);
1476                 const Index idx1 = dm1.baseIndex(p, r, c);
1477                 const Index idx2 = dm2.baseIndex(p, r, c);
1478                 const Index idx3 = dm3.baseIndex(p, r, c);
1479 
1480                 const Index start_depth =
1481                     ((c == start_col) && (r == start_row) && (p == start_plane))
1482                         ? rhs.depthOffset()
1483                         : 0;
1484                 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1485                 eigen_assert((max_depth - start_depth) % packet_size == 0);
1486 
1487                 for (Index d = start_depth; d < max_depth; d += packet_size) {
1488                   eigen_assert(k < peeled_k);
1489                   PacketBlock<Packet, 2> kernel0;
1490                   PacketBlock<Packet, 2> kernel1;
1491                   kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
1492                                            : rhs.packetNoPadding(d, idx0);
1493                   kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
1494                                            : rhs.packetNoPadding(d, idx1);
1495                   kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
1496                                            : rhs.packetNoPadding(d, idx2);
1497                   kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
1498                                            : rhs.packetNoPadding(d, idx3);
1499                   ptranspose(kernel0);
1500                   ptranspose(kernel1);
1501                   pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1502                   pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1503                   pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1504                   pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1505                   block += 4 * packet_size;
1506                   k += packet_size;
1507                 }
1508               }
1509             }
1510           }
1511 
1512           // The loop above should fill peeled_k elements.
1513           eigen_assert(peeled_k == k);
1514 
1515         } else {
1516           for (; k < peeled_k; k += packet_size) {
1517             PacketBlock<Packet, 2> kernel0;
1518             PacketBlock<Packet, 2> kernel1;
1519             kernel0.packet[0] = dm0.loadPacketStandard(k);
1520             kernel0.packet[1] = dm1.loadPacketStandard(k);
1521             kernel1.packet[0] = dm2.loadPacketStandard(k);
1522             kernel1.packet[1] = dm3.loadPacketStandard(k);
1523             ptranspose(kernel0);
1524             ptranspose(kernel1);
1525             pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1526             pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1527             pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1528             pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1529             block += 4 * packet_size;
1530           }
1531         }
1532       }
1533 
1534       // Copy the remaining coefficients of the column block after the peeled_k.
1535       if (!rhs.nonStandardPatches()) {
1536         for (; k < depth; k++) {
1537           block[0] = dm0.loadCoeffStandard(k);
1538           block[1] = dm1.loadCoeffStandard(k);
1539           block[2] = dm2.loadCoeffStandard(k);
1540           block[3] = dm3.loadCoeffStandard(k);
1541           block += 4;
1542         }
1543       } else {
1544         for (; k < depth; k++) {
1545           block[0] = dm0(k);
1546           block[1] = dm1(k);
1547           block[2] = dm2(k);
1548           block[3] = dm3(k);
1549           block += 4;
1550         }
1551       }
1552     }
1553 
1554     // Copy the remaining columns one at a time (nr==1).
1555     for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1556       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1557       for (Index k = 0; k < depth; k++) {
1558         *block = dm0(k);
1559         block += 1;
1560       }
1561     }
1562   }
1563 };
1564 
1565 // Special case for non-vectorized types such as float16 (packet_size = 1).
1566 template <typename NewDimension, Index Planes, Index Rows, Index Cols,
1567           typename ArgType, typename Device, typename Scalar, typename Index,
1568           typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
1569           bool inner_dim_reordered, int Alignment, int nr>
1570 struct gemm_pack_rhs<
1571     Scalar, Index,
1572     TensorContractionSubMapper<
1573         Scalar, Index, Rhs,
1574         TensorEvaluator<const TensorReshapingOp<
1575                             NewDimension, const TensorVolumePatchOp<
1576                                               Planes, Rows, Cols, ArgType> >,
1577                         Device>,
1578         nocontract_t, contract_t, /*packet_size*/ 1, inner_dim_contiguous,
1579         inner_dim_reordered, Alignment>,
1580     nr, ColMajor, false, false> {
1581   typedef TensorContractionSubMapper<
1582       Scalar, Index, Rhs,
1583       TensorEvaluator<const TensorReshapingOp<
1584                           NewDimension, const TensorVolumePatchOp<
1585                                             Planes, Rows, Cols, ArgType> >,
1586                       Device>,
1587       nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
1588       Alignment>
1589       SubMapper;
1590   typedef SubMapper DataMapper;
1591 
1592   EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
1593 
1594   EIGEN_DEVICE_FUNC
1595   EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
1596                                     Index depth, Index cols, Index stride = 0,
1597                                     Index offset = 0) const {
1598     eigen_assert(stride == 0);
1599     eigen_assert(offset == 0);
1600 
1601     const Index packet_cols4 = (cols / 4) * 4;
1602 
1603     for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
1604       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1605       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1606       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1607       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1608 
1609       if (!rhs.nonStandardPatches()) {
1610         for (Index k = 0; k < depth; k++) {
1611           block[0] = dm0.loadCoeffStandard(k);
1612           block[1] = dm1.loadCoeffStandard(k);
1613           block[2] = dm2.loadCoeffStandard(k);
1614           block[3] = dm3.loadCoeffStandard(k);
1615           block += 4;
1616         }
1617       } else {
1618         for (Index k = 0; k < depth; k++) {
1619           block[0] = dm0(k);
1620           block[1] = dm1(k);
1621           block[2] = dm2(k);
1622           block[3] = dm3(k);
1623           block += 4;
1624         }
1625       }
1626     }
1627 
1628     // Copy the remaining columns one at a time (nr==1).
1629     for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1630       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1631       for (Index k = 0; k < depth; k++) {
1632         *block = dm0(k);
1633         block += 1;
1634       }
1635     }
1636   }
1637 };
1638 #endif
1639 
1640 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
1641 // Pack a block of the right input matrix (in our case it's always a "virtual
1642 // matrix" constructed from extracted volume patches) into a contiguous block
1643 // in column-major storage order. Knowing the properties of the original patch
1644 // op we can do it more efficiently than the default gemm_pack_colmajor_block.
1645 //
1646 // TODO(ezhulenev): gemm_pack_colmajor_block for spatial convolutions supports
1647 // squeezing reads along the 2 innermost dimensions, add it here if needed.
1648 template <typename NewDimension, Index Planes, Index Rows, Index Cols,
1649           typename ArgType, typename Device, typename Scalar,
1650           typename StorageIndex, typename nocontract_t, typename contract_t,
1651           int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
1652           int Alignment>
1653 struct gemm_pack_colmajor_block<
1654     Scalar, StorageIndex,
1655     TensorContractionSubMapper<
1656         Scalar, StorageIndex, Rhs,
1657         TensorEvaluator<const TensorReshapingOp<
1658                             NewDimension, const TensorVolumePatchOp<
1659                                               Planes, Rows, Cols, ArgType> >,
1660                         Device>,
1661         nocontract_t, contract_t, packet_size, inner_dim_contiguous,
1662         inner_dim_reordered, Alignment>,
1663     ColMajor> {
1664   typedef TensorContractionSubMapper<
1665       Scalar, StorageIndex, Rhs,
1666       TensorEvaluator<const TensorReshapingOp<
1667                           NewDimension, const TensorVolumePatchOp<
1668                                             Planes, Rows, Cols, ArgType> >,
1669                       Device>,
1670       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
1671       inner_dim_reordered, Alignment>
1672       SubMapper;
1673 
1674   typedef SubMapper DataMapper;
1675   typedef typename packet_traits<Scalar>::type Packet;
1676 
1677   EIGEN_DONT_INLINE
1678   void operator()(Scalar* block, const DataMapper& rhs, StorageIndex rows,
1679                   StorageIndex cols) {
1680     const bool standard_patches = !rhs.nonStandardPatches();
1681 
1682     if (standard_patches && rhs.patchDepth() % packet_size == 0) {
1683       packStandardPatches<true>(block, rhs, rows, cols);
1684 
1685     } else if (standard_patches) {
1686       packStandardPatches<false>(block, rhs, rows, cols);
1687 
1688     } else {
1689       // With non-standard patches we don't do any vectorized loads.
1690       // TODO(ezhulenev): It doesn't look like we should completely give up
1691       // on packets. Make this code path faster!
1692       for (StorageIndex col = 0; col < cols; ++col) {
1693         SubMapper lm = rhs.getLinearMapper(0, col);
1694         for (StorageIndex i = 0; i < rows; ++i) {
1695           *block = lm(i);
1696           ++block;
1697         }
1698       }
1699     }
1700   }
1701 
1702  private:
1703   // Pack standard volume patches:
1704   //
1705   // - patch_depth_is_multiple_of_packet_size=true: The depth dimension size
1706   //   is guaranteed to be a multiple of the packet size, so we can skip all
1707   //   non-vectorized loads and checks.
1708   //
1709   template <bool patch_depth_is_multiple_of_packet_size>
1710   EIGEN_ALWAYS_INLINE void packStandardPatches(Scalar* block,
1711                                                const DataMapper& rhs,
1712                                                StorageIndex rows,
1713                                                StorageIndex cols) {
1714     eigen_assert(!rhs.nonStandardPatches());
1715 
1716     // Name the vectorized row count peeled_k, as in the gemm_pack_rhs above.
1717     const Index peeled_k = (rows / packet_size) * packet_size;
1718 
1719     const Index start_col = rhs.colOffset();
1720     const Index max_col = rhs.maxCol(peeled_k);
1721 
1722     for (StorageIndex col = 0; col < cols; ++col) {
1723       SubMapper lm = rhs.getLinearMapper(0, col);
1724 
1725       Index k = 0;
1726       for (Index c = start_col; c < max_col; ++c) {
1727         eigen_assert(k <= peeled_k);
1728 
1729         const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1730         const Index max_row = rhs.maxRow(peeled_k, c);
1731         const bool pad_col = lm.padCol(c);
1732 
1733         for (Index r = start_row; r < max_row; ++r) {
1734           eigen_assert(k <= peeled_k);
1735 
1736           const Index start_plane =
1737               ((c == start_col) && (r == start_row)) ? rhs.planeOffset() : 0;
1738           const Index max_plane = rhs.maxPlane(peeled_k, c, r);
1739           const bool pad_row = pad_col || lm.padRow(r);
1740 
1741           for (Index p = start_plane; p < max_plane; ++p) {
1742             eigen_assert(k <= peeled_k);
1743 
1744             const Index start_depth =
1745                 ((c == start_col) && (r == start_row) && (p == start_plane))
1746                     ? rhs.depthOffset()
1747                     : 0;
1748             const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1749 
1750             const bool pad = pad_col || pad_row || lm.padPlane(p);
1751             const Index base_idx = lm.baseIndex(p, r, c);
1752 
1753             if (patch_depth_is_multiple_of_packet_size)
1754               eigen_assert((max_depth - start_depth) % packet_size == 0);
1755 
1756             // If patch depth is a multiple of packet size, it's guaranteed that
1757             // we can process all values in the depth dimension with packets.
1758             const Index max_vectorized_depth =
1759                 patch_depth_is_multiple_of_packet_size
1760                     ? max_depth
1761                     : max_depth - packet_size;
1762 
1763             Index d = start_depth;
1764 
1765             // 1. Process depth dimension with vectorized instructions.
1766             for (; d < max_vectorized_depth; d += packet_size) {
1767               eigen_assert(k < peeled_k);
1768               const Packet packet = pad ? pset1<Packet>(Scalar(0))
1769                                         : rhs.packetNoPadding(d, base_idx);
1770               internal::pstoreu(block, packet);
1771               block += packet_size;
1772               k += packet_size;
1773             }
1774 
1775             // 2. Finish with coefficients.
1776             if (!patch_depth_is_multiple_of_packet_size) {
1777               for (; d < max_depth; d++) {
1778                 eigen_assert(k < peeled_k);
1779                 *block = pad ? Scalar(0) : rhs.coeffNoPadding(d, base_idx);
1780                 ++block;
1781                 ++k;
1782               }
1783             }
1784           }
1785         }
1786       }
1787 
1788       // The loop above should fill peeled_k elements.
1789       eigen_assert(peeled_k == k);
1790 
1791       // Fill remaining elements using loadCoeffStandard.
1792       for (; k < rows; ++k) {
1793         *block = lm.loadCoeffStandard(k);
1794         ++block;
1795       }
1796     }
1797   }
1798 };
1799 #endif  // defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
1800 
1801 }  // namespace internal
1802 
1803 /** CuboidConvolution
1804  * \ingroup CXX11_NeuralNetworks_Module
1805  *
1806  * \brief Applies a 3D convolution over a multichannel input voxel block.
1807  *
1808  * The input parameter is expected to be a tensor with a rank of 4 or more
1809  * (channels, depth, height, width, and optionally others).
1810  * The kernel parameter is expected to be a 5D tensor (filters, channels,
1811  * kernel_depth, kernel_height, kernel_width).
1812  * The result can be assigned to a tensor of rank equal to the rank of the
1813  * input. The dimensions of the result will be filters, depth, height, width
1814  * (and others if applicable).
1815  *
1816  * The input and kernel have to be in the same layout, and both row-major and
1817  * col-major are supported. The shapes given above are for col-major layout.
1818  * For row-major, all dimensions should be reversed.
1819  *
1820  * It is possible to swap the order of the depth, width, and height dimensions
1821  * provided that the same order is used in the input, the kernel, and the
1822  * output.
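 *
 * A minimal usage sketch (illustrative only; the tensor sizes below are
 * arbitrary assumptions, col-major layout):
 *
 *   // input: (channels = 3, planes = 5, rows = 7, cols = 9, batch = 2)
 *   Eigen::Tensor<float, 5> input(3, 5, 7, 9, 2);
 *   input.setRandom();
 *   // kernel: (filters = 4, channels = 3, planes = 2, rows = 2, cols = 2)
 *   Eigen::Tensor<float, 5> kernel(4, 3, 2, 2, 2);
 *   kernel.setRandom();
 *   // With the default stride of 1 and PADDING_SAME the spatial dimensions
 *   // are preserved, so the result shape is (filters, planes, rows, cols,
 *   // batch) = (4, 5, 7, 9, 2).
 *   Eigen::Tensor<float, 5> result(4, 5, 7, 9, 2);
 *   result = Eigen::CuboidConvolution(input, kernel);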
1823  */
1824 template <typename Input, typename Kernel>
1825 EIGEN_ALWAYS_INLINE static const std::conditional_t<
1826     internal::traits<Input>::Layout == ColMajor,
1827     TensorReshapingOp<
1828         const DSizes<typename internal::traits<Input>::Index,
1829                      internal::traits<Input>::NumDimensions>,
1830         const TensorContractionOp<
1831             const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
1832             const TensorReshapingOp<
1833                 const DSizes<typename internal::traits<Input>::Index, 2>,
1834                 const Kernel>,
1835             const TensorReshapingOp<
1836                 const DSizes<typename internal::traits<Input>::Index, 2>,
1837                 const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
1838                                           const Input> > > >,
1839     TensorReshapingOp<
1840         const DSizes<typename internal::traits<Input>::Index,
1841                      internal::traits<Input>::NumDimensions>,
1842         const TensorContractionOp<
1843             const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
1844             const TensorReshapingOp<
1845                 const DSizes<typename internal::traits<Input>::Index, 2>,
1846                 const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
1847                                           const Input> >,
1848             const TensorReshapingOp<
1849                 const DSizes<typename internal::traits<Input>::Index, 2>,
1850                 const Kernel> > > >
1851 CuboidConvolution(const Input& input, const Kernel& kernel,
1852                   const Index stridePlanes = 1, const Index strideRows = 1,
1853                   const Index strideCols = 1,
1854                   const PaddingType padding_type = PADDING_SAME) {
1855   typedef typename internal::traits<Input>::Index TensorIndex;
1856   TensorRef<Tensor<typename internal::traits<Input>::Scalar,
1857                    internal::traits<Input>::NumDimensions,
1858                    internal::traits<Input>::Layout, TensorIndex> >
1859       in(input);
1860   TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
1861                    internal::traits<Kernel>::NumDimensions,
1862                    internal::traits<Kernel>::Layout, TensorIndex> >
1863       kern(kernel);
1864 
1865   EIGEN_STATIC_ASSERT(
1866       internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
1867       YOU_MADE_A_PROGRAMMING_MISTAKE);
1868   static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
1869   static const int NumDims = internal::traits<Input>::NumDimensions;
1870 
1871   // Number of filters to apply. This is the same as the output depth of the
1872   // result.
1873   const TensorIndex kernelFilters =
1874       isColMajor ? kern.dimensions()[0] : kern.dimensions()[4];
1875   const TensorIndex kernelChannels =
1876       isColMajor ? kern.dimensions()[1] : kern.dimensions()[3];
1877 
1878   // Spatial size of the kernel.
1879   const TensorIndex kernelPlanes =
1880       isColMajor ? kern.dimensions()[2] : kern.dimensions()[2];
1881   const TensorIndex kernelRows =
1882       isColMajor ? kern.dimensions()[3] : kern.dimensions()[1];
1883   const TensorIndex kernelCols =
1884       isColMajor ? kern.dimensions()[4] : kern.dimensions()[0];
1885 
1886   if (isColMajor) {
1887     eigen_assert(kernelChannels == in.dimension(0));
1888   } else {
1889     eigen_assert(kernelChannels == in.dimension(NumDims - 1));
1890   }
1891 
1892   const TensorIndex inputPlanes =
1893       isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
1894   const TensorIndex inputRows =
1895       isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
1896   const TensorIndex inputCols =
1897       isColMajor ? in.dimension(3) : in.dimension(NumDims - 4);
1898 
1899   TensorIndex out_planes;
1900   TensorIndex out_height;
1901   TensorIndex out_width;
1902   switch (padding_type) {
1903     case PADDING_VALID:
1904       out_planes = Eigen::divup(inputPlanes - kernelPlanes + 1,
1905                                 static_cast<TensorIndex>(stridePlanes));
1906       out_height = Eigen::divup(inputRows - kernelRows + 1,
1907                                 static_cast<TensorIndex>(strideRows));
1908       out_width = Eigen::divup(inputCols - kernelCols + 1,
1909                                static_cast<TensorIndex>(strideCols));
1910       break;
1911     case PADDING_SAME:
1912       out_planes =
1913           Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
1914       out_height =
1915           Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
1916       out_width = Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
1917       break;
1918     default:
1919       out_planes = 0;
1920       out_height = 0;
1921       out_width = 0;
1922       eigen_assert(false && "unexpected padding");
1923   }
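
  // E.g. (hypothetical sizes): with inputPlanes = 5, kernelPlanes = 2 and
  // stridePlanes = 2, PADDING_VALID gives divup(5 - 2 + 1, 2) = 2 output
  // planes, while PADDING_SAME gives divup(5, 2) = 3.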
1924 
1925   DSizes<TensorIndex, 2> kernel_dims;
1926   if (isColMajor) {
1927     kernel_dims[0] = kernelFilters;
1928     kernel_dims[1] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
1929   } else {
1930     kernel_dims[0] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
1931     kernel_dims[1] = kernelFilters;
1932   }
1933 
1934   // Molds the result of the patch extraction into a 2D tensor:
1935   // - the first dimension (dims[0]): the patch values to be multiplied with the
1936   // kernels
1937   // - the second dimension (dims[1]): everything else
1938   DSizes<TensorIndex, 2> pre_contract_dims;
1939   if (isColMajor) {
1940     pre_contract_dims[0] =
1941         kernelChannels * kernelPlanes * kernelRows * kernelCols;
1942     pre_contract_dims[1] = out_planes * out_height * out_width;
1943     for (int i = 4; i < NumDims; ++i) {
1944       pre_contract_dims[1] *= in.dimension(i);
1945     }
1946   } else {
1947     pre_contract_dims[1] =
1948         kernelChannels * kernelPlanes * kernelRows * kernelCols;
1949     pre_contract_dims[0] = out_planes * out_height * out_width;
1950     for (int i = 0; i < NumDims - 4; ++i) {
1951       pre_contract_dims[0] *= in.dimension(i);
1952     }
1953   }
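
  // E.g. (col-major, hypothetical sizes): with kernelChannels = 3, a 2x2x2
  // kernel, and out_planes * out_height * out_width * batch = 5 * 7 * 9 * 2,
  // pre_contract_dims is (24, 630).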
1954 
1955   array<IndexPair<TensorIndex>, 1> contract_dims;
1956   contract_dims[0] = IndexPair<TensorIndex>(1, 0);
1957 
1958   // Molds the output of the contraction into the shape expected by the user
1959   // (assuming ColMajor):
1960   // - 1st dim: kernel filters
1961   // - 2nd dim: output depth
1962   // - 3rd dim: output height
1963   // - 4th dim: output width
1964   // - 5th dim and beyond: everything else including batch size
1965   DSizes<TensorIndex, NumDims> post_contract_dims;
1966   if (isColMajor) {
1967     post_contract_dims[0] = kernelFilters;
1968     post_contract_dims[1] = out_planes;
1969     post_contract_dims[2] = out_height;
1970     post_contract_dims[3] = out_width;
1971     for (int i = 4; i < NumDims; ++i) {
1972       post_contract_dims[i] = in.dimension(i);
1973     }
1974   } else {
1975     post_contract_dims[NumDims - 1] = kernelFilters;
1976     post_contract_dims[NumDims - 2] = out_planes;
1977     post_contract_dims[NumDims - 3] = out_height;
1978     post_contract_dims[NumDims - 4] = out_width;
1979     for (int i = 0; i < NumDims - 4; ++i) {
1980       post_contract_dims[i] = in.dimension(i);
1981     }
1982   }
1983 
1984   return choose(
1985       Cond<internal::traits<Input>::Layout == ColMajor>(),
1986       kernel.reshape(kernel_dims)
1987           .contract(input
1988                         .extract_volume_patches(
1989                             kernelPlanes, kernelRows, kernelCols, stridePlanes,
1990                             strideRows, strideCols, padding_type)
1991                         .reshape(pre_contract_dims),
1992                     contract_dims)
1993           .reshape(post_contract_dims),
1994       input
1995           .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
1996                                   stridePlanes, strideRows, strideCols,
1997                                   padding_type)
1998           .reshape(pre_contract_dims)
1999           .contract(kernel.reshape(kernel_dims), contract_dims)
2000           .reshape(post_contract_dims));
2001 }
2002 
2003 }  // end namespace Eigen
2004 
2005 #endif  // TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
2006