69 ComputeType initial_accum) {
72 LayoutA::kRank == 2 &&
73 LayoutB::kRank == 2 &&
74 LayoutC::kRank == 2,
"Tensors must be of rank 2");
78 int const M = problem_size.
m();
79 int const N = problem_size.
n();
80 int const K = problem_size.
k();
83 int const Mblock = 16;
84 int const Nblock = 16;
87 InnerProductOp inner_product_op;
89 for (
int row_block = 0; row_block < M; row_block += Mblock) {
90 for (
int col_block = 0; col_block < N; col_block += Nblock) {
92 ComputeType accum[Mblock][Nblock];
94 for (
int j = 0; j < Nblock; j++) {
95 for (
int i = 0; i < Mblock; i++) {
96 accum[i][j] = initial_accum;
100 for (
int k_block = 0; k_block < K; ++k_block) {
101 for (
int j = 0; j < Nblock; j++) {
102 for (
int i = 0; i < Mblock; i++) {
103 int row = row_block + i;
104 int col = col_block + j;
106 if (row < M && col < N) {
110 accum[i][j] = inner_product_op(ComputeType(a), ComputeType(b), accum[i][j]);
116 for (
int j = 0; j < Nblock; j++) {
117 for (
int i = 0; i < Mblock; i++) {
118 int row = row_block + i;
119 int col = col_block + j;
123 if (row < M && col < N) {
124 tensor_d.
at(coord) = convert_op(
125 alpha * ScalarType(accum[i][j]) +
126 beta * ScalarType(tensor_c.
at(coord)));
146 typename ComputeType,
157 ComputeType initial_accum) {
158 compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
159 ScalarType, ComputeType, InnerProductOp, ConvertOp>(
160 problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
174 typename ComputeType,
175 typename InnerProductOp = cutlass::arch::OpMultiplyAdd
182 template <
typename ElementA,
typename LayoutA,
typename ElementB,
183 typename LayoutB,
typename ElementC,
typename LayoutC,
184 typename ScalarType,
typename ComputeType>
185 struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
186 ComputeType, arch::OpMultiplyAdd> {
192 ComputeType initial_accum = ComputeType(0)) {
194 LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
195 "Tensors must be of rank 2");
197 compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
199 problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
207 ComputeType initial_accum = ComputeType(0)) {
209 LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
210 "Tensors must be of rank 2");
212 compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
214 problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
221 template <
typename ElementA,
typename LayoutA,
typename ElementB,
222 typename LayoutB,
typename ElementC,
typename LayoutC,
223 typename ScalarType,
typename ComputeType>
224 struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
225 ComputeType, arch::OpMultiplyAddSaturate> {
231 ComputeType initial_accum = ComputeType(0)) {
233 LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
234 "Tensors must be of rank 2");
236 compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
237 ScalarType, ComputeType, multiply_add<ComputeType>,
239 problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
247 ComputeType initial_accum = ComputeType(0)) {
249 LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
250 "Tensors must be of rank 2");
252 compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
253 ScalarType, ComputeType, multiply_add<ComputeType>,
255 problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
262 template <
typename ElementA,
typename LayoutA,
typename ElementB,
263 typename LayoutB,
typename ElementC,
typename LayoutC,
264 typename ScalarType,
typename ComputeType>
265 struct Gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType,
266 ComputeType, arch::OpXorPopc> {
272 ComputeType initial_accum = ComputeType(0)) {
274 LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
275 "Tensors must be of rank 2");
277 compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
279 problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
287 ComputeType initial_accum = ComputeType(0)) {
289 LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
290 "Tensors must be of rank 2");
292 compute_gemm<ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
294 problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
309 typename TensorRefCollectionA,
310 typename TensorRefCollectionB,
311 typename TensorRefCollectionC,
313 typename AccumulatorType
319 TensorRefCollectionA
const& tensor_a,
320 TensorRefCollectionB
const& tensor_b,
322 TensorRefCollectionC &tensor_c,
323 AccumulatorType initial_accum) {
325 typename TensorRefCollectionA::ConstIterator tensor_a_it = tensor_a.begin();
326 typename TensorRefCollectionB::ConstIterator tensor_b_it = tensor_b.begin();
327 typename TensorRefCollectionC::ConstIterator tensor_c_it = tensor_c.begin();
331 ++batch, ++tensor_a_it, ++tensor_b_it, ++tensor_c_it) {
333 Gemm<
typename TensorRefCollectionA::Element,
334 typename TensorRefCollectionA::Layout,
335 typename TensorRefCollectionB::Element,
336 typename TensorRefCollectionB::Layout,
337 typename TensorRefCollectionC::Element,
338 typename TensorRefCollectionC::Layout,
339 typename TensorRefCollectionC::Element,
340 typename TensorRefCollectionC::Element>
343 gemm(problem_size, alpha, *tensor_a_it, *tensor_b_it, beta, *tensor_c_it,
354 typename TensorRefCollectionA,
355 typename TensorRefCollectionB,
356 typename TensorRefCollectionC,
358 typename AccumulatorType
364 TensorRefCollectionA
const& tensor_a,
365 TensorRefCollectionB
const& tensor_b,
367 TensorRefCollectionC &tensor_c) {
369 BatchedGemm(problem_size, batch_count, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0));
Fused multiply-add.
Definition: functional.h:92
void BatchedGemm(gemm::GemmCoord problem_size, int batch_count, ScalarType alpha, TensorRefCollectionA const &tensor_a, TensorRefCollectionB const &tensor_b, ScalarType beta, TensorRefCollectionC &tensor_c, AccumulatorType initial_accum)
Computes a batch of GEMMs over a set of matrices of common dimension.
Definition: tools/util/include/cutlass/util/reference/host/gemm.h:315
Definition: aligned_buffer.h:35
Definition: numeric_conversion.h:254
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Definition: include/cutlass/gemm/gemm.h:94
Definition: tools/util/include/cutlass/util/reference/host/gemm.h:177
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
Defines a structure containing strides and a pointer to tensor data.
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
Templates exposing architecture support for multiply-add operations.
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, TensorRef< ElementC, LayoutC > tensor_d, ComputeType initial_accum=ComputeType(0))
Definition: tools/util/include/cutlass/util/reference/host/gemm.h:242
Boost-like numeric conversion operator for CUTLASS numeric types.
Top-level include for all CUTLASS numeric types.
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, TensorRef< ElementC, LayoutC > tensor_d, ComputeType initial_accum=ComputeType(0))
Definition: tools/util/include/cutlass/util/reference/host/gemm.h:202
Definition: numeric_conversion.h:59
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:307
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, TensorRef< ElementC, LayoutC > tensor_d, ComputeType initial_accum=ComputeType(0))
Definition: tools/util/include/cutlass/util/reference/host/gemm.h:282
void compute_gemm(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, TensorRef< ElementC, LayoutC > tensor_d, ComputeType initial_accum)
Definition: tools/util/include/cutlass/util/reference/host/gemm.h:61
Fused multiply-add.
Definition: functional.h:101
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, ComputeType initial_accum=ComputeType(0))
Definition: tools/util/include/cutlass/util/reference/host/gemm.h:188
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, ComputeType initial_accum=ComputeType(0))
Definition: tools/util/include/cutlass/util/reference/host/gemm.h:227
void operator()(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, TensorRef< ElementB, LayoutB > tensor_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, ComputeType initial_accum=ComputeType(0))
Definition: tools/util/include/cutlass/util/reference/host/gemm.h:268
Defines properties of matrices used to denote layout and operands to GEMM kernels.
Definition: matrix_coord.h:39
Defines basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...