47 typename ThreadblockSwizzle_
67 typename Mma::IteratorA::TensorRef
ref_A;
69 typename Mma::IteratorB::TensorRef
ref_B;
70 typename Epilogue::OutputTileIterator::Params
params_D;
71 typename Epilogue::OutputTileIterator::TensorRef
ref_D;
87 typename Mma::IteratorA::TensorRef ref_A,
88 typename Mma::IteratorB::TensorRef ref_B,
89 typename Epilogue::OutputTileIterator::TensorRef ref_D,
90 typename OutputOp::Params output_op,
91 int64_t splitk_slice_stride
93 problem_size(problem_size),
94 grid_tiled_shape(grid_tiled_shape),
95 params_A(ref_A.layout()),
97 params_B(ref_B.layout()),
99 params_D(ref_D.layout()),
101 output_op(output_op),
102 splitk_slice_stride(splitk_slice_stride) {
104 int full_gemm_k_iterations = problem_size.
k() / Mma::Shape::kK;
105 int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.
k();
107 gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
142 threadblock_tile_offset.
m() * Mma::Shape::kM,
148 threadblock_tile_offset.
n() * Mma::Shape::kN
157 problem_size_k = (threadblock_tile_offset.
k() + 1) * params.
gemm_k_size;
161 int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
164 int thread_idx = threadIdx.x;
167 typename Mma::IteratorA iterator_A(
174 typename Mma::IteratorB iterator_B(
181 int warp_idx = threadIdx.x / 32;
182 int lane_idx = threadIdx.x % 32;
190 Mma mma(shared_storage.
main_loop, thread_idx, warp_idx, lane_idx);
192 typename Mma::FragmentC accumulators;
194 accumulators.clear();
196 mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
208 threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
212 threadblock_tile_offset.
m() * Mma::Shape::kM,
213 threadblock_tile_offset.
n() * Mma::Shape::kN
217 typename Epilogue::OutputTileIterator iterator_D(
235 epilogue(output_op, iterator_D, accumulators, iterator_D);
CUTLASS_DEVICE void operator()(Params const &params, SharedStorage &shared_storage)
Executes one GEMM.
Definition: kernel/gemm_splitk_parallel.h:126
CUTLASS_HOST_DEVICE GemmSplitKParallel()
Definition: kernel/gemm_splitk_parallel.h:122
Definition: aligned_buffer.h:35
Epilogue_ Epilogue
Definition: kernel/gemm_splitk_parallel.h:52
cutlass::gemm::GemmCoord problem_size
Definition: kernel/gemm_splitk_parallel.h:64
Shared memory storage structure.
Definition: kernel/gemm_splitk_parallel.h:112
Epilogue::SharedStorage epilogue
Definition: kernel/gemm_splitk_parallel.h:114
Definition: include/cutlass/gemm/gemm.h:94
CUTLASS_HOST_DEVICE Coord< 2 > mn() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:171
cutlass::gemm::GemmCoord grid_tiled_shape
Definition: kernel/gemm_splitk_parallel.h:65
static int const kThreadCount
Definition: kernel/gemm_splitk_parallel.h:58
Mma::SharedStorage main_loop
Definition: kernel/gemm_splitk_parallel.h:113
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
Parameters structure.
Definition: kernel/gemm_splitk_parallel.h:63
typename Mma::WarpCount WarpCount
Warp count (concept: GemmShape)
Definition: kernel/gemm_splitk_parallel.h:57
ThreadblockSwizzle_ ThreadblockSwizzle
Definition: kernel/gemm_splitk_parallel.h:54
CUTLASS_HOST_DEVICE Params(cutlass::gemm::GemmCoord const &problem_size, cutlass::gemm::GemmCoord const &grid_tiled_shape, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, typename Epilogue::OutputTileIterator::TensorRef ref_D, typename OutputOp::Params output_op, int64_t splitk_slice_stride)
Definition: kernel/gemm_splitk_parallel.h:84
OutputOp::Params output_op
Definition: kernel/gemm_splitk_parallel.h:72
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
Mma::IteratorA::TensorRef ref_A
Definition: kernel/gemm_splitk_parallel.h:67
Mma::IteratorB::TensorRef ref_B
Definition: kernel/gemm_splitk_parallel.h:69
int gemm_k_size
Definition: kernel/gemm_splitk_parallel.h:74
Epilogue::OutputTileIterator::TensorRef ref_D
Definition: kernel/gemm_splitk_parallel.h:71
CUTLASS_HOST_DEVICE Params()
Definition: kernel/gemm_splitk_parallel.h:81
Epilogue::OutputTileIterator::Params params_D
Definition: kernel/gemm_splitk_parallel.h:70
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Mma::IteratorA::Params params_A
Definition: kernel/gemm_splitk_parallel.h:66
static int const kAlignmentK
Definition: kernel/gemm_splitk_parallel.h:60
Defines a canonical coordinate for rank=2 matrices offering named indices.
Definition: kernel/gemm_splitk_parallel.h:49
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
Mma_ Mma
Definition: kernel/gemm_splitk_parallel.h:51
Mma::IteratorB::Params params_B
Definition: kernel/gemm_splitk_parallel.h:68
int64_t splitk_slice_stride
Definition: kernel/gemm_splitk_parallel.h:73
Basic include for CUTLASS.
Definition: matrix_coord.h:39
typename Epilogue::OutputOp OutputOp
Definition: kernel/gemm_splitk_parallel.h:53