CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
linear_combination_clamp.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator, D = alpha * accumulator + beta * source, then clamps
/// the result before converting to the output element type.
template <
  typename ElementOutput_,
  int Count,
  typename ElementAccumulator_ = ElementOutput_,
  typename ElementCompute_ = ElementOutput_,
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationClamp {
public:

  using ElementOutput = ElementOutput_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;

  static int const kCount = Count;

  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using ComputeFragment = Array<ElementCompute, kCount>;

  static FloatRoundStyle const kRound = Round;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;                  ///< scales accumulators
    ElementCompute beta;                   ///< scales source tensor
    ElementCompute const *alpha_ptr;       ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr;        ///< pointer to source scalar - if not null, loads it from memory

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      alpha(ElementCompute(1)),
      beta(ElementCompute(0)),
      alpha_ptr(nullptr),
      beta_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute alpha,
      ElementCompute beta
    ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {

    }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr,
      ElementCompute const *beta_ptr
    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {

    }
  };

private:

  //
  // Data members
  //

  ElementCompute alpha_;
  ElementCompute beta_;

public:

  /// Constructs the function object, possibly loading from pointers in host memory
  CUTLASS_HOST_DEVICE
  LinearCombinationClamp(Params const &params) {

    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
  }

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    return beta_ != ElementCompute(0);
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition) {
    if (k_partition) {
      beta_ = ElementCompute(1);
    }
  }

  /// Computes linear scaling: D = alpha * accumulator + beta * source
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
    FragmentAccumulator const &accumulator,
    FragmentOutput const &source,
    ElementCompute uniform = ElementCompute(0)) const {

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    ComputeFragment converted_source = source_converter(source);
    ComputeFragment converted_accumulator = accumulator_converter(accumulator);

    // Perform binary operations

    ComputeFragment intermediate;

    multiplies<ComputeFragment> mul_add_source;
    multiply_add<ComputeFragment> mul_add_accumulator;

    minimum<ComputeFragment> min_accumulator;
    maximum<ComputeFragment> max_accumulator;

    intermediate = mul_add_source(beta_, converted_source);                           // X = beta * C + uniform
    intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);  // D = alpha * Accum + X

    // Clamp to the representable range of ElementOutput: [-2^(b-1), 2^(b-1) - 1]
    ElementCompute const kClamp = ElementCompute(1U << (sizeof_bits<ElementOutput>::value - 1));

    intermediate = max_accumulator(intermediate, -kClamp);
    intermediate = min_accumulator(intermediate, kClamp - ElementCompute(1));

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;

    return destination_converter(intermediate);
  }

};
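
// Editorial worked example (not part of the original source): with int8_t as ElementOutput,
// sizeof_bits<int8_t>::value == 8, so kClamp = ElementCompute(1 << 7) = 128 and the intermediate
// values above are clamped to [-128, 127] before conversion to the output element type.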

/////////////////////////////////////////////////////////////////////////////////////////////////

// Conditional guards to enable partial specialization for packed integers
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2)

/// Partial specialization for int accumulators scaled in single-precision float, with the result
/// clamped before conversion to the output element type.
template <
  typename ElementOutput_,
  int Count,
  FloatRoundStyle Round
>
class LinearCombinationClamp<ElementOutput_, Count, int, float, Round> {
public:

  using ElementOutput = ElementOutput_;
  using ElementAccumulator = int;
  using ElementCompute = float;

  static int const kCount = Count;

  using FragmentOutput = Array<ElementOutput, kCount>;
  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
  using ComputeFragment = Array<ElementCompute, kCount>;

  static FloatRoundStyle const kRound = Round;

  /// Host-constructable parameters structure
  struct Params {

    ElementCompute alpha;                  ///< scales accumulators
    ElementCompute beta;                   ///< scales source tensor
    ElementCompute const *alpha_ptr;       ///< pointer to accumulator scalar - if not null, loads it from memory
    ElementCompute const *beta_ptr;        ///< pointer to source scalar - if not null, loads it from memory

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      alpha(ElementCompute(1)),
      beta(ElementCompute(0)),
      alpha_ptr(nullptr),
      beta_ptr(nullptr) { }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute alpha,
      ElementCompute beta
    ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {

    }

    CUTLASS_HOST_DEVICE
    Params(
      ElementCompute const *alpha_ptr,
      ElementCompute const *beta_ptr
    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {

    }
  };

private:

  //
  // Data members
  //

  ElementCompute alpha_;
  ElementCompute beta_;

public:

  /// Constructs the function object, possibly loading from pointers in host memory
  CUTLASS_HOST_DEVICE
  LinearCombinationClamp(Params const &params) {

    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
  }

  /// Returns true if source is needed
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const {
    return beta_ != ElementCompute(0);
  }

  /// Functionally required for serial reduction in the epilogue
  CUTLASS_HOST_DEVICE
  void set_k_partition(int k_partition) {
    if (k_partition) {
      beta_ = ElementCompute(1);
    }
  }

  /// Computes linear scaling: D = alpha * accumulator + beta * source
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
    FragmentAccumulator const &accumulator,
    FragmentOutput const &source,
    ElementCompute uniform = ElementCompute(0)) const {

    // Convert source to internal compute numeric type
    NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

    ComputeFragment converted_source = source_converter(source);
    ComputeFragment converted_accumulator = accumulator_converter(accumulator);

    // Compute linear scaling in floating point
    ComputeFragment intermediate;

    multiplies<ComputeFragment> mul_add_source;
    multiply_add<ComputeFragment> mul_add_accumulator;

    // Float min-max
    intermediate = mul_add_source(beta_, converted_source);                           // X = beta * C + uniform
    intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);  // D = alpha * Accum + X

    // Convert floats back to INT
    FragmentAccumulator scaled_accumulator;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kCount; ++i) {
      scaled_accumulator[i] = static_cast<int>(intermediate[i]);
    }

    // Convert to destination numeric type
    NumericArrayConverter<ElementOutput, int, kCount, Round> destination_converter;

    return destination_converter(scaled_accumulator);
  }
};

#endif // Conditional guards to enable partial specialization for packed integers

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace thread
} // namespace epilogue
} // namespace cutlass
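
Usage note (editorial, not part of the original header): a CUTLASS GEMM kernel normally instantiates this functor as its epilogue output op and applies it to each fragment, but the class can also be exercised directly. Below is a minimal host-side sketch assuming the header sits at its usual path under cutlass/epilogue/thread/ and CUTLASS is on the include path; the element types, fragment width, and alpha/beta values are illustrative only.

// Minimal sketch: apply LinearCombinationClamp directly to small fragments on the host.
#include <cstdio>
#include <cstdint>
#include "cutlass/epilogue/thread/linear_combination_clamp.h"

int main() {
  // int accumulators, float compute type, int8_t output, 4 elements per operation (illustrative)
  using OutputOp = cutlass::epilogue::thread::LinearCombinationClamp<int8_t, 4, int, float>;

  OutputOp::Params params(1.5f, 0.5f);   // alpha = 1.5, beta = 0.5 (illustrative values)
  OutputOp op(params);

  cutlass::Array<int, 4> accum;
  cutlass::Array<int8_t, 4> source;
  for (int i = 0; i < 4; ++i) {
    accum[i] = 100 * (i + 1);            // accumulator fragment
    source[i] = int8_t(10 * i);          // source (C) fragment
  }

  // D = clamp(alpha * accum + beta * source), converted to int8_t; large values saturate at 127
  cutlass::Array<int8_t, 4> d = op(accum, source);

  for (int i = 0; i < 4; ++i) {
    printf("D[%d] = %d\n", i, int(d[i]));
  }
  return 0;
}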