/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_

#include "tensorflow/core/kernels/eigen_convolution_helpers.h"

// Note this header is used in both TF and TFLite.
namespace Eigen {

namespace internal {

#if !EIGEN_ALTIVEC_USE_CUSTOM_PACK
// WARNING: Most of the code here implicitly assumes that the matrix is in
// ColMajor layout. This is guaranteed by the tensor contraction (see
// TensorContraction.h).
//
// Inside Eigen a tensor contraction is represented by a matrix multiplication.
// We don't want to actually extract image patches and reshape the result into
// a matrix (this involves allocating huge extra memory), so the patch
// extraction and reshape operations are implicit.
//
// TensorContractionInputMapper takes a matrix index and returns the coefficient
// (or the packet) of the "virtual tensor" that would be at that index if we
// were to actually reshape the result of patch extraction.
//
// TensorContractionSubMapper provides a similar view into the "virtual matrix"
// at the given vertical and horizontal offsets.
//
// "Virtual matrix" dimensions:
//   *0: kernelChannels * kernelRows * kernelCols;
//    1: out_height * out_width; * OTHERS (e.g. batches, etc...)
//
// *) extracted patches are contiguous in memory (innermost dimension assuming
//    col major layout)
//
// With these dimensions:
//   row - offset within a single patch (in code: patchId)
//   col - index of the extracted patch (in code: patchIndex)
//         patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
//
// TODO(ezhulenev): Consolidate this part of the code with the image patch
// extraction code since they are both very similar.
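//
// Purely as an illustration (hypothetical sizes, not taken from the code
// below): with patch_depth = 3, patch_rows = 4, patch_cols = 5 the "virtual
// matrix" has 3 * 4 * 5 = 60 rows, and a row index (patchId) decomposes as
//
//   patchOffset = patchId / patch_depth;                 // element of the patch
//   colOffset   = patchOffset / patch_rows;              // patch column
//   rowOffset   = patchOffset - colOffset * patch_rows;  // patch row
//   depth       = patchId - patchOffset * patch_depth;   // channel
//
// which mirrors the arithmetic performed in loadCoeff() and loadPacket()
// further down in this file.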

template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar_, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper<
    Scalar_, Index, Side,
    TensorEvaluator<
        const TensorReshapingOp<NewDimension,
                                const TensorImagePatchOp<Rows, Cols, ArgType> >,
        Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    inner_dim_reordered, Alignment> {
 public:
  typedef Scalar_ Scalar;

  typedef TensorContractionInputMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      Self;

  typedef TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;

  typedef SubMapper VectorMapper;
  typedef SubMapper LinearMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  typedef TensorEvaluator<ArgType, Device> TensorEvaluatorT;

  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(
      const TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>& tensor,
      const nocontract_t&, const nocontract_t&, const contract_t&,
      const contract_t&)
      : m_impl(tensor.impl().impl()) {
    Index patch_rows;
    Index patch_depth;
    if (internal::traits<ArgType>::Layout == ColMajor) {
      patch_depth = tensor.impl().dimensions()[0];
      patch_rows = tensor.impl().dimensions()[1];
      m_patch_cols = tensor.impl().dimensions()[2];
      m_num_patches = tensor.impl().dimensions()[3];
    } else {
      const size_t NumDims = tensor.impl().dimensions().size();
      patch_depth = tensor.impl().dimensions()[NumDims - 1];
      patch_rows = tensor.impl().dimensions()[NumDims - 2];
      m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
      m_num_patches = tensor.impl().dimensions()[NumDims - 4];
    }

    // Strides for navigating through the single patch.
    m_patch_row_stride = patch_depth;
    m_patch_col_stride = patch_rows * m_patch_row_stride;

    m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
    m_patch_col_inflate_strides = tensor.impl().colInflateStride();

    m_colStride = patch_rows;

    m_outputRows = tensor.impl().outputRows();
    m_outputCols = tensor.impl().outputCols();
    m_row_strides = tensor.impl().userRowStride();
    m_col_strides = tensor.impl().userColStride();

    m_in_row_strides = tensor.impl().userInRowStride();
    m_in_col_strides = tensor.impl().userInColStride();

    if (internal::traits<ArgType>::Layout == ColMajor) {
      m_inputRows = tensor.impl().impl().dimensions()[1];
      m_inputCols = tensor.impl().impl().dimensions()[2];
    } else {
      const int NumDims = tensor.impl().impl().dimensions().size();
      m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
      m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];
    }

    m_rowInputStride = patch_depth;
    m_colInputStride = patch_depth * m_inputRows;
    m_patchInputStride = patch_depth * m_inputRows * m_inputCols;

    m_rowPaddingTop = tensor.impl().rowPaddingTop();
    m_colPaddingLeft = tensor.impl().colPaddingLeft();

    m_fastPatchRowStride =
        internal::TensorIntDivisor<Index>(m_patch_row_stride);
    m_fastPatchColStride =
        internal::TensorIntDivisor<Index>(m_patch_col_stride);
    m_fastInputRowStride =
        internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
    m_fastInputColStride =
        internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
    m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
    m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
    m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
    m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth);
  }

  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
      : m_impl(base_mapper.m_impl) {
    m_patch_cols = base_mapper.m_patch_cols;
    m_num_patches = base_mapper.m_num_patches;

    m_patch_row_stride = base_mapper.m_patch_row_stride;
    m_patch_col_stride = base_mapper.m_patch_col_stride;

    m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
    m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;

    m_colStride = base_mapper.m_colStride;

    m_rowInputStride = base_mapper.m_rowInputStride;
    m_colInputStride = base_mapper.m_colInputStride;
    m_patchInputStride = base_mapper.m_patchInputStride;

    m_inputRows = base_mapper.m_inputRows;
    m_inputCols = base_mapper.m_inputCols;

    m_outputRows = base_mapper.m_outputRows;
    m_outputCols = base_mapper.m_outputCols;
    m_row_strides = base_mapper.m_row_strides;
    m_col_strides = base_mapper.m_col_strides;

    m_in_row_strides = base_mapper.m_in_row_strides;
    m_in_col_strides = base_mapper.m_in_col_strides;

    m_rowPaddingTop = base_mapper.m_rowPaddingTop;
    m_colPaddingLeft = base_mapper.m_colPaddingLeft;

    m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
    m_fastPatchColStride = base_mapper.m_fastPatchColStride;
    m_fastInputRowStride = base_mapper.m_fastInputRowStride;
    m_fastInputColStride = base_mapper.m_fastInputColStride;
    m_fastNumPatches = base_mapper.m_fastNumPatches;
    m_fastColStride = base_mapper.m_fastColStride;
    m_fastOutputRows = base_mapper.m_fastOutputRows;
    m_fastDimZero = base_mapper.m_fastDimZero;
  }

  // If true, turns off some optimizations for loading packets since the image
  // patches are "non-standard", i.e. there are non-trivial strides or
  // inflations in the input.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
    return m_in_row_strides != 1 || m_in_col_strides != 1 ||
           m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    return SubMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    return LinearMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(0, rowIndex, colIndex, otherIndex);
    return loadCoeff(row, rowIndex, colIndex, otherIndex);
  }

  // Load the coefficient at the patchIndex location instead of the usual
  // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
  // gpu code.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
    return loadCoeff(row, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(0, rowIndex, colIndex, otherIndex);
    return loadPacket(row, rowIndex, colIndex, otherIndex);
  }

  // Load the packet at the patchIndex location instead of the usual m_rowIndex,
  // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
    Index rowIndex, colIndex, otherIndex;
    computeBaseIndices(patchIndex, rowIndex, colIndex, otherIndex);
    return loadPacket(row, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
    return m_impl;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_rowInputStride; }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchRows() const { return m_colStride; }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }

 private:
  friend class TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>;

  // Load coefficient from a patch specified by the "within patch offset"
  // (patchId) and the precomputed indices of the first element of the patch.
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex,
                                       Index colIndex, Index otherIndex) const {
    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;

    const Index colOffset = patchOffset / m_fastColStride;
    const Index inputCol = colIndex + colOffset * m_in_col_strides;
    const Index origInputCol =
        (m_patch_col_inflate_strides == 1)
            ? inputCol
            : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);

    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
    const Index origInputRow =
        (m_patch_row_inflate_strides == 1)
            ? inputRow
            : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
    if (origInputCol < 0 || origInputRow < 0 || origInputCol >= m_inputCols ||
        origInputRow >= m_inputRows ||
        (inputCol != origInputCol * m_patch_col_inflate_strides) ||
        (inputRow != origInputRow * m_patch_row_inflate_strides)) {
      return Scalar(0);
    }
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + origInputRow * m_rowInputStride +
                             origInputCol * m_colInputStride + otherIndex;
    return m_impl.coeff(inputIndex);
  }

  // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
  // and `in_strides` equal to 1 (template specialization without templates).
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex,
                                               Index colIndex,
                                               Index otherIndex) const {
    eigen_assert(!nonStandardPatches());

    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;
    const Index colOffset = patchOffset / m_fastColStride;
    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputCol = colIndex + colOffset;
    const Index inputRow = rowIndex + rowOffset;
    if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
        inputRow >= m_inputRows) {
      return Scalar(0);
    }
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + inputRow * m_rowInputStride +
                             inputCol * m_colInputStride + otherIndex;
    return m_impl.coeff(inputIndex);
  }

  // Load packet from a patch specified by the "within patch offset"
  // (patchId) and the precomputed indices of the first element of the patch.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex,
                                        Index colIndex,
                                        Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    if (nonStandardPatches()) {
      return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
    }
    typedef decltype(m_impl) TensorEvaluatorT;
    return loadPacketStandard<Packet, TensorEvaluatorT>(patchId, rowIndex,
                                                        colIndex, otherIndex);
  }

  // Helper function to load a 'partial' packet - this is the single column
  // part of a packet that is split across two columns.
  // In the 'partial' packet, the elements corresponding to the column
  // (specified through colOffset) are loaded and the rest of the elements are
  // zero-filled into the 'partial' packet. This function is called from
  // loadPacketStandardFromTwoColumns(). This code path is exercised only when
  // the packet type supports masked load and when the partial packet load is
  // available in the TensorEvaluator.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard(
      Index rowIndex, Index colIndex, Index otherIndex, Index patchId,
      const Index span[], const Index patchOffsets[], Index colOffset) const {
    const Index inputCol = colIndex + colOffset;
    const Index rowOffsets[2] = {patchOffsets[0] - colOffset * m_colStride,
                                 patchOffsets[1] - colOffset * m_colStride};
    const Index inputRows[2] = {rowIndex + rowOffsets[0],
                                rowIndex + rowOffsets[1]};

    if (inputRows[0] >= m_inputRows || inputRows[1] < 0 ||
        inputCol >= m_inputCols || inputCol < 0) {
      // Partial packet is all zeros.
      return internal::pset1<Packet>(Scalar(0));
    } else if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
      // From inputIndex-span[0], we need to load elements starting from index
      // span[0] all the way up to (and including) span[1].
      const Index depth = patchId - patchOffsets[0] * patchDepth();
      const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
                               inputCol * m_colInputStride + otherIndex;
      return m_impl.template partialPacket<Packet>(
          inputIndex - span[0], mask<Packet>(span[0], span[1] + 1));
    } else {
      // Using slow path for this partial packet.
      // We need to load elements starting from index span[0] all the way up to
      // (and including) span[1]. We split this load into 3 parts:
      //   0 : span[0]-1            - Zeros will be loaded for these indices
      //   span[0] : span[1]        - Elements will be loaded for these indices
      //   span[1]+1 : packetSize-1 - Zeros will be loaded for these indices
      const Index packetSize = internal::unpacket_traits<Packet>::size;
      EIGEN_ALIGN_MAX
      std::remove_const_t<Scalar> values[packetSize];
      for (int i = 0; i < span[0]; ++i) values[i] = Scalar(0);
      for (int i = span[0]; i < span[1] + 1; ++i)
        values[i] =
            loadCoeff(patchId - span[0] + i, rowIndex, colIndex, otherIndex);
      for (int i = span[1] + 1; i < packetSize; ++i) values[i] = Scalar(0);
      return internal::pload<Packet>(values);
    }
  }

  // Helper function to load a packet that is split across two columns.
  // If required, this function is called from loadPacketStandard() when the
  // packet type supports masked load and when the partial packet load is
  // available in the TensorEvaluator.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromTwoColumns(
      Index patchId, Index rowIndex, Index colIndex, Index otherIndex,
      const Index patchOffsets[], const Index colOffsets[]) const {
    eigen_assert(colOffsets[1] == colOffsets[0] + 1);
    const Index packetSize = internal::unpacket_traits<Packet>::size;

    // Packet to load will be split into 2 parts where each part spans a single
    // column. First determine where to split.
    const Index patchIdSplit =
        ((colOffsets[1] * m_colStride) * m_rowInputStride) - 1;
    const Index patchOffsetSplit = patchIdSplit / m_fastDimZero;

    // patchIds[i]:          patchId corresponding to partial packet i
    // spans[i]:             start and end indices of the elements to be loaded
    //                       for partial packet i
    // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i
    const Index patchIds[2] = {patchId, patchIdSplit + 1};
    const Index spans[2][2] = {{0, patchIdSplit - patchId},
                               {patchIdSplit - patchId + 1, packetSize - 1}};
    const Index patchOffsets2Cols[2][2] = {
        {patchOffsets[0], patchOffsetSplit},
        {patchOffsetSplit + 1, patchOffsets[1]}};

    // Load partial packets and do bit-wise OR to generate the required packet.
    return internal::por<Packet>(
        loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[0],
                                  spans[0], patchOffsets2Cols[0],
                                  colOffsets[0]),
        loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[1],
                                  spans[1], patchOffsets2Cols[1],
                                  colOffsets[1]));
  }

  // Helper function to load a packet that is present in a single column.
  // If required, this function is called from loadPacketStandard().
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumn(
      Index patchId, Index rowIndex, Index colIndex, Index otherIndex,
      const Index patchOffsets[], const Index colOffsets[],
      const Index inputCols[]) const {
    eigen_assert(colOffsets[0] == colOffsets[1]);
    const Index rowOffsets[2] = {patchOffsets[0] - colOffsets[0] * m_colStride,
                                 patchOffsets[1] - colOffsets[1] * m_colStride};
    eigen_assert(rowOffsets[0] <= rowOffsets[1]);
    const Index inputRows[2] = {rowIndex + rowOffsets[0],
                                rowIndex + rowOffsets[1]};

    if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
      // All zeros.
      return internal::pset1<Packet>(Scalar(0));
    }

    if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
      // No padding.
      const Index depth = patchId - patchOffsets[0] * patchDepth();
      const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
                               inputCols[0] * m_colInputStride + otherIndex;
      return m_impl.template packet<Unaligned>(inputIndex);
    }
    return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
  }

  // Load standard packet from a patch specified by the "within patch offset"
  // (patchId) and the precomputed indices of the first element of the patch.
  // This function will be called if partial packet loading is not available
  // for the TensorEvaluator or if the packet type does not support masked
  // load.
  template <typename PacketT, typename TensorEvaluatorT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
      !TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
      PacketT>::type
  loadPacketStandard(Index patchId, Index rowIndex, Index colIndex,
                     Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    eigen_assert(!nonStandardPatches());

    if ((patchDepth() % packetSize) == 0) {
      return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
    }

    // Offsets and input calculation here are identical to
    // loadCoeffStandard(...), but repeated twice.
    const Index patchOffsets[2] = {patchId / m_fastDimZero,
                                   (patchId + packetSize - 1) / m_fastDimZero};
    const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
                                 patchOffsets[1] / m_fastColStride};
    const Index inputCols[2] = {colIndex + colOffsets[0],
                                colIndex + colOffsets[1]};

    if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
      // All zeros.
      return internal::pset1<Packet>(Scalar(0));
    }
    if (inputCols[0] == inputCols[1]) {
      return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex,
                                                otherIndex, patchOffsets,
                                                colOffsets, inputCols);
    }
    return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
  }

  // Load standard packet from a patch specified by the "within patch offset"
  // (patchId) and the precomputed indices of the first element of the patch.
  // This function will be called if partial packet loading is available for
  // the TensorEvaluator and if the packet type supports masked load.
  // The only difference between this and the other case is that if the packet
  // to load is split across two columns, then in this case instead of going to
  // the slow (element-by-element) load, we load two packets - each containing
  // elements from one of the columns (the rest of the elements of the packets
  // are zeroes), and then combine these two packets to generate the required
  // packet. The idea is to enable fast load (if possible) of these 'partial'
  // packets.
  template <typename PacketT, typename TensorEvaluatorT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
      TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
      PacketT>::type
  loadPacketStandard(Index patchId, Index rowIndex, Index colIndex,
                     Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<PacketT>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    eigen_assert(!nonStandardPatches());

    if ((patchDepth() % packetSize) == 0) {
      return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
    }

    // Offsets and input calculation here are identical to
    // loadCoeffStandard(...), but repeated twice.
    const Index patchOffsets[2] = {patchId / m_fastDimZero,
                                   (patchId + packetSize - 1) / m_fastDimZero};
    const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
                                 patchOffsets[1] / m_fastColStride};
    const Index inputCols[2] = {colIndex + colOffsets[0],
                                colIndex + colOffsets[1]};

    if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
      // All zeros.
      return internal::pset1<PacketT>(Scalar(0));
    }
    if (inputCols[0] == inputCols[1]) {
      return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex,
                                                otherIndex, patchOffsets,
                                                colOffsets, inputCols);
    }
    if (inputCols[1] == inputCols[0] + 1) {
      return loadPacketStandardFromTwoColumns(
          patchId, rowIndex, colIndex, otherIndex, patchOffsets, colOffsets);
    }
    return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index rowIndex,
                                            Index colIndex,
                                            Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);

    eigen_assert(!nonStandardPatches());
    eigen_assert((patchDepth() % packetSize) == 0);
    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;
    eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);

    const Index colOffset = patchOffset / m_fastColStride;
    const Index rowOffset = patchOffset - colOffset * m_colStride;
    const Index inputCol = colIndex + colOffset;
    const Index inputRow = rowIndex + rowOffset;
    if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols ||
        inputRow >= m_inputRows) {
      // All zeros.
      return internal::pset1<Packet>(Scalar(0));
    }
    // No padding.
    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + inputRow * m_rowInputStride +
                             inputCol * m_colInputStride + otherIndex;
    return m_impl.template packet<Unaligned>(inputIndex);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(
      Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const {
    const int packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_ALIGN_MAX
    std::remove_const_t<Scalar> values[packetSize];
    for (int i = 0; i < packetSize; ++i) {
      values[i] = loadCoeff(patchId + i, rowIndex, colIndex, otherIndex);
    }
    Packet rslt = internal::pload<Packet>(values);
    return rslt;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
      Index patchIndex, Index& rowIndex, Index& colIndex,
      Index& otherIndex) const {
    const size_t NumInputDims = array_size<
        typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
    otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
    const Index patch2DIndex = (NumInputDims == 3)
                                   ? patchIndex
                                   : (patchIndex - otherIndex * m_num_patches);
    otherIndex *= m_patchInputStride;
    colIndex = patch2DIndex / m_fastOutputRows;
    rowIndex = patch2DIndex - colIndex * m_outputRows;
    colIndex = colIndex * m_col_strides - m_colPaddingLeft;
    rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
  }

  Index m_patch_cols;   // number of columns in the patch
  Index m_num_patches;  // number of patches to extract

  // Strides for navigating through the single patch.
  Index m_patch_row_stride;
  Index m_patch_col_stride;
  internal::TensorIntDivisor<Index> m_fastPatchRowStride;
  internal::TensorIntDivisor<Index> m_fastPatchColStride;

  Index m_patch_row_inflate_strides;  // the strides for row inflation in the
                                      // image patch
  Index m_patch_col_inflate_strides;  // the strides for col inflation in the
                                      // image patch
  // Fast representation of inflation strides.
  internal::TensorIntDivisor<Index> m_fastInputRowStride;
  internal::TensorIntDivisor<Index> m_fastInputColStride;

  Index m_otherStride;
  Index m_colStride;
  internal::TensorIntDivisor<Index> m_fastNumPatches;
  internal::TensorIntDivisor<Index> m_fastColStride;

  Index m_rowInputStride;    // row stride in the input tensor
  Index m_colInputStride;    // col stride in the input tensor
  Index m_patchInputStride;  // patch stride in the input tensor

  Index m_inputRows;  // Number of rows in the input tensor
  Index m_inputCols;  // Number of cols in the input tensor

  Index m_outputRows;  // Number of convolution output rows
  Index m_outputCols;  // Number of convolution output columns

  Index m_row_strides;  // User specified row stride
  Index m_col_strides;  // User specified col stride

  Index m_in_row_strides;  // User specified input row stride
  Index m_in_col_strides;  // User specified input col stride

  Index m_rowPaddingTop;   // Row padding
  Index m_colPaddingLeft;  // Column padding

  internal::TensorIntDivisor<Index> m_fastOutputRows;
  internal::TensorIntDivisor<Index> m_fastDimZero;

  const TensorEvaluator<ArgType, Device> m_impl;
};

template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper<
    Scalar, Index, Side,
    TensorEvaluator<
        const TensorReshapingOp<NewDimension,
                                const TensorImagePatchOp<Rows, Cols, ArgType> >,
        Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    inner_dim_reordered, Alignment> {
 public:
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename packet_traits<Scalar>::half HalfPacket;

  typedef TensorContractionInputMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      ParentMapper;

  typedef TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      Self;

  typedef Self LinearMapper;
  typedef typename ParentMapper::TensorEvaluatorT TensorEvaluatorT;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
      const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
      : m_depth_offset(vert_offset),
        m_col_offset(horiz_offset),
        m_base_mapper(base_mapper) {
    m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
                                     m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
      const Self& base_mapper, Index vert_offset, Index horiz_offset)
      : m_depth_offset(vert_offset + base_mapper.m_depth_offset),
        m_col_offset(horiz_offset + base_mapper.m_col_offset),
        m_base_mapper(base_mapper.m_base_mapper) {
    m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
                                     m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
    return m_base_mapper.loadCoeff(i + m_depth_offset, m_rowIndex, m_colIndex,
                                   m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
                                                          Index j) const {
    return m_base_mapper(i + m_depth_offset, j + m_col_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
    return m_base_mapper.loadPacket(i + m_depth_offset, m_rowIndex, m_colIndex,
                                    m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
                                                          Index j) const {
    return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
                                                        j + m_col_offset);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
  loadCoeffStandard(Index i) const {
    return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex,
                                           m_colIndex, m_otherIndex);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
    return m_base_mapper.loadPacketFast(i + m_depth_offset, m_rowIndex,
                                        m_colIndex, m_otherIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
  loadPacketStandard(Index i) const {
    typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT;
    return m_base_mapper.template loadPacketStandard<Packet, TensorEvaluatorT>(
        i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
  }
  template <typename Packet>
  EIGEN_DEVICE_FUNC bool aligned(Index) const {
    return false;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
    return m_base_mapper.nonStandardPatches();
  }

  // Max(Col|Row|Depth): compute the upper limit for the column, row and depth
  // index respectively that fits into the peeled_k elements starting at
  // m_depth_offset.

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
    const Index max_col =
        (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1)) /
        fastPatchColStride();
    return std::min<Index>(1 + max_col, patchCols());
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
                                   const Index col) const {
    const Index max_row = (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1) -
                           col * patchColStride()) /
                          fastPatchRowStride();
    return std::min<Index>(1 + max_row, patchRows());
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxDepth(const Index peeled_k, const Index col,
                                     Index row) const {
    const Index max_depth = m_depth_offset + peeled_k -  //
                            col * patchColStride() -     //
                            row * patchRowStride();
    return std::min<Index>(max_depth, patchDepth());
  }

  // MaxDepth uses only the remaining number of elements in the peeled_k.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
                                     const Index start_depth) const {
    return std::min<Index>(start_depth + num_elements, patchDepth());
  }

  // Every register matters in this code, so sometimes to prevent register
  // spilling, instead of the variable that you would expect to see, we use
  // another one that is guaranteed to have the same value. E.g. patch depth is
  // always the same as input depth, and it's also the same as input row stride.
  // A bunch of other parameters have similar relations.

  typedef internal::TensorIntDivisor<Index> IndexDivisor;

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchDepth() const {
    return m_base_mapper.m_rowInputStride;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchRows() const {
    return m_base_mapper.m_colStride;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchCols() const {
    return m_base_mapper.m_patch_cols;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchRowStride() const {
    eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
                 "Patch depth must be equal to patch row stride.");
    return patchDepth();
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index patchColStride() const {
    return m_base_mapper.m_patch_col_stride;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
    eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
                 "Patch depth must be equal to patch row stride.");
    return m_base_mapper.m_fastDimZero;  // patch_depth
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
    return m_base_mapper.m_fastPatchColStride;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
                                             const Index baseIndex) const {
    const Index inputIndex = depth + baseIndex;
    return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth,
                                            const Index baseIndex) const {
    const Index inputIndex = depth + baseIndex;
    return m_base_mapper.m_impl.coeff(inputIndex);
  }
  template <typename PacketT = Packet>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if<
      TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value,
      PacketT>::type
  partialPacketNoPadding(const Index depth, const Index baseIndex,
                         Index num_coeffs) const {
    const Index inputIndex = depth + baseIndex;
    return m_base_mapper.m_impl.template partialPacket<PacketT>(
        inputIndex, mask<PacketT>(0, num_coeffs));
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool hasPadding() const {
    // TODO(ezhulenev): It does seem that for an inflated filter it's still
    // possible to guarantee "no padding or skipping" for non-standard packing.
    if (nonStandardPatches()) return true;

    // Non zero padding before.
    if (m_base_mapper.m_rowPaddingTop > 0) return true;
    if (m_base_mapper.m_colPaddingLeft > 0) return true;

    // Non zero padding after in rows.
    const Index last_row =
        (m_base_mapper.m_outputRows - 1) * m_base_mapper.m_row_strides;
    if (last_row + (patchRows() - 1) >= m_base_mapper.m_inputRows) return true;

    // Non zero padding after in cols.
    const Index last_col =
        (m_base_mapper.m_outputCols - 1) * m_base_mapper.m_col_strides;
    if (last_col + (patchCols() - 1) >= m_base_mapper.m_inputCols) return true;

    return false;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
    const Index r = m_rowIndex + row;
    return r < 0 || r >= m_base_mapper.m_inputRows;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padAnyRow(const Index first_row,
                                     const Index last_row) const {
    return m_rowIndex + first_row < 0 ||
           m_rowIndex + last_row >= m_base_mapper.m_inputRows;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padOrSkipRow(const Index row,
                                        Index* orig_row) const {
    eigen_assert(nonStandardPatches());

    const Index input_row = m_rowIndex + row * m_base_mapper.m_in_row_strides;
    *orig_row = (m_base_mapper.m_patch_row_inflate_strides == 1)
                    ? input_row
                    : ((input_row >= 0)
                           ? (input_row / m_base_mapper.m_fastInputRowStride)
                           : 0);

    return (*orig_row < 0 || *orig_row >= m_base_mapper.m_inputRows) ||
           (input_row != *orig_row * m_base_mapper.m_patch_row_inflate_strides);
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
    const Index c = m_colIndex + col;
    return c < 0 || c >= m_base_mapper.m_inputCols;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE bool padOrSkipCol(const Index col,
                                        Index* orig_col) const {
    eigen_assert(nonStandardPatches());

    const Index input_col = m_colIndex + col * m_base_mapper.m_in_col_strides;
    *orig_col = (m_base_mapper.m_patch_col_inflate_strides == 1)
                    ? input_col
                    : ((input_col >= 0)
                           ? (input_col / m_base_mapper.m_fastInputColStride)
                           : 0);

    return (*orig_col < 0 || *orig_col >= m_base_mapper.m_inputCols) ||
           (input_col != *orig_col * m_base_mapper.m_patch_col_inflate_strides);
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index baseIndex(const Index row, const Index col) const {
    const Index r = m_rowIndex + row;
    const Index c = m_colIndex + col;
    return r * m_base_mapper.m_rowInputStride +
           c * m_base_mapper.m_colInputStride + m_otherIndex;
  }
  // Compute a base index when original input row and column were precomputed
  // using padOrSkipRow and padOrSkipCol. Used only for non-standard patches.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index origBaseIndex(const Index orig_row,
                                          const Index orig_col) const {
    return orig_row * m_base_mapper.m_rowInputStride +
           orig_col * m_base_mapper.m_colInputStride + m_otherIndex;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index rowStride() const {
    return m_base_mapper.m_row_strides;
  }
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index colStride() const {
    return m_base_mapper.m_col_strides;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index rowOffset() const {
    const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
    const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
    return patchOffset - colOffset * m_base_mapper.m_colStride;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index colOffset() const {
    const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
    const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
    return colOffset;
  }

  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Index depthOffset() const {
    return m_depth_offset % patchDepth();
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
  getLinearMapper(Index i, Index j) const {
    return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
  }

 private:
  Index m_depth_offset;  // First row in the input matrix
  Index m_col_offset;    // First col in the input matrix

  // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
  // indices for the first element in a patch specified by col_offset
  // (see computeBaseIndices(...) for details).
  Index m_rowIndex;
  Index m_colIndex;
  Index m_otherIndex;

  const ParentMapper m_base_mapper;  // Keeping a copy instead of a reference
                                     // performs better in benchmarks.
};

// Arrange a block of the right input matrix (in our case it's always a
// "virtual matrix" constructed from extracted image patches) in contiguous
// memory.
//
// Given column major input (A0 beside A1 in memory):
// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
// A8 ...
// ...
//
// *) A, B, C, ... - patches extracted from the original input.
// *) A0, A1, A2 ... - values from the same patch at different offsets.
//
// The traversal (packed rhs memory) order (B0 beside A0 in memory):
// A0 B0 C0 D0 A1 B1 C1 D1 ...
// E0 F0 G0 H0 E1 F1 G1 H1 ...
// ...
// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
//
// This traversal order must be the same as in default gemm_pack_rhs defined in
// GeneralBlockPanelKernel.h.
//
// *) nr - number of registers along the 'n' dimension.
//    See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
//    Multiplication" paper.
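//
// As a concrete (hypothetical) illustration: with nr = 4 and packet_size = 4,
// one iteration of the packing loop below reads a packet from each of four
// patches A..D at depth offset d and, after ptranspose(), writes
//
//   A[d] B[d] C[d] D[d]  A[d+1] B[d+1] C[d+1] D[d+1]  ...  A[d+3] B[d+3] C[d+3] D[d+3]
//
// to the packed block, which matches the traversal order shown above.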
template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<
            const TensorReshapingOp<
                NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
            Device>,
        nocontract_t, contract_t, packet_size, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;
  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const Index packet_cols4 = (cols / 4) * 4;
    const Index peeled_k = (depth / packet_size) * packet_size;
    const bool non_standard_patches = rhs.nonStandardPatches();

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      Index k = 0;
      if ((packet_size % 4) == 0 && !non_standard_patches) {
        // FAST PATH:
        // Iterate over patch columns and rows, if we know that a single
        // packet does not span across multiple rows or columns.
        if ((rhs.patchDepth() % packet_size) == 0) {
          const Index start_col = rhs.colOffset();
          const Index max_col = rhs.maxCol(peeled_k);

          for (Index c = start_col; c < max_col; ++c) {
            eigen_assert(k <= peeled_k);

            const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
            const Index max_row = rhs.maxRow(peeled_k, c);

            const bool pad_col0 = dm0.padCol(c);
            const bool pad_col1 = dm1.padCol(c);
            const bool pad_col2 = dm2.padCol(c);
            const bool pad_col3 = dm3.padCol(c);

            // Check if we can squeeze reads along the `row` and `depth`
            // dimensions (two innermost dimensions).
            if (!pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 &&    //
                !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) &&  //
                !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) &&  //
                !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) &&  //
                !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) {
              // Compute how many elements we can squeeze read.
              const Index start_depth =
                  (c == start_col) ? rhs.depthOffset() : 0;

              // Upper bound for the number of elements in the depth dimension
              // that we can squeeze read.
              const Index squeeze_length =
                  (max_row - start_row) * rhs.patchDepth() - start_depth;

              // Do not overshoot beyond the block size.
              const Index max_depth =
                  start_depth + std::min<Index>(peeled_k - k, squeeze_length);
              eigen_assert((max_depth - start_depth) % packet_size == 0);

              const Index idx0 = dm0.baseIndex(start_row, c);
              const Index idx1 = dm1.baseIndex(start_row, c);
              const Index idx2 = dm2.baseIndex(start_row, c);
              const Index idx3 = dm3.baseIndex(start_row, c);

              for (Index d = start_depth; d < max_depth; d += packet_size) {
                eigen_assert(k < peeled_k);
                PacketBlock<Packet, 4> kernel;
                kernel.packet[0] = rhs.packetNoPadding(d, idx0);
                kernel.packet[1] = rhs.packetNoPadding(d, idx1);
                kernel.packet[2] = rhs.packetNoPadding(d, idx2);
                kernel.packet[3] = rhs.packetNoPadding(d, idx3);
                ptranspose(kernel);
                pstoreu(block + 0 * packet_size, kernel.packet[0]);
                pstoreu(block + 1 * packet_size, kernel.packet[1]);
                pstoreu(block + 2 * packet_size, kernel.packet[2]);
                pstoreu(block + 3 * packet_size, kernel.packet[3]);
                block += 4 * packet_size;
                k += packet_size;
              }

              // Go to the next column.
              continue;
            }

            // If we can't squeeze reads, process rows one by one.
            for (Index r = start_row; r < max_row; ++r) {
              eigen_assert(k <= peeled_k);

              const bool pad0 = pad_col0 || dm0.padRow(r);
              const bool pad1 = pad_col1 || dm1.padRow(r);
              const bool pad2 = pad_col2 || dm2.padRow(r);
              const bool pad3 = pad_col3 || dm3.padRow(r);

              const Index idx0 = dm0.baseIndex(r, c);
              const Index idx1 = dm1.baseIndex(r, c);
              const Index idx2 = dm2.baseIndex(r, c);
              const Index idx3 = dm3.baseIndex(r, c);

              const Index start_depth = ((c == start_col) && (r == start_row))
                                            ? rhs.depthOffset()
                                            : 0;
              const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
              eigen_assert((max_depth - start_depth) % packet_size == 0);

              for (Index d = start_depth; d < max_depth; d += packet_size) {
                eigen_assert(k < peeled_k);
                PacketBlock<Packet, 4> kernel;
                kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, idx0);
                kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, idx1);
                kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, idx2);
                kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, idx3);
                ptranspose(kernel);
                pstoreu(block + 0 * packet_size, kernel.packet[0]);
                pstoreu(block + 1 * packet_size, kernel.packet[1]);
                pstoreu(block + 2 * packet_size, kernel.packet[2]);
                pstoreu(block + 3 * packet_size, kernel.packet[3]);
                block += 4 * packet_size;
                k += packet_size;
              }
            }
          }

          // The loop above should fill peeled_k elements.
          eigen_assert(peeled_k == k);

        } else {
          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 4> kernel;
            kernel.packet[0] = dm0.loadPacketStandard(k);
            kernel.packet[1] = dm1.loadPacketStandard(k);
            kernel.packet[2] = dm2.loadPacketStandard(k);
            kernel.packet[3] = dm3.loadPacketStandard(k);
            ptranspose(kernel);
            pstoreu(block + 0 * packet_size, kernel.packet[0]);
            pstoreu(block + 1 * packet_size, kernel.packet[1]);
            pstoreu(block + 2 * packet_size, kernel.packet[2]);
            pstoreu(block + 3 * packet_size, kernel.packet[3]);
            block += 4 * packet_size;
          }
        }
      }

      // Copy the remaining coefficients of the column block after the peeled_k.
      if (!rhs.nonStandardPatches()) {
        for (; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
    // Remaining columns are handled differently for PPC.
    for (Index k = 0; k < depth; k++) {
      for (Index j2 = packet_cols4; j2 < cols; ++j2) {
        *block = rhs(k, j2);
        block += 1;
      }
    }
#else
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
#endif
  }
};

// Template specialization for packet_size = 2. We must special-case packet
// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<
            const TensorReshapingOp<
                NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
            Device>,
        nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered,
        Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, 2, inner_dim_contiguous, inner_dim_reordered,
      Alignment>
      SubMapper;
  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const int packet_size = 2;
    const Index packet_cols4 = (cols / 4) * 4;
    const Index peeled_k = (depth / packet_size) * packet_size;
    const bool non_standard_patches = rhs.nonStandardPatches();

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      Index k = 0;
      if (!non_standard_patches) {
        // FAST PATH:
        // Iterate over patch columns and rows if we know that a single
        // packet does not span across multiple rows or columns.
        if ((rhs.patchDepth() % packet_size) == 0) {
          const Index start_col = rhs.colOffset();
          const Index max_col = rhs.maxCol(peeled_k);

          for (Index c = start_col; c < max_col; ++c) {
            eigen_assert(k <= peeled_k);

            const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
            const Index max_row = rhs.maxRow(peeled_k, c);

            const bool pad_col0 = dm0.padCol(c);
            const bool pad_col1 = dm1.padCol(c);
            const bool pad_col2 = dm2.padCol(c);
            const bool pad_col3 = dm3.padCol(c);

            // We can squeeze reads along the `row` and `depth` dimensions if
            // the row stride is `1`, which means that `row` and `depth`
            // dimensions are contiguous (two innermost dimensions).
            if (rhs.rowStride() == 1 &&                                //
                !pad_col0 && !pad_col1 && !pad_col2 && !pad_col3 &&    //
                !dm0.padRow(start_row) && !dm0.padRow(max_row - 1) &&  //
                !dm1.padRow(start_row) && !dm1.padRow(max_row - 1) &&  //
                !dm2.padRow(start_row) && !dm2.padRow(max_row - 1) &&  //
                !dm3.padRow(start_row) && !dm3.padRow(max_row - 1)) {
              // Compute how many elements we can squeeze read.
              const Index start_depth =
                  (c == start_col) ? rhs.depthOffset() : 0;

              // Upper bound for the number of elements in the depth dimension
              // that we can squeeze read.
              const Index squeeze_length =
                  (max_row - start_row) * rhs.patchDepth() - start_depth;

              // Do not overshoot beyond the block size.
              const Index max_depth =
                  start_depth + std::min<Index>(peeled_k - k, squeeze_length);
              eigen_assert((max_depth - start_depth) % packet_size == 0);

              const Index idx0 = dm0.baseIndex(start_row, c);
              const Index idx1 = dm1.baseIndex(start_row, c);
              const Index idx2 = dm2.baseIndex(start_row, c);
              const Index idx3 = dm3.baseIndex(start_row, c);

              for (Index d = start_depth; d < max_depth; d += packet_size) {
                PacketBlock<Packet, 2> kernel0;
                PacketBlock<Packet, 2> kernel1;
                kernel0.packet[0] = rhs.packetNoPadding(d, idx0);
                kernel0.packet[1] = rhs.packetNoPadding(d, idx1);
                kernel1.packet[0] = rhs.packetNoPadding(d, idx2);
                kernel1.packet[1] = rhs.packetNoPadding(d, idx3);
                ptranspose(kernel0);
                ptranspose(kernel1);
                pstoreu(block + 0 * packet_size, kernel0.packet[0]);
                pstoreu(block + 1 * packet_size, kernel1.packet[0]);
                pstoreu(block + 2 * packet_size, kernel0.packet[1]);
                pstoreu(block + 3 * packet_size, kernel1.packet[1]);
                block += 4 * packet_size;
                k += packet_size;
              }

              // Go to the next column.
              continue;
            }

            // If we can't squeeze reads, process rows one by one.
            for (Index r = start_row; r < max_row; ++r) {
              eigen_assert(k <= peeled_k);

              const bool pad0 = pad_col0 || dm0.padRow(r);
              const bool pad1 = pad_col1 || dm1.padRow(r);
              const bool pad2 = pad_col2 || dm2.padRow(r);
              const bool pad3 = pad_col3 || dm3.padRow(r);

              const Index idx0 = dm0.baseIndex(r, c);
              const Index idx1 = dm1.baseIndex(r, c);
              const Index idx2 = dm2.baseIndex(r, c);
              const Index idx3 = dm3.baseIndex(r, c);

              const Index start_depth = ((c == start_col) && (r == start_row))
                                            ? rhs.depthOffset()
                                            : 0;
              const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
              eigen_assert((max_depth - start_depth) % packet_size == 0);

              for (Index d = start_depth; d < max_depth; d += packet_size) {
                eigen_assert(k < peeled_k);
                PacketBlock<Packet, 2> kernel0;
                PacketBlock<Packet, 2> kernel1;
                kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx0);
                kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx1);
                kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx2);
                kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
                                         : rhs.packetNoPadding(d, idx3);
                ptranspose(kernel0);
                ptranspose(kernel1);
                pstoreu(block + 0 * packet_size, kernel0.packet[0]);
                pstoreu(block + 1 * packet_size, kernel1.packet[0]);
                pstoreu(block + 2 * packet_size, kernel0.packet[1]);
                pstoreu(block + 3 * packet_size, kernel1.packet[1]);
                block += 4 * packet_size;
                k += packet_size;
              }
            }
          }

          // The loop above should fill peeled_k elements.
          eigen_assert(peeled_k == k);

        } else {
          // Packet can span multiple rows or columns, so we have to go
          // through the slower "standard" path.
          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 2> kernel0;
            PacketBlock<Packet, 2> kernel1;
            kernel0.packet[0] = dm0.loadPacketStandard(k);
            kernel0.packet[1] = dm1.loadPacketStandard(k);
            kernel1.packet[0] = dm2.loadPacketStandard(k);
            kernel1.packet[1] = dm3.loadPacketStandard(k);
            ptranspose(kernel0);
            ptranspose(kernel1);
            pstoreu(block + 0 * packet_size, kernel0.packet[0]);
            pstoreu(block + 1 * packet_size, kernel1.packet[0]);
            pstoreu(block + 2 * packet_size, kernel0.packet[1]);
            pstoreu(block + 3 * packet_size, kernel1.packet[1]);
            block += 4 * packet_size;
          }
        }
      }

      // Copy the remaining coefficients of the column block after the
      // peeled_k.
      if (!non_standard_patches) {
        for (; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

// Special case for non-vectorized types such as float16.
template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
          typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<
            const TensorReshapingOp<
                NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
            Device>,
        nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
        Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<
          const TensorReshapingOp<
              NewDimension, const TensorImagePatchOp<Rows, Cols, ArgType> >,
          Device>,
      nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
      Alignment>
      SubMapper;
  typedef SubMapper DataMapper;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE)

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const Index packet_cols4 = (cols / 4) * 4;

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      if (!rhs.nonStandardPatches()) {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};
#endif
}  // end namespace internal

/** SpatialConvolution
 * \ingroup CXX11_NeuralNetworks_Module
 *
 * \brief Applies a 2D convolution over a multichannel input image.
 *
 * The input parameter is expected to be a tensor with a rank of 3 or more
 * (channels, height, width, and optionally others).
 * The kernel parameter is expected to be a 4D tensor (filters, channels,
 * kernel_height, kernel_width).
 * The input and the kernel must both be in col-major layout. The result will
 * also be in col-major layout.
 *
 * If col_in_stride or row_in_stride is greater than 1, the convolution is
 * applied with holes (aka atrous convolution), sampling every col_in_stride,
 * row_in_stride input pixels.
 *
 * If padding_top, padding_bottom, padding_left, or padding_right is specified,
 * then those paddings will be used to pad the input, and padding_type must be
 * PADDING_VALID.
 *
 * The result can be assigned to a tensor of rank equal to the rank of the
 * input. The dimensions of the result will be filters, height, width (and
 * others if applicable).
 *
 * It is possible to swap the order of the width and height dimensions,
 * provided that the same order is used in the input, the kernel, and the
 * output.
 *
 * It is also possible to add an output kernel to the contraction; the output
 * kernel is called by Eigen when it "finalizes" a block of the output tensor.
 *
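 * A minimal usage sketch (illustrative sizes only; assumes the Eigen Tensor
 * module is available). With col-major layout the input is (channels, height,
 * width, batch) and the kernel is (filters, channels, kernel_rows,
 * kernel_cols):
 *
 * \code
 * Eigen::Tensor<float, 4> input(3, 32, 32, 16);
 * Eigen::Tensor<float, 4> kernel(8, 3, 5, 5);
 * Eigen::Tensor<float, 4> output(8, 32, 32, 16);
 * // Defaults: row/col stride 1 and PADDING_SAME, so the output stays 32x32.
 * output = Eigen::SpatialConvolution(input, kernel);
 * \endcode
 *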
 */
template <typename Input, typename Kernel,
          typename OutputKernel = const NoOpOutputKernel>
EIGEN_ALWAYS_INLINE static const std::conditional_t<
    internal::traits<Input>::Layout == ColMajor,
    TensorReshapingOp<
        const DSizes<typename internal::traits<Input>::Index,
                     internal::traits<Input>::NumDimensions>,
        const TensorContractionOp<
            const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const Kernel>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
            const OutputKernel> >,
    TensorReshapingOp<
        const DSizes<typename internal::traits<Input>::Index,
                     internal::traits<Input>::NumDimensions>,
        const TensorContractionOp<
            const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const Kernel>,
            const OutputKernel> > >
SpatialConvolution(const Input& input, const Kernel& kernel,
                   const Index row_stride = 1, const Index col_stride = 1,
                   const PaddingType padding_type = PADDING_SAME,
                   const Index row_in_stride = 1, const Index col_in_stride = 1,
                   const OutputKernel& output_kernel = OutputKernel(),
                   Index padding_top = 0, Index padding_bottom = 0,
                   Index padding_left = 0, Index padding_right = 0) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  typedef typename internal::traits<Input>::Scalar InputScalar;
  TensorRef<Tensor<InputScalar, internal::traits<Input>::NumDimensions,
                   internal::traits<Input>::Layout, TensorIndex> >
      in(input);
  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
                   internal::traits<Kernel>::NumDimensions,
                   internal::traits<Kernel>::Layout, TensorIndex> >
      kern(kernel);

  EIGEN_STATIC_ASSERT(
      internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
      YOU_MADE_A_PROGRAMMING_MISTAKE)
  const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);

  const int NumDims = internal::traits<Input>::NumDimensions;

  // Number of filters to apply. This is the same as the output depth of the
  // result.
  const TensorIndex kernelFilters =
      isColMajor ? kern.dimensions()[0] : kern.dimensions()[3];
  // Number of channels. This is the same as the input depth.
  const TensorIndex kernelChannels =
      isColMajor ? kern.dimensions()[1] : kern.dimensions()[2];
  const TensorIndex kernelRows =
      isColMajor ? kern.dimensions()[2] : kern.dimensions()[1];
  const TensorIndex kernelCols =
      isColMajor ? kern.dimensions()[3] : kern.dimensions()[0];
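
  // Added note: for example, a 3x3 kernel applied with
  // row_in_stride == col_in_stride == 2 (atrous convolution) has an effective
  // extent of 5x5 on the input: kernelRowsEff = 3 + (3 - 1) * (2 - 1) = 5.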
  const Index kernelRowsEff =
      kernelRows + (kernelRows - 1) * (row_in_stride - 1);
  const Index kernelColsEff =
      kernelCols + (kernelCols - 1) * (col_in_stride - 1);

  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  const TensorIndex InputRows =
      isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
  const TensorIndex InputCols =
      isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
  const bool padding_explicit =
      (padding_top || padding_bottom || padding_left || padding_right);

  TensorIndex out_height;
  TensorIndex out_width;
  switch (padding_type) {
    case PADDING_VALID: {
      const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom;
      const TensorIndex InputColsEff = InputCols + padding_left + padding_right;
      out_height = divup(InputRowsEff - kernelRowsEff + 1, row_stride);
      out_width = divup(InputColsEff - kernelColsEff + 1, col_stride);
      break;
    }
    case PADDING_SAME: {
      eigen_assert(!padding_explicit);
      out_height = divup(InputRows, row_stride);
      out_width = divup(InputCols, col_stride);
      break;
    }
    default: {
      // Initialize unused variables to avoid a compiler warning.
      out_height = 0;
      out_width = 0;
      eigen_assert(false && "unexpected padding");
    }
  }

  // Molds the output of the patch extraction code into a 2d tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with
  //   the kernels
  // - the second dimension (dims[1]): everything else
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
    pre_contract_dims[1] = out_height * out_width;
    for (int i = 3; i < NumDims; ++i) {
      pre_contract_dims[1] *= in.dimension(i);
    }
  } else {
    pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols;
    pre_contract_dims[0] = out_height * out_width;
    for (int i = 0; i < NumDims - 3; ++i) {
      pre_contract_dims[0] *= in.dimension(i);
    }
  }

  // Molds the output of the contraction into the shape expected by the user
  // (assuming this is ColMajor):
  // - 1st dim: kernel filters
  // - 2nd dim: output height
  // - 3rd dim: output width
  // - 4th dim and beyond: everything else including batch size
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelFilters;
    post_contract_dims[1] = out_height;
    post_contract_dims[2] = out_width;
    for (int i = 3; i < NumDims; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  } else {
    post_contract_dims[NumDims - 1] = kernelFilters;
    post_contract_dims[NumDims - 2] = out_height;
    post_contract_dims[NumDims - 3] = out_width;
    for (int i = 0; i < NumDims - 3; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  }

  DSizes<TensorIndex, 2> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters;
    kernel_dims[1] = kernelChannels * kernelRows * kernelCols;
  } else {
    kernel_dims[0] = kernelChannels * kernelRows * kernelCols;
    kernel_dims[1] = kernelFilters;
  }
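
  // Added note: with these shapes the convolution below is evaluated as a
  // single matrix product (shown for ColMajor): a [kernelFilters x K] kernel
  // matrix times a [K x (out_height * out_width * others)] patch matrix,
  // where K = kernelChannels * kernelRows * kernelCols, followed by a reshape
  // back to the image dimensions.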
  if (padding_explicit) {
    return choose(
        Cond<internal::traits<Input>::Layout == ColMajor>(),
        kernel.reshape(kernel_dims)
            .contract(input
                          .extract_image_patches(
                              kernelRows, kernelCols, row_stride, col_stride,
                              row_in_stride, col_in_stride,
                              /*row_inflate_stride=*/1,
                              /*col_inflate_stride=*/1, padding_top,
                              padding_bottom, padding_left, padding_right,
                              /*padding_value=*/static_cast<InputScalar>(0))
                          .reshape(pre_contract_dims),
                      contract_dims, output_kernel)
            .reshape(post_contract_dims),
        input
            .extract_image_patches(
                kernelRows, kernelCols, row_stride, col_stride, row_in_stride,
                col_in_stride,
                /*row_inflate_stride=*/1,
                /*col_inflate_stride=*/1, padding_top, padding_bottom,
                padding_left, padding_right,
                /*padding_value=*/static_cast<InputScalar>(0))
            .reshape(pre_contract_dims)
            .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
            .reshape(post_contract_dims));
  } else {
    return choose(
        Cond<internal::traits<Input>::Layout == ColMajor>(),
        kernel.reshape(kernel_dims)
            .contract(input
                          .extract_image_patches(
                              kernelRows, kernelCols, row_stride, col_stride,
                              row_in_stride, col_in_stride, padding_type)
                          .reshape(pre_contract_dims),
                      contract_dims, output_kernel)
            .reshape(post_contract_dims),
        input
            .extract_image_patches(kernelRows, kernelCols, row_stride,
                                   col_stride, row_in_stride, col_in_stride,
                                   padding_type)
            .reshape(pre_contract_dims)
            .contract(kernel.reshape(kernel_dims), contract_dims, output_kernel)
            .reshape(post_contract_dims));
  }
}

}  // end namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_SPATIAL_CONVOLUTIONS_INL_H_