52 typename ElementOutput_,
54 typename ElementAccumulator_ = ElementOutput_,
55 typename ElementCompute_ = ElementOutput_,
96 ): alpha(alpha), beta(beta), alpha_ptr(
nullptr), beta_ptr(
nullptr) {
104 ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
154 ComputeFragment converted_accumulator = accumulator_converter(accumulator);
166 intermediate = mul_add_source(beta_, converted_source);
167 intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);
172 intermediate = max_accumulator(intermediate, -kClamp);
173 intermediate = min_accumulator(intermediate, kClamp -
ElementCompute(1));
178 return destination_converter(intermediate);
186 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 720) && (__CUDACC_VER_MAJOR__ >= 10) && (__CUDACC_VER_MINOR__ >= 2) 194 typename ElementOutput_,
205 static int const kCount = Count;
294 ComputeFragment converted_accumulator = accumulator_converter(accumulator);
303 intermediate = mul_add_source(beta_, converted_source);
304 intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);
310 for (
int i = 0; i <
kCount; ++i) {
311 scaled_accumulator[i] =
static_cast<int>(intermediate[i]);
317 return destination_converter(scaled_accumulator);
321 #endif // Conditional guards to enable partial specialization for packed integers Fused multiply-add.
Definition: functional.h:92
ElementCompute_ ElementCompute
Definition: linear_combination_clamp.h:63
Definition: aligned_buffer.h:35
ElementCompute beta
scales source tensor
Definition: linear_combination_clamp.h:77
CUTLASS_HOST_DEVICE Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr)
Definition: linear_combination_clamp.h:101
CUTLASS_HOST_DEVICE Params(ElementCompute alpha, ElementCompute beta)
Definition: linear_combination_clamp.h:93
Definition: linear_combination_clamp.h:58
Definition: functional.h:298
Definition: functional.h:235
static int const kCount
Definition: linear_combination_clamp.h:65
CUTLASS_HOST_DEVICE FragmentOutput operator()(FragmentAccumulator const &accumulator, FragmentOutput const &source, ElementCompute uniform=ElementCompute(0)) const
Computes linear scaling: D = alpha * accumulator + beta * source.
Definition: linear_combination_clamp.h:144
CUTLASS_HOST_DEVICE Params()
Definition: linear_combination_clamp.h:86
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Boost-like numeric conversion operator for CUTLASS numeric types.
Defines the size of an element in bits.
Definition: numeric_types.h:42
CUTLASS_HOST_DEVICE LinearCombinationClamp(Params const ¶ms)
Constructs the function object, possibly loading from pointers in host memory.
Definition: linear_combination_clamp.h:122
Definition: functional.h:64
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
Array< ElementCompute, kCount > ComputeFragment
Definition: linear_combination_clamp.h:69
ElementOutput_ ElementOutput
Definition: linear_combination_clamp.h:61
Array< ElementOutput, kCount > FragmentOutput
Definition: linear_combination_clamp.h:67
ElementAccumulator_ ElementAccumulator
Definition: linear_combination_clamp.h:62
ElementCompute const * beta_ptr
pointer to source scalar - if not null, loads it from memory
Definition: linear_combination_clamp.h:79
CUTLASS_HOST_DEVICE void set_k_partition(int k_partition)
Functionally required for serial reduction in the epilogue.
Definition: linear_combination_clamp.h:136
FloatRoundStyle
Definition: numeric_conversion.h:43
Conversion operator for Array.
Definition: numeric_conversion.h:294
Host-constructable parameters structure.
Definition: linear_combination_clamp.h:74
static FloatRoundStyle const kRound
Definition: linear_combination_clamp.h:71
CUTLASS_HOST_DEVICE bool is_source_needed() const
Returns true if source is needed.
Definition: linear_combination_clamp.h:130
Basic include for CUTLASS.
ElementCompute const * alpha_ptr
pointer to accumulator scalar - if not null, loads it from memory
Definition: linear_combination_clamp.h:78
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
ElementCompute alpha
scales accumulators
Definition: linear_combination_clamp.h:76
Array< ElementAccumulator, kCount > FragmentAccumulator
Definition: linear_combination_clamp.h:68