31 #include "cutlass/shape.h" 34 #include "cutlass/gemm/linear_scaling.h" 54 typename ScalarAlphaBeta_,
56 typename ScalarAccum_,
58 int ReductionSize_ = 1,
60 typename OutputTile_ = Shape<1, 1, 128>,
62 typename SubTile_ = Shape<1, 1, 64>,
64 typename ThreadShape_ = Shape<1, 1, 2>,
66 typename Index_ = int,
68 typename BlockSwizzle_ = DefaultBlockSwizzle,
74 typename Functor_ =
typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >
122 static int const kThreads = SubTile::kW / ThreadShape::kW;
128 static_assert(SubTile::kW % ThreadShape::kW == 0,
"cannot evenly distribute work load among threads");
130 static_assert(kThreads % 32 == 0,
"threads per threadblock is not multiple of 32");
132 static_assert(OutputTile::kW % SubTile::kW == 0,
"cannot evenly distribute work load among iterations");
134 static_assert(ReductionSize * ThreadShape::kW <= maxInReg,
"ReductionSize * ThreadShape::kW should not be bigger than maxInReg");
136 static_assert(ThreadShape::kW <= maxOutReg,
"ThreadShape::kW should not be bigger than maxOutReg");
164 ScalarAlphaBeta alpha_,
165 ScalarAlphaBeta beta_,
166 long long int reduction_stride_,
176 reduction_stride = reduction_stride_;
184 functorParams.initialize(alpha_, beta_);
Coord< 3 > problem_size
The dimension of output tensor.
Definition: batched_reduction_traits.h:140
Definition: aligned_buffer.h:35
Definition: batched_reduction_traits.h:138
BlockSwizzle_ BlockSwizzle
The thread block swizzle.
Definition: batched_reduction_traits.h:113
BatchedReductionTraits< ScalarA_, ScalarC_, ScalarD_, ScalarAlphaBeta_, ScalarAccum_, ReductionSize_, OutputTile_, SubTile_, ThreadShape_, Index_, BlockSwizzle_, maxInReg_, maxOutReg_, Functor_ > This_
Definition: batched_reduction_traits.h:91
static int const kThreads
Definition: batched_reduction_traits.h:122
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:387
ScalarAccum_ ScalarAccum
The type for accumulation.
Definition: batched_reduction_traits.h:109
Index lda
Definition: batched_reduction_traits.h:150
ScalarAlphaBeta beta
The beta.
Definition: batched_reduction_traits.h:144
Defines functors for mapping blockIdx to partitions of the batched reduction computation.
ThreadShape_ ThreadShape
Definition: batched_reduction_traits.h:99
Index ldd
Definition: batched_reduction_traits.h:158
ScalarD_ ScalarD
The output pointer type.
Definition: batched_reduction_traits.h:105
long long int reduction_stride
Stride between two elements that will be summed.
Definition: batched_reduction_traits.h:146
SubTile_ SubTile
Definition: batched_reduction_traits.h:97
ScalarC const * d_c
Definition: batched_reduction_traits.h:152
OutputTile_ OutputTile
Definition: batched_reduction_traits.h:95
Index ldc
Definition: batched_reduction_traits.h:154
ScalarAlphaBeta_ ScalarAlphaBeta
The alpha beta type.
Definition: batched_reduction_traits.h:107
ScalarC_ ScalarC
Definition: batched_reduction_traits.h:103
ScalarA const * d_a
Definition: batched_reduction_traits.h:148
Definition: batched_reduction.h:52
static const int ReductionSize
Definition: batched_reduction_traits.h:115
ScalarA_ ScalarA
The input pointer type.
Definition: batched_reduction_traits.h:101
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
ScalarAlphaBeta alpha
The alpha.
Definition: batched_reduction_traits.h:142
Functor_ Functor
Definition: batched_reduction_traits.h:119
static int const maxInReg
Definition: batched_reduction_traits.h:124
ScalarD * d_d
Definition: batched_reduction_traits.h:156
Functor::Params functorParams
The functor params.
Definition: batched_reduction_traits.h:160
static const bool ThreadShapeMultiple2
Check whether ThreadShape is a multiple of 2.
Definition: batched_reduction_traits.h:117
Index_ Index
The index.
Definition: batched_reduction_traits.h:111
static int const maxOutReg
Definition: batched_reduction_traits.h:126
Implements a software-pipelined efficient batched reduction. D = alpha * Reduction(A) + beta * C...
Basic include for CUTLASS.
Definition: batched_reduction_traits.h:76
CUTLASS_HOST_DEVICE int initialize(Index m_, Index n_, ScalarAlphaBeta alpha_, ScalarAlphaBeta beta_, long long int reduction_stride_, ScalarA const *d_a_, Index lda_, ScalarC const *d_c_, Index ldc_, ScalarD *d_d_, Index ldd_)
Initialize the parameters for 2D output tensor.
Definition: batched_reduction_traits.h:162
cutlass::reduction::BatchedReduction< This_ > KernelClass
The struct that consumes this Traits.
Definition: batched_reduction_traits.h:93