CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
default_gemm.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 
36 #pragma once
37 
38 #include "cutlass/cutlass.h"
39 
40 #include "cutlass/layout/matrix.h"
41 #include "cutlass/numeric_types.h"
42 #include "cutlass/arch/wmma.h"
43 
46 
47 #include "cutlass/gemm/gemm.h"
55 
60 
61 #if defined(CUTLASS_ARCH_WMMA_ENABLED)
63 #endif //CUTLASS_ARCH_WMMA_ENABLED
64 
65 
67 
68 namespace cutlass {
69 namespace gemm {
70 namespace kernel {
71 
73 
/// Primary template for DefaultGemm. It is only declared here (no definition), so an
/// argument combination that matches none of the partial specializations below is an
/// incomplete type. The visible specializations cover:
///   - tensor ops on arch::Sm75 (row-major C, and interleaved int8-style layouts),
///   - tensor ops on arch::Sm70 with InstructionShape GemmShape<8, 8, 4>,
///   - SIMT with GemmShape<1, 1, 1> and int8 SIMT with GemmShape<1, 1, 4>,
///   - WMMA tensor ops (only when CUTLASS_ARCH_WMMA_ENABLED is defined).
/// Each visible specialization defines member types `Mma` (from
/// threadblock::DefaultMma<...>::ThreadblockMma) and `Epilogue`.
74 template <
  /// A operand element type; forwarded to threadblock::DefaultMma by the specializations.
76  typename ElementA_,
  /// A operand layout; forwarded to threadblock::DefaultMma.
78  typename LayoutA_,
  /// A operand alignment; forwarded to threadblock::DefaultMma.
80  int kAlignmentA,
  /// B operand element type; forwarded to threadblock::DefaultMma.
82  typename ElementB_,
  /// B operand layout; forwarded to threadblock::DefaultMma.
84  typename LayoutB_,
  /// B operand alignment; forwarded to threadblock::DefaultMma.
86  int kAlignmentB,
  /// C operand element type (the interleaved specialization derives the epilogue's
  /// elements-per-access from sizeof_bits of this type).
88  typename ElementC_,
  /// C operand layout. Specializations match layout::RowMajor,
  /// layout::ColumnMajorInterleaved<K>, or any layout (int8 SIMT, WMMA).
90  typename LayoutC_,
  /// Accumulator element type; forwarded to threadblock::DefaultMma.
92  typename ElementAccumulator,
  /// Operator class tag: arch::OpClassTensorOp, arch::OpClassSimt, or
  /// arch::OpClassWmmaTensorOp in the visible specializations.
94  typename OperatorClass,
  /// Target architecture tag (arch::Sm75 and arch::Sm70 are matched explicitly; the SIMT
  /// and WMMA specializations accept any tag).
96  typename ArchTag,
  /// Threadblock tile shape (its kK member is read by several specializations).
98  typename ThreadblockShape,
  /// Warp tile shape (its kK member is read by several specializations).
100  typename WarpShape,
  /// Instruction shape; some specializations pin it (GemmShape<8, 8, 4>,
  /// GemmShape<1, 1, 1>, GemmShape<1, 1, 4>).
102  typename InstructionShape,
  /// Epilogue output functor; several specializations read EpilogueOutputOp::kCount.
104  typename EpilogueOutputOp,
  /// Threadblock swizzle. Not referenced in the visible lines of any specialization;
  /// presumably consumed by the kernel-level type (not shown in this listing).
106  typename ThreadblockSwizzle,
  /// Number of mainloop stages. Every specialization except WMMA matches only 2;
  /// WMMA forwards it to threadblock::DefaultMma.
108  int Stages,
  /// Serial split-K flag. Not referenced in the visible lines; presumably consumed by the
  /// kernel-level type (not shown in this listing) — TODO confirm.
111  bool SplitKSerial,
  /// Operation tag forwarded to threadblock::DefaultMma.
113  typename Operator,
  /// Defaults to false. Only the interleaved Sm75 specialization accepts true (and forwards
  /// it to its epilogue); the others match only the default / explicit false.
115  bool IsBetaZero = false>
116 struct DefaultGemm;
117 
/// Partial specialization: Turing (arch::Sm75) tensor-core GEMM with a row-major C
/// operand and a two-stage mainloop (Stages == 2). The 19th argument (IsBetaZero) is not
/// listed, so this specialization only matches the primary template's default, false.
120 template <
122  typename ElementA,
124  typename LayoutA,
126  int kAlignmentA,
128  typename ElementB,
130  typename LayoutB,
132  int kAlignmentB,
134  typename ElementC,
136  typename ElementAccumulator,
138  typename ThreadblockShape,
140  typename WarpShape,
142  typename InstructionShape,
144  typename EpilogueOutputOp,
146  typename ThreadblockSwizzle,
148  bool SplitKSerial,
150  typename Operator
151 >
152 struct DefaultGemm<
153  ElementA, LayoutA, kAlignmentA,
154  ElementB, LayoutB, kAlignmentB,
155  ElementC, layout::RowMajor,
156  ElementAccumulator,
157  arch::OpClassTensorOp,
158  arch::Sm75,
159  ThreadblockShape,
160  WarpShape,
161  InstructionShape,
162  EpilogueOutputOp,
163  ThreadblockSwizzle,
164  2,
165  SplitKSerial,
166  Operator
167 > {
168
  /// Threadblock-scoped matrix multiply-accumulate.
  /// NOTE(review): listing lines 169-170 (the opening of this `Mma` alias) and line 178
  /// are missing from this rendering. Per the Doxygen tooltip, the full type is
  /// `typename cutlass::gemm::threadblock::DefaultMma<ElementA, LayoutA, kAlignmentA,
  /// ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor,
  /// arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape,
  /// 2, Operator>::ThreadblockMma`.
171  ElementA,
172  LayoutA,
173  kAlignmentA,
174  ElementB,
175  LayoutB,
176  kAlignmentB,
177  ElementAccumulator,
179  arch::OpClassTensorOp,
180  arch::Sm75,
181  ThreadblockShape,
182  WarpShape,
183  InstructionShape,
184  2,
185  Operator
186  >::ThreadblockMma;
187
  /// How many warp tiles split the threadblock tile along K; passed to the epilogue.
  /// Integer division: assumes ThreadblockShape::kK is a multiple of WarpShape::kK
  /// (not checked here).
188  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
189
  /// Epilogue; elements per access are taken from EpilogueOutputOp::kCount.
  /// NOTE(review): listing lines 190-191 and 197 are missing. Per the tooltip, the full type
  /// is `typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<ThreadblockShape,
  /// typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
  /// EpilogueOutputOp::kCount>::Epilogue`.
192  ThreadblockShape,
193  typename Mma::Operator,
194  kPartitionsK,
195  EpilogueOutputOp,
196  EpilogueOutputOp::kCount
198
  // NOTE(review): listing lines 199-200 are not shown in this rendering.
201 };
202 
/// Partial specialization: Sm75 tensor-core GEMM on interleaved layouts. It matches:
///   - A in layout::ColumnMajorInterleaved<InterleavedK>,
///   - B in layout::RowMajorInterleaved<InterleavedK>,
///   - C in layout::ColumnMajorInterleaved<InterleavedK>,
///   - an int32_t accumulator and two stages.
/// Unlike the other specializations, it accepts either value of IsBetaZero and forwards it
/// to the epilogue.
205 template <
207  typename ElementA,
209  int kAlignmentA,
211  typename ElementB,
213  int kAlignmentB,
215  typename ElementC,
217  typename ThreadblockShape,
219  typename WarpShape,
221  typename InstructionShape,
223  typename EpilogueOutputOp,
225  typename ThreadblockSwizzle,
  /// Interleaving factor shared by the A, B and C layouts.
227  int InterleavedK,
230  bool SplitKSerial,
232  typename Operator,
234  bool IsBetaZero>
235 struct DefaultGemm<ElementA, layout::ColumnMajorInterleaved<InterleavedK>,
236  kAlignmentA, ElementB,
237  layout::RowMajorInterleaved<InterleavedK>, kAlignmentB,
238  ElementC, layout::ColumnMajorInterleaved<InterleavedK>,
239  int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape,
240  WarpShape, InstructionShape, EpilogueOutputOp,
241  ThreadblockSwizzle, 2, SplitKSerial, Operator, IsBetaZero> {
  // NOTE(review): `LayoutA`, `LayoutB` and `LayoutC` are used below but are not declared in
  // the visible lines (listing lines 242-244 are missing). Presumably they alias the
  // interleaved layouts matched above — confirm against the real header.
245
246  using ElementAccumulator = int32_t;
247
  /// Threadblock-scoped matrix multiply-accumulate.
  /// NOTE(review): listing lines 248-249 (the opening of this `Mma` alias) are missing. Per
  /// the tooltip, it is `typename cutlass::gemm::threadblock::DefaultMma<..., 2, Operator,
  /// true>::ThreadblockMma`. The meaning of the trailing `true` is defined in
  /// default_mma.h, not here.
250  ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC,
251  arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape,
252  InstructionShape, 2, Operator, true>::ThreadblockMma;
253
  /// How many warp tiles split the threadblock tile along K (integer division; assumes
  /// divisibility, not checked here).
254  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
255
  /// Epilogue. Elements per access = 64 / sizeof_bits<ElementC>, i.e. each access covers
  /// 64 bits of C. InterleavedK and IsBetaZero are forwarded.
  /// NOTE(review): listing lines 256-258 are missing. Per the tooltip, this is
  /// `typename cutlass::epilogue::threadblock::DefaultInterleavedEpilogueTensorOp<...>::Epilogue`.
259  ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
260  64 / sizeof_bits<ElementC>::value, InterleavedK,
261  IsBetaZero>::Epilogue;
262
  // NOTE(review): listing lines 263-264 are not shown in this rendering.
265 };
266 
268 
269 
/// Partial specialization: Volta (arch::Sm70) tensor-core GEMM with a row-major C operand,
/// InstructionShape fixed to GemmShape<8, 8, 4>, and two stages. IsBetaZero is not listed,
/// so only the default (false) matches.
271 template <
273  typename ElementA,
275  typename LayoutA,
277  int kAlignmentA,
279  typename ElementB,
281  typename LayoutB,
283  int kAlignmentB,
285  typename ElementC,
287  typename ElementAccumulator,
289  typename ThreadblockShape,
291  typename WarpShape,
293  typename EpilogueOutputOp,
295  typename ThreadblockSwizzle,
297  bool SplitKSerial,
299  typename Operator
300 >
301 struct DefaultGemm<
302  ElementA, LayoutA, kAlignmentA,
303  ElementB, LayoutB, kAlignmentB,
304  ElementC, layout::RowMajor,
305  ElementAccumulator,
306  arch::OpClassTensorOp,
307  arch::Sm70,
308  ThreadblockShape,
309  WarpShape,
310  GemmShape<8, 8, 4>,
311  EpilogueOutputOp,
312  ThreadblockSwizzle,
313  2,
314  SplitKSerial,
315  Operator
316 > {
317
  /// Threadblock-scoped matrix multiply-accumulate.
  /// NOTE(review): listing lines 318-319, 327 and 332 are missing. Per the tooltip, the full
  /// type is `typename cutlass::gemm::threadblock::DefaultMma<ElementA, LayoutA,
  /// kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor,
  /// arch::OpClassTensorOp, arch::Sm70, ThreadblockShape, WarpShape, GemmShape<8, 8, 4>,
  /// 2, Operator>::ThreadblockMma`.
320  ElementA,
321  LayoutA,
322  kAlignmentA,
323  ElementB,
324  LayoutB,
325  kAlignmentB,
326  ElementAccumulator,
328  arch::OpClassTensorOp,
329  arch::Sm70,
330  ThreadblockShape,
331  WarpShape,
333  2,
334  Operator
335  >::ThreadblockMma;
336
  /// How many warp tiles split the threadblock tile along K (integer division; assumes
  /// divisibility, not checked here).
337  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
338
  /// Epilogue; elements per access are taken from EpilogueOutputOp::kCount.
  /// NOTE(review): listing lines 339-340 and 346 are missing. Per the tooltip, this is
  /// `typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp<ThreadblockShape,
  /// typename Mma::Operator, kPartitionsK, EpilogueOutputOp,
  /// EpilogueOutputOp::kCount>::Epilogue`.
341  ThreadblockShape,
342  typename Mma::Operator,
343  kPartitionsK,
344  EpilogueOutputOp,
345  EpilogueOutputOp::kCount
347
  // NOTE(review): listing lines 348-349 are not shown in this rendering.
350 };
351 
353 
/// Partial specialization: SIMT GEMM with InstructionShape GemmShape<1, 1, 1>, a row-major
/// C operand and two stages. It matches any ArchTag, but the Mma below is always
/// instantiated with arch::Sm50. IsBetaZero is not listed, so only the default (false)
/// matches.
355 template <
357  typename ElementA,
359  typename LayoutA,
361  int kAlignmentA,
363  typename ElementB,
365  typename LayoutB,
367  int kAlignmentB,
369  typename ElementC,
371  typename ElementAccumulator,
373  typename ArchTag,
375  typename ThreadblockShape,
377  typename WarpShape,
379  typename EpilogueOutputOp,
381  typename ThreadblockSwizzle,
383  bool SplitKSerial,
385  typename Operator
386  >
387 struct DefaultGemm<
388  ElementA,
389  LayoutA,
390  kAlignmentA,
391  ElementB,
392  LayoutB,
393  kAlignmentB,
394  ElementC,
395  layout::RowMajor,
396  ElementAccumulator,
397  arch::OpClassSimt,
398  ArchTag,
399  ThreadblockShape,
400  WarpShape,
401  GemmShape<1, 1, 1>,
402  EpilogueOutputOp,
403  ThreadblockSwizzle,
404  2,
405  SplitKSerial,
406  Operator> {
  /// Threadblock-scoped matrix multiply-accumulate (always built with arch::Sm50,
  /// regardless of ArchTag).
  /// NOTE(review): listing lines 407-408, 416 and 421 are missing. Per the tooltip, the full
  /// type is `typename cutlass::gemm::threadblock::DefaultMma<ElementA, LayoutA,
  /// kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor,
  /// arch::OpClassSimt, arch::Sm50, ThreadblockShape, WarpShape, GemmShape<1, 1, 1>, 2,
  /// Operator>::ThreadblockMma`.
409  ElementA,
410  LayoutA,
411  kAlignmentA,
412  ElementB,
413  LayoutB,
414  kAlignmentB,
415  ElementAccumulator,
417  arch::OpClassSimt,
418  arch::Sm50,
419  ThreadblockShape,
420  WarpShape,
422  2,
423  Operator>::ThreadblockMma;
424
  /// The SIMT epilogue accesses one element at a time; this is enforced at compile time
  /// on EpilogueOutputOp::kCount.
425  static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount;
426  static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars");
427
  /// Epilogue (no kPartitionsK argument, unlike the tensor-op specializations).
  /// NOTE(review): listing lines 428-429 and 434 are missing. Per the tooltip, this is
  /// `typename cutlass::epilogue::threadblock::DefaultEpilogueSimt<ThreadblockShape,
  /// typename Mma::Operator, EpilogueOutputOp, kEpilogueElementsPerAccess>::Epilogue`.
430  ThreadblockShape,
431  typename Mma::Operator,
432  EpilogueOutputOp,
433  kEpilogueElementsPerAccess
435
  // NOTE(review): listing lines 436-437 are not shown in this rendering.
438 };
439 
441 
444 
/// Partial specialization: SIMT GEMM with int8_t A and B operands and InstructionShape
/// GemmShape<1, 1, 4>. It accepts any C layout and any ArchTag (the Mma is built with
/// arch::Sm50) and matches only Stages == 2 and IsBetaZero == false.
445 template <
447  typename LayoutA,
449  int kAlignmentA,
451  typename LayoutB,
453  int kAlignmentB,
455  typename LayoutC,
457  typename ElementC,
459  typename ArchTag,
461  typename ElementAccumulator,
463  typename ThreadblockShape,
465  typename WarpShape,
467  typename EpilogueOutputOp,
469  typename ThreadblockSwizzle,
472  bool SplitKSerial,
474  typename Operator>
475 struct DefaultGemm<int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB,
476  ElementC, LayoutC, ElementAccumulator, arch::OpClassSimt,
477  ArchTag, ThreadblockShape, WarpShape, GemmShape<1, 1, 4>,
478  EpilogueOutputOp, ThreadblockSwizzle, 2, SplitKSerial,
479  Operator, false> {
  // Local aliases restating the values fixed by the specialization pattern above.
480  using InstructionShape = GemmShape<1, 1, 4>;
481  using ElementA = int8_t;
482  using ElementB = int8_t;
483
  // NOTE(review): OperatorClass is not referenced in the visible lines; the Mma below
  // spells arch::OpClassSimt directly.
484  using OperatorClass = arch::OpClassSimt;
  /// Threadblock-scoped matrix multiply-accumulate.
  /// NOTE(review): listing lines 485-486 are missing. Per the tooltip, the full type is
  /// `typename cutlass::gemm::threadblock::DefaultMma<ElementA, LayoutA, kAlignmentA,
  /// ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, arch::OpClassSimt,
  /// arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, 2, Operator,
  /// false>::ThreadblockMma`. The leading `ElementA` argument is on a missing line.
487  LayoutA,
488  kAlignmentA,
489  ElementB,
490  LayoutB,
491  kAlignmentB,
492  ElementAccumulator,
493  LayoutC,
494  arch::OpClassSimt,
495  arch::Sm50,
496  ThreadblockShape,
497  WarpShape,
498  InstructionShape,
499  2,
500  Operator,
501  false
502  >::ThreadblockMma;
503
  /// The SIMT epilogue accesses one element at a time; this is enforced at compile time.
504  static int const kEpilogueElementsPerAccess = EpilogueOutputOp::kCount;
505  static_assert(kEpilogueElementsPerAccess == 1, "simt epilogue must operate on scalars");
506
  /// Epilogue.
  /// NOTE(review): listing lines 507-508 and 513 are missing. Per the tooltip, this is
  /// `typename cutlass::epilogue::threadblock::DefaultEpilogueSimt<ThreadblockShape,
  /// typename Mma::Operator, EpilogueOutputOp, kEpilogueElementsPerAccess>::Epilogue`.
509  ThreadblockShape,
510  typename Mma::Operator,
511  EpilogueOutputOp,
512  kEpilogueElementsPerAccess
514
  // NOTE(review): listing lines 515-516 are not shown in this rendering.
517 };
518 
519 
520 #if defined(CUTLASS_ARCH_WMMA_ENABLED)
/// Partial specialization: WMMA tensor-op GEMM (compiled only when
/// CUTLASS_ARCH_WMMA_ENABLED is defined). It accepts any layouts, ArchTag and
/// InstructionShape. Unlike the other specializations, it forwards Stages to DefaultMma
/// instead of fixing it to 2. IsBetaZero is not listed, so only the default (false)
/// matches.
521 template <
525  typename ElementA,
527  typename LayoutA,
529  int kAlignmentA,
531  typename ElementB,
533  typename LayoutB,
535  int kAlignmentB,
537  typename ElementC,
539  typename LayoutC,
541  typename ElementAccumulator,
543  typename ArchTag,
545  typename ThreadblockShape,
547  typename WarpShape,
549  typename InstructionShape,
551  typename EpilogueOutputOp,
553  typename ThreadblockSwizzle,
555  int Stages,
558  bool SplitKSerial,
560  typename Operator>
561 struct DefaultGemm<
562  ElementA, LayoutA, kAlignmentA,
563  ElementB, LayoutB, kAlignmentB,
564  ElementC, LayoutC,
565  ElementAccumulator,
566  arch::OpClassWmmaTensorOp,
567  ArchTag,
568  ThreadblockShape, WarpShape, InstructionShape,
569  EpilogueOutputOp,
570  ThreadblockSwizzle,
571  Stages,
572  SplitKSerial,
573  Operator> {
  /// Threadblock-scoped matrix multiply-accumulate, built with the caller's ArchTag and
  /// Stages.
  /// NOTE(review): listing lines 574-575 (the opening of this `Mma` alias) are missing;
  /// by analogy with the other specializations, presumably
  /// `typename cutlass::gemm::threadblock::DefaultMma<` — confirm.
576  ElementA, LayoutA, kAlignmentA,
577  ElementB, LayoutB, kAlignmentB,
578  ElementAccumulator, LayoutC,
579  arch::OpClassWmmaTensorOp,
580  ArchTag,
581  ThreadblockShape,
582  WarpShape,
583  InstructionShape,
584  Stages,
585  Operator>::ThreadblockMma;
586
  /// How many warp tiles split the threadblock tile along K (integer division; assumes
  /// divisibility, not checked here).
587  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;
588
  /// Epilogue; elements per access are taken from EpilogueOutputOp::kCount.
  /// NOTE(review): listing lines 589-590 (the epilogue template name) are missing. The
  /// page's tooltips mention WMMA tensor-op epilogue defaults
  /// (default_epilogue_wmma_tensor_op.h), but which template is used here is not shown —
  /// confirm.
591  ThreadblockShape,
592  typename Mma::Operator,
593  kPartitionsK,
594  EpilogueOutputOp,
595  EpilogueOutputOp::kCount
596  >::Epilogue;
597
  // NOTE(review): listing lines 598-599 are not shown in this rendering.
600 };
602 #endif //CUTLASS_ARCH_WMMA_ENABLED
603 
605 
606 } // namespace kernel
607 } // namespace gemm
608 } // namespace cutlass
typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm70, ThreadblockShape, WarpShape, GemmShape< 8, 8, 4 >, 2, Operator >::ThreadblockMma Mma
Define the threadblock-scoped matrix multiply-accumulate.
Definition: default_gemm.h:335
Definition: default_gemm.h:116
typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape, 2, Operator >::ThreadblockMma Mma
Define the threadblock-scoped matrix multiply-accumulate.
Definition: default_gemm.h:186
Definition: aligned_buffer.h:35
Defines sensible defaults for epilogues for SimtOps.
Definition: default_epilogue_simt.h:70
Definition: arch.h:37
Epilogue for threadblock scoped GEMMs using Tensor Ops.
Definition: arch.h:46
Defines common types used for all GEMM-like operators.
typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, true >::ThreadblockMma Mma
Define the threadblock-scoped matrix multiply-accumulate.
Definition: default_gemm.h:252
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, arch::OpClassSimt, arch::Sm50, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false >::ThreadblockMma Mma
Define the threadblock-scoped matrix multiply-accumulate.
Definition: default_gemm.h:502
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
Definition: default_mma.h:87
Definition: arch.h:52
Functor performing linear combination operations used by epilogues.
Defines the size of an element in bits.
Definition: numeric_types.h:42
typename cutlass::epilogue::threadblock::DefaultEpilogueVoltaTensorOp< ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount >::Epilogue Epilogue
Define the epilogue.
Definition: default_gemm.h:346
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
#define static_assert(__e, __m)
Definition: platform.h:153
Epilogue for threadblock scoped GEMMs using Tensor Ops.
typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< ThreadblockShape, typename Mma::Operator, EpilogueOutputOp, kEpilogueElementsPerAccess >::Epilogue Epilogue
Define the epilogue.
Definition: default_gemm.h:434
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
Defines sensible defaults for epilogues for TensorOps.
Definition: default_epilogue_volta_tensor_op.h:71
typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, EpilogueOutputOp::kCount >::Epilogue Epilogue
Define the epilogue.
Definition: default_gemm.h:197
typename cutlass::epilogue::threadblock::DefaultInterleavedEpilogueTensorOp< ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, 64/sizeof_bits< ElementC >::value, InterleavedK, IsBetaZero >::Epilogue Epilogue
Define the epilogue.
Definition: default_gemm.h:261
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Epilogue for threadblock scoped GEMMs using Tensor Ops.
Definition: include/cutlass/gemm/kernel/gemm.h:52
Defines layout functions used by TensorRef and derived classes.
Defines sensible defaults for epilogues for WMMA TensorOps.
Definition: default_epilogue_wmma_tensor_op.h:71
Definition: default_epilogue_tensor_op.h:147
Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems...
Definition: layout/matrix.h:343
typename cutlass::gemm::threadblock::DefaultMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, arch::Sm50, ThreadblockShape, WarpShape, GemmShape< 1, 1, 1 >, 2, Operator >::ThreadblockMma Mma
Define the threadblock-scoped matrix multiply-accumulate.
Definition: default_gemm.h:423
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
Templates implementing loading of tiles from pitch-linear rank=2 tensors.
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
Templates exposing architecture support for warp matrix multiply-add (WMMA) operations.
Defines sensible defaults for epilogues for TensorOps.
Definition: default_epilogue_tensor_op.h:72
typename cutlass::epilogue::threadblock::DefaultEpilogueSimt< ThreadblockShape, typename Mma::Operator, EpilogueOutputOp, kEpilogueElementsPerAccess >::Epilogue Epilogue
Define the epilogue.
Definition: default_gemm.h:513
Basic include for CUTLASS.
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
Epilogue for threadblock scoped GEMMs using SIMT.
Epilogue for threadblock scoped GEMMs using Tensor Ops on Volta.
Definition: layout/matrix.h:237