#include <batched_reduction_traits.h>
|
CUTLASS_HOST_DEVICE int | initialize (Index m_, Index n_, ScalarAlphaBeta alpha_, ScalarAlphaBeta beta_, long long int reduction_stride_, ScalarA const *d_a_, Index lda_, ScalarC const *d_c_, Index ldc_, ScalarD *d_d_, Index ldd_) |
| Initialize the parameters for a 2D output tensor.
|
|
template<typename ScalarA_ , typename ScalarC_ , typename ScalarD_ , typename ScalarAlphaBeta_ , typename ScalarAccum_ , int ReductionSize_ = 1, typename OutputTile_ = Shape<1, 1, 128>, typename SubTile_ = Shape<1, 1, 64>, typename ThreadShape_ = Shape<1, 1, 2>, typename Index_ = int, typename BlockSwizzle_ = DefaultBlockSwizzle, int maxInReg_ = 160, int maxOutReg_ = 64, typename Functor_ = typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >>
CUTLASS_HOST_DEVICE int cutlass::reduction::BatchedReductionTraits< ScalarA_, ScalarC_, ScalarD_, ScalarAlphaBeta_, ScalarAccum_, ReductionSize_, OutputTile_, SubTile_, ThreadShape_, Index_, BlockSwizzle_, maxInReg_, maxOutReg_, Functor_ >::Params::initialize (
    Index m_,
    Index n_,
    ScalarAlphaBeta alpha_,
    ScalarAlphaBeta beta_,
    long long int reduction_stride_,
    ScalarA const * d_a_,
    Index lda_,
    ScalarC const * d_c_,
    Index ldc_,
    ScalarD * d_d_,
    Index ldd_
)   [inline]
template<typename ScalarA_ , typename ScalarC_ , typename ScalarD_ , typename ScalarAlphaBeta_ , typename ScalarAccum_ , int ReductionSize_ = 1, typename OutputTile_ = Shape<1, 1, 128>, typename SubTile_ = Shape<1, 1, 64>, typename ThreadShape_ = Shape<1, 1, 2>, typename Index_ = int, typename BlockSwizzle_ = DefaultBlockSwizzle, int maxInReg_ = 160, int maxOutReg_ = 64, typename Functor_ = typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >>
ScalarAlphaBeta cutlass::reduction::BatchedReductionTraits< ScalarA_, ScalarC_, ScalarD_, ScalarAlphaBeta_, ScalarAccum_, ReductionSize_, OutputTile_, SubTile_, ThreadShape_, Index_, BlockSwizzle_, maxInReg_, maxOutReg_, Functor_ >::Params::alpha |
template<typename ScalarA_ , typename ScalarC_ , typename ScalarD_ , typename ScalarAlphaBeta_ , typename ScalarAccum_ , int ReductionSize_ = 1, typename OutputTile_ = Shape<1, 1, 128>, typename SubTile_ = Shape<1, 1, 64>, typename ThreadShape_ = Shape<1, 1, 2>, typename Index_ = int, typename BlockSwizzle_ = DefaultBlockSwizzle, int maxInReg_ = 160, int maxOutReg_ = 64, typename Functor_ = typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >>
ScalarAlphaBeta cutlass::reduction::BatchedReductionTraits< ScalarA_, ScalarC_, ScalarD_, ScalarAlphaBeta_, ScalarAccum_, ReductionSize_, OutputTile_, SubTile_, ThreadShape_, Index_, BlockSwizzle_, maxInReg_, maxOutReg_, Functor_ >::Params::beta |
template<typename ScalarA_ , typename ScalarC_ , typename ScalarD_ , typename ScalarAlphaBeta_ , typename ScalarAccum_ , int ReductionSize_ = 1, typename OutputTile_ = Shape<1, 1, 128>, typename SubTile_ = Shape<1, 1, 64>, typename ThreadShape_ = Shape<1, 1, 2>, typename Index_ = int, typename BlockSwizzle_ = DefaultBlockSwizzle, int maxInReg_ = 160, int maxOutReg_ = 64, typename Functor_ = typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >>
ScalarA const* cutlass::reduction::BatchedReductionTraits< ScalarA_, ScalarC_, ScalarD_, ScalarAlphaBeta_, ScalarAccum_, ReductionSize_, OutputTile_, SubTile_, ThreadShape_, Index_, BlockSwizzle_, maxInReg_, maxOutReg_, Functor_ >::Params::d_a |
template<typename ScalarA_ , typename ScalarC_ , typename ScalarD_ , typename ScalarAlphaBeta_ , typename ScalarAccum_ , int ReductionSize_ = 1, typename OutputTile_ = Shape<1, 1, 128>, typename SubTile_ = Shape<1, 1, 64>, typename ThreadShape_ = Shape<1, 1, 2>, typename Index_ = int, typename BlockSwizzle_ = DefaultBlockSwizzle, int maxInReg_ = 160, int maxOutReg_ = 64, typename Functor_ = typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >>
ScalarC const* cutlass::reduction::BatchedReductionTraits< ScalarA_, ScalarC_, ScalarD_, ScalarAlphaBeta_, ScalarAccum_, ReductionSize_, OutputTile_, SubTile_, ThreadShape_, Index_, BlockSwizzle_, maxInReg_, maxOutReg_, Functor_ >::Params::d_c |
template<typename ScalarA_ , typename ScalarC_ , typename ScalarD_ , typename ScalarAlphaBeta_ , typename ScalarAccum_ , int ReductionSize_ = 1, typename OutputTile_ = Shape<1, 1, 128>, typename SubTile_ = Shape<1, 1, 64>, typename ThreadShape_ = Shape<1, 1, 2>, typename Index_ = int, typename BlockSwizzle_ = DefaultBlockSwizzle, int maxInReg_ = 160, int maxOutReg_ = 64, typename Functor_ = typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >>
ScalarD* cutlass::reduction::BatchedReductionTraits< ScalarA_, ScalarC_, ScalarD_, ScalarAlphaBeta_, ScalarAccum_, ReductionSize_, OutputTile_, SubTile_, ThreadShape_, Index_, BlockSwizzle_, maxInReg_, maxOutReg_, Functor_ >::Params::d_d |
template<typename ScalarA_ , typename ScalarC_ , typename ScalarD_ , typename ScalarAlphaBeta_ , typename ScalarAccum_ , int ReductionSize_ = 1, typename OutputTile_ = Shape<1, 1, 128>, typename SubTile_ = Shape<1, 1, 64>, typename ThreadShape_ = Shape<1, 1, 2>, typename Index_ = int, typename BlockSwizzle_ = DefaultBlockSwizzle, int maxInReg_ = 160, int maxOutReg_ = 64, typename Functor_ = typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >>
Functor::Params cutlass::reduction::BatchedReductionTraits< ScalarA_, ScalarC_, ScalarD_, ScalarAlphaBeta_, ScalarAccum_, ReductionSize_, OutputTile_, SubTile_, ThreadShape_, Index_, BlockSwizzle_, maxInReg_, maxOutReg_, Functor_ >::Params::functorParams |
template<typename ScalarA_ , typename ScalarC_ , typename ScalarD_ , typename ScalarAlphaBeta_ , typename ScalarAccum_ , int ReductionSize_ = 1, typename OutputTile_ = Shape<1, 1, 128>, typename SubTile_ = Shape<1, 1, 64>, typename ThreadShape_ = Shape<1, 1, 2>, typename Index_ = int, typename BlockSwizzle_ = DefaultBlockSwizzle, int maxInReg_ = 160, int maxOutReg_ = 64, typename Functor_ = typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >>
Index cutlass::reduction::BatchedReductionTraits< ScalarA_, ScalarC_, ScalarD_, ScalarAlphaBeta_, ScalarAccum_, ReductionSize_, OutputTile_, SubTile_, ThreadShape_, Index_, BlockSwizzle_, maxInReg_, maxOutReg_, Functor_ >::Params::lda |
template<typename ScalarA_ , typename ScalarC_ , typename ScalarD_ , typename ScalarAlphaBeta_ , typename ScalarAccum_ , int ReductionSize_ = 1, typename OutputTile_ = Shape<1, 1, 128>, typename SubTile_ = Shape<1, 1, 64>, typename ThreadShape_ = Shape<1, 1, 2>, typename Index_ = int, typename BlockSwizzle_ = DefaultBlockSwizzle, int maxInReg_ = 160, int maxOutReg_ = 64, typename Functor_ = typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >>
Index cutlass::reduction::BatchedReductionTraits< ScalarA_, ScalarC_, ScalarD_, ScalarAlphaBeta_, ScalarAccum_, ReductionSize_, OutputTile_, SubTile_, ThreadShape_, Index_, BlockSwizzle_, maxInReg_, maxOutReg_, Functor_ >::Params::ldc |
template<typename ScalarA_ , typename ScalarC_ , typename ScalarD_ , typename ScalarAlphaBeta_ , typename ScalarAccum_ , int ReductionSize_ = 1, typename OutputTile_ = Shape<1, 1, 128>, typename SubTile_ = Shape<1, 1, 64>, typename ThreadShape_ = Shape<1, 1, 2>, typename Index_ = int, typename BlockSwizzle_ = DefaultBlockSwizzle, int maxInReg_ = 160, int maxOutReg_ = 64, typename Functor_ = typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >>
Index cutlass::reduction::BatchedReductionTraits< ScalarA_, ScalarC_, ScalarD_, ScalarAlphaBeta_, ScalarAccum_, ReductionSize_, OutputTile_, SubTile_, ThreadShape_, Index_, BlockSwizzle_, maxInReg_, maxOutReg_, Functor_ >::Params::ldd |
template<typename ScalarA_ , typename ScalarC_ , typename ScalarD_ , typename ScalarAlphaBeta_ , typename ScalarAccum_ , int ReductionSize_ = 1, typename OutputTile_ = Shape<1, 1, 128>, typename SubTile_ = Shape<1, 1, 64>, typename ThreadShape_ = Shape<1, 1, 2>, typename Index_ = int, typename BlockSwizzle_ = DefaultBlockSwizzle, int maxInReg_ = 160, int maxOutReg_ = 64, typename Functor_ = typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >>
Coord<3> cutlass::reduction::BatchedReductionTraits< ScalarA_, ScalarC_, ScalarD_, ScalarAlphaBeta_, ScalarAccum_, ReductionSize_, OutputTile_, SubTile_, ThreadShape_, Index_, BlockSwizzle_, maxInReg_, maxOutReg_, Functor_ >::Params::problem_size |
template<typename ScalarA_ , typename ScalarC_ , typename ScalarD_ , typename ScalarAlphaBeta_ , typename ScalarAccum_ , int ReductionSize_ = 1, typename OutputTile_ = Shape<1, 1, 128>, typename SubTile_ = Shape<1, 1, 64>, typename ThreadShape_ = Shape<1, 1, 2>, typename Index_ = int, typename BlockSwizzle_ = DefaultBlockSwizzle, int maxInReg_ = 160, int maxOutReg_ = 64, typename Functor_ = typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >>
long long int cutlass::reduction::BatchedReductionTraits< ScalarA_, ScalarC_, ScalarD_, ScalarAlphaBeta_, ScalarAccum_, ReductionSize_, OutputTile_, SubTile_, ThreadShape_, Index_, BlockSwizzle_, maxInReg_, maxOutReg_, Functor_ >::Params::reduction_stride |
The documentation for this struct was generated from the following file: