CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
arch/mma_sm50.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/arch/mma.h"
32 #include "cutlass/complex.h"
33 
34 #include "cutlass/layout/matrix.h"
35 #include "cutlass/gemm/gemm.h"
36 
38 
39 namespace cutlass {
40 namespace arch {
41 
43 
45 template <
47  typename LayoutA,
49  typename LayoutB,
51  typename LayoutC
52 >
53 struct Mma<gemm::GemmShape<1, 1, 1>, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> {
54 
56 
58  void operator()(
59  Array<float, 1> &d,
60  Array<float, 1> const &a,
61  Array<float, 1> const &b,
62  Array<float, 1> const &c
63  ) {
64  d[0] = a[0] * b[0] + c[0];
65  }
66 };
67 
69 
71 template <
73  typename LayoutA,
75  typename LayoutB,
77  typename LayoutC
78 >
79 struct Mma<gemm::GemmShape<1, 1, 1>, 1, double, LayoutA, double, LayoutB, double, LayoutC, OpMultiplyAdd> {
80 
82 
84  void operator()(
85  Array<double, 1> &d,
86  Array<double, 1> const &a,
87  Array<double, 1> const &b,
88  Array<double, 1> const &c
89  ) {
90 
91  d[0] = a[0] * b[0] + c[0];
92  }
93 };
94 
96 
98 template <
100  typename LayoutA,
102  typename LayoutB,
104  typename LayoutC
105 >
106 struct Mma<gemm::GemmShape<1, 1, 1>, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd> {
107 
109 
112  Array<int, 1> &d,
113  Array<int, 1> const &a,
114  Array<int, 1> const &b,
115  Array<int, 1> const &c
116  ) {
117 
118  d[0] = a[0] * b[0] + c[0];
119  }
120 };
121 
123 
125 template <
127  typename LayoutA,
129  typename LayoutB,
131  typename LayoutC
132 >
133 struct Mma<
134  gemm::GemmShape<1, 1, 1>,
135  1,
136  complex<float>,
137  LayoutA,
138  complex<float>,
139  LayoutB,
140  complex<float>,
141  LayoutC,
142  OpMultiplyAdd> {
143 
145 
148  Array<complex<float>, 1> &d,
149  Array<complex<float>, 1> const &a,
150  Array<complex<float>, 1> const &b,
151  Array<complex<float>, 1> const &c
152  ) {
153 
154  d[0].real() = a[0].real() * b[0].real() + c[0].real();
155  d[0].imag() = a[0].imag() * b[0].real() + c[0].imag();
156  d[0].real() = -a[0].imag() * b[0].imag() + d[0].real();
157  d[0].imag() = a[0].real() * b[0].imag() + d[0].imag();
158  }
159 };
160 
162 
164 template <
166  typename LayoutA,
168  typename LayoutB,
170  typename LayoutC
171 >
172 struct Mma<
173  gemm::GemmShape<1, 1, 1>,
174  1,
175  complex<float>,
176  LayoutA,
177  float,
178  LayoutB,
179  complex<float>,
180  LayoutC,
181  OpMultiplyAdd> {
182 
184 
187  Array<complex<float>, 1> &d,
188  Array<complex<float>, 1> const &a,
189  Array<float, 1> const &b,
190  Array<complex<float>, 1> const &c
191  ) {
192 
193  d[0].real() = a[0].real() * b[0] + c[0].real();
194  d[0].imag() = a[0].imag() * b[0] + c[0].imag();
195  }
196 };
197 
199 
201 template <
203  typename LayoutA,
205  typename LayoutB,
207  typename LayoutC
208 >
209 struct Mma<
210  gemm::GemmShape<1, 1, 1>,
211  1,
212  float,
213  LayoutA,
214  complex<float>,
215  LayoutB,
216  complex<float>,
217  LayoutC,
218  OpMultiplyAdd> {
219 
221 
224  Array<complex<float>, 1> &d,
225  Array<float, 1> const &a,
226  Array<complex<float>, 1> const &b,
227  Array<complex<float>, 1> const &c
228  ) {
229 
230  d[0].real() = a[0] * b[0].real() + c[0].real();
231  d[0].imag() = a[0] * b[0].imag() + d[0].imag();
232  }
233 };
234 
236 
238 template <
240  typename LayoutA,
242  typename LayoutB,
244  typename LayoutC
245 >
246 struct Mma<
247  gemm::GemmShape<1, 1, 1>,
248  1,
249  complex<double>,
250  LayoutA,
251  complex<double>,
252  LayoutB,
253  complex<double>,
254  LayoutC,
255  OpMultiplyAdd> {
256 
258 
261  Array<complex<double>, 1> &d,
262  Array<complex<double>, 1> const &a,
263  Array<complex<double>, 1> const &b,
264  Array<complex<double>, 1> const &c
265  ) {
266 
267  d[0].real() = a[0].real() * b[0].real() + c[0].real();
268  d[0].imag() = a[0].imag() * b[0].real() + c[0].imag();
269  d[0].real() = -a[0].imag() * b[0].imag() + d[0].real();
270  d[0].imag() = a[0].real() * b[0].imag() + d[0].imag();
271  }
272 };
273 
275 template <
277  typename LayoutA,
279  typename LayoutB,
281  typename LayoutC
282 >
283 struct Mma<
284  gemm::GemmShape<1, 1, 1>,
285  1,
286  complex<double>,
287  LayoutA,
288  double,
289  LayoutB,
290  complex<double>,
291  LayoutC,
292  OpMultiplyAdd> {
293 
295 
298  Array<complex<double>, 1> &d,
299  Array<complex<double>, 1> const &a,
300  Array<double, 1> const &b,
301  Array<complex<double>, 1> const &c
302  ) {
303 
304  d[0].real() = a[0].real() * b[0] + c[0].real();
305  d[0].imag() = a[0].imag() * b[0] + c[0].imag();
306  }
307 };
308 
310 template <
312  typename LayoutA,
314  typename LayoutB,
316  typename LayoutC
317 >
318 struct Mma<
319  gemm::GemmShape<1, 1, 1>,
320  1,
321  double,
322  LayoutA,
323  complex<double>,
324  LayoutB,
325  complex<double>,
326  LayoutC,
327  OpMultiplyAdd> {
328 
330 
333  Array<complex<double>, 1> &d,
334  Array<double, 1> const &a,
335  Array<complex<double>, 1> const &b,
336  Array<complex<double>, 1> const &c
337  ) {
338 
339  d[0].real() = a[0] * b[0].real() + c[0].real();
340  d[0].imag() = a[0] * b[0].imag() + d[0].imag();
341  }
342 };
343 
345 
347 template <
349  typename LayoutA,
351  typename LayoutB,
353  typename LayoutC
354 >
355 struct Mma<gemm::GemmShape<1, 1, 1>, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd> {
356 
358 
361  Array<float, 1> &d,
362  Array<half_t, 1> const &a,
363  Array<half_t, 1> const &b,
364  Array<float, 1> const &c
365  ) {
366  d[0] = float(a[0]) * float(b[0]) + c[0];
367  }
368 };
369 
371 
372 }
373 }
CUTLASS_HOST_DEVICE void operator()(Array< int, 1 > &d, Array< int, 1 > const &a, Array< int, 1 > const &b, Array< int, 1 > const &c)
Definition: arch/mma_sm50.h:111
CUTLASS_HOST_DEVICE void operator()(Array< complex< double >, 1 > &d, Array< complex< double >, 1 > const &a, Array< double, 1 > const &b, Array< complex< double >, 1 > const &c)
Definition: arch/mma_sm50.h:297
Definition: aligned_buffer.h:35
IEEE half-precision floating-point type.
Definition: half.h:126
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE void operator()(Array< complex< float >, 1 > &d, Array< complex< float >, 1 > const &a, Array< complex< float >, 1 > const &b, Array< complex< float >, 1 > const &c)
Definition: arch/mma_sm50.h:147
CUTLASS_HOST_DEVICE void operator()(Array< complex< double >, 1 > &d, Array< complex< double >, 1 > const &a, Array< complex< double >, 1 > const &b, Array< complex< double >, 1 > const &c)
Definition: arch/mma_sm50.h:260
CUTLASS_HOST_DEVICE void operator()(Array< float, 1 > &d, Array< half_t, 1 > const &a, Array< half_t, 1 > const &b, Array< float, 1 > const &c)
Definition: arch/mma_sm50.h:360
Templates exposing architecture support for multiply-add operations.
CUTLASS_HOST_DEVICE void operator()(Array< complex< float >, 1 > &d, Array< complex< float >, 1 > const &a, Array< float, 1 > const &b, Array< complex< float >, 1 > const &c)
Definition: arch/mma_sm50.h:186
CUTLASS_HOST_DEVICE void operator()(Array< float, 1 > &d, Array< float, 1 > const &a, Array< float, 1 > const &b, Array< float, 1 > const &c)
Definition: arch/mma_sm50.h:58
CUTLASS_HOST_DEVICE void operator()(Array< complex< float >, 1 > &d, Array< float, 1 > const &a, Array< complex< float >, 1 > const &b, Array< complex< float >, 1 > const &c)
Definition: arch/mma_sm50.h:223
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Definition: complex.h:92
Defines layout functions used by TensorRef and derived classes.
Matrix multiply-add operation.
Definition: arch/mma.h:92
CUTLASS_HOST_DEVICE void operator()(Array< complex< double >, 1 > &d, Array< double, 1 > const &a, Array< complex< double >, 1 > const &b, Array< complex< double >, 1 > const &c)
Definition: arch/mma_sm50.h:332
CUTLASS_HOST_DEVICE void operator()(Array< double, 1 > &d, Array< double, 1 > const &a, Array< double, 1 > const &b, Array< double, 1 > const &c)
Definition: arch/mma_sm50.h:84