57 cutlass::arch::OpMultiplyAdd
59 #if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED) 62 using LayoutA = LayoutA_;
64 using LayoutB = LayoutB_;
65 using ElementC = int32_t;
66 using LayoutC = LayoutC_;
67 using Operator = cutlass::arch::OpMultiplyAdd;
72 "Supported list of wmma operator shape for s8 multiplicands is: 8x8x32");
76 using FragmentA = nvcuda::wmma::fragment<
77 nvcuda::wmma::matrix_a,
81 typename CutlassToWmmaDataType<ElementA>::Type,
82 typename CutlassToWmmaLayout<LayoutA>::Layout>;
84 using FragmentB = nvcuda::wmma::fragment<
85 nvcuda::wmma::matrix_b,
89 typename CutlassToWmmaDataType<ElementB>::Type,
90 typename CutlassToWmmaLayout<LayoutB>::Layout>;
92 using FragmentC = nvcuda::wmma::fragment<
93 nvcuda::wmma::accumulator,
97 typename CutlassToWmmaDataType<ElementC>::Type>;
105 FragmentC
const &C)
const {
106 nvcuda::wmma::mma_sync(D, A, B, C);
110 static_assert(
false,
"wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond");
135 cutlass::arch::OpXorPopc
137 #if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED) 138 using Shape = Shape_;
140 using LayoutA = LayoutA_;
142 using LayoutB = LayoutB_;
143 using ElementC = int32_t;
144 using LayoutC = LayoutC_;
145 using Operator = cutlass::arch::OpXorPopc;
150 "Supported list of wmma operator shape for b1 multiplicands is: 8x8x128");
154 using FragmentA = nvcuda::wmma::fragment<
155 nvcuda::wmma::matrix_a,
159 typename CutlassToWmmaDataType<ElementA>::Type,
160 typename CutlassToWmmaLayout<LayoutA>::Layout>;
162 using FragmentB = nvcuda::wmma::fragment<
163 nvcuda::wmma::matrix_b,
167 typename CutlassToWmmaDataType<ElementB>::Type,
168 typename CutlassToWmmaLayout<LayoutB>::Layout>;
170 using FragmentC = nvcuda::wmma::fragment<
171 nvcuda::wmma::accumulator,
175 typename CutlassToWmmaDataType<ElementC>::Type>;
183 FragmentC
const &C)
const {
185 nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,
186 nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
190 static_assert(
false,
"wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond");
Definition: aligned_buffer.h:35
integer_subbyte< 1, false > uint1b_t
1-bit Unsigned integer type
Definition: integer_subbyte.h:152
4-bit signed integer type
Definition: integer_subbyte.h:42
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Defines layout functions used by TensorRef and derived classes.
integer_subbyte< 4, true > int4b_t
4-bit Integer type
Definition: integer_subbyte.h:155