CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
linear_combination_relu.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
30 #pragma once
31 
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"

#if !defined(__CUDACC_RTC__)
#include <cmath>
#endif
37 
39 
40 namespace cutlass {
41 namespace epilogue {
42 namespace thread {
43 
45 
51 template <
52  typename ElementOutput_,
53  int Count,
54  typename ElementAccumulator_ = ElementOutput_,
55  typename ElementCompute_ = ElementOutput_,
57 >
59 public:
60 
61  using ElementOutput = ElementOutput_;
62  using ElementAccumulator = ElementAccumulator_;
63  using ElementCompute = ElementCompute_;
64 
65  static int const kCount = Count;
66 
67  using FragmentOutput = Array<ElementOutput, kCount>;
68  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
69  using ComputeFragment = Array<ElementCompute, kCount>;
70 
71  static FloatRoundStyle const kRound = Round;
72 
74  struct Params {
75 
81 
82  //
83  // Methods
84  //
85 
87  Params():
88  alpha(ElementCompute(1)),
89  beta(ElementCompute(0)),
90  threshold(ElementCompute(0)),
91  alpha_ptr(nullptr),
92  beta_ptr(nullptr) { }
93 
96  ElementCompute alpha,
97  ElementCompute beta,
98  ElementCompute threshold = ElementCompute(0)
99  ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
100 
101  }
102 
105  ElementCompute const *alpha_ptr,
106  ElementCompute const *beta_ptr,
107  ElementCompute threshold = ElementCompute(0)
108  ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
109 
110  }
111  };
112 
113 private:
114 
115  //
116  // Data members
117  //
118 
119  ElementCompute alpha_;
120  ElementCompute beta_;
121  ElementCompute threshold_;
122 
123 public:
124 
127  LinearCombinationRelu(Params const &params) {
128 
129  alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
130  beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
131  threshold_ = params.threshold;
132  }
133 
136  bool is_source_needed() const {
137  return beta_ != ElementCompute(0);
138  }
139 
142  void set_k_partition(int k_partition) {
143  if (k_partition) {
144  beta_ = ElementCompute(1);
145  }
146  }
147 
151  FragmentAccumulator const &accumulator,
152  FragmentOutput const &source,
153  ElementCompute uniform = ElementCompute(0)) const {
154 
155  // Convert source to interal compute numeric type
158 
159  ComputeFragment converted_source = source_converter(source);
160  ComputeFragment converted_accumulator = accumulator_converter(accumulator);
161 
162  // Perform binary operations
163 
164  ComputeFragment intermediate;
165 
166  multiplies<ComputeFragment> mul_add_source;
167  multiply_add<ComputeFragment> mul_add_accumulator;
168 
169  maximum<ComputeFragment> max_accumulator;
170 
171  intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
172  intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
173 
174  intermediate = max_accumulator(intermediate, threshold_);
175 
176  // Convert to destination numeric type
178 
179  return destination_converter(intermediate);
180  }
181 };
182 
183 
185 
191 template <
192  typename ElementOutput_,
193  int Count,
194  FloatRoundStyle Round
195 >
196 class LinearCombinationRelu<ElementOutput_, Count, int, float, Round> {
197 public:
198 
199  using ElementOutput = ElementOutput_;
200  using ElementAccumulator = int;
201  using ElementCompute = float;
202 
203  static int const kCount = Count;
204 
205  using FragmentOutput = Array<ElementOutput, kCount>;
206  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
207  using ComputeFragment = Array<ElementCompute, kCount>;
208 
209  static FloatRoundStyle const kRound = Round;
210 
212  struct Params {
213 
219 
220  //
221  // Methods
222  //
223 
225  Params():
226  alpha(ElementCompute(1)),
227  beta(ElementCompute(0)),
228  threshold(ElementCompute(0)),
229  alpha_ptr(nullptr),
230  beta_ptr(nullptr) { }
231 
234  ElementCompute alpha,
235  ElementCompute beta,
236  ElementCompute threshold = ElementCompute(0)
237  ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(nullptr), beta_ptr(nullptr) {
238 
239  }
240 
243  ElementCompute const *alpha_ptr,
244  ElementCompute const *beta_ptr,
245  ElementCompute threshold = ElementCompute(0)
246  ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
247 
248  }
249  };
250 
251 private:
252 
253  //
254  // Data members
255  //
256 
257  ElementCompute alpha_;
258  ElementCompute beta_;
259  ElementCompute threshold_;
260 
261 public:
262 
265  LinearCombinationRelu(Params const &params) {
266 
267  alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
268  beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
269  threshold_ = params.threshold;
270  }
271 
274  bool is_source_needed() const {
275  return beta_ != ElementCompute(0);
276  }
277 
280  void set_k_partition(int k_partition) {
281  if (k_partition) {
282  beta_ = ElementCompute(1);
283  }
284  }
285 
289  FragmentAccumulator const &accumulator,
290  FragmentOutput const &source,
291  ElementCompute uniform = ElementCompute(0)) const {
292 
293  // Convert source to interal compute numeric type
296 
297  ComputeFragment converted_source = source_converter(source);
298  ComputeFragment converted_accumulator = accumulator_converter(accumulator);
299 
300  // Perform binary operations
301 
302  ComputeFragment intermediate;
303 
304  multiplies<ComputeFragment> mul_add_source;
305  multiply_add<ComputeFragment> mul_add_accumulator;
306 
307  maximum<ComputeFragment> max_accumulator;
308 
309  intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
310  intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
311 
312  // Clamp to theshold
313  intermediate = max_accumulator(intermediate, threshold_);
314 
315  // Convert back to accumulator data type
316  FragmentAccumulator scaled_accumulator;
317 
319  for (int i = 0; i < kCount; ++i) {
320  scaled_accumulator[i] = static_cast<int>(intermediate[i]);
321  }
322 
323  // Convert to destination numeric type and pack
325 
326  return destination_converter(scaled_accumulator);
327  }
328 };
329 
331 
332 } // namespace thread
333 } // namespace epilogue
334 } // namespace cutlass
Fused multiply-add.
Definition: functional.h:92
CUTLASS_HOST_DEVICE FragmentOutput operator()(FragmentAccumulator const &accumulator, FragmentOutput const &source, ElementCompute uniform=ElementCompute(0)) const
Computes linear scaling: D = alpha * accumulator + beta * source.
Definition: linear_combination_relu.h:150
CUTLASS_HOST_DEVICE Params()
Definition: linear_combination_relu.h:87
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr, ElementCompute threshold=ElementCompute(0))
Definition: linear_combination_relu.h:104
ElementCompute beta
scales source tensor
Definition: linear_combination_relu.h:77
CUTLASS_HOST_DEVICE Params(ElementCompute alpha, ElementCompute beta, ElementCompute threshold=ElementCompute(0))
Definition: linear_combination_relu.h:233
Array< ElementOutput, kCount > FragmentOutput
Definition: linear_combination_relu.h:67
Array< ElementAccumulator, kCount > FragmentAccumulator
Definition: linear_combination_relu.h:68
ElementCompute const * beta_ptr
pointer to source scalar - if not null, loads it from memory
Definition: linear_combination_relu.h:218
Definition: linear_combination_relu.h:58
Definition: functional.h:235
ElementCompute const * beta_ptr
pointer to source scalar - if not null, loads it from memory
Definition: linear_combination_relu.h:80
CUTLASS_HOST_DEVICE LinearCombinationRelu(Params const &params)
Constructs the function object, possibly loading from pointers in host memory.
Definition: linear_combination_relu.h:265
ElementCompute_ ElementCompute
Definition: linear_combination_relu.h:63
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_HOST_DEVICE void set_k_partition(int k_partition)
Functionally required for serial reduction in the epilogue.
Definition: linear_combination_relu.h:280
Boost-like numeric conversion operator for CUTLASS numeric types.
#define nullptr
nullptr
Definition: platform.h:144
ElementCompute alpha
scales accumulators
Definition: linear_combination_relu.h:214
CUTLASS_HOST_DEVICE LinearCombinationRelu(Params const &params)
Constructs the function object, possibly loading from pointers in host memory.
Definition: linear_combination_relu.h:127
ElementCompute beta
scales source tensor
Definition: linear_combination_relu.h:215
ElementCompute threshold
Relu threshold.
Definition: linear_combination_relu.h:78
Array< ElementOutput, kCount > FragmentOutput
Definition: linear_combination_relu.h:205
CUTLASS_HOST_DEVICE Params(ElementCompute alpha, ElementCompute beta, ElementCompute threshold=ElementCompute(0))
Definition: linear_combination_relu.h:95
Definition: functional.h:64
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
static FloatRoundStyle const kRound
Definition: linear_combination_relu.h:71
Top-level include for all CUTLASS numeric types.
Array< ElementCompute, kCount > ComputeFragment
Definition: linear_combination_relu.h:69
CUTLASS_HOST_DEVICE void set_k_partition(int k_partition)
Functionally required for serial reduction in the epilogue.
Definition: linear_combination_relu.h:142
Array< ElementCompute, kCount > ComputeFragment
Definition: linear_combination_relu.h:207
CUTLASS_HOST_DEVICE bool is_source_needed() const
Returns true if source is needed.
Definition: linear_combination_relu.h:274
FloatRoundStyle
Definition: numeric_conversion.h:43
ElementCompute const * alpha_ptr
pointer to accumulator scalar - if not null, loads it from memory
Definition: linear_combination_relu.h:217
CUTLASS_HOST_DEVICE Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr, ElementCompute threshold=ElementCompute(0))
Definition: linear_combination_relu.h:242
ElementCompute threshold
Relu threshold.
Definition: linear_combination_relu.h:216
Conversion operator for Array.
Definition: numeric_conversion.h:294
ElementCompute alpha
scales accumulators
Definition: linear_combination_relu.h:76
ElementAccumulator_ ElementAccumulator
Definition: linear_combination_relu.h:62
CUTLASS_HOST_DEVICE bool is_source_needed() const
Returns true if source is needed.
Definition: linear_combination_relu.h:136
CUTLASS_HOST_DEVICE FragmentOutput operator()(FragmentAccumulator const &accumulator, FragmentOutput const &source, ElementCompute uniform=ElementCompute(0)) const
Computes linear scaling: D = alpha * accumulator + beta * source.
Definition: linear_combination_relu.h:288
static int const kCount
Definition: linear_combination_relu.h:65
Basic include for CUTLASS.
ElementCompute const * alpha_ptr
pointer to accumulator scalar - if not null, loads it from memory
Definition: linear_combination_relu.h:79
Array< ElementAccumulator, kCount > FragmentAccumulator
Definition: linear_combination_relu.h:206
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
Host-constructable parameters structure.
Definition: linear_combination_relu.h:74
ElementOutput_ ElementOutput
Definition: linear_combination_relu.h:61