CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
default_mma.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/cutlass.h"
32 #include "cutlass/numeric_types.h"
33 #include "cutlass/arch/arch.h"
34 #include "cutlass/arch/wmma.h"
35 
40 #if defined(CUTLASS_ARCH_WMMA_ENABLED)
42 #endif //CUTLASS_ARCH_WMMA_ENABLED
43 
45 
46 namespace cutlass {
47 namespace gemm {
48 namespace threadblock {
49 
51 
/// Primary template for the threadblock-level MMA defaults.
///
/// It is only declared here and never defined. Only the partial
/// specializations in this file provide a body. For any other combination of
/// arguments DefaultMma<...> is an incomplete type, and accessing its members
/// fails to compile.
///
/// Each partial specialization in this file builds, for one family of
/// configurations, three things:
///   - an MmaCore,
///   - A/B operand tile iterators,
///   - a threadblock-scoped matrix multiply.
/// The specializations match Stages == 2 (pipelined). The only Stages == 1
/// (single-stage) specialization is the arch::OpClassWmmaTensorOp one.
/// AccumulatorsInRowMajor is true only in the
/// layout::ColumnMajorInterleaved<InterleavedK> specialization; all others
/// match its default of false.
///
/// NOTE(review): listing lines 53, 55, ..., 81 and 83-84 are not rendered
/// here. They are presumably the per-parameter doc comments; consult the
/// original header.
52 template <
54  typename ElementA_,
56  typename LayoutA_,
58  int kAlignmentA,
60  typename ElementB_,
62  typename LayoutB_,
64  int kAlignmentB,
66  typename ElementAccumulator_,
68  typename LayoutC_,
70  typename OperatorClass_,
72  typename ArchTag_,
74  typename ThreadblockShape_,
76  typename WarpShape_,
78  typename InstructionShape_,
80  int Stages,
82  typename Operator,
85  bool AccumulatorsInRowMajor = false
86  >
87 struct DefaultMma;
88 
90 
/// Partial specialization for SIMT math (arch::OpClassSimt) with:
///   - a row-major accumulator layout (LayoutC = layout::RowMajor),
///   - two stages,
///   - AccumulatorsInRowMajor == false.
///
/// It defines, in order:
///   1. MmaCore, a DefaultMmaCore over the same shapes, element types and
///      layouts, with OpClassSimt and 2 stages.
///   2. IteratorA / IteratorB over the A and B operand tiles, parameterized by
///      MmaCore::IteratorThreadMapA/B and by kAlignmentA / kAlignmentB.
///   3. The threadblock-scoped pipelined matrix multiply, built from:
///      MmaCore::Shape, the two iterators, MmaCore::SmemIteratorA/B, the
///      accumulator type, layout::RowMajor and MmaCore::MmaPolicy.
///
/// NOTE(review): listing lines 122, 129-130, 135-136 and 140 are missing.
/// They hold the heads of the MmaCore declaration, the iterator template
/// name(s) and the pipelined-MMA alias; only their argument lists appear
/// below.
/// NOTE(review): the literal 1 (A) / 0 (B) iterator argument presumably
/// selects the dimension each iterator advances along. Confirm against the
/// iterator's template signature.
92 template <
94  typename ElementA,
96  typename LayoutA,
98  int kAlignmentA,
100  typename ElementB,
102  typename LayoutB,
104  int kAlignmentB,
106  typename ElementAccumulator,
108  typename ArchTag,
110  typename ThreadblockShape,
112  typename WarpShape,
114  typename InstructionShape,
116  typename Operator>
117 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
118  kAlignmentB, ElementAccumulator, layout::RowMajor,
119  arch::OpClassSimt, ArchTag, ThreadblockShape, WarpShape,
120  InstructionShape, 2, Operator, false> {
121  // Define the MmaCore components
123  ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
124  ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
125  arch::OpClassSimt, 2, Operator>;
126 
127  // Define iterators over tiles from the A operand
128  using IteratorA =
131  ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>;
132 
133  // Define iterators over tiles from the B operand
134  using IteratorB =
137  ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>;
138 
139  // Define the threadblock-scoped pipelined matrix multiply
141  typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
142  IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
143  layout::RowMajor, typename MmaCore::MmaPolicy>;
144 };
145 
146 
/// Partial specialization for Tensor Core math (arch::OpClassTensorOp) with:
///   - a row-major accumulator layout (LayoutC = layout::RowMajor),
///   - two stages,
///   - AccumulatorsInRowMajor == false.
///
/// Its structure mirrors the SIMT specialization; only the operator class
/// passed to DefaultMmaCore differs:
///   1. MmaCore, a DefaultMmaCore with OpClassTensorOp and 2 stages.
///   2. IteratorA / IteratorB, parameterized by MmaCore's thread maps and by
///      kAlignmentA / kAlignmentB.
///   3. The threadblock-scoped pipelined matrix multiply with a
///      layout::RowMajor accumulator.
///
/// NOTE(review): listing lines 179, 186-187, 192-193 and 197 are missing, so
/// the declaration heads of MmaCore, the iterators' template name(s) and the
/// pipelined-MMA alias are not visible here.
148 template <
150  typename ElementA,
152  typename LayoutA,
154  int kAlignmentA,
156  typename ElementB,
158  typename LayoutB,
160  int kAlignmentB,
162  typename ElementAccumulator,
164  typename ArchTag,
166  typename ThreadblockShape,
168  typename WarpShape,
170  typename InstructionShape,
172  typename Operator
173  >
174 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
175  kAlignmentB, ElementAccumulator, layout::RowMajor,
176  arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape,
177  InstructionShape, 2, Operator, false> {
178  // Define the MmaCore components
180  ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
181  ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
182  arch::OpClassTensorOp, 2, Operator>;
183 
184  // Define iterators over tiles from the A operand
185  using IteratorA =
188  ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>;
189 
190  // Define iterators over tiles from the B operand
191  using IteratorB =
194  ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>;
195 
196  // Define the threadblock-scoped pipelined matrix multiply
198  typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
199  IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
200  layout::RowMajor, typename MmaCore::MmaPolicy>;
201 };
203 
/// Partial specialization for an interleaved column-major accumulator layout
/// (LayoutC = layout::ColumnMajorInterleaved<InterleavedK>). It accepts any
/// operator class, with two stages and AccumulatorsInRowMajor == true.
/// InterleavedK is forwarded to DefaultMmaCore, whose argument list also ends
/// with `true`, and to the pipelined matrix multiply.
///
/// Unlike the SIMT / TensorOp specializations, no alignment argument follows
/// the thread map in the visible iterator argument lists. Instead, the two
/// static_asserts require kAlignmentA == 128 / sizeof_bits<ElementA>::value,
/// and likewise for B. The alignment, counted in elements, must therefore
/// span exactly 128 bits; the assert message ties this to the thread data
/// map's vector length.
///
/// NOTE(review): listing lines 240, 253-254, 258-259 and 263 are missing. This
/// includes the `using IteratorA =` / `using IteratorB =` heads, so only the
/// tail of each iterator alias is visible. Why AccumulatorsInRowMajor must be
/// true for this layout is not shown here.
205 template <
207  typename ElementA,
209  typename LayoutA,
211  int kAlignmentA,
213  typename ElementB,
215  typename LayoutB,
217  int kAlignmentB,
219  typename ElementAccumulator,
221  typename OperatorClass,
223  typename ArchTag,
225  typename ThreadblockShape,
227  typename WarpShape,
229  typename InstructionShape,
231  typename Operator,
233  int InterleavedK>
234 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
235  kAlignmentB, ElementAccumulator,
236  layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass,
237  ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
238  Operator, true> {
239  // Define the MmaCore components
241  ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
242  ElementB, LayoutB, ElementAccumulator,
243  layout::ColumnMajorInterleaved<InterleavedK>, OperatorClass, 2, Operator,
244  true>;
245 
246  static_assert(kAlignmentA == 128 / sizeof_bits<ElementA>::value,
247  "Alignment must match thread data map's vector length");
248 
249  static_assert(kAlignmentB ==128 / sizeof_bits<ElementB>::value,
250  "Alignment must match thread data map's vector length");
251 
252  // Define iterators over tiles from the A operand
255  LayoutA, 1, typename MmaCore::IteratorThreadMapA>;
256 
257  // Define iterators over tiles from the B operand
260  LayoutB, 0, typename MmaCore::IteratorThreadMapB>;
261 
262  // Define the threadblock-scoped pipelined matrix multiply
264  typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
265  IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
266  layout::ColumnMajorInterleaved<InterleavedK>,
267  typename MmaCore::MmaPolicy>;
268 };
269 
/// Partial specialization for int8_t A and B operands on SIMT math
/// (arch::OpClassSimt), with:
///   - instruction shape GemmShape<1, 1, 4>,
///   - a row-major accumulator layout,
///   - two stages,
///   - AccumulatorsInRowMajor == false.
///
/// Some argument combinations also match the generic OpClassSimt
/// specialization. For those, C++ partial ordering picks this one because it
/// is more specialized (it fixes the element types and the instruction
/// shape). The aliases at the top re-expose the fixed arguments under the
/// names the body uses.
///
/// The A/B iterators receive transposeA / transposeB in the position where
/// the other specializations pass kAlignmentA / kAlignmentB. kAlignmentA/B
/// are matched but do not appear in the visible body lines.
///
/// NOTE(review): transposeA / transposeB are declared on elided listing lines
/// 301-302, presumably via an is_same test on LayoutA / LayoutB (confirm in
/// the original header). Lines 305, 312-313, 318-319 and 323 are also
/// missing, which hides the heads of the MmaCore, iterator and pipelined-MMA
/// declarations.
/// NOTE(review): the 1x1x4 instruction shape presumably targets a 4-way int8
/// dot-product instruction. Confirm in the SIMT DefaultMmaCore.
273 template <
275  typename LayoutA,
277  int kAlignmentA,
279  typename LayoutB,
281  int kAlignmentB,
283  typename ElementAccumulator,
285  typename ArchTag,
287  typename ThreadblockShape,
289  typename Operator,
291  typename WarpShape>
292 struct DefaultMma<int8_t, LayoutA, kAlignmentA, int8_t, LayoutB, kAlignmentB,
293  ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
294  ArchTag, ThreadblockShape, WarpShape, GemmShape<1, 1, 4>, 2,
295  Operator, false> {
296  using InstructionShape = GemmShape<1, 1, 4>;
297  using ElementA = int8_t;
298  using ElementB = int8_t;
299  using OperatorClass = arch::OpClassSimt;
300 
303 
304  // Define the MmaCore components
306  ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
307  ElementB, LayoutB, ElementAccumulator, layout::RowMajor,
308  OperatorClass, 2, Operator>;
309 
310  // Define iterators over tiles from the A operand
311  using IteratorA =
314  ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA>;
315 
316  // Define iterators over tiles from the B operand
317  using IteratorB =
320  ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB>;
321 
322  // Define the threadblock-scoped pipelined matrix multiply
324  typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
325  IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
326  layout::RowMajor, typename MmaCore::MmaPolicy>;
327 };
328 
329 #if defined(CUTLASS_ARCH_WMMA_ENABLED)
/// Partial specialization for WMMA math (arch::OpClassWmmaTensorOp) with any
/// accumulator layout LayoutC and two stages.
///
/// AccumulatorsInRowMajor is not written in the argument list, so this
/// matches only the primary template's default value (false). The
/// specialization is compiled only when CUTLASS_ARCH_WMMA_ENABLED is defined,
/// per the enclosing #if.
///
/// LayoutC is forwarded both to the MmaCore argument list and to the
/// pipelined matrix multiply. The SIMT/TensorOp specializations instead fix
/// layout::RowMajor in those positions.
///
/// NOTE(review): listing lines 331-332, 363, 370-371, 376-377 and 381 are
/// missing, which hides the heads of the MmaCore, iterator and pipelined-MMA
/// declarations.
330 template <
333  typename ElementA,
335  typename LayoutA,
337  int kAlignmentA,
339  typename ElementB,
341  typename LayoutB,
343  int kAlignmentB,
345  typename ElementAccumulator,
347  typename LayoutC,
349  typename ArchTag,
351  typename ThreadblockShape,
353  typename WarpShape,
355  typename InstructionShape,
357  typename Operator>
358 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
359  kAlignmentB, ElementAccumulator, LayoutC,
360  arch::OpClassWmmaTensorOp, ArchTag, ThreadblockShape, WarpShape,
361  InstructionShape, 2, Operator> {
362  // Define the MmaCore components
364  ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
365  ElementB, LayoutB, ElementAccumulator, LayoutC,
366  arch::OpClassWmmaTensorOp, 2, Operator>;
367 
368  // Define iterators over tiles from the A operand
369  using IteratorA =
372  ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>;
373 
374  // Define iterators over tiles from the B operand
375  using IteratorB =
378  ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>;
379 
380  // Define the threadblock-scoped pipelined matrix multiply
382  typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
383  IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
384  LayoutC, typename MmaCore::MmaPolicy>;
385 };
386 
/// Partial specialization for WMMA math (arch::OpClassWmmaTensorOp) with any
/// accumulator layout LayoutC and a single stage (Stages == 1). It is the
/// only Stages == 1 specialization in this file, and the threadblock-scoped
/// multiply it defines is the single-stage variant rather than the pipelined
/// one.
///
/// AccumulatorsInRowMajor is not written in the argument list, so this
/// matches only the primary template's default value (false). The
/// specialization is compiled only when CUTLASS_ARCH_WMMA_ENABLED is defined,
/// per the enclosing #if.
///
/// NOTE(review): listing lines 420, 427-428, 433-434 and 438 are missing,
/// which hides the heads of the MmaCore, iterator and single-stage-MMA
/// declarations.
388 template <
390  typename ElementA,
392  typename LayoutA,
394  int kAlignmentA,
396  typename ElementB,
398  typename LayoutB,
400  int kAlignmentB,
402  typename ElementAccumulator,
404  typename LayoutC,
406  typename ArchTag,
408  typename ThreadblockShape,
410  typename WarpShape,
412  typename InstructionShape,
414  typename Operator>
415 struct DefaultMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
416  kAlignmentB, ElementAccumulator, LayoutC,
417  arch::OpClassWmmaTensorOp, ArchTag, ThreadblockShape, WarpShape,
418  InstructionShape, 1, Operator> {
419  // Define the MmaCore components
421  ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
422  ElementB, LayoutB, ElementAccumulator, LayoutC,
423  arch::OpClassWmmaTensorOp, 1, Operator>;
424 
425  // Define iterators over tiles from the A operand
426  using IteratorA =
429  ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, kAlignmentA>;
430 
431  // Define iterators over tiles from the B operand
432  using IteratorB =
435  ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, kAlignmentB>;
436 
437  // Define the threadblock-scoped single-stage matrix multiply
439  typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA,
440  IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
441  LayoutC, typename MmaCore::MmaPolicy>;
442 };
444 #endif //CUTLASS_ARCH_WMMA_ENABLED
445 
446 } // namespace threadblock
447 } // namespace gemm
448 } // namespace cutlass
449 
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
Definition: aligned_buffer.h:35
typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, 2, Operator > MmaCore
Definition: default_mma.h:308
std::is_same (false specialization)
Definition: platform.h:394
Definition: default_mma_core.h:90
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
Definition: mma_pipelined.h:86
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
Definition: mma_singlestage.h:76
typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, 2, Operator > MmaCore
Definition: default_mma.h:125
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
Definition: default_mma.h:87
Defines the size of an element in bits.
Definition: numeric_types.h:42
typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 2, Operator > MmaCore
Definition: default_mma.h:182
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
#define static_assert(__e, __m)
Definition: platform.h:153
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
Definition: transform/threadblock/predicated_tile_iterator.h:133
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Defines tags for architecture-specific configurations.
Definition: layout/matrix.h:343
Templates implementing loading of tiles from pitch-linear rank=2 tensors.
typename cutlass::gemm::threadblock::DefaultMmaCore< ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::ColumnMajorInterleaved< InterleavedK >, OperatorClass, 2, Operator, true > MmaCore
Definition: default_mma.h:244
Definition: predicated_tile_iterator_2dthreadtile.h:133
Templates exposing architecture support for warp matrix multiply-add (WMMA) operations.
Basic include for CUTLASS.
Templates implementing loading of tiles from pitch-linear rank=2 tensors.