36 #if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1)) 37 #define CUTLASS_ARCH_MMA_SM70_SUPPORTED 40 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) 42 #if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 &&__CUDACC_VER_MINOR__ >= 1)) 43 #define CUTLASS_ARCH_MMA_SM70_ENABLED 62 gemm::GemmShape<8,8,4>,
96 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) 98 unsigned const *A =
reinterpret_cast<unsigned const *
>(&a);
99 unsigned const *B =
reinterpret_cast<unsigned const *
>(&b);
100 unsigned const *C =
reinterpret_cast<unsigned const *
>(&c);
101 unsigned *D =
reinterpret_cast<unsigned *
>(&d);
103 asm volatile(
"mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" 104 :
"=r"(D[0]),
"=r"(D[1]),
"=r"(D[2]),
"=r"(D[3])
105 :
"r"(A[0]),
"r"(A[1]),
"r"(B[0]),
"r"(B[1]),
"r"(C[0]),
"r"(C[1]),
"r"(C[2]),
"r"(C[3])
117 gemm::GemmShape<8, 8, 4>,
151 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) 153 unsigned const *A =
reinterpret_cast<unsigned const *
>(&a);
154 unsigned const *B =
reinterpret_cast<unsigned const *
>(&b);
155 unsigned const *C =
reinterpret_cast<unsigned const *
>(&c);
156 unsigned *D =
reinterpret_cast<unsigned *
>(&d);
158 asm volatile(
"mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" 159 :
"=r"(D[0]),
"=r"(D[1]),
"=r"(D[2]),
"=r"(D[3])
160 :
"r"(A[0]),
"r"(A[1]),
"r"(B[0]),
"r"(B[1]),
"r"(C[0]),
"r"(C[1]),
"r"(C[2]),
"r"(C[3])
172 gemm::GemmShape<8, 8, 4>,
206 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) 208 unsigned const *A =
reinterpret_cast<unsigned const *
>(&a);
209 unsigned const *B =
reinterpret_cast<unsigned const *
>(&b);
210 unsigned const *C =
reinterpret_cast<unsigned const *
>(&c);
211 unsigned *D =
reinterpret_cast<unsigned *
>(&d);
213 asm volatile(
"mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" 214 :
"=r"(D[0]),
"=r"(D[1]),
"=r"(D[2]),
"=r"(D[3])
215 :
"r"(A[0]),
"r"(A[1]),
"r"(B[0]),
"r"(B[1]),
"r"(C[0]),
"r"(C[1]),
"r"(C[2]),
"r"(C[3])
227 gemm::GemmShape<8, 8, 4>,
261 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) 263 unsigned const *A =
reinterpret_cast<unsigned const *
>(&a);
264 unsigned const *B =
reinterpret_cast<unsigned const *
>(&b);
265 unsigned const *C =
reinterpret_cast<unsigned const *
>(&c);
266 unsigned *D =
reinterpret_cast<unsigned *
>(&d);
268 asm volatile(
"mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n" 269 :
"=r"(D[0]),
"=r"(D[1]),
"=r"(D[2]),
"=r"(D[3])
270 :
"r"(A[0]),
"r"(A[1]),
"r"(B[0]),
"r"(B[1]),
"r"(C[0]),
"r"(C[1]),
"r"(C[2]),
"r"(C[3])
288 gemm::GemmShape<8, 8, 4>,
323 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) 325 unsigned const *A =
reinterpret_cast<unsigned const *
>(&a);
326 unsigned const *B =
reinterpret_cast<unsigned const *
>(&b);
327 float const *C =
reinterpret_cast<float const *
>(&c);
328 float *D =
reinterpret_cast<float *
>(&d);
330 asm volatile(
"mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " 331 "{%12,%13,%14,%15,%16,%17,%18,%19};\n" 363 gemm::GemmShape<8, 8, 4>,
398 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) 400 unsigned const *A =
reinterpret_cast<unsigned const *
>(&a);
401 unsigned const *B =
reinterpret_cast<unsigned const *
>(&b);
402 float const *C =
reinterpret_cast<float const *
>(&c);
403 float *D =
reinterpret_cast<float *
>(&d);
405 asm volatile(
"mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " 406 "{%12,%13,%14,%15,%16,%17,%18,%19};\n" 438 gemm::GemmShape<8, 8, 4>,
473 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) 475 unsigned const *A =
reinterpret_cast<unsigned const *
>(&a);
476 unsigned const *B =
reinterpret_cast<unsigned const *
>(&b);
477 float const *C =
reinterpret_cast<float const *
>(&c);
478 float *D =
reinterpret_cast<float *
>(&d);
480 asm volatile(
"mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " 481 "{%12,%13,%14,%15,%16,%17,%18,%19};\n" 513 gemm::GemmShape<8, 8, 4>,
548 #if defined(CUTLASS_ARCH_MMA_SM70_ENABLED) 550 unsigned const *A =
reinterpret_cast<unsigned const *
>(&a);
551 unsigned const *B =
reinterpret_cast<unsigned const *
>(&b);
552 float const *C =
reinterpret_cast<float const *
>(&c);
553 float *D =
reinterpret_cast<float *
>(&d);
555 asm volatile(
"mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, " 556 "{%12,%13,%14,%15,%16,%17,%18,%19};\n" 596 gemm::GemmShape<16, 16, 4>,
607 gemm::GemmShape<8, 8, 4>,
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentC Array< half_t, 8 > FragmentC
Definition: mma_sm70.h:84
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentB Array< half_t, 4 > FragmentB
Definition: mma_sm70.h:245
Definition: aligned_buffer.h:35
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentC Array< float, 8 > FragmentC
Definition: mma_sm70.h:535
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentC Array< half_t, 8 > FragmentC
Definition: mma_sm70.h:194
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::ElementC float ElementC
Definition: mma_sm70.h:308
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)
Definition: mma_sm70.h:199
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::Operator OpMultiplyAdd Operator
Definition: mma_sm70.h:86
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentA Array< half_t, 4 > FragmentA
Definition: mma_sm70.h:131
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::Operator OpMultiplyAdd Operator
Definition: mma_sm70.h:312
IEEE half-precision floating-point type.
Definition: half.h:126
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentC Array< float, 8 > FragmentC
Definition: mma_sm70.h:310
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)
Multiply-add.
Definition: mma_sm70.h:391
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentB Array< half_t, 4 > FragmentB
Definition: mma_sm70.h:80
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentC Array< float, 8 > FragmentC
Definition: mma_sm70.h:385
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::ElementC float ElementC
Definition: mma_sm70.h:458
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)
Definition: mma_sm70.h:89
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::ElementC float ElementC
Definition: mma_sm70.h:383
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)
Multiply-add.
Definition: mma_sm70.h:316
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentA Array< half_t, 4 > FragmentA
Definition: mma_sm70.h:241
Templates exposing architecture support for multiply-add operations.
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentB Array< half_t, 4 > FragmentB
Definition: mma_sm70.h:531
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)
Definition: mma_sm70.h:254
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::Operator OpMultiplyAdd Operator
Definition: mma_sm70.h:251
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)
Definition: mma_sm70.h:144
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentA Array< half_t, 4 > FragmentA
Definition: mma_sm70.h:377
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentB Array< half_t, 4 > FragmentB
Definition: mma_sm70.h:306
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::ElementC float ElementC
Definition: mma_sm70.h:533
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentA Array< half_t, 4 > FragmentA
Definition: mma_sm70.h:76
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentA Array< half_t, 4 > FragmentA
Definition: mma_sm70.h:186
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentA Array< half_t, 4 > FragmentA
Definition: mma_sm70.h:302
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::Operator OpMultiplyAdd Operator
Definition: mma_sm70.h:196
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentC Array< half_t, 8 > FragmentC
Definition: mma_sm70.h:249
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentA Array< half_t, 4 > FragmentA
Definition: mma_sm70.h:527
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentC Array< float, 8 > FragmentC
Definition: mma_sm70.h:460
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::Operator OpMultiplyAdd Operator
Definition: mma_sm70.h:141
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentB Array< half_t, 4 > FragmentB
Definition: mma_sm70.h:190
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)
Multiply-add.
Definition: mma_sm70.h:541
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentC Array< half_t, 8 > FragmentC
Definition: mma_sm70.h:139
Defines layout functions used by TensorRef and derived classes.
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, half_t, layout::RowMajor, OpMultiplyAdd >::FragmentB Array< half_t, 4 > FragmentB
Definition: mma_sm70.h:135
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::Operator OpMultiplyAdd Operator
Definition: mma_sm70.h:387
Matrix multiply-add operation.
Definition: arch/mma.h:92
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::operator() CUTLASS_HOST_DEVICE void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c)
Multiply-add.
Definition: mma_sm70.h:466
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::Operator OpMultiplyAdd Operator
Definition: mma_sm70.h:537
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::Operator OpMultiplyAdd Operator
Definition: mma_sm70.h:462
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentA Array< half_t, 4 > FragmentA
Definition: mma_sm70.h:452
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::RowMajor, half_t, layout::ColumnMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentB Array< half_t, 4 > FragmentB
Definition: mma_sm70.h:456
cutlass::arch::Mma< gemm::GemmShape< 8, 8, 4 >, 8, half_t, layout::ColumnMajor, half_t, layout::RowMajor, float, layout::RowMajor, OpMultiplyAdd >::FragmentB Array< half_t, 4 > FragmentB
Definition: mma_sm70.h:381