173 typename ElementAccumulator_ = ElementC_,
175 typename OperatorClass_ = arch::OpClassSimt,
177 typename ArchTag_ = arch::Sm70,
179 typename ThreadblockShape_ =
typename DefaultGemmConfiguration<
180 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
181 ElementAccumulator_>::ThreadblockShape,
183 typename WarpShape_ =
typename DefaultGemmConfiguration<
184 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
185 ElementAccumulator_>::WarpShape,
187 typename InstructionShape_ =
typename DefaultGemmConfiguration<
188 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
189 ElementAccumulator_>::InstructionShape,
191 typename EpilogueOutputOp_ =
typename DefaultGemmConfiguration<
192 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
193 ElementAccumulator_>::EpilogueOutputOp,
195 typename ThreadblockSwizzle_ = threadblock::GemmBatchedIdentityThreadblockSwizzle,
198 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
199 ElementC_, ElementAccumulator_>::kStages,
202 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
203 ElementC_, ElementAccumulator_>::kAlignmentA,
206 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
207 ElementC_, ElementAccumulator_>::kAlignmentB,
209 typename Operator_ =
typename DefaultGemmConfiguration<
210 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
211 ElementAccumulator_>::Operator
305 typename EpilogueOutputOp::Params epilogue_,
308 problem_size(problem_size_),
318 batch_count(batch_count_) { }
334 if (!
TensorRef_aligned(args.ref_A, kAlignmentA) || (args.stride_A % kAlignmentA)) {
338 if (!
TensorRef_aligned(args.ref_B, kAlignmentB) || (args.stride_B % kAlignmentB)) {
342 if (!
TensorRef_aligned(args.ref_C, kAlignmentC) || (args.stride_C % kAlignmentC)) {
346 if (!
TensorRef_aligned(args.ref_D, kAlignmentC) || (args.stride_D % kAlignmentC)) {
369 ThreadblockSwizzle threadblock_swizzle;
374 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK});
380 args.ref_A.non_const_ref(),
382 args.ref_B.non_const_ref(),
384 args.ref_C.non_const_ref(),
398 params_.
ref_A.reset(args.ref_A.non_const_ref().data());
399 params_.
ref_B.reset(args.ref_B.non_const_ref().data());
400 params_.
ref_C.reset(args.ref_C.non_const_ref().data());
401 params_.
ref_D.reset(args.ref_D.data());
409 ThreadblockSwizzle threadblock_swizzle;
417 if (smem_size >= (48 << 10)) {
418 result = cudaFuncSetAttribute(Kernel<GemmKernel>,
419 cudaFuncAttributeMaxDynamicSharedMemorySize,
422 if (result != cudaSuccess) {
426 result = cudaFuncSetAttribute(
428 cudaFuncAttributePreferredSharedMemoryCarveout, 100);
430 if (result != cudaSuccess) {
435 cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
437 result = cudaGetLastError();
450 void *workspace =
nullptr,
451 cudaStream_t stream =
nullptr) {
456 status =
run(stream);
478 typename ElementAccumulator_,
480 typename OperatorClass_,
484 typename ThreadblockShape_,
488 typename InstructionShape_,
490 typename EpilogueOutputOp_,
492 typename ThreadblockSwizzle_,
546 static bool const kSplitKSerial =
false;
611 typename EpilogueOutputOp::Params epilogue_,
614 problem_size(problem_size_),
624 batch_count(batch_count_) { }
639 {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
640 {args.ref_B.data(), args.ref_B.stride(0)},
642 {args.ref_A.data(), args.ref_A.stride(0)},
644 {args.ref_C.data(), args.ref_C.stride(0)},
646 {args.ref_D.data(), args.ref_D.stride(0)},
656 return UnderlyingOperator::can_implement(to_underlying_arguments(args));
662 return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
668 return underlying_operator_.
initialize(to_underlying_arguments(args), workspace);
674 return underlying_operator_.
update(to_underlying_arguments(args), workspace);
680 return underlying_operator_.
run(stream);
691 void *workspace =
nullptr,
692 cudaStream_t stream =
nullptr) {
697 status =
run(stream);
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ElementC ElementC_ ElementC
Definition: device/gemm_batched.h:529
Definition: default_gemm.h:116
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ThreadblockSwizzle ThreadblockSwizzle_ ThreadblockSwizzle
Definition: device/gemm_batched.h:540
static int const kAlignmentB
Definition: device/gemm_batched.h:236
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::ref_A TensorRef< ElementA const, LayoutA > ref_A
Definition: device/gemm_batched.h:580
TensorRef< ElementC, LayoutC > ref_D
Definition: device/gemm_batched.h:280
GemmCoord problem_size
Definition: device/gemm_batched.h:273
Definition: aligned_buffer.h:35
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ElementB ElementB_ ElementB
Definition: device/gemm_batched.h:526
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::operator() Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_batched.h:689
int64_t stride_D
Definition: device/gemm_batched.h:281
Epilogue::OutputTileIterator::TensorRef ref_D
Definition: kernel/gemm_batched.h:74
CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, int64_t stride_A_, TensorRef< ElementB const, LayoutB > ref_B_, int64_t stride_B_, TensorRef< ElementC const, LayoutC > ref_C_, int64_t stride_C_, TensorRef< ElementC, LayoutC > ref_D_, int64_t stride_D_, typename EpilogueOutputOp::Params epilogue_, int batch_count_)
Constructs an Arguments structure.
Definition: device/gemm_batched.h:295
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::LayoutB LayoutB_ LayoutB
Definition: device/gemm_batched.h:527
cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::DefaultGemmKernel typename kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, kStages, false, Operator, false >::GemmKernel DefaultGemmKernel
Define the kernel.
Definition: device/gemm_batched.h:262
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::can_implement static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem.
Definition: device/gemm_batched.h:654
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ThreadblockShape ThreadblockShape_ ThreadblockShape
Definition: device/gemm_batched.h:536
Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_batched.h:443
Definition: include/cutlass/gemm/gemm.h:94
Argument structure.
Definition: device/gemm_batched.h:267
cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::Operator typename DefaultGemmConfiguration< OperatorClass, ArchTag, ElementB, ElementA, ElementC,ElementAccumulator >::Operator Operator
Definition: device/gemm_batched.h:238
Mma::IteratorB::TensorRef ref_B
Definition: kernel/gemm_batched.h:68
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::stride_C int64_t stride_C
Definition: device/gemm_batched.h:585
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::OperatorClass OperatorClass_ OperatorClass
Definition: device/gemm_batched.h:534
cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::LayoutB typename layout::LayoutTranspose< LayoutA >::type LayoutB
Definition: device/gemm_batched.h:220
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ArchTag ArchTag_ ArchTag
Definition: device/gemm_batched.h:535
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::GemmKernel typename UnderlyingOperator::GemmKernel GemmKernel
Definition: device/gemm_batched.h:570
static size_t get_workspace_size(Arguments const &args)
Gets the workspace size.
Definition: device/gemm_batched.h:361
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::stride_D int64_t stride_D
Definition: device/gemm_batched.h:587
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::WarpShape WarpShape_ WarpShape
Definition: device/gemm_batched.h:537
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::epilogue EpilogueOutputOp::Params epilogue
Definition: device/gemm_batched.h:588
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::update Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments.
Definition: device/gemm_batched.h:672
int64_t stride_A
Definition: device/gemm_batched.h:275
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::InstructionShape InstructionShape_ InstructionShape
Definition: device/gemm_batched.h:538
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::initialize Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Initializes GEMM state from arguments.
Definition: device/gemm_batched.h:666
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::run Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_batched.h:678
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ElementAccumulator ElementAccumulator_ ElementAccumulator
Definition: device/gemm_batched.h:533
Epilogue::OutputTileIterator::TensorRef ref_C
Definition: kernel/gemm_batched.h:71
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::OperatorClass OperatorClass OperatorClass
Definition: device/gemm_batched.h:227
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::stride_B int64_t stride_B
Definition: device/gemm_batched.h:583
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropr...
static int const kAlignmentC
Definition: device/gemm_batched.h:237
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::ref_D TensorRef< ElementC, LayoutC > ref_D
Definition: device/gemm_batched.h:586
cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::ThreadblockShape ThreadblockShape ThreadblockShape
Definition: device/gemm_batched.h:229
Shared memory storage structure.
Definition: kernel/gemm_batched.h:124
CUTLASS_HOST_DEVICE Arguments()
Default ctor.
Definition: device/gemm_batched.h:291
cutlass::gemm::GemmCoord grid_tiled_shape
Definition: kernel/gemm_batched.h:63
cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::InstructionShape InstructionShape InstructionShape
Definition: device/gemm_batched.h:231
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::GemmBatched GemmBatched()
Constructs the GEMM.
Definition: device/gemm_batched.h:634
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::problem_size GemmCoord problem_size
Definition: device/gemm_batched.h:579
static int const kStages
Definition: device/gemm_batched.h:234
Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments.
Definition: device/gemm_batched.h:396
static int const kThreadCount
Definition: kernel/gemm_batched.h:58
static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem.
Definition: device/gemm_batched.h:332
int64_t stride_C
Definition: device/gemm_batched.h:279
Parameters structure.
Definition: kernel/gemm_batched.h:61
Defines transposes of matrix layouts.
Definition: layout/matrix.h:921
operands fail alignment requirements.
cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::ElementC ElementC ElementC
Definition: device/gemm_batched.h:222
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::UnderlyingArguments typename UnderlyingOperator::Arguments UnderlyingArguments
Definition: device/gemm_batched.h:569
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::to_underlying_arguments static UnderlyingArguments to_underlying_arguments(Arguments const &args)
Helper to construct a transposed equivalent for the underlying GEMM operator.
Definition: device/gemm_batched.h:637
An error within CUTLASS occurred.
TensorRef< ElementB const, LayoutB > ref_B
Definition: device/gemm_batched.h:276
static int const kAlignmentA
Definition: device/gemm_batched.h:235
Template for generic CUTLASS kernel.
Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Initializes GEMM state from arguments.
Definition: device/gemm_batched.h:366
cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::ThreadblockSwizzle ThreadblockSwizzle ThreadblockSwizzle
Definition: device/gemm_batched.h:233
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::EpilogueOutputOp EpilogueOutputOp_ EpilogueOutputOp
Definition: device/gemm_batched.h:539
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
GemmBatched()
Constructs the GEMM.
Definition: device/gemm_batched.h:329
Top-level include for all CUTLASS numeric types.
cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::EpilogueOutputOp EpilogueOutputOp EpilogueOutputOp
Definition: device/gemm_batched.h:232
int batch_count
Definition: device/gemm_batched.h:283
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::operator() Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_batched.h:684
Definitions for GEMM structures.
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::ElementA ElementA_ ElementA
Definition: device/gemm_batched.h:523
cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::LayoutA typename layout::LayoutTranspose< LayoutB >::type LayoutA
Definition: device/gemm_batched.h:217
cutlass::gemm::GemmCoord problem_size
Definition: kernel/gemm_batched.h:62
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::ElementB ElementA ElementB
Definition: device/gemm_batched.h:219
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::get_workspace_size static size_t get_workspace_size(Arguments const &args)
Gets the workspace size.
Definition: device/gemm_batched.h:660
cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::ElementA ElementB ElementA
Definition: device/gemm_batched.h:216
TensorRef< ElementC const, LayoutC > ref_C
Definition: device/gemm_batched.h:278
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::Arguments CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, int64_t stride_A_, TensorRef< ElementB const, LayoutB > ref_B_, int64_t stride_B_, TensorRef< ElementC const, LayoutC > ref_C_, int64_t stride_C_, TensorRef< ElementC, LayoutC > ref_D_, int64_t stride_D_, typename EpilogueOutputOp::Params epilogue_, int batch_count_)
Constructs an Arguments structure.
Definition: device/gemm_batched.h:601
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::Arguments CUTLASS_HOST_DEVICE Arguments()
Default ctor.
Definition: device/gemm_batched.h:597
bool TensorRef_aligned(TensorRef< Element, Layout > const &ref, int alignment)
Definition: tensor_ref.h:382
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::ref_B TensorRef< ElementB const, LayoutB > ref_B
Definition: device/gemm_batched.h:582
Operation was successful.
Mma::IteratorA::TensorRef ref_A
Definition: kernel/gemm_batched.h:65
Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_batched.h:407
Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems...
Defines tags for architecture-specific configurations.
Definition: kernel/gemm_batched.h:49
cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::ElementAccumulator ElementAccumulator ElementAccumulator
Definition: device/gemm_batched.h:226
int64_t stride_B
Definition: device/gemm_batched.h:277
cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::WarpShape WarpShape WarpShape
Definition: device/gemm_batched.h:230
cutlass::gemm::device::GemmBatched< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA >::ArchTag ArchTag ArchTag
Definition: device/gemm_batched.h:228
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::LayoutA LayoutA_ LayoutA
Definition: device/gemm_batched.h:524
Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: device/gemm_batched.h:448
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::ref_C TensorRef< ElementC const, LayoutC > ref_C
Definition: device/gemm_batched.h:584
EpilogueOutputOp::Params epilogue
Definition: device/gemm_batched.h:282
Basic include for CUTLASS.
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::batch_count int batch_count
Definition: device/gemm_batched.h:589
kernel::GemmBatched< typename DefaultGemmKernel::Mma, typename DefaultGemmKernel::Epilogue, ThreadblockSwizzle > GemmKernel
Definition: device/gemm_batched.h:264
Definition: device/gemm_batched.h:213
Status
Status code returned by CUTLASS operations.
Definition: cutlass.h:39
LayoutC_ LayoutC
Definition: device/gemm_batched.h:223
TensorRef< ElementA const, LayoutA > ref_A
Definition: device/gemm_batched.h:274
cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments::stride_A int64_t stride_A
Definition: device/gemm_batched.h:581