59 cutlass::arch::OpMultiplyAdd
62 #if defined(CUTLASS_ARCH_WMMA_SM70_ENABLED) 65 using LayoutA = LayoutA_;
67 using LayoutB = LayoutB_;
68 using ElementC = ElementC_;
69 using LayoutC = LayoutC_;
70 using Operator = cutlass::arch::OpMultiplyAdd;
77 "Supported list of wmma operator shapes for f16 multiplicands are: 16x16x16, 8x32x16, and 32x8x16");
82 "Supported wmma output data types for f16 multiplicands are: f16 and f32");
85 using FragmentA = nvcuda::wmma::fragment<
86 nvcuda::wmma::matrix_a,
90 typename CutlassToWmmaDataType<ElementA>::Type,
91 typename CutlassToWmmaLayout<LayoutA>::Layout>;
93 using FragmentB = nvcuda::wmma::fragment<
94 nvcuda::wmma::matrix_b,
98 typename CutlassToWmmaDataType<ElementB>::Type,
99 typename CutlassToWmmaLayout<LayoutB>::Layout>;
101 using FragmentC = nvcuda::wmma::fragment<
102 nvcuda::wmma::accumulator,
106 typename CutlassToWmmaDataType<ElementC>::Type>;
114 FragmentC
const &C)
const {
116 nvcuda::wmma::mma_sync(D, A, B, C);
119 static_assert(
false,
"wmma.mma.sync for floating point multiplicands is available only for SM70 and beyond");
Definition: aligned_buffer.h:35
IEEE half-precision floating-point type.
Definition: half.h:126
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Defines layout functions used by TensorRef and derived classes.