63 typename ConvertOp = NumericConverter<ElementC, ScalarType>,
64 typename InnerProductOp = multiply_add<ComputeType>
75 ComputeType initial_accum) {
78 LayoutA::kRank == 2 &&
79 LayoutB::kRank == 2 &&
80 LayoutC::kRank == 2,
"Tensors must be of rank 2");
83 int const M = problem_size.
m();
84 int const N = problem_size.
n();
85 int const K = problem_size.
k();
88 int const Mblock = 16;
89 int const Nblock = 16;
92 InnerProductOp inner_product_op;
94 for (
int row_block = 0; row_block < M; row_block += Mblock) {
95 for (
int col_block = 0; col_block < N; col_block += Nblock) {
97 ComputeType accum[Mblock][Nblock];
99 for (
int j = 0; j < Nblock; j++) {
100 for (
int i = 0; i < Mblock; i++) {
101 accum[i][j] = initial_accum;
105 for (
int k_block = 0; k_block < K; ++k_block) {
106 for (
int j = 0; j < Nblock; j++) {
107 for (
int i = 0; i < Mblock; i++) {
108 int row = row_block + i;
109 int col = col_block + j;
111 if (row < M && col < N) {
115 ComputeType a_ik = ComputeType(a);
116 ComputeType b_kj = ComputeType(b);
126 accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]);
132 for (
int j = 0; j < Nblock; j++) {
133 for (
int i = 0; i < Mblock; i++) {
134 int row = row_block + i;
135 int col = col_block + j;
139 if (row < M && col < N) {
141 tensor_c.
at(coord) = convert_op(
142 alpha * ScalarType(accum[i][j]) +
143 beta * ScalarType(tensor_c.
at(coord)));
176 GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, ScalarType(0));
Definition: aligned_buffer.h:35
ComplexTransform
Enumerated type describing a transformation on a complex value.
Definition: complex.h:43
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Definition: include/cutlass/gemm/gemm.h:94
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
Defines a structure containing strides and a pointer to tensor data.
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
void GemmComplex(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, ComplexTransform transform_a, TensorRef< ElementB, LayoutB > tensor_b, ComplexTransform transform_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, ComputeType initial_accum)
Definition: tools/util/include/cutlass/util/reference/host/gemm_complex.h:66
Boost-like numeric conversion operator for CUTLASS numeric types.
CUTLASS_HOST_DEVICE complex< T > conj(complex< T > const &z)
Returns the complex conjugate.
Definition: complex.h:356
Top-level include for all CUTLASS numeric types.
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:307
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
Defines properties of matrices used to denote layout and operands to GEMM kernels.
Definition: matrix_coord.h:39
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...