57 typename InterleavedTileShape,
92 static int const kIterations = Policy::kIterations;
113 accumulators_(reinterpret_cast<
AccessType const *>(&accum)),
138 static int const kAccessesPerMma = Policy::kElementsPerMma / Policy::kElementsPerAccess;
141 for (
int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) {
143 int tile_access_idx =
144 (tile_n * Policy::TileIterations::kRow + (index_ & 2) / 2) * Policy::MmaIterations::kCount * kAccessesPerMma;
147 for (
int mma_n = 0; mma_n < Policy::MmaIterations::kColumn * kAccessesPerMma; ++mma_n) {
149 int mma_access_idx = ((mma_n & 1) * 2 + (index_ & 1)) * kAccessesPerMma + (mma_n & 2) / 2;
151 frag_ptr[tile_n * Policy::MmaIterations::kColumn * kAccessesPerMma +
152 mma_n] = accumulators_[tile_access_idx + mma_access_idx];
169 using ElementC = float;
185 static int const kIterations = Policy::kIterations;
206 accumulators_(reinterpret_cast<
AccessType const *>(&accum)),
230 int const kRegsPerMmaRow = 2;
233 for (
int reg_row = 0; reg_row < Policy::kRowsPerMmaTile; ++reg_row) {
236 for (
int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) {
239 for (
int mma_n = 0; mma_n < Policy::MmaIterations::kColumn * 2; ++mma_n) {
241 int mma_idx = (index_ & 1) + (index_ & 2) * Policy::MmaIterations::kCount / 2 +
242 (tile_n * Policy::TileIterations::kRow) * Policy::MmaIterations::kCount + (mma_n & 1) * 2;
244 int reg_offset = reg_row * kRegsPerMmaRow + (mma_n & 2) * 2;
245 int reg_idx = mma_idx * Policy::kElementsPerMma + reg_offset;
247 *frag_ptr = accumulators_[reg_idx / Policy::kElementsPerAccess];
typename Policy::AccumulatorTile AccumulatorTile
This is the complete warp-level accumulator tile.
Definition: fragment_iterator_volta_tensor_op.h:182
CUTLASS_HOST_DEVICE FragmentIteratorVoltaTensorOp & operator++()
Increments.
Definition: fragment_iterator_volta_tensor_op.h:120
Definition: aligned_buffer.h:35
typename Policy::AccessType AccessType
Array type for aligned memory accesses.
Definition: fragment_iterator_volta_tensor_op.h:176
CUTLASS_HOST_DEVICE FragmentIteratorVoltaTensorOp & operator--()
Decrements.
Definition: fragment_iterator_volta_tensor_op.h:219
IEEE half-precision floating-point type.
Definition: half.h:126
typename Policy::AccumulatorTile AccumulatorTile
This is the complete warp-level accumulator tile.
Definition: fragment_iterator_volta_tensor_op.h:87
Defines common types used for all GEMM-like operators.
AccumulatorTile OutputAccumulatorTile
Definition: fragment_iterator_volta_tensor_op.h:89
CUTLASS_HOST_DEVICE FragmentIteratorVoltaTensorOp & operator++()
Increments.
Definition: fragment_iterator_volta_tensor_op.h:212
WarpShape_ WarpShape
Definition: fragment_iterator_volta_tensor_op.h:167
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_HOST_DEVICE FragmentIteratorVoltaTensorOp(AccumulatorTile const &accum)
Constructs an iterator.
Definition: fragment_iterator_volta_tensor_op.h:205
typename Policy::AccessType AccessType
Array type for aligned memory accesses.
Definition: fragment_iterator_volta_tensor_op.h:81
CUTLASS_HOST_DEVICE FragmentIteratorVoltaTensorOp & operator--()
Decrements.
Definition: fragment_iterator_volta_tensor_op.h:127
CUTLASS_HOST_DEVICE void load(Fragment &frag, int index_offset=0) const
Loads a fragment from the referenced part of the accumulator tile.
Definition: fragment_iterator_volta_tensor_op.h:134
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Definition: fragment_iterator_volta_tensor_op.h:61
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
CUTLASS_HOST_DEVICE void load(Fragment &frag, int index_offset=0) const
Loads a fragment from the referenced part of the accumulator tile.
Definition: fragment_iterator_volta_tensor_op.h:226
Defines layout functions used by TensorRef and derived classes.
Policy details related to the epilogue.
Definition: volta_tensor_op_policy.h:52
CUTLASS_HOST_DEVICE FragmentIteratorVoltaTensorOp(AccumulatorTile const &accum)
Constructs an iterator.
Definition: fragment_iterator_volta_tensor_op.h:112
WarpShape_ WarpShape
Definition: fragment_iterator_volta_tensor_op.h:72
typename Policy::Fragment Fragment
This is the fragment size produced by one access of the iterator.
Definition: fragment_iterator_volta_tensor_op.h:84
Defines basic structures needed for implementing the warp-scoped phase of the epilogue. These quantities assume a 'column-major' arrangement of TensorOp instructions, of which a row-oriented slice is visible per iteration.
typename Policy::Fragment Fragment
This is the fragment size produced by one access of the iterator.
Definition: fragment_iterator_volta_tensor_op.h:179