CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
#include <gemm_complex.h>
Classes
struct Arguments
    Argument structure.
Public Member Functions
GemmComplex ()
    Constructs the GEMM.
Status initialize (Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
    Initializes GEMM state from arguments.
Status update (Arguments const &args, void *workspace=nullptr)
    Lightweight update given a subset of arguments.
Status run (cudaStream_t stream=nullptr)
    Runs the kernel using initialized state.
Status operator() (cudaStream_t stream=nullptr)
    Runs the kernel using initialized state.
Status operator() (Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
    Initializes GEMM state from the given arguments, then runs the kernel.
Static Public Member Functions
static Status can_implement (Arguments const &args)
    Determines whether the GEMM can execute the given problem.
static size_t get_workspace_size (Arguments const &args)
    Gets the workspace size.
Static Public Attributes
static int const kStages = Stages
static ComplexTransform const kTransformA = TransformA
static ComplexTransform const kTransformB = TransformB
static bool const kSplitKSerial = SplitKSerial
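The kTransformA and kTransformB attributes select a complex transform for each operand. As a rough illustration, assuming the transform is applied elementwise to an operand before multiplication (for example, complex conjugation), the effect on a single inner product can be sketched in plain host C++. The names `ToyTransform`, `apply_transform`, and `transformed_dot` are hypothetical and are not part of CUTLASS.

```cpp
#include <complex>

// Hypothetical stand-in for an elementwise complex transform selector.
enum class ToyTransform { kNone, kConjugate };

// Applies the selected transform to one element.
inline std::complex<float> apply_transform(ToyTransform t, std::complex<float> z) {
  return t == ToyTransform::kConjugate ? std::conj(z) : z;
}

// Inner product over k elements: sum_p op_a(a[p]) * op_b(b[p]).
// In a GEMM, a would be one row of A and b one column of B.
inline std::complex<float> transformed_dot(ToyTransform ta, std::complex<float> const *a,
                                           ToyTransform tb, std::complex<float> const *b,
                                           int k) {
  std::complex<float> acc(0.0f, 0.0f);
  for (int p = 0; p < k; ++p) {
    acc += apply_transform(ta, a[p]) * apply_transform(tb, b[p]);
  }
  return acc;
}
```

With a = 1 + 2i and b = 3 + 4i, no transform gives the product (1 + 2i)(3 + 4i), while conjugating the first operand gives (1 - 2i)(3 + 4i).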
Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may be invoked from host code.
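The contributions of this class are:
1. At compile time, it maps data types and high-level structural parameters onto specific CUTLASS components.
2. At runtime, it maps logical arguments to GEMM problems to kernel parameters.
3. At runtime, it launches kernels on the device.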
The intent is to provide a convenient mechanism for interacting with most plausible GEMM configurations for each supported architecture. Consequently, not all parameters are exposed to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy are selected to trade off simplicity of the interface against flexibility. We expect most configurations to be specified at this level. Applications with more exotic requirements may construct their kernels of interest using CUTLASS components at the threadblock, warp, and thread levels of abstraction.
CUTLASS exposes computations using the functor design pattern, in which objects compose some internal state with an overloaded function call operator. This decouples initialization from execution, which can reduce overhead during steady-state phases of application execution.
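The functor pattern can be illustrated without CUTLASS itself. The following is a minimal host-only sketch, not CUTLASS code: `ToyAxpy`, its `Arguments`, and the `Status` enum are hypothetical stand-ins that only mirror the shape of the interface (initialize once, then run repeatedly).

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a status code; not the real cutlass::Status.
enum class Status { kSuccess, kErrorInvalidProblem, kErrorNotInitialized };

// Toy functor computing y = alpha * x + y on the host.
class ToyAxpy {
public:
  // Logical inputs supplied by the caller.
  struct Arguments {
    float alpha;
    std::vector<float> const *x;
    std::vector<float> *y;
  };

  // Validates the arguments and stores them as internal state.
  Status initialize(Arguments const &args) {
    if (!args.x || !args.y || args.x->size() != args.y->size()) {
      return Status::kErrorInvalidProblem;
    }
    args_ = args;
    initialized_ = true;
    return Status::kSuccess;
  }

  // Executes using previously initialized state.
  Status run() {
    if (!initialized_) return Status::kErrorNotInitialized;
    for (std::size_t i = 0; i < args_.x->size(); ++i) {
      (*args_.y)[i] += args_.alpha * (*args_.x)[i];
    }
    return Status::kSuccess;
  }

  // Steady-state call: no re-initialization.
  Status operator()() { return run(); }

  // Convenience call: initialize, then run.
  Status operator()(Arguments const &args) {
    Status status = initialize(args);
    if (status == Status::kSuccess) status = run();
    return status;
  }

private:
  Arguments args_{};
  bool initialized_ = false;
};
```

A caller would typically invoke `op(args)` once and then `op()` in a loop, paying the validation cost only on the first call.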
CUTLASS device-level operators expose an Arguments structure encompassing each logical input to the computation. This is distinct from the kernel-level Params structure pattern, which contains application-specific precomputed state needed by the device code.
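The split between logical arguments and precomputed parameters can be sketched as follows. This is an illustrative host-only example under stated assumptions: `ToyArguments`, `ToyParams`, `make_params`, and the 128x128x8 tile shape are all hypothetical and do not reflect the actual contents of any CUTLASS Params structure.

```cpp
// Hypothetical logical inputs, as a caller would describe the problem.
struct ToyArguments {
  int m, n, k;
  float alpha, beta;
};

// Hypothetical precomputed state that device code would consume directly.
struct ToyParams {
  int tiles_m;       // number of threadblock tiles along M
  int tiles_n;       // number of threadblock tiles along N
  int k_iterations;  // mainloop iterations along K
  float alpha, beta;
};

// Assumed threadblock tile shape, chosen only for illustration.
constexpr int kTileM = 128;
constexpr int kTileN = 128;
constexpr int kTileK = 8;

// Integer division rounding up.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Host-side conversion performed once, at initialization time.
inline ToyParams make_params(ToyArguments const &args) {
  return ToyParams{ceil_div(args.m, kTileM),
                   ceil_div(args.n, kTileN),
                   ceil_div(args.k, kTileK),
                   args.alpha,
                   args.beta};
}
```

Doing this arithmetic once on the host means the device code does not repeat it on every launch.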
An example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN follows.
Instantiate the CUTLASS GEMM operator:
    cutlass::gemm::device::Gemm<
      float,
      cutlass::layout::ColumnMajor,
      float,
      cutlass::layout::ColumnMajor,
      float,
      cutlass::layout::ColumnMajor
    > gemm_op;
Launch the GEMM operation on the device:
    cutlass::Status status = gemm_op({
      {m, n, k},        // GemmCoord problem_size,
      {A, lda},         // TensorRef<float, layout::ColumnMajor> ref_A,
      {B, ldb},         // TensorRef<float, layout::ColumnMajor> ref_B,
      {C, ldc},         // TensorRef<float, layout::ColumnMajor> ref_C,
      {D, ldd},         // TensorRef<float, layout::ColumnMajor> ref_D,
      {alpha, beta}     // EpilogueOutputOp::Params epilogue_op_params
    });
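For checking results, it can help to have a plain host-side reference for what an SGEMM NN computes, namely D = alpha * A * B + beta * C with all operands in column-major layout. The sketch below is an assumption-laden illustration: `reference_sgemm_nn` is a hypothetical helper, not part of CUTLASS or cuBLAS, and it makes no attempt at performance.

```cpp
#include <cassert>

// Hypothetical host reference: D = alpha * A * B + beta * C, all column-major.
// Element (i, j) of a column-major matrix with leading dimension ld is ptr[i + j * ld].
inline void reference_sgemm_nn(int m, int n, int k, float alpha,
                               float const *A, int lda,
                               float const *B, int ldb, float beta,
                               float const *C, int ldc,
                               float *D, int ldd) {
  // Leading dimensions must cover the number of rows of each matrix.
  assert(lda >= m && ldb >= k && ldc >= m && ldd >= m);
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) {
        acc += A[i + p * lda] * B[p + j * ldb];
      }
      D[i + j * ldd] = alpha * acc + beta * C[i + j * ldc];
    }
  }
}
```

Comparing device output against such a reference on small problem sizes is a simple way to confirm that layouts and leading dimensions were passed as intended.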
A simplified view of the template is listed below.
    template <
      /// Element type for A matrix operand
      typename ElementA,
      /// Layout type for A matrix operand
      typename LayoutA,
      /// Element type for B matrix operand
      typename ElementB,
      /// Layout type for B matrix operand
      typename LayoutB,
      /// Element type for C and D matrix operands
      typename ElementC,
      /// Layout type for C and D matrix operands
      typename LayoutC,
      /// Element type for internal accumulation
      typename ElementAccumulator,
      /// Operator class tag
      typename OperatorClass,
      /// Tag indicating architecture to tune for
      typename ArchTag,
      /// Threadblock-level tile size (concept: GemmShape)
      typename ThreadblockShape,
      /// Warp-level tile size (concept: GemmShape)
      typename WarpShape,
      /// Instruction-level tile size (concept: GemmShape)
      typename InstructionShape,
      /// Epilogue output operator
      typename EpilogueOutputOp,
      /// Threadblock-level swizzling operator
      typename ThreadblockSwizzle,
      /// Number of stages used in the pipelined mainloop
      int Stages
    >
    class Gemm;
Member typedefs of cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >:
using ArchTag = ArchTag_
using ElementA = ElementA_
using ElementAccumulator = ElementAccumulator_
using ElementB = ElementB_
using ElementC = ElementC_
using EpilogueOutputOp = EpilogueOutputOp_
using GemmKernel = typename kernel::DefaultGemmComplex< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, kStages, kTransformA, kTransformB, kSplitKSerial >::GemmKernel
using InstructionShape = InstructionShape_
using LayoutA = LayoutA_
using LayoutB = LayoutB_
using LayoutC = LayoutC_
using OperatorClass = OperatorClass_
using TensorRefA = TensorRef<ElementA const, LayoutA>
using TensorRefB = TensorRef<ElementB const, LayoutB>
using TensorRefC = TensorRef<ElementC const, LayoutC>
using TensorRefD = TensorRef<ElementC, LayoutC>
using ThreadblockShape = ThreadblockShape_
using ThreadblockSwizzle = ThreadblockSwizzle_
using WarpShape = WarpShape_