52 typename ReductionOp_,
53 int PartitionsPerStage = 4
102 size_t partition_stride_,
106 typename OutputOp::Params output_ =
typename OutputOp::Params(),
107 typename ReductionOp::Params reduction_ =
typename ReductionOp::Params()
109 problem_size(problem_size_),
110 partitions(partitions_),
111 partition_stride(sizeof(
FragmentWorkspace) * partition_stride_ / kElementsPerAccess),
112 workspace(workspace_),
113 destination(destination_),
116 reduction(reduction_) {
132 (problem_size.
column() + Shape::kColumn - 1) / Shape::kColumn,
133 (problem_size.
row() + Shape::kRow -1) / Shape::kRow);
139 return dim3(Shape::kColumn / kElementsPerAccess, Shape::kRow);
148 int(blockIdx.y) * Shape::kRow + threadIdx.y,
149 int(blockIdx.x) * Shape::kColumn + threadIdx.x * kElementsPerAccess
170 char const *workspace_ptr =
171 reinterpret_cast<char const *
>(
200 accumulator = reduction_op(accumulator, workspace_frag[i]);
216 if (output_op.is_source_needed()) {
224 typename OutputOp::FragmentOutput output_frag = output_op(accumulator, source_frag);
233 *dest_ptr =
reinterpret_cast<FragmentOutput const &
>(output_frag);
typename ReductionOp::Element ElementWorkspace
Definition: reduce_split_k.h:64
typename ReductionOp::ElementAccumulator ElementAccumulator
Definition: reduce_split_k.h:65
CUTLASS_HOST_DEVICE Index const & column() const
Returns the column of the coordinate.
Definition: matrix_coord.h:85
OutputTensorRef source
Definition: reduce_split_k.h:87
Definition: aligned_buffer.h:35
OutputTensorRef destination
Definition: reduce_split_k.h:86
Defines a structure containing strides, bounds, and a pointer to tensor data.
size_t partition_stride
Definition: reduce_split_k.h:84
CUTLASS_HOST_DEVICE Element * data() const
Returns the pointer to referenced data.
Definition: tensor_ref.h:254
static CUTLASS_HOST_DEVICE dim3 block_shape()
Determines the threadblock shape.
Definition: reduce_split_k.h:138
Aligned array type.
Definition: array.h:511
typename OutputOp::ElementOutput ElementOutput
Definition: reduce_split_k.h:66
CUTLASS_HOST_DEVICE Index const & row() const
Returns the row of the coordinate.
Definition: matrix_coord.h:77
int partitions
Definition: reduce_split_k.h:83
CUTLASS_HOST_DEVICE Params(MatrixCoord problem_size_, int partitions_, size_t partition_stride_, WorkspaceTensorRef workspace_, OutputTensorRef destination_, OutputTensorRef source_, typename OutputOp::Params output_=typename OutputOp::Params(), typename ReductionOp::Params reduction_=typename ReductionOp::Params())
Definition: reduce_split_k.h:99
CUTLASS_HOST_DEVICE Params()
Definition: reduce_split_k.h:96
Params structure.
Definition: reduce_split_k.h:80
ReductionOp_ ReductionOp
Definition: reduce_split_k.h:59
CUTLASS_DEVICE void operator()(Params const &params, SharedStorage &storage)
Perform a reduction.
Definition: reduce_split_k.h:144
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Shape_ Shape
Definition: reduce_split_k.h:58
Boost-like numeric conversion operator for CUTLASS numeric types.
Defines a Shape template for matrix tiles.
WorkspaceTensorRef workspace
Definition: reduce_split_k.h:85
ReductionOp::Params reduction
Definition: reduce_split_k.h:89
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
CUTLASS_HOST_DEVICE LongIndex offset(TensorCoord const &coord) const
Computes the offset of an index from the origin of the tensor.
Definition: tensor_ref.h:301
Definition: reduce_split_k.h:55
#define CUTLASS_PRAGMA_NO_UNROLL
Definition: cutlass.h:111
static int const kPartitionsPerStage
Definition: reduce_split_k.h:62
static CUTLASS_HOST_DEVICE dim3 grid_shape(cutlass::MatrixCoord problem_size)
Computes the grid size given a chosen threadblock shape.
Definition: reduce_split_k.h:128
static int const kElementsPerAccess
Definition: reduce_split_k.h:61
Defines layout functions used by TensorRef and derived classes.
OutputOp_ OutputOp
Definition: reduce_split_k.h:60
MatrixCoord problem_size
Definition: reduce_split_k.h:82
Array< ElementAccumulator, kElementsPerAccess > FragmentAccumulator
Definition: reduce_split_k.h:72
Basic include for CUTLASS.
Definition: matrix_coord.h:39
Definition: reduce_split_k.h:121
OutputOp::Params output
Definition: reduce_split_k.h:88
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...