CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
#include <mma_base.h>
Classes
class SharedStorage
Shared storage object needed by the threadblock-scoped GEMM.
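To estimate how large SharedStorage becomes, the sketch below counts the bytes reserved for operands A and B. It assumes each of the kStages buffers holds one Shape::kM x Shape::kK tile of A and one Shape::kK x Shape::kN tile of B, and it ignores any shared-memory padding the policy may add. `smem_operand_bytes` is a hypothetical helper, not a CUTLASS function.

```cpp
#include <cstddef>

// Hypothetical helper (not a CUTLASS API): estimates the shared-memory bytes
// needed to buffer operands A and B for an M x N x K threadblock tile across
// `stages` pipeline stages. Policy-defined padding is ignored (assumption).
constexpr std::size_t smem_operand_bytes(int m, int n, int k, int stages,
                                         std::size_t elem_bytes_a,
                                         std::size_t elem_bytes_b) {
  // One M x K tile of A per stage.
  std::size_t a_bytes =
      static_cast<std::size_t>(m) * static_cast<std::size_t>(k) *
      static_cast<std::size_t>(stages) * elem_bytes_a;
  // One K x N tile of B per stage.
  std::size_t b_bytes =
      static_cast<std::size_t>(k) * static_cast<std::size_t>(n) *
      static_cast<std::size_t>(stages) * elem_bytes_b;
  return a_bytes + b_bytes;
}

// Example: 128x128x32 tile, 3 stages, 16-bit elements (2 bytes each):
// A = 128*32*3*2 = 24576 bytes, B = 32*128*3*2 = 24576 bytes.
static_assert(smem_operand_bytes(128, 128, 32, 3, 2, 2) == 49152,
              "3 stages of 128x32 and 32x128 half-precision tiles = 48 KiB");
```

Under these assumptions the footprint grows linearly with the number of stages, so a deeper pipeline trades shared memory (and potentially occupancy) for more latency hiding.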
Public Types
using Shape = Shape_
Size of the threadblock-scoped GEMM problem (concept: gemm::GemmShape<>).
using Policy = Policy_
Policy describing tuning details.
using Operator = typename Policy::Operator
Warp-level MMA operator.
using WarpGemm = typename Policy::Operator::Shape
Shape describing the overall GEMM computed from shared memory by each warp.
using WarpCount = GemmShape< Shape::kM/WarpGemm::kM, Shape::kN/WarpGemm::kN, Shape::kK/WarpGemm::kK >
Shape describing the number of warps filling the CTA.
using TensorRefA = TensorRef< typename Operator::ElementA, typename Operator::LayoutA >
Tensor reference to the A operand.
using TensorRefB = TensorRef< typename Operator::ElementB, typename Operator::LayoutB >
Tensor reference to the B operand.
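The WarpCount typedef is compile-time division of the threadblock tile by the warp tile, one dimension at a time. The sketch below reproduces that arithmetic with a minimal stand-in for cutlass::gemm::GemmShape that models only kM, kN and kK; the `sketch` namespace, `WarpCountOf`, and the divisibility check are illustrative additions, not part of CUTLASS.

```cpp
namespace sketch {

// Minimal stand-in for cutlass::gemm::GemmShape (assumption: only the
// kM/kN/kK members that MmaBase uses are modeled).
template <int M, int N, int K>
struct GemmShape {
  static constexpr int kM = M;
  static constexpr int kN = N;
  static constexpr int kK = K;
};

// Mirrors the WarpCount typedef: each dimension of the threadblock tile is
// divided by the matching dimension of the warp-level tile.
template <typename Shape, typename WarpGemm>
struct WarpCountOf {
  // Guard added for this sketch: the warp tile must divide the CTA tile.
  static_assert(Shape::kM % WarpGemm::kM == 0 &&
                    Shape::kN % WarpGemm::kN == 0 &&
                    Shape::kK % WarpGemm::kK == 0,
                "warp tile must evenly divide the threadblock tile");
  using type = GemmShape<Shape::kM / WarpGemm::kM,
                         Shape::kN / WarpGemm::kN,
                         Shape::kK / WarpGemm::kK>;
};

}  // namespace sketch

// Example: a 128x128x32 threadblock tile split into 64x64x32 warp tiles.
using ThreadblockShape = sketch::GemmShape<128, 128, 32>;
using WarpShape = sketch::GemmShape<64, 64, 32>;
using WarpCount = sketch::WarpCountOf<ThreadblockShape, WarpShape>::type;

static_assert(WarpCount::kM == 2 && WarpCount::kN == 2 && WarpCount::kK == 1,
              "2 x 2 x 1 = 4 warps per threadblock");
```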
Public Member Functions
CUTLASS_DEVICE MmaBase (SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx)
Constructs the warp-scoped tile iterators from the operand tensor references held in shared storage.
Static Public Attributes
static int const kWarpGemmIterations
Number of warp-level GEMM operations.
static int const kStages = Stages
Number of stages.
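The summary does not show how kWarpGemmIterations is initialized. In the CUTLASS 2.x sources we are aware of it is `WarpGemm::kK / Operator::Policy::MmaShape::kK`, i.e. the number of instruction-level K slices a warp steps through per threadblock K tile; treat that formula as an assumption that may differ between versions. The hypothetical functions below sketch the resulting loop nest for the common case WarpCount::kK == 1: an outer loop over threadblock K tiles and an inner loop of kWarpGemmIterations warp-level MMAs.

```cpp
// Hypothetical helpers (not CUTLASS APIs) illustrating how
// kWarpGemmIterations nests inside a threadblock-level mainloop.

// Assumed formula: warp-tile K extent / K extent of one instruction-level MMA.
constexpr int warp_gemm_iterations(int warp_tile_k, int mma_instruction_k) {
  return warp_tile_k / mma_instruction_k;
}

// Number of threadblock K tiles needed to cover a problem of depth problem_k
// (ceiling division, so a partial final tile still counts).
constexpr int threadblock_k_iterations(int problem_k, int tile_k) {
  return (problem_k + tile_k - 1) / tile_k;
}

// Counts the warp-level MMA steps one warp issues along K by walking the
// same two nested loops a mainloop would run.
constexpr int count_warp_mma_steps(int problem_k, int tile_k,
                                   int warp_iterations) {
  int steps = 0;
  for (int k_tile = 0; k_tile < threadblock_k_iterations(problem_k, tile_k);
       ++k_tile) {
    for (int warp_mma_k = 0; warp_mma_k < warp_iterations; ++warp_mma_k) {
      ++steps;  // one warp-level MMA over a K slice of the warp tile
    }
  }
  return steps;
}

// Example: warp tile K = 32 with 16-deep MMA instructions gives 2 iterations;
// K = 1000 with 32-deep threadblock tiles gives 32 tiles, so 64 MMA steps.
static_assert(warp_gemm_iterations(32, 16) == 2, "32 / 16");
static_assert(threadblock_k_iterations(1000, 32) == 32, "ceil(1000 / 32)");
static_assert(count_warp_mma_steps(1000, 32, warp_gemm_iterations(32, 16)) == 64,
              "32 tiles x 2 iterations");
```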
Protected Attributes
Operator::IteratorA warp_tile_iterator_A_
Iterator to load a warp-scoped tile of the A operand from shared memory.
Operator::IteratorB warp_tile_iterator_B_
Iterator to load a warp-scoped tile of the B operand from shared memory.
Detailed Description
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
Member Typedef Documentation
using cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, Stages, Enable >::Operator = typename Policy::Operator
using cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, Stages, Enable >::Policy = Policy_
using cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, Stages, Enable >::Shape = Shape_
using cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, Stages, Enable >::TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>
using cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, Stages, Enable >::TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>
using cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, Stages, Enable >::WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>
using cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, Stages, Enable >::WarpGemm = typename Policy::Operator::Shape
Shape describing the overall GEMM computed from shared memory by each warp.
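Each warp owns one WarpGemm-sized piece of the threadblock tile, so a linear warp index has to be mapped onto WarpCount coordinates. As we understand the sources, derived threadblock MMAs in CUTLASS (for example MmaPipelined) use an M-fastest decomposition for this; the sketch below reproduces that mapping on the host with hypothetical names (`WarpCoord`, `warp_coord`) and scales the coordinates by the warp tile extents to get the warp's origin within the tile.

```cpp
// Hypothetical type (not a CUTLASS API): a warp's position within the
// threadblock tile, in units of warp tiles.
struct WarpCoord {
  int m;  // position along M, in units of WarpGemm::kM
  int n;  // position along N, in units of WarpGemm::kN
  int k;  // position along K (nonzero only when WarpCount::kK > 1)
};

// Maps a linear warp index to (m, n, k), M-fastest, given the warp counts
// along M and N (WarpCount::kM and WarpCount::kN in MmaBase terms).
constexpr WarpCoord warp_coord(int warp_idx, int warp_count_m,
                               int warp_count_n) {
  int warp_idx_mn = warp_idx % (warp_count_m * warp_count_n);
  int warp_idx_k = warp_idx / (warp_count_m * warp_count_n);
  return WarpCoord{warp_idx_mn % warp_count_m, warp_idx_mn / warp_count_m,
                   warp_idx_k};
}

// Example: 2 x 2 x 1 warps with 64x64 warp tiles. Warp 3 sits at (1, 1, 0),
// so its accumulator sub-tile starts at row 1 * 64 and column 1 * 64.
constexpr int kWarpTileM = 64;
constexpr int kWarpTileN = 64;
constexpr WarpCoord kWarp3 = warp_coord(3, 2, 2);
static_assert(kWarp3.m == 1 && kWarp3.n == 1 && kWarp3.k == 0,
              "warp 3 -> (1, 1, 0)");
static_assert(kWarp3.m * kWarpTileM == 64 && kWarp3.n * kWarpTileN == 64,
              "sub-tile origin (64, 64)");
```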
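Constructor & Destructor Documentation
CUTLASS_DEVICE cutlass::gemm::threadblock::MmaBase< Shape_, Policy_, Stages, Enable >::MmaBase (SharedStorage &shared_storage, int thread_idx, int warp_idx, int lane_idx) [inline]
Parameters
shared_storage | Shared storage needed for internal use by the threadblock-scoped GEMM
thread_idx | Index of the thread within the threadblock
warp_idx | Index of the warp within the threadblock
lane_idx | Index of the thread within its warp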
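The three index arguments are usually derived from the linear thread index of a one-dimensional threadblock. The sketch below shows that decomposition on the host, assuming 32-thread warps; `MmaThreadIds` and `mma_thread_ids` are hypothetical names used only for illustration.

```cpp
// Hypothetical type (not a CUTLASS API) bundling the constructor's
// index arguments.
struct MmaThreadIds {
  int thread_idx;  // index of the thread within the threadblock
  int warp_idx;    // index of the warp within the threadblock
  int lane_idx;    // index of the thread within its warp
};

constexpr int kWarpSize = 32;  // assumption: 32 threads per warp

// Splits a linear thread index of a 1-D threadblock into the
// (thread_idx, warp_idx, lane_idx) triple passed to MmaBase.
constexpr MmaThreadIds mma_thread_ids(int linear_thread_idx) {
  return MmaThreadIds{linear_thread_idx, linear_thread_idx / kWarpSize,
                      linear_thread_idx % kWarpSize};
}

// Example: thread 70 of the threadblock is lane 6 of warp 2 (70 = 2*32 + 6).
static_assert(mma_thread_ids(70).warp_idx == 2 &&
                  mma_thread_ids(70).lane_idx == 6,
              "70 = 2*32 + 6");
```

In device code the input would be threadIdx.x. As far as we know, CUTLASS's kernel-level GEMM also broadcasts the warp index from lane 0 with __shfl_sync so that the compiler can treat it as warp-uniform.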