/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing loading of tiles from pitch-linear rank=2
   tensors.

    This iterator uses masks to guard out-of-bounds accesses. The first tile
   this iterator visits may be partial; the remaining tiles are complete, so
   predicates need to be computed only twice: once for the first (residual)
   tile and once for the remaining full tiles, which can all share the same
   predicates.

    A precomputed "Params" object minimizes the amount of state that must be
   stored in registers, and integer addition is used to advance the pointer
   through memory.
*/

#pragma once

#include <cutlass/arch/memory.h>
#include <cutlass/transform/threadblock/predicated_tile_access_iterator.h>

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace transform {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// PredicatedTileIteratorResidualLast
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
/// Regular tile iterator using a precomputed control structure to minimize
/// register liveness and integer arithmetic.
///
/// Layout is assumed to be invariant at the time the precomputed "Params"
/// object is constructed.
///
/// Base pointer and tensor extents may be specified at the time the iterator is
/// constructed. Subsequently, they are assumed to be immutable.
///
/// Adding a logical coordinate offset may be performed at the time the iterator
/// is constructed. Subsequent additions to the logical coordinate offset may be
/// performed but are relatively expensive.
///
/// Visitation order is intended to first visit a "residual" tile that may be
/// partially full in both the advance dimension and the steady-state dimension.
/// This is assumed to be the last tile in the iteration sequence. Advancing an
/// iterator that has just been constructed moves to the first tile that is full
/// in the advance dimension and recomputes predicates. Subsequent accesses may
/// be performed without updating internal predicates and are efficient in terms
/// of live register state and pointer arithmetic instructions.
///
/// To be efficient, this assumes the iterator will be dereferenced and advanced
/// at least once outside any looping structure to minimize integer arithmetic.
///
/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to
/// dereferencing the iterator.
///
///
/// Example:
///
/// An efficient pipeline structure may be constructed as follows:
///
// template <typename Iterator>
// __global__ void kernel(
//   typename Iterator::Params params,
//   typename Iterator::Element *ptr,
//   TensorCoord extent) {
//
//   typename Iterator::Fragment fragment;
//
//   TensorCoord threadblock_offset(0, 0);
//
//   Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offset);
//
//
//   fragment = *iter;        // load "residue" tile first
//   ++iter;                  // advance to first "steady state" tile and update
//   internal masks
//
//
//   #pragma unroll
//   for (int i = Remaining - 1; i >= 0; --i) {
//
//     f(fragment);
//
//     if (!i) {
//       iter.clear_mask();   // lightweight operation to clear masks -
//       subsequent loads become NO-OPs.
//     }
//
//     fragment = *iter;      // load tile during "steady state" phase
//     ++iter;                // advance to next tile - lightweight due to
//     steady-state masks
//   }
// }
//
// void host(TensorView<Element, 2, layout::PitchLinear> view) {
//
//   using Iterator =
//   transform::threadblock::PredicatedTileIteratorResidualLast;
//
//   typename Iterator::Params params(view.layout());
//
//   kernel<Iterator><<<grid, block>>>(params, view.data(), view.extent());
// }
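///
/// A possible usage sketch (not part of the original example) of the
/// residual-last behavior: the mainloop visits all full tiles first and the
/// partial "residual" tile last, switching predicates with
/// `set_residual_tile()` just before the final iteration. The names
/// `gemm_k_iterations` and `f()` are placeholders assumed for illustration.
///
// template <typename Iterator>
// __global__ void kernel_residual_last(
//   typename Iterator::Params params,
//   typename Iterator::Element *ptr,
//   TensorCoord extent,
//   int gemm_k_iterations) {
//
//   typename Iterator::Fragment fragment;
//
//   Iterator iter(params, ptr, extent, threadIdx.x, TensorCoord(0, 0));
//
//   // If only one tile remains, that tile is already the residual tile.
//   iter.set_residual_tile(gemm_k_iterations == 1);
//
//   for (int k = gemm_k_iterations; k > 0; --k) {
//
//     iter.load(fragment);              // predicated load
//     f(fragment);                      // consume the fragment
//
//     iter.set_residual_tile(k == 2);   // the next tile is the residual one
//     ++iter;                           // lightweight pointer advance
//   }
// }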
///
///
template <
    typename Shape,
    typename Element,
    typename Layout,
    int AdvanceRank,
    typename ThreadMap,
    int AccessSize = ThreadMap::kElementsPerAccess,
    bool Gather = false>
class PredicatedTileIteratorResidualLast;

////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <
    typename Shape_,
    typename Element_,
    int AdvanceRank,
    typename ThreadMap_,
    int AccessSize,
    bool Gather>
class PredicatedTileIteratorResidualLast<
    Shape_,
    Element_,
    layout::PitchLinear,
    AdvanceRank,
    ThreadMap_,
    AccessSize,
    Gather> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may advance along the "
      "contiguous(rank=0) or strided(rank=1) dimension.");

  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::PitchLinear;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element*;
  using NonConstPointer = typename platform::remove_const<Element>::type*;

  /// Type used for internal memory accesses
  using AccessType = AlignedArray<
      Element,
      AccessSize,
      (AccessSize * sizeof_bits<Element>::value / 8)>;

  /// Underlying iterator to compute the addresses
  using TileAccessIterator = PredicatedTileAccessIteratorResidualLast<
      Shape,
      Element,
      Layout,
      kAdvanceRank,
      ThreadMap,
      AccessType,
      Gather>;

  static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<
      Element,
      ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename TileAccessIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   public:
    using Base = typename TileAccessIterator::Params::Base;

    friend PredicatedTileIteratorResidualLast;

   private:
    /// Parameters object
    typename TileAccessIterator::Params params_;

   public:
    /// Construct the Params object given a pitch-linear tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const& layout) : params_(layout) {}

    CUTLASS_HOST_DEVICE
    Params() {}

    CUTLASS_HOST_DEVICE
    Params(Base const& base) : params_(base) {}
  };

 private:
  /// Internal pointer type permits fast address arithmetic
  using BytePointer = char*;

 private:
  //
  // Data members
  //

  /// Data member to the tile access iterator
  TileAccessIterator address_iterator_;

 public:
  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      /// Precomputed parameters object
      Params const& params,
      /// Pointer to start of tensor
      Pointer pointer,
      /// Extent of tensor
      TensorCoord extent,
      /// ID of each participating thread
      int thread_id,
      /// Initial offset of threadblock
      TensorCoord const& threadblock_offset,
      /// Gather indices
      int const* indices = nullptr)
      : address_iterator_(
            params.params_,
            pointer,
            extent,
            thread_id,
            threadblock_offset,
            indices) {}

  /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
  /// offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer, ///< Pointer to start of tensor
      TensorCoord extent, ///< Extent of tensor
      int thread_id ///< ID of each participating thread
      )
      : PredicatedTileIteratorResidualLast(
            params,
            pointer,
            extent,
            thread_id,
            make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    address_iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast& operator++() {
    if (kAdvanceRank)
      address_iterator_.add_tile_offset({0, 1});
    else
      address_iterator_.add_tile_offset({1, 0});

    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast operator++(int) {
    PredicatedTileIteratorResidualLast self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) {
    address_iterator_.clear_mask(enable);
  }

  CUTLASS_HOST_DEVICE
  void set_residual_tile(bool enable) {
    address_iterator_.set_residual_tile(enable);
  }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    address_iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding value stored in predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const& mask) {
    address_iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask& mask) {
    address_iterator_.get_mask(mask);
  }

  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
    load_with_byte_offset(
        frag, pointer_offset * sizeof_bits<Element>::value / 8);
  }

  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
    AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      CUTLASS_PRAGMA_UNROLL
      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < kAccessesPerVector; ++v) {
          int idx = v +
              kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);

          address_iterator_.set_iteration_index(idx);
          char const* byte_ptr =
              reinterpret_cast<char const*>(address_iterator_.get()) +
              byte_offset;

          AccessType const* access_ptr =
              reinterpret_cast<AccessType const*>(byte_ptr);

          cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
              frag_ptr[idx], access_ptr, address_iterator_.valid());

          ++address_iterator_;
        }
      }
    }
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment& frag) {
    load_with_byte_offset(frag, 0);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
    store_with_byte_offset(
        frag, pointer_offset * sizeof_bits<Element>::value / 8);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
    address_iterator_.set_iteration_index(0);
    AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      CUTLASS_PRAGMA_UNROLL
      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < kAccessesPerVector; ++v) {
          int idx = v +
              kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);

          char* byte_ptr =
              reinterpret_cast<char*>(address_iterator_.get()) + byte_offset;
          AccessType* access_ptr = reinterpret_cast<AccessType*>(byte_ptr);

          if (address_iterator_.valid()) {
            *access_ptr = frag_ptr[idx];
          }
          ++address_iterator_;
        }
      }
    }
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const& frag) {
    store_with_byte_offset(frag, 0);
  }
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIteratorResidualLast for column-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <
    typename Shape_,
    typename Element_,
    int AdvanceRank,
    typename ThreadMap_,
    int AccessSize,
    bool Gather>
class PredicatedTileIteratorResidualLast<
    Shape_,
    Element_,
    layout::ColumnMajor,
    AdvanceRank,
    ThreadMap_,
    AccessSize,
    Gather> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may advance along the "
      "contiguous(rank=0) or strided(rank=1) dimension.");

  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::ColumnMajor;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element*;
  using NonConstPointer = typename platform::remove_const<Element>::type*;

  using UnderlyingIterator = PredicatedTileIteratorResidualLast<
      layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
      Element,
      layout::PitchLinear,
      (kAdvanceRank == 0 ? 0 : 1),
      ThreadMap,
      AccessSize,
      Gather>;

  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<
      Element,
      ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   private:
    friend PredicatedTileIteratorResidualLast;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

   public:
    CUTLASS_HOST_DEVICE
    Params() {}

    /// Construct the Params object given a pitch-linear tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const& layout)
        : params_(layout::PitchLinear(layout.stride(0))) {}

    CUTLASS_HOST_DEVICE
    Params(typename UnderlyingIterator::Params::Base const& base)
        : params_(base) {}
  };

 private:
  //
  // Data members
  //

  /// Underlying pitch-linear tile iterator
  UnderlyingIterator iterator_;

 public:
  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer, ///< Pointer to start of tensor
      TensorCoord extent, ///< Extent of tensor
      int thread_id, ///< ID of each participating thread
      TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
      int const* indices = nullptr ///< Gather indices
      )
      : iterator_(
            params.params_,
            pointer,
            layout::PitchLinearCoord(extent.row(), extent.column()),
            thread_id,
            layout::PitchLinearCoord(
                threadblock_offset.row(),
                threadblock_offset.column()),
            indices) {}

  /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
  /// offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer, ///< Pointer to start of tensor
      TensorCoord extent, ///< Extent of tensor
      int thread_id ///< ID of each participating thread
      )
      : PredicatedTileIteratorResidualLast(
            params,
            pointer,
            extent,
            thread_id,
            make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast& operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast operator++(int) {
    PredicatedTileIteratorResidualLast self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) {
    iterator_.clear_mask(enable);
  }

  CUTLASS_HOST_DEVICE
  void set_residual_tile(bool enable) {
    iterator_.set_residual_tile(enable);
  }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding value stored in predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const& mask) {
    iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask& mask) {
    iterator_.get_mask(mask);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
    iterator_.load_with_byte_offset(frag, byte_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment& frag) {
    load_with_pointer_offset(frag, 0);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
    iterator_.store_with_byte_offset(frag, byte_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const& frag) {
    store_with_pointer_offset(frag, 0);
  }
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIteratorResidualLast for row-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <
    typename Shape_,
    typename Element_,
    int AdvanceRank,
    typename ThreadMap_,
    int AccessSize,
    bool Gather>
class PredicatedTileIteratorResidualLast<
    Shape_,
    Element_,
    layout::RowMajor,
    AdvanceRank,
    ThreadMap_,
    AccessSize,
    Gather> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may advance along the "
      "contiguous(rank=0) or strided(rank=1) dimension.");

  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::RowMajor;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element*;
  using NonConstPointer = typename platform::remove_const<Element>::type*;

  using UnderlyingIterator = PredicatedTileIteratorResidualLast<
      layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
      Element,
      layout::PitchLinear,
      (kAdvanceRank == 0 ? 1 : 0),
      ThreadMap,
      AccessSize,
      Gather>;

  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<
      Element,
      ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   private:
    friend PredicatedTileIteratorResidualLast;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

   public:
    CUTLASS_HOST_DEVICE
    Params() {}

    /// Construct the Params object given a pitch-linear tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const& layout)
        : params_(layout::PitchLinear(layout.stride(0))) {}

    CUTLASS_HOST_DEVICE
    Params(typename UnderlyingIterator::Params::Base const& base)
        : params_(base) {}
  };

 private:
  //
  // Data members
  //

  /// Underlying pitch-linear tile iterator
  UnderlyingIterator iterator_;

 public:
  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer, ///< Pointer to start of tensor
      TensorCoord extent, ///< Extent of tensor
      int thread_id, ///< ID of each participating thread
      TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
      int const* indices = nullptr ///< Gather indices
      )
      : iterator_(
            params.params_,
            pointer,
            layout::PitchLinearCoord(extent.column(), extent.row()),
            thread_id,
            layout::PitchLinearCoord(
                threadblock_offset.column(),
                threadblock_offset.row()),
            indices) {}

  /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
  /// offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer, ///< Pointer to start of tensor
      TensorCoord extent, ///< Extent of tensor
      int thread_id ///< ID of each participating thread
      )
      : PredicatedTileIteratorResidualLast(
            params,
            pointer,
            extent,
            thread_id,
            make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast& operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast operator++(int) {
    PredicatedTileIteratorResidualLast self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) {
    iterator_.clear_mask(enable);
  }

  CUTLASS_HOST_DEVICE
  void set_residual_tile(bool enable) {
    iterator_.set_residual_tile(enable);
  }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding value stored in predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const& mask) {
    iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask& mask) {
    iterator_.get_mask(mask);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
    iterator_.load_with_byte_offset(frag, byte_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment& frag) {
    load_with_pointer_offset(frag, 0);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
    iterator_.store_with_byte_offset(frag, byte_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const& frag) {
    store_with_pointer_offset(frag, 0);
  }
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIteratorResidualLast for affine rank-2 data.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <
    typename Shape_,
    typename Element_,
    int AdvanceRank,
    typename ThreadMap_,
    int AccessSize>
class PredicatedTileIteratorResidualLast<
    Shape_,
    Element_,
    layout::AffineRankN<2>,
    AdvanceRank,
    ThreadMap_,
    AccessSize,
    false> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may advance along the "
      "contiguous(rank=0) or strided(rank=1) dimension.");

  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::AffineRankN<2>;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element*;
  using NonConstPointer = typename platform::remove_const<Element>::type*;

  /// Type used for internal memory accesses
  using AccessType = AlignedArray<
      Element,
      AccessSize,
      (AccessSize * sizeof_bits<Element>::value / 8)>;

  /// Underlying iterator to compute the addresses
  using TileAccessIterator = PredicatedTileAccessIteratorResidualLast<
      Shape,
      Element,
      Layout,
      kAdvanceRank,
      ThreadMap,
      AccessType>;

  static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<
      Element,
      ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename TileAccessIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   public:
    friend PredicatedTileIteratorResidualLast;

   private:
    /// Parameters object
    typename TileAccessIterator::Params params_;

   public:
    /// Construct the Params object given an AffineRankN<2> tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const& layout) : params_(layout) {}

    CUTLASS_HOST_DEVICE
    Params() {}
  };

 private:
  /// Internal pointer type permits fast address arithmetic
  using BytePointer = char*;

 private:
  //
  // Data members
  //

  /// Data member to the tile access iterator
  TileAccessIterator address_iterator_;

 public:
  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      /// Precomputed parameters object
      Params const& params,
      /// Pointer to start of tensor
      Pointer pointer,
      /// Extent of tensor
      TensorCoord extent,
      /// ID of each participating thread
      int thread_id,
      /// Initial offset of threadblock
      TensorCoord const& threadblock_offset,
      int const* indices =
          nullptr ///< gather/scatter indices, note no support for
                  ///< gather/scatter at this specialization
      )
      : address_iterator_(
            params.params_,
            pointer,
            extent,
            thread_id,
            threadblock_offset) {}

  /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
  /// offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer, ///< Pointer to start of tensor
      TensorCoord extent, ///< Extent of tensor
      int thread_id ///< ID of each participating thread
      )
      : PredicatedTileIteratorResidualLast(
            params,
            pointer,
            extent,
            thread_id,
            make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    address_iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast& operator++() {
    if (kAdvanceRank)
      address_iterator_.add_tile_offset(make_Coord(0, 1));
    else
      address_iterator_.add_tile_offset(make_Coord(1, 0));

    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast operator++(int) {
    PredicatedTileIteratorResidualLast self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) {
    address_iterator_.clear_mask(enable);
  }

  CUTLASS_HOST_DEVICE
  void set_residual_tile(bool enable) {
    address_iterator_.set_residual_tile(enable);
  }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    address_iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding value stored in predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const& mask) {
    address_iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask& mask) {
    address_iterator_.get_mask(mask);
  }

  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
    load_with_byte_offset(
        frag, pointer_offset * sizeof_bits<Element>::value / 8);
  }

  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
    AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      CUTLASS_PRAGMA_UNROLL
      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < kAccessesPerVector; ++v) {
          int idx = v +
              kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);

          address_iterator_.set_iteration_index(idx);
          char const* byte_ptr =
              reinterpret_cast<char const*>(address_iterator_.get()) +
              byte_offset;

          AccessType const* access_ptr =
              reinterpret_cast<AccessType const*>(byte_ptr);

          cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
              frag_ptr[idx], access_ptr, address_iterator_.valid());

          ++address_iterator_;
        }
      }
    }
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment& frag) {
    load_with_byte_offset(frag, 0);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
    store_with_byte_offset(
        frag, pointer_offset * sizeof_bits<Element>::value / 8);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
    address_iterator_.set_iteration_index(0);
    AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      CUTLASS_PRAGMA_UNROLL
      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
        CUTLASS_PRAGMA_UNROLL
        for (int v = 0; v < kAccessesPerVector; ++v) {
          int idx = v +
              kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);

          char* byte_ptr =
              reinterpret_cast<char*>(address_iterator_.get()) + byte_offset;
          AccessType* access_ptr = reinterpret_cast<AccessType*>(byte_ptr);

          if (address_iterator_.valid()) {
            *access_ptr = frag_ptr[idx];
          }
          ++address_iterator_;
        }
      }
    }
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const& frag) {
    store_with_byte_offset(frag, 0);
  }
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2
/// column-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <
    typename Shape_,
    typename Element_,
    int AdvanceRank,
    typename ThreadMap_,
    int AccessSize>
class PredicatedTileIteratorResidualLast<
    Shape_,
    Element_,
    layout::AffineRank2ColumnMajor,
    AdvanceRank,
    ThreadMap_,
    AccessSize,
    false> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may advance along the "
      "contiguous(rank=0) or strided(rank=1) dimension.");

  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::AffineRank2ColumnMajor;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element*;
  using NonConstPointer = typename platform::remove_const<Element>::type*;

  // Map to the underlying AffineRankN<2> layout
  using UnderlyingIterator = PredicatedTileIteratorResidualLast<
      layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
      Element,
      layout::AffineRankN<2>,
      (kAdvanceRank == 0 ? 0 : 1),
      ThreadMap,
      AccessSize>;

  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<
      Element,
      ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   private:
    friend PredicatedTileIteratorResidualLast;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

   public:
    CUTLASS_HOST_DEVICE
    Params() {}

    /// Construct the Params object given an AffineRankN<2> tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const& layout)
        : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {}
  };

 private:
  //
  // Data members
  //

  /// Underlying AffineRankN<2> tile iterator
  UnderlyingIterator iterator_;

 public:
  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer, ///< Pointer to start of tensor
      TensorCoord extent, ///< Extent of tensor
      int thread_id, ///< ID of each participating thread
      TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
      int const* indices =
          nullptr ///< gather/scatter indices, note no support for
                  ///< gather/scatter at this specialization
      )
      : iterator_(
            params.params_,
            pointer,
            layout::PitchLinearCoord(extent.row(), extent.column()),
            thread_id,
            layout::PitchLinearCoord(
                threadblock_offset.row(),
                threadblock_offset.column())) {}

  /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
  /// offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer, ///< Pointer to start of tensor
      TensorCoord extent, ///< Extent of tensor
      int thread_id ///< ID of each participating thread
      )
      : PredicatedTileIteratorResidualLast(
            params,
            pointer,
            extent,
            thread_id,
            make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast& operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast operator++(int) {
    PredicatedTileIteratorResidualLast self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) {
    iterator_.clear_mask(enable);
  }

  CUTLASS_HOST_DEVICE
  void set_residual_tile(bool enable) {
    iterator_.set_residual_tile(enable);
  }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding value stored in predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const& mask) {
    iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask& mask) {
    iterator_.get_mask(mask);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
    iterator_.load_with_byte_offset(frag, byte_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment& frag) {
    load_with_pointer_offset(frag, 0);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
    iterator_.store_with_byte_offset(frag, byte_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const& frag) {
    store_with_pointer_offset(frag, 0);
  }
};
1427 
1428 ////////////////////////////////////////////////////////////////////////////////
1429 
1430 /// Specialization of PredicatedTileIteratorResidualLast for affine rank 2
1431 /// row-major data.
1432 ///
1433 /// Satisfies: ForwardTileIteratorConcept |
1434 ///            ReadableContiguousTileIteratorConcept |
1435 ///            WriteableContiguousTileIteratorConcept |
1436 ///            MaskedTileIteratorConcept
1437 ///
1438 template <
1439     typename Shape_,
1440     typename Element_,
1441     int AdvanceRank,
1442     typename ThreadMap_,
1443     int AccessSize>
1444 class PredicatedTileIteratorResidualLast<
1445     Shape_,
1446     Element_,
1447     layout::AffineRank2RowMajor,
1448     AdvanceRank,
1449     ThreadMap_,
1450     AccessSize,
1451     false> {
1452  public:
1453   static_assert(
1454       AdvanceRank == 0 || AdvanceRank == 1,
1455       "Specialization for pitch-linear iterator may along advance along the "
1456       "contiguous(rank=0) or strided(rank=1) dimension.");
1457 
1458   using Shape = Shape_;
1459   using Element = Element_;
1460   using Layout = layout::AffineRank2RowMajor;
1461   static int const kAdvanceRank = AdvanceRank;
1462   using ThreadMap = ThreadMap_;
1463 
1464   using Index = typename Layout::Index;
1465   using LongIndex = typename Layout::LongIndex;
1466 
1467   using TensorRef = TensorRef<Element, Layout>;
1468   using TensorView = TensorView<Element, Layout>;
1469   using TensorCoord = typename Layout::TensorCoord;
1470 
1471   using Pointer = Element*;
1472   using NonConstPointer = typename platform::remove_const<Element>::type*;
1473 
1474   // Map to the underlying AffineRankN<2> layout
1475   using UnderlyingIterator = PredicatedTileIteratorResidualLast<
1476       layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
1477       Element,
1478       layout::AffineRankN<2>,
1479       (kAdvanceRank == 0 ? 1 : 0),
1480       ThreadMap,
1481       AccessSize>;
1482 
1483   using AccessType = typename UnderlyingIterator::AccessType;
1484 
1485   /// Fragment object to be loaded or stored
1486   using Fragment = cutlass::Array<
1487       Element,
1488       ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
1489 
1490   /// Predicate vector stores mask to guard accesses
1491   using Mask = typename UnderlyingIterator::Mask;
1492 
1493   /// Parameters object is precomputed state and is host-constructible
1494   class Params {
1495    private:
1496     friend PredicatedTileIteratorResidualLast;
1497 
1498     /// Parameters object
1499     typename UnderlyingIterator::Params params_;
1500 
1501    public:
1502     CUTLASS_HOST_DEVICE
Params()1503     Params() {}
1504 
1505     /// Construct the Params object given an AffineRankN<2> tensor's layout
1506     CUTLASS_HOST_DEVICE
Params(Layout const & layout)1507     Params(Layout const& layout)
1508         : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {}
1509   };
1510 
1511  private:
1512   //
1513   // Data members
1514   //
1515 
1516   /// Underlying AffineRankN<2> tile iterator
1517   UnderlyingIterator iterator_;
1518 
1519  public:
1520   /// Constructs a TileIterator from its precomputed state, threadblock offset,
1521   /// and thread ID
1522   CUTLASS_HOST_DEVICE
1523   PredicatedTileIteratorResidualLast(
1524       Params const& params, ///< Precomputed parameters object
1525       Pointer pointer, ///< Pointer to start of tensor
1526       TensorCoord extent, ///< Extent of tensor
1527       int thread_id, ///< ID of each participating thread
1528       TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
1529       int const* indices =
1530           nullptr ///< gather/scatter indices, note no support for
1531                   ///< gather/scatter at this specialization
1532       )
      : iterator_(
            params.params_,
            pointer,
            layout::PitchLinearCoord(extent.column(), extent.row()),
            thread_id,
            layout::PitchLinearCoord(
                threadblock_offset.column(),
                threadblock_offset.row())) {}

  /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
  /// offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer, ///< Pointer to start of tensor
      TensorCoord extent, ///< Extent of tensor
      int thread_id ///< ID of each participating thread
      )
      : PredicatedTileIteratorResidualLast(
            params,
            pointer,
            extent,
            thread_id,
            make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast& operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast operator++(int) {
    PredicatedTileIteratorResidualLast self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) {
    iterator_.clear_mask(enable);
  }

  /// Enables or disables the residual tile's predicate set
  CUTLASS_HOST_DEVICE
  void set_residual_tile(bool enable) {
    iterator_.set_residual_tile(enable);
  }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding value stored in predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const& mask) {
    iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask& mask) {
    iterator_.get_mask(mask);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
    iterator_.load_with_byte_offset(frag, byte_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment& frag) {
    load_with_pointer_offset(frag, 0);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
    iterator_.store_with_byte_offset(frag, byte_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const& frag) {
    store_with_pointer_offset(frag, 0);
  }
};
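
////////////////////////////////////////////////////////////////////////////////

/// Illustrative usage sketch for the AffineRank2RowMajor specialization above.
/// The tile shape, thread map, and variable names below are hypothetical
/// placeholders (they are not defined in this header), and the trailing
/// template argument of the specialization is assumed to be defaulted by the
/// primary template. The point of the sketch is that row-major (row, column)
/// extents and offsets are handed to the underlying AffineRankN<2> iterator in
/// (contiguous = column, strided = row) order.
///
/// \code
/// using Iterator =
///     cutlass::transform::threadblock::PredicatedTileIteratorResidualLast<
///         cutlass::MatrixShape<64, 64>,    // hypothetical tile shape
///         float,
///         cutlass::layout::AffineRank2RowMajor,
///         1,                               // advance along the strided rank
///         ExampleThreadMap,                // hypothetical thread map
///         1>;                              // one element per access
///
/// // layout is an AffineRank2RowMajor instance; its two strides are forwarded
/// // to the underlying AffineRankN<2> Params in swapped order.
/// Iterator::Params params(layout);
/// Iterator iter(params, pointer, extent, thread_id, threadblock_offset);
///
/// Iterator::Fragment frag;
/// iter.load(frag);   // accesses are guarded by the iterator's predicates
/// ++iter;            // advance to the next tile
/// \endcode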

////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIteratorResidualLast for interleaved data.
/// It is mapped to the congruous layout.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///

template <
    typename Shape_,
    typename Element_,
    int AdvanceRank,
    typename ThreadMap_,
    int AccessSize,
    int InterleavedK>
class PredicatedTileIteratorResidualLast<
    Shape_,
    Element_,
    layout::ColumnMajorInterleaved<InterleavedK>,
    AdvanceRank,
    ThreadMap_,
    AccessSize,
    false> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may advance along the "
      "contiguous (rank=0) or strided (rank=1) dimension.");

  using Shape = Shape_;
  using Element = Element_;
  static int const kInterleavedK = InterleavedK;
  using Layout = layout::ColumnMajorInterleaved<kInterleavedK>;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element*;
  using NonConstPointer = typename platform::remove_const<Element>::type*;

  using UnderlyingIterator = PredicatedTileIteratorResidualLast<
      layout::PitchLinearShape<
          Shape::kRow * kInterleavedK,
          Shape::kColumn / kInterleavedK>,
      Element,
      layout::PitchLinear,
      (kAdvanceRank == 0 ? 0 : 1),
      ThreadMap,
      AccessSize>;

  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<
      Element,
      ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   private:
    friend PredicatedTileIteratorResidualLast;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

   public:
    CUTLASS_HOST_DEVICE
    Params() {}

    /// Construct the Params object given a ColumnMajorInterleaved tensor's
    /// layout
    CUTLASS_HOST_DEVICE
    Params(Layout const& layout)
        : params_(layout::PitchLinear(layout.stride(0))) {}

    CUTLASS_HOST_DEVICE
    Params(typename UnderlyingIterator::Params::Base const& base)
        : params_(base) {}
  };

 private:
  //
  // Data members
  //

  /// Underlying pitch-linear tile iterator
  UnderlyingIterator iterator_;

 public:
  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      /// Precomputed parameters object
      Params const& params,
      /// Pointer to start of tensor
      Pointer pointer,
      /// Extent of tensor
      TensorCoord extent,
      /// ID of each participating thread
      int thread_id,
      /// Initial offset of threadblock
      TensorCoord const& threadblock_offset,
      int const* indices =
          nullptr ///< gather/scatter indices; gather/scatter is not supported
                  ///< by this specialization
      )
      : iterator_(
            params.params_,
            pointer,
            layout::PitchLinearCoord(
                extent.row() * kInterleavedK,
                extent.column() / kInterleavedK),
            thread_id,
            layout::PitchLinearCoord(
                threadblock_offset.row() * kInterleavedK,
                threadblock_offset.column() / kInterleavedK)) {}

  /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
  /// offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer, ///< Pointer to start of tensor
      TensorCoord extent, ///< Extent of tensor
      int thread_id ///< ID of each participating thread
      )
      : PredicatedTileIteratorResidualLast(
            params,
            pointer,
            extent,
            thread_id,
            make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast& operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast operator++(int) {
    PredicatedTileIteratorResidualLast self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) {
    iterator_.clear_mask(enable);
  }

  /// Enables or disables the residual tile's predicate set
  CUTLASS_HOST_DEVICE
  void set_residual_tile(bool enable) {
    iterator_.set_residual_tile(enable);
  }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding value stored in predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const& mask) {
    iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask& mask) {
    iterator_.get_mask(mask);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment& frag) {
    load_with_pointer_offset(frag, 0);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const& frag) {
    store_with_pointer_offset(frag, 0);
  }
};
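
////////////////////////////////////////////////////////////////////////////////

/// Worked example of the coordinate remapping done by the
/// ColumnMajorInterleaved specialization above. All values are illustrative
/// assumptions (the interleaving factor, the extent, and the variable names
/// are not taken from this header); the sketch only mirrors the arithmetic the
/// constructor uses when forwarding the extent to the underlying pitch-linear
/// iterator.
///
/// \code
/// int const kInterleavedK = 32;              // hypothetical interleaving factor
/// cutlass::MatrixCoord extent(128, 64);      // hypothetical (row, column) extent
/// cutlass::layout::PitchLinearCoord pl_extent(
///     extent.row() * kInterleavedK,          // contiguous: 128 * 32 = 4096
///     extent.column() / kInterleavedK);      // strided:     64 / 32 = 2
/// // The threadblock offset is remapped with the same expressions.
/// \endcode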

////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIteratorResidualLast for row-major
/// interleaved data. It is mapped to the congruous layout.
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
template <
    typename Shape_,
    typename Element_,
    int AdvanceRank,
    typename ThreadMap_,
    int AccessSize,
    int InterleavedK>
class PredicatedTileIteratorResidualLast<
    Shape_,
    Element_,
    layout::RowMajorInterleaved<InterleavedK>,
    AdvanceRank,
    ThreadMap_,
    AccessSize,
    false> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may advance along the "
      "contiguous (rank=0) or strided (rank=1) dimension.");

  using Shape = Shape_;
  using Element = Element_;
  static int const kInterleavedK = InterleavedK;
  using Layout = layout::RowMajorInterleaved<kInterleavedK>;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element*;
  using NonConstPointer = typename platform::remove_const<Element>::type*;

  using UnderlyingIterator = PredicatedTileIteratorResidualLast<
      layout::PitchLinearShape<
          Shape::kColumn * kInterleavedK,
          Shape::kRow / kInterleavedK>,
      Element,
      layout::PitchLinear,
      (kAdvanceRank == 0 ? 1 : 0),
      ThreadMap,
      AccessSize>;

  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<
      Element,
      ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
   private:
    friend PredicatedTileIteratorResidualLast;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

   public:
    CUTLASS_HOST_DEVICE
    Params() {}

    /// Construct the Params object given a RowMajorInterleaved tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const& layout)
        : params_(layout::PitchLinear(layout.stride(0))) {}

    CUTLASS_HOST_DEVICE
    Params(typename UnderlyingIterator::Params::Base const& base)
        : params_(base) {}
  };

 private:
  //
  // Data members
  //

  /// Underlying pitch-linear tile iterator
  UnderlyingIterator iterator_;

 public:
  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      /// Precomputed parameters object
      Params const& params,
      /// Pointer to start of tensor
      Pointer pointer,
      /// Extent of tensor
      TensorCoord extent,
      /// ID of each participating thread
      int thread_id,
      /// Initial offset of threadblock
      TensorCoord const& threadblock_offset,
      int const* indices =
          nullptr ///< gather/scatter indices; gather/scatter is not supported
                  ///< by this specialization
      )
      : iterator_(
            params.params_,
            pointer,
            layout::PitchLinearCoord(
                extent.column() * kInterleavedK,
                extent.row() / kInterleavedK),
            thread_id,
            layout::PitchLinearCoord(
                threadblock_offset.column() * kInterleavedK,
                threadblock_offset.row() / kInterleavedK)) {}

  /// Construct a PredicatedTileIteratorResidualLast with zero threadblock
  /// offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast(
      Params const& params, ///< Precomputed parameters object
      Pointer pointer, ///< Pointer to start of tensor
      TensorCoord extent, ///< Extent of tensor
      int thread_id ///< ID of each participating thread
      )
      : PredicatedTileIteratorResidualLast(
            params,
            pointer,
            extent,
            thread_id,
            make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast& operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal
  /// pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIteratorResidualLast operator++(int) {
    PredicatedTileIteratorResidualLast self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask(bool enable = true) {
    iterator_.clear_mask(enable);
  }

  /// Enables or disables the residual tile's predicate set
  CUTLASS_HOST_DEVICE
  void set_residual_tile(bool enable) {
    iterator_.set_residual_tile(enable);
  }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding value stored in predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const& mask) {
    iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask& mask) {
    iterator_.get_mask(mask);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment& frag) {
    load_with_pointer_offset(frag, 0);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Store a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const& frag) {
    store_with_pointer_offset(frag, 0);
  }
};
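
////////////////////////////////////////////////////////////////////////////////

/// Worked example of the coordinate remapping done by the RowMajorInterleaved
/// specialization above; it mirrors the column-major interleaved case with the
/// roles of row and column exchanged. All values are illustrative assumptions,
/// not taken from this header.
///
/// \code
/// int const kInterleavedK = 32;              // hypothetical interleaving factor
/// cutlass::MatrixCoord extent(64, 128);      // hypothetical (row, column) extent
/// cutlass::layout::PitchLinearCoord pl_extent(
///     extent.column() * kInterleavedK,       // contiguous: 128 * 32 = 4096
///     extent.row() / kInterleavedK);         // strided:     64 / 32 = 2
/// // The threadblock offset is remapped with the same expressions.
/// \endcode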

////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace transform
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////