CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
default_gemm_configuration.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/cutlass.h"
32 #include "cutlass/numeric_types.h"
33 #include "cutlass/arch/arch.h"
34 #include "cutlass/arch/mma.h"
35 #include "cutlass/arch/wmma.h"
36 
37 #include "cutlass/gemm/gemm.h"
40 
42 
43 namespace cutlass {
44 namespace gemm {
45 namespace device {
46 
48 
/// Primary template for DefaultGemmConfiguration, keyed on the operator class,
/// the architecture tag, and the element types of A, B, C and the accumulator.
/// The partial specializations that follow supply the default operand alignments,
/// stage count, epilogue output operator and math Operator for specific combinations.
49 template <
50  typename OperatorClass,
51  typename ArchTag,
52  typename ElementA,
53  typename ElementB,
54  typename ElementC,
55  typename ElementAccumulator
56 >
58 
60 
/// Partial specialization for the SIMT operator class (arch::OpClassSimt),
/// applicable to any ArchTag and any A/B/C/accumulator element types.
61 template <
62  typename ArchTag,
63  typename ElementA,
64  typename ElementB,
65  typename ElementC,
66  typename ElementAccumulator>
68  arch::OpClassSimt,
69  ArchTag,
70  ElementA,
71  ElementB,
72  ElementC,
73  ElementAccumulator> {
74 
// A and B operands use an alignment of a single element.
75  static int const kAlignmentA = 1;
76  static int const kAlignmentB = 1;
80  static int const kStages = 2;
81 
// Epilogue output operator arguments: ElementC output, 1 (presumably elements per
// operation), and ElementAccumulator for the remaining two type parameters.
83  ElementC,
84  1,
85  ElementAccumulator,
86  ElementAccumulator
87  >;
88 
// Math operator: plain arch::OpMultiplyAdd (not the saturating variant).
89  using Operator = arch::OpMultiplyAdd;
90 };
91 
93 
/// Partial specialization for SIMT GEMM with int8_t A and B operands and an
/// int32_t accumulator, for any ArchTag and any output element type ElementC.
94 template <
95  typename ArchTag,
96  typename ElementC>
97 struct DefaultGemmConfiguration<arch::OpClassSimt, ArchTag, int8_t, int8_t, ElementC, int32_t> {
98 
// A and B use an alignment of 4 elements (4 x 8-bit int8_t = 32 bits).
99  static int const kAlignmentA = 4;
100  static int const kAlignmentB = 4;
104  static int const kStages = 2;
105 
// Epilogue output operator arguments: ElementC output, 1 (presumably elements per
// operation), then int32_t and float as the remaining type parameters.
107  ElementC,
108  1,
109  int32_t,
110  float
111  >;
112 
113  using Operator = arch::OpMultiplyAdd;
114 };
115 
117 
/// Partial specialization for the WMMA (warp matrix multiply-add) operator class
/// (arch::OpClassWmmaTensorOp), for any ArchTag and any element types.
118 template <
119  typename ArchTag,
120  typename ElementA,
121  typename ElementB,
122  typename ElementC,
123  typename ElementAccumulator>
125  arch::OpClassWmmaTensorOp,
126  ArchTag,
127  ElementA,
128  ElementB,
129  ElementC,
130  ElementAccumulator> {
131 
// Alignment in elements = 128 bits divided by the bit width of each operand's own type.
132  static int const kAlignmentA = 128 / sizeof_bits<ElementA>::value;
133  static int const kAlignmentB = 128 / sizeof_bits<ElementB>::value;
134 
135  static int const kStages = 2;
136 
// Epilogue output operator arguments (partially shown): ElementC output and
// ElementAccumulator for the remaining two type parameters.
138  ElementC,
140  ElementAccumulator,
141  ElementAccumulator
142  >;
143 
144  using Operator = arch::OpMultiplyAdd;
145 };
146 
148 
/// Partial specialization for the Tensor Op operator class (arch::OpClassTensorOp)
/// on arch::Sm70, for any A/B/C/accumulator element types.
149 template <
150  typename ElementA,
151  typename ElementB,
152  typename ElementC,
153  typename ElementAccumulator>
155  arch::OpClassTensorOp,
156  arch::Sm70,
157  ElementA,
158  ElementB,
159  ElementC,
160  ElementAccumulator> {
161 
// Alignment in elements = 128 bits divided by the bit width of each operand's own type.
162  static int const kAlignmentA = 128 / sizeof_bits<ElementA>::value;
163  static int const kAlignmentB = 128 / sizeof_bits<ElementB>::value;
164 
168  static int const kStages = 2;
169 
// Epilogue output operator arguments (partially shown): ElementC output and
// ElementAccumulator for the remaining two type parameters.
171  ElementC,
173  ElementAccumulator,
174  ElementAccumulator
175  >;
176 
177  using Operator = arch::OpMultiplyAdd;
178 };
179 
181 
/// Partial specialization for the Tensor Op operator class (arch::OpClassTensorOp)
/// on arch::Sm75, for any A/B/C/accumulator element types not covered by the
/// integer specializations.
182 template <
183  typename ElementA,
184  typename ElementB,
185  typename ElementC,
186  typename ElementAccumulator>
188  arch::OpClassTensorOp,
189  arch::Sm75,
190  ElementA,
191  ElementB,
192  ElementC,
193  ElementAccumulator> {
194 
// Alignment in elements = 128 bits divided by the bit width of each operand's own type.
// kAlignmentB must use ElementB (it previously used ElementA). That gave B the wrong
// alignment whenever A and B differ in width, and disagreed with the Sm70 and WMMA
// specializations.
195  static int const kAlignmentA = 128 / sizeof_bits<ElementA>::value;
196  static int const kAlignmentB = 128 / sizeof_bits<ElementB>::value;
200  static int const kStages = 2;
201 
// Epilogue output operator arguments (partially shown): ElementC output and
// ElementAccumulator for the remaining two type parameters.
203  ElementC,
205  ElementAccumulator,
206  ElementAccumulator
207  >;
208 
// Selects arch::OpMultiplyAddSaturate when ElementA is int8_t, int4b_t, uint8_t or
// uint4b_t, and arch::OpMultiplyAdd otherwise.
209  using Operator = typename platform::conditional<
214  arch::OpMultiplyAddSaturate, arch::OpMultiplyAdd>::type;
215 };
216 
218 
/// Sm75 Tensor Op specialization for int8_t A x int8_t B with an int32_t accumulator,
/// for any output element type ElementC.
219 template <
220  typename ElementC>
222  arch::OpClassTensorOp,
223  arch::Sm75,
224  int8_t,
225  int8_t,
226  ElementC,
227  int32_t> {
228 
// Alignment in elements = 128 bits / 8 bits = 16 for both operands.
229  static int const kAlignmentA = 128 / sizeof_bits<int8_t>::value;
230  static int const kAlignmentB = 128 / sizeof_bits<int8_t>::value;
231 
235  static int const kStages = 2;
236 
// Epilogue output operator arguments: ElementC output, 128 / sizeof_bits<ElementC>
// (presumably elements per operation), then int32_t and float.
238  ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
239 
// Integer inputs use the saturating multiply-add operator.
240  using Operator = arch::OpMultiplyAddSaturate;
241 };
242 
244 
/// Sm75 Tensor Op specialization for int8_t A x uint8_t B with an int32_t accumulator,
/// for any output element type ElementC.
245 template <
246  typename ElementC>
248  arch::OpClassTensorOp,
249  arch::Sm75,
250  int8_t,
251  uint8_t,
252  ElementC,
253  int32_t> {
254 
// Alignment in elements = 128 bits / 8 bits = 16 for both operands.
255  static int const kAlignmentA = 128 / sizeof_bits<int8_t>::value;
256  static int const kAlignmentB = 128 / sizeof_bits<uint8_t>::value;
257 
261  static int const kStages = 2;
262 
// Epilogue output operator arguments: ElementC output, 128 / sizeof_bits<ElementC>
// (presumably elements per operation), then int32_t and float.
264  ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
265 
// Integer inputs use the saturating multiply-add operator.
266  using Operator = arch::OpMultiplyAddSaturate;
267 };
268 
270 
/// Sm75 Tensor Op specialization for uint8_t A x int8_t B with an int32_t accumulator,
/// for any output element type ElementC.
271 template <
272  typename ElementC>
274  arch::OpClassTensorOp,
275  arch::Sm75,
276  uint8_t,
277  int8_t,
278  ElementC,
279  int32_t> {
280 
// Alignment in elements = 128 bits / 8 bits = 16 for both operands.
281  static int const kAlignmentA = 128 / sizeof_bits<uint8_t>::value;
282  static int const kAlignmentB = 128 / sizeof_bits<int8_t>::value;
283 
287  static int const kStages = 2;
288 
// Epilogue output operator arguments: ElementC output, 128 / sizeof_bits<ElementC>
// (presumably elements per operation), then int32_t and float.
290  ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
291 
// Integer inputs use the saturating multiply-add operator.
292  using Operator = arch::OpMultiplyAddSaturate;
293 };
294 
296 
/// Sm75 Tensor Op specialization for uint8_t A x uint8_t B with an int32_t accumulator,
/// for any output element type ElementC.
297 template <
298  typename ElementC>
300  arch::OpClassTensorOp,
301  arch::Sm75,
302  uint8_t,
303  uint8_t,
304  ElementC,
305  int32_t> {
306 
// Alignment in elements = 128 bits / 8 bits = 16 for both operands.
307  static int const kAlignmentA = 128 / sizeof_bits<uint8_t>::value;
308  static int const kAlignmentB = 128 / sizeof_bits<uint8_t>::value;
309 
313  static int const kStages = 2;
314 
// Epilogue output operator arguments: ElementC output, 128 / sizeof_bits<ElementC>
// (presumably elements per operation), then int32_t and float.
316  ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
317 
// Integer inputs use the saturating multiply-add operator.
318  using Operator = arch::OpMultiplyAddSaturate;
319 };
320 
322 
/// Sm75 Tensor Op specialization for int4b_t A x int4b_t B (4-bit signed integers)
/// with an int32_t accumulator, for any output element type ElementC.
323 template <
324  typename ElementC>
326  arch::OpClassTensorOp,
327  arch::Sm75,
328  int4b_t,
329  int4b_t,
330  ElementC,
331  int32_t> {
332 
// Alignment in elements = 128 bits divided by the operand bit width.
333  static int const kAlignmentA = 128 / sizeof_bits<int4b_t>::value;
334  static int const kAlignmentB = 128 / sizeof_bits<int4b_t>::value;
335 
339  static int const kStages = 2;
340 
// Epilogue output operator arguments: ElementC output, 128 / sizeof_bits<ElementC>
// (presumably elements per operation), then int32_t and float.
342  ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
343 
// Integer inputs use the saturating multiply-add operator.
344  using Operator = arch::OpMultiplyAddSaturate;
345 };
346 
348 
/// Sm75 Tensor Op specialization for int4b_t A x uint4b_t B with an int32_t
/// accumulator, for any output element type ElementC.
349 template <
350  typename ElementC>
352  arch::OpClassTensorOp,
353  arch::Sm75,
354  int4b_t,
355  uint4b_t,
356  ElementC,
357  int32_t> {
358 
// Alignment in elements = 128 bits divided by each operand's bit width.
359  static int const kAlignmentA = 128 / sizeof_bits<int4b_t>::value;
360  static int const kAlignmentB = 128 / sizeof_bits<uint4b_t>::value;
361 
365  static int const kStages = 2;
366 
// Epilogue output operator arguments: ElementC output, 128 / sizeof_bits<ElementC>
// (presumably elements per operation), then int32_t and float.
368  ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
369 
// Integer inputs use the saturating multiply-add operator.
370  using Operator = arch::OpMultiplyAddSaturate;
371 };
372 
374 
/// Sm75 Tensor Op specialization for uint4b_t A x int4b_t B with an int32_t
/// accumulator, for any output element type ElementC.
375 template <
376  typename ElementC>
378  arch::OpClassTensorOp,
379  arch::Sm75,
380  uint4b_t,
381  int4b_t,
382  ElementC,
383  int32_t> {
384 
// Alignment in elements = 128 bits divided by each operand's bit width.
385  static int const kAlignmentA = 128 / sizeof_bits<uint4b_t>::value;
386  static int const kAlignmentB = 128 / sizeof_bits<int4b_t>::value;
387 
391  static int const kStages = 2;
392 
// Epilogue output operator arguments: ElementC output, 128 / sizeof_bits<ElementC>
// (presumably elements per operation), then int32_t and float.
394  ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
395 
// Integer inputs use the saturating multiply-add operator.
396  using Operator = arch::OpMultiplyAddSaturate;
397 };
398 
400 
/// Sm75 Tensor Op specialization for uint4b_t A x uint4b_t B with an int32_t
/// accumulator, for any output element type ElementC.
401 template <
402  typename ElementC>
404  arch::OpClassTensorOp,
405  arch::Sm75,
406  uint4b_t,
407  uint4b_t,
408  ElementC,
409  int32_t> {
410 
// Alignment in elements = 128 bits divided by the operand bit width.
411  static int const kAlignmentA = 128 / sizeof_bits<uint4b_t>::value;
412  static int const kAlignmentB = 128 / sizeof_bits<uint4b_t>::value;
413 
417  static int const kStages = 2;
418 
// Epilogue output operator arguments: ElementC output, 128 / sizeof_bits<ElementC>
// (presumably elements per operation), then int32_t and float.
420  ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;
421 
// Integer inputs use the saturating multiply-add operator.
422  using Operator = arch::OpMultiplyAddSaturate;
423 };
424 
426 } // namespace device
427 } // namespace gemm
428 } // namespace cutlass
429 
Definition: aligned_buffer.h:35
Definition: linear_combination.h:56
std::is_same (false specialization)
Definition: platform.h:394
Definition: linear_combination_clamp.h:58
4-bit signed integer type
Definition: integer_subbyte.h:42
Functor performing linear scaling operations used by epilogues. Values are clamped before converting ...
Definition: arch.h:46
Defines common types used for all GEMM-like operators.
typename platform::conditional< (platform::is_same< ElementA, int8_t >::value||platform::is_same< ElementA, int4b_t >::value||platform::is_same< ElementA, uint8_t >::value||platform::is_same< ElementA, uint4b_t >::value), arch::OpMultiplyAddSaturate, arch::OpMultiplyAdd >::type Operator
Definition: default_gemm_configuration.h:214
Templates exposing architecture support for multiply-add operations.
Definition: arch.h:52
Functor performing linear combination operations used by epilogues.
Defines the size of an element in bits.
Definition: numeric_types.h:42
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
std::conditional (true specialization)
Definition: platform.h:325
Definition: default_gemm_configuration.h:57
Defines tags for architecture-specific configurations.
Templates exposing architecture support for warp matrix multiply-add (WMMA) operations.
Basic include for CUTLASS.