CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
Public Types | Static Public Attributes | List of all members
cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ > Struct Template Reference

Partial specialization for row-major int8_t A and column-major int8_t B using the SIMT operator class (dp4a). More...

#include <default_mma_core_simt.h>

Public Types

using Shape = Shape_
 
using WarpShape = WarpShape_
 
using InstructionShape = GemmShape< 1, 1, 4 >
 
using ElementA = int8_t
 
using LayoutA = layout::RowMajor
 
using ElementB = int8_t
 
using LayoutB = layout::ColumnMajor
 
using ElementC = ElementC_
 
using LayoutC = LayoutC_
 
using OperatorClass = arch::OpClassSimt
 
using Operator = Operator_
 Default Operator. More...
 
using WarpCount = GemmShape< Shape::kM/WarpShape::kM, Shape::kN/WarpShape::kN, PartitionsK >
 Number of warps present. More...
 
using SmemLayoutA = layout::ColumnMajorInterleaved< 4 >
 
using SmemLayoutB = layout::RowMajorInterleaved< 4 >
 
using IteratorThreadMapA = transform::PitchLinear2DThreadTileStripminedThreadMap< layout::PitchLinearShape< Shape::kK, Shape::kM >, kThreads, layout::PitchLinearShape< 4, 4 > >
 ThreadMap of iterator A. More...
 
using SmemThreadMapA = transform::TransposePitchLinearThreadMap2DThreadTile< IteratorThreadMapA >
 Transpose the ThreadMap of iterator A. More...
 
using SmemIteratorA = transform::threadblock::RegularTileIterator2dThreadTile< MatrixShape< Shape::kM, Shape::kK >, ElementA, SmemLayoutA, 1, SmemThreadMapA >
 Shared memory iterator to A operand. More...
 
using IteratorThreadMapB = transform::PitchLinear2DThreadTileStripminedThreadMap< layout::PitchLinearShape< Shape::kK, Shape::kN >, kThreads, layout::PitchLinearShape< 4, 4 > >
 ThreadMap of iterator B. More...
 
using SmemThreadMapB = transform::TransposePitchLinearThreadMap2DThreadTile< IteratorThreadMapB >
 Transpose the ThreadMap of iterator B. More...
 
using SmemIteratorB = transform::threadblock::RegularTileIterator2dThreadTile< MatrixShape< Shape::kK, Shape::kN >, ElementB, SmemLayoutB, 0, SmemThreadMapB >
 Shared memory iterator to B operand. More...
 
using LaneMmaShape = cutlass::gemm::GemmShape< LaneM, LaneN, 4 >
 
using Policy = cutlass::gemm::warp::MmaSimtPolicy< cutlass::MatrixShape< WarpNumThreadsM, WarpNumThreadsN >, cutlass::layout::ColumnMajorInterleaved< LaneLayout >, LaneMmaShape >
 
using MmaWarpSimt = cutlass::gemm::warp::MmaSimt< WarpShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, ElementC, LayoutC, Policy, PartitionsK >
 
using MmaPolicy = MmaPolicy< MmaWarpSimt, MatrixShape< kPaddingM, 0 >, MatrixShape< 0, kPaddingN >, WarpCount::kK >
 Policy used to define MmaPipelined. More...
 

Static Public Attributes

static int const PartitionsK = Shape::kK / WarpShape::kK
 
static int const kWarpSize = warp::WarpSize<arch::OpClassSimt>::value
 Number of threads per warp. More...
 
static int const kThreads = WarpCount::kCount * kWarpSize
 Number of threads total. More...
 
static const int WarpNumThreadsM = detail::simt_get_warp_threads_m<WarpShape>()
 
static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM
 
static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM
 
static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN
 
static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1
 
static const int numElementsA = 128 / sizeof_bits<ElementA>::value
 
static const int numElementsB = 128 / sizeof_bits<ElementB>::value
 
static const int LaneM = cutlass::const_min(4, ThreadTileM)
 
static const int LaneN = cutlass::const_min(4, ThreadTileN)
 
static int const kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementA>::value)
 
static int const kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementB>::value)
 

Detailed Description

template<typename Shape_, typename WarpShape_, typename ElementC_, typename LayoutC_, typename Operator_>
struct cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >

A: row-major; B: column-major; Operator: SIMT class, for dp4a.

This uses the default warp-level operator for the given tile sizes.

Member Typedef Documentation

template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::ElementA = int8_t
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::ElementB = int8_t
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::ElementC = ElementC_
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::InstructionShape = GemmShape<1, 1, 4>
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::IteratorThreadMapA = transform::PitchLinear2DThreadTileStripminedThreadMap< layout::PitchLinearShape<Shape::kK, Shape::kM>, kThreads, layout::PitchLinearShape<4, 4> >
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::IteratorThreadMapB = transform::PitchLinear2DThreadTileStripminedThreadMap< layout::PitchLinearShape<Shape::kK, Shape::kN>, kThreads, layout::PitchLinearShape<4, 4> >
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::LaneMmaShape = cutlass::gemm::GemmShape< LaneM, LaneN, 4>
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::LayoutA = layout::RowMajor
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::LayoutB = layout::ColumnMajor
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::LayoutC = LayoutC_
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::MmaPolicy = MmaPolicy< MmaWarpSimt, MatrixShape<kPaddingM, 0>, MatrixShape<0, kPaddingN>, WarpCount::kK >
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::MmaWarpSimt = cutlass::gemm::warp::MmaSimt< WarpShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, ElementC, LayoutC, Policy, PartitionsK >
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::Operator = Operator_
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::OperatorClass = arch::OpClassSimt
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::Policy = cutlass::gemm::warp::MmaSimtPolicy< cutlass::MatrixShape<WarpNumThreadsM, WarpNumThreadsN>, cutlass::layout::ColumnMajorInterleaved<LaneLayout>, LaneMmaShape >
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::Shape = Shape_
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::SmemIteratorA = transform::threadblock::RegularTileIterator2dThreadTile< MatrixShape<Shape::kM, Shape::kK>, ElementA, SmemLayoutA, 1, SmemThreadMapA >
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::SmemIteratorB = transform::threadblock::RegularTileIterator2dThreadTile< MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 0, SmemThreadMapB >
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::SmemLayoutA = layout::ColumnMajorInterleaved<4>
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::SmemLayoutB = layout::RowMajorInterleaved<4>
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::SmemThreadMapA = transform::TransposePitchLinearThreadMap2DThreadTile<IteratorThreadMapA>
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::SmemThreadMapB = transform::TransposePitchLinearThreadMap2DThreadTile<IteratorThreadMapB>
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::WarpCount = GemmShape< Shape::kM / WarpShape::kM, Shape::kN / WarpShape::kN, PartitionsK >
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
using cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::WarpShape = WarpShape_

Member Data Documentation

template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
int const cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::kPaddingM = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementA>::value)
static
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
int const cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::kPaddingN = detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits<ElementB>::value)
static
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
int const cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::kThreads = WarpCount::kCount * kWarpSize
static
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
int const cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::kWarpSize = warp::WarpSize<arch::OpClassSimt>::value
static
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
const int cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 2 : 1
static
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
const int cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::LaneM = cutlass::const_min(4, ThreadTileM)
static
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
const int cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::LaneN = cutlass::const_min(4, ThreadTileN)
static
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
const int cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::numElementsA = 128 / sizeof_bits<ElementA>::value
static
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
const int cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::numElementsB = 128 / sizeof_bits<ElementB>::value
static
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
int const cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::PartitionsK = Shape::kK / WarpShape::kK
static
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
const int cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::ThreadTileM = WarpShape::kM / WarpNumThreadsM
static
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
const int cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::ThreadTileN = WarpShape::kN / WarpNumThreadsN
static
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
const int cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::WarpNumThreadsM = detail::simt_get_warp_threads_m<WarpShape>()
static
template<typename Shape_ , typename WarpShape_ , typename ElementC_ , typename LayoutC_ , typename Operator_ >
const int cutlass::gemm::threadblock::DefaultMmaCore< Shape_, WarpShape_, GemmShape< 1, 1, 4 >, int8_t, layout::RowMajor, int8_t, layout::ColumnMajor, ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_ >::WarpNumThreadsN = kWarpSize / WarpNumThreadsM
static

The documentation for this struct was generated from the following file: