42 namespace threadblock {
63 Shape::kM / Operator::Shape::kM,
64 Shape::kN / Operator::Shape::kN,
69 "Direct epilogue cannot be used with when the threadblock tile is partitioned along the K dimension.");
113 typename OutputOp::Params output_op_,
114 typename ConvertOp::Params convert_op_
116 destination_ref(destination_ref_),
117 source_ref(source_ref_),
118 output_op(output_op_),
119 convert_op(convert_op_) {
128 typename OutputOp::Params output_op_
162 output_op(params.output_op),
163 convert_op(params.convert_op),
179 warp_m * Operator::Shape::kM,
180 warp_n * Operator::Shape::kN
195 MatrixCoord{tb_tile_coord.
m() * Shape::kM, tb_tile_coord.
n() * Shape::kN} + warp_origin_;
199 Operator::Shape::kM / Operator::Policy::Operator::Shape::kM,
200 Operator::Shape::kN / Operator::Policy::Operator::Shape::kN
206 int const kElementsPerAccess = Operator::Policy::Operator::Shape::kN / 4;
207 int const kRowsPerTile = 8;
208 int const kAccumulatorRows = Operator::Policy::Operator::Shape::kM / kRowsPerTile;
211 for (
int mma_n = 0; mma_n < MmaIterations::kN; ++mma_n) {
213 for (
int mma_m = 0; mma_m < MmaIterations::kM; ++mma_m) {
215 int mma_accum_start = kAccumulatorRows * kElementsPerAccess *
216 (mma_m * MmaIterations::kN + mma_n);
219 for (
int row = 0; row < kAccumulatorRows; ++row) {
221 for (
int col = 0; col < kElementsPerAccess; ++col) {
223 int accum_m = mma_m * Operator::Policy::Operator::Shape::kM + row * kRowsPerTile;
224 int accum_n = mma_n * Operator::Policy::Operator::Shape::kN + col;
225 int idx = mma_accum_start + row * kElementsPerAccess + col;
229 MatrixCoord thread_coord = thread_origin + accum_coord;
231 if (thread_coord <
MatrixCoord{problem_size.
m(), problem_size.
n()}) {
233 typename ConvertOp::result_type converted_accum =
convert_op(accumulators[idx]);
235 typename OutputOp::result_type output =
output_op(converted_accum, source_ref_.
at(accum_coord));
237 destination_ref_.
at(accum_coord) = output;
Epilogue operator.
Definition: direct_epilogue_tensor_op.h:55
static int const kM
Definition: include/cutlass/gemm/gemm.h:58
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
CUTLASS_HOST_DEVICE Params(TensorRef destination_ref_, TensorRef source_ref_, typename OutputOp::Params output_op_)
Constructs a Params object.
Definition: direct_epilogue_tensor_op.h:125
Parameters structure for host-constructible state.
Definition: direct_epilogue_tensor_op.h:92
Definition: aligned_buffer.h:35
CUTLASS_DEVICE void operator()(gemm::GemmCoord problem_size, gemm::GemmCoord tb_tile_coord, FragmentC const &accumulators)
Streams the result to global memory.
Definition: direct_epilogue_tensor_op.h:189
TensorRef destination_ref
Definition: direct_epilogue_tensor_op.h:98
Definition: include/cutlass/gemm/gemm.h:94
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
TensorRef< Element, Layout::kRank, Layout > TensorRef
Reference to source and destination tensors.
Definition: direct_epilogue_tensor_op.h:87
CUTLASS_HOST_DEVICE TensorRef & add_coord_offset(TensorCoord const &coord)
Adds an offset to each pointer.
Definition: tensor_ref.h:326
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
TensorRef source_ref
Definition: direct_epilogue_tensor_op.h:99
CUTLASS_DEVICE DirectEpilogueTensorOp(Params const &params, SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Constructor.
Definition: direct_epilogue_tensor_op.h:155
OutputOp::Params output_op
Definition: direct_epilogue_tensor_op.h:101
ConvertOp::Params convert_op
Definition: direct_epilogue_tensor_op.h:102
CUTLASS_HOST_DEVICE Params(TensorRef destination_ref_, TensorRef source_ref_, typename OutputOp::Params output_op_, typename ConvertOp::Params convert_op_)
Constructs a Params object.
Definition: direct_epilogue_tensor_op.h:110
Shared storage allocation needed by the epilogue.
Definition: direct_epilogue_tensor_op.h:139
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
typename Operator::FragmentC FragmentC
Accumulator tile is really the warp-scoped tile.
Definition: direct_epilogue_tensor_op.h:72
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Operator_ Operator
Definition: direct_epilogue_tensor_op.h:59
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:307
OutputOp_ OutputOp
Function operator computing final output.
Definition: direct_epilogue_tensor_op.h:81
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
ConvertOp_ ConvertOp
Conversion operator applied to each accumulator element before the output operator.
Definition: direct_epilogue_tensor_op.h:84
Basic include for CUTLASS.
Definition: matrix_coord.h:39
Element_ Element
Data type of output tensor.
Definition: direct_epilogue_tensor_op.h:75
Shape_ Shape
Definition: direct_epilogue_tensor_op.h:58
static int const kN
Definition: include/cutlass/gemm/gemm.h:59