52 typename AccumulatorType,
54 typename InnerProductOp,
65 AccumulatorType initial_accum) {
69 (threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow,
70 (threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn
83 > gemm(initial_accum);
91 gemm.epilogue(problem_size, alpha, beta, tensor_c, tensor_d, output_coord);
99 typename TensorRefCollectionA,
100 typename TensorRefCollectionB,
101 typename TensorRefCollectionC,
103 typename AccumulatorType,
105 typename InnerProductOp,
111 TensorRefCollectionA tensor_collection_a,
112 TensorRefCollectionB tensor_collection_b,
114 TensorRefCollectionC tensor_collection_c,
115 AccumulatorType initial_accum) {
118 int batch_id = blockIdx.z;
121 typename TensorRefCollectionA::TensorRef tensor_a = tensor_collection_a.at(batch_id);
122 typename TensorRefCollectionB::TensorRef tensor_b = tensor_collection_b.at(batch_id);
123 typename TensorRefCollectionC::TensorRef tensor_c = tensor_collection_c.at(batch_id);
127 (threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kColumn,
128 (threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kRow
133 typename TensorRefCollectionA::TensorRef,
134 typename TensorRefCollectionB::TensorRef,
135 typename TensorRefCollectionC::TensorRef,
141 > gemm(initial_accum);
149 gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord);
Thread-level blocked general matrix product.
Definition: tools/util/include/cutlass/util/reference/device/thread/gemm.h:57
Definition: aligned_buffer.h:35
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Definition: include/cutlass/gemm/gemm.h:94
Defines common types used for all GEMM-like operators.
Defines a structure containing strides and a pointer to tensor data.
__global__ void BatchedGemm(gemm::GemmCoord problem_size, ScalarType alpha, TensorRefCollectionA tensor_collection_a, TensorRefCollectionB tensor_collection_b, ScalarType beta, TensorRefCollectionC tensor_collection_c, AccumulatorType initial_accum)
Definition: tools/util/include/cutlass/util/reference/device/kernel/gemm.h:108
Defines properties of matrices used to denote layout and operands to GEMM kernels.
__global__ void Gemm(gemm::GemmCoord problem_size, ScalarType alpha, TensorRefA tensor_a, TensorRefB tensor_b, ScalarType beta, TensorRefC tensor_c, TensorRefC tensor_d, AccumulatorType initial_accum)
Definition: tools/util/include/cutlass/util/reference/device/kernel/gemm.h:57
Definition: matrix_coord.h:39