82 template <FloatRoundStyle Round>
106 template <
typename T, FloatRoundStyle Round>
132 template <FloatRoundStyle Round>
187 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 188 return half_t(__float2half_rz(flt));
191 unsigned const& s =
reinterpret_cast<unsigned const &
>(flt);
192 uint16_t sign = uint16_t((s >> 16) & 0x8000);
193 int16_t
exp = uint16_t(((s >> 23) & 0xff) - 127);
194 int mantissa = s & 0x7fffff;
197 if ((s & 0x7fffffff) == 0) {
203 if (exp == 128 && mantissa) {
215 exp = uint16_t(exp + uint16_t(15));
216 u = uint16_t(((exp & 0x1f) << 10));
217 u = uint16_t(u | (mantissa >> 13));
220 int rshift = (-14 -
exp);
222 mantissa |= (1 << 23);
223 mantissa = (mantissa >> rshift);
224 u = (uint16_t(mantissa >> 13) & 0x3ff);
235 #endif // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 262 "Clamp is only needed for integer types");
268 (0x1U << (sizeof_bits<result_type>::value - 1)) - 1;
270 bool is_int_min = !(s > kClamp_min);
271 bool is_int_max = !(s < kClamp_max);
272 return is_int_min ? kClamp_min : (is_int_max ? kClamp_max : convert_op(s));
307 for (
int i = 0; i < N; ++i) {
308 result[i] = convert_(s[i]);
333 Array<half_t, 2> result;
335 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 336 reinterpret_cast<__half2 &
>(result) = __float22half2_rn(reinterpret_cast<float2 const &>(source));
339 result[0] = convert_(source[0]);
340 result[1] = convert_(source[1]);
353 template <FloatRoundStyle Round>
363 Array<float, 2> result;
365 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 366 reinterpret_cast<float2 &
>(result) = __half22float2(reinterpret_cast<__half2 const &>(source));
369 result[0] = convert_(source[0]);
370 result[1] = convert_(source[1]);
403 Array<half_t, 2> *result_ptr =
reinterpret_cast<Array<half_t, 2> *
>(&result);
404 Array<float, 2>
const *source_ptr =
reinterpret_cast<Array<float, 2>
const *
>(&source);
407 for (
int i = 0; i < N / 2; ++i) {
408 result_ptr[i] = convert_vector_(source_ptr[i]);
412 result[N - 1] = convert_element_(source[N - 1]);
444 Array<float, 2> *result_ptr =
reinterpret_cast<Array<float, 2> *
>(&result);
445 Array<half_t, 2>
const *source_ptr =
reinterpret_cast<Array<half_t, 2>
const *
>(&source);
448 for (
int i = 0; i < N / 2; ++i) {
449 result_ptr[i] = convert_vector_(source_ptr[i]);
453 result[N - 1] = convert_element_(source[N - 1]);
468 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2) 481 static result_type convert(
source_type const & source) {
486 result[0] = convert_element_(source[0]);
508 static result_type convert(
source_type const & source) {
513 "cvt.pack.sat.s8.s32.b32 %0, %2, %1, 0;\n" 514 :
"=r"(tmp) :
"r"(source[0]),
"r"(source[1]));
516 uint16_t out = (tmp & 0xffff);
517 return reinterpret_cast<result_type
const &
>(out);
537 static result_type convert(
source_type const & source) {
543 "cvt.pack.sat.s8.s32.b32 r4, %4, %3, 0;" 544 "cvt.pack.sat.s8.s32.b32 %0, %2, %1, r4;" 546 :
"=r"(out) :
"r"(source[0]),
"r"(source[1]),
"r"(source[2]),
"r"(source[3]));
548 return reinterpret_cast<result_type
const &
>(out);
576 Array<int8_t, 4> *result_ptr =
reinterpret_cast<Array<int8_t, 4> *
>(&result);
577 Array<int, 4>
const *source_ptr =
reinterpret_cast<Array<int, 4>
const *
>(&source);
580 for (
int i = 0; i < N / 4; ++i) {
581 result_ptr[i] = convert_vector_(source_ptr[i]);
593 #endif // Conditional guards to enable partial specialization for packed integers 597 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2) 610 static result_type convert(
source_type const & source) {
616 "cvt.pack.sat.s4.s32.b32 r4, %8, %7, 0;" 617 "cvt.pack.sat.s4.s32.b32 r4, %6, %5, r4;" 618 "cvt.pack.sat.s4.s32.b32 r4, %4, %3, r4;" 619 "cvt.pack.sat.s4.s32.b32 %0, %2, %1, r4;" 622 :
"r"(source[0]),
"r"(source[1]),
"r"(source[2]),
"r"(source[3]),
623 "r"(source[4]),
"r"(source[5]),
"r"(source[6]),
"r"(source[7]));
625 return reinterpret_cast<result_type
const &
>(out);
653 Array<int4b_t, 8> *result_ptr =
reinterpret_cast<Array<int4b_t, 8> *
>(&result);
654 Array<int, 8>
const *source_ptr =
reinterpret_cast<Array<int, 8>
const *
>(&source);
657 for (
int i = 0; i < N / 8; ++i) {
658 result_ptr[i] = convert_vector_(source_ptr[i]);
670 #endif // Conditional guards to enable partial specialization for packed integers
T result_type
Definition: numeric_conversion.h:256
float source_type
Definition: numeric_conversion.h:180
Partial specialization for float <= half_t.
Definition: numeric_conversion.h:133
static CUTLASS_HOST_DEVICE half_t bitcast(uint16_t x)
Constructs from an unsigned short.
Definition: half.h:141
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:120
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:276
Definition: aligned_buffer.h:35
Array< float, 2 > result_type
Definition: numeric_conversion.h:356
Definition: numeric_conversion.h:254
float source_type
Definition: numeric_conversion.h:86
static CUTLASS_HOST_DEVICE result_type convert(source_type const &source)
Definition: numeric_conversion.h:396
static CUTLASS_HOST_DEVICE result_type convert(source_type const &flt)
Round toward zero.
Definition: numeric_conversion.h:185
T result_type
Definition: numeric_conversion.h:109
static CUTLASS_HOST_DEVICE result_type convert(source_type const &source)
Definition: numeric_conversion.h:331
static CUTLASS_HOST_DEVICE result_type convert(source_type const &s)
Definition: numeric_conversion.h:265
Defines a class for using IEEE half-precision floating-point types in host or device code...
T result_type
Definition: numeric_conversion.h:61
IEEE half-precision floating-point type.
Definition: half.h:126
Array< half_t, 2 > source_type
Definition: numeric_conversion.h:357
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:377
Array< float, N > result_type
Definition: numeric_conversion.h:432
CUTLASS_HOST_DEVICE complex< T > exp(complex< T > const &z)
Computes the complex exponential of z.
Definition: complex.h:375
float source_type
Definition: numeric_conversion.h:158
Array< half_t, N > source_type
Definition: numeric_conversion.h:433
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:460
Array< half_t, 2 > result_type
Definition: numeric_conversion.h:326
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:419
static CUTLASS_HOST_DEVICE result_type convert(source_type const &source)
Definition: numeric_conversion.h:437
Array< half_t, N > result_type
Definition: numeric_conversion.h:391
Round toward negative infinity.
static CUTLASS_HOST_DEVICE result_type convert(source_type const &s)
Definition: numeric_conversion.h:162
static CUTLASS_HOST_DEVICE result_type convert(source_type const &s)
Definition: numeric_conversion.h:114
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:170
Add 0.5 ulp to the integer representation, then round toward zero.
T source_type
Definition: numeric_conversion.h:110
Partial specialization for Array<float, 2> <= Array<half_t, 2>, round to nearest. ...
Definition: numeric_conversion.h:354
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:239
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Array< S, N > source_type
Definition: numeric_conversion.h:297
Top-level include for all CUTLASS numeric types.
static CUTLASS_HOST_DEVICE result_type convert(source_type const &s)
Definition: numeric_conversion.h:66
S source_type
Definition: numeric_conversion.h:257
Array< float, 2 > source_type
Definition: numeric_conversion.h:327
Definition: numeric_conversion.h:59
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:315
static CUTLASS_HOST_DEVICE result_type convert(source_type const &source)
Definition: numeric_conversion.h:361
FloatRoundStyle
Definition: numeric_conversion.h:43
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:148
int8_t result_type
Definition: numeric_conversion.h:85
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:347
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:72
Array< float, N > source_type
Definition: numeric_conversion.h:392
S source_type
Definition: numeric_conversion.h:62
float result_type
Definition: numeric_conversion.h:135
Conversion operator for Array.
Definition: numeric_conversion.h:294
static CUTLASS_HOST_DEVICE result_type convert(source_type const &s)
Definition: numeric_conversion.h:90
Array< T, N > result_type
Definition: numeric_conversion.h:296
Basic include for CUTLASS.
static CUTLASS_HOST_DEVICE result_type convert(source_type const &s)
Definition: numeric_conversion.h:140
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:98
static CUTLASS_HOST_DEVICE result_type convert(source_type const &s)
Definition: numeric_conversion.h:301