CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
gemv_batched_strided.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 * * Redistributions of source code must retain the above copyright notice, this list of
 * conditions and the following disclaimer.
 * * Redistributions in binary form must reproduce the above copyright notice, this list of
 * conditions and the following disclaimer in the documentation and/or other materials
 * provided with the distribution.
 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 * to endorse or promote products derived from this software without specific prior written
 * permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"

#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/gemm/gemm.h"

namespace cutlass {
namespace gemm {
namespace kernel {

namespace detail
{
  // Scales accumulators by alpha and, unless BetaIsZero, adds beta * C before writing D.
  template<typename ElementAlphaBeta, bool BetaIsZero>
  struct GemvBatchedStridedEpilogueScaling
  {
    ElementAlphaBeta const & alpha;
    ElementAlphaBeta const & beta;

    CUTLASS_DEVICE
    GemvBatchedStridedEpilogueScaling(ElementAlphaBeta& alpha_, ElementAlphaBeta& beta_) :
      alpha(alpha_), beta(beta_)
    { }

    template<typename FragmentCD, typename FragmentAccumulator>
    CUTLASS_DEVICE
    void operator()(FragmentAccumulator& accumulators,
                    FragmentCD const& fragment_C,
                    FragmentCD& fragment_D) const
    {
      using AccType = typename FragmentAccumulator::value_type;
      using CDType = typename FragmentCD::value_type;

      static_assert(FragmentCD::kElements == FragmentAccumulator::kElements,
                    "Mismatch in fragment sizes.");

      for (int i = 0; i < FragmentCD::kElements; ++i)
      {
        if (BetaIsZero)
        {
          fragment_D[i] = CDType(accumulators[i] * AccType(alpha));
        }
        else
        {
          fragment_D[i] = CDType(accumulators[i] * AccType(alpha)
                                 + AccType(fragment_C[i]) * AccType(beta));
        }
      }
    }
  };
} // namespace detail
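
// Per batch entry b, GemvBatchedStridedDevice computes the row-vector GEMV
//
//   D[b] = alpha * (A[b] * B[b]) + beta * C[b]
//
// where A[b] is consumed as a 1 x K operand, B[b] as a K x N operand, and C[b] / D[b] as
// 1 x N operands. The lda/ldb/ldc/ldd arguments are batch strides (elements between
// consecutive batch entries); the per-operand leading dimensions are carried by the
// TensorRef layouts. Each threadblock processes one batch entry and one N-tile, as
// selected by GemvKernel::ThreadBlockSwizzle.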

template <typename GemvKernel, typename ElementAlphaBeta, bool BetaIsZero=false>
CUTLASS_DEVICE void GemvBatchedStridedDevice(
  cutlass::gemm::BatchedGemmCoord problem_size,
  ElementAlphaBeta alpha,
  ElementAlphaBeta beta,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_C,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldc,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  using ThreadBlockGemv = typename GemvKernel::ThreadBlockGemv;
  using ThreadBlockSwizzle = typename GemvKernel::ThreadBlockSwizzle;
  using EpilogueScale = detail::GemvBatchedStridedEpilogueScaling<ElementAlphaBeta, BetaIsZero>;

  ThreadBlockSwizzle swizzler;

  // Compute initial location in logical coordinates
  BatchedGemmCoord tb_offset = swizzler.get_tile_offset();
  int const batch_idx = swizzler.get_batch_idx();

  // Offset to the batch
  ref_A.add_pointer_offset(batch_idx*lda);
  ref_B.add_pointer_offset(batch_idx*ldb);

  // Construct iterators to A and B operands
  typename GemvKernel::IteratorA::Params params_A(ref_A.layout());
  typename GemvKernel::IteratorA iterator_A(
    params_A,
    ref_A.data(),
    { 1, problem_size.k() },
    0,
    { 0, 0 });

  typename GemvKernel::IteratorB::Params params_B(ref_B.layout());
  typename GemvKernel::IteratorB iterator_B(
    params_B,
    ref_B.data(),
    { problem_size.k(), problem_size.n() },
    threadIdx.x,
    { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN });

  //
  // Main loop
  //

  // Construct thread-scoped matrix multiply
  ThreadBlockGemv mma;

  typename ThreadBlockGemv::FragmentC accumulators;
  accumulators.clear();

  // Compute threadblock-scoped gemv
  mma(problem_size.mnk(), accumulators, iterator_A, iterator_B, accumulators);

  //
  // Epilogue (TODO: Epilogue as template argument)
  //
  typename GemvKernel::FragmentCD fragment_CD;

  // Load C (skip if beta is zero)
  if (!BetaIsZero)
  {
    tb_offset = swizzler.get_tile_offset();
    ref_C.add_pointer_offset(batch_idx*ldc);
    typename GemvKernel::IteratorCD::Params params_C(ref_C.layout());
    typename GemvKernel::IteratorCD iterator_C(
      params_C,
      ref_C.data(),
      { 1, problem_size.n() },
      threadIdx.x,
      { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN });
    iterator_C.load(fragment_CD);
  }

  // Apply alpha/beta scaling
  EpilogueScale epilogue_scale(alpha, beta);
  epilogue_scale(accumulators, fragment_CD, fragment_CD);

  // Store D
  tb_offset = swizzler.get_tile_offset();
  ref_D.add_pointer_offset(batch_idx*ldd);
  typename GemvKernel::IteratorCD::Params params_D(ref_D.layout());
  typename GemvKernel::IteratorCD iterator_D(
    params_D,
    ref_D.data(),
    { 1, problem_size.n() },
    threadIdx.x,
    { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN });
  iterator_D.store(fragment_CD);
}

template <typename GemvKernel, typename ElementAlphaBeta, bool BetaIsZero>
__global__ void GemvBatchedStrided(
  cutlass::gemm::BatchedGemmCoord problem_size,
  ElementAlphaBeta alpha,
  ElementAlphaBeta beta,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_C,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldc,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  GemvBatchedStridedDevice<GemvKernel, ElementAlphaBeta, BetaIsZero>(
    problem_size, alpha, beta, ref_A, lda, ref_B, ldb, ref_C, ldc, ref_D, ldd
  );
}

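// Overload for beta == 0: the BetaIsZero epilogue path never reads C, so ref_D / ldd are
// forwarded in place of the C operand.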
template <typename GemvKernel, typename ElementAlphaBeta>
__global__ void GemvBatchedStrided(
  cutlass::gemm::BatchedGemmCoord problem_size,
  ElementAlphaBeta alpha,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  GemvBatchedStridedDevice<GemvKernel, ElementAlphaBeta, true>(
    problem_size, alpha, ElementAlphaBeta(0), ref_A, lda, ref_B, ldb, ref_D, ldd, ref_D, ldd
  );
}

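// Overload with alpha fixed to 1 and beta fixed to 0; the scalar type is taken from the
// C/D element type.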
template <typename GemvKernel>
__global__ void GemvBatchedStrided(
  cutlass::gemm::BatchedGemmCoord problem_size,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  using ElementAlphaBeta = typename GemvKernel::IteratorCD::Element;
  GemvBatchedStridedDevice<GemvKernel, ElementAlphaBeta, true>(
    problem_size, ElementAlphaBeta(1), ElementAlphaBeta(0), ref_A, lda, ref_B, ldb, ref_D, ldd, ref_D, ldd
  );
}

} // namespace kernel
} // namespace gemm
} // namespace cutlass
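
For reference, the per-batch result produced by these kernels can be checked against a naive host-side loop. The sketch below is illustrative only and not part of this header: it assumes row-major operands, and it introduces a hypothetical helper name (reference_gemv_batched_strided) plus an explicit leading dimension ldb_matrix for B; in the kernels above that per-matrix stride is carried by ref_B's layout, while lda/ldb/ldc/ldd are batch strides.

#include <cstdint>

// Illustrative reference (not part of CUTLASS): computes, for each batch b,
//   D[b] = alpha * (A[b] * B[b]) + beta * C[b]
// with A[b] a 1 x K row vector, B[b] a K x N row-major matrix with leading dimension
// ldb_matrix, and C[b] / D[b] 1 x N row vectors. The batch_stride_* parameters mirror the
// lda/ldb/ldc/ldd batch strides passed to GemvBatchedStrided.
template <typename Element>
void reference_gemv_batched_strided(
    int K, int N, int batch_count,
    Element alpha, Element beta,
    Element const *A, int64_t batch_stride_A,
    Element const *B, int64_t ldb_matrix, int64_t batch_stride_B,
    Element const *C, int64_t batch_stride_C,
    Element *D, int64_t batch_stride_D) {

  for (int b = 0; b < batch_count; ++b) {
    for (int n = 0; n < N; ++n) {
      Element acc = Element(0);
      for (int k = 0; k < K; ++k) {
        acc += A[b * batch_stride_A + k] * B[b * batch_stride_B + k * ldb_matrix + n];
      }
      D[b * batch_stride_D + n] = alpha * acc + beta * C[b * batch_stride_C + n];
    }
  }
}

A launch of one of the GemvBatchedStrided kernels above, with the grid and block configuration dictated by GemvKernel::ThreadBlockSwizzle and the threadblock shape of GemvKernel::ThreadBlockGemv, should reproduce this loop elementwise up to floating-point accumulation-order differences.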