#if !defined(__CUDACC_RTC__)
// (host-only includes were elided in this extraction -- restore from upstream)
#endif
#include "cutlass/util/platform.h"
#include "cutlass/fragment.h"

////////////////////////////////////////////////////////////////////////////////

/// Batched-reduction kernel entry point.
///
/// Builds the device-side reduction object from the launch parameters and
/// delegates all work to its run() method.  __launch_bounds__ caps the block
/// at Traits::kThreads threads and requests at least 1 resident block per SM.
template <typename batched_reduction_>
__global__ __launch_bounds__(batched_reduction_::Traits::kThreads, 1)
void batched_reduction_kernel(typename batched_reduction_::Params params) {
  // Construct the reduction object, then execute it.
  batched_reduction_ op(params);
  op.run();
}
/// Device-side batched reduction: for each output tile, sums ReductionSize
/// partial inputs from d_a, combines the result with d_c through the traits'
/// functor, and writes d_d (see run()).
template <typename BatchedReductionTraits_>
struct BatchedReduction {
  /// This class (used by launch() to instantiate the kernel template).
  typedef BatchedReduction<BatchedReductionTraits_> This_;
  /// The traits bundle: shapes, scalar types, block swizzle, and functor.
  typedef BatchedReductionTraits_ Traits;
  /// Kernel parameter structure, defined by the traits.
  typedef typename Traits::Params Params;
  /// Elementwise epilogue functor type (member `functor` is of this type).
  typedef typename Traits::Functor Functor;
68 CUTLASS_DEVICE
void run() {
69 #if (__CUDA_ARCH__ >= 600) 71 typename Traits::BlockSwizzle block_swizzle;
73 block_swizzle.get_threadblock_offset(make_Coord_from_shape<Traits::SubTile>());
75 int subTileSize = gridDim.x * Traits::SubTile::kW;
76 int tileSize =
params.problem_size[1] *
params.problem_size[2];
77 int subTileOffset = threadblock_offset[2] + threadIdx.x * Traits::ThreadShape::kW;
81 typename Traits::ScalarA inRegs[Traits::maxInReg];
82 typename Traits::ScalarAccum AccumRegs[Traits::maxOutReg];
84 for (
int subTile = 0; subTile < tileSize; subTile += subTileSize) {
85 int tileOffset = subTileBase + subTileOffset;
88 for (
int i = 0; i < Traits::ThreadShape::kW; i++)
89 AccumRegs[i] = static_cast<typename Traits::ScalarAccum>(0.0f);
91 typename Traits::ScalarAccum c0[Traits::ThreadShape::kW];
93 for (
int i = 0; i< Traits::ThreadShape::kW; i++)
94 c0[i] = static_cast<typename Traits::ScalarAccum>(
params.d_c[tileOffset + i]);
98 for (
int s = 0; s < Traits::ReductionSize; s++) {
99 int inRegOffset = s * Traits::ThreadShape::kW;
100 int dOffset = (s * tileSize) + tileOffset;
102 for (
int i = 0; i< Traits::ThreadShape::kW; i++) {
103 inRegs[inRegOffset + i] =
params.d_a[dOffset + i];
109 for (
int s = 0; s < Traits::ReductionSize; s++) {
110 int inRegOffset = s * Traits::ThreadShape::kW;
112 for (
int i = 0; i < Traits::ThreadShape::kW; i++) {
115 AccumRegs[i] =
static_cast<typename Traits::ScalarAccum
>(inRegs[inRegOffset + i]) + AccumRegs[i];
119 functor_caller<Traits::ThreadShapeMultiple2>(AccumRegs, c0, AccumRegs);
123 for (
int i = 0; i < Traits::ThreadShape::kW; i++) {
124 params.d_d[tileOffset + i] =
static_cast<typename Traits::ScalarD
>(AccumRegs[i]);
128 subTileBase += subTileSize;
130 #endif //#if (__CUDA_ARCH__ >= 600) 133 template<
bool ThreadShapeMultiple2>
134 CUTLASS_DEVICE
void functor_caller(
typename Traits::ScalarAccum
const *accum,
typename Traits::ScalarAccum
const *old,
typename Traits::ScalarAccum *output) {
135 if (ThreadShapeMultiple2 ==
true) {
137 for (
int i = 0; i < Traits::ThreadShape::kW / 2; i++) {
138 functor.template evaluate<typename Traits::ScalarAccum, typename Traits::ScalarAccum, 2>(&accum[2 * i], &old[2 * i], &output[2 * i]);
143 for (
int i = 0; i < Traits::ThreadShape::kW; i++) {
144 functor.template evaluate<typename Traits::ScalarAccum, typename Traits::ScalarAccum, 1>(&accum[i], &old[i], &output[i]);
#if !defined(__CUDACC_RTC__)
  /// Host-side launcher: derives the grid/block shape from the traits and
  /// launches batched_reduction_kernel on the given stream.
  ///
  /// @param params  kernel parameters (problem size, device pointers, ...)
  /// @param stream  CUDA stream for the launch (legacy default stream if omitted)
  /// @return status from cudaGetLastError(); the launch itself is asynchronous,
  ///         so this reports configuration errors only, not execution errors.
  static __host__ cudaError_t launch(Params const &params,
                                     cudaStream_t stream = cudaStreamDefault) {
    // Compute the grid layout via the block swizzle.
    typename Traits::BlockSwizzle block_swizzle;
    dim3 grid = block_swizzle.get_grid_layout(
        params.problem_size, make_Coord_from_shape<typename Traits::OutputTile>());

    dim3 block;  // NOTE(review): declaration restored -- extraction dropped it
    block.x = Traits::kThreads;

    batched_reduction_kernel<This_><<<grid, block, 0, stream>>>(params);
    return cudaGetLastError();
  }
#endif // !defined(__CUDACC_RTC__)
Definition: aligned_buffer.h:35
Params const & params
The params.
Definition: batched_reduction.h:173
__global__ __launch_bounds__(batched_reduction_::Traits::kThreads, 1) void batched_reduction_kernel(typename batched_reduction_
Definition: batched_reduction.h:45
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_DEVICE void run()
Definition: batched_reduction.h:68
BatchedReduction< BatchedReductionTraits_ > This_
This class.
Definition: batched_reduction.h:54
Functor functor
Definition: batched_reduction.h:175
Definition: batched_reduction.h:52
CUTLASS_DEVICE BatchedReduction(Params const &params_)
ctor
Definition: batched_reduction.h:63
Traits::Params Params
Params.
Definition: batched_reduction.h:58
CUTLASS_DEVICE void functor_caller(typename Traits::ScalarAccum const *accum, typename Traits::ScalarAccum const *old, typename Traits::ScalarAccum *output)
Definition: batched_reduction.h:134
Traits::Functor Functor
functor
Definition: batched_reduction.h:60
BatchedReductionTraits_ Traits
The traits.
Definition: batched_reduction.h:56
static __host__ cudaError_t launch(Params const &params, cudaStream_t stream=cudaStreamDefault)
Launch the kernel.
Definition: batched_reduction.h:154