109 reinterpret_cast<ElementC *>(&D), LayoutC::packed({ Shape::kM, Shape::kN }));
124 arch::OpMultiplyAdd>;
130 for (
int k = 0; k < Shape::kK / Mma::Shape::kK; ++k) {
133 for (
int n = 0; n < Shape::kN; ++n) {
136 for (
int m = 0; m < Shape::kM; ++m) {
139 Array<int8_t, 4>
const *ptr_A =
reinterpret_cast<Array<int8_t, 4>
const *
>(&A);
140 Array<int8_t, 4>
const *ptr_B =
reinterpret_cast<Array<int8_t, 4>
const *
>(&B);
142 Array<int32_t, 1> tmp =
reinterpret_cast<Array<int32_t, 1> &
>(d.
at(mn));
146 ptr_A[m * Shape::kK / Mma::Shape::kK + k],
147 ptr_B[n * Shape::kK / Mma::Shape::kK + k],
150 d.
at(mn) =
reinterpret_cast<int32_t &
>(tmp);
222 reinterpret_cast<ElementC *>(&D), LayoutC::packed({ Shape::kM, Shape::kN }));
237 arch::OpMultiplyAdd>;
240 Array<int8_t, 4>
const *ptr_A =
reinterpret_cast<Array<int8_t, 4>
const *
>(&A);
241 Array<int8_t, 4>
const *ptr_B =
reinterpret_cast<Array<int8_t, 4>
const *
>(&B);
245 for (
int k = 0; k < Shape::kK / Mma::Shape::kK; ++k) {
248 for (
int n = 0; n < Shape::kN; ++n) {
251 for (
int m = 0; m < Shape::kM; ++m) {
254 Array<int32_t, 1> tmp =
reinterpret_cast<Array<int32_t, 1> &
>(d.
at(mn));
258 ptr_A[m + k * Shape::kM],
259 ptr_B[n + k * Shape::kN],
262 d.
at(mn) =
reinterpret_cast<int32_t &
>(tmp);
Definition: aligned_buffer.h:35
Defines a structure containing strides, bounds, and a pointer to tensor data.
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm61.h:102
Array< ElementB, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm61.h:204
Shape_ Shape
Size of the Gemm problem - concept: gemm::GemmShape<>
Definition: gemm/thread/mma_sm61.h:64
CUTLASS_HOST_DEVICE void operator()(FragmentC &D, FragmentA const &A, FragmentB const &B, FragmentC const &C)
Computes a matrix product D = A * B + C.
Definition: gemm/thread/mma_sm61.h:215
arch::OpMultiplyAdd Operator
Underlying mathematical operator.
Definition: gemm/thread/mma_sm61.h:85
Array< ElementC, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm61.h:207
Defines common types used for all GEMM-like operators.
arch::OpMultiplyAdd Operator
Underlying mathematical operator.
Definition: gemm/thread/mma_sm61.h:198
int8_t ElementB
Data type of operand B.
Definition: gemm/thread/mma_sm61.h:73
Array< ElementA, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm61.h:88
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
int8_t ElementB
Data type of operand B.
Definition: gemm/thread/mma_sm61.h:186
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Templates exposing architecture support for warp-level multiply-add operations.
Array< ElementA, Shape::kMK > FragmentA
A operand storage.
Definition: gemm/thread/mma_sm61.h:201
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
int8_t ElementA
Data type of operand A.
Definition: gemm/thread/mma_sm61.h:67
int32_t ElementC
Element type of operand C.
Definition: gemm/thread/mma_sm61.h:192
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
LayoutC_ LayoutC
Layout of C matrix (concept: layout::MapFunc)
Definition: gemm/thread/mma_sm61.h:195
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:307
Structure to compute the matrix product.
Definition: gemm/thread/mma.h:66
LayoutC_ LayoutC
Layout of C matrix (concept: layout::MapFunc)
Definition: gemm/thread/mma_sm61.h:82
Defines layout functions used by TensorRef and derived classes.
int32_t ElementC
Element type of operand C.
Definition: gemm/thread/mma_sm61.h:79
Array< ElementB, Shape::kKN > FragmentB
B operand storage.
Definition: gemm/thread/mma_sm61.h:91
Array< ElementC, Shape::kMN > FragmentC
C operand storage.
Definition: gemm/thread/mma_sm61.h:94
Matrix multiply-add operation.
Definition: arch/mma.h:92
Basic include for CUTLASS.
Definition: matrix_coord.h:39
Shape_ Shape
Size of the Gemm problem - concept: gemm::GemmShape<>
Definition: gemm/thread/mma_sm61.h:177
int8_t ElementA
Data type of operand A.
Definition: gemm/thread/mma_sm61.h:180