76 bool AccumulatorsInRowMajor =
false,
80 typename Enable =
bool 143 typename Policy::Operator::Shape,
typename Policy::OpDelta>;
151 !(Shape::kM % Policy::Operator::Shape::kM) &&
152 !(Shape::kN % Policy::Operator::Shape::kN),
153 "Shape of warp-level Mma must be divisible by operator shape.");
157 Shape::kM / Policy::Operator::Shape::kM,
158 (Shape::kN / Policy::Operator::Shape::kN / kPartitionsN > 0) ?
159 Shape::kN / Policy::Operator::Shape::kN / kPartitionsN :
166 typename Policy::Operator
mma;
185 int const &partitionN_idx = 0)
const {
187 using MmaOperandA =
typename Policy::Operator::FragmentA;
188 using MmaOperandB =
typename Policy::Operator::FragmentB;
189 using MmaOperandC =
typename Policy::Operator::FragmentC;
193 MmaOperandA
const *ptr_A =
reinterpret_cast<MmaOperandA
const *
>(&A);
194 MmaOperandB
const *ptr_B =
reinterpret_cast<MmaOperandB
const *
>(&B);
195 MmaOperandC *ptr_D =
reinterpret_cast<MmaOperandC *
>(&D);
198 const int n_off = partitionN_idx * FragmentB::kElements / MmaOperandB::kElements /
kPartitionsN;
201 for (
int n = 0; n < MmaIterations::kColumn; ++n) {
204 for (
int m = 0; m < MmaIterations::kRow; ++m) {
206 int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
208 if (AccumulatorsInRowMajor) {
210 ptr_D[n + m_serpentine * MmaIterations::kColumn],
213 ptr_D[n + m_serpentine * MmaIterations::kColumn]);
216 ptr_D[m_serpentine + (n + n_off) * MmaIterations::kRow],
219 ptr_D[m_serpentine + (n + n_off) * MmaIterations::kRow]);
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
typename IteratorA::Fragment FragmentA
Storage for A tile.
Definition: mma_tensor_op.h:129
Definition: aligned_buffer.h:35
LayoutB_ LayoutB
Layout of multiplicand B.
Definition: mma_tensor_op.h:97
CUTLASS_DEVICE MmaTensorOp()
Ctor.
Definition: mma_tensor_op.h:176
Architecture-specific operators on memory added for SM75.
Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
Defines common types used for all GEMM-like operators.
static int const kThreadCount
Number of threads participating in warp-level matrix product.
Definition: mma_tensor_op.h:112
static int const kPartitionsN
Number of partitions along the N dimension for multiplicand B.
Definition: mma_tensor_op.h:118
Structure to compute the warp-level matrix product targeting Tensor Cores.
Definition: mma_tensor_op.h:82
LayoutA_ LayoutA
Layout of multiplicand A.
Definition: mma_tensor_op.h:91
typename IteratorB::Fragment FragmentB
Storage for B tile.
Definition: mma_tensor_op.h:138
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Templates exposing architecture support for warp-level multiply-add operations.
Defines a Shape template for matrix tiles.
typename IteratorC::Fragment FragmentC
Storage for C tile.
Definition: mma_tensor_op.h:146
Definition: mma_tensor_op_tile_iterator.h:1794
CUTLASS_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C, int const &partitionN_idx=0) const
Performs a warp-level matrix multiply-accumulate operation.
Definition: mma_tensor_op.h:180
Policy::Operator mma
Underlying matrix multiply operator (concept: arch::Mma)
Definition: mma_tensor_op.h:166
ElementC_ ElementC
Data type of accumulator matrix C.
Definition: mma_tensor_op.h:100
Top-level include for all CUTLASS numeric types.
Definition: mma_tensor_op_tile_iterator.h:75
LayoutC_ LayoutC
Layout of accumulator matrix C.
Definition: mma_tensor_op.h:103
Policy_ Policy
Policy describing implementation details of the warp-level Tensor Core matrix multiply (concept: MmaTensorOpPolicy)
Definition: mma_tensor_op.h:106
Shape_ Shape
Shape of warp-level matrix operation (concept: GemmShape)
Definition: mma_tensor_op.h:85
ElementB_ ElementB
Data type of multiplicand B.
Definition: mma_tensor_op.h:94
static int const kPartitionsK
Number of partitions along K dimension.
Definition: mma_tensor_op.h:115
arch::OpClassTensorOp OperatorClass
Indicates class of matrix operator.
Definition: mma_tensor_op.h:109
ElementA_ ElementA
Data type of multiplicand A.
Definition: mma_tensor_op.h:88
Matrix multiply for SM75.
Basic include for CUTLASS.
Policy describing implementation details of warp-level GEMM targeting Tensor Cores.