CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
type_traits.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include <cublas_v2.h>
32 #include <cuda_fp16.h>
33 #include <stdint.h>
34 
35 #include "cutlass/numeric_types.h"
36 #include "cutlass/complex.h"
37 
38 namespace cutlass {
39 struct half_t;
40 
41 template <typename T>
42 struct TypeTraits {  // Fallback traits for any T without a specialization: host and device types are both T
43  using host_type = T;
44  using device_type = T;
45  static T remove_negative_zero(T x) { return x; }  // identity in the generic case
46  static T to_print(T x) { return x; }  // the value is printed as-is
47  static device_type to_device(host_type x) { return x; }  // no conversion: host_type and device_type are the same type
48 };
49 
50 template <>
51 struct TypeTraits<int8_t> {  // Traits for signed 8-bit integers
52  static const cudaDataType_t cublas_type = CUDA_R_8I;
53  using host_type = int8_t;
54  using device_type = int8_t;
55  using integer_type = int8_t;
56  using unsigned_type = uint8_t;
57  static int8_t remove_negative_zero(int8_t x) { return x; }  // identity for this integer type
58  static int to_print(int8_t x) { return static_cast<int>(x); }  // widened to int for printing
59  static device_type to_device(host_type x) { return x; }
60 };
61 
62 template <>
63 struct TypeTraits<uint8_t> {  // Traits for unsigned 8-bit integers
64  static cudaDataType_t const cublas_type = CUDA_R_8U;  // unsigned 8-bit enumerator; the signed CUDA_R_8I belongs to int8_t
65  typedef uint8_t host_type;
66  typedef uint8_t device_type;
67  typedef uint8_t integer_type;
68  typedef uint8_t unsigned_type;
69  static inline uint8_t remove_negative_zero(uint8_t x) { return x; }  // identity for this integer type
70  static inline uint32_t to_print(uint8_t x) { return (uint32_t)x; }  // widened to uint32_t for printing
71  static inline device_type to_device(host_type x) { return x; }  // host_type and device_type are the same type
72 };
73 
74 template <>
75 struct TypeTraits<int> {  // Traits for int
76  static const cudaDataType_t cublas_type = CUDA_R_32I;
77  using host_type = int;
78  using device_type = int;
79  using integer_type = int32_t;
80  using unsigned_type = uint32_t;
81  static int32_t remove_negative_zero(int32_t x) { return x; }  // identity for this integer type
82  static int to_print(int x) { return x; }
83  static device_type to_device(host_type x) { return x; }
84 };
85 
86 template <>
87 struct TypeTraits<unsigned> {  // Traits for unsigned int
88  static cudaDataType_t const cublas_type = CUDA_R_32U;  // unsigned 32-bit enumerator; the signed CUDA_R_32I belongs to int
89  typedef unsigned host_type;
90  typedef unsigned device_type;
91  typedef uint32_t integer_type;
92  typedef uint32_t unsigned_type;
93  static inline uint32_t remove_negative_zero(uint32_t x) { return x; }  // identity for this integer type
94  static inline uint32_t to_print(uint32_t x) { return x; }
95  static inline device_type to_device(host_type x) { return x; }  // host_type and device_type are the same type
96 };
97 
98 template <>
99 struct TypeTraits<int64_t> {  // Traits for signed 64-bit integers
100  static const cudaDataType_t cublas_type = CUDA_R_8I;  // NOTE(review): same enumerator as the int8_t specialization - looks like a placeholder, confirm
101  using host_type = int64_t;
102  using device_type = int64_t;
103  using integer_type = int64_t;
104  using unsigned_type = uint64_t;
105  static int64_t remove_negative_zero(int64_t x) { return x; }  // identity for this integer type
106  static int64_t to_print(int64_t x) { return x; }
107  static device_type to_device(host_type x) { return x; }
108 };
109 
110 template <>
111 struct TypeTraits<uint64_t> {  // Traits for unsigned 64-bit integers
112  static const cudaDataType_t cublas_type = CUDA_R_8I;  // NOTE(review): same enumerator as the int8_t specialization - looks like a placeholder, confirm
113  using host_type = uint64_t;
114  using device_type = uint64_t;
115  using integer_type = uint64_t;
116  using unsigned_type = uint64_t;
117  static uint64_t remove_negative_zero(uint64_t x) { return x; }  // identity for this integer type
118  static uint64_t to_print(uint64_t x) { return x; }
119  static device_type to_device(host_type x) { return x; }
120 };
121 
122 template <>
123 struct TypeTraits<half_t> {  // Traits for cutlass::half_t
124  static cudaDataType_t const cublas_type = CUDA_R_16F;
125  typedef half_t host_type;
126  typedef half_t device_type;
127  typedef int16_t integer_type;
128  typedef uint16_t unsigned_type;
129  static inline half_t remove_negative_zero(half_t x) {  // maps the raw pattern 0x8000 (sign bit only) to the all-zero pattern
130  return (x.raw() == 0x8000 ? half_t::bitcast(0) : x);
131  }
132  static inline half_t to_print(half_t x) { return x; }
133  static inline device_type to_device(half_t x) { return reinterpret_cast<device_type const &>(x); }  // reinterprets the bits; no value conversion
134 };
135 
136 template <>
137 struct TypeTraits<float> {  // Traits for single-precision floating point
138  static const cudaDataType_t cublas_type = CUDA_R_32F;
139  using host_type = float;
140  using device_type = float;
141  using integer_type = int32_t;
142  using unsigned_type = uint32_t;
143  static float remove_negative_zero(float x) { return (x == -0.f) ? 0.f : x; }  // any value comparing equal to zero becomes +0.f
144  static float to_print(float x) { return x; }
145  static device_type to_device(host_type x) { return x; }
146 };
147 
148 template <>
149 struct TypeTraits<double> {  // Traits for double-precision floating point
150  static const cudaDataType_t cublas_type = CUDA_R_64F;
151  using host_type = double;
152  using device_type = double;
153  using integer_type = int64_t;
154  using unsigned_type = uint64_t;
155  static double remove_negative_zero(double x) { return (x == -0.0) ? 0.0 : x; }  // any value comparing equal to zero becomes +0.0
156  static double to_print(double x) { return x; }
157  static device_type to_device(host_type x) { return x; }
158 };
159 
161 //
162 // Complex types
163 //
165 
166 template <>
167 struct TypeTraits<complex<half> > {  // Traits for complex numbers built on the CUDA half type
168  static cudaDataType_t const cublas_type = CUDA_C_16F;
169  typedef complex<half_t> host_type;
170  typedef complex<half> device_type;
171  typedef int16_t integer_type;
172  typedef uint16_t unsigned_type;
173  static inline device_type to_device(complex<half> x) { return reinterpret_cast<device_type const &>(x); }  // reinterprets the bits; no value conversion
174 };
175 
176 template <>
177 struct TypeTraits<complex<half_t> > {  // Traits for complex numbers built on cutlass::half_t
178  static cudaDataType_t const cublas_type = CUDA_C_16F;
179  typedef complex<half_t> host_type;
180  typedef complex<half> device_type;
181  typedef int16_t integer_type;
182  typedef uint16_t unsigned_type;
183  static inline complex<half_t> remove_negative_zero(complex<half_t> x) {  // applied to real and imaginary parts independently
184  return complex<half_t>(
185  real(x) == -0_hf ? 0_hf : real(x),
186  imag(x) == -0_hf ? 0_hf : imag(x)
187  );
188  }
189  static inline complex<half_t> to_print(complex<half_t> x) { return x; }
190  static inline device_type to_device(complex<half_t> x) { return reinterpret_cast<device_type const &>(x); }  // reinterprets complex<half_t> as complex<half>
191 };
192 
193 template <>
194 struct TypeTraits<complex<float> > {  // Traits for single-precision complex numbers
195 
196  static cudaDataType_t const cublas_type = CUDA_C_32F;
197  typedef complex<float> host_type;
198  typedef complex<float> device_type;
199  typedef int64_t integer_type;
200  typedef uint64_t unsigned_type;
201 
202  static inline complex<float> remove_negative_zero(complex<float> x) {  // applied to real and imaginary parts independently
203  return complex<float>(
204  real(x) == -0.f ? 0.f : real(x),
205  imag(x) == -0.f ? 0.f : imag(x)
206  );
207  }
208 
209  static inline complex<float> to_print(complex<float> x) { return x; }
210  static inline device_type to_device(complex<float> x) { return reinterpret_cast<device_type const &>(x); }  // host_type and device_type are the same type
211 };
212 
213 template <>
214 struct TypeTraits<complex<double> > {  // Traits for double-precision complex numbers
215  static cudaDataType_t const cublas_type = CUDA_C_64F;
216  typedef complex<double> host_type;
217  typedef complex<double> device_type;
218  struct integer_type { int64_t real, imag; };  // a pair of 64-bit fields, one per component
219  struct unsigned_type { uint64_t real, imag; };
220  static inline complex<double> remove_negative_zero(complex<double> x) {  // applied to real and imaginary parts independently
221  return complex<double>(
222  real(x) == -0.0 ? 0.0 : real(x),
223  imag(x) == -0.0 ? 0.0 : imag(x)
224  );
225  }
226  static inline complex<double> to_print(complex<double> x) { return x; }
227  static inline device_type to_device(complex<double> x) { return reinterpret_cast<device_type const &>(x); }  // host_type and device_type are the same type
228 };
229 
231 
232 } // namespace cutlass
int64_t integer_type
Definition: type_traits.h:103
complex< half_t > host_type
Definition: type_traits.h:169
static CUTLASS_HOST_DEVICE half_t bitcast(uint16_t x)
Constructs from an unsigned short.
Definition: half.h:141
static complex< float > to_print(complex< float > x)
Definition: type_traits.h:209
Definition: aligned_buffer.h:35
static double remove_negative_zero(double x)
Definition: type_traits.h:155
T host_type
Definition: type_traits.h:43
int16_t integer_type
Definition: type_traits.h:127
uint64_t unsigned_type
Definition: type_traits.h:154
float device_type
Definition: type_traits.h:140
CUTLASS_HOST_DEVICE float const & imag(cuFloatComplex const &z)
Returns the imaginary part of the complex number.
Definition: complex.h:72
static int8_t remove_negative_zero(int8_t x)
Definition: type_traits.h:57
static float remove_negative_zero(float x)
Definition: type_traits.h:143
static device_type to_device(host_type x)
Definition: type_traits.h:95
uint32_t integer_type
Definition: type_traits.h:91
int16_t integer_type
Definition: type_traits.h:181
T device_type
Definition: type_traits.h:44
int32_t integer_type
Definition: type_traits.h:79
uint32_t unsigned_type
Definition: type_traits.h:92
int device_type
Definition: type_traits.h:78
static int32_t remove_negative_zero(int32_t x)
Definition: type_traits.h:81
static uint32_t remove_negative_zero(uint32_t x)
Definition: type_traits.h:93
static device_type to_device(host_type x)
Definition: type_traits.h:119
uint64_t unsigned_type
Definition: type_traits.h:104
IEEE half-precision floating-point type.
Definition: half.h:126
complex< double > device_type
Definition: type_traits.h:217
int16_t integer_type
Definition: type_traits.h:171
uint64_t integer_type
Definition: type_traits.h:115
static complex< double > to_print(complex< double > x)
Definition: type_traits.h:226
CUTLASS_HOST_DEVICE float const & real(cuFloatComplex const &z)
Returns the real part of the complex number.
Definition: complex.h:56
static half_t remove_negative_zero(half_t x)
Definition: type_traits.h:129
complex< half > device_type
Definition: type_traits.h:180
uint64_t host_type
Definition: type_traits.h:113
static device_type to_device(host_type x)
Definition: type_traits.h:145
int64_t host_type
Definition: type_traits.h:101
static half_t to_print(half_t x)
Definition: type_traits.h:132
uint8_t integer_type
Definition: type_traits.h:67
static T remove_negative_zero(T x)
Definition: type_traits.h:45
static device_type to_device(host_type x)
Definition: type_traits.h:157
complex< half_t > host_type
Definition: type_traits.h:179
static float to_print(float x)
Definition: type_traits.h:144
int64_t device_type
Definition: type_traits.h:102
static complex< double > remove_negative_zero(complex< double > x)
Definition: type_traits.h:220
static uint8_t remove_negative_zero(uint8_t x)
Definition: type_traits.h:69
double device_type
Definition: type_traits.h:152
static int64_t remove_negative_zero(int64_t x)
Definition: type_traits.h:105
static uint64_t to_print(uint64_t x)
Definition: type_traits.h:118
int64_t integer_type
Definition: type_traits.h:199
static device_type to_device(half_t x)
Definition: type_traits.h:133
complex< double > host_type
Definition: type_traits.h:216
static uint64_t remove_negative_zero(uint64_t x)
Definition: type_traits.h:117
uint32_t unsigned_type
Definition: type_traits.h:142
static uint32_t to_print(uint32_t x)
Definition: type_traits.h:94
int32_t integer_type
Definition: type_traits.h:141
static device_type to_device(host_type x)
Definition: type_traits.h:83
static device_type to_device(host_type x)
Definition: type_traits.h:107
double host_type
Definition: type_traits.h:151
int8_t device_type
Definition: type_traits.h:54
uint32_t unsigned_type
Definition: type_traits.h:80
uint64_t real
Definition: type_traits.h:219
static device_type to_device(complex< float > x)
Definition: type_traits.h:210
unsigned device_type
Definition: type_traits.h:90
uint64_t device_type
Definition: type_traits.h:114
Top-level include for all CUTLASS numeric types.
static int64_t to_print(int64_t x)
Definition: type_traits.h:106
uint8_t unsigned_type
Definition: type_traits.h:56
uint8_t unsigned_type
Definition: type_traits.h:68
static device_type to_device(host_type x)
Definition: type_traits.h:47
int8_t integer_type
Definition: type_traits.h:55
static int to_print(int x)
Definition: type_traits.h:82
int64_t real
Definition: type_traits.h:218
uint16_t unsigned_type
Definition: type_traits.h:128
static uint32_t to_print(uint8_t x)
Definition: type_traits.h:70
static double to_print(double x)
Definition: type_traits.h:156
static device_type to_device(complex< half_t > x)
Definition: type_traits.h:190
unsigned host_type
Definition: type_traits.h:89
half_t host_type
Definition: type_traits.h:125
uint64_t unsigned_type
Definition: type_traits.h:116
uint8_t host_type
Definition: type_traits.h:65
int64_t integer_type
Definition: type_traits.h:153
uint16_t unsigned_type
Definition: type_traits.h:172
Definition: complex.h:92
static complex< float > remove_negative_zero(complex< float > x)
Definition: type_traits.h:202
float host_type
Definition: type_traits.h:139
static device_type to_device(host_type x)
Definition: type_traits.h:59
complex< float > device_type
Definition: type_traits.h:198
complex< float > host_type
Definition: type_traits.h:197
uint64_t unsigned_type
Definition: type_traits.h:200
static device_type to_device(complex< double > x)
Definition: type_traits.h:227
int host_type
Definition: type_traits.h:77
static device_type to_device(host_type x)
Definition: type_traits.h:71
static complex< half_t > to_print(complex< half_t > x)
Definition: type_traits.h:189
uint16_t unsigned_type
Definition: type_traits.h:182
half_t device_type
Definition: type_traits.h:126
static int to_print(int8_t x)
Definition: type_traits.h:58
static device_type to_device(complex< half > x)
Definition: type_traits.h:173
complex< half > device_type
Definition: type_traits.h:170
uint8_t device_type
Definition: type_traits.h:66
int8_t host_type
Definition: type_traits.h:53
static complex< half_t > remove_negative_zero(complex< half_t > x)
Definition: type_traits.h:183
CUTLASS_HOST_DEVICE uint16_t & raw()
Accesses raw internal state.
Definition: half.h:367
static T to_print(T x)
Definition: type_traits.h:46
Definition: type_traits.h:42