58 namespace threadblock {
64 template <
typename Shape,
typename Element,
typename Layout,
int AdvanceRank,
65 typename ThreadMap,
typename AccessType>
72 template <
typename Shape_,
typename Element_,
int AdvanceRank,
73 typename ThreadMap_,
typename AccessType_>
75 AdvanceRank, ThreadMap_, AccessType_> {
78 AdvanceRank == 0 || AdvanceRank == 1,
79 "Specialization for pitch-linear iterator may along advance along the " 80 "contiguous(rank=0) or strided(rank=1) dimension.");
83 using Element = Element_;
85 static int const kAdvanceRank = AdvanceRank;
86 using ThreadMap = ThreadMap_;
// Predicate packing layout. Guard bits are stored four per byte: only the low
// four bits of each byte are used, and a predicate's bit position inside a
// 32-bit word is (byte_idx * 8 + bit_idx), as used where predicates_ is
// written and tested.
99 static int const kPredicatesPerByte = 4;
// Four bytes per 32-bit word, hence 16 predicates per word.
100 static int const kPredicatesPerWord = 4 * kPredicatesPerByte;
// Bytes needed for Iterations::kCount * ThreadAccessShape::kStrided predicates,
// rounded up to a whole byte.
103 static int const kPredicateByteCount = (ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kStrided + kPredicatesPerByte - 1) / kPredicatesPerByte;
// 32-bit words needed to hold those bytes, rounded up to a whole word.
104 static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4;
// Selects the used bits of one predicate byte: (1 << 4) - 1 == 0xF.
106 static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u;
// Compile-time cap: at most four words, i.e. at most 64 predicates.
108 static_assert(kPredicateWordCount <= 4,
"Too many predicates.");
// Word array with the same length as predicates_; it is copied element-wise
// to and from predicates_ by the mask accessors.
111 using Mask = Array<uint32_t, kPredicateWordCount>;
// Default constructor: zero-initializes the stride and every pointer increment.
135 Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { }
139 Params(Layout
const &layout) : stride_(layout.stride(0)) {
142 (stride_ * ThreadMap::Delta::kStrided) *
int(
sizeof(Element));
146 inc_advance_ = Shape::kStrided * stride_ * int(
sizeof(Element));
149 inc_advance_ = Shape::kContiguous * int(
sizeof(Element));
152 inc_next_ = inc_advance_ - (ThreadMap::Iterations::kStrided - 1) *
153 ThreadMap::Delta::kStrided * stride_ *
154 int(
sizeof(Element));
// Byte-granular pointer type: arithmetic on it moves in units of one char,
// which is why offsets applied to it are scaled by sizeof(Element).
160 using BytePointer =
char *;
168 Params
const &params_;
// Byte pointer to the current access; moved by add_pointer_offset and by the
// inc_strided_ / inc_next_ / inc_advance_ values held in the Params object.
171 BytePointer pointer_;
// Packed guard bits, filled in by compute_predicates_.
174 uint32_t predicates_[kPredicateWordCount];
// Tile index used to build the residue offset along the advance rank
// (multiplied by Shape::kStrided or Shape::kContiguous).
// NOTE(review): the full expression that computes it is not visible in this listing.
183 int residue_tile_idx_;
// Set to true by the constructor; cleared once add_tile_offset has run.
186 bool is_residue_tile_;
// Decomposition of the linear iteration index: position along the contiguous
// iterations, along the strided iterations, and within the thread's strided
// access shape.
189 int iteration_contiguous_;
192 int iteration_strided_;
195 int iteration_thread_;
200 void compute_predicates_(
202 bool is_steady_state =
false) {
205 for (
int i = 0; i < kPredicateWordCount; ++i) {
210 for (
int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
212 for (
int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
214 for (
int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++) {
216 TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous,
217 ts + s * ThreadMap::Delta::kStrided);
219 TensorCoord coord = thread_offset_ + iteration_coord;
223 if (is_steady_state) {
224 if (kAdvanceRank == 0) {
225 guard = (coord.strided() < extent_.strided());
227 guard = (coord.contiguous() < extent_.contiguous());
230 guard = (coord.strided() < extent_.strided() &&
231 coord.contiguous() < extent_.contiguous());
234 int pred_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided;
235 int word_idx = pred_idx / kPredicatesPerWord;
236 int residual = pred_idx % kPredicatesPerWord;
237 int byte_idx = residual / kPredicatesPerByte;
238 int bit_idx = residual % kPredicatesPerByte;
240 predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx));
254 Params
const &params,
264 pointer_(reinterpret_cast<BytePointer>(
267 is_residue_tile_(true) {
273 (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) /
275 residue_offset =
make_Coord(0, residue_tile_idx_ * Shape::kStrided);
278 (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) /
280 residue_offset =
make_Coord(residue_tile_idx_ * Shape::kContiguous, 0);
284 thread_offset_ = threadblock_offset + residue_offset +
285 ThreadMap::initial_offset(thread_id);
288 Layout layout(params_.stride_);
289 add_pointer_offset(layout(thread_offset_));
291 compute_predicates_(
false);
293 set_iteration_index(0);
300 Params
const &params,
314 int residual = index % (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided);
315 iteration_strided_ = index / (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided);
317 iteration_contiguous_ = residual / ThreadMap::ThreadAccessShape::kStrided;
318 iteration_thread_ = residual % ThreadMap::ThreadAccessShape::kStrided;
325 pointer_ += int(
sizeof(Element)) * pointer_offset;
332 if (is_residue_tile_) {
335 residue_offset =
TensorCoord(0, residue_tile_idx_ * Shape::kStrided);
337 residue_offset =
TensorCoord(residue_tile_idx_ * Shape::kContiguous, 0);
340 thread_offset_ -= residue_offset;
342 Layout layout(params_.stride_);
343 add_pointer_offset(-layout(residue_offset));
345 compute_predicates_(
true);
348 pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1);
349 pointer_ += Shape::kContiguous * tile_offset.contiguous();
351 pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1);
352 pointer_ += Shape::kStrided * tile_offset.strided();
356 pointer_ += params_.inc_advance_ * tile_offset.strided();
357 pointer_ += Shape::kContiguous * tile_offset.contiguous();
359 pointer_ += params_.inc_advance_ * tile_offset.contiguous();
360 pointer_ += Shape::kStrided * tile_offset.strided();
363 is_residue_tile_ =
false;
370 pointer_ + (iteration_thread_ * params_.stride_ + iteration_contiguous_ * ThreadMap::Delta::kContiguous) *
int(
sizeof(Element)));
381 if (iteration_thread_ < ThreadMap::ThreadAccessShape::kStrided)
384 iteration_thread_ = 0;
386 ++iteration_contiguous_;
388 if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
393 iteration_contiguous_ = 0;
394 ++iteration_strided_;
396 if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
397 pointer_ += params_.inc_strided_;
403 iteration_strided_ = 0;
406 pointer_ += params_.inc_next_;
411 pointer_ -= params_.inc_advance_;
428 for (
int i = 0; i < kPredicateWordCount; ++i) {
438 for (
int i = 0; i < kPredicateWordCount; ++i) {
439 predicates_[i] = 0xffffffff;
447 for (
int i = 0; i < kPredicateWordCount; ++i) {
448 predicates_[i] = mask[i];
457 for (
int i = 0; i < kPredicateWordCount; ++i) {
458 mask[i] = predicates_[i];
468 iteration_contiguous_ * ThreadMap::ThreadAccessShape::kStrided +
469 iteration_strided_ * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided;
471 int word_idx = pred_idx / kPredicatesPerWord;
472 int residual = pred_idx % kPredicatesPerWord;
473 int byte_idx = residual / kPredicatesPerByte;
474 int bit_idx = residual % kPredicatesPerByte;
476 bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
491 template <
typename Shape_,
typename Element_,
int AdvanceRank,
492 typename ThreadMap_,
typename AccessType_>
494 AdvanceRank, ThreadMap_, AccessType_> {
497 AdvanceRank == 0 || AdvanceRank == 1,
498 "Specialization for pitch-linear iterator may along advance along the " 499 "contiguous(rank=0) or strided(rank=1) dimension.");
502 using Element = Element_;
504 static int const kAdvanceRank = AdvanceRank;
505 using ThreadMap = ThreadMap_;
523 using Mask =
typename UnderlyingIterator::Mask;
531 typename UnderlyingIterator::Params params_;
542 : params_(layout::PitchLinear(layout.stride(0))){};
559 Params
const &params,
568 : iterator_(params.params_, pointer,
569 layout::PitchLinearCoord(extent.row(), extent.column()),
571 layout::PitchLinearCoord(threadblock_offset.row(),
572 threadblock_offset.column())) {}
577 Params
const &params,
592 iterator_.add_pointer_offset(pointer_offset);
599 iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
605 return reinterpret_cast<AccessType *
>(iterator_.get());
652 return iterator_.valid();
665 template <
typename Shape_,
typename Element_,
int AdvanceRank,
666 typename ThreadMap_,
typename AccessType_>
668 AdvanceRank, ThreadMap_, AccessType_> {
671 AdvanceRank == 0 || AdvanceRank == 1,
672 "Specialization for pitch-linear iterator may along advance along the " 673 "contiguous(rank=0) or strided(rank=1) dimension.");
676 using Element = Element_;
678 static int const kAdvanceRank = AdvanceRank;
679 using ThreadMap = ThreadMap_;
697 using Mask =
typename UnderlyingIterator::Mask;
705 typename UnderlyingIterator::Params params_;
716 : params_(layout::PitchLinear(layout.stride(0))){};
733 Params
const &params,
742 : iterator_(params.params_, pointer,
743 layout::PitchLinearCoord(extent.column(), extent.row()),
745 layout::PitchLinearCoord(threadblock_offset.column(),
746 threadblock_offset.row())) {}
751 Params
const &params,
766 iterator_.add_pointer_offset(pointer_offset);
773 iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
779 return reinterpret_cast<AccessType *
>(iterator_.get());
826 return iterator_.valid();
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:62
Definition: aligned_buffer.h:35
Coordinate in pitch-linear space.
Definition: pitch_linear.h:52
Defines a structure containing strides, bounds, and a pointer to tensor data.
Mapping function for pitch-linear memory.
Definition: pitch_linear.h:163
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 2-element coordinate.
Definition: coord.h:387
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:154
Defines a structure containing strides and a pointer to tensor data.
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
Template defining a shape used by pitch-linear operators.
Definition: pitch_linear.h:43
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:59
CUTLASS_HOST_DEVICE half_t & operator++(half_t &lhs)
Definition: half.h:694
int64_t LongIndex
Long index type used for offsets.
Definition: pitch_linear.h:175
Defines a Shape template for matrix tiles.
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
int32_t Index
Index type used for coordinates.
Definition: pitch_linear.h:172
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Defines layout functions used by TensorRef and derived classes.
Defines layout functions used by TensorRef and derived classes for pitch-linear memory.
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:151
Basic include for CUTLASS.
Definition: matrix_coord.h:39