72 typename ElementAccumulator_ = ElementC_,
74 typename OperatorClass_ = arch::OpClassSimt,
76 typename ArchTag_ = arch::Sm70,
78 typename ThreadblockShape_ =
typename DefaultGemmConfiguration<
79 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
80 ElementAccumulator_>::ThreadblockShape,
82 typename WarpShape_ =
typename DefaultGemmConfiguration<
83 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
84 ElementAccumulator_>::WarpShape,
86 typename InstructionShape_ =
typename DefaultGemmConfiguration<
87 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
88 ElementAccumulator_>::InstructionShape,
90 typename EpilogueOutputOp_ =
typename DefaultGemmConfiguration<
91 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
92 ElementAccumulator_>::EpilogueOutputOp,
96 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
98 ElementAccumulator_>::EpilogueOutputOp::kCount,
102 ElementAccumulator_,
typename EpilogueOutputOp_::ElementAccumulator,
103 EpilogueOutputOp_::kCount>,
105 typename ThreadblockSwizzle_ =
106 threadblock::GemmSplitKHorizontalThreadblockSwizzle,
109 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
110 ElementC_, ElementAccumulator_>::kStages,
113 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
114 ElementC_, ElementAccumulator_>::kAlignmentA,
117 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
118 ElementC_, ElementAccumulator_>::kAlignmentB,
120 typename Operator_ =
typename DefaultGemmConfiguration<
121 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
122 ElementAccumulator_>::Operator>
211 typename EpilogueOutputOp::Params epilogue_ =
212 typename EpilogueOutputOp::Params(),
213 int split_k_slices = 1,
214 typename ConvertScaledOp::Params convert_ =
215 typename ConvertScaledOp::Params(),
216 typename ReductionOp::Params reduction_ =
217 typename ReductionOp::Params()
219 problem_size(problem_size_),
225 split_k_slices(split_k_slices),
227 reduction(reduction_) { }
233 typename GemmKernel::Params gemm_params_;
255 ThreadblockSwizzle threadblock_swizzle;
259 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
260 args.split_k_slices);
262 return sizeof(ElementAccumulator_) *
size_t(args.problem_size.m()) *
size_t(args.problem_size.n()) * grid_shape.
k();
269 ThreadblockSwizzle threadblock_swizzle;
273 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
274 args.split_k_slices);
282 static_cast<ElementAccumulator_ *>(workspace),
283 args.problem_size.n());
285 int64_t partition_stride = int64_t(args.problem_size.m()) * int64_t(args.problem_size.n());
288 gemm_params_ =
typename GemmKernel::Params{
291 args.ref_A.non_const_ref(),
292 args.ref_B.non_const_ref(),
299 args.problem_size.mn(),
318 gemm_params_.ref_A.reset(args.ref_A.data());
319 gemm_params_.ref_B.reset(args.ref_B.data());
320 gemm_params_.ref_D.reset(workspace);
322 reduction_params_.ref_D.reset(args.ref_D.data());
323 reduction_params_.ref_C.reset(args.ref_C.data());
335 ThreadblockSwizzle threadblock_swizzle;
337 dim3 grid = threadblock_swizzle.get_grid_shape(gemm_params_.grid_tiled_shape);
338 dim3 block(GemmKernel::kThreadCount, 1, 1);
342 int smem_size = int(
sizeof(
typename GemmKernel::SharedStorage));
343 if (smem_size >= (48 << 10)) {
345 result = cudaFuncSetAttribute(
347 cudaFuncAttributeMaxDynamicSharedMemorySize,
350 if (result != cudaSuccess) {
354 result = cudaFuncSetAttribute(
356 cudaFuncAttributePreferredSharedMemoryCarveout, 100);
358 if (result != cudaSuccess) {
363 Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(gemm_params_);
365 result = cudaGetLastError();
367 if (result != cudaSuccess) {
378 Kernel<ReductionKernel><<< grid, block, 0, stream >>>(reduction_params_);
380 result = cudaGetLastError();
382 if (result != cudaSuccess) {
397 void *workspace =
nullptr,
398 cudaStream_t stream =
nullptr) {
403 status =
run(stream);
425 typename ElementAccumulator_,
427 typename OperatorClass_,
431 typename ThreadblockShape_,
435 typename InstructionShape_,
437 typename EpilogueOutputOp_,
439 typename ConvertScaledOp_,
441 typename ReductionOp_,
443 typename ThreadblockSwizzle_,
445 int Stages,
int kAlignmentA,
int kAlignmentB,
449 layout::ColumnMajor, ElementAccumulator_,
450 OperatorClass_, ArchTag_, ThreadblockShape_,
451 WarpShape_, InstructionShape_, EpilogueOutputOp_,
452 ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_,
453 Stages, kAlignmentA, kAlignmentB, Operator_> {
456 using ElementA = ElementA_;
457 using LayoutA = LayoutA_;
458 using ElementB = ElementB_;
459 using LayoutB = LayoutB_;
462 using ElementAccumulator = ElementAccumulator_;
463 using OperatorClass = OperatorClass_;
464 using ArchTag = ArchTag_;
465 using ThreadblockShape = ThreadblockShape_;
466 using WarpShape = WarpShape_;
467 using InstructionShape = InstructionShape_;
468 using ConvertScaledOp = ConvertScaledOp_;
469 using EpilogueOutputOp = EpilogueOutputOp_;
471 using ThreadblockSwizzle = ThreadblockSwizzle_;
473 static int const kStages = Stages;
535 typename EpilogueOutputOp::Params epilogue_ =
536 typename EpilogueOutputOp::Params(),
537 int split_k_slices = 1,
538 typename ConvertScaledOp::Params convert_ =
539 typename ConvertScaledOp::Params(),
540 typename ReductionOp::Params reduction_ =
541 typename ReductionOp::Params()
543 problem_size(problem_size_),
549 split_k_slices(split_k_slices),
551 reduction(reduction_) { }
567 {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
568 {args.ref_B.data(), args.ref_B.stride(0)},
569 {args.ref_A.data(), args.ref_A.stride(0)},
570 {args.ref_C.data(), args.ref_C.stride(0)},
571 {args.ref_D.data(), args.ref_D.stride(0)},
582 return UnderlyingOperator::can_implement(to_underlying_arguments(args));
588 return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
594 return underlying_operator_.
initialize(to_underlying_arguments(args), workspace);
600 return underlying_operator_.
update(to_underlying_arguments(args), workspace);
606 return underlying_operator_.
run(stream);
617 void *workspace =
nullptr,
618 cudaStream_t stream =
nullptr) {
623 status =
run(stream);
Definition: conversion_op.h:53
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::WarpShape WarpShape WarpShape
Definition: device/gemm_splitk_parallel.h:136
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::GemmKernel typename UnderlyingOperator::GemmKernel GemmKernel
Definition: device/gemm_splitk_parallel.h:499
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::ref_D TensorRef< ElementC, LayoutC > ref_D
Definition: device/gemm_splitk_parallel.h:513
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::LayoutB typename layout::LayoutTranspose< LayoutA >::type LayoutB
Definition: device/gemm_splitk_parallel.h:129
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::ref_C TensorRef< ElementC const, LayoutC > ref_C
Definition: device/gemm_splitk_parallel.h:512
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::Operator Operator Operator
Definition: device/gemm_splitk_parallel.h:142
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::ReductionKernel typename UnderlyingOperator::ReductionKernel ReductionKernel
Definition: device/gemm_splitk_parallel.h:500
Definition: aligned_buffer.h:35
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::problem_size GemmCoord problem_size
Definition: device/gemm_splitk_parallel.h:509
Definition: default_gemm_splitk_parallel.h:88
Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_splitk_parallel.h:395
static int const kStages
Definition: device/gemm_splitk_parallel.h:143
static CUTLASS_HOST_DEVICE dim3 block_shape()
Determines the threadblock shape.
Definition: reduce_split_k.h:138
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::ElementC ElementC_ ElementC
Definition: device/gemm_splitk_parallel.h:460
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::ref_A TensorRef< ElementA const, LayoutA > ref_A
Definition: device/gemm_splitk_parallel.h:510
Kernel performing a reduction over densely packed tensors in global memory.
Definition: include/cutlass/gemm/gemm.h:94
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::epilogue EpilogueOutputOp::Params epilogue
Definition: device/gemm_splitk_parallel.h:514
Functor performing conversion operations used by epilogues.
int split_k_slices
Definition: device/gemm_splitk_parallel.h:191
ReductionOp::Params reduction
Definition: device/gemm_splitk_parallel.h:193
Mixed-precision reduction.
Definition: reduction_operators.h:50
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::run Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_splitk_parallel.h:604
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::update Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments.
Definition: device/gemm_splitk_parallel.h:598
Params structure.
Definition: reduce_split_k.h:80
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::InstructionShape InstructionShape InstructionShape
Definition: device/gemm_splitk_parallel.h:137
CUTLASS_HOST_DEVICE Arguments()
Default ctor.
Definition: device/gemm_splitk_parallel.h:201
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
ConvertScaledOp::Params convert
Definition: device/gemm_splitk_parallel.h:192
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::GemmSplitKParallel GemmSplitKParallel()
Constructs the GEMM.
Definition: device/gemm_splitk_parallel.h:562
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ElementC ElementC ElementC
Definition: device/gemm_splitk_parallel.h:130
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ThreadblockShape ThreadblockShape ThreadblockShape
Definition: device/gemm_splitk_parallel.h:135
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::ref_B TensorRef< ElementB const, LayoutB > ref_B
Definition: device/gemm_splitk_parallel.h:511
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::EpilogueOutputOp EpilogueOutputOp EpilogueOutputOp
Definition: device/gemm_splitk_parallel.h:139
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::get_workspace_size static size_t get_workspace_size(Arguments const &args)
Gets the workspace size.
Definition: device/gemm_splitk_parallel.h:586
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ArchTag ArchTag ArchTag
Definition: device/gemm_splitk_parallel.h:134
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::operator() Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_splitk_parallel.h:615
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::operator() Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_splitk_parallel.h:610
static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem.
Definition: device/gemm_splitk_parallel.h:244
TensorRef< ElementC, LayoutC > ref_D
Definition: device/gemm_splitk_parallel.h:189
GemmSplitKParallel()
Constructs the GEMM.
Definition: device/gemm_splitk_parallel.h:241
Defines transposes of matrix layouts.
Definition: layout/matrix.h:921
GemmCoord problem_size
Definition: device/gemm_splitk_parallel.h:185
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::ReductionOp ReductionOp_ ReductionOp
Definition: device/gemm_splitk_parallel.h:470
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::GemmKernel typename kernel::DefaultGemmSplitKParallel< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, ConvertScaledOp, ThreadblockSwizzle, kStages, Operator >::GemmKernel GemmKernel
GEMM kernel.
Definition: device/gemm_splitk_parallel.h:165
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::reduction ReductionOp::Params reduction
Definition: device/gemm_splitk_parallel.h:517
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::LayoutA typename layout::LayoutTranspose< LayoutB >::type LayoutA
Definition: device/gemm_splitk_parallel.h:127
Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_splitk_parallel.h:390
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ReductionOp ReductionOp ReductionOp
Definition: device/gemm_splitk_parallel.h:140
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ElementB ElementA ElementB
Definition: device/gemm_splitk_parallel.h:128
An error within CUTLASS occurred.
Template for generic CUTLASS kernel.
Kernel performing a reduction over densely packed tensors in global memory.
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::UnderlyingArguments typename UnderlyingOperator::Arguments UnderlyingArguments
Definition: device/gemm_splitk_parallel.h:498
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Operator Operator_ Operator
Definition: device/gemm_splitk_parallel.h:472
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::convert ConvertScaledOp::Params convert
Definition: device/gemm_splitk_parallel.h:516
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
Definition: reduce_split_k.h:55
static CUTLASS_HOST_DEVICE dim3 grid_shape(cutlass::MatrixCoord problem_size)
Computes the grid size given a chosen threadblock shape.
Definition: reduce_split_k.h:128
Definitions for GEMM structures.
CUTLASS_HOST_DEVICE NonConstTensorRef non_const_ref() const
Definition: tensor_ref.h:229
static size_t get_workspace_size(Arguments const &args)
Gets the workspace size.
Definition: device/gemm_splitk_parallel.h:252
Definition: device/gemm_splitk_parallel.h:123
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
TensorRef< ElementC const, LayoutC > ref_C
Definition: device/gemm_splitk_parallel.h:188
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::initialize Status initialize(Arguments const &args, void *workspace)
Initializes GEMM state from arguments.
Definition: device/gemm_splitk_parallel.h:592
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ConvertScaledOp ConvertScaledOp ConvertScaledOp
Definition: device/gemm_splitk_parallel.h:138
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::can_implement static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem.
Definition: device/gemm_splitk_parallel.h:580
CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1, typename ConvertScaledOp::Params convert_=typename ConvertScaledOp::Params(), typename ReductionOp::Params reduction_=typename ReductionOp::Params())
Constructs an Arguments structure.
Definition: device/gemm_splitk_parallel.h:205
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ElementAccumulator ElementAccumulator ElementAccumulator
Definition: device/gemm_splitk_parallel.h:132
The given workspace is null when it is required to be non-null.
Operation was successful.
TensorRef< ElementB const, LayoutB > ref_B
Definition: device/gemm_splitk_parallel.h:187
Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems...
Defines tags for architecture-specific configurations.
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ThreadblockSwizzle ThreadblockSwizzle ThreadblockSwizzle
Definition: device/gemm_splitk_parallel.h:141
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::split_k_slices int split_k_slices
Definition: device/gemm_splitk_parallel.h:515
Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments.
Definition: device/gemm_splitk_parallel.h:312
Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_splitk_parallel.h:329
EpilogueOutputOp::Params epilogue
Definition: device/gemm_splitk_parallel.h:190
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropr...
Argument structure.
Definition: device/gemm_splitk_parallel.h:179
LayoutC_ LayoutC
Definition: device/gemm_splitk_parallel.h:131
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::Arguments CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1, typename ConvertScaledOp::Params convert_=typename ConvertScaledOp::Params(), typename ReductionOp::Params reduction_=typename ReductionOp::Params())
Constructs an Arguments structure.
Definition: device/gemm_splitk_parallel.h:529
Status initialize(Arguments const &args, void *workspace)
Initializes GEMM state from arguments.
Definition: device/gemm_splitk_parallel.h:266
TensorRef< ElementA const, LayoutA > ref_A
Definition: device/gemm_splitk_parallel.h:186
Basic include for CUTLASS.
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::Arguments::Arguments CUTLASS_HOST_DEVICE Arguments()
Default ctor.
Definition: device/gemm_splitk_parallel.h:525
cutlass::gemm::device::GemmSplitKParallel< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB, Operator_ >::to_underlying_arguments static UnderlyingArguments to_underlying_arguments(Arguments const &args)
Helper to construct a transposed equivalent for the underlying GEMM operator.
Definition: device/gemm_splitk_parallel.h:565
Status
Status code returned by CUTLASS operations.
Definition: cutlass.h:39
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::ElementA ElementB ElementA
Definition: device/gemm_splitk_parallel.h:126
cutlass::gemm::device::GemmSplitKParallel< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ConvertScaledOp, ReductionOp, ThreadblockSwizzle, Stages, kAlignmentA, kAlignmentB, Operator >::OperatorClass OperatorClass OperatorClass
Definition: device/gemm_splitk_parallel.h:133