50 typename ElementOutput_,
52 typename ElementAccumulator_ = ElementOutput_,
53 typename ElementCompute_ = ElementOutput_,
94 ): alpha(alpha), beta(beta), alpha_ptr(
nullptr), beta_ptr(
nullptr) {
102 ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
151 ComputeFragment converted_accumulator = accumulator_converter(accumulator);
160 intermediate = mul_add_source(beta_, converted_source);
161 intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);
166 return destination_converter(intermediate);
Fused multiply-add.
Definition: functional.h:92
Definition: aligned_buffer.h:35
Definition: linear_combination.h:56
static int const kCount
Definition: linear_combination.h:63
ElementCompute alpha
scales accumulators
Definition: linear_combination.h:74
ElementCompute const * alpha_ptr
pointer to accumulator scalar - if not null, loads it from memory
Definition: linear_combination.h:76
Array< ElementAccumulator, kCount > FragmentAccumulator
Definition: linear_combination.h:66
CUTLASS_HOST_DEVICE void set_k_partition(int k_partition)
Functionally required for serial reduction in the epilogue.
Definition: linear_combination.h:134
Array< ElementCompute, kCount > ComputeFragment
Definition: linear_combination.h:67
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
static FloatRoundStyle const kRound
Definition: linear_combination.h:69
ElementAccumulator_ ElementAccumulator
Definition: linear_combination.h:60
CUTLASS_HOST_DEVICE FragmentOutput operator()(FragmentAccumulator const &accumulator, FragmentOutput const &source) const
Computes linear scaling: D = alpha * accumulator + beta * source.
Definition: linear_combination.h:142
ElementOutput_ ElementOutput
Definition: linear_combination.h:59
Boost-like numeric conversion operator for CUTLASS numeric types.
CUTLASS_HOST_DEVICE Params()
Definition: linear_combination.h:84
CUTLASS_HOST_DEVICE LinearCombination(Params const &params)
Constructs the function object, possibly loading from pointers in host memory.
Definition: linear_combination.h:120
Definition: functional.h:64
CUTLASS_HOST_DEVICE Params(ElementCompute alpha, ElementCompute beta)
Definition: linear_combination.h:91
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
Array< ElementOutput, kCount > FragmentOutput
Definition: linear_combination.h:65
ElementCompute beta
scales source tensor
Definition: linear_combination.h:75
CUTLASS_HOST_DEVICE Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
Definition: linear_combination.h:99
FloatRoundStyle
Definition: numeric_conversion.h:43
ElementCompute_ ElementCompute
Definition: linear_combination.h:61
Conversion operator for Array.
Definition: numeric_conversion.h:294
Basic include for CUTLASS.
Host-constructable parameters structure.
Definition: linear_combination.h:72
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
CUTLASS_HOST_DEVICE bool is_source_needed() const
Returns true if source is needed.
Definition: linear_combination.h:128
ElementCompute const * beta_ptr
pointer to source scalar - if not null, loads it from memory
Definition: linear_combination.h:77