// NOTE(review): this text is a Doxygen extraction of a CUTLASS header. The
// integer at the start of each line is the original source line number; many
// interior lines are missing and some originals are fused onto one extracted
// line, so the fragment is not compilable as-is. Comments below describe only
// what the visible tokens establish.
//
// Partial specialization of DefaultMmaCore for arch::OpClassWmmaTensorOp with
// operand A in column-major and operand B in row-major layout (fixed by the
// template argument list below). It defines the type aliases — shared-memory
// layouts, thread maps, iterators, and the warp-level MMA policy — consumed by
// threadblock-scoped WMMA GEMMs.
39 #if defined(CUTLASS_ARCH_WMMA_ENABLED) 55 namespace threadblock {
73 typename InstructionShape_,
86 struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
87 layout::ColumnMajor, ElementB_, layout::RowMajor,
88 ElementC_, LayoutC_, arch::OpClassWmmaTensorOp, Stages,
// Re-export the template parameters plus the layouts this specialization fixes.
91 using WarpShape = WarpShape_;
92 using InstructionShape = InstructionShape_;
93 using ElementA = ElementA_;
94 using LayoutA = layout::ColumnMajor;
95 using ElementB = ElementB_;
96 using LayoutB = layout::RowMajor;
97 using ElementC = ElementC_;
98 using LayoutC = LayoutC_;
99 using OperatorClass = arch::OpClassWmmaTensorOp;
// Warps per threadblock along each GEMM dimension: the threadblock tile
// divided element-wise by the warp tile.
102 using WarpCount = GemmShape<
103 Shape::kM / WarpShape::kM,
104 Shape::kN / WarpShape::kN,
105 Shape::kK / WarpShape::kK
// Compile-time check that the threadblock tile is an exact multiple of the
// warp tile in M and N — the divisions above would otherwise truncate.
// (kThreads on the fused line below is the total thread count: warps * 32.)
110 !(Shape::kM % WarpShape::kM) &&
111 !(Shape::kN % WarpShape::kN),
112 "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 119 static int const kThreads = WarpCount::kCount * kWarpSize;
// Access granularity in bits for tile loads/stores.
// NOTE(review): presumably chosen to enable 128-bit vectorized memory
// instructions — confirm against the (not visible) iterator definitions.
122 static int const kAccessSizeInBits = 128;
125 using Operator = Operator_;
// Shared-memory tiles keep the same layouts as the global-memory operands.
131 using SmemLayoutA = LayoutA;
132 using SmemLayoutB = LayoutB;
// Thread map for operand A: column-major A is pitch-linear with M as the
// contiguous dimension, hence PitchLinearShape<kM, kK>.
143 using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
144 layout::PitchLinearShape<Shape::kM, Shape::kK>,
150 using SmemIteratorA = transform::threadblock::RegularTileIterator<
151 MatrixShape<Shape::kM, Shape::kK>,
// Thread map for operand B: row-major B is pitch-linear with N as the
// contiguous dimension, hence PitchLinearShape<kN, kK>.
159 using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
160 layout::PitchLinearShape<Shape::kN, Shape::kK>,
166 using SmemIteratorB = transform::threadblock::RegularTileIterator<
167 MatrixShape<Shape::kK, Shape::kN>,
// Warp-level WMMA operator and the policy wrapping it. The MatrixShape
// arguments give shared-memory padding: rows for A, columns for B.
// NOTE(review): the padding presumably avoids smem bank conflicts — confirm
// against the kPaddingA/kPaddingB definitions, which are not visible here.
193 using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma<
205 using MmaPolicy = MmaPolicy<
207 MatrixShape<kPaddingA, 0>,
208 MatrixShape<0, kPaddingB>,
// Partial specialization of DefaultMmaCore for arch::OpClassWmmaTensorOp with
// operand A in row-major and operand B in column-major layout (both operands
// have K as the contiguous dimension — see the PitchLinearShape aliases
// below). Extracted fragment: leading integers are original line numbers and
// interior lines are missing.
230 typename InstructionShape_,
243 struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
244 layout::RowMajor, ElementB_, layout::ColumnMajor,
245 ElementC_, LayoutC_, arch::OpClassWmmaTensorOp, Stages,
// Re-export the template parameters plus the layouts this specialization fixes.
247 using Shape = Shape_;
248 using WarpShape = WarpShape_;
249 using InstructionShape = InstructionShape_;
250 using ElementA = ElementA_;
251 using LayoutA = layout::RowMajor;
252 using ElementB = ElementB_;
253 using LayoutB = layout::ColumnMajor;
254 using ElementC = ElementC_;
255 using LayoutC = LayoutC_;
256 using OperatorClass = arch::OpClassWmmaTensorOp;
// Warps per threadblock along each GEMM dimension.
259 using WarpCount = GemmShape<
260 Shape::kM / WarpShape::kM,
261 Shape::kN / WarpShape::kN,
262 Shape::kK / WarpShape::kK
// Compile-time divisibility check in M and N; kThreads (fused line below) is
// the total thread count: warps * 32.
267 !(Shape::kM % WarpShape::kM) &&
268 !(Shape::kN % WarpShape::kN),
269 "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 276 static int const kThreads = WarpCount::kCount * kWarpSize;
// Access granularity in bits for tile loads/stores (see note in the other
// specializations).
280 static int const kAccessSizeInBits = 128;
283 using Operator = Operator_;
// Per-warp thread arrangement for each operand: threads form a 2-D
// (contiguous x strided) grid; the strided extent is derived so the product
// is exactly kWarpSize. The contiguous extents' initializers are not visible
// in this extraction.
286 static int const kWarpThreadArrangementContiguousA =
289 static int const kWarpThreadArrangementStridedA =
290 kWarpSize / kWarpThreadArrangementContiguousA;
292 static int const kWarpThreadArrangementContiguousB =
295 static int const kWarpThreadArrangementStridedB =
296 kWarpSize / kWarpThreadArrangementContiguousB;
// Shared-memory tiles keep the same layouts as the global-memory operands.
303 using SmemLayoutA = LayoutA;
304 using SmemLayoutB = LayoutB;
// Thread map for operand A: row-major A is pitch-linear with K contiguous,
// hence PitchLinearShape<kK, kM>.
313 using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
314 layout::PitchLinearShape<Shape::kK, Shape::kM>,
320 using SmemIteratorA = transform::threadblock::RegularTileIterator<
321 MatrixShape<Shape::kM, Shape::kK>,
// Thread map for operand B: column-major B is pitch-linear with K contiguous,
// hence PitchLinearShape<kK, kN>.
329 using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
330 layout::PitchLinearShape<Shape::kK, Shape::kN>,
336 using SmemIteratorB = transform::threadblock::RegularTileIterator<
337 MatrixShape<Shape::kK, Shape::kN>,
// Warp-level WMMA operator and its policy. Padding here is columns for A and
// rows for B — the transpose of the ColumnMajor-A / RowMajor-B specialization,
// consistent with the swapped contiguous dimensions.
363 using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma<
375 using MmaPolicy = MmaPolicy<
377 MatrixShape<0, kPaddingA>,
378 MatrixShape<kPaddingB, 0>,
// Partial specialization of DefaultMmaCore for arch::OpClassWmmaTensorOp with
// both operands row-major (A has K contiguous, B has N contiguous — see the
// PitchLinearShape aliases below). Extracted fragment: leading integers are
// original line numbers and interior lines are missing.
401 typename InstructionShape_,
414 struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
415 layout::RowMajor, ElementB_, layout::RowMajor, ElementC_,
416 LayoutC_, arch::OpClassWmmaTensorOp, Stages, Operator_> {
// Re-export the template parameters plus the layouts this specialization fixes.
417 using Shape = Shape_;
418 using WarpShape = WarpShape_;
419 using InstructionShape = InstructionShape_;
420 using ElementA = ElementA_;
421 using LayoutA = layout::RowMajor;
422 using ElementB = ElementB_;
423 using LayoutB = layout::RowMajor;
424 using ElementC = ElementC_;
425 using LayoutC = LayoutC_;
426 using OperatorClass = arch::OpClassWmmaTensorOp;
// Warps per threadblock along each GEMM dimension.
429 using WarpCount = GemmShape<
430 Shape::kM / WarpShape::kM,
431 Shape::kN / WarpShape::kN,
432 Shape::kK / WarpShape::kK
// Compile-time divisibility check in M and N; kThreads (fused line below) is
// the total thread count: warps * 32.
437 !(Shape::kM % WarpShape::kM) &&
438 !(Shape::kN % WarpShape::kN),
439 "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 446 static int const kThreads = WarpCount::kCount * kWarpSize;
// Access granularity in bits for tile loads/stores.
449 static int const kAccessSizeInBits = 128;
452 using Operator = Operator_;
// Per-warp 2-D thread arrangement for operand A only (B uses the default
// stripmined map here); strided extent derived so the product is kWarpSize.
// The contiguous extent's initializer is not visible in this extraction.
455 static int const kWarpThreadArrangementContiguousA =
458 static int const kWarpThreadArrangementStridedA =
459 kWarpSize / kWarpThreadArrangementContiguousA;
// Shared-memory tiles keep the same layouts as the global-memory operands.
466 using SmemLayoutA = LayoutA;
467 using SmemLayoutB = LayoutB;
// Thread map for operand A: row-major A is pitch-linear with K contiguous,
// hence PitchLinearShape<kK, kM>.
478 using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
479 layout::PitchLinearShape<Shape::kK, Shape::kM>,
486 using SmemIteratorA = transform::threadblock::RegularTileIterator<
487 MatrixShape<Shape::kM, Shape::kK>,
// Thread map for operand B: row-major B is pitch-linear with N contiguous,
// hence PitchLinearShape<kN, kK>.
495 using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
496 layout::PitchLinearShape<Shape::kN, Shape::kK>,
502 using SmemIteratorB = transform::threadblock::RegularTileIterator<
503 MatrixShape<Shape::kK, Shape::kN>,
// Warp-level WMMA operator and its policy. Both operands pad the column
// dimension here (A columns = K, B columns = N), matching each operand's
// contiguous dimension in this layout combination.
529 using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma<
541 using MmaPolicy = MmaPolicy<
543 MatrixShape<0, kPaddingA>,
544 MatrixShape<0, kPaddingB>,
// Partial specialization of DefaultMmaCore for arch::OpClassWmmaTensorOp with
// both operands column-major (A has M contiguous, B has K contiguous — see
// the PitchLinearShape aliases below). Extracted fragment: leading integers
// are original line numbers and interior lines are missing.
565 typename InstructionShape_,
578 struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
579 layout::ColumnMajor, ElementB_, layout::ColumnMajor,
580 ElementC_, LayoutC_, arch::OpClassWmmaTensorOp, Stages,
// Re-export the template parameters plus the layouts this specialization fixes.
582 using Shape = Shape_;
583 using WarpShape = WarpShape_;
584 using InstructionShape = InstructionShape_;
585 using ElementA = ElementA_;
586 using LayoutA = layout::ColumnMajor;
587 using ElementB = ElementB_;
588 using LayoutB = layout::ColumnMajor;
589 using ElementC = ElementC_;
590 using LayoutC = LayoutC_;
591 using OperatorClass = arch::OpClassWmmaTensorOp;
// Warps per threadblock along each GEMM dimension (single-line formatting in
// the original here, unlike the other specializations).
595 GemmShape<Shape::kM / WarpShape::kM, Shape::kN / WarpShape::kN,
596 Shape::kK / WarpShape::kK>;
// Compile-time divisibility check in M and N.
600 !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
601 "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");
// Total thread count: warps * 32.
607 static int const kThreads = WarpCount::kCount * kWarpSize;
// Access granularity in bits for tile loads/stores.
610 static int const kAccessSizeInBits = 128;
613 using Operator = Operator_;
// Per-warp 2-D thread arrangement for operand B only (A uses the default
// stripmined map here); strided extent derived so the product is kWarpSize.
// The contiguous extent's initializer is not visible in this extraction.
616 static int const kWarpThreadArrangementContiguousB =
619 static int const kWarpThreadArrangementStridedB =
620 kWarpSize / kWarpThreadArrangementContiguousB;
// Shared-memory tiles keep the same layouts as the global-memory operands.
627 using SmemLayoutA = LayoutA;
628 using SmemLayoutB = LayoutB;
// Thread map for operand A: column-major A is pitch-linear with M contiguous,
// hence PitchLinearShape<kM, kK>.
639 using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
640 layout::PitchLinearShape<Shape::kM, Shape::kK>,
646 using SmemIteratorA = transform::threadblock::RegularTileIterator<
647 MatrixShape<Shape::kM, Shape::kK>, ElementA, SmemLayoutA, 1,
// Thread map for operand B: column-major B is pitch-linear with K contiguous,
// hence PitchLinearShape<kK, kN>.
651 using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
652 layout::PitchLinearShape<Shape::kK, Shape::kN>,
658 using SmemIteratorB = transform::threadblock::RegularTileIterator<
659 MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 0,
// Warp-level WMMA operator and its policy. Both operands pad the row
// dimension here (A rows = M, B rows = K), matching each operand's contiguous
// dimension in this layout combination.
681 using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma<
693 using MmaPolicy = MmaPolicy<
695 MatrixShape<kPaddingA, 0>,
696 MatrixShape<kPaddingB, 0>,
705 #endif // defined(CUTLASS_ARCH_WMMA_ENABLED)
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
Templates implementing loading of tiles from pitch-linear rank=2 tensors.
Definition: aligned_buffer.h:35
static int const value
Definition: numeric_types.h:43
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
Defines a Shape template for matrix tiles.
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the global memory fragments, data types, and internal tile sizes.
static int const value
Definition: gemm/warp/mma.h:44
Top-level include for all CUTLASS numeric types.
Policy.
Definition: mma_tensor_op_policy.h:48
Templates implementing warp-level matrix multiply-accumulate operations targeting Tensor Cores...
Templates exposing architecture support for warp matrix multiply-add (WMMA) operations.
Basic include for CUTLASS.
Policy describing implementation details of warp-level GEMM targeting Tensor Cores.