CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
default_mma_core_simt.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
32 #pragma once
33 
34 #include "cutlass/cutlass.h"
35 #include "cutlass/array.h"
36 #include "cutlass/fast_math.h"
37 
38 #include "cutlass/numeric_types.h"
39 #include "cutlass/matrix_shape.h"
40 
41 
45 
49 
51 
52 namespace cutlass {
53 namespace gemm {
54 namespace threadblock {
55 
56 namespace detail {
57 
// convert a WarpShape which is the whole tile of elements into warp num threads.
// The goal is for each thread's tile of elements to be as square as possible
// for performance (4x4 will be faster than 2x8).
template<typename WarpShape>
constexpr int simt_get_warp_threads_m() {
  // Taller-than-wide warp tiles get 8 threads along M (8x4 lane grid);
  // otherwise 4 (4x8 lane grid), keeping each thread's tile near-square.
  return (WarpShape::kM > WarpShape::kN) ? 8 : 4;
}
65 
/// Computes padding in shared memory to perform efficient transpose without bank conflicts.
///
/// @param threads      number of threads cooperating on the transpose (warp size here)
/// @param crosswise    crosswise extent of the tile (the K dimension at the call sites)
/// @param size_in_bits width of one element in bits
/// @return number of padding elements to append per row/column of the SMEM tile
constexpr int simt_transpose_padding(int threads, int crosswise, int size_in_bits) {
  // For elements of at least one 32-bit bank, fewer padding elements are
  // needed as elements widen; for sub-32-bit elements, more elements are
  // packed per bank, so the padding count scales up by (32 / size_in_bits).
  return (size_in_bits >= 32
              ? threads / crosswise / (size_in_bits / 32)
              : threads / crosswise * (32 / size_in_bits));
}
73 
74 }
75 
77 
/// Partial specialization of DefaultMmaCore: SIMT op class, instruction shape
/// GemmShape<1,1,1>, operand A column-major, operand B row-major, 2 stages.
/// Defines shared-memory layouts/write iterators and the warp-level MMA policy.
/// No transpose-padding constants appear in this specialization, and the
/// operand skews in MmaPolicy are MatrixShape<0, 0> (per the cross-reference,
/// both A and B skews are zero here).
/// NOTE(review): this listing was recovered from generated documentation;
/// several declarations (SmemLayoutA/B, the iterator thread maps, and the
/// SmemIterator/LaneMmaShape/Policy/MmaWarpSimt openings) lost their lines.
/// Restore them from the original header before compiling.
85 template <
88  typename Shape_,
90  typename WarpShape_,
92  typename ElementA_,
94  typename ElementB_,
96  typename ElementC_,
98  typename LayoutC_,
100  typename Operator_>
101 struct DefaultMmaCore<Shape_, WarpShape_, GemmShape<1, 1, 1>, ElementA_,
102  layout::ColumnMajor, ElementB_, layout::RowMajor,
103  ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_
104  > {
105  using Shape = Shape_;
106  using WarpShape = WarpShape_;
108  using ElementA = ElementA_;
110  using ElementB = ElementB_;
112  using ElementC = ElementC_;
113  using LayoutC = LayoutC_;
114  using OperatorClass = arch::OpClassSimt;
115  static int const PartitionsK = Shape::kK / WarpShape::kK;
116 
118  using Operator = Operator_;
119 
121  using WarpCount = GemmShape<
122  Shape::kM / WarpShape::kM,
123  Shape::kN / WarpShape::kN,
124  PartitionsK
125  >;
126 
127  // Divisibility requirements
129  !(Shape::kM % WarpShape::kM) &&
130  !(Shape::kN % WarpShape::kN),
131  "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
132  );
133 
135  static int const kWarpSize = warp::WarpSize<arch::OpClassSimt>::value;
136 
138  static int const kThreads = WarpCount::kCount * kWarpSize;
139 
140  static int const kElementsPerAccess = 1;
141 
142  //
143  // Shared memory layouts
144  //
145 
148 
149  //
150  // Iterators to write to shared memory
151  //
152 
156  kThreads,
157  kElementsPerAccess
158  >;
159 
163  ElementA,
164  SmemLayoutA,
165  1,
167  >;
168 
172  kThreads,
173  kElementsPerAccess
174  >;
175 
179  ElementB,
180  SmemLayoutB,
181  0,
183  >;
184 
185  //
186  // Warp-level matrix multiply operator
187  //
188 
189  // Define the warp-level op
190  static const int WarpNumThreadsM = detail::simt_get_warp_threads_m<WarpShape>();
191  static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM;
192  static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM;
193  static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN;
194  static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
195  "WarpShape must be divisible by ThreadTile shape.");
196  static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1;
197  static const int numElementsA = 128 / sizeof_bits<ElementA>::value;
198  static const int numElementsB = 128 / sizeof_bits<ElementB>::value;
199  static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM);
200  static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN);
201  // these should have max of thread tile also
203  LaneM,
204  LaneN,
205  1>;
210  >;
211 
213  WarpShape,
214  ElementA,
215  SmemLayoutA,
216  ElementB,
217  SmemLayoutB,
218  ElementC,
219  LayoutC,
220  Policy
221  >;
222 
  // NOTE(review): the A-operand skew line (MatrixShape<0, 0>) was dropped by
  // extraction; the cross-reference shows both skews are zero in this policy.
224  using MmaPolicy = MmaPolicy<
225  MmaWarpSimt,
227  MatrixShape<0, 0>,
228  WarpCount::kK
229  >;
230 };
231 
233 
241 template <
244  typename Shape_,
246  typename WarpShape_,
248  typename ElementA_,
250  typename ElementB_,
252  typename ElementC_,
254  typename LayoutC_,
256  typename Operator_>
257 struct DefaultMmaCore<Shape_, WarpShape_, GemmShape<1, 1, 1>, ElementA_,
258  layout::RowMajor, ElementB_, layout::ColumnMajor,
259  ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_
260  > {
261  using Shape = Shape_;
262  using WarpShape = WarpShape_;
264  using ElementA = ElementA_;
266  using ElementB = ElementB_;
268  using ElementC = ElementC_;
269  using LayoutC = LayoutC_;
270  using OperatorClass = arch::OpClassSimt;
271  static int const PartitionsK = Shape::kK / WarpShape::kK;
272 
274  using Operator = Operator_;
275 
277  using WarpCount = GemmShape<
278  Shape::kM / WarpShape::kM,
279  Shape::kN / WarpShape::kN,
280  PartitionsK
281  >;
282 
283  // Divisility requirements
285  !(Shape::kM % WarpShape::kM) &&
286  !(Shape::kN % WarpShape::kN),
287  "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
288  );
289 
291  static int const kWarpSize = warp::WarpSize<arch::OpClassSimt>::value;
292 
294  static int const kThreads = WarpCount::kCount * kWarpSize;
295 
296  static int const kElementsPerAccess = 1;
297 
298  //
299  // Shared memory layouts
300  //
301 
302  using SmemLayoutA = layout::ColumnMajor;
303  using SmemLayoutB = layout::RowMajor;
304 
305  //
306  // Iterators to write to shared memory
307  //
308 
312  kThreads,
313  kElementsPerAccess
314  >;
315 
318 
321  MatrixShape<Shape::kM, Shape::kK>,
322  ElementA,
323  SmemLayoutA,
324  1,
325  SmemThreadMapA // was IteratorThreadMapA
326  >;
327 
331  kThreads,
332  kElementsPerAccess
333  >;
334 
337 
340  MatrixShape<Shape::kK, Shape::kN>,
341  ElementB,
342  SmemLayoutB,
343  0,
344  SmemThreadMapB // was IteratorThreadMapA
345  >;
346 
347  //
348  // Warp-level matrix multiply operator
349  //
350 
351  // Define the warp-level op
352  static const int WarpNumThreadsM = detail::simt_get_warp_threads_m<WarpShape>();
353  static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM;
354  static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM;
355  static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN;
356  static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
357  "WarpShape must be divisible by ThreadTile shape.");
358  static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1;
359  static const int numElementsA = 128 / sizeof_bits<ElementA>::value;
360  static const int numElementsB = 128 / sizeof_bits<ElementB>::value;
361  static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM);
362  static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN);
363 
364  static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementA>::value);
365  static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementB>::value);
366 
367  // these should have max of thread tile also
369  LaneM,
370  LaneN,
371  1>;
376  >;
377 
379  WarpShape,
380  ElementA,
381  SmemLayoutA,
382  ElementB,
383  SmemLayoutB,
384  ElementC,
385  LayoutC,
386  Policy
387  >;
388 
390  using MmaPolicy = MmaPolicy<
391  MmaWarpSimt,
392  MatrixShape<kPaddingN, 0>, // skew for A matrix to avoid SMEM bank conflicts
393  MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts
394  WarpCount::kK
395  >;
396 };
397 
399 
/// Partial specialization of DefaultMmaCore: SIMT op class, instruction shape
/// GemmShape<1,1,1>, operand A row-major, operand B row-major, 2 stages.
/// A is transposed into column-major shared memory, so only kPaddingM (based
/// on ElementA's width) is computed and applied as the A-operand skew; B is
/// stored row-major directly and gets no skew.
/// NOTE(review): recovered from generated documentation; thread-map, iterator
/// and policy opening lines were lost in extraction -- restore from the
/// original header before compiling.
407 template <
410  typename Shape_,
412  typename WarpShape_,
414  typename ElementA_,
416  typename ElementB_,
418  typename ElementC_,
420  typename LayoutC_,
422  typename Operator_>
423 struct DefaultMmaCore<Shape_, WarpShape_, GemmShape<1, 1, 1>, ElementA_,
424  layout::RowMajor, ElementB_, layout::RowMajor, ElementC_,
425  LayoutC_, arch::OpClassSimt, 2, Operator_
426  > {
427  using Shape = Shape_;
428  using WarpShape = WarpShape_;
430  using ElementA = ElementA_;
432  using ElementB = ElementB_;
434  using ElementC = ElementC_;
435  using LayoutC = LayoutC_;
436  using OperatorClass = arch::OpClassSimt;
437  static int const PartitionsK = Shape::kK / WarpShape::kK;
438 
440  using Operator = Operator_;
441 
443  using WarpCount = GemmShape<
444  Shape::kM / WarpShape::kM,
445  Shape::kN / WarpShape::kN,
446  PartitionsK
447  >;
448 
449  // Divisibility requirements
451  !(Shape::kM % WarpShape::kM) &&
452  !(Shape::kN % WarpShape::kN),
453  "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
454  );
455 
457  static int const kWarpSize = warp::WarpSize<arch::OpClassSimt>::value;
458 
460  static int const kThreads = WarpCount::kCount * kWarpSize;
461 
462  static int const kElementsPerAccess = 1;
463 
464  //
465  // Shared memory layouts
466  //
467 
468  using SmemLayoutA = layout::ColumnMajor;
469  using SmemLayoutB = layout::RowMajor;
470 
471  //
472  // Iterators to write to shared memory
473  //
474 
477  layout::PitchLinearShape<Shape::kK, Shape::kM>,
478  kThreads,
479  kElementsPerAccess
480  >;
481 
484 
487  MatrixShape<Shape::kM, Shape::kK>,
488  ElementA,
489  SmemLayoutA,
490  1,
492  >;
493 
496  layout::PitchLinearShape<Shape::kN, Shape::kK>,
497  kThreads,
498  kElementsPerAccess
499  >;
500 
503  MatrixShape<Shape::kK, Shape::kN>,
504  ElementB,
505  SmemLayoutB,
506  0,
508  >;
509 
510  //
511  // Warp-level matrix multiply operator
512  //
513 
514  // Define the warp-level op
515  static const int WarpNumThreadsM = detail::simt_get_warp_threads_m<WarpShape>();
516  static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM;
517  static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM;
518  static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN;
519  static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
520  "WarpShape must be divisible by ThreadTile shape.");
521  static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1;
522  static const int numElementsA = 128 / sizeof_bits<ElementA>::value;
523  static const int numElementsB = 128 / sizeof_bits<ElementB>::value;
524  static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM);
525  static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN);
526 
527  static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementA>::value);
528 
529  // these should have max of thread tile also
531  LaneM,
532  LaneN,
533  1>;
538  >;
539 
541  WarpShape,
542  ElementA,
543  SmemLayoutA,
544  ElementB,
545  SmemLayoutB,
546  ElementC,
547  LayoutC,
548  Policy
549  >;
550 
  // NOTE(review): the B-operand skew line (MatrixShape<0, 0>) was dropped by
  // extraction; the cross-reference shows only the A operand is skewed here.
552  using MmaPolicy = MmaPolicy<
553  MmaWarpSimt,
554  MatrixShape<kPaddingM, 0>, // skew for A matrix to avoid SMEM bank conflicts
556  WarpCount::kK
557  >;
558 };
559 
561 
/// Partial specialization of DefaultMmaCore: SIMT op class, instruction shape
/// GemmShape<1,1,1>, operand A column-major, operand B column-major, 2 stages.
/// B is transposed into row-major shared memory, so only kPaddingN (based on
/// ElementB's width) is computed and applied as the B-operand skew; A is
/// stored column-major directly and gets no skew.
/// NOTE(review): recovered from generated documentation; thread-map, iterator
/// and policy opening lines were lost in extraction -- restore from the
/// original header before compiling.
569 template <
572  typename Shape_,
574  typename WarpShape_,
576  typename ElementA_,
578  typename ElementB_,
580  typename ElementC_,
582  typename LayoutC_,
584  typename Operator_>
585 struct DefaultMmaCore<Shape_, WarpShape_, GemmShape<1, 1, 1>, ElementA_,
586  layout::ColumnMajor, ElementB_, layout::ColumnMajor,
587  ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_
588  > {
589  using Shape = Shape_;
590  using WarpShape = WarpShape_;
592  using ElementA = ElementA_;
594  using ElementB = ElementB_;
596  using ElementC = ElementC_;
597  using LayoutC = LayoutC_;
598  using OperatorClass = arch::OpClassSimt;
599  static int const PartitionsK = Shape::kK / WarpShape::kK;
600 
602  using Operator = Operator_;
603 
605  using WarpCount = GemmShape<
606  Shape::kM / WarpShape::kM,
607  Shape::kN / WarpShape::kN,
608  PartitionsK
609  >;
610 
611  // Divisibility requirements
613  !(Shape::kM % WarpShape::kM) &&
614  !(Shape::kN % WarpShape::kN),
615  "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
616  );
617 
619  static int const kWarpSize = warp::WarpSize<arch::OpClassSimt>::value;
620 
622  static int const kThreads = WarpCount::kCount * kWarpSize;
623 
624  static int const kElementsPerAccess = 1;
625 
626  //
627  // Shared memory layouts
628  //
629 
630  using SmemLayoutA = layout::ColumnMajor;
631  using SmemLayoutB = layout::RowMajor;
632 
633  //
634  // Iterators to write to shared memory
635  //
636 
639  layout::PitchLinearShape<Shape::kM, Shape::kK>,
640  kThreads,
641  kElementsPerAccess
642  >;
643 
646  MatrixShape<Shape::kM, Shape::kK>,
647  ElementA,
648  SmemLayoutA,
649  1,
651  >;
652 
655  layout::PitchLinearShape<Shape::kK, Shape::kN>,
656  kThreads,
657  kElementsPerAccess
658  >;
659 
662 
665  MatrixShape<Shape::kK, Shape::kN>,
666  ElementB,
667  SmemLayoutB,
668  0,
670  >;
671 
672  //
673  // Warp-level matrix multiply operator
674  //
675 
676  // Define the warp-level op
677  static const int WarpNumThreadsM = detail::simt_get_warp_threads_m<WarpShape>();
678  static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM;
679  static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM;
680  static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN;
681  static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
682  "WarpShape must be divisible by ThreadTile shape.");
683  static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1;
684  static const int numElementsA = 128 / sizeof_bits<ElementA>::value;
685  static const int numElementsB = 128 / sizeof_bits<ElementB>::value;
686  static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM);
687  static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN);
688 
689  static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementB>::value);
690 
691  // these should have max of thread tile also
693  LaneM,
694  LaneN,
695  1>;
700  >;
701 
703  WarpShape,
704  ElementA,
705  SmemLayoutA,
706  ElementB,
707  SmemLayoutB,
708  ElementC,
709  LayoutC,
710  Policy
711  >;
712 
  // NOTE(review): the A-operand skew line (MatrixShape<0, 0>) was dropped by
  // extraction; the cross-reference shows only the B operand is skewed here.
714  using MmaPolicy = MmaPolicy<
715  MmaWarpSimt,
717  MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts
718  WarpCount::kK
719  >;
720 };
721 
723 
/// Partial specialization of DefaultMmaCore: SIMT op class, int8_t operands,
/// instruction shape GemmShape<1,1,4> (four-wide K per instruction --
/// presumably the dp4a path; confirm against the warp-level operator),
/// operand A column-major, operand B row-major, 2 stages. Shared memory uses
/// interleaved-by-4 layouts (SmemLayoutB = RowMajorInterleaved<4> is visible;
/// the SmemLayoutA line was dropped by extraction).
/// NOTE(review): recovered from generated documentation; thread-map, iterator
/// and policy opening lines were lost -- restore from the original header
/// before compiling.
731 template <
734  typename Shape_,
736  typename WarpShape_,
738  typename ElementC_,
740  typename LayoutC_,
742  typename Operator_>
743 struct DefaultMmaCore<Shape_, WarpShape_, GemmShape<1, 1, 4>, int8_t,
744  layout::ColumnMajor, int8_t, layout::RowMajor, ElementC_,
745  LayoutC_, arch::OpClassSimt, 2, Operator_
746  > {
747 
748  using Shape = Shape_;
749  using WarpShape = WarpShape_;
751  using ElementA = int8_t;
753  using ElementB = int8_t;
755  using ElementC = ElementC_;
756  using LayoutC = LayoutC_;
757  using OperatorClass = arch::OpClassSimt;
758  static int const PartitionsK = Shape::kK / WarpShape::kK;
759 
761  using Operator = Operator_;
762 
764  using WarpCount = GemmShape<
765  Shape::kM / WarpShape::kM,
766  Shape::kN / WarpShape::kN,
767  PartitionsK
768  >;
769 
770  // Divisibility requirements
772  !(Shape::kM % WarpShape::kM) &&
773  !(Shape::kN % WarpShape::kN),
774  "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
775  );
776 
778  static int const kWarpSize = warp::WarpSize<arch::OpClassSimt>::value;
779 
781  static int const kThreads = WarpCount::kCount * kWarpSize;
782 
783  //
784  // Shared memory layouts
785  //
786 
788  using SmemLayoutB = layout::RowMajorInterleaved<4>;
789 
790  //
791  // Iterators to write to shared memory
792  //
793 
796  layout::PitchLinearShape<Shape::kM, Shape::kK>,
797  kThreads,
799  >;
800 
803  MatrixShape<Shape::kM, Shape::kK>,
804  ElementA,
805  SmemLayoutA,
806  1,
808  >;
809 
810 
813  layout::PitchLinearShape<Shape::kN, Shape::kK>,
814  kThreads,
816  >;
817 
820  MatrixShape<Shape::kK, Shape::kN>,
821  ElementB,
822  SmemLayoutB,
823  0,
825  >;
826 
827  //
828  // Warp-level matrix multiply operator
829  //
830 
831  // Define the warp-level op
832  static const int WarpNumThreadsM = detail::simt_get_warp_threads_m<WarpShape>();
833  static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM;
834  static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM;
835  static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN;
836  static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
837  "WarpShape must be divisible by ThreadTile shape.");
838  static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1;
839  static const int numElementsA = 128 / sizeof_bits<ElementA>::value;
840  static const int numElementsB = 128 / sizeof_bits<ElementB>::value;
841  static const int LaneM = cutlass::const_min(4, ThreadTileM);
842  static const int LaneN = cutlass::const_min(4, ThreadTileN);
843  // these should have max of thread tile also
845  LaneM,
846  LaneN,
847  4>;
848 
853  >;
854 
856  WarpShape,
857  ElementA,
858  SmemLayoutA,
859  ElementB,
860  SmemLayoutB,
861  ElementC,
862  LayoutC,
863  Policy,
864  PartitionsK
865  >;
866 
  // NOTE(review): the A-operand skew line (MatrixShape<0, 0>) was dropped by
  // extraction; the cross-reference shows both skews are zero in this policy.
868  using MmaPolicy = MmaPolicy<
869  MmaWarpSimt,
871  MatrixShape<0, 0>,
872  WarpCount::kK
873  >;
874 };
875 
878 //
/// Partial specialization of DefaultMmaCore: SIMT op class, int8_t operands,
/// instruction shape GemmShape<1,1,4>, operand A row-major, operand B
/// column-major, 2 stages. Both operands are transposed on the way into
/// interleaved shared memory, so kPaddingM and kPaddingN are both computed.
/// NOTE(review): recovered from generated documentation; thread-map, iterator
/// and policy opening lines were lost, and the two skew lines of MmaPolicy
/// are missing -- per the cross-reference they are MatrixShape<kPaddingM, 0>
/// (A) and MatrixShape<0, kPaddingN> (B). Restore from the original header
/// before compiling.
885 template <
888  typename Shape_,
890  typename WarpShape_,
892  typename ElementC_,
894  typename LayoutC_,
896  typename Operator_>
897 struct DefaultMmaCore<Shape_, WarpShape_, GemmShape<1, 1, 4>, int8_t,
898  layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_,
899  LayoutC_, arch::OpClassSimt, 2, Operator_
900  > {
901 
902  using Shape = Shape_;
903  using WarpShape = WarpShape_;
905  using ElementA = int8_t;
907  using ElementB = int8_t;
909  using ElementC = ElementC_;
910  using LayoutC = LayoutC_;
911  using OperatorClass = arch::OpClassSimt;
912  static int const PartitionsK = Shape::kK / WarpShape::kK;
913 
915  using Operator = Operator_;
916 
918  using WarpCount = GemmShape<
919  Shape::kM / WarpShape::kM,
920  Shape::kN / WarpShape::kN,
921  PartitionsK
922  >;
923 
924  // Divisibility requirements
926  !(Shape::kM % WarpShape::kM) &&
927  !(Shape::kN % WarpShape::kN),
928  "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
929  );
930 
932  static int const kWarpSize = warp::WarpSize<arch::OpClassSimt>::value;
933 
935  static int const kThreads = WarpCount::kCount * kWarpSize;
936 
937  //
938  // Shared memory layouts
939  //
940 
942  using SmemLayoutB = layout::RowMajorInterleaved<4>;
943 
944  //
945  // Iterators to write to shared memory
946  //
947 
950  layout::PitchLinearShape<Shape::kK, Shape::kM>,
951  kThreads,
953  >;
954 
957 
960  MatrixShape<Shape::kM, Shape::kK>,
961  ElementA,
962  SmemLayoutA,
963  1,
965  >;
966 
967 
970  layout::PitchLinearShape<Shape::kK, Shape::kN>,
971  kThreads,
973  >;
974 
977 
980  MatrixShape<Shape::kK, Shape::kN>,
981  ElementB,
982  SmemLayoutB,
983  0,
985  >;
986 
987  //
988  // Warp-level matrix multiply operator
989  //
990 
991  // Define the warp-level op
992  static const int WarpNumThreadsM = detail::simt_get_warp_threads_m<WarpShape>();
993  static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM;
994  static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM;
995  static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN;
996  static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
997  "WarpShape must be divisible by ThreadTile shape.");
998  static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1;
999  static const int numElementsA = 128 / sizeof_bits<ElementA>::value;
1000  static const int numElementsB = 128 / sizeof_bits<ElementB>::value;
1001  static const int LaneM = cutlass::const_min(4, ThreadTileM);
1002  static const int LaneN = cutlass::const_min(4, ThreadTileN);
1003  // these should have max of thread tile also
1005  LaneM,
1006  LaneN,
1007  4>;
1008 
1012  LaneMmaShape
1013  >;
1014 
1016  WarpShape,
1017  ElementA,
1018  SmemLayoutA,
1019  ElementB,
1020  SmemLayoutB,
1021  ElementC,
1022  LayoutC,
1023  Policy,
1024  PartitionsK
1025  >;
1026 
1027  static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementA>::value);
1028  static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementB>::value);
1029 
1031  using MmaPolicy = MmaPolicy<
1032  MmaWarpSimt,
1035  WarpCount::kK
1036  >;
1037 };
1038 
1041 //
/// Partial specialization of DefaultMmaCore: SIMT op class, int8_t operands,
/// instruction shape GemmShape<1,1,4>, operand A row-major, operand B
/// row-major, 2 stages. A is transposed into interleaved shared memory;
/// both kPaddingM and kPaddingN are computed below, though per the
/// cross-reference only the A skew (MatrixShape<kPaddingM, 0>) is non-zero
/// in the resulting MmaPolicy.
/// NOTE(review): recovered from generated documentation; thread-map, iterator
/// and policy opening lines (and the MmaPolicy skew lines) were lost in
/// extraction -- restore from the original header before compiling.
1048 template <
1051  typename Shape_,
1053  typename WarpShape_,
1055  typename ElementC_,
1057  typename LayoutC_,
1059  typename Operator_>
1060 struct DefaultMmaCore<Shape_, WarpShape_, GemmShape<1, 1, 4>, int8_t,
1061  layout::RowMajor, int8_t, layout::RowMajor, ElementC_,
1062  LayoutC_, arch::OpClassSimt, 2, Operator_
1063  > {
1064 
1065  using Shape = Shape_;
1066  using WarpShape = WarpShape_;
1068  using ElementA = int8_t;
1070  using ElementB = int8_t;
1072  using ElementC = ElementC_;
1073  using LayoutC = LayoutC_;
1074  using OperatorClass = arch::OpClassSimt;
1075  static int const PartitionsK = Shape::kK / WarpShape::kK;
1076 
1078  using Operator = Operator_;
1079 
1081  using WarpCount = GemmShape<
1082  Shape::kM / WarpShape::kM,
1083  Shape::kN / WarpShape::kN,
1084  PartitionsK
1085  >;
1086 
1087  // Divisibility requirements
1088  static_assert(
1089  !(Shape::kM % WarpShape::kM) &&
1090  !(Shape::kN % WarpShape::kN),
1091  "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
1092  );
1093 
1095  static int const kWarpSize = warp::WarpSize<arch::OpClassSimt>::value;
1096 
1098  static int const kThreads = WarpCount::kCount * kWarpSize;
1099 
1100  //
1101  // Shared memory layouts
1102  //
1103 
1105  using SmemLayoutB = layout::RowMajorInterleaved<4>;
1106 
1107  //
1108  // Iterators to write to shared memory
1109  //
1110 
1113  layout::PitchLinearShape<Shape::kK, Shape::kM>,
1114  kThreads,
1116  >;
1117 
1120 
1123  MatrixShape<Shape::kM, Shape::kK>,
1124  ElementA,
1125  SmemLayoutA,
1126  1,
1128  >;
1129 
1132  layout::PitchLinearShape<Shape::kN, Shape::kK>,
1133  kThreads,
1135  >;
1136 
1139  MatrixShape<Shape::kK, Shape::kN>,
1140  ElementB,
1141  SmemLayoutB,
1142  0,
1144  >;
1145 
1146  //
1147  // Warp-level matrix multiply operator
1148  //
1149 
1150  // Define the warp-level op
1151  static const int WarpNumThreadsM = detail::simt_get_warp_threads_m<WarpShape>();
1152  static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM;
1153  static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM;
1154  static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN;
1155  static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
1156  "WarpShape must be divisible by ThreadTile shape.");
1157  static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1;
1158  static const int numElementsA = 128 / sizeof_bits<ElementA>::value;
1159  static const int numElementsB = 128 / sizeof_bits<ElementB>::value;
1160  static const int LaneM = cutlass::const_min(4, ThreadTileM);
1161  static const int LaneN = cutlass::const_min(4, ThreadTileN);
1162  // these should have max of thread tile also
1164  LaneM,
1165  LaneN,
1166  4>;
1167 
1171  LaneMmaShape
1172  >;
1173 
1175  WarpShape,
1176  ElementA,
1177  SmemLayoutA,
1178  ElementB,
1179  SmemLayoutB,
1180  ElementC,
1181  LayoutC,
1182  Policy,
1183  PartitionsK
1184  >;
1185 
1186  static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementA>::value);
1187  static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementB>::value);
1188 
1190  using MmaPolicy = MmaPolicy<
1191  MmaWarpSimt,
1194  WarpCount::kK
1195  >;
1196 };
1197 
1200 //
/// Partial specialization of DefaultMmaCore: SIMT op class, int8_t operands,
/// instruction shape GemmShape<1,1,4>, operand A column-major, operand B
/// column-major, 2 stages. B is transposed into interleaved shared memory;
/// both paddings are computed below, though per the cross-reference only the
/// B skew (MatrixShape<0, kPaddingN>) is non-zero in the resulting MmaPolicy.
/// NOTE(review): recovered from generated documentation; thread-map, iterator
/// and policy opening lines (and the MmaPolicy skew lines) were lost in
/// extraction -- restore from the original header before compiling.
1207 template <
1210  typename Shape_,
1212  typename WarpShape_,
1214  typename ElementC_,
1216  typename LayoutC_,
1218  typename Operator_>
1219 struct DefaultMmaCore<Shape_, WarpShape_, GemmShape<1, 1, 4>, int8_t,
1220  layout::ColumnMajor, int8_t, layout::ColumnMajor, ElementC_,
1221  LayoutC_, arch::OpClassSimt, 2, Operator_
1222  > {
1223 
1224  using Shape = Shape_;
1225  using WarpShape = WarpShape_;
1227  using ElementA = int8_t;
1229  using ElementB = int8_t;
1231  using ElementC = ElementC_;
1232  using LayoutC = LayoutC_;
1233  using OperatorClass = arch::OpClassSimt;
1234  static int const PartitionsK = Shape::kK / WarpShape::kK;
1235 
1237  using Operator = Operator_;
1238 
1240  using WarpCount = GemmShape<
1241  Shape::kM / WarpShape::kM,
1242  Shape::kN / WarpShape::kN,
1243  PartitionsK
1244  >;
1245 
1246  // Divisibility requirements
1247  static_assert(
1248  !(Shape::kM % WarpShape::kM) &&
1249  !(Shape::kN % WarpShape::kN),
1250  "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
1251  );
1252 
1254  static int const kWarpSize = warp::WarpSize<arch::OpClassSimt>::value;
1255 
1257  static int const kThreads = WarpCount::kCount * kWarpSize;
1258 
1259  //
1260  // Shared memory layouts
1261  //
1262 
1264  using SmemLayoutB = layout::RowMajorInterleaved<4>;
1265 
1266  //
1267  // Iterators to write to shared memory
1268  //
1269 
1272  layout::PitchLinearShape<Shape::kM, Shape::kK>,
1273  kThreads,
1275  >;
1276 
1279  MatrixShape<Shape::kM, Shape::kK>,
1280  ElementA,
1281  SmemLayoutA,
1282  1,
1284  >;
1285 
1286 
1289  layout::PitchLinearShape<Shape::kK, Shape::kN>,
1290  kThreads,
1292  >;
1293 
1296 
1299  MatrixShape<Shape::kK, Shape::kN>,
1300  ElementB,
1301  SmemLayoutB,
1302  0,
1304  >;
1305 
1306  //
1307  // Warp-level matrix multiply operator
1308  //
1309 
1310  // Define the warp-level op
1311  static const int WarpNumThreadsM = detail::simt_get_warp_threads_m<WarpShape>();
1312  static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM;
1313  static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM;
1314  static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN;
1315  static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN),
1316  "WarpShape must be divisible by ThreadTile shape.");
1317  static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1;
1318  static const int numElementsA = 128 / sizeof_bits<ElementA>::value;
1319  static const int numElementsB = 128 / sizeof_bits<ElementB>::value;
1320  static const int LaneM = cutlass::const_min(4, ThreadTileM);
1321  static const int LaneN = cutlass::const_min(4, ThreadTileN);
1322  // these should have max of thread tile also
1324  LaneM,
1325  LaneN,
1326  4>;
1327 
1331  LaneMmaShape
1332  >;
1333 
1335  WarpShape,
1336  ElementA,
1337  SmemLayoutA,
1338  ElementB,
1339  SmemLayoutB,
1340  ElementC,
1341  LayoutC,
1342  Policy,
1343  PartitionsK
1344  >;
1345 
1346  static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementA>::value);
1347  static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementB>::value);
1348 
1350  using MmaPolicy = MmaPolicy<
1351  MmaWarpSimt,
1354  WarpCount::kK
1355  >;
1356 };
1357 
1358 } // namespace threadblock
1359 } // namespace gemm
1360 } // namespace cutlass
Describes the lane policy used by warp-level matrix multiply operators targeting SIMT instructions...
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
Templates implementing loading of tiles from pitch-linear rank=2 tensors.
MmaPolicy< MmaWarpSimt, MatrixShape< kPaddingM, 0 >, MatrixShape< 0, 0 >, WarpCount::kK > MmaPolicy
Policy used to define MmaPipelined.
Definition: default_mma_core_simt.h:1195
Definition: aligned_buffer.h:35
#define constexpr
Definition: platform.h:137
Query the number of threads per warp.
Definition: gemm/warp/mma.h:43
Definition: default_mma_core.h:90
MmaPolicy< MmaWarpSimt, MatrixShape< 0, 0 >, MatrixShape< 0, kPaddingN >, WarpCount::kK > MmaPolicy
Policy used to define MmaPipelined.
Definition: default_mma_core_simt.h:719
Templates implementing how threads are mapped to a given tile.
Definition: pitch_linear_thread_map.h:431
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
Definition: mma_simt.h:74
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
Template defining a shape used by pitch-linear operators.
Definition: pitch_linear.h:43
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
Describes the arrangement and configuration of per-lane operations in warp-level matrix multiply...
Definition: mma_simt_policy.h:46
MmaPolicy< MmaWarpSimt, MatrixShape< 0, 0 >, MatrixShape< 0, 0 >, WarpCount::kK > MmaPolicy
Policy used to define MmaPipelined.
Definition: default_mma_core_simt.h:873
Defines a Shape template for matrix tiles.
MmaPolicy< MmaWarpSimt, MatrixShape< kPaddingM, 0 >, MatrixShape< 0, kPaddingN >, WarpCount::kK > MmaPolicy
Policy used to define MmaPipelined.
Definition: default_mma_core_simt.h:1036
Defines the size of an element in bits.
Definition: numeric_types.h:42
Thread Mapping a 2D threadtiled mapping as a transposed Pitchlinear2DThreadTile mapping.
Definition: pitch_linear_thread_map.h:713
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Definition: regular_tile_iterator.h:50
#define static_assert(__e, __m)
Definition: platform.h:153
constexpr int simt_transpose_padding(int threads, int crosswise, int size_in_bits)
Computes padding in shared memory to perform efficient transpose without bank conflicts.
Definition: default_mma_core_simt.h:67
MmaPolicy< MmaWarpSimt, MatrixShape< 0, 0 >, MatrixShape< 0, kPaddingN >, WarpCount::kK > MmaPolicy
Policy used to define MmaPipelined.
Definition: default_mma_core_simt.h:1355
Templates implementing loading of tiles from pitch-linear rank=2 tensors.
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
constexpr int simt_get_warp_threads_m()
Definition: default_mma_core_simt.h:62
MmaPolicy< MmaWarpSimt, MatrixShape< 0, 0 >, MatrixShape< 0, 0 >, WarpCount::kK > MmaPolicy
Used for partial specialization.
Definition: default_mma_core_simt.h:229
Math utilities.
MmaPolicy< MmaWarpSimt, MatrixShape< kPaddingN, 0 >, MatrixShape< 0, kPaddingN >, WarpCount::kK > MmaPolicy
Policy used to define MmaPipelined.
Definition: default_mma_core_simt.h:395
Templates implementing warp-level matrix multiply-accumulate operations.
CUTLASS_HOST_DEVICE constexpr int const_min(int a, int b)
Definition: fast_math.h:219
MmaPolicy< MmaWarpSimt, MatrixShape< kPaddingM, 0 >, MatrixShape< 0, 0 >, WarpCount::kK > MmaPolicy
Policy used to define MmaPipelined.
Definition: default_mma_core_simt.h:557
Definition: regular_tile_iterator_pitch_linear_2dthreadtile.h:59
Basic include for CUTLASS.
Definition: pitch_linear_thread_map.h:59
Definition: layout/matrix.h:237