CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
arch/mma_sm60.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include <cuda_fp16.h>

#include "cutlass/arch/mma.h"

#include "cutlass/layout/matrix.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace arch {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation
template <typename LayoutA, typename LayoutB, typename LayoutC>
struct Mma<
  gemm::GemmShape<2, 1, 1>,
  1,
  half_t,
  LayoutA,
  half_t,
  LayoutB,
  half_t,
  LayoutC,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<2, 1, 1>;

  CUTLASS_HOST_DEVICE
  void operator()(
    Array<half_t, 2> &d,
    Array<half_t, 2> const &a,
    Array<half_t, 1> const &b,
    Array<half_t, 2> const &c
  ) {

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

    __half2 const & A = reinterpret_cast<__half2 const &>(a);
    __half2 B = __half2half2(reinterpret_cast<__half const &>(b));
    __half2 const & C = reinterpret_cast<__half2 const &>(c);

    __half2 D = __hfma2(A, B, C);

    d = reinterpret_cast<Array<half_t, 2> &>(D);

#else
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 2; ++i) {
      d[i] = a[i] * b[0] + c[i];
    }
#endif
  }
};
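// Illustrative usage sketch, not part of the original header: shows how the 2x1x1
// specialization above is selected and invoked. The concrete layouts below are an
// arbitrary choice for the example (this partial specialization accepts any
// LayoutA/LayoutB/LayoutC), and the function name example_mma_2x1x1 is hypothetical.
CUTLASS_HOST_DEVICE
void example_mma_2x1x1(
    Array<half_t, 2> &d,
    Array<half_t, 2> const &a,
    Array<half_t, 1> const &b,
    Array<half_t, 2> const &c) {

  Mma<gemm::GemmShape<2, 1, 1>, 1,
      half_t, layout::ColumnMajor,
      half_t, layout::ColumnMajor,
      half_t, layout::ColumnMajor,
      OpMultiplyAdd> mma;

  // Computes d[i] = a[i] * b[0] + c[i] -- a single HFMA2 instruction on SM60 and above.
  mma(d, a, b, c);
}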

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation
template <typename LayoutA, typename LayoutB>
struct Mma<
  gemm::GemmShape<1, 2, 1>,
  1,
  half_t,
  LayoutA,
  half_t,
  LayoutB,
  half_t,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<1, 2, 1>;

  CUTLASS_HOST_DEVICE
  void operator()(
    Array<half_t, 2> &d,
    Array<half_t, 1> const &a,
    Array<half_t, 2> const &b,
    Array<half_t, 2> const &c
  ) {

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

    __half2 const & A = __half2half2(reinterpret_cast<__half const &>(a));
    __half2 B = reinterpret_cast<__half2 const &>(b);
    __half2 const & C = reinterpret_cast<__half2 const &>(c);

    __half2 D = __hfma2(A, B, C);

    d = reinterpret_cast<Array<half_t, 2> &>(D);

#else
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 2; ++i) {
      d[i] = a[0] * b[i] + c[i];
    }
#endif
  }
};
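// Illustrative usage sketch, not part of the original header: because operator() is
// CUTLASS_HOST_DEVICE, host compilation (where __CUDA_ARCH__ is undefined) takes the
// scalar fallback loop, so the 1x2x1 specialization can be sanity-checked in plain C++.
// The function name example_check_mma_1x2x1 and the A/B layouts are assumptions for
// the example; LayoutC must be layout::RowMajor to match this specialization.
inline void example_check_mma_1x2x1() {
  Mma<gemm::GemmShape<1, 2, 1>, 1,
      half_t, layout::ColumnMajor,
      half_t, layout::RowMajor,
      half_t, layout::RowMajor,
      OpMultiplyAdd> mma;

  Array<half_t, 1> a;
  Array<half_t, 2> b, c, d;
  a[0] = half_t(2.0f);                       // scalar column of A
  b[0] = half_t(3.0f); b[1] = half_t(5.0f);  // 1x2 row of B
  c[0] = half_t(1.0f); c[1] = half_t(1.0f);  // accumulator

  mma(d, a, b, c);
  // Row-major 1x2 result: d = {a[0]*b[0] + c[0], a[0]*b[1] + c[1]} = {7, 11}
}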

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation
template <>
struct Mma <
  gemm::GemmShape<2, 2, 1>,
  1,
  half_t,
  layout::ColumnMajor,
  half_t,
  layout::RowMajor,
  half_t,
  layout::ColumnMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<2, 2, 1>;

  CUTLASS_HOST_DEVICE
  void operator()(
    Array<half_t, 4> &d,
    Array<half_t, 2> const &a,
    Array<half_t, 2> const &b,
    Array<half_t, 4> const &c
  ) {

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

    __half2 const & A = reinterpret_cast<__half2 const &>(a);
    __half2 Blo = __low2half2(reinterpret_cast<__half2 const &>(b));
    __half2 Bhi = __high2half2(reinterpret_cast<__half2 const &>(b));

    __half2 const *C = reinterpret_cast<__half2 const *>(&c);

    __half2 Dlo = __hfma2(A, Blo, C[0]);
    __half2 Dhi = __hfma2(A, Bhi, C[1]);

    Array<half_t, 2> * D = reinterpret_cast<Array<half_t, 2> *>(&d);

    D[0] = reinterpret_cast<Array<half_t, 2> const &>(Dlo);
    D[1] = reinterpret_cast<Array<half_t, 2> const &>(Dhi);

#else
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < 2; ++j) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < 2; ++i) {
        d[i + 2 * j] = a[i] * b[j] + c[i + 2 * j];
      }
    }
#endif
  }
};
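// Illustrative usage sketch, not part of the original header: a 2x2 rank-1 update
// (outer product of a 2x1 column of A with a 1x2 row of B) accumulated into a
// column-major 2x2 tile, again checked on the host via the fallback path. The
// function name example_check_mma_2x2x1_colmajor is hypothetical.
inline void example_check_mma_2x2x1_colmajor() {
  Mma<gemm::GemmShape<2, 2, 1>, 1,
      half_t, layout::ColumnMajor,
      half_t, layout::RowMajor,
      half_t, layout::ColumnMajor,
      OpMultiplyAdd> mma;

  Array<half_t, 2> a, b;
  Array<half_t, 4> c, d;
  a[0] = half_t(1.0f); a[1] = half_t(2.0f);   // 2x1 column of A
  b[0] = half_t(3.0f); b[1] = half_t(4.0f);   // 1x2 row of B
  for (int i = 0; i < 4; ++i) c[i] = half_t(0.0f);

  mma(d, a, b, c);
  // Column-major tile: d[i + 2 * j] = a[i] * b[j], i.e. d = {3, 6, 4, 8}
}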

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Matrix multiply-add operation
template <>
struct Mma<
  gemm::GemmShape<2, 2, 1>,
  1,
  half_t,
  layout::ColumnMajor,
  half_t,
  layout::RowMajor,
  half_t,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<2, 2, 1>;

  CUTLASS_HOST_DEVICE
  void operator()(
    Array<half_t, 4> &d,
    Array<half_t, 2> const &a,
    Array<half_t, 2> const &b,
    Array<half_t, 4> const &c
  ) {

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600))

    __half2 Alo = __low2half2(reinterpret_cast<__half2 const &>(a));
    __half2 Ahi = __high2half2(reinterpret_cast<__half2 const &>(a));
    __half2 const & B = reinterpret_cast<__half2 const &>(b);

    __half2 const *C = reinterpret_cast<__half2 const *>(&c);

    __half2 Dlo = __hfma2(Alo, B, C[0]);
    __half2 Dhi = __hfma2(Ahi, B, C[1]);

    Array<half_t, 2> * D = reinterpret_cast<Array<half_t, 2> *>(&d);

    D[0] = reinterpret_cast<Array<half_t, 2> &>(Dlo);
    D[1] = reinterpret_cast<Array<half_t, 2> &>(Dhi);
#else
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 2; ++i) {
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < 2; ++j) {
        d[i * 2 + j] = a[i] * b[j] + c[i * 2 + j];
      }
    }
#endif
  }
};
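// Descriptive note, not part of the original header: this variant differs from the
// column-major specialization above only in how the 2x2 accumulator is linearized.
// With the same inputs as the earlier sketch (a = {1, 2}, b = {3, 4}, c = 0), the
// row-major tile is d[i * 2 + j] = a[i] * b[j], i.e. d = {3, 4, 6, 8} rather than
// the column-major {3, 6, 4, 8}.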

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace arch
} // namespace cutlass