31 #include <cublas_v2.h> 32 #include <cuda_fp16.h> 47 static inline device_type
to_device(host_type x) {
return x; }
52 static cudaDataType_t
const cublas_type = CUDA_R_8I;
/// Returns the int8_t argument widened to int (type_traits.h:58).
/// Presumably exists so the value is streamed as a number rather than as a
/// character -- the call sites are not visible in this listing.
static inline int to_print(int8_t x) {
  return static_cast<int>(x);
}
59 static inline device_type
to_device(host_type x) {
return x; }
64 static cudaDataType_t
const cublas_type = CUDA_R_8I;
/// Returns the uint8_t argument widened to uint32_t (type_traits.h:70).
/// Presumably exists so the value is streamed as a number rather than as a
/// character -- the call sites are not visible in this listing.
static inline uint32_t to_print(uint8_t x) {
  return static_cast<uint32_t>(x);
}
71 static inline device_type
to_device(host_type x) {
return x; }
76 static cudaDataType_t
const cublas_type = CUDA_R_32I;
/// Returns the int argument unchanged (type_traits.h:82); an int needs no
/// widening to be printable.
static inline int to_print(int x) {
  return x;
}
83 static inline device_type
to_device(host_type x) {
return x; }
88 static cudaDataType_t
const cublas_type = CUDA_R_32I;
/// Returns the uint32_t argument unchanged (type_traits.h:94).
static inline uint32_t to_print(uint32_t x) {
  return x;
}
95 static inline device_type
to_device(host_type x) {
return x; }
100 static cudaDataType_t
const cublas_type = CUDA_R_8I;
/// Returns the int64_t argument unchanged (type_traits.h:106).
static inline int64_t to_print(int64_t x) {
  return x;
}
107 static inline device_type
to_device(host_type x) {
return x; }
112 static cudaDataType_t
const cublas_type = CUDA_R_8I;
/// Returns the uint64_t argument unchanged (type_traits.h:118).
static inline uint64_t to_print(uint64_t x) {
  return x;
}
119 static inline device_type
to_device(host_type x) {
return x; }
124 static cudaDataType_t
const cublas_type = CUDA_R_16F;
133 static inline device_type
to_device(
half_t x) {
return reinterpret_cast<device_type
const &
>(x); }
138 static cudaDataType_t
const cublas_type = CUDA_R_32F;
/// Returns the float argument unchanged (type_traits.h:144).
static inline float to_print(float x) {
  return x;
}
145 static inline device_type
to_device(host_type x) {
return x; }
150 static cudaDataType_t
const cublas_type = CUDA_R_64F;
/// Returns the double argument unchanged (type_traits.h:156).
static inline double to_print(double x) {
  return x;
}
157 static inline device_type
to_device(host_type x) {
return x; }
168 static cudaDataType_t
const cublas_type = CUDA_C_16F;
178 static cudaDataType_t
const cublas_type = CUDA_C_16F;
196 static cudaDataType_t
const cublas_type = CUDA_C_32F;
215 static cudaDataType_t
const cublas_type = CUDA_C_64F;
int64_t integer_type
Definition: type_traits.h:103
complex< half_t > host_type
Definition: type_traits.h:169
static CUTLASS_HOST_DEVICE half_t bitcast(uint16_t x)
Constructs from an unsigned short.
Definition: half.h:141
static complex< float > to_print(complex< float > x)
Definition: type_traits.h:209
Definition: aligned_buffer.h:35
static double remove_negative_zero(double x)
Definition: type_traits.h:155
T host_type
Definition: type_traits.h:43
int16_t integer_type
Definition: type_traits.h:127
uint64_t unsigned_type
Definition: type_traits.h:154
float device_type
Definition: type_traits.h:140
CUTLASS_HOST_DEVICE float const & imag(cuFloatComplex const &z)
Returns the imaginary part of the complex number.
Definition: complex.h:72
static int8_t remove_negative_zero(int8_t x)
Definition: type_traits.h:57
static float remove_negative_zero(float x)
Definition: type_traits.h:143
static device_type to_device(host_type x)
Definition: type_traits.h:95
uint32_t integer_type
Definition: type_traits.h:91
int16_t integer_type
Definition: type_traits.h:181
T device_type
Definition: type_traits.h:44
int32_t integer_type
Definition: type_traits.h:79
uint32_t unsigned_type
Definition: type_traits.h:92
int device_type
Definition: type_traits.h:78
static int32_t remove_negative_zero(int32_t x)
Definition: type_traits.h:81
static uint32_t remove_negative_zero(uint32_t x)
Definition: type_traits.h:93
static device_type to_device(host_type x)
Definition: type_traits.h:119
uint64_t unsigned_type
Definition: type_traits.h:104
IEEE half-precision floating-point type.
Definition: half.h:126
complex< double > device_type
Definition: type_traits.h:217
int16_t integer_type
Definition: type_traits.h:171
uint64_t integer_type
Definition: type_traits.h:115
static complex< double > to_print(complex< double > x)
Definition: type_traits.h:226
CUTLASS_HOST_DEVICE float const & real(cuFloatComplex const &z)
Returns the real part of the complex number.
Definition: complex.h:56
static half_t remove_negative_zero(half_t x)
Definition: type_traits.h:129
complex< half > device_type
Definition: type_traits.h:180
uint64_t host_type
Definition: type_traits.h:113
static device_type to_device(host_type x)
Definition: type_traits.h:145
int64_t host_type
Definition: type_traits.h:101
static half_t to_print(half_t x)
Definition: type_traits.h:132
uint8_t integer_type
Definition: type_traits.h:67
static T remove_negative_zero(T x)
Definition: type_traits.h:45
static device_type to_device(host_type x)
Definition: type_traits.h:157
complex< half_t > host_type
Definition: type_traits.h:179
static float to_print(float x)
Definition: type_traits.h:144
int64_t device_type
Definition: type_traits.h:102
static complex< double > remove_negative_zero(complex< double > x)
Definition: type_traits.h:220
static uint8_t remove_negative_zero(uint8_t x)
Definition: type_traits.h:69
double device_type
Definition: type_traits.h:152
static int64_t remove_negative_zero(int64_t x)
Definition: type_traits.h:105
static uint64_t to_print(uint64_t x)
Definition: type_traits.h:118
int64_t integer_type
Definition: type_traits.h:199
static device_type to_device(half_t x)
Definition: type_traits.h:133
complex< double > host_type
Definition: type_traits.h:216
static uint64_t remove_negative_zero(uint64_t x)
Definition: type_traits.h:117
uint32_t unsigned_type
Definition: type_traits.h:142
static uint32_t to_print(uint32_t x)
Definition: type_traits.h:94
int32_t integer_type
Definition: type_traits.h:141
static device_type to_device(host_type x)
Definition: type_traits.h:83
static device_type to_device(host_type x)
Definition: type_traits.h:107
double host_type
Definition: type_traits.h:151
int8_t device_type
Definition: type_traits.h:54
uint32_t unsigned_type
Definition: type_traits.h:80
uint64_t real
Definition: type_traits.h:219
static device_type to_device(complex< float > x)
Definition: type_traits.h:210
unsigned device_type
Definition: type_traits.h:90
uint64_t device_type
Definition: type_traits.h:114
Top-level include for all CUTLASS numeric types.
static int64_t to_print(int64_t x)
Definition: type_traits.h:106
uint8_t unsigned_type
Definition: type_traits.h:56
uint8_t unsigned_type
Definition: type_traits.h:68
static device_type to_device(host_type x)
Definition: type_traits.h:47
int8_t integer_type
Definition: type_traits.h:55
static int to_print(int x)
Definition: type_traits.h:82
int64_t real
Definition: type_traits.h:218
uint16_t unsigned_type
Definition: type_traits.h:128
static uint32_t to_print(uint8_t x)
Definition: type_traits.h:70
static double to_print(double x)
Definition: type_traits.h:156
static device_type to_device(complex< half_t > x)
Definition: type_traits.h:190
unsigned host_type
Definition: type_traits.h:89
half_t host_type
Definition: type_traits.h:125
uint64_t unsigned_type
Definition: type_traits.h:116
uint8_t host_type
Definition: type_traits.h:65
int64_t integer_type
Definition: type_traits.h:153
uint16_t unsigned_type
Definition: type_traits.h:172
static complex< float > remove_negative_zero(complex< float > x)
Definition: type_traits.h:202
float host_type
Definition: type_traits.h:139
static device_type to_device(host_type x)
Definition: type_traits.h:59
complex< float > device_type
Definition: type_traits.h:198
complex< float > host_type
Definition: type_traits.h:197
uint64_t unsigned_type
Definition: type_traits.h:200
static device_type to_device(complex< double > x)
Definition: type_traits.h:227
int host_type
Definition: type_traits.h:77
static device_type to_device(host_type x)
Definition: type_traits.h:71
static complex< half_t > to_print(complex< half_t > x)
Definition: type_traits.h:189
uint16_t unsigned_type
Definition: type_traits.h:182
half_t device_type
Definition: type_traits.h:126
static int to_print(int8_t x)
Definition: type_traits.h:58
static device_type to_device(complex< half > x)
Definition: type_traits.h:173
complex< half > device_type
Definition: type_traits.h:170
uint8_t device_type
Definition: type_traits.h:66
int8_t host_type
Definition: type_traits.h:53
static complex< half_t > remove_negative_zero(complex< half_t > x)
Definition: type_traits.h:183
CUTLASS_HOST_DEVICE uint16_t & raw()
Accesses raw internal state.
Definition: half.h:367
static T to_print(T x)
Definition: type_traits.h:46
Definition: type_traits.h:42