52 typename AccumulatorType,
74 AccumulatorType
accum[OutputTile::kColumn][OutputTile::kRow];
82 Gemm(AccumulatorType initial_accum = AccumulatorType(0)) {
85 for (
int i = 0; i < OutputTile::kColumn; ++i) {
89 for (
int j = 0; j < OutputTile::kColumn; ++j) {
95 for (
int j = 0; j < OutputTile::kColumn; ++j) {
97 for (
int i = 0; i < OutputTile::kRow; ++i) {
98 accum[j][i] = initial_accum;
111 InnerProductOp inner_product_op;
115 for (
int k = 0; k < problem_size.
k(); ++k) {
119 for (
int i = 0; i < OutputTile::kColumn; ++i) {
120 if (output_coord.row() + i < problem_size.
m()) {
121 A_tile[i] = tensor_a.at(
make_Coord(output_coord.row() + i, k));
127 for (
int j = 0; j < OutputTile::kRow; ++j) {
128 if (output_coord.column() + j < problem_size.
n()) {
129 B_tile[j] = tensor_b.at(
make_Coord(k, output_coord.column() + j));
135 for (
int j = 0; j < OutputTile::kRow; ++j) {
137 for (
int i = 0; i < OutputTile::kColumn; ++i) {
138 accum[j][i] = inner_product_op(A_tile[i], B_tile[j], accum[j][i]);
156 ConvertOp convert_op;
159 for (
int j = 0; j < OutputTile::kRow; ++j) {
160 for (
int i = 0; i < OutputTile::kColumn; ++i) {
162 if (coord.
row() < problem_size.
m() && coord.
column() < problem_size.
n()) {
164 tensor_d.at(coord) = convert_op(
165 alpha * ScalarType(accum[j][i]) +
166 beta * ScalarType(tensor_c.at(coord))
Fused multiply-add.
Definition: functional.h:92
CUTLASS_HOST_DEVICE Index const & column() const
Returns the column of the coordinate.
Definition: matrix_coord.h:85
Thread-level blocked general matrix product.
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:57
Definition: aligned_buffer.h:35
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
typename TensorRefA::Element ElementA
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:59
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:387
ElementB B_tile[OutputTile::kRow]
Tile for B operand.
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:71
Definition: include/cutlass/gemm/gemm.h:94
AccumulatorType accum[OutputTile::kColumn][OutputTile::kRow]
Tile for Accumulator.
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:74
typename TensorRefB::Element ElementB
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:60
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & row() const
Returns the row of the coordinate.
Definition: matrix_coord.h:77
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
Defines a structure containing strides and a pointer to tensor data.
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
CUTLASS_HOST_DEVICE Gemm(AccumulatorType initial_accum=AccumulatorType(0))
Constructor.
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:82
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
ElementA A_tile[OutputTile::kColumn]
Tile for A operand.
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:68
typename TensorRefC::Element ElementC
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:61
CUTLASS_HOST_DEVICE Gemm & multiply_add(gemm::GemmCoord problem_size, TensorRefA tensor_a, TensorRefB tensor_b, MatrixCoord output_coord=MatrixCoord())
Computes a matrix product.
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:105
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
#define CUTLASS_PRAGMA_NO_UNROLL
Definition: cutlass.h:111
CUTLASS_HOST_DEVICE Gemm & epilogue(gemm::GemmCoord problem_size, ScalarType alpha, ScalarType beta, TensorRefC tensor_c, TensorRefC tensor_d, MatrixCoord output_coord=MatrixCoord())
Performs linear scaling of matrix product and updates output tensor.
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:148
Definition: numeric_conversion.h:59
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
Defines properties of matrices used to denote layout and operands to GEMM kernels.
Definition: matrix_coord.h:39