173 typename ElementAccumulator_ = ElementC_,
175 typename OperatorClass_ = arch::OpClassSimt,
177 typename ArchTag_ = arch::Sm70,
179 typename ThreadblockShape_ =
typename DefaultGemmConfiguration<
180 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
181 ElementAccumulator_>::ThreadblockShape,
183 typename WarpShape_ =
typename DefaultGemmConfiguration<
184 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
185 ElementAccumulator_>::WarpShape,
187 typename InstructionShape_ =
typename DefaultGemmConfiguration<
188 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
189 ElementAccumulator_>::InstructionShape,
191 typename EpilogueOutputOp_ =
typename DefaultGemmConfiguration<
192 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
193 ElementAccumulator_>::EpilogueOutputOp,
195 typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle,
198 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
199 ElementC_, ElementAccumulator_>::kStages,
202 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
203 ElementC_, ElementAccumulator_>::kAlignmentA,
206 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
207 ElementC_, ElementAccumulator_>::kAlignmentB,
209 bool SplitKSerial =
false,
211 typename Operator_ =
typename DefaultGemmConfiguration<
212 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
213 ElementAccumulator_>::Operator,
215 bool IsBetaZero =
false>
302 typename EpilogueOutputOp::Params epilogue_ =
303 typename EpilogueOutputOp::Params(),
304 int split_k_slices = 1
306 problem_size(problem_size_),
312 split_k_slices(split_k_slices) {
320 typename GemmKernel::Params params_;
330 if (!kSplitKSerial && args.split_k_slices > 1) {
334 Status status = GemmKernel::can_implement(
336 args.ref_A.non_const_ref(),
337 args.ref_B.non_const_ref(),
338 args.ref_C.non_const_ref(),
352 if (kSplitKSerial && args.split_k_slices > 1) {
355 ThreadblockSwizzle threadblock_swizzle;
359 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
360 args.split_k_slices);
362 return sizeof(int) *
size_t(tiled_shape.
m()) *
size_t(tiled_shape.
n());
372 ThreadblockSwizzle threadblock_swizzle;
376 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
377 args.split_k_slices);
380 if (args.split_k_slices > 1) {
387 cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
389 if (result != cudaSuccess) {
396 if (args.split_k_slices > 1) {
402 params_ =
typename GemmKernel::Params{
405 args.ref_A.non_const_ref(),
406 args.ref_B.non_const_ref(),
407 args.ref_C.non_const_ref(),
410 static_cast<int *
>(workspace)
419 if (kSplitKSerial && args.split_k_slices > 1) {
425 params_.ref_A.reset(args.ref_A.non_const_ref().data());
426 params_.ref_B.reset(args.ref_B.non_const_ref().data());
427 params_.ref_C.reset(args.ref_C.non_const_ref().data());
428 params_.ref_D.reset(args.ref_D.data());
429 params_.semaphore =
static_cast<int *
>(workspace);
437 ThreadblockSwizzle threadblock_swizzle;
439 dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
440 dim3 block(GemmKernel::kThreadCount, 1, 1);
444 int smem_size = int(
sizeof(
typename GemmKernel::SharedStorage));
445 if (smem_size >= (48 << 10)) {
446 result = cudaFuncSetAttribute(Kernel<GemmKernel>,
447 cudaFuncAttributeMaxDynamicSharedMemorySize,
450 if (result != cudaSuccess) {
454 result = cudaFuncSetAttribute(
456 cudaFuncAttributePreferredSharedMemoryCarveout, 100);
458 if (result != cudaSuccess) {
463 cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
465 result = cudaGetLastError();
478 void *workspace =
nullptr,
479 cudaStream_t stream =
nullptr) {
484 status =
run(stream);
506 typename ElementAccumulator_,
508 typename OperatorClass_,
512 typename ThreadblockShape_,
516 typename InstructionShape_,
518 typename EpilogueOutputOp_,
520 typename ThreadblockSwizzle_,
533 class Gemm<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
535 ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_,
536 WarpShape_, InstructionShape_, EpilogueOutputOp_,
537 ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial,
538 Operator_, IsBetaZero> {
541 using ElementA = ElementA_;
542 using LayoutA = LayoutA_;
544 using ElementB = ElementB_;
545 using LayoutB = LayoutB_;
547 using ElementC = ElementC_;
551 using ElementAccumulator = ElementAccumulator_;
552 using OperatorClass = OperatorClass_;
553 using ArchTag = ArchTag_;
554 using ThreadblockShape = ThreadblockShape_;
555 using WarpShape = WarpShape_;
556 using InstructionShape = InstructionShape_;
557 using EpilogueOutputOp = EpilogueOutputOp_;
558 using ThreadblockSwizzle = ThreadblockSwizzle_;
559 using Operator = Operator_;
560 static int const kStages = Stages;
561 static int const kAlignmentA = AlignmentA;
562 static int const kAlignmentB = AlignmentB;
563 static bool const kSplitKSerial = SplitKSerial;
564 static bool const kIsBetaZero = IsBetaZero;
591 static int const kAlignmentC = UnderlyingOperator::kAlignmentC;
624 typename EpilogueOutputOp::Params epilogue_ =
625 typename EpilogueOutputOp::Params(),
626 int split_k_slices = 1
628 problem_size(problem_size_),
634 split_k_slices(split_k_slices) { }
649 {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
650 {args.ref_B.data(), args.ref_B.stride(0)},
651 {args.ref_A.data(), args.ref_A.stride(0)},
652 {args.ref_C.data(), args.ref_C.stride(0)},
653 {args.ref_D.data(), args.ref_D.stride(0)},
662 return UnderlyingOperator::can_implement(to_underlying_arguments(args));
668 return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
674 return underlying_operator_.
initialize(to_underlying_arguments(args), workspace);
680 return underlying_operator_.
update(to_underlying_arguments(args), workspace);
686 return underlying_operator_.
run(stream);
697 void *workspace =
nullptr,
698 cudaStream_t stream =
nullptr) {
703 status =
run(stream);
Definition: default_gemm.h:116
static int const kStages
Definition: include/cutlass/gemm/device/gemm.h:238
GemmCoord problem_size
Definition: include/cutlass/gemm/device/gemm.h:276
Definition: aligned_buffer.h:35
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::split_k_slices int split_k_slices
Definition: include/cutlass/gemm/device/gemm.h:606
Specified problem size is not supported by operator.
cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::ElementA ElementB ElementA
Definition: include/cutlass/gemm/device/gemm.h:219
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::ref_A TensorRef< ElementA const, LayoutA > ref_A
Definition: include/cutlass/gemm/device/gemm.h:601
static size_t get_workspace_size(Arguments const &args)
Gets the workspace size.
Definition: include/cutlass/gemm/device/gemm.h:350
cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::ThreadblockSwizzle ThreadblockSwizzle ThreadblockSwizzle
Definition: include/cutlass/gemm/device/gemm.h:236
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::Arguments CUTLASS_HOST_DEVICE Arguments()
Default ctor.
Definition: include/cutlass/gemm/device/gemm.h:614
static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem.
Definition: include/cutlass/gemm/device/gemm.h:328
CUTLASS_HOST_DEVICE Arguments()
Default ctor.
Definition: include/cutlass/gemm/device/gemm.h:290
CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1)
Constructs an Arguments structure.
Definition: include/cutlass/gemm/device/gemm.h:296
Definition: include/cutlass/gemm/gemm.h:94
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::get_workspace_size static size_t get_workspace_size(Arguments const &args)
Gets the workspace size.
Definition: include/cutlass/gemm/device/gemm.h:666
Definition: include/cutlass/gemm/device/gemm.h:216
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
static bool const kSplitKSerial
Definition: include/cutlass/gemm/device/gemm.h:242
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::update Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments.
Definition: include/cutlass/gemm/device/gemm.h:678
EpilogueOutputOp::Params epilogue
Definition: include/cutlass/gemm/device/gemm.h:281
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::epilogue EpilogueOutputOp::Params epilogue
Definition: include/cutlass/gemm/device/gemm.h:605
TensorRef< ElementA const, LayoutA > ref_A
Definition: include/cutlass/gemm/device/gemm.h:277
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Gemm Gemm()
Constructs the GEMM.
Definition: include/cutlass/gemm/device/gemm.h:644
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::can_implement static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem.
Definition: include/cutlass/gemm/device/gemm.h:660
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropriate threadblock-scoped epilogue.
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::operator() Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm.h:690
cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::InstructionShape InstructionShape InstructionShape
Definition: include/cutlass/gemm/device/gemm.h:234
cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::Operator Operator Operator
Definition: include/cutlass/gemm/device/gemm.h:237
Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments.
Definition: include/cutlass/gemm/device/gemm.h:417
int split_k_slices
Definition: include/cutlass/gemm/device/gemm.h:282
cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::ElementC ElementC ElementC
Definition: include/cutlass/gemm/device/gemm.h:225
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::run Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm.h:684
Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm.h:471
cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::OperatorClass OperatorClass OperatorClass
Definition: include/cutlass/gemm/device/gemm.h:230
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::UnderlyingArguments typename UnderlyingOperator::Arguments UnderlyingArguments
Definition: include/cutlass/gemm/device/gemm.h:589
cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::ElementAccumulator ElementAccumulator ElementAccumulator
Definition: include/cutlass/gemm/device/gemm.h:229
cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::GemmKernel typename kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, kStages, kSplitKSerial, Operator, kIsBetaZero >::GemmKernel GemmKernel
Define the kernel.
Definition: include/cutlass/gemm/device/gemm.h:267
static int const kAlignmentB
Definition: include/cutlass/gemm/device/gemm.h:240
Defines transposes of matrix layouts.
Definition: layout/matrix.h:921
static int const kAlignmentA
Definition: include/cutlass/gemm/device/gemm.h:239
cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::ThreadblockShape ThreadblockShape ThreadblockShape
Definition: include/cutlass/gemm/device/gemm.h:232
TensorRef< ElementC const, LayoutC > ref_C
Definition: include/cutlass/gemm/device/gemm.h:279
TensorRef< ElementB const, LayoutB > ref_B
Definition: include/cutlass/gemm/device/gemm.h:278
An error within CUTLASS occurred.
Template for generic CUTLASS kernel.
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::problem_size GemmCoord problem_size
Definition: include/cutlass/gemm/device/gemm.h:600
Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm.h:476
Gemm()
Constructs the GEMM.
Definition: include/cutlass/gemm/device/gemm.h:325
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::Arguments CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1)
Constructs an Arguments structure.
Definition: include/cutlass/gemm/device/gemm.h:618
LayoutC_ LayoutC
Definition: include/cutlass/gemm/device/gemm.h:226
cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::LayoutB typename layout::LayoutTranspose< LayoutA >::type LayoutB
Definition: include/cutlass/gemm/device/gemm.h:223
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::ref_C TensorRef< ElementC const, LayoutC > ref_C
Definition: include/cutlass/gemm/device/gemm.h:603
Argument structure.
Definition: include/cutlass/gemm/device/gemm.h:270
Definitions for GEMM structures.
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::initialize Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Initializes GEMM state from arguments.
Definition: include/cutlass/gemm/device/gemm.h:672
Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm.h:435
cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::EpilogueOutputOp EpilogueOutputOp EpilogueOutputOp
Definition: include/cutlass/gemm/device/gemm.h:235
Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Initializes GEMM state from arguments.
Definition: include/cutlass/gemm/device/gemm.h:369
The given workspace is null when it is required to be non-null.
cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::WarpShape WarpShape WarpShape
Definition: include/cutlass/gemm/device/gemm.h:233
Operation was successful.
static int const kAlignmentC
Definition: include/cutlass/gemm/device/gemm.h:241
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::to_underlying_arguments static UnderlyingArguments to_underlying_arguments(Arguments const &args)
Helper to construct a transposed equivalent for the underlying GEMM operator.
Definition: include/cutlass/gemm/device/gemm.h:647
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::ElementB ElementA ElementB
Definition: include/cutlass/gemm/device/gemm.h:222
Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems.
Defines tags for architecture-specific configurations.
TensorRef< ElementC, LayoutC > ref_D
Definition: include/cutlass/gemm/device/gemm.h:280
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::ref_D TensorRef< ElementC, LayoutC > ref_D
Definition: include/cutlass/gemm/device/gemm.h:604
static bool const kIsBetaZero
Definition: include/cutlass/gemm/device/gemm.h:243
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::GemmKernel typename UnderlyingOperator::GemmKernel GemmKernel
Definition: include/cutlass/gemm/device/gemm.h:590
cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::ArchTag ArchTag ArchTag
Definition: include/cutlass/gemm/device/gemm.h:231
Basic include for CUTLASS.
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::operator() Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm.h:695
Status
Status code returned by CUTLASS operations.
Definition: cutlass.h:39
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
cutlass::gemm::device::Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >::Arguments::ref_B TensorRef< ElementB const, LayoutB > ref_B
Definition: include/cutlass/gemm/device/gemm.h:602
cutlass::gemm::device::Gemm< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, SplitKSerial, Operator, kIsBetaZero >::LayoutA typename layout::LayoutTranspose< LayoutB >::type LayoutA
Definition: include/cutlass/gemm/device/gemm.h:220