CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm/thread/mma_sm60.h
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/cutlass.h"
32 #include "cutlass/tensor_ref.h"
33 #include "cutlass/layout/matrix.h"
34 #include "cutlass/gemm/gemm.h"
35 #include "cutlass/arch/mma.h"
36 #include "cutlass/functional.h"
37 #include "cutlass/reduction/thread/reduce.h"
38 
40 
41 namespace cutlass {
42 namespace gemm {
43 namespace thread {
44 
46 
47 namespace detail {
48 
49 /// Structure to compute the matrix product for HFMA
50 template <
52  typename Shape,
53 
55  typename LayoutA,
56 
58  typename LayoutB,
59 
61  typename LayoutC,
62 
64  bool
65 >
66 struct Mma_HFMA2;
67 
68 
70 // Specialization for NNN //
72 
73 template <typename Shape>
74 struct Mma_HFMA2 <
75  Shape,
76  layout::ColumnMajor,
77  layout::ColumnMajor,
78  layout::ColumnMajor,
79  true
80  > {
81 
82  static_assert(
83  !(Shape::kM % 2),
84  "Mma_HFMA2 requires the M dimension to be divisible by 2."
85  );
86 
88  using FragmentA = Array<half_t, Shape::kMK>;
89 
91  using FragmentB = Array<half_t, Shape::kKN>;
92 
94  using FragmentC = Array<half_t, Shape::kMN>;
95 
96  //
97  // Methods
98  //
99 
100  /// Computes a matrix product D = A * B + C
101  CUTLASS_HOST_DEVICE
102  void operator()(
103  FragmentC & D,
104  FragmentA const & A,
105  FragmentB const & B,
106  FragmentC const & C) {
107 
109  D = C;
110 
112  using Mma = arch::Mma<
113  gemm::GemmShape<2, 1, 1>,
114  1,
115  half_t,
116  layout::ColumnMajor,
117  half_t,
118  layout::ColumnMajor,
119  half_t,
120  layout::ColumnMajor,
121  arch::OpMultiplyAdd>;
122 
123  Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);
124  Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);
125  Array<half_t, 1> const *ptr_B = reinterpret_cast<Array<half_t, 1> const *>(&B);
126 
127  Mma mma;
128 
130  for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
131 
133  for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
134 
136  for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
137 
138  Array<half_t, 2> tmp;
139  Array<half_t, 2> *ptr_tmp = &tmp;
140  ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m];
141 
142  mma(
143  tmp,
144  ptr_A[k*Shape::kM/2 + m],
145  ptr_B[n*Shape::kK + k],
146  tmp);
147 
148  ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0];
149  }
150  }
151  }
152  }
153 };
154 
156 // Specialization for NNT //
158 
159 template <typename Shape>
160 struct Mma_HFMA2<
161  Shape,
162  layout::ColumnMajor,
163  layout::ColumnMajor,
164  layout::RowMajor,
165  true
166  > {
167 
168  static_assert(
169  !(Shape::kN % 2),
170  "Mma_HFMA2 requires the N dimension to be divisible by 2."
171  );
172 
174  using FragmentA = Array<half_t, Shape::kMK>;
175 
177  using FragmentB = Array<half_t, Shape::kKN>;
178 
180  using FragmentC = Array<half_t, Shape::kMN>;
181 
182  //
183  // Methods
184  //
185 
186  /// Computes a matrix product D = A * B + C
187  CUTLASS_HOST_DEVICE
188  void operator()(
189  FragmentC & D,
190  FragmentA const & A,
191  FragmentB const & B,
192  FragmentC const & C) {
193 
195  D = C;
196 
198  using Mma = arch::Mma<
199  gemm::GemmShape<1, 2, 1>,
200  1,
201  half_t,
202  layout::ColumnMajor,
203  half_t,
204  layout::ColumnMajor,
205  half_t,
206  layout::RowMajor,
207  arch::OpMultiplyAdd>;
208 
209  Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);
210  Array<half_t, 1> const *ptr_A = reinterpret_cast<Array<half_t, 1> const *>(&A);
211  Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);
212 
213  Mma mma;
214 
216  for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
217 
219  for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
220 
222  for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
223 
224  Array<half_t, 2> tmp;
225  Array<half_t, 2> *ptr_tmp = &tmp;
226  ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];
227 
228  Array<half_t, 2> tmp_B;
229  tmp_B[0] = ptr_B->at(2*n*Shape::kK + k);
230  tmp_B[1] = ptr_B->at((2*n+1)*Shape::kK + k);
231 
232  mma(
233  tmp,
234  ptr_A[k*Shape::kM + m],
235  tmp_B,
236  tmp);
237 
238  ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];
239  }
240  }
241  }
242  }
243 };
244 
245 
247 // Specialization for NTN //
249 
250 template <typename Shape>
251 struct Mma_HFMA2 <
252  Shape,
253  layout::ColumnMajor,
254  layout::RowMajor,
255  layout::ColumnMajor,
256  true
257  > {
258 
259  static_assert(
260  !(Shape::kM % 2),
261  "Mma_HFMA2 requires the GEMM M dimension to be divisible by 2."
262  );
263 
265  using FragmentA = Array<half_t, Shape::kMK>;
266 
268  using FragmentB = Array<half_t, Shape::kKN>;
269 
271  using FragmentC = Array<half_t, Shape::kMN>;
272 
273  //
274  // Methods
275  //
276 
277  /// Computes a matrix product D = A * B + C
278  CUTLASS_HOST_DEVICE
279  void operator()(
280  FragmentC & D,
281  FragmentA const & A,
282  FragmentB const & B,
283  FragmentC const & C) {
284 
286  D = C;
287 
288  using Mma = arch::Mma<
289  gemm::GemmShape<2, 1, 1>,
290  1,
291  half_t,
292  layout::ColumnMajor,
293  half_t,
294  layout::RowMajor,
295  half_t,
296  layout::ColumnMajor,
297  arch::OpMultiplyAdd>;
298 
299  Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);
300  Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);
301  Array<half_t, 1> const *ptr_B = reinterpret_cast<Array<half_t, 1> const *>(&B);
302 
303  Mma mma;
304 
306  for (int k = 0; k < Shape::kK / Mma::Shape::kK; ++k) {
307 
309  for (int m = 0; m < Shape::kM / Mma::Shape::kM; ++m) {
310 
312  for (int n = 0; n < Shape::kN / Mma::Shape::kN; ++n) {
313 
314  Array<half_t, 2> tmp;
315  Array<half_t, 2> *ptr_tmp = &tmp;
316 
317  ptr_tmp[0] = ptr_D[m + n * Shape::kM/2];
318 
319  mma(
320  tmp,
321  ptr_A[m + k * Shape::kM/2],
322  ptr_B[k * Shape::kN + n],
323  tmp);
324 
325  ptr_D[m + n * Shape::kM/2] = ptr_tmp[0];
326  }
327  }
328  }
329  }
330 };
331 
333 // Specialization for NTT //
335 
336 template <typename Shape>
337 struct Mma_HFMA2<
338  Shape,
339  layout::ColumnMajor,
340  layout::RowMajor,
341  layout::RowMajor,
342  true
343  > {
344 
345  static_assert(
346  !(Shape::kN % 2),
347  "Mma_HFMA2 requires the N dimension to be divisible by 2."
348  );
349 
351  using FragmentA = Array<half_t, Shape::kMK>;
352 
354  using FragmentB = Array<half_t, Shape::kKN>;
355 
357  using FragmentC = Array<half_t, Shape::kMN>;
358 
359  //
360  // Methods
361  //
362 
363  /// Computes a matrix product D = A * B + C
364  CUTLASS_HOST_DEVICE
365  void operator()(
366  FragmentC & D,
367  FragmentA const & A,
368  FragmentB const & B,
369  FragmentC const & C) {
370 
372  D = C;
373 
375  using Mma = arch::Mma<
376  gemm::GemmShape<1, 2, 1>,
377  1,
378  half_t,
379  layout::ColumnMajor,
380  half_t,
381  layout::RowMajor,
382  half_t,
383  layout::RowMajor,
384  arch::OpMultiplyAdd>;
385 
386  Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);
387  Array<half_t, 1> const *ptr_A = reinterpret_cast<Array<half_t, 1> const *>(&A);
388  Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);
389 
390  Mma mma;
391 
393  for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
394 
396  for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
397 
399  for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
400 
401  Array<half_t, 2> tmp;
402  Array<half_t, 2> *ptr_tmp = &tmp;
403  ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];
404 
405  mma(
406  tmp,
407  ptr_A[k*Shape::kM + m],
408  ptr_B[k*Shape::kN/2 + n],
409  tmp);
410 
411  ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];
412  }
413  }
414  }
415  }
416 };
417 
418 
420 // Specialization for TNN //
422 
423 template <typename Shape>
424 struct Mma_HFMA2 <
425  Shape,
426  layout::RowMajor,
427  layout::ColumnMajor,
428  layout::ColumnMajor,
429  true
430  > {
431 
432  static_assert(
433  !(Shape::kM % 2),
434  "Mma_HFMA2 requires the M dimension to be divisible by 2."
435  );
436 
438  using FragmentA = Array<half_t, Shape::kMK>;
439 
441  using FragmentB = Array<half_t, Shape::kKN>;
442 
444  using FragmentC = Array<half_t, Shape::kMN>;
445 
446  //
447  // Methods
448  //
449 
450  /// Computes a matrix product D = A * B + C
451  CUTLASS_HOST_DEVICE
452  void operator()(
453  FragmentC & D,
454  FragmentA const & A,
455  FragmentB const & B,
456  FragmentC const & C) {
457 
459  D = C;
460 
462  using Mma = arch::Mma<
463  gemm::GemmShape<2, 1, 1>,
464  1,
465  half_t,
466  layout::RowMajor,
467  half_t,
468  layout::ColumnMajor,
469  half_t,
470  layout::ColumnMajor,
471  arch::OpMultiplyAdd>;
472 
473  Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);
474  Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);
475  Array<half_t, 1> const *ptr_B = reinterpret_cast<Array<half_t, 1> const *>(&B);
476 
477  Mma mma;
478 
480  for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
481 
483  for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
484 
486  for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
487 
488  Array<half_t, 2> tmp;
489  Array<half_t, 2> *ptr_tmp = &tmp;
490  ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m];
491 
492  Array<half_t, 2> tmp_A;
493  tmp_A[0] = ptr_A->at(2*m*Shape::kK + k);
494  tmp_A[1] = ptr_A->at((2*m+1)*Shape::kK + k);
495 
496  mma(
497  tmp,
498  tmp_A,
499  ptr_B[n*Shape::kK + k],
500  tmp);
501 
502  ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0];
503  }
504  }
505  }
506  }
507 };
508 
510 // Specialization for TNT //
512 
513 template <typename Shape>
514 struct Mma_HFMA2 <
515  Shape,
516  layout::RowMajor,
517  layout::ColumnMajor,
518  layout::RowMajor,
519  true
520  > {
521 
522  static_assert(
523  !(Shape::kN % 2),
524  "Mma_HFMA2 requires the N dimension to be divisible by 2."
525  );
526 
528  using FragmentA = Array<half_t, Shape::kMK>;
529 
531  using FragmentB = Array<half_t, Shape::kKN>;
532 
534  using FragmentC = Array<half_t, Shape::kMN>;
535 
536  //
537  // Methods
538  //
539 
540  /// Computes a matrix product D = A * B + C
541  CUTLASS_HOST_DEVICE
542  void operator()(
543  FragmentC & D,
544  FragmentA const & A,
545  FragmentB const & B,
546  FragmentC const & C) {
547 
549  D = C;
550 
552  using Mma = arch::Mma<
553  gemm::GemmShape<1, 2, 1>,
554  1,
555  half_t,
556  layout::RowMajor,
557  half_t,
558  layout::ColumnMajor,
559  half_t,
560  layout::RowMajor,
561  arch::OpMultiplyAdd>;
562 
563  Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);
564  Array<half_t, 1> const *ptr_A = reinterpret_cast<Array<half_t, 1> const *>(&A);
565  Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);
566 
567  Mma mma;
568 
570  for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
571 
573  for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
574 
576  for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
577 
578  Array<half_t, 2> tmp;
579  Array<half_t, 2> *ptr_tmp = &tmp;
580  ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];
581 
582  Array<half_t, 2> tmp_B;
583  tmp_B[0] = ptr_B->at(2*n*Shape::kK + k);
584  tmp_B[1] = ptr_B->at((2*n+1)*Shape::kK + k);
585 
586  mma(
587  tmp,
588  ptr_A[m*Shape::kK + k],
589  tmp_B,
590  tmp);
591 
592  ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];
593  }
594  }
595  }
596  }
597 };
598 
600 // Specialization for TTN //
602 
603 template <typename Shape>
604 struct Mma_HFMA2 <
605  Shape,
606  layout::RowMajor,
607  layout::RowMajor,
608  layout::ColumnMajor,
609  true
610  > {
611 
612  static_assert(
613  !(Shape::kM % 2),
614  "Mma_HFMA2 requires the M dimension to be divisible by 2."
615  );
616 
618  using FragmentA = Array<half_t, Shape::kMK>;
619 
621  using FragmentB = Array<half_t, Shape::kKN>;
622 
624  using FragmentC = Array<half_t, Shape::kMN>;
625 
626  //
627  // Methods
628  //
629 
630  /// Computes a matrix product D = A * B + C
631  CUTLASS_HOST_DEVICE
632  void operator()(
633  FragmentC & D,
634  FragmentA const & A,
635  FragmentB const & B,
636  FragmentC const & C) {
637 
639  D = C;
640 
642  using Mma = arch::Mma<
643  gemm::GemmShape<2, 1, 1>,
644  1,
645  half_t,
646  layout::RowMajor,
647  half_t,
648  layout::RowMajor,
649  half_t,
650  layout::ColumnMajor,
651  arch::OpMultiplyAdd>;
652 
653  Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);
654  Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);
655  Array<half_t, 1> const *ptr_B = reinterpret_cast<Array<half_t, 1> const *>(&B);
656 
657  Mma mma;
658 
660  for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
661 
663  for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
664 
666  for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
667 
668  Array<half_t, 2> tmp;
669  Array<half_t, 2> *ptr_tmp = &tmp;
670  ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m];
671 
672  Array<half_t, 2> tmp_A;
673  tmp_A[0] = ptr_A->at(2*m*Shape::kK + k);
674  tmp_A[1] = ptr_A->at((2*m+1)*Shape::kK + k);
675 
676  mma(
677  tmp,
678  tmp_A,
679  ptr_B[k*Shape::kN + n],
680  tmp);
681 
682  ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0];
683  }
684  }
685  }
686  }
687 };
688 
689 
691 // Specialization for TTT //
693 
694 template <typename Shape>
695 struct Mma_HFMA2<
696  Shape,
697  layout::RowMajor,
698  layout::RowMajor,
699  layout::RowMajor,
700  true
701  > {
702 
703  static_assert(
704  !(Shape::kN % 2),
705  "Mma_HFMA2 requires the N dimension to be divisible by 2."
706  );
707 
709  using FragmentA = Array<half_t, Shape::kMK>;
710 
712  using FragmentB = Array<half_t, Shape::kKN>;
713 
715  using FragmentC = Array<half_t, Shape::kMN>;
716 
717  //
718  // Methods
719  //
720 
721  /// Computes a matrix product D = A * B + C
722  CUTLASS_HOST_DEVICE
723  void operator()(
724  FragmentC & D,
725  FragmentA const & A,
726  FragmentB const & B,
727  FragmentC const & C) {
728 
730  D = C;
731 
733  using Mma = arch::Mma<
734  gemm::GemmShape<1, 2, 1>,
735  1,
736  half_t,
737  layout::RowMajor,
738  half_t,
739  layout::RowMajor,
740  half_t,
741  layout::RowMajor,
742  arch::OpMultiplyAdd>;
743 
744  Array<half_t, 2> *ptr_D = reinterpret_cast<Array<half_t, 2> *>(&D);
745  Array<half_t, 1> const *ptr_A = reinterpret_cast<Array<half_t, 1> const *>(&A);
746  Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);
747 
748  Mma mma;
749 
751  for(auto k=0; k < Shape::kK / Mma::Shape::kK; k++){
752 
754  for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){
755 
757  for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){
758 
759  Array<half_t, 2> tmp;
760  Array<half_t, 2> *ptr_tmp = &tmp;
761  ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n];
762 
763  mma(
764  tmp,
765  ptr_A[m*Shape::kK + k],
766  ptr_B[k*Shape::kN/2 + n],
767  tmp);
768 
769  ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0];
770  }
771  }
772  }
773  }
774 };
775 
777 // Specialization for TNT + Inner Product or 1x1x2K + LayoutC = T //
779 
780 template <typename Shape, typename LayoutA, typename LayoutB>
781 struct Mma_HFMA2<
782  Shape,
783  LayoutA,
784  LayoutB,
785  layout::RowMajor,
786  false
787  > {
788 
789  static_assert(
790  !(Shape::kK % 2),
791  "Mma_HFMA2 requires the K dimension to be divisible by 2."
792  );
793 
795  using FragmentA = Array<half_t, Shape::kMK>;
796 
798  using FragmentB = Array<half_t, Shape::kKN>;
799 
801  using FragmentC = Array<half_t, Shape::kMN>;
802 
803  //
804  // Methods
805  //
806 
807  /// Computes a matrix product D = A * B + C
808  CUTLASS_HOST_DEVICE
809  void operator()(
810  FragmentC & D,
811  FragmentA const & A,
812  FragmentB const & B,
813  FragmentC const & C) {
814 
816  D = C;
817 
818  // Use a 1x1x2 inner product step
819  using GemmShape = gemm::GemmShape<1, 1, 2>;
820 
821  Array<half_t, 1> *ptr_D = reinterpret_cast<Array<half_t, 1> *>(&D);
822  Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);
823  Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);
824 
825  // Inner product is calculated using MACs, followed by final reduction
826  multiply_add<Array<half_t, 2>> mac;
827  cutlass::reduction::thread::Reduce< plus<half_t>, Array<half_t, 2> > reduce;
828 
830  for(auto n=0; n < Shape::kN / GemmShape::kN; n++){
831 
833  for(auto m=0; m < Shape::kM / GemmShape::kM; m++){
834 
835  Array<half_t, 2> tmp_C;
836  tmp_C.clear();
837  Array<half_t, 1> *ptr_tmp_C = reinterpret_cast<Array<half_t, 1> *>(&tmp_C);
838  ptr_tmp_C[0] = ptr_D[n*Shape::kM + m];
839 
841  for(auto k=0; k < Shape::kK / GemmShape::kK; k++){
842  tmp_C = mac(ptr_A[m*Shape::kK/2 + k], ptr_B[n*Shape::kK/2 + k], tmp_C);
843  }
844 
845  Array<half_t, 1> res;
846  Array<half_t, 1> *ptr_res = &res;
847  res = reduce(tmp_C);
848 
849  ptr_D[m*Shape::kN + n] = ptr_res[0];
850  }
851  }
852  }
853 };
854 
856 // Specialization for TNN + Inner Product or 1x1x2K + LayoutC = N //
858 
859 template <typename Shape, typename LayoutA, typename LayoutB>
860 struct Mma_HFMA2<
861  Shape,
862  LayoutA,
863  LayoutB,
864  layout::ColumnMajor,
865  false
866  > {
867 
868  static_assert(
869  !(Shape::kK % 2),
870  "Mma_HFMA2 requires the K dimension to be divisible by 2."
871  );
872 
874  using FragmentA = Array<half_t, Shape::kMK>;
875 
877  using FragmentB = Array<half_t, Shape::kKN>;
878 
880  using FragmentC = Array<half_t, Shape::kMN>;
881 
882  //
883  // Methods
884  //
885 
886  /// Computes a matrix product D = A * B + C
887  CUTLASS_HOST_DEVICE
888  void operator()(
889  FragmentC & D,
890  FragmentA const & A,
891  FragmentB const & B,
892  FragmentC const & C) {
893 
895  D = C;
896 
897  // Use a 1x1x2 inner product step
898  using GemmShape = gemm::GemmShape<1, 1, 2>;
899 
900  Array<half_t, 1> *ptr_D = reinterpret_cast<Array<half_t, 1> *>(&D);
901  Array<half_t, 2> const *ptr_A = reinterpret_cast<Array<half_t, 2> const *>(&A);
902  Array<half_t, 2> const *ptr_B = reinterpret_cast<Array<half_t, 2> const *>(&B);
903 
904  // Inner product is calculated using MACs, followed by final reduction
905  multiply_add<Array<half_t, 2>> mac;
906  cutlass::reduction::thread::Reduce< plus<half_t>, Array<half_t, 2> > reduce;
907 
909  for(auto n=0; n < Shape::kN / GemmShape::kN; n++){
910 
912  for(auto m=0; m < Shape::kM / GemmShape::kM; m++){
913 
914  Array<half_t, 2> tmp_C;
915  tmp_C.clear();
916  Array<half_t, 1> *ptr_tmp_C = reinterpret_cast<Array<half_t, 1> *>(&tmp_C);
917  ptr_tmp_C[0] = ptr_D[n*Shape::kM + m];
918 
920  for(auto k=0; k < Shape::kK / GemmShape::kK; k++){
921 
922  tmp_C = mac(ptr_A[m*Shape::kK/2 + k], ptr_B[n*Shape::kK/2 + k], tmp_C);
923 
924  }
925 
926  Array<half_t, 1> res;
927  Array<half_t, 1> *ptr_res = &res;
928  res = reduce(tmp_C);
929 
930  ptr_D[n*Shape::kM + m] = ptr_res[0];
931  }
932  }
933  }
934 };
935 
936 } // namespace detail
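The eight specializations above all reduce the thread-level GEMM to a sequence of two-wide half-precision FMAs, which is why the dimension in which C is contiguous (M for column-major C, N for row-major C) must be divisible by 2. The following standalone CUDA sketch is not part of CUTLASS (the kernel name and storage choices are illustrative only); it expresses the same idea directly with the __hfma2 intrinsic: one instruction updates a 2x1 column of the accumulator per K step, mirroring the arch::Mma<gemm::GemmShape<2, 1, 1>, ...> tile used by the column-major-C specializations.

#include <cuda_fp16.h>

// Each thread owns one pair of adjacent rows (m2) and one column (n) of D.
// A and D are column-major and stored as __half2 row pairs; B is column-major __half.
__global__ void hfma2_outer_product_sketch(
    __half2 *D, __half2 const *A, __half const *B,
    int half_M, int N, int K) {

  int m2 = blockIdx.x * blockDim.x + threadIdx.x;
  int n  = blockIdx.y * blockDim.y + threadIdx.y;
  if (m2 >= half_M || n >= N) return;

  __half2 acc = D[n * half_M + m2];
  for (int k = 0; k < K; ++k) {
    __half2 b = __half2half2(B[n * K + k]);      // broadcast B(k, n) into both lanes
    acc = __hfma2(A[k * half_M + m2], b, acc);   // two rows of D updated by one FMA
  }
  D[n * half_M + m2] = acc;
}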
937 
939 
941 template <
943  typename Shape_, typename LayoutA, typename LayoutB, typename LayoutC
944 >
945 struct Mma<
946  Shape_,
947  half_t,
948  LayoutA,
949  half_t,
950  LayoutB,
951  half_t,
952  LayoutC,
953  arch::OpMultiplyAdd
954  > {
955 
957  using Shape = Shape_;
958 
960  using ElementA = half_t;
961 
963  using ElementB = half_t;
964 
966  using ElementC = half_t;
967 
969  using Operator = arch::OpMultiplyAdd;
970 
972  using FragmentA = Array<ElementA, Shape::kMK>;
973 
975  using FragmentB = Array<ElementB, Shape::kKN>;
976 
978  using FragmentC = Array<ElementC, Shape::kMN>;
979 
980  //
981  // Methods
982  //
983 
984  /// Computes a matrix product D = A * B + C
985  CUTLASS_HOST_DEVICE
986  void operator()(
987  FragmentC & D,
988  FragmentA const & A,
989  FragmentB const & B,
990  FragmentC const & C) {
991 
991 
992  constexpr bool a_row_major = platform::is_same< LayoutA, layout::RowMajor>::value;
993  constexpr bool b_column_major = platform::is_same< LayoutB, layout::ColumnMajor>::value;
994  constexpr bool c_row_major = platform::is_same< LayoutC, layout::RowMajor>::value;
995  constexpr bool c_column_major = platform::is_same< LayoutC, layout::ColumnMajor>::value;
996 
997  constexpr bool m_mod2 = !(Shape::kM % 2);
998  constexpr bool n_mod2 = !(Shape::kN % 2);
999  constexpr bool k_mod2 = !(Shape::kK % 2);
1000 
1001  // HFMA-based MMA optimizations come in two forms:
1002  // 1. Inner product
1003  // 2. Outer product
1004  // The outer-product form is selected from LayoutC (plus an even M or N);
1005  // the inner-product form from LayoutA/LayoutB or a 1x1x2K shape (plus an even K).
1006  // If neither applies, the generic MMA is used.
1007  constexpr bool use_outer_prod = (c_column_major && m_mod2) || (c_row_major && n_mod2);
1008  constexpr bool use_inner_prod = (a_row_major && b_column_major && k_mod2) || (Shape::kM==1 && Shape::kN==1 && k_mod2);
1009  constexpr bool use_optimized = (use_outer_prod || use_inner_prod);
1010 
1011  typename platform::conditional< use_optimized,
1012  detail::Mma_HFMA2<Shape, LayoutA, LayoutB, LayoutC, use_outer_prod>,
1013  MmaGeneric<Shape, half_t, LayoutA, half_t, LayoutB, half_t, LayoutC, arch::OpMultiplyAdd>
1014  >::type mma;
1015 
1016  mma(D, A, B, C);
1017 
1018  }
1019 };
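The selection above never inspects data; it is decided entirely by the layouts and by the parity of Shape. The short standalone sketch below is plain C++ outside CUTLASS (the SelectsHfma2 helper and its boolean parameters are hypothetical); it re-derives the same predicates for three sample configurations to show which path each one takes.

// Mirrors use_outer_prod / use_inner_prod from the operator() above.
template <int M, int N, int K,
          bool CColumnMajor, bool CRowMajor,
          bool ARowMajor, bool BColumnMajor>
struct SelectsHfma2 {
  static constexpr bool m_mod2 = (M % 2) == 0;
  static constexpr bool n_mod2 = (N % 2) == 0;
  static constexpr bool k_mod2 = (K % 2) == 0;
  static constexpr bool use_outer_prod = (CColumnMajor && m_mod2) || (CRowMajor && n_mod2);
  static constexpr bool use_inner_prod =
      (ARowMajor && BColumnMajor && k_mod2) || (M == 1 && N == 1 && k_mod2);
  static constexpr bool value = use_outer_prod || use_inner_prod;
};

// 8x8x4 with column-major C and even M: outer-product Mma_HFMA2 path.
static_assert(SelectsHfma2<8, 8, 4, true, false, false, false>::use_outer_prod, "outer");
// 1x1x8 dot product with row-major A and column-major B: inner-product Mma_HFMA2 path.
static_assert(SelectsHfma2<1, 1, 8, false, true, true, true>::use_inner_prod, "inner");
// 3x5x7 with column-major C: neither predicate holds, so MmaGeneric is selected.
static_assert(!SelectsHfma2<3, 5, 7, true, false, false, false>::value, "generic");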
1020 
1022 
1023 namespace detail {
1024 
1025  /// Determines whether to enable thread::Gemm<> specializations compatible with SM50
1026  template <
1027  typename LayoutA,
1029  typename LayoutB>
1030  struct EnableMma_Crow_SM60 {
1031 
1032  static bool const kIsConventionalLayout =
1033  (platform::is_same<LayoutA, layout::RowMajor>::value ||
1034  platform::is_same<LayoutA, layout::ColumnMajor>::value) &&
1035  (platform::is_same<LayoutB, layout::RowMajor>::value ||
1036  platform::is_same<LayoutB, layout::ColumnMajor>::value);
1037 
1038  static bool const value = kIsConventionalLayout;
1039  };
1040 };
1041 
1043 
1045 template <
1047  typename Shape_,
1048  typename LayoutA_,
1049  typename LayoutB_
1050 >
1051 struct Mma<
1052  Shape_,
1053  half_t,
1054  LayoutA_,
1055  half_t,
1056  LayoutB_,
1057  half_t,
1058  layout::RowMajor,
1059  arch::OpMultiplyAdd,
1060  typename platform::enable_if<detail::EnableMma_Crow_SM60<
1061  LayoutA_,
1062  LayoutB_
1063  >::value>::type>{
1064 
1065  using Shape = Shape_;
1066  using ElementA = half_t;
1067  using LayoutA = LayoutA_;
1068  using ElementB = half_t;
1069  using LayoutB = LayoutB_;
1070  using ElementC = half_t;
1071  using LayoutC = layout::RowMajor;
1072  using Operator = arch::OpMultiplyAdd;
1073 
1074  using TransposeMma = Mma<
1075  GemmShape<Shape::kN, Shape::kM, Shape::kK>,
1076  half_t,
1077  typename layout::LayoutTranspose<LayoutB>::type,
1078  half_t,
1079  typename layout::LayoutTranspose<LayoutA>::type,
1080  half_t,
1081  layout::ColumnMajor,
1082  arch::OpMultiplyAdd,
1083  bool>;
1084 
1085  using FragmentA = Array<ElementA, Shape::kMK>;
1086  using FragmentB = Array<ElementB, Shape::kKN>;
1087  using FragmentC = Array<ElementC, Shape::kMN>;
1088 
1089  CUTLASS_HOST_DEVICE
1090  void operator()(
1091  FragmentC & D,
1092  FragmentA const & A,
1093  FragmentB const & B,
1094  FragmentC const & C) {
1095 
1096  TransposeMma mma;
1097 
1098  mma(D, B, A, C);
1099  }
1100 };
1101 
1103 
1104 } // namespace thread
1105 } // namespace gemm
1106 } // namespace cutlass
1107 
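A minimal usage sketch, assuming the CUTLASS 2.x headers and that cutlass/gemm/thread/mma.h pulls in the specializations defined in this file. The 4x4x2 shape and all-column-major layouts are arbitrary choices that satisfy the even-M requirement, so the outer-product Mma_HFMA2 path above is selected; on SM60 and newer targets the inner arch::Mma tiles are intended to lower to paired HFMA2 instructions.

#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/thread/mma.h"

using ThreadMma = cutlass::gemm::thread::Mma<
    cutlass::gemm::GemmShape<4, 4, 2>,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    cutlass::arch::OpMultiplyAdd>;

// Computes D = A * B + C on one thread's register-resident fragments.
CUTLASS_HOST_DEVICE
void thread_level_gemm(
    ThreadMma::FragmentC &D,
    ThreadMma::FragmentA const &A,
    ThreadMma::FragmentB const &B,
    ThreadMma::FragmentC const &C) {

  ThreadMma mma;
  mma(D, A, B, C);
}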