58 namespace threadblock {
64 template <
typename Shape,
typename Element,
typename Layout,
int AdvanceRank,
65 typename ThreadMap,
typename AccessType>
72 template <
typename Shape_,
typename Element_,
int AdvanceRank,
73 typename ThreadMap_,
typename AccessType_>
75 AdvanceRank, ThreadMap_, AccessType_> {
78 AdvanceRank == 0 || AdvanceRank == 1,
79 "Specialization for pitch-linear iterator may along advance along the " 80 "contiguous(rank=0) or strided(rank=1) dimension.");
83 using Element = Element_;
85 static int const kAdvanceRank = AdvanceRank;
86 using ThreadMap = ThreadMap_;
99 static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
101 static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
102 "Vectors implied by the thread map must be divisible by the access type.");
104 static int const kPredicatesPerByte = 4;
105 static int const kPredicatesPerWord = 4 * kPredicatesPerByte;
107 static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector;
110 static int const kPredicateByteCount =
111 (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte;
112 static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4;
114 static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u;
116 static_assert(kPredicateWordCount <= 4,
"Too many predicates.");
119 using Mask = Array<uint32_t, kPredicateWordCount>;
143 Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { }
147 Params(Layout
const &layout) : stride_(layout.stride(0)) {
148 inc_strided_ = (stride_ * ThreadMap::Delta::kStrided) *
160 inc_next_ = inc_advance_ - (ThreadMap::Iterations::kStrided - 1) *
161 ThreadMap::Delta::kStrided * stride_ *
168 using BytePointer =
char *;
176 Params
const ¶ms_;
179 BytePointer pointer_;
182 uint32_t predicates_[kPredicateWordCount];
194 bool is_residue_tile_;
197 int iteration_vector_;
200 int iteration_contiguous_;
203 int iteration_strided_;
208 void compute_predicates_(
212 bool is_steady_state =
false) {
215 for (
int i = 0; i < kPredicateWordCount; ++i) {
219 for (
int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) {
221 int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
223 int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
225 int c = access_residual / kAccessesPerVector;
226 int v = access_residual % kAccessesPerVector;
228 TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements,
229 s * ThreadMap::Delta::kStrided);
231 TensorCoord coord = thread_offset_ + iteration_coord;
235 if (is_steady_state) {
236 if (kAdvanceRank == 0) {
237 guard = (coord.strided() < extent.strided());
239 guard = (coord.contiguous() < extent.contiguous());
242 guard = (coord.strided() < extent.strided() &&
243 coord.contiguous() < extent.contiguous());
246 int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s);
248 int word_idx = pred_idx / kPredicatesPerWord;
249 int residual = pred_idx % kPredicatesPerWord;
250 int byte_idx = residual / kPredicatesPerByte;
251 int bit_idx = residual % kPredicatesPerByte;
253 predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx));
265 Params
const ¶ms,
275 pointer_(reinterpret_cast<BytePointer>(
278 is_residue_tile_(true) {
283 Index residue_size = (extent_[kAdvanceRank] % Shape::kStrided);
285 residue_size = Shape::kStrided;
288 residue_offset_ =
make_Coord(0, residue_size);
290 extent_.contiguous(),
291 min(threadblock_offset.strided() + residue_offset_.strided(), extent_.strided())
296 Index residue_size = (extent_[kAdvanceRank] % Shape::kContiguous);
298 residue_size = Shape::kContiguous;
300 residue_offset_ =
make_Coord(residue_size, 0);
302 min(extent_.contiguous(), threadblock_offset.contiguous() + residue_offset_.contiguous()),
308 thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id);
311 Layout layout(params_.stride_);
312 add_pointer_offset(layout(thread_offset_));
314 compute_predicates_(residue_extent,
false);
316 set_iteration_index(0);
323 Params
const ¶ms,
337 iteration_vector_ = index % kAccessesPerVector;
338 int residual_access = index / kAccessesPerVector;
340 iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
341 iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
355 if (is_residue_tile_) {
357 thread_offset_ += residue_offset_;
359 Layout layout(params_.stride_);
360 add_pointer_offset(layout(residue_offset_));
362 compute_predicates_(extent_,
true);
365 pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1);
366 pointer_ += Shape::kContiguous * tile_offset.contiguous();
368 pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1);
369 pointer_ += Shape::kStrided * tile_offset.strided();
373 pointer_ += params_.inc_advance_ * tile_offset.strided();
374 pointer_ += Shape::kContiguous * tile_offset.contiguous();
376 pointer_ += params_.inc_advance_ * tile_offset.contiguous();
377 pointer_ += Shape::kStrided * tile_offset.strided();
380 is_residue_tile_ =
false;
396 if (iteration_vector_ < kAccessesPerVector) {
400 iteration_vector_ = 0;
401 ++iteration_contiguous_;
403 if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
409 iteration_contiguous_ = 0;
410 ++iteration_strided_;
412 if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
413 pointer_ += params_.inc_strided_;
419 iteration_strided_ = 0;
422 pointer_ += params_.inc_next_;
427 pointer_ -= params_.inc_advance_;
444 for (
int i = 0; i < kPredicateWordCount; ++i) {
454 for (
int i = 0; i < kPredicateWordCount; ++i) {
455 predicates_[i] = 0xffffffff;
463 for (
int i = 0; i < kPredicateWordCount; ++i) {
464 predicates_[i] = mask[i];
473 for (
int i = 0; i < kPredicateWordCount; ++i) {
474 mask[i] = predicates_[i];
484 iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous);
486 int word_idx = pred_idx / kPredicatesPerWord;
487 int residual = pred_idx % kPredicatesPerWord;
488 int byte_idx = residual / kPredicatesPerByte;
489 int bit_idx = residual % kPredicatesPerByte;
491 bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
508 template <
typename Shape_,
typename Element_,
int AdvanceRank,
509 typename ThreadMap_,
typename AccessType_>
511 AdvanceRank, ThreadMap_, AccessType_> {
514 AdvanceRank == 0 || AdvanceRank == 1,
515 "Specialization for pitch-linear iterator may along advance along the " 516 "contiguous(rank=0) or strided(rank=1) dimension.");
519 using Element = Element_;
521 static int const kAdvanceRank = AdvanceRank;
522 using ThreadMap = ThreadMap_;
540 using Mask =
typename UnderlyingIterator::Mask;
542 static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
550 typename UnderlyingIterator::Params params_;
561 : params_(layout::PitchLinear(layout.stride(0))){};
578 Params
const ¶ms,
587 : iterator_(params.params_, pointer,
588 layout::PitchLinearCoord(extent.row(), extent.column()),
590 layout::PitchLinearCoord(threadblock_offset.row(),
591 threadblock_offset.column())) {}
596 Params
const ¶ms,
611 iterator_.add_pointer_offset(pointer_offset);
618 iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
624 return reinterpret_cast<AccessType *
>(iterator_.get());
671 return iterator_.valid();
684 template <
typename Shape_,
typename Element_,
int AdvanceRank,
685 typename ThreadMap_,
typename AccessType_>
687 AdvanceRank, ThreadMap_, AccessType_> {
690 AdvanceRank == 0 || AdvanceRank == 1,
691 "Specialization for pitch-linear iterator may along advance along the " 692 "contiguous(rank=0) or strided(rank=1) dimension.");
695 using Element = Element_;
697 static int const kAdvanceRank = AdvanceRank;
698 using ThreadMap = ThreadMap_;
715 static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
718 using Mask =
typename UnderlyingIterator::Mask;
726 typename UnderlyingIterator::Params params_;
737 : params_(layout::PitchLinear(layout.stride(0))){};
754 Params
const ¶ms,
763 : iterator_(params.params_, pointer,
764 layout::PitchLinearCoord(extent.column(), extent.row()),
766 layout::PitchLinearCoord(threadblock_offset.column(),
767 threadblock_offset.row())) {}
772 Params
const ¶ms,
787 iterator_.add_pointer_offset(pointer_offset);
794 iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
800 return reinterpret_cast<AccessType *
>(iterator_.get());
847 return iterator_.valid();
862 template <
typename Shape_,
typename Element_,
int AdvanceRank,
863 typename ThreadMap_,
typename AccessType_,
int InterleavedK>
865 layout::ColumnMajorInterleaved<InterleavedK>,
866 AdvanceRank, ThreadMap_, AccessType_> {
869 AdvanceRank == 0 || AdvanceRank == 1,
870 "Specialization for pitch-linear iterator may along advance along the " 871 "contiguous(rank=0) or strided(rank=1) dimension.");
874 using Element = Element_;
875 static int const kInterleavedK = InterleavedK;
877 static int const kAdvanceRank = AdvanceRank;
878 using ThreadMap = ThreadMap_;
893 Shape::kColumn / kInterleavedK>,
897 static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
900 using Mask =
typename UnderlyingIterator::Mask;
908 typename UnderlyingIterator::Params params_;
917 : params_(layout::PitchLinear(layout.stride(0))) {}
934 Params
const ¶ms,
943 : iterator_(params.params_, pointer,
944 layout::PitchLinearCoord(extent.row() * kInterleavedK,
945 extent.column() / kInterleavedK),
947 layout::PitchLinearCoord(
948 threadblock_offset.row() * kInterleavedK,
949 threadblock_offset.column() / kInterleavedK)) {}
954 Params
const ¶ms,
969 iterator_.add_pointer_offset(pointer_offset);
976 iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
982 return reinterpret_cast<AccessType *
>(iterator_.get());
1028 bool valid() {
return iterator_.valid(); }
1041 template <
typename Shape_,
typename Element_,
int AdvanceRank,
1042 typename ThreadMap_,
typename AccessType_,
int InterleavedK>
1044 layout::RowMajorInterleaved<InterleavedK>,
1045 AdvanceRank, ThreadMap_, AccessType_> {
1048 AdvanceRank == 0 || AdvanceRank == 1,
1049 "Specialization for pitch-linear iterator may along advance along the " 1050 "contiguous(rank=0) or strided(rank=1) dimension.");
1053 using Element = Element_;
1054 static int const kInterleavedK = InterleavedK;
1056 static int const kAdvanceRank = AdvanceRank;
1057 using ThreadMap = ThreadMap_;
1072 Shape::kRow / kInterleavedK>,
1077 static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
1080 using Mask =
typename UnderlyingIterator::Mask;
1088 typename UnderlyingIterator::Params params_;
1097 : params_(layout::PitchLinear(layout.stride(0))) {}
1114 Params
const ¶ms,
1123 : iterator_(params.params_, pointer,
1124 layout::PitchLinearCoord(extent.column() * kInterleavedK,
1125 extent.row() / kInterleavedK),
1127 layout::PitchLinearCoord(
1128 threadblock_offset.column() * kInterleavedK,
1129 threadblock_offset.row() / kInterleavedK)) {}
1134 Params
const ¶ms,
1149 iterator_.add_pointer_offset(pointer_offset);
1156 iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
1162 return reinterpret_cast<AccessType *
>(iterator_.get());
1208 bool valid() {
return iterator_.valid(); }
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:62
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:355
Definition: aligned_buffer.h:35
Coordinate in pitch-linear space.
Definition: pitch_linear.h:52
Defines a structure containing strides, bounds, and a pointer to tensor data.
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:249
Mapping function for pitch-linear memory.
Definition: pitch_linear.h:163
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:352
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:387
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:154
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:246
Defines a structure containing strides and a pointer to tensor data.
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
Template defining a shape used by pitch-linear operators.
Definition: pitch_linear.h:43
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:59
CUTLASS_HOST_DEVICE half_t & operator++(half_t &lhs)
Definition: half.h:694
int64_t LongIndex
Long index type used for offsets.
Definition: pitch_linear.h:175
Defines a Shape template for matrix tiles.
Defines the size of an element in bits.
Definition: numeric_types.h:42
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
int32_t Index
Index type used for coordinates.
Definition: pitch_linear.h:172
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Defines layout functions used by TensorRef and derived classes.
Defines layout functions used by TensorRef and derived classes for pitch-linear memory.
Definition: layout/matrix.h:343
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:151
Basic include for CUTLASS.
Definition: matrix_coord.h:39
Definition: layout/matrix.h:237