53 namespace threadblock {
68 using Shape =
typename ThreadMap::Shape;
81 static int const kThreads = ThreadMap::kThreads;
84 static_assert( ThreadMap::Iterations::kRow > 0,
"ThreadMap::Iterations::kRow must be > 0");
85 static_assert( ThreadMap::Iterations::kGroup > 0,
"ThreadMap::Iterations::kGroup must be > 0");
86 static_assert( ThreadMap::Iterations::kCluster > 0,
"ThreadMap::Iterations::kCluster must be > 0");
87 static_assert( ThreadMap::Iterations::kColumn > 0,
"ThreadMap::Iterations::kColumn must be > 0");
92 ThreadMap::Iterations::kColumn *
93 ThreadMap::Iterations::kRow *
94 ThreadMap::Iterations::kGroup *
95 ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
130 increment_row = stride * ThreadMap::Delta::kRow;
132 increment_group = stride * ThreadMap::Delta::kGroup
133 - stride * ThreadMap::Delta::kRow * (ThreadMap::Iterations::kRow - 1);
135 increment_cluster = stride * ThreadMap::Delta::kCluster
136 - stride * ThreadMap::Delta::kGroup * (ThreadMap::Iterations::kGroup - 1)
137 - stride * ThreadMap::Delta::kRow * (ThreadMap::Iterations::kRow - 1);
139 advance_row = stride * ThreadMap::Shape::kRow;
141 advance_group = stride * (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
145 ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;;
149 ThreadMap::Shape::kGroup *
150 ThreadMap::Shape::kRow *
151 ThreadMap::Shape::kCluster *
152 ThreadMap::Shape::kTile;
172 static int const kCount = ThreadMap::Iterations::kColumn;
175 bool predicates[kCount];
188 for (
int i = 0; i < kCount; ++i) {
189 predicates[i] =
false;
196 for (
int i = 0; i < kCount; ++i) {
197 predicates[i] =
true;
212 uint8_t *byte_pointer_;
221 Index thread_start_row_;
249 TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
251 extent_row_ = extent.
row();
252 thread_start_row_ = thread_offset.
row();
256 for (
int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
259 + ThreadMap::Delta::kColumn * c) < extent.
column());
263 byte_pointer_ =
reinterpret_cast<uint8_t *
>(pointer) +
268 state_[0] = state_[1] = state_[2] = 0;
281 uint8_t *byte_pointer = byte_pointer_;
285 for (
int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
288 for (
int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
291 for (
int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
294 (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
296 int row_offset = row * ThreadMap::Delta::kRow
297 + group * ThreadMap::Delta::kGroup
298 + cluster * ThreadMap::Delta::kCluster;
300 bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
305 for (
int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
307 bool guard = row_guard && mask_.
predicates[column];
310 frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] =
315 if (row + 1 < ThreadMap::Iterations::kRow) {
320 if (group + 1 < ThreadMap::Iterations::kGroup) {
325 if (cluster + 1 < ThreadMap::Iterations::kCluster) {
334 uint8_t *byte_pointer = byte_pointer_;
338 for (
int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
341 for (
int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
344 for (
int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
347 (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
349 int row_offset = row * ThreadMap::Delta::kRow
350 + group * ThreadMap::Delta::kGroup
351 + cluster * ThreadMap::Delta::kCluster;
353 bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
358 for (
int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
360 bool guard = row_guard && mask_.
predicates[column];
365 frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column];
369 if (row + 1 < ThreadMap::Iterations::kRow) {
374 if (group + 1 < ThreadMap::Iterations::kGroup) {
379 if (cluster + 1 < ThreadMap::Iterations::kCluster) {
391 thread_start_row_ += ThreadMap::Shape::kRow;
393 if (state_[0] == ThreadMap::Count::kRow) {
399 thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
400 ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
402 if (state_[1] == ThreadMap::Count::kGroup) {
408 thread_start_row_ += ThreadMap::Count::kGroup *
409 ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
411 if (state_[2] == ThreadMap::Count::kCluster) {
456 using Element = Element_;
466 static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
467 static int const kThreads = ThreadMap::kThreads;
468 static int const kIterations = ThreadMap::Iterations::kCount;
471 using Fragment = Array<Element, ThreadMap::kElementsPerAccess>;
503 stride_ - ThreadMap::Iterations::kContiguous * kElementsPerAccess *
523 static int const kCount = (ThreadMap::Iterations::kContiguous < 8)
525 : ThreadMap::Iterations::kContiguous;
528 bool predicates[kCount];
541 for (
int i = 0; i < kCount; ++i) {
542 predicates[i] =
false;
549 for (
int i = 0; i < kCount; ++i) {
550 predicates[i] =
true;
565 uint8_t *byte_pointer_;
575 Index thread_start_col_;
578 int iteration_contiguous_;
580 int iteration_strided_;
604 TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) +
606 threadblock_offset.
strided() / InterleavedK);
608 extent_col_ = extent.
strided() / InterleavedK;
609 thread_start_col_ = thread_offset.
strided();
613 for (
int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
615 ((thread_offset.
contiguous() + ThreadMap::Delta::kContiguous * c) <
620 byte_pointer_ =
reinterpret_cast<uint8_t *
>(pointer) +
625 iteration_contiguous_ = iteration_strided_ = 0;
637 uint8_t *byte_pointer = byte_pointer_;
641 int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided;
643 bool col_guard = ((thread_start_col_ + col_offset) < extent_col_);
645 bool guard = col_guard && mask_.
predicates[iteration_contiguous_];
648 *frag_ptr = *memory_pointer;
655 uint8_t *byte_pointer = byte_pointer_;
659 int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided;
661 bool col_guard = ((thread_start_col_ + col_offset) < extent_col_);
663 bool guard = col_guard && mask_.
predicates[iteration_contiguous_];
666 *memory_pointer = *frag_ptr;
673 iteration_contiguous_ = iteration % ThreadMap::Iterations::kContiguous;
674 iteration_strided_ = iteration / ThreadMap::Iterations::kContiguous;
681 ++iteration_contiguous_;
684 if (iteration_contiguous_ == ThreadMap::Iterations::kContiguous) {
686 iteration_contiguous_ = 0;
687 ++iteration_strided_;
690 if (iteration_strided_ == ThreadMap::Iterations::kStrided) {
691 iteration_strided_ = 0;
bool predicates[kCount]
Predicate state.
Definition: epilogue/threadblock/predicated_tile_iterator.h:175
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:62
static int const kElementsPerAccess
Definition: epilogue/threadblock/predicated_tile_iterator.h:80
CUTLASS_DEVICE void enable()
Enables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:194
CUTLASS_HOST_DEVICE Index const & column() const
Returns the column of the coordinate.
Definition: matrix_coord.h:85
Index advance_row
amount to add to move to the next 'row' position
Definition: epilogue/threadblock/predicated_tile_iterator.h:116
CUTLASS_DEVICE void load(Fragment &frag)
Loads a fragment from memory.
Definition: epilogue/threadblock/predicated_tile_iterator.h:279
Element_ Element
Definition: epilogue/threadblock/predicated_tile_iterator.h:70
Definition: aligned_buffer.h:35
Coordinate in pitch-linear space.
Definition: pitch_linear.h:52
Defines a structure containing strides, bounds, and a pointer to tensor data.
AlignedArray< Element, ThreadMap::kElementsPerAccess > AccessType
Memory access size.
Definition: epilogue/threadblock/predicated_tile_iterator.h:98
CUTLASS_HOST_DEVICE Status initialize(Index stride_)
Definition: epilogue/threadblock/predicated_tile_iterator.h:496
CUTLASS_HOST_DEVICE void clear()
Efficiently disables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:539
Templates implementing how threads are mapped to a given tile.
CUTLASS_DEVICE void get_mask(Mask &mask)
Gets the mask (copies it into the `mask` argument).
Definition: epilogue/threadblock/predicated_tile_iterator.h:432
Array< Element, ThreadMap::kElementsPerAccess > Fragment
Fragment object.
Definition: epilogue/threadblock/predicated_tile_iterator.h:471
ThreadMap_ ThreadMap
Definition: epilogue/threadblock/predicated_tile_iterator.h:454
Aligned array type.
Definition: array.h:511
Mask object.
Definition: epilogue/threadblock/predicated_tile_iterator.h:170
CUTLASS_HOST_DEVICE Index const & row() const
Returns the row of the coordinate.
Definition: matrix_coord.h:77
bool predicates[kCount]
Predicate state.
Definition: epilogue/threadblock/predicated_tile_iterator.h:528
typename TensorRef::ConstTensorRef ConstTensorRef
Definition: epilogue/threadblock/predicated_tile_iterator.h:74
CUTLASS_HOST_DEVICE Mask()
Default-constructs the mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:534
CUTLASS_HOST_DEVICE Stride stride() const
Returns the stride of the layout.
Definition: layout/matrix.h:418
CUTLASS_HOST_DEVICE Stride stride() const
Returns the stride of the layout.
Definition: layout/matrix.h:112
Array< Element, ThreadMap::Iterations::kColumn *ThreadMap::Iterations::kRow *ThreadMap::Iterations::kGroup *ThreadMap::Iterations::kCluster *ThreadMap::kElementsPerAccess > Fragment
Fragment object.
Definition: epilogue/threadblock/predicated_tile_iterator.h:95
Definition: epilogue/threadblock/predicated_tile_iterator.h:480
CUTLASS_DEVICE void store(Fragment const &frag)
Stores a fragment to memory.
Definition: epilogue/threadblock/predicated_tile_iterator.h:333
typename Layout::LongIndex LongIndex
Definition: epilogue/threadblock/predicated_tile_iterator.h:77
TensorRef< typename platform::remove_const< Element >::type const, Layout > ConstTensorRef
TensorRef to constant data.
Definition: tensor_ref.h:179
CUTLASS_DEVICE void load(Fragment &frag)
Loads a fragment from memory.
Definition: epilogue/threadblock/predicated_tile_iterator.h:636
CUTLASS_DEVICE InterleavedPredicatedTileIterator(Params const &params, Element *pointer, TensorCoord extent, int thread_idx, TensorCoord threadblock_offset)
Constructor.
Definition: epilogue/threadblock/predicated_tile_iterator.h:596
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
typename Layout::Index Index
Definition: epilogue/threadblock/predicated_tile_iterator.h:462
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:59
Index advance_cluster
amount to add to move to the next 'cluster' position
Definition: epilogue/threadblock/predicated_tile_iterator.h:118
Defines a Shape template for matrix tiles.
CUTLASS_DEVICE void store(Fragment const &frag)
Stores a fragment to memory.
Definition: epilogue/threadblock/predicated_tile_iterator.h:654
Defines the size of an element in bits.
Definition: numeric_types.h:42
Index advance_row
amount to add to move to the next 'row' position
Definition: epilogue/threadblock/predicated_tile_iterator.h:488
CUTLASS_DEVICE void clear_mask()
Efficiently disables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:699
ThreadMap_ ThreadMap
Definition: epilogue/threadblock/predicated_tile_iterator.h:67
Mask object.
Definition: epilogue/threadblock/predicated_tile_iterator.h:522
CUTLASS_HOST_DEVICE Params()
Definition: epilogue/threadblock/predicated_tile_iterator.h:510
CUTLASS_DEVICE PredicatedTileIterator(Params const &params, Element *pointer, TensorCoord extent, int thread_idx, TensorCoord threadblock_offset=TensorCoord())
Constructor.
Definition: epilogue/threadblock/predicated_tile_iterator.h:240
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
CUTLASS_HOST_DEVICE Index const & contiguous() const
Returns the contiguous dimension.
Definition: pitch_linear.h:89
Definition: epilogue/threadblock/predicated_tile_iterator.h:452
typename TensorRef::ConstTensorRef ConstTensorRef
Definition: epilogue/threadblock/predicated_tile_iterator.h:460
CUTLASS_HOST_DEVICE void set_iteration_index(int iteration)
Overrides the internal iteration index.
Definition: epilogue/threadblock/predicated_tile_iterator.h:672
Index stride
stride in bytes between rows
Definition: epilogue/threadblock/predicated_tile_iterator.h:110
CUTLASS_HOST_DEVICE PredicatedTileIterator & operator++()
Advances to the next position to load or store.
Definition: epilogue/threadblock/predicated_tile_iterator.h:387
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Definition: epilogue/threadblock/predicated_tile_iterator.h:163
Index stride
stride in bytes between columns
Definition: epilogue/threadblock/predicated_tile_iterator.h:486
Index advance_column
amount to add to move to the next 'column' position
Definition: epilogue/threadblock/predicated_tile_iterator.h:489
Definition: epilogue/threadblock/predicated_tile_iterator.h:104
static int const kIterations
Definition: epilogue/threadblock/predicated_tile_iterator.h:82
Index advance_tile
amount to add to move to the next 'tile'
Definition: epilogue/threadblock/predicated_tile_iterator.h:119
Metaprogram for determining the mapping of output elements to threads for epilogue tiles...
CUTLASS_HOST_DEVICE InterleavedPredicatedTileIterator & operator++()
Advances to the next position to load or store.
Definition: epilogue/threadblock/predicated_tile_iterator.h:679
CUTLASS_DEVICE void clear_mask()
Efficiently disables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:422
Index increment_group
increment quantity (in bytes) to advance when moving to the next group
Definition: epilogue/threadblock/predicated_tile_iterator.h:113
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
typename Layout::Index Index
Definition: epilogue/threadblock/predicated_tile_iterator.h:76
Definition: epilogue/threadblock/predicated_tile_iterator.h:65
typename Layout::LongIndex LongIndex
Definition: epilogue/threadblock/predicated_tile_iterator.h:463
CUTLASS_DEVICE void set_mask(Mask const &mask)
Sets the mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:714
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Definition: epilogue/threadblock/predicated_tile_iterator.h:515
CUTLASS_DEVICE void set_mask(Mask const &mask)
Sets the mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:437
CUTLASS_DEVICE void enable()
Enables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:547
Index advance_group
amount to add to move to the next 'group' position
Definition: epilogue/threadblock/predicated_tile_iterator.h:117
Defines layout functions used by TensorRef and derived classes.
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: epilogue/threadblock/predicated_tile_iterator.h:630
CUTLASS_HOST_DEVICE Status initialize(Index stride_)
Definition: epilogue/threadblock/predicated_tile_iterator.h:126
Operation was successful.
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: epilogue/threadblock/predicated_tile_iterator.h:273
typename ThreadMap::Shape Shape
Definition: epilogue/threadblock/predicated_tile_iterator.h:68
Definition: layout/matrix.h:343
MatrixCoord TensorCoord
Definition: epilogue/threadblock/predicated_tile_iterator.h:78
CUTLASS_HOST_DEVICE Params()
Definition: epilogue/threadblock/predicated_tile_iterator.h:158
CUTLASS_DEVICE void enable_mask()
Efficiently enables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:427
Index increment_row
increment quantity (in bytes) to advance when moving between rows
Definition: epilogue/threadblock/predicated_tile_iterator.h:112
Index increment_cluster
increment quantity (in bytes) to advance when moving to the next cluster
Definition: epilogue/threadblock/predicated_tile_iterator.h:114
CUTLASS_DEVICE void get_mask(Mask &mask)
Gets the mask (copies it into the `mask` argument).
Definition: epilogue/threadblock/predicated_tile_iterator.h:709
Basic include for CUTLASS.
Definition: matrix_coord.h:39
CUTLASS_HOST_DEVICE void clear()
Efficiently disables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:186
CUTLASS_HOST_DEVICE Index const & strided() const
Returns the strided dimension.
Definition: pitch_linear.h:97
Status
Status code returned by CUTLASS operations.
Definition: cutlass.h:39
CUTLASS_DEVICE void enable_mask()
Efficiently enables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:704
CUTLASS_HOST_DEVICE Mask()
Default-constructs the mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:181
static int const kThreads
Definition: epilogue/threadblock/predicated_tile_iterator.h:81