52 typename ElementOutput_,
54 typename ElementAccumulator_ = ElementOutput_,
55 typename ElementCompute_ = ElementOutput_,
99 ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(
nullptr), beta_ptr(
nullptr) {
108 ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
160 ComputeFragment converted_accumulator = accumulator_converter(accumulator);
171 intermediate = mul_add_source(beta_, converted_source);
172 intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);
174 intermediate = max_accumulator(intermediate, threshold_);
179 return destination_converter(intermediate);
192 typename ElementOutput_,
203 static int const kCount = Count;
237 ): alpha(alpha), beta(beta), threshold(threshold), alpha_ptr(
nullptr), beta_ptr(
nullptr) {
246 ): alpha(0), beta(0), threshold(threshold), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
267 alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
268 beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
269 threshold_ = params.threshold;
298 ComputeFragment converted_accumulator = accumulator_converter(accumulator);
309 intermediate = mul_add_source(beta_, converted_source);
310 intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);
313 intermediate = max_accumulator(intermediate, threshold_);
319 for (
int i = 0; i <
kCount; ++i) {
320 scaled_accumulator[i] =
static_cast<int>(intermediate[i]);
326 return destination_converter(scaled_accumulator);
Fused multiply-add.
Definition: functional.h:92
CUTLASS_HOST_DEVICE FragmentOutput operator()(FragmentAccumulator const &accumulator, FragmentOutput const &source, ElementCompute uniform=ElementCompute(0)) const
Computes linear scaling followed by the ReLU threshold: D = max(alpha * accumulator + beta * source, threshold).
Definition: linear_combination_relu.h:150
CUTLASS_HOST_DEVICE Params()
Definition: linear_combination_relu.h:87
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr, ElementCompute threshold=ElementCompute(0))
Definition: linear_combination_relu.h:104
ElementCompute beta
scales source tensor
Definition: linear_combination_relu.h:77
CUTLASS_HOST_DEVICE Params(ElementCompute alpha, ElementCompute beta, ElementCompute threshold=ElementCompute(0))
Definition: linear_combination_relu.h:233
Array< ElementOutput, kCount > FragmentOutput
Definition: linear_combination_relu.h:67
Array< ElementAccumulator, kCount > FragmentAccumulator
Definition: linear_combination_relu.h:68
ElementCompute const * beta_ptr
pointer to source scalar - if not null, loads it from memory
Definition: linear_combination_relu.h:218
Definition: linear_combination_relu.h:58
Definition: functional.h:235
ElementCompute const * beta_ptr
pointer to source scalar - if not null, loads it from memory
Definition: linear_combination_relu.h:80
CUTLASS_HOST_DEVICE LinearCombinationRelu(Params const &params)
Constructs the function object, possibly loading from pointers in host memory.
Definition: linear_combination_relu.h:265
ElementCompute_ ElementCompute
Definition: linear_combination_relu.h:63
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_HOST_DEVICE void set_k_partition(int k_partition)
Functionally required for serial reduction in the epilogue.
Definition: linear_combination_relu.h:280
Boost-like numeric conversion operator for CUTLASS numeric types.
ElementCompute alpha
scales accumulators
Definition: linear_combination_relu.h:214
CUTLASS_HOST_DEVICE LinearCombinationRelu(Params const &params)
Constructs the function object, possibly loading from pointers in host memory.
Definition: linear_combination_relu.h:127
ElementCompute beta
scales source tensor
Definition: linear_combination_relu.h:215
ElementCompute threshold
Relu threshold.
Definition: linear_combination_relu.h:78
Array< ElementOutput, kCount > FragmentOutput
Definition: linear_combination_relu.h:205
CUTLASS_HOST_DEVICE Params(ElementCompute alpha, ElementCompute beta, ElementCompute threshold=ElementCompute(0))
Definition: linear_combination_relu.h:95
Definition: functional.h:64
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
static FloatRoundStyle const kRound
Definition: linear_combination_relu.h:71
Top-level include for all CUTLASS numeric types.
Array< ElementCompute, kCount > ComputeFragment
Definition: linear_combination_relu.h:69
CUTLASS_HOST_DEVICE void set_k_partition(int k_partition)
Functionally required for serial reduction in the epilogue.
Definition: linear_combination_relu.h:142
CUTLASS_HOST_DEVICE Params()
Definition: linear_combination_relu.h:225
Array< ElementCompute, kCount > ComputeFragment
Definition: linear_combination_relu.h:207
ElementOutput_ ElementOutput
Definition: linear_combination_relu.h:199
CUTLASS_HOST_DEVICE bool is_source_needed() const
Returns true if source is needed.
Definition: linear_combination_relu.h:274
FloatRoundStyle
Definition: numeric_conversion.h:43
ElementCompute const * alpha_ptr
pointer to accumulator scalar - if not null, loads it from memory
Definition: linear_combination_relu.h:217
CUTLASS_HOST_DEVICE Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr, ElementCompute threshold=ElementCompute(0))
Definition: linear_combination_relu.h:242
ElementCompute threshold
Relu threshold.
Definition: linear_combination_relu.h:216
Conversion operator for Array.
Definition: numeric_conversion.h:294
ElementCompute alpha
scales accumulators
Definition: linear_combination_relu.h:76
int ElementAccumulator
Definition: linear_combination_relu.h:200
float ElementCompute
Definition: linear_combination_relu.h:201
ElementAccumulator_ ElementAccumulator
Definition: linear_combination_relu.h:62
CUTLASS_HOST_DEVICE bool is_source_needed() const
Returns true if source is needed.
Definition: linear_combination_relu.h:136
CUTLASS_HOST_DEVICE FragmentOutput operator()(FragmentAccumulator const &accumulator, FragmentOutput const &source, ElementCompute uniform=ElementCompute(0)) const
Computes linear scaling followed by the ReLU threshold: D = max(alpha * accumulator + beta * source, threshold).
Definition: linear_combination_relu.h:288
static int const kCount
Definition: linear_combination_relu.h:65
Basic include for CUTLASS.
ElementCompute const * alpha_ptr
pointer to accumulator scalar - if not null, loads it from memory
Definition: linear_combination_relu.h:79
Array< ElementAccumulator, kCount > FragmentAccumulator
Definition: linear_combination_relu.h:206
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
Host-constructable parameters structure.
Definition: linear_combination_relu.h:74
ElementOutput_ ElementOutput
Definition: linear_combination_relu.h:61