73 template <
typename Shape>
84 "Mma_HFMA2 requires the M dimension to be divisible by 2." 121 arch::OpMultiplyAdd>;
123 Array<half_t, 2> *ptr_D =
reinterpret_cast<Array<half_t, 2> *
>(&D);
124 Array<half_t, 2>
const *ptr_A =
reinterpret_cast<Array<half_t, 2>
const *
>(&A);
125 Array<half_t, 1>
const *ptr_B =
reinterpret_cast<Array<half_t, 1>
const *
>(&B);
130 for(
auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
133 for(
auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
136 for(
auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
138 Array<half_t, 2> tmp;
139 Array<half_t, 2> *ptr_tmp = &tmp;
140 ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m];
144 ptr_A[k*Shape::kM/2 + m],
145 ptr_B[n*Shape::kK + k],
148 ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0];
159 template <
typename Shape>
170 "Mma_HFMA2 requires the N dimension to be divisible by 2." 207 arch::OpMultiplyAdd>;
209 Array<half_t, 2> *ptr_D =
reinterpret_cast<Array<half_t, 2> *
>(&D);
210 Array<half_t, 1>
const *ptr_A =
reinterpret_cast<Array<half_t, 1>
const *
>(&A);
211 Array<half_t, 2>
const *ptr_B =
reinterpret_cast<Array<half_t, 2>
const *
>(&B);
216 for(
auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
219 for(
auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
222 for(
auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
224 Array<half_t, 2> tmp;
225 Array<half_t, 2> *ptr_tmp = &tmp;
226 ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];
228 Array<half_t, 2> tmp_B;
229 tmp_B[0] = ptr_B->at(2*n*Shape::kK + k);
230 tmp_B[1] = ptr_B->at((2*n+1)*Shape::kK + k);
234 ptr_A[k*Shape::kM + m],
238 ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];
250 template <
typename Shape>
261 "Mma_HFMA2 requires the GEMM M dimension to be divisible by 2." 297 arch::OpMultiplyAdd>;
299 Array<half_t, 2> *ptr_D =
reinterpret_cast<Array<half_t, 2> *
>(&D);
300 Array<half_t, 2>
const *ptr_A =
reinterpret_cast<Array<half_t, 2>
const *
>(&A);
301 Array<half_t, 1>
const *ptr_B =
reinterpret_cast<Array<half_t, 1>
const *
>(&B);
306 for (
int k = 0; k < Shape::kK / Mma::Shape::kK; ++k) {
309 for (
int m = 0; m < Shape::kM / Mma::Shape::kM; ++m) {
312 for (
int n = 0; n < Shape::kN / Mma::Shape::kN; ++n) {
314 Array<half_t, 2> tmp;
315 Array<half_t, 2> *ptr_tmp = &tmp;
317 ptr_tmp[0] = ptr_D[m + n * Shape::kM/2];
321 ptr_A[m + k * Shape::kM/2],
322 ptr_B[k * Shape::kN + n],
325 ptr_D[m + n * Shape::kM/2] = ptr_tmp[0];
336 template <
typename Shape>
347 "Mma_HFMA2 requires the N dimension to be divisible by 2." 384 arch::OpMultiplyAdd>;
386 Array<half_t, 2> *ptr_D =
reinterpret_cast<Array<half_t, 2> *
>(&D);
387 Array<half_t, 1>
const *ptr_A =
reinterpret_cast<Array<half_t, 1>
const *
>(&A);
388 Array<half_t, 2>
const *ptr_B =
reinterpret_cast<Array<half_t, 2>
const *
>(&B);
393 for(
auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
396 for(
auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
399 for(
auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
401 Array<half_t, 2> tmp;
402 Array<half_t, 2> *ptr_tmp = &tmp;
403 ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];
407 ptr_A[k*Shape::kM + m],
408 ptr_B[k*Shape::kN/2 + n],
411 ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];
423 template <
typename Shape>
434 "Mma_HFMA2 requires the M dimension to be divisible by 2." 471 arch::OpMultiplyAdd>;
473 Array<half_t, 2> *ptr_D =
reinterpret_cast<Array<half_t, 2> *
>(&D);
474 Array<half_t, 2>
const *ptr_A =
reinterpret_cast<Array<half_t, 2>
const *
>(&A);
475 Array<half_t, 1>
const *ptr_B =
reinterpret_cast<Array<half_t, 1>
const *
>(&B);
480 for(
auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
483 for(
auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
486 for(
auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
488 Array<half_t, 2> tmp;
489 Array<half_t, 2> *ptr_tmp = &tmp;
490 ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m];
492 Array<half_t, 2> tmp_A;
493 tmp_A[0] = ptr_A->at(2*m*Shape::kK + k);
494 tmp_A[1] = ptr_A->at((2*m+1)*Shape::kK + k);
499 ptr_B[n*Shape::kK + k],
502 ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0];
513 template <
typename Shape>
524 "Mma_HFMA2 requires the N dimension to be divisible by 2." 561 arch::OpMultiplyAdd>;
563 Array<half_t, 2> *ptr_D =
reinterpret_cast<Array<half_t, 2> *
>(&D);
564 Array<half_t, 1>
const *ptr_A =
reinterpret_cast<Array<half_t, 1>
const *
>(&A);
565 Array<half_t, 2>
const *ptr_B =
reinterpret_cast<Array<half_t, 2>
const *
>(&B);
570 for(
auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
573 for(
auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
576 for(
auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
578 Array<half_t, 2> tmp;
579 Array<half_t, 2> *ptr_tmp = &tmp;
580 ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];
582 Array<half_t, 2> tmp_B;
583 tmp_B[0] = ptr_B->at(2*n*Shape::kK + k);
584 tmp_B[1] = ptr_B->at((2*n+1)*Shape::kK + k);
588 ptr_A[m*Shape::kK + k],
592 ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];
603 template <
typename Shape>
614 "Mma_HFMA2 requires the M dimension to be divisible by 2." 651 arch::OpMultiplyAdd>;
653 Array<half_t, 2> *ptr_D =
reinterpret_cast<Array<half_t, 2> *
>(&D);
654 Array<half_t, 2>
const *ptr_A =
reinterpret_cast<Array<half_t, 2>
const *
>(&A);
655 Array<half_t, 1>
const *ptr_B =
reinterpret_cast<Array<half_t, 1>
const *
>(&B);
660 for(
auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
663 for(
auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
666 for(
auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
668 Array<half_t, 2> tmp;
669 Array<half_t, 2> *ptr_tmp = &tmp;
670 ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m];
672 Array<half_t, 2> tmp_A;
673 tmp_A[0] = ptr_A->at(2*m*Shape::kK + k);
674 tmp_A[1] = ptr_A->at((2*m+1)*Shape::kK + k);
679 ptr_B[k*Shape::kN + n],
682 ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0];
694 template <
typename Shape>
705 "Mma_HFMA2 requires the N dimension to be divisible by 2." 742 arch::OpMultiplyAdd>;
744 Array<half_t, 2> *ptr_D =
reinterpret_cast<Array<half_t, 2> *
>(&D);
745 Array<half_t, 1>
const *ptr_A =
reinterpret_cast<Array<half_t, 1>
const *
>(&A);
746 Array<half_t, 2>
const *ptr_B =
reinterpret_cast<Array<half_t, 2>
const *
>(&B);
751 for(
auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
754 for(
auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
757 for(
auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
759 Array<half_t, 2> tmp;
760 Array<half_t, 2> *ptr_tmp = &tmp;
761 ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];
765 ptr_A[m*Shape::kK + k],
766 ptr_B[k*Shape::kN/2 + n],
769 ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];
780 template <
typename Shape,
typename LayoutA,
typename LayoutB>
791 "Mma_HFMA2 requires the K dimension to be divisible by 2." 821 Array<half_t, 1> *ptr_D =
reinterpret_cast<Array<half_t, 1> *
>(&D);
822 Array<half_t, 2>
const *ptr_A =
reinterpret_cast<Array<half_t, 2>
const *
>(&A);
823 Array<half_t, 2>
const *ptr_B =
reinterpret_cast<Array<half_t, 2>
const *
>(&B);
835 Array<half_t, 2> tmp_C;
837 Array<half_t, 1> *ptr_tmp_C =
reinterpret_cast<Array<half_t, 1> *
>(&tmp_C);
838 ptr_tmp_C[0] = ptr_D[n*Shape::kM + m];
842 tmp_C =
mac(ptr_A[m*Shape::kK/2 + k], ptr_B[n*Shape::kK/2 + k], tmp_C);
845 Array<half_t, 1> res;
846 Array<half_t, 1> *ptr_res = &res;
849 ptr_D[m*Shape::kN + n] = ptr_res[0];
859 template <
typename Shape,
typename LayoutA,
typename LayoutB>
870 "Mma_HFMA2 requires the K dimension to be divisible by 2." 900 Array<half_t, 1> *ptr_D =
reinterpret_cast<Array<half_t, 1> *
>(&D);
901 Array<half_t, 2>
const *ptr_A =
reinterpret_cast<Array<half_t, 2>
const *
>(&A);
902 Array<half_t, 2>
const *ptr_B =
reinterpret_cast<Array<half_t, 2>
const *
>(&B);
914 Array<half_t, 2> tmp_C;
916 Array<half_t, 1> *ptr_tmp_C =
reinterpret_cast<Array<half_t, 1> *
>(&tmp_C);
917 ptr_tmp_C[0] = ptr_D[n*Shape::kM + m];
922 tmp_C =
mac(ptr_A[m*Shape::kK/2 + k], ptr_B[n*Shape::kK/2 + k], tmp_C);
926 Array<half_t, 1> res;
927 Array<half_t, 1> *ptr_res = &res;
930 ptr_D[n*Shape::kM + m] = ptr_res[0];
943 typename Shape_,
typename LayoutA,
typename LayoutB,
typename LayoutC
997 constexpr bool m_mod2 = !(Shape::kM % 2);
998 constexpr bool n_mod2 = !(Shape::kN % 2);
999 constexpr bool k_mod2 = !(Shape::kK % 2);
1007 constexpr bool use_outer_prod = (c_column_major && m_mod2) || (c_row_major && n_mod2);
1008 constexpr bool use_inner_prod = (a_row_major && b_column_major && k_mod2) || (Shape::kM==1 && Shape::kN==1 && k_mod2);
1009 constexpr bool use_optimized = (use_outer_prod || use_inner_prod);
1032 static bool const kIsConventionalLayout =
1038 static bool const value = kIsConventionalLayout;
1059 arch::OpMultiplyAdd,
1067 using LayoutA = LayoutA_;
1069 using LayoutB = LayoutB_;
1082 arch::OpMultiplyAdd,
Fused multiply-add.
Definition: functional.h:92
Determines whether to enable thread::Gemm<> specializations compatible with SM50. ...
Definition: gemm/thread/mma_sm60.h:1030
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:801
static int const kM
Definition: include/cutlass/gemm/gemm.h:58
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:271
Definition: aligned_buffer.h:35
Defines a structure containing strides, bounds, and a pointer to tensor data.
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:94
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:528
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:809
Structure to compute the matrix product for HFMA.
Definition: gemm/thread/mma_sm60.h:66
Array< ElementC, Shape::kMN > FragmentC
Definition: gemm/thread/mma_sm60.h:1087
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:441
IEEE half-precision floating-point type.
Definition: half.h:126
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Definition: gemm/thread/mma_sm60.h:1090
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:444
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:438
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:357
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:102
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:723
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:632
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:174
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:712
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:624
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
arch::OpMultiplyAdd Operator
Definition: gemm/thread/mma_sm60.h:1072
static int const kK
Definition: include/cutlass/gemm/gemm.h:60
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:177
Array< ElementB, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:975
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:91
Array< ElementB, Shape::kKN > FragmentB
Definition: gemm/thread/mma_sm60.h:1086
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:365
arch::OpMultiplyAdd Operator
Underlying mathematical operator.
Definition: gemm/thread/mma_sm60.h:969
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:709
Defines transposes of matrix layouts.
Definition: layout/matrix.h:921
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:531
Gemplate that handles all packed matrix layouts.
Definition: gemm/thread/mma_sm50.h:65
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:888
Defines basic thread level reduction with specializations for Array<T, N>.
Array< ElementC, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:978
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:188
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Templates exposing architecture support for warp-level multiply-add operations.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:265
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:88
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:880
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:986
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:795
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:351
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:452
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:621
Shape_ Shape
Size of the Gemm problem - concept: gemm::GemmShape<>
Definition: gemm/thread/mma_sm60.h:957
Structure to compute the matrix product.
Definition: gemm/thread/mma.h:66
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:877
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:715
Defines layout functions used by TensorRef and derived classes.
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:534
Array< half_t, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm60.h:180
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:798
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:618
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:268
Matrix multiply-add operation.
Definition: arch/mma.h:92
Array< ElementA, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:972
Array< half_t, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm60.h:354
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:542
Basic include for CUTLASS.
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm60.h:279
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
Array< ElementA, Shape::kMK > FragmentA
Definition: gemm/thread/mma_sm60.h:1085
Structure to compute the thread level reduction.
Definition: reduce.h:43
CUTLASS_HOST_DEVICE Array< T, N > mac(Array< T, N > const &a, Array< T, N > const &b, Array< T, N > const &c)
Definition: simd.h:84
Shape_ Shape
Definition: gemm/thread/mma_sm60.h:1065
Array< half_t, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm60.h:874
static int const kN
Definition: include/cutlass/gemm/gemm.h:59