46 namespace threadblock {
59 typename SmemIteratorA_,
65 typename SmemIteratorB_,
73 typename TransformA_ = NumericArrayConverter<
74 typename SmemIteratorA_::Element,
75 typename IteratorA_::Element,
76 IteratorA_::Fragment::kElements>,
79 typename TransformB_ = NumericArrayConverter<
80 typename SmemIteratorB_::Element,
81 typename IteratorB_::Element,
82 IteratorB_::Fragment::kElements>,
84 typename Enable =
bool 126 using WarpFragmentA =
typename Operator::FragmentA;
127 using WarpFragmentB =
typename Operator::FragmentB;
142 typename Base::SharedStorage &shared_storage,
147 Base(shared_storage, thread_idx, warp_idx, lane_idx),
148 smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
149 smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
171 int gemm_k_iterations,
193 iterator_A.load(tb_frag_A);
194 iterator_B.load(tb_frag_B);
199 this->smem_iterator_A_.store(transform_A(tb_frag_A));
200 this->smem_iterator_B_.store(transform_B(tb_frag_B));
208 WarpFragmentA warp_frag_A[2];
209 WarpFragmentB warp_frag_B[2];
222 int smem_write_stage_idx = 1;
225 if (gemm_k_iterations <= 1) {
226 iterator_A.clear_mask();
227 iterator_B.clear_mask();
239 for (; gemm_k_iterations > 0; --gemm_k_iterations) {
250 if (warp_mma_k == Base::kWarpGemmIterations - 1) {
253 this->smem_iterator_A_.store(transform_A(tb_frag_A));
255 this->smem_iterator_B_.store(transform_B(tb_frag_B));
263 if (smem_write_stage_idx == 1) {
269 {0, -
Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
275 smem_write_stage_idx ^= 1;
287 if (warp_mma_k == 0) {
289 iterator_A.load(tb_frag_A);
290 iterator_B.load(tb_frag_B);
296 if (gemm_k_iterations <= 2) {
297 iterator_A.clear_mask();
298 iterator_B.clear_mask();
302 warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
static int const kM
Definition: include/cutlass/gemm/gemm.h:58
LayoutC_ LayoutC
Layout of accumulator matrix.
Definition: mma_pipelined.h:96
TransformB_ TransformB
Definition: mma_pipelined.h:103
Definition: aligned_buffer.h:35
Policy_ Policy
Policy describing tuning details.
Definition: mma_pipelined.h:97
Operator::IteratorB warp_tile_iterator_B_
Iterator to load a warp-scoped tile of B operand from shared memory.
Definition: mma_base.h:193
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
Definition: mma_pipelined.h:86
IteratorB_ IteratorB
Iterates over tiles of B operand in global memory.
Definition: mma_pipelined.h:94
Defines common types used for all GEMM-like operators.
CUTLASS_DEVICE void operator()(int gemm_k_iterations, FragmentC &accum, IteratorA iterator_A, IteratorB iterator_B, FragmentC const &src_accum, TransformA transform_A=TransformA(), TransformB transform_B=TransformB())
Perform a threadblock-scoped matrix multiply-accumulate.
Definition: mma_pipelined.h:170
IteratorA_ IteratorA
Iterates over tiles of A operand in global memory.
Definition: mma_pipelined.h:93
typename IteratorB::Fragment FragmentB
Fragment of operand B loaded from global memory.
Definition: mma_pipelined.h:113
SmemIteratorA_ SmemIteratorA
Definition: mma_pipelined.h:99
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Boost-like numeric conversion operator for CUTLASS numeric types.
Defines a Shape template for matrix tiles.
static int const kWarpGemmIterations
Number of warp-level GEMM operations.
Definition: mma_base.h:108
Template for a double-buffered threadblock-scoped GEMM kernel.
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared memory.
Shape_ Shape
Size of the Gemm problem - concept: gemm::GemmShape<>
Definition: mma_pipelined.h:92
static int const kStages
Number of stages.
Definition: mma_base.h:112
typename IteratorA::Fragment FragmentA
Fragment of operand A loaded from global memory.
Definition: mma_pipelined.h:110
Top-level include for all CUTLASS numeric types.
Definition: mma_base.h:83
typename Policy::Operator::FragmentC FragmentC
Fragment of accumulator tile.
Definition: mma_pipelined.h:116
Operator::IteratorA warp_tile_iterator_A_
Iterator to load a warp-scoped tile of A operand from shared memory.
Definition: mma_base.h:190
#define CUTLASS_GEMM_LOOP
Definition: cutlass.h:112
ElementC_ ElementC
Data type of accumulator matrix.
Definition: mma_pipelined.h:95
SmemIteratorA smem_iterator_A_
Iterator to write threadblock-scoped tile of A operand to shared memory.
Definition: mma_pipelined.h:132
SmemIteratorB_ SmemIteratorB
Definition: mma_pipelined.h:100
SmemIteratorB smem_iterator_B_
Iterator to write threadblock-scoped tile of B operand to shared memory.
Definition: mma_pipelined.h:135
CUTLASS_DEVICE MmaPipelined(typename Base::SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Construct from tensor references.
Definition: mma_pipelined.h:141
Basic include for CUTLASS.
TransformA_ TransformA
Definition: mma_pipelined.h:102
typename Policy::Operator Operator
Warp-level Mma.
Definition: mma_pipelined.h:119
static int const kN
Definition: include/cutlass/gemm/gemm.h:59