53 struct Mma<gemm::GemmShape<1, 1, 1>, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> {
60 Array<float, 1> const &a,
61 Array<float, 1> const &b,
62 Array<float, 1> const &c
64 d[0] = a[0] * b[0] + c[0];
79 struct Mma<gemm::GemmShape<1, 1, 1>, 1, double, LayoutA, double, LayoutB, double, LayoutC, OpMultiplyAdd> {
86 Array<double, 1> const &a,
87 Array<double, 1> const &b,
88 Array<double, 1> const &c
91 d[0] = a[0] * b[0] + c[0];
106 struct Mma<gemm::GemmShape<1, 1, 1>, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd> {
113 Array<int, 1> const &a,
114 Array<int, 1> const &b,
115 Array<int, 1> const &c
118 d[0] = a[0] * b[0] + c[0];
134 gemm::GemmShape<1, 1, 1>,
154 d[0].real() = a[0].real() * b[0].real() + c[0].real();
155 d[0].imag() = a[0].imag() * b[0].real() + c[0].imag();
156 d[0].real() = -a[0].imag() * b[0].imag() + d[0].real();
157 d[0].imag() = a[0].real() * b[0].imag() + d[0].imag();
173 gemm::GemmShape<1, 1, 1>,
189 Array<float, 1> const &b,
193 d[0].real() = a[0].real() * b[0] + c[0].real();
194 d[0].imag() = a[0].imag() * b[0] + c[0].imag();
210 gemm::GemmShape<1, 1, 1>,
225 Array<float, 1> const &a,
230 d[0].real() = a[0] * b[0].real() + c[0].real();
231 d[0].imag() = a[0] * b[0].imag() + c[0].imag();
247 gemm::GemmShape<1, 1, 1>,
267 d[0].real() = a[0].real() * b[0].real() + c[0].real();
268 d[0].imag() = a[0].imag() * b[0].real() + c[0].imag();
269 d[0].real() = -a[0].imag() * b[0].imag() + d[0].real();
270 d[0].imag() = a[0].real() * b[0].imag() + d[0].imag();
284 gemm::GemmShape<1, 1, 1>,
300 Array<double, 1> const &b,
304 d[0].real() = a[0].real() * b[0] + c[0].real();
305 d[0].imag() = a[0].imag() * b[0] + c[0].imag();
319 gemm::GemmShape<1, 1, 1>,
334 Array<double, 1> const &a,
339 d[0].real() = a[0] * b[0].real() + c[0].real();
340 d[0].imag() = a[0] * b[0].imag() + c[0].imag();
355 struct Mma<gemm::GemmShape<1, 1, 1>, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd> {
362 Array<half_t, 1> const &a,
363 Array<half_t, 1> const &b,
364 Array<float, 1> const &c
366 d[0] = float(a[0]) * float(b[0]) + c[0];
cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(Array< int, 1 > &d, Array< int, 1 > const &a, Array< int, 1 > const &b, Array< int, 1 > const &c)
Definition: arch/mma_sm50.h:111
cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, complex< double >, LayoutA, double, LayoutB, complex< double >, LayoutC, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(Array< complex< double >, 1 > &d, Array< complex< double >, 1 > const &a, Array< double, 1 > const &b, Array< complex< double >, 1 > const &c)
Definition: arch/mma_sm50.h:297
Definition: aligned_buffer.h:35
IEEE half-precision floating-point type.
Definition: half.h:126
Defines common types used for all GEMM-like operators.
cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, complex< float >, LayoutA, complex< float >, LayoutB, complex< float >, LayoutC, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(Array< complex< float >, 1 > &d, Array< complex< float >, 1 > const &a, Array< complex< float >, 1 > const &b, Array< complex< float >, 1 > const &c)
Definition: arch/mma_sm50.h:147
cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, complex< double >, LayoutA, complex< double >, LayoutB, complex< double >, LayoutC, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(Array< complex< double >, 1 > &d, Array< complex< double >, 1 > const &a, Array< complex< double >, 1 > const &b, Array< complex< double >, 1 > const &c)
Definition: arch/mma_sm50.h:260
cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(Array< float, 1 > &d, Array< half_t, 1 > const &a, Array< half_t, 1 > const &b, Array< float, 1 > const &c)
Definition: arch/mma_sm50.h:360
Templates exposing architecture support for multiply-add operations.
cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, complex< float >, LayoutA, float, LayoutB, complex< float >, LayoutC, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(Array< complex< float >, 1 > &d, Array< complex< float >, 1 > const &a, Array< float, 1 > const &b, Array< complex< float >, 1 > const &c)
Definition: arch/mma_sm50.h:186
cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(Array< float, 1 > &d, Array< float, 1 > const &a, Array< float, 1 > const &b, Array< float, 1 > const &c)
Definition: arch/mma_sm50.h:58
cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, float, LayoutA, complex< float >, LayoutB, complex< float >, LayoutC, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(Array< complex< float >, 1 > &d, Array< float, 1 > const &a, Array< complex< float >, 1 > const &b, Array< complex< float >, 1 > const &c)
Definition: arch/mma_sm50.h:223
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Defines layout functions used by TensorRef and derived classes.
Matrix multiply-add operation.
Definition: arch/mma.h:92
cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, double, LayoutA, complex< double >, LayoutB, complex< double >, LayoutC, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(Array< complex< double >, 1 > &d, Array< double, 1 > const &a, Array< complex< double >, 1 > const &b, Array< complex< double >, 1 > const &c)
Definition: arch/mma_sm50.h:332
cutlass::arch::Mma< gemm::GemmShape< 1, 1, 1 >, 1, double, LayoutA, double, LayoutB, double, LayoutC, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(Array< double, 1 > &d, Array< double, 1 > const &a, Array< double, 1 > const &b, Array< double, 1 > const &c)
Definition: arch/mma_sm50.h:84