1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_
17 #define TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_
18 
19 #include "tensorflow/core/kernels/eigen_convolution_helpers.h"
20 
21 // Note this header is used in both TF and TFLite.
22 namespace Eigen {
23 
24 namespace internal {
25 
26 #if !EIGEN_ALTIVEC_USE_CUSTOM_PACK
27 // WARNING: Most of the code here implicitly assumes that the matrix is in
28 // ColMajor layout. This is guaranteed by the tensor contraction (see
29 // TensorContraction.h).
30 //
31 // Inside Eigen a tensor contraction is represented by a matrix multiplication.
32 // We don't want to actually extract image patches and reshape the result into
33 // a matrix (this involves allocating huge extra memory), so the patch
34 // extraction and reshape operations are implicit.
35 //
36 // TensorContractionInputMapper takes a matrix index and returns the coefficient
37 // (or the packet) of the "virtual tensor" that would be at that index if we
38 // were to actually reshape the result of patch extraction.
39 //
40 // TensorContractionSubMapper provides a similar view into the "virtual matrix"
41 // at the given vertical and horizontal offsets.
42 //
43 // "Virtual matrix" dimensions:
44 //   *0: kernelChannels * kernelRows * kernelCols;
45 //    1: out_height * out_width * OTHERS (e.g. batches, etc...)
46 //
47 // *) extracted patches are contiguous in memory (innermost dimension assuming
48 //    col major layout)
49 //
50 // With these dimensions:
51 //   row - offset within a single patch (in code: patchId)
52 //   col - index of the extracted patch (in code: patchIndex)
53 //         patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
54 //
55 // TODO(ezhulenev): Consolidate this part of the code with the image patch
56 // extraction code since they are both very similar.
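//
// Illustrative example (hypothetical sizes, added for clarity): with
// kernelChannels = 8, kernelRows = 3 and kernelCols = 3 the "virtual matrix"
// has 8 * 3 * 3 = 72 rows per patch. A row index patchId = 50 decomposes (as
// in loadCoeff below) into depth = 50 % 8 = 2, patchOffset = 50 / 8 = 6,
// kernel row = 6 % 3 = 0 and kernel col = 6 / 3 = 2, i.e. channel 2 of the
// kernel element at (row 0, col 2).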
57 
58 template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
59           typename Device, typename Scalar_, typename Index,
60           typename nocontract_t, typename contract_t, int Side, int packet_size,
61           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
62 class TensorContractionInputMapper<
63     Scalar_, Index, Side,
64     TensorEvaluator<
65         const TensorReshapingOp<NewDimension,
66                                 const TensorImagePatchOp<Rows, Cols, ArgType> >,
67         Device>,
68     nocontract_t, contract_t, packet_size, inner_dim_contiguous,
69     inner_dim_reordered, Alignment> {
70  public:
71   typedef Scalar_ Scalar;
72 
73   typedef TensorContractionInputMapper<
74       Scalar, Index, Side,
75       TensorEvaluator<
76           const TensorReshapingOp<
77               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
78           Device>,
79       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
80       inner_dim_reordered, Alignment>
81       Self;
82 
83   typedef TensorContractionSubMapper<
84       Scalar, Index, Side,
85       TensorEvaluator<
86           const TensorReshapingOp<
87               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
88           Device>,
89       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
90       inner_dim_reordered, Alignment>
91       SubMapper;
92 
93   typedef SubMapper VectorMapper;
94   typedef SubMapper LinearMapper;
95   typedef typename packet_traits<Scalar>::type Packet;
96 
97   typedef TensorEvaluator<ArgType, Device> TensorEvaluatorT;
98 
99   EIGEN_DEVICE_FUNC
100   TensorContractionInputMapper(
101       const TensorEvaluator<
102           const TensorReshapingOp<
103               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
104           Device>& tensor,
105       const nocontract_t&, const nocontract_t&, const contract_t&,
106       const contract_t&)
107       : m_impl(tensor.impl().impl()) {
108     Index patch_rows;
109     Index patch_depth;
110     if (internal::traits<ArgType>::Layout == ColMajor) {
111       patch_depth = tensor.impl().dimensions()[0];
112       patch_rows = tensor.impl().dimensions()[1];
113       m_patch_cols = tensor.impl().dimensions()[2];
114       m_num_patches = tensor.impl().dimensions()[3];
115     } else {
116       const size_t NumDims = tensor.impl().dimensions().size();
117       patch_depth = tensor.impl().dimensions()[NumDims - 1];
118       patch_rows = tensor.impl().dimensions()[NumDims - 2];
119       m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
120       m_num_patches = tensor.impl().dimensions()[NumDims - 4];
121     }
122 
123     // Strides for navigating through the single patch.
124     m_patch_row_stride = patch_depth;
125     m_patch_col_stride = patch_rows * m_patch_row_stride;
126 
127     m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
128     m_patch_col_inflate_strides = tensor.impl().colInflateStride();
129 
130     m_colStride = patch_rows;
131 
132     m_outputRows = tensor.impl().outputRows();
133     m_outputCols = tensor.impl().outputCols();
134     m_row_strides = tensor.impl().userRowStride();
135     m_col_strides = tensor.impl().userColStride();
136 
137     m_in_row_strides = tensor.impl().userInRowStride();
138     m_in_col_strides = tensor.impl().userInColStride();
139 
140     if (internal::traits<ArgType>::Layout == ColMajor) {
141       m_inputRows = tensor.impl().impl().dimensions()[1];
142       m_inputCols = tensor.impl().impl().dimensions()[2];
143     } else {
144       const int NumDims = tensor.impl().impl().dimensions().size();
145       m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
146       m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];
147     }
148 
149     m_rowInputStride = patch_depth;
150     m_colInputStride = patch_depth * m_inputRows;
151     m_patchInputStride = patch_depth * m_inputRows * m_inputCols;
152 
153     m_rowPaddingTop = tensor.impl().rowPaddingTop();
154     m_colPaddingLeft = tensor.impl().colPaddingLeft();
155 
156     m_fastPatchRowStride =
157         internal::TensorIntDivisor<Index>(m_patch_row_stride);
158     m_fastPatchColStride =
159         internal::TensorIntDivisor<Index>(m_patch_col_stride);
160     m_fastInputRowStride =
161         internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
162     m_fastInputColStride =
163         internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
164     m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
165     m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
166     m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
167     m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth);
168   }
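  // Illustrative example (hypothetical sizes, not from the original code): for
  // an input of depth 8 with 32x32 spatial dimensions and 3x3 patches, the
  // constructor above computes m_rowInputStride = 8, m_colInputStride = 256,
  // m_patchInputStride = 8192, m_patch_row_stride = 8 and
  // m_patch_col_stride = 24 (patch_rows * m_patch_row_stride).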
169 
170   EIGEN_DEVICE_FUNC
171   TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
172       : m_impl(base_mapper.m_impl) {
173     m_patch_cols = base_mapper.m_patch_cols;
174     m_num_patches = base_mapper.m_num_patches;
175 
176     m_patch_row_stride = base_mapper.m_patch_row_stride;
177     m_patch_col_stride = base_mapper.m_patch_col_stride;
178 
179     m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
180     m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
181 
182     m_colStride = base_mapper.m_colStride;
183 
184     m_rowInputStride = base_mapper.m_rowInputStride;
185     m_colInputStride = base_mapper.m_colInputStride;
186     m_patchInputStride = base_mapper.m_patchInputStride;
187 
188     m_inputRows = base_mapper.m_inputRows;
189     m_inputCols = base_mapper.m_inputCols;
190 
191     m_outputRows = base_mapper.m_outputRows;
192     m_outputCols = base_mapper.m_outputCols;
193     m_row_strides = base_mapper.m_row_strides;
194     m_col_strides = base_mapper.m_col_strides;
195 
196     m_in_row_strides = base_mapper.m_in_row_strides;
197     m_in_col_strides = base_mapper.m_in_col_strides;
198 
199     m_rowPaddingTop = base_mapper.m_rowPaddingTop;
200     m_colPaddingLeft = base_mapper.m_colPaddingLeft;
201 
202     m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
203     m_fastPatchColStride = base_mapper.m_fastPatchColStride;
204     m_fastInputRowStride = base_mapper.m_fastInputRowStride;
205     m_fastInputColStride = base_mapper.m_fastInputColStride;
206     m_fastNumPatches = base_mapper.m_fastNumPatches;
207     m_fastColStride = base_mapper.m_fastColStride;
208     m_fastOutputRows = base_mapper.m_fastOutputRows;
209     m_fastDimZero = base_mapper.m_fastDimZero;
210   }
211 
212   // If true, turns off some optimizations for loading packets since the image
213   // patches are "non-standard": e.g. there are non-trivial strides or
214   // inflations in the input.
215   EIGEN_DEVICE_FUNC
216   EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
217     return m_in_row_strides != 1 || m_in_col_strides != 1 ||
218            m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
219   }
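  // For example, a dilated convolution (user in_row/in_col strides > 1) or an
  // inflated input (inflate strides > 1, e.g. when computing the gradient of a
  // strided convolution) makes the patches non-standard, which disables the
  // fast packet-load paths below.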
220 
221   EIGEN_DEVICE_FUNC
222   EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
223     return SubMapper(*this, i, j);
224   }
225 
226   EIGEN_DEVICE_FUNC
227   EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
228     return LinearMapper(*this, i, j);
229   }
230 
231   EIGEN_DEVICE_FUNC
232   EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
233     Index rowIndex, colIndex, otherIndex;
234     computeBaseIndices(0, rowIndex, colIndex, otherIndex);
235     return loadCoeff(row, rowIndex, colIndex, otherIndex);
236   }
237 
238   // Load the coefficient at the patchIndex location instead of the usual
239   // m_rowIndex, m_colIndex, m_otherIndex.
240   // This is currently only used by the gpu code.
241   // EIGEN_DEVICE_FUNC
242   EIGEN_DEVICE_FUNC
243   EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
244     Index rowIndex, colIndex, otherIndex;
245     computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
246     return loadCoeff(row, rowIndex, colIndex, otherIndex);
247   }
248 
249   EIGEN_DEVICE_FUNC
250   EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
251     Index rowIndex, colIndex, otherIndex;
252     computeBaseIndices(0, rowIndex, colIndex, otherIndex);
253     return loadPacket(row, rowIndex, colIndex, otherIndex);
254   }
255 
256   // Load the packet at the patchIndex location instead of the usual m_rowIndex,
257   // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
258   EIGEN_DEVICE_FUNC
259   EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
260     Index rowIndex, colIndex, otherIndex;
261     computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
262     return loadPacket(row, rowIndex, colIndex, otherIndex);
263   }
264 
265   EIGEN_DEVICE_FUNC
266   EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
267     return m_impl;
268   }
269 
270   EIGEN_DEVICE_FUNC
271   EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; }
272   EIGEN_DEVICE_FUNC
273   EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; }
274   EIGEN_DEVICE_FUNC
275   EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
276 
277  private:
278   friend class TensorContractionSubMapper<
279       Scalar, Index, Side,
280       TensorEvaluator<
281           const TensorReshapingOp<
282               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
283           Device>,
284       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
285       inner_dim_reordered, Alignment>;
286 
287   // Load coefficient from a patch specified by the "within patch offset"
288   // (patchId) and the precomputed indices of the first element of the patch.
289   EIGEN_DEVICE_FUNC
290   EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex,
291                                        Index colIndex, Index otherIndex) const {
292     // Find the offset of the element wrt the location of the first element.
293     const Index patchOffset = patchId / m_fastDimZero;
294 
295     const Index colOffset = patchOffset / m_fastColStride;
296     const Index inputCol = colIndex + colOffset * m_in_col_strides;
297     const Index origInputCol =
298         (m_patch_col_inflate_strides == 1)
299             ? inputCol
300             : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
301 
302     const Index rowOffset = patchOffset - colOffset * m_colStride;
303     const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
304     const Index origInputRow =
305         (m_patch_row_inflate_strides == 1)
306             ? inputRow
307             : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
308     if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols ||
309         origInputRow >= m_inputRows ||
310         (inputCol != origInputCol * m_patch_col_inflate_strides) ||
311         (inputRow != origInputRow * m_patch_row_inflate_strides)) {
312       return Scalar(0);
313     }
314     const Index depth = patchId - patchOffset * patchDepth();
315     const Index inputIndex = depth + origInputRow * m_rowInputStride +
316                              origInputCol * m_colInputStride + otherIndex;
317     return m_impl.coeff(inputIndex);
318   }
319 
320   // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
321   // and `in_strides` equal to 1 (template specialization without templates).
322   EIGEN_DEVICE_FUNC
323   EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex,
324                                                Index colIndex,
325                                                Index otherIndex) const {
326     eigen_assert(!nonStandardPatches());
327 
328     // Find the offset of the element wrt the location of the first element.
329     const Index patchOffset = patchId / m_fastDimZero;
330     const Index colOffset = patchOffset / m_fastColStride;
331     const Index rowOffset = patchOffset - colOffset * m_colStride;
332     const Index inputCol = colIndex + colOffset;
333     const Index inputRow = rowIndex + rowOffset;
334     if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
335         inputRow >= m_inputRows) {
336       return Scalar(0);
337     }
338     const Index depth = patchId - patchOffset * patchDepth();
339     const Index inputIndex = depth + inputRow * m_rowInputStride +
340                              inputCol * m_colInputStride + otherIndex;
341     return m_impl.coeff(inputIndex);
342   }
343 
344   // Load packet from a patch specified by the "within patch offset"
345   // (patchId) and the precomputed indices of the first element of the patch.
346   EIGEN_DEVICE_FUNC
347   EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex,
348                                         Index colIndex,
349                                         Index otherIndex) const {
350     const Index packetSize = internal::unpacket_traits<Packet>::size;
351     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
352     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
353 
354     if (nonStandardPatches()) {
355       return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
356     }
357     typedef decltype(m_impl) TensorEvaluatorT;
358     return loadPacketStandard<Packet, TensorEvaluatorT>(patchId, rowIndex,
359                                                         colIndex, otherIndex);
360   }
361 
362   // Helper function to load a 'partial' packet - this is the single column
363   // part of a packet that is split across two columns. In the 'partial' packet,
364   // the elements corresponding to the column (specified through colOffset) are
365   // loaded and the rest of the elements are zero-filled into the 'partial'
366   // packet. This function is called from loadPacketStandardFromTwoColumns().
367   // This code path is exercised only when the packet type supports masked load
368   // and when the partial packet load is available in the TensorEvaluator.
369   EIGEN_DEVICE_FUNC
370   EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard(
371       Index rowIndex, Index colIndex, Index otherIndex, Index patchId,
372       const Index span[], const Index patchOffsets[], Index colOffset) const {
373     const Index inputCol = colIndex + colOffset;
374     const Index rowOffsets[2] = {patchOffsets[0] - colOffset * m_colStride,
375                                  patchOffsets[1] - colOffset * m_colStride};
376     const Index inputRows[2] = {rowIndex + rowOffsets[0],
377                                 rowIndex + rowOffsets[1]};
378 
379     if (inputRows[0] >= m_inputRows || inputRows[1] < 0 ||
380         inputCol >= m_inputCols || inputCol < 0) {
381       // Partial packet is all zeros
382       return internal::pset1<Packet>(Scalar(0));
383     } else if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
384       // From inputIndex-span[0], we need to load elements starting from index
385       // span[0] all the way up to (and including) span[1].
386       const Index depth = patchId - patchOffsets[0] * patchDepth();
387       const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
388                                inputCol * m_colInputStride + otherIndex;
389       return m_impl.template partialPacket<Packet>(
390           inputIndex - span[0], mask<Packet>(span[0], span[1] + 1));
391     } else {
392       // Using slow path for this partial packet.
393       // We need to load elements starting from index span[0] all the way up to
394       // (and including) span[1]. We split this load into 3 parts:
395       // 0 : span[0]-1 - Zeros will be loaded for these indices
396       // span[0] : span[1] - Elements will be loaded here for these indices
397       // span[1]+1 : packetSize-1 - Zeros will be loaded for these indices
398       const Index packetSize = internal::unpacket_traits<Packet>::size;
399       EIGEN_ALIGN_MAX
400       std::remove_const_t<Scalar> values[packetSize];
401       for (int i = 0; i < span[0]; ++i) values[i] = Scalar(0);
402       for (int i = span[0]; i < span[1] + 1; ++i)
403         values[i] =
404             loadCoeff(patchId - span[0] + i, rowIndex, colIndex, otherIndex);
405       for (int i = span[1] + 1; i < packetSize; ++i) values[i] = Scalar(0);
406       return internal::pload<Packet>(values);
407     }
408   }
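  // Illustrative example (hypothetical sizes): with packetSize = 8 and a
  // packet split after lane 4, the first partial packet is requested with
  // span = {0, 4} (lanes 0..4 come from the first column, lanes 5..7 stay
  // zero) and the second with span = {5, 7}. mask<Packet>(span[0],
  // span[1] + 1) selects exactly those lanes, and loading from
  // inputIndex - span[0] aligns lane span[0] with the first requested element.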
409 
410   // Helper function to load a packet that is split across two columns.
411   // If required, this function is called from loadPacketStandard() when the
412   // packet type supports masked load and when the partial packet load is
413   // available in the TensorEvaluator.
414   EIGEN_DEVICE_FUNC
415   EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromTwoColumns(
416       Index patchId, Index rowIndex, Index colIndex, Index otherIndex,
417       const Index patchOffsets[], const Index colOffsets[]) const {
418     eigen_assert(colOffsets[1] == colOffsets[0] + 1);
419     const Index packetSize = internal::unpacket_traits<Packet>::size;
420 
421     // Packet to load will be split into 2 parts where each part spans a single
422     // column. First determine where to split.
423     const Index patchIdSplit =
424         ((colOffsets[1] * m_colStride) * m_rowInputStride) - 1;
425     const Index patchOffsetSplit = patchIdSplit / m_fastDimZero;
426 
427     // patchIds[i]:          patchId corresponding to partial packet i
428     // spans[i]:             Start and end indices corresponding to the elements
429     //                       to be loaded for partial packet i
430     // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i
431     const Index patchIds[2] = {patchId, patchIdSplit + 1};
432     const Index spans[2][2] = {{0, patchIdSplit - patchId},
433                                {patchIdSplit - patchId + 1, packetSize - 1}};
434     const Index patchOffsets2Cols[2][2] = {
435         {patchOffsets[0], patchOffsetSplit},
436         {patchOffsetSplit + 1, patchOffsets[1]}};
437 
438     // Load partial packets and do bit-wise OR to generate required packet
439     return internal::por<Packet>(
440         loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[0],
441                                   spans[0], patchOffsets2Cols[0],
442                                   colOffsets[0]),
443         loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[1],
444                                   spans[1], patchOffsets2Cols[1],
445                                   colOffsets[1]));
446   }
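  // For example (hypothetical sizes): with patch_depth = 3 and patch_rows = 4
  // each patch column holds 12 elements, so for colOffsets = {1, 2} the split
  // point is patchIdSplit = (2 * 4) * 3 - 1 = 23, the last element that still
  // belongs to column 1; lanes up to that element go into the first partial
  // packet and the remaining lanes into the second.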
447 
448   // Helper function to load a packet that is present in a single column.
449   // If required, this function is called from loadPacketStandard().
450   EIGEN_DEVICE_FUNC
451   EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumn(
452       Index patchId, Index rowIndex, Index colIndex, Index otherIndex,
453       const Index patchOffsets[], const Index colOffsets[],
454       const Index inputCols[]) const {
455     eigen_assert(colOffsets[0] == colOffsets[1]);
456     const Index rowOffsets[2] = {patchOffsets[0] - colOffsets[0] * m_colStride,
457                                  patchOffsets[1] - colOffsets[1] * m_colStride};
458     eigen_assert(rowOffsets[0] <= rowOffsets[1]);
459     const Index inputRows[2] = {rowIndex + rowOffsets[0],
460                                 rowIndex + rowOffsets[1]};
461 
462     if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
463       // all zeros
464       return internal::pset1<Packet>(Scalar(0));
465     }
466 
467     if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
468       // no padding
469       const Index depth = patchId - patchOffsets[0] * patchDepth();
470       const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
471                                inputCols[0] * m_colInputStride + otherIndex;
472       return m_impl.template packet<Unaligned>(inputIndex);
473     }
474     return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
475   }
476 
477   // Load standard packet from a patch specified by the "within patch offset"
478   // (patchId) and the precomputed indices of the first element of the patch.
479   // This function will be called if partial packet loading is not available
480   // for the TensorEvaluator or if the packet type does not support masked
481   // load.
482   template <typename PacketT, typename TensorEvaluatorT>
483   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
484       !TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
485       PacketT>::type
486   loadPacketStandard(Index patchId, Index rowIndex, Index colIndex,
487                      Index otherIndex) const {
488     const Index packetSize = internal::unpacket_traits<Packet>::size;
489     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
490     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
491 
492     eigen_assert(!nonStandardPatches());
493 
494     if ((patchDepth() % packetSize) == 0) {
495       return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
496     }
497 
498     // Offsets and input calculation here are identical to
499     // loadCoeffStandard(...), but repeated twice.
500     const Index patchOffsets[2] = {patchId / m_fastDimZero,
501                                    (patchId + packetSize - 1) / m_fastDimZero};
502     const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
503                                  patchOffsets[1] / m_fastColStride};
504     const Index inputCols[2] = {colIndex + colOffsets[0],
505                                 colIndex + colOffsets[1]};
506 
507     if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
508       // all zeros
509       return internal::pset1<Packet>(Scalar(0));
510     }
511     if (inputCols[0] == inputCols[1]) {
512       return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex,
513                                                 otherIndex, patchOffsets,
514                                                 colOffsets, inputCols);
515     }
516     return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
517   }
518 
519   // Load standard packet from a patch specified by the "within patch offset"
520   // (patchId) and the precomputed indices of the first element of the patch.
521   // This function will be called if partial packet loading is available for
522   // the TensorEvaluator and if the packet type supports masked load.
523   // The only difference between this and the other case is that if the packet
524   // to load is split across two columns, then in this case instead of going to
525   // the slow (element-by-element) load, we load two packets - each containing
526   // elements from one of the columns (rest of the elements of the packets are
527   // zeroes), and then combine these two packets to generate the required
528   // packet. The idea is to enable fast load (if possible) of these 'partial'
529   // packets.
530   template <typename PacketT, typename TensorEvaluatorT>
531   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
532       TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
533       PacketT>::type
534   loadPacketStandard(Index patchId, Index rowIndex, Index colIndex,
535                      Index otherIndex) const {
536     const Index packetSize = internal::unpacket_traits<PacketT>::size;
537     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
538     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
539 
540     eigen_assert(!nonStandardPatches());
541 
542     if ((patchDepth() % packetSize) == 0) {
543       return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
544     }
545 
546     // Offsets and input calculation here are identical to
547     // loadCoeffStandard(...), but repeated twice.
548     const Index patchOffsets[2] = {patchId / m_fastDimZero,
549                                    (patchId + packetSize - 1) / m_fastDimZero};
550     const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
551                                  patchOffsets[1] / m_fastColStride};
552     const Index inputCols[2] = {colIndex + colOffsets[0],
553                                 colIndex + colOffsets[1]};
554 
555     if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
556       // all zeros
557       return internal::pset1<PacketT>(Scalar(0));
558     }
559     if (inputCols[0] == inputCols[1]) {
560       return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex,
561                                                 otherIndex, patchOffsets,
562                                                 colOffsets, inputCols);
563     }
564     if (inputCols[1] == inputCols[0] + 1) {
565       return loadPacketStandardFromTwoColumns(
566           patchId, rowIndex, colIndex, otherIndex, patchOffsets, colOffsets);
567     }
568     return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
569   }
570 
571   EIGEN_DEVICE_FUNC
572   EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex,
573                                             Index colIndex,
574                                             Index otherIndex) const {
575     const Index packetSize = internal::unpacket_traits<Packet>::size;
576     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
577     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
578 
579     eigen_assert(!nonStandardPatches());
580     eigen_assert((patchDepth() % packetSize) == 0);
581     // Find the offset of the element wrt the location of the first element.
582     const Index patchOffset = patchId / m_fastDimZero;
583     eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
584 
585     const Index colOffset = patchOffset / m_fastColStride;
586     const Index rowOffset = patchOffset - colOffset * m_colStride;
587     const Index inputCol = colIndex + colOffset;
588     const Index inputRow = rowIndex + rowOffset;
589     if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols ||
590         inputRow >= m_inputRows) {
591       // all zeros
592       return internal::pset1<Packet>(Scalar(0));
593     }
594     // no padding
595     const Index depth = patchId - patchOffset * patchDepth();
596     const Index inputIndex = depth + inputRow * m_rowInputStride +
597                              inputCol * m_colInputStride + otherIndex;
598     return m_impl.template packet<Unaligned>(inputIndex);
599   }
600 
601   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(
602       Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
603     const int packetSize = internal::unpacket_traits<Packet>::size;
604     EIGEN_ALIGN_MAX
605     std::remove_const_t<Scalar> values[packetSize];
606     for (int i = 0; i < packetSize; ++i) {
607       values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex);
608     }
609     Packet rslt = internal::pload<Packet>(values);
610     return rslt;
611   }
612 
613   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
614       Index patchIndex, Index& rowIndex, Index& colIndex,
615       Index& otherIndex) const {
616     const size_t NumInputDims = array_size<
617         typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
618     otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
619     const Index patch2DIndex = (NumInputDims == 3)
620                                    ? patchIndex
621                                    : (patchIndex - otherIndex * m_num_patches);
622     otherIndex *= m_patchInputStride;
623     colIndex = patch2DIndex / m_fastOutputRows;
624     rowIndex = patch2DIndex - colIndex * m_outputRows;
625     colIndex = colIndex * m_col_strides - m_colPaddingLeft;
626     rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
627   }
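  // Illustrative example (hypothetical sizes): with m_outputRows =
  // m_outputCols = 5 (m_num_patches = 25 per image), user strides of 2 and
  // paddings of 1, patchIndex = 27 gives otherIndex = 27 / 25 = 1 (scaled by
  // m_patchInputStride), patch2DIndex = 2, colIndex = 2 / 5 = 0 and
  // rowIndex = 2, which become colIndex = 0 * 2 - 1 = -1 and
  // rowIndex = 2 * 2 - 1 = 3 after applying strides and padding.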
628 
629   Index m_patch_cols;   // number of columns in the patch
630   Index m_num_patches;  // number of patches to extract.
631 
632   // Strides for navigating through the single patch.
633   Index m_patch_row_stride;
634   Index m_patch_col_stride;
635   internal::TensorIntDivisor<Index> m_fastPatchRowStride;
636   internal::TensorIntDivisor<Index> m_fastPatchColStride;
637 
638   Index m_patch_row_inflate_strides;  // the strides for row inflation in the
639                                       // image patch
640   Index m_patch_col_inflate_strides;  // the strides for col inflation in the
641                                       // image patch
642   // Fast representation of inflation strides.
643   internal::TensorIntDivisor<Index> m_fastInputRowStride;
644   internal::TensorIntDivisor<Index> m_fastInputColStride;
645 
646   Index m_otherStride;
647   Index m_colStride;
648   internal::TensorIntDivisor<Index> m_fastNumPatches;
649   internal::TensorIntDivisor<Index> m_fastColStride;
650 
651   Index m_rowInputStride;    // row stride in the input tensor
652   Index m_colInputStride;    // col stride in the input tensor
653   Index m_patchInputStride;  // patch stride in the input tensor
654 
655   Index m_inputRows;  // Number of rows in the input tensor
656   Index m_inputCols;  // Number of cols in the input tensor
657 
658   Index m_outputRows;  // Number of convolution output rows
659   Index m_outputCols;  // Number of convolution output column
660 
661   Index m_row_strides;  // User specified row stride
662   Index m_col_strides;  // User specified col stride
663 
664   Index m_in_row_strides;  // User specified input row stride
665   Index m_in_col_strides;  // User specified input col stride
666 
667   Index m_rowPaddingTop;   // Row padding
668   Index m_colPaddingLeft;  // Column padding
669 
670   internal::TensorIntDivisor<Index> m_fastOutputRows;
671   internal::TensorIntDivisor<Index> m_fastDimZero;
672 
673   const TensorEvaluator<ArgType, Device> m_impl;
674 };
675 
676 template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
677           typename Device, typename Scalar, typename Index,
678           typename nocontract_t, typename contract_t, int Side, int packet_size,
679           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
680 class TensorContractionSubMapper<
681     Scalar, Index, Side,
682     TensorEvaluator<
683         const TensorReshapingOp<NewDimension,
684                                 const TensorImagePatchOp<Rows, Cols, ArgType> >,
685         Device>,
686     nocontract_t, contract_t, packet_size, inner_dim_contiguous,
687     inner_dim_reordered, Alignment> {
688  public:
689   typedef typename packet_traits<Scalar>::type Packet;
690   typedef typename packet_traits<Scalar>::half HalfPacket;
691 
692   typedef TensorContractionInputMapper<
693       Scalar, Index, Side,
694       TensorEvaluator<
695           const TensorReshapingOp<
696               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
697           Device>,
698       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
699       inner_dim_reordered, Alignment>
700       ParentMapper;
701 
702   typedef TensorContractionSubMapper<
703       Scalar, Index, Side,
704       TensorEvaluator<
705           const TensorReshapingOp<
706               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
707           Device>,
708       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
709       inner_dim_reordered, Alignment>
710       Self;
711 
712   typedef Self LinearMapper;
713 
714   typedef typename ParentMapper::TensorEvaluatorT TensorEvaluatorT;
715 
716   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
717       const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
718       : m_depth_offset(vert_offset),
719         m_col_offset(horiz_offset),
720         m_base_mapper(base_mapper) {
721     m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
722                                      m_otherIndex);
723   }
724   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
725       const Self& base_mapper, Index vert_offset, Index horiz_offset)
726       : m_depth_offset(vert_offset + base_mapper.m_depth_offset),
727         m_col_offset(horiz_offset + base_mapper.m_col_offset),
728         m_base_mapper(base_mapper.m_base_mapper) {
729     m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
730                                      m_otherIndex);
731   }
732   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
733     return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex,
734                                    m_otherIndex);
735   }
736   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
737                                                           Index j) const {
738     return m_base_mapper(i + m_depth_offset, j + m_col_offset);
739   }
740 
741   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
742     return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex,
743                                     m_otherIndex);
744   }
745   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
746                                                           Index j) const {
747     return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
748                                                         j + m_col_offset);
749   }
750   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
751   loadCoeffStandard(Index i) const {
752     return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex,
753                                            m_colIndex, m_otherIndex);
754   }
755 
756   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
757     return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex,
758                                         m_colIndex, m_otherIndex);
759   }
760   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
761   loadPacketStandard(Index i) const {
762     typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT;
763     return m_base_mapper.template loadPacketStandard<Packet, TensorEvaluatorT>(
764         i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
765   }
766   template <typename Packet>
767   EIGEN_DEVICE_FUNC bool aligned(Index) const {
768     return false;
769   }
770 
771   EIGEN_DEVICE_FUNC
772   EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
773     return m_base_mapper.nonStandardPatches();
774   }
775 
776   // Max(Col|Row|Depth): compute the upper limit for the column, row and depth
777   // index respectively that fits into the peeled_k elements starting at
778   // m_depth_offset.
779 
780   EIGEN_DEVICE_FUNC
781   EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
782     const Index max_col =
783         (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1)) /
784         fastPatchColStride();
785     return std::min<Index>(1 + max_col, patchCols());
786   }
787 
788   EIGEN_DEVICE_FUNC
789   EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
790                                    const Index col) const {
791     const Index max_row = (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1) -
792                            col * patchColStride()) /
793                           fastPatchRowStride();
794     return std::min<Index>(1 + max_row, patchRows());
795   }
796 
797   EIGEN_DEVICE_FUNC
798   EIGEN_ALWAYS_INLINE Index maxDepth(const Index peeled_k, const Index col,
799                                      Index row) const {
800     const Index max_depth = m_depth_offset + peeled_k -  //
801                             col * patchColStride() -     //
802                             row * patchRowStride();
803     return std::min<Index>(max_depth, patchDepth());
804   }
805 
806   // MaxDepth uses only the remaining number of elements in the peeled_k.
807   EIGEN_DEVICE_FUNC
808   EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
809                                      const Index start_depth) const {
810     return std::min<Index>(start_depth + num_elements, patchDepth());
811   }
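  // Worked example (hypothetical sizes): with patchDepth() = 8, patchRows() =
  // 3 and patchCols() = 3 (so patchColStride() = 24), m_depth_offset = 0 and
  // peeled_k = 32: maxCol(32) = min(1 + 31 / 24, 3) = 2, maxRow(32, 0) =
  // min(1 + 31 / 8, 3) = 3 and maxDepth(32, 1, 0) = min(32 - 24, 8) = 8, i.e.
  // the 32 peeled elements cover all of patch column 0 plus the first row of
  // column 1.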
812 
813   // Every register matters in this code, so sometimes to prevent register
814   // spilling, instead of the variable that you would expect to see, we use
815   // another one that is guaranteed to have the same value. E.g. patch depth is
816   // always the same as input depth, and it's also the same as input row stride.
817   // A bunch of other parameters have similar relations.
818 
819   typedef internal::TensorIntDivisor<Index> IndexDivisor;
820 
821   EIGEN_DEVICE_FUNC
822   EIGEN_ALWAYS_INLINE Index patchDepth() const {
823     return m_base_mapper.m_rowInputStride;
824   }
825   EIGEN_DEVICE_FUNC
826   EIGEN_ALWAYS_INLINE Index patchRows() const {
827     return m_base_mapper.m_colStride;
828   }
829   EIGEN_DEVICE_FUNC
830   EIGEN_ALWAYS_INLINE Index patchCols() const {
831     return m_base_mapper.m_patch_cols;
832   }
833 
834   EIGEN_DEVICE_FUNC
835   EIGEN_ALWAYS_INLINE Index patchRowStride() const {
836     eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
837                  "Patch depth must be equal to patch row stride.");
838     return patchDepth();
839   }
840   EIGEN_DEVICE_FUNC
841   EIGEN_ALWAYS_INLINE Index patchColStride() const {
842     return m_base_mapper.m_patch_col_stride;
843   }
844 
845   EIGEN_DEVICE_FUNC
846   EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
847     eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
848                  "Patch depth must be equal to patch row stride.");
849     return m_base_mapper.m_fastDimZero;  // patch_depth
850   }
851   EIGEN_DEVICE_FUNC
852   EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
853     return m_base_mapper.m_fastPatchColStride;
854   }
855 
856   EIGEN_DEVICE_FUNC
857   EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
858                                              const Index baseIndex) const {
859     const Index inputIndex = depth + baseIndex;
860     return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
861   }
862   EIGEN_DEVICE_FUNC
863   EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth,
864                                             const Index baseIndex) const {
865     const Index inputIndex = depth + baseIndex;
866     return m_base_mapper.m_impl.coeff(inputIndex);
867   }
868   template <typename PacketT = Packet>
869   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
870       TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
871       PacketT>::type
872   partialPacketNoPadding(const Index depth, const Index baseIndex,
873                          Index num_coeffs) const {
874     const Index inputIndex = depth + baseIndex;
875     return m_base_mapper.m_impl.template partialPacket<PacketT>(
876         inputIndex, mask<PacketT>(0, num_coeffs));
877   }
878   EIGEN_DEVICE_FUNC
879   EIGEN_ALWAYS_INLINE bool hasPadding() const {
880     // TODO(ezhulenev): It does seem that for an inflated filter it's still
881     // possible to guarantee "no padding or skipping" for non-standard packing.
882     if (nonStandardPatches()) return true;
883 
884     // Non-zero padding before.
885     if (m_base_mapper.m_rowPaddingTop > 0) return true;
886     if (m_base_mapper.m_colPaddingLeft > 0) return true;
887 
888     // Non-zero padding after in rows.
889     const Index last_row =
890         (m_base_mapper.m_outputRows - 1) * m_base_mapper.m_row_strides;
891     if (last_row + (patchRows() - 1) >= m_base_mapper.m_inputRows) return true;
892 
893     // Non-zero padding after in cols.
894     const Index last_col =
895         (m_base_mapper.m_outputCols - 1) * m_base_mapper.m_col_strides;
896     if (last_col + (patchCols() - 1) >= m_base_mapper.m_inputCols) return true;
897 
898     return false;
899   }
900   EIGEN_DEVICE_FUNC
901   EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
902     const Index r = m_rowIndex + row;
903     return r < 0 || r >= m_base_mapper.m_inputRows;
904   }
905   EIGEN_DEVICE_FUNC
906   EIGEN_ALWAYS_INLINE bool padAnyRow(const Index first_row,
907                                      const Index last_row) const {
908     return m_rowIndex + first_row < 0 ||
909            m_rowIndex + last_row >= m_base_mapper.m_inputRows;
910   }
911   EIGEN_DEVICE_FUNC
912   EIGEN_ALWAYS_INLINE bool padOrSkipRow(const Index row,
913                                         Index* orig_row) const {
914     eigen_assert(nonStandardPatches());
915 
916     const Index input_row = m_rowIndex + row * m_base_mapper.m_in_row_strides;
917     *orig_row = (m_base_mapper.m_patch_row_inflate_strides == 1)
918                     ? input_row
919                     : ((input_row >= 0)
920                            ? (input_row / m_base_mapper.m_fastInputRowStride)
921                            : 0);
922 
923     return (*orig_row < 0 || *orig_row >= m_base_mapper.m_inputRows) ||
924            (input_row != *orig_row * m_base_mapper.m_patch_row_inflate_strides);
925   }
926   EIGEN_DEVICE_FUNC
927   EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
928     const Index c = m_colIndex + col;
929     return c < 0 || c >= m_base_mapper.m_inputCols;
930   }
931   EIGEN_DEVICE_FUNC
932   EIGEN_ALWAYS_INLINE bool padOrSkipCol(const Index col,
933                                         Index* orig_col) const {
934     eigen_assert(nonStandardPatches());
935 
936     const Index input_col = m_colIndex + col * m_base_mapper.m_in_col_strides;
937     *orig_col = (m_base_mapper.m_patch_col_inflate_strides == 1)
938                     ? input_col
939                     : ((input_col >= 0)
940                            ? (input_col / m_base_mapper.m_fastInputColStride)
941                            : 0);
942 
943     return (*orig_col < 0 || *orig_col >= m_base_mapper.m_inputCols) ||
944            (input_col != *orig_col * m_base_mapper.m_patch_col_inflate_strides);
945   }
946   EIGEN_DEVICE_FUNC
947   EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const {
948     const Index r = m_rowIndex + row;
949     const Index c = m_colIndex + col;
950     return r * m_base_mapper.m_rowInputStride +
951            c * m_base_mapper.m_colInputStride + m_otherIndex;
952   }
953   // Compute a base index when original input row and column were precomputed
954   // using padOrSkipRow and padOrSkipCol. Used only for non-standard patches.
955   EIGEN_DEVICE_FUNC
956   EIGEN_ALWAYS_INLINE Index origBaseIndex(const Index orig_row,
957                                           const Index orig_col) const {
958     return orig_row * m_base_mapper.m_rowInputStride +
959            orig_col * m_base_mapper.m_colInputStride + m_otherIndex;
960   }
961 
962   EIGEN_DEVICE_FUNC
963   EIGEN_ALWAYS_INLINE Index rowStride() const {
964     return m_base_mapper.m_row_strides;
965   }
966   EIGEN_DEVICE_FUNC
967   EIGEN_ALWAYS_INLINE Index colStride() const {
968     return m_base_mapper.m_col_strides;
969   }
970 
971   EIGEN_DEVICE_FUNC
972   EIGEN_ALWAYS_INLINE Index rowOffset() const {
973     const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
974     const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
975     return patchOffset - colOffset * m_base_mapper.m_colStride;
976   }
977 
978   EIGEN_DEVICE_FUNC
979   EIGEN_ALWAYS_INLINE Index colOffset() const {
980     const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
981     const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
982     return colOffset;
983   }
984 
985   EIGEN_DEVICE_FUNC
986   EIGEN_ALWAYS_INLINE Index depthOffset() const {
987     return m_depth_offset % patchDepth();
988   }
989 
990   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
991   getLinearMapper(Index i, Index j) const {
992     return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
993   }
994 
995  private:
996   Index m_depth_offset;  // First row in the input matrix
997   Index m_col_offset;    // First col in the input matrix
998 
999   // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
1000   // indices for the first element in a patch specified by col_offset
1001   // (see computeBaseIndices(...) for details).
1002   Index m_rowIndex;
1003   Index m_colIndex;
1004   Index m_otherIndex;
1005 
1006   const ParentMapper m_base_mapper;  // Keeping a copy instead of a reference
1007                                      // performs better in benchmarks.
1008 };
1009 
1010 // Arrange a block of the right input matrix (in our case it's always a "virtual
1011 // matrix" constructed from extracted image patches) in contiguous memory.
1012 //
1013 // Given column major input (A0 beside A1 in memory):
1014 // A0 B0 C0 D0  E0 F0 G0 H0 ... Z0
1015 // A1 B1 C1 D1  E1 F1 G1 H1 ... Z1
1016 // A2 B2 C2 D2  E2 F2 G2 H2 ... Z2
1017 // A3 B3 C3 D3  E3 F3 G3 H3 ... Z3
1018 // A4 B4 C4 D4  E4 F4 G4 H4 ... Z4
1019 // A5 B5 C5 D5  E5 F5 G5 H5 ... Z5
1020 // A6 B6 C6 D6  E6 F6 G6 H6 ... Z6
1021 // A7 B7 C7 D7  E7 F7 G7 H7 ... Z7
1022 // A8 ...
1023 // ...
1024 //
1025 // *) A, B, C, ... - patches extracted from the original input.
1026 // *) A0, A1, A2 ... - values from the same patch at different offsets.
1027 //
1028 // The traversal (packed rhs memory) order (B0 besides A0 in memory):
1029 // A0 B0 C0 D0 A1 B1 C1 D1 ...
1030 // E0 F0 G0 H0 E1 F1 G1 H1 ...
1031 // ...
1032 // Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
1033 //
1034 // This traversal order must be the same as in default gemm_pack_rhs defined in
1035 // GeneralBlockPanelKernel.h.
1036 //
1037 // *) nr - number of registers along the 'n' dimension.
1038 //    See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
1039 //    Multiplication" paper.
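//
// Illustrative note (assuming packet_size = 4 and a zero depth offset): in the
// fast path below, the first transposed PacketBlock for the first four patches
// writes A0 B0 C0 D0 | A1 B1 C1 D1 | A2 B2 C2 D2 | A3 B3 C3 D3 to the packed
// block, which is exactly the traversal order shown above.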
1040 template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
1041           typename Device, typename Scalar, typename Index,
1042           typename nocontract_t, typename contract_t, int packet_size,
1043           bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
1044           int nr>
1045 struct gemm_pack_rhs<
1046     Scalar, Index,
1047     TensorContractionSubMapper<
1048         Scalar, Index, Rhs,
1049         TensorEvaluator<
1050             const TensorReshapingOp<
1051                 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
1052             Device>,
1053         nocontract_t, contract_t, packet_size, inner_dim_contiguous,
1054         inner_dim_reordered, Alignment>,
1055     nr, ColMajor, false, false> {
1056   typedef TensorContractionSubMapper<
1057       Scalar, Index, Rhs,
1058       TensorEvaluator<
1059           const TensorReshapingOp<
1060               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
1061           Device>,
1062       nocontract_t, contract_t, packet_size, inner_dim_contiguous,
1063       inner_dim_reordered, Alignment>
1064       SubMapper;
1065   typedef SubMapper DataMapper;
1066   typedef typename packet_traits<Scalar>::type Packet;
1067 
1068   EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1069 
1070   EIGEN_DEVICE_FUNC
1071   EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
1072                                     Index depth, Index cols, Index stride = 0,
1073                                     Index offset = 0) const {
1074     eigen_assert(stride == 0);
1075     eigen_assert(offset == 0);
1076 
1077     const Index packet_cols4 = (cols / 4) * 4;
1078     const Index peeled_k = (depth / packet_size) * packet_size;
1079     const bool non_standard_patches = rhs.nonStandardPatches();
1080 
1081     for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
1082       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1083       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1084       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1085       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1086 
1087       Index k = 0;
1088       if ((packet_size % 4) == 0 && !non_standard_patches) {
1089         // FAST PATH:
1090         // Iterate over patch columns and rows if we know that a single
1091         // packet does not span multiple rows or columns.
1092         if ((rhs.patchDepth() % packet_size) == 0) {
1093           const Index start_col = rhs.colOffset();
1094           const Index max_col = rhs.maxCol(peeled_k);
1095 
1096           for (Index c = start_col; c < max_col; ++c) {
1097             eigen_assert(k <= peeled_k);
1098 
1099             const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1100             const Index max_row = rhs.maxRow(peeled_k, c);
1101 
1102             const bool pad_col0 = dm0.padCol(c);
1103             const bool pad_col1 = dm1.padCol(c);
1104             const bool pad_col2 = dm2.padCol(c);
1105             const bool pad_col3 = dm3.padCol(c);
1106 
1107             // Check if we can squeeze reads along the `row` and `depth`
1108             // dimensions (two innermost dimensions).
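            // (Squeezed reads are valid here because, within a patch column,
            // consecutive rows are contiguous in the input: the row input
            // stride equals the patch depth. Checking padding only at
            // start_row and max_row - 1 is therefore sufficient, since border
            // padding cannot occur in the middle of that row range.)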
1109             if (!pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 &&    //
1110                 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) &&  //
1111                 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) &&  //
1112                 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) &&  //
1113                 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) {
1114               // Compute how many elements we can squeeze read.
1115               const Index start_depth =
1116                   (c == start_col) ? rhs.depthOffset() : 0;
1117 
1118               // Upper bound for the number of elements in the depth dimension
1119               // that we can squeeze read.
1120               const Index squeeze_length =
1121                   (max_row - start_row) * rhs.patchDepth() - start_depth;
1122 
1123               // Do not overshoot beyond the block size.
1124               const Index max_depth =
1125                   start_depth + std::min<Index>(peeled_k - k, squeeze_length);
1126               eigen_assert((max_depth - start_depth) % packet_size == 0);
1127 
1128               const Index idx0 = dm0.baseIndex(start_row, c);
1129               const Index idx1 = dm1.baseIndex(start_row, c);
1130               const Index idx2 = dm2.baseIndex(start_row, c);
1131               const Index idx3 = dm3.baseIndex(start_row, c);
1132 
1133               for (Index d = start_depth; d < max_depth; d += packet_size) {
1134                 eigen_assert(k < peeled_k);
1135                 PacketBlock<Packet, 4> kernel;
1136                 kernel.packet[0] = rhs.packetNoPadding(d, idx0);
1137                 kernel.packet[1] = rhs.packetNoPadding(d, idx1);
1138                 kernel.packet[2] = rhs.packetNoPadding(d, idx2);
1139                 kernel.packet[3] = rhs.packetNoPadding(d, idx3);
1140                 ptranspose(kernel);
1141                 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1142                 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1143                 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1144                 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1145                 block += 4 * packet_size;
1146                 k += packet_size;
1147               }
1148 
1149               // Go to the next column.
1150               continue;
1151             }
1152 
1153             // If we can't squeeze reads, process rows one by one.
1154             for (Index r = start_row; r < max_row; ++r) {
1155               eigen_assert(k <= peeled_k);
1156 
1157               const bool pad0 = pad_col0 || dm0.padRow(r);
1158               const bool pad1 = pad_col1 || dm1.padRow(r);
1159               const bool pad2 = pad_col2 || dm2.padRow(r);
1160               const bool pad3 = pad_col3 || dm3.padRow(r);
1161 
1162               const Index idx0 = dm0.baseIndex(r, c);
1163               const Index idx1 = dm1.baseIndex(r, c);
1164               const Index idx2 = dm2.baseIndex(r, c);
1165               const Index idx3 = dm3.baseIndex(r, c);
1166 
1167               const Index start_depth = ((c == start_col) && (r == start_row))
1168                                             ? rhs.depthOffset()
1169                                             : 0;
1170               const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1171               eigen_assert((max_depth - start_depth) % packet_size == 0);
1172 
1173               for (Index d = start_depth; d < max_depth; d += packet_size) {
1174                 eigen_assert(k < peeled_k);
1175                 PacketBlock<Packet, 4> kernel;
1176                 kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
1177                                         : rhs.packetNoPadding(d, idx0);
1178                 kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
1179                                         : rhs.packetNoPadding(d, idx1);
1180                 kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
1181                                         : rhs.packetNoPadding(d, idx2);
1182                 kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
1183                                         : rhs.packetNoPadding(d, idx3);
1184                 ptranspose(kernel);
1185                 pstoreu(block + 0 * packet_size, kernel.packet[0]);
1186                 pstoreu(block + 1 * packet_size, kernel.packet[1]);
1187                 pstoreu(block + 2 * packet_size, kernel.packet[2]);
1188                 pstoreu(block + 3 * packet_size, kernel.packet[3]);
1189                 block += 4 * packet_size;
1190                 k += packet_size;
1191               }
1192             }
1193           }
1194 
1195           // The loop above should fill peeled_k elements.
1196           eigen_assert(peeled_k == k);
1197 
1198         } else {
1199           for (; k < peeled_k; k += packet_size) {
1200             PacketBlock<Packet, 4> kernel;
1201             kernel.packet[0] = dm0.loadPacketStandard(k);
1202             kernel.packet[1] = dm1.loadPacketStandard(k);
1203             kernel.packet[2] = dm2.loadPacketStandard(k);
1204             kernel.packet[3] = dm3.loadPacketStandard(k);
1205             ptranspose(kernel);
1206             pstoreu(block + 0 * packet_size, kernel.packet[0]);
1207             pstoreu(block + 1 * packet_size, kernel.packet[1]);
1208             pstoreu(block + 2 * packet_size, kernel.packet[2]);
1209             pstoreu(block + 3 * packet_size, kernel.packet[3]);
1210             block += 4 * packet_size;
1211           }
1212         }
1213       }
1214 
1215       // Copy the remaining coefficients of the column block after the peeled_k.
1216       if (!rhs.nonStandardPatches()) {
1217         for (; k < depth; k++) {
1218           block[0] = dm0.loadCoeffStandard(k);
1219           block[1] = dm1.loadCoeffStandard(k);
1220           block[2] = dm2.loadCoeffStandard(k);
1221           block[3] = dm3.loadCoeffStandard(k);
1222           block += 4;
1223         }
1224       } else {
1225         for (; k < depth; k++) {
1226           block[0] = dm0(k);
1227           block[1] = dm1(k);
1228           block[2] = dm2(k);
1229           block[3] = dm3(k);
1230           block += 4;
1231         }
1232       }
1233     }
1234 
1235     // Copy the remaining columns one at a time (nr==1).
1236 #if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
1237     // The remaining columns are handled differently on PPC.
1238     for (Index k = 0; k < depth; k++) {
1239       for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1240         *block = rhs(k, j2);
1241         block += 1;
1242       }
1243     }
1244 #else
1245     for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1246       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1247       for (Index k = 0; k < depth; k++) {
1248         *block = dm0(k);
1249         block += 1;
1250       }
1251     }
1252 #endif
1253   }
1254 };
1255 
1256 // Template specialization for packet_size = 2. We must special-case packet
1257 // blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
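// Each group of four columns is packed with two 2x2 transposes: kernel0 holds
// columns 0-1 and kernel1 holds columns 2-3, and their packets are interleaved
// on store so the packed block keeps the same column-interleaved layout as the
// generic specialization above.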
1258 template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
1259           typename Device, typename Scalar, typename Index,
1260           typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
1261           bool inner_dim_reordered, int Alignment, int nr>
1262 struct gemm_pack_rhs<
1263     Scalar, Index,
1264     TensorContractionSubMapper<
1265         Scalar, Index, Rhs,
1266         TensorEvaluator<
1267             const TensorReshapingOp<
1268                 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
1269             Device>,
1270         nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered,
1271         Alignment>,
1272     nr, ColMajor, false, false> {
1273   typedef TensorContractionSubMapper<
1274       Scalar, Index, Rhs,
1275       TensorEvaluator<
1276           const TensorReshapingOp<
1277               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
1278           Device>,
1279       nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered,
1280       Alignment>
1281       SubMapper;
1282   typedef SubMapper DataMapper;
1283   typedef typename packet_traits<Scalar>::type Packet;
1284 
1285   EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1286 
1287   EIGEN_DEVICE_FUNC
1288   EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
1289                                     Index depth, Index cols, Index stride = 0,
1290                                     Index offset = 0) const {
1291     eigen_assert(stride == 0);
1292     eigen_assert(offset == 0);
1293 
1294     const int packet_size = 2;
1295     const Index packet_cols4 = (cols / 4) * 4;
1296     const Index peeled_k = (depth / packet_size) * packet_size;
1297     const bool non_standard_patches = rhs.nonStandardPatches();
1298 
1299     for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
1300       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1301       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1302       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1303       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1304 
1305       Index k = 0;
1306       if (!non_standard_patches) {
1307         // FAST PATH:
1308         // Iterate over patch columns and rows if we know that a single
1309         // packet does not span multiple rows or columns.
1310         if ((rhs.patchDepth() % packet_size) == 0) {
1311           const Index start_col = rhs.colOffset();
1312           const Index max_col = rhs.maxCol(peeled_k);
1313 
1314           for (Index c = start_col; c < max_col; ++c) {
1315             eigen_assert(k <= peeled_k);
1316 
1317             const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
1318             const Index max_row = rhs.maxRow(peeled_k, c);
1319 
1320             const bool pad_col0 = dm0.padCol(c);
1321             const bool pad_col1 = dm1.padCol(c);
1322             const bool pad_col2 = dm2.padCol(c);
1323             const bool pad_col3 = dm3.padCol(c);
1324 
1325             // We can squeeze reads along the `row` and `depth` dimensions if
1326             // the row stride is `1`, which means that the `row` and `depth`
1327             // dimensions are contiguous (the two innermost dimensions).
1328             if (rhs.rowStride() == 1 &&                                //
1329                 !pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 &&    //
1330                 !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) &&  //
1331                 !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) &&  //
1332                 !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) &&  //
1333                 !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) {
1334               // Compute how many elements we can squeeze read.
1335               const Index start_depth =
1336                   (c == start_col) ? rhs.depthOffset() : 0;
1337 
1338               // Upper bound for the number of elements in the depth dimension
1339               // that we can squeeze read.
1340               const Index squeeze_length =
1341                   (max_row - start_row) * rhs.patchDepth() - start_depth;
1342 
1343               // Do not overshoot beyond the block size.
1344               const Index max_depth =
1345                   start_depth + std::min<Index>(peeled_k - k, squeeze_length);
1346               eigen_assert((max_depth - start_depth) % packet_size == 0);
1347 
1348               const Index idx0 = dm0.baseIndex(start_row, c);
1349               const Index idx1 = dm1.baseIndex(start_row, c);
1350               const Index idx2 = dm2.baseIndex(start_row, c);
1351               const Index idx3 = dm3.baseIndex(start_row, c);
1352 
1353               for (Index d = start_depth; d < max_depth; d += packet_size) {
1354                 PacketBlock<Packet, 2> kernel0;
1355                 PacketBlock<Packet, 2> kernel1;
1356                 kernel0.packet[0] = rhs.packetNoPadding(d, idx0);
1357                 kernel0.packet[1] = rhs.packetNoPadding(d, idx1);
1358                 kernel1.packet[0] = rhs.packetNoPadding(d, idx2);
1359                 kernel1.packet[1] = rhs.packetNoPadding(d, idx3);
1360                 ptranspose(kernel0);
1361                 ptranspose(kernel1);
1362                 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1363                 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1364                 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1365                 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1366                 block += 4 * packet_size;
1367                 k += packet_size;
1368               }
1369 
1370               // Go to the next column.
1371               continue;
1372             }
1373 
1374             // If we can't squeeze reads, process rows one by one.
1375             for (Index r = start_row; r < max_row; ++r) {
1376               eigen_assert(k <= peeled_k);
1377 
1378               const bool pad0 = pad_col0 || dm0.padRow(r);
1379               const bool pad1 = pad_col1 || dm1.padRow(r);
1380               const bool pad2 = pad_col2 || dm2.padRow(r);
1381               const bool pad3 = pad_col3 || dm3.padRow(r);
1382 
1383               const Index idx0 = dm0.baseIndex(r, c);
1384               const Index idx1 = dm1.baseIndex(r, c);
1385               const Index idx2 = dm2.baseIndex(r, c);
1386               const Index idx3 = dm3.baseIndex(r, c);
1387 
1388               const Index start_depth = ((c == start_col) && (r == start_row))
1389                                             ? rhs.depthOffset()
1390                                             : 0;
1391               const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
1392               eigen_assert((max_depth - start_depth) % packet_size == 0);
1393 
1394               for (Index d = start_depth; d < max_depth; d += packet_size) {
1395                 eigen_assert(k < peeled_k);
1396                 PacketBlock<Packet, 2> kernel0;
1397                 PacketBlock<Packet, 2> kernel1;
1398                 kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
1399                                          : rhs.packetNoPadding(d, idx0);
1400                 kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
1401                                          : rhs.packetNoPadding(d, idx1);
1402                 kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
1403                                          : rhs.packetNoPadding(d, idx2);
1404                 kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
1405                                          : rhs.packetNoPadding(d, idx3);
1406                 ptranspose(kernel0);
1407                 ptranspose(kernel1);
1408                 pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1409                 pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1410                 pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1411                 pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1412                 block += 4 * packet_size;
1413                 k += packet_size;
1414               }
1415             }
1416           }
1417 
1418           // The loop above should fill peeled_k elements.
1419           eigen_assert(peeled_k == k);
1420 
1421         } else {
1422           // A packet can span multiple rows or columns, so we have to go
1423           // through the slower "standard" path.
1424           for (; k < peeled_k; k += packet_size) {
1425             PacketBlock<Packet, 2> kernel0;
1426             PacketBlock<Packet, 2> kernel1;
1427             kernel0.packet[0] = dm0.loadPacketStandard(k);
1428             kernel0.packet[1] = dm1.loadPacketStandard(k);
1429             kernel1.packet[0] = dm2.loadPacketStandard(k);
1430             kernel1.packet[1] = dm3.loadPacketStandard(k);
1431             ptranspose(kernel0);
1432             ptranspose(kernel1);
1433             pstoreu(block + 0 * packet_size, kernel0.packet[0]);
1434             pstoreu(block + 1 * packet_size, kernel1.packet[0]);
1435             pstoreu(block + 2 * packet_size, kernel0.packet[1]);
1436             pstoreu(block + 3 * packet_size, kernel1.packet[1]);
1437             block += 4 * packet_size;
1438           }
1439         }
1440       }
1441 
1442       // Copy the remaining coefficients of the column block after the peeled_k.
1443       if (!non_standard_patches) {
1444         for (; k < depth; k++) {
1445           block[0] = dm0.loadCoeffStandard(k);
1446           block[1] = dm1.loadCoeffStandard(k);
1447           block[2] = dm2.loadCoeffStandard(k);
1448           block[3] = dm3.loadCoeffStandard(k);
1449           block += 4;
1450         }
1451       } else {
1452         for (; k < depth; k++) {
1453           block[0] = dm0(k);
1454           block[1] = dm1(k);
1455           block[2] = dm2(k);
1456           block[3] = dm3(k);
1457           block += 4;
1458         }
1459       }
1460     }
1461 
1462     // Copy the remaining columns one at a time (nr==1).
1463     for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1464       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1465       for (Index k = 0; k < depth; k++) {
1466         *block = dm0(k);
1467         block += 1;
1468       }
1469     }
1470   }
1471 };
1472 
1473 // Special case for non-vectorized types such as float16.
1474 template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
1475           typename Device, typename Scalar, typename Index,
1476           typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
1477           bool inner_dim_reordered, int Alignment, int nr>
1478 struct gemm_pack_rhs<
1479     Scalar, Index,
1480     TensorContractionSubMapper<
1481         Scalar, Index, Rhs,
1482         TensorEvaluator<
1483             const TensorReshapingOp<
1484                 NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
1485             Device>,
1486         nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
1487         Alignment>,
1488     nr, ColMajor, false, false> {
1489   typedef TensorContractionSubMapper<
1490       Scalar, Index, Rhs,
1491       TensorEvaluator<
1492           const TensorReshapingOp<
1493               NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
1494           Device>,
1495       nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
1496       Alignment>
1497       SubMapper;
1498   typedef SubMapper DataMapper;
1499 
1500   EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)
1501 
1502   EIGEN_DEVICE_FUNC
1503   EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
1504                                     Index depth, Index cols, Index stride = 0,
1505                                     Index offset = 0) const {
1506     eigen_assert(stride == 0);
1507     eigen_assert(offset == 0);
1508 
1509     const Index packet_cols4 = (cols / 4) * 4;
1510 
1511     for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
1512       const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
1513       const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
1514       const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
1515       const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
1516 
1517       if (!rhs.nonStandardPatches()) {
1518         for (Index k = 0; k < depth; k++) {
1519           block[0] = dm0.loadCoeffStandard(k);
1520           block[1] = dm1.loadCoeffStandard(k);
1521           block[2] = dm2.loadCoeffStandard(k);
1522           block[3] = dm3.loadCoeffStandard(k);
1523           block += 4;
1524         }
1525       } else {
1526         for (Index k = 0; k < depth; k++) {
1527           block[0] = dm0(k);
1528           block[1] = dm1(k);
1529           block[2] = dm2(k);
1530           block[3] = dm3(k);
1531           block += 4;
1532         }
1533       }
1534     }
1535 
1536     // Copy the remaining columns one at a time (nr==1).
1537     for (Index j2 = packet_cols4; j2 < cols; ++j2) {
1538       const SubMapper dm0 = rhs.getLinearMapper(0, j2);
1539       for (Index k = 0; k < depth; k++) {
1540         *block = dm0(k);
1541         block += 1;
1542       }
1543     }
1544   }
1545 };
1546 #endif
1547 }  // end namespace internal
1548 
1549 /** SpatialConvolution
1550  * \ingroup CXX11_NeuralNetworks_Module
1551  *
1552  * \brief Applies a 2D convolution over a multichannel input image.
1553  *
1554  * The input parameter is expected to be a tensor with a rank of 3 or more
1555  * (channels, height, width, and optionally others).
1556  * The kernel parameter is expected to be a 4D tensor (filters, channels,
1557  * kernel_height, kernel_width).
1558  * The input and the kernel must both be in col-major layout. The result will
1559  * also be in col-major layout.
1560  *
1561  * If col_in_stride or row_in_stride is greater than 1, the convolution is
1562  * applied with holes (a.k.a. atrous convolution), sampling every
1563  * col_in_stride-th column and row_in_stride-th row of the input.
1564  *
1565  * If padding_top, padding_bottom, padding_left, or padding_right is specified,
1566  * then those paddings will be used to pad the input, and padding_type must be
1567  * PADDING_VALID.
1568  *
1569  * The result can be assigned to a tensor of rank equal to the rank of the
1570  * input. The dimensions of the result will be filters, height, width (and
1571  * others if applicable).
1572  *
1573  * It is possible to swap the order of the width and height dimensions provided
1574  * that the same order is used in the input, the kernel, and the output.
1575  *
1576  * It is also possible to add an output kernel to the contraction; the output
1577  * kernel is called by Eigen when it "finalizes" a block of the output tensor.
1578  *
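 * Example (an illustrative sketch; the tensor sizes and the ColMajor float
 * layout are arbitrary choices, not taken from the library):
 *
 *   Eigen::Tensor<float, 4> input(3, 32, 32, 16);   // channels, rows, cols, batch
 *   Eigen::Tensor<float, 4> filters(8, 3, 3, 3);    // filters, channels, rows, cols
 *   Eigen::Tensor<float, 4> output(8, 32, 32, 16);  // PADDING_SAME, stride 1
 *   output = Eigen::SpatialConvolution(input, filters);
 *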
1579  */
1580 template <typename Input, typename Kernel,
1581           typename OutputKernel = const NoOpOutputKernel>
1582 EIGEN_ALWAYS_INLINE static const std::conditional_t<
1583     internal::traits<Input>::Layout == ColMajor,
1584     TensorReshapingOp<
1585         const DSizes<typename internal::traits<Input>::Index,
1586                      internal::traits<Input>::NumDimensions>,
1587         const TensorContractionOp<
1588             const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
1589             const TensorReshapingOp<
1590                 const DSizes<typename internal::traits<Input>::Index, 2>,
1591                 const Kernel>,
1592             const TensorReshapingOp<
1593                 const DSizes<typename internal::traits<Input>::Index, 2>,
1594                 const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
1595             const OutputKernel> >,
1596     TensorReshapingOp<
1597         const DSizes<typename internal::traits<Input>::Index,
1598                      internal::traits<Input>::NumDimensions>,
1599         const TensorContractionOp<
1600             const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
1601             const TensorReshapingOp<
1602                 const DSizes<typename internal::traits<Input>::Index, 2>,
1603                 const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
1604             const TensorReshapingOp<
1605                 const DSizes<typename internal::traits<Input>::Index, 2>,
1606                 const Kernel>,
1607             const OutputKernel> > >
1608 SpatialConvolution(const Input& input, const Kernel& kernel,
1609                    const Index row_stride = 1, const Index col_stride = 1,
1610                    const PaddingType padding_type = PADDING_SAME,
1611                    const Index row_in_stride = 1, const Index col_in_stride = 1,
1612                    const OutputKernel& output_kernel = OutputKernel(),
1613                    Index padding_top = 0, Index padding_bottom = 0,
1614                    Index padding_left = 0, Index padding_right = 0) {
1615   typedef typename internal::traits<Input>::Index TensorIndex;
1616   typedef typename internal::traits<Input>::Scalar InputScalar;
1617   TensorRef<Tensor<InputScalar, internal::traits<Input>::NumDimensions,
1618                    internal::traits<Input>::Layout, TensorIndex> >
1619       in(input);
1620   TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
1621                    internal::traits<Kernel>::NumDimensions,
1622                    internal::traits<Kernel>::Layout, TensorIndex> >
1623       kern(kernel);
1624 
1625   EIGEN_STATIC_ASSERT(
1626       internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
1627       YOU_MADE_A_PROGRAMMING_MISTAKE)
1628   const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
1629 
1630   const int NumDims = internal::traits<Input>::NumDimensions;
1631 
1632   // Number of filters to apply. This is the same as the output depth of the
1633   // result.
1634   const TensorIndex kernelFilters =
1635       isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
1636   // Number of channels. This is the same as the input depth.
1637   const TensorIndex kernelChannels =
1638       isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
1639   const TensorIndex kernelRows =
1640       isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
1641   const TensorIndex kernelCols =
1642       isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];
1643 
1644   const Index kernelRowsEff =
1645       kernelRows + (kernelRows - 1) * (row_in_stride - 1);
1646   const Index kernelColsEff =
1647       kernelCols + (kernelCols - 1) * (col_in_stride - 1);
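  // For example (illustrative numbers): a 3x3 kernel with
  // row_in_stride == col_in_stride == 2 has an effective extent of
  // 3 + (3 - 1) * (2 - 1) = 5, i.e. it covers a 5x5 input window.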
1648 
1649   array<IndexPair<TensorIndex>, 1> contract_dims;
1650   contract_dims[0] = IndexPair<TensorIndex>(1, 0);
1651 
1652   const TensorIndex InputRows =
1653       isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
1654   const TensorIndex InputCols =
1655       isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
1656   const bool padding_explicit =
1657       (padding_top || padding_bottom || padding_left || padding_right);
1658 
1659   TensorIndex out_height;
1660   TensorIndex out_width;
1661   switch (padding_type) {
1662     case PADDING_VALID: {
1663       const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom;
1664       const TensorIndex InputColsEff = InputCols + padding_left + padding_right;
1665       out_height = divup(InputRowsEff - kernelRowsEff + 1, row_stride);
1666       out_width = divup(InputColsEff - kernelColsEff + 1, col_stride);
1667       break;
1668     }
1669     case PADDING_SAME: {
1670       eigen_assert(!padding_explicit);
1671       out_height = divup(InputRows, row_stride);
1672       out_width = divup(InputCols, col_stride);
1673       break;
1674     }
1675     default: {
1676       // Initialize unused variables to avoid a compiler warning
1677       out_height = 0;
1678       out_width = 0;
1679       eigen_assert(false && "unexpected padding");
1680     }
1681   }
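  // For example (illustrative numbers): with InputRows == 32, the 5x5 effective
  // kernel above, row_stride == 1 and no explicit padding, PADDING_VALID gives
  // out_height = divup(32 - 5 + 1, 1) = 28, while PADDING_SAME gives
  // out_height = divup(32, 1) = 32.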
1682 
1683   // Molds the output of the patch extraction code into a 2D tensor:
1684   // - the first dimension (dims[0]): the patch values to be multiplied with the
1685   // kernels
1686   // - the second dimension (dims[1]): everything else
1687   DSizes<TensorIndex, 2> pre_contract_dims;
1688   if (isColMajor) {
1689     pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
1690     pre_contract_dims[1] = out_height * out_width;
1691     for (int i = 3; i < NumDims; ++i) {
1692       pre_contract_dims[1] *= in.dimension(i);
1693     }
1694   } else {
1695     pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
1696     pre_contract_dims[0] = out_height * out_width;
1697     for (int i = 0; i < NumDims - 3; ++i) {
1698       pre_contract_dims[0] *= in.dimension(i);
1699     }
1700   }
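  // For example (illustrative numbers, ColMajor): with kernelChannels == 8,
  // kernelRows == kernelCols == 3, out_height == out_width == 28 and a batch
  // of 16, pre_contract_dims is {8 * 3 * 3, 28 * 28 * 16} == {72, 12544}.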
1701 
1702   // Molds the output of the contraction into the shape expected by the user
1703   // (assuming this is ColMajor):
1704   // - 1st dim: kernel filters
1705   // - 2nd dim: output height
1706   // - 3rd dim: output width
1707   // - 4th dim and beyond: everything else including batch size
1708   DSizes<TensorIndex, NumDims> post_contract_dims;
1709   if (isColMajor) {
1710     post_contract_dims[0] = kernelFilters;
1711     post_contract_dims[1] = out_height;
1712     post_contract_dims[2] = out_width;
1713     for (int i = 3; i < NumDims; ++i) {
1714       post_contract_dims[i] = in.dimension(i);
1715     }
1716   } else {
1717     post_contract_dims[NumDims - 1] = kernelFilters;
1718     post_contract_dims[NumDims - 2] = out_height;
1719     post_contract_dims[NumDims - 3] = out_width;
1720     for (int i = 0; i < NumDims - 3; ++i) {
1721       post_contract_dims[i] = in.dimension(i);
1722     }
1723   }
1724 
1725   DSizes<TensorIndex, 2> kernel_dims;
1726   if (isColMajor) {
1727     kernel_dims[0] = kernelFilters;
1728     kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
1729   } else {
1730     kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
1731     kernel_dims[1] = kernelFilters;
1732   }
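  // Continuing the illustrative ColMajor example: with 16 filters the
  // contraction multiplies the {16, 72} reshaped kernel by the {72, 12544}
  // reshaped patches, producing a {16, 12544} result that is then reshaped to
  // post_contract_dims.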
1733   if (padding_explicit) {
1734     return choose(
1735         Cond<internal::traits<Input>::Layout == ColMajor>(),
1736         kernel.reshape(kernel_dims)
1737             .contract(input
1738                           .extract_image_patches(
1739                               kernelRows, kernelCols, row_stride, col_stride,
1740                               row_in_stride, col_in_stride,
1741                               /*row_inflate_stride=*/1,
1742                               /*col_inflate_stride=*/1, padding_top,
1743                               padding_bottom, padding_left, padding_right,
1744                               /*padding_value=*/static_cast<InputScalar>(0))
1745                           .reshape(pre_contract_dims),
1746                       contract_dims, output_kernel)
1747             .reshape(post_contract_dims),
1748         input
1749             .extract_image_patches(
1750                 kernelRows, kernelCols, row_stride, col_stride, row_in_stride,
1751                 col_in_stride,
1752                 /*row_inflate_stride=*/1,
1753                 /*col_inflate_stride=*/1, padding_top, padding_bottom,
1754                 padding_left, padding_right,
1755                 /*padding_value=*/static_cast<InputScalar>(0))
1756             .reshape(pre_contract_dims)
1757             .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
1758             .reshape(post_contract_dims));
1759   } else {
1760     return choose(
1761         Cond<internal::traits<Input>::Layout == ColMajor>(),
1762         kernel.reshape(kernel_dims)
1763             .contract(input
1764                           .extract_image_patches(
1765                               kernelRows, kernelCols, row_stride, col_stride,
1766                               row_in_stride, col_in_stride, padding_type)
1767                           .reshape(pre_contract_dims),
1768                       contract_dims, output_kernel)
1769             .reshape(post_contract_dims),
1770         input
1771             .extract_image_patches(kernelRows, kernelCols, row_stride,
1772                                    col_stride, row_in_stride, col_in_stride,
1773                                    padding_type)
1774             .reshape(pre_contract_dims)
1775             .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
1776             .reshape(post_contract_dims));
1777   }
1778 }
1779 
1780 }  // end namespace Eigen
1781 
1782 #endif  // TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_
1783