39 #include "cutlass/gemm/kernel/default_gemm_complex.h" 173 typename ElementAccumulator_ = ElementC_,
175 typename OperatorClass_ = arch::OpClassSimt,
177 typename ArchTag_ = arch::Sm70,
179 typename ThreadblockShape_ =
typename DefaultGemmConfiguration<
180 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
181 ElementAccumulator_>::ThreadblockShape,
183 typename WarpShape_ =
typename DefaultGemmConfiguration<
184 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
185 ElementAccumulator_>::WarpShape,
187 typename InstructionShape_ =
typename DefaultGemmConfiguration<
188 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
189 ElementAccumulator_>::InstructionShape,
191 typename EpilogueOutputOp_ =
typename DefaultGemmConfiguration<
192 OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
193 ElementAccumulator_>::EpilogueOutputOp,
195 typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle,
198 DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
199 ElementC_, ElementAccumulator_>::kStages,
205 bool SplitKSerial =
false 234 using GemmKernel =
typename kernel::DefaultGemmComplex<
288 typename EpilogueOutputOp::Params epilogue_ =
289 typename EpilogueOutputOp::Params(),
290 int split_k_slices = 1
292 problem_size(problem_size_),
298 split_k_slices(split_k_slices) {
306 typename GemmKernel::Params params_;
316 if (!kSplitKSerial && args.split_k_slices > 1) {
326 if (kSplitKSerial && args.split_k_slices > 1) {
329 ThreadblockSwizzle threadblock_swizzle;
333 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
334 args.split_k_slices);
336 return sizeof(int) *
size_t(tiled_shape.
m()) *
size_t(tiled_shape.
n());
346 ThreadblockSwizzle threadblock_swizzle;
350 {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
351 args.split_k_slices);
354 if (args.split_k_slices > 1) {
361 cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
363 if (result != cudaSuccess) {
370 if (args.split_k_slices > 1) {
376 params_ =
typename GemmKernel::Params{
379 args.ref_A.non_const_ref(),
380 args.ref_B.non_const_ref(),
381 args.ref_C.non_const_ref(),
384 static_cast<int *
>(workspace)
393 if (kSplitKSerial && args.split_k_slices > 1) {
399 params_.ref_A.reset(args.ref_A.non_const_ref().data());
400 params_.ref_B.reset(args.ref_B.non_const_ref().data());
401 params_.ref_C.reset(args.ref_C.non_const_ref().data());
402 params_.ref_D.reset(args.ref_D.data());
403 params_.semaphore =
static_cast<int *
>(workspace);
411 ThreadblockSwizzle threadblock_swizzle;
413 dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
414 dim3 block(GemmKernel::kThreadCount, 1, 1);
418 int smem_size = int(
sizeof(
typename GemmKernel::SharedStorage));
419 if (smem_size >= (48 << 10)) {
420 result = cudaFuncSetAttribute(Kernel<GemmKernel>,
421 cudaFuncAttributeMaxDynamicSharedMemorySize,
424 if (result != cudaSuccess) {
428 result = cudaFuncSetAttribute(
430 cudaFuncAttributePreferredSharedMemoryCarveout, 100);
432 if (result != cudaSuccess) {
437 cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
439 result = cudaGetLastError();
452 void *workspace =
nullptr,
453 cudaStream_t stream =
nullptr) {
458 status =
run(stream);
480 typename ElementAccumulator_,
482 typename OperatorClass_,
486 typename ThreadblockShape_,
490 typename InstructionShape_,
492 typename EpilogueOutputOp_,
494 typename ThreadblockSwizzle_,
526 using ElementA = ElementA_;
527 using LayoutA = LayoutA_;
529 using ElementB = ElementB_;
530 using LayoutB = LayoutB_;
532 using ElementC = ElementC_;
536 using ElementAccumulator = ElementAccumulator_;
537 using OperatorClass = OperatorClass_;
538 using ArchTag = ArchTag_;
539 using ThreadblockShape = ThreadblockShape_;
540 using WarpShape = WarpShape_;
541 using InstructionShape = InstructionShape_;
542 using EpilogueOutputOp = EpilogueOutputOp_;
543 using ThreadblockSwizzle = ThreadblockSwizzle_;
544 static int const kStages = Stages;
545 static bool const kSplitKSerial = SplitKSerial;
602 typename EpilogueOutputOp::Params epilogue_ =
603 typename EpilogueOutputOp::Params(),
604 int split_k_slices = 1
606 problem_size(problem_size_),
612 split_k_slices(split_k_slices) { }
627 {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
628 {args.ref_B.data(), args.ref_B.stride(0)},
629 {args.ref_A.data(), args.ref_A.stride(0)},
630 {args.ref_C.data(), args.ref_C.stride(0)},
631 {args.ref_D.data(), args.ref_D.stride(0)},
640 return UnderlyingOperator::can_implement(to_underlying_arguments(args));
646 return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
652 return underlying_operator_.
initialize(to_underlying_arguments(args), workspace);
658 return underlying_operator_.
update(to_underlying_arguments(args), workspace);
664 return underlying_operator_.
run(stream);
675 void *workspace =
nullptr,
676 cudaStream_t stream =
nullptr) {
681 status =
run(stream);
Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm_complex.h:450
static ComplexTransform const kTransformA
Definition: include/cutlass/gemm/device/gemm_complex.h:229
TensorRef< ElementA const, LayoutA > ref_A
Definition: include/cutlass/gemm/device/gemm_complex.h:263
Definition: include/cutlass/gemm/device/gemm_complex.h:207
Definition: aligned_buffer.h:35
ComplexTransform
Enumerated type describing a transformation on a complex value.
Definition: complex.h:43
cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::ElementB ElementA ElementB
Definition: include/cutlass/gemm/device/gemm_complex.h:213
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::ref_C TensorRef< ElementC const, LayoutC > ref_C
Definition: include/cutlass/gemm/device/gemm_complex.h:581
Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm_complex.h:409
Specified problem size is not supported by operator.
GemmCoord problem_size
Definition: include/cutlass/gemm/device/gemm_complex.h:262
static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem.
Definition: include/cutlass/gemm/device/gemm_complex.h:314
cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::EpilogueOutputOp EpilogueOutputOp EpilogueOutputOp
Definition: include/cutlass/gemm/device/gemm_complex.h:226
static int const kStages
Definition: include/cutlass/gemm/device/gemm_complex.h:228
cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::ArchTag ArchTag ArchTag
Definition: include/cutlass/gemm/device/gemm_complex.h:222
TensorRef< ElementB const, LayoutB > ref_B
Definition: include/cutlass/gemm/device/gemm_complex.h:264
Definition: include/cutlass/gemm/gemm.h:94
EpilogueOutputOp::Params epilogue
Definition: include/cutlass/gemm/device/gemm_complex.h:267
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::operator() Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm_complex.h:673
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Initializes GEMM state from arguments.
Definition: include/cutlass/gemm/device/gemm_complex.h:343
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::split_k_slices int split_k_slices
Definition: include/cutlass/gemm/device/gemm_complex.h:584
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::to_underlying_arguments static UnderlyingArguments to_underlying_arguments(Arguments const &args)
Helper to construct a transposed equivalent for the underlying GEMM operator.
Definition: include/cutlass/gemm/device/gemm_complex.h:625
cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::LayoutA typename layout::LayoutTranspose< LayoutB >::type LayoutA
Definition: include/cutlass/gemm/device/gemm_complex.h:211
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
Argument structure.
Definition: include/cutlass/gemm/device/gemm_complex.h:256
static size_t get_workspace_size(Arguments const &args)
Gets the workspace size.
Definition: include/cutlass/gemm/device/gemm_complex.h:324
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::initialize Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Initializes GEMM state from arguments.
Definition: include/cutlass/gemm/device/gemm_complex.h:650
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::run Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm_complex.h:662
LayoutC_ LayoutC
Definition: include/cutlass/gemm/device/gemm_complex.h:217
CUTLASS_HOST_DEVICE Arguments()
Default ctor.
Definition: include/cutlass/gemm/device/gemm_complex.h:276
Defines transposes of matrix layouts.
Definition: layout/matrix.h:921
cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::ElementC ElementC ElementC
Definition: include/cutlass/gemm/device/gemm_complex.h:216
cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::ElementAccumulator ElementAccumulator ElementAccumulator
Definition: include/cutlass/gemm/device/gemm_complex.h:220
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::get_workspace_size static size_t get_workspace_size(Arguments const &args)
Gets the workspace size.
Definition: include/cutlass/gemm/device/gemm_complex.h:644
An error within CUTLASS occurred.
Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm_complex.h:445
Template for generic CUTLASS kernel.
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::GemmComplex GemmComplex()
Constructs the GEMM.
Definition: include/cutlass/gemm/device/gemm_complex.h:622
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::operator() Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm_complex.h:668
cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::ThreadblockSwizzle ThreadblockSwizzle ThreadblockSwizzle
Definition: include/cutlass/gemm/device/gemm_complex.h:227
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
GemmComplex()
Constructs the GEMM.
Definition: include/cutlass/gemm/device/gemm_complex.h:311
Definitions for GEMM structures.
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::GemmKernel typename UnderlyingOperator::GemmKernel GemmKernel
Definition: include/cutlass/gemm/device/gemm_complex.h:569
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::problem_size GemmCoord problem_size
Definition: include/cutlass/gemm/device/gemm_complex.h:578
cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::LayoutB typename layout::LayoutTranspose< LayoutA >::type LayoutB
Definition: include/cutlass/gemm/device/gemm_complex.h:214
TensorRef< ElementC, LayoutC > ref_D
Definition: include/cutlass/gemm/device/gemm_complex.h:266
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
TensorRef< ElementC const, LayoutC > ref_C
Definition: include/cutlass/gemm/device/gemm_complex.h:265
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::UnderlyingArguments typename UnderlyingOperator::Arguments UnderlyingArguments
Definition: include/cutlass/gemm/device/gemm_complex.h:568
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::ref_B TensorRef< ElementB const, LayoutB > ref_B
Definition: include/cutlass/gemm/device/gemm_complex.h:580
cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::WarpShape WarpShape WarpShape
Definition: include/cutlass/gemm/device/gemm_complex.h:224
Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments.
Definition: include/cutlass/gemm/device/gemm_complex.h:391
static bool const kSplitKSerial
Definition: include/cutlass/gemm/device/gemm_complex.h:231
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::update Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments.
Definition: include/cutlass/gemm/device/gemm_complex.h:656
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::Arguments CUTLASS_HOST_DEVICE Arguments()
Default ctor.
Definition: include/cutlass/gemm/device/gemm_complex.h:592
The given workspace is null when it is required to be non-null.
Operation was successful.
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::can_implement static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem.
Definition: include/cutlass/gemm/device/gemm_complex.h:638
Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems...
Defines tags for architecture-specific configurations.
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::ref_D TensorRef< ElementC, LayoutC > ref_D
Definition: include/cutlass/gemm/device/gemm_complex.h:582
int split_k_slices
Definition: include/cutlass/gemm/device/gemm_complex.h:268
cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::InstructionShape InstructionShape InstructionShape
Definition: include/cutlass/gemm/device/gemm_complex.h:225
cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::OperatorClass OperatorClass OperatorClass
Definition: include/cutlass/gemm/device/gemm_complex.h:221
CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1)
Constructs an Arguments structure.
Definition: include/cutlass/gemm/device/gemm_complex.h:282
cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::GemmKernel typename kernel::DefaultGemmComplex< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, kStages, kTransformA, kTransformB, kSplitKSerial >::GemmKernel GemmKernel
Define the kernel.
Definition: include/cutlass/gemm/device/gemm_complex.h:253
cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::ThreadblockShape ThreadblockShape ThreadblockShape
Definition: include/cutlass/gemm/device/gemm_complex.h:223
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::epilogue EpilogueOutputOp::Params epilogue
Definition: include/cutlass/gemm/device/gemm_complex.h:583
cutlass::gemm::device::GemmComplex< ElementB, typename layout::LayoutTranspose< LayoutB >::type, ElementA, typename layout::LayoutTranspose< LayoutA >::type, ElementC, layout::RowMajor, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, TransformA, TransformB, SplitKSerial >::ElementA ElementB ElementA
Definition: include/cutlass/gemm/device/gemm_complex.h:210
static ComplexTransform const kTransformB
Definition: include/cutlass/gemm/device/gemm_complex.h:230
Basic include for CUTLASS.
Status
Status code returned by CUTLASS operations.
Definition: cutlass.h:39
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::Arguments CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1)
Constructs an Arguments structure.
Definition: include/cutlass/gemm/device/gemm_complex.h:596
cutlass::gemm::device::GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >::Arguments::ref_A TensorRef< ElementA const, LayoutA > ref_A
Definition: include/cutlass/gemm/device/gemm_complex.h:579