CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
numeric_conversion.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/cutlass.h"
32 #include "cutlass/numeric_types.h"
33 
34 #include "cutlass/array.h"
35 #include "cutlass/half.h"
36 
37 namespace cutlass {
38 
40 
43 enum class FloatRoundStyle {
50 };
51 
53 
54 template <
55  typename T,
56  typename S,
58 >
60 
61  using result_type = T;
62  using source_type = S;
63  static FloatRoundStyle const round_style = Round;
64 
66  static result_type convert(source_type const & s) {
67 
68  return static_cast<result_type>(s);
69  }
70 
73  return convert(s);
74  }
75 };
76 
78 //
79 // Partial specializations for float => int8_t
80 //
82 template <FloatRoundStyle Round>
83 struct NumericConverter<int8_t, float, Round> {
84 
85  using result_type = int8_t;
86  using source_type = float;
87  static FloatRoundStyle const round_style = Round;
88 
90  static result_type convert(source_type const & s) {
91 
92  result_type result = static_cast<int8_t>(s);
93 
94  return result;
95  }
96 
99  return convert(s);
100  }
101 };
102 
104 
106 template <typename T, FloatRoundStyle Round>
107 struct NumericConverter<T, T, Round> {
108 
109  using result_type = T;
110  using source_type = T;
111  static FloatRoundStyle const round_style = Round;
112 
114  static result_type convert(source_type const & s) {
115 
116  return s;
117  }
118 
121  return convert(s);
122  }
123 };
124 
126 //
127 // Partial specializations for float <=> half_t
128 //
130 
132 template <FloatRoundStyle Round>
133 struct NumericConverter<float, half_t, Round> {
134 
135  using result_type = float;
137  static FloatRoundStyle const round_style = Round;
138 
140  static result_type convert(source_type const & s) {
141 
142  result_type result = static_cast<float>(s);
143 
144  return result;
145  }
146 
149  return convert(s);
150  }
151 };
152 
154 template <>
156 
158  using source_type = float;
160 
162  static result_type convert(source_type const & s) {
163 
164  result_type result = static_cast<half_t>(s);
165 
166  return result;
167  }
168 
171  return convert(s);
172  }
173 };
174 
176 template <>
178 
180  using source_type = float;
182 
185  static result_type convert(source_type const & flt) {
186 
187  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
188  return half_t(__float2half_rz(flt));
189  #else
190  // software implementation truncates the discarded mantissa bits (round toward zero)
191  unsigned const& s = reinterpret_cast<unsigned const &>(flt);
192  uint16_t sign = uint16_t((s >> 16) & 0x8000);
193  int16_t exp = uint16_t(((s >> 23) & 0xff) - 127);
194  int mantissa = s & 0x7fffff;
195  uint16_t u = 0;
196 
197  if ((s & 0x7fffffff) == 0) {
198  // sign-preserving zero
199  return half_t::bitcast(sign);
200  }
201 
202  if (exp > 15) {
203  if (exp == 128 && mantissa) {
204  // not a number
205  u = 0x7fff;
206  } else {
207  // overflow to infinity
208  u = sign | 0x7c00;
209  }
210  return half_t::bitcast(u);
211  }
212 
213  if (exp >= -14) {
214  // normal fp32 to normal fp16
215  exp = uint16_t(exp + uint16_t(15));
216  u = uint16_t(((exp & 0x1f) << 10));
217  u = uint16_t(u | (mantissa >> 13));
218  } else {
219  // normal single-precision to subnormal half_t-precision representation
220  int rshift = (-14 - exp);
221  if (rshift < 32) {
222  mantissa |= (1 << 23);
223  mantissa = (mantissa >> rshift);
224  u = (uint16_t(mantissa >> 13) & 0x3ff);
225  } else {
226  mantissa = 0;
227  u = 0;
228  }
229  }
230 
231  u |= sign;
232 
233  return half_t::bitcast(u);
234 
235  #endif // defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
236  }
237 
240  return convert(s);
241  }
242 };
243 
245 //
246 // Conversion and Clamp operator for Integers
247 //
249 
250 template <
251  typename T,
252  typename S
253 >
255 
256  using result_type = T;
257  using source_type = S;
258 
262  "Clamp is only needed for integer types");
263 
265  static result_type convert(source_type const & s) {
267  result_type const kClamp_max =
268  (0x1U << (sizeof_bits<result_type>::value - 1)) - 1;
269  result_type const kClamp_min = -kClamp_max - 1;
270  bool is_int_min = !(s > kClamp_min);
271  bool is_int_max = !(s < kClamp_max);
272  return is_int_min ? kClamp_min : (is_int_max ? kClamp_max : convert_op(s));
273  }
274 
277  return convert(s);
278  }
279 };
280 
282 //
283 // Conversion operator for Array
284 //
286 
288 template <
289  typename T,
290  typename S,
291  int N,
293 >
295 
296  using result_type = Array<T, N>;
297  using source_type = Array<S, N>;
298  static FloatRoundStyle const round_style = Round;
299 
301  static result_type convert(source_type const & s) {
302 
303  result_type result;
305 
307  for (int i = 0; i < N; ++i) {
308  result[i] = convert_(s[i]);
309  }
310 
311  return result;
312  }
313 
316  return convert(s);
317  }
318 };
319 
321 
323 template <>
325 
326  using result_type = Array<half_t, 2>;
327  using source_type = Array<float, 2>;
329 
331  static result_type convert(source_type const & source) {
332 
333  Array<half_t, 2> result;
334 
335  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
336  reinterpret_cast<__half2 &>(result) = __float22half2_rn(reinterpret_cast<float2 const &>(source));
337  #else
339  result[0] = convert_(source[0]);
340  result[1] = convert_(source[1]);
341  #endif
342 
343  return result;
344  }
345 
348  return convert(s);
349  }
350 };
351 
353 template <FloatRoundStyle Round>
354 struct NumericArrayConverter<float, half_t, 2, Round> {
355 
356  using result_type = Array<float, 2>;
357  using source_type = Array<half_t, 2>;
359 
361  static result_type convert(source_type const & source) {
362 
363  Array<float, 2> result;
364 
365  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
366  reinterpret_cast<float2 &>(result) = __half22float2(reinterpret_cast<__half2 const &>(source));
367  #else
369  result[0] = convert_(source[0]);
370  result[1] = convert_(source[1]);
371  #endif
372 
373  return result;
374  }
375 
378  return convert(s);
379  }
380 };
381 
383 
385 template <
386  int N,
387  FloatRoundStyle Round
388 >
389 struct NumericArrayConverter<half_t, float, N, Round> {
390 
391  using result_type = Array<half_t, N>;
392  using source_type = Array<float, N>;
393  static FloatRoundStyle const round_style = Round;
394 
396  static result_type convert(source_type const & source) {
397 
400 
401  result_type result;
402 
403  Array<half_t, 2> *result_ptr = reinterpret_cast<Array<half_t, 2> *>(&result);
404  Array<float, 2> const *source_ptr = reinterpret_cast<Array<float, 2> const *>(&source);
405 
407  for (int i = 0; i < N / 2; ++i) {
408  result_ptr[i] = convert_vector_(source_ptr[i]);
409  }
410 
411  if (N % 2) {
412  result[N - 1] = convert_element_(source[N - 1]);
413  }
414 
415  return result;
416  }
417 
420  return convert(s);
421  }
422 };
423 
424 
426 template <
427  int N,
428  FloatRoundStyle Round
429 >
430 struct NumericArrayConverter<float, half_t, N, Round> {
431 
432  using result_type = Array<float, N>;
433  using source_type = Array<half_t, N>;
434  static FloatRoundStyle const round_style = Round;
435 
437  static result_type convert(source_type const & source) {
438 
441 
442  result_type result;
443 
444  Array<float, 2> *result_ptr = reinterpret_cast<Array<float, 2> *>(&result);
445  Array<half_t, 2> const *source_ptr = reinterpret_cast<Array<half_t, 2> const *>(&source);
446 
448  for (int i = 0; i < N / 2; ++i) {
449  result_ptr[i] = convert_vector_(source_ptr[i]);
450  }
451 
452  if (N % 2) {
453  result[N - 1] = convert_element_(source[N - 1]);
454  }
455 
456  return result;
457  }
458 
461  return convert(s);
462  }
463 };
464 
466 
467 // Conditional guards to enable partial specialization for packed integers
// Enable the packed int8 specializations only for device code targeting SM72 or newer
// that is compiled with CUDA 10.2 or later. The toolkit version is compared as a
// (major, minor) pair: testing the minor number on its own would wrongly reject
// releases such as 11.0 and 11.1, whose minor number is below 2.
468 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 2)))
469 
/// Partial specialization of NumericArrayConverter for Array<int8_t, 1> <= Array<int, 1>.
///
/// A single element cannot use the packed conversion instructions, so the work is
/// delegated to the scalar NumericConverter<int8_t, int, Round>.
471 template <
472  FloatRoundStyle Round
473 >
474 struct NumericArrayConverter<int8_t, int, 1, Round> {
475 
476  using result_type = Array<int8_t, 1>;
477  using source_type = Array<int, 1>;
478  static FloatRoundStyle const round_style = Round;
479 
     /// Returns a one-element array holding source[0] converted by the scalar converter.
     // NOTE(review): range handling (saturation vs. plain cast) is whatever
     // NumericConverter<int8_t, int, Round> does; that definition is not visible here.
481  static result_type convert(source_type const & source) {
482  NumericConverter<int8_t, int, Round> convert_element_;
483 
484  result_type result;
485 
486  result[0] = convert_element_(source[0]);
487 
488  return result;
489  }
490 
     /// Function-call form; forwards to convert().
492  result_type operator()(source_type const &s) {
493  return convert(s);
494  }
495 };
496 
/// Partial specialization of NumericArrayConverter for Array<int8_t, 2> <= Array<int, 2>.
///
/// Both elements are converted by one PTX cvt.pack.sat.s8.s32.b32 instruction. Round is
/// recorded in round_style but is not otherwise used by the conversion below.
498 template <
499  FloatRoundStyle Round
500 >
501 struct NumericArrayConverter<int8_t, int, 2, Round> {
502 
503  using result_type = Array<int8_t, 2>;
504  using source_type = Array<int, 2>;
505  static FloatRoundStyle const round_style = Round;
506 
     /// Returns the two source integers converted to int8_t and packed into 16 bits.
508  static result_type convert(source_type const & source) {
509 
510  uint32_t tmp;
511 
     // Operands are passed as a = source[1] (%2), b = source[0] (%1), c = 0.
     // Per the PTX ISA description of cvt.pack, each input is saturated to the s8
     // range, b lands in the lowest byte and a in the next one, so element 0 ends up
     // in the lowest byte of tmp.
512  asm volatile(
513  "cvt.pack.sat.s8.s32.b32 %0, %2, %1, 0;\n"
514  : "=r"(tmp) : "r"(source[0]), "r"(source[1]));
515 
     // Only the low 16 bits carry the two packed results.
516  uint16_t out = (tmp & 0xffff);
     // NOTE(review): the cast assumes Array<int8_t, 2> is exactly 16 bits wide - confirm
     // against the Array definition.
517  return reinterpret_cast<result_type const &>(out);
518  }
519 
     /// Function-call form; forwards to convert().
521  result_type operator()(source_type const &s) {
522  return convert(s);
523  }
524 };
525 
/// Partial specialization of NumericArrayConverter for Array<int8_t, 4> <= Array<int, 4>.
///
/// The four elements are converted by two chained PTX cvt.pack.sat.s8.s32.b32
/// instructions. Round is recorded in round_style but is not otherwise used below.
527 template <
528  FloatRoundStyle Round
529 >
530 struct NumericArrayConverter<int8_t, int, 4, Round> {
531 
532  using result_type = Array<int8_t, 4>;
533  using source_type = Array<int, 4>;
534  static FloatRoundStyle const round_style = Round;
535 
     /// Returns the four source integers converted to int8_t and packed into 32 bits.
537  static result_type convert(source_type const & source) {
538 
539  unsigned out;
540 
     // r4 is a temporary PTX register scoped to the braces of the asm block.
     // First instruction: packs source[3] (%4) and source[2] (%3) into r4.
     // Second instruction: packs source[1] (%2) and source[0] (%1) and takes r4 as the
     // third operand, which (per the PTX ISA description of cvt.pack) fills the upper
     // half of the destination. Element 0 therefore ends up in the lowest byte of out.
541  asm volatile(
542  "{ .reg .u32 r4;"
543  "cvt.pack.sat.s8.s32.b32 r4, %4, %3, 0;"
544  "cvt.pack.sat.s8.s32.b32 %0, %2, %1, r4;"
545  "}"
546  : "=r"(out) : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3]));
547 
     // NOTE(review): the cast assumes Array<int8_t, 4> is exactly 32 bits wide - confirm
     // against the Array definition.
548  return reinterpret_cast<result_type const &>(out);
549  }
550 
     /// Function-call form; forwards to convert().
552  result_type operator()(source_type const &s) {
553  return convert(s);
554  }
555 };
556 
558 template <
559  int N,
560  FloatRoundStyle Round
561 >
562 struct NumericArrayConverter<int8_t, int, N, Round> {
563  static_assert(!(N % 4), "N must be multiple of 4.");
564 
565  using result_type = Array<int8_t, N>;
566  using source_type = Array<int, N>;
567  static FloatRoundStyle const round_style = Round;
568 
570  static result_type convert(source_type const & source) {
571 
573 
574  result_type result;
575 
576  Array<int8_t, 4> *result_ptr = reinterpret_cast<Array<int8_t, 4> *>(&result);
577  Array<int, 4> const *source_ptr = reinterpret_cast<Array<int, 4> const *>(&source);
578 
580  for (int i = 0; i < N / 4; ++i) {
581  result_ptr[i] = convert_vector_(source_ptr[i]);
582  }
583 
584  return result;
585  }
586 
588  result_type operator()(source_type const &s) {
589  return convert(s);
590  }
591 };
592 
593 #endif // Conditional guards to enable partial specialization for packed integers
594 
596 
// Enable the packed int4 specializations only for device code targeting SM75 or newer
// that is compiled with CUDA 10.2 or later. The toolkit version is compared as a
// (major, minor) pair: testing the minor number on its own would wrongly reject
// releases such as 11.0 and 11.1, whose minor number is below 2.
597 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && ((__CUDACC_VER_MAJOR__ > 10) || ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 2)))
598 
/// Partial specialization of NumericArrayConverter for Array<int4b_t, 8> <= Array<int, 8>.
///
/// The eight elements are converted by four chained PTX cvt.pack.sat.s4.s32.b32
/// instructions. Round is recorded in round_style but is not otherwise used below.
600 template <
601  FloatRoundStyle Round
602 >
603 struct NumericArrayConverter<int4b_t, int, 8, Round> {
604 
605  using result_type = Array<int4b_t, 8>;
606  using source_type = Array<int, 8>;
607  static FloatRoundStyle const round_style = Round;
608 
     /// Returns the eight source integers converted to 4-bit values and packed into 32 bits.
610  static result_type convert(source_type const & source) {
611 
612  unsigned out;
613 
     // r4 is a temporary PTX register scoped to the braces of the asm block.
     // Pairs are packed from the highest indices down: (source[7], source[6]) first,
     // then (source[5], source[4]), (source[3], source[2]) and finally
     // (source[1], source[0]). Each step passes the previous r4 as the third operand,
     // which (per the PTX ISA description of cvt.pack) supplies the bits above the newly
     // packed pair. Element 0 therefore ends up in the lowest 4 bits of out.
614  asm volatile(
615  "{ .reg .u32 r4;"
616  "cvt.pack.sat.s4.s32.b32 r4, %8, %7, 0;"
617  "cvt.pack.sat.s4.s32.b32 r4, %6, %5, r4;"
618  "cvt.pack.sat.s4.s32.b32 r4, %4, %3, r4;"
619  "cvt.pack.sat.s4.s32.b32 %0, %2, %1, r4;"
620  "}"
621  : "=r"(out)
622  : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3]),
623  "r"(source[4]), "r"(source[5]), "r"(source[6]), "r"(source[7]));
624 
     // NOTE(review): the cast assumes Array<int4b_t, 8> is exactly 32 bits wide - confirm
     // against the Array definition.
625  return reinterpret_cast<result_type const &>(out);
626  }
627 
     /// Function-call form; forwards to convert().
629  result_type operator()(source_type const &s) {
630  return convert(s);
631  }
632 };
633 
635 template <
636  int N,
637  FloatRoundStyle Round
638 >
639 struct NumericArrayConverter<int4b_t, int, N, Round> {
640  static_assert(!(N % 8), "N must be multiple of 8.");
641 
642  using result_type = Array<int4b_t, N>;
643  using source_type = Array<int, N>;
644  static FloatRoundStyle const round_style = Round;
645 
647  static result_type convert(source_type const & source) {
648 
650 
651  result_type result;
652 
653  Array<int4b_t, 8> *result_ptr = reinterpret_cast<Array<int4b_t, 8> *>(&result);
654  Array<int, 8> const *source_ptr = reinterpret_cast<Array<int, 8> const *>(&source);
655 
657  for (int i = 0; i < N / 8; ++i) {
658  result_ptr[i] = convert_vector_(source_ptr[i]);
659  }
660 
661  return result;
662  }
663 
665  result_type operator()(source_type const &s) {
666  return convert(s);
667  }
668 };
669 
670 #endif // Conditional guards to enable partial specialization for packed integers
671 
673 
674 } // namespace cutlass
T result_type
Definition: numeric_conversion.h:256
Partial specialization for float <= half_t.
Definition: numeric_conversion.h:133
static CUTLASS_HOST_DEVICE half_t bitcast(uint16_t x)
Constructs from an unsigned short.
Definition: half.h:141
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:120
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:276
Definition: aligned_buffer.h:35
Array< float, 2 > result_type
Definition: numeric_conversion.h:356
Definition: numeric_conversion.h:254
float source_type
Definition: numeric_conversion.h:86
static CUTLASS_HOST_DEVICE result_type convert(source_type const &source)
Definition: numeric_conversion.h:396
static CUTLASS_HOST_DEVICE result_type convert(source_type const &flt)
Round toward zero.
Definition: numeric_conversion.h:185
T result_type
Definition: numeric_conversion.h:109
static CUTLASS_HOST_DEVICE result_type convert(source_type const &source)
Definition: numeric_conversion.h:331
std::is_same (false specialization)
Definition: platform.h:394
static CUTLASS_HOST_DEVICE result_type convert(source_type const &s)
Definition: numeric_conversion.h:265
Defines a class for using IEEE half-precision floating-point types in host or device code...
T result_type
Definition: numeric_conversion.h:61
IEEE half-precision floating-point type.
Definition: half.h:126
Array< half_t, 2 > source_type
Definition: numeric_conversion.h:357
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:377
Array< float, N > result_type
Definition: numeric_conversion.h:432
CUTLASS_HOST_DEVICE complex< T > exp(complex< T > const &z)
Computes the complex exponential of z.
Definition: complex.h:375
Array< half_t, N > source_type
Definition: numeric_conversion.h:433
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:460
Array< half_t, 2 > result_type
Definition: numeric_conversion.h:326
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:419
static CUTLASS_HOST_DEVICE result_type convert(source_type const &source)
Definition: numeric_conversion.h:437
Array< half_t, N > result_type
Definition: numeric_conversion.h:391
static CUTLASS_HOST_DEVICE result_type convert(source_type const &s)
Definition: numeric_conversion.h:162
static CUTLASS_HOST_DEVICE result_type convert(source_type const &s)
Definition: numeric_conversion.h:114
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:170
add 0.5ulp to integer representation then round toward zero
T source_type
Definition: numeric_conversion.h:110
Partial specialization for Array<float, 2> <= Array<half_t, 2>, round to nearest. ...
Definition: numeric_conversion.h:354
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:239
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Array< S, N > source_type
Definition: numeric_conversion.h:297
Top-level include for all CUTLASS numeric types.
#define static_assert(__e, __m)
Definition: platform.h:153
static CUTLASS_HOST_DEVICE result_type convert(source_type const &s)
Definition: numeric_conversion.h:66
S source_type
Definition: numeric_conversion.h:257
Definition: numeric_conversion.h:59
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:315
static CUTLASS_HOST_DEVICE result_type convert(source_type const &source)
Definition: numeric_conversion.h:361
FloatRoundStyle
Definition: numeric_conversion.h:43
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:148
int8_t result_type
Definition: numeric_conversion.h:85
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:347
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:72
Array< float, N > source_type
Definition: numeric_conversion.h:392
S source_type
Definition: numeric_conversion.h:62
float result_type
Definition: numeric_conversion.h:135
Conversion operator for Array.
Definition: numeric_conversion.h:294
static CUTLASS_HOST_DEVICE result_type convert(source_type const &s)
Definition: numeric_conversion.h:90
Array< T, N > result_type
Definition: numeric_conversion.h:296
Basic include for CUTLASS.
static CUTLASS_HOST_DEVICE result_type convert(source_type const &s)
Definition: numeric_conversion.h:140
CUTLASS_HOST_DEVICE result_type operator()(source_type const &s)
Definition: numeric_conversion.h:98
static CUTLASS_HOST_DEVICE result_type convert(source_type const &s)
Definition: numeric_conversion.h:301