47 typename ThreadblockSwizzle_
65 typename Mma::IteratorA::TensorRef
ref_A;
68 typename Mma::IteratorB::TensorRef
ref_B;
70 typename Epilogue::OutputTileIterator::Params
params_C;
71 typename Epilogue::OutputTileIterator::TensorRef
ref_C;
73 typename Epilogue::OutputTileIterator::Params
params_D;
74 typename Epilogue::OutputTileIterator::TensorRef
ref_D;
91 typename Mma::IteratorA::TensorRef ref_A_,
93 typename Mma::IteratorB::TensorRef ref_B_,
95 typename Epilogue::OutputTileIterator::TensorRef ref_C_,
97 typename Epilogue::OutputTileIterator::TensorRef ref_D_,
99 typename OutputOp::Params epilogue_,
102 problem_size(problem_size_),
103 grid_tiled_shape(grid_tiled_shape_),
104 params_A(ref_A_.layout()),
107 params_B(ref_B_.layout()),
110 params_C(ref_C_.layout()),
113 params_D(ref_D_.layout()),
117 batch_count(batch_count_),
118 gemm_k_iterations((problem_size.k() +
Mma::Shape::kK - 1) /
Mma::Shape::kK) {
154 for (
int batch_idx = threadblock_swizzle.get_batch_idx();
156 batch_idx += gridDim.z) {
160 threadblock_tile_offset.
m() * Mma::Shape::kM,
166 threadblock_tile_offset.
n() * Mma::Shape::kN
170 int thread_idx = threadIdx.x;
173 typename Mma::IteratorA iterator_A(
180 iterator_A.add_pointer_offset(params.
stride_A * batch_idx);
182 typename Mma::IteratorB iterator_B(
189 iterator_B.add_pointer_offset(params.
stride_B * batch_idx);
197 int warp_idx = threadIdx.x / 32;
198 int lane_idx = threadIdx.x % 32;
200 Mma mma(shared_storage.
main_loop, thread_idx, warp_idx, lane_idx);
202 typename Mma::FragmentC accumulators;
204 accumulators.clear();
208 mma(params.
gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
220 threadblock_tile_offset = threadblock_swizzle.get_tile_offset();
224 threadblock_tile_offset.
m() * Mma::Shape::kM,
225 threadblock_tile_offset.
n() * Mma::Shape::kN
229 typename Epilogue::OutputTileIterator iterator_C(
237 iterator_C.add_pointer_offset(params.
stride_C * batch_idx);
240 typename Epilogue::OutputTileIterator iterator_D(
248 iterator_D.add_pointer_offset(params.
stride_D * batch_idx);
257 epilogue(output_op, iterator_D, accumulators, iterator_C);
CUTLASS_DEVICE void operator()(Params const &params, SharedStorage &shared_storage)
Executes one GEMM.
Definition: kernel/gemm_batched.h:138
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE Params()
Definition: kernel/gemm_batched.h:85
typename Epilogue::OutputOp OutputOp
Definition: kernel/gemm_batched.h:53
Epilogue::OutputTileIterator::TensorRef ref_D
Definition: kernel/gemm_batched.h:74
Definition: include/cutlass/gemm/gemm.h:94
CUTLASS_HOST_DEVICE Coord< 2 > mn() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:171
Defines common types used for all GEMM-like operators.
Mma::IteratorB::TensorRef ref_B
Definition: kernel/gemm_batched.h:68
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
int gemm_k_iterations
Definition: kernel/gemm_batched.h:78
Epilogue::OutputTileIterator::TensorRef ref_C
Definition: kernel/gemm_batched.h:71
CUTLASS_HOST_DEVICE GemmBatched()
Definition: kernel/gemm_batched.h:134
Epilogue_ Epilogue
Definition: kernel/gemm_batched.h:52
Shared memory storage structure.
Definition: kernel/gemm_batched.h:124
cutlass::gemm::GemmCoord grid_tiled_shape
Definition: kernel/gemm_batched.h:63
Mma::SharedStorage main_loop
Definition: kernel/gemm_batched.h:125
static int const kThreadCount
Definition: kernel/gemm_batched.h:58
Parameters structure.
Definition: kernel/gemm_batched.h:61
typename Mma::WarpCount WarpCount
Warp count (concept: GemmShape)
Definition: kernel/gemm_batched.h:57
Epilogue::OutputTileIterator::Params params_D
Definition: kernel/gemm_batched.h:73
Epilogue::OutputTileIterator::Params params_C
Definition: kernel/gemm_batched.h:70
OutputOp::Params epilogue
Definition: kernel/gemm_batched.h:76
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
int64_t stride_C
Definition: kernel/gemm_batched.h:72
CUTLASS_HOST_DEVICE Coord< 2 > mk() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:177
Mma_ Mma
Definition: kernel/gemm_batched.h:51
cutlass::gemm::GemmCoord problem_size
Definition: kernel/gemm_batched.h:62
Mma::IteratorA::Params params_A
Definition: kernel/gemm_batched.h:64
Defines a canonical coordinate for rank=2 matrices offering named indices.
int batch_count
Definition: kernel/gemm_batched.h:77
Mma::IteratorB::Params params_B
Definition: kernel/gemm_batched.h:67
CUTLASS_HOST_DEVICE Coord< 2 > kn() const
Obtains a Coord<2> from GemmCoord.
Definition: include/cutlass/gemm/gemm.h:195
int64_t stride_B
Definition: kernel/gemm_batched.h:69
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
Mma::IteratorA::TensorRef ref_A
Definition: kernel/gemm_batched.h:65
CUTLASS_HOST_DEVICE Params(cutlass::gemm::GemmCoord const &problem_size_, cutlass::gemm::GemmCoord const &grid_tiled_shape_, typename Mma::IteratorA::TensorRef ref_A_, int64_t stride_A_, typename Mma::IteratorB::TensorRef ref_B_, int64_t stride_B_, typename Epilogue::OutputTileIterator::TensorRef ref_C_, int64_t stride_C_, typename Epilogue::OutputTileIterator::TensorRef ref_D_, int64_t stride_D_, typename OutputOp::Params epilogue_, int batch_count_)
Definition: kernel/gemm_batched.h:88
int64_t stride_A
Definition: kernel/gemm_batched.h:66
Definition: kernel/gemm_batched.h:49
Epilogue::SharedStorage epilogue
Definition: kernel/gemm_batched.h:126
int64_t stride_D
Definition: kernel/gemm_batched.h:75
ThreadblockSwizzle_ ThreadblockSwizzle
Definition: kernel/gemm_batched.h:54
Basic include for CUTLASS.
Definition: matrix_coord.h:39