int PartitionGroupSize = 1   ///< Size of a partition group (primary template parameter, defaulted to 1)

int PartitionGroupSize       ///< Size of a partition group (specialization for Operand::kA, layout::ColumnMajor)
124 "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension.");
126 static_assert(Shape::kRow > 0,
"Shape::kRow must be greater than zero.");
127 static_assert(Shape::kColumn > 0,
"Shape::kColumn must be greater than zero.");
128 static_assert(Policy::WarpShape::kRow > 0,
"Policy::WarpShape::kRow must be greater than zero.");
129 static_assert(Shape::kRow / Policy::WarpShape::kRow > 0,
"Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");
133 Shape::kRow / Policy::WarpShape::kRow,
137 static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM),
138 "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");
142 ThreadShape::kRow / Policy::LaneMmaShape::kM,
147 using Fragment = Array<Element, ThreadShape::kCount>;
CUTLASS_HOST_DEVICE
MmaSimtTileIterator(TensorRef ref, int lane_id) {
  // compute this lane's starting coordinate from the warp's lane layout
  typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
  MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
    MatrixCoord(Policy::LaneMmaShape::kM, 0);
  ref.add_coord_offset(lane_offset);

  // the internal reference is held in units of LaneMmaShape::kM-element vectors
  ref_.reset(
    reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(ref.data()),
    ref.stride(0) / Policy::LaneMmaShape::kM);
}

/// Advances the iterator along logical dimensions of the matrix in units of whole tiles
CUTLASS_HOST_DEVICE
MmaSimtTileIterator & add_tile_offset(TensorCoord const &coord) {
  ref_.add_coord_offset({
    coord.row() * Shape::kRow / Policy::LaneMmaShape::kM,
    coord.column() * Shape::kColumn});
  return *this;
}

/// Loads a fragment from memory at the location pointed to by the iterator
CUTLASS_HOST_DEVICE
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
  Array<Element, Policy::LaneMmaShape::kM> *dst_ptr =
    reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(&frag);

  CUTLASS_PRAGMA_UNROLL
  for (int k = 0; k < Iterations::kColumn; ++k) {
    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < Iterations::kRow; ++m) {
      dst_ptr[m + k * Iterations::kRow] =
        *(ref_.data() + ref_.offset({m * Policy::WarpShape::kRow, k}) + pointer_offset / Policy::LaneMmaShape::kM);
    }
  }
}

CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
  load_with_pointer_offset(frag, 0);
}
/// Stores a fragment to memory at the location pointed to by the iterator
CUTLASS_HOST_DEVICE
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {
  Array<Element, Policy::LaneMmaShape::kM> const *src_ptr =
    reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> const *>(&frag);

  CUTLASS_PRAGMA_UNROLL
  for (int k = 0; k < Iterations::kColumn; ++k) {
    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < Iterations::kRow; ++m) {
      *(ref_.data() + ref_.offset({m * Policy::WarpShape::kRow, k}) + pointer_offset / Policy::LaneMmaShape::kM) =
        src_ptr[m + k * Iterations::kRow];
    }
  }
}

CUTLASS_HOST_DEVICE
void store(Fragment const &frag) const {
  store_with_pointer_offset(frag, 0);
}
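For concreteness, the shape arithmetic above works out as follows for a hypothetical configuration; the numbers are illustrative only and not taken from this header.

// Hypothetical configuration for the A-operand (column-major) iterator:
//   Shape                = MatrixShape<32, 4>   -- warp tile of A: 32 rows (M) by 4 columns (K)
//   Policy::WarpShape    = MatrixShape<4, 8>    -- 4 lanes arranged along M
//   Policy::LaneMmaShape = GemmShape<4, 4, 1>   -- each lane's MMA covers kM = 4 rows
//
// The definitions above then give, per thread:
//   ThreadShape = MatrixShape<32 / 4, 4> = MatrixShape<8, 4>
//   Iterations  = MatrixShape<8 / 4, 4>  = MatrixShape<2, 4>   -- 2 x 4 vector loads of kM = 4 elements each
//   Fragment    = Array<Element, 8 * 4>                        -- 32 elements held in registers per thread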
int PartitionGroupSize       ///< Size of a partition group (specialization for Operand::kB, layout::RowMajor)
static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn),
  "The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension.");

static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");
static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");

/// Thread-level shape of a fragment
using ThreadShape = MatrixShape<
  Shape::kRow,
  Shape::kColumn / Policy::WarpShape::kColumn
>;

static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN),
  "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");

/// Number of individual loads
using Iterations = MatrixShape<
  ThreadShape::kRow,
  ThreadShape::kColumn / Policy::LaneMmaShape::kN
>;

/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, ThreadShape::kCount>;
CUTLASS_HOST_DEVICE
MmaSimtTileIterator(TensorRef ref, int lane_id) {
  // compute this lane's starting coordinate from the warp's lane layout
  typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
  MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
    MatrixCoord(0, Policy::LaneMmaShape::kN);
  ref.add_coord_offset(lane_offset);

  // the internal reference is held in units of LaneMmaShape::kN-element vectors
  ref_.reset(
    reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(ref.data()),
    ref.stride(0) / Policy::LaneMmaShape::kN);
}

/// Advances the iterator along logical dimensions of the matrix in units of whole tiles
CUTLASS_HOST_DEVICE
MmaSimtTileIterator & add_tile_offset(TensorCoord const &coord) {
  ref_.add_coord_offset({
    coord.row() * Shape::kRow,
    coord.column() * Shape::kColumn / Policy::LaneMmaShape::kN});
  return *this;
}

/// Loads a fragment from memory at the location pointed to by the iterator
CUTLASS_HOST_DEVICE
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
  Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =
    reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag);

  CUTLASS_PRAGMA_UNROLL
  for (int k = 0; k < Iterations::kRow; ++k) {
    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < Iterations::kColumn; ++n) {
      dst_ptr[n + k * Iterations::kColumn] =
        *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kColumn}) + pointer_offset / Policy::LaneMmaShape::kN);
    }
  }
}

CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
  load_with_pointer_offset(frag, 0);
}
/// Stores a fragment to memory at the location pointed to by the iterator
CUTLASS_HOST_DEVICE
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {
  Array<Element, Policy::LaneMmaShape::kN> const *src_ptr =
    reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> const *>(&frag);

  CUTLASS_PRAGMA_UNROLL
  for (int k = 0; k < Iterations::kRow; ++k) {
    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < Iterations::kColumn; ++n) {
      *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kColumn}) + pointer_offset / Policy::LaneMmaShape::kN) =
        src_ptr[n + k * Iterations::kColumn];
    }
  }
}

CUTLASS_HOST_DEVICE
void store(Fragment const &frag) const {
  store_with_pointer_offset(frag, 0);
}
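All of the A- and B-operand specializations are driven the same way from a warp-level mainloop: construct the iterator from a TensorRef and the lane id, load a fragment, and advance along K with operator++. The sketch below illustrates that pattern; IteratorA stands for some concrete MmaSimtTileIterator instantiation, and consume_operand_a and k_iterations are hypothetical names, none of them taken from this header.

// Minimal usage sketch (assumptions noted above; not a definitive mainloop).
template <typename IteratorA>
__device__ void consume_operand_a(typename IteratorA::TensorRef ref_A, int lane_id, int k_iterations) {
  IteratorA iter_A(ref_A, lane_id);      // per-lane view of the warp's A tile
  typename IteratorA::Fragment frag_A;   // this thread's registers for the tile

  for (int k = 0; k < k_iterations; ++k) {
    iter_A.load(frag_A);                 // gather this lane's elements for the current k-group
    ++iter_A;                            // advance along the K dimension
    // ... feed frag_A to the thread-level MMA ...
  }
}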
// Accumulator (C operand) tile iterator for column-major layouts: compile-time checks and derived types
static_assert(
  (!(Shape::kRow % Policy::WarpShape::kRow)) && (!(Shape::kColumn % Policy::WarpShape::kColumn)),
  "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp.");

static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");
static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");
static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");

/// Thread-level shape of a fragment
using ThreadShape = MatrixShape<
  Shape::kRow / Policy::WarpShape::kRow,
  Shape::kColumn / Policy::WarpShape::kColumn
>;

static_assert(
  (!(ThreadShape::kRow % Policy::LaneMmaShape::kM)) && (!(ThreadShape::kColumn % Policy::LaneMmaShape::kN)),
  "Thread-level GEMM shape must be divisible by Policy::LaneMmaShape.");

/// Number of individual MMA operations per thread
using Iterations = MatrixShape<
  ThreadShape::kRow / Policy::LaneMmaShape::kM,
  ThreadShape::kColumn / Policy::LaneMmaShape::kN
>;

/// Distance (in elements) between sub-tiles touched by the same lane
using Delta = MatrixShape<
  Policy::WarpShape::kRow * Policy::LaneMmaShape::kM,
  Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN
>;

/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, ThreadShape::kCount>;
// Constructor from (TensorRef, lane_id): locate this lane's accumulator sub-tile
typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
  MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);
ref_.add_coord_offset(lane_offset);

// add_tile_offset(coord): advance by whole tiles
ref_.add_coord_offset({
  coord.row() * Shape::kRow,
  coord.column() * Shape::kColumn});
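add_tile_offset moves the iterator in units of whole tiles; with a hypothetical tile shape this works out as follows.

// With Shape = MatrixShape<32, 64> (illustrative only), the call
//   iter.add_tile_offset({1, 2});
// moves the accumulator iterator down by 1 * 32 rows and right by 2 * 64 columns.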
/// Loads a fragment from memory with additional logical offset
CUTLASS_HOST_DEVICE
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {

  CUTLASS_PRAGMA_UNROLL
  for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {

      Array<Element, Policy::LaneMmaShape::kM> const *src_ptr =
        reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> const *>(
          ref_.data() + pointer_offset + ref_.offset({0, mma_n * Delta::kColumn + n}));

      CUTLASS_PRAGMA_UNROLL
      for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {

        Array<Element, Policy::LaneMmaShape::kM> *dst_ptr =
          reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(&frag) +
          mma_m + Iterations::kRow * (n + mma_n * Policy::LaneMmaShape::kN);

        *dst_ptr = src_ptr[mma_m * Policy::WarpShape::kRow];
      }
    }
  }
}

CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
  load_with_pointer_offset(frag, 0);
}
/// Stores a fragment to memory at the location pointed to by the iterator
CUTLASS_HOST_DEVICE
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {

  CUTLASS_PRAGMA_UNROLL
  for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {

      Array<Element, Policy::LaneMmaShape::kM> *dst_ptr =
        reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(
          ref_.data() + pointer_offset + ref_.offset({0, mma_n * Delta::kColumn + n}));

      CUTLASS_PRAGMA_UNROLL
      for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {

        Array<Element, Policy::LaneMmaShape::kM> const *src_ptr =
          reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> const *>(&frag) +
          mma_m + Iterations::kRow * (n + mma_n * Policy::LaneMmaShape::kN);

        dst_ptr[mma_m * Policy::WarpShape::kRow] = *src_ptr;
      }
    }
  }
}

CUTLASS_HOST_DEVICE
void store(Fragment const &frag) const {
  store_with_pointer_offset(frag, 0);
}
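The accumulator fragment is addressed in units of LaneMmaShape::kM-element vectors at index mma_m + Iterations::kRow * (n + mma_n * LaneMmaShape::kN). With hypothetical sizes (not taken from this header) this works out as follows.

// Suppose Iterations = MatrixShape<2, 2> and Policy::LaneMmaShape::kN = 4 (illustrative only).
//   (mma_m = 1, mma_n = 0, n = 2)  ->  1 + 2 * (2 + 0 * 4) = 5
//   (mma_m = 0, mma_n = 1, n = 0)  ->  0 + 2 * (0 + 1 * 4) = 8
// i.e. the fragment stores Iterations::kRow vectors per column of each MMA tile, grouped by mma_n.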
// Accumulator (C operand) tile iterator for row-major layouts: compile-time checks and derived types
static_assert(
  (!(Shape::kRow % Policy::WarpShape::kRow)) && (!(Shape::kColumn % Policy::WarpShape::kColumn)),
  "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp.");

static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");
static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");
static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");

/// Thread-level shape of a fragment
using ThreadShape = MatrixShape<
  Shape::kRow / Policy::WarpShape::kRow,
  Shape::kColumn / Policy::WarpShape::kColumn
>;

static_assert(
  (!(ThreadShape::kRow % Policy::LaneMmaShape::kM)) && (!(ThreadShape::kColumn % Policy::LaneMmaShape::kN)),
  "Thread-level GEMM shape must be divisible by Policy::LaneMmaShape.");

/// Number of individual MMA operations per thread
using Iterations = MatrixShape<
  ThreadShape::kRow / Policy::LaneMmaShape::kM,
  ThreadShape::kColumn / Policy::LaneMmaShape::kN
>;

/// Distance (in elements) between sub-tiles touched by the same lane
using Delta = MatrixShape<
  Policy::WarpShape::kRow * Policy::LaneMmaShape::kM,
  Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN
>;

/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, ThreadShape::kCount>;
// Constructor from (TensorRef, lane_id): locate this lane's accumulator sub-tile
typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
  MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);
ref_.add_coord_offset(lane_offset);

// add_pointer_offset(offset): advance through memory in units of elements
ref_.add_pointer_offset(offset);

// add_tile_offset(coord): advance by whole tiles
ref_.add_coord_offset({
  coord.row() * Shape::kRow,
  coord.column() * Shape::kColumn});

// operator++ / operator-- : step forward or back by one tile along the M dimension
ref_.add_coord_offset({Shape::kRow, 0});
ref_.add_coord_offset({-Shape::kRow, 0});
/// Loads a fragment from memory with additional logical offset
CUTLASS_HOST_DEVICE
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {

  CUTLASS_PRAGMA_UNROLL
  for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {

      Array<Element, Policy::LaneMmaShape::kN> const *src_ptr =
        reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> const *>(
          ref_.data() + pointer_offset + ref_.offset({mma_m * Delta::kRow + m, 0}));

      CUTLASS_PRAGMA_UNROLL
      for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {

        Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =
          reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag) +
          mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM);

        *dst_ptr = src_ptr[mma_n * Policy::WarpShape::kColumn];
      }
    }
  }
}

CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
  load_with_pointer_offset(frag, 0);
}
/// Stores a fragment to memory at the location pointed to by the iterator
CUTLASS_HOST_DEVICE
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {

  CUTLASS_PRAGMA_UNROLL
  for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {

      Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =
        reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(
          ref_.data() + pointer_offset + ref_.offset({mma_m * Delta::kRow + m, 0}));

      CUTLASS_PRAGMA_UNROLL
      for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {

        Array<Element, Policy::LaneMmaShape::kN> const *src_ptr =
          reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> const *>(&frag) +
          mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM);

        dst_ptr[mma_n * Policy::WarpShape::kColumn] = *src_ptr;
      }
    }
  }
}

CUTLASS_HOST_DEVICE
void store(Fragment const &frag) const {
  store_with_pointer_offset(frag, 0);
}
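Delta is the element distance between sub-tiles that the same lane touches on successive iterations. A small worked example with hypothetical policy shapes follows.

// With Policy::WarpShape = MatrixShape<4, 8> and Policy::LaneMmaShape = GemmShape<4, 4, 1> (illustrative only):
//   Delta = MatrixShape<4 * 4, 8 * 4> = MatrixShape<16, 32>
// i.e. a lane's consecutive accumulator sub-tiles are 16 rows apart along M and 32 columns apart along N.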
int PartitionGroupSize       ///< Size of a partition group (specialization for Operand::kA, column-major interleaved layout)

/// Interleaving quantity
static const int kInterleave = 4;

/// Number of partitions along the K dimension
static const int kPartitionsK = PartitionsK;

/// Number of k-groups per tile
static const int kGroupPerTile = PartitionGroupSize / Shape::kColumn;

static_assert(!(Shape::kRow % Policy::WarpShape::kRow),
  "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension.");

static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");

/// Thread-level shape of a fragment
using ThreadShape = MatrixShape<
  Shape::kRow / Policy::WarpShape::kRow,
  Shape::kColumn
>;

static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM) && !(ThreadShape::kColumn % Policy::LaneMmaShape::kK),
  "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");

using Iterations = MatrixShape<
  ThreadShape::kRow / Policy::LaneMmaShape::kM,
  ThreadShape::kColumn / Policy::LaneMmaShape::kK
>;

/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, ThreadShape::kCount>;
// Constructor from (TensorRef, lane_id)
typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
  MatrixCoord(Policy::LaneMmaShape::kM, 0);   // per-lane step (assumed analogous to the non-interleaved case)
ref.add_coord_offset(lane_offset);
// the internal reference is held in units of LaneMmaShape::kMK-element vectors (kM x kK interleaved)
ref_.reset(
  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kMK> *>(ref.data()),
  ref.stride(0) / Policy::LaneMmaShape::kMK);

// add_tile_offset(coord): advance by whole tiles
ref_.add_coord_offset({
  coord.row() * Shape::kRow / Policy::LaneMmaShape::kMK,
  coord.column() * Shape::kColumn});
// operator++ : advance one k-group; skip the other partitions' groups when this one is done
add_tile_offset({0, 1});
if (kPartitionsK > 1) {
  ++k_group_idx_;                 // assumed bookkeeping (not shown in this excerpt)
  if (k_group_idx_ == kGroupPerTile) {
    add_tile_offset({0, kGroupPerTile * (kPartitionsK - 1)});
  }
}
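With hypothetical values, not taken from this header, the k-group bookkeeping above works out as follows.

// Suppose PartitionGroupSize = 16, Shape::kColumn = 4, and kPartitionsK = 2 (illustrative only).
//   kGroupPerTile = 16 / 4 = 4
// Each operator++ advances one tile along K ({0, 1}).  After 4 increments this partition's group is
// exhausted, so the iterator jumps ahead by kGroupPerTile * (kPartitionsK - 1) = 4 tiles, skipping the
// k-groups owned by the other K partition.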
/// Loads a fragment from memory at the location pointed to by the iterator
CUTLASS_HOST_DEVICE
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
  Array<Element, Policy::LaneMmaShape::kMK> *dst_ptr =
    reinterpret_cast<Array<Element, Policy::LaneMmaShape::kMK> *>(&frag);

  CUTLASS_PRAGMA_UNROLL
  for (int k = 0; k < Iterations::kColumn; ++k) {
    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < Iterations::kRow; ++m) {
      dst_ptr[m + k * Iterations::kRow] =
        *(ref_.data() +
          ref_.offset({m * Policy::WarpShape::kRow / kInterleave, k * Policy::LaneMmaShape::kK}) +
          pointer_offset / Policy::LaneMmaShape::kM);
    }
  }
}

CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
  load_with_pointer_offset(frag, 0);
}
/// Stores a fragment to memory at the location pointed to by the iterator
CUTLASS_HOST_DEVICE
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {
  Array<Element, Policy::LaneMmaShape::kMK> const *src_ptr =
    reinterpret_cast<Array<Element, Policy::LaneMmaShape::kMK> const *>(&frag);

  CUTLASS_PRAGMA_UNROLL
  for (int k = 0; k < Iterations::kColumn; ++k) {
    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < Iterations::kRow; ++m) {
      *(ref_.data() + ref_.offset({m * Policy::WarpShape::kRow, k}) + pointer_offset / Policy::LaneMmaShape::kM) =
        src_ptr[m + k * Iterations::kRow];
    }
  }
}

CUTLASS_HOST_DEVICE
void store(Fragment const &frag) const {
  store_with_pointer_offset(frag, 0);
}
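Because the interleaved iterator addresses memory in kM-by-kK vectors, the strides and tile offsets above are divided by LaneMmaShape::kMK. With a hypothetical lane MMA shape:

// Suppose Policy::LaneMmaShape = GemmShape<4, 4, 2> (illustrative only), so kM = 4, kK = 2, kMK = 8.
//   ref_ then holds Array<Element, 8> vectors and its stride is ref.stride(0) / 8;
//   add_tile_offset({1, 0}) advances by Shape::kRow / 8 vector rows rather than Shape::kRow elements.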
int PartitionGroupSize       ///< Size of a partition group (specialization for Operand::kB, row-major interleaved layout)

/// Interleaving quantity
static const int kInterleave = 4;

/// Number of partitions along the K dimension
static const int kPartitionsK = PartitionsK;

/// Number of k-groups per tile
static const int kGroupPerTile = PartitionGroupSize / Shape::kRow;

static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn),
  "The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension.");

static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");
static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");

/// Thread-level shape of a fragment
using ThreadShape = MatrixShape<
  Shape::kRow,
  Shape::kColumn / Policy::WarpShape::kColumn
>;

static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN) && !(ThreadShape::kRow % Policy::LaneMmaShape::kK),
  "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");

using Iterations = MatrixShape<
  ThreadShape::kRow / Policy::LaneMmaShape::kK,
  ThreadShape::kColumn / Policy::LaneMmaShape::kN
>;

/// Fragment object holding a thread's part of a tile
using Fragment = Array<Element, ThreadShape::kCount>;
// Constructor from (TensorRef, lane_id)
typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
  MatrixCoord(0, Policy::LaneMmaShape::kN);   // per-lane step (assumed analogous to the non-interleaved case)
ref.add_coord_offset(lane_offset);
// the internal reference is held in units of LaneMmaShape::kKN-element vectors (kK x kN interleaved)
ref_.reset(
  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kKN> *>(ref.data()),
  ref.stride(0) / Policy::LaneMmaShape::kKN);

// add_tile_offset(coord): advance by whole tiles
ref_.add_coord_offset({
  coord.row() * Shape::kRow,
  coord.column() * Shape::kColumn / Policy::LaneMmaShape::kKN});
// operator++ : advance one k-group; skip the other partitions' groups when this one is done
add_tile_offset({1, 0});
if (kPartitionsK > 1) {
  ++k_group_idx_;                 // assumed bookkeeping (not shown in this excerpt)
  if (k_group_idx_ == kGroupPerTile) {
    add_tile_offset({kGroupPerTile * (kPartitionsK - 1), 0});
  }
}
/// Loads a fragment from memory at the location pointed to by the iterator
CUTLASS_HOST_DEVICE
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
  Array<Element, Policy::LaneMmaShape::kKN> *dst_ptr =
    reinterpret_cast<Array<Element, Policy::LaneMmaShape::kKN> *>(&frag);

  CUTLASS_PRAGMA_UNROLL
  for (int k = 0; k < Iterations::kRow; ++k) {
    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < Iterations::kColumn; ++n) {
      dst_ptr[n + k * Iterations::kColumn] =
        *(ref_.data() +
          ref_.offset({k * Policy::LaneMmaShape::kK, n * Policy::WarpShape::kColumn / kInterleave}) +
          pointer_offset / Policy::LaneMmaShape::kN);
    }
  }
}

CUTLASS_HOST_DEVICE
void load(Fragment &frag) const {
  load_with_pointer_offset(frag, 0);
}
/// Stores a fragment to memory at the location pointed to by the iterator
CUTLASS_HOST_DEVICE
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {
  Array<Element, Policy::LaneMmaShape::kN> const *src_ptr =
    reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> const *>(&frag);

  CUTLASS_PRAGMA_UNROLL
  for (int k = 0; k < Iterations::kRow; ++k) {
    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < Iterations::kColumn; ++n) {
      *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kColumn}) + pointer_offset / Policy::LaneMmaShape::kN) =
        src_ptr[n + k * Iterations::kColumn];
    }
  }
}

CUTLASS_HOST_DEVICE
void store(Fragment const &frag) const {
  store_with_pointer_offset(frag, 0);
}