/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_
#define TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#include "tensorflow/core/kernels/eigen_convolution_helpers.h"

namespace Eigen {

namespace internal {

#if !EIGEN_ALTIVEC_USE_CUSTOM_PACK
// WARNING: Most of the code here implicitly assumes that the matrix is in
// ColMajor layout. This is guaranteed by the tensor contraction (see
// TensorContraction.h).
//
// Inside Eigen a tensor contraction is represented by a matrix multiplication.
// We don't want to actually extract volume patches and reshape the result into
// a matrix (this involves allocating huge extra memory), so the patch
// extraction and reshape operations are implicit.
//
// TensorContractionInputMapper takes a matrix index and returns the coefficient
// (or the packet) of the "virtual tensor", that would be at that index if we
// were to actually reshape the result of patch extraction.
//
// TensorContractionSubMapper provides a similar view into the "virtual matrix"
// at the given vertical and horizontal offsets.
//
// "Virtual matrix" dimensions:
//   *0: kernelChannels * kernelPlanes * kernelRows * kernelCols
//    1: out_planes * out_height * out_width * OTHERS (e.g. batches, etc...)
//
// *) extracted patches are contiguous in memory (innermost dimension assuming
//    col major layout)
//
// With these dimensions:
//   row - offset within a single patch (in code: patchId)
//   col - index of the extracted patch (in code: patchIndex)
//         patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
//
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar_, typename Index,
          typename nocontract_t, typename contract_t, int Side, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper<
    Scalar_, Index, Side,
    TensorEvaluator<const TensorReshapingOp<NewDimension,
                                            const TensorVolumePatchOp<
                                                Planes, Rows, Cols, ArgType> >,
                    Device>,
    nocontract_t, contract_t, packet_size, inner_dim_contiguous,
    inner_dim_reordered, Alignment> {
 public:
  typedef Scalar_ Scalar;
  typedef TensorContractionInputMapper<
      Scalar, Index, Side,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      Self;
  typedef TensorContractionSubMapper<
      Scalar, Index, Side,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;
  typedef SubMapper VectorMapper;
  typedef SubMapper LinearMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(
      const TensorEvaluator<
          const TensorReshapingOp<
              NewDimension,
              const TensorVolumePatchOp<Planes, Rows, Cols, ArgType> >,
          Device>& tensor,
      const nocontract_t&, const nocontract_t&, const contract_t&,
      const contract_t&)
      : m_impl(tensor.impl().impl()) {
    if (internal::traits<ArgType>::Layout == ColMajor) {
      m_patch_depth = tensor.impl().dimensions()[0];
      m_patch_planes = tensor.impl().dimensions()[1];
      m_patch_rows = tensor.impl().dimensions()[2];
      m_patch_cols = tensor.impl().dimensions()[3];
      m_num_patches = tensor.impl().dimensions()[4];
    } else {
      const int NumDims = tensor.impl().dimensions().size();
      m_patch_depth = tensor.impl().dimensions()[NumDims - 1];
      m_patch_planes = tensor.impl().dimensions()[NumDims - 2];
      m_patch_rows = tensor.impl().dimensions()[NumDims - 3];
      m_patch_cols = tensor.impl().dimensions()[NumDims - 4];
      m_num_patches = tensor.impl().dimensions()[NumDims - 5];
    }

    // Strides for navigating through the single patch.
    m_patch_plane_stride = m_patch_depth;
    m_patch_row_stride = m_patch_planes * m_patch_plane_stride;
    m_patch_col_stride = m_patch_rows * m_patch_row_stride;

    // Strides for the output tensor.
    // IMPORTANT: These strides are used to locate an element in a patch at
    // depth zero (channel), which is not quite the same as the "traditional"
    // stride.
    m_rowStride = m_patch_planes;
    m_colStride = m_patch_rows * m_rowStride;
    m_patchStride = m_colStride * m_patch_cols * m_patch_depth;
    m_otherStride = m_patchStride * m_num_patches;

    m_outputPlanes = tensor.impl().outputPlanes();
    m_outputRows = tensor.impl().outputRows();
    m_outputCols = tensor.impl().outputCols();

    m_outputPlanesRows = m_outputPlanes * m_outputRows;

    m_plane_strides = tensor.impl().userPlaneStride();
    m_row_strides = tensor.impl().userRowStride();
    m_col_strides = tensor.impl().userColStride();

    m_in_plane_strides = tensor.impl().userInPlaneStride();
    m_in_row_strides = tensor.impl().userInRowStride();
    m_in_col_strides = tensor.impl().userInColStride();

    m_patch_plane_inflate_strides = tensor.impl().planeInflateStride();
    m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
    m_patch_col_inflate_strides = tensor.impl().colInflateStride();

    if (internal::traits<ArgType>::Layout == ColMajor) {
      m_inputDepth = tensor.impl().impl().dimensions()[0];
      m_inputPlanes = tensor.impl().impl().dimensions()[1];
      m_inputRows = tensor.impl().impl().dimensions()[2];
      m_inputCols = tensor.impl().impl().dimensions()[3];
    } else {
      const int NumDims = tensor.impl().impl().dimensions().size();
      m_inputDepth = tensor.impl().impl().dimensions()[NumDims - 1];
      m_inputPlanes = tensor.impl().impl().dimensions()[NumDims - 2];
      m_inputRows = tensor.impl().impl().dimensions()[NumDims - 3];
      m_inputCols = tensor.impl().impl().dimensions()[NumDims - 4];
    }

    // Strides for navigating through the input tensor.
    m_planeInputStride = m_inputDepth;
    m_rowInputStride = m_inputDepth * m_inputPlanes;
    m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
    m_patchInputStride =
        m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;

    m_planePaddingTop = tensor.impl().planePaddingTop();
    m_rowPaddingTop = tensor.impl().rowPaddingTop();
    m_colPaddingLeft = tensor.impl().colPaddingLeft();

    m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);

    m_fastPatchPlaneStride =
        internal::TensorIntDivisor<Index>(m_patch_plane_stride);
    m_fastPatchRowStride =
        internal::TensorIntDivisor<Index>(m_patch_row_stride);
    m_fastPatchColStride =
        internal::TensorIntDivisor<Index>(m_patch_col_stride);

    m_fastInputPlaneStride =
        internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides);
    m_fastInputRowStride =
        internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
    m_fastInputColStride =
        internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);

    m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
    m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);

    m_fastDimZero = internal::TensorIntDivisor<Index>(m_patch_depth);
    m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
    m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
    m_fastOutputCols = internal::TensorIntDivisor<Index>(m_outputCols);

    m_fastOutputPlanesRows =
        internal::TensorIntDivisor<Index>(m_outputPlanesRows);
  }

  EIGEN_DEVICE_FUNC
  TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
      : m_impl(base_mapper.m_impl) {
    m_patch_depth = base_mapper.m_patch_depth;
    m_patch_planes = base_mapper.m_patch_planes;
    m_patch_rows = base_mapper.m_patch_rows;
    m_patch_cols = base_mapper.m_patch_cols;
    m_num_patches = base_mapper.m_num_patches;

    m_patch_plane_stride = base_mapper.m_patch_plane_stride;
    m_patch_row_stride = base_mapper.m_patch_row_stride;
    m_patch_col_stride = base_mapper.m_patch_col_stride;

    m_rowStride = base_mapper.m_rowStride;
    m_colStride = base_mapper.m_colStride;
    m_patchStride = base_mapper.m_patchStride;
    m_otherStride = base_mapper.m_otherStride;

    m_planeInputStride = base_mapper.m_planeInputStride;
    m_rowInputStride = base_mapper.m_rowInputStride;
    m_colInputStride = base_mapper.m_colInputStride;
    m_patchInputStride = base_mapper.m_patchInputStride;
    m_otherInputStride = base_mapper.m_otherInputStride;

    m_inputDepth = base_mapper.m_inputDepth;
    m_inputPlanes = base_mapper.m_inputPlanes;
    m_inputRows = base_mapper.m_inputRows;
    m_inputCols = base_mapper.m_inputCols;

    m_outputPlanes = base_mapper.m_outputPlanes;
    m_outputRows = base_mapper.m_outputRows;
    m_outputCols = base_mapper.m_outputCols;

    m_plane_strides = base_mapper.m_plane_strides;
    m_row_strides = base_mapper.m_row_strides;
    m_col_strides = base_mapper.m_col_strides;

    m_in_plane_strides = base_mapper.m_in_plane_strides;
    m_in_row_strides = base_mapper.m_in_row_strides;
    m_in_col_strides = base_mapper.m_in_col_strides;

    m_patch_plane_inflate_strides = base_mapper.m_patch_plane_inflate_strides;
    m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
    m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;

    m_planePaddingTop = base_mapper.m_planePaddingTop;
    m_rowPaddingTop = base_mapper.m_rowPaddingTop;
    m_colPaddingLeft = base_mapper.m_colPaddingLeft;

    m_outputPlanesRows = base_mapper.m_outputPlanesRows;

    m_fastNumPatches = base_mapper.m_fastNumPatches;
    m_fastPatchPlaneStride = base_mapper.m_fastPatchPlaneStride;
    m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
    m_fastPatchColStride = base_mapper.m_fastPatchColStride;
    m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride;
    m_fastInputRowStride = base_mapper.m_fastInputRowStride;
    m_fastInputColStride = base_mapper.m_fastInputColStride;
    m_fastRowStride = base_mapper.m_fastRowStride;
    m_fastColStride = base_mapper.m_fastColStride;
    m_fastOutputPlanes = base_mapper.m_fastOutputPlanes;
    m_fastOutputRows = base_mapper.m_fastOutputRows;
    m_fastOutputCols = base_mapper.m_fastOutputCols;
    m_fastDimZero = base_mapper.m_fastDimZero;
    m_fastOutputPlanesRows = base_mapper.m_fastOutputPlanesRows;
  }

  // If true, turns off some optimizations for loading packets since the image
  // patches are "non-standard", e.g. when there are non-trivial strides or
  // inflations in the input.
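  // A patch is "standard" when all of the user-specified input ("atrous")
  // strides and inflation strides are 1, i.e. patch offsets map to input
  // coordinates by simple addition; the optimized loadCoeffStandard() /
  // loadPacketStandard() / loadPacketFast() paths below rely on this.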
276 EIGEN_DEVICE_FUNC nonStandardPatches()277 EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { 278 return m_in_plane_strides != 1 || m_in_row_strides != 1 || 279 m_in_col_strides != 1 || m_patch_plane_inflate_strides != 1 || 280 m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1; 281 } 282 283 EIGEN_DEVICE_FUNC getSubMapper(Index i,Index j)284 EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const { 285 return SubMapper(*this, i, j); 286 } 287 288 EIGEN_DEVICE_FUNC getLinearMapper(Index i,Index j)289 EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const { 290 return LinearMapper(*this, i, j); 291 } 292 293 EIGEN_DEVICE_FUNC operator()294 EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const { 295 Index planeIndex, rowIndex, colIndex, otherIndex; 296 computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex); 297 return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex); 298 } 299 300 // Load the coefficient at the patchIndex location instead of the usual 301 // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the 302 // gpu code. 303 EIGEN_DEVICE_FUNC operator()304 EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const { 305 Index planeIndex, rowIndex, colIndex, otherIndex; 306 computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex); 307 return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex); 308 } 309 310 EIGEN_DEVICE_FUNC loadPacket(Index row)311 EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const { 312 Index planeIndex, rowIndex, colIndex, otherIndex; 313 computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex); 314 return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex); 315 } 316 317 // Load the packet at the patchIndex location instead of the usual m_rowIndex, 318 // m_colIndex, m_otherIndex. This is currently only used by the gpu code. 319 EIGEN_DEVICE_FUNC loadPacket(Index row,Index patchIndex)320 EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const { 321 Index planeIndex, rowIndex, colIndex, otherIndex; 322 computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex); 323 return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex); 324 } 325 326 EIGEN_DEVICE_FUNC impl()327 EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const { 328 return m_impl; 329 } 330 331 EIGEN_DEVICE_FUNC patchDepth()332 EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_planeInputStride; } 333 EIGEN_DEVICE_FUNC patchPlanes()334 EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_rowStride; } 335 EIGEN_DEVICE_FUNC patchRows()336 EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; } 337 EIGEN_DEVICE_FUNC patchCols()338 EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; } 339 340 private: 341 friend class TensorContractionSubMapper< 342 Scalar, Index, Side, 343 TensorEvaluator<const TensorReshapingOp< 344 NewDimension, const TensorVolumePatchOp< 345 Planes, Rows, Cols, ArgType> >, 346 Device>, 347 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 348 inner_dim_reordered, Alignment>; 349 350 // Load coefficient from a patch specified by the "within patch offset" 351 // (patchId) and the precomputed indices of the first element of the patch. 
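  // Within a patch, elements are laid out depth-first (ColMajor):
  //   patchId = depth +
  //             patchDepth() * (plane + patchPlanes() * (row + patchRows() * col))
  // and the divisions by m_fastDimZero / m_fastColStride / m_fastRowStride in
  // loadCoeff() recover (depth, plane, row, col) again. Illustration: with
  // patch_depth = 4, patch_planes = 3, patch_rows = 5, patchId = 7 decodes to
  // depth 3, plane 1, row 0, col 0.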
352 EIGEN_DEVICE_FUNC loadCoeff(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)353 EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index planeIndex, 354 Index rowIndex, Index colIndex, 355 Index otherIndex) const { 356 // Find the offset of the element wrt the location of the first element. 357 const Index patchOffset = patchId / m_fastDimZero; 358 359 const Index colOffset = patchOffset / m_fastColStride; 360 const Index inputCol = colIndex + colOffset * m_in_col_strides; 361 const Index origInputCol = 362 (m_patch_col_inflate_strides == 1) 363 ? inputCol 364 : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0); 365 366 const Index rowOffset = 367 (patchOffset - colOffset * m_colStride) / m_fastRowStride; 368 const Index inputRow = rowIndex + rowOffset * m_in_row_strides; 369 const Index origInputRow = 370 (m_patch_row_inflate_strides == 1) 371 ? inputRow 372 : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0); 373 374 const Index planeOffset = 375 patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; 376 const Index inputPlane = planeIndex + planeOffset * m_in_plane_strides; 377 const Index origInputPlane = 378 (m_patch_plane_inflate_strides == 1) 379 ? inputPlane 380 : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0); 381 382 if (origInputCol < 0 || origInputRow < 0 || origInputPlane < 0 || 383 origInputCol >= m_inputCols || origInputRow >= m_inputRows || 384 origInputPlane >= m_inputPlanes || 385 (inputCol != origInputCol * m_patch_col_inflate_strides) || 386 (inputRow != origInputRow * m_patch_row_inflate_strides) || 387 (inputPlane != origInputPlane * m_patch_plane_inflate_strides)) { 388 return Scalar(0); 389 } 390 391 const Index depth = patchId - patchOffset * patchDepth(); 392 const Index inputIndex = depth + origInputPlane * m_planeInputStride + 393 origInputRow * m_rowInputStride + 394 origInputCol * m_colInputStride + otherIndex; 395 396 return m_impl.coeff(inputIndex); 397 } 398 399 // This is the same as loadCoeff(...), but optimized for all `inflate_strides` 400 // and `in_strides` equal to 1 (template specialization without templates). 401 EIGEN_DEVICE_FUNC loadCoeffStandard(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)402 EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index planeIndex, 403 Index rowIndex, Index colIndex, 404 Index otherIndex) const { 405 eigen_assert(!nonStandardPatches()); 406 407 // Find the offset of the element wrt the location of the first element. 
408 const Index patchOffset = patchId / m_fastDimZero; 409 410 const Index colOffset = patchOffset / m_fastColStride; 411 const Index rowOffset = 412 (patchOffset - colOffset * m_colStride) / m_fastRowStride; 413 const Index planeOffset = 414 patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; 415 416 const Index inputCol = colIndex + colOffset; 417 const Index inputRow = rowIndex + rowOffset; 418 const Index inputPlane = planeIndex + planeOffset; 419 420 if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || 421 inputRow >= m_inputRows || inputPlane < 0 || 422 inputPlane >= m_inputPlanes) { 423 return Scalar(0); 424 } 425 426 const Index depth = patchId - patchOffset * patchDepth(); 427 const Index inputIndex = depth + inputPlane * m_planeInputStride + 428 inputRow * m_rowInputStride + 429 inputCol * m_colInputStride + otherIndex; 430 431 return m_impl.coeff(inputIndex); 432 } 433 434 // Load packet from a patch specified by the "within patch offset" 435 // (patchId) and the precomputed indices of the first element of the patch. 436 EIGEN_DEVICE_FUNC loadPacket(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)437 EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index planeIndex, 438 Index rowIndex, Index colIndex, 439 Index otherIndex) const { 440 const Index packetSize = internal::unpacket_traits<Packet>::size; 441 442 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) 443 eigen_assert(patchId < 444 patchDepth() * patchPlanes() * patchRows() * patchCols()); 445 446 if (nonStandardPatches()) { 447 return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, 448 otherIndex); 449 } 450 typedef decltype(m_impl) TensorEvaluatorT; 451 return loadPacketStandard<Packet, TensorEvaluatorT>( 452 patchId, planeIndex, rowIndex, colIndex, otherIndex); 453 } 454 455 // Helper function to load a 'partial' packet - this is the single row part of 456 // a packet that is split across two rows (but single column). In the 457 // 'partial' packet, the elements corresponding to the row (specified through 458 // rowOffset) are loaded and the rest of the elements are zero-filled into the 459 // 'partial' packet. This function is called from 460 // loadPacketStandardFromSingleColumnTwoRows(). This code path is exercised 461 // only when the packet type supports masked load and when the partial packet 462 // load is available in the TensorEvaluator. 
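  // Illustration (not from the original source): with packetSize == 4 and a
  // packet whose last element falls into the next row, the first partial
  // packet uses span = {0, 2} (three elements of the current row, fourth lane
  // zero) and the second partial packet supplies the remaining lane from the
  // next row.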
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard(
      Index planeIndex, Index rowIndex, Index colIndex, Index otherIndex,
      Index patchId, const Index span[], const Index patchOffsets[],
      Index colOffset, Index rowOffset) const {
    const Index inputCol = colIndex + colOffset;
    const Index inputRow = rowIndex + rowOffset;
    const Index planeOffsets[2] = {
        patchOffsets[0] - colOffset * m_colStride - rowOffset * m_rowStride,
        patchOffsets[1] - colOffset * m_colStride - rowOffset * m_rowStride};
    const Index inputPlanes[2] = {planeIndex + planeOffsets[0],
                                  planeIndex + planeOffsets[1]};

    if (inputRow >= m_inputRows || inputRow < 0 || inputCol >= m_inputCols ||
        inputCol < 0 || inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) {
      // Partial packet is all zeros
      return internal::pset1<Packet>(Scalar(0));
    } else if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
      // From inputIndex-span[0], we need to load elements starting from index
      // span[0] all the way up to (and including) span[1].
      const Index depth = patchId - patchOffsets[0] * patchDepth();
      const Index inputIndex = depth + inputPlanes[0] * m_planeInputStride +
                               inputRow * m_rowInputStride +
                               inputCol * m_colInputStride + otherIndex;
      return m_impl.template partialPacket<Packet>(
          inputIndex - span[0], mask<Packet>(span[0], span[1] + 1));
    } else {
      // Using slow path for this partial packet.
      // We need to load elements starting from index span[0] all the way up to
      // (and including) span[1]. We split this load into 3 parts:
      //   0 : span[0]-1            - Zeros will be loaded for these indices
      //   span[0] : span[1]        - Elements will be loaded here for these indices
      //   span[1]+1 : packetSize-1 - Zeros will be loaded for these indices
      const Index packetSize = internal::unpacket_traits<Packet>::size;
      EIGEN_ALIGN_MAX
      std::remove_const_t<Scalar> values[packetSize];
      for (int i = 0; i < span[0]; ++i) values[i] = Scalar(0);
      for (int i = span[0]; i < span[1] + 1; ++i)
        values[i] = loadCoeff(patchId - span[0] + i, planeIndex, rowIndex,
                              colIndex, otherIndex);
      for (int i = span[1] + 1; i < packetSize; ++i) values[i] = Scalar(0);
      return internal::pload<Packet>(values);
    }
  }

  // Helper function to load a packet that is split across two rows (but single
  // column). If required, this function is called from loadPacketStandard()
  // when the packet type supports masked load and when the partial packet load
  // is available in the TensorEvaluator.
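  // The split point is chosen as the last element of the packet that still
  // belongs to the first row; the two partial packets are then combined with a
  // bit-wise OR, which is exact because every lane is non-zero in at most one
  // of them.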
512 EIGEN_DEVICE_FUNC loadPacketStandardFromSingleColumnTwoRows(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex,const Index patchOffsets[],const Index colOffsets[],const Index rowOffsets[])513 EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumnTwoRows( 514 Index patchId, Index planeIndex, Index rowIndex, Index colIndex, 515 Index otherIndex, const Index patchOffsets[], const Index colOffsets[], 516 const Index rowOffsets[]) const { 517 eigen_assert(colOffsets[1] == colOffsets[0] && 518 rowOffsets[1] == rowOffsets[0] + 1); 519 const Index packetSize = internal::unpacket_traits<Packet>::size; 520 521 // Packet to load will be split into 2 parts where each part spans a single 522 // row and both the parts span the same column. 523 // First determine where to split. 524 const Index patchIdSplit = 525 (((rowOffsets[1] * m_rowStride) + (colOffsets[0] * m_colStride)) * 526 m_patch_depth) - 527 1; 528 const Index patchOffsetSplit = patchIdSplit / m_fastDimZero; 529 530 // patchIds[i]: patchId corresponding to partial packet i 531 // spans[i]: Start and end indices corresponding to the elements 532 // to be loaded for partial packet i 533 // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i 534 const Index patchIds[2] = {patchId, patchIdSplit + 1}; 535 const Index spans[2][2] = {{0, patchIdSplit - patchId}, 536 {patchIdSplit - patchId + 1, packetSize - 1}}; 537 const Index patchOffsets2Cols[2][2] = { 538 {patchOffsets[0], patchOffsetSplit}, 539 {patchOffsetSplit + 1, patchOffsets[1]}}; 540 541 // Load partial packets and do bit-wise OR to generate required packet 542 return internal::por<Packet>( 543 loadPartialPacketStandard(planeIndex, rowIndex, colIndex, otherIndex, 544 patchIds[0], spans[0], patchOffsets2Cols[0], 545 colOffsets[0], rowOffsets[0]), 546 loadPartialPacketStandard(planeIndex, rowIndex, colIndex, otherIndex, 547 patchIds[1], spans[1], patchOffsets2Cols[1], 548 colOffsets[1], rowOffsets[1])); 549 } 550 551 // Helper function to load a packet that is present in a single column and 552 // row. If required, this function is called from loadPacketStandard(). 
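  // Note that even within a single column and row the packet may still cross
  // a plane boundary: a direct unaligned packet load is used only when both
  // the first and the last element fall into valid planes, otherwise we fall
  // back to the element-by-element packetWithPossibleZero() path.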
553 EIGEN_DEVICE_FUNC loadPacketStandardFromSingleColumnSingleRow(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex,const Index patchOffsets[],const Index colOffsets[],const Index rowOffsets[],const Index inputCols[],const Index inputRows[])554 EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumnSingleRow( 555 Index patchId, Index planeIndex, Index rowIndex, Index colIndex, 556 Index otherIndex, const Index patchOffsets[], const Index colOffsets[], 557 const Index rowOffsets[], const Index inputCols[], 558 const Index inputRows[]) const { 559 eigen_assert(colOffsets[1] == colOffsets[0] && 560 rowOffsets[1] == rowOffsets[0]); 561 const Index planeOffsets[2] = { 562 patchOffsets[0] - colOffsets[0] * m_colStride - 563 rowOffsets[0] * m_rowStride, 564 patchOffsets[1] - colOffsets[1] * m_colStride - 565 rowOffsets[1] * m_rowStride}; 566 eigen_assert(planeOffsets[0] <= planeOffsets[1]); 567 const Index inputPlanes[2] = {planeIndex + planeOffsets[0], 568 planeIndex + planeOffsets[1]}; 569 570 if (inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) { 571 return internal::pset1<Packet>(Scalar(0)); 572 } 573 if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) { 574 const Index depth = patchId - patchOffsets[0] * patchDepth(); 575 const Index inputIndex = depth + inputPlanes[0] * m_planeInputStride + 576 inputRows[0] * m_rowInputStride + 577 inputCols[0] * m_colInputStride + otherIndex; 578 return m_impl.template packet<Unaligned>(inputIndex); 579 } 580 return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, 581 otherIndex); 582 } 583 584 // Load standard packet from a patch specified by the "within patch offset" 585 // (patchId) and the precomputed indices of the first element of the patch. 586 // This function will be called if partial packet loading is not available 587 // for the TensorEvaluator or if the packet type does not support masked 588 // load. 589 template <typename PacketT, typename TensorEvaluatorT> 590 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if< 591 !TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, 592 PacketT>::type loadPacketStandard(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)593 loadPacketStandard(Index patchId, Index planeIndex, Index rowIndex, 594 Index colIndex, Index otherIndex) const { 595 const Index packetSize = internal::unpacket_traits<Packet>::size; 596 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) 597 eigen_assert(patchId < 598 patchDepth() * patchPlanes() * patchRows() * patchCols()); 599 eigen_assert(!nonStandardPatches()); 600 601 if ((patchDepth() % packetSize) == 0) { 602 return loadPacketFast(patchId, planeIndex, rowIndex, colIndex, 603 otherIndex); 604 } else { 605 // Offsets and input calculation here are identical to 606 // loadCoeffStandard(...), but repeated twice. 
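      // "Repeated twice" means the offsets are computed both for the first
      // element of the packet (patchId) and for the last one
      // (patchId + packetSize - 1), so we can tell whether the packet stays
      // within a single column, row and plane.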
607 608 const Index patchOffsets[2] = { 609 patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero}; 610 611 const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride, 612 patchOffsets[1] / m_fastColStride}; 613 eigen_assert(colOffsets[0] <= colOffsets[1]); 614 615 const Index inputCols[2] = {colIndex + colOffsets[0], 616 colIndex + colOffsets[1]}; 617 if (inputCols[0] >= m_inputCols || inputCols[1] < 0) { 618 return internal::pset1<Packet>(Scalar(0)); 619 } 620 621 if (inputCols[0] == inputCols[1]) { 622 const Index rowOffsets[2] = { 623 (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride, 624 (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride}; 625 eigen_assert(rowOffsets[0] <= rowOffsets[1]); 626 const Index inputRows[2] = {rowIndex + rowOffsets[0], 627 rowIndex + rowOffsets[1]}; 628 629 if (inputRows[0] >= m_inputRows || inputRows[1] < 0) { 630 return internal::pset1<Packet>(Scalar(0)); 631 } 632 633 if (inputRows[0] == inputRows[1]) { 634 return loadPacketStandardFromSingleColumnSingleRow( 635 patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets, 636 colOffsets, rowOffsets, inputCols, inputRows); 637 } 638 } 639 } 640 641 return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, 642 otherIndex); 643 } 644 645 // Load standard packet from a patch specified by the "within patch offset" 646 // (patchId) and the precomputed indices of the first element of the patch. 647 // This function will be called if partial packet loading is available for 648 // the TensorEvaluator and if the packet type supports masked load. 649 // The only difference between this and the other case is that if the packet 650 // to load is split across two rows (but in same column), then in this case 651 // instead of going to the slow (element-by-element) load, we load two packets 652 // - each containing elements from one of the rows (rest of the elements of 653 // the packets are zeroes), and then combine these two packets to generate the 654 // required packet. The idea is to enable fast load (if possible) of these 655 // 'partial' packets. 656 template <typename PacketT, typename TensorEvaluatorT> 657 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if< 658 TensorEvaluatorHasPartialPacket<TensorEvaluatorT, PacketT, Index>::value, 659 PacketT>::type loadPacketStandard(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)660 loadPacketStandard(Index patchId, Index planeIndex, Index rowIndex, 661 Index colIndex, Index otherIndex) const { 662 const Index packetSize = internal::unpacket_traits<Packet>::size; 663 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) 664 eigen_assert(patchId < 665 patchDepth() * patchPlanes() * patchRows() * patchCols()); 666 eigen_assert(!nonStandardPatches()); 667 668 if ((patchDepth() % packetSize) == 0) { 669 return loadPacketFast(patchId, planeIndex, rowIndex, colIndex, 670 otherIndex); 671 } else { 672 // Offsets and input calculation here are identical to 673 // loadCoeffStandard(...), but repeated twice. 
674 675 const Index patchOffsets[2] = { 676 patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero}; 677 678 const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride, 679 patchOffsets[1] / m_fastColStride}; 680 eigen_assert(colOffsets[0] <= colOffsets[1]); 681 682 const Index inputCols[2] = {colIndex + colOffsets[0], 683 colIndex + colOffsets[1]}; 684 if (inputCols[0] >= m_inputCols || inputCols[1] < 0) { 685 return internal::pset1<Packet>(Scalar(0)); 686 } 687 688 if (inputCols[0] == inputCols[1]) { 689 const Index rowOffsets[2] = { 690 (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride, 691 (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride}; 692 eigen_assert(rowOffsets[0] <= rowOffsets[1]); 693 const Index inputRows[2] = {rowIndex + rowOffsets[0], 694 rowIndex + rowOffsets[1]}; 695 696 if (inputRows[0] >= m_inputRows || inputRows[1] < 0) { 697 return internal::pset1<Packet>(Scalar(0)); 698 } 699 700 if (inputRows[0] == inputRows[1]) { 701 return loadPacketStandardFromSingleColumnSingleRow( 702 patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets, 703 colOffsets, rowOffsets, inputCols, inputRows); 704 } 705 if (inputRows[0] + 1 == inputRows[1]) { 706 return loadPacketStandardFromSingleColumnTwoRows( 707 patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets, 708 colOffsets, rowOffsets); 709 } 710 } 711 } 712 713 return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, 714 otherIndex); 715 } 716 717 EIGEN_DEVICE_FUNC loadPacketFast(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)718 EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index planeIndex, 719 Index rowIndex, Index colIndex, 720 Index otherIndex) const { 721 const Index packetSize = internal::unpacket_traits<Packet>::size; 722 EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) 723 eigen_assert(patchId < 724 patchDepth() * patchPlanes() * patchRows() * patchCols()); 725 726 eigen_assert(!nonStandardPatches()); 727 eigen_assert((patchDepth() % packetSize) == 0); 728 729 // Find the offset of the element wrt the location of the first element. 
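    // Here the caller guarantees that the packet does not cross a depth run
    // (see the assertion below), so all packetSize elements share the same
    // plane, row and column and are contiguous in the input tensor.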
730 const Index patchOffset = patchId / m_fastDimZero; 731 eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset); 732 733 const Index colOffset = patchOffset / m_fastColStride; 734 const Index rowOffset = 735 (patchOffset - colOffset * m_colStride) / m_fastRowStride; 736 const Index planeOffset = 737 patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; 738 739 const Index inputCol = colIndex + colOffset; 740 const Index inputRow = rowIndex + rowOffset; 741 const Index inputPlane = planeIndex + planeOffset; 742 743 if (inputCol < 0 || inputRow < 0 || inputPlane < 0 || 744 inputCol >= m_inputCols || inputRow >= m_inputRows || 745 inputPlane >= m_inputPlanes) { 746 return internal::pset1<Packet>(Scalar(0)); 747 } 748 749 const Index depth = patchId - patchOffset * patchDepth(); 750 const Index inputIndex = depth + inputPlane * m_planeInputStride + 751 inputRow * m_rowInputStride + 752 inputCol * m_colInputStride + otherIndex; 753 return m_impl.template packet<Unaligned>(inputIndex); 754 } 755 756 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetWithPossibleZero(Index patchId,Index planeIndex,Index rowIndex,Index colIndex,Index otherIndex)757 packetWithPossibleZero(Index patchId, Index planeIndex, Index rowIndex, 758 Index colIndex, Index otherIndex) const { 759 const int packetSize = internal::unpacket_traits<Packet>::size; 760 EIGEN_ALIGN_MAX 761 std::remove_const_t<Scalar> values[packetSize]; 762 for (int i = 0; i < packetSize; ++i) { 763 values[i] = 764 loadCoeff(patchId + i, planeIndex, rowIndex, colIndex, otherIndex); 765 } 766 Packet rslt = internal::pload<Packet>(values); 767 return rslt; 768 } 769 770 // Precompute the indices (plane, row, col, other) of the first element of 771 // the given patch index, within the output tensor of the TensorVolumePatchOp. computeBaseIndices(Index patchIndex,Index & planeIndex,Index & rowIndex,Index & colIndex,Index & otherIndex)772 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices( 773 Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex, 774 Index& otherIndex) const { 775 const size_t NumInputDims = array_size< 776 typename TensorEvaluator<ArgType, Device>::Dimensions>::value; 777 778 // Check if patchIndex might contain batch and other dimensions. 779 otherIndex = (NumInputDims == 4) ? 0 : patchIndex / m_fastNumPatches; 780 781 // Compute index of the patch within the batch (and other dimensions). 782 const Index patch3DIndex = (NumInputDims == 4) 783 ? patchIndex 784 : (patchIndex - otherIndex * m_num_patches); 785 786 otherIndex *= m_patchInputStride; 787 788 colIndex = patch3DIndex / m_fastOutputPlanesRows; 789 rowIndex = 790 (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes; 791 planeIndex = 792 patch3DIndex - (colIndex * m_outputRows + rowIndex) * m_outputPlanes; 793 794 colIndex = colIndex * m_col_strides - m_colPaddingLeft; 795 rowIndex = rowIndex * m_row_strides - m_rowPaddingTop; 796 planeIndex = planeIndex * m_plane_strides - m_planePaddingTop; 797 } 798 799 Index m_patch_depth; // number of channels in the patch 800 Index m_patch_planes; // number of planes in the patch 801 Index m_patch_rows; // number of rows in the patch 802 Index m_patch_cols; // number of columns in the patch 803 Index m_num_patches; // number of patches to extract 804 805 // Strides for navigating through the single patch. 
806 Index m_patch_plane_stride; 807 Index m_patch_row_stride; 808 Index m_patch_col_stride; 809 810 // Strides for the output tensor (depth is not the part of the stride). 811 Index m_rowStride; 812 Index m_colStride; 813 Index m_patchStride; 814 Index m_otherStride; 815 816 Index m_planeInputStride; // Plane stride in the input tensor 817 Index m_rowInputStride; // Row stride in the input tensor 818 Index m_colInputStride; // Col stride in the input tensor 819 Index m_patchInputStride; // Patch stride in the input tensor 820 Index m_otherInputStride; 821 822 Index m_inputDepth; // Depth of the input tensor 823 Index m_inputPlanes; // Number of planes in the input tensor 824 Index m_inputRows; // Number of rows in the input tensor 825 Index m_inputCols; // Number of cols in the input tensor 826 827 Index m_outputPlanes; // Number of output planes 828 Index m_outputRows; // Number of output rows 829 Index m_outputCols; // Number of output cols 830 Index m_outputPlanesRows; // Cached outputPlanes * outputRows. 831 832 Index m_plane_strides; // User specified plane stride 833 Index m_row_strides; // User specified row stride 834 Index m_col_strides; // User specified col stride 835 836 // User specified plane/row/col atrous convolution strides. 837 Index m_in_plane_strides; 838 Index m_in_row_strides; 839 Index m_in_col_strides; 840 841 // User specified plane/row/col inflation strides in the image patch. 842 Index m_patch_plane_inflate_strides; 843 Index m_patch_row_inflate_strides; 844 Index m_patch_col_inflate_strides; 845 846 Index m_planePaddingTop; // Plane padding 847 Index m_rowPaddingTop; // Row padding 848 Index m_colPaddingLeft; // Column padding 849 850 // Fast representation of various divisors. 851 internal::TensorIntDivisor<Index> m_fastNumPatches; 852 853 internal::TensorIntDivisor<Index> m_fastPatchPlaneStride; 854 internal::TensorIntDivisor<Index> m_fastPatchRowStride; 855 internal::TensorIntDivisor<Index> m_fastPatchColStride; 856 857 internal::TensorIntDivisor<Index> m_fastInputPlaneStride; 858 internal::TensorIntDivisor<Index> m_fastInputRowStride; 859 internal::TensorIntDivisor<Index> m_fastInputColStride; 860 861 internal::TensorIntDivisor<Index> m_fastRowStride; 862 internal::TensorIntDivisor<Index> m_fastColStride; 863 864 internal::TensorIntDivisor<Index> m_fastDimZero; // aka output depth 865 internal::TensorIntDivisor<Index> m_fastOutputPlanes; 866 internal::TensorIntDivisor<Index> m_fastOutputRows; 867 internal::TensorIntDivisor<Index> m_fastOutputCols; 868 internal::TensorIntDivisor<Index> m_fastOutputPlanesRows; 869 870 const TensorEvaluator<ArgType, Device> m_impl; 871 }; 872 873 template <typename NewDimension, Index Planes, Index Rows, Index Cols, 874 typename ArgType, typename Device, typename Scalar, typename Index, 875 typename nocontract_t, typename contract_t, int Side, int packet_size, 876 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> 877 class TensorContractionSubMapper< 878 Scalar, Index, Side, 879 TensorEvaluator<const TensorReshapingOp<NewDimension, 880 const TensorVolumePatchOp< 881 Planes, Rows, Cols, ArgType> >, 882 Device>, 883 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 884 inner_dim_reordered, Alignment> { 885 public: 886 typedef typename packet_traits<Scalar>::type Packet; 887 typedef typename packet_traits<Scalar>::half HalfPacket; 888 889 typedef TensorContractionInputMapper< 890 Scalar, Index, Side, 891 TensorEvaluator<const TensorReshapingOp< 892 NewDimension, const TensorVolumePatchOp< 893 Planes, 
Rows, Cols, ArgType> >, 894 Device>, 895 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 896 inner_dim_reordered, Alignment> 897 ParentMapper; 898 typedef TensorContractionSubMapper< 899 Scalar, Index, Side, 900 TensorEvaluator<const TensorReshapingOp< 901 NewDimension, const TensorVolumePatchOp< 902 Planes, Rows, Cols, ArgType> >, 903 Device>, 904 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 905 inner_dim_reordered, Alignment> 906 Self; 907 typedef Self LinearMapper; 908 TensorContractionSubMapper(const ParentMapper & base_mapper,Index vert_offset,Index horiz_offset)909 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( 910 const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset) 911 : m_base_mapper(base_mapper), 912 m_depth_offset(vert_offset), 913 m_col_offset(horiz_offset) { 914 m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex, 915 m_colIndex, m_otherIndex); 916 } TensorContractionSubMapper(const Self & base_mapper,Index vert_offset,Index horiz_offset)917 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( 918 const Self& base_mapper, Index vert_offset, Index horiz_offset) 919 : m_base_mapper(base_mapper.m_base_mapper), 920 m_depth_offset(vert_offset + base_mapper.m_depth_offset), 921 m_col_offset(horiz_offset + base_mapper.m_col_offset) { 922 m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex, 923 m_colIndex, m_otherIndex); 924 } operator()925 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { 926 return m_base_mapper.loadCoeff(i + m_depth_offset, m_planeIndex, m_rowIndex, 927 m_colIndex, m_otherIndex); 928 } operator()929 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, 930 Index j) const { 931 return m_base_mapper(i + m_depth_offset, j + m_col_offset); 932 } 933 loadPacket(Index i)934 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { 935 return m_base_mapper.loadPacket(i + m_depth_offset, m_planeIndex, 936 m_rowIndex, m_colIndex, m_otherIndex); 937 } loadPacket(Index i,Index j)938 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, 939 Index j) const { 940 return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset, 941 j + m_col_offset); 942 } 943 944 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar loadCoeffStandard(Index i)945 loadCoeffStandard(Index i) const { 946 return m_base_mapper.loadCoeffStandard( 947 i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex); 948 } 949 loadPacketFast(Index i)950 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const { 951 return m_base_mapper.loadPacketFast(i + m_depth_offset, m_planeIndex, 952 m_rowIndex, m_colIndex, m_otherIndex); 953 } 954 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index i)955 loadPacketStandard(Index i) const { 956 typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT; 957 return m_base_mapper.template loadPacketStandard<Packet, TensorEvaluatorT>( 958 i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex); 959 } 960 template <typename Packet> aligned(Index)961 EIGEN_DEVICE_FUNC bool aligned(Index) const { 962 return false; 963 } 964 965 EIGEN_DEVICE_FUNC nonStandardPatches()966 EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { 967 return m_base_mapper.nonStandardPatches(); 968 } 969 970 // Max(Col|Row|Plane|Depth): compute the upper limit for the column, row, 971 // plane and depth index respectively that fits into the peeled_k elements 972 // 
starting at m_depth_offset. 973 974 EIGEN_DEVICE_FUNC maxCol(const Index peeled_k)975 EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const { 976 const Index max_col = 977 fastPatchColStride().divide(m_depth_offset + peeled_k); 978 return std::min<Index>(1 + max_col, patchCols()); 979 } 980 981 EIGEN_DEVICE_FUNC maxRow(const Index peeled_k,const Index col)982 EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k, 983 const Index col) const { 984 const Index max_row = fastPatchRowStride().divide( 985 m_depth_offset + peeled_k - col * patchColStride()); 986 return std::min<Index>(1 + max_row, patchRows()); 987 } 988 989 EIGEN_DEVICE_FUNC maxPlane(const Index peeled_k,const Index col,const Index row)990 EIGEN_ALWAYS_INLINE Index maxPlane(const Index peeled_k, const Index col, 991 const Index row) const { 992 const Index max_plane = fastPatchPlaneStride().divide( 993 m_depth_offset + peeled_k - col * patchColStride() - 994 row * patchRowStride()); 995 return std::min<Index>(1 + max_plane, patchPlanes()); 996 } 997 998 // MaxDepth uses only the remaining number of elements in the peeled_k. 999 EIGEN_DEVICE_FUNC maxDepth(const Index num_elements,const Index start_depth)1000 EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements, 1001 const Index start_depth) const { 1002 return std::min<Index>(start_depth + num_elements, patchDepth()); 1003 } 1004 1005 // Every register matters in this code, so sometimes to prevent register 1006 // spilling, instead of the variable that you would expect to see, we use 1007 // another one, that is guaranteed to have the same value. E.g. patch depth is 1008 // always the same as input depth, and it's also the same as input plane 1009 // stride. Bunch of other parameters have similar relations. 1010 1011 typedef internal::TensorIntDivisor<Index> IndexDivisor; 1012 1013 EIGEN_DEVICE_FUNC patchDepth()1014 EIGEN_ALWAYS_INLINE Index patchDepth() const { 1015 eigen_assert(m_base_mapper.m_patch_depth == 1016 m_base_mapper.m_planeInputStride && 1017 "Patch depth must be equal to plane input stride."); 1018 return m_base_mapper.m_planeInputStride; 1019 } 1020 1021 EIGEN_DEVICE_FUNC patchPlanes()1022 EIGEN_ALWAYS_INLINE Index patchPlanes() const { 1023 eigen_assert(m_base_mapper.m_patch_planes == m_base_mapper.m_rowStride && 1024 "Patch planes must be equal to row stride."); 1025 return m_base_mapper.m_rowStride; 1026 } 1027 EIGEN_DEVICE_FUNC patchRows()1028 EIGEN_ALWAYS_INLINE Index patchRows() const { 1029 return m_base_mapper.m_patch_rows; 1030 } 1031 EIGEN_DEVICE_FUNC patchCols()1032 EIGEN_ALWAYS_INLINE Index patchCols() const { 1033 return m_base_mapper.m_patch_cols; 1034 } 1035 1036 EIGEN_DEVICE_FUNC patchPlaneStride()1037 EIGEN_ALWAYS_INLINE Index patchPlaneStride() const { 1038 eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride && 1039 "Patch depth must be equal to patch plane stride."); 1040 return patchDepth(); 1041 } 1042 EIGEN_DEVICE_FUNC patchRowStride()1043 EIGEN_ALWAYS_INLINE Index patchRowStride() const { 1044 return m_base_mapper.m_patch_row_stride; 1045 } 1046 EIGEN_DEVICE_FUNC patchColStride()1047 EIGEN_ALWAYS_INLINE Index patchColStride() const { 1048 return m_base_mapper.m_patch_col_stride; 1049 } 1050 1051 EIGEN_DEVICE_FUNC fastPatchPlaneStride()1052 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchPlaneStride() const { 1053 eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride && 1054 "Patch depth must be equal to patch plane stride."); 1055 return m_base_mapper.m_fastDimZero; // patch_depth 1056 } 1057 EIGEN_DEVICE_FUNC 
fastPatchRowStride()1058 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const { 1059 return m_base_mapper.m_fastPatchRowStride; 1060 } 1061 EIGEN_DEVICE_FUNC fastPatchColStride()1062 EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const { 1063 return m_base_mapper.m_fastPatchColStride; 1064 } 1065 1066 EIGEN_DEVICE_FUNC packetNoPadding(const Index depth,const Index baseIndex)1067 EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, 1068 const Index baseIndex) const { 1069 const Index inputIndex = depth + baseIndex; 1070 return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex); 1071 } 1072 EIGEN_DEVICE_FUNC coeffNoPadding(const Index depth,const Index baseIndex)1073 EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth, 1074 const Index baseIndex) const { 1075 const Index inputIndex = depth + baseIndex; 1076 return m_base_mapper.m_impl.coeff(inputIndex); 1077 } 1078 1079 EIGEN_DEVICE_FUNC padPlane(const Index plane)1080 EIGEN_ALWAYS_INLINE bool padPlane(const Index plane) const { 1081 const Index p = m_planeIndex + plane; 1082 return p < 0 || p >= m_base_mapper.m_inputPlanes; 1083 } 1084 EIGEN_DEVICE_FUNC padRow(const Index row)1085 EIGEN_ALWAYS_INLINE bool padRow(const Index row) const { 1086 const Index r = m_rowIndex + row; 1087 return r < 0 || r >= m_base_mapper.m_inputRows; 1088 } 1089 EIGEN_DEVICE_FUNC padCol(const Index col)1090 EIGEN_ALWAYS_INLINE bool padCol(const Index col) const { 1091 const Index c = m_colIndex + col; 1092 return c < 0 || c >= m_base_mapper.m_inputCols; 1093 } 1094 EIGEN_DEVICE_FUNC baseIndex(const Index plane,const Index row,const Index col)1095 EIGEN_ALWAYS_INLINE Index baseIndex(const Index plane, const Index row, 1096 const Index col) const { 1097 const Index p = m_planeIndex + plane; 1098 const Index r = m_rowIndex + row; 1099 const Index c = m_colIndex + col; 1100 return p * m_base_mapper.m_planeInputStride + 1101 r * m_base_mapper.m_rowInputStride + 1102 c * m_base_mapper.m_colInputStride + m_otherIndex; 1103 } 1104 1105 EIGEN_DEVICE_FUNC planeOffset()1106 EIGEN_ALWAYS_INLINE Index planeOffset() const { 1107 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; 1108 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; 1109 const Index rowOffset = 1110 (patchOffset - colOffset * m_base_mapper.m_colStride) / 1111 m_base_mapper.m_fastRowStride; 1112 const Index planeOffset = patchOffset - 1113 colOffset * m_base_mapper.m_colStride - 1114 rowOffset * m_base_mapper.m_rowStride; 1115 return planeOffset; 1116 } 1117 1118 EIGEN_DEVICE_FUNC rowOffset()1119 EIGEN_ALWAYS_INLINE Index rowOffset() const { 1120 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; 1121 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; 1122 const Index rowOffset = 1123 (patchOffset - colOffset * m_base_mapper.m_colStride) / 1124 m_base_mapper.m_fastRowStride; 1125 return rowOffset; 1126 } 1127 1128 EIGEN_DEVICE_FUNC colOffset()1129 EIGEN_ALWAYS_INLINE Index colOffset() const { 1130 const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; 1131 const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; 1132 return colOffset; 1133 } 1134 1135 EIGEN_DEVICE_FUNC depthOffset()1136 EIGEN_ALWAYS_INLINE Index depthOffset() const { 1137 return m_depth_offset % patchDepth(); 1138 } 1139 1140 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i,Index j)1141 getLinearMapper(Index i, Index j) const { 1142 return LinearMapper(m_base_mapper, i + 
m_depth_offset, j + m_col_offset); 1143 } 1144 1145 private: 1146 const ParentMapper m_base_mapper; // Keeping a copy instead of a reference 1147 // performs better in benchmarks. 1148 1149 Index m_depth_offset; // First row in the input matrix 1150 Index m_col_offset; // First col in the input matrix 1151 1152 // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base 1153 // indices for the first element in a patch specified by col_offset 1154 // (see computeBaseIndices(...) for details). 1155 Index m_planeIndex; 1156 Index m_rowIndex; 1157 Index m_colIndex; 1158 Index m_otherIndex; 1159 }; 1160 1161 // Arrange a block of the right input matrix (in our case it's always a "virtual 1162 // matrix" constructed from extracted volume patches) in contiguous memory. 1163 // 1164 // Given column major input (A0 beside A1 in memory): 1165 // A0 B0 C0 D0 E0 F0 G0 H0 ... Z0 1166 // A1 B1 C1 D1 E1 F1 G1 H1 ... Z1 1167 // A2 B2 C2 D2 E2 F2 G2 H2 ... Z2 1168 // A3 B3 C3 D3 E3 F3 G3 H3 ... Z3 1169 // A4 B4 C4 D4 E4 F4 G4 H4 ... Z4 1170 // A5 B5 C5 D5 E5 F5 G5 H5 ... Z5 1171 // A6 B6 C6 D6 E6 F6 G6 H6 ... Z6 1172 // A7 B7 C7 D7 E7 F7 G7 H7 ... Z7 1173 // A8 ... 1174 // ... 1175 // 1176 // *) A, B, C, ... - patches extracted from the original input. 1177 // *) A0, A1, A2 ... - values from the same patch at different offsets. 1178 // 1179 // The traversal (packed rhs memory) order (B0 besides A0 in memory): 1180 // A0 B0 C0 D0 A1 B1 C1 D1 ... 1181 // E0 F0 G0 H0 E1 F1 G1 H1 ... 1182 // ... 1183 // Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4) 1184 // 1185 // This traversal order must be the same as in default gemm_pack_rhs defined in 1186 // GeneralBlockPanelKernel.h. 1187 // 1188 // *) nr - number of registers along the 'n' dimension. 1189 // See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix 1190 // Multiplication" paper. 1191 // 1192 // TODO(ezhulenev): Add support for squeezing reads along two innermost 1193 // dimensions (see eigen_spatial_convolutions). 
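// In the packing routine below, columns are processed in groups of four
// (nr == 4). For "standard" patches whose depth is a multiple of the packet
// size, the peeled part of every column group is traversed patch-column by
// patch-row by patch-plane, loading full packets along the depth dimension and
// transposing 4-packet blocks into the destination buffer. Otherwise it falls
// back to loadPacketStandard() for the peeled part and to coefficient-wise
// loads for the remainder.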
1194 template <typename NewDimension, Index Planes, Index Rows, Index Cols, 1195 typename ArgType, typename Device, typename Scalar, typename Index, 1196 typename nocontract_t, typename contract_t, int packet_size, 1197 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, 1198 int nr> 1199 struct gemm_pack_rhs< 1200 Scalar, Index, 1201 TensorContractionSubMapper< 1202 Scalar, Index, Rhs, 1203 TensorEvaluator<const TensorReshapingOp< 1204 NewDimension, const TensorVolumePatchOp< 1205 Planes, Rows, Cols, ArgType> >, 1206 Device>, 1207 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 1208 inner_dim_reordered, Alignment>, 1209 nr, ColMajor, false, false> { 1210 typedef TensorContractionSubMapper< 1211 Scalar, Index, Rhs, 1212 TensorEvaluator<const TensorReshapingOp< 1213 NewDimension, const TensorVolumePatchOp< 1214 Planes, Rows, Cols, ArgType> >, 1215 Device>, 1216 nocontract_t, contract_t, packet_size, inner_dim_contiguous, 1217 inner_dim_reordered, Alignment> 1218 SubMapper; 1219 1220 typedef SubMapper DataMapper; 1221 typedef typename packet_traits<Scalar>::type Packet; 1222 1223 EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); 1224 1225 EIGEN_DEVICE_FUNC 1226 EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, 1227 Index depth, Index cols, Index stride = 0, 1228 Index offset = 0) const { 1229 eigen_assert(stride == 0); 1230 eigen_assert(offset == 0); 1231 1232 const Index packet_cols4 = (cols / 4) * 4; 1233 const Index peeled_k = (depth / packet_size) * packet_size; 1234 const bool non_standard_patches = rhs.nonStandardPatches(); 1235 1236 for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { 1237 const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); 1238 const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); 1239 const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); 1240 const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); 1241 1242 Index k = 0; 1243 if ((packet_size % 4) == 0 && !non_standard_patches) { 1244 // FAST PATH: 1245 // Iterate over patch columns, rows and planes if we know that a single 1246 // packet do not span across multiple planes, rows or columns. 1247 if ((rhs.patchDepth() % packet_size) == 0) { 1248 const Index start_col = rhs.colOffset(); 1249 const Index max_col = rhs.maxCol(peeled_k); 1250 1251 for (Index c = start_col; c < max_col; ++c) { 1252 eigen_assert(k <= peeled_k); 1253 1254 const Index start_row = (c == start_col) ? rhs.rowOffset() : 0; 1255 const Index max_row = rhs.maxRow(peeled_k, c); 1256 1257 const bool pad_col0 = dm0.padCol(c); 1258 const bool pad_col1 = dm1.padCol(c); 1259 const bool pad_col2 = dm2.padCol(c); 1260 const bool pad_col3 = dm3.padCol(c); 1261 1262 for (Index r = start_row; r < max_row; ++r) { 1263 eigen_assert(k <= peeled_k); 1264 1265 const Index start_plane = ((c == start_col) && (r == start_row)) 1266 ? 
rhs.planeOffset() 1267 : 0; 1268 const Index max_plane = rhs.maxPlane(peeled_k, c, r); 1269 1270 const bool pad_row0 = pad_col0 || dm0.padRow(r); 1271 const bool pad_row1 = pad_col1 || dm1.padRow(r); 1272 const bool pad_row2 = pad_col2 || dm2.padRow(r); 1273 const bool pad_row3 = pad_col3 || dm3.padRow(r); 1274 1275 for (Index p = start_plane; p < max_plane; ++p) { 1276 eigen_assert(k <= peeled_k); 1277 1278 const bool pad0 = pad_row0 || dm0.padPlane(p); 1279 const bool pad1 = pad_row1 || dm1.padPlane(p); 1280 const bool pad2 = pad_row2 || dm2.padPlane(p); 1281 const bool pad3 = pad_row3 || dm3.padPlane(p); 1282 1283 const Index idx0 = dm0.baseIndex(p, r, c); 1284 const Index idx1 = dm1.baseIndex(p, r, c); 1285 const Index idx2 = dm2.baseIndex(p, r, c); 1286 const Index idx3 = dm3.baseIndex(p, r, c); 1287 1288 const Index start_depth = 1289 ((c == start_col) && (r == start_row) && (p == start_plane)) 1290 ? rhs.depthOffset() 1291 : 0; 1292 const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); 1293 eigen_assert((max_depth - start_depth) % packet_size == 0); 1294 1295 for (Index d = start_depth; d < max_depth; d += packet_size) { 1296 eigen_assert(k < peeled_k); 1297 PacketBlock<Packet, 4> kernel; 1298 kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) 1299 : rhs.packetNoPadding(d, idx0); 1300 kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) 1301 : rhs.packetNoPadding(d, idx1); 1302 kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0)) 1303 : rhs.packetNoPadding(d, idx2); 1304 kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0)) 1305 : rhs.packetNoPadding(d, idx3); 1306 ptranspose(kernel); 1307 pstoreu(block + 0 * packet_size, kernel.packet[0]); 1308 pstoreu(block + 1 * packet_size, kernel.packet[1]); 1309 pstoreu(block + 2 * packet_size, kernel.packet[2]); 1310 pstoreu(block + 3 * packet_size, kernel.packet[3]); 1311 block += 4 * packet_size; 1312 k += packet_size; 1313 } 1314 } 1315 } 1316 } 1317 1318 // The loop above should fill peeled_k elements. 1319 eigen_assert(peeled_k == k); 1320 1321 } else { 1322 // Packet can span multiple planes, rows or columns, so we have to go 1323 // though the slower "standard" path. 1324 for (; k < peeled_k; k += packet_size) { 1325 PacketBlock<Packet, 4> kernel; 1326 kernel.packet[0] = dm0.loadPacketStandard(k); 1327 kernel.packet[1] = dm1.loadPacketStandard(k); 1328 kernel.packet[2] = dm2.loadPacketStandard(k); 1329 kernel.packet[3] = dm3.loadPacketStandard(k); 1330 ptranspose(kernel); 1331 pstoreu(block + 0 * packet_size, kernel.packet[0]); 1332 pstoreu(block + 1 * packet_size, kernel.packet[1]); 1333 pstoreu(block + 2 * packet_size, kernel.packet[2]); 1334 pstoreu(block + 3 * packet_size, kernel.packet[3]); 1335 block += 4 * packet_size; 1336 } 1337 } 1338 } 1339 1340 // Copy the remaining coefficients of the column block after the peeled_k. 1341 if (!non_standard_patches) { 1342 for (; k < depth; k++) { 1343 block[0] = dm0.loadCoeffStandard(k); 1344 block[1] = dm1.loadCoeffStandard(k); 1345 block[2] = dm2.loadCoeffStandard(k); 1346 block[3] = dm3.loadCoeffStandard(k); 1347 block += 4; 1348 } 1349 } else { 1350 for (; k < depth; k++) { 1351 block[0] = dm0(k); 1352 block[1] = dm1(k); 1353 block[2] = dm2(k); 1354 block[3] = dm3(k); 1355 block += 4; 1356 } 1357 } 1358 } 1359 1360 // Copy the remaining columns one at a time (nr==1). 
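    // The remainder loop is ordered depth-first (k outer, j2 inner) on
    // AltiVec/VSX and column-first (j2 outer, k inner) on other platforms.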
#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
    // Remaining columns are handled differently for PPC.
    for (Index k = 0; k < depth; k++) {
      for (Index j2 = packet_cols4; j2 < cols; ++j2) {
        *block = rhs(k, j2);
        block += 1;
      }
    }
#else
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
#endif
  }
};

// Template specialization for packet_size = 2. We must special-case packet
// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
//
// TODO(ezhulenev): Add support for squeezing reads along the two innermost
// dimensions (see eigen_spatial_convolutions).
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<const TensorReshapingOp<
                            NewDimension, const TensorVolumePatchOp<
                                              Planes, Rows, Cols, ArgType> >,
                        Device>,
        nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;
  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const int packet_size = 2;

    const Index packet_cols4 = (cols / 4) * 4;
    const Index peeled_k = (depth / packet_size) * packet_size;
    const bool non_standard_patches = rhs.nonStandardPatches();

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      Index k = 0;
      if (!non_standard_patches) {
        // FAST PATH:
        // Iterate over patch columns, rows and planes when we know that a
        // single packet does not span multiple planes, rows or columns.
        if ((rhs.patchDepth() % packet_size) == 0) {
          const Index start_col = rhs.colOffset();
          const Index max_col = rhs.maxCol(peeled_k);

          for (Index c = start_col; c < max_col; ++c) {
            eigen_assert(k <= peeled_k);

            const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
            const Index max_row = rhs.maxRow(peeled_k, c);

            const bool pad_col0 = dm0.padCol(c);
            const bool pad_col1 = dm1.padCol(c);
            const bool pad_col2 = dm2.padCol(c);
            const bool pad_col3 = dm3.padCol(c);

            for (Index r = start_row; r < max_row; ++r) {
              eigen_assert(k <= peeled_k);

              const Index start_plane = ((c == start_col) && (r == start_row))
                                            ? rhs.planeOffset()
                                            : 0;
              const Index max_plane = rhs.maxPlane(peeled_k, c, r);

              const bool pad_row0 = dm0.padRow(r);
              const bool pad_row1 = dm1.padRow(r);
              const bool pad_row2 = dm2.padRow(r);
              const bool pad_row3 = dm3.padRow(r);

              for (Index p = start_plane; p < max_plane; ++p) {
                eigen_assert(k <= peeled_k);

                const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
                const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
                const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
                const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);

                const Index idx0 = dm0.baseIndex(p, r, c);
                const Index idx1 = dm1.baseIndex(p, r, c);
                const Index idx2 = dm2.baseIndex(p, r, c);
                const Index idx3 = dm3.baseIndex(p, r, c);

                const Index start_depth =
                    ((c == start_col) && (r == start_row) && (p == start_plane))
                        ? rhs.depthOffset()
                        : 0;
                const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
                eigen_assert((max_depth - start_depth) % packet_size == 0);

                for (Index d = start_depth; d < max_depth; d += packet_size) {
                  eigen_assert(k < peeled_k);
                  PacketBlock<Packet, 2> kernel0;
                  PacketBlock<Packet, 2> kernel1;
                  kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
                                           : rhs.packetNoPadding(d, idx0);
                  kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
                                           : rhs.packetNoPadding(d, idx1);
                  kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
                                           : rhs.packetNoPadding(d, idx2);
                  kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
                                           : rhs.packetNoPadding(d, idx3);
                  ptranspose(kernel0);
                  ptranspose(kernel1);
                  pstoreu(block + 0 * packet_size, kernel0.packet[0]);
                  pstoreu(block + 1 * packet_size, kernel1.packet[0]);
                  pstoreu(block + 2 * packet_size, kernel0.packet[1]);
                  pstoreu(block + 3 * packet_size, kernel1.packet[1]);
                  block += 4 * packet_size;
                  k += packet_size;
                }
              }
            }
          }

          // The loop above should fill peeled_k elements.
          eigen_assert(peeled_k == k);

        } else {
          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 2> kernel0;
            PacketBlock<Packet, 2> kernel1;
            kernel0.packet[0] = dm0.loadPacketStandard(k);
            kernel0.packet[1] = dm1.loadPacketStandard(k);
            kernel1.packet[0] = dm2.loadPacketStandard(k);
            kernel1.packet[1] = dm3.loadPacketStandard(k);
            ptranspose(kernel0);
            ptranspose(kernel1);
            pstoreu(block + 0 * packet_size, kernel0.packet[0]);
            pstoreu(block + 1 * packet_size, kernel1.packet[0]);
            pstoreu(block + 2 * packet_size, kernel0.packet[1]);
            pstoreu(block + 3 * packet_size, kernel1.packet[1]);
            block += 4 * packet_size;
          }
        }
      }

      // Copy the remaining coefficients of the column block after peeled_k.
      if (!rhs.nonStandardPatches()) {
        for (; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

// Special case for non-vectorized types such as float16 (packet_size = 1).
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar, typename Index,
          typename nocontract_t, typename contract_t, bool inner_dim_contiguous,
          bool inner_dim_reordered, int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<const TensorReshapingOp<
                            NewDimension, const TensorVolumePatchOp<
                                              Planes, Rows, Cols, ArgType> >,
                        Device>,
        nocontract_t, contract_t, /*packet_size*/ 1, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
      Alignment>
      SubMapper;
  typedef SubMapper DataMapper;

  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);

  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    const Index packet_cols4 = (cols / 4) * 4;

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      if (!rhs.nonStandardPatches()) {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (Index k = 0; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};
#endif  // !EIGEN_ALTIVEC_USE_CUSTOM_PACK

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
// Pack a block of the right input matrix (in our case it's always a "virtual
// matrix" constructed from extracted image patches) into a contiguous block in
// column-major storage order. Knowing the properties of the original patch op,
// we can do this more efficiently than the default gemm_pack_colmajor_block.
//
// TODO(ezhulenev): gemm_pack_colmajor_block for spatial convolutions supports
// squeezing reads along the two innermost dimensions; add it here if needed.
template <typename NewDimension, Index Planes, Index Rows, Index Cols,
          typename ArgType, typename Device, typename Scalar,
          typename StorageIndex, typename nocontract_t, typename contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
          int Alignment>
struct gemm_pack_colmajor_block<
    Scalar, StorageIndex,
    TensorContractionSubMapper<
        Scalar, StorageIndex, Rhs,
        TensorEvaluator<const TensorReshapingOp<
                            NewDimension, const TensorVolumePatchOp<
                                              Planes, Rows, Cols, ArgType> >,
                        Device>,
        nocontract_t, contract_t, packet_size, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    ColMajor> {
  typedef TensorContractionSubMapper<
      Scalar, StorageIndex, Rhs,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;

  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  EIGEN_DONT_INLINE
  void operator()(Scalar* block, const DataMapper& rhs, StorageIndex rows,
                  StorageIndex cols) {
    const bool standard_patches = !rhs.nonStandardPatches();

    if (standard_patches && rhs.patchDepth() % packet_size == 0) {
      packStandardPatches<true>(block, rhs, rows, cols);

    } else if (standard_patches) {
      packStandardPatches<false>(block, rhs, rows, cols);

    } else {
      // With non-standard patches we don't do any vectorized loads.
      // TODO(ezhulenev): It doesn't look like we should completely give up
      // on packets. Make this code path faster!
      for (StorageIndex col = 0; col < cols; ++col) {
        SubMapper lm = rhs.getLinearMapper(0, col);
        for (StorageIndex i = 0; i < rows; ++i) {
          *block = lm(i);
          ++block;
        }
      }
    }
  }

 private:
  // Pack standard volume patches:
  //
  // - patch_depth_is_multiple_of_packet_size=true: the patch depth is
  //   guaranteed to be a multiple of the packet size, so we can skip all
  //   non-vectorized loads and checks.
  //
  template <bool patch_depth_is_multiple_of_packet_size>
  EIGEN_ALWAYS_INLINE void packStandardPatches(Scalar* block,
                                               const DataMapper& rhs,
                                               StorageIndex rows,
                                               StorageIndex cols) {
    eigen_assert(!rhs.nonStandardPatches());

    // Use the name `peeled_k` for the number of rows that can be processed
    // with full packets, matching the gemm_pack_rhs implementations above.
    const Index peeled_k = (rows / packet_size) * packet_size;

    const Index start_col = rhs.colOffset();
    const Index max_col = rhs.maxCol(peeled_k);

    for (StorageIndex col = 0; col < cols; ++col) {
      SubMapper lm = rhs.getLinearMapper(0, col);

      Index k = 0;
      for (Index c = start_col; c < max_col; ++c) {
        eigen_assert(k <= peeled_k);

        const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
        const Index max_row = rhs.maxRow(peeled_k, c);
        const bool pad_col = lm.padCol(c);

        for (Index r = start_row; r < max_row; ++r) {
          eigen_assert(k <= peeled_k);

          const Index start_plane =
              ((c == start_col) && (r == start_row)) ? rhs.planeOffset() : 0;
          const Index max_plane = rhs.maxPlane(peeled_k, c, r);
          const bool pad_row = pad_col || lm.padRow(r);

          for (Index p = start_plane; p < max_plane; ++p) {
            eigen_assert(k <= peeled_k);

            const Index start_depth =
                ((c == start_col) && (r == start_row) && (p == start_plane))
                    ? rhs.depthOffset()
                    : 0;
            const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);

            const bool pad = pad_col || pad_row || lm.padPlane(p);
            const Index base_idx = lm.baseIndex(p, r, c);

            if (patch_depth_is_multiple_of_packet_size)
              eigen_assert((max_depth - start_depth) % packet_size == 0);

            // If patch depth is a multiple of packet size, it's guaranteed
            // that we can process all values in the depth dimension with
            // packets.
            const Index max_vectorized_depth =
                patch_depth_is_multiple_of_packet_size
                    ? max_depth
                    : max_depth - packet_size;

            Index d = start_depth;

            // 1. Process depth dimension with vectorized instructions.
            for (; d < max_vectorized_depth; d += packet_size) {
              eigen_assert(k < peeled_k);
              const Packet packet = pad ? pset1<Packet>(Scalar(0))
                                        : rhs.packetNoPadding(d, base_idx);
              internal::pstoreu(block, packet);
              block += packet_size;
              k += packet_size;
            }

            // 2. Finish with coefficients.
            if (!patch_depth_is_multiple_of_packet_size) {
              for (; d < max_depth; d++) {
                eigen_assert(k < peeled_k);
                *block = pad ? Scalar(0) : rhs.coeffNoPadding(d, base_idx);
                ++block;
                ++k;
              }
            }
          }
        }
      }

      // The loop above should fill peeled_k elements.
      eigen_assert(peeled_k == k);

      // Fill remaining elements using loadCoeffStandard.
      for (; k < rows; ++k) {
        *block = lm.loadCoeffStandard(k);
        ++block;
      }
    }
  }
};
#endif  // defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)

}  // namespace internal

/** CuboidConvolution
 * \ingroup CXX11_NeuralNetworks_Module
 *
 * \brief Applies a 3D convolution over a multichannel input voxel block.
 *
 * The input parameter is expected to be a tensor with a rank of 4 or more
 * (channels, depth, height, width, and optionally others).
 * The kernel parameter is expected to be a 5D tensor (filters, channels,
 * kernel_depth, kernel_height, kernel_width).
 * The result can be assigned to a tensor of rank equal to the rank of the
 * input. The dimensions of the result will be filters, depth, height, width
 * (and others if applicable).
 *
 * The input and kernel have to be in the same layout, and both row-major and
 * col-major are supported. The shapes given above are for col-major layout.
 * For row-major, all dimensions should be reversed.
 *
 * It is possible to swap the order of the depth, width, and height dimensions
 * provided that the same order is used in the input, the kernel, and the
 * output.
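 *
 * Example (illustrative sketch only; the tensor sizes below are hypothetical
 * and assume col-major layout with a trailing batch dimension):
 *
 *   // input:   channels=4, planes=8, rows=16, cols=16, batch=32
 *   Eigen::Tensor<float, 5> input(4, 8, 16, 16, 32);
 *   // filters: filters=7, channels=4, kernel planes/rows/cols = 3x3x3
 *   Eigen::Tensor<float, 5> filters(7, 4, 3, 3, 3);
 *   // With the default stride of 1 and PADDING_SAME the spatial dimensions
 *   // are preserved, so the result has shape (7, 8, 16, 16, 32).
 *   Eigen::Tensor<float, 5> output(7, 8, 16, 16, 32);
 *   output = Eigen::CuboidConvolution(input, filters);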
 */
template <typename Input, typename Kernel>
EIGEN_ALWAYS_INLINE static const std::conditional_t<
    internal::traits<Input>::Layout == ColMajor,
    TensorReshapingOp<
        const DSizes<typename internal::traits<Input>::Index,
                     internal::traits<Input>::NumDimensions>,
        const TensorContractionOp<
            const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const Kernel>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
                                          const Input> > > >,
    TensorReshapingOp<
        const DSizes<typename internal::traits<Input>::Index,
                     internal::traits<Input>::NumDimensions>,
        const TensorContractionOp<
            const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
                                          const Input> >,
            const TensorReshapingOp<
                const DSizes<typename internal::traits<Input>::Index, 2>,
                const Kernel> > > >
CuboidConvolution(const Input& input, const Kernel& kernel,
                  const Index stridePlanes = 1, const Index strideRows = 1,
                  const Index strideCols = 1,
                  const PaddingType padding_type = PADDING_SAME) {
  typedef typename internal::traits<Input>::Index TensorIndex;
  TensorRef<Tensor<typename internal::traits<Input>::Scalar,
                   internal::traits<Input>::NumDimensions,
                   internal::traits<Input>::Layout, TensorIndex> >
      in(input);
  TensorRef<Tensor<typename internal::traits<Kernel>::Scalar,
                   internal::traits<Kernel>::NumDimensions,
                   internal::traits<Kernel>::Layout, TensorIndex> >
      kern(kernel);

  EIGEN_STATIC_ASSERT(
      internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
      YOU_MADE_A_PROGRAMMING_MISTAKE);
  static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
  static const int NumDims = internal::traits<Input>::NumDimensions;

  // Number of filters to apply. This is the same as the output depth of the
  // result.
  const TensorIndex kernelFilters =
      isColMajor ? kern.dimensions()[0] : kern.dimensions()[4];
  const TensorIndex kernelChannels =
      isColMajor ? kern.dimensions()[1] : kern.dimensions()[3];

  // Spatial size of the kernel. Note that the middle dimension of a 5D kernel
  // has the same index in both layouts.
  const TensorIndex kernelPlanes =
      isColMajor ? kern.dimensions()[2] : kern.dimensions()[2];
  const TensorIndex kernelRows =
      isColMajor ? kern.dimensions()[3] : kern.dimensions()[1];
  const TensorIndex kernelCols =
      isColMajor ? kern.dimensions()[4] : kern.dimensions()[0];

  if (isColMajor) {
    eigen_assert(kernelChannels == in.dimension(0));
  } else {
    eigen_assert(kernelChannels == in.dimension(NumDims - 1));
  }

  const TensorIndex inputPlanes =
      isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
  const TensorIndex inputRows =
      isColMajor ? in.dimension(2) : in.dimension(NumDims - 3);
  const TensorIndex inputCols =
      isColMajor ? in.dimension(3) : in.dimension(NumDims - 4);

  // Output spatial sizes: ceil((in - kernel + 1) / stride) for VALID padding,
  // ceil(in / stride) for SAME padding.
  TensorIndex out_planes;
  TensorIndex out_height;
  TensorIndex out_width;
  switch (padding_type) {
    case PADDING_VALID:
      out_planes = Eigen::divup(inputPlanes - kernelPlanes + 1,
                                static_cast<TensorIndex>(stridePlanes));
      out_height = Eigen::divup(inputRows - kernelRows + 1,
                                static_cast<TensorIndex>(strideRows));
      out_width = Eigen::divup(inputCols - kernelCols + 1,
                               static_cast<TensorIndex>(strideCols));
      break;
    case PADDING_SAME:
      out_planes =
          Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
      out_height =
          Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
      out_width = Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
      break;
    default:
      out_planes = 0;
      out_height = 0;
      out_width = 0;
      eigen_assert(false && "unexpected padding");
  }

  DSizes<TensorIndex, 2> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters;
    kernel_dims[1] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
  } else {
    kernel_dims[0] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
    kernel_dims[1] = kernelFilters;
  }

  // Molds the output of the patch extraction into a 2D tensor:
  // - the first dimension (dims[0]): the patch values to be multiplied with
  //   the kernels
  // - the second dimension (dims[1]): everything else
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] =
        kernelChannels * kernelPlanes * kernelRows * kernelCols;
    pre_contract_dims[1] = out_planes * out_height * out_width;
    for (int i = 4; i < NumDims; ++i) {
      pre_contract_dims[1] *= in.dimension(i);
    }
  } else {
    pre_contract_dims[1] =
        kernelChannels * kernelPlanes * kernelRows * kernelCols;
    pre_contract_dims[0] = out_planes * out_height * out_width;
    for (int i = 0; i < NumDims - 4; ++i) {
      pre_contract_dims[0] *= in.dimension(i);
    }
  }

  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  // Molds the output of the contraction into the shape expected by the user
  // (assuming ColMajor):
  // - 1st dim: kernel filters
  // - 2nd dim: output depth (planes)
  // - 3rd dim: output height
  // - 4th dim: output width
  // - 5th dim and beyond: everything else including batch size
  DSizes<TensorIndex, NumDims> post_contract_dims;
  if (isColMajor) {
    post_contract_dims[0] = kernelFilters;
    post_contract_dims[1] = out_planes;
    post_contract_dims[2] = out_height;
    post_contract_dims[3] = out_width;
    for (int i = 4; i < NumDims; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  } else {
    post_contract_dims[NumDims - 1] = kernelFilters;
    post_contract_dims[NumDims - 2] = out_planes;
    post_contract_dims[NumDims - 3] = out_height;
    post_contract_dims[NumDims - 4] = out_width;
    for (int i = 0; i < NumDims - 4; ++i) {
      post_contract_dims[i] = in.dimension(i);
    }
  }

  return choose(
      Cond<internal::traits<Input>::Layout == ColMajor>(),
      kernel.reshape(kernel_dims)
          .contract(input
                        .extract_volume_patches(
                            kernelPlanes, kernelRows, kernelCols, stridePlanes,
                            strideRows, strideCols, padding_type)
                        .reshape(pre_contract_dims),
                    contract_dims)
          .reshape(post_contract_dims),
      input
          .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
                                  stridePlanes, strideRows, strideCols,
                                  padding_type)
          .reshape(pre_contract_dims)
          .contract(kernel.reshape(kernel_dims), contract_dims)
          .reshape(post_contract_dims));
}

}  // end namespace Eigen

#endif  // TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_