75 typename Enable =
bool 114 "Shape must be a multiple of InterleavedTileShape.");
124 Policy::Operator::Shape::kM,
125 Policy::Operator::Shape::kK
127 Policy::OpDelta::kRow,
141 Policy::Operator::Shape::kK,
142 Policy::Operator::Shape::kN
144 Policy::OpDelta::kRow,
156 typename Policy::Operator::Shape,
157 typename Policy::OpDelta
166 !(Shape::kM % Policy::Operator::Shape::kM) &&
167 !(Shape::kN % Policy::Operator::Shape::kN),
168 "Shape of warp-level Mma must be divisible by operator shape.");
172 InterleavedTileShape::kM / Policy::Operator::Shape::kM,
186 typename Policy::Operator
mma;
205 int const &partitionN_idx = 0) {
207 using MmaOperandA =
typename Policy::Operator::FragmentA;
208 using MmaOperandB =
typename Policy::Operator::FragmentB;
209 using MmaOperandC =
typename Policy::Operator::FragmentC;
213 MmaOperandA
const *ptr_A =
reinterpret_cast<MmaOperandA
const *
>(&A);
214 MmaOperandB
const *ptr_B =
reinterpret_cast<MmaOperandB
const *
>(&B);
215 MmaOperandC *ptr_D =
reinterpret_cast<MmaOperandC *
>(&D);
227 int op_col = inner_col + MmaIterations::kColumn * outer_col;
230 int inner_row_serp = inner_row;
231 int outer_row_serp = outer_row;
233 inner_row_serp = MmaIterations::kRow - inner_row - 1;
234 outer_row_serp = TileIterations::kRow - outer_row - 1;
236 int op_row = inner_row_serp + MmaIterations::kRow * outer_row_serp;
237 int op_idx = inner_row_serp + MmaIterations::kRow *
238 (inner_col + MmaIterations::kColumn *
239 (outer_row_serp + TileIterations::kRow * outer_col));
Policy_ Policy
Policy describing implementation details of the warp-level GEMM targeting Tensor Cores (underlying operator and operator delta).
Definition: mma_tensor_op_sm70.h:101
typename IteratorB::Fragment FragmentB
Storage for B tile.
Definition: mma_tensor_op_sm70.h:149
static int const kM
Definition: include/cutlass/gemm/gemm.h:58
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
LayoutB_ LayoutB
Layout of multiplicand B.
Definition: mma_tensor_op_sm70.h:92
Definition: mma_tensor_op_tile_iterator_sm70.h:70
Definition: aligned_buffer.h:35
static int const kColumn
columns of a matrix
Definition: matrix_shape.h:44
Defines common types used for all GEMM-like operators.
Shape_ Shape
Shape of warp-level matrix operation (concept: GemmShape)
Definition: mma_tensor_op_sm70.h:80
arch::OpClassTensorOp OperatorClass
Indicates class of matrix operator.
Definition: mma_tensor_op_sm70.h:104
LayoutA_ LayoutA
Layout of multiplicand A.
Definition: mma_tensor_op_sm70.h:86
Array< Element, Shape::kCount/kThreads > Fragment
Fragment object holding a thread's part of a tile.
Definition: mma_tensor_op_tile_iterator_sm70.h:1213
ElementA_ ElementA
Data type of multiplicand A.
Definition: mma_tensor_op_sm70.h:83
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Templates exposing architecture support for warp-level multiply-add operations.
Templates exposing architecture support for multiply-add operations.
Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
Defines a Shape template for matrix tiles.
ElementB_ ElementB
Data type of multiplicand B.
Definition: mma_tensor_op_sm70.h:89
CUTLASS_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C, int const &partitionN_idx=0)
Performs a warp-level matrix multiply-accumulate operation.
Definition: mma_tensor_op_sm70.h:200
Policy::Operator mma
Underlying matrix multiply operator (concept: arch::Mma)
Definition: mma_tensor_op_sm70.h:186
Structure to compute the warp-level matrix product targeting Volta (SM70) Tensor Cores.
Definition: mma_tensor_op_sm70.h:77
static int const kRow
rows of a matrix
Definition: matrix_shape.h:43
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
typename IteratorC::Fragment FragmentC
Storage for C tile.
Definition: mma_tensor_op_sm70.h:161
LayoutC_ LayoutC
Layout of accumulator matrix C.
Definition: mma_tensor_op_sm70.h:98
static int const kThreadCount
Number of threads participating in warp-level matrix product.
Definition: mma_tensor_op_sm70.h:107
CUTLASS_DEVICE MmaVoltaTensorOp()
Ctor.
Definition: mma_tensor_op_sm70.h:196
Definition: mma_tensor_op_tile_iterator_sm70.h:1135
ElementC_ ElementC
Data type of accumulator matrix C.
Definition: mma_tensor_op_sm70.h:95
typename IteratorA::Fragment FragmentA
Storage for A tile.
Definition: mma_tensor_op_sm70.h:132
Basic include for CUTLASS.
Policy describing implementation details of warp-level GEMM targeting Tensor Cores.
static int const kN
Definition: include/cutlass/gemm/gemm.h:59