47 #include <cuda_runtime.h> 161 template <
typename T> T
from_string(std::string
const &);
258 instruction_shape(instruction_shape), element_accumulator(element_accumulator), opcode_class(opcode_class) {}
289 int threadblock_stages = 0,
292 int minimum_compute_capability = 0,
293 int maximum_compute_capability = 0
295 threadblock_shape(threadblock_shape),
296 threadblock_stages(threadblock_stages),
297 warp_count(warp_count),
298 math_instruction(math_instruction),
299 minimum_compute_capability(minimum_compute_capability),
300 maximum_compute_capability(maximum_compute_capability) { }
319 char const * name =
"unknown",
323 name(name), kind(kind), tile_description(tile_description) { }
351 int log_extent_range = 24,
352 int log_stride_range = 24
356 alignment(alignment),
357 log_extent_range(log_extent_range),
358 log_stride_range(log_stride_range) { }
404 gemm_kind(gemm_kind),
408 element_epilogue(element_epilogue),
409 split_k_mode(split_k_mode),
410 transform_A(transform_A),
411 transform_B(transform_B) {}
425 virtual Status can_implement(
426 void const *configuration,
427 void const *arguments)
const = 0;
429 virtual uint64_t get_host_workspace_size(
430 void const *configuration)
const = 0;
432 virtual uint64_t get_device_workspace_size(
433 void const *configuration)
const = 0;
435 virtual Status initialize(
436 void const *configuration,
437 void *host_workspace,
438 void *device_workspace,
439 cudaStream_t stream =
nullptr)
const = 0;
442 void const *arguments,
443 void *host_workspace,
444 void *device_workspace =
nullptr,
445 cudaStream_t stream =
nullptr)
const = 0;
565 void const *
const *
A;
566 void const *
const *
B;
567 void const *
const *
C;
int64_t lda
Definition: library.h:609
int alignment
Alignment restriction on pointers, strides, and extents.
Definition: library.h:336
void const *const * A
Definition: library.h:565
virtual ~Operation()
Definition: library.h:421
High-level description of an operation.
Definition: library.h:304
char const * to_string(OperationKind type, bool pretty=false)
Converts an OperationKind enumerant to a string.
Definition: aligned_buffer.h:35
bool is_complex_type(NumericTypeID type)
Returns true if the numeric type is a complex data type or false if real-valued.
void *const * D
Definition: library.h:568
LayoutTypeID layout
Enumerant identifying the layout function for the tensor.
Definition: library.h:333
GemmKind gemm_kind
Indicates the kind of GEMM performed.
Definition: library.h:367
int64_t ldc
Definition: library.h:587
Arguments for GEMM.
Definition: library.h:477
int batch_count
Definition: library.h:560
ComplexTransform
Enumerated type describing a transformation on a complex value.
Definition: library.h:111
void const *const * C
Definition: library.h:567
gemm::GemmCoord problem_size
Definition: library.h:583
Configuration for batched GEMM in which multiple matrix products are computed.
Definition: library.h:551
bool is_signed_integer(NumericTypeID type)
Returns true if numeric type is a signed integer.
GemmKind
Enumeration indicating what kind of GEMM operation to perform.
Definition: library.h:149
NumericTypeID get_real_type(NumericTypeID type)
Returns the real-valued type underlying a type (only different from 'type' if complex) ...
OperationKind from_string< OperationKind >(std::string const &str)
Parses an OperationKind enumerant from a string.
Definition: include/cutlass/gemm/gemm.h:94
int get_layout_stride_rank(LayoutTypeID layout_id)
Returns the rank of a layout's stride based on the LayoutTypeID.
int64_t ldb
Leading dimension of B matrix.
Definition: library.h:517
int64_t const * ldc
Definition: library.h:557
int64_t batched_stride_B
Definition: library.h:620
Complex valued GEMM in which real and imaginary parts are separated by a stride.
Definition: library.h:581
int log_stride_range
log2() of the maximum value each relevant stride may have
Definition: library.h:342
Defines common types used for all GEMM-like operators.
ComplexTransform transform_A
Transformation on A operand.
Definition: library.h:385
int64_t imag_stride_B
Definition: library.h:615
GemmDescription(GemmKind gemm_kind=GemmKind::kGemm, TensorDescription const &A=TensorDescription(), TensorDescription const &B=TensorDescription(), TensorDescription const &C=TensorDescription(), NumericTypeID element_epilogue=NumericTypeID::kInvalid, SplitKMode split_k_mode=SplitKMode::kNone, ComplexTransform transform_A=ComplexTransform::kNone, ComplexTransform transform_B=ComplexTransform::kNone)
Definition: library.h:394
int sizeof_bits(NumericTypeID type)
Returns the size of a data type in bits.
Base class for all device-wide operations.
Definition: library.h:418
int64_t imag_stride_A
Definition: library.h:590
NumericTypeID from_string< NumericTypeID >(std::string const &str)
Parses a NumericType enumerant from a string.
LayoutTypeID
Layout type identifier.
Definition: library.h:63
OpcodeClassID
Indicates the classification of the math instruction.
Definition: library.h:139
int64_t ldc
Definition: library.h:611
std::string lexical_cast(int64_t int_value)
Lexical cast from int64_t to string.
ScalarPointerMode pointer_mode
Enumerant indicating whether alpha/beta point to host or device memory.
Definition: library.h:498
int64_t const * ldb
Definition: library.h:556
gemm::GemmCoord problem_size
Definition: library.h:553
int64_t batched_stride_C
Definition: library.h:621
OperationDescription(char const *name="unknown", OperationKind kind=OperationKind::kInvalid, TileDescription const &tile_description=TileDescription())
Definition: library.h:318
int maximum_compute_capability
Maximum compute capability (e.g. 70, 75) of a device eligible to run the operation.
Definition: library.h:281
Configuration for basic GEMM operations.
Definition: library.h:455
int64_t imag_stride_D
Definition: library.h:617
Definition: library.h:238
void const * B
Pointer to B matrix.
Definition: library.h:483
int64_t imag_stride_A
Definition: library.h:614
TensorDescription A
Describes the A operand.
Definition: library.h:370
Structure describing the tiled structure of a GEMM-like computation.
Definition: library.h:263
int split_k_slices
Number of partitions of K dimension.
Definition: library.h:473
int64_t imag_stride_B
Definition: library.h:591
OpcodeClassID from_string< OpcodeClassID >(std::string const &str)
Parses an OpcodeClassID enumerant from a string.
void const * A
Pointer to A matrix.
Definition: library.h:480
Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats...
int64_t ldd
Leading dimension of D matrix.
Definition: library.h:470
ComplexTransform transform_B
Transformation on B operand.
Definition: library.h:388
int64_t ldb
Definition: library.h:610
bool is_signed_type(NumericTypeID type)
Returns true if numeric type is signed.
NumericTypeID element_epilogue
Describes the data type of the scalars passed to the epilogue.
Definition: library.h:379
int64_t const * lda
Definition: library.h:555
int64_t const * ldd
Definition: library.h:558
int minimum_compute_capability
Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation.
Definition: library.h:278
int64_t ldd
Definition: library.h:588
int64_t batch_stride_C
Stride between instances of the C matrix in memory.
Definition: library.h:532
void const *const * B
Definition: library.h:566
NumericTypeID
Numeric data type.
Definition: library.h:77
int64_t lda
Definition: library.h:585
cutlass::gemm::GemmCoord warp_count
Number of warps in each logical dimension.
Definition: library.h:272
int64_t lda
Leading dimension of A matrix.
Definition: library.h:461
bool is_float_type(NumericTypeID type)
Returns true if numeric type is floating-point type.
TileDescription(cutlass::gemm::GemmCoord threadblock_shape=cutlass::gemm::GemmCoord(), int threadblock_stages=0, cutlass::gemm::GemmCoord warp_count=cutlass::gemm::GemmCoord(), MathInstructionDescription math_instruction=MathInstructionDescription(), int minimum_compute_capability=0, int maximum_compute_capability=0)
Definition: library.h:287
bool cast_from_double(std::vector< uint8_t > &bytes, NumericTypeID type, double src)
Casts from a real value represented as a double to the destination type. Returns true if successful...
NumericTypeID element_accumulator
Describes the data type of the internal accumulator.
Definition: library.h:244
Defines a canonical coordinate for rank=4 tensors offering named indices.
TensorDescription B
Describes the B operand.
Definition: library.h:373
void const * alpha
Definition: library.h:569
void const * beta
Host or device pointer to beta scalar.
Definition: library.h:495
int64_t ldd
Leading dimension of D matrix.
Definition: library.h:523
OpcodeClassID opcode_class
Classification of math instruction.
Definition: library.h:247
gemm::GemmCoord problem_size
GEMM problem size.
Definition: library.h:511
int64_t ldc
Leading dimension of C matrix.
Definition: library.h:467
void * D
Pointer to D matrix.
Definition: library.h:489
bool cast_from_uint64(std::vector< uint8_t > &bytes, NumericTypeID type, uint64_t src)
Casts from an unsigned int64 to the destination type. Returns true if successful. ...
TensorDescription C
Describes the source and destination matrices.
Definition: library.h:376
int64_t ldb
Definition: library.h:586
int64_t imag_stride_C
Definition: library.h:616
Batched complex valued GEMM in which real and imaginary parts are separated by a stride.
Definition: library.h:605
int64_t batched_stride_D
Definition: library.h:622
Configuration for batched GEMM in which multiple matrix products are computed.
Definition: library.h:508
int64_t batch_stride_A
Stride between instances of the A matrix in memory.
Definition: library.h:526
bool is_integer_type(NumericTypeID type)
Returns true if numeric type is integer.
ScalarPointerMode
Enumeration indicating whether scalars are in host or device memory.
Definition: library.h:123
NumericTypeID element
Numeric type of an individual element.
Definition: library.h:330
int batch_count
Number of GEMMs in batch.
Definition: library.h:538
void const * C
Pointer to C matrix.
Definition: library.h:486
T from_string(std::string const &)
Lexical cast from string.
int threadblock_stages
Describes the number of pipeline stages in the threadblock-scoped mainloop.
Definition: library.h:269
Defines a canonical coordinate for rank=2 matrices offering named indices.
int64_t imag_stride_D
Definition: library.h:593
LayoutTypeID from_string< LayoutTypeID >(std::string const &str)
Parses a LayoutTypeID enumerant from a string.
ScalarPointerMode pointer_mode
Definition: library.h:571
MathInstructionDescription(cutlass::gemm::GemmCoord instruction_shape=cutlass::gemm::GemmCoord(), NumericTypeID element_accumulator=NumericTypeID::kInvalid, OpcodeClassID opcode_class=OpcodeClassID::kInvalid)
Definition: library.h:253
int64_t batched_stride_A
Definition: library.h:619
Description of all GEMM computations.
Definition: library.h:364
int64_t lda
Leading dimension of A matrix.
Definition: library.h:514
gemm::GemmCoord problem_size
GEMM problem size.
Definition: library.h:458
SplitKMode split_k_mode
Describes the structure of parallel reductions.
Definition: library.h:382
bool cast_from_int64(std::vector< uint8_t > &bytes, NumericTypeID type, int64_t src)
Casts from a signed int64 to the destination type. Returns true if successful.
int log_extent_range
log2() of the maximum extent of each dimension
Definition: library.h:339
char const * name
Unique identifier describing the operation.
Definition: library.h:307
int64_t batch_stride_B
Stride between instances of the B matrix in memory.
Definition: library.h:529
cutlass::gemm::GemmCoord instruction_shape
Shape of the target math instruction.
Definition: library.h:241
void const * beta
Definition: library.h:570
TileDescription tile_description
Describes the tiled structure of a GEMM-like computation.
Definition: library.h:313
int64_t ldc
Leading dimension of C matrix.
Definition: library.h:520
Structure describing the properties of a tensor.
Definition: library.h:327
int64_t ldb
Leading dimension of B matrix.
Definition: library.h:464
gemm::GemmCoord problem_size
Definition: library.h:607
bool is_unsigned_integer(NumericTypeID type)
Returns true if numeric type is an unsigned integer.
Arguments for GEMM - used by all the GEMM operations.
Definition: library.h:564
OperationKind
Enumeration indicating the kind of operation.
Definition: library.h:117
void const * alpha
Host or device pointer to alpha scalar.
Definition: library.h:492
OperationKind kind
Kind of operation.
Definition: library.h:310
cutlass::gemm::GemmCoord threadblock_shape
Describes the shape of a threadblock (in elements)
Definition: library.h:266
int64_t ldd
Definition: library.h:612
int64_t imag_stride_C
Definition: library.h:592
MathInstructionDescription math_instruction
Core math instruction.
Definition: library.h:275
Basic include for CUTLASS.
SplitKMode
Describes how reductions are performed across threadblocks.
Definition: library.h:130
Status
Status code returned by CUTLASS operations.
Definition: cutlass.h:39
int64_t batch_stride_D
Stride between instances of the D matrix in memory.
Definition: library.h:535
TensorDescription(NumericTypeID element=NumericTypeID::kInvalid, LayoutTypeID layout=LayoutTypeID::kInvalid, int alignment=1, int log_extent_range=24, int log_stride_range=24)
Definition: library.h:347