47 namespace threadblock {
60 typename SmemIteratorA_,
66 typename SmemIteratorB_,
74 typename Enable =
bool 112 using WarpFragmentA =
typename Operator::FragmentA;
113 using WarpFragmentB =
typename Operator::FragmentB;
128 typename Base::SharedStorage &shared_storage,
133 Base(shared_storage, thread_idx, warp_idx, lane_idx),
134 smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
135 smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
158 int gemm_k_iterations,
179 iterator_A.load(tb_frag_A);
180 iterator_B.load(tb_frag_B);
186 WarpFragmentA warp_frag_A[2];
187 WarpFragmentB warp_frag_B[2];
191 if (gemm_k_iterations <= 1) {
192 iterator_A.clear_mask();
193 iterator_B.clear_mask();
202 for (; gemm_k_iterations > 0; --gemm_k_iterations) {
203 this->smem_iterator_A_.store(tb_frag_A);
204 this->smem_iterator_B_.store(tb_frag_B);
228 warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
237 iterator_A.load(tb_frag_A);
238 iterator_B.load(tb_frag_B);
244 if (gemm_k_iterations <= 2) {
245 iterator_A.clear_mask();
246 iterator_B.clear_mask();
static int const kM
Definition: include/cutlass/gemm/gemm.h:58
IteratorB_ IteratorB
Iterates over tiles of B operand in global memory.
Definition: mma_singlestage.h:84
Definition: aligned_buffer.h:35
ElementC_ ElementC
Data type of accumulator matrix.
Definition: mma_singlestage.h:85
Operator::IteratorB warp_tile_iterator_B_
Iterator to load a warp-scoped tile of B operand from shared memory.
Definition: mma_base.h:193
SmemIteratorB smem_iterator_B_
Iterator to write threadblock-scoped tile of B operand to shared memory.
Definition: mma_singlestage.h:121
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
Definition: mma_singlestage.h:76
Defines common types used for all GEMM-like operators.
CUTLASS_DEVICE void operator()(int gemm_k_iterations, FragmentC &accum, IteratorA iterator_A, IteratorB iterator_B, FragmentC const &src_accum)
Perform a threadblock-scoped matrix multiply-accumulate.
Definition: mma_singlestage.h:157
Policy_ Policy
Policy describing tuning details.
Definition: mma_singlestage.h:87
typename IteratorA::Fragment FragmentA
Fragment of operand A loaded from global memory.
Definition: mma_singlestage.h:97
LayoutC_ LayoutC
Layout of accumulator matrix.
Definition: mma_singlestage.h:86
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Shape_ Shape
Size of the Gemm problem - concept: gemm::GemmShape<>
Definition: mma_singlestage.h:82
CUTLASS_DEVICE MmaSingleStage(typename Base::SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Construct from tensor references.
Definition: mma_singlestage.h:127
Defines a Shape template for matrix tiles.
SmemIteratorA smem_iterator_A_
Iterator to write threadblock-scoped tile of A operand to shared memory.
Definition: mma_singlestage.h:118
static int const kWarpGemmIterations
Number of warp-level GEMM operations.
Definition: mma_base.h:108
Template for a double-buffered threadblock-scoped GEMM kernel.
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared memory.
static int const kStages
Number of stages.
Definition: mma_base.h:112
Top-level include for all CUTLASS numeric types.
SmemIteratorB_ SmemIteratorB
Definition: mma_singlestage.h:90
typename Policy::Operator::FragmentC FragmentC
Fragment of accumulator tile.
Definition: mma_singlestage.h:103
Definition: mma_base.h:83
typename Policy::Operator Operator
Warp-level Mma.
Definition: mma_singlestage.h:106
Operator::IteratorA warp_tile_iterator_A_
Iterator to load a warp-scoped tile of A operand from shared memory.
Definition: mma_base.h:190
#define CUTLASS_GEMM_LOOP
Definition: cutlass.h:112
SmemIteratorA_ SmemIteratorA
Definition: mma_singlestage.h:89
typename IteratorB::Fragment FragmentB
Fragment of operand B loaded from global memory.
Definition: mma_singlestage.h:100
IteratorA_ IteratorA
Iterates over tiles of A operand in global memory.
Definition: mma_singlestage.h:83
Basic include for CUTLASS.
static int const kN
Definition: include/cutlass/gemm/gemm.h:59