#if defined(CUTLASS_ARCH_WMMA_ENABLED)

/// Structure computing a warp-level matrix multiply-accumulate using WMMA operations
template <
  typename Shape_,        ///< Size of the GEMM problem (concept: gemm::GemmShape)
  typename ElementA_,     ///< Data type of A elements
  typename LayoutA_,      ///< Layout of A matrix (concept: MatrixLayout)
  typename ElementB_,     ///< Data type of B elements
  typename LayoutB_,      ///< Layout of B matrix (concept: MatrixLayout)
  typename ElementC_,     ///< Data type of C elements
  typename LayoutC_,      ///< Layout of C matrix (concept: MatrixLayout)
  typename Policy_,       ///< Policy describing the warp-level WMMA operation
  int PartitionsK_ = 1,   ///< Number of partitions along the K dimension
  int PartitionsN_ = 1,   ///< Number of partitions along the N dimension
  typename Enable = bool  ///< Used for partial specialization
>
class MmaTensorOpWmma {
public:
  using Shape = Shape_;
  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using Policy = Policy_;

  /// Indicates class of matrix operator
  using OperatorClass = arch::OpClassTensorOp;

  /// Number of threads participating in the warp-level matrix product
  static int const kThreadCount = 32;

  /// Number of partitions along the K and N dimensions
  static int const kPartitionsK = PartitionsK_;
  static int const kPartitionsN = PartitionsN_;

  /// Iterates over the A operand in memory; FragmentA is its per-thread storage
  using IteratorA = MmaTensorOpWmmaMultiplicandTileIterator<
      MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
      Policy::OpDelta::kRow, kThreadCount, Policy>;
  using FragmentA = typename IteratorA::Fragment;

  /// Iterates over the B operand in memory; FragmentB is its per-thread storage
  using IteratorB = MmaTensorOpWmmaMultiplicandTileIterator<
      MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
      Policy::OpDelta::kRow, kThreadCount, Policy>;
  using FragmentB = typename IteratorB::Fragment;

  /// Iterates over the C (accumulator) operand; FragmentC is its per-thread storage
  using IteratorC = MmaTensorOpWmmaAccumulatorTileIterator<
      MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
      typename Policy::OpDelta, Policy>;
  using FragmentC = typename IteratorC::Fragment;

private:
  static_assert(
      !(Shape::kM % Policy::Operator::Shape::kM) &&
      !(Shape::kN % Policy::Operator::Shape::kN),
      "Shape of warp-level Wmma must be divisible by operator shape (wmma native size)");

  /// Number of WMMA operations performed per warp-level multiply-accumulate
  using WmmaIterations = MatrixShape<
      Shape::kM / Policy::Operator::Shape::kM,
      (Shape::kN / Policy::Operator::Shape::kN / kPartitionsN > 0) ?
          Shape::kN / Policy::Operator::Shape::kN / kPartitionsN : 1>;

public:
  /// Underlying WMMA operator
  typename Policy::Operator wmma;

  /// Computes the warp-level matrix product-accumulate D = A * B + C
  CUTLASS_DEVICE
  void operator()(
      FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C,
      int const &partitionN_idx = 0) const {

    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < WmmaIterations::kColumn; ++n) {
      CUTLASS_PRAGMA_UNROLL
      for (int m = 0; m < WmmaIterations::kRow; ++m) {
        // Each call accumulates one native-size WMMA tile of the output
        wmma(D[m * WmmaIterations::kColumn + n], A[m], B[n],
             C[m * WmmaIterations::kColumn + n]);
      }
    }
  }
};

#endif // defined(CUTLASS_ARCH_WMMA_ENABLED)
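For context, a hedged sketch of how this warp-level operator might be instantiated and invoked is shown below. It is not part of the header above: the include paths, the 16x16x16 native WMMA shape with half_t inputs and float accumulators, the 64x64x16 warp tile, and the warp_mma_step helper are assumptions chosen for illustration, and the operand fragments are presumed to have been populated by IteratorA and IteratorB beforehand.

// Illustrative only: include paths and tile sizes are assumptions, not taken
// from the listing above.
#include "cutlass/arch/wmma.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
#include "cutlass/gemm/warp/mma_tensor_op_wmma.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

// Native WMMA instruction: 16x16x16, half_t inputs, float accumulators
using WmmaInstruction = cutlass::arch::Wmma<
    cutlass::gemm::GemmShape<16, 16, 16>,
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    float, cutlass::layout::RowMajor,
    cutlass::arch::OpMultiplyAdd>;

// Policy pairs the native operator with the delta between adjacent operations
using WarpPolicy = cutlass::gemm::warp::MmaTensorOpPolicy<
    WmmaInstruction, cutlass::MatrixShape<1, 1>>;

// Warp-level operator covering a 64x64x16 tile, i.e. 4x4 WMMA iterations,
// which satisfies the divisibility static_assert above
using WarpMma = cutlass::gemm::warp::MmaTensorOpWmma<
    cutlass::gemm::GemmShape<64, 64, 16>,
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    float, cutlass::layout::RowMajor,
    WarpPolicy>;

// Hypothetical helper: fragments are assumed to have been loaded by the
// iterators already; the accumulator is updated in place.
__device__ void warp_mma_step(
    WarpMma::FragmentC &accum,
    WarpMma::FragmentA const &frag_A,
    WarpMma::FragmentB const &frag_B) {

  WarpMma mma;
  // D = A * B + C, with C aliased to the running accumulator
  mma(accum, frag_A, frag_B, accum);
}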