/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue iterator that supports prefetching

  Mostly copied from <cutlass/epilogue/threadblock/predicated_tile_iterator.h>
*/

#pragma once

#include <cutlass/arch/arch.h>
#include <cutlass/arch/memory.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/threadblock/output_tile_thread_map.h>
#include <cutlass/epilogue/threadblock/predicated_tile_iterator_params.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/layout/tensor.h>
#include <cutlass/matrix_shape.h>
#include <cutlass/numeric_types.h>
#include <cutlass/tensor_ref.h>
#include <cutlass/transform/pitch_linear_thread_map.h>

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {

////////////////////////////////////////////////////////////////////////////////

namespace epilogue {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Tile iterator used to load and store output tile from global memory in
/// epilogue.
///
/// Satisfies: ReadableTileIterator | PredicatedTileIterator |
/// ForwardTileIterator
///
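/// A minimal usage sketch (illustrative only; the thread map, output pointer,
/// extent, and threadblock offset are assumed to come from the surrounding
/// epilogue, and the names below are placeholders):
///
///   using Iterator =
///       PredicatedTileIteratorPrefetch<OutputTileThreadMap, float>;
///   Iterator::Params params(layout);  // layout::RowMajor of the output
///   Iterator iter(params, ptr, extent, thread_idx, threadblock_offset);
///   iter.prefetch_all();  // warms L1; note this advances the iterator
///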
template <
    typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
    typename Element_, ///< Element data type
    bool ScatterD = false, ///< Scatter D operand or not
    bool UseCUDAStore = false>
class PredicatedTileIteratorPrefetch {
 public:
  using ThreadMap = ThreadMap_;
  using Shape = typename ThreadMap::Shape;

  using Element = Element_;

  using Layout = layout::RowMajor;
  using TensorRef = TensorRef<Element, Layout>;
  using ConstTensorRef = typename TensorRef::ConstTensorRef;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;
  using TensorCoord = MatrixCoord;

  static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
  static int const kThreads = ThreadMap::kThreads;
  static int const kIterations = ThreadMap::Count::kTile;

  static_assert(
      ThreadMap::Iterations::kRow > 0,
      "ThreadMap::Iterations::kRow must be > 0");
  static_assert(
      ThreadMap::Iterations::kGroup > 0,
      "ThreadMap::Iterations::kGroup must be > 0");
  static_assert(
      ThreadMap::Iterations::kCluster > 0,
      "ThreadMap::Iterations::kCluster must be > 0");
  static_assert(
      ThreadMap::Iterations::kColumn > 0,
      "ThreadMap::Iterations::kColumn must be > 0");

  /// Fragment object
  using Fragment = Array<
      Element,
      ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow *
          ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster *
          ThreadMap::kElementsPerAccess>;

  /// Memory access size
  using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;

  //
  // Parameters struct
  //

  /// Uses a non-template class
  struct Params : PredicatedTileIteratorParams {
    using Base = PredicatedTileIteratorParams;

    CUTLASS_HOST_DEVICE
    Params() {}

    CUTLASS_HOST_DEVICE
    Params(Layout const& layout)
        : PredicatedTileIteratorParams(
              layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess,
              make_OutputTileThreadMapDesc<ThreadMap>()) {}

    CUTLASS_HOST_DEVICE
    Params(Base const& base) : Base(base) {}
  };

  /// Mask object
  struct Mask {
    static int const kCount = ThreadMap::Iterations::kColumn;

    /// Predicate state
    bool predicates[kCount];

    //
    // Mask
    //
    CUTLASS_HOST_DEVICE
    Mask() {
      enable();
    }

    ///< Efficiently disables all accesses guarded by mask
    CUTLASS_HOST_DEVICE void clear() {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kCount; ++i) {
        predicates[i] = false;
      }
    }

    ///< Efficiently enables all accesses guarded by mask
    CUTLASS_DEVICE void enable() {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kCount; ++i) {
        predicates[i] = true;
      }
    }
  };

 private:
  //
  // Data members
  //

  /// Parameters structure containing reference and precomputed state.
  PredicatedTileIteratorParams params_;

  /// Byte-level pointer
  uint8_t* byte_pointer_;

  /// Array of boolean values to contain steady-state predicates
  Mask mask_;

  /// Extent of the matrix tile in rows
  Index extent_row_;

  /// Extent of the matrix tile in columns
  Index extent_column_;

  /// A thread's starting row position (assuming steady-state predicates have
  /// been computed)
  Index thread_start_row_;

  /// A thread's starting column
  Index thread_start_column_;

  /// Internal state counter
  int state_[3];

  /// Scatter indices
  int const* indices_;

  //
  // Static asserts about internal strides
  //

  static_assert(sizeof(extent_row_) == 4, "Expected 32b extents");
  static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents");
  static_assert(
      sizeof(PredicatedTileIteratorParams::stride) == 8,
      "Expected 64b strides");

 private:
  //
  // Methods
  //

 public:
  //
  // Methods
  //

  /// Constructor
  CUTLASS_DEVICE
  PredicatedTileIteratorPrefetch(
      PredicatedTileIteratorParams const& params,
      Element* pointer,
      TensorCoord extent,
      int thread_idx,
      TensorCoord threadblock_offset = TensorCoord(),
      int const* indices = nullptr)
      : params_(params), indices_(indices) {
    TensorCoord thread_offset =
        ThreadMap::initial_offset(thread_idx) + threadblock_offset;

    extent_row_ = extent.row();
    extent_column_ = extent.column();

    thread_start_row_ = thread_offset.row();
    thread_start_column_ = thread_offset.column();

    // Initialize predicates
    CUTLASS_PRAGMA_UNROLL
    for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
      mask_.predicates[c] =
          ((thread_offset.column() + ThreadMap::Delta::kColumn * c) <
           extent.column());
    }

    // Null pointer performs no accesses
    if (!pointer) {
      mask_.clear();
    }

    if (ScatterD && !indices) {
      mask_.clear();
    }

    // Initialize pointer
    byte_pointer_ = reinterpret_cast<uint8_t*>(pointer) +
        LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
        LongIndex(thread_offset.column()) * sizeof(AccessType) /
            kElementsPerAccess;

    if (ScatterD) {
      byte_pointer_ = reinterpret_cast<uint8_t*>(pointer) +
          LongIndex(thread_offset.column()) * sizeof(AccessType) /
              kElementsPerAccess;
    }

    // Initialize internal state counter
    state_[0] = state_[1] = state_[2] = 0;
  }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
  }

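  /// Prefetches the entire output footprint covered by this iterator into L1.
  /// Note that prefetch_all() advances the iterator as it goes, so it is
  /// typically called on a separate iterator instance (or before the iterator
  /// is used for loads/stores).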
  CUTLASS_DEVICE
  void prefetch_all() {
    CUTLASS_PRAGMA_UNROLL
    for (int iter = 0; iter < kIterations; ++iter) {
      prefetch();
      ++(*this);
    }
  }

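  /// Issues L1 prefetch instructions (prefetch.global.L1) for the global
  /// memory addresses of the current tile, without loading any data into
  /// registers.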
  CUTLASS_DEVICE
  void prefetch() {
    uint8_t* byte_pointer = byte_pointer_;

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;

          AccessType* memory_pointer =
              reinterpret_cast<AccessType*>(byte_pointer);

          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            // On Windows, using unsigned long here gives:
            //   error: asm operand type size(4) does not match
            //   type/size implied by constraint 'l'
            uint64_t addr = (uint64_t)((void*)&memory_pointer
                                           [column * ThreadMap::Delta::kColumn /
                                            kElementsPerAccess]);
            asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr));
          }

          if (row + 1 < ThreadMap::Iterations::kRow) {
            if (!ScatterD) {
              byte_pointer += params_.increment_row;
            }
          }
        }

        if (group + 1 < ThreadMap::Iterations::kGroup) {
          byte_pointer += params_.increment_group;
        }
      }

      if (cluster + 1 < ThreadMap::Iterations::kCluster) {
        byte_pointer += params_.increment_cluster;
      }
    }
  }


  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const {
    uint8_t* byte_pointer = byte_pointer_;
    AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int frag_row_idx =
              (row +
               ThreadMap::Iterations::kRow *
                   (group + ThreadMap::Iterations::kGroup * cluster));

          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;

          bool row_guard = ((row_offset + thread_start_row_) < extent_row_);

          AccessType* memory_pointer =
              reinterpret_cast<AccessType*>(byte_pointer + byte_offset);

          if (ScatterD && row_guard) {
            assert(indices_);

            memory_pointer = reinterpret_cast<AccessType*>(
                byte_pointer + byte_offset +
                LongIndex(indices_[row_offset + thread_start_row_]) *
                    LongIndex(params_.stride));
          }

          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            bool guard = row_guard && mask_.predicates[column];

            cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
                frag_ptr
                    [frag_row_idx * ThreadMap::Iterations::kColumn + column],
                (void*)&memory_pointer
                    [column * ThreadMap::Delta::kColumn / kElementsPerAccess],
                guard);
          }

          if (row + 1 < ThreadMap::Iterations::kRow) {
            if (!ScatterD) {
              byte_pointer += params_.increment_row;
            }
          }
        }

        if (group + 1 < ThreadMap::Iterations::kGroup) {
          byte_pointer += params_.increment_group;
        }
      }

      if (cluster + 1 < ThreadMap::Iterations::kCluster) {
        byte_pointer += params_.increment_cluster;
      }
    }
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment& frag) const {
    load_with_byte_offset(frag, 0);
  }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const {
    uint8_t* byte_pointer = byte_pointer_;
    AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int frag_row_idx =
              (row +
               ThreadMap::Iterations::kRow *
                   (group + ThreadMap::Iterations::kGroup * cluster));

          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;

          bool row_guard = ((row_offset + thread_start_row_) < extent_row_);

          AccessType* memory_pointer =
              reinterpret_cast<AccessType*>(byte_pointer + byte_offset);

          if (ScatterD && row_guard) {
            assert(indices_);

            memory_pointer = reinterpret_cast<AccessType*>(
                byte_pointer + byte_offset +
                LongIndex(indices_[row_offset + thread_start_row_]) *
                    LongIndex(params_.stride));
          }

          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            bool guard = row_guard && mask_.predicates[column];

            if (UseCUDAStore) {
              if (guard) {
                memory_pointer
                    [column * ThreadMap::Delta::kColumn / kElementsPerAccess] =
                        frag_ptr
                            [frag_row_idx * ThreadMap::Iterations::kColumn +
                             column];
              }
            } else {
              cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
                  frag_ptr
                      [frag_row_idx * ThreadMap::Iterations::kColumn + column],
                  (void*)&memory_pointer
                      [column * ThreadMap::Delta::kColumn / kElementsPerAccess],
                  guard);
            }
          }

          if (row + 1 < ThreadMap::Iterations::kRow) {
            if (!ScatterD) {
              byte_pointer += params_.increment_row;
            }
          }
        }

        if (group + 1 < ThreadMap::Iterations::kGroup) {
          byte_pointer += params_.increment_group;
        }
      }

      if (cluster + 1 < ThreadMap::Iterations::kCluster) {
        byte_pointer += params_.increment_cluster;
      }
    }
  }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const& frag) const {
    store_with_byte_offset(frag, 0);
  }

  /// Loads a fragment from memory, remapping each output row to the
  /// corresponding row of a 2x larger (2*P x 2*Q) input feature map
  /// (used for fused downsampling of a convolution output).
  CUTLASS_DEVICE
  void downsample_load_with_byte_offset(
      Fragment& frag,
      int64_t byte_offset,
      int convolution_P,
      int convolution_Q,
      int add_P,
      int add_Q,
      int problem_N) const {
    uint8_t* byte_pointer = byte_pointer_;
    AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int frag_row_idx =
              (row +
               ThreadMap::Iterations::kRow *
                   (group + ThreadMap::Iterations::kGroup * cluster));

          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;

          bool row_guard = ((row_offset + thread_start_row_) < extent_row_);

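          // Decompose the linear output row into (n, p, q) coordinates, then
          // map it to the matching row of the 2x larger (2*P x 2*Q) input
          // feature map; add_P/add_Q select the sub-pixel within the 2x2
          // input window. The recomputed byte_offset below shadows the
          // function parameter of the same name.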
          int output_row = row_offset + thread_start_row_;
          int output_N = output_row / (convolution_P * convolution_Q);
          int output_PQ = output_row % (convolution_P * convolution_Q);
          int output_P = output_PQ / convolution_Q;
          int output_Q = output_PQ % convolution_Q;

          int input_row = output_N * 2 * convolution_P * 2 * convolution_Q +
              (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q;

          int64_t byte_offset =
              (input_row - output_row) * problem_N * sizeof(float);

          AccessType* memory_pointer =
              reinterpret_cast<AccessType*>(byte_pointer + byte_offset);

          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            bool guard = row_guard && mask_.predicates[column];

            cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
                frag_ptr
                    [frag_row_idx * ThreadMap::Iterations::kColumn + column],
                (void*)&memory_pointer
                    [column * ThreadMap::Delta::kColumn / kElementsPerAccess],
                guard);
          }

          if (row + 1 < ThreadMap::Iterations::kRow) {
            byte_pointer += params_.increment_row;
          }
        }

        if (group + 1 < ThreadMap::Iterations::kGroup) {
          byte_pointer += params_.increment_group;
        }
      }

      if (cluster + 1 < ThreadMap::Iterations::kCluster) {
        byte_pointer += params_.increment_cluster;
      }
    }
  }

  /// Loads a fragment from memory, remapping each output row to the
  /// corresponding row of a 2x smaller (P/2 x Q/2) input feature map
  /// (used for fused upsampling of a convolution output).
  CUTLASS_DEVICE
  void upsample_load_with_byte_offset(
      Fragment& frag,
      int64_t byte_offset,
      int convolution_P,
      int convolution_Q,
      int add_P,
      int add_Q,
      int problem_N) const {
    uint8_t* byte_pointer = byte_pointer_;
    AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
         ++cluster) {
      CUTLASS_PRAGMA_UNROLL
      for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
        CUTLASS_PRAGMA_UNROLL
        for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
          int frag_row_idx =
              (row +
               ThreadMap::Iterations::kRow *
                   (group + ThreadMap::Iterations::kGroup * cluster));

          int row_offset = row * ThreadMap::Delta::kRow +
              group * ThreadMap::Delta::kGroup +
              cluster * ThreadMap::Delta::kCluster;

          bool row_guard = ((row_offset + thread_start_row_) < extent_row_);

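          // Decompose the linear output row into (n, p, q) coordinates, then
          // map it to the matching row of the 2x smaller (P/2 x Q/2) input
          // feature map; add_P/add_Q bias the rounding and are dropped at the
          // right/bottom border. The recomputed byte_offset below shadows the
          // function parameter of the same name.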
          int output_row = row_offset + thread_start_row_;
          int output_N = output_row / (convolution_P * convolution_Q);
          int output_PQ = output_row % (convolution_P * convolution_Q);
          int output_P = output_PQ / convolution_Q;
          int output_Q = output_PQ % convolution_Q;
          int row_add_P = add_P;
          int row_add_Q = add_Q;
          if (output_P > convolution_P - 2)
            row_add_P = 0;
          if (output_Q > convolution_Q - 2)
            row_add_Q = 0;

          int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) +
              ((output_P + row_add_P) / 2) * (convolution_Q / 2) +
              (output_Q + row_add_Q) / 2;

          int64_t byte_offset =
              (input_row - output_row) * problem_N * sizeof(float);

          AccessType* memory_pointer =
              reinterpret_cast<AccessType*>(byte_pointer + byte_offset);

          CUTLASS_PRAGMA_UNROLL
          for (int column = 0; column < ThreadMap::Iterations::kColumn;
               ++column) {
            bool guard = row_guard && mask_.predicates[column];

            cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
                frag_ptr
                    [frag_row_idx * ThreadMap::Iterations::kColumn + column],
                (void*)&memory_pointer
                    [column * ThreadMap::Delta::kColumn / kElementsPerAccess],
                guard);
          }

          if (row + 1 < ThreadMap::Iterations::kRow) {
            byte_pointer += params_.increment_row;
          }
        }

        if (group + 1 < ThreadMap::Iterations::kGroup) {
          byte_pointer += params_.increment_group;
        }
      }

      if (cluster + 1 < ThreadMap::Iterations::kCluster) {
        byte_pointer += params_.increment_cluster;
      }
    }
  }

  CUTLASS_DEVICE
  MatrixCoord thread_start() const {
    return MatrixCoord(thread_start_row_, thread_start_column_);
  }

  /// Need to get the thread start row from the tile iterator
  CUTLASS_DEVICE
  int32_t thread_start_row() const {
    return thread_start_row_;
  }

  /// Need to get the thread start column from the tile iterator
  CUTLASS_DEVICE
  int32_t thread_start_column() const {
    return thread_start_column_;
  }

  /// Extent of the matrix in rows
  CUTLASS_DEVICE
  Index extent_row() const {
    return extent_row_;
  }

  /// Extent of the matrix in columns
  CUTLASS_DEVICE
  Index extent_column() const {
    return extent_column_;
  }

  /// Advances to the next position to load or store
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorPrefetch& operator++() {
    ++state_[0];

    if (!ScatterD) {
      byte_pointer_ += params_.advance_row;
    }

    thread_start_row_ += ThreadMap::Shape::kRow;

    if (state_[0] == ThreadMap::Count::kRow) {
      state_[0] = 0;
      ++state_[1];
      byte_pointer_ += params_.advance_group;

      thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
          ThreadMap::Shape::kRow * ThreadMap::Count::kRow;

      if (state_[1] == ThreadMap::Count::kGroup) {
        state_[1] = 0;
        ++state_[2];
        byte_pointer_ += params_.advance_cluster;

        thread_start_row_ += ThreadMap::Count::kGroup *
            ThreadMap::Shape::kGroup * ThreadMap::Count::kRow *
            ThreadMap::Shape::kRow;

        if (state_[2] == ThreadMap::Count::kCluster) {
          state_[2] = 0;
          byte_pointer_ += params_.advance_tile;
        }
      }
    }

    return *this;
  }

  ///< Efficiently disables all accesses guarded by mask
  CUTLASS_DEVICE void clear_mask() {
    mask_.clear();
  }

  ///< Efficiently enables all accesses guarded by mask
  CUTLASS_DEVICE void enable_mask() {
    mask_.enable();
  }

  ///< Gets the mask
  CUTLASS_DEVICE void get_mask(Mask& mask) const {
    mask = mask_;
  }

  ///< Sets the mask
  CUTLASS_DEVICE void set_mask(Mask const& mask) {
    mask_ = mask;
  }
};

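/// Rebinds an existing epilogue output tile iterator type to its
/// prefetch-enabled equivalent, preserving its ThreadMap and Element types.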
template <typename IT>
struct MakePrefetchableIterator {
  using Iterator = PredicatedTileIteratorPrefetch<
      typename IT::ThreadMap,
      typename IT::Element>;
};
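
// A hypothetical usage sketch (`DefaultIterator` is a placeholder for any
// output tile iterator type exposing `ThreadMap` and `Element`); note that
// ScatterD and UseCUDAStore fall back to their default values:
//
//   using PrefetchIterator =
//       typename MakePrefetchableIterator<DefaultIterator>::Iterator;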

///////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
