56 typename OperatorShape,
57 typename OperatorElementC,
58 typename OperatorFragmentC,
68 typename OperatorShape_,
69 typename OperatorElementC_,
70 typename OperatorFragmentC_
76 using OperatorShape = OperatorShape_;
77 using OperatorElementC = OperatorElementC_;
78 using OperatorFragmentC = OperatorFragmentC_;
86 Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>;
91 OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn>;
96 static int const kIterations = Policy::kIterations;
101 using AccessType = Array<OperatorElementC, Policy::kElementsPerAccess>;
110 AccessType
const *accumulators_;
120 accumulators_(reinterpret_cast<AccessType const *>(&accum)),
142 int index = index_ + index_offset;
144 AccessType *frag_ptr =
reinterpret_cast<AccessType *
>(&frag);
147 for (
int n = 0; n < Policy::OperatorCount::kColumn; ++n) {
149 int accumulator_access_offset =
150 index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess;
152 frag_ptr[n] = accumulators_[accumulator_access_offset];
164 typename OperatorShape_,
166 typename OperatorElementC_,
168 typename OperatorFragmentC_,
172 layout::ColumnMajorInterleaved<InterleavedK>> {
175 using OperatorShape = OperatorShape_;
176 using OperatorElementC = OperatorElementC_;
177 using OperatorFragmentC = OperatorFragmentC_;
178 static int const kInterleavedK = InterleavedK;
185 Array<OperatorElementC,
186 Policy::kElementsPerAccess * InterleavedK / OperatorShape::kN>;
190 Array<OperatorElementC, OperatorFragmentC::kElements *
191 Policy::OperatorCount::kRow *
192 Policy::OperatorCount::kColumn>;
195 static int const kIterations = Policy::kIterations;
200 Array<OperatorElementC, Policy::kElementsPerAccess>;
208 AccessType
const *accumulators_;
217 : accumulators_(reinterpret_cast<AccessType const *>(&accum)),
237 int index = index_ + index_offset;
239 AccessType *frag_ptr =
reinterpret_cast<AccessType *
>(&frag);
242 for (
int n = 0; n < (InterleavedK / OperatorShape::kN); ++n) {
243 int index_m = index % (Policy::OperatorCount::kRow *
244 Policy::kIterationsPerInstruction);
245 int index_n = index / (Policy::OperatorCount::kRow *
246 Policy::kIterationsPerInstruction);
247 int accumulator_access_offset =
248 (index_m / Policy::kIterationsPerInstruction) *
249 (Policy::OperatorCount::kColumn *
250 Policy::kIterationsPerInstruction) +
251 (index_m % Policy::kIterationsPerInstruction) +
252 index_n * (InterleavedK / OperatorShape::kN) *
253 Policy::kIterationsPerInstruction +
254 n * Policy::kIterationsPerInstruction;
256 frag_ptr[n] = accumulators_[accumulator_access_offset];
WarpShape_ WarpShape
Definition: fragment_iterator_tensor_op.h:75
CUTLASS_HOST_DEVICE void load(Fragment &frag, int index_offset=0) const
Loads a fragment from the referenced part of the accumulator tile.
Definition: fragment_iterator_tensor_op.h:140
Definition: aligned_buffer.h:35
Defines basic structures needed for implementing the warp-scoped phase of the epilogue. These quantities assume a 'column-major' arrangement of TensorOp instructions, of which a row-oriented slice is visible per iteration.
AccumulatorTile OutputAccumulatorTile
Definition: fragment_iterator_tensor_op.h:93
CUTLASS_HOST_DEVICE void load(Fragment &frag, int index_offset=0) const
Loads a fragment from the referenced part of the accumulator tile.
Definition: fragment_iterator_tensor_op.h:236
CUTLASS_HOST_DEVICE FragmentIteratorTensorOp & operator--()
Decrements.
Definition: fragment_iterator_tensor_op.h:229
WarpShape_ WarpShape
Definition: fragment_iterator_tensor_op.h:174
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Array< OperatorElementC, OperatorFragmentC::kElements *Policy::OperatorCount::kRow *Policy::OperatorCount::kColumn > AccumulatorTile
This is the complete warp-level accumulator tile.
Definition: fragment_iterator_tensor_op.h:91
Array< OperatorElementC, Policy::OperatorCount::kColumn *Policy::kElementsPerAccess > Fragment
This is the fragment size produced by one access of the iterator.
Definition: fragment_iterator_tensor_op.h:86
Policy details related to the epilogue.
Definition: tensor_op_policy.h:50
CUTLASS_HOST_DEVICE FragmentIteratorTensorOp & operator++()
Increments.
Definition: fragment_iterator_tensor_op.h:222
CUTLASS_HOST_DEVICE FragmentIteratorTensorOp(AccumulatorTile const &accum)
Constructs an iterator.
Definition: fragment_iterator_tensor_op.h:119
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
CUTLASS_HOST_DEVICE FragmentIteratorTensorOp(AccumulatorTile const &accum)
Constructs an iterator.
Definition: fragment_iterator_tensor_op.h:216
Definition: fragment_iterator_tensor_op.h:61
Array< OperatorElementC, Policy::kElementsPerAccess *InterleavedK/OperatorShape::kN > Fragment
This is the fragment size produced by one access of the iterator.
Definition: fragment_iterator_tensor_op.h:186
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
CUTLASS_HOST_DEVICE FragmentIteratorTensorOp & operator++()
Increments.
Definition: fragment_iterator_tensor_op.h:126
Defines layout functions used by TensorRef and derived classes.
Definition: layout/matrix.h:343
Array< OperatorElementC, OperatorFragmentC::kElements *Policy::OperatorCount::kRow *Policy::OperatorCount::kColumn > AccumulatorTile
This is the complete warp-level accumulator tile.
Definition: fragment_iterator_tensor_op.h:192
CUTLASS_HOST_DEVICE FragmentIteratorTensorOp & operator--()
Decrements.
Definition: fragment_iterator_tensor_op.h:133