58 namespace threadblock {
65 typename WarpMmaOperator_,
67 typename OutputTileIterator_,
68 typename AccumulatorFragmentIterator_,
69 typename WarpTileIterator_,
70 typename SharedLoadIterator_,
79 AccumulatorFragmentIterator_,
89 AccumulatorFragmentIterator_,
121 using TensorRef =
typename OutputTileIterator::TensorRef;
131 typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>;
142 static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements,
143 "Mismatch between shared load iterator and output tile iterator.");
145 static_assert(OutputTileIterator::kElementsPerAccess,
"OutputTileIterator::kElementsPerAccess must not be zero.");
147 static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess),
165 Base(shared_storage, thread_idx, warp_idx, lane_idx),
166 shared_load_iterator_(shared_storage.reference(), thread_idx) { }
177 typename OutputTileIterator::Fragment source_fragment;
179 if (!output_op.is_source_needed()) {
180 source_iterator.clear_mask();
183 source_fragment.clear();
196 for (
int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
202 source_iterator.load(source_fragment);
211 typename AccumulatorFragmentIterator::Fragment accum_fragment;
213 accum_fragment_iterator.load(accum_fragment);
214 ++accum_fragment_iterator;
226 shared_load_iterator_.load(aligned_accum_fragment[0]);
229 if (kPartitionsK > 1)
232 const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK;
236 shared_load_iterator_.add_tile_offset({tile_row_offset , 0});
237 shared_load_iterator_.load(aligned_accum_fragment[i]);
238 aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
241 shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0});
248 typename OutputTileIterator::Fragment output_fragment;
250 apply_output_operator_(output_fragment, output_op, aligned_accum_fragment[0], source_fragment);
257 destination_iterator.store(output_fragment);
258 ++destination_iterator;
267 void apply_output_operator_(
268 typename OutputTileIterator::Fragment &output_fragment,
271 typename OutputTileIterator::Fragment
const &source_fragment) {
282 int const kOutputOpIterations =
283 OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess;
286 for (
int i = 0; i < kOutputOpIterations; ++i) {
289 output_frag_ptr[i] = output_op(compute_frag_ptr[i], source_frag_ptr[i]);
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:62
typename Layout::LongIndex LongIndex
Definition: epilogue.h:105
typename Base::WarpCount WarpCount
Number of warps.
Definition: epilogue.h:137
Array< Element, ThreadMap::Iterations::kColumn *ThreadMap::Iterations::kRow *ThreadMap::Iterations::kGroup *ThreadMap::Iterations::kCluster *ThreadMap::kElementsPerAccess > Fragment
Fragment object.
Definition: shared_load_iterator.h:91
Definition: aligned_buffer.h:35
WarpTileIterator warp_tile_iterator_
Stores a warp's fragment of accumulators to SMEM.
Definition: epilogue_base.h:176
Templates implementing how threads are mapped to a given tile.
Shared storage allocation needed by the epilogue.
Definition: epilogue_base.h:97
CUTLASS_DEVICE void operator()(OutputOp const &output_op, OutputTileIterator destination_iterator, AccumulatorTile const &accumulators, OutputTileIterator source_iterator)
Streams the result to global memory.
Definition: epilogue.h:170
OutputTileIterator_ OutputTileIterator
Definition: epilogue.h:96
Epilogue for threadblock scoped GEMMs using Tensor Ops.
Defines common types used for all GEMM-like operators.
CUTLASS_DEVICE Epilogue(typename Base::SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Constructor.
Definition: epilogue.h:159
Shape_ Shape
Definition: epilogue.h:93
typename OutputTileIterator::TensorRef TensorRef
Tensor reference to destination tensor.
Definition: epilogue.h:121
gemm::GemmShape< Shape::kM/WarpMmaOperator::Shape::kM, Shape::kN/WarpMmaOperator::Shape::kN, kPartitionsK > WarpCount
Number of warps.
Definition: epilogue_base.h:92
Definition: functional.h:46
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats...
static int const kPartitionsK
Definition: epilogue.h:95
OutputOp_ OutputOp
Definition: epilogue.h:100
Definition: tensor_ref.h:146
Padding_ Padding
Definition: epilogue.h:101
Defines a canonical coordinate for rank=4 tensors offering named indices.
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared me...
AccumulatorFragmentIterator_ AccumulatorFragmentIterator
Definition: epilogue.h:97
Top-level include for all CUTLASS numeric types.
typename OutputTileIterator::ConstTensorRef ConstTensorRef
Const tensor reference to source tensor.
Definition: epilogue.h:127
WarpTileIterator_ WarpTileIterator
Definition: epilogue.h:98
SharedLoadIterator_ SharedLoadIterator
Definition: epilogue.h:99
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Epilogue operator without splitk.
Definition: epilogue.h:74
typename WarpTileIterator::Element ElementAccumulator
Accumulator element.
Definition: epilogue.h:111
Defines layout functions used for rank=1 vectors.
Templates implementing storing of tiles from pitch-linear rank=2 tensors.
Epilogue for threadblock scoped GEMMs using Tensor Ops.
Base class for epilogues defining warp-level.
Definition: epilogue_base.h:67
WarpMmaOperator_ WarpMmaOperator
Definition: epilogue.h:94
typename Base::AccumulatorTile AccumulatorTile
The complete warp-level accumulator tile.
Definition: epilogue.h:108
Array< typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess > OutputAccessType
Array type used to output.
Definition: epilogue.h:131
static int const kElementsPerAccess
Output access size.
Definition: epilogue.h:118
typename AccumulatorFragmentIterator::AccumulatorTile AccumulatorTile
The complete warp-level accumulator tile.
Definition: epilogue_base.h:81
typename OutputTileIterator::Element ElementOutput
Output element.
Definition: epilogue.h:115
Basic include for CUTLASS.
typename cutlass::TensorRef< int, cutlass::layout::PackedVectorLayout > SyncTensorRef
Tensor reference to sync tensor.
Definition: epilogue.h:124
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
Array< typename WarpTileIterator::Element, OutputTileIterator::kElementsPerAccess > AccumulatorAccessType
Array type used by output functor.
Definition: epilogue.h:134