CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
linear_combination.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"

39 namespace cutlass {
40 namespace epilogue {
41 namespace thread {
42 
44 
49 template <
50  typename ElementOutput_,
51  int Count,
52  typename ElementAccumulator_ = ElementOutput_,
53  typename ElementCompute_ = ElementOutput_,
55 >
57 public:
58 
59  using ElementOutput = ElementOutput_;
60  using ElementAccumulator = ElementAccumulator_;
61  using ElementCompute = ElementCompute_;
62 
63  static int const kCount = Count;
64 
65  using FragmentOutput = Array<ElementOutput, kCount>;
66  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
67  using ComputeFragment = Array<ElementCompute, kCount>;
68 
69  static FloatRoundStyle const kRound = Round;
70 
72  struct Params {
73 
78 
79  //
80  // Methods
81  //
82 
84  Params():
85  alpha(ElementCompute(1)),
86  beta(ElementCompute(0)),
87  alpha_ptr(nullptr),
88  beta_ptr(nullptr) { }
89 
92  ElementCompute alpha,
93  ElementCompute beta
94  ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
95 
96  }
97 
100  ElementCompute const *alpha_ptr,
101  ElementCompute const *beta_ptr
102  ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
103 
104  }
105  };
106 
107 private:
108 
109  //
110  // Data members
111  //
112 
113  ElementCompute alpha_;
114  ElementCompute beta_;
115 
116 public:
117 
120  LinearCombination(Params const &params) {
121 
122  alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
123  beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
124  }
125 
128  bool is_source_needed() const {
129  return beta_ != ElementCompute(0);
130  }
131 
134  void set_k_partition(int k_partition) {
135  if (k_partition) {
136  beta_ = ElementCompute(1);
137  }
138  }
139 
143  FragmentAccumulator const &accumulator,
144  FragmentOutput const &source) const {
145 
146  // Convert source to interal compute numeric type
149 
150  ComputeFragment converted_source = source_converter(source);
151  ComputeFragment converted_accumulator = accumulator_converter(accumulator);
152 
153  // Perform binary operations
154 
155  ComputeFragment intermediate;
156 
157  multiplies<ComputeFragment> mul_add_source;
158  multiply_add<ComputeFragment> mul_add_accumulator;
159 
160  intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
161  intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
162 
163  // Convert to destination numeric type
165 
166  return destination_converter(intermediate);
167  }
168 };
169 
171 
172 } // namespace thread
173 } // namespace epilogue
174 } // namespace cutlass
Fused multiply-add.
Definition: functional.h:92
Definition: aligned_buffer.h:35
Definition: linear_combination.h:56
static int const kCount
Definition: linear_combination.h:63
ElementCompute alpha
scales accumulators
Definition: linear_combination.h:74
ElementCompute const * alpha_ptr
pointer to accumulator scalar - if not null, loads it from memory
Definition: linear_combination.h:76
Array< ElementAccumulator, kCount > FragmentAccumulator
Definition: linear_combination.h:66
CUTLASS_HOST_DEVICE void set_k_partition(int k_partition)
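Functionally required for serial reduction in the epilogue: for any nonzero k_partition, beta is forced to 1 so the partial result passed as source is accumulated onto rather than replaced.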
Definition: linear_combination.h:134
Array< ElementCompute, kCount > ComputeFragment
Definition: linear_combination.h:67
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
static FloatRoundStyle const kRound
Definition: linear_combination.h:69
ElementAccumulator_ ElementAccumulator
Definition: linear_combination.h:60
CUTLASS_HOST_DEVICE FragmentOutput operator()(FragmentAccumulator const &accumulator, FragmentOutput const &source) const
Computes linear scaling: D = alpha * accumulator + beta * source.
Definition: linear_combination.h:142
ElementOutput_ ElementOutput
Definition: linear_combination.h:59
Boost-like numeric conversion operator for CUTLASS numeric types.
#define nullptr
nullptr
Definition: platform.h:144
CUTLASS_HOST_DEVICE Params()
Definition: linear_combination.h:84
CUTLASS_HOST_DEVICE LinearCombination(Params const &params)
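Constructs the function object, loading alpha and beta through alpha_ptr/beta_ptr when those pointers are non-null; because the constructor is host/device, the pointers must be dereferenceable wherever the functor is constructed.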
Definition: linear_combination.h:120
Definition: functional.h:64
CUTLASS_HOST_DEVICE Params(ElementCompute alpha, ElementCompute beta)
Definition: linear_combination.h:91
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
Array< ElementOutput, kCount > FragmentOutput
Definition: linear_combination.h:65
ElementCompute beta
scales source tensor
Definition: linear_combination.h:75
CUTLASS_HOST_DEVICE Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
Definition: linear_combination.h:99
FloatRoundStyle
Definition: numeric_conversion.h:43
ElementCompute_ ElementCompute
Definition: linear_combination.h:61
Conversion operator for Array.
Definition: numeric_conversion.h:294
Basic include for CUTLASS.
Host-constructable parameters structure.
Definition: linear_combination.h:72
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
CUTLASS_HOST_DEVICE bool is_source_needed() const
Returns true if source is needed.
Definition: linear_combination.h:128
ElementCompute const * beta_ptr
pointer to source scalar - if not null, loads it from memory
Definition: linear_combination.h:77