60 typename RealElementA,
64 typename RealElementB,
68 typename RealElementC,
78 typename Enable =
bool 89 typename RealElementA,
93 typename RealElementB,
97 typename RealElementC,
156 static int const kThreadCount = 32;
167 Policy::OpDelta::kRow,
182 Policy::OpDelta::kColumn,
192 !(Shape::kM % Policy::Operator::Shape::kM) &&
193 !(Shape::kN % Policy::Operator::Shape::kN),
194 "Shape of warp-level Mma must be divisible by operator shape.");
198 Shape::kM / Policy::Operator::Shape::kM,
199 Shape::kN / Policy::Operator::Shape::kN
207 typename Policy::Operator::Shape,
208 typename Policy::OpDelta>;
217 FragmentC::kElements == 2 * MmaIterations::kCount * Policy::Operator::FragmentC::kElements,
218 "Unexpected planar complex fragment length.");
227 typename Policy::Operator mma;
248 using MmaOperandA =
typename Policy::Operator::FragmentA;
249 using MmaOperandB =
typename Policy::Operator::FragmentB;
250 using MmaOperandC =
typename Policy::Operator::FragmentC;
253 "This implementation only supports math instructions in which exactly one element is needed for the A operand." 254 "We can geneneralize later.");
257 "This implementation only supports math instructions in which exactly one element is needed for the A operand." 258 "We can geneneralize later.");
263 for (
int m = 0; m < MmaIterations::kRow; ++m) {
267 for (
int n = 0; n < MmaIterations::kColumn; ++n) {
270 MmaOperandA operand_A;
271 MmaOperandB operand_B;
273 operand_A[0] = A[m].real();
274 operand_B[0] = B[n].real();
277 MmaOperandC *accum =
reinterpret_cast<MmaOperandC *
>(&D) +
278 (m + n * MmaIterations::kRow);
280 mma(*accum, operand_A, operand_B, *accum);
285 for (
int n = MmaIterations::kColumn - 1; n >= 0; --n) {
288 MmaOperandA operand_A;
289 MmaOperandB operand_B;
291 operand_A[0] = A[m].real();
295 MmaOperandC *accum =
reinterpret_cast<MmaOperandC *
>(&D) +
296 (m + n * MmaIterations::kRow) + MmaIterations::kCount;
298 mma(*accum, operand_A, operand_B, *accum);
303 for (
int n = 0; n < MmaIterations::kColumn; ++n) {
306 MmaOperandA operand_A;
307 MmaOperandB operand_B;
314 MmaOperandC *accum =
reinterpret_cast<MmaOperandC *
>(&D) +
315 (m + n * MmaIterations::kRow);
317 mma(*accum, operand_A, operand_B, *accum);
322 for (
int n = MmaIterations::kColumn - 1; n >= 0; --n) {
325 MmaOperandA operand_A;
326 MmaOperandB operand_B;
329 operand_B[0] = B[n].real();
332 MmaOperandC *accum =
reinterpret_cast<MmaOperandC *
>(&D) +
333 (m + n * MmaIterations::kRow) + MmaIterations::kCount;
335 mma(*accum, operand_A, operand_B, *accum);
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
Definition: aligned_buffer.h:35
ComplexTransform
Enumerated type describing a transformation on a complex value.
Definition: complex.h:43
Architecture-specific operators on memory added for SM75.
Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores.
cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::FragmentA typename IteratorA::Fragment FragmentA
Storage for A tile.
Definition: mma_complex_tensor_op.h:173
Defines common types used for all GEMM-like operators.
cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::LayoutB LayoutB_ LayoutB
Layout of multiplicand B.
Definition: mma_complex_tensor_op.h:135
cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::FragmentB typename IteratorB::Fragment FragmentB
Storage for B tile.
Definition: mma_complex_tensor_op.h:188
Definition: mma_complex_tensor_op.h:80
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Templates exposing architecture support for warp-level multiply-add operations.
Defines a Shape template for matrix tiles.
Definition: mma_tensor_op_tile_iterator.h:1794
cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::FragmentC typename IteratorC::Fragment FragmentC
Definition: mma_complex_tensor_op.h:214
cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::Shape Shape_ Shape
Shape of warp-level matrix operation (concept: GemmShape)
Definition: mma_complex_tensor_op.h:123
cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::MmaComplexTensorOp CUTLASS_DEVICE MmaComplexTensorOp()
Default constructor.
Definition: mma_complex_tensor_op.h:237
Top-level include for all CUTLASS numeric types.
Definition: mma_tensor_op_tile_iterator.h:75
cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::OperatorClass arch::OpClassTensorOp OperatorClass
Indicates class of matrix operator.
Definition: mma_complex_tensor_op.h:153
Matrix multiply for SM75.
cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::Policy Policy_ Policy
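Policy describing implementation details of the warp-level GEMM targeting Tensor Cores (supplies the underlying Operator and the OpDelta spacing between operator tiles).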
Definition: mma_complex_tensor_op.h:144
Templates implementing warp-level matrix multiply-accumulate operations targeting Tensor Cores...
Basic include for CUTLASS.
cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::operator() CUTLASS_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C) const
Performs a warp-level matrix multiply-accumulate operation.
Definition: mma_complex_tensor_op.h:241
cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::LayoutC LayoutC_ LayoutC
Layout of accumulator matrix C.
Definition: mma_complex_tensor_op.h:141
Policy describing implementation details of warp-level GEMM targeting Tensor Cores.
cutlass::gemm::warp::MmaComplexTensorOp< Shape_, complex< RealElementA >, LayoutA_, complex< RealElementB >, LayoutB_, complex< RealElementC >, LayoutC_, Policy_, TransformA, TransformB, Enable >::LayoutA LayoutA_ LayoutA
Layout of multiplicand A.
Definition: mma_complex_tensor_op.h:129