CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
mma_sm75.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Matrix multiply for SM75
*/
#pragma once

#include <assert.h>

#include "cutlass/arch/wmma.h"

#if defined(CUTLASS_ARCH_WMMA_ENABLED)
// CUDA Toolkit includes for nvcuda::wmma needed for binarized matrix multiply.
#include <mma.h>
#include "cutlass/wmma_array.h"
#endif

// CUTLASS includes
#include "cutlass/arch/mma.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))

#define CUTLASS_ARCH_MMA_SM75_SUPPORTED 1

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
#define CUTLASS_ARCH_MMA_SM75_ENABLED
#endif
#endif

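// Illustrative note (not part of the original header): CUTLASS_ARCH_MMA_SM75_SUPPORTED means the
// toolkit can compile these PTX instructions (CUDA 10.2 or newer), while CUTLASS_ARCH_MMA_SM75_ENABLED
// additionally requires the device compilation pass to target sm_75 or newer. A hypothetical caller
// might guard its own Turing-specific code the same way:
//
//   #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED)
//   // safe to name the cutlass::arch::Mma<...> SM75 specializations in host code
//   #endif
//
//   #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)
//   // safe to execute mma.sync.aligned.* PTX in device code
//   #endif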
/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace arch {

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Matrix Multiply 1688 - FP16 accumulation
//
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation: F16 = F16 * F16 + F16
template <>
struct Mma<
  gemm::GemmShape<16, 8, 8>,
  32,
  half_t,
  layout::RowMajor,
  half_t,
  layout::ColumnMajor,
  half_t,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<16, 8, 8>;

  using ElementA = half_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<half_t, 4>;

  using ElementB = half_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<half_t, 2>;

  using ElementC = half_t;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<half_t, 4>;

  using Operator = OpMultiplyAdd;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const *A = reinterpret_cast<unsigned const *>(&a);
    unsigned const *B = reinterpret_cast<unsigned const *>(&b);
    unsigned const *C = reinterpret_cast<unsigned const *>(&c);
    unsigned *D = reinterpret_cast<unsigned *>(&d);

    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0,%1}, {%2,%3}, {%4}, {%5,%6};\n"
        : "=r"(D[0]), "=r"(D[1])
        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1]));

#else
    assert(0);
#endif
  }
};
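
// Illustrative usage sketch (not part of the original header): one warp cooperatively computes a
// 16x8x8 FP16 product. Fragment contents are distributed across the 32 threads of the warp per the
// PTX mma.m16n8k8 register mapping, so each fragment below holds only the calling thread's share.
// The kernel name is hypothetical and the loads of frag_a/frag_b are omitted.
#if 0
__global__ void example_mma_1688_f16() {
  using Mma1688 = cutlass::arch::Mma<
      cutlass::gemm::GemmShape<16, 8, 8>, 32,
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::half_t, cutlass::layout::ColumnMajor,
      cutlass::half_t, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAdd>;

  Mma1688::FragmentA frag_a;  // 4 halves/thread: this thread's share of the 16x8 A tile
  Mma1688::FragmentB frag_b;  // 2 halves/thread: this thread's share of the 8x8 B tile
  Mma1688::FragmentC acc;     // 4 halves/thread: this thread's share of the 16x8 accumulator

  acc.clear();                // zero the accumulator fragment

  Mma1688 mma;
  mma(acc, frag_a, frag_b, acc);  // acc = A * B + acc, issued once per warp
}
#endif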

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Matrix Multiply 1688 - FP32 accumulation
//
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation: F32 = F16 * F16 + F32
template <>
struct Mma<
  gemm::GemmShape<16, 8, 8>,
  32,
  half_t,
  layout::RowMajor,
  half_t,
  layout::ColumnMajor,
  float,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<16, 8, 8>;

  using ElementA = half_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<half_t, 4>;

  using ElementB = half_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<half_t, 2>;

  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<float, 4>;

  using Operator = OpMultiplyAdd;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
                  FragmentC const &c) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const *A = reinterpret_cast<unsigned const *>(&a);
    unsigned const *B = reinterpret_cast<unsigned const *>(&b);
    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

    asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
        :
          "r"(A[0]), "r"(A[1]),
          "r"(B[0]),
          "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
    );

#else
    assert(0);
#endif
  }
};
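
// Illustrative check (not part of the original header): the per-thread fragment sizes above, times
// the 32 threads of a warp, must cover each 16x8x8 operand tile exactly. These static_asserts are
// an added sanity sketch, not original code.
static_assert(32 * 4 == 16 * 8, "FragmentA: 4 halves/thread x 32 threads covers the 16x8 A tile");
static_assert(32 * 2 ==  8 * 8, "FragmentB: 2 halves/thread x 32 threads covers the 8x8 B tile");
static_assert(32 * 4 == 16 * 8, "FragmentC: 4 floats/thread x 32 threads covers the 16x8 C tile");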

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Integer matrix multiply .8816 (8b)
//
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation: S32 = S8 * S8 + S32
template <>
struct Mma<
  gemm::GemmShape<8, 8, 16>,
  32,
  int8_t,
  layout::RowMajor,
  int8_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<8, 8, 16>;

  using ElementA = int8_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<int8_t, 4>;

  using ElementB = int8_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<int8_t, 4>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAdd;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const & A = reinterpret_cast<unsigned const &>(a);
    unsigned const & B = reinterpret_cast<unsigned const &>(b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

    asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=r"(D[0]), "=r"(D[1])
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: S32 = U8 * S8 + S32
template <>
struct Mma<
  gemm::GemmShape<8, 8, 16>,
  32,
  uint8_t,
  layout::RowMajor,
  int8_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<8, 8, 16>;

  using ElementA = uint8_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<uint8_t, 4>;

  using ElementB = int8_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<int8_t, 4>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAdd;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const & A = reinterpret_cast<unsigned const &>(a);
    unsigned const & B = reinterpret_cast<unsigned const &>(b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

    asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=r"(D[0]), "=r"(D[1])
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: S32 = S8 * U8 + S32
template <>
struct Mma<
  gemm::GemmShape<8, 8, 16>,
  32,
  int8_t,
  layout::RowMajor,
  uint8_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<8, 8, 16>;

  using ElementA = int8_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<int8_t, 4>;

  using ElementB = uint8_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<uint8_t, 4>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAdd;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const & A = reinterpret_cast<unsigned const &>(a);
    unsigned const & B = reinterpret_cast<unsigned const &>(b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

    asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=r"(D[0]), "=r"(D[1])
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: S32 = U8 * U8 + S32
template <>
struct Mma<
  gemm::GemmShape<8, 8, 16>,
  32,
  uint8_t,
  layout::RowMajor,
  uint8_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<8, 8, 16>;

  using ElementA = uint8_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<uint8_t, 4>;

  using ElementB = uint8_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<uint8_t, 4>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAdd;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const & A = reinterpret_cast<unsigned const &>(a);
    unsigned const & B = reinterpret_cast<unsigned const &>(b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

    asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=r"(D[0]), "=r"(D[1])
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    assert(0);
#endif
  }
};
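
// Illustrative scalar reference (not part of the original header): semantically, each of the 8x8
// s32 outputs of the m8n8k16 instruction is a 16-term dot product of 8-bit operands accumulated
// into 32 bits. A hypothetical host-side reference for one output element:
#if 0
inline int mma_8816_reference_element(int8_t const a_row[16], int8_t const b_col[16], int c) {
  int d = c;
  for (int k = 0; k < 16; ++k) {
    d += int(a_row[k]) * int(b_col[k]);  // widen to 32-bit before accumulating
  }
  return d;
}
#endif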

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Integer matrix multiply (8b) with SATURATE
//
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation: S32 = S8 * S8 + S32
template <>
struct Mma<
  gemm::GemmShape<8,8,16>,
  32,
  int8_t,
  layout::RowMajor,
  int8_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  using Shape = gemm::GemmShape<8,8,16>;

  using ElementA = int8_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<int8_t, 4>;

  using ElementB = int8_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<int8_t, 4>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const & A = reinterpret_cast<unsigned const &>(a);
    unsigned const & B = reinterpret_cast<unsigned const &>(b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

    asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=r"(D[0]), "=r"(D[1])
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: S32 = U8 * S8 + S32
template <>
struct Mma<
  gemm::GemmShape<8,8,16>,
  32,
  uint8_t,
  layout::RowMajor,
  int8_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  using Shape = gemm::GemmShape<8,8,16>;

  using ElementA = uint8_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<uint8_t, 4>;

  using ElementB = int8_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<int8_t, 4>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const & A = reinterpret_cast<unsigned const &>(a);
    unsigned const & B = reinterpret_cast<unsigned const &>(b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

    asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=r"(D[0]), "=r"(D[1])
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: S32 = S8 * U8 + S32
template <>
struct Mma<
  gemm::GemmShape<8,8,16>,
  32,
  int8_t,
  layout::RowMajor,
  uint8_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  using Shape = gemm::GemmShape<8,8,16>;

  using ElementA = int8_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<int8_t, 4>;

  using ElementB = uint8_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<uint8_t, 4>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const & A = reinterpret_cast<unsigned const &>(a);
    unsigned const & B = reinterpret_cast<unsigned const &>(b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

    asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=r"(D[0]), "=r"(D[1])
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: S32 = U8 * U8 + S32
template <>
struct Mma<
  gemm::GemmShape<8,8,16>,
  32,
  uint8_t,
  layout::RowMajor,
  uint8_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  using Shape = gemm::GemmShape<8,8,16>;

  using ElementA = uint8_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<uint8_t, 4>;

  using ElementB = uint8_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<uint8_t, 4>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const & A = reinterpret_cast<unsigned const &>(a);
    unsigned const & B = reinterpret_cast<unsigned const &>(b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

    asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=r"(D[0]), "=r"(D[1])
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    assert(0);
#endif
  }
};

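// Illustrative note (not part of the original header): the .satfinite qualifier in the variants
// above clamps each s32 result to [INT32_MIN, INT32_MAX] instead of letting it wrap on overflow.
// A hypothetical scalar sketch of the saturating accumulate, using 64-bit math to detect overflow
// before narrowing:
#if 0
#include <algorithm>
#include <cstdint>

inline int32_t saturating_accumulate(int64_t dot_product, int32_t c) {
  int64_t d = dot_product + int64_t(c);
  d = std::min<int64_t>(d, INT32_MAX);  // clamp overflow to the largest s32
  d = std::max<int64_t>(d, INT32_MIN);  // clamp underflow to the smallest s32
  return int32_t(d);
}
#endif
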
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Integer matrix multiply (4b)
//
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation: S32 = S4 * S4 + S32
template <>
struct Mma<
  gemm::GemmShape<8,8,32>,
  32,
  int4b_t,
  layout::RowMajor,
  int4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<8,8,32>;

  using ElementA = int4b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<int4b_t, 8>;

  using ElementB = int4b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<int4b_t, 8>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAdd;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const & A = reinterpret_cast<unsigned const &>(a);
    unsigned const & B = reinterpret_cast<unsigned const &>(b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

    asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=r"(D[0]), "=r"(D[1])
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: S32 = U4 * S4 + S32
template <>
struct Mma<
  gemm::GemmShape<8,8,32>,
  32,
  uint4b_t,
  layout::RowMajor,
  int4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<8,8,32>;

  using ElementA = uint4b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<uint4b_t, 8>;

  using ElementB = int4b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<int4b_t, 8>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAdd;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const & A = reinterpret_cast<unsigned const &>(a);
    unsigned const & B = reinterpret_cast<unsigned const &>(b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

    asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=r"(D[0]), "=r"(D[1])
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: S32 = S4 * U4 + S32
template <>
struct Mma<
  gemm::GemmShape<8,8,32>,
  32,
  int4b_t,
  layout::RowMajor,
  uint4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<8,8,32>;

  using ElementA = int4b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<int4b_t, 8>;

  using ElementB = uint4b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<uint4b_t, 8>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAdd;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const & A = reinterpret_cast<unsigned const &>(a);
    unsigned const & B = reinterpret_cast<unsigned const &>(b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

    asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=r"(D[0]), "=r"(D[1])
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: S32 = U4 * U4 + S32
template <>
struct Mma<
  gemm::GemmShape<8,8,32>,
  32,
  uint4b_t,
  layout::RowMajor,
  uint4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<8,8,32>;

  using ElementA = uint4b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<uint4b_t, 8>;

  using ElementB = uint4b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<uint4b_t, 8>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAdd;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const & A = reinterpret_cast<unsigned const &>(a);
    unsigned const & B = reinterpret_cast<unsigned const &>(b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

    asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=r"(D[0]), "=r"(D[1])
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    assert(0);
#endif
  }
};
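
// Illustrative packing sketch (not part of the original header): FragmentA/FragmentB above hold
// eight 4-bit values per thread, which operator() reinterprets as one 32-bit register for the
// instruction. A hypothetical helper showing the nibble packing (element 0 in the least
// significant bits) that cutlass::Array of int4b_t implies:
#if 0
#include <cstdint>

inline uint32_t pack_eight_s4(int const (&v)[8]) {
  uint32_t packed = 0;
  for (int i = 0; i < 8; ++i) {
    packed |= (uint32_t(v[i]) & 0xF) << (4 * i);  // element i occupies bits [4i, 4i+3]
  }
  return packed;
}
#endif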

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Integer matrix multiply (4b) - SATURATE
//
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation: S32 = S4 * S4 + S32
template <>
struct Mma<
  gemm::GemmShape<8,8,32>,
  32,
  int4b_t,
  layout::RowMajor,
  int4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  using Shape = gemm::GemmShape<8,8,32>;

  using ElementA = int4b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<int4b_t, 8>;

  using ElementB = int4b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<int4b_t, 8>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const & A = reinterpret_cast<unsigned const &>(a);
    unsigned const & B = reinterpret_cast<unsigned const &>(b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

    asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=r"(D[0]), "=r"(D[1])
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: S32 = U4 * S4 + S32
template <>
struct Mma<
  gemm::GemmShape<8,8,32>,
  32,
  uint4b_t,
  layout::RowMajor,
  int4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  using Shape = gemm::GemmShape<8,8,32>;

  using ElementA = uint4b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<uint4b_t, 8>;

  using ElementB = int4b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<int4b_t, 8>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const & A = reinterpret_cast<unsigned const &>(a);
    unsigned const & B = reinterpret_cast<unsigned const &>(b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

    asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=r"(D[0]), "=r"(D[1])
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: S32 = S4 * U4 + S32
template <>
struct Mma<
  gemm::GemmShape<8,8,32>,
  32,
  int4b_t,
  layout::RowMajor,
  uint4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  using Shape = gemm::GemmShape<8,8,32>;

  using ElementA = int4b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<int4b_t, 8>;

  using ElementB = uint4b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<uint4b_t, 8>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const & A = reinterpret_cast<unsigned const &>(a);
    unsigned const & B = reinterpret_cast<unsigned const &>(b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

    asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=r"(D[0]), "=r"(D[1])
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: S32 = U4 * U4 + S32
template <>
struct Mma<
  gemm::GemmShape<8,8,32>,
  32,
  uint4b_t,
  layout::RowMajor,
  uint4b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpMultiplyAddSaturate> {

  using Shape = gemm::GemmShape<8,8,32>;

  using ElementA = uint4b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<uint4b_t, 8>;

  using ElementB = uint4b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<uint4b_t, 8>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpMultiplyAddSaturate;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

    unsigned const & A = reinterpret_cast<unsigned const &>(a);
    unsigned const & B = reinterpret_cast<unsigned const &>(b);

    int const *C = reinterpret_cast<int const *>(&c);
    int *D = reinterpret_cast<int *>(&d);

    asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=r"(D[0]), "=r"(D[1])
        : "r"(A), "r"(B), "r"(C[0]), "r"(C[1]));

#else
    assert(0);
#endif
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////
//
// b1 ^ b1 + s32 => s32
//
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation: S32 = B1 ^ B1 + S32
template <>
struct Mma<
  gemm::GemmShape<8,8,128>,
  32,
  uint1b_t,
  layout::RowMajor,
  uint1b_t,
  layout::ColumnMajor,
  int,
  layout::RowMajor,
  OpXorPopc> {

  using Shape = gemm::GemmShape<8,8,128>;

  using ElementA = uint1b_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<uint1b_t, 32>;

  using ElementB = uint1b_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<uint1b_t, 32>;

  using ElementC = int;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<int, 2>;

  using Operator = OpXorPopc;

  /// Computes multiply-add
  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) const {

#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED)

#if defined(CUTLASS_ARCH_WMMA_ENABLED)
    using WmmaFragmentA = nvcuda::wmma::fragment<
        nvcuda::wmma::matrix_a,
        Shape::kM,
        Shape::kN,
        Shape::kK,
        nvcuda::wmma::experimental::precision::b1,
        nvcuda::wmma::row_major>;

    using WmmaFragmentB = nvcuda::wmma::fragment<
        nvcuda::wmma::matrix_b,
        Shape::kM,
        Shape::kN,
        Shape::kK,
        nvcuda::wmma::experimental::precision::b1,
        nvcuda::wmma::col_major>;

    using WmmaFragmentC = nvcuda::wmma::fragment<
        nvcuda::wmma::accumulator,
        Shape::kM,
        Shape::kN,
        Shape::kK,
        int>;

    WmmaFragmentA const & A = reinterpret_cast<WmmaFragmentA const &>(a);
    WmmaFragmentB const & B = reinterpret_cast<WmmaFragmentB const &>(b);

    WmmaFragmentC const & C = reinterpret_cast<WmmaFragmentC const &>(c);
    WmmaFragmentC & D = reinterpret_cast<WmmaFragmentC &>(d);

    nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,
                            nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
#else

    assert(0); // WMMA must be supported to issue binary matrix multiply-accumulate instructions.

#endif // defined(CUTLASS_ARCH_WMMA_ENABLED)

#else
    assert(0);
#endif

  }
};
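
// Illustrative scalar reference (not part of the original header): each s32 output of the
// binarized m8n8k128 operation is d = popcount(a_row XOR b_col) + c over a 128-bit row/column
// pair. A hypothetical sketch of one output element using four 32-bit words per operand:
#if 0
__device__ int bmma_xor_popc_reference_element(unsigned const a_row[4],
                                               unsigned const b_col[4], int c) {
  int d = c;
  for (int w = 0; w < 4; ++w) {
    d += __popc(a_row[w] ^ b_col[w]);  // count mismatching bits in this 32-bit word
  }
  return d;
}
#endif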

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace arch
} // namespace cutlass