CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
library.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
40 #pragma once
41 
43 
44 #include <vector>
45 #include <string>
46 #include <cstdint>
47 #include <cuda_runtime.h>
48 
49 #include "cutlass/cutlass.h"
50 #include "cutlass/matrix_coord.h"
51 #include "cutlass/tensor_coord.h"
52 #include "cutlass/layout/tensor.h"
53 
54 #include "cutlass/gemm/gemm.h"
56 
57 namespace cutlass {
58 namespace library {
59 
61 
/// Layout type identifier.
// NOTE(review): the listing's line numbers jump (64 -> 66 -> 73), so some enumerators may have
// been dropped when this page was extracted -- confirm against the original header.
63 enum class LayoutTypeID {
64  kUnknown,
66  kRowMajor,
73  kInvalid
74 };
75 
/// Numeric data type.
// NOTE(review): the enumerator names look like <kind><bit width> with a 'C' prefix for complex
// variants (e.g. kS8, kCF32) -- presumably B/U/S/F mean bit/unsigned/signed/float; confirm
// against sizeof_bits() and is_complex_type().
77 enum class NumericTypeID {
78  kUnknown,
79  kVoid,
80  kB1,
81  kU4,
82  kU8,
83  kU16,
84  kU32,
85  kU64,
86  kS4,
87  kS8,
88  kS16,
89  kS32,
90  kS64,
91  kF16,
92  kF32,
93  kF64,
94  kCF16,
95  kCF32,
96  kCF64,
97  kCS4,
98  kCS8,
99  kCS16,
100  kCS32,
101  kCS64,
102  kCU4,
103  kCU8,
104  kCU16,
105  kCU32,
106  kCU64,
107  kInvalid
108 };
109 
/// Enumerated type describing a transformation on a complex value.
111 enum class ComplexTransform {
112  kNone,
113  kConjugate
114 };
115 
/// Enumeration indicating the kind of operation.
117 enum class OperationKind {
118  kGemm,
119  kInvalid
120 };
121 
/// Enumeration indicating whether scalars are in host or device memory.
123 enum class ScalarPointerMode {
124  kHost,
125  kDevice,
126  kInvalid
127 };
128 
/// Describes how reductions are performed across threadblocks.
// NOTE(review): listing line 134 is absent, so an enumerator may have been dropped when this
// page was extracted -- confirm against the original header.
130 enum class SplitKMode {
131  kNone,
132  kSerial,
133  kParallel,
135  kInvalid
136 };
137 
/// Indicates the classification of the math instruction.
// NOTE(review): listing line 142 is absent, so an enumerator may have been dropped when this
// page was extracted -- confirm against the original header.
139 enum class OpcodeClassID {
140  kSimt,
141  kTensorOp,
143  kInvalid
144 };
145 
147 
/// Enumeration indicating what kind of GEMM operation to perform.
// NOTE(review): listing lines 153-154 are absent, so enumerators may have been dropped when
// this page was extracted -- confirm against the original header.
149 enum class GemmKind {
150  kGemm,
151  kBatched,
152  kArray,
155  kInvalid
156 };
157 
159 
/// Lexical cast from string. Primary template; only declared here, with explicit
/// specializations declared below for individual enumerated types.
161 template <typename T> T from_string(std::string const &);
162 
/// Converts an OperationKind enumerant to a string.
// NOTE(review): the effect of `pretty` is not visible here -- presumably selects a
// human-readable spelling; the same applies to every to_string() overload below. Confirm.
164 char const *to_string(OperationKind type, bool pretty = false);
165 
/// Parses an OperationKind enumerant from a string.
167 template <> OperationKind from_string<OperationKind>(std::string const &str);
168 
/// Converts a NumericType enumerant to a string.
170 char const *to_string(NumericTypeID type, bool pretty = false);
171 
/// Parses a NumericType enumerant from a string.
173 template <> NumericTypeID from_string<NumericTypeID>(std::string const &str);
174 
/// Returns the size of a data type in bits.
176 int sizeof_bits(NumericTypeID type);
177 
/// Returns true if the numeric type is a complex data type or false if real-valued.
179 bool is_complex_type(NumericTypeID type);
180 
183 
/// Returns true if numeric type is integer.
185 bool is_integer_type(NumericTypeID type);
186 
/// Returns true if numeric type is signed.
188 bool is_signed_type(NumericTypeID type);
189 
192 
195 
/// Returns true if numeric type is floating-point type.
197 bool is_float_type(NumericTypeID type);
198 
/// Converts a Status code to a string.
200 char const *to_string(Status status, bool pretty = false);
201 
/// Converts a LayoutTypeID enumerant to a string.
203 char const *to_string(LayoutTypeID layout, bool pretty = false);
204 
/// Parses a LayoutType enumerant from a string.
206 template <> LayoutTypeID from_string<LayoutTypeID>(std::string const &str);
207 
/// Returns the rank of a layout's stride based on the LayoutTypeID.
209 int get_layout_stride_rank(LayoutTypeID layout_id);
210 
/// Converts an OpcodeClassID enumerant to a string.
212 char const *to_string(OpcodeClassID type, bool pretty = false);
213 
/// Parses an OpcodeClassID enumerant from a string.
215 template <>
216 OpcodeClassID from_string<OpcodeClassID>(std::string const &str);
217 
/// Lexical cast from int64_t to string.
219 std::string lexical_cast(int64_t int_value);
220 
// NOTE(review): presumably parses `str` as a value of `type` into `bytes` and returns true on
// success -- the contract is not visible here; confirm against the implementation.
222 bool lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type, std::string const &str);
223 
// NOTE(review): presumably formats the value of `type` held in `bytes` as a string -- confirm
// against the implementation.
225 std::string lexical_cast(std::vector<uint8_t> &bytes, NumericTypeID type);
226 
/// Casts from a signed int64 to the destination type. Returns true if successful.
228 bool cast_from_int64(std::vector<uint8_t> &bytes, NumericTypeID type, int64_t src);
229 
/// Casts from an unsigned int64 to the destination type. Returns true if successful.
231 bool cast_from_uint64(std::vector<uint8_t> &bytes, NumericTypeID type, uint64_t src);
232 
/// Casts from a real value represented as a double to the destination type. Returns true if
/// successful.
234 bool cast_from_double(std::vector<uint8_t> &bytes, NumericTypeID type, double src);
235 
237 
239 
242 
245 
248 
249  //
250  // Methods
251  //
252 
255  NumericTypeID element_accumulator = NumericTypeID::kInvalid,
257  ):
258  instruction_shape(instruction_shape), element_accumulator(element_accumulator), opcode_class(opcode_class) {}
259 
260 };
261 
264 
267 
270 
273 
276 
279 
282 
283  //
284  // Methods
285  //
286 
289  int threadblock_stages = 0,
292  int minimum_compute_capability = 0,
293  int maximum_compute_capability = 0
294  ):
295  threadblock_shape(threadblock_shape),
296  threadblock_stages(threadblock_stages),
297  warp_count(warp_count),
298  math_instruction(math_instruction),
299  minimum_compute_capability(minimum_compute_capability),
300  maximum_compute_capability(maximum_compute_capability) { }
301 };
302 
305 
307  char const * name;
308 
311 
314 
315  //
316  // Methods
317  //
319  char const * name = "unknown",
321  TileDescription const & tile_description = TileDescription()
322  ):
323  name(name), kind(kind), tile_description(tile_description) { }
324 };
325 
328 
331 
334 
337 
340 
343 
344  //
345  // Methods
346  //
350  int alignment = 1,
351  int log_extent_range = 24,
352  int log_stride_range = 24
353  ):
354  element(element),
355  layout(layout),
356  alignment(alignment),
357  log_extent_range(log_extent_range),
358  log_stride_range(log_stride_range) { }
359 };
360 
362 
365 
368 
371 
374 
377 
380 
383 
386 
389 
390  //
391  // Methods
392  //
393 
395  GemmKind gemm_kind = GemmKind::kGemm,
399  NumericTypeID element_epilogue = NumericTypeID::kInvalid,
400  SplitKMode split_k_mode = SplitKMode::kNone,
403  ):
404  gemm_kind(gemm_kind),
405  A(A),
406  B(B),
407  C(C),
408  element_epilogue(element_epilogue),
409  split_k_mode(split_k_mode),
410  transform_A(transform_A),
411  transform_B(transform_B) {}
412 };
413 
416 
/// Base class for all device-wide operations.
///
/// Abstract interface: every member function other than the destructor is pure virtual, and
/// configuration/argument/workspace objects are passed as untyped (void) pointers.
// NOTE(review): the concrete struct each void pointer must refer to is not stated here --
// presumably the *Configuration / *Arguments structs declared later in this header; confirm.
418 class Operation {
419 public:
420 
  /// Virtual destructor (empty body).
421  virtual ~Operation() { }
422 
  /// Returns a const reference to the OperationDescription of this operation.
423  virtual OperationDescription const & description() const = 0;
424 
  /// Returns a Status for the given configuration and arguments.
  // NOTE(review): presumably reports whether this operation supports them -- confirm.
425  virtual Status can_implement(
426  void const *configuration,
427  void const *arguments) const = 0;
428 
  /// Returns the host workspace size for the given configuration.
  // NOTE(review): unit is presumably bytes -- confirm.
429  virtual uint64_t get_host_workspace_size(
430  void const *configuration) const = 0;
431 
  /// Returns the device workspace size for the given configuration.
  // NOTE(review): unit is presumably bytes -- confirm.
432  virtual uint64_t get_device_workspace_size(
433  void const *configuration) const = 0;
434 
  /// Takes a configuration plus host and device workspace pointers and returns a Status.
  /// `stream` defaults to nullptr.
435  virtual Status initialize(
436  void const *configuration,
437  void *host_workspace,
438  void *device_workspace,
439  cudaStream_t stream = nullptr) const = 0;
440 
  /// Takes arguments plus host and device workspace pointers and returns a Status.
  /// `device_workspace` and `stream` both default to nullptr.
441  virtual Status run(
442  void const *arguments,
443  void *host_workspace,
444  void *device_workspace = nullptr,
445  cudaStream_t stream = nullptr) const = 0;
446 };
447 
449 
451 //
452 // OperationKind: Gemm
453 // GemmKind: Gemm
454 //
456 
459 
461  int64_t lda;
462 
464  int64_t ldb;
465 
467  int64_t ldc;
468 
470  int64_t ldd;
471 
474 };
475 
478 
480  void const *A;
481 
483  void const *B;
484 
486  void const *C;
487 
489  void *D;
490 
492  void const *alpha;
493 
495  void const *beta;
496 
499 };
500 
502 
504 //
505 // OperationKind: Gemm
506 // GemmKind: Batched
507 
509 
512 
514  int64_t lda;
515 
517  int64_t ldb;
518 
520  int64_t ldc;
521 
523  int64_t ldd;
524 
526  int64_t batch_stride_A;
527 
529  int64_t batch_stride_B;
530 
532  int64_t batch_stride_C;
533 
535  int64_t batch_stride_D;
536 
539 };
540 
543 
545 
547 //
548 // OperationKind: Gemm
549 // GemmKind: Array
550 
552 
554 
555  int64_t const *lda;
556  int64_t const *ldb;
557  int64_t const *ldc;
558  int64_t const *ldd;
559 
561 };
562 
565  void const * const *A;
566  void const * const *B;
567  void const * const *C;
568  void * const *D;
569  void const *alpha;
570  void const *beta;
572 };
573 
575 
577 //
578 // OperationKind: Gemm
579 // GemmKind: Planar complex
580 
582 
584 
585  int64_t lda;
586  int64_t ldb;
587  int64_t ldc;
588  int64_t ldd;
589 
590  int64_t imag_stride_A;
591  int64_t imag_stride_B;
592  int64_t imag_stride_C;
593  int64_t imag_stride_D;
594 };
595 
597 
599 
601 //
602 // OperationKind: Gemm
603 // GemmKind: Planar complex batched
604 //
606 
608 
609  int64_t lda;
610  int64_t ldb;
611  int64_t ldc;
612  int64_t ldd;
613 
614  int64_t imag_stride_A;
615  int64_t imag_stride_B;
616  int64_t imag_stride_C;
617  int64_t imag_stride_D;
618 
623 };
624 
626 
628 
629 } // namespace library
630 } // namespace cutlass
631 
int alignment
Alignment restriction on pointers, strides, and extents.
Definition: library.h:336
void const *const * A
Definition: library.h:565
virtual ~Operation()
Definition: library.h:421
High-level description of an operation.
Definition: library.h:304
char const * to_string(OperationKind type, bool pretty=false)
Converts an OperationKind enumerant to a string.
Definition: aligned_buffer.h:35
bool is_complex_type(NumericTypeID type)
Returns true if the numeric type is a complex data type or false if real-valued.
void *const * D
Definition: library.h:568
LayoutTypeID layout
Enumerant identifying the layout function for the tensor.
Definition: library.h:333
GemmKind gemm_kind
Indicates the kind of GEMM performed.
Definition: library.h:367
int64_t ldc
Definition: library.h:587
Arguments for GEMM.
Definition: library.h:477
int batch_count
Definition: library.h:560
ComplexTransform
Enumerated type describing a transformation on a complex value.
Definition: library.h:111
void const *const * C
Definition: library.h:567
gemm::GemmCoord problem_size
Definition: library.h:583
Configuration for batched GEMM in which multiple matrix products are computed.
Definition: library.h:551
bool is_signed_integer(NumericTypeID type)
Returns true if numeric type is a signed integer.
GemmKind
Enumeration indicating what kind of GEMM operation to perform.
Definition: library.h:149
NumericTypeID get_real_type(NumericTypeID type)
Returns the real-valued type underlying a type (only different from 'type' if complex) ...
OperationKind from_string< OperationKind >(std::string const &str)
Parses an OperationKind enumerant from a string.
Definition: include/cutlass/gemm/gemm.h:94
int get_layout_stride_rank(LayoutTypeID layout_id)
Returns the rank of a layout's stride based on the LayoutTypeID.
int64_t ldb
Leading dimension of B matrix.
Definition: library.h:517
int64_t const * ldc
Definition: library.h:557
int64_t batched_stride_B
Definition: library.h:620
Complex valued GEMM in which real and imaginary parts are separated by a stride.
Definition: library.h:581
int log_stride_range
log2() of the maximum value each relevant stride may have
Definition: library.h:342
Defines common types used for all GEMM-like operators.
ComplexTransform transform_A
Transformation on A operand.
Definition: library.h:385
GemmDescription(GemmKind gemm_kind=GemmKind::kGemm, TensorDescription const &A=TensorDescription(), TensorDescription const &B=TensorDescription(), TensorDescription const &C=TensorDescription(), NumericTypeID element_epilogue=NumericTypeID::kInvalid, SplitKMode split_k_mode=SplitKMode::kNone, ComplexTransform transform_A=ComplexTransform::kNone, ComplexTransform transform_B=ComplexTransform::kNone)
Definition: library.h:394
int sizeof_bits(NumericTypeID type)
Returns the size of a data type in bits.
Base class for all device-wide operations.
Definition: library.h:418
int64_t imag_stride_A
Definition: library.h:590
NumericTypeID from_string< NumericTypeID >(std::string const &str)
Parses a NumericType enumerant from a string.
LayoutTypeID
Layout type identifier.
Definition: library.h:63
OpcodeClassID
Indicates the classification of the math instruction.
Definition: library.h:139
std::string lexical_cast(int64_t int_value)
Lexical cast from int64_t to string.
ScalarPointerMode pointer_mode
Enumerant indicating whether alpha/beta point to host or device memory.
Definition: library.h:498
int64_t const * ldb
Definition: library.h:556
gemm::GemmCoord problem_size
Definition: library.h:553
int64_t batched_stride_C
Definition: library.h:621
OperationDescription(char const *name="unknown", OperationKind kind=OperationKind::kInvalid, TileDescription const &tile_description=TileDescription())
Definition: library.h:318
int maximum_compute_capability
Maximum compute capability (e.g. 70, 75) of a device eligible to run the operation.
Definition: library.h:281
Configuration for basic GEMM operations.
Definition: library.h:455
void const * B
Pointer to B matrix.
Definition: library.h:483
TensorDescription A
Describes the A operand.
Definition: library.h:370
Structure describing the tiled structure of a GEMM-like computation.
Definition: library.h:263
int split_k_slices
Number of partitions of K dimension.
Definition: library.h:473
int64_t imag_stride_B
Definition: library.h:591
OpcodeClassID from_string< OpcodeClassID >(std::string const &str)
Parses an OpcodeClassID enumerant from a string.
void const * A
Pointer to A matrix.
Definition: library.h:480
Defines layout functions used by TensorRef and derived classes for common 4-D and 5-D tensor formats...
int64_t ldd
Leading dimension of D matrix.
Definition: library.h:470
ComplexTransform transform_B
Transformation on B operand.
Definition: library.h:388
bool is_signed_type(NumericTypeID type)
Returns true if numeric type is signed.
NumericTypeID element_epilogue
Describes the data type of the scalars passed to the epilogue.
Definition: library.h:379
int64_t const * lda
Definition: library.h:555
int64_t const * ldd
Definition: library.h:558
int minimum_compute_capability
Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation.
Definition: library.h:278
int64_t ldd
Definition: library.h:588
int64_t batch_stride_C
Stride between instances of the C matrix in memory.
Definition: library.h:532
void const *const * B
Definition: library.h:566
NumericTypeID
Numeric data type.
Definition: library.h:77
int64_t lda
Definition: library.h:585
cutlass::gemm::GemmCoord warp_count
Number of warps in each logical dimension.
Definition: library.h:272
int64_t lda
Leading dimension of A matrix.
Definition: library.h:461
bool is_float_type(NumericTypeID type)
Returns true if numeric type is floating-point type.
TileDescription(cutlass::gemm::GemmCoord threadblock_shape=cutlass::gemm::GemmCoord(), int threadblock_stages=0, cutlass::gemm::GemmCoord warp_count=cutlass::gemm::GemmCoord(), MathInstructionDescription math_instruction=MathInstructionDescription(), int minimum_compute_capability=0, int maximum_compute_capability=0)
Definition: library.h:287
bool cast_from_double(std::vector< uint8_t > &bytes, NumericTypeID type, double src)
Casts from a real value represented as a double to the destination type. Returns true if successful...
NumericTypeID element_accumulator
Describes the data type of the internal accumulator.
Definition: library.h:244
Defines a canonical coordinate for rank=4 tensors offering named indices.
TensorDescription B
Describes the B operand.
Definition: library.h:373
void const * alpha
Definition: library.h:569
void const * beta
Host or device pointer to beta scalar.
Definition: library.h:495
int64_t ldd
Leading dimension of D matrix.
Definition: library.h:523
OpcodeClassID opcode_class
Classification of math instruction.
Definition: library.h:247
gemm::GemmCoord problem_size
GEMM problem size.
Definition: library.h:511
int64_t ldc
Leading dimension of C matrix.
Definition: library.h:467
void * D
Pointer to D matrix.
Definition: library.h:489
bool cast_from_uint64(std::vector< uint8_t > &bytes, NumericTypeID type, uint64_t src)
Casts from an unsigned int64 to the destination type. Returns true if successful. ...
TensorDescription C
Describes the source and destination matrices.
Definition: library.h:376
int64_t ldb
Definition: library.h:586
Batched complex valued GEMM in which real and imaginary parts are separated by a stride.
Definition: library.h:605
int64_t batched_stride_D
Definition: library.h:622
Configuration for batched GEMM in which multiple matrix products are computed.
Definition: library.h:508
int64_t batch_stride_A
Stride between instances of the A matrix in memory.
Definition: library.h:526
bool is_integer_type(NumericTypeID type)
Returns true if numeric type is integer.
ScalarPointerMode
Enumeration indicating whether scalars are in host or device memory.
Definition: library.h:123
NumericTypeID element
Numeric type of an individual element.
Definition: library.h:330
int batch_count
Number of GEMMs in batch.
Definition: library.h:538
void const * C
Pointer to C matrix.
Definition: library.h:486
T from_string(std::string const &)
Lexical cast from string.
int threadblock_stages
Describes the number of pipeline stages in the threadblock-scoped mainloop.
Definition: library.h:269
Defines a canonical coordinate for rank=2 matrices offering named indices.
int64_t imag_stride_D
Definition: library.h:593
LayoutTypeID from_string< LayoutTypeID >(std::string const &str)
Parses a LayoutType enumerant from a string.
ScalarPointerMode pointer_mode
Definition: library.h:571
MathInstructionDescription(cutlass::gemm::GemmCoord instruction_shape=cutlass::gemm::GemmCoord(), NumericTypeID element_accumulator=NumericTypeID::kInvalid, OpcodeClassID opcode_class=OpcodeClassID::kInvalid)
Definition: library.h:253
int64_t batched_stride_A
Definition: library.h:619
Description of all GEMM computations.
Definition: library.h:364
int64_t lda
Leading dimension of A matrix.
Definition: library.h:514
gemm::GemmCoord problem_size
GEMM problem size.
Definition: library.h:458
SplitKMode split_k_mode
Describes the structure of parallel reductions.
Definition: library.h:382
bool cast_from_int64(std::vector< uint8_t > &bytes, NumericTypeID type, int64_t src)
Casts from a signed int64 to the destination type. Returns true if successful.
int log_extent_range
log2() of the maximum extent of each dimension
Definition: library.h:339
char const * name
Unique identifier describing the operation.
Definition: library.h:307
int64_t batch_stride_B
Stride between instances of the B matrix in memory.
Definition: library.h:529
cutlass::gemm::GemmCoord instruction_shape
Shape of the target math instruction.
Definition: library.h:241
void const * beta
Definition: library.h:570
TileDescription tile_description
Describes the tiled structure of a GEMM-like computation.
Definition: library.h:313
int64_t ldc
Leading dimension of C matrix.
Definition: library.h:520
Structure describing the properties of a tensor.
Definition: library.h:327
int64_t ldb
Leading dimension of B matrix.
Definition: library.h:464
gemm::GemmCoord problem_size
Definition: library.h:607
bool is_unsigned_integer(NumericTypeID type)
Returns true if numeric type is an unsigned integer.
Arguments for GEMM - used by all the GEMM operations.
Definition: library.h:564
OperationKind
Enumeration indicating the kind of operation.
Definition: library.h:117
void const * alpha
Host or device pointer to alpha scalar.
Definition: library.h:492
OperationKind kind
Kind of operation.
Definition: library.h:310
cutlass::gemm::GemmCoord threadblock_shape
Describes the shape of a threadblock (in elements)
Definition: library.h:266
int64_t imag_stride_C
Definition: library.h:592
MathInstructionDescription math_instruction
Core math instruction.
Definition: library.h:275
Basic include for CUTLASS.
SplitKMode
Describes how reductions are performed across threadblocks.
Definition: library.h:130
Status
Status code returned by CUTLASS operations.
Definition: cutlass.h:39
int64_t batch_stride_D
Stride between instances of the D matrix in memory.
Definition: library.h:535
TensorDescription(NumericTypeID element=NumericTypeID::kInvalid, LayoutTypeID layout=LayoutTypeID::kInvalid, int alignment=1, int log_extent_range=24, int log_stride_range=24)
Definition: library.h:347