CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
default_mma_core_sm70.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
32 #pragma once
33 
34 #include "cutlass/cutlass.h"
35 #include "cutlass/array.h"
36 
37 #include "cutlass/numeric_types.h"
38 #include "cutlass/matrix_shape.h"
39 
40 
44 
47 
49 
50 namespace cutlass {
51 namespace gemm {
52 namespace threadblock {
53 
55 
// Partial specialization of DefaultMmaCore for Volta Tensor Cores
// (mma.sync with instruction shape 8x8x4), operand A column-major,
// operand B row-major, two pipeline stages.
// NOTE(review): this is a Doxygen code listing; the extraction dropped
// several source lines (the static_assert head, the SmemLayoutA/SmemLayoutB
// type expressions, and the iterator/thread-map/warp-op template names).
// Consult the original default_mma_core_sm70.h before relying on it.
63 template <
66  typename Shape_,
68  typename WarpShape_,
70  typename ElementA_,
72  typename ElementB_,
74  typename ElementC_,
76  typename LayoutC_,
78  typename Operator_>
79 struct DefaultMmaCore<Shape_, WarpShape_, GemmShape<8, 8, 4>, ElementA_,
80  layout::ColumnMajor, ElementB_, layout::RowMajor,
81  ElementC_, LayoutC_, arch::OpClassTensorOp, 2, Operator_
82  > {
83  using Shape = Shape_;
84  using WarpShape = WarpShape_;
86  using ElementA = ElementA_;
88  using ElementB = ElementB_;
90  using ElementC = ElementC_;
91  using LayoutC = LayoutC_;
92  using OperatorClass = arch::OpClassTensorOp;
93 
95  using Operator = Operator_;
96 
// Number of warps per threadblock along each GEMM dimension.
98  using WarpCount = GemmShape<
99  Shape::kM / WarpShape::kM,
100  Shape::kN / WarpShape::kN,
101  Shape::kK / WarpShape::kK
102  >;
103 
104  // Divisibility requirements
106  !(Shape::kM % WarpShape::kM) &&
107  !(Shape::kN % WarpShape::kN),
108  "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
109  );
110 
112  static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
113 
// Total participating threads: warps per threadblock times warp size.
115  static int const kThreads = WarpCount::kCount * kWarpSize;
116 
// Each shared-memory access is 128 bits wide (vectorized load/store).
118  static int const kAccessSizeInBits = 128;
119 
120  //
121  // Shared memory layouts
122  //
123 
// Shared memory layout for operand A
// NOTE(review): the layout type expression was lost in extraction.
124  using SmemLayoutA =
127 
128  // Shared memory layout for operand B
129  using SmemLayoutB =
132 
133  //
134  // Iterators to write to shared memory
135  //
136 
// Thread map for operand A; access width derived from the 128-bit access size.
140  kThreads,
142  kAccessSizeInBits / sizeof_bits<ElementA>::value
143  >;
144 
// Shared-memory tile iterator for operand A.
148  ElementA,
149  SmemLayoutA,
150  1,
152  >;
153 
// Thread map for operand B over an 8x4 pitch-linear arrangement.
157  kThreads,
158  layout::PitchLinearShape<8, 4>,
159  kAccessSizeInBits / sizeof_bits<ElementB>::value
160  >;
161 
// Shared-memory tile iterator for operand B.
165  ElementB,
166  SmemLayoutB,
167  0,
169  >;
170 
171  //
172  // Warp-level matrix multiply operator
173  //
174 
175  // Define the warp-level tensor op
179  32,
180  ElementA,
181  LayoutA,
182  ElementB,
183  LayoutB,
184  ElementC,
186  cutlass::arch::OpMultiplyAdd
187  >,
189  >;
190 
// Warp-level MMA operating on the shared-memory layouts defined above.
192  WarpShape,
193  ElementA,
194  SmemLayoutA,
195  ElementB,
196  SmemLayoutB,
197  ElementC,
198  LayoutC,
199  Policy
200  >;
201 
// Policy consumed by the pipelined threadblock-scoped MMA.
203  using MmaPolicy = MmaPolicy<
204  MmaTensorOp,
206  MatrixShape<0, 0>,
207  WarpCount::kK
208  >;
209 };
210 
// Partial specialization of DefaultMmaCore for Volta Tensor Cores
// (mma.sync with instruction shape 8x8x4), operand A row-major,
// operand B column-major, two pipeline stages.
// NOTE(review): Doxygen code listing with lines dropped by extraction
// (static_assert head, SmemLayoutA/SmemLayoutB type expressions, and
// iterator/thread-map/warp-op template names). Consult the original
// default_mma_core_sm70.h before relying on it.
218 template <
221  typename Shape_,
223  typename WarpShape_,
225  typename ElementA_,
227  typename ElementB_,
229  typename ElementC_,
231  typename LayoutC_,
233  typename Operator_>
234 struct DefaultMmaCore<Shape_, WarpShape_, GemmShape<8, 8, 4>, ElementA_,
235  layout::RowMajor, ElementB_, layout::ColumnMajor,
236  ElementC_, LayoutC_, arch::OpClassTensorOp, 2, Operator_
237  > {
238  using Shape = Shape_;
239  using WarpShape = WarpShape_;
241  using ElementA = ElementA_;
243  using ElementB = ElementB_;
245  using ElementC = ElementC_;
246  using LayoutC = LayoutC_;
247  using OperatorClass = arch::OpClassTensorOp;
248 
250  using Operator = Operator_;
251 
// Number of warps per threadblock along each GEMM dimension.
253  using WarpCount = GemmShape<
254  Shape::kM / WarpShape::kM,
255  Shape::kN / WarpShape::kN,
256  Shape::kK / WarpShape::kK
257  >;
258 
259  // Divisibility requirements
261  !(Shape::kM % WarpShape::kM) &&
262  !(Shape::kN % WarpShape::kN),
263  "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
264  );
265 
267  static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
268 
// Total participating threads: warps per threadblock times warp size.
270  static int const kThreads = WarpCount::kCount * kWarpSize;
271 
// Each shared-memory access is 128 bits wide (vectorized load/store).
273  static int const kAccessSizeInBits = 128;
274 
275  //
276  // Shared memory layouts
277  //
278 
// NOTE(review): SmemLayoutA definition lost in extraction.
281 
282  // Shared memory layout for operand B (definition lost in extraction)
285 
286  //
287  // Iterators to write to shared memory
288  //
289 
// Thread map for operand A; access width derived from the 128-bit access size.
293  kThreads,
295  kAccessSizeInBits / sizeof_bits<ElementA>::value
296  >;
297 
// Shared-memory tile iterator for operand A.
301  ElementA,
302  SmemLayoutA,
303  0,
305  >;
306 
// Thread map for operand B over a 4x8 pitch-linear arrangement.
310  kThreads,
311  layout::PitchLinearShape<4, 8>,
312  kAccessSizeInBits / sizeof_bits<ElementB>::value
313  >;
314 
// Shared-memory tile iterator for operand B.
318  ElementB,
319  SmemLayoutB,
320  1,
322  >;
323 
324  //
325  // Warp-level matrix multiply operator
326  //
327 
328  // Define the warp-level tensor op
332  32,
333  ElementA,
334  LayoutA,
335  ElementB,
336  LayoutB,
337  ElementC,
339  cutlass::arch::OpMultiplyAdd
340  >,
342  >;
343 
// Warp-level MMA operating on the shared-memory layouts defined above.
345  WarpShape,
346  ElementA,
347  SmemLayoutA,
348  ElementB,
349  SmemLayoutB,
350  ElementC,
351  LayoutC,
352  Policy
353  >;
354 
// Policy consumed by the pipelined threadblock-scoped MMA.
356  using MmaPolicy = MmaPolicy<
357  MmaTensorOp,
359  MatrixShape<0, 0>,
360  WarpCount::kK
361  >;
362 };
363 
365 
// Partial specialization of DefaultMmaCore for Volta Tensor Cores
// (mma.sync with instruction shape 8x8x4), operands A and B both
// row-major, two pipeline stages.
// NOTE(review): Doxygen code listing with lines dropped by extraction
// (static_assert head, SmemLayoutA/SmemLayoutB type expressions, and
// iterator/thread-map/warp-op template names). Consult the original
// default_mma_core_sm70.h before relying on it.
373 template <
376  typename Shape_,
378  typename WarpShape_,
380  typename ElementA_,
382  typename ElementB_,
384  typename ElementC_,
386  typename LayoutC_,
388  typename Operator_>
389 struct DefaultMmaCore<Shape_, WarpShape_, GemmShape<8, 8, 4>, ElementA_,
390  layout::RowMajor, ElementB_, layout::RowMajor, ElementC_,
391  LayoutC_, arch::OpClassTensorOp, 2, Operator_
392  > {
393  using Shape = Shape_;
394  using WarpShape = WarpShape_;
396  using ElementA = ElementA_;
398  using ElementB = ElementB_;
400  using ElementC = ElementC_;
401  using LayoutC = LayoutC_;
402  using OperatorClass = arch::OpClassTensorOp;
403 
405  using Operator = Operator_;
406 
// Number of warps per threadblock along each GEMM dimension.
408  using WarpCount = GemmShape<
409  Shape::kM / WarpShape::kM,
410  Shape::kN / WarpShape::kN,
411  Shape::kK / WarpShape::kK
412  >;
413 
414  // Divisibility requirements
416  !(Shape::kM % WarpShape::kM) &&
417  !(Shape::kN % WarpShape::kN),
418  "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
419  );
420 
422  static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
423 
// Total participating threads: warps per threadblock times warp size.
425  static int const kThreads = WarpCount::kCount * kWarpSize;
426 
// Each shared-memory access is 128 bits wide (vectorized load/store).
428  static int const kAccessSizeInBits = 128;
429 
430  //
431  // Shared memory layouts
432  //
433 
// NOTE(review): SmemLayoutA definition lost in extraction.
436 
437  // Shared memory layout for operand B (definition lost in extraction)
440 
441  //
442  // Iterators to write to shared memory
443  //
444 
// Thread map for operand A; access width derived from the 128-bit access size.
448  kThreads,
450  kAccessSizeInBits / sizeof_bits<ElementA>::value
451  >;
452 
// Shared-memory tile iterator for operand A.
456  ElementA,
457  SmemLayoutA,
458  0,
460  >;
461 
// Thread map for operand B; access width derived from the 128-bit access size.
465  kThreads,
467  kAccessSizeInBits / sizeof_bits<ElementB>::value
468  >;
469 
// Shared-memory tile iterator for operand B.
473  ElementB,
474  SmemLayoutB,
475  0,
477  >;
478 
479  //
480  // Warp-level matrix multiply operator
481  //
482 
483  // Define the warp-level tensor op
487  32,
488  ElementA,
489  LayoutA,
490  ElementB,
491  LayoutB,
492  ElementC,
494  cutlass::arch::OpMultiplyAdd
495  >,
497  >;
498 
// Warp-level MMA operating on the shared-memory layouts defined above.
500  WarpShape,
501  ElementA,
502  SmemLayoutA,
503  ElementB,
504  SmemLayoutB,
505  ElementC,
506  LayoutC,
507  Policy
508  >;
509 
// Policy consumed by the pipelined threadblock-scoped MMA.
511  using MmaPolicy = MmaPolicy<
512  MmaTensorOp,
514  MatrixShape<0, 0>,
515  WarpCount::kK
516  >;
517 };
518 
520 
// Partial specialization of DefaultMmaCore for Volta Tensor Cores
// (mma.sync with instruction shape 8x8x4), operands A and B both
// column-major, two pipeline stages.
// NOTE(review): Doxygen code listing with lines dropped by extraction
// (static_assert head, SmemLayoutA/SmemLayoutB type expressions, and
// iterator/thread-map/warp-op template names). Consult the original
// default_mma_core_sm70.h before relying on it.
528 template <
531  typename Shape_,
533  typename WarpShape_,
535  typename ElementA_,
537  typename ElementB_,
539  typename ElementC_,
541  typename LayoutC_,
543  typename Operator_>
544 struct DefaultMmaCore<Shape_, WarpShape_, GemmShape<8, 8, 4>, ElementA_,
545  layout::ColumnMajor, ElementB_, layout::ColumnMajor,
546  ElementC_, LayoutC_, arch::OpClassTensorOp, 2, Operator_
547  > {
548  using Shape = Shape_;
549  using WarpShape = WarpShape_;
551  using ElementA = ElementA_;
553  using ElementB = ElementB_;
555  using ElementC = ElementC_;
556  using LayoutC = LayoutC_;
557  using OperatorClass = arch::OpClassTensorOp;
558 
560  using Operator = Operator_;
561 
// Number of warps per threadblock along each GEMM dimension.
563  using WarpCount = GemmShape<
564  Shape::kM / WarpShape::kM,
565  Shape::kN / WarpShape::kN,
566  Shape::kK / WarpShape::kK
567  >;
568 
569  // Divisibility requirements
571  !(Shape::kM % WarpShape::kM) &&
572  !(Shape::kN % WarpShape::kN),
573  "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
574  );
575 
577  static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
578 
// Total participating threads: warps per threadblock times warp size.
580  static int const kThreads = WarpCount::kCount * kWarpSize;
581 
// Each shared-memory access is 128 bits wide (vectorized load/store).
583  static int const kAccessSizeInBits = 128;
584 
585  //
586  // Shared memory layouts
587  //
588 
// NOTE(review): SmemLayoutA definition lost in extraction.
591 
592  // Shared memory layout for operand B (definition lost in extraction)
595 
596  //
597  // Iterators to write to shared memory
598  //
599 
// Thread map for operand A; access width derived from the 128-bit access size.
603  kThreads,
605  kAccessSizeInBits / sizeof_bits<ElementA>::value
606  >;
607 
// Shared-memory tile iterator for operand A.
611  ElementA,
612  SmemLayoutA,
613  1,
615  >;
616 
// Thread map for operand B; access width derived from the 128-bit access size.
620  kThreads,
622  kAccessSizeInBits / sizeof_bits<ElementB>::value
623  >;
624 
// Shared-memory tile iterator for operand B.
628  ElementB,
629  SmemLayoutB,
630  1,
632  >;
633 
634  //
635  // Warp-level matrix multiply operator
636  //
637 
638  // Define the warp-level tensor op
642  32,
643  ElementA,
644  LayoutA,
645  ElementB,
646  LayoutB,
647  ElementC,
649  cutlass::arch::OpMultiplyAdd
650  >,
652  >;
653 
// Warp-level MMA operating on the shared-memory layouts defined above.
655  WarpShape,
656  ElementA,
657  SmemLayoutA,
658  ElementB,
659  SmemLayoutB,
660  ElementC,
661  LayoutC,
662  Policy
663  >;
664 
// Policy consumed by the pipelined threadblock-scoped MMA.
666  using MmaPolicy = MmaPolicy<
667  MmaTensorOp,
669  MatrixShape<0, 0>,
670  WarpCount::kK
671  >;
672 };
673 
674 } // namespace threadblock
675 } // namespace gemm
676 } // namespace cutlass
Template mapping a row-major view of pitch-linear memory to VoltaTensorOpMultiplicandCongruous.
Definition: tensor_op_multiplicand_sm70.h:630
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
Definition: aligned_buffer.h:35
Templates implementing loading of tiles from pitch-linear rank=2 tensors.
Query the number of threads per warp.
Definition: gemm/warp/mma.h:43
Templates implementing warp-level matrix multiply-accumulate operations targeting Tensor Cores...
Definition: default_mma_core.h:90
Templates implementing how threads are mapped to a given tile.
Definition: tensor_op_multiplicand_sm70.h:848
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
Template defining a shape used by pitch-linear operators.
Definition: pitch_linear.h:43
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
Defines a Shape template for matrix tiles.
MmaPolicy< MmaTensorOp, MatrixShape< 0, 0 >, MatrixShape< 0, 0 >, WarpCount::kK > MmaPolicy
Policy used to define MmaPipelined.
Definition: default_mma_core_sm70.h:671
Defines the size of an element in bits.
Definition: numeric_types.h:42
MmaPolicy< MmaTensorOp, MatrixShape< 0, 0 >, MatrixShape< 0, 0 >, WarpCount::kK > MmaPolicy
Policy used to define MmaPipelined.
Definition: default_mma_core_sm70.h:208
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
Definition: pitch_linear_thread_map.h:205
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
Definition: mma_tensor_op_sm70.h:77
MmaPolicy< MmaTensorOp, MatrixShape< 0, 0 >, MatrixShape< 0, 0 >, WarpCount::kK > MmaPolicy
Policy used to define MmaPipelined.
Definition: default_mma_core_sm70.h:361
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Definition: regular_tile_iterator.h:50
#define static_assert(__e, __m)
Definition: platform.h:153
Policy.
Definition: mma_tensor_op_policy.h:48
Definition: tensor_op_multiplicand_sm70.h:943
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
MmaPolicy< MmaTensorOp, MatrixShape< 0, 0 >, MatrixShape< 0, 0 >, WarpCount::kK > MmaPolicy
Policy used to define MmaPipelined.
Definition: default_mma_core_sm70.h:516
Matrix multiply-add operation.
Definition: arch/mma.h:92
Template mapping a column-major view of pitch-linear memory to VoltaTensorOpMultiplicandCongruous.
Definition: tensor_op_multiplicand_sm70.h:191
Basic include for CUTLASS.