32 #ifndef CUTLASS_ENABLE_F16C 33 #define CUTLASS_ENABLE_F16C 0 36 #if defined(__CUDACC_RTC__) 44 # define FP_INFINITE 1 50 # define FP_SUBNORMAL 3 58 #undef CUTLASS_ENABLE_F16C 59 #define CUTLASS_ENABLE_F16C 0 69 #include <cuda_fp16.h> 76 #if !defined(__CUDA_ARCH__) && (CUTLASS_ENABLE_F16C) 79 #include <immintrin.h> 81 #define F16C_ROUND_NEAREST 0 83 #if !defined(__CUDA_ARCH__) 84 extern __inline
float _cvtsh_ss (
unsigned short __S) {
86 std::memcpy(&packed, &__S,
sizeof(__S));
88 __m128 result = _mm_cvtph_ps(packed);
91 std::memcpy(&flt, &result,
sizeof(flt));
96 __inline
unsigned short _cvtss_sh (
float __F,
const int) {
98 std::memcpy(&packed, &__F,
sizeof(__F));
100 __m128i result = _mm_cvtps_ph(packed, F16C_ROUND_NEAREST);
103 std::memcpy(&u, &result,
sizeof(u));
112 #include <x86intrin.h> 113 #define F16C_ROUND_NEAREST (_MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC) 116 #endif // !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C 148 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) 150 __device__ __noinline__
155 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 156 return half_t(__float2half_rn(flt));
157 #elif !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C 158 unsigned short u = _cvtss_sh(flt, F16C_ROUND_NEAREST);
162 unsigned const& s =
reinterpret_cast<unsigned const &
>(flt);
163 uint16_t sign = uint16_t((s >> 16) & 0x8000);
164 int16_t
exp = uint16_t(((s >> 23) & 0xff) - 127);
168 if ((s & 0x7fffffff) == 0) {
174 if (exp == 128 && mantissa) {
188 exp = uint16_t(exp + uint16_t(15));
189 u = uint16_t(((exp & 0x1f) << 10));
190 u = uint16_t(u | (mantissa >> 13));
193 int rshift = (-14 -
exp);
195 mantissa |= (1 << 23);
197 sticky_bit = ((mantissa & ((1 << rshift) - 1)) != 0);
199 mantissa = (mantissa >> rshift);
200 u = (uint16_t(mantissa >> 13) & 0x3ff);
208 int round_bit = ((mantissa >> 12) & 1);
209 sticky_bit |= ((mantissa & ((1 << 12) - 1)) != 0);
211 if ((round_bit && sticky_bit) || (round_bit && (u & 1))) {
224 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 225 return half_t(__int2half_rn(n));
234 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 235 return half_t(__uint2half_rn(n));
242 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530) 244 __device__ __noinline__
249 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 250 return __half2float(x.
to_half());
251 #elif !defined(__CUDA_ARCH__) && CUTLASS_ENABLE_F16C 256 int sign = ((h >> 15) & 1);
257 int exp = ((h >> 10) & 0x1f);
261 if (exp > 0 && exp < 31) {
264 f = (sign << 31) | (exp << 23) | (mantissa << 13);
265 }
else if (exp == 0) {
269 while ((mantissa & (1 << 10)) == 0) {
274 f = (sign << 31) | (exp << 23) | (mantissa << 13);
279 }
else if (exp == 31) {
283 f = (0xff << 23) | (sign << 31);
286 return reinterpret_cast<float const&
>(f);
300 explicit half_t(half
const & x): storage(reinterpret_cast<uint16_t const &>(x)) {
331 storage =
reinterpret_cast<uint16_t
const &
>(x);
337 operator float()
const {
343 operator double()
const {
349 explicit operator int()
const {
355 operator bool()
const {
356 return (
convert(*
this) != 0.0f);
362 return reinterpret_cast<half
const &
>(
storage);
380 return ((storage & 0x8000) != 0);
386 return int((storage >> 10) & 0x1f);
398 return int(storage & 0x3ff);
406 return ((h.
raw() & 0x8000) != 0);
465 #if defined(__CUDACC_RTC__) 475 uint16_t a_mag = (
reinterpret_cast<uint16_t
const &
>(a) & 0x7fff);
476 uint16_t b_sign = (
reinterpret_cast<uint16_t
const &
>(b) & 0x8000);
477 uint16_t result = (a_mag | b_sign);
494 #if !defined(__CUDACC_RTC__) 498 static bool const is_specialized =
true;
499 static bool const is_signed =
true;
500 static bool const is_integer =
false;
501 static bool const is_exact =
false;
502 static bool const has_infinity =
true;
503 static bool const has_quiet_NaN =
true;
504 static bool const has_signaling_NaN =
false;
505 static std::float_denorm_style
const has_denorm = std::denorm_present;
506 static bool const has_denorm_loss =
true;
508 static bool const is_iec559 =
true;
509 static bool const is_bounded =
true;
510 static bool const is_modulo =
false;
511 static int const digits = 10;
556 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 559 return float(lhs) == float(rhs);
565 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 568 return float(lhs) != float(rhs);
574 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 577 return float(lhs) < float(rhs);
583 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 586 return float(lhs) <= float(rhs);
592 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 595 return float(lhs) > float(rhs);
601 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 604 return float(lhs) >= float(rhs);
610 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 613 return half_t(
float(lhs) +
float(rhs));
619 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 622 return half_t(-
float(lhs));
628 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 631 return half_t(
float(lhs) -
float(rhs));
637 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 640 return half_t(
float(lhs) *
float(rhs));
646 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 649 return half_t(
float(lhs) /
float(rhs));
655 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 658 lhs =
half_t(
float(lhs) +
float(rhs));
665 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 668 lhs =
half_t(
float(lhs) -
float(rhs));
675 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 678 lhs =
half_t(
float(lhs) *
float(rhs));
685 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 688 lhs =
half_t(
float(lhs) /
float(rhs));
695 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 707 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 720 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 733 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) static cutlass::half_t max()
Maximum finite value.
Definition: half.h:520
static CUTLASS_HOST_DEVICE half_t bitcast(uint16_t x)
Constructs from an unsigned short.
Definition: half.h:141
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE half_t(int x)
Integer conversion - round to nearest even.
Definition: half.h:318
CUTLASS_HOST_DEVICE T abs(complex< T > const &z)
Returns the magnitude of the complex number.
Definition: complex.h:313
static cutlass::half_t signaling_NaN()
Returns the signaling NaN value (note: has_signaling_NaN is false for this type).
Definition: half.h:535
static cutlass::half_t infinity()
Returns positive infinity.
Definition: half.h:529
CUTLASS_HOST_DEVICE half_t()
Default constructor.
Definition: half.h:296
static CUTLASS_HOST_DEVICE half_t convert(float const &flt)
FP32 -> FP16 conversion - rounds to nearest even.
Definition: half.h:154
CUTLASS_HOST_DEVICE half_t & operator/=(half_t &lhs, half_t const &rhs)
Definition: half.h:684
uint16_t storage
Storage type.
Definition: half.h:133
IEEE half-precision floating-point type.
Definition: half.h:126
static CUTLASS_HOST_DEVICE half_t convert(int const &n)
Integer -> FP16 conversion - rounds to nearest even.
Definition: half.h:223
CUTLASS_HOST_DEVICE bool isnormal(cutlass::half_t const &h)
Definition: half.h:436
CUTLASS_HOST_DEVICE bool operator<=(half_t const &lhs, half_t const &rhs)
Definition: half.h:582
static cutlass::half_t denorm_min()
Returns the smallest positive subnormal (denormalized) value.
Definition: half.h:538
CUTLASS_HOST_DEVICE complex< T > exp(complex< T > const &z)
Computes the complex exponential of z.
Definition: complex.h:375
static CUTLASS_HOST_DEVICE float convert(half_t const &x)
Converts a half-precision value stored as a uint16_t to a float.
Definition: half.h:248
CUTLASS_HOST_DEVICE half_t & operator+=(half_t &lhs, half_t const &rhs)
Definition: half.h:654
CUTLASS_HOST_DEVICE half_t(unsigned x)
Unsigned integer conversion - round to nearest even.
Definition: half.h:324
CUTLASS_HOST_DEVICE half_t operator+(half_t const &lhs, half_t const &rhs)
Definition: half.h:609
CUTLASS_HOST_DEVICE half_t & operator-=(half_t &lhs, half_t const &rhs)
Definition: half.h:664
static cutlass::half_t round_error()
Returns the maximum rounding error.
Definition: half.h:526
CUTLASS_HOST_DEVICE half_t & operator++(half_t &lhs)
Definition: half.h:694
CUTLASS_HOST_DEVICE bool signbit() const
Returns the sign bit.
Definition: half.h:379
CUTLASS_HOST_DEVICE cutlass::half_t sqrt(cutlass::half_t const &h)
Definition: half.h:464
CUTLASS_HOST_DEVICE int fpclassify(cutlass::half_t const &h)
Definition: half.h:441
CUTLASS_HOST_DEVICE bool operator!=(half_t const &lhs, half_t const &rhs)
Definition: half.h:564
CUTLASS_HOST_DEVICE half to_half() const
Bitcasts to CUDA's half type.
Definition: half.h:361
CUTLASS_HOST_DEVICE half_t & operator=(half const &x)
Assignment.
Definition: half.h:330
CUTLASS_HOST_DEVICE half_t(half const &x)
Reinterpret cast from CUDA's half type.
Definition: half.h:300
CUTLASS_HOST_DEVICE uint16_t raw() const
Accesses raw internal state.
Definition: half.h:373
CUTLASS_HOST_DEVICE bool isinf(cutlass::half_t const &h)
Definition: half.h:431
CUTLASS_HOST_DEVICE half_t copysign(half_t const &a, half_t const &b)
Definition: half.h:473
CUTLASS_HOST_DEVICE half_t & operator--(half_t &lhs)
Definition: half.h:706
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
CUTLASS_HOST_DEVICE cutlass::half_t nanh(const char *)
Definition: half.h:425
CUTLASS_HOST_DEVICE half_t(float x)
Floating point conversion.
Definition: half.h:306
CUTLASS_HOST_DEVICE bool operator>(half_t const &lhs, half_t const &rhs)
Definition: half.h:591
CUTLASS_HOST_DEVICE half_t operator-(half_t const &lhs)
Definition: half.h:618
CUTLASS_HOST_DEVICE bool isfinite(cutlass::half_t const &h)
Definition: half.h:420
CUTLASS_HOST_DEVICE half_t & operator*=(half_t &lhs, half_t const &rhs)
Definition: half.h:674
static cutlass::half_t min()
Least positive value.
Definition: half.h:514
CUTLASS_HOST_DEVICE int mantissa() const
Returns the mantissa.
Definition: half.h:397
static cutlass::half_t lowest()
Minimum finite value.
Definition: half.h:517
static cutlass::half_t epsilon()
Returns the machine epsilon: the difference between 1 and the next representable value.
Definition: half.h:523
CUTLASS_HOST_DEVICE bool operator==(half_t const &lhs, half_t const &rhs)
Definition: half.h:555
CUTLASS_HOST_DEVICE Coord< Rank, Index > operator/(Index s, Coord< Rank, Index > coord)
Scalar division.
Definition: coord.h:360
CUTLASS_HOST_DEVICE bool operator>=(half_t const &lhs, half_t const &rhs)
Definition: half.h:600
CUTLASS_HOST_DEVICE half_t(double x)
Floating point conversion.
Definition: half.h:312
static CUTLASS_HOST_DEVICE half_t convert(unsigned const &n)
Unsigned integer -> FP16 conversion - rounds to nearest even.
Definition: half.h:233
CUTLASS_HOST_DEVICE bool operator<(half_t const &lhs, half_t const &rhs)
Definition: half.h:573
Basic include for CUTLASS.
CUTLASS_HOST_DEVICE complex< T > sqrt(complex< T > const &z)
Computes the square root of complex number z.
Definition: complex.h:393
CUTLASS_HOST_DEVICE uint16_t & raw()
Accesses raw internal state.
Definition: half.h:367
static cutlass::half_t quiet_NaN()
Returns a quiet NaN.
Definition: half.h:532
CUTLASS_HOST_DEVICE int exponent() const
Returns the unbiased exponent.
Definition: half.h:391
CUTLASS_HOST_DEVICE int exponent_biased() const
Returns the biased exponent.
Definition: half.h:385
CUTLASS_HOST_DEVICE half_t operator*(half_t const &lhs, half_t const &rhs)
Definition: half.h:636
CUTLASS_HOST_DEVICE bool isnan(cutlass::half_t const &h)
Definition: half.h:415