CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
mma_sm70.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Matrix multiply
*/
#pragma once

#include <assert.h>

#include "mma.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))
#define CUTLASS_ARCH_MMA_SM70_SUPPORTED
#endif

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700))

#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1))
#define CUTLASS_ARCH_MMA_SM70_ENABLED
#endif

#endif
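// CUTLASS_ARCH_MMA_SM70_SUPPORTED means the toolchain (CUDA 10.1 or newer) can emit the
// mma.sync.m8n8k4 instruction; CUTLASS_ARCH_MMA_SM70_ENABLED additionally requires that
// the current compilation pass targets sm_70 or newer, gating the inline PTX below.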

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace arch {

////////////////////////////////////////////////////////////////////////////////
//
// Matrix multiply accumulate 884 - FP16 accumulation
//
////////////////////////////////////////////////////////////////////////////////
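// Each 8-by-8-by-4 specialization below is a quad-pair operation: the lane count of 8
// indicates that eight threads cooperate to issue one mma.sync. Per thread, A and B each
// contribute four half_t elements (two packed per 32-bit register) and the accumulator
// fragment holds eight elements.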
/// Matrix multiply-add operation: F16 = F16 * F16 + F16
template <>
struct Mma<
  gemm::GemmShape<8, 8, 4>,
  8,
  half_t,
  layout::ColumnMajor,
  half_t,
  layout::ColumnMajor,
  half_t,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<8, 8, 4>;

  using ElementA = half_t;
  using LayoutA = layout::ColumnMajor;
  using FragmentA = Array<half_t, 4>;

  using ElementB = half_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<half_t, 4>;

  using ElementC = half_t;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<half_t, 8>;

  using Operator = OpMultiplyAdd;

  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) {

#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)

    unsigned const *A = reinterpret_cast<unsigned const *>(&a);
    unsigned const *B = reinterpret_cast<unsigned const *>(&b);
    unsigned const *C = reinterpret_cast<unsigned const *>(&c);
    unsigned *D = reinterpret_cast<unsigned *>(&d);

    asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
    );

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: F16 = F16 * F16 + F16
template <>
struct Mma<
  gemm::GemmShape<8, 8, 4>,
  8,
  half_t,
  layout::ColumnMajor,
  half_t,
  layout::RowMajor,
  half_t,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<8, 8, 4>;

  using ElementA = half_t;
  using LayoutA = layout::ColumnMajor;
  using FragmentA = Array<half_t, 4>;

  using ElementB = half_t;
  using LayoutB = layout::RowMajor;
  using FragmentB = Array<half_t, 4>;

  using ElementC = half_t;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<half_t, 8>;

  using Operator = OpMultiplyAdd;

  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) {

#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)

    unsigned const *A = reinterpret_cast<unsigned const *>(&a);
    unsigned const *B = reinterpret_cast<unsigned const *>(&b);
    unsigned const *C = reinterpret_cast<unsigned const *>(&c);
    unsigned *D = reinterpret_cast<unsigned *>(&d);

    asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
    );

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: F16 = F16 * F16 + F16
template <>
struct Mma<
  gemm::GemmShape<8, 8, 4>,
  8,
  half_t,
  layout::RowMajor,
  half_t,
  layout::ColumnMajor,
  half_t,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<8, 8, 4>;

  using ElementA = half_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<half_t, 4>;

  using ElementB = half_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<half_t, 4>;

  using ElementC = half_t;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<half_t, 8>;

  using Operator = OpMultiplyAdd;

  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) {

#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)

    unsigned const *A = reinterpret_cast<unsigned const *>(&a);
    unsigned const *B = reinterpret_cast<unsigned const *>(&b);
    unsigned const *C = reinterpret_cast<unsigned const *>(&c);
    unsigned *D = reinterpret_cast<unsigned *>(&d);

    asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
    );

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: F16 = F16 * F16 + F16
template <>
struct Mma<
  gemm::GemmShape<8, 8, 4>,
  8,
  half_t,
  layout::RowMajor,
  half_t,
  layout::RowMajor,
  half_t,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<8, 8, 4>;

  using ElementA = half_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<half_t, 4>;

  using ElementB = half_t;
  using LayoutB = layout::RowMajor;
  using FragmentB = Array<half_t, 4>;

  using ElementC = half_t;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<half_t, 8>;

  using Operator = OpMultiplyAdd;

  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) {

#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)

    unsigned const *A = reinterpret_cast<unsigned const *>(&a);
    unsigned const *B = reinterpret_cast<unsigned const *>(&b);
    unsigned const *C = reinterpret_cast<unsigned const *>(&c);
    unsigned *D = reinterpret_cast<unsigned *>(&d);

    asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 {%0,%1,%2,%3}, {%4,%5}, {%6,%7}, {%8,%9,%10,%11};\n"
        : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])
    );

#else
    assert(0);
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////
//
// Matrix multiply accumulate 884 - FP32 accumulation
//
////////////////////////////////////////////////////////////////////////////////
/// Matrix multiply-add operation: F32 = F16 * F16 + F32
template <>
struct Mma<
  gemm::GemmShape<8, 8, 4>,
  8,
  half_t,
  layout::ColumnMajor,
  half_t,
  layout::ColumnMajor,
  float,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<8, 8, 4>;

  using ElementA = half_t;
  using LayoutA = layout::ColumnMajor;
  using FragmentA = Array<half_t, 4>;

  using ElementB = half_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<half_t, 4>;

  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<float, 8>;

  using Operator = OpMultiplyAdd;

  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) {

#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)

    unsigned const *A = reinterpret_cast<unsigned const *>(&a);
    unsigned const *B = reinterpret_cast<unsigned const *>(&b);
    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

    asm volatile(
        "mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
        "{%12,%13,%14,%15,%16,%17,%18,%19};\n"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]),
          "=f"(D[4]), "=f"(D[5]), "=f"(D[6]), "=f"(D[7])
        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]),
          "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
          "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])
    );

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: F32 = F16 * F16 + F32
template <>
struct Mma<
  gemm::GemmShape<8, 8, 4>,
  8,
  half_t,
  layout::ColumnMajor,
  half_t,
  layout::RowMajor,
  float,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<8, 8, 4>;

  using ElementA = half_t;
  using LayoutA = layout::ColumnMajor;
  using FragmentA = Array<half_t, 4>;

  using ElementB = half_t;
  using LayoutB = layout::RowMajor;
  using FragmentB = Array<half_t, 4>;

  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<float, 8>;

  using Operator = OpMultiplyAdd;

  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) {

#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)

    unsigned const *A = reinterpret_cast<unsigned const *>(&a);
    unsigned const *B = reinterpret_cast<unsigned const *>(&b);
    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

    asm volatile(
        "mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
        "{%12,%13,%14,%15,%16,%17,%18,%19};\n"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]),
          "=f"(D[4]), "=f"(D[5]), "=f"(D[6]), "=f"(D[7])
        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]),
          "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
          "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])
    );

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: F32 = F16 * F16 + F32
template <>
struct Mma<
  gemm::GemmShape<8, 8, 4>,
  8,
  half_t,
  layout::RowMajor,
  half_t,
  layout::ColumnMajor,
  float,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<8, 8, 4>;

  using ElementA = half_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<half_t, 4>;

  using ElementB = half_t;
  using LayoutB = layout::ColumnMajor;
  using FragmentB = Array<half_t, 4>;

  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<float, 8>;

  using Operator = OpMultiplyAdd;

  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) {

#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)

    unsigned const *A = reinterpret_cast<unsigned const *>(&a);
    unsigned const *B = reinterpret_cast<unsigned const *>(&b);
    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

    asm volatile(
        "mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
        "{%12,%13,%14,%15,%16,%17,%18,%19};\n"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]),
          "=f"(D[4]), "=f"(D[5]), "=f"(D[6]), "=f"(D[7])
        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]),
          "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
          "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])
    );

#else
    assert(0);
#endif
  }
};

/// Matrix multiply-add operation: F32 = F16 * F16 + F32
template <>
struct Mma<
  gemm::GemmShape<8, 8, 4>,
  8,
  half_t,
  layout::RowMajor,
  half_t,
  layout::RowMajor,
  float,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<8, 8, 4>;

  using ElementA = half_t;
  using LayoutA = layout::RowMajor;
  using FragmentA = Array<half_t, 4>;

  using ElementB = half_t;
  using LayoutB = layout::RowMajor;
  using FragmentB = Array<half_t, 4>;

  using ElementC = float;
  using LayoutC = layout::RowMajor;
  using FragmentC = Array<float, 8>;

  using Operator = OpMultiplyAdd;

  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,
    FragmentA const &a,
    FragmentB const &b,
    FragmentC const &c
  ) {

#if defined(CUTLASS_ARCH_MMA_SM70_ENABLED)

    unsigned const *A = reinterpret_cast<unsigned const *>(&a);
    unsigned const *B = reinterpret_cast<unsigned const *>(&b);
    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

    asm volatile(
        "mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32 {%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, "
        "{%12,%13,%14,%15,%16,%17,%18,%19};\n"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]),
          "=f"(D[4]), "=f"(D[5]), "=f"(D[6]), "=f"(D[7])
        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(B[1]),
          "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]),
          "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])
    );

#else
    assert(0);
#endif
  }
};

////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation specialized for the entire warp: inherits the
/// per-quad-pair 8x8x4 implementation above unchanged.
template <
  typename LayoutA,
  typename LayoutB,
  typename ElementC,
  typename LayoutC,
  typename Operator
>
struct Mma<
  gemm::GemmShape<16, 16, 4>,
  32,
  half_t,
  LayoutA,
  half_t,
  LayoutB,
  ElementC,
  LayoutC,
  Operator
> :
  public Mma<
    gemm::GemmShape<8, 8, 4>,
    8,
    half_t,
    LayoutA,
    half_t,
    LayoutB,
    ElementC,
    LayoutC,
    Operator> {

  using Shape = gemm::GemmShape<16, 16, 4>;
};

////////////////////////////////////////////////////////////////////////////////

} // namespace arch
} // namespace cutlass
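Usage sketch (an illustrative addition, not part of the header): assuming the translation unit is compiled for sm_70 or newer so that CUTLASS_ARCH_MMA_SM70_ENABLED is defined, a kernel can invoke one of the specializations above as follows. The kernel name and fill values are hypothetical; a real kernel loads fragments from memory according to the Volta mma.sync thread mapping.

#include <cuda_runtime.h>

#include "cutlass/arch/mma_sm70.h"
#include "cutlass/array.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"

// Hypothetical demo kernel: every quad-pair of the warp issues one 8x8x4
// FP32-accumulate mma.sync using the col.row specialization above.
__global__ void mma_sm70_demo() {
  using Mma = cutlass::arch::Mma<
      cutlass::gemm::GemmShape<8, 8, 4>,
      8,
      cutlass::half_t, cutlass::layout::ColumnMajor,   // A
      cutlass::half_t, cutlass::layout::RowMajor,      // B
      float, cutlass::layout::RowMajor,                // C/D
      cutlass::arch::OpMultiplyAdd>;

  Mma::FragmentA a;
  Mma::FragmentB b;
  Mma::FragmentC c;
  Mma::FragmentC d;

  a.fill(cutlass::half_t(1.0f));   // placeholder operand values
  b.fill(cutlass::half_t(1.0f));
  c.clear();                       // accumulate into zeros

  Mma mma;
  mma(d, a, b, c);                 // d = a * b + c for this thread's fragment slice
}

// Launch with at least one full warp, e.g.: mma_sm70_demo<<<1, 32>>>();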