42 namespace threadblock {
51 typename SmemPaddingA_,
53 typename SmemPaddingB_,
82 typename Enable =
bool>
104 Shape::kN / WarpGemm::kN,
105 Shape::kK / WarpGemm::kK>;
108 static int const kWarpGemmIterations =
109 (WarpGemm::kK / Operator::Policy::MmaShape::kK);
112 static int const kStages = Stages;
133 Shape::kK * kStages +
134 Policy::SmemPaddingA::kColumn>;
138 MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
139 Shape::kN + Policy::SmemPaddingB::kColumn>;
161 return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
167 return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
201 SharedStorage &shared_storage,
209 warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
210 warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {
Policy_ Policy
Definition: mma_base.h:89
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
Definition: aligned_buffer.h:35
Architecture-specific operators on memory.
AlignedBuffer< typename Operator::ElementB, ShapeB::kCount > operand_B
Buffer for B operand.
Definition: mma_base.h:150
Operator::IteratorB warp_tile_iterator_B_
Iterator to load a warp-scoped tile of B operand from shared memory.
Definition: mma_base.h:193
typename Policy::Operator::Shape WarpGemm
Definition: mma_base.h:100
Defines common types used for all GEMM-like operators.
Shared storage object needed by threadblock-scoped GEMM.
Definition: mma_base.h:125
Shape_ Shape
Policy describing tuning details.
Definition: mma_base.h:88
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
Operator_ Operator
Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt)
Definition: mma_base.h:58
SmemPaddingA_ SmemPaddingA
Padding used for A operand in shared memory.
Definition: mma_base.h:61
Defines a Shape template for matrix tiles.
static CUTLASS_HOST_DEVICE Operator::LayoutB LayoutB()
Returns a layout object for the B matrix.
Definition: mma_base.h:166
Policy object describing MmaTensorOp.
Definition: mma_base.h:56
Definition: tensor_ref.h:146
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared memory.
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
Modifies semantics of cutlass::Array<> to provide guaranteed alignment.
Definition: aligned_buffer.h:45
CUTLASS_HOST_DEVICE TensorRefA operand_A_ref()
Returns a TensorRef to the A operand.
Definition: mma_base.h:172
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
CUTLASS_HOST_DEVICE pointer data()
Definition: aligned_buffer.h:84
CUTLASS_DEVICE MmaBase(SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Construct from tensor references.
Definition: mma_base.h:199
Definition: mma_base.h:83
typename Policy::Operator Operator
Warp-level Mma.
Definition: mma_base.h:96
Operator::IteratorA warp_tile_iterator_A_
Iterator to load a warp-scoped tile of A operand from shared memory.
Definition: mma_base.h:190
AlignedBuffer< typename Operator::ElementA, ShapeA::kCount > operand_A
Buffer for A operand.
Definition: mma_base.h:147
static CUTLASS_DEVICE Operator::LayoutA LayoutA()
Returns a layout object for the A matrix.
Definition: mma_base.h:160
SmemPaddingB_ SmemPaddingB
Padding used for B operand in shared memory.
Definition: mma_base.h:64
static int const kPartitionsK
Number of partitions of K dimension.
Definition: mma_base.h:67
CUTLASS_HOST_DEVICE TensorRefB operand_B_ref()
Returns a TensorRef to the B operand.
Definition: mma_base.h:178
Basic include for CUTLASS.