61 #if defined(CUTLASS_ARCH_WMMA_ENABLED) 63 #endif //CUTLASS_ARCH_WMMA_ENABLED 92 typename ElementAccumulator,
94 typename OperatorClass,
98 typename ThreadblockShape,
102 typename InstructionShape,
104 typename EpilogueOutputOp,
106 typename ThreadblockSwizzle,
115 bool IsBetaZero =
false>
136 typename ElementAccumulator,
138 typename ThreadblockShape,
142 typename InstructionShape,
144 typename EpilogueOutputOp,
146 typename ThreadblockSwizzle,
153 ElementA, LayoutA, kAlignmentA,
154 ElementB, LayoutB, kAlignmentB,
155 ElementC, layout::RowMajor,
157 arch::OpClassTensorOp,
179 arch::OpClassTensorOp,
188 static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
193 typename Mma::Operator,
196 EpilogueOutputOp::kCount
217 typename ThreadblockShape,
221 typename InstructionShape,
223 typename EpilogueOutputOp,
225 typename ThreadblockSwizzle,
235 struct DefaultGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
236 kAlignmentA, ElementB,
239 int32_t, arch::OpClassTensorOp,
arch::Sm75, ThreadblockShape,
240 WarpShape, InstructionShape, EpilogueOutputOp,
241 ThreadblockSwizzle, 2, SplitKSerial, Operator, IsBetaZero> {
246 using ElementAccumulator = int32_t;
250 ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator,
LayoutC,
251 arch::OpClassTensorOp,
arch::Sm75, ThreadblockShape, WarpShape,
252 InstructionShape, 2, Operator,
true>::ThreadblockMma;
254 static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
259 ThreadblockShape,
typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
287 typename ElementAccumulator,
289 typename ThreadblockShape,
293 typename EpilogueOutputOp,
295 typename ThreadblockSwizzle,
302 ElementA, LayoutA, kAlignmentA,
303 ElementB, LayoutB, kAlignmentB,
304 ElementC, layout::RowMajor,
306 arch::OpClassTensorOp,
328 arch::OpClassTensorOp,
337 static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
342 typename Mma::Operator,
345 EpilogueOutputOp::kCount
371 typename ElementAccumulator,
375 typename ThreadblockShape,
379 typename EpilogueOutputOp,
381 typename ThreadblockSwizzle,
423 Operator>::ThreadblockMma;
425 static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount;
426 static_assert(kEpilogueElementsPerAccess == 1,
"simt epilogue must operate on scalars");
431 typename Mma::Operator,
433 kEpilogueElementsPerAccess
461 typename ElementAccumulator,
463 typename ThreadblockShape,
467 typename EpilogueOutputOp,
469 typename ThreadblockSwizzle,
475 struct DefaultGemm<int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB,
476 ElementC, LayoutC, ElementAccumulator, arch::OpClassSimt,
477 ArchTag, ThreadblockShape, WarpShape,
GemmShape<1, 1, 4>,
478 EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial,
482 using ElementB = int8_t;
484 using OperatorClass = arch::OpClassSimt;
504 static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount;
505 static_assert(kEpilogueElementsPerAccess == 1,
"simt epilogue must operate on scalars");
510 typename Mma::Operator,
512 kEpilogueElementsPerAccess
520 #if defined(CUTLASS_ARCH_WMMA_ENABLED) 541 typename ElementAccumulator,
545 typename ThreadblockShape,
549 typename InstructionShape,
551 typename EpilogueOutputOp,
553 typename ThreadblockSwizzle,
563 ElementB, LayoutB, kAlignmentB,
566 arch::OpClassWmmaTensorOp,
568 ThreadblockShape, WarpShape, InstructionShape,
577 ElementB, LayoutB, kAlignmentB,
578 ElementAccumulator, LayoutC,
579 arch::OpClassWmmaTensorOp,
585 Operator>::ThreadblockMma;
587 static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
592 typename Mma::Operator,
595 EpilogueOutputOp::kCount
602 #endif //CUTLASS_ARCH_WMMA_ENABLED cutlass::gemm::kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, arch::Sm70, ThreadblockShape, WarpShape, GemmShape< 8, 8, 4 >, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator >::Mma typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm70, ThreadblockShape, WarpShape, GemmShape< 8, 8, 4 >, 2, Operator >::ThreadblockMma Mma
Define the threadblock-scoped matrix multiply-accumulate.
Definition: default_gemm.h:335
Definition: default_gemm.h:116
cutlass::gemm::kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator >::Mma typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape, 2, Operator >::ThreadblockMma Mma
Define the threadblock-scoped matrix multiply-accumulate.
Definition: default_gemm.h:186
Definition: aligned_buffer.h:35
Defines sensible defaults for epilogues for SimtOps.
Definition: default_epilogue_simt.h:70
Epilogue for threadblock scoped GEMMs using Tensor Ops.
Defines common types used for all GEMM-like operators.
cutlass::gemm::kernel::DefaultGemm< ElementA, layout::ColumnMajorInterleaved< InterleavedK >, kAlignmentA, ElementB, layout::RowMajorInterleaved< InterleavedK >, kAlignmentB, ElementC, layout::ColumnMajorInterleaved< InterleavedK >, int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator, IsBetaZero >::Mma typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, true >::ThreadblockMma Mma
Define the threadblock-scoped matrix multiply-accumulate.
Definition: default_gemm.h:252
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
cutlass::gemm::kernel::DefaultGemm< int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB, ElementC, LayoutC, ElementAccumulator, arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, GemmShape< 1, 1, 4 >, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator, false >::Mma typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, arch::OpClassSimt, arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false >::ThreadblockMma Mma
Define the threadblock-scoped matrix multiply-accumulate.
Definition: default_gemm.h:502
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
Definition: default_mma.h:87
Functor performing linear combination operations used by epilogues.
Defines the size of an element in bits.
Definition: numeric_types.h:42
cutlass::gemm::kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, arch::Sm70, ThreadblockShape, WarpShape, GemmShape< 8, 8, 4 >, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator >::Epilogue typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount >::Epilogue Epilogue
Define the epilogue.
Definition: default_gemm.h:346
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Epilogue for threadblock scoped GEMMs using Tensor Ops.
cutlass::gemm::kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, GemmShape< 1, 1, 1 >, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator >::Epilogue typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< ThreadblockShape, typename Mma::Operator, EpilogueOutputOp, kEpilogueElementsPerAccess >::Epilogue Epilogue
Define the epilogue.
Definition: default_gemm.h:434
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
Defines sensible defaults for epilogues for TensorOps.
Definition: default_epilogue_volta_tensor_op.h:71
cutlass::gemm::kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator >::Epilogue typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount >::Epilogue Epilogue
Define the epilogue.
Definition: default_gemm.h:197
cutlass::gemm::kernel::DefaultGemm< ElementA, layout::ColumnMajorInterleaved< InterleavedK >, kAlignmentA, ElementB, layout::RowMajorInterleaved< InterleavedK >, kAlignmentB, ElementC, layout::ColumnMajorInterleaved< InterleavedK >, int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator, IsBetaZero >::Epilogue typename cutlass::epilogue::threadblock::DefaultInterleavedEpilogueTensorOp< ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, 64/sizeof_bits< ElementC >::value, InterleavedK, IsBetaZero >::Epilogue Epilogue
Define the epilogue.
Definition: default_gemm.h:261
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Epilogue for threadblock scoped GEMMs using Tensor Ops.
Definition: include/cutlass/gemm/kernel/gemm.h:52
Defines layout functions used by TensorRef and derived classes.
Defines sensible defaults for epilogues for WMMA TensorOps.
Definition: default_epilogue_wmma_tensor_op.h:71
Definition: default_epilogue_tensor_op.h:147
Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems...
Definition: layout/matrix.h:343
cutlass::gemm::kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, layout::RowMajor, ElementAccumulator, arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, GemmShape< 1, 1, 1 >, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator >::Mma typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, arch::Sm50, ThreadblockShape, WarpShape, GemmShape< 1, 1, 1 >, 2, Operator >::ThreadblockMma Mma
Define the threadblock-scoped matrix multiply-accumulate.
Definition: default_gemm.h:423
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
cutlass::gemm::kernel::DefaultGemm< int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB, ElementC, LayoutC, ElementAccumulator, arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, GemmShape< 1, 1, 4 >, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator, false >::ElementA int8_t ElementA
Definition: default_gemm.h:481
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
Templates exposing architecture support for warp matrix multiply-add (WMMA) operations.
Defines sensible defaults for epilogues for TensorOps.
Definition: default_epilogue_tensor_op.h:72
cutlass::gemm::kernel::DefaultGemm< int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB, ElementC, LayoutC, ElementAccumulator, arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape, GemmShape< 1, 1, 4 >, EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial, Operator, false >::Epilogue typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< ThreadblockShape, typename Mma::Operator, EpilogueOutputOp, kEpilogueElementsPerAccess >::Epilogue Epilogue
Define the epilogue.
Definition: default_gemm.h:513
Basic include for CUTLASS.
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
Epilogue for threadblock scoped GEMMs using SIMT.
Epilogue for threadblock scoped GEMMs using Tensor Ops on Volta.
Definition: layout/matrix.h:237