#include <cuda_fp16.h>

#include "cutlass/arch/mma.h"
#include "cutlass/layout/matrix.h"

// The specializations below are defined within namespace cutlass::arch.

/// Matrix multiply-add operation: 2-by-1-by-1 shape. Two elements of a column of A
/// are multiplied by a single element of B using one paired half-precision FMA.
template <
  typename LayoutA,
  typename LayoutB,
  typename LayoutC>
struct Mma<
  gemm::GemmShape<2, 1, 1>,
  1,
  half_t, LayoutA,
  half_t, LayoutB,
  half_t, LayoutC,
  OpMultiplyAdd> {

  CUTLASS_HOST_DEVICE
  void operator()(
    Array<half_t, 2> &d,
    Array<half_t, 2> const &a,
    Array<half_t, 1> const &b,
    Array<half_t, 2> const &c
  ) {

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

    // Reinterpret the packed operands as __half2 and broadcast the single element of B.
    __half2 const & A = reinterpret_cast<__half2 const &>(a);
    __half2 B = __half2half2(reinterpret_cast<__half const &>(b));
    __half2 const & C = reinterpret_cast<__half2 const &>(c);

    // One paired FMA computes both elements of D.
    __half2 D = __hfma2(A, B, C);

    d = reinterpret_cast<Array<half_t, 2> &>(D);

#else

    // Scalar fallback for architectures without native __half2 arithmetic.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 2; ++i) {
      d[i] = a[i] * b[0] + c[i];
    }

#endif
  }
};
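For reference, the same paired-FMA idea can be exercised outside of CUTLASS with the cuda_fp16.h intrinsics alone. The sketch below is illustrative only: the kernel name and input values are invented, and it assumes compilation with nvcc for an SM60-class target (for example -arch=sm_60) so that __hfma2 is available.

#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Computes d[i] = a[i] * b0 + c[i] for i = 0, 1 with a single paired FMA,
// mirroring the SIMD path of the 2-by-1-by-1 specialization above.
__global__ void hfma2_column_demo(float a0, float a1, float b0, float c0, float c1) {
  __half2 A = __floats2half2_rn(a0, a1);        // pack the 2-element column of A
  __half2 B = __half2half2(__float2half(b0));   // broadcast the single element of B
  __half2 C = __floats2half2_rn(c0, c1);        // pack the 2-element accumulator

  __half2 D = __hfma2(A, B, C);

  printf("d = (%f, %f)\n", __low2float(D), __high2float(D));
}

int main() {
  hfma2_column_demo<<<1, 1>>>(1.0f, 2.0f, 3.0f, 0.5f, 0.25f);  // expected: (3.500000, 6.250000)
  cudaDeviceSynchronize();
  return 0;
}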

/// Matrix multiply-add operation: 1-by-2-by-1 shape with a row-major accumulator.
/// A single element of A is broadcast against two elements of a row of B.
template <
  typename LayoutA,
  typename LayoutB>
struct Mma<
  gemm::GemmShape<1, 2, 1>,
  1,
  half_t, LayoutA,
  half_t, LayoutB,
  half_t, layout::RowMajor,
  OpMultiplyAdd> {

  CUTLASS_HOST_DEVICE
  void operator()(
    Array<half_t, 2> &d,
    Array<half_t, 1> const &a,
    Array<half_t, 2> const &b,
    Array<half_t, 2> const &c
  ) {

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

    // Broadcast the single element of A; B and C are already packed pairs.
    __half2 const & A = __half2half2(reinterpret_cast<__half const &>(a));
    __half2 B = reinterpret_cast<__half2 const &>(b);
    __half2 const & C = reinterpret_cast<__half2 const &>(c);

    __half2 D = __hfma2(A, B, C);

    d = reinterpret_cast<Array<half_t, 2> &>(D);

#else

    // Scalar fallback.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 2; ++i) {
      d[i] = a[0] * b[i] + c[i];
    }

#endif
  }
};
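As a usage sketch, this specialization can be reached through cutlass::arch::Mma from a device function. The example below is illustrative rather than taken from the CUTLASS sources; it assumes, as in recent CUTLASS releases, that cutlass/arch/mma.h pulls in the SM60 specializations (otherwise cutlass/arch/mma_sm60.h can be included directly), and the function name is invented.

#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/arch/mma.h"

// Applies the 1-by-2-by-1 half-precision multiply-add: d[i] = a[0] * b[i] + c[i].
__device__ void mma_1x2x1_example(
    cutlass::Array<cutlass::half_t, 2> &d,
    cutlass::Array<cutlass::half_t, 1> const &a,
    cutlass::Array<cutlass::half_t, 2> const &b,
    cutlass::Array<cutlass::half_t, 2> const &c) {

  using Mma = cutlass::arch::Mma<
      cutlass::gemm::GemmShape<1, 2, 1>,
      1,                                              // one participating thread
      cutlass::half_t, cutlass::layout::ColumnMajor,  // element and layout of A
      cutlass::half_t, cutlass::layout::RowMajor,     // element and layout of B
      cutlass::half_t, cutlass::layout::RowMajor,     // accumulator layout fixed by the specialization
      cutlass::arch::OpMultiplyAdd>;

  Mma mma;
  mma(d, a, b, c);
}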

/// Matrix multiply-add operation: 2-by-2-by-1 shape with column-major A,
/// row-major B, and a column-major accumulator.
template <>
struct Mma<
  gemm::GemmShape<2, 2, 1>,
  1,
  half_t, layout::ColumnMajor,
  half_t, layout::RowMajor,
  half_t, layout::ColumnMajor,
  OpMultiplyAdd> {

  CUTLASS_HOST_DEVICE
  void operator()(
    Array<half_t, 4> &d,
    Array<half_t, 2> const &a,
    Array<half_t, 2> const &b,
    Array<half_t, 4> const &c
  ) {

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

    // A is a packed column of two elements; each element of B is broadcast into
    // both lanes so that one paired FMA produces one column of D.
    __half2 const & A = reinterpret_cast<__half2 const &>(a);
    __half2 Blo = __low2half2(reinterpret_cast<__half2 const &>(b));
    __half2 Bhi = __high2half2(reinterpret_cast<__half2 const &>(b));

    __half2 const *C = reinterpret_cast<__half2 const *>(&c);

    __half2 Dlo = __hfma2(A, Blo, C[0]);   // column 0 of D
    __half2 Dhi = __hfma2(A, Bhi, C[1]);   // column 1 of D

    Array<half_t, 2> * D = reinterpret_cast<Array<half_t, 2> *>(&d);

    D[0] = reinterpret_cast<Array<half_t, 2> const &>(Dlo);
    D[1] = reinterpret_cast<Array<half_t, 2> const &>(Dhi);

#else

    // Scalar fallback: column-major accumulator, D(i, j) = d[i + 2 * j].
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < 2; ++j) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < 2; ++i) {
        d[i + 2 * j] = a[i] * b[j] + c[i + 2 * j];
      }
    }

#endif
  }
};
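The scalar path makes the accumulator packing explicit: Dlo holds column 0 of D (d[0], d[1]) and Dhi holds column 1 (d[2], d[3]). The host-only reference below restates that indexing in plain C++; it is an illustrative sketch with invented names and values, not CUTLASS code.

#include <cstdio>

// Reference 2x2 outer-product update with a column-major accumulator:
// d[i + 2 * j] holds D(i, j), so d[0..1] is column 0 and d[2..3] is column 1.
void mma_2x2x1_colmajor_ref(float d[4], float const a[2], float const b[2], float const c[4]) {
  for (int j = 0; j < 2; ++j) {
    for (int i = 0; i < 2; ++i) {
      d[i + 2 * j] = a[i] * b[j] + c[i + 2 * j];
    }
  }
}

int main() {
  float const a[2] = {1.0f, 2.0f};
  float const b[2] = {3.0f, 4.0f};
  float const c[4] = {0.0f, 0.0f, 0.0f, 0.0f};
  float d[4];
  mma_2x2x1_colmajor_ref(d, a, b, c);
  std::printf("column 0: %g %g  column 1: %g %g\n", d[0], d[1], d[2], d[3]);  // 3 6  4 8
  return 0;
}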

/// Matrix multiply-add operation: 2-by-2-by-1 shape with column-major A,
/// row-major B, and a row-major accumulator.
template <>
struct Mma<
  gemm::GemmShape<2, 2, 1>,
  1,
  half_t, layout::ColumnMajor,
  half_t, layout::RowMajor,
  half_t, layout::RowMajor,
  OpMultiplyAdd> {

  CUTLASS_HOST_DEVICE
  void operator()(
    Array<half_t, 4> &d,
    Array<half_t, 2> const &a,
    Array<half_t, 2> const &b,
    Array<half_t, 4> const &c
  ) {

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

    // Each element of A is broadcast into both lanes; B is a packed row of two
    // elements, so one paired FMA produces one row of D.
    __half2 Alo = __low2half2(reinterpret_cast<__half2 const &>(a));
    __half2 Ahi = __high2half2(reinterpret_cast<__half2 const &>(a));
    __half2 const & B = reinterpret_cast<__half2 const &>(b);

    __half2 const *C = reinterpret_cast<__half2 const *>(&c);

    __half2 Dlo = __hfma2(Alo, B, C[0]);   // row 0 of D
    __half2 Dhi = __hfma2(Ahi, B, C[1]);   // row 1 of D; uses C[1] to match the scalar path

    Array<half_t, 2> * D = reinterpret_cast<Array<half_t, 2> *>(&d);

    D[0] = reinterpret_cast<Array<half_t, 2> &>(Dlo);
    D[1] = reinterpret_cast<Array<half_t, 2> &>(Dhi);

#else

    // Scalar fallback: row-major accumulator, D(i, j) = d[i * 2 + j].
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 2; ++i) {
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < 2; ++j) {
        d[i * 2 + j] = a[i] * b[j] + c[i * 2 + j];
      }
    }

#endif
  }
};
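The two 2-by-2-by-1 specializations consume identical A and B fragments; only the packing of the 4-element accumulator differs (by columns of D in the previous specialization, by rows of D here). The device-side sketch below instantiates both to make the contrast concrete. It is illustrative only, with an invented function name and the same include assumptions as the earlier usage sketch.

#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/arch/mma.h"

// Runs the same 2x2x1 half-precision multiply-add with both accumulator layouts.
// Afterwards, d_col[i + 2 * j] and d_row[i * 2 + j] hold the same value D(i, j).
__device__ void mma_2x2x1_both_layouts(
    cutlass::Array<cutlass::half_t, 4> &d_col,
    cutlass::Array<cutlass::half_t, 4> &d_row,
    cutlass::Array<cutlass::half_t, 2> const &a,
    cutlass::Array<cutlass::half_t, 2> const &b,
    cutlass::Array<cutlass::half_t, 4> const &c_col,
    cutlass::Array<cutlass::half_t, 4> const &c_row) {

  using Shape = cutlass::gemm::GemmShape<2, 2, 1>;
  using Element = cutlass::half_t;
  using cutlass::layout::ColumnMajor;
  using cutlass::layout::RowMajor;

  cutlass::arch::Mma<Shape, 1, Element, ColumnMajor, Element, RowMajor,
                     Element, ColumnMajor, cutlass::arch::OpMultiplyAdd> mma_col;
  cutlass::arch::Mma<Shape, 1, Element, ColumnMajor, Element, RowMajor,
                     Element, RowMajor, cutlass::arch::OpMultiplyAdd> mma_row;

  mma_col(d_col, a, b, c_col);  // column-major accumulator
  mma_row(d_row, a, b, c_row);  // row-major accumulator
}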