namespace detail
{
  template<typename ElementAlphaBeta, bool BetaIsZero>
  struct GemvBatchedStridedEpilogueScaling
  {
    ElementAlphaBeta const & alpha;
    ElementAlphaBeta const & beta;

    CUTLASS_DEVICE
    GemvBatchedStridedEpilogueScaling(ElementAlphaBeta& alpha_, ElementAlphaBeta& beta_) :
      alpha(alpha_), beta(beta_)
    { }

    template<typename FragmentCD, typename FragmentAccumulator>
    CUTLASS_DEVICE
    void operator()(FragmentAccumulator& accumulators,
                    FragmentCD const& fragment_C,
                    FragmentCD& fragment_D) const
    {
      using AccType = typename FragmentAccumulator::value_type;
      using CDType = typename FragmentCD::value_type;

      static_assert(FragmentCD::kElements == FragmentAccumulator::kElements,
                    "Mismatch in fragment sizes.");

      for (int i = 0; i < FragmentCD::kElements; ++i)
      {
        if (BetaIsZero)
        {
          fragment_D[i] = CDType(accumulators[i] * AccType(alpha));
        }
        else
        {
          fragment_D[i] = CDType(accumulators[i] * AccType(alpha)
                                 + AccType(fragment_C[i]) * AccType(beta));
        }
      }
    }
  };
} // namespace detail
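// Usage sketch (an illustration only, assuming fragment and scalar types
// supplied by an enclosing kernel; not code from this header):
//
//   ElementAlphaBeta alpha = ElementAlphaBeta(2);
//   ElementAlphaBeta beta  = ElementAlphaBeta(1);
//   detail::GemvBatchedStridedEpilogueScaling<ElementAlphaBeta, false> scale(alpha, beta);
//   // Element-wise: fragment_D[i] = alpha * accumulators[i] + beta * fragment_C[i]
//   scale(accumulators, fragment_C, fragment_D);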
template <typename GemvKernel, typename ElementAlphaBeta, bool BetaIsZero = false>
CUTLASS_DEVICE void GemvBatchedStridedDevice(
  cutlass::gemm::BatchedGemmCoord problem_size,
  ElementAlphaBeta alpha,
  ElementAlphaBeta beta,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_C,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldc,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  using ThreadBlockGemv = typename GemvKernel::ThreadBlockGemv;
  using ThreadBlockSwizzle = typename GemvKernel::ThreadBlockSwizzle;
  using EpilogueScale = detail::GemvBatchedStridedEpilogueScaling<ElementAlphaBeta, BetaIsZero>;

  ThreadBlockSwizzle swizzler;

  // Compute initial location in logical coordinates
  BatchedGemmCoord tb_offset = swizzler.get_tile_offset();
  int const batch_idx = swizzler.get_batch_idx();

  // Offset to the batch
  ref_A.add_pointer_offset(batch_idx*lda);
  ref_B.add_pointer_offset(batch_idx*ldb);

  // Construct iterators to A and B operands
  typename GemvKernel::IteratorA::Params params_A(ref_A.layout());
  typename GemvKernel::IteratorA iterator_A(
    params_A,
    ref_A.data(),
    { 1, problem_size.k() },
    0,
    { 0, 0 });

  typename GemvKernel::IteratorB::Params params_B(ref_B.layout());
  typename GemvKernel::IteratorB iterator_B(
    params_B,
    ref_B.data(),
    { problem_size.k(), problem_size.n() },
    threadIdx.x,
    { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN });

  //
  // Main loop
  //

  // Construct thread-scoped matrix multiply
  ThreadBlockGemv mma;

  typename ThreadBlockGemv::FragmentC accumulators;
  accumulators.clear();

  // Compute the threadblock-scoped GEMV
  mma(problem_size.mnk(), accumulators, iterator_A, iterator_B, accumulators);

  //
  // Epilogue
  //
  typename GemvKernel::FragmentCD fragment_CD;

  // Load C (skip if beta is zero)
  if (!BetaIsZero)
  {
    tb_offset = swizzler.get_tile_offset();
    ref_C.add_pointer_offset(batch_idx*ldc);
    typename GemvKernel::IteratorCD::Params params_C(ref_C.layout());
    typename GemvKernel::IteratorCD iterator_C(
      params_C,
      ref_C.data(),
      { 1, problem_size.n() },
      threadIdx.x,
      { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN });
    iterator_C.load(fragment_CD);
  }

  // Apply alpha/beta scaling
  EpilogueScale epilogue_scale(alpha, beta);
  epilogue_scale(accumulators, fragment_CD, fragment_CD);

  // Store D
  tb_offset = swizzler.get_tile_offset();
  ref_D.add_pointer_offset(batch_idx*ldd);
  typename GemvKernel::IteratorCD::Params params_D(ref_D.layout());
  typename GemvKernel::IteratorCD iterator_D(
    params_D,
    ref_D.data(),
    { 1, problem_size.n() },
    threadIdx.x,
    { 0, tb_offset.n()*ThreadBlockGemv::Shape::kN });
  iterator_D.store(fragment_CD);
}
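// Note on the mapping (a reading of the code above; the exact grid shape is
// owned by GemvKernel::ThreadBlockSwizzle): each threadblock computes one
// tile of ThreadBlockGemv::Shape::kN columns of the output row vector D for
// a single batch index, treating the GEMV as a (1 x N x K) GEMM per batch.
// Accordingly, lda, ldb, ldc and ldd act as batch strides applied via
// add_pointer_offset(), not as leading dimensions within a matrix.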
template <typename GemvKernel, typename ElementAlphaBeta, bool BetaIsZero>
__global__ void GemvBatchedStrided(
  cutlass::gemm::BatchedGemmCoord problem_size,
  ElementAlphaBeta alpha,
  ElementAlphaBeta beta,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_C,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldc,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  GemvBatchedStridedDevice<GemvKernel, ElementAlphaBeta, BetaIsZero>(
    problem_size, alpha, beta, ref_A, lda, ref_B, ldb, ref_C, ldc, ref_D, ldd
  );
}
template <typename GemvKernel, typename ElementAlphaBeta>
__global__ void GemvBatchedStrided(
  cutlass::gemm::BatchedGemmCoord problem_size,
  ElementAlphaBeta alpha,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  GemvBatchedStridedDevice<GemvKernel, ElementAlphaBeta, true>(
    problem_size, alpha, ElementAlphaBeta(0), ref_A, lda, ref_B, ldb, ref_D, ldd, ref_D, ldd
  );
}
template <typename GemvKernel>
__global__ void GemvBatchedStrided(
  cutlass::gemm::BatchedGemmCoord problem_size,
  typename GemvKernel::IteratorA::TensorRef ref_A,
  typename GemvKernel::IteratorA::TensorRef::LongIndex lda,
  typename GemvKernel::IteratorB::TensorRef ref_B,
  typename GemvKernel::IteratorB::TensorRef::LongIndex ldb,
  typename GemvKernel::IteratorCD::TensorRef ref_D,
  typename GemvKernel::IteratorCD::TensorRef::LongIndex ldd)
{
  using ElementAlphaBeta = typename GemvKernel::IteratorCD::Element;
  GemvBatchedStridedDevice<GemvKernel, ElementAlphaBeta, true>(
    problem_size, ElementAlphaBeta(1), ElementAlphaBeta(0), ref_A, lda, ref_B, ldb, ref_D, ldd, ref_D, ldd
  );
}
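// Example host-side launch (a sketch under assumptions: the kernel type,
// tile size TileN, thread count kThreadsPerBlock and grid geometry below are
// placeholders for illustration, not APIs confirmed by this header):
//
//   using GemvKernel = /* a kernel type supplying IteratorA/B/CD,
//                         ThreadBlockGemv and ThreadBlockSwizzle */;
//   cutlass::gemm::BatchedGemmCoord problem_size(1, n, k, batch_count);
//
//   // One threadblock per (column tile, batch); the real mapping is owned
//   // by GemvKernel::ThreadBlockSwizzle.
//   dim3 grid((n + TileN - 1) / TileN, 1, batch_count);
//   dim3 block(kThreadsPerBlock, 1, 1);
//
//   // alpha = 1, beta = 0 overload: ref_D serves as both C and D.
//   GemvBatchedStrided<GemvKernel><<<grid, block>>>(
//       problem_size, ref_A, lda, ref_B, ldb, ref_D, ldd);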