57 namespace threadblock {
66 typename WarpMmaOperator_,
70 typename OutputTileIterator_,
72 typename AccumulatorFragmentIterator_,
78 bool IsBetaZero =
false>
104 using TensorRef =
typename OutputTileIterator::TensorRef;
115 OutputTileIterator::kElementsPerAccess>;
119 Array<ElementAccumulator, OutputTileIterator::kElementsPerAccess>;
124 Shape::kN / WarpMmaOperator::Shape::kN, kPartitionsK>;
128 "This must not be zero.");
131 OutputTileIterator::kElementsPerAccess),
160 if (IsBetaZero && output_op.is_source_needed())
163 typename OutputTileIterator::Fragment source_fragment;
166 if (!output_op.is_source_needed()) {
167 source_iterator.clear_mask();
171 source_fragment.clear();
184 for (
int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
190 source_iterator.set_iteration_index(iter);
191 source_iterator.load(source_fragment);
199 typename AccumulatorFragmentIterator::Fragment accum_fragment;
201 accum_fragment_iterator.load(accum_fragment);
202 ++accum_fragment_iterator;
208 typename OutputTileIterator::Fragment output_fragment;
209 apply_output_operator_(output_op, output_fragment, accum_fragment, source_fragment);
215 destination_iterator.set_iteration_index(iter);
216 destination_iterator.store(output_fragment);
217 ++destination_iterator;
224 void apply_output_operator_(
226 typename OutputTileIterator::Fragment &output_fragment,
227 typename AccumulatorFragmentIterator::Fragment
const 228 &aligned_accum_fragment,
229 typename OutputTileIterator::Fragment
const &source_fragment) {
235 &aligned_accum_fragment);
240 int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
241 OutputTileIterator::kElementsPerAccess;
244 for (
int i = 0; i < kOutputOpIterations; ++i) {
246 output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
Shape_ Shape
Definition: interleaved_epilogue.h:81
Definition: aligned_buffer.h:35
typename AccumulatorTile::Element ElementAccumulator
Accumulator element.
Definition: interleaved_epilogue.h:95
Templates implementing how threads are mapped to a given tile.
CUTLASS_DEVICE InterleavedEpilogue(SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Constructor.
Definition: interleaved_epilogue.h:141
typename AccumulatorFragmentIterator::AccumulatorTile AccumulatorTile
The complete warp-level accumulator tile.
Definition: interleaved_epilogue.h:92
Epilogue for threadblock scoped GEMMs using Tensor Ops.
Array< ElementAccumulator, OutputTileIterator::kElementsPerAccess > AccumulatorAccessType
Array type used by output functor.
Definition: interleaved_epilogue.h:119
Epilogue operator without splitk.
Definition: interleaved_epilogue.h:79
Array< typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess > OutputAccessType
Array type used to output.
Definition: interleaved_epilogue.h:115
OutputOp_ OutputOp
Definition: interleaved_epilogue.h:86
Defines common types used for all GEMM-like operators.
typename OutputTileIterator::ConstTensorRef ConstTensorRef
Const tensor reference to source tensor.
Definition: interleaved_epilogue.h:111
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats...
typename OutputTileIterator::TensorRef TensorRef
Tensor reference to destination tensor.
Definition: interleaved_epilogue.h:104
Shared storage allocation needed by the epilogue.
Definition: interleaved_epilogue.h:135
WarpMmaOperator_ WarpMmaOperator
Definition: interleaved_epilogue.h:82
typename cutlass::TensorRef< int, cutlass::layout::PackedVectorLayout > SyncTensorRef
Tensor reference to sync tensor.
Definition: interleaved_epilogue.h:108
Definition: tensor_ref.h:146
Defines a canonical coordinate for rank=4 tensors offering named indices.
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared me...
OutputTileIterator_ OutputTileIterator
Definition: interleaved_epilogue.h:85
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
CUTLASS_DEVICE void operator()(OutputOp const &output_op, OutputTileIterator destination_iterator, AccumulatorTile const &accumulators, OutputTileIterator source_iterator)
Streams the result to global memory.
Definition: interleaved_epilogue.h:150
static int const kElementsPerAccess
Output access size.
Definition: interleaved_epilogue.h:101
Defines layout functions used for rank=1 vectors.
Templates implementing storing of tiles from pitch-linear rank=2 tensors.
Epilogue for threadblock scoped GEMMs using Tensor Ops.
Definition: layout/matrix.h:343
static int const kPartitionsK
Definition: interleaved_epilogue.h:83
AccumulatorFragmentIterator_ AccumulatorFragmentIterator
Definition: interleaved_epilogue.h:84
typename OutputTileIterator::Element ElementOutput
Output element.
Definition: interleaved_epilogue.h:98
Basic include for CUTLASS.