32 #if !defined(__clang__) 34 #if (__CUDACC_VER_MAJOR__ >= 9) 35 #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700)) 36 #define CUTLASS_ARCH_WMMA_ENABLED 37 #define CUTLASS_ARCH_WMMA_SM70_ENABLED 41 #if (__CUDACC_VER_MAJOR__ >= 10) 42 #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 720)) 43 #define CUTLASS_ARCH_INTEGER_MATRIX_MULTIPLY_ENABLED 44 #define CUTLASS_ARCH_WMMA_SM72_ENABLED 48 #if (__CUDACC_VER_MAJOR__ >= 10) 49 #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 750)) 50 #define CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED 51 #define CUTLASS_ARCH_WMMA_SM75_ENABLED 57 #if defined(CUTLASS_ARCH_WMMA_ENABLED) 74 enum class MemoryKind {
84 static int const kThreadsPerWarp = 32;
85 static int const kQuadsPerWarp = 8;
86 static int const kThreadsPerQuad = 4;
92 template <
typename Type_>
93 struct CutlassToWmmaDataType{
99 struct CutlassToWmmaDataType<
cutlass::half_t> {
106 struct CutlassToWmmaDataType<int8_t> {
107 using Type =
signed char;
112 struct CutlassToWmmaDataType<uint8_t> {
113 using Type =
unsigned char;
118 struct CutlassToWmmaDataType<int32_t> {
122 #if defined(CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED) 126 using Type = nvcuda::wmma::experimental::precision::s4;
132 using Type = nvcuda::wmma::experimental::precision::u4;
138 using Type = nvcuda::wmma::experimental::precision::b1;
145 template <
typename Layout_>
146 struct CutlassToWmmaLayout {
151 struct CutlassToWmmaLayout<
cutlass::layout::RowMajor> {
152 using Layout = nvcuda::wmma::row_major;
153 static nvcuda::wmma::layout_t
const value = nvcuda::wmma::layout_t::mem_row_major;
160 struct CutlassToWmmaLayout<
cutlass::layout::ColumnMajor> {
161 using Layout = nvcuda::wmma::col_major;
162 static nvcuda::wmma::layout_t
const value = nvcuda::wmma::layout_t::mem_col_major;
169 template <
typename Type_>
170 struct WmmaToCutlassDataType{
176 struct WmmaToCutlassDataType<__half> {
194 typename Operator_ = cutlass::arch::OpMultiplyAdd
208 #ifdef CUTLASS_ARCH_WMMA_SM70_ENABLED 212 #ifdef CUTLASS_ARCH_WMMA_SM72_ENABLED 216 #ifdef CUTLASS_ARCH_WMMA_SM75_ENABLED 222 #endif //CUTLASS_ARCH_WMMA_ENABLED integer_subbyte< 4, false > uint4b_t
4-bit Unsigned integer type
Definition: integer_subbyte.h:158
Definition: aligned_buffer.h:35
integer_subbyte< 1, false > uint1b_t
1-bit Unsigned integer type
Definition: integer_subbyte.h:152
IEEE half-precision floating-point type.
Definition: half.h:126
Defines common types used for all GEMM-like operators.
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
Templates exposing architecture support for multiply-add operations.
Top-level include for all CUTLASS numeric types.
integer_subbyte< 4, true > int4b_t
4-bit Integer type
Definition: integer_subbyte.h:155