// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <[email protected]>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H

namespace Eigen {

/** \class TensorContraction
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor contraction class.
  *
  *
  */
namespace internal {

template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> >
{
  // Type promotion to handle the case where the types of the lhs and the rhs are different.
  typedef typename gebp_traits<typename remove_const<typename LhsXprType::Scalar>::type,
                               typename remove_const<typename RhsXprType::Scalar>::type>::ResScalar Scalar;

  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
                                        typename traits<RhsXprType>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef typename remove_reference<LhsNested>::type _LhsNested;
  typedef typename remove_reference<RhsNested>::type _RhsNested;

  // From NumDims below.
  static const int NumDimensions = traits<LhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value;
  static const int Layout = traits<LhsXprType>::Layout;
  typedef typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
                               typename traits<LhsXprType>::PointerType,
                               typename traits<RhsXprType>::PointerType>::type
      PointerType;

  enum {
    Flags = 0
  };
};

template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, Eigen::Dense>
{
  typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>& type;
};

template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> >::type>
{
  typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> type;
};

template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename OutputKernelType_, typename Device_>
struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_, OutputKernelType_>, Device_> > {
  typedef Indices_ Indices;
  typedef LeftArgType_ LeftArgType;
  typedef RightArgType_ RightArgType;
  typedef OutputKernelType_ OutputKernelType;
  typedef Device_ Device;

  // From NumDims below.
  static const int NumDimensions = traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value;
};
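// Illustrative sketch (not part of this header): the NumDimensions formula
// above matches what the user-facing Tensor API produces. Contracting a
// rank-3 tensor with a rank-2 tensor over one index pair yields a tensor of
// rank 3 + 2 - 2 * 1 = 3. Assuming the usual TensorBase::contract entry
// point, the expression could be built like this:
//
//   Eigen::Tensor<float, 3> A(2, 3, 4);
//   Eigen::Tensor<float, 2> B(4, 5);
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(2, 0)};
//   Eigen::Tensor<float, 3> C = A.contract(B, dims);  // C has dimensions (2, 3, 5)
//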
// Helper class to allocate and deallocate temporary memory for packed buffers.
template <typename LhsScalar, typename RhsScalar>
struct TensorContractionBlockMemAllocator {
  typedef void* BlockMemHandle;

  template <typename Device>
  EIGEN_DEVICE_FUNC static BlockMemHandle allocate(Device& d, const Index bm,
                                                   const Index bk,
                                                   const Index bn,
                                                   LhsScalar** lhs_block,
                                                   RhsScalar** rhs_block) {
    eigen_assert(lhs_block);
    eigen_assert(rhs_block);
    BlockSizes sz = ComputeLhsRhsBlockSizes(bm, bk, bn);
    char* block_mem = static_cast<char*>(d.allocate(sz.lhs_size + sz.rhs_size));
    eigen_assert(block_mem);
    *lhs_block = reinterpret_cast<LhsScalar*>(block_mem);
    *rhs_block = reinterpret_cast<RhsScalar*>(block_mem + sz.lhs_size);
    return block_mem;
  }

  template <typename Device>
  EIGEN_DEVICE_FUNC static BlockMemHandle allocateSlices(
      Device& d, const Index bm, const Index bk, const Index bn,
      const Index num_lhs, const Index num_rhs, const Index num_slices,
      std::vector<LhsScalar*>* lhs_blocks,
      std::vector<RhsScalar*>* rhs_blocks) {
    eigen_assert(num_slices > 0);
    eigen_assert(num_lhs >= 0 && num_rhs >= 0);
    eigen_assert(num_lhs == 0 || lhs_blocks);
    eigen_assert(num_rhs == 0 || rhs_blocks);
    BlockSizes sz = ComputeLhsRhsBlockSizes(bm, bk, bn);
    void* block_mem = d.allocate(
        (num_lhs * sz.lhs_size + num_rhs * sz.rhs_size) * num_slices);
    eigen_assert(block_mem);
    char* mem = static_cast<char*>(block_mem);

    for (Index x = 0; x < num_slices; x++) {
      if (num_lhs > 0) lhs_blocks[x].resize(num_lhs);
      for (Index m = 0; m < num_lhs; m++) {
        lhs_blocks[x][m] = reinterpret_cast<LhsScalar*>(mem);
        mem += sz.lhs_size;
      }
      if (num_rhs > 0) rhs_blocks[x].resize(num_rhs);
      for (Index n = 0; n < num_rhs; n++) {
        rhs_blocks[x][n] = reinterpret_cast<RhsScalar*>(mem);
        mem += sz.rhs_size;
      }
    }

    return block_mem;
  }

  template <typename Device>
  EIGEN_DEVICE_FUNC static void deallocate(Device& d, BlockMemHandle handle) {
    d.deallocate(handle);
  }

 private:
  struct BlockSizes {
    Index lhs_size;
    Index rhs_size;
  };
  EIGEN_DEVICE_FUNC static BlockSizes ComputeLhsRhsBlockSizes(const Index bm,
                                                              const Index bk,
                                                              const Index bn) {
    Index align = numext::maxi(EIGEN_MAX_ALIGN_BYTES, 1);
    BlockSizes sz;
    sz.lhs_size = divup<Index>(bm * bk * sizeof(LhsScalar), align) * align;
    sz.rhs_size = divup<Index>(bn * bk * sizeof(RhsScalar), align) * align;
    return sz;
  }
};
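// Worked example (illustrative, assuming EIGEN_MAX_ALIGN_BYTES == 64): for
// float blocks with bm = 48, bk = 32, bn = 24, ComputeLhsRhsBlockSizes rounds
// each buffer up to a multiple of the alignment:
//   lhs_size = divup(48 * 32 * 4, 64) * 64 = divup(6144, 64) * 64 = 6144 bytes
//   rhs_size = divup(24 * 32 * 4, 64) * 64 = divup(3072, 64) * 64 = 3072 bytes
// so allocate() returns a single 9216-byte allocation, with the Rhs block
// starting at offset lhs_size.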
// WARNING: In this code we assume that Lhs and Rhs tensor expressions are in
// ColMajor storage order. This property is guaranteed by the
// TensorContractionOp evaluator. TensorContractionKernel specifies how we pack
// blocks of Lhs and Rhs tensor expressions, and how we invoke matrix
// multiplication for these blocks. Default tensor contraction uses
// gemm_pack_rhs, gemm_pack_lhs and gebp_kernel from Eigen Core (see
// GeneralBlockPanelKernel.h for details).
//
// By specializing contraction kernels we can use other low level libraries to
// perform matrix multiplication, and still rely on the Eigen contraction
// evaluator. This also includes full support in TensorContractionThreadPool,
// assuming that the underlying gemm does not use its own threading.
//
// - ResScalar/LhsScalar/RhsScalar - scalar type for the result of
//   multiplication, lhs tensor and rhs tensor respectively.
//
// - StorageIndex - index type for the tensor expressions. In practice almost
//   always is Eigen::Index.
//
// - OutputMapper provides access to the memory of the output matrix. In
//   practice it's always column major blas_data_mapper (it must be of ResScalar
//   type).
//
// - LhsMapper/RhsMapper similarly to blas_data_mapper provide a two dimensional
//   view into the Lhs/Rhs tensor expressions. In practice it's
//   TensorContractionInputMapper, or some specialization of it based on the
//   type of tensor expression (e.g. TensorImagePatchOp has optimized input
//   mapper).
template <typename ResScalar, typename LhsScalar, typename RhsScalar,
          typename StorageIndex, typename OutputMapper, typename LhsMapper,
          typename RhsMapper>
struct TensorContractionKernel {
  // True if `invoke()` supports `beta` in `C <- alpha * A * B + beta * C`
  // (otherwise beta should always be equal to 1).
  enum { HasBeta = false };

  EIGEN_DEVICE_FUNC
  TensorContractionKernel(StorageIndex m_, StorageIndex k_, StorageIndex n_,
                          StorageIndex bm_, StorageIndex bk_, StorageIndex bn_)
      : m(m_), k(k_), n(n_), bm(bm_), bk(bk_), bn(bn_) {}

  // Pack blocks of Lhs and Rhs into contiguous blocks in memory.
  typedef LhsScalar* LhsBlock;
  typedef RhsScalar* RhsBlock;

  // Packed Lhs/Rhs block memory allocator.
  typedef TensorContractionBlockMemAllocator<LhsScalar, RhsScalar>
      BlockMemAllocator;
  typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle;

  typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

  typedef internal::gemm_pack_lhs<
      LhsScalar, StorageIndex, typename LhsMapper::SubMapper, Traits::mr,
      Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor>
      LhsPacker;

  typedef internal::gemm_pack_rhs<RhsScalar, StorageIndex,
                                  typename RhsMapper::SubMapper, Traits::nr,
                                  ColMajor>
      RhsPacker;

  typedef internal::gebp_kernel<LhsScalar, RhsScalar, StorageIndex,
                                OutputMapper, Traits::mr, Traits::nr,
                                /*ConjugateLhs*/ false, /*ConjugateRhs*/ false>
      GebpKernel;

  template <typename Device>
  EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block,
                                            RhsBlock* rhs_block) {
    return BlockMemAllocator::allocate(d, bm, bk, bn, lhs_block, rhs_block);
  }

  template <typename Device>
  EIGEN_DEVICE_FUNC BlockMemHandle allocateSlices(
      Device& d, const StorageIndex num_lhs, const StorageIndex num_rhs,
      const StorageIndex num_slices, std::vector<LhsBlock>* lhs_blocks,
      std::vector<RhsBlock>* rhs_blocks) {
    return BlockMemAllocator::allocateSlices(
        d, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_blocks, rhs_blocks);
  }

  template <typename Device>
  EIGEN_DEVICE_FUNC static void deallocate(Device& d, BlockMemHandle handle) {
    BlockMemAllocator::deallocate(d, handle);
  }

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs(
      LhsBlock* lhsBlock, const typename LhsMapper::SubMapper& data_mapper,
      const StorageIndex depth, const StorageIndex rows) {
    LhsPacker()(*lhsBlock, data_mapper, depth, rows, /*stride*/ 0,
                /*offset*/ 0);
  }

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs(
      RhsBlock* rhsBlock, const typename RhsMapper::SubMapper& data_mapper,
      const StorageIndex depth, const StorageIndex cols) {
    RhsPacker()(*rhsBlock, data_mapper, depth, cols);
  }

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke(
      const OutputMapper& output_mapper, const LhsBlock& lhsBlock,
      const RhsBlock& rhsBlock, const StorageIndex rows,
      const StorageIndex depth, const StorageIndex cols,
      const ResScalar alpha, const ResScalar beta) {
    // Default GEBP kernel does not support beta.
    eigen_assert(beta == ResScalar(1));
    static const int kComputeStrideFromBlockDimensions = -1;
    GebpKernel()(output_mapper, lhsBlock, rhsBlock, rows, depth, cols, alpha,
                 /*strideA*/ kComputeStrideFromBlockDimensions,
                 /*strideB*/ kComputeStrideFromBlockDimensions,
                 /*offsetA*/ 0, /*offsetB*/ 0);
  }

 private:
  // These are dimensions of the original Tensors, and selected block sizes.
  // The actual block sizes passed to all functions above might be smaller
  // because of the partial blocks at the end.
  const StorageIndex m;
  const StorageIndex k;
  const StorageIndex n;
  const StorageIndex bm;
  const StorageIndex bk;
  const StorageIndex bn;
};
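// Illustrative sketch of how the contraction evaluator drives a kernel (the
// real loop lives in evalGemmPartial below; names such as `device`, `lhs` and
// `rhs` mappers are placeholders here):
//
//   TensorContractionKernel kernel(m, k, n, bm, bk, bn);
//   LhsBlock blockA;
//   RhsBlock blockB;
//   BlockMemHandle mem = kernel.allocate(device, &blockA, &blockB);
//   kernel.packLhs(&blockA, lhs.getSubMapper(i, k0), kc, mc);
//   kernel.packRhs(&blockB, rhs.getSubMapper(k0, j), kc, nc);
//   kernel.invoke(output.getSubMapper(i, j), blockA, blockB, mc, kc, nc,
//                 /*alpha=*/Scalar(1), /*beta=*/Scalar(1));
//   kernel.deallocate(device, mem);
//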
}  // end namespace internal

// Tensor contraction parameters that make it possible to map the
// 2-dimensional coordinates of the output matrix to the output tensor
// dimensions.
struct TensorContractionParams {
  // The TensorContraction evaluator assumes that both tensors are in ColMajor
  // layout; if the tensors are in RowMajor, the evaluator swaps lhs with rhs.
  bool swapped_arguments;
};

// Output kernels make it possible to fuse operations into the tensor
// contraction.
//
// Examples:
//  1. Elementwise Relu transformation following Conv2D.
//  2. AddBias to the Conv2D output channels dimension.
//
// The NoOpOutputKernel implements an output kernel that does absolutely nothing.
struct NoOpOutputKernel {
  /**
   * Tensor contraction evaluator calls this kernel after finishing each block
   * of the output matrix. Output blocks belong to the 2-dimensional output
   * tensor.
   *
   * TensorContractionParams contains contraction dimensions information
   * required to map the output 2-d space into the expected output tensor space
   * (potentially higher dimensional).
   *
   * \param[in] output_mapper Access to output tensor memory
   * \param[in] params Tensor contraction parameters
   * \param[in] i Index of the first row available through output_mapper
   * \param[in] j Index of the first column available through output_mapper
   * \param[in] num_rows Number of available rows
   * \param[in] num_cols Number of available columns
   */
  template <typename Index, typename Scalar>
  EIGEN_ALWAYS_INLINE void operator()(
      const internal::blas_data_mapper<Scalar, Index, ColMajor>& output_mapper,
      const TensorContractionParams& params, Index i,
      Index j, Index num_rows, Index num_cols) const {
    EIGEN_UNUSED_VARIABLE(output_mapper);
    EIGEN_UNUSED_VARIABLE(params);
    EIGEN_UNUSED_VARIABLE(i);
    EIGEN_UNUSED_VARIABLE(j);
    EIGEN_UNUSED_VARIABLE(num_rows);
    EIGEN_UNUSED_VARIABLE(num_cols);
  }
};
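// Example of a custom output kernel (an illustrative sketch, not part of this
// header; `ReluOutputKernel` is a hypothetical name). It clamps every
// coefficient of the just-computed output block to be non-negative, which is
// the "fuse a Relu into the contraction" use case mentioned above. The block
// is addressed through the same blas_data_mapper interface that
// NoOpOutputKernel receives, with block-local row/column indices:
//
//   struct ReluOutputKernel {
//     template <typename Index, typename Scalar>
//     EIGEN_ALWAYS_INLINE void operator()(
//         const Eigen::internal::blas_data_mapper<Scalar, Index, Eigen::ColMajor>&
//             output_mapper,
//         const Eigen::TensorContractionParams& /*params*/, Index /*i*/,
//         Index /*j*/, Index num_rows, Index num_cols) const {
//       for (Index c = 0; c < num_cols; ++c) {
//         for (Index r = 0; r < num_rows; ++r) {
//           Scalar& v = output_mapper(r, c);
//           v = Eigen::numext::maxi(v, Scalar(0));
//         }
//       }
//     }
//   };
//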
template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType = const NoOpOutputKernel>
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar;
    typedef typename internal::gebp_traits<typename LhsXprType::CoeffReturnType,
                                           typename RhsXprType::CoeffReturnType>::ResScalar CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorContractionOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorContractionOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(
        const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims,
        const OutputKernelType& output_kernel = OutputKernelType())
        : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims),
          m_output_kernel(output_kernel) {}

    EIGEN_DEVICE_FUNC
    const Indices& indices() const { return m_indices; }

    /** \returns the nested expressions */
    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename LhsXprType::Nested>::type&
    lhsExpression() const { return m_lhs_xpr; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename RhsXprType::Nested>::type&
    rhsExpression() const { return m_rhs_xpr; }

    EIGEN_DEVICE_FUNC
    const OutputKernelType& outputKernel() const { return m_output_kernel; }

  protected:
    typename LhsXprType::Nested m_lhs_xpr;
    typename RhsXprType::Nested m_rhs_xpr;
    const Indices m_indices;
    const OutputKernelType m_output_kernel;
};
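// Usage sketch (illustrative, not part of this header): a TensorContractionOp
// is normally created through TensorBase::contract rather than constructed
// directly. Assuming a contract overload that accepts an output kernel is
// available in this build, and reusing the hypothetical ReluOutputKernel from
// the sketch above:
//
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
//   auto plain = A.contract(B, dims);                      // NoOpOutputKernel
//   auto fused = A.contract(B, dims, ReluOutputKernel());  // Relu fused in
//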
template<typename Derived>
struct TensorContractionEvaluatorBase : internal::no_assignment_operator
{
  typedef typename internal::traits<Derived>::Indices Indices;
  typedef typename internal::traits<Derived>::LeftArgType LeftArgType;
  typedef typename internal::traits<Derived>::RightArgType RightArgType;
  typedef typename internal::traits<Derived>::OutputKernelType OutputKernelType;
  typedef typename internal::traits<Derived>::Device Device;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  typedef StorageMemory<Scalar, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;

  enum {
    IsAligned = true,
    PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
    BlockAccess = false,
    PreferBlockAccess = false,
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = true
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockNotImplemented TensorBlock;
  //===--------------------------------------------------------------------===//

  // Most of the code assumes that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluatorType;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluatorType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;
  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  typedef DSizes<Index, NumDims> Dimensions;

  EIGEN_STRONG_INLINE
  TensorContractionEvaluatorBase(const XprType& op, const Device& device)
      : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                          op.lhsExpression(), op.rhsExpression()), device),
        m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                           op.rhsExpression(), op.lhsExpression()), device),
        m_device(device),
        m_output_kernel(op.outputKernel()),
        m_result(NULL) {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
                         static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
                        YOU_MADE_A_PROGRAMMING_MISTAKE);


    DSizes<Index, LDims> eval_left_dims;
    DSizes<Index, RDims> eval_right_dims;
    array<IndexPair<Index>, ContractDims> eval_op_indices;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // For ColMajor, we keep using the existing dimensions
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[i];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[i];
      }
      // We keep the pairs of contracting indices.
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = op.indices()[i].first;
        eval_op_indices[i].second = op.indices()[i].second;
      }
    } else {
      // For RowMajor, we need to reverse the existing dimensions
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[LDims - i - 1];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[RDims - i - 1];
      }
      // We need to flip all the pairs of contracting indices as well as
      // reversing the dimensions.
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = LDims - 1 - op.indices()[ContractDims - 1 - i].second;
        eval_op_indices[i].second = RDims - 1 - op.indices()[ContractDims - 1 - i].first;
      }
    }

    // Check for duplicate axes and make sure the first index in eval_op_indices
    // is increasing. Using O(n^2) sorting is OK since ContractDims is small
    for (int i = 0; i < ContractDims; i++) {
      for (int j = i + 1; j < ContractDims; j++) {
        eigen_assert(eval_op_indices[j].first != eval_op_indices[i].first &&
                     eval_op_indices[j].second != eval_op_indices[i].second &&
                     "contraction axes should be unique");
        if (eval_op_indices[j].first < eval_op_indices[i].first) {
          numext::swap(eval_op_indices[j], eval_op_indices[i]);
        }
      }
    }

    array<Index, LDims> lhs_strides;
    lhs_strides[0] = 1;
    for (int i = 0; i < LDims-1; ++i) {
      lhs_strides[i+1] = lhs_strides[i] * eval_left_dims[i];
    }

    array<Index, RDims> rhs_strides;
    rhs_strides[0] = 1;
    for (int i = 0; i < RDims-1; ++i) {
      rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i];
    }
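    // Note: lhs_strides/rhs_strides are the usual column-major strides of the
    // (possibly reversed) input dimensions. For example, if eval_left_dims is
    // (2, 3, 4) then lhs_strides is (1, 2, 6): moving one step along dimension
    // d advances the flat index by lhs_strides[d].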
    if (m_i_strides.size() > 0) m_i_strides[0] = 1;
    if (m_j_strides.size() > 0) m_j_strides[0] = 1;
    if (m_k_strides.size() > 0) m_k_strides[0] = 1;

    m_i_size = 1;
    m_j_size = 1;
    m_k_size = 1;

    // To compute the dimension, we simply concatenate the non-contracting
    // dimensions of the left and then the right tensor. Additionally, we also
    // compute the strides corresponding to the left non-contracting
    // dimensions and right non-contracting dimensions.
    m_lhs_inner_dim_contiguous = true;
    int dim_idx = 0;
    Index nocontract_idx = 0;

    for (int i = 0; i < LDims; i++) {
      // find if we are contracting on index i of left tensor
      bool contracting = false;
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].first == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        // add dimension size to output dimensions
        m_dimensions[dim_idx] = eval_left_dims[i];
        m_left_nocontract_strides[nocontract_idx] = lhs_strides[i];
        if (dim_idx != i) {
          m_lhs_inner_dim_contiguous = false;
        }
        if (nocontract_idx+1 < internal::array_size<left_nocontract_t>::value) {
          m_i_strides[nocontract_idx+1] =
              m_i_strides[nocontract_idx] * eval_left_dims[i];
        } else {
          m_i_size = m_i_strides[nocontract_idx] * eval_left_dims[i];
        }
        dim_idx++;
        nocontract_idx++;
      }
    }

    nocontract_idx = 0;
    for (int i = 0; i < RDims; i++) {
      bool contracting = false;
      // find if we are contracting on index i of right tensor
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].second == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        m_dimensions[dim_idx] = eval_right_dims[i];
        if (nocontract_idx+1 < internal::array_size<right_nocontract_t>::value) {
          m_j_strides[nocontract_idx+1] =
              m_j_strides[nocontract_idx] * eval_right_dims[i];
        } else {
          m_j_size = m_j_strides[nocontract_idx] * eval_right_dims[i];
        }
        m_right_nocontract_strides[nocontract_idx] = rhs_strides[i];
        dim_idx++;
        nocontract_idx++;
      }
    }

    // Now compute the strides corresponding to the contracting dimensions. We
    // assumed above that non-contracting axes are represented in the same order
    // in the matrix as they are in the tensor. This is not the case for
    // contracting axes. As the contracting axes must be of the same size in
    // each tensor, we'll only look at the first tensor here.
    m_rhs_inner_dim_contiguous = true;
    m_rhs_inner_dim_reordered = false;
    for (int i = 0; i < ContractDims; i++) {
      Index left = eval_op_indices[i].first;
      Index right = eval_op_indices[i].second;

      Index size = eval_left_dims[left];
      eigen_assert(size == eval_right_dims[right] &&
                   "Contraction axes must be same size");

      if (i+1 < static_cast<int>(internal::array_size<contract_t>::value)) {
        m_k_strides[i+1] = m_k_strides[i] * size;
      } else {
        m_k_size = m_k_strides[i] * size;
      }
      m_left_contracting_strides[i] = lhs_strides[left];
      m_right_contracting_strides[i] = rhs_strides[right];

      if (i > 0 && right < eval_op_indices[i-1].second) {
        m_rhs_inner_dim_reordered = true;
      }
      if (right != i) {
        m_rhs_inner_dim_contiguous = false;
      }
    }

    // If the layout is RowMajor, we need to reverse the m_dimensions
    if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) {
      for (int i = 0, j = NumDims - 1; i < j; i++, j--) {
        numext::swap(m_dimensions[i], m_dimensions[j]);
      }
    }

    // A set of parameters that will allow the output kernel to get from the
    // output matrix coordinates (i, j) back to the original tensor dimensions.
    // TODO(ezhulenev): Add parameters required to infer output tensor index for
    // more complex contractions than 2x2 on internal dimension.
    m_tensor_contraction_params.swapped_arguments = static_cast<int>(Layout) == RowMajor;
  }
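  // Worked example (illustrative): contracting a (2, 3, 4) left tensor with a
  // (4, 5) right tensor over the single pair {2, 0} makes the constructor
  // above collapse the operands into an effective matrix product with
  //   m_i_size = 2 * 3 = 6   (product of left non-contracting dimensions),
  //   m_j_size = 5           (product of right non-contracting dimensions),
  //   m_k_size = 4           (product of contracting dimensions),
  // i.e. a 6x4 matrix times a 4x5 matrix. For RowMajor inputs the dimensions
  // and index pairs are first reversed, so the same contraction is expressed
  // on the swapped/reversed operands before reaching this point.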
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      m_result = static_cast<EvaluatorPointerType>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(m_result);
      return true;
    }
  }

#ifdef EIGEN_USE_THREADS
  template <typename EvalSubExprsCallback>
  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
      EvaluatorPointerType dest, EvalSubExprsCallback done) {
    m_leftImpl.evalSubExprsIfNeededAsync(nullptr, [this, done, dest](bool) {
      m_rightImpl.evalSubExprsIfNeededAsync(nullptr, [this, done, dest](bool) {
        if (dest) {
          evalToAsync(dest, [done]() { done(false); });
        } else {
          m_result = static_cast<EvaluatorPointerType>(
              m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
          evalToAsync(m_result, [done]() { done(true); });
        }
      });
    });
  }
#endif // EIGEN_USE_THREADS

#ifndef TENSOR_CONTRACTION_DISPATCH
#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS)    \
  if (this->m_lhs_inner_dim_contiguous) {                       \
    if (this->m_rhs_inner_dim_contiguous) {                     \
      if (this->m_rhs_inner_dim_reordered) {                    \
        METHOD<true, true, true, ALIGNMENT> ARGS;               \
      } else {                                                  \
        METHOD<true, true, false, ALIGNMENT> ARGS;              \
      }                                                         \
    } else {                                                    \
      if (this->m_rhs_inner_dim_reordered) {                    \
        METHOD<true, false, true, ALIGNMENT> ARGS;              \
      } else {                                                  \
        METHOD<true, false, false, ALIGNMENT> ARGS;             \
      }                                                         \
    }                                                           \
  } else {                                                      \
    if (this->m_rhs_inner_dim_contiguous) {                     \
      if (this->m_rhs_inner_dim_reordered) {                    \
        METHOD<false, true, true, ALIGNMENT> ARGS;              \
      } else {                                                  \
        METHOD<false, true, false, ALIGNMENT> ARGS;             \
      }                                                         \
    } else {                                                    \
      if (this->m_rhs_inner_dim_reordered) {                    \
        METHOD<false, false, true, ALIGNMENT> ARGS;             \
      } else {                                                  \
        METHOD<false, false, false, ALIGNMENT> ARGS;            \
      }                                                         \
    }                                                           \
  }
#endif

#ifndef TENSOR_CONTRACTION_ASYNC_DISPATCH
#define TENSOR_CONTRACTION_ASYNC_DISPATCH(METHOD, DONE, ALIGNMENT, ARGS, FN) \
  if (this->m_lhs_inner_dim_contiguous) {                                    \
    if (this->m_rhs_inner_dim_contiguous) {                                  \
      if (this->m_rhs_inner_dim_reordered) {                                 \
        (new METHOD<DONE, true, true, true, ALIGNMENT> ARGS)->FN;            \
      } else {                                                               \
        (new METHOD<DONE, true, true, false, ALIGNMENT> ARGS)->FN;           \
      }                                                                      \
    } else {                                                                 \
      if (this->m_rhs_inner_dim_reordered) {                                 \
        (new METHOD<DONE, true, false, true, ALIGNMENT> ARGS)->FN;           \
      } else {                                                               \
        (new METHOD<DONE, true, false, false, ALIGNMENT> ARGS)->FN;          \
      }                                                                      \
    }                                                                        \
  } else {                                                                   \
    if (this->m_rhs_inner_dim_contiguous) {                                  \
      if (this->m_rhs_inner_dim_reordered) {                                 \
        (new METHOD<DONE, false, true, true, ALIGNMENT> ARGS)->FN;           \
      } else {                                                               \
        (new METHOD<DONE, false, true, false, ALIGNMENT> ARGS)->FN;          \
      }                                                                      \
    } else {                                                                 \
      if (this->m_rhs_inner_dim_reordered) {                                 \
        (new METHOD<DONE, false, false, true, ALIGNMENT> ARGS)->FN;          \
      } else {                                                               \
        (new METHOD<DONE, false, false, false, ALIGNMENT> ARGS)->FN;         \
      }                                                                      \
    }                                                                        \
  }
#endif
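  // Illustrative note (not part of the original comments): the dispatch macros
  // above turn the three runtime booleans (m_lhs_inner_dim_contiguous,
  // m_rhs_inner_dim_contiguous, m_rhs_inner_dim_reordered) into compile-time
  // template parameters. For example,
  //
  //   TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential,
  //                               Unaligned, (buffer));
  //
  // expands to a nest of if/else branches, each calling one of the eight
  // instantiations such as
  //
  //   this->template evalProductSequential<true, false, true, Unaligned>(buffer);
  //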
  EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
    static_cast<const Derived*>(this)->template evalProduct<Unaligned>(buffer);
  }

#ifdef EIGEN_USE_THREADS
  template <typename EvalToCallback>
  void evalToAsync(Scalar* buffer, EvalToCallback done) const {
    static_cast<const Derived*>(this)
        ->template evalProductAsync<EvalToCallback, Unaligned>(buffer,
                                                               std::move(done));
  }
#endif // EIGEN_USE_THREADS

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  void evalProductSequential(Scalar* buffer) const {
    if (this->m_j_size == 1) {
      this->template evalGemv<lhs_inner_dim_contiguous,
                              rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
                              Alignment>(buffer);
    } else {
      this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
                              rhs_inner_dim_reordered, Alignment>(buffer);
    }
  }
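  // evalGemv handles the degenerate case where the collapsed right-hand side
  // has a single column (m_j_size == 1): the contraction then reduces to a
  // matrix-vector product of the (m_i_size x m_k_size) flattened lhs with the
  // length-m_k_size rhs, evaluated via general_matrix_vector_product below.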
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
#if !defined(EIGEN_HIPCC)
  EIGEN_DEVICE_FUNC
#endif
  void evalGemv(Scalar* buffer) const {
    const Index rows = m_i_size;
    const Index cols = m_k_size;

    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;
    const int lhs_alignment = LeftEvaluator::IsAligned ? Aligned : Unaligned;
    const int rhs_alignment = RightEvaluator::IsAligned ? Aligned : Unaligned;
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, lhs_alignment> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, rhs_alignment> RhsMapper;

    LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides,
                  m_left_contracting_strides, m_k_strides);
    RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides,
                  m_right_contracting_strides, m_k_strides);

    const Scalar alpha(1);
    const Index resIncr(1);

    // zero out the result buffer (which must be of size at least rows * sizeof(Scalar))
    m_device.memset(buffer, 0, rows * sizeof(Scalar));

    internal::general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,false,RhsScalar,RhsMapper,false>::run(
        rows, cols, lhs, rhs,
        buffer, resIncr, alpha);

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
    m_output_kernel(OutputMapper(buffer, rows), m_tensor_contraction_params,
                    static_cast<Index>(0), static_cast<Index>(0), rows,
                    static_cast<Index>(1));
  }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
#if !defined(EIGEN_HIPCC)
  EIGEN_DEVICE_FUNC
#endif
  void evalGemm(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    this->template evalGemmPartial<lhs_inner_dim_contiguous,
                                   rhs_inner_dim_contiguous,
                                   rhs_inner_dim_reordered,
                                   Alignment, true>(buffer, 0, k, 1);
  }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemmPartialWithoutOutputKernel(
      Scalar* buffer, Index k_start, Index k_end, int num_threads) const {
    evalGemmPartial<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
                    rhs_inner_dim_reordered, Alignment,
                    /*use_output_kernel*/ false>(buffer, k_start, k_end,
                                                 num_threads);
  }
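  // evalGemmPartial below runs a classic Goto-style blocked GEMM over the
  // k-range [k_start, k_end): the outer loop walks row panels of width mc,
  // the middle loop walks depth slices of width kc (packing one Lhs panel per
  // (i2, k2) pair), and the inner loop walks column blocks of width nc,
  // packing the corresponding Rhs block and invoking the GEBP kernel on each
  // (i2, j2) output block. The output kernel, if enabled, is applied once per
  // output block after its final k-slice has been accumulated.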
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment, bool use_output_kernel>
  EIGEN_DEVICE_FUNC void evalGemmPartial(Scalar* buffer, Index k_start, Index k_end, int num_threads) const {
    eigen_assert(k_end >= k_start && k_start >= 0 && k_end <= this->m_k_size);
    // columns in slice on left side, rows on right side
    const Index k_slice = k_end - k_start;

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // define data mappers for Lhs and Rhs
    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;

    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    typedef internal::TensorContractionKernel<
        Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
        TensorContractionKernel;

    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    // Sizes of the blocks to load in cache. See the Goto paper for details.
    internal::TensorContractionBlocking<Scalar, LhsScalar, RhsScalar,
                                        Index, internal::ShardByCol>
        blocking(k_slice, m, n, num_threads);
    const Index kc = blocking.kc();
    const Index mc = numext::mini(m, blocking.mc());
    const Index nc = numext::mini(n, blocking.nc());

    typedef typename TensorContractionKernel::LhsBlock LhsBlock;
    typedef typename TensorContractionKernel::RhsBlock RhsBlock;

    LhsBlock blockA;
    RhsBlock blockB;

    TensorContractionKernel kernel(m, k_slice, n, mc, kc, nc);

    typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
    const BlockMemHandle packed_mem =
        kernel.allocate(this->m_device, &blockA, &blockB);

    // If a contraction kernel does not support beta, explicitly initialize
    // output buffer with zeroes.
    if (!TensorContractionKernel::HasBeta) {
      this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
    }

    for(Index i2=0; i2<m; i2+=mc)
    {
      const Index actual_mc = numext::mini(i2+mc,m)-i2;
      for (Index k2 = k_start; k2 < k_end; k2 += kc) {
        // make sure we don't overshoot right edge of left matrix, then pack vertical panel
        const Index actual_kc = numext::mini(k2 + kc, k_end) - k2;
        kernel.packLhs(&blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);

        // If kernel supports beta, there is no need to initialize output
        // buffer with zeroes.
        const Scalar alpha = Scalar(1);
        const Scalar beta = (TensorContractionKernel::HasBeta && k2 == k_start)
                                ? Scalar(0)
                                : Scalar(1);

        // series of horizontal blocks
        for (Index j2 = 0; j2 < n; j2 += nc) {
          // make sure we don't overshoot right edge of right matrix, then pack block
          const Index actual_nc = numext::mini(j2 + nc, n) - j2;
          kernel.packRhs(&blockB, rhs.getSubMapper(k2, j2), actual_kc,
                         actual_nc);

          // call gebp (matrix kernel)
          // The parameters here are copied from Eigen's GEMM implementation
          const OutputMapper output_mapper = output.getSubMapper(i2, j2);
          kernel.invoke(output_mapper, blockA, blockB, actual_mc, actual_kc,
                        actual_nc, alpha, beta);

          // We are done with this [i2, j2] output block.
          if (use_output_kernel && k2 + kc >= k_end) {
            m_output_kernel(output_mapper, m_tensor_contraction_params, i2, j2,
                            actual_mc, actual_nc);
          }
        }
      }
    }

    kernel.deallocate(this->m_device, packed_mem);
  }
  EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();

    if (m_result != NULL) {
      m_device.deallocate(m_result);
      m_result = NULL;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    return m_result[index];
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0);
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
    return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EvaluatorPointerType data() const { return m_result; }

 protected:
  Dimensions m_dimensions;

  contract_t m_k_strides;
  contract_t m_left_contracting_strides;
  contract_t m_right_contracting_strides;

  bool m_lhs_inner_dim_contiguous;
  bool m_rhs_inner_dim_contiguous;
  bool m_rhs_inner_dim_reordered;

  left_nocontract_t m_i_strides;
  right_nocontract_t m_j_strides;
  left_nocontract_t m_left_nocontract_strides;
  right_nocontract_t m_right_nocontract_strides;

  Index m_i_size;
  Index m_j_size;
  Index m_k_size;

  TensorContractionParams m_tensor_contraction_params;

  TensorEvaluator<EvalLeftArgType, Device> m_leftImpl;
  TensorEvaluator<EvalRightArgType, Device> m_rightImpl;
  const Device EIGEN_DEVICE_REF m_device;
  OutputKernelType m_output_kernel;
  EvaluatorPointerType m_result;
};


// evaluator for default device
template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType, typename Device>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> :
    public TensorContractionEvaluatorBase<
        TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> > {
  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout
  };

  // Most of the code assumes that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;

  // Could we use NumDimensions here?
  typedef DSizes<Index, NumDims> Dimensions;

  TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) { }

  template <int Alignment>
  void evalProduct(Scalar* buffer) const {
    TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Alignment, (buffer));
  }
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H