/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Templates implementing loading of tiles from pitch-linear rank=2
           tensors.

    This iterator uses masks to guard out-of-bounds accesses. The first tile
    this iterator visits may be partial; the remaining tiles are complete.
    Consequently, predicates only need to be computed twice: once before the
    first tile and once for the remaining full tiles, which all share the same
    predicates.

    A precomputed "Params" object minimizes the amount of state that must be
    stored in registers, and integer addition is used to advance the pointer
    through memory.
*/

#pragma once

#include <cutlass/arch/memory.h>
#include <cutlass/transform/threadblock/predicated_tile_access_iterator.h>

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace transform {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// PredicatedTileIteratorResidualLast
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
/// Regular tile iterator using a precomputed control structure to minimize
/// register liveness and integer arithmetic.
///
/// Layout is assumed to be invariant at the time the precomputed "Params"
/// object is constructed.
///
/// Base pointer and tensor extents may be specified at the time the iterator
/// is constructed. Subsequently, they are assumed to be immutable.
///
/// Adding a logical coordinate offset may be performed at the time the
/// iterator is constructed. Subsequent additions to the logical coordinate
/// offset may be performed but are relatively expensive.
///
/// Visitation order is intended to first visit a "residual" tile that may be
/// partially full in both the advance dimension and the steady-state
/// dimension. This is assumed to be the last tile in the iteration sequence.
/// Advancing an iterator that has just been constructed moves to the first
/// tile that is full in the advance dimension and recomputes predicates.
/// Subsequent accesses may be performed without updating internal predicates
/// and are efficient in terms of live register state and pointer arithmetic
/// instructions.
///
/// To be efficient, this assumes the iterator will be dereferenced and
/// advanced at least once outside any looping structure to minimize integer
/// arithmetic.
///
/// Accesses out of bounds are safe so long as `clear_mask()` is called prior
/// to dereferencing the iterator.
///
///
/// Example:
///
/// An efficient pipeline structure may be constructed as follows:
///
// template <typename Iterator>
// __global__ void kernel(
//     typename Iterator::Params params,
//     typename Iterator::Element *ptr,
//     typename Iterator::TensorCoord extent) {
//
//   typename Iterator::Fragment fragment;
//
//   typename Iterator::TensorCoord threadblock_offset(0, 0);
//
//   Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offset);
//
//   fragment = *iter;  // load "residue" tile first
//   ++iter;            // advance to first "steady state" tile and update
//                      // internal masks
//
//   #pragma unroll
//   for (int i = Remaining - 1; i >= 0; --i) {
//
//     f(fragment);
//
//     if (!i) {
//       iter.clear_mask();  // light-weight operation to clear masks -
//                           // subsequent loads become NO-OPs.
//     }
//
//     fragment = *iter;  // load tile during "steady state" phase
//     ++iter;            // advance to next tile - lightweight due to
//                        // steady-state masks
//   }
// }
//
// void host(TensorView<Element, layout::PitchLinear> view) {
//
//   using Iterator =
//       transform::threadblock::PredicatedTileIteratorResidualLast;
//
//   typename Iterator::Params params(view.layout());
//
//   kernel<Iterator>(params, view.data(), view.extent());
// }
///
///
template <
    typename Shape,
    typename Element,
    typename Layout,
    int AdvanceRank,
    typename ThreadMap,
    int AccessSize = ThreadMap::kElementsPerAccess,
    bool Gather = false>
class PredicatedTileIteratorResidualLast;

////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data.
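///
/// A minimal instantiation sketch for this specialization. The tile shape,
/// element type, thread count, and thread map below are illustrative
/// assumptions rather than requirements:
///
//   using Shape     = cutlass::layout::PitchLinearShape<64, 8>;
//   using ThreadMap =
//       cutlass::transform::PitchLinearStripminedThreadMap<Shape, 32>;
//
//   using Iterator =
//       cutlass::transform::threadblock::PredicatedTileIteratorResidualLast<
//           Shape, float, cutlass::layout::PitchLinear,
//           1 /*AdvanceRank: strided*/, ThreadMap>;
//
//   // Host side: precompute Params from the tensor's leading dimension.
//   typename Iterator::Params params(cutlass::layout::PitchLinear(ldm));
//
//   // Device side: each thread constructs its own iterator.
//   Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offset);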
156 /// 157 /// Satisfies: ForwardTileIteratorConcept | 158 /// ReadableContiguousTileIteratorConcept | 159 /// WriteableContiguousTileIteratorConcept | 160 /// MaskedTileIteratorConcept 161 /// 162 template < 163 typename Shape_, 164 typename Element_, 165 int AdvanceRank, 166 typename ThreadMap_, 167 int AccessSize, 168 bool Gather> 169 class PredicatedTileIteratorResidualLast< 170 Shape_, 171 Element_, 172 layout::PitchLinear, 173 AdvanceRank, 174 ThreadMap_, 175 AccessSize, 176 Gather> { 177 public: 178 static_assert( 179 AdvanceRank == 0 || AdvanceRank == 1, 180 "Specialization for pitch-linear iterator may advance along the " 181 "contiguous(rank=0) or strided(rank=1) dimension."); 182 183 using Shape = Shape_; 184 using Element = Element_; 185 using Layout = layout::PitchLinear; 186 static int const kAdvanceRank = AdvanceRank; 187 using ThreadMap = ThreadMap_; 188 189 using Index = typename Layout::Index; 190 using LongIndex = typename Layout::LongIndex; 191 192 using TensorRef = TensorRef<Element, Layout>; 193 using TensorView = TensorView<Element, Layout>; 194 using TensorCoord = typename Layout::TensorCoord; 195 196 using Pointer = Element*; 197 using NonConstPointer = typename platform::remove_const<Element>::type*; 198 199 /// Type used for internal memory accesses 200 using AccessType = AlignedArray< 201 Element, 202 AccessSize, 203 (AccessSize * sizeof_bits<Element>::value / 8)>; 204 205 /// Underlying iterator to compute the addresses 206 using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< 207 Shape, 208 Element, 209 Layout, 210 kAdvanceRank, 211 ThreadMap, 212 AccessType, 213 Gather>; 214 215 static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; 216 217 /// Fragment object to be loaded or stored 218 using Fragment = cutlass::Array< 219 Element, 220 ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; 221 222 /// Predicate vector stores mask to guard accesses 223 using Mask = typename TileAccessIterator::Mask; 224 225 /// Parameters object is precomputed state and is host-constructible 226 class Params { 227 public: 228 using Base = typename TileAccessIterator::Params::Base; 229 230 friend PredicatedTileIteratorResidualLast; 231 232 private: 233 /// Parameters object 234 typename TileAccessIterator::Params params_; 235 236 public: 237 /// Construct the Params object given a pitch-linear tensor's layout 238 CUTLASS_HOST_DEVICE Params(Layout const & layout)239 Params(Layout const& layout) : params_(layout) {} 240 241 CUTLASS_HOST_DEVICE Params()242 Params() {} 243 244 CUTLASS_HOST_DEVICE Params(Base const & base)245 Params(Base const& base) : params_(base) {} 246 }; 247 248 private: 249 /// Internal pointer type permits fast address arithmetic 250 using BytePointer = char*; 251 252 private: 253 // 254 // Data members 255 // 256 257 /// Data member to the tile access iterator 258 TileAccessIterator address_iterator_; 259 260 public: 261 /// Constructs a TileIterator from its precomputed state, threadblock offset, 262 /// and thread ID 263 CUTLASS_HOST_DEVICE 264 PredicatedTileIteratorResidualLast( 265 /// Precomputed parameters object 266 Params const& params, 267 /// Pointer to start of tensor 268 Pointer pointer, 269 /// Extent of tensor 270 TensorCoord extent, 271 /// ID of each participating thread 272 int thread_id, 273 /// Initial offset of threadblock 274 TensorCoord const& threadblock_offset, 275 /// Gather indices 276 int const* indices = nullptr) 277 : address_iterator_( 278 params.params_, 279 pointer, 280 extent, 
281 thread_id, 282 threadblock_offset, 283 indices) {} 284 285 /// Construct a PredicatedTileIteratorResidualLast with zero threadblock 286 /// offset 287 CUTLASS_HOST_DEVICE PredicatedTileIteratorResidualLast(Params const & params,Pointer pointer,TensorCoord extent,int thread_id)288 PredicatedTileIteratorResidualLast( 289 Params const& params, ///< Precomputed parameters object 290 Pointer pointer, ///< Pointer to start of tensor 291 TensorCoord extent, ///< Extent of tensor 292 int thread_id ///< ID of each participating thread 293 ) 294 : PredicatedTileIteratorResidualLast( 295 params, 296 pointer, 297 extent, 298 thread_id, 299 make_Coord(0, 0)) {} 300 301 /// Adds a pointer offset in units of Element 302 CUTLASS_HOST_DEVICE add_pointer_offset(LongIndex pointer_offset)303 void add_pointer_offset(LongIndex pointer_offset) { 304 address_iterator_.add_pointer_offset(pointer_offset); 305 } 306 307 /// Advances to the next tile in memory. 308 /// 309 /// The first time this method is called, predicates are updated, and the 310 /// iterator's internal pointer is reverted to the first "steady state" tile. 311 /// Subsequent calls are lightweight and must only update the internal 312 /// pointer. 313 CUTLASS_HOST_DEVICE 314 PredicatedTileIteratorResidualLast& operator++() { 315 if (kAdvanceRank) 316 address_iterator_.add_tile_offset({0, 1}); 317 else 318 address_iterator_.add_tile_offset({1, 0}); 319 320 return *this; 321 } 322 323 /// Advances to the next tile in memory. 324 /// 325 /// The first time this method is called, predicates are updated, and the 326 /// iterator's internal pointer is reverted to the first "steady state" tile. 327 /// Subsequent calls are lightweight and must only update the internal 328 /// pointer. 329 CUTLASS_HOST_DEVICE 330 PredicatedTileIteratorResidualLast operator++(int) { 331 PredicatedTileIteratorResidualLast self(*this); 332 operator++(); 333 return self; 334 } 335 336 /// Clears the predicate set efficiently 337 CUTLASS_HOST_DEVICE 338 void clear_mask(bool enable = true) { 339 address_iterator_.clear_mask(enable); 340 } 341 342 CUTLASS_HOST_DEVICE set_residual_tile(bool enable)343 void set_residual_tile(bool enable) { 344 address_iterator_.set_residual_tile(enable); 345 } 346 347 /// Clears the predicate set efficiently 348 CUTLASS_HOST_DEVICE enable_mask()349 void enable_mask() { 350 address_iterator_.enable_mask(); 351 } 352 353 /// Sets the predicate mask, overriding value stored in predicate iterator 354 CUTLASS_HOST_DEVICE set_mask(Mask const & mask)355 void set_mask(Mask const& mask) { 356 address_iterator_.set_mask(mask); 357 } 358 359 /// Gets the mask 360 CUTLASS_HOST_DEVICE get_mask(Mask & mask)361 void get_mask(Mask& mask) { 362 address_iterator_.get_mask(mask); 363 } 364 365 CUTLASS_DEVICE load_with_pointer_offset(Fragment & frag,Index pointer_offset)366 void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { 367 load_with_byte_offset( 368 frag, pointer_offset * sizeof_bits<Element>::value / 8); 369 } 370 371 CUTLASS_DEVICE load_with_byte_offset(Fragment & frag,LongIndex byte_offset)372 void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { 373 AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag); 374 375 CUTLASS_PRAGMA_UNROLL 376 for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { 377 CUTLASS_PRAGMA_UNROLL 378 for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { 379 CUTLASS_PRAGMA_UNROLL 380 for (int v = 0; v < kAccessesPerVector; ++v) { 381 int idx = v + 382 kAccessesPerVector * (c + s * 
ThreadMap::Iterations::kContiguous); 383 384 address_iterator_.set_iteration_index(idx); 385 char const* byte_ptr = 386 reinterpret_cast<char const*>(address_iterator_.get()) + 387 byte_offset; 388 389 AccessType const* access_ptr = 390 reinterpret_cast<AccessType const*>(byte_ptr); 391 392 cutlass::arch::global_load<AccessType, sizeof(AccessType)>( 393 frag_ptr[idx], access_ptr, address_iterator_.valid()); 394 395 ++address_iterator_; 396 } 397 } 398 } 399 } 400 401 /// Loads a fragment from memory 402 CUTLASS_DEVICE load(Fragment & frag)403 void load(Fragment& frag) { 404 load_with_byte_offset(frag, 0); 405 } 406 407 /// Store a fragment to memory 408 CUTLASS_DEVICE store_with_pointer_offset(Fragment const & frag,Index pointer_offset)409 void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { 410 store_with_byte_offset( 411 frag, pointer_offset * sizeof_bits<Element>::value / 8); 412 } 413 414 /// Store a fragment to memory 415 CUTLASS_DEVICE store_with_byte_offset(Fragment const & frag,LongIndex byte_offset)416 void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { 417 address_iterator_.set_iteration_index(0); 418 AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag); 419 420 CUTLASS_PRAGMA_UNROLL 421 for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { 422 CUTLASS_PRAGMA_UNROLL 423 for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { 424 CUTLASS_PRAGMA_UNROLL 425 for (int v = 0; v < kAccessesPerVector; ++v) { 426 int idx = v + 427 kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); 428 429 char* byte_ptr = 430 reinterpret_cast<char*>(address_iterator_.get()) + byte_offset; 431 AccessType* access_ptr = reinterpret_cast<AccessType*>(byte_ptr); 432 433 if (address_iterator_.valid()) { 434 *access_ptr = frag_ptr[idx]; 435 } 436 ++address_iterator_; 437 } 438 } 439 } 440 } 441 442 /// Store a fragment to memory 443 CUTLASS_DEVICE store(Fragment const & frag)444 void store(Fragment const& frag) { 445 store_with_byte_offset(frag, 0); 446 } 447 }; 448 449 //////////////////////////////////////////////////////////////////////////////// 450 451 /// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. 
452 /// 453 /// Satisfies: ForwardTileIteratorConcept | 454 /// ReadableContiguousTileIteratorConcept | 455 /// WriteableContiguousTileIteratorConcept | 456 /// MaskedTileIteratorConcept 457 /// 458 template < 459 typename Shape_, 460 typename Element_, 461 int AdvanceRank, 462 typename ThreadMap_, 463 int AccessSize, 464 bool Gather> 465 class PredicatedTileIteratorResidualLast< 466 Shape_, 467 Element_, 468 layout::ColumnMajor, 469 AdvanceRank, 470 ThreadMap_, 471 AccessSize, 472 Gather> { 473 public: 474 static_assert( 475 AdvanceRank == 0 || AdvanceRank == 1, 476 "Specialization for pitch-linear iterator may along advance along the " 477 "contiguous(rank=0) or strided(rank=1) dimension."); 478 479 using Shape = Shape_; 480 using Element = Element_; 481 using Layout = layout::ColumnMajor; 482 static int const kAdvanceRank = AdvanceRank; 483 using ThreadMap = ThreadMap_; 484 485 using Index = typename Layout::Index; 486 using LongIndex = typename Layout::LongIndex; 487 488 using TensorRef = TensorRef<Element, Layout>; 489 using TensorView = TensorView<Element, Layout>; 490 using TensorCoord = typename Layout::TensorCoord; 491 492 using Pointer = Element*; 493 using NonConstPointer = typename platform::remove_const<Element>::type*; 494 495 using UnderlyingIterator = PredicatedTileIteratorResidualLast< 496 layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, 497 Element, 498 layout::PitchLinear, 499 (kAdvanceRank == 0 ? 0 : 1), 500 ThreadMap, 501 AccessSize, 502 Gather>; 503 504 using AccessType = typename UnderlyingIterator::AccessType; 505 506 /// Fragment object to be loaded or stored 507 using Fragment = cutlass::Array< 508 Element, 509 ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; 510 511 /// Predicate vector stores mask to guard accesses 512 using Mask = typename UnderlyingIterator::Mask; 513 514 /// Parameters object is precomputed state and is host-constructible 515 class Params { 516 private: 517 friend PredicatedTileIteratorResidualLast; 518 519 /// Parameters object 520 typename UnderlyingIterator::Params params_; 521 522 public: 523 CUTLASS_HOST_DEVICE Params()524 Params() {} 525 526 /// Construct the Params object given a pitch-linear tensor's layout 527 CUTLASS_HOST_DEVICE Params(Layout const & layout)528 Params(Layout const& layout) 529 : params_(layout::PitchLinear(layout.stride(0))) {} 530 531 CUTLASS_HOST_DEVICE Params(typename UnderlyingIterator::Params::Base const & base)532 Params(typename UnderlyingIterator::Params::Base const& base) 533 : params_(base) {} 534 }; 535 536 private: 537 // 538 // Data members 539 // 540 541 /// Underlying pitch-linear tile iterator 542 UnderlyingIterator iterator_; 543 544 public: 545 /// Constructs a TileIterator from its precomputed state, threadblock offset, 546 /// and thread ID 547 CUTLASS_HOST_DEVICE 548 PredicatedTileIteratorResidualLast( 549 Params const& params, ///< Precomputed parameters object 550 Pointer pointer, ///< Pointer to start of tensor 551 TensorCoord extent, ///< Extent of tensor 552 int thread_id, ///< ID of each participating thread 553 TensorCoord const& threadblock_offset, ///< Initial offset of threadblock 554 int const* indices = 555 nullptr ///< gather/scatter indices, note no support for 556 ///< gather/scatter at this specialization 557 ) 558 : iterator_( 559 params.params_, 560 pointer, 561 layout::PitchLinearCoord(extent.row(), extent.column()), 562 thread_id, 563 layout::PitchLinearCoord( 564 threadblock_offset.row(), 565 threadblock_offset.column()), 566 indices) {} 567 568 /// 
Construct a PredicatedTileIteratorResidualLast with zero threadblock 569 /// offset 570 CUTLASS_HOST_DEVICE PredicatedTileIteratorResidualLast(Params const & params,Pointer pointer,TensorCoord extent,int thread_id)571 PredicatedTileIteratorResidualLast( 572 Params const& params, ///< Precomputed parameters object 573 Pointer pointer, ///< Pointer to start of tensor 574 TensorCoord extent, ///< Extent of tensor 575 int thread_id ///< ID of each participating thread 576 ) 577 : PredicatedTileIteratorResidualLast( 578 params, 579 pointer, 580 extent, 581 thread_id, 582 make_Coord(0, 0)) {} 583 584 /// Adds a pointer offset in units of Element 585 CUTLASS_HOST_DEVICE add_pointer_offset(LongIndex pointer_offset)586 void add_pointer_offset(LongIndex pointer_offset) { 587 iterator_.add_pointer_offset(pointer_offset); 588 } 589 590 /// Advances to the next tile in memory. 591 /// 592 /// The first time this method is called, predicates are updated, and the 593 /// iterator's internal pointer is reverted to the first "steady state" tile. 594 /// Subsequent calls are lightweight and must only update the internal 595 /// pointer. 596 CUTLASS_HOST_DEVICE 597 PredicatedTileIteratorResidualLast& operator++() { 598 ++iterator_; 599 return *this; 600 } 601 602 /// Advances to the next tile in memory. 603 /// 604 /// The first time this method is called, predicates are updated, and the 605 /// iterator's internal pointer is reverted to the first "steady state" tile. 606 /// Subsequent calls are lightweight and must only update the internal 607 /// pointer. 608 CUTLASS_HOST_DEVICE 609 PredicatedTileIteratorResidualLast operator++(int) { 610 PredicatedTileIteratorResidualLast self(*this); 611 operator++(); 612 return self; 613 } 614 615 /// Clears the predicate set efficiently 616 CUTLASS_HOST_DEVICE 617 void clear_mask(bool enable = true) { 618 iterator_.clear_mask(enable); 619 } 620 621 CUTLASS_HOST_DEVICE set_residual_tile(bool enable)622 void set_residual_tile(bool enable) { 623 iterator_.set_residual_tile(enable); 624 } 625 626 /// Clears the predicate set efficiently 627 CUTLASS_HOST_DEVICE enable_mask()628 void enable_mask() { 629 iterator_.enable_mask(); 630 } 631 632 /// Sets the predicate mask, overriding value stored in predicate iterator 633 CUTLASS_HOST_DEVICE set_mask(Mask const & mask)634 void set_mask(Mask const& mask) { 635 iterator_.set_mask(mask); 636 } 637 638 /// Gets the mask 639 CUTLASS_HOST_DEVICE get_mask(Mask & mask)640 void get_mask(Mask& mask) { 641 iterator_.get_mask(mask); 642 } 643 644 /// Loads a fragment from memory 645 CUTLASS_DEVICE load_with_pointer_offset(Fragment & frag,Index pointer_offset)646 void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { 647 iterator_.load_with_pointer_offset(frag, pointer_offset); 648 } 649 650 /// Loads a fragment from memory 651 CUTLASS_DEVICE load_with_byte_offset(Fragment & frag,LongIndex byte_offset)652 void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { 653 iterator_.load_with_byte_offset(frag, byte_offset); 654 } 655 656 /// Loads a fragment from memory 657 CUTLASS_DEVICE load(Fragment & frag)658 void load(Fragment& frag) { 659 load_with_pointer_offset(frag, 0); 660 } 661 662 /// Store a fragment to memory 663 CUTLASS_DEVICE store_with_pointer_offset(Fragment const & frag,Index pointer_offset)664 void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { 665 iterator_.store_with_pointer_offset(frag, pointer_offset); 666 } 667 668 /// Store a fragment to memory 669 CUTLASS_DEVICE 
store_with_byte_offset(Fragment const & frag,LongIndex byte_offset)670 void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { 671 iterator_.store_with_byte_offset(frag, byte_offset); 672 } 673 674 /// Store a fragment to memory 675 CUTLASS_DEVICE store(Fragment const & frag)676 void store(Fragment const& frag) { 677 store_with_pointer_offset(frag, 0); 678 } 679 }; 680 681 //////////////////////////////////////////////////////////////////////////////// 682 683 /// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. 684 /// 685 /// Satisfies: ForwardTileIteratorConcept | 686 /// ReadableContiguousTileIteratorConcept | 687 /// WriteableContiguousTileIteratorConcept | 688 /// MaskedTileIteratorConcept 689 /// 690 template < 691 typename Shape_, 692 typename Element_, 693 int AdvanceRank, 694 typename ThreadMap_, 695 int AccessSize, 696 bool Gather> 697 class PredicatedTileIteratorResidualLast< 698 Shape_, 699 Element_, 700 layout::RowMajor, 701 AdvanceRank, 702 ThreadMap_, 703 AccessSize, 704 Gather> { 705 public: 706 static_assert( 707 AdvanceRank == 0 || AdvanceRank == 1, 708 "Specialization for pitch-linear iterator may along advance along the " 709 "contiguous(rank=0) or strided(rank=1) dimension."); 710 711 using Shape = Shape_; 712 using Element = Element_; 713 using Layout = layout::RowMajor; 714 static int const kAdvanceRank = AdvanceRank; 715 using ThreadMap = ThreadMap_; 716 717 using Index = typename Layout::Index; 718 using LongIndex = typename Layout::LongIndex; 719 720 using TensorRef = TensorRef<Element, Layout>; 721 using TensorView = TensorView<Element, Layout>; 722 using TensorCoord = typename Layout::TensorCoord; 723 724 using Pointer = Element*; 725 using NonConstPointer = typename platform::remove_const<Element>::type*; 726 727 using UnderlyingIterator = PredicatedTileIteratorResidualLast< 728 layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, 729 Element, 730 layout::PitchLinear, 731 (kAdvanceRank == 0 ? 
1 : 0), 732 ThreadMap, 733 AccessSize, 734 Gather>; 735 736 using AccessType = typename UnderlyingIterator::AccessType; 737 738 /// Fragment object to be loaded or stored 739 using Fragment = cutlass::Array< 740 Element, 741 ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; 742 743 /// Predicate vector stores mask to guard accesses 744 using Mask = typename UnderlyingIterator::Mask; 745 746 /// Parameters object is precomputed state and is host-constructible 747 class Params { 748 private: 749 friend PredicatedTileIteratorResidualLast; 750 751 /// Parameters object 752 typename UnderlyingIterator::Params params_; 753 754 public: 755 CUTLASS_HOST_DEVICE Params()756 Params() {} 757 758 /// Construct the Params object given a pitch-linear tensor's layout 759 CUTLASS_HOST_DEVICE Params(Layout const & layout)760 Params(Layout const& layout) 761 : params_(layout::PitchLinear(layout.stride(0))) {} 762 763 CUTLASS_HOST_DEVICE Params(typename UnderlyingIterator::Params::Base const & base)764 Params(typename UnderlyingIterator::Params::Base const& base) 765 : params_(base) {} 766 }; 767 768 private: 769 // 770 // Data members 771 // 772 773 /// Underlying pitch-linear tile iterator 774 UnderlyingIterator iterator_; 775 776 public: 777 /// Constructs a TileIterator from its precomputed state, threadblock offset, 778 /// and thread ID 779 CUTLASS_HOST_DEVICE 780 PredicatedTileIteratorResidualLast( 781 Params const& params, ///< Precomputed parameters object 782 Pointer pointer, ///< Pointer to start of tensor 783 TensorCoord extent, ///< Extent of tensor 784 int thread_id, ///< ID of each participating thread 785 TensorCoord const& threadblock_offset, ///< Initial offset of threadblock 786 int const* indices = nullptr ///< Gather indices 787 ) 788 : iterator_( 789 params.params_, 790 pointer, 791 layout::PitchLinearCoord(extent.column(), extent.row()), 792 thread_id, 793 layout::PitchLinearCoord( 794 threadblock_offset.column(), 795 threadblock_offset.row()), 796 indices) {} 797 798 /// Construct a PredicatedTileIteratorResidualLast with zero threadblock 799 /// offset 800 CUTLASS_HOST_DEVICE PredicatedTileIteratorResidualLast(Params const & params,Pointer pointer,TensorCoord extent,int thread_id)801 PredicatedTileIteratorResidualLast( 802 Params const& params, ///< Precomputed parameters object 803 Pointer pointer, ///< Pointer to start of tensor 804 TensorCoord extent, ///< Extent of tensor 805 int thread_id ///< ID of each participating thread 806 ) 807 : PredicatedTileIteratorResidualLast( 808 params, 809 pointer, 810 extent, 811 thread_id, 812 make_Coord(0, 0)) {} 813 814 /// Adds a pointer offset in units of Element 815 CUTLASS_HOST_DEVICE add_pointer_offset(LongIndex pointer_offset)816 void add_pointer_offset(LongIndex pointer_offset) { 817 iterator_.add_pointer_offset(pointer_offset); 818 } 819 820 /// Advances to the next tile in memory. 821 /// 822 /// The first time this method is called, predicates are updated, and the 823 /// iterator's internal pointer is reverted to the first "steady state" tile. 824 /// Subsequent calls are lightweight and must only update the internal 825 /// pointer. 826 CUTLASS_HOST_DEVICE 827 PredicatedTileIteratorResidualLast& operator++() { 828 ++iterator_; 829 return *this; 830 } 831 832 /// Advances to the next tile in memory. 833 /// 834 /// The first time this method is called, predicates are updated, and the 835 /// iterator's internal pointer is reverted to the first "steady state" tile. 
836 /// Subsequent calls are lightweight and must only update the internal 837 /// pointer. 838 CUTLASS_HOST_DEVICE 839 PredicatedTileIteratorResidualLast operator++(int) { 840 PredicatedTileIteratorResidualLast self(*this); 841 operator++(); 842 return self; 843 } 844 845 /// Clears the predicate set efficiently 846 CUTLASS_HOST_DEVICE 847 void clear_mask(bool enable = true) { 848 iterator_.clear_mask(enable); 849 } 850 851 CUTLASS_HOST_DEVICE set_residual_tile(bool enable)852 void set_residual_tile(bool enable) { 853 iterator_.set_residual_tile(enable); 854 } 855 856 /// Clears the predicate set efficiently 857 CUTLASS_HOST_DEVICE enable_mask()858 void enable_mask() { 859 iterator_.enable_mask(); 860 } 861 862 /// Sets the predicate mask, overriding value stored in predicate iterator 863 CUTLASS_HOST_DEVICE set_mask(Mask const & mask)864 void set_mask(Mask const& mask) { 865 iterator_.set_mask(mask); 866 } 867 868 /// Gets the mask 869 CUTLASS_HOST_DEVICE get_mask(Mask & mask)870 void get_mask(Mask& mask) { 871 iterator_.get_mask(mask); 872 } 873 874 /// Loads a fragment from memory 875 CUTLASS_DEVICE load_with_pointer_offset(Fragment & frag,Index pointer_offset)876 void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { 877 iterator_.load_with_pointer_offset(frag, pointer_offset); 878 } 879 880 /// Loads a fragment from memory 881 CUTLASS_DEVICE load_with_byte_offset(Fragment & frag,LongIndex byte_offset)882 void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { 883 iterator_.load_with_byte_offset(frag, byte_offset); 884 } 885 886 /// Loads a fragment from memory 887 CUTLASS_DEVICE load(Fragment & frag)888 void load(Fragment& frag) { 889 load_with_pointer_offset(frag, 0); 890 } 891 892 /// Store a fragment to memory 893 CUTLASS_DEVICE store_with_pointer_offset(Fragment const & frag,Index pointer_offset)894 void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { 895 iterator_.store_with_pointer_offset(frag, pointer_offset); 896 } 897 898 /// Store a fragment to memory 899 CUTLASS_DEVICE store_with_byte_offset(Fragment const & frag,LongIndex byte_offset)900 void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { 901 iterator_.store_with_byte_offset(frag, byte_offset); 902 } 903 904 /// Store a fragment to memory 905 CUTLASS_DEVICE store(Fragment const & frag)906 void store(Fragment const& frag) { 907 store_with_pointer_offset(frag, 0); 908 } 909 }; 910 911 //////////////////////////////////////////////////////////////////////////////// 912 913 /// Specialization of PredicatedTileIteratorResidualLast for affine rank-2 data. 
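///
/// A minimal sketch of constructing this specialization's Params. Unlike the
/// pitch-linear case, both element strides of the affine layout are given
/// explicitly; `stride_c` and `stride_s` are illustrative names for the
/// contiguous and strided element strides, and the shape and thread map are
/// assumed to match the earlier pitch-linear sketch:
///
//   using Iterator =
//       cutlass::transform::threadblock::PredicatedTileIteratorResidualLast<
//           cutlass::layout::PitchLinearShape<64, 8>, float,
//           cutlass::layout::AffineRankN<2>, 1, ThreadMap>;
//
//   cutlass::layout::AffineRankN<2> layout(stride_c, stride_s);
//   typename Iterator::Params params(layout);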
914 /// 915 /// Satisfies: ForwardTileIteratorConcept | 916 /// ReadableContiguousTileIteratorConcept | 917 /// WriteableContiguousTileIteratorConcept | 918 /// MaskedTileIteratorConcept 919 /// 920 template < 921 typename Shape_, 922 typename Element_, 923 int AdvanceRank, 924 typename ThreadMap_, 925 int AccessSize> 926 class PredicatedTileIteratorResidualLast< 927 Shape_, 928 Element_, 929 layout::AffineRankN<2>, 930 AdvanceRank, 931 ThreadMap_, 932 AccessSize, 933 false> { 934 public: 935 static_assert( 936 AdvanceRank == 0 || AdvanceRank == 1, 937 "Specialization for pitch-linear iterator may advance along the " 938 "contiguous(rank=0) or strided(rank=1) dimension."); 939 940 using Shape = Shape_; 941 using Element = Element_; 942 using Layout = layout::AffineRankN<2>; 943 static int const kAdvanceRank = AdvanceRank; 944 using ThreadMap = ThreadMap_; 945 946 using Index = typename Layout::Index; 947 using LongIndex = typename Layout::LongIndex; 948 949 using TensorRef = TensorRef<Element, Layout>; 950 using TensorView = TensorView<Element, Layout>; 951 using TensorCoord = typename Layout::TensorCoord; 952 953 using Pointer = Element*; 954 using NonConstPointer = typename platform::remove_const<Element>::type*; 955 956 /// Type used for internal memory accesses 957 using AccessType = AlignedArray< 958 Element, 959 AccessSize, 960 (AccessSize * sizeof_bits<Element>::value / 8)>; 961 962 /// Underlying iterator to compute the addresses 963 using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< 964 Shape, 965 Element, 966 Layout, 967 kAdvanceRank, 968 ThreadMap, 969 AccessType>; 970 971 static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; 972 973 /// Fragment object to be loaded or stored 974 using Fragment = cutlass::Array< 975 Element, 976 ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; 977 978 /// Predicate vector stores mask to guard accesses 979 using Mask = typename TileAccessIterator::Mask; 980 981 /// Parameters object is precomputed state and is host-constructible 982 class Params { 983 public: 984 friend PredicatedTileIteratorResidualLast; 985 986 private: 987 /// Parameters object 988 typename TileAccessIterator::Params params_; 989 990 public: 991 /// Construct the Params object given a pitch-linear tensor's layout 992 CUTLASS_HOST_DEVICE Params(Layout const & layout)993 Params(Layout const& layout) : params_(layout) {} 994 995 CUTLASS_HOST_DEVICE Params()996 Params() {} 997 }; 998 999 private: 1000 /// Internal pointer type permits fast address arithmetic 1001 using BytePointer = char*; 1002 1003 private: 1004 // 1005 // Data members 1006 // 1007 1008 /// Data member to the tile access iterator 1009 TileAccessIterator address_iterator_; 1010 1011 public: 1012 /// Constructs a TileIterator from its precomputed state, threadblock offset, 1013 /// and thread ID 1014 CUTLASS_HOST_DEVICE 1015 PredicatedTileIteratorResidualLast( 1016 /// Precomputed parameters object 1017 Params const& params, 1018 /// Pointer to start of tensor 1019 Pointer pointer, 1020 /// Extent of tensor 1021 TensorCoord extent, 1022 /// ID of each participating thread 1023 int thread_id, 1024 /// Initial offset of threadblock 1025 TensorCoord const& threadblock_offset, 1026 int const* indices = 1027 nullptr ///< gather/scatter indices, note no support for 1028 ///< gather/scatter at this specialization 1029 ) 1030 : address_iterator_( 1031 params.params_, 1032 pointer, 1033 extent, 1034 thread_id, 1035 threadblock_offset) {} 1036 1037 /// Construct a 
PredicatedTileIteratorResidualLast with zero threadblock 1038 /// offset 1039 CUTLASS_HOST_DEVICE PredicatedTileIteratorResidualLast(Params const & params,Pointer pointer,TensorCoord extent,int thread_id)1040 PredicatedTileIteratorResidualLast( 1041 Params const& params, ///< Precomputed parameters object 1042 Pointer pointer, ///< Pointer to start of tensor 1043 TensorCoord extent, ///< Extent of tensor 1044 int thread_id ///< ID of each participating thread 1045 ) 1046 : PredicatedTileIteratorResidualLast( 1047 params, 1048 pointer, 1049 extent, 1050 thread_id, 1051 make_Coord(0, 0)) {} 1052 1053 /// Adds a pointer offset in units of Element 1054 CUTLASS_HOST_DEVICE add_pointer_offset(LongIndex pointer_offset)1055 void add_pointer_offset(LongIndex pointer_offset) { 1056 address_iterator_.add_pointer_offset(pointer_offset); 1057 } 1058 1059 /// Advances to the next tile in memory. 1060 /// 1061 /// The first time this method is called, predicates are updated, and the 1062 /// iterator's internal pointer is reverted to the first "steady state" tile. 1063 /// Subsequent calls are lightweight and must only update the internal 1064 /// pointer. 1065 CUTLASS_HOST_DEVICE 1066 PredicatedTileIteratorResidualLast& operator++() { 1067 if (kAdvanceRank) 1068 address_iterator_.add_tile_offset(make_Coord(0, 1)); 1069 else 1070 address_iterator_.add_tile_offset(make_Coord(1, 0)); 1071 1072 return *this; 1073 } 1074 1075 /// Advances to the next tile in memory. 1076 /// 1077 /// The first time this method is called, predicates are updated, and the 1078 /// iterator's internal pointer is reverted to the first "steady state" tile. 1079 /// Subsequent calls are lightweight and must only update the internal 1080 /// pointer. 1081 CUTLASS_HOST_DEVICE 1082 PredicatedTileIteratorResidualLast operator++(int) { 1083 PredicatedTileIteratorResidualLast self(*this); 1084 operator++(); 1085 return self; 1086 } 1087 1088 /// Clears the predicate set efficiently 1089 CUTLASS_HOST_DEVICE 1090 void clear_mask(bool enable = true) { 1091 address_iterator_.clear_mask(enable); 1092 } 1093 1094 CUTLASS_HOST_DEVICE set_residual_tile(bool enable)1095 void set_residual_tile(bool enable) { 1096 address_iterator_.set_residual_tile(enable); 1097 } 1098 1099 /// Clears the predicate set efficiently 1100 CUTLASS_HOST_DEVICE enable_mask()1101 void enable_mask() { 1102 address_iterator_.enable_mask(); 1103 } 1104 1105 /// Sets the predicate mask, overriding value stored in predicate iterator 1106 CUTLASS_HOST_DEVICE set_mask(Mask const & mask)1107 void set_mask(Mask const& mask) { 1108 address_iterator_.set_mask(mask); 1109 } 1110 1111 /// Gets the mask 1112 CUTLASS_HOST_DEVICE get_mask(Mask & mask)1113 void get_mask(Mask& mask) { 1114 address_iterator_.get_mask(mask); 1115 } 1116 1117 CUTLASS_DEVICE load_with_pointer_offset(Fragment & frag,Index pointer_offset)1118 void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { 1119 load_with_byte_offset( 1120 frag, pointer_offset * sizeof_bits<Element>::value / 8); 1121 } 1122 1123 CUTLASS_DEVICE load_with_byte_offset(Fragment & frag,LongIndex byte_offset)1124 void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { 1125 AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag); 1126 1127 CUTLASS_PRAGMA_UNROLL 1128 for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { 1129 CUTLASS_PRAGMA_UNROLL 1130 for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { 1131 CUTLASS_PRAGMA_UNROLL 1132 for (int v = 0; v < kAccessesPerVector; ++v) { 1133 int idx = v + 
1134 kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); 1135 1136 address_iterator_.set_iteration_index(idx); 1137 char const* byte_ptr = 1138 reinterpret_cast<char const*>(address_iterator_.get()) + 1139 byte_offset; 1140 1141 AccessType const* access_ptr = 1142 reinterpret_cast<AccessType const*>(byte_ptr); 1143 1144 cutlass::arch::global_load<AccessType, sizeof(AccessType)>( 1145 frag_ptr[idx], access_ptr, address_iterator_.valid()); 1146 1147 ++address_iterator_; 1148 } 1149 } 1150 } 1151 } 1152 1153 /// Loads a fragment from memory 1154 CUTLASS_DEVICE load(Fragment & frag)1155 void load(Fragment& frag) { 1156 load_with_byte_offset(frag, 0); 1157 } 1158 1159 /// Store a fragment to memory 1160 CUTLASS_DEVICE store_with_pointer_offset(Fragment const & frag,Index pointer_offset)1161 void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { 1162 store_with_byte_offset( 1163 frag, pointer_offset * sizeof_bits<Element>::value / 8); 1164 } 1165 1166 /// Store a fragment to memory 1167 CUTLASS_DEVICE store_with_byte_offset(Fragment const & frag,LongIndex byte_offset)1168 void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { 1169 address_iterator_.set_iteration_index(0); 1170 AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag); 1171 1172 CUTLASS_PRAGMA_UNROLL 1173 for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { 1174 CUTLASS_PRAGMA_UNROLL 1175 for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { 1176 CUTLASS_PRAGMA_UNROLL 1177 for (int v = 0; v < kAccessesPerVector; ++v) { 1178 int idx = v + 1179 kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); 1180 1181 char* byte_ptr = 1182 reinterpret_cast<char*>(address_iterator_.get()) + byte_offset; 1183 AccessType* access_ptr = reinterpret_cast<AccessType*>(byte_ptr); 1184 1185 if (address_iterator_.valid()) { 1186 *access_ptr = frag_ptr[idx]; 1187 } 1188 ++address_iterator_; 1189 } 1190 } 1191 } 1192 } 1193 1194 /// Store a fragment to memory 1195 CUTLASS_DEVICE store(Fragment const & frag)1196 void store(Fragment const& frag) { 1197 store_with_byte_offset(frag, 0); 1198 } 1199 }; 1200 1201 //////////////////////////////////////////////////////////////////////////////// 1202 1203 /// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 1204 /// column-major data. 
1205 /// 1206 /// Satisfies: ForwardTileIteratorConcept | 1207 /// ReadableContiguousTileIteratorConcept | 1208 /// WriteableContiguousTileIteratorConcept | 1209 /// MaskedTileIteratorConcept 1210 /// 1211 template < 1212 typename Shape_, 1213 typename Element_, 1214 int AdvanceRank, 1215 typename ThreadMap_, 1216 int AccessSize> 1217 class PredicatedTileIteratorResidualLast< 1218 Shape_, 1219 Element_, 1220 layout::AffineRank2ColumnMajor, 1221 AdvanceRank, 1222 ThreadMap_, 1223 AccessSize, 1224 false> { 1225 public: 1226 static_assert( 1227 AdvanceRank == 0 || AdvanceRank == 1, 1228 "Specialization for pitch-linear iterator may along advance along the " 1229 "contiguous(rank=0) or strided(rank=1) dimension."); 1230 1231 using Shape = Shape_; 1232 using Element = Element_; 1233 using Layout = layout::AffineRank2ColumnMajor; 1234 static int const kAdvanceRank = AdvanceRank; 1235 using ThreadMap = ThreadMap_; 1236 1237 using Index = typename Layout::Index; 1238 using LongIndex = typename Layout::LongIndex; 1239 1240 using TensorRef = TensorRef<Element, Layout>; 1241 using TensorView = TensorView<Element, Layout>; 1242 using TensorCoord = typename Layout::TensorCoord; 1243 1244 using Pointer = Element*; 1245 using NonConstPointer = typename platform::remove_const<Element>::type*; 1246 1247 // Map to the underlying AffineRankN<2> layout 1248 using UnderlyingIterator = PredicatedTileIteratorResidualLast< 1249 layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, 1250 Element, 1251 layout::AffineRankN<2>, 1252 (kAdvanceRank == 0 ? 0 : 1), 1253 ThreadMap, 1254 AccessSize>; 1255 1256 using AccessType = typename UnderlyingIterator::AccessType; 1257 1258 /// Fragment object to be loaded or stored 1259 using Fragment = cutlass::Array< 1260 Element, 1261 ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; 1262 1263 /// Predicate vector stores mask to guard accesses 1264 using Mask = typename UnderlyingIterator::Mask; 1265 1266 /// Parameters object is precomputed state and is host-constructible 1267 class Params { 1268 private: 1269 friend PredicatedTileIteratorResidualLast; 1270 1271 /// Parameters object 1272 typename UnderlyingIterator::Params params_; 1273 1274 public: 1275 CUTLASS_HOST_DEVICE Params()1276 Params() {} 1277 1278 /// Construct the Params object given an AffineRankN<2> tensor's layout 1279 CUTLASS_HOST_DEVICE Params(Layout const & layout)1280 Params(Layout const& layout) 1281 : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {} 1282 }; 1283 1284 private: 1285 // 1286 // Data members 1287 // 1288 1289 /// Underlying AffineRankN<2> tile iterator 1290 UnderlyingIterator iterator_; 1291 1292 public: 1293 /// Constructs a TileIterator from its precomputed state, threadblock offset, 1294 /// and thread ID 1295 CUTLASS_HOST_DEVICE 1296 PredicatedTileIteratorResidualLast( 1297 Params const& params, ///< Precomputed parameters object 1298 Pointer pointer, ///< Pointer to start of tensor 1299 TensorCoord extent, ///< Extent of tensor 1300 int thread_id, ///< ID of each participating thread 1301 TensorCoord const& threadblock_offset, ///< Initial offset of threadblock 1302 int const* indices = 1303 nullptr ///< gather/scatter indices, note no support for 1304 ///< gather/scatter at this specialization 1305 ) 1306 : iterator_( 1307 params.params_, 1308 pointer, 1309 layout::PitchLinearCoord(extent.row(), extent.column()), 1310 thread_id, 1311 layout::PitchLinearCoord( 1312 threadblock_offset.row(), 1313 threadblock_offset.column())) {} 1314 1315 /// Construct a 
PredicatedTileIteratorResidualLast with zero threadblock 1316 /// offset 1317 CUTLASS_HOST_DEVICE PredicatedTileIteratorResidualLast(Params const & params,Pointer pointer,TensorCoord extent,int thread_id)1318 PredicatedTileIteratorResidualLast( 1319 Params const& params, ///< Precomputed parameters object 1320 Pointer pointer, ///< Pointer to start of tensor 1321 TensorCoord extent, ///< Extent of tensor 1322 int thread_id ///< ID of each participating thread 1323 ) 1324 : PredicatedTileIteratorResidualLast( 1325 params, 1326 pointer, 1327 extent, 1328 thread_id, 1329 make_Coord(0, 0)) {} 1330 1331 /// Adds a pointer offset in units of Element 1332 CUTLASS_HOST_DEVICE add_pointer_offset(LongIndex pointer_offset)1333 void add_pointer_offset(LongIndex pointer_offset) { 1334 iterator_.add_pointer_offset(pointer_offset); 1335 } 1336 1337 /// Advances to the next tile in memory. 1338 /// 1339 /// The first time this method is called, predicates are updated, and the 1340 /// iterator's internal pointer is reverted to the first "steady state" tile. 1341 /// Subsequent calls are lightweight and must only update the internal 1342 /// pointer. 1343 CUTLASS_HOST_DEVICE 1344 PredicatedTileIteratorResidualLast& operator++() { 1345 ++iterator_; 1346 return *this; 1347 } 1348 1349 /// Advances to the next tile in memory. 1350 /// 1351 /// The first time this method is called, predicates are updated, and the 1352 /// iterator's internal pointer is reverted to the first "steady state" tile. 1353 /// Subsequent calls are lightweight and must only update the internal 1354 /// pointer. 1355 CUTLASS_HOST_DEVICE 1356 PredicatedTileIteratorResidualLast operator++(int) { 1357 PredicatedTileIteratorResidualLast self(*this); 1358 operator++(); 1359 return self; 1360 } 1361 1362 /// Clears the predicate set efficiently 1363 CUTLASS_HOST_DEVICE 1364 void clear_mask(bool enable = true) { 1365 iterator_.clear_mask(enable); 1366 } 1367 1368 CUTLASS_HOST_DEVICE set_residual_tile(bool enable)1369 void set_residual_tile(bool enable) { 1370 iterator_.set_residual_tile(enable); 1371 } 1372 1373 /// Clears the predicate set efficiently 1374 CUTLASS_HOST_DEVICE enable_mask()1375 void enable_mask() { 1376 iterator_.enable_mask(); 1377 } 1378 1379 /// Sets the predicate mask, overriding value stored in predicate iterator 1380 CUTLASS_HOST_DEVICE set_mask(Mask const & mask)1381 void set_mask(Mask const& mask) { 1382 iterator_.set_mask(mask); 1383 } 1384 1385 /// Gets the mask 1386 CUTLASS_HOST_DEVICE get_mask(Mask & mask)1387 void get_mask(Mask& mask) { 1388 iterator_.get_mask(mask); 1389 } 1390 1391 /// Loads a fragment from memory 1392 CUTLASS_DEVICE load_with_pointer_offset(Fragment & frag,Index pointer_offset)1393 void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { 1394 iterator_.load_with_pointer_offset(frag, pointer_offset); 1395 } 1396 1397 /// Loads a fragment from memory 1398 CUTLASS_DEVICE load_with_byte_offset(Fragment & frag,LongIndex byte_offset)1399 void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { 1400 iterator_.load_with_byte_offset(frag, byte_offset); 1401 } 1402 1403 /// Loads a fragment from memory 1404 CUTLASS_DEVICE load(Fragment & frag)1405 void load(Fragment& frag) { 1406 load_with_pointer_offset(frag, 0); 1407 } 1408 1409 /// Store a fragment to memory 1410 CUTLASS_DEVICE store_with_pointer_offset(Fragment const & frag,Index pointer_offset)1411 void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { 1412 iterator_.store_with_pointer_offset(frag, 
pointer_offset); 1413 } 1414 1415 /// Store a fragment to memory 1416 CUTLASS_DEVICE store_with_byte_offset(Fragment const & frag,LongIndex byte_offset)1417 void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { 1418 iterator_.store_with_byte_offset(frag, byte_offset); 1419 } 1420 1421 /// Store a fragment to memory 1422 CUTLASS_DEVICE store(Fragment const & frag)1423 void store(Fragment const& frag) { 1424 store_with_pointer_offset(frag, 0); 1425 } 1426 }; 1427 1428 //////////////////////////////////////////////////////////////////////////////// 1429 1430 /// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 1431 /// row-major data. 1432 /// 1433 /// Satisfies: ForwardTileIteratorConcept | 1434 /// ReadableContiguousTileIteratorConcept | 1435 /// WriteableContiguousTileIteratorConcept | 1436 /// MaskedTileIteratorConcept 1437 /// 1438 template < 1439 typename Shape_, 1440 typename Element_, 1441 int AdvanceRank, 1442 typename ThreadMap_, 1443 int AccessSize> 1444 class PredicatedTileIteratorResidualLast< 1445 Shape_, 1446 Element_, 1447 layout::AffineRank2RowMajor, 1448 AdvanceRank, 1449 ThreadMap_, 1450 AccessSize, 1451 false> { 1452 public: 1453 static_assert( 1454 AdvanceRank == 0 || AdvanceRank == 1, 1455 "Specialization for pitch-linear iterator may along advance along the " 1456 "contiguous(rank=0) or strided(rank=1) dimension."); 1457 1458 using Shape = Shape_; 1459 using Element = Element_; 1460 using Layout = layout::AffineRank2RowMajor; 1461 static int const kAdvanceRank = AdvanceRank; 1462 using ThreadMap = ThreadMap_; 1463 1464 using Index = typename Layout::Index; 1465 using LongIndex = typename Layout::LongIndex; 1466 1467 using TensorRef = TensorRef<Element, Layout>; 1468 using TensorView = TensorView<Element, Layout>; 1469 using TensorCoord = typename Layout::TensorCoord; 1470 1471 using Pointer = Element*; 1472 using NonConstPointer = typename platform::remove_const<Element>::type*; 1473 1474 // Map to the underlying AffineRankN<2> layout 1475 using UnderlyingIterator = PredicatedTileIteratorResidualLast< 1476 layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, 1477 Element, 1478 layout::AffineRankN<2>, 1479 (kAdvanceRank == 0 ? 
1 : 0), 1480 ThreadMap, 1481 AccessSize>; 1482 1483 using AccessType = typename UnderlyingIterator::AccessType; 1484 1485 /// Fragment object to be loaded or stored 1486 using Fragment = cutlass::Array< 1487 Element, 1488 ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; 1489 1490 /// Predicate vector stores mask to guard accesses 1491 using Mask = typename UnderlyingIterator::Mask; 1492 1493 /// Parameters object is precomputed state and is host-constructible 1494 class Params { 1495 private: 1496 friend PredicatedTileIteratorResidualLast; 1497 1498 /// Parameters object 1499 typename UnderlyingIterator::Params params_; 1500 1501 public: 1502 CUTLASS_HOST_DEVICE Params()1503 Params() {} 1504 1505 /// Construct the Params object given an AffineRankN<2> tensor's layout 1506 CUTLASS_HOST_DEVICE Params(Layout const & layout)1507 Params(Layout const& layout) 1508 : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {} 1509 }; 1510 1511 private: 1512 // 1513 // Data members 1514 // 1515 1516 /// Underlying AffineRankN<2> tile iterator 1517 UnderlyingIterator iterator_; 1518 1519 public: 1520 /// Constructs a TileIterator from its precomputed state, threadblock offset, 1521 /// and thread ID 1522 CUTLASS_HOST_DEVICE 1523 PredicatedTileIteratorResidualLast( 1524 Params const& params, ///< Precomputed parameters object 1525 Pointer pointer, ///< Pointer to start of tensor 1526 TensorCoord extent, ///< Extent of tensor 1527 int thread_id, ///< ID of each participating thread 1528 TensorCoord const& threadblock_offset, ///< Initial offset of threadblock 1529 int const* indices = 1530 nullptr ///< gather/scatter indices, note no support for 1531 ///< gather/scatter at this specialization 1532 ) 1533 : iterator_( 1534 params.params_, 1535 pointer, 1536 layout::PitchLinearCoord(extent.column(), extent.row()), 1537 thread_id, 1538 layout::PitchLinearCoord( 1539 threadblock_offset.column(), 1540 threadblock_offset.row())) {} 1541 1542 /// Construct a PredicatedTileIteratorResidualLast with zero threadblock 1543 /// offset 1544 CUTLASS_HOST_DEVICE PredicatedTileIteratorResidualLast(Params const & params,Pointer pointer,TensorCoord extent,int thread_id)1545 PredicatedTileIteratorResidualLast( 1546 Params const& params, ///< Precomputed parameters object 1547 Pointer pointer, ///< Pointer to start of tensor 1548 TensorCoord extent, ///< Extent of tensor 1549 int thread_id ///< ID of each participating thread 1550 ) 1551 : PredicatedTileIteratorResidualLast( 1552 params, 1553 pointer, 1554 extent, 1555 thread_id, 1556 make_Coord(0, 0)) {} 1557 1558 /// Adds a pointer offset in units of Element 1559 CUTLASS_HOST_DEVICE add_pointer_offset(LongIndex pointer_offset)1560 void add_pointer_offset(LongIndex pointer_offset) { 1561 iterator_.add_pointer_offset(pointer_offset); 1562 } 1563 1564 /// Advances to the next tile in memory. 1565 /// 1566 /// The first time this method is called, predicates are updated, and the 1567 /// iterator's internal pointer is reverted to the first "steady state" tile. 1568 /// Subsequent calls are lightweight and must only update the internal 1569 /// pointer. 1570 CUTLASS_HOST_DEVICE 1571 PredicatedTileIteratorResidualLast& operator++() { 1572 ++iterator_; 1573 return *this; 1574 } 1575 1576 /// Advances to the next tile in memory. 1577 /// 1578 /// The first time this method is called, predicates are updated, and the 1579 /// iterator's internal pointer is reverted to the first "steady state" tile. 
1580 /// Subsequent calls are lightweight and must only update the internal 1581 /// pointer. 1582 CUTLASS_HOST_DEVICE 1583 PredicatedTileIteratorResidualLast operator++(int) { 1584 PredicatedTileIteratorResidualLast self(*this); 1585 operator++(); 1586 return self; 1587 } 1588 1589 /// Clears the predicate set efficiently 1590 CUTLASS_HOST_DEVICE 1591 void clear_mask(bool enable = true) { 1592 iterator_.clear_mask(enable); 1593 } 1594 1595 CUTLASS_HOST_DEVICE set_residual_tile(bool enable)1596 void set_residual_tile(bool enable) { 1597 iterator_.set_residual_tile(enable); 1598 } 1599 1600 /// Clears the predicate set efficiently 1601 CUTLASS_HOST_DEVICE enable_mask()1602 void enable_mask() { 1603 iterator_.enable_mask(); 1604 } 1605 1606 /// Sets the predicate mask, overriding value stored in predicate iterator 1607 CUTLASS_HOST_DEVICE set_mask(Mask const & mask)1608 void set_mask(Mask const& mask) { 1609 iterator_.set_mask(mask); 1610 } 1611 1612 /// Gets the mask 1613 CUTLASS_HOST_DEVICE get_mask(Mask & mask)1614 void get_mask(Mask& mask) { 1615 iterator_.get_mask(mask); 1616 } 1617 1618 /// Loads a fragment from memory 1619 CUTLASS_DEVICE load_with_pointer_offset(Fragment & frag,Index pointer_offset)1620 void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { 1621 iterator_.load_with_pointer_offset(frag, pointer_offset); 1622 } 1623 1624 /// Loads a fragment from memory 1625 CUTLASS_DEVICE load_with_byte_offset(Fragment & frag,LongIndex byte_offset)1626 void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { 1627 iterator_.load_with_byte_offset(frag, byte_offset); 1628 } 1629 1630 /// Loads a fragment from memory 1631 CUTLASS_DEVICE load(Fragment & frag)1632 void load(Fragment& frag) { 1633 load_with_pointer_offset(frag, 0); 1634 } 1635 1636 /// Store a fragment to memory 1637 CUTLASS_DEVICE store_with_pointer_offset(Fragment const & frag,Index pointer_offset)1638 void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { 1639 iterator_.store_with_pointer_offset(frag, pointer_offset); 1640 } 1641 1642 /// Store a fragment to memory 1643 CUTLASS_DEVICE store_with_byte_offset(Fragment const & frag,LongIndex byte_offset)1644 void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { 1645 iterator_.store_with_byte_offset(frag, byte_offset); 1646 } 1647 1648 /// Store a fragment to memory 1649 CUTLASS_DEVICE store(Fragment const & frag)1650 void store(Fragment const& frag) { 1651 store_with_pointer_offset(frag, 0); 1652 } 1653 }; 1654 1655 //////////////////////////////////////////////////////////////////////////////// 1656 1657 /// Specialization of PredicatedTileIteratorResidualLast for interleaved data. 1658 /// It is mapped to the congruous layout. 
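///
/// A hedged sketch of the coordinate remapping this specialization performs
/// when delegating to its underlying pitch-linear iterator; `extent` and
/// `threadblock_offset` stand for the MatrixCoord constructor arguments:
///
//   // ColumnMajorInterleaved<K>: K columns are folded into one contiguous
//   // "super-column", so the contiguous extent grows by K while the strided
//   // extent shrinks by K.
//   layout::PitchLinearCoord pl_extent(
//       extent.row() * kInterleavedK, extent.column() / kInterleavedK);
//   layout::PitchLinearCoord pl_offset(
//       threadblock_offset.row() * kInterleavedK,
//       threadblock_offset.column() / kInterleavedK);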
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <
    typename Shape_,
    typename Element_,
    int AdvanceRank,
    typename ThreadMap_,
    int AccessSize,
    int InterleavedK>
class PredicatedTileIteratorResidualLast<
    Shape_,
    Element_,
    layout::ColumnMajorInterleaved<InterleavedK>,
    AdvanceRank,
    ThreadMap_,
    AccessSize,
    false> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may advance along the "
      "contiguous(rank=0) or strided(rank=1) dimension.");

  using Shape = Shape_;
  using Element = Element_;
  static int const kInterleavedK = InterleavedK;
  using Layout = layout::ColumnMajorInterleaved<kInterleavedK>;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element*;
  using NonConstPointer = typename platform::remove_const<Element>::type*;

  using UnderlyingIterator = PredicatedTileIteratorResidualLast<
      layout::PitchLinearShape<
          Shape::kRow * kInterleavedK,
          Shape::kColumn / kInterleavedK>,
      Element,
      layout::PitchLinear,
      (kAdvanceRank == 0 ? 0 : 1),
      ThreadMap,
      AccessSize>;

  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<
      Element,
      ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   private:
    friend PredicatedTileIteratorResidualLast;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

   public:
    CUTLASS_HOST_DEVICE
    Params() {}

    /// Construct the Params object given a pitch-linear tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const& layout)
        : params_(layout::PitchLinear(layout.stride(0))) {}

    CUTLASS_HOST_DEVICE
    Params(typename UnderlyingIterator::Params::Base const& base)
        : params_(base) {}
  };

 private:
  //
  // Data members
  //

  /// Underlying pitch-linear tile iterator
  UnderlyingIterator iterator_;

 public:
  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      /// Precomputed parameters object
      Params const& params,
      /// Pointer to start of tensor
      Pointer pointer,
      /// Extent of tensor
      TensorCoord extent,
      /// ID of each participating thread
      int thread_id,
      /// Initial offset of threadblock
      TensorCoord const& threadblock_offset,
      int const* indices =
          nullptr ///< gather/scatter indices; gather/scatter is not
                  ///< supported by this specialization
      )
      : iterator_(
            params.params_,
            pointer,
            layout::PitchLinearCoord(
                extent.row() * kInterleavedK,
                extent.column() / kInterleavedK),
            thread_id,
            layout::PitchLinearCoord(
                threadblock_offset.row() * kInterleavedK,
                threadblock_offset.column() / kInterleavedK)) {}

  /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
  /// offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer,      ///< Pointer to start of tensor
      TensorCoord extent,   ///< Extent of tensor
      int thread_id         ///< ID of each participating thread
      )
      : PredicatedTileIteratorResidualLast(
            params,
            pointer,
            extent,
            thread_id,
            make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and need only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast& operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and need only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast operator++(int) {
    PredicatedTileIteratorResidualLast self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) {
    iterator_.clear_mask(enable);
  }

  /// Indicates whether the next tile visited is the residual (possibly
  /// partial) tile
  CUTLASS_HOST_DEVICE
  void set_residual_tile(bool enable) {
    iterator_.set_residual_tile(enable);
  }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding value stored in predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const& mask) {
    iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask& mask) {
    iterator_.get_mask(mask);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment& frag) {
    load_with_pointer_offset(frag, 0);
  }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const& frag) {
    store_with_pointer_offset(frag, 0);
  }
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIteratorResidualLast for interleaved data.
/// It is mapped to the congruous layout.
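///
/// The interleaved extent is presented to the underlying pitch-linear
/// iterator with the column dimension widened by kInterleavedK and the row
/// dimension narrowed by kInterleavedK. A minimal sketch of that mapping
/// follows (illustrative only; the extent values and the variable name are
/// hypothetical):
///
//   // For RowMajorInterleaved<4> and a logical extent of
//   // 64 rows x 128 columns, the underlying pitch-linear iterator sees:
//   //   contiguous extent = 128 * 4 = 512
//   //   strided extent    =  64 / 4 =  16
//   layout::PitchLinearCoord underlying_extent(
//       extent.column() * kInterleavedK,  // 512
//       extent.row() / kInterleavedK);    //  16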
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <
    typename Shape_,
    typename Element_,
    int AdvanceRank,
    typename ThreadMap_,
    int AccessSize,
    int InterleavedK>
class PredicatedTileIteratorResidualLast<
    Shape_,
    Element_,
    layout::RowMajorInterleaved<InterleavedK>,
    AdvanceRank,
    ThreadMap_,
    AccessSize,
    false> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may advance along the "
      "contiguous(rank=0) or strided(rank=1) dimension.");

  using Shape = Shape_;
  using Element = Element_;
  static int const kInterleavedK = InterleavedK;
  using Layout = layout::RowMajorInterleaved<kInterleavedK>;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element*;
  using NonConstPointer = typename platform::remove_const<Element>::type*;

  using UnderlyingIterator = PredicatedTileIteratorResidualLast<
      layout::PitchLinearShape<
          Shape::kColumn * kInterleavedK,
          Shape::kRow / kInterleavedK>,
      Element,
      layout::PitchLinear,
      (kAdvanceRank == 0 ? 1 : 0),
      ThreadMap,
      AccessSize>;

  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<
      Element,
      ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   private:
    friend PredicatedTileIteratorResidualLast;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

   public:
    CUTLASS_HOST_DEVICE
    Params() {}

    /// Construct the Params object given a pitch-linear tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const& layout)
        : params_(layout::PitchLinear(layout.stride(0))) {}

    CUTLASS_HOST_DEVICE
    Params(typename UnderlyingIterator::Params::Base const& base)
        : params_(base) {}
  };

 private:
  //
  // Data members
  //

  /// Underlying pitch-linear tile iterator
  UnderlyingIterator iterator_;

 public:
  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      /// Precomputed parameters object
      Params const& params,
      /// Pointer to start of tensor
      Pointer pointer,
      /// Extent of tensor
      TensorCoord extent,
      /// ID of each participating thread
      int thread_id,
      /// Initial offset of threadblock
      TensorCoord const& threadblock_offset,
      int const* indices =
          nullptr ///< gather/scatter indices; gather/scatter is not
                  ///< supported by this specialization
      )
      : iterator_(
            params.params_,
            pointer,
            layout::PitchLinearCoord(
                extent.column() * kInterleavedK,
                extent.row() / kInterleavedK),
            thread_id,
            layout::PitchLinearCoord(
                threadblock_offset.column() * kInterleavedK,
                threadblock_offset.row() / kInterleavedK)) {}

  /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
  /// offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer,      ///< Pointer to start of tensor
      TensorCoord extent,   ///< Extent of tensor
      int thread_id         ///< ID of each participating thread
      )
      : PredicatedTileIteratorResidualLast(
            params,
            pointer,
            extent,
            thread_id,
            make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and need only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast& operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and need only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast operator++(int) {
    PredicatedTileIteratorResidualLast self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) {
    iterator_.clear_mask(enable);
  }

  /// Indicates whether the next tile visited is the residual (possibly
  /// partial) tile
  CUTLASS_HOST_DEVICE
  void set_residual_tile(bool enable) {
    iterator_.set_residual_tile(enable);
  }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding value stored in predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const& mask) {
    iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask& mask) {
    iterator_.get_mask(mask);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment& frag) {
    load_with_pointer_offset(frag, 0);
  }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const& frag) {
    store_with_pointer_offset(frag, 0);
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace transform
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
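// Note on set_residual_tile(): this hook lets the consuming mainloop tell the
// iterator whether the tile it is about to load is the residual (possibly
// partial) tile or a full "steady state" tile. The sketch below is a minimal,
// presumed usage pattern and is not part of this header's API; the function
// name, loop structure, and iteration count are hypothetical, and the exact
// call sequence depends on the mainloop that consumes the iterator.
//
//   template <typename Iterator>
//   __device__ void consume_tiles(Iterator &iter, int tiles_remaining) {
//     typename Iterator::Fragment frag;
//
//     CUTLASS_PRAGMA_UNROLL
//     for (; tiles_remaining > 0; --tiles_remaining) {
//       // Arm the residual predicates only when the next load touches the
//       // last (residual) tile; otherwise steady-state predicates apply.
//       iter.set_residual_tile(tiles_remaining == 1);
//
//       iter.load(frag);
//       ++iter;
//     }
//   }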