/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue iterator that supports prefetching

  Mostly copied from <cutlass/epilogue/threadblock/predicated_tile_iterator.h>
*/

#pragma once

#include <cutlass/arch/arch.h>
#include <cutlass/arch/memory.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/threadblock/output_tile_thread_map.h>
#include <cutlass/epilogue/threadblock/predicated_tile_iterator_params.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/layout/tensor.h>
#include <cutlass/matrix_shape.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tensor_ref.h>
#include <cutlass/transform/pitch_linear_thread_map.h>

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {

////////////////////////////////////////////////////////////////////////////////

namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Tile iterator used to load and store output tile from global memory in
/// epilogue.
///
/// Satisfies: ReadableTileIterator | PredicatedTileIterator |
/// ForwardTileIterator
///
template <
    typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
    typename Element_, ///< Element data type
    bool ScatterD = false, ///< Scatter D operand or not
    bool UseCUDAStore = false>
class PredicatedTileIteratorPrefetch {
 public:
  using ThreadMap = ThreadMap_;
  using Shape = typename ThreadMap::Shape;

  using Element = Element_;

  using Layout = layout::RowMajor;
  using TensorRef = TensorRef<Element, Layout>;
  using ConstTensorRef = typename TensorRef::ConstTensorRef;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  using TensorCoord = MatrixCoord;

  static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
  static int const kThreads = ThreadMap::kThreads;
  static int const kIterations = ThreadMap::Count::kTile;

  static_assert(
      ThreadMap::Iterations::kRow > 0,
      "ThreadMap::Iterations::kRow must be > 0");
  static_assert(
      ThreadMap::Iterations::kGroup > 0,
      "ThreadMap::Iterations::kGroup must be > 0");
  static_assert(
      ThreadMap::Iterations::kCluster > 0,
      "ThreadMap::Iterations::kCluster must be > 0");
  static_assert(
      ThreadMap::Iterations::kColumn > 0,
      "ThreadMap::Iterations::kColumn must be > 0");

  /// Fragment object
  using Fragment = Array<
      Element,
      ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow *
          ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster *
          ThreadMap::kElementsPerAccess>;

  /// Memory access size
  using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;

  //
  // Parameters struct
  //

  /// Uses a non-template class
  struct Params : PredicatedTileIteratorParams {
    using Base = PredicatedTileIteratorParams;

    CUTLASS_HOST_DEVICE
    Params() {}

    CUTLASS_HOST_DEVICE
    Params(Layout const& layout)
        : PredicatedTileIteratorParams(
              layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess,
              make_OutputTileThreadMapDesc<ThreadMap>()) {}

    CUTLASS_HOST_DEVICE
    Params(Base const& base) : Base(base) {}
  };

  /// Mask object
  struct Mask {
    static int const kCount = ThreadMap::Iterations::kColumn;

    /// Predicate state
    bool predicates[kCount];

    //
    // Mask
    //
    CUTLASS_HOST_DEVICE
    Mask() {
      enable();
    }

    ///< Efficiently disables all accesses guarded by mask
    CUTLASS_HOST_DEVICE void clear() {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kCount; ++i) {
        predicates[i] = false;
      }
    }

    ///< Efficiently enables all accesses guarded by mask
    CUTLASS_DEVICE void enable() {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kCount; ++i) {
        predicates[i] = true;
      }
    }
  };

 private:
  //
  // Data members
  //

  /// Parameters structure containing reference and precomputed state.
  PredicatedTileIteratorParams params_;

  /// Byte-level pointer
  uint8_t* byte_pointer_;

  /// Array of boolean values to contain steady-state predicates
  Mask mask_;

  /// Extent of the matrix tile in rows
  Index extent_row_;

  /// Extent of the matrix tile in columns
  Index extent_column_;

  /// A thread's starting row position (assuming steady-state predicates have
  /// been computed)
  Index thread_start_row_;

  /// A thread's starting column
  Index thread_start_column_;

  /// Internal state counter
  int state_[3];

  /// Scatter indices
  int const* indices_;

  //
  // Static asserts about internal strides
  //

  static_assert(sizeof(extent_row_) == 4, "Expected 32b extents");
  static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents");
  static_assert(
      sizeof(PredicatedTileIteratorParams::stride) == 8,
      "Expected 64b strides");

 private:
  //
  // Methods
  //

 public:
  //
  // Methods
  //

  /// Constructor
  CUTLASS_DEVICE
  PredicatedTileIteratorPrefetch(
      PredicatedTileIteratorParams const& params,
      Element* pointer,
      TensorCoord extent,
      int thread_idx,
      TensorCoord threadblock_offset = TensorCoord(),
      int const* indices = nullptr)
      : params_(params), indices_(indices) {
    TensorCoord thread_offset =
        ThreadMap::initial_offset(thread_idx) + threadblock_offset;

    extent_row_ = extent.row();
    extent_column_ = extent.column();

    thread_start_row_ = thread_offset.row();
    thread_start_column_ = thread_offset.column();

    // Initialize predicates
    CUTLASS_PRAGMA_UNROLL
    for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
      mask_.predicates[c] =
          ((thread_offset.column() + ThreadMap::Delta::kColumn * c) <
           extent.column());
    }

    // Null pointer performs no accesses
    if (!pointer) {
      mask_.clear();
    }

    if (ScatterD && !indices) {
      mask_.clear();
    }

    // Initialize pointer
    byte_pointer_ = reinterpret_cast<uint8_t*>(pointer) +
        LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
        LongIndex(thread_offset.column()) * sizeof(AccessType) /
            kElementsPerAccess;

    if (ScatterD) {
      byte_pointer_ = reinterpret_cast<uint8_t*>(pointer) +
          LongIndex(thread_offset.column()) * sizeof(AccessType) /
              kElementsPerAccess;
    }

    // Initialize internal state counter
    state_[0] = state_[1] = state_[2] = 0;
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
  }

  /// Prefetches every access of the output tile, advancing this iterator by
  /// kIterations positions as it goes
  CUTLASS_DEVICE
  void prefetch_all() {
    CUTLASS_PRAGMA_UNROLL
    for (int iter = 0; iter < kIterations; ++iter) {
      prefetch();
      ++(*this);
    }
  }

  /// Prefetches the accesses of the current tile position into L1
  CUTLASS_DEVICE
  void prefetch() {
    uint8_t* byte_pointer = byte_pointer_;

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;

          AccessType* memory_pointer =
              reinterpret_cast<AccessType*>(byte_pointer);
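          // Issue an L1 prefetch hint for each column access this thread will
          // later perform on this row. The address arithmetic mirrors
          // load_with_byte_offset() / store_with_byte_offset() below; column
          // predicates are not consulted for prefetches.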
          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            // on windows using unsigned long here gives the error
            // error: asm operand type size(4) does not match
            // type/size implied by constraint 'l'
            uint64_t addr = (uint64_t)((void*)&memory_pointer
                                           [column * ThreadMap::Delta::kColumn /
                                            kElementsPerAccess]);
            asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr));
          }

          if (row + 1 < ThreadMap::Iterations::kRow) {
            if (!ScatterD) {
              byte_pointer += params_.increment_row;
            }
          }
        }

        if (group + 1 < ThreadMap::Iterations::kGroup) {
          byte_pointer += params_.increment_group;
        }
      }

      if (cluster + 1 < ThreadMap::Iterations::kCluster) {
        byte_pointer += params_.increment_cluster;
      }
    }
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const {
    uint8_t* byte_pointer = byte_pointer_;
    AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int frag_row_idx =
              (row +
               ThreadMap::Iterations::kRow *
                   (group + ThreadMap::Iterations::kGroup * cluster));

          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;

          bool row_guard = ((row_offset + thread_start_row_) < extent_row_);

          AccessType* memory_pointer =
              reinterpret_cast<AccessType*>(byte_pointer + byte_offset);

          if (ScatterD && row_guard) {
            assert(indices_);

            memory_pointer = reinterpret_cast<AccessType*>(
                byte_pointer + byte_offset +
                LongIndex(indices_[row_offset + thread_start_row_]) *
                    LongIndex(params_.stride));
          }

          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            bool guard = row_guard && mask_.predicates[column];

            cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
                frag_ptr
                    [frag_row_idx * ThreadMap::Iterations::kColumn + column],
                (void*)&memory_pointer
                    [column * ThreadMap::Delta::kColumn / kElementsPerAccess],
                guard);
          }

          if (row + 1 < ThreadMap::Iterations::kRow) {
            if (!ScatterD) {
              byte_pointer += params_.increment_row;
            }
          }
        }

        if (group + 1 < ThreadMap::Iterations::kGroup) {
          byte_pointer += params_.increment_group;
        }
      }

      if (cluster + 1 < ThreadMap::Iterations::kCluster) {
        byte_pointer += params_.increment_cluster;
      }
    }
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment& frag) const {
    load_with_byte_offset(frag, 0);
  }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const {
    uint8_t* byte_pointer = byte_pointer_;
    AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int frag_row_idx =
              (row +
               ThreadMap::Iterations::kRow *
                   (group + ThreadMap::Iterations::kGroup * cluster));

          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;

          bool row_guard = ((row_offset + thread_start_row_) < extent_row_);

          AccessType* memory_pointer =
              reinterpret_cast<AccessType*>(byte_pointer + byte_offset);

          if (ScatterD && row_guard) {
            assert(indices_);

            memory_pointer = reinterpret_cast<AccessType*>(
                byte_pointer + byte_offset +
                LongIndex(indices_[row_offset + thread_start_row_]) *
                    LongIndex(params_.stride));
          }

          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            bool guard = row_guard && mask_.predicates[column];

            if (UseCUDAStore) {
              if (guard) {
                memory_pointer
                    [column * ThreadMap::Delta::kColumn / kElementsPerAccess] =
                        frag_ptr
                            [frag_row_idx * ThreadMap::Iterations::kColumn +
                             column];
              }
            } else {
              cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
                  frag_ptr
                      [frag_row_idx * ThreadMap::Iterations::kColumn + column],
                  (void*)&memory_pointer
                      [column * ThreadMap::Delta::kColumn / kElementsPerAccess],
                  guard);
            }
          }

          if (row + 1 < ThreadMap::Iterations::kRow) {
            if (!ScatterD) {
              byte_pointer += params_.increment_row;
            }
          }
        }

        if (group + 1 < ThreadMap::Iterations::kGroup) {
          byte_pointer += params_.increment_group;
        }
      }

      if (cluster + 1 < ThreadMap::Iterations::kCluster) {
        byte_pointer += params_.increment_cluster;
      }
    }
  }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const& frag) const {
    store_with_byte_offset(frag, 0);
  }

  /// Loads a fragment from memory, remapping each output row to the
  /// corresponding row of a 2x-resolution input (downsample)
  CUTLASS_DEVICE
  void downsample_load_with_byte_offset(
      Fragment& frag,
      int64_t byte_offset,
      int convolution_P,
      int convolution_Q,
      int add_P,
      int add_Q,
      int problem_N) const {
    uint8_t* byte_pointer = byte_pointer_;
    AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int frag_row_idx =
              (row +
               ThreadMap::Iterations::kRow *
                   (group + ThreadMap::Iterations::kGroup * cluster));

          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;

          bool row_guard = ((row_offset + thread_start_row_) < extent_row_);

          int output_row = row_offset + thread_start_row_;
          int output_N = output_row / (convolution_P * convolution_Q);
          int output_PQ = output_row % (convolution_P * convolution_Q);
          int output_P = output_PQ / convolution_Q;
          int output_Q = output_PQ % convolution_Q;

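          // Downsample mapping (2x): output pixel (n, p, q) reads input pixel
          // (n, 2p + add_P, 2q + add_Q) of a double-resolution tensor. The
          // byte offset computed below (shadowing this function's byte_offset
          // argument) rebases the row pointer by (input_row - output_row)
          // rows of problem_N floats.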
          int input_row = output_N * 2 * convolution_P * 2 * convolution_Q +
              (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q +
              add_Q;

          int64_t byte_offset =
              (input_row - output_row) * problem_N * sizeof(float);

          AccessType* memory_pointer =
              reinterpret_cast<AccessType*>(byte_pointer + byte_offset);

          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            bool guard = row_guard && mask_.predicates[column];

            cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
                frag_ptr
                    [frag_row_idx * ThreadMap::Iterations::kColumn + column],
                (void*)&memory_pointer
                    [column * ThreadMap::Delta::kColumn / kElementsPerAccess],
                guard);
          }

          if (row + 1 < ThreadMap::Iterations::kRow) {
            byte_pointer += params_.increment_row;
          }
        }

        if (group + 1 < ThreadMap::Iterations::kGroup) {
          byte_pointer += params_.increment_group;
        }
      }

      if (cluster + 1 < ThreadMap::Iterations::kCluster) {
        byte_pointer += params_.increment_cluster;
      }
    }
  }

  /// Loads a fragment from memory, remapping each output row to the
  /// corresponding row of a half-resolution input (upsample)
  CUTLASS_DEVICE
  void upsample_load_with_byte_offset(
      Fragment& frag,
      int64_t byte_offset,
      int convolution_P,
      int convolution_Q,
      int add_P,
      int add_Q,
      int problem_N) const {
    uint8_t* byte_pointer = byte_pointer_;
    AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int frag_row_idx =
              (row +
               ThreadMap::Iterations::kRow *
                   (group + ThreadMap::Iterations::kGroup * cluster));

          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;

          bool row_guard = ((row_offset + thread_start_row_) < extent_row_);

          int output_row = row_offset + thread_start_row_;
          int output_N = output_row / (convolution_P * convolution_Q);
          int output_PQ = output_row % (convolution_P * convolution_Q);
          int output_P = output_PQ / convolution_Q;
          int output_Q = output_PQ % convolution_Q;

          // Upsample mapping (2x): output pixel (n, p, q) reads input pixel
          // (n, (p + add_P) / 2, (q + add_Q) / 2) of a half-resolution tensor;
          // the add offsets are dropped on the last output row/column to stay
          // in bounds.
          int row_add_P = add_P;
          int row_add_Q = add_Q;
          if (output_P > convolution_P - 2)
            row_add_P = 0;
          if (output_Q > convolution_Q - 2)
            row_add_Q = 0;

          int input_row =
              output_N * (convolution_P / 2) * (convolution_Q / 2) +
              ((output_P + row_add_P) / 2) * (convolution_Q / 2) +
              (output_Q + row_add_Q) / 2;

          int64_t byte_offset =
              (input_row - output_row) * problem_N * sizeof(float);

          AccessType* memory_pointer =
              reinterpret_cast<AccessType*>(byte_pointer + byte_offset);

          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            bool guard = row_guard && mask_.predicates[column];

            cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
                frag_ptr
                    [frag_row_idx * ThreadMap::Iterations::kColumn + column],
                (void*)&memory_pointer
                    [column * ThreadMap::Delta::kColumn / kElementsPerAccess],
                guard);
          }

          if (row + 1 < ThreadMap::Iterations::kRow) {
            byte_pointer += params_.increment_row;
          }
        }

        if (group + 1 < ThreadMap::Iterations::kGroup) {
          byte_pointer += params_.increment_group;
        }
      }

      if (cluster + 1 < ThreadMap::Iterations::kCluster) {
        byte_pointer += params_.increment_cluster;
      }
    }
  }

  CUTLASS_DEVICE
  MatrixCoord thread_start() const {
    return MatrixCoord(thread_start_row_, thread_start_column_);
  }

  /// Need to get the thread start row from the tile iterator
  CUTLASS_DEVICE
  int32_t thread_start_row() const {
    return thread_start_row_;
  }

  /// Need to get the thread start column from the tile iterator
  CUTLASS_DEVICE
  int32_t thread_start_column() const {
    return thread_start_column_;
  }

  /// Extent of the matrix in rows
  CUTLASS_DEVICE
  Index extent_row() const {
    return extent_row_;
  }

  /// Extent of the matrix in columns
  CUTLASS_DEVICE
  Index extent_column() const {
    return extent_column_;
  }

  /// Advances to the next position to load or store
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorPrefetch& operator++() {
    ++state_[0];

    if (!ScatterD) {
      byte_pointer_ += params_.advance_row;
    }

    thread_start_row_ += ThreadMap::Shape::kRow;

    if (state_[0] == ThreadMap::Count::kRow) {
      state_[0] = 0;
      ++state_[1];
      byte_pointer_ += params_.advance_group;

      thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
          ThreadMap::Shape::kRow * ThreadMap::Count::kRow;

      if (state_[1] == ThreadMap::Count::kGroup) {
        state_[1] = 0;
        ++state_[2];
        byte_pointer_ += params_.advance_cluster;

        thread_start_row_ += ThreadMap::Count::kGroup *
            ThreadMap::Shape::kGroup * ThreadMap::Count::kRow *
            ThreadMap::Shape::kRow;

        if (state_[2] == ThreadMap::Count::kCluster) {
          state_[2] = 0;
          byte_pointer_ += params_.advance_tile;
        }
      }
    }

    return *this;
  }

  ///< Efficiently disables all accesses guarded by mask
  CUTLASS_DEVICE void clear_mask() {
    mask_.clear();
  }

  ///< Efficiently enables all accesses guarded by mask
  CUTLASS_DEVICE void enable_mask() {
    mask_.enable();
  }

  ///< Gets the mask
  CUTLASS_DEVICE void get_mask(Mask& mask) const {
    mask = mask_;
  }

  ///< Sets the mask
  CUTLASS_DEVICE void set_mask(Mask const& mask) {
    mask_ = mask;
  }
};

template <typename IT>
struct MakePrefetchableIterator {
  using Iterator = PredicatedTileIteratorPrefetch<
      typename IT::ThreadMap,
      typename IT::Element>;
};

///////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
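
// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the original header, and not the
// library's prescribed epilogue wiring). It shows one plausible way to drive
// PredicatedTileIteratorPrefetch from device code: build Params from a
// row-major stride, construct the iterator for this thread, issue the
// prefetches from a throw-away copy (prefetch_all() advances the iterator it
// is called on), then store a fragment through the untouched original. The
// function name, its parameter list, and the choice to prefetch from a copy
// are assumptions made for this sketch. MakePrefetchableIterator<IT>::Iterator
// simply rewraps an existing output-iterator type IT (via IT::ThreadMap and
// IT::Element) into this prefetch-capable variant.
// ---------------------------------------------------------------------------
namespace example {

template <typename ThreadMap, typename Element>
CUTLASS_DEVICE void prefetch_then_store(
    Element* output_ptr, // base pointer of the row-major output matrix
    int ldm, // leading dimension in elements
    cutlass::MatrixCoord extent, // problem extent (rows, columns)
    int thread_idx, // linear thread index within the threadblock
    cutlass::MatrixCoord tb_offset, // this threadblock's tile offset
    typename cutlass::epilogue::threadblock::
        PredicatedTileIteratorPrefetch<ThreadMap, Element>::Fragment const&
            frag) {
  using Iterator = cutlass::epilogue::threadblock::
      PredicatedTileIteratorPrefetch<ThreadMap, Element>;

  // Params are derived from the row-major layout of the output tensor.
  typename Iterator::Params params(cutlass::layout::RowMajor(ldm));

  Iterator iterator(params, output_ptr, extent, thread_idx, tb_offset);

  // Prefetch from a copy: prefetch_all() walks and increments its iterator.
  Iterator prefetch_iterator(iterator);
  prefetch_iterator.prefetch_all();

  // Store the epilogue fragment through the original, unadvanced iterator.
  iterator.store(frag);
}

} // namespace example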