48 typename InterleavedTileShape,
77 Policy::kRowsPerIteration,
91 static int const kIterations = Policy::kIterations;
94 static int const kElementsPerAccess = Policy::kElementsPerAccess;
98 static int const kLanesInQuad = 4;
99 static int const kRowsPerQuad = 4;
100 static int const kColumnsPerQuad = 8;
101 static int const kAccessesPerQuad = kColumnsPerQuad / Policy::kElementsPerAccess;
102 static int const kAccessQuadDelta = 16;
108 Policy::kElementsPerAccess>;
134 pointer_(reinterpret_cast<
AccessType *>(ref.data())),
135 layout_(ref.stride()[0] /
Policy::kElementsPerAccess) {
137 int quad_id = lane_id / Detail::kLanesInQuad;
138 int lane_in_quad = (lane_id % Detail::kLanesInQuad);
140 int quad_row_idx = ((quad_id & 4) >> 1) + (quad_id & 1);
141 int quad_col_idx = ((quad_id & 2) >> 1);
143 int row = quad_row_idx * Detail::kRowsPerQuad + lane_in_quad;
144 int column = quad_col_idx * Detail::kColumnsPerQuad;
146 pointer_ += layout_({row, column / kElementsPerAccess});
152 pointer_ += pointer_offset / Policy::kElementsPerAccess;
160 pointer_ += layout_({
161 tile_offset.
row() * Shape::kRow,
162 tile_offset.
column() * Shape::kColumn / Policy::kElementsPerAccess});
170 add_tile_offset(tile_offset);
181 for (
int tile_idx = 0; tile_idx < Policy::TileIterations::kColumn; ++tile_idx) {
184 for (
int access_idx = 0; access_idx < Policy::kAccessesPerInterleavedTile; ++access_idx) {
186 int access_quad = access_idx / 2;
187 int access = access_idx % 2;
189 int ptr_offset = tile_idx * InterleavedTileShape::kN / Policy::kElementsPerAccess +
190 access_quad * Detail::kAccessQuadDelta / Policy::kElementsPerAccess + access;
192 int frag_idx = tile_idx * Policy::kAccessesPerInterleavedTile + access_idx;
194 AccessType access_vector = frag_ptr[frag_idx];
196 pointer_[ptr_offset] = access_vector;
204 store_with_pointer_offset(frag, 0);
214 for (
int tile_idx = 0; tile_idx < Policy::TileIterations::kColumn; ++tile_idx) {
217 for (
int access_idx = 0; access_idx < Policy::kAccessesPerInterleavedTile; ++access_idx) {
219 int access_quad = access_idx / 2;
220 int access = access_idx % 2;
222 int ptr_offset = tile_idx * Detail::kTileDelta + access_quad * Detail::kAccessQuadDelta + access;
223 int frag_idx = tile_idx * Policy::kAccessesPerInterleavedTile + access_idx;
225 frag_ptr[frag_idx] = pointer_[ptr_offset];
233 load_with_pointer_offset(frag, 0);
260 Policy::kRowsPerIteration,
274 static int const kIterations = Policy::kIterations;
277 static int const kElementsPerAccess = Policy::kElementsPerAccess;
281 static int const kLanesInQuad = 4;
282 static int const kRowsPerQuad = 4;
283 static int const kColumnsPerQuad = 8;
284 static int const kAccessesPerQuad = kColumnsPerQuad / Policy::kElementsPerAccess;
285 static int const kAccessQuadDelta = 16;
291 Policy::kElementsPerAccess>;
317 pointer_(reinterpret_cast<
AccessType *>(ref.data())),
318 layout_(ref.stride()[0] /
Policy::kElementsPerAccess) {
320 int quad_id = lane_id / Detail::kLanesInQuad;
321 int lane_in_quad = (lane_id % Detail::kLanesInQuad);
323 int const kQuadRowDelta = 4;
324 int const kQuadColumnDelta = 2 * Policy::MmaIterations::kColumn;
326 int quad_row_offset = ((quad_id & 4) / 2 + (quad_id & 1)) * kQuadRowDelta;
327 int quad_column_offset = (quad_id & 2) / 2 * kQuadColumnDelta;
329 int thread_row_offset = (lane_in_quad & 1);
330 int thread_column_offset = (lane_in_quad & 2) / 2;
332 int row = quad_row_offset + thread_row_offset;
333 int column = quad_column_offset + thread_column_offset;
335 pointer_ += layout_({row, column});
341 pointer_ += pointer_offset / Policy::kElementsPerAccess;
349 pointer_ += layout_({
350 tile_offset.
row() * Shape::kRow,
351 tile_offset.
column() * Shape::kColumn / Policy::kElementsPerAccess});
359 add_tile_offset(tile_offset);
369 int const kAccessesPerRow = Policy::TileIterations::kColumn * Policy::MmaIterations::kColumn * 2;
372 for (
int row_idx = 0; row_idx < Policy::kRowsPerMmaTile; ++row_idx) {
375 for (
int access_idx = 0; access_idx < kAccessesPerRow; ++access_idx) {
377 int frag_idx = row_idx * kAccessesPerRow + access_idx;
379 int ptr_column_offset = (access_idx & 1) * 2 +
380 (access_idx & 2) * Policy::MmaIterations::kColumn * 2 +
381 (access_idx & 4) * Policy::MmaIterations::kColumn * 2;
383 int ptr_row_offset = row_idx * 2;
385 int ptr_offset = layout_({ptr_row_offset, ptr_column_offset});
387 pointer_[ptr_offset] = frag_ptr[frag_idx];
395 store_with_pointer_offset(frag, 0);
410 load_with_pointer_offset(frag, 0);
CUTLASS_HOST_DEVICE void load_with_pointer_offset(Fragment const &frag, Index pointer_offset)
Load.
Definition: tile_iterator_volta_tensor_op.h:400
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
CUTLASS_HOST_DEVICE Index const & column() const
Returns the column of the coordinate.
Definition: matrix_coord.h:85
typename TensorRef::Index Index
Definition: tile_iterator_volta_tensor_op.h:70
Definition: aligned_buffer.h:35
Defines basic structures needed for implementing the warp-scoped phase of the epilogue. These quantities assume a 'column-major' arrangement of TensorOp instructions, of which a row-oriented slice is visible per iteration.
CUTLASS_HOST_DEVICE void load(Fragment const &frag)
Load.
Definition: tile_iterator_volta_tensor_op.h:232
CUTLASS_DEVICE void store_with_pointer_offset(Fragment const &frag, Index pointer_offset)
Store.
Definition: tile_iterator_volta_tensor_op.h:365
typename TensorRef::LongIndex LongIndex
Definition: tile_iterator_volta_tensor_op.h:71
IEEE half-precision floating-point type.
Definition: half.h:126
CUTLASS_HOST_DEVICE Index const & row() const
Returns the row of the coordinate.
Definition: matrix_coord.h:77
typename Policy::AccessType AccessType
Array type for aligned memory accesses.
Definition: tile_iterator_volta_tensor_op.h:265
CUTLASS_HOST_DEVICE TileIteratorVoltaTensorOp & add_tile_offset(TensorCoord const &tile_offset)
Advances in units of whole tiles along the logical coordinate space of the tensor.
Definition: tile_iterator_volta_tensor_op.h:158
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
WarpShape_ WarpShape
Definition: tile_iterator_volta_tensor_op.h:63
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_HOST_DEVICE TileIteratorVoltaTensorOp & operator+=(TensorCoord const &tile_offset)
Definition: tile_iterator_volta_tensor_op.h:358
CUTLASS_HOST_DEVICE void load(Fragment const &frag)
Load.
Definition: tile_iterator_volta_tensor_op.h:409
CUTLASS_HOST_DEVICE TileIteratorVoltaTensorOp & operator+=(TensorCoord const &tile_offset)
Definition: tile_iterator_volta_tensor_op.h:169
typename Policy::Fragment Fragment
This is the fragment size produced by one access of the iterator.
Definition: tile_iterator_volta_tensor_op.h:268
CUTLASS_HOST_DEVICE TileIteratorVoltaTensorOp & add_pointer_offset(Index pointer_offset)
Adds a pointer offset.
Definition: tile_iterator_volta_tensor_op.h:151
Template for reading and writing tiles of accumulators to shared memory.
Definition: tile_iterator_volta_tensor_op.h:52
CUTLASS_HOST_DEVICE TileIteratorVoltaTensorOp()
Default constructor.
Definition: tile_iterator_volta_tensor_op.h:309
CUTLASS_HOST_DEVICE void store(Fragment const &frag)
Store.
Definition: tile_iterator_volta_tensor_op.h:203
CUTLASS_HOST_DEVICE TileIteratorVoltaTensorOp & add_tile_offset(TensorCoord const &tile_offset)
Advances in units of whole tiles along the logical coordinate space of the tensor.
Definition: tile_iterator_volta_tensor_op.h:347
CUTLASS_DEVICE void store_with_pointer_offset(Fragment const &frag, Index pointer_offset)
Store.
Definition: tile_iterator_volta_tensor_op.h:176
typename Policy::AccumulatorTile AccumulatorTile
This is the complete warp-level accumulator tile.
Definition: tile_iterator_volta_tensor_op.h:271
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
float Element
Definition: tile_iterator_volta_tensor_op.h:248
WarpShape_ WarpShape
Definition: tile_iterator_volta_tensor_op.h:246
typename Policy::AccessType AccessType
Array type for aligned memory accesses.
Definition: tile_iterator_volta_tensor_op.h:82
typename TensorRef::LongIndex LongIndex
Definition: tile_iterator_volta_tensor_op.h:254
typename Layout::Index Index
Index type.
Definition: tensor_ref.h:165
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
CUTLASS_DEVICE TileIteratorVoltaTensorOp(TensorRef const &ref, unsigned lane_id)
Constructor from TensorRef.
Definition: tile_iterator_volta_tensor_op.h:130
typename Policy::Fragment Fragment
This is the fragment size produced by one access of the iterator.
Definition: tile_iterator_volta_tensor_op.h:85
Defines layout functions used by TensorRef and derived classes.
Defines layout functions used by TensorRef and derived classes for pitch-linear memory.
typename TensorRef::Index Index
Definition: tile_iterator_volta_tensor_op.h:253
Policy details related to the epilogue.
Definition: volta_tensor_op_policy.h:52
CUTLASS_HOST_DEVICE void store(Fragment const &frag)
Store.
Definition: tile_iterator_volta_tensor_op.h:394
CUTLASS_HOST_DEVICE void load_with_pointer_offset(Fragment const &frag, Index pointer_offset)
Load.
Definition: tile_iterator_volta_tensor_op.h:209
CUTLASS_HOST_DEVICE TileIteratorVoltaTensorOp & add_pointer_offset(Index pointer_offset)
Adds a pointer offset.
Definition: tile_iterator_volta_tensor_op.h:340
CUTLASS_DEVICE TileIteratorVoltaTensorOp(TensorRef const &ref, unsigned lane_id)
Constructor from TensorRef.
Definition: tile_iterator_volta_tensor_op.h:313
Definition: matrix_coord.h:39
typename Policy::AccumulatorTile AccumulatorTile
This is the complete warp-level accumulator tile.
Definition: tile_iterator_volta_tensor_op.h:88
CUTLASS_HOST_DEVICE TileIteratorVoltaTensorOp()
Default constructor.
Definition: tile_iterator_volta_tensor_op.h:126
typename Layout::LongIndex LongIndex
Long index used for pointer offsets.
Definition: tensor_ref.h:168