64 typename AccumulatorType,
76 AccumulatorType initial_accum) {
79 LayoutA::kRank == 2 &&
80 LayoutB::kRank == 2 &&
81 LayoutC::kRank == 2,
"Tensors must be of rank 2");
92 (problem_size.
m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
93 (problem_size.
n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn)
106 ><<< grid, block >>>(
131 typename AccumulatorType,
142 AccumulatorType initial_accum) {
144 compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
145 ScalarType, AccumulatorType, InnerProductOp, ConvertOp>(
146 problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
158 typename AccumulatorType,
159 typename InnerProductOp = cutlass::arch::OpMultiplyAdd
166 template <
typename ElementA,
typename LayoutA,
typename ElementB,
167 typename LayoutB,
typename ElementC,
typename LayoutC,
168 typename ScalarType,
typename AccumulatorType>
169 struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
170 ScalarType, AccumulatorType, arch::OpMultiplyAdd> {
176 AccumulatorType initial_accum = AccumulatorType(0)) {
179 LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
180 "Tensors must be of rank 2");
182 compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
184 problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
192 AccumulatorType initial_accum = AccumulatorType(0)) {
194 LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
195 "Tensors must be of rank 2");
197 compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
199 problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
206 template <
typename ElementA,
typename LayoutA,
typename ElementB,
207 typename LayoutB,
typename ElementC,
typename LayoutC,
208 typename ScalarType,
typename AccumulatorType>
209 struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
210 AccumulatorType, arch::OpMultiplyAddSaturate> {
216 AccumulatorType initial_accum = AccumulatorType(0)) {
218 LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
219 "Tensors must be of rank 2");
221 compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
222 ScalarType, AccumulatorType, multiply_add<AccumulatorType>,
224 problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
232 AccumulatorType initial_accum = AccumulatorType(0)) {
234 LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
235 "Tensors must be of rank 2");
237 compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
238 ScalarType, AccumulatorType, multiply_add<AccumulatorType>,
240 problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
247 template <
typename ElementA,
typename LayoutA,
typename ElementB,
248 typename LayoutB,
typename ElementC,
typename LayoutC,
249 typename ScalarType,
typename AccumulatorType>
250 struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
251 AccumulatorType, arch::OpXorPopc> {
257 AccumulatorType initial_accum = AccumulatorType(0)) {
259 LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
260 "Tensors must be of rank 2");
262 compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
264 problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
272 AccumulatorType initial_accum = AccumulatorType(0)) {
274 LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
275 "Tensors must be of rank 2");
277 compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
279 problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
295 typename TensorRefCollectionA,
296 typename TensorRefCollectionB,
297 typename TensorRefCollectionC,
299 typename AccumulatorType,
300 typename InnerProductOp,
307 TensorRefCollectionA
const& tensor_a,
308 TensorRefCollectionB
const& tensor_b,
310 TensorRefCollectionC &tensor_c,
311 AccumulatorType initial_accum) {
314 TensorRefCollectionA::kRank == 2 &&
315 TensorRefCollectionB::kRank == 2 &&
316 TensorRefCollectionC::kRank == 2,
"Tensors must be of rank 2");
326 (problem_size.
m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
327 (problem_size.
n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn),
333 TensorRefCollectionA,
334 TensorRefCollectionB,
335 TensorRefCollectionC,
341 ><<< grid, block >>>(
358 typename TensorRefCollectionA,
359 typename TensorRefCollectionB,
360 typename TensorRefCollectionC,
362 typename AccumulatorType
368 TensorRefCollectionA
const& tensor_a,
369 TensorRefCollectionB
const& tensor_b,
371 TensorRefCollectionC &tensor_c) {
373 BatchedGemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
Fused multiply-add.
Definition: functional.h:92
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, TensorRef< ElementC, LayoutC > tensor_d, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:267
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
Definition: aligned_buffer.h:35
Definition: numeric_conversion.h:254
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Definition: include/cutlass/gemm/gemm.h:94
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
Defines a structure containing strides and a pointer to tensor data.
__global__ void BatchedGemm(gemm::GemmCoord problem_size, ScalarType alpha, TensorRefCollectionA tensor_collection_a, TensorRefCollectionB tensor_collection_b, ScalarType beta, TensorRefCollectionC tensor_collection_c, AccumulatorType initial_accum)
Definition: tools/util/include/cutlass/util/reference/device/kernel/gemm.h:108
Boost-like numeric conversion operator for CUTLASS numeric types.
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:161
void BatchedGemm(gemm::GemmCoord problem_size, int batch_count, ScalarType alpha, TensorRefCollectionA const &tensor_a, TensorRefCollectionB const &tensor_b, ScalarType beta, TensorRefCollectionC &tensor_c, AccumulatorType initial_accum)
Computes a batch of GEMMs over a set of matrices of common dimension.
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:303
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:212
Top-level include for all CUTLASS numeric types.
Definition: numeric_conversion.h:59
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, TensorRef< ElementC, LayoutC > tensor_d, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:227
Fused multiply-add.
Definition: functional.h:101
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
void compute_gemm(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, TensorRef< ElementC, LayoutC > tensor_d, AccumulatorType initial_accum)
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:68
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, TensorRef< ElementC, LayoutC > tensor_d, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:187
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:253
Defines properties of matrices used to denote layout and operands to GEMM kernels.
__global__ void Gemm(gemm::GemmCoord problem_size, ScalarType alpha, TensorRefA tensor_a, TensorRefB tensor_b, ScalarType beta, TensorRefC tensor_c, TensorRefC tensor_d, AccumulatorType initial_accum)
Definition: tools/util/include/cutlass/util/reference/device/kernel/gemm.h:57
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, AccumulatorType initial_accum=AccumulatorType(0))
Definition: tools/util/include/cutlass/util/reference/device/gemm.h:172