CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
default_mma_core_sm75.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data
      layout of the global memory fragments, data types, and internal tile sizes.

      Partial specializations for threadblock::Mma operations targeting TensorOp instructions.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/platform/platform.h"

#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/layout/tensor_op_multiplicand_sm75.h"

#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h"

#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/gemm/threadblock/default_mma_core.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization:
///
///   A: column-major
///   B: row-major
///   Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept: GemmShape)
    typename Shape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    /// Shape of one matrix production operation (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A operand
    typename ElementA_,
    /// Data type of B operand
    typename ElementB_,
    /// Data type of accumulator
    typename ElementC_,
    /// Layout of accumulator
    typename LayoutC_,
    /// Operation performed by GEMM
    typename Operator_>
struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
                      layout::ColumnMajor, ElementB_, layout::RowMajor,
                      ElementC_, LayoutC_, arch::OpClassTensorOp, 2, Operator_
                     > {
  using Shape = Shape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using ElementA = ElementA_;
  using LayoutA = layout::ColumnMajor;
  using ElementB = ElementB_;
  using LayoutB = layout::RowMajor;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using OperatorClass = arch::OpClassTensorOp;

  /// Number of warps present
  using WarpCount = GemmShape<
    Shape::kM / WarpShape::kM,
    Shape::kN / WarpShape::kN,
    Shape::kK / WarpShape::kK
  >;
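
  // For example (illustrative tile sizes, not from the original source): a
  // 128x128x32 threadblock tile with 64x64x32 warp tiles yields
  // WarpCount = GemmShape<2, 2, 1>, i.e. four warps per threadblock.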

  // Divisibility requirements
  static_assert(
    !(Shape::kM % WarpShape::kM) &&
    !(Shape::kN % WarpShape::kN),
    "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
  );

  /// Number of threads per warp
  static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;

  /// Number of threads total
  static int const kThreads = WarpCount::kCount * kWarpSize;

  /// Size of a threadblock-scoped access
  static int const kAccessSizeInBits = 128;
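
  // Continuing the four-warp example above, kThreads = 4 * 32 = 128. Each
  // 128-bit access moves kAccessSizeInBits / sizeof_bits<Element>::value
  // elements, e.g. 8 half_t values or 16 int8_t values per access.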

  /// Default Operator
  using Operator = Operator_;

  //
  // Shared memory layouts
  //

  using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous<
      sizeof_bits<ElementA>::value, int(128 / sizeof(ElementA))>;

  // Shared memory layout
  using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous<
      sizeof_bits<ElementB>::value, int(128 / sizeof(ElementB))>;

  //
  // Iterators to write to shared memory
  //

  /// ThreadMap of iterator A
  using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
    layout::PitchLinearShape<Shape::kM, Shape::kK>,
    kThreads,
    layout::PitchLinearShape<8, 4>,
    kAccessSizeInBits / sizeof_bits<ElementA>::value
  >;
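
  // The PitchLinearShape<8, 4> argument is the warp-level thread arrangement:
  // 8 threads along the contiguous dimension by 4 threads along the strided
  // dimension within each warp.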

  /// Shared memory iterator to A operand
  using SmemIteratorA = transform::threadblock::RegularTileIterator<
    MatrixShape<Shape::kM, Shape::kK>,
    ElementA,
    SmemLayoutA,
    1,
    IteratorThreadMapA
  >;

  /// ThreadMap of iterator B
  using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
    layout::PitchLinearShape<Shape::kN, Shape::kK>,
    kThreads,
    layout::PitchLinearShape<8, 4>,
    kAccessSizeInBits / sizeof_bits<ElementB>::value
  >;

  /// Shared memory iterator to B operand
  using SmemIteratorB = transform::threadblock::RegularTileIterator<
    MatrixShape<Shape::kK, Shape::kN>,
    ElementB,
    SmemLayoutB,
    0,
    IteratorThreadMapB
  >;

  //
  // Warp-level matrix multiply operator
  //

  // Define the warp-level tensor op
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
      ElementC, LayoutC, Operator, WarpCount::kK>::Type;
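
  // The two MatrixShape<0, 0> arguments below are the shared-memory padding
  // shapes for the A and B operands; the swizzled tensor-op layouts are
  // designed to be bank-conflict free, so no padding is needed.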

  /// Policy used to define MmaPipelined
  using MmaPolicy = MmaPolicy<
    MmaTensorOp,
    MatrixShape<0, 0>,
    MatrixShape<0, 0>,
    WarpCount::kK
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization:
///
///   A: row-major
///   B: column-major
///   Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept: GemmShape)
    typename Shape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    /// Shape of one matrix production operation (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A operand
    typename ElementA_,
    /// Data type of B operand
    typename ElementB_,
    /// Data type of accumulator
    typename ElementC_,
    /// Layout of accumulator
    typename LayoutC_,
    /// Operation performed by GEMM
    typename Operator_>
struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
                      layout::RowMajor, ElementB_, layout::ColumnMajor,
                      ElementC_, LayoutC_, arch::OpClassTensorOp, 2, Operator_
                     > {
  using Shape = Shape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using ElementA = ElementA_;
  using LayoutA = layout::RowMajor;
  using ElementB = ElementB_;
  using LayoutB = layout::ColumnMajor;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using OperatorClass = arch::OpClassTensorOp;

  /// Number of warps present
  using WarpCount = GemmShape<
    Shape::kM / WarpShape::kM,
    Shape::kN / WarpShape::kN,
    Shape::kK / WarpShape::kK
  >;

  // Divisibility requirements
  static_assert(
    !(Shape::kM % WarpShape::kM) &&
    !(Shape::kN % WarpShape::kN),
    "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
  );

  /// Number of threads per warp
  static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;

  /// Number of threads total
  static int const kThreads = WarpCount::kCount * kWarpSize;

  /// Size of a threadblock-scoped access
  static int const kAccessSizeInBits = 128;

  /// Default Operator
  using Operator = Operator_;

  // Warp thread arrangement
  static int const kWarpThreadArrangementContiguousA =
      Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementA>::value);

  static int const kWarpThreadArrangementStridedA =
      kWarpSize / kWarpThreadArrangementContiguousA;

  static int const kWarpThreadArrangementContiguousB =
      Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementB>::value);

  static int const kWarpThreadArrangementStridedB =
      kWarpSize / kWarpThreadArrangementContiguousB;
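
  // Worked example (illustrative values): for 16-bit operands and
  // Shape::kK == 32, each 128-bit access covers 8 elements, so
  // kWarpThreadArrangementContiguousA = 32 / 8 = 4 and
  // kWarpThreadArrangementStridedA = 32 / 4 = 8, i.e. a <4, 8> arrangement
  // covering all 4 * 8 = 32 threads of the warp.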

  //
  // Shared memory layouts
  //

  using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise<
      sizeof_bits<ElementA>::value, Shape::kK>;

  // Shared memory layout
  using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise<
      sizeof_bits<ElementB>::value, Shape::kK>;

  //
  // Iterators to write to shared memory
  //

  /// ThreadMap of iterator A
  using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
      layout::PitchLinearShape<Shape::kK, Shape::kM>, kThreads,
      layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
                               kWarpThreadArrangementStridedA>,
      kAccessSizeInBits / sizeof_bits<ElementA>::value>;

  /// Shared memory iterator to A operand
  using SmemIteratorA = transform::threadblock::RegularTileIterator<
      MatrixShape<Shape::kM, Shape::kK>,
      ElementA,
      SmemLayoutA,
      0,
      IteratorThreadMapA
  >;

  /// ThreadMap of iterator B
  using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
      layout::PitchLinearShape<Shape::kK, Shape::kN>, kThreads,
      layout::PitchLinearShape<kWarpThreadArrangementContiguousB,
                               kWarpThreadArrangementStridedB>,
      kAccessSizeInBits / sizeof_bits<ElementB>::value>;

  /// Shared memory iterator to B operand
  using SmemIteratorB = transform::threadblock::RegularTileIterator<
      MatrixShape<Shape::kK, Shape::kN>,
      ElementB,
      SmemLayoutB,
      1,
      IteratorThreadMapB
  >;

  //
  // Warp-level matrix multiply operator
  //

  // Define the warp-level tensor op
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
      ElementC, LayoutC, Operator, WarpCount::kK>::Type;

  /// Policy used to define MmaPipelined
  using MmaPolicy = MmaPolicy<
    MmaTensorOp,
    MatrixShape<0, 0>,
    MatrixShape<0, 0>,
    WarpCount::kK
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization:
///
///   A: row-major
///   B: row-major
///   Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept: GemmShape)
    typename Shape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    /// Shape of one matrix production operation (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A operand
    typename ElementA_,
    /// Data type of B operand
    typename ElementB_,
    /// Data type of accumulator
    typename ElementC_,
    /// Layout of accumulator
    typename LayoutC_,
    /// Operation performed by GEMM
    typename Operator_>
struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
                      layout::RowMajor, ElementB_, layout::RowMajor, ElementC_,
                      LayoutC_, arch::OpClassTensorOp, 2, Operator_
                     > {
  using Shape = Shape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using ElementA = ElementA_;
  using LayoutA = layout::RowMajor;
  using ElementB = ElementB_;
  using LayoutB = layout::RowMajor;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using OperatorClass = arch::OpClassTensorOp;

  /// Number of warps present
  using WarpCount = GemmShape<
    Shape::kM / WarpShape::kM,
    Shape::kN / WarpShape::kN,
    Shape::kK / WarpShape::kK
  >;

  // Divisibility requirements
  static_assert(
    !(Shape::kM % WarpShape::kM) &&
    !(Shape::kN % WarpShape::kN),
    "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
  );

  /// Number of threads per warp
  static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;

  /// Number of threads total
  static int const kThreads = WarpCount::kCount * kWarpSize;

  /// Size of a threadblock-scoped access
  static int const kAccessSizeInBits = 128;

  /// Default Operator
  using Operator = Operator_;

  // Warp thread arrangement
  static int const kWarpThreadArrangementContiguousA =
      Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementA>::value);

  static int const kWarpThreadArrangementStridedA =
      kWarpSize / kWarpThreadArrangementContiguousA;

  //
  // Shared memory layouts
  //

  using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise<
      sizeof_bits<ElementA>::value, Shape::kK>;

  // Shared memory layout
  using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous<
      sizeof_bits<ElementB>::value, int(128 / sizeof(ElementB))>;

  //
  // Iterators to write to shared memory
  //

  /// ThreadMap of iterator A
  using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
      layout::PitchLinearShape<Shape::kK, Shape::kM>, kThreads,
      layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
                               kWarpThreadArrangementStridedA>,
      kAccessSizeInBits / sizeof_bits<ElementA>::value>;

  /// Shared memory iterator to A operand
  using SmemIteratorA = transform::threadblock::RegularTileIterator<
      MatrixShape<Shape::kM, Shape::kK>,
      ElementA,
      SmemLayoutA,
      0,
      IteratorThreadMapA
  >;

  /// ThreadMap of iterator B
  using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
      layout::PitchLinearShape<Shape::kN, Shape::kK>,
      kThreads,
      layout::PitchLinearShape<8, 4>,
      kAccessSizeInBits / sizeof_bits<ElementB>::value
  >;

  /// Shared memory iterator to B operand
  using SmemIteratorB = transform::threadblock::RegularTileIterator<
      MatrixShape<Shape::kK, Shape::kN>,
      ElementB,
      SmemLayoutB,
      0,
      IteratorThreadMapB
  >;

  //
  // Warp-level matrix multiply operator
  //

  // Define the warp-level tensor op
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
      ElementC, LayoutC, Operator, WarpCount::kK>::Type;

  /// Policy used to define MmaPipelined
  using MmaPolicy = MmaPolicy<
    MmaTensorOp,
    MatrixShape<0, 0>,
    MatrixShape<0, 0>,
    WarpCount::kK
  >;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization:
///
///   A: column-major
///   B: column-major
///   Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept: GemmShape)
    typename Shape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    /// Shape of one matrix production operation (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A operand
    typename ElementA_,
    /// Data type of B operand
    typename ElementB_,
    /// Data type of accumulator
    typename ElementC_,
    /// Layout of accumulator
    typename LayoutC_,
    /// Operation performed by GEMM
    typename Operator_>
struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
                      layout::ColumnMajor, ElementB_, layout::ColumnMajor,
                      ElementC_, LayoutC_, arch::OpClassTensorOp, 2, Operator_
                     > {
  using Shape = Shape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using ElementA = ElementA_;
  using LayoutA = layout::ColumnMajor;
  using ElementB = ElementB_;
  using LayoutB = layout::ColumnMajor;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using OperatorClass = arch::OpClassTensorOp;

  /// Number of warps present
  using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
                              Shape::kN / WarpShape::kN,
                              Shape::kK / WarpShape::kK>;

  // Divisibility requirements
  static_assert(
      !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
      "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");

  /// Number of threads per warp
  static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;

  /// Number of threads total
  static int const kThreads = WarpCount::kCount * kWarpSize;

  /// Size of a threadblock-scoped access
  static int const kAccessSizeInBits = 128;

  /// Default Operator
  using Operator = Operator_;

  // Warp thread arrangement
  static int const kWarpThreadArrangementContiguousB =
      Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementB>::value);

  static int const kWarpThreadArrangementStridedB =
      kWarpSize / kWarpThreadArrangementContiguousB;

  //
  // Shared memory layouts
  //

  using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous<
      sizeof_bits<ElementA>::value, int(128 / sizeof(ElementA))>;

  // Shared memory layout
  using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise<
      sizeof_bits<ElementB>::value, Shape::kK>;

  //
  // Iterators to write to shared memory
  //

  /// ThreadMap of iterator A
  using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
      layout::PitchLinearShape<Shape::kM, Shape::kK>, kThreads,
      layout::PitchLinearShape<8, 4>,
      kAccessSizeInBits / sizeof_bits<ElementA>::value>;

  /// Shared memory iterator to A operand
  using SmemIteratorA = transform::threadblock::RegularTileIterator<
      MatrixShape<Shape::kM, Shape::kK>, ElementA, SmemLayoutA, 1,
      IteratorThreadMapA>;

  /// ThreadMap of iterator B
  using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
      layout::PitchLinearShape<Shape::kK, Shape::kN>, kThreads,
      layout::PitchLinearShape<kWarpThreadArrangementContiguousB,
                               kWarpThreadArrangementStridedB>,
      kAccessSizeInBits / sizeof_bits<ElementB>::value>;

  /// Shared memory iterator to B operand
  using SmemIteratorB = transform::threadblock::RegularTileIterator<
      MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 0,
      IteratorThreadMapB>;

  //
  // Warp-level matrix multiply operator
  //

  // Define the warp-level tensor op
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
      ElementC, LayoutC, Operator, WarpCount::kK>::Type;

  /// Policy used to define MmaPipelined
  using MmaPolicy = MmaPolicy<MmaTensorOp, MatrixShape<0, 0>,
                              MatrixShape<0, 0>, WarpCount::kK>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization:
///
///   A: column-major-interleaved
///   B: row-major-interleaved
///   Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept: GemmShape)
    typename Shape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    /// Shape of one matrix production operation (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A operand
    typename ElementA_,
    /// Data type of B operand
    typename ElementB_,
    /// Data type of accumulator
    typename ElementC_,
    /// Layout of accumulator
    typename LayoutC_,
    /// Operation performed by GEMM
    typename Operator_,
    /// Store the accumulators in row major or column major. Row major is used
    /// when the output layout is interleaved.
    bool AccumulatorsInRowMajor,
    /// Number of interleaved K
    int InterleavedK>
struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
                      layout::ColumnMajorInterleaved<InterleavedK>, ElementB_,
                      layout::RowMajorInterleaved<InterleavedK>, ElementC_,
                      LayoutC_, arch::OpClassTensorOp, 2, Operator_,
                      AccumulatorsInRowMajor> {
  using Shape = Shape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using ElementA = ElementA_;
  using LayoutA = layout::ColumnMajorInterleaved<InterleavedK>;
  using ElementB = ElementB_;
  using LayoutB = layout::RowMajorInterleaved<InterleavedK>;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using OperatorClass = arch::OpClassTensorOp;
  static int const kInterleavedK = InterleavedK;

  /// Number of warps present
  using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
                              Shape::kN / WarpShape::kN,
                              Shape::kK / WarpShape::kK>;

  // Divisibility requirements
  static_assert(
      !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
      "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");

  /// Number of threads per warp
  static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;

  /// Number of threads total
  static int const kThreads = WarpCount::kCount * kWarpSize;

  /// Size of a threadblock-scoped access
  static int const kAccessSizeInBits = 128;

  /// Default Operator
  using Operator = Operator_;

  // Warp thread arrangement
  static int const kElementsPerAccess =
      kAccessSizeInBits / sizeof_bits<ElementA>::value;

  static int const kWarpThreadArrangementContiguous =
      kInterleavedK / kElementsPerAccess;

  static int const kWarpThreadArrangementStrided =
      kWarpSize / kWarpThreadArrangementContiguous;
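
  // Worked example (illustrative values): for int8_t operands (8 bits) and
  // InterleavedK == 32, kElementsPerAccess = 128 / 8 = 16,
  // kWarpThreadArrangementContiguous = 32 / 16 = 2, and
  // kWarpThreadArrangementStrided = 32 / 2 = 16, i.e. a <2, 16> arrangement.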

  //
  // Shared memory layouts
  //

  using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise<
      sizeof_bits<ElementA>::value, kInterleavedK>;

  // Shared memory layout
  using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise<
      sizeof_bits<ElementB>::value, kInterleavedK>;

  //
  // Iterators to write to shared memory
  //

  /// ThreadMap of iterator A
  using IteratorThreadMapA = transform::PitchLinearWarpStripedThreadMap<
      layout::PitchLinearShape<Shape::kM * kInterleavedK,
                               Shape::kK / kInterleavedK>,
      kThreads, layout::PitchLinearShape<32, 1>, kElementsPerAccess>;
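
  // A <32, 1> warp arrangement places all 32 threads of a warp along the
  // contiguous dimension, so each warp's global loads of the interleaved
  // operand are fully coalesced.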

  /// Transpose the ThreadMap of iterator A
  using SmemThreadMapA = transform::TransposePitchLinearThreadMap<
      IteratorThreadMapA,
      layout::PitchLinearShape<kWarpThreadArrangementContiguous,
                               kWarpThreadArrangementStrided>>;

  /// Shared memory iterator to A operand
  using SmemIteratorA = transform::threadblock::RegularTileIterator<
      MatrixShape<Shape::kM, Shape::kK>, ElementA, SmemLayoutA, 0,
      SmemThreadMapA>;

  /// ThreadMap of iterator B
  using IteratorThreadMapB = transform::PitchLinearWarpStripedThreadMap<
      layout::PitchLinearShape<Shape::kN * kInterleavedK,
                               Shape::kK / kInterleavedK>,
      kThreads, layout::PitchLinearShape<32, 1>, kElementsPerAccess>;

  /// Transpose the ThreadMap of iterator B
  using SmemThreadMapB = transform::TransposePitchLinearThreadMap<
      IteratorThreadMapB,
      layout::PitchLinearShape<kWarpThreadArrangementContiguous,
                               kWarpThreadArrangementStrided>>;

  /// Shared memory iterator to B operand
  using SmemIteratorB = transform::threadblock::RegularTileIterator<
      MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 1,
      SmemThreadMapB>;

  //
  // Warp-level matrix multiply operator
  //

  // Define the warp-level tensor op
  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
      ElementC, LayoutC, Operator, WarpCount::kK, AccumulatorsInRowMajor>::Type;

  /// Policy used to define MmaPipelined
  using MmaPolicy = MmaPolicy<MmaTensorOp, MatrixShape<0, 0>,
                              MatrixShape<0, 0>, WarpCount::kK>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
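
/////////////////////////////////////////////////////////////////////////////////////////////////

// Usage sketch (illustrative, not part of the original header). It shows how
// the first partial specialization above (column-major A, row-major B) might
// be selected; the tile sizes and element types are assumptions, and MmaCore
// is a hypothetical alias:
//
//   using MmaCore = cutlass::gemm::threadblock::DefaultMmaCore<
//       cutlass::gemm::GemmShape<128, 128, 32>,          // threadblock tile
//       cutlass::gemm::GemmShape<64, 64, 32>,            // warp tile
//       cutlass::gemm::GemmShape<16, 8, 8>,              // Turing tensor op shape
//       cutlass::half_t, cutlass::layout::ColumnMajor,   // operand A
//       cutlass::half_t, cutlass::layout::RowMajor,      // operand B
//       float, cutlass::layout::RowMajor,                // accumulator C
//       cutlass::arch::OpClassTensorOp,                  // tensor core math
//       2,                                               // pipeline stages
//       cutlass::arch::OpMultiplyAdd>;                   // math operation
//
// The nested MmaCore::SmemIteratorA, MmaCore::SmemIteratorB, and
// MmaCore::MmaPolicy types then parameterize the threadblock-scoped mainloop
// (e.g. cutlass::gemm::threadblock::MmaPipelined).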