CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
default_mma_core_wmma.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data
      layout of the global memory fragments, data types, and internal tile sizes.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/fast_math.h"
#include "cutlass/arch/wmma.h"

#if defined(CUTLASS_ARCH_WMMA_ENABLED)

#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"

#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
#include "cutlass/gemm/warp/mma_tensor_op_wmma.h"
#include "cutlass/gemm/threadblock/default_mma_core.h"
namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization:
///
///   A: column-major
///   B: row-major
///   Operator: wmma tensor op class
///
/// This uses the default warp-level wmma operator given the tile sizes.
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept: GemmShape)
    typename Shape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    /// Shape of one matrix multiply instruction (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A operand
    typename ElementA_,
    /// Data type of B operand
    typename ElementB_,
    /// Data type of accumulator
    typename ElementC_,
    /// Layout of accumulator
    typename LayoutC_,
    /// Operation performed by MMA
    typename Operator_,
    /// Number of stages
    int Stages>
struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
                      layout::ColumnMajor, ElementB_, layout::RowMajor,
                      ElementC_, LayoutC_, arch::OpClassWmmaTensorOp, Stages,
                      Operator_> {
  using Shape = Shape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using ElementA = ElementA_;
  using LayoutA = layout::ColumnMajor;
  using ElementB = ElementB_;
  using LayoutB = layout::RowMajor;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using OperatorClass = arch::OpClassWmmaTensorOp;

  /// Number of warps present
  using WarpCount = GemmShape<
    Shape::kM / WarpShape::kM,
    Shape::kN / WarpShape::kN,
    Shape::kK / WarpShape::kK
  >;

  // Divisibility requirements
  static_assert(
    !(Shape::kM % WarpShape::kM) &&
    !(Shape::kN % WarpShape::kN),
    "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
  );

  /// Number of threads per warp
  static int const kWarpSize = warp::WarpSize<arch::OpClassWmmaTensorOp>::value;

  /// Number of threads total
  static int const kThreads = WarpCount::kCount * kWarpSize;

  /// Size of a threadblock-scoped access
  static int const kAccessSizeInBits = 128;

  /// Default Operator
  using Operator = Operator_;

  //
  // Shared memory layouts
  //

  // NOTE: the shared memory layout for wmma is the same as the operands' layout in global memory
  using SmemLayoutA = LayoutA;
  using SmemLayoutB = LayoutB;

  // Pad shared memory to avoid bank conflicts
  static int const kPaddingA = 128 / sizeof_bits<ElementA>::value;
  static int const kPaddingB = 128 / sizeof_bits<ElementB>::value;
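
  // Example: for half_t operands (sizeof_bits == 16) the pad works out to
  // 128 / 16 == 8 elements, i.e. 16 bytes appended along each operand's
  // contiguous dimension to stagger tiles across shared memory banks.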

  //
  // Iterators to write to shared memory
  //

  /// ThreadMap of iterator A
  using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
    layout::PitchLinearShape<Shape::kM, Shape::kK>,
    kThreads,
    kAccessSizeInBits / sizeof_bits<ElementA>::value
  >;

  /// Shared memory iterator to A operand
  using SmemIteratorA = transform::threadblock::RegularTileIterator<
    MatrixShape<Shape::kM, Shape::kK>,
    ElementA,
    SmemLayoutA,
    1,
    IteratorThreadMapA
  >;

  /// ThreadMap of iterator B
  using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
    layout::PitchLinearShape<Shape::kN, Shape::kK>,
    kThreads,
    kAccessSizeInBits / sizeof_bits<ElementB>::value
  >;

  /// Shared memory iterator to B operand
  using SmemIteratorB = transform::threadblock::RegularTileIterator<
    MatrixShape<Shape::kK, Shape::kN>,
    ElementB,
    SmemLayoutB,
    0,
    IteratorThreadMapB
  >;

  //
  // Warp-level matrix multiply operator
  //

  // Define the warp-level tensor op
  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
    cutlass::arch::Wmma<
      InstructionShape,
      ElementA,
      LayoutA,
      ElementB,
      LayoutB,
      ElementC,
      LayoutC,
      Operator
    >,
    cutlass::MatrixShape<1, 1>
  >;

  using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma<
    WarpShape,
    ElementA,
    SmemLayoutA,
    ElementB,
    SmemLayoutB,
    ElementC,
    LayoutC,
    Policy
  >;

  /// Policy used by the threadblock-scoped GEMM
  using MmaPolicy = MmaPolicy<
    MmaTensorOp,
    MatrixShape<kPaddingA, 0>,
    MatrixShape<0, kPaddingB>,
    WarpCount::kK
  >;
};
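
// A minimal usage sketch (illustrative only, not part of the original header; the
// alias name is an assumption): resolving this column-major x row-major
// specialization for a 64x64x32 threadblock tile of half_t data with float
// accumulation and the canonical 16x16x16 WMMA instruction shape. Assumes
// arch::OpMultiplyAdd (from cutlass/arch/mma.h) is visible at this point.
using ExampleMmaCoreNT = DefaultMmaCore<
    GemmShape<64, 64, 32>,        // threadblock tile (M, N, K)
    GemmShape<32, 32, 32>,        // warp tile
    GemmShape<16, 16, 16>,        // one WMMA instruction
    half_t, layout::ColumnMajor,  // A
    half_t, layout::RowMajor,     // B
    float, layout::RowMajor,      // C
    arch::OpClassWmmaTensorOp,
    2,                            // pipeline stages
    arch::OpMultiplyAdd>;

static_assert(ExampleMmaCoreNT::WarpCount::kCount == 4 &&
              ExampleMmaCoreNT::kThreads == 128,
              "A 64x64 threadblock tile over 32x32 warp tiles yields 2x2x1 warps");
static_assert(ExampleMmaCoreNT::kPaddingA == 8,
              "128 bits / 16-bit half_t == 8 elements of padding");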

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization:
///
///   A: row-major
///   B: column-major
///   Operator: wmma tensor op class
///
/// This uses the default warp-level wmma operator given the tile sizes.
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept: GemmShape)
    typename Shape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    /// Shape of one matrix multiply instruction (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A operand
    typename ElementA_,
    /// Data type of B operand
    typename ElementB_,
    /// Data type of accumulator
    typename ElementC_,
    /// Layout of accumulator
    typename LayoutC_,
    /// Operation performed by MMA
    typename Operator_,
    /// Number of stages
    int Stages>
struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
                      layout::RowMajor, ElementB_, layout::ColumnMajor,
                      ElementC_, LayoutC_, arch::OpClassWmmaTensorOp, Stages,
                      Operator_> {
  using Shape = Shape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using ElementA = ElementA_;
  using LayoutA = layout::RowMajor;
  using ElementB = ElementB_;
  using LayoutB = layout::ColumnMajor;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using OperatorClass = arch::OpClassWmmaTensorOp;

  /// Number of warps present
  using WarpCount = GemmShape<
    Shape::kM / WarpShape::kM,
    Shape::kN / WarpShape::kN,
    Shape::kK / WarpShape::kK
  >;

  // Divisibility requirements
  static_assert(
    !(Shape::kM % WarpShape::kM) &&
    !(Shape::kN % WarpShape::kN),
    "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
  );

  /// Number of threads per warp
  static int const kWarpSize = warp::WarpSize<arch::OpClassWmmaTensorOp>::value;

  /// Number of threads total
  static int const kThreads = WarpCount::kCount * kWarpSize;

  /// Size of a threadblock-scoped access
  static int const kAccessSizeInBits = 128;

  /// Default Operator
  using Operator = Operator_;

  // Warp thread arrangement
  static int const kWarpThreadArrangementContiguousA =
      Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementA>::value);

  static int const kWarpThreadArrangementStridedA =
      kWarpSize / kWarpThreadArrangementContiguousA;

  static int const kWarpThreadArrangementContiguousB =
      Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementB>::value);

  static int const kWarpThreadArrangementStridedB =
      kWarpSize / kWarpThreadArrangementContiguousB;
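
  // Example: with 16-bit operands and Shape::kK == 32, each 128-bit access covers
  // 128 / 16 == 8 elements, so kWarpThreadArrangementContiguous{A,B} == 32 / 8 == 4
  // threads along the contiguous (K) dimension, and kWarpThreadArrangementStrided{A,B}
  // == 32 / 4 == 8 threads along the strided dimension of a 32-thread warp.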

  //
  // Shared memory layouts
  //

  // NOTE: the shared memory layout for wmma is the same as the operands' layout in global memory
  using SmemLayoutA = LayoutA;
  using SmemLayoutB = LayoutB;

  // Pad shared memory to avoid bank conflicts
  static int const kPaddingA = 128 / sizeof_bits<ElementA>::value;
  static int const kPaddingB = 128 / sizeof_bits<ElementB>::value;

  //
  // Iterators to write to shared memory
  //

  /// ThreadMap of iterator A
  using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
    layout::PitchLinearShape<Shape::kK, Shape::kM>,
    kThreads,
    kAccessSizeInBits / sizeof_bits<ElementA>::value
  >;

  /// Shared memory iterator to A operand
  using SmemIteratorA = transform::threadblock::RegularTileIterator<
    MatrixShape<Shape::kM, Shape::kK>,
    ElementA,
    SmemLayoutA,
    1,
    IteratorThreadMapA
  >;

  /// ThreadMap of iterator B
  using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
    layout::PitchLinearShape<Shape::kK, Shape::kN>,
    kThreads,
    kAccessSizeInBits / sizeof_bits<ElementB>::value
  >;

  /// Shared memory iterator to B operand
  using SmemIteratorB = transform::threadblock::RegularTileIterator<
    MatrixShape<Shape::kK, Shape::kN>,
    ElementB,
    SmemLayoutB,
    0,
    IteratorThreadMapB
  >;

  //
  // Warp-level matrix multiply operator
  //

  // Define the warp-level tensor op
  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
    cutlass::arch::Wmma<
      InstructionShape,
      ElementA,
      LayoutA,
      ElementB,
      LayoutB,
      ElementC,
      LayoutC,
      Operator
    >,
    cutlass::MatrixShape<1, 1>
  >;

  using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma<
    WarpShape,
    ElementA,
    SmemLayoutA,
    ElementB,
    SmemLayoutB,
    ElementC,
    LayoutC,
    Policy
  >;

  /// Policy used by the threadblock-scoped GEMM
  using MmaPolicy = MmaPolicy<
    MmaTensorOp,
    MatrixShape<0, kPaddingA>,
    MatrixShape<kPaddingB, 0>,
    WarpCount::kK
  >;
};
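
// In BLAS terms, this row-major A x column-major B arrangement is the "TN" case:
// both operands are contiguous along K, which is why both thread maps above use
// PitchLinearShape<Shape::kK, ...> with K as the contiguous dimension.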

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization:
///
///   A: row-major
///   B: row-major
///   Operator: wmma tensor op class
///
/// This uses the default warp-level wmma operator given the tile sizes.
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept: GemmShape)
    typename Shape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    /// Shape of one matrix multiply instruction (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A operand
    typename ElementA_,
    /// Data type of B operand
    typename ElementB_,
    /// Data type of accumulator
    typename ElementC_,
    /// Layout of accumulator
    typename LayoutC_,
    /// Operation performed by MMA
    typename Operator_,
    /// Number of stages
    int Stages>
struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
                      layout::RowMajor, ElementB_, layout::RowMajor, ElementC_,
                      LayoutC_, arch::OpClassWmmaTensorOp, Stages, Operator_> {
  using Shape = Shape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using ElementA = ElementA_;
  using LayoutA = layout::RowMajor;
  using ElementB = ElementB_;
  using LayoutB = layout::RowMajor;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using OperatorClass = arch::OpClassWmmaTensorOp;

  /// Number of warps present
  using WarpCount = GemmShape<
    Shape::kM / WarpShape::kM,
    Shape::kN / WarpShape::kN,
    Shape::kK / WarpShape::kK
  >;

  // Divisibility requirements
  static_assert(
    !(Shape::kM % WarpShape::kM) &&
    !(Shape::kN % WarpShape::kN),
    "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
  );

  /// Number of threads per warp
  static int const kWarpSize = warp::WarpSize<arch::OpClassWmmaTensorOp>::value;

  /// Number of threads total
  static int const kThreads = WarpCount::kCount * kWarpSize;

  /// Size of a threadblock-scoped access
  static int const kAccessSizeInBits = 128;

  /// Default Operator
  using Operator = Operator_;

  // Warp thread arrangement
  static int const kWarpThreadArrangementContiguousA =
      Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementA>::value);

  static int const kWarpThreadArrangementStridedA =
      kWarpSize / kWarpThreadArrangementContiguousA;

  //
  // Shared memory layouts
  //

  // NOTE: the shared memory layout for wmma is the same as the operands' layout in global memory
  using SmemLayoutA = LayoutA;
  using SmemLayoutB = LayoutB;

  // Pad shared memory to avoid bank conflicts
  static int const kPaddingA = 128 / sizeof_bits<ElementA>::value;
  static int const kPaddingB = 128 / sizeof_bits<ElementB>::value;

  //
  // Iterators to write to shared memory
  //

  /// ThreadMap of iterator A
  using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
    layout::PitchLinearShape<Shape::kK, Shape::kM>,
    kThreads,
    kAccessSizeInBits / sizeof_bits<ElementA>::value
  >;

  /// Shared memory iterator to A operand
  using SmemIteratorA = transform::threadblock::RegularTileIterator<
    MatrixShape<Shape::kM, Shape::kK>,
    ElementA,
    SmemLayoutA,
    1,
    IteratorThreadMapA
  >;

  /// ThreadMap of iterator B
  using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
    layout::PitchLinearShape<Shape::kN, Shape::kK>,
    kThreads,
    kAccessSizeInBits / sizeof_bits<ElementB>::value
  >;

  /// Shared memory iterator to B operand
  using SmemIteratorB = transform::threadblock::RegularTileIterator<
    MatrixShape<Shape::kK, Shape::kN>,
    ElementB,
    SmemLayoutB,
    0,
    IteratorThreadMapB
  >;

  //
  // Warp-level matrix multiply operator
  //

  // Define the warp-level tensor op
  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
    cutlass::arch::Wmma<
      InstructionShape,
      ElementA,
      LayoutA,
      ElementB,
      LayoutB,
      ElementC,
      LayoutC,
      Operator
    >,
    cutlass::MatrixShape<1, 1>
  >;

  using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma<
    WarpShape,
    ElementA,
    SmemLayoutA,
    ElementB,
    SmemLayoutB,
    ElementC,
    LayoutC,
    Policy
  >;

  /// Policy used by the threadblock-scoped GEMM
  using MmaPolicy = MmaPolicy<
    MmaTensorOp,
    MatrixShape<0, kPaddingA>,
    MatrixShape<0, kPaddingB>,
    WarpCount::kK
  >;
};
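
// Here only operand A has a K-contiguous warp thread arrangement; row-major B is
// contiguous along N, so IteratorThreadMapB above uses
// PitchLinearShape<Shape::kN, Shape::kK> instead of a K-major shape.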

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization:
///
///   A: column-major
///   B: column-major
///   Operator: wmma tensor op class
///
/// This uses the default warp-level wmma operator given the tile sizes.
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept: GemmShape)
    typename Shape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    /// Shape of one matrix multiply instruction (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A operand
    typename ElementA_,
    /// Data type of B operand
    typename ElementB_,
    /// Data type of accumulator
    typename ElementC_,
    /// Layout of accumulator
    typename LayoutC_,
    /// Operation performed by MMA
    typename Operator_,
    /// Number of stages
    int Stages>
struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
                      layout::ColumnMajor, ElementB_, layout::ColumnMajor,
                      ElementC_, LayoutC_, arch::OpClassWmmaTensorOp, Stages,
                      Operator_> {
  using Shape = Shape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using ElementA = ElementA_;
  using LayoutA = layout::ColumnMajor;
  using ElementB = ElementB_;
  using LayoutB = layout::ColumnMajor;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using OperatorClass = arch::OpClassWmmaTensorOp;

  /// Number of warps present
  using WarpCount = GemmShape<
    Shape::kM / WarpShape::kM,
    Shape::kN / WarpShape::kN,
    Shape::kK / WarpShape::kK
  >;

  // Divisibility requirements
  static_assert(
    !(Shape::kM % WarpShape::kM) &&
    !(Shape::kN % WarpShape::kN),
    "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
  );

  /// Number of threads per warp
  static int const kWarpSize = warp::WarpSize<arch::OpClassWmmaTensorOp>::value;

  /// Number of threads total
  static int const kThreads = WarpCount::kCount * kWarpSize;

  /// Size of a threadblock-scoped access
  static int const kAccessSizeInBits = 128;

  /// Default Operator
  using Operator = Operator_;

  // Warp thread arrangement
  static int const kWarpThreadArrangementContiguousB =
      Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementB>::value);

  static int const kWarpThreadArrangementStridedB =
      kWarpSize / kWarpThreadArrangementContiguousB;

  //
  // Shared memory layouts
  //

  // NOTE: the shared memory layout for wmma is the same as the operands' layout in global memory
  using SmemLayoutA = LayoutA;
  using SmemLayoutB = LayoutB;

  // Pad shared memory to avoid bank conflicts
  static int const kPaddingA = 128 / sizeof_bits<ElementA>::value;
  static int const kPaddingB = 128 / sizeof_bits<ElementB>::value;

  //
  // Iterators to write to shared memory
  //

  /// ThreadMap of iterator A
  using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap<
    layout::PitchLinearShape<Shape::kM, Shape::kK>,
    kThreads,
    kAccessSizeInBits / sizeof_bits<ElementA>::value
  >;

  /// Shared memory iterator to A operand
  using SmemIteratorA = transform::threadblock::RegularTileIterator<
    MatrixShape<Shape::kM, Shape::kK>,
    ElementA,
    SmemLayoutA,
    1,
    IteratorThreadMapA
  >;

  /// ThreadMap of iterator B
  using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap<
    layout::PitchLinearShape<Shape::kK, Shape::kN>,
    kThreads,
    kAccessSizeInBits / sizeof_bits<ElementB>::value
  >;

  /// Shared memory iterator to B operand
  using SmemIteratorB = transform::threadblock::RegularTileIterator<
    MatrixShape<Shape::kK, Shape::kN>,
    ElementB,
    SmemLayoutB,
    0,
    IteratorThreadMapB
  >;

  //
  // Warp-level matrix multiply operator
  //

  // Define the warp-level tensor op
  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
    cutlass::arch::Wmma<
      InstructionShape,
      ElementA,
      LayoutA,
      ElementB,
      LayoutB,
      ElementC,
      LayoutC,
      Operator
    >,
    cutlass::MatrixShape<1, 1>
  >;

  using MmaTensorOp = cutlass::gemm::warp::MmaTensorOpWmma<
    WarpShape,
    ElementA,
    SmemLayoutA,
    ElementB,
    SmemLayoutB,
    ElementC,
    LayoutC,
    Policy
  >;

  /// Policy used by the threadblock-scoped GEMM
  using MmaPolicy = MmaPolicy<
    MmaTensorOp,
    MatrixShape<kPaddingA, 0>,
    MatrixShape<kPaddingB, 0>,
    WarpCount::kK
  >;
};
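
/////////////////////////////////////////////////////////////////////////////////////////////////

// Note on the four specializations above: the shared memory pad is always applied
// along each operand's contiguous dimension, i.e. rows of a column-major tile
// (MatrixShape<kPadding, 0>) and columns of a row-major tile
// (MatrixShape<0, kPadding>), so that successive strides land in different
// shared memory banks.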

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

#endif // defined(CUTLASS_ARCH_WMMA_ENABLED)