CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
wmma.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
// Feature-detection macros for the WMMA code paths in this header.
//
// CUTLASS WMMA does not support clang at present, so nothing below is defined
// when __clang__ is set.
#if !defined(__clang__)

// __CUDACC_VER_MAJOR__ >= 9: enable WMMA when __CUDA_ARCH__ is either
// undefined or at least 700.
#if (__CUDACC_VER_MAJOR__ >= 9)
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700))
#define CUTLASS_ARCH_WMMA_ENABLED
#define CUTLASS_ARCH_WMMA_SM70_ENABLED
#endif
#endif  // (__CUDACC_VER_MAJOR__ >= 9)

// __CUDACC_VER_MAJOR__ >= 10: the two groups below are gated independently
// on __CUDA_ARCH__ (undefined, or >= 720 / >= 750 respectively).
#if (__CUDACC_VER_MAJOR__ >= 10)

#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 720))
#define CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED
#define CUTLASS_ARCH_WMMA_SM72_ENABLED
#endif

#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 750))
#define CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED
#define CUTLASS_ARCH_WMMA_SM75_ENABLED
#endif

#endif  // (__CUDACC_VER_MAJOR__ >= 10)

#endif  // !defined(__clang__)
56 
57 #if defined(CUTLASS_ARCH_WMMA_ENABLED)
58 
59 #include <mma.h>
60 #include "cutlass/arch/mma.h"
61 #include "cutlass/array.h"
62 #include "cutlass/numeric_types.h"
63 #include "cutlass/gemm/gemm.h"
64 
65 
67 
68 namespace cutlass {
69 namespace arch {
70 
/// Identifies where a piece of data resides.
enum class MemoryKind {
  kShared = 0,  ///< Data resides in shared memory
  kGlobal = 1   ///< Data resides in global memory
};
78 
79 
/// Compile-time constants describing how the threads of a warp are grouped.
///
/// The members are `constexpr` so that, when compiled as C++17 or newer, they
/// are implicitly inline and may be ODR-used (for example bound to a
/// `const int &` parameter) without a separate out-of-class definition.
struct WarpParams {
  /// Number of threads in a warp
  static constexpr int kThreadsPerWarp = 32;

  /// Number of threads in a quad
  static constexpr int kThreadsPerQuad = 4;

  /// Number of quads in a warp (evaluates to 8)
  static constexpr int kQuadsPerWarp = kThreadsPerWarp / kThreadsPerQuad;

  static_assert(kQuadsPerWarp * kThreadsPerQuad == kThreadsPerWarp,
                "A warp must be evenly divisible into quads");
};
88 
/// Type trait mapping a CUTLASS element type to the type exposed as `Type`.
///
/// The primary template is the identity mapping: `Type` is the template
/// argument itself. Explicit specializations override the mapping for
/// individual element types.
template <typename Type_>
struct CutlassToWmmaDataType {
  using Type = Type_;
};
96 
98 template<>
99 struct CutlassToWmmaDataType<cutlass::half_t> {
100  using Type = __half;
101 };
102 
103 
105 template<>
106 struct CutlassToWmmaDataType<int8_t> {
107  using Type = signed char;
108 };
109 
111 template<>
112 struct CutlassToWmmaDataType<uint8_t> {
113  using Type = unsigned char;
114 };
115 
117 template<>
118 struct CutlassToWmmaDataType<int32_t> {
119  using Type = int;
120 };
121 
122 #if defined(CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED)
123 template<>
125 struct CutlassToWmmaDataType<cutlass::int4b_t> {
126  using Type = nvcuda::wmma::experimental::precision::s4;
127 };
128 
130 template<>
131 struct CutlassToWmmaDataType<cutlass::uint4b_t> {
132  using Type = nvcuda::wmma::experimental::precision::u4;
133 };
134 
136 template<>
137 struct CutlassToWmmaDataType<cutlass::uint1b_t> {
138  using Type = nvcuda::wmma::experimental::precision::b1;
139 };
140 #endif
141 
/// Type trait keyed on a CUTLASS layout type.
///
/// The primary template is intentionally empty: it declares neither `Layout`
/// nor `value`, so only layouts with an explicit specialization provide them.
template <typename Layout_>
struct CutlassToWmmaLayout {};
148 
150 template <>
151 struct CutlassToWmmaLayout<cutlass::layout::RowMajor> {
152  using Layout = nvcuda::wmma::row_major;
153  static nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_row_major;
154 };
155 
159 template <>
160 struct CutlassToWmmaLayout<cutlass::layout::ColumnMajor> {
161  using Layout = nvcuda::wmma::col_major;
162  static nvcuda::wmma::layout_t const value = nvcuda::wmma::layout_t::mem_col_major;
163 };
165 
/// Type trait for the reverse direction of CutlassToWmmaDataType.
///
/// The primary template is the identity mapping: `Type` is the template
/// argument itself. Explicit specializations override the mapping for
/// individual types.
template <typename Type_>
struct WmmaToCutlassDataType {
  using Type = Type_;
};
173 
175 template<>
176 struct WmmaToCutlassDataType<__half> {
177  using Type = cutlass::half_t;
178 };
180 
182 // WMMA template structure defines nvcuda::wmma::fragments and static assertion chaeks
183 // for a specific template paramterized data type (Element[A|B|C]), layout (Layout[A|B|C]),
184 // and native wmma size (Shape)
186 template <
187  typename Shape_,
188  typename ElementA_,
189  typename LayoutA_,
190  typename ElementB_,
191  typename LayoutB_,
192  typename ElementC_,
193  typename LayoutC_,
194  typename Operator_ = cutlass::arch::OpMultiplyAdd
195 >
196 struct Wmma;
198 
199 
200 } // namespace arch
201 } // namespace cutlass
202 
204 
205 //
206 // Specializations for each compute capability
207 //
208 #ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED
209 #include "cutlass/arch/wmma_sm70.h"
210 #endif
211 
212 #ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED
213 #include "cutlass/arch/wmma_sm72.h"
214 #endif
215 
216 #ifdef CUTLASS_ARCH_WMMA_SM75_ENABLED
217 #include "cutlass/arch/wmma_sm75.h"
218 #endif
219 
221 
222 #endif //CUTLASS_ARCH_WMMA_ENABLED
integer_subbyte< 4, false > uint4b_t
4-bit Unsigned integer type
Definition: integer_subbyte.h:158
Definition: aligned_buffer.h:35
Matrix multiply.
integer_subbyte< 1, false > uint1b_t
1-bit Unsigned integer type
Definition: integer_subbyte.h:152
Matrix multiply.
IEEE half-precision floating-point type.
Definition: half.h:126
Defines common types used for all GEMM-like operators.
Matrix multiply.
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
Templates exposing architecture support for multiply-add operations.
Top-level include for all CUTLASS numeric types.
integer_subbyte< 4, true > int4b_t
4-bit Integer type
Definition: integer_subbyte.h:155