CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
fast_math.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 
26 #pragma once
27 
28 #include <cstdint>
29 #include "cutlass/cutlass.h"
30 
36 namespace cutlass {
37 
38 /******************************************************************************
39  * Static math utilities
40  ******************************************************************************/
41 
/// Compile-time predicate: `value` is true exactly when N is a positive power of two.
///
/// The bit test `N & (N - 1)` clears the lowest set bit, so it yields zero for any value with
/// at most one bit set. The `N > 0` guard is required because zero also passes that bit test
/// (0 & (0 - 1) == 0) even though it is not a power of two.
template <int N>
struct is_pow2 {
  static bool const value = (N > 0) && ((N & (N - 1)) == 0);
};
49 
/// Compile-time log2(N), rounded down.
///
/// Recursive step: the running value is shifted right by one bit and the shift counter is
/// incremented. The recursion only terminates once the running value equals 1 (see the
/// partial specialization below), so this is intended for N >= 1.
template <int N, int Remaining = N, int Shifts = 0>
struct log2_down {
  enum {
    value = log2_down<N, (Remaining >> 1), Shifts + 1>::value
  };
};

/// Recursion terminator: the running value has been reduced to 1, so the number of shifts
/// performed so far is the result.
template <int N, int Shifts>
struct log2_down<N, 1, Shifts> {
  enum {
    value = Shifts
  };
};
64 
/// Compile-time log2(N), rounded up.
///
/// Recursive step: the running value is shifted right by one bit and the shift counter is
/// incremented. The recursion only terminates once the running value equals 1 (see the
/// partial specialization below), so this is intended for N >= 1.
template <int N, int Remaining = N, int Shifts = 0>
struct log2_up {
  enum {
    value = log2_up<N, (Remaining >> 1), Shifts + 1>::value
  };
};

/// Recursion terminator: `Shifts` is the rounded-down logarithm at this point. One more is
/// added when 2^Shifts is still smaller than N, i.e. when N is not an exact power of two.
template <int N, int Shifts>
struct log2_up<N, 1, Shifts> {
  enum {
    value = Shifts + (((1 << Shifts) < N) ? 1 : 0)
  };
};
79 
83 template <int N>
84 struct sqrt_est {
85  enum { value = 1 << (log2_up<N>::value / 2) };
86 };
87 
/// Compile-time integer division that refuses to compile unless the division is exact.
///
/// `value` is Dividend / Divisor; instantiation fails with a static assertion when Dividend
/// is not a multiple of Divisor.
template <int Dividend, int Divisor>
struct divide_assert {
  // Reject inexact divisions at compile time.
  static_assert((Dividend % Divisor == 0), "Not an even multiple");

  enum {
    value = Dividend / Divisor
  };
};
98 
99 /******************************************************************************
100  * Rounding
101  ******************************************************************************/
102 
/// Rounds `dividend` up to a multiple of `divisor`.
///
/// Note that, despite the name, the result is never smaller than `dividend` for non-negative
/// integer operands: `divisor - 1` is added before the truncating division, so any remainder
/// pushes the result to the next multiple.
///
/// \param dividend  value to round
/// \param divisor   granularity to round to (must be nonzero, since it is divided by)
/// \return          the rounded value, converted to `dividend_t`
template <typename dividend_t, typename divisor_t>
CUTLASS_HOST_DEVICE dividend_t round_nearest(dividend_t dividend, divisor_t divisor) {
  return (dividend + divisor - 1) / divisor * divisor;
}
110 
/// Greatest common divisor of `a` and `b`, computed by alternating remainder steps.
///
/// If either argument is zero the other one is returned, so gcd(0, 0) is 0.
template <typename value_t>
CUTLASS_HOST_DEVICE value_t gcd(value_t a, value_t b) {
  while (a != 0) {
    // Reduce b modulo a; a zero remainder means a divides b.
    b %= a;
    if (b == 0) {
      return a;
    }
    // Reduce a modulo the (nonzero) remainder and go around again.
    a %= b;
  }
  return b;
}
123 
/// Least common multiple of `a` and `b`.
///
/// Returns 0 when gcd(a, b) is 0. Otherwise `a` is divided by the gcd before multiplying by
/// `b`, which keeps the intermediate value small.
template <typename value_t>
CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b) {
  value_t common = gcd(a, b);

  if (common) {
    return a / common * b;
  }
  return 0;
}
133 
/// Counts the leading zero bits of `x`, looking only at bit positions 31 down to 0.
///
/// \return the number of zero bits above the highest set bit in that 32-bit window, or 32
///         when none of those bits is set.
template <typename value_t>
CUTLASS_HOST_DEVICE value_t clz(value_t x) {
  int zeros = 0;
  while (zeros < 32) {
    // Probe from the most significant bit downwards; the first hit ends the scan.
    if ((1 << (31 - zeros)) & x) {
      return zeros;
    }
    ++zeros;
  }
  return 32;
}
146 
/// Base-2 logarithm of `x`, rounded up.
///
/// The position of the highest set bit (31 - clz(x)) gives the rounded-down logarithm; one is
/// added when `x` has more than one bit set, i.e. when it is not an exact power of two.
template <typename value_t>
CUTLASS_HOST_DEVICE value_t find_log2(value_t x) {
  int const floor_log2 = int(31 - clz(x));
  bool const exact_power = (x & (x - 1)) == 0;
  return exact_power ? floor_log2 : floor_log2 + 1;
}
153 
154 
159 void find_divisor(unsigned int& mul, unsigned int& shr, unsigned int denom) {
160  if (denom == 1) {
161  mul = 0;
162  shr = 0;
163  } else {
164  unsigned int p = 31 + find_log2(denom);
165  unsigned m = unsigned(((1ull << p) + unsigned(denom) - 1) / unsigned(denom));
166 
167  mul = m;
168  shr = p - 32;
169  }
170 }
171 
/// Computes quotient and remainder of src / div using a precomputed multiplier and shift.
///
/// For div != 1 the quotient is (high 32 bits of the 64-bit product src * mul) >> shr.
/// For div == 1 the source is copied through unchanged.
///
/// \param[out] quo  quotient
/// \param[out] rem  remainder, computed as src - quo * div
/// \param[in]  src  dividend
/// \param[in]  div  divisor
/// \param[in]  mul  multiplier precomputed for `div` (presumably by find_divisor() — verify at call site)
/// \param[in]  shr  shift amount precomputed for `div`
CUTLASS_HOST_DEVICE
void fast_divmod(int& quo, int& rem, int src, int div, unsigned int mul, unsigned int shr) {

  if (div != 1) {
    #if defined(__CUDA_ARCH__)
    // Use IMUL.HI: upper 32 bits of the unsigned 32 x 32 -> 64-bit product.
    unsigned int const hi = __umulhi(static_cast<unsigned int>(src), mul);
    #else
    // Host equivalent of __umulhi: widen to 64 bits before multiplying and keep the upper
    // word. A plain 32-bit multiply would keep only the lower word and give a wrong quotient.
    unsigned int const hi = static_cast<unsigned int>(
        (static_cast<uint64_t>(static_cast<unsigned int>(src)) * mul) >> 32);
    #endif
    quo = static_cast<int>(hi >> shr);
  } else {
    quo = src;
  }

  // The remainder.
  rem = src - (quo * div);

}
189 
/// Overload of fast_divmod for a 64-bit source value and 64-bit remainder.
///
/// For div != 1 the quotient is (high 32 bits of the 64-bit product src * mul) >> shr, where
/// only the low 32 bits of `src` enter the multiply (both paths use a 32 x 32 -> 64-bit
/// product). For div == 1 the source is copied through, narrowed to int.
///
/// \param[out] quo  quotient (32-bit)
/// \param[out] rem  remainder, computed as src - quo * div
/// \param[in]  src  dividend
/// \param[in]  div  divisor
/// \param[in]  mul  multiplier precomputed for `div` (presumably by find_divisor() — verify at call site)
/// \param[in]  shr  shift amount precomputed for `div`
CUTLASS_HOST_DEVICE
void fast_divmod(int& quo, int64_t& rem, int64_t src, int div, unsigned int mul, unsigned int shr) {

  if (div != 1) {
    #if defined(__CUDA_ARCH__)
    // Use IMUL.HI: upper 32 bits of the unsigned 32 x 32 -> 64-bit product.
    unsigned int const hi = __umulhi(static_cast<unsigned int>(src), mul);
    #else
    // Host equivalent of __umulhi: keep the upper word of the 64-bit product. Shifting the
    // product by `shr` alone would skip the 32-bit high-word extraction and give a wrong
    // quotient.
    unsigned int const hi = static_cast<unsigned int>(
        (static_cast<uint64_t>(static_cast<unsigned int>(src)) * mul) >> 32);
    #endif
    quo = static_cast<int>(hi >> shr);
  } else {
    quo = static_cast<int>(src);
  }

  // The remainder.
  rem = src - (quo * div);
}
203 
204 /******************************************************************************
205  * Min/Max
206  ******************************************************************************/
207 
/// Compile-time minimum of two integers; the result is exposed as `kValue`.
template <int A, int B>
struct Min {
  static int const kValue = (B > A) ? A : B;
};
212 
/// Compile-time maximum of two integers; the result is exposed as `kValue`.
template <int A, int B>
struct Max {
  static int const kValue = (B < A) ? A : B;
};
217 
/// Returns the smaller of `a` and `b`; usable in constant expressions.
///
/// Marked CUTLASS_HOST_DEVICE like the other functions in this header so that it can be
/// called from device code as well as host code.
CUTLASS_HOST_DEVICE
constexpr int const_min(int a, int b) {
  return (b < a ? b : a);
}
222 
/// Returns the larger of `a` and `b`; usable in constant expressions.
///
/// Marked CUTLASS_HOST_DEVICE like the other functions in this header so that it can be
/// called from device code as well as host code.
CUTLASS_HOST_DEVICE
constexpr int const_max(int a, int b) {
  return (b > a ? b : a);
}
227 
228 } // namespace cutlass
Definition: aligned_buffer.h:35
#define constexpr
Definition: platform.h:137
CUTLASS_HOST_DEVICE void fast_divmod(int &quo, int &rem, int src, int div, unsigned int mul, unsigned int shr)
Definition: fast_math.h:176
CUTLASS_HOST_DEVICE value_t find_log2(value_t x)
Definition: fast_math.h:148
Definition: fast_math.h:54
Definition: fast_math.h:209
CUTLASS_HOST_DEVICE constexpr int const_max(int a, int b)
Definition: fast_math.h:224
static bool const value
Definition: fast_math.h:47
CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b)
Definition: fast_math.h:128
CUTLASS_HOST_DEVICE dividend_t round_nearest(dividend_t dividend, divisor_t divisor)
Definition: fast_math.h:107
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
#define static_assert(__e, __m)
Definition: platform.h:153
Definition: fast_math.h:214
CUTLASS_HOST_DEVICE void find_divisor(unsigned int &mul, unsigned int &shr, unsigned int denom)
Definition: fast_math.h:159
CUTLASS_HOST_DEVICE value_t gcd(value_t a, value_t b)
Definition: fast_math.h:115
Definition: fast_math.h:93
Definition: fast_math.h:69
CUTLASS_HOST_DEVICE value_t clz(value_t x)
Definition: fast_math.h:140
Definition: fast_math.h:46
CUTLASS_HOST_DEVICE constexpr int const_min(int a, int b)
Definition: fast_math.h:219
Basic include for CUTLASS.
Definition: fast_math.h:84