49 typename ThreadblockSwizzle_,
69 typename Mma::IteratorA::TensorRef
ref_A;
71 typename Mma::IteratorB::TensorRef
ref_B;
72 typename Epilogue::OutputTileIterator::Params
params_C;
73 typename Epilogue::OutputTileIterator::TensorRef
ref_C;
74 typename Epilogue::OutputTileIterator::Params
params_D;
75 typename Epilogue::OutputTileIterator::TensorRef
ref_D;
92 typename Mma::IteratorA::TensorRef ref_A,
93 typename Mma::IteratorB::TensorRef ref_B,
94 typename Epilogue::OutputTileIterator::TensorRef ref_C,
95 typename Epilogue::OutputTileIterator::TensorRef ref_D,
96 typename OutputOp::Params output_op =
typename OutputOp::Params(),
97 int *semaphore =
nullptr 99 problem_size(problem_size),
100 grid_tiled_shape(grid_tiled_shape),
101 params_A(ref_A.layout()),
103 params_B(ref_B.layout()),
105 params_C(ref_C.layout()),
107 params_D(ref_D.layout()),
109 output_op(output_op),
110 semaphore(semaphore) {
112 int total_gemm_k_iterations = (problem_size.
k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
113 int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.
k() - 1) / grid_tiled_shape.
k();
115 gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
135 typename Mma::IteratorA::TensorRef ref_A,
136 typename Mma::IteratorB::TensorRef ref_B,
137 typename Epilogue::OutputTileIterator::TensorRef ref_C,
138 typename Epilogue::OutputTileIterator::TensorRef ref_D) {
140 static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
141 static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
142 static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;
160 if ((problem_size.
m() % kAlignmentA) || (problem_size.
k() % kAlignmentA) ||
161 (problem_size.
n() % kAlignmentB) || (problem_size.
k() % kAlignmentB) ||
162 (problem_size.
m() % kAlignmentC) || (problem_size.
n() % kAlignmentC)) {
188 threadblock_tile_offset.
m() * Mma::Shape::kM,
194 threadblock_tile_offset.
n() * Mma::Shape::kN
198 int problem_size_k =
min(
200 (threadblock_tile_offset.
k() + 1) * params.
gemm_k_size);
203 int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
206 int thread_idx = threadIdx.x;
209 typename Mma::IteratorA iterator_A(
216 typename Mma::IteratorB iterator_B(
223 int warp_idx = threadIdx.x / 32;
224 int lane_idx = threadIdx.x % 32;
231 Mma mma(shared_storage.
main_loop, thread_idx, warp_idx, lane_idx);
233 typename Mma::FragmentC accumulators;
235 accumulators.clear();
237 if (!kSplitKSerial || gemm_k_iterations > 0) {
239 mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
252 threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
256 threadblock_tile_offset.
m() * Mma::Shape::kM,
257 threadblock_tile_offset.
n() * Mma::Shape::kN
260 int block_idx = threadblock_tile_offset.
m() + threadblock_tile_offset.
n() * params.
grid_tiled_shape.
m();
272 output_op.set_k_partition(threadblock_tile_offset.
k());
276 typename Epilogue::OutputTileIterator iterator_C(
285 typename Epilogue::OutputTileIterator iterator_D(
303 if (threadblock_tile_offset.
k()) {
304 iterator_C = iterator_D;
307 semaphore.
wait(threadblock_tile_offset.
k());
313 epilogue(output_op, iterator_D, accumulators, iterator_C);
329 lock = threadblock_tile_offset.
k() + 1;
Epilogue::OutputTileIterator::TensorRef ref_C
Definition: include/cutlass/gemm/kernel/gemm.h:73
Definition: aligned_buffer.h:35
Epilogue::SharedStorage epilogue
Definition: include/cutlass/gemm/kernel/gemm.h:122
Epilogue::OutputTileIterator::Params params_D
Definition: include/cutlass/gemm/kernel/gemm.h:74
Mma::IteratorA::Params params_A
Definition: include/cutlass/gemm/kernel/gemm.h:68
Epilogue_ Epilogue
Definition: include/cutlass/gemm/kernel/gemm.h:55
Mma::IteratorB::Params params_B
Definition: include/cutlass/gemm/kernel/gemm.h:70
CUTLASS_HOST_DEVICE Params(cutlass::gemm::GemmCoord const &problem_size, cutlass::gemm::GemmCoord const &grid_tiled_shape, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, typename Epilogue::OutputTileIterator::TensorRef ref_C, typename Epilogue::OutputTileIterator::TensorRef ref_D, typename OutputOp::Params output_op=typename OutputOp::Params(), int *semaphore=nullptr)
Definition: include/cutlass/gemm/kernel/gemm.h:89
Definition: include/cutlass/gemm/gemm.h:94
CUTLASS_HOST_DEVICE Coord< 2 > mn() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:171
Epilogue::OutputTileIterator::Params params_C
Definition: include/cutlass/gemm/kernel/gemm.h:72
static int const kThreadCount
Definition: include/cutlass/gemm/kernel/gemm.h:62
Defines common types used for all GEMM-like operators.
CUTLASS_DEVICE void fetch()
Permit fetching the synchronization mechanism early.
Definition: semaphore.h:68
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
cutlass::gemm::GemmCoord grid_tiled_shape
Definition: include/cutlass/gemm/kernel/gemm.h:67
int gemm_k_iterations
Definition: include/cutlass/gemm/kernel/gemm.h:78
Mma::IteratorB::TensorRef ref_B
Definition: include/cutlass/gemm/kernel/gemm.h:71
static Status can_implement(cutlass::gemm::GemmCoord const &problem_size, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, typename Epilogue::OutputTileIterator::TensorRef ref_C, typename Epilogue::OutputTileIterator::TensorRef ref_D)
Determines whether kernel satisfies alignment.
Definition: include/cutlass/gemm/kernel/gemm.h:133
CUTLASS_HOST_DEVICE Gemm()
Definition: include/cutlass/gemm/kernel/gemm.h:130
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
static bool const kSplitKSerial
Definition: include/cutlass/gemm/kernel/gemm.h:58
typename Epilogue::OutputOp OutputOp
Definition: include/cutlass/gemm/kernel/gemm.h:56
Parameters structure.
Definition: include/cutlass/gemm/kernel/gemm.h:65
OutputOp::Params output_op
Definition: include/cutlass/gemm/kernel/gemm.h:76
operands fail alignment requirements.
Shared memory storage structure.
Definition: include/cutlass/gemm/kernel/gemm.h:120
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
int gemm_k_size
Definition: include/cutlass/gemm/kernel/gemm.h:79
int * semaphore
Definition: include/cutlass/gemm/kernel/gemm.h:77
CUTLASS_DEVICE void operator()(Params const &params, SharedStorage &shared_storage)
Executes one GEMM.
Definition: include/cutlass/gemm/kernel/gemm.h:172
CTA-wide semaphore for inter-CTA synchronization.
Definition: semaphore.h:48
Implementation of a CTA-wide semaphore for inter-CTA synchronization.
Defines a canonical coordinate for rank=2 matrices offering named indices.
CUTLASS_DEVICE void release(int status=0)
Updates the lock with the given result.
Definition: semaphore.h:98
cutlass::gemm::GemmCoord problem_size
Definition: include/cutlass/gemm/kernel/gemm.h:66
ThreadblockSwizzle_ ThreadblockSwizzle
Definition: include/cutlass/gemm/kernel/gemm.h:57
Definition: include/cutlass/gemm/kernel/gemm.h:52
Mma::IteratorA::TensorRef ref_A
Definition: include/cutlass/gemm/kernel/gemm.h:69
bool TensorRef_aligned(TensorRef< Element, Layout > const &ref, int alignment)
Definition: tensor_ref.h:382
CUTLASS_DEVICE void wait(int status=0)
Waits until the semaphore is equal to the given value.
Definition: semaphore.h:81
Operation was successful.
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
Mma_ Mma
Definition: include/cutlass/gemm/kernel/gemm.h:54
typename Mma::WarpCount WarpCount
Warp count (concept: GemmShape)
Definition: include/cutlass/gemm/kernel/gemm.h:61
Basic include for CUTLASS.
Definition: matrix_coord.h:39
CUTLASS_HOST_DEVICE Params()
Definition: include/cutlass/gemm/kernel/gemm.h:86
Status
Status code returned by CUTLASS operations.
Definition: cutlass.h:39
Mma::SharedStorage main_loop
Definition: include/cutlass/gemm/kernel/gemm.h:121
Epilogue::OutputTileIterator::TensorRef ref_D
Definition: include/cutlass/gemm/kernel/gemm.h:75