CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
batched_reduction_traits.h
Go to the documentation of this file.
1 /***************************************************************************************************
2 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without modification, are permitted
5 * provided that the following conditions are met:
6 * * Redistributions of source code must retain the above copyright notice, this list of
7 * conditions and the following disclaimer.
8 * * Redistributions in binary form must reproduce the above copyright notice, this list of
9 * conditions and the following disclaimer in the documentation and/or other materials
10 * provided with the distribution.
11 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12 * to endorse or promote products derived from this software without specific prior written
13 * permission.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 *
24 **************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/shape.h"
#include "cutlass/reduction/threadblock_swizzle.h"
#include "cutlass/reduction/batched_reduction.h"
#include "cutlass/gemm/linear_scaling.h"
36 namespace cutlass {
37 namespace reduction {
38 
39 /*
40 OutputTile defines the work load per thread block
41 Subtile defines the work load per thread block per iteration
42 OutputTile / Subtile = number of iterations within a kernel
43 ThreadShape defines the work load per thread
44 Subtile / ThreadShape = number of threads per thread block
45 */
46 template <
48  typename ScalarA_,
50  typename ScalarC_,
52  typename ScalarD_,
54  typename ScalarAlphaBeta_,
56  typename ScalarAccum_,
58  int ReductionSize_ = 1,
60  typename OutputTile_ = Shape<1, 1, 128>,
62  typename SubTile_ = Shape<1, 1, 64>,
64  typename ThreadShape_ = Shape<1, 1, 2>,
66  typename Index_ = int,
68  typename BlockSwizzle_ = DefaultBlockSwizzle,
70  int maxInReg_ = 160,
72  int maxOutReg_ = 64,
74  typename Functor_ = typename cutlass::gemm::LinearScaling<ScalarAlphaBeta_, typename cutlass::gemm::FragmentMultiplyAdd<ScalarAlphaBeta_, ScalarAccum_, (ThreadShape_::kW % 2 == 0)> >
75 >
78  typedef BatchedReductionTraits<ScalarA_,
79  ScalarC_,
80  ScalarD_,
81  ScalarAlphaBeta_,
82  ScalarAccum_,
83  ReductionSize_,
84  OutputTile_,
85  SubTile_,
86  ThreadShape_,
87  Index_,
88  BlockSwizzle_,
89  maxInReg_,
90  maxOutReg_,
91  Functor_> This_;
95  typedef OutputTile_ OutputTile;
97  typedef SubTile_ SubTile;
99  typedef ThreadShape_ ThreadShape;
101  typedef ScalarA_ ScalarA;
103  typedef ScalarC_ ScalarC;
105  typedef ScalarD_ ScalarD;
107  typedef ScalarAlphaBeta_ ScalarAlphaBeta;
109  typedef ScalarAccum_ ScalarAccum;
111  typedef Index_ Index;
113  typedef BlockSwizzle_ BlockSwizzle;
115  static const int ReductionSize = ReductionSize_;
117  static const bool ThreadShapeMultiple2 = (ThreadShape::kW % 2 == 0);
119  typedef Functor_ Functor;
122  static int const kThreads = SubTile::kW / ThreadShape::kW;
123  //
124  static int const maxInReg = maxInReg_;
125  //
126  static int const maxOutReg = maxOutReg_;
127  //
128  static_assert(SubTile::kW % ThreadShape::kW == 0, "cannot evenly distribute work load among threads");
129  //
130  static_assert(kThreads % 32 == 0, "threads per threadblock is not multiple of 32");
131  //
132  static_assert(OutputTile::kW % SubTile::kW == 0, "cannot evenly distribute work load among iterations");
133  //
134  static_assert(ReductionSize * ThreadShape::kW <= maxInReg, "ReductionSize * ThreadShape::kW should not be bigger than maxInReg");
135  //
136  static_assert(ThreadShape::kW <= maxOutReg, "ThreadShape::kW should not be bigger than maxOutReg");
137 
138  struct Params {
142  ScalarAlphaBeta alpha;
144  ScalarAlphaBeta beta;
146  long long int reduction_stride;
147  //
148  ScalarA const *d_a;
149  //
150  Index lda;
151  //
152  ScalarC const *d_c;
153  //
154  Index ldc;
155  //
156  ScalarD *d_d;
157  //
158  Index ldd;
160  typename Functor::Params functorParams;
163  Index n_,
164  ScalarAlphaBeta alpha_,
165  ScalarAlphaBeta beta_,
166  long long int reduction_stride_,
167  ScalarA const *d_a_,
168  Index lda_,
169  ScalarC const *d_c_,
170  Index ldc_,
171  ScalarD *d_d_,
172  Index ldd_){
173  problem_size = make_Coord(1, n_, m_);
174  alpha = alpha_;
175  beta = beta_;
176  reduction_stride = reduction_stride_;
177  d_a = d_a_;
178  lda = lda_;
179  d_c = d_c_;
180  d_d = d_d_;
181  ldc = ldc_;
182  ldd = ldd_;
183 
184  functorParams.initialize(alpha_, beta_);
185 
186  return 0;
187  }
188  };
189 
190 };
191 } // namespace reduction
192 } // namespace cutlass
Coord< 3 > problem_size
The dimension of output tensor.
Definition: batched_reduction_traits.h:140
Definition: aligned_buffer.h:35
Definition: batched_reduction_traits.h:138
BlockSwizzle_ BlockSwizzle
The thread block swizzle.
Definition: batched_reduction_traits.h:113
BatchedReductionTraits< ScalarA_, ScalarC_, ScalarD_, ScalarAlphaBeta_, ScalarAccum_, ReductionSize_, OutputTile_, SubTile_, ThreadShape_, Index_, BlockSwizzle_, maxInReg_, maxOutReg_, Functor_ > This_
Definition: batched_reduction_traits.h:91
static int const kThreads
Definition: batched_reduction_traits.h:122
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 2-element coordinate.
Definition: coord.h:387
ScalarAccum_ ScalarAccum
The type for accumulation.
Definition: batched_reduction_traits.h:109
Index lda
Definition: batched_reduction_traits.h:150
ScalarAlphaBeta beta
The beta.
Definition: batched_reduction_traits.h:144
Defines functors for mapping blockIdx to partitions of the batched reduction computation.
ThreadShape_ ThreadShape
Definition: batched_reduction_traits.h:99
Index ldd
Definition: batched_reduction_traits.h:158
ScalarD_ ScalarD
The output pointer type.
Definition: batched_reduction_traits.h:105
long long int reduction_stride
stride between two elements that will be summed
Definition: batched_reduction_traits.h:146
SubTile_ SubTile
Definition: batched_reduction_traits.h:97
ScalarC const * d_c
Definition: batched_reduction_traits.h:152
OutputTile_ OutputTile
Definition: batched_reduction_traits.h:95
Index ldc
Definition: batched_reduction_traits.h:154
ScalarAlphaBeta_ ScalarAlphaBeta
The alpha beta type.
Definition: batched_reduction_traits.h:107
ScalarC_ ScalarC
Definition: batched_reduction_traits.h:103
ScalarA const * d_a
Definition: batched_reduction_traits.h:148
Definition: batched_reduction.h:52
static const int ReductionSize
Definition: batched_reduction_traits.h:115
ScalarA_ ScalarA
The input pointer type.
Definition: batched_reduction_traits.h:101
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
#define static_assert(__e, __m)
Definition: platform.h:153
ScalarAlphaBeta alpha
The alpha.
Definition: batched_reduction_traits.h:142
Functor_ Functor
Definition: batched_reduction_traits.h:119
static int const maxInReg
Definition: batched_reduction_traits.h:124
ScalarD * d_d
Definition: batched_reduction_traits.h:156
Functor::Params functorParams
The functor params.
Definition: batched_reduction_traits.h:160
static const bool ThreadShapeMultiple2
check if threadShape is multiple of 2.
Definition: batched_reduction_traits.h:117
Index_ Index
The index.
Definition: batched_reduction_traits.h:111
static int const maxOutReg
Definition: batched_reduction_traits.h:126
Implements a software-pipelined efficient batched reduction. D = alpha * Reduction(A) + beta * C...
Basic include for CUTLASS.
Definition: batched_reduction_traits.h:76
CUTLASS_HOST_DEVICE int initialize(Index m_, Index n_, ScalarAlphaBeta alpha_, ScalarAlphaBeta beta_, long long int reduction_stride_, ScalarA const *d_a_, Index lda_, ScalarC const *d_c_, Index ldc_, ScalarD *d_d_, Index ldd_)
Initialize the parameters for 2D output tensor.
Definition: batched_reduction_traits.h:162
cutlass::reduction::BatchedReduction< This_ > KernelClass
The struct that consumes this Traits.
Definition: batched_reduction_traits.h:93