40 #if defined(CUTLASS_ARCH_WMMA_ENABLED) 42 #endif //CUTLASS_ARCH_WMMA_ENABLED 48 namespace threadblock {
66 typename ElementAccumulator_,
70 typename OperatorClass_,
74 typename ThreadblockShape_,
78 typename InstructionShape_,
85 bool AccumulatorsInRowMajor =
false 106 typename ElementAccumulator,
110 typename ThreadblockShape,
114 typename InstructionShape,
117 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
118 kAlignmentB, ElementAccumulator, layout::RowMajor,
119 arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape,
120 InstructionShape, 2, Operator, false> {
123 ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
125 arch::OpClassSimt, 2, Operator>;
131 ElementA, LayoutA, 1,
typename MmaCore::IteratorThreadMapA, kAlignmentA>;
137 ElementB, LayoutB, 0,
typename MmaCore::IteratorThreadMapB, kAlignmentB>;
141 typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
142 IteratorB,
typename MmaCore::SmemIteratorB, ElementAccumulator,
143 layout::RowMajor,
typename MmaCore::MmaPolicy>;
162 typename ElementAccumulator,
166 typename ThreadblockShape,
170 typename InstructionShape,
174 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
175 kAlignmentB, ElementAccumulator, layout::RowMajor,
176 arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape,
177 InstructionShape, 2, Operator, false> {
180 ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
182 arch::OpClassTensorOp, 2, Operator>;
188 ElementA, LayoutA, 1,
typename MmaCore::IteratorThreadMapA, kAlignmentA>;
194 ElementB, LayoutB, 0,
typename MmaCore::IteratorThreadMapB, kAlignmentB>;
198 typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
199 IteratorB,
typename MmaCore::SmemIteratorB, ElementAccumulator,
200 layout::RowMajor,
typename MmaCore::MmaPolicy>;
219 typename ElementAccumulator,
221 typename OperatorClass,
225 typename ThreadblockShape,
229 typename InstructionShape,
234 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
235 kAlignmentB, ElementAccumulator,
236 layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass,
237 ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
241 ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
242 ElementB, LayoutB, ElementAccumulator,
247 "Alignment must match thread data map's vector length");
250 "Alignment must match thread data map's vector length");
255 LayoutA, 1,
typename MmaCore::IteratorThreadMapA>;
260 LayoutB, 0,
typename MmaCore::IteratorThreadMapB>;
264 typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
265 IteratorB,
typename MmaCore::SmemIteratorB, ElementAccumulator,
266 layout::ColumnMajorInterleaved<InterleavedK>,
267 typename MmaCore::MmaPolicy>;
283 typename ElementAccumulator,
287 typename ThreadblockShape,
292 struct DefaultMma<int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB,
293 ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
294 ArchTag, ThreadblockShape, WarpShape,
GemmShape<1, 1, 4>, 2,
298 using ElementB = int8_t;
299 using OperatorClass = arch::OpClassSimt;
306 ThreadblockShape, WarpShape, InstructionShape,
ElementA, LayoutA,
308 OperatorClass, 2, Operator>;
314 ElementA, LayoutA, 1,
typename MmaCore::IteratorThreadMapA, transposeA>;
320 ElementB, LayoutB, 0,
typename MmaCore::IteratorThreadMapB, transposeB>;
324 typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
325 IteratorB,
typename MmaCore::SmemIteratorB, ElementAccumulator,
326 layout::RowMajor,
typename MmaCore::MmaPolicy>;
329 #if defined(CUTLASS_ARCH_WMMA_ENABLED) 345 typename ElementAccumulator,
351 typename ThreadblockShape,
355 typename InstructionShape,
359 kAlignmentB, ElementAccumulator, LayoutC,
360 arch::OpClassWmmaTensorOp, ArchTag, ThreadblockShape, WarpShape,
361 InstructionShape, 2, Operator> {
364 ThreadblockShape, WarpShape, InstructionShape,
ElementA, LayoutA,
365 ElementB, LayoutB, ElementAccumulator, LayoutC,
366 arch::OpClassWmmaTensorOp, 2, Operator>;
372 ElementA, LayoutA, 1,
typename MmaCore::IteratorThreadMapA, kAlignmentA>;
378 ElementB, LayoutB, 0,
typename MmaCore::IteratorThreadMapB, kAlignmentB>;
382 typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
383 IteratorB,
typename MmaCore::SmemIteratorB, ElementAccumulator,
384 LayoutC,
typename MmaCore::MmaPolicy>;
402 typename ElementAccumulator,
408 typename ThreadblockShape,
412 typename InstructionShape,
415 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
416 kAlignmentB, ElementAccumulator, LayoutC,
417 arch::OpClassWmmaTensorOp, ArchTag, ThreadblockShape, WarpShape,
418 InstructionShape, 1, Operator> {
421 ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
422 ElementB, LayoutB, ElementAccumulator, LayoutC,
423 arch::OpClassWmmaTensorOp, 1, Operator>;
429 ElementA, LayoutA, 1,
typename MmaCore::IteratorThreadMapA, kAlignmentA>;
435 ElementB, LayoutB, 0,
typename MmaCore::IteratorThreadMapB, kAlignmentB>;
439 typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
440 IteratorB,
typename MmaCore::SmemIteratorB, ElementAccumulator,
441 LayoutC,
typename MmaCore::MmaPolicy>;
444 #endif //CUTLASS_ARCH_WMMA_ENABLED Describes the size of a matrix tile.
Definition: matrix_shape.h:42
Definition: aligned_buffer.h:35
cutlass::gemm::threadblock::DefaultMma< int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, GemmShape< 1, 1, 4 >, 2, Operator, false >::MmaCore typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, 2, Operator > MmaCore
Definition: default_mma.h:308
Definition: default_mma_core.h:90
cutlass::gemm::threadblock::DefaultMma< int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, GemmShape< 1, 1, 4 >, 2, Operator, false >::ElementA int8_t ElementA
Definition: default_mma.h:297
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
Definition: mma_pipelined.h:86
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
Definition: mma_singlestage.h:76
cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false >::MmaCore typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, 2, Operator > MmaCore
Definition: default_mma.h:125
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
Definition: default_mma.h:87
Defines the size of an element in bits.
Definition: numeric_types.h:42
cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false >::MmaCore typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 2, Operator > MmaCore
Definition: default_mma.h:182
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Defines tags for architecture-specific configurations.
Definition: layout/matrix.h:343
cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::ColumnMajorInterleaved< InterleavedK >, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, true >::MmaCore typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::ColumnMajorInterleaved< InterleavedK >, OperatorClass, 2, Operator, true > MmaCore
Definition: default_mma.h:244
Templates exposing architecture support for warp matrix multiply-add (WMMA) operations.
Basic include for CUTLASS.
Templates implementing loading of tiles from pitch-linear rank=2 tensors.