CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
default_mma_core_sm50.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
32 #pragma once
33 
34 #include "cutlass/cutlass.h"
35 #include "cutlass/array.h"
36 
37 #include "cutlass/numeric_types.h"
38 #include "cutlass/matrix_shape.h"
39 
40 #include "cutlass/layout/matrix.h"
43 
46 
48 
// NOTE(review): this is a Doxygen source listing, not compilable source. The
// embedded source line numbers skip values (e.g. 102, 118-119, 121-122,
// 128-130, 135-137, 165, 173-175, 177-178, 187), so parts of the definition
// below are not visible here: presumably the `static_assert(` opener, the
// SmemLayoutA/SmemLayoutB declarations, the shared-memory thread maps and
// iterators, and the head of the WarpMma alias.
49 namespace cutlass {
50 namespace gemm {
51 namespace threadblock {
52 
54 
/// Partial specialization of DefaultMmaCore for SIMT math (arch::OpClassSimt)
/// with a GemmShape<1, 1, 1> instruction shape, column-major A and row-major B.
/// The literal `2` in the argument list is presumably the pipeline stage count
/// -- confirm against the primary DefaultMmaCore template.
///
/// NOTE(review): the argument list ends with `Operator_,` immediately followed
/// by `>` (listing lines 81 and 82 are consecutive). A trailing comma in a
/// template-argument list is ill-formed C++ -- confirm against the real header.
63 template <
66  typename Shape_,
68  typename WarpShape_,
70  typename ElementA_,
72  typename ElementB_,
74  typename ElementC_,
76  typename LayoutC_,
78  typename Operator_>
79 struct DefaultMmaCore<Shape_, WarpShape_, GemmShape<1, 1, 1>, ElementA_,
80  layout::ColumnMajor, ElementB_, layout::RowMajor,
81  ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_,
82  > {
83  using Shape = Shape_;
84  using WarpShape = WarpShape_;
// NOTE(review): InstructionShape_ is not one of this specialization's template
// parameters (Shape_, WarpShape_, ElementA_, ElementB_, ElementC_, LayoutC_,
// Operator_). The instruction shape is fixed to GemmShape<1, 1, 1> in the
// specialization arguments, so this alias presumably should be
// `GemmShape<1, 1, 1>` -- confirm.
85  using InstructionShape = InstructionShape_;
86  using ElementA = ElementA_;
88  using ElementB = ElementB_;
90  using ElementC = ElementC_;
91  using LayoutC = LayoutC_;
92  using OperatorClass = arch::OpClassSimt;
93 
// Number of warps tiling the threadblock tile in each of the M, N and K
// dimensions (threadblock extent divided by warp extent).
95  using WarpCount = GemmShape<
96  Shape::kM / WarpShape::kM,
97  Shape::kN / WarpShape::kN,
98  Shape::kK / WarpShape::kK
99  >;
100 
101  // Divisibility requirements
// NOTE(review): only the M and N extents are checked in the visible lines;
// Shape::kK % WarpShape::kK is not checked even though WarpCount divides by it.
103  !(Shape::kM % WarpShape::kM) &&
104  !(Shape::kN % WarpShape::kN),
105  "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
106  );
107 
// Threads per warp. NOTE(review): queried for arch::OpClassTensorOp although
// OperatorClass is arch::OpClassSimt; arch::OpClassSimt looks intended. The
// two may well yield the same value -- confirm.
109  static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
110 
// Total threads in the threadblock: one full warp per WarpCount entry.
112  static int const kThreads = WarpCount::kCount * kWarpSize;
113 
114  //
115  // Shared memory layouts
116  //
117 
120 
123 
124  //
125  // Iterators to write to shared memory
126  //
// The four declarations below appear only partially: their opening lines
// (128-130, 135-137, 144-146, 151-153) are missing from this listing.
127 
131  kThreads,
132  1
133  >;
134 
138  ElementA,
139  SmemLayoutA,
140  1,
142  >;
143 
147  kThreads,
148  1
149  >;
150 
154  ElementB,
155  SmemLayoutB,
156  0,
158  >;
159 
160  //
161  // Warp-level matrix multiply operator
162  //
163 
164  // Define the warp-level SIMT multiply-accumulate operator (WarpMma)
166  WarpShape,
167  ElementA,
168  SmemLayoutA,
169  ElementB,
170  SmemLayoutB,
171  ElementC,
172  LayoutC,
176  GemmShape<
179  1>
180  >
181  >
182  >;
183 
// Policy used to define MmaPipelined. The Doxygen tooltip for this alias lists
// two MatrixShape<0, 0> arguments; the second (line 187) is missing here.
185  using MmaPolicy = MmaPolicy<
186  WarpMma,
188  MatrixShape<0, 0>,
189  WarpCount::kK
190  >;
191 };
192 
194 
195 } // namespace threadblock
196 } // namespace gemm
197 } // namespace cutlass
Describes the size of a matrix tile.
Definition: matrix_shape.h:42
Definition: aligned_buffer.h:35
Query the number of threads per warp.
Definition: gemm/warp/mma.h:43
Definition: default_mma_core.h:90
Templates implementing how threads are mapped to a given tile.
MmaPolicy< WarpMma, MatrixShape< 0, 0 >, MatrixShape< 0, 0 >, WarpCount::kK > MmaPolicy
Policy used to define MmaPipelined.
Definition: default_mma_core_sm50.h:190
Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
Definition: mma_simt.h:74
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
Template defining a shape used by pitch-linear operators.
Definition: pitch_linear.h:43
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
Describes the arrangement and configuration of per-lane operations in warp-level matrix multiply...
Definition: mma_simt_policy.h:46
Defines a Shape template for matrix tiles.
Defines the size of an element in bits.
Definition: numeric_types.h:42
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the global memory fragments, data types, and internal tile sizes.
Top-level include for all CUTLASS numeric types.
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Definition: regular_tile_iterator.h:50
#define static_assert(__e, __m)
Definition: platform.h:153
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Templates implementing storing of tiles from pitch-linear rank=2 tensors.
Defines layout functions used by TensorRef and derived classes.
cutlass::gemm::warp::MmaSimt< WarpShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, ElementC, LayoutC, warp::MmaSimtPolicy< MatrixShape< 4, 8 >, layout::RowMajorInterleaved< 2 >, GemmShape< 128/sizeof_bits< ElementA >::value, 128/sizeof_bits< ElementB >::value, 1 > > > > WarpMma
Definition: default_mma_core_sm50.h:182
Templates implementing warp-level matrix multiply-accumulate operations.
Basic include for CUTLASS.
Definition: pitch_linear_thread_map.h:59
Definition: layout/matrix.h:237